nabla_cli/cli/
jwt_store.rs

1use anyhow::{Result, anyhow};
2use home::home_dir;
3use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
4use serde::{Deserialize, Serialize};
5use std::fs;
6use std::path::PathBuf;
7
8#[derive(Debug, Serialize, Deserialize, Clone)]
9pub struct JwtData {
10    pub token: String,
11    pub sub: String,
12    pub deployment_id: String,
13    pub expires_at: i64,
14    pub features: PlanFeatures,
15}
16
17#[derive(Debug, Serialize, Deserialize, Clone)]
18pub struct PlanFeatures {
19    pub chat_enabled: bool,
20    pub api_access: bool,
21    pub file_upload_limit_mb: u32,
22    pub concurrent_requests: u32,
23    pub custom_models: bool,
24    pub sbom_generation: bool,
25    pub vulnerability_scanning: bool,
26    pub signed_attestation: bool,
27    pub monthly_binaries: u32,
28}
29
30#[derive(Debug, Serialize, Deserialize)]
31struct JwtClaims {
32    pub sub: String,
33    pub deployment_id: String,
34    pub exp: i64,
35    pub features: PlanFeatures,
36}
37
/// Handle to the on-disk location where a single JWT record is persisted.
pub struct JwtStore {
    /// Full path of the JSON file that holds the stored token record.
    store_path: PathBuf,
}
41
42impl JwtStore {
43    pub fn new() -> Result<Self> {
44        let home = home_dir().ok_or_else(|| anyhow!("Could not find home directory"))?;
45        let nabla_dir = home.join(".nabla");
46
47        if !nabla_dir.exists() {
48            fs::create_dir_all(&nabla_dir)?;
49        }
50
51        Ok(Self {
52            store_path: nabla_dir.join("jwt.json"),
53        })
54    }
55
56    pub fn save_jwt(&self, jwt_data: &JwtData) -> Result<()> {
57        let json = serde_json::to_string_pretty(jwt_data)?;
58        fs::write(&self.store_path, json)?;
59        Ok(())
60    }
61
62    pub fn load_jwt(&self) -> Result<Option<JwtData>> {
63        if !self.store_path.exists() {
64            return Ok(None);
65        }
66
67        let content = fs::read_to_string(&self.store_path)?;
68        let jwt_data: JwtData = serde_json::from_str(&content)?;
69
70        // Check if token is expired
71        let now = chrono::Utc::now().timestamp();
72        if jwt_data.expires_at < now {
73            self.clear_jwt()?;
74            return Ok(None);
75        }
76
77        Ok(Some(jwt_data))
78    }
79
80    pub fn clear_jwt(&self) -> Result<()> {
81        if self.store_path.exists() {
82            fs::remove_file(&self.store_path)?;
83        }
84        Ok(())
85    }
86
87    pub fn verify_and_store_jwt(&self, jwt_token: &str) -> Result<JwtData> {
88        // TODO: Replace with your actual signing key - this should be the same key used in your backend
89        // For now using a placeholder - you'll need to set this to your actual signing key
90        let signing_key = std::env::var("NABLA_JWT_SECRET").map_err(|_| {
91            anyhow!("NABLA_JWT_SECRET environment variable is required for JWT verification")
92        })?;
93
94        let key = DecodingKey::from_secret(signing_key.as_ref());
95        let mut validation = Validation::new(Algorithm::HS256);
96        validation.validate_exp = true;
97
98        // Decode and verify the JWT
99        let token_data = decode::<JwtClaims>(jwt_token, &key, &validation)
100            .map_err(|e| anyhow!("JWT verification failed: {}", e))?;
101
102        let claims = token_data.claims;
103
104        let jwt_data = JwtData {
105            token: jwt_token.to_string(),
106            sub: claims.sub,
107            deployment_id: claims.deployment_id,
108            expires_at: claims.exp,
109            features: claims.features,
110        };
111
112        // Store the verified JWT
113        self.save_jwt(&jwt_data)?;
114
115        Ok(jwt_data)
116    }
117}