nabla_cli/cli/
jwt_store.rs1use anyhow::{Result, anyhow};
2use base64::{Engine as _, engine::general_purpose};
3use home::home_dir;
4use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
5use serde::{Deserialize, Serialize};
6use std::fs;
7use std::path::PathBuf;
8
9#[derive(Debug, Serialize, Deserialize, Clone)]
10pub struct JwtData {
11 pub token: String,
12 pub sub: String,
13 pub deployment_id: String,
14 pub expires_at: i64,
15 pub features: PlanFeatures,
16}
17
18#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct PlanFeatures {
20 pub chat_enabled: bool,
21 pub api_access: bool,
22 pub file_upload_limit_mb: u32,
23 pub concurrent_requests: u32,
24 pub custom_models: bool,
25 pub sbom_generation: bool,
26 pub vulnerability_scanning: bool,
27 pub signed_attestation: bool,
28 pub monthly_binaries: u32,
29}
30
31#[derive(Debug, Serialize, Deserialize)]
32struct JwtClaims {
33 pub sub: String,
34 pub deployment_id: String,
35 pub exp: i64,
36 pub features: PlanFeatures,
37}
38
/// On-disk cache for a verified license JWT, stored as pretty-printed JSON.
pub struct JwtStore {
    /// Path of the JSON file holding the cached [`JwtData`].
    store_path: PathBuf,
}
42
43impl JwtStore {
44 pub fn new() -> Result<Self> {
45 let home = home_dir().ok_or_else(|| anyhow!("Could not find home directory"))?;
46 let nabla_dir = home.join(".nabla");
47
48 if !nabla_dir.exists() {
49 fs::create_dir_all(&nabla_dir)?;
50 }
51
52 Ok(Self {
53 store_path: nabla_dir.join("jwt.json"),
54 })
55 }
56
57 pub fn save_jwt(&self, jwt_data: &JwtData) -> Result<()> {
58 let json = serde_json::to_string_pretty(jwt_data)?;
59 fs::write(&self.store_path, json)?;
60 Ok(())
61 }
62
63 pub fn load_jwt(&self) -> Result<Option<JwtData>> {
64 if !self.store_path.exists() {
65 return Ok(None);
66 }
67
68 let content = fs::read_to_string(&self.store_path)?;
69 let jwt_data: JwtData = serde_json::from_str(&content)?;
70
71 let now = chrono::Utc::now().timestamp();
73 if jwt_data.expires_at < now {
74 self.clear_jwt()?;
75 return Ok(None);
76 }
77
78 Ok(Some(jwt_data))
79 }
80
81 pub fn clear_jwt(&self) -> Result<()> {
82 if self.store_path.exists() {
83 fs::remove_file(&self.store_path)?;
84 }
85 Ok(())
86 }
87
88 pub fn verify_and_store_jwt(&self, jwt_token: &str) -> Result<JwtData> {
89 let signing_key_b64 = self.get_license_signing_key()?;
91
92 let key_bytes = general_purpose::URL_SAFE_NO_PAD
94 .decode(signing_key_b64.trim())
95 .map_err(|e| anyhow!("Failed to decode LICENSE_SIGNING_KEY as base64: {}", e))?;
96
97 let key = DecodingKey::from_secret(&key_bytes);
98 let mut validation = Validation::new(Algorithm::HS256);
99 validation.validate_exp = true;
100
101 let token_data = decode::<JwtClaims>(jwt_token, &key, &validation)
103 .map_err(|e| anyhow!("JWT verification failed: {}", e))?;
104
105 let claims = token_data.claims;
106
107 let jwt_data = JwtData {
108 token: jwt_token.to_string(),
109 sub: claims.sub,
110 deployment_id: claims.deployment_id,
111 expires_at: claims.exp,
112 features: claims.features,
113 };
114
115 self.save_jwt(&jwt_data)?;
117
118 Ok(jwt_data)
119 }
120
121 fn get_license_signing_key(&self) -> Result<String> {
122 if let Ok(key) = std::env::var("LICENSE_SIGNING_KEY") {
124 return Ok(key);
125 }
126
127 if let (Ok(project), Ok(config_name)) = (
129 std::env::var("DOPPLER_PROJECT"),
130 std::env::var("DOPPLER_CONFIG"),
131 ) {
132 let doppler_token = if config_name.contains("prd") {
134 std::env::var("DOPPLER_TOKEN_PRD").or_else(|_| std::env::var("DOPPLER_TOKEN"))
135 } else if config_name.contains("oss") {
136 std::env::var("DOPPLER_TOKEN_OSS").or_else(|_| std::env::var("DOPPLER_TOKEN"))
137 } else {
138 std::env::var("DOPPLER_TOKEN")
139 };
140
141 if let Ok(token) = doppler_token {
142 let url = format!(
144 "https://api.doppler.com/v3/configs/config/secret?project={}&config={}&name=LICENSE_SIGNING_KEY",
145 project, config_name
146 );
147
148 if let Ok(response) = ureq::get(&url)
149 .set("Authorization", &format!("Bearer {}", token))
150 .call()
151 {
152 if let Ok(json) = response.into_json::<serde_json::Value>() {
153 if let Some(value) = json
154 .get("value")
155 .and_then(|v| v.get("computed"))
156 .and_then(|c| c.as_str())
157 {
158 return Ok(value.to_string());
159 }
160 }
161 }
162 }
163 }
164
165 let deployment_type =
167 std::env::var("NABLA_DEPLOYMENT").unwrap_or_else(|_| "oss".to_string());
168
169 match deployment_type.to_lowercase().as_str() {
170 "oss" => {
171 Ok("t6eLp6y0Ly8BZJIVv_wK71WyBtJ1zY2Pxz2M_0z5t8Q".to_string())
173 }
174 "private" => Err(anyhow!(
175 "LICENSE_SIGNING_KEY required for private deployment (try Doppler CLI or env var)"
176 )),
177 _ => {
178 Ok("t6eLp6y0Ly8BZJIVv_wK71WyBtJ1zY2Pxz2M_0z5t8Q".to_string())
180 }
181 }
182 }
183}