nabla_cli/
lib.rs

1// src/lib.rs
2use std::sync::Arc;
3pub mod binary;
4pub mod config;
5pub mod middleware;
6pub mod routes;
7pub mod ssrf_protection; // Add SSRF protection module
8// pub mod providers; // Using enterprise providers instead
9pub mod cli;
10pub mod enterprise;
11
// Re-export `Config` at the crate root so integration tests can build an `AppState` (and routers) easily.
13pub use config::Config;
14use reqwest::Client;
15
16#[derive(Clone)]
17pub struct AppState {
18    pub config: Config,
19    pub client: Client,
20    pub base_url: String,
21    pub license_jwt_secret: Arc<[u8; 32]>,
22    pub crypto_provider: enterprise::crypto::CryptoProvider,
23    pub inference_manager: Arc<enterprise::providers::InferenceManager>, // add this
24}
25
// The binary crate (main.rs) currently defines its own `AppState` as well. The two
// definitions do not conflict because the bin and lib are separate crates, but they can
// drift apart; prefer importing this library's `AppState` from main.rs.
// NOTE(review): despite an earlier comment, neither definition is `cfg`-gated.
29
30// Re-export the server function from the binary crate
31pub mod server {
32    pub async fn run_server(port: u16) -> anyhow::Result<()> {
33        use crate::enterprise::providers::InferenceManager;
34        use axum::extract::DefaultBodyLimit;
35        use axum::{
36            Router,
37            http::{Method, header},
38            routing::post,
39        };
40        use base64::Engine;
41        use dotenvy::dotenv;
42        use std::sync::Arc;
43        use tower_http::cors::{Any, CorsLayer};
44        use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
45
46        dotenv().ok();
47        let config = crate::Config::from_env()?;
48
49        // Use consistent key loading from config
50        let key_b64 = config.license_signing_key.clone();
51
52        let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(key_b64.trim())?;
53        let secret_array: [u8; 32] = decoded
54            .try_into()
55            .map_err(|_| anyhow::anyhow!("LICENSE_SIGNING_KEY must be exactly 32 bytes"))?;
56        let license_jwt_secret = Arc::new(secret_array);
57
58        tracing_subscriber::registry()
59            .with(
60                tracing_subscriber::EnvFilter::try_from_default_env()
61                    .unwrap_or_else(|_| "nabla=debug,tower_http=debug".into()),
62            )
63            .with(tracing_subscriber::fmt::layer())
64            .init();
65
66        let client = reqwest::Client::builder()
67            .redirect(reqwest::redirect::Policy::none()) // disable redirects for SSRF protection
68            .build()?;
69
70        let inference_manager = Arc::new(InferenceManager::new());
71
72
73        let crypto_provider = crate::enterprise::crypto::CryptoProvider::new(
74            config.fips_mode,
75            config.fips_validation,
76        )?;
77
78        let state = crate::AppState {
79            config: config.clone(),
80            client,
81            base_url: config.base_url.clone(),
82            license_jwt_secret,
83            crypto_provider,
84            inference_manager,
85        };
86
87        let cors = CorsLayer::new()
88            .allow_origin(Any)
89            .allow_methods([Method::GET, Method::POST])
90            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION]);
91
92        // Create middleware layer that validates API keys & enforces quotas
93        let auth_layer = axum::middleware::from_fn_with_state(
94            state.clone(),
95            crate::middleware::validate_license_jwt,
96        );
97
98        // Public routes (no auth)
99        let public_routes = Router::new().route(
100            "/health",
101            axum::routing::get(crate::routes::binary::health_check),
102        );
103
104        // Protected routes (with auth)
105        let protected_routes = Router::new()
106            .route(
107                "/binary/analyze",
108                post(crate::routes::binary::upload_and_analyze_binary),
109            )
110            .route("/binary/diff", post(crate::routes::binary::diff_binaries))
111            .route(
112                "/binary/attest",
113                post(crate::enterprise::attestation::attest_binary),
114            )
115            .route("/binary/check-cves", post(crate::routes::binary::check_cve))
116            .route(
117                "/binary/chat",
118                post(crate::routes::binary::chat_with_binary),
119            )
120            .route_layer(auth_layer);
121
122        let app = Router::new()
123            .merge(public_routes)
124            .merge(protected_routes)
125            .layer(cors)
126            .layer(DefaultBodyLimit::max(64 * 1024 * 1024))
127            .with_state(state);
128
129        let listener = tokio::net::TcpListener::bind(&format!("0.0.0.0:{}", port)).await?;
130        tracing::info!("Server starting on port {}", port);
131
132        axum::serve(listener, app).await?;
133        Ok(())
134    }
135}