nabla_cli/
middleware.rs

1use axum::{
2    Json,
3    extract::{Request, State},
4    http::{StatusCode, header},
5    middleware::Next,
6    response::{IntoResponse, Response},
7};
8use chrono::{DateTime, Utc};
9use dashmap::DashMap;
10use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
11use serde::{Deserialize, Serialize};
12
13use crate::{AppState, config::DeploymentType};
14use once_cell::sync::Lazy;
15use serde_json::json;
16
17#[derive(Debug, Serialize, Deserialize, Clone)]
18pub struct Claims {
19    pub sub: String,            // Company name (e.g., "acme-corp")
20    pub uid: String,            // User ID within the company
21    pub exp: i64,               // Expiration timestamp
22    pub iat: i64,               // Issued at timestamp
23    pub jti: String,            // JWT ID
24    pub rate_limit: i32,        // Requests per hour
25    pub deployment_id: String,  // UUID for deployment isolation
26    pub features: PlanFeatures, // Feature flags - required field
27}
28
29#[derive(Debug, Serialize, Deserialize, Clone)]
30pub struct PlanFeatures {
31    pub chat_enabled: bool,
32    pub api_access: bool,
33    pub file_upload_limit_mb: u32,
34    pub concurrent_requests: u32,
35    pub custom_models: bool,
36    pub sbom_generation: bool,
37    pub vulnerability_scanning: bool,
38    pub exploitability_analysis: bool,
39    pub signed_attestation: bool,
40    pub monthly_binaries: u32,
41}
42
43// In-memory rate limiting store
44static RATE_LIMITS: Lazy<DashMap<String, (u32, DateTime<Utc>)>> = Lazy::new(DashMap::new);
45
46impl PlanFeatures {
47    pub fn default_oss() -> Self {
48        Self {
49            chat_enabled: true,
50            api_access: true,
51            file_upload_limit_mb: 10,
52            concurrent_requests: 1,
53            custom_models: false,
54            sbom_generation: true,
55            vulnerability_scanning: true,
56            exploitability_analysis: false,
57            signed_attestation: false,
58            monthly_binaries: 100,
59        }
60    }
61}
62
63pub async fn validate_license_jwt(
64    State(state): State<AppState>,
65    mut request: Request,
66    next: Next,
67) -> Result<Response, impl IntoResponse> {
68    // If in OSS deployment, use default features and skip JWT validation
69    if state.config.deployment_type == DeploymentType::OSS {
70        tracing::info!("OSS deployment - using default features");
71        request.extensions_mut().insert(PlanFeatures::default_oss());
72        return Ok(next.run(request).await);
73    }
74
75    // In NablaSecure deployment, a valid JWT is required
76    let auth_header = request
77        .headers()
78        .get(header::AUTHORIZATION)
79        .and_then(|h| h.to_str().ok())
80        .and_then(|h| h.strip_prefix("Bearer "))
81        .ok_or_else(|| {
82            (
83                StatusCode::UNAUTHORIZED,
84                Json(json!({
85                    "error": "missing_authorization",
86                    "message": "Missing or invalid Authorization header (required for private deployment)"
87                })),
88            )
89        })?;
90
91    // Decode and validate JWT token using HMAC secret
92    let decoding_key = DecodingKey::from_secret(&state.license_jwt_secret[..]);
93    let validation = Validation::new(Algorithm::HS256);
94
95    let token_data = decode::<Claims>(auth_header, &decoding_key, &validation).map_err(|e| {
96        eprintln!("JWT decode error: {:?}", e);
97        (
98            StatusCode::UNAUTHORIZED,
99            Json(json!({
100                "error": "invalid_token",
101                "message": "Invalid or expired token"
102            })),
103        )
104    })?;
105
106    // Check rate limiting
107    let claims = token_data.claims;
108    let key = format!("{}:{}", claims.sub, claims.deployment_id);
109
110    let now = Utc::now();
111    let entry = RATE_LIMITS
112        .entry(key.clone())
113        .and_modify(|entry| {
114            let (count, start) = *entry;
115            if now.signed_duration_since(start).num_seconds() >= 3600 {
116                // Reset window
117                *entry = (1, now);
118            } else {
119                *entry = (count + 1, start);
120            }
121        })
122        .or_insert((1, now));
123
124    let (current_count, _window_start) = *entry;
125
126    if current_count > claims.rate_limit as u32 {
127        return Err((
128            StatusCode::TOO_MANY_REQUESTS,
129            Json(json!({
130                "error": "rate_limit_exceeded",
131                "message": format!("Rate limit exceeded. Limit: {}, Used: {}", claims.rate_limit, current_count)
132            })),
133        ));
134    }
135
136    // Add features from the token to request extensions
137    request.extensions_mut().insert(claims.features.clone());
138
139    // Continue with the request
140    Ok(next.run(request).await)
141}