1use axum::{
2 Json,
3 extract::{Request, State},
4 http::{StatusCode, header},
5 middleware::Next,
6 response::{IntoResponse, Response},
7};
8use chrono::{DateTime, Utc};
9use dashmap::DashMap;
10use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
11use serde::{Deserialize, Serialize};
12
13use crate::AppState;
14use once_cell::sync::Lazy;
15use serde_json::json;
16
17#[derive(Debug, Serialize, Deserialize, Clone)]
18pub struct Claims {
19 pub sub: String, pub uid: String, pub exp: i64, pub iat: i64, pub jti: String, pub rate_limit: i32, pub deployment_id: String, pub features: PlanFeatures, }
28
29#[derive(Debug, Serialize, Deserialize, Clone)]
30pub struct PlanFeatures {
31 pub chat_enabled: bool,
32 pub api_access: bool,
33 pub file_upload_limit_mb: u32,
34 pub concurrent_requests: u32,
35 pub custom_models: bool,
36 pub sbom_generation: bool,
37 pub vulnerability_scanning: bool,
38 pub signed_attestation: bool,
39 pub monthly_binaries: u32,
40}
41
42static RATE_LIMITS: Lazy<DashMap<String, (u32, DateTime<Utc>)>> = Lazy::new(DashMap::new);
44
45impl PlanFeatures {
46 pub fn default_oss() -> Self {
47 Self {
48 chat_enabled: false,
49 api_access: true,
50 file_upload_limit_mb: 10,
51 concurrent_requests: 1,
52 custom_models: false,
53 sbom_generation: true,
54 vulnerability_scanning: true,
55 signed_attestation: false,
56 monthly_binaries: 100,
57 }
58 }
59}
60
61pub async fn validate_license_jwt(
62 State(state): State<AppState>,
63 mut request: Request,
64 next: Next,
65) -> Result<Response, impl IntoResponse> {
66 if !state.config.fips_mode {
68 tracing::info!("FIPS mode disabled - using default OSS features");
69 request.extensions_mut().insert(PlanFeatures::default_oss());
71 return Ok(next.run(request).await);
72 }
73
74 let auth_header = request
76 .headers()
77 .get(header::AUTHORIZATION)
78 .and_then(|h| h.to_str().ok())
79 .and_then(|h| h.strip_prefix("Bearer "))
80 .ok_or_else(|| {
81 (
82 StatusCode::UNAUTHORIZED,
83 Json(json!({
84 "error": "missing_authorization",
85 "message": "Missing or invalid Authorization header (required in FIPS mode)"
86 })),
87 )
88 })?;
89
90 let decoding_key = DecodingKey::from_secret(&state.license_jwt_secret[..]);
92
93 let algorithm = if state.config.fips_mode {
95 Algorithm::HS256 } else {
97 Algorithm::HS256 };
99
100 let validation = Validation::new(algorithm);
101
102 let token_data = decode::<Claims>(auth_header, &decoding_key, &validation).map_err(|e| {
103 eprintln!("JWT decode error: {:?}", e);
104 (
105 StatusCode::UNAUTHORIZED,
106 Json(json!({
107 "error": "invalid_token",
108 "message": "Invalid or expired token"
109 })),
110 )
111 })?;
112
113 let claims = token_data.claims;
115 let key = format!("{}:{}", claims.sub, claims.deployment_id);
116
117 let now = Utc::now();
118 let entry = RATE_LIMITS
119 .entry(key.clone())
120 .and_modify(|entry| {
121 let (count, start) = *entry;
122 if now.signed_duration_since(start).num_seconds() >= 3600 {
123 *entry = (1, now);
125 } else {
126 *entry = (count + 1, start);
127 }
128 })
129 .or_insert((1, now));
130
131 let (current_count, _window_start) = *entry;
132
133 if current_count > claims.rate_limit as u32 {
134 return Err((
135 StatusCode::TOO_MANY_REQUESTS,
136 Json(json!({
137 "error": "rate_limit_exceeded",
138 "message": format!("Rate limit exceeded. Limit: {}, Used: {}", claims.rate_limit, current_count)
139 })),
140 ));
141 }
142
143 request.extensions_mut().insert(claims.features.clone());
145
146 Ok(next.run(request).await)
148}