nabla_cli/cli/
jwt_store.rs

1use anyhow::{Result, anyhow};
2use base64::{Engine as _, engine::general_purpose};
3use home::home_dir;
4use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
5use serde::{Deserialize, Serialize};
6use std::fs;
7use std::path::PathBuf;
8
9#[cfg(feature = "private")]
10use doppler_rs::{apis::client::Client, apis::Error as DopplerError};
11
12#[derive(Debug, Serialize, Deserialize, Clone)]
13pub struct JwtData {
14    pub token: String,
15    pub sub: String,
16    pub deployment_id: String,
17    pub expires_at: i64,
18    pub features: PlanFeatures,
19}
20
21#[derive(Debug, Serialize, Deserialize, Clone)]
22pub struct PlanFeatures {
23    pub chat_enabled: bool,
24    pub api_access: bool,
25    pub file_upload_limit_mb: u32,
26    pub concurrent_requests: u32,
27    pub custom_models: bool,
28    pub sbom_generation: bool,
29    pub vulnerability_scanning: bool,
30    pub signed_attestation: bool,
31    pub monthly_binaries: u32,
32}
33
34#[derive(Debug, Serialize, Deserialize)]
35struct JwtClaims {
36    pub sub: String,
37    pub deployment_id: String,
38    pub exp: i64,
39    pub features: PlanFeatures,
40}
41
42pub struct JwtStore {
43    store_path: PathBuf,
44}
45
46impl JwtStore {
47    pub fn new() -> Result<Self> {
48        let home = home_dir().ok_or_else(|| anyhow!("Could not find home directory"))?;
49        let nabla_dir = home.join(".nabla");
50
51        if !nabla_dir.exists() {
52            fs::create_dir_all(&nabla_dir)?;
53        }
54
55        Ok(Self {
56            store_path: nabla_dir.join("jwt.json"),
57        })
58    }
59
60    pub fn save_jwt(&self, jwt_data: &JwtData) -> Result<()> {
61        let json = serde_json::to_string_pretty(jwt_data)?;
62        fs::write(&self.store_path, json)?;
63        Ok(())
64    }
65
66    pub fn load_jwt(&self) -> Result<Option<JwtData>> {
67        if !self.store_path.exists() {
68            return Ok(None);
69        }
70
71        let content = fs::read_to_string(&self.store_path)?;
72        let jwt_data: JwtData = serde_json::from_str(&content)?;
73
74        // Check if token is expired
75        let now = chrono::Utc::now().timestamp();
76        if jwt_data.expires_at < now {
77            self.clear_jwt()?;
78            return Ok(None);
79        }
80
81        Ok(Some(jwt_data))
82    }
83
84    pub fn clear_jwt(&self) -> Result<()> {
85        if self.store_path.exists() {
86            fs::remove_file(&self.store_path)?;
87        }
88        Ok(())
89    }
90
91    pub fn verify_and_store_jwt(&self, jwt_token: &str) -> Result<JwtData> {
92        // Get signing key using the same logic as config.rs
93        let signing_key_b64 = self.get_license_signing_key()?;
94
95        // Decode the base64 key like the minting tool does
96        let key_bytes = general_purpose::URL_SAFE_NO_PAD.decode(signing_key_b64.trim())
97            .map_err(|e| anyhow!("Failed to decode LICENSE_SIGNING_KEY as base64: {}", e))?;
98
99        let key = DecodingKey::from_secret(&key_bytes);
100        let mut validation = Validation::new(Algorithm::HS256);
101        validation.validate_exp = true;
102
103        // Decode and verify the JWT
104        let token_data = decode::<JwtClaims>(jwt_token, &key, &validation)
105            .map_err(|e| anyhow!("JWT verification failed: {}", e))?;
106
107        let claims = token_data.claims;
108
109        let jwt_data = JwtData {
110            token: jwt_token.to_string(),
111            sub: claims.sub,
112            deployment_id: claims.deployment_id,
113            expires_at: claims.exp,
114            features: claims.features,
115        };
116
117        // Store the verified JWT
118        self.save_jwt(&jwt_data)?;
119
120        Ok(jwt_data)
121    }
122
123    fn get_license_signing_key(&self) -> Result<String> {
124        // Determine deployment type
125        let deployment_type = std::env::var("NABLA_DEPLOYMENT")
126            .unwrap_or_else(|_| "oss".to_string());
127
128        match deployment_type.to_lowercase().as_str() {
129            "oss" => {
130                // Use default OSS key
131                Ok("t6eLp6y0Ly8BZJIVv_wK71WyBtJ1zY2Pxz2M_0z5t8Q".to_string())
132            }
133            "private" => {
134                #[cfg(feature = "private")]
135                {
136                    // Try Doppler first for private deployments
137                    if let (Ok(project), Ok(config_name)) = (
138                        std::env::var("DOPPLER_PROJECT"),
139                        std::env::var("DOPPLER_CONFIG")
140                    ) {
141                        if let Ok(doppler_token) = std::env::var("DOPPLER_TOKEN") {
142                            let client = Client::new(&doppler_token);
143                            if let Ok(secret) = client.get_secret(&project, &config_name, "LICENSE_SIGNING_KEY") {
144                                return Ok(secret.value);
145                            }
146                        }
147                    }
148                }
149                
150                // Fallback to environment variable
151                std::env::var("LICENSE_SIGNING_KEY")
152                    .map_err(|_| anyhow!("LICENSE_SIGNING_KEY environment variable is required for JWT verification"))
153            }
154            _ => {
155                // Invalid deployment type, fallback to environment variable
156                std::env::var("LICENSE_SIGNING_KEY")
157                    .map_err(|_| anyhow!("LICENSE_SIGNING_KEY environment variable is required for JWT verification"))
158            }
159        }
160    }
161}