// nabla_cli/cli/jwt_store.rs

1use anyhow::{Result, anyhow};
2use base64::{Engine as _, engine::general_purpose};
3use home::home_dir;
4use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
5use serde::{Deserialize, Serialize};
6use std::fs;
7use std::path::PathBuf;
8
9#[derive(Debug, Serialize, Deserialize, Clone)]
10pub struct JwtData {
11    pub token: String,
12    pub sub: String,
13    pub deployment_id: String,
14    pub expires_at: i64,
15    pub features: PlanFeatures,
16}
17
18#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct PlanFeatures {
20    pub chat_enabled: bool,
21    pub api_access: bool,
22    pub file_upload_limit_mb: u32,
23    pub concurrent_requests: u32,
24    pub custom_models: bool,
25    pub sbom_generation: bool,
26    pub vulnerability_scanning: bool,
27    pub signed_attestation: bool,
28    pub monthly_binaries: u32,
29}
30
31#[derive(Debug, Serialize, Deserialize)]
32struct JwtClaims {
33    pub sub: String,
34    pub deployment_id: String,
35    pub exp: i64,
36    pub features: PlanFeatures,
37}
38
/// On-disk cache for the current license JWT.
///
/// [`JwtStore::new`] points this at `~/.nabla/jwt.json`.
pub struct JwtStore {
    /// Location of the JSON file holding the serialized [`JwtData`].
    store_path: PathBuf,
}
42
43impl JwtStore {
44    pub fn new() -> Result<Self> {
45        let home = home_dir().ok_or_else(|| anyhow!("Could not find home directory"))?;
46        let nabla_dir = home.join(".nabla");
47
48        if !nabla_dir.exists() {
49            fs::create_dir_all(&nabla_dir)?;
50        }
51
52        Ok(Self {
53            store_path: nabla_dir.join("jwt.json"),
54        })
55    }
56
57    pub fn save_jwt(&self, jwt_data: &JwtData) -> Result<()> {
58        let json = serde_json::to_string_pretty(jwt_data)?;
59        fs::write(&self.store_path, json)?;
60        Ok(())
61    }
62
63    pub fn load_jwt(&self) -> Result<Option<JwtData>> {
64        if !self.store_path.exists() {
65            return Ok(None);
66        }
67
68        let content = fs::read_to_string(&self.store_path)?;
69        let jwt_data: JwtData = serde_json::from_str(&content)?;
70
71        // Check if token is expired
72        let now = chrono::Utc::now().timestamp();
73        if jwt_data.expires_at < now {
74            self.clear_jwt()?;
75            return Ok(None);
76        }
77
78        Ok(Some(jwt_data))
79    }
80
81    pub fn clear_jwt(&self) -> Result<()> {
82        if self.store_path.exists() {
83            fs::remove_file(&self.store_path)?;
84        }
85        Ok(())
86    }
87
88    pub fn verify_and_store_jwt(&self, jwt_token: &str) -> Result<JwtData> {
89        // Get signing key using the same logic as config.rs
90        let signing_key_b64 = self.get_license_signing_key()?;
91
92        // Decode the base64 key like the minting tool does
93        let key_bytes = general_purpose::URL_SAFE_NO_PAD
94            .decode(signing_key_b64.trim())
95            .map_err(|e| anyhow!("Failed to decode LICENSE_SIGNING_KEY as base64: {}", e))?;
96
97        let key = DecodingKey::from_secret(&key_bytes);
98        let mut validation = Validation::new(Algorithm::HS256);
99        validation.validate_exp = true;
100
101        // Decode and verify the JWT
102        let token_data = decode::<JwtClaims>(jwt_token, &key, &validation)
103            .map_err(|e| anyhow!("JWT verification failed: {}", e))?;
104
105        let claims = token_data.claims;
106
107        let jwt_data = JwtData {
108            token: jwt_token.to_string(),
109            sub: claims.sub,
110            deployment_id: claims.deployment_id,
111            expires_at: claims.exp,
112            features: claims.features,
113        };
114
115        // Store the verified JWT
116        self.save_jwt(&jwt_data)?;
117
118        Ok(jwt_data)
119    }
120
121    fn get_license_signing_key(&self) -> Result<String> {
122        // Try environment variable first (fastest)
123        if let Ok(key) = std::env::var("LICENSE_SIGNING_KEY") {
124            return Ok(key);
125        }
126
127        // Try Doppler API via HTTP for both OSS and NablaSecure deployments
128        if let (Ok(project), Ok(config_name)) = (
129            std::env::var("DOPPLER_PROJECT"),
130            std::env::var("DOPPLER_CONFIG"),
131        ) {
132            // Try deployment-specific token first, then fall back to general token
133            let doppler_token = if config_name.contains("prd") {
134                std::env::var("DOPPLER_TOKEN_PRD").or_else(|_| std::env::var("DOPPLER_TOKEN"))
135            } else if config_name.contains("oss") {
136                std::env::var("DOPPLER_TOKEN_OSS").or_else(|_| std::env::var("DOPPLER_TOKEN"))
137            } else {
138                std::env::var("DOPPLER_TOKEN")
139            };
140
141            if let Ok(token) = doppler_token {
142                // Use ureq for sync HTTP requests (no runtime conflicts)
143                let url = format!(
144                    "https://api.doppler.com/v3/configs/config/secret?project={}&config={}&name=LICENSE_SIGNING_KEY",
145                    project, config_name
146                );
147
148                if let Ok(response) = ureq::get(&url)
149                    .set("Authorization", &format!("Bearer {}", token))
150                    .call()
151                {
152                    if let Ok(json) = response.into_json::<serde_json::Value>() {
153                        if let Some(value) = json
154                            .get("value")
155                            .and_then(|v| v.get("computed"))
156                            .and_then(|c| c.as_str())
157                        {
158                            return Ok(value.to_string());
159                        }
160                    }
161                }
162            }
163        }
164
165        // Final fallback based on deployment type
166        let deployment_type =
167            std::env::var("NABLA_DEPLOYMENT").unwrap_or_else(|_| "oss".to_string());
168
169        match deployment_type.to_lowercase().as_str() {
170            "oss" => {
171                // Hardcoded public key for OSS deployments as last resort
172                Ok("t6eLp6y0Ly8BZJIVv_wK71WyBtJ1zY2Pxz2M_0z5t8Q".to_string())
173            }
174            "private" => Err(anyhow!(
175                "LICENSE_SIGNING_KEY required for private deployment (try Doppler CLI or env var)"
176            )),
177            _ => {
178                // Invalid deployment type, try OSS fallback
179                Ok("t6eLp6y0Ly8BZJIVv_wK71WyBtJ1zY2Pxz2M_0z5t8Q".to_string())
180            }
181        }
182    }
183}