nabla_cli/
middleware.rs

1use axum::{
2    Json,
3    extract::{Request, State},
4    http::{StatusCode, header},
5    middleware::Next,
6    response::{IntoResponse, Response},
7};
8use chrono::{DateTime, Utc};
9use dashmap::DashMap;
10use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
11use serde::{Deserialize, Serialize};
12
13use crate::AppState;
14use once_cell::sync::Lazy;
15use serde_json::json;
16
17#[derive(Debug, Serialize, Deserialize, Clone)]
18pub struct Claims {
19    pub sub: String,            // Company name (e.g., "acme-corp")
20    pub uid: String,            // User ID within the company
21    pub exp: i64,               // Expiration timestamp
22    pub iat: i64,               // Issued at timestamp
23    pub jti: String,            // JWT ID
24    pub rate_limit: i32,        // Requests per hour
25    pub deployment_id: String,  // UUID for deployment isolation
26    pub features: PlanFeatures, // Feature flags - required field
27}
28
29#[derive(Debug, Serialize, Deserialize, Clone)]
30pub struct PlanFeatures {
31    pub chat_enabled: bool,
32    pub api_access: bool,
33    pub file_upload_limit_mb: u32,
34    pub concurrent_requests: u32,
35    pub custom_models: bool,
36    pub sbom_generation: bool,
37    pub vulnerability_scanning: bool,
38    pub signed_attestation: bool,
39    pub monthly_binaries: u32,
40}
41
42// In-memory rate limiting store
43static RATE_LIMITS: Lazy<DashMap<String, (u32, DateTime<Utc>)>> = Lazy::new(DashMap::new);
44
45impl PlanFeatures {
46    pub fn default_oss() -> Self {
47        Self {
48            chat_enabled: false,
49            api_access: true,
50            file_upload_limit_mb: 10,
51            concurrent_requests: 1,
52            custom_models: false,
53            sbom_generation: true,
54            vulnerability_scanning: true,
55            signed_attestation: false,
56            monthly_binaries: 100,
57        }
58    }
59}
60
61pub async fn validate_license_jwt(
62    State(state): State<AppState>,
63    mut request: Request,
64    next: Next,
65) -> Result<Response, impl IntoResponse> {
66    // Check if FIPS mode is enabled
67    if !state.config.fips_mode {
68        tracing::info!("FIPS mode disabled - using default OSS features");
69        // Add default OSS features to request extensions for endpoints to use
70        request.extensions_mut().insert(PlanFeatures::default_oss());
71        return Ok(next.run(request).await);
72    }
73
74    // 1. Extract Authorization header
75    let auth_header = request
76        .headers()
77        .get(header::AUTHORIZATION)
78        .and_then(|h| h.to_str().ok())
79        .and_then(|h| h.strip_prefix("Bearer "))
80        .ok_or_else(|| {
81            (
82                StatusCode::UNAUTHORIZED,
83                Json(json!({
84                    "error": "missing_authorization",
85                    "message": "Missing or invalid Authorization header (required in FIPS mode)"
86                })),
87            )
88        })?;
89
90    // 2. Decode and validate JWT token using HMAC secret
91    let decoding_key = DecodingKey::from_secret(&state.license_jwt_secret[..]);
92
93    // Use FIPS-compliant algorithm when FIPS mode is enabled
94    let algorithm = if state.config.fips_mode {
95        Algorithm::HS256 // FIPS-approved HMAC-SHA256
96    } else {
97        Algorithm::HS256 // Default to HS256 for consistency
98    };
99
100    let validation = Validation::new(algorithm);
101
102    let token_data = decode::<Claims>(auth_header, &decoding_key, &validation).map_err(|e| {
103        eprintln!("JWT decode error: {:?}", e);
104        (
105            StatusCode::UNAUTHORIZED,
106            Json(json!({
107                "error": "invalid_token",
108                "message": "Invalid or expired token"
109            })),
110        )
111    })?;
112
113    // 3. Check rate limiting
114    let claims = token_data.claims;
115    let key = format!("{}:{}", claims.sub, claims.deployment_id);
116
117    let now = Utc::now();
118    let entry = RATE_LIMITS
119        .entry(key.clone())
120        .and_modify(|entry| {
121            let (count, start) = *entry;
122            if now.signed_duration_since(start).num_seconds() >= 3600 {
123                // Reset window
124                *entry = (1, now);
125            } else {
126                *entry = (count + 1, start);
127            }
128        })
129        .or_insert((1, now));
130
131    let (current_count, _window_start) = *entry;
132
133    if current_count > claims.rate_limit as u32 {
134        return Err((
135            StatusCode::TOO_MANY_REQUESTS,
136            Json(json!({
137                "error": "rate_limit_exceeded",
138                "message": format!("Rate limit exceeded. Limit: {}, Used: {}", claims.rate_limit, current_count)
139            })),
140        ));
141    }
142
143    // 4. Add features to request extensions for endpoints to use
144    request.extensions_mut().insert(claims.features.clone());
145
146    // 5. Continue with the request
147    Ok(next.run(request).await)
148}