// nabla_cli/cli/jwt_store.rs
use anyhow::{Result, anyhow};
use base64::{Engine as _, engine::general_purpose};
use home::home_dir;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
use serde::{Deserialize, Serialize};
use std::fs;
use std::io::{ErrorKind, Write};
use std::path::PathBuf;
use std::time::Duration;

9#[derive(Debug, Serialize, Deserialize, Clone)]
10pub struct JwtData {
11 pub token: String,
12 pub sub: String,
13 pub deployment_id: String,
14 pub expires_at: i64,
15 pub features: PlanFeatures,
16}
17
18#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct PlanFeatures {
20 pub chat_enabled: bool,
21 pub api_access: bool,
22 pub file_upload_limit_mb: u32,
23 pub concurrent_requests: u32,
24 pub custom_models: bool,
25 pub sbom_generation: bool,
26 pub vulnerability_scanning: bool,
27 pub signed_attestation: bool,
28 pub monthly_binaries: u32,
29}
30
31#[derive(Debug, Serialize, Deserialize)]
32struct JwtClaims {
33 pub sub: String,
34 pub deployment_id: String,
35 pub exp: i64,
36 pub features: PlanFeatures,
37}
38
/// File-backed persistence for the CLI's license JWT.
pub struct JwtStore {
    /// Path of the JSON file holding the token (`~/.nabla/jwt.json` when built via `new`).
    store_path: PathBuf,
}

43impl JwtStore {
44 pub fn new() -> Result<Self> {
45 let home = home_dir().ok_or_else(|| anyhow!("Could not find home directory"))?;
46 let nabla_dir = home.join(".nabla");
47
48 if !nabla_dir.exists() {
49 fs::create_dir_all(&nabla_dir)?;
50 }
51
52 Ok(Self {
53 store_path: nabla_dir.join("jwt.json"),
54 })
55 }
56
57 pub fn save_jwt(&self, jwt_data: &JwtData) -> Result<()> {
58 let json = serde_json::to_string_pretty(jwt_data)?;
59 fs::write(&self.store_path, json)?;
60 Ok(())
61 }
62
63 pub fn load_jwt(&self) -> Result<Option<JwtData>> {
64 if !self.store_path.exists() {
65 return Ok(None);
66 }
67
68 let content = fs::read_to_string(&self.store_path)?;
69 let jwt_data: JwtData = serde_json::from_str(&content)?;
70
71 let now = chrono::Utc::now().timestamp();
73 if jwt_data.expires_at < now {
74 self.clear_jwt()?;
75 return Ok(None);
76 }
77
78 Ok(Some(jwt_data))
79 }
80
81 pub fn clear_jwt(&self) -> Result<()> {
82 if self.store_path.exists() {
83 fs::remove_file(&self.store_path)?;
84 }
85 Ok(())
86 }
87
88 pub fn verify_and_store_jwt(&self, jwt_token: &str) -> Result<JwtData> {
89 let signing_key_b64 = self.get_license_signing_key()?;
91
92 let key_bytes = general_purpose::URL_SAFE_NO_PAD.decode(signing_key_b64.trim())
94 .map_err(|e| anyhow!("Failed to decode LICENSE_SIGNING_KEY as base64: {}", e))?;
95
96 let key = DecodingKey::from_secret(&key_bytes);
97 let mut validation = Validation::new(Algorithm::HS256);
98 validation.validate_exp = true;
99
100 let token_data = decode::<JwtClaims>(jwt_token, &key, &validation)
102 .map_err(|e| anyhow!("JWT verification failed: {}", e))?;
103
104 let claims = token_data.claims;
105
106 let jwt_data = JwtData {
107 token: jwt_token.to_string(),
108 sub: claims.sub,
109 deployment_id: claims.deployment_id,
110 expires_at: claims.exp,
111 features: claims.features,
112 };
113
114 self.save_jwt(&jwt_data)?;
116
117 Ok(jwt_data)
118 }
119
120 fn get_license_signing_key(&self) -> Result<String> {
121 if let Ok(key) = std::env::var("LICENSE_SIGNING_KEY") {
123 return Ok(key);
124 }
125
126 if let (Ok(project), Ok(config_name)) = (
128 std::env::var("DOPPLER_PROJECT"),
129 std::env::var("DOPPLER_CONFIG")
130 ) {
131 let doppler_token = if config_name.contains("prd") {
133 std::env::var("DOPPLER_TOKEN_PRD")
134 .or_else(|_| std::env::var("DOPPLER_TOKEN"))
135 } else if config_name.contains("oss") {
136 std::env::var("DOPPLER_TOKEN_OSS")
137 .or_else(|_| std::env::var("DOPPLER_TOKEN"))
138 } else {
139 std::env::var("DOPPLER_TOKEN")
140 };
141
142 if let Ok(token) = doppler_token {
143 let url = format!("https://api.doppler.com/v3/configs/config/secret?project={}&config={}&name=LICENSE_SIGNING_KEY",
145 project, config_name);
146
147 if let Ok(response) = ureq::get(&url)
148 .set("Authorization", &format!("Bearer {}", token))
149 .call()
150 {
151 if let Ok(json) = response.into_json::<serde_json::Value>() {
152 if let Some(value) = json.get("value")
153 .and_then(|v| v.get("computed"))
154 .and_then(|c| c.as_str())
155 {
156 return Ok(value.to_string());
157 }
158 }
159 }
160 }
161 }
162
163 let deployment_type = std::env::var("NABLA_DEPLOYMENT")
165 .unwrap_or_else(|_| "oss".to_string());
166
167 match deployment_type.to_lowercase().as_str() {
168 "oss" => {
169 Ok("t6eLp6y0Ly8BZJIVv_wK71WyBtJ1zY2Pxz2M_0z5t8Q".to_string())
171 }
172 "private" => {
173 Err(anyhow!("LICENSE_SIGNING_KEY required for private deployment (try Doppler CLI or env var)"))
174 }
175 _ => {
176 Ok("t6eLp6y0Ly8BZJIVv_wK71WyBtJ1zY2Pxz2M_0z5t8Q".to_string())
178 }
179 }
180 }
181}