// nabla_cli/cli/jwt_store.rs
use anyhow::{Result, anyhow};
use base64::{Engine as _, engine::general_purpose};
use home::home_dir;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
use serde::{Deserialize, Serialize};
use std::fs;
use std::io::{ErrorKind, Write};
use std::path::PathBuf;

#[cfg(feature = "private")]
use doppler_rs::{apis::client::Client, apis::Error as DopplerError};

12#[derive(Debug, Serialize, Deserialize, Clone)]
13pub struct JwtData {
14 pub token: String,
15 pub sub: String,
16 pub deployment_id: String,
17 pub expires_at: i64,
18 pub features: PlanFeatures,
19}
20
21#[derive(Debug, Serialize, Deserialize, Clone)]
22pub struct PlanFeatures {
23 pub chat_enabled: bool,
24 pub api_access: bool,
25 pub file_upload_limit_mb: u32,
26 pub concurrent_requests: u32,
27 pub custom_models: bool,
28 pub sbom_generation: bool,
29 pub vulnerability_scanning: bool,
30 pub signed_attestation: bool,
31 pub monthly_binaries: u32,
32}
33
34#[derive(Debug, Serialize, Deserialize)]
35struct JwtClaims {
36 pub sub: String,
37 pub deployment_id: String,
38 pub exp: i64,
39 pub features: PlanFeatures,
40}
41
/// File-backed cache holding at most one licence JWT.
///
/// `JwtStore::new` points it at `~/.nabla/jwt.json`.
pub struct JwtStore {
    /// Location of the JSON-serialised token cache.
    store_path: PathBuf,
}

46impl JwtStore {
47 pub fn new() -> Result<Self> {
48 let home = home_dir().ok_or_else(|| anyhow!("Could not find home directory"))?;
49 let nabla_dir = home.join(".nabla");
50
51 if !nabla_dir.exists() {
52 fs::create_dir_all(&nabla_dir)?;
53 }
54
55 Ok(Self {
56 store_path: nabla_dir.join("jwt.json"),
57 })
58 }
59
60 pub fn save_jwt(&self, jwt_data: &JwtData) -> Result<()> {
61 let json = serde_json::to_string_pretty(jwt_data)?;
62 fs::write(&self.store_path, json)?;
63 Ok(())
64 }
65
66 pub fn load_jwt(&self) -> Result<Option<JwtData>> {
67 if !self.store_path.exists() {
68 return Ok(None);
69 }
70
71 let content = fs::read_to_string(&self.store_path)?;
72 let jwt_data: JwtData = serde_json::from_str(&content)?;
73
74 let now = chrono::Utc::now().timestamp();
76 if jwt_data.expires_at < now {
77 self.clear_jwt()?;
78 return Ok(None);
79 }
80
81 Ok(Some(jwt_data))
82 }
83
84 pub fn clear_jwt(&self) -> Result<()> {
85 if self.store_path.exists() {
86 fs::remove_file(&self.store_path)?;
87 }
88 Ok(())
89 }
90
91 pub fn verify_and_store_jwt(&self, jwt_token: &str) -> Result<JwtData> {
92 let signing_key_b64 = self.get_license_signing_key()?;
94
95 let key_bytes = general_purpose::URL_SAFE_NO_PAD.decode(signing_key_b64.trim())
97 .map_err(|e| anyhow!("Failed to decode LICENSE_SIGNING_KEY as base64: {}", e))?;
98
99 let key = DecodingKey::from_secret(&key_bytes);
100 let mut validation = Validation::new(Algorithm::HS256);
101 validation.validate_exp = true;
102
103 let token_data = decode::<JwtClaims>(jwt_token, &key, &validation)
105 .map_err(|e| anyhow!("JWT verification failed: {}", e))?;
106
107 let claims = token_data.claims;
108
109 let jwt_data = JwtData {
110 token: jwt_token.to_string(),
111 sub: claims.sub,
112 deployment_id: claims.deployment_id,
113 expires_at: claims.exp,
114 features: claims.features,
115 };
116
117 self.save_jwt(&jwt_data)?;
119
120 Ok(jwt_data)
121 }
122
123 fn get_license_signing_key(&self) -> Result<String> {
124 let deployment_type = std::env::var("NABLA_DEPLOYMENT")
126 .unwrap_or_else(|_| "oss".to_string());
127
128 match deployment_type.to_lowercase().as_str() {
129 "oss" => {
130 Ok("t6eLp6y0Ly8BZJIVv_wK71WyBtJ1zY2Pxz2M_0z5t8Q".to_string())
132 }
133 "private" => {
134 #[cfg(feature = "private")]
135 {
136 if let (Ok(project), Ok(config_name)) = (
138 std::env::var("DOPPLER_PROJECT"),
139 std::env::var("DOPPLER_CONFIG")
140 ) {
141 if let Ok(doppler_token) = std::env::var("DOPPLER_TOKEN") {
142 let client = Client::new(&doppler_token);
143 if let Ok(secret) = client.get_secret(&project, &config_name, "LICENSE_SIGNING_KEY") {
144 return Ok(secret.value);
145 }
146 }
147 }
148 }
149
150 std::env::var("LICENSE_SIGNING_KEY")
152 .map_err(|_| anyhow!("LICENSE_SIGNING_KEY environment variable is required for JWT verification"))
153 }
154 _ => {
155 std::env::var("LICENSE_SIGNING_KEY")
157 .map_err(|_| anyhow!("LICENSE_SIGNING_KEY environment variable is required for JWT verification"))
158 }
159 }
160 }
161}