1use axum::{
2 Json,
3 extract::{Request, State},
4 http::{StatusCode, header},
5 middleware::Next,
6 response::{IntoResponse, Response},
7};
8use chrono::{DateTime, Utc};
9use dashmap::DashMap;
10use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
11use serde::{Deserialize, Serialize};
12
13use crate::{AppState, config::DeploymentType};
14use once_cell::sync::Lazy;
15use serde_json::json;
16
17#[derive(Debug, Serialize, Deserialize, Clone)]
18pub struct Claims {
19 pub sub: String, pub uid: String, pub exp: i64, pub iat: i64, pub jti: String, pub rate_limit: i32, pub deployment_id: String, pub features: PlanFeatures, }
28
29#[derive(Debug, Serialize, Deserialize, Clone)]
30pub struct PlanFeatures {
31 pub chat_enabled: bool,
32 pub api_access: bool,
33 pub file_upload_limit_mb: u32,
34 pub concurrent_requests: u32,
35 pub custom_models: bool,
36 pub sbom_generation: bool,
37 pub vulnerability_scanning: bool,
38 pub exploitability_analysis: bool,
39 pub signed_attestation: bool,
40 pub monthly_binaries: u32,
41}
42
43static RATE_LIMITS: Lazy<DashMap<String, (u32, DateTime<Utc>)>> = Lazy::new(DashMap::new);
45
46impl PlanFeatures {
47 pub fn default_oss() -> Self {
48 Self {
49 chat_enabled: true,
50 api_access: true,
51 file_upload_limit_mb: 10,
52 concurrent_requests: 1,
53 custom_models: false,
54 sbom_generation: true,
55 vulnerability_scanning: true,
56 exploitability_analysis: false,
57 signed_attestation: false,
58 monthly_binaries: 100,
59 }
60 }
61}
62
63pub async fn validate_license_jwt(
64 State(state): State<AppState>,
65 mut request: Request,
66 next: Next,
67) -> Result<Response, impl IntoResponse> {
68 if state.config.deployment_type == DeploymentType::OSS {
70 tracing::info!("OSS deployment - using default features");
71 request.extensions_mut().insert(PlanFeatures::default_oss());
72 return Ok(next.run(request).await);
73 }
74
75 let auth_header = request
77 .headers()
78 .get(header::AUTHORIZATION)
79 .and_then(|h| h.to_str().ok())
80 .and_then(|h| h.strip_prefix("Bearer "))
81 .ok_or_else(|| {
82 (
83 StatusCode::UNAUTHORIZED,
84 Json(json!({
85 "error": "missing_authorization",
86 "message": "Missing or invalid Authorization header (required for private deployment)"
87 })),
88 )
89 })?;
90
91 let decoding_key = DecodingKey::from_secret(&state.license_jwt_secret[..]);
93 let validation = Validation::new(Algorithm::HS256);
94
95 let token_data = decode::<Claims>(auth_header, &decoding_key, &validation).map_err(|e| {
96 eprintln!("JWT decode error: {:?}", e);
97 (
98 StatusCode::UNAUTHORIZED,
99 Json(json!({
100 "error": "invalid_token",
101 "message": "Invalid or expired token"
102 })),
103 )
104 })?;
105
106 let claims = token_data.claims;
108 let key = format!("{}:{}", claims.sub, claims.deployment_id);
109
110 let now = Utc::now();
111 let entry = RATE_LIMITS
112 .entry(key.clone())
113 .and_modify(|entry| {
114 let (count, start) = *entry;
115 if now.signed_duration_since(start).num_seconds() >= 3600 {
116 *entry = (1, now);
118 } else {
119 *entry = (count + 1, start);
120 }
121 })
122 .or_insert((1, now));
123
124 let (current_count, _window_start) = *entry;
125
126 if current_count > claims.rate_limit as u32 {
127 return Err((
128 StatusCode::TOO_MANY_REQUESTS,
129 Json(json!({
130 "error": "rate_limit_exceeded",
131 "message": format!("Rate limit exceeded. Limit: {}, Used: {}", claims.rate_limit, current_count)
132 })),
133 ));
134 }
135
136 request.extensions_mut().insert(claims.features.clone());
138
139 Ok(next.run(request).await)
141}