nabla_cli/
lib.rs

1// src/lib.rs
2use std::sync::Arc;
3pub mod binary;
4pub mod config;
5pub mod middleware;
6pub mod routes;
7pub mod ssrf_protection; // Add SSRF protection module
8// pub mod providers; // Using enterprise providers instead
9pub mod cli;
10pub mod enterprise;
11
// Re-export `Config` at the crate root; together with the public `AppState` below, this lets integration tests build routers easily.
13pub use config::Config;
14use reqwest::Client;
15
16#[derive(Clone)]
17pub struct AppState {
18    pub config: Config,
19    pub client: Client,
20    pub base_url: String,
21    pub enterprise_features: bool,
22    pub license_jwt_secret: Arc<[u8; 32]>,
23}
24
// NOTE: the binary crate's main.rs still has its own AppState. The duplicate definition
// compiles because the binary and the library are separate crates, but the two copies can
// drift apart; consider having main.rs use this one instead.
28
29// Re-export the server function from the binary crate
30pub mod server {
31    pub async fn run_server(port: u16) -> anyhow::Result<()> {
32        use axum::extract::DefaultBodyLimit;
33        use axum::{
34            Router,
35            http::{Method, header},
36            routing::post,
37        };
38        use base64::Engine;
39        use dotenvy::dotenv;
40        use std::sync::Arc;
41        use tower_http::cors::{Any, CorsLayer};
42        use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
43
44        dotenv().ok();
45        let config = crate::Config::from_env()?;
46
47        // Use consistent key loading from config
48        let key_b64 = config.license_signing_key.clone();
49
50        let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(key_b64.trim())?;
51        let secret_array: [u8; 32] = decoded
52            .try_into()
53            .map_err(|_| anyhow::anyhow!("LICENSE_SIGNING_KEY must be exactly 32 bytes"))?;
54        let license_jwt_secret = Arc::new(secret_array);
55
56        tracing_subscriber::registry()
57            .with(
58                tracing_subscriber::EnvFilter::try_from_default_env()
59                    .unwrap_or_else(|_| "nabla=debug,tower_http=debug".into()),
60            )
61            .with(tracing_subscriber::fmt::layer())
62            .init();
63
64        let client = reqwest::Client::builder()
65            .redirect(reqwest::redirect::Policy::none()) // disable redirects for SSRF protection
66            .build()?;
67
68        let state = crate::AppState {
69            config: config.clone(),
70            client,
71            base_url: config.base_url.clone(),
72            enterprise_features: config.enterprise_features,
73            license_jwt_secret,
74        };
75
76        let cors = CorsLayer::new()
77            .allow_origin(Any)
78            .allow_methods([Method::GET, Method::POST])
79            .allow_headers([header::CONTENT_TYPE, header::AUTHORIZATION]);
80
81        // Create middleware layer that validates API keys & enforces quotas
82        let auth_layer = axum::middleware::from_fn_with_state(
83            state.clone(),
84            crate::middleware::validate_license_jwt,
85        );
86
87        // Public routes (no auth)
88        let public_routes = Router::new().route(
89            "/health",
90            axum::routing::get(crate::routes::binary::health_check),
91        );
92
93        // Protected routes (with auth)
94        let protected_routes = Router::new()
95            .route(
96                "/binary/analyze",
97                post(crate::routes::binary::upload_and_analyze_binary),
98            )
99            .route("/binary/diff", post(crate::routes::binary::diff_binaries))
100            .route(
101                "/binary/attest",
102                post(crate::enterprise::attestation::attest_binary),
103            )
104            .route("/binary/check-cves", post(crate::routes::binary::check_cve))
105            .route_layer(auth_layer);
106
107        let app = Router::new()
108            .merge(public_routes)
109            .merge(protected_routes)
110            .layer(cors)
111            .layer(DefaultBodyLimit::max(64 * 1024 * 1024))
112            .with_state(state);
113
114        let listener = tokio::net::TcpListener::bind(&format!("0.0.0.0:{}", port)).await?;
115        tracing::info!("Server starting on port {}", port);
116
117        axum::serve(listener, app).await?;
118        Ok(())
119    }
120}