1use anyhow::{Context, Result};
8use parking_lot::RwLock as ClientRwLock;
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use vaultrs::client::{VaultClient, VaultClientSettingsBuilder};
13use vaultrs::error::ClientError;
14use vaultrs::kv2;
15
/// Full KV v2 HTTP path of the default provider secrets: the default `secret` mount,
/// the `data/` segment the KV v2 API inserts, and the default `codetether/providers`
/// base path.
// The `kv2` calls in this module take the mount and path separately, so this combined
// form is kept only as a reference value; `dead_code` is allowed because it may be unused.
#[allow(dead_code)]
const DEFAULT_SECRETS_PATH: &str = "secret/data/codetether/providers";
19
20#[derive(Clone)]
22pub struct SecretsManager {
23 client: Arc<ClientRwLock<Option<Arc<VaultClient>>>>,
24 pub cache: Arc<RwLock<HashMap<String, String>>>,
26 mount: String,
27 path: String,
28 k8s_auth: Option<Arc<KubernetesAuthConfig>>,
29}
30
/// Everything needed to (re-)authenticate to Vault with a Kubernetes service-account JWT.
///
/// Only the path of the JWT is stored, never the token itself.
#[derive(Debug, Clone)]
struct KubernetesAuthConfig {
    /// Vault server address.
    address: String,
    /// Vault role to log in as.
    role: String,
    /// Mount path of Vault's Kubernetes auth method.
    auth_mount: String,
    /// Filesystem path of the service-account JWT; re-read on every login.
    jwt_path: String,
}
38
39impl Default for SecretsManager {
40 fn default() -> Self {
41 Self {
42 client: Arc::new(ClientRwLock::new(None)),
43 cache: Arc::new(RwLock::new(HashMap::new())),
44 mount: "secret".to_string(),
45 path: "codetether/providers".to_string(),
46 k8s_auth: None,
47 }
48 }
49}
50
51impl SecretsManager {
52 pub async fn new(config: &VaultConfig) -> Result<Self> {
54 let settings = VaultClientSettingsBuilder::default()
55 .address(&config.address)
56 .token(&config.token)
57 .build()
58 .context("Failed to build Vault client settings")?;
59
60 let client = VaultClient::new(settings).context("Failed to create Vault client")?;
61
62 Ok(Self {
63 client: Arc::new(ClientRwLock::new(Some(Arc::new(client)))),
64 cache: Arc::new(RwLock::new(HashMap::new())),
65 mount: config.mount.clone().unwrap_or_else(|| "secret".to_string()),
66 path: config
67 .path
68 .clone()
69 .unwrap_or_else(|| "codetether/providers".to_string()),
70 k8s_auth: None,
71 })
72 }
73
74 pub async fn from_k8s_auth(
81 address: &str,
82 role: &str,
83 mount: &str,
84 kv_mount: Option<&str>,
85 kv_path: Option<&str>,
86 ) -> Result<Self> {
87 let jwt_path = std::env::var("VAULT_K8S_SA_JWT_PATH")
88 .unwrap_or_else(|_| "/var/run/secrets/kubernetes.io/serviceaccount/token".to_string());
89
90 let auth = KubernetesAuthConfig {
91 address: address.to_string(),
92 role: role.to_string(),
93 auth_mount: mount.to_string(),
94 jwt_path,
95 };
96
97 let client = Self::login_with_kubernetes(&auth).await?;
98
99 Ok(Self {
100 client: Arc::new(ClientRwLock::new(Some(client))),
101 cache: Arc::new(RwLock::new(HashMap::new())),
102 mount: kv_mount.unwrap_or("secret").to_string(),
103 path: kv_path.unwrap_or("codetether/providers").to_string(),
104 k8s_auth: Some(Arc::new(auth)),
105 })
106 }
107
108 async fn login_with_kubernetes(auth: &KubernetesAuthConfig) -> Result<Arc<VaultClient>> {
109 let jwt = tokio::fs::read_to_string(&auth.jwt_path)
110 .await
111 .with_context(|| {
112 format!(
113 "Failed to read Kubernetes service account token from {}",
114 auth.jwt_path
115 )
116 })?;
117 let jwt = jwt.trim().to_string();
118
119 let bootstrap_settings = VaultClientSettingsBuilder::default()
122 .address(&auth.address)
123 .token("")
124 .build()
125 .context("Failed to build bootstrap Vault client settings")?;
126 let bootstrap_client = VaultClient::new(bootstrap_settings)
127 .context("Failed to create bootstrap Vault client")?;
128
129 let auth_info =
130 vaultrs::auth::kubernetes::login(&bootstrap_client, &auth.auth_mount, &auth.role, &jwt)
131 .await
132 .context("Vault Kubernetes auth login failed")?;
133
134 let settings = VaultClientSettingsBuilder::default()
135 .address(&auth.address)
136 .token(&auth_info.client_token)
137 .build()
138 .context("Failed to build authenticated Vault client settings")?;
139 let client =
140 VaultClient::new(settings).context("Failed to create authenticated Vault client")?;
141
142 Ok(Arc::new(client))
143 }
144
145 pub async fn from_env() -> Result<Self> {
152 let address = std::env::var("VAULT_ADDR").context("VAULT_ADDR not set")?;
153 let kv_mount = std::env::var("VAULT_MOUNT").ok();
154 let kv_path = std::env::var("VAULT_SECRETS_PATH").ok();
155
156 if let Ok(role) = std::env::var("VAULT_ROLE") {
161 let role = role.trim().to_string();
162 if !role.is_empty() {
163 let k8s_mount =
164 std::env::var("VAULT_AUTH_MOUNT").unwrap_or_else(|_| "kubernetes".to_string());
165
166 match Self::from_k8s_auth(
167 &address,
168 &role,
169 &k8s_mount,
170 kv_mount.as_deref(),
171 kv_path.as_deref(),
172 )
173 .await
174 {
175 Ok(manager) => {
176 tracing::info!(
177 role = %role,
178 mount = %k8s_mount,
179 "Authenticated to Vault via Kubernetes service account"
180 );
181 return Ok(manager);
182 }
183 Err(e) => {
184 tracing::warn!(
185 error = %e,
186 "Vault Kubernetes auth failed; falling back to VAULT_TOKEN"
187 );
188 }
189 }
190 }
191 }
192
193 let token = std::env::var("VAULT_TOKEN").context("VAULT_TOKEN not set")?;
194 let config = VaultConfig {
195 address,
196 token,
197 mount: kv_mount,
198 path: kv_path,
199 };
200
201 Self::new(&config).await
202 }
203
204 pub fn is_connected(&self) -> bool {
206 self.client.read().is_some()
207 }
208
209 pub async fn verify_reachable(&self) -> Result<()> {
215 let client = match self.client() {
216 Some(c) => c,
217 None => anyhow::bail!("Vault client not configured"),
218 };
219
220 let mut result = kv2::list(client.as_ref(), &self.mount, &self.path).await;
221 if matches!(result.as_ref().err(), Some(err) if Self::should_refresh_vault_token(err)) {
222 if let Some(client) = self.refresh_kubernetes_auth().await? {
223 result = kv2::list(client.as_ref(), &self.mount, &self.path).await;
224 }
225 }
226
227 match result {
228 Ok(_) | Err(vaultrs::error::ClientError::APIError { code: 404, .. }) => Ok(()),
229 Err(error) => Err(error).context("Failed to reach configured Vault provider path"),
230 }
231 }
232
233 fn client(&self) -> Option<Arc<VaultClient>> {
234 self.client.read().clone()
235 }
236
237 async fn refresh_kubernetes_auth(&self) -> Result<Option<Arc<VaultClient>>> {
238 let Some(auth) = self.k8s_auth.as_deref() else {
239 return Ok(self.client());
240 };
241
242 tracing::warn!("Vault token was rejected; refreshing Kubernetes auth token");
243 let client = Self::login_with_kubernetes(auth).await?;
244 {
245 let mut current = self.client.write();
246 *current = Some(client.clone());
247 }
248 self.clear_cache().await;
249 tracing::info!(
250 role = %auth.role,
251 mount = %auth.auth_mount,
252 "Refreshed Vault Kubernetes auth token"
253 );
254 Ok(Some(client))
255 }
256
257 fn should_refresh_vault_token(err: &ClientError) -> bool {
258 match err {
259 ClientError::APIError { code, errors } => {
260 *code == 403
261 || errors.iter().any(|msg| {
262 let msg = msg.to_ascii_lowercase();
263 msg.contains("invalid token") || msg.contains("permission denied")
264 })
265 }
266 _ => false,
267 }
268 }
269
270 pub async fn get_api_key(&self, provider_id: &str) -> Result<Option<String>> {
272 {
274 let cache = self.cache.read().await;
275 if let Some(key) = cache.get(provider_id) {
276 return Ok(Some(key.clone()));
277 }
278 }
279
280 let client = match self.client() {
282 Some(c) => c,
283 None => return Ok(None),
284 };
285
286 let secret_path = format!("{}/{}", self.path, provider_id);
287
288 let mut result =
289 kv2::read::<ProviderSecrets>(client.as_ref(), &self.mount, &secret_path).await;
290 if matches!(result.as_ref().err(), Some(err) if Self::should_refresh_vault_token(err)) {
291 if let Some(client) = self.refresh_kubernetes_auth().await? {
292 result =
293 kv2::read::<ProviderSecrets>(client.as_ref(), &self.mount, &secret_path).await;
294 }
295 }
296
297 match result {
298 Ok(secret) => {
299 if let Some(ref api_key) = secret.api_key {
301 let mut cache = self.cache.write().await;
302 cache.insert(provider_id.to_string(), api_key.clone());
303 }
304 Ok(secret.api_key)
305 }
306 Err(vaultrs::error::ClientError::APIError { code: 404, .. }) => Ok(None),
307 Err(e) => {
308 tracing::warn!("Failed to fetch secret for {}: {}", provider_id, e);
309 Ok(None)
310 }
311 }
312 }
313
314 pub async fn get_provider_secrets(&self, provider_id: &str) -> Result<Option<ProviderSecrets>> {
316 let client = match self.client() {
317 Some(c) => c,
318 None => return Ok(None),
319 };
320
321 let secret_path = format!("{}/{}", self.path, provider_id);
322
323 let mut result =
324 kv2::read::<ProviderSecrets>(client.as_ref(), &self.mount, &secret_path).await;
325 if matches!(result.as_ref().err(), Some(err) if Self::should_refresh_vault_token(err)) {
326 if let Some(client) = self.refresh_kubernetes_auth().await? {
327 result =
328 kv2::read::<ProviderSecrets>(client.as_ref(), &self.mount, &secret_path).await;
329 }
330 }
331
332 match result {
333 Ok(secret) => Ok(Some(secret)),
334 Err(vaultrs::error::ClientError::APIError { code: 404, .. }) => Ok(None),
335 Err(e) => {
336 tracing::warn!("Failed to fetch secrets for {}: {}", provider_id, e);
337 Ok(None)
338 }
339 }
340 }
341
342 pub async fn set_provider_secrets(
344 &self,
345 provider_id: &str,
346 secrets: &ProviderSecrets,
347 ) -> Result<()> {
348 let client = match self.client() {
349 Some(c) => c,
350 None => anyhow::bail!("Vault client not configured"),
351 };
352
353 let secret_path = format!("{}/{}", self.path, provider_id);
354 let mut result = kv2::set(client.as_ref(), &self.mount, &secret_path, secrets).await;
355 if matches!(result.as_ref().err(), Some(err) if Self::should_refresh_vault_token(err)) {
356 if let Some(client) = self.refresh_kubernetes_auth().await? {
357 result = kv2::set(client.as_ref(), &self.mount, &secret_path, secrets).await;
358 }
359 }
360 result.with_context(|| format!("Failed to write provider secrets for {}", provider_id))?;
361
362 let mut cache = self.cache.write().await;
364 if let Some(api_key) = secrets.api_key.clone() {
365 cache.insert(provider_id.to_string(), api_key);
366 } else {
367 cache.remove(provider_id);
368 }
369
370 Ok(())
371 }
372
373 pub async fn has_api_key(&self, provider_id: &str) -> bool {
375 matches!(self.get_api_key(provider_id).await, Ok(Some(_)))
376 }
377
378 pub async fn list_configured_providers(&self) -> Result<Vec<String>> {
380 let client = match self.client() {
381 Some(c) => c,
382 None => return Ok(Vec::new()),
383 };
384
385 let mut result = kv2::list(client.as_ref(), &self.mount, &self.path).await;
386 if matches!(result.as_ref().err(), Some(err) if Self::should_refresh_vault_token(err)) {
387 if let Some(client) = self.refresh_kubernetes_auth().await? {
388 result = kv2::list(client.as_ref(), &self.mount, &self.path).await;
389 }
390 }
391
392 match result {
393 Ok(keys) => Ok(keys),
394 Err(vaultrs::error::ClientError::APIError { code: 404, .. }) => Ok(Vec::new()),
395 Err(e) => {
396 tracing::warn!("Failed to list providers: {}", e);
397 Ok(Vec::new())
398 }
399 }
400 }
401
402 pub async fn clear_cache(&self) {
404 let mut cache = self.cache.write().await;
405 cache.clear();
406 }
407}
408
409#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
411pub struct VaultConfig {
412 pub address: String,
414
415 pub token: String,
417
418 #[serde(default)]
420 pub mount: Option<String>,
421
422 #[serde(default)]
424 pub path: Option<String>,
425}
426
427impl Default for VaultConfig {
428 fn default() -> Self {
429 Self {
430 address: String::new(),
431 token: String::new(),
432 mount: Some("secret".to_string()),
433 path: Some("codetether/providers".to_string()),
434 }
435 }
436}
437
438#[derive(Clone, serde::Serialize, serde::Deserialize)]
440pub struct ProviderSecrets {
441 #[serde(default)]
443 pub api_key: Option<String>,
444
445 #[serde(default)]
447 pub base_url: Option<String>,
448
449 #[serde(default)]
451 pub organization: Option<String>,
452
453 #[serde(default)]
455 pub headers: Option<HashMap<String, String>>,
456
457 #[serde(flatten)]
459 pub extra: HashMap<String, serde_json::Value>,
460}
461
462impl std::fmt::Debug for ProviderSecrets {
463 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
464 f.debug_struct("ProviderSecrets")
465 .field("api_key", &self.api_key.as_ref().map(|_| "<REDACTED>"))
466 .field("api_key_len", &self.api_key.as_ref().map(|k| k.len()))
467 .field("base_url", &self.base_url)
468 .field("organization", &self.organization)
469 .field("headers_present", &self.headers.is_some())
470 .field("extra_fields", &self.extra.len())
471 .finish()
472 }
473}
474
475impl ProviderSecrets {
476 pub fn has_valid_api_key(&self) -> bool {
478 self.api_key
479 .as_ref()
480 .map(|k| !k.is_empty())
481 .unwrap_or(false)
482 }
483
484 pub fn api_key_len(&self) -> Option<usize> {
486 self.api_key.as_ref().map(|k| k.len())
487 }
488}
489
490static SECRETS_MANAGER: tokio::sync::OnceCell<SecretsManager> = tokio::sync::OnceCell::const_new();
492
493pub async fn init_secrets_manager(config: &VaultConfig) -> Result<()> {
495 let manager = SecretsManager::new(config).await?;
496 SECRETS_MANAGER
497 .set(manager)
498 .map_err(|_| anyhow::anyhow!("Secrets manager already initialized"))?;
499 Ok(())
500}
501
502pub fn init_from_manager(manager: SecretsManager) -> Result<()> {
504 SECRETS_MANAGER
505 .set(manager)
506 .map_err(|_| anyhow::anyhow!("Secrets manager already initialized"))?;
507 Ok(())
508}
509
510pub fn secrets_manager() -> Option<&'static SecretsManager> {
512 SECRETS_MANAGER.get()
513}
514
515pub async fn get_api_key(provider_id: &str) -> Option<String> {
517 match SECRETS_MANAGER.get() {
518 Some(manager) => manager.get_api_key(provider_id).await.ok().flatten(),
519 None => None,
520 }
521}
522
523pub async fn has_api_key(provider_id: &str) -> bool {
525 match SECRETS_MANAGER.get() {
526 Some(manager) => manager.has_api_key(provider_id).await,
527 None => false,
528 }
529}
530
531pub async fn get_provider_secrets(provider_id: &str) -> Option<ProviderSecrets> {
533 match SECRETS_MANAGER.get() {
534 Some(manager) => manager
535 .get_provider_secrets(provider_id)
536 .await
537 .ok()
538 .flatten(),
539 None => None,
540 }
541}
542
543pub async fn set_provider_secrets(provider_id: &str, secrets: &ProviderSecrets) -> Result<()> {
545 match SECRETS_MANAGER.get() {
546 Some(manager) => manager.set_provider_secrets(provider_id, secrets).await,
547 None => anyhow::bail!("Secrets manager not initialized"),
548 }
549}
550
551pub async fn verify_reachable() -> Result<()> {
553 match SECRETS_MANAGER.get() {
554 Some(manager) => manager.verify_reachable().await,
555 None => anyhow::bail!("Secrets manager not initialized"),
556 }
557}