Skip to main content

codetether_agent/secrets/
mod.rs

//! Secrets management via HashiCorp Vault
//!
//! HashiCorp Vault is this module's only secrets backend: provider secrets are
//! read from (and can be written to) a Vault KV v2 engine. Provider
//! initialization elsewhere may add local-development env/AWS fallback
//! credentials unless `CODETETHER_DISABLE_ENV_FALLBACK=1` is set.
6
7use anyhow::{Context, Result};
8use parking_lot::RwLock as ClientRwLock;
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use vaultrs::client::{VaultClient, VaultClientSettingsBuilder};
13use vaultrs::error::ClientError;
14use vaultrs::kv2;
15
/// Vault API path of the provider-secrets directory: the default `secret`
/// mount joined with the default `codetether/providers` path (presumably via
/// the KV v2 `data/` segment — confirm if this constant is ever used).
// Not referenced by the manager, which builds paths from its own `mount` and
// `path` fields; kept as documentation of the default layout.
#[allow(dead_code)]
const DEFAULT_SECRETS_PATH: &str = "secret/data/codetether/providers";
19
20/// Vault-based secrets manager
21#[derive(Clone)]
22pub struct SecretsManager {
23    client: Arc<ClientRwLock<Option<Arc<VaultClient>>>>,
24    /// Cache of loaded API keys (provider_id -> api_key)
25    pub cache: Arc<RwLock<HashMap<String, String>>>,
26    mount: String,
27    path: String,
28    k8s_auth: Option<Arc<KubernetesAuthConfig>>,
29}
30
/// Parameters needed to (re-)run a Vault Kubernetes auth login.
#[derive(Clone, Debug)]
struct KubernetesAuthConfig {
    /// Vault server address.
    address: String,
    /// Vault role to log in as.
    role: String,
    /// Mount point of the Kubernetes auth method in Vault.
    auth_mount: String,
    /// Filesystem path of the service-account JWT presented at login.
    jwt_path: String,
}
38
39impl Default for SecretsManager {
40    fn default() -> Self {
41        Self {
42            client: Arc::new(ClientRwLock::new(None)),
43            cache: Arc::new(RwLock::new(HashMap::new())),
44            mount: "secret".to_string(),
45            path: "codetether/providers".to_string(),
46            k8s_auth: None,
47        }
48    }
49}
50
51impl SecretsManager {
52    /// Create a new secrets manager with Vault configuration
53    pub async fn new(config: &VaultConfig) -> Result<Self> {
54        let settings = VaultClientSettingsBuilder::default()
55            .address(&config.address)
56            .token(&config.token)
57            .build()
58            .context("Failed to build Vault client settings")?;
59
60        let client = VaultClient::new(settings).context("Failed to create Vault client")?;
61
62        Ok(Self {
63            client: Arc::new(ClientRwLock::new(Some(Arc::new(client)))),
64            cache: Arc::new(RwLock::new(HashMap::new())),
65            mount: config.mount.clone().unwrap_or_else(|| "secret".to_string()),
66            path: config
67                .path
68                .clone()
69                .unwrap_or_else(|| "codetether/providers".to_string()),
70            k8s_auth: None,
71        })
72    }
73
74    /// Authenticate to Vault using the pod's Kubernetes service account JWT.
75    ///
76    /// Reads the SA JWT from the standard Kubernetes mount path (overridable via
77    /// `VAULT_K8S_SA_JWT_PATH`) then calls the Vault `auth/kubernetes/login`
78    /// endpoint.  The returned manager holds the short-lived token that Vault
79    /// issued — no `VAULT_TOKEN` environment variable is required.
80    pub async fn from_k8s_auth(
81        address: &str,
82        role: &str,
83        mount: &str,
84        kv_mount: Option<&str>,
85        kv_path: Option<&str>,
86    ) -> Result<Self> {
87        let jwt_path = std::env::var("VAULT_K8S_SA_JWT_PATH")
88            .unwrap_or_else(|_| "/var/run/secrets/kubernetes.io/serviceaccount/token".to_string());
89
90        let auth = KubernetesAuthConfig {
91            address: address.to_string(),
92            role: role.to_string(),
93            auth_mount: mount.to_string(),
94            jwt_path,
95        };
96
97        let client = Self::login_with_kubernetes(&auth).await?;
98
99        Ok(Self {
100            client: Arc::new(ClientRwLock::new(Some(client))),
101            cache: Arc::new(RwLock::new(HashMap::new())),
102            mount: kv_mount.unwrap_or("secret").to_string(),
103            path: kv_path.unwrap_or("codetether/providers").to_string(),
104            k8s_auth: Some(Arc::new(auth)),
105        })
106    }
107
108    async fn login_with_kubernetes(auth: &KubernetesAuthConfig) -> Result<Arc<VaultClient>> {
109        let jwt = tokio::fs::read_to_string(&auth.jwt_path)
110            .await
111            .with_context(|| {
112                format!(
113                    "Failed to read Kubernetes service account token from {}",
114                    auth.jwt_path
115                )
116            })?;
117        let jwt = jwt.trim().to_string();
118
119        // Bootstrap client with an empty token — only used for the one-shot
120        // auth call; the real authenticated client is built from the result.
121        let bootstrap_settings = VaultClientSettingsBuilder::default()
122            .address(&auth.address)
123            .token("")
124            .build()
125            .context("Failed to build bootstrap Vault client settings")?;
126        let bootstrap_client = VaultClient::new(bootstrap_settings)
127            .context("Failed to create bootstrap Vault client")?;
128
129        let auth_info =
130            vaultrs::auth::kubernetes::login(&bootstrap_client, &auth.auth_mount, &auth.role, &jwt)
131                .await
132                .context("Vault Kubernetes auth login failed")?;
133
134        let settings = VaultClientSettingsBuilder::default()
135            .address(&auth.address)
136            .token(&auth_info.client_token)
137            .build()
138            .context("Failed to build authenticated Vault client settings")?;
139        let client =
140            VaultClient::new(settings).context("Failed to create authenticated Vault client")?;
141
142        Ok(Arc::new(client))
143    }
144
145    /// Try to create from environment (for initial bootstrap only).
146    ///
147    /// When `VAULT_ROLE` is set the worker authenticates via Kubernetes service
148    /// account — no static token is needed and the resulting Vault token is
149    /// short-lived and automatically rotated by Vault itself.  Falls back to
150    /// `VAULT_TOKEN` when `VAULT_ROLE` is absent or K8s auth fails.
151    pub async fn from_env() -> Result<Self> {
152        let address = std::env::var("VAULT_ADDR").context("VAULT_ADDR not set")?;
153        let kv_mount = std::env::var("VAULT_MOUNT").ok();
154        let kv_path = std::env::var("VAULT_SECRETS_PATH").ok();
155
156        // Prefer Kubernetes service-account auth when VAULT_ROLE is set.
157        // This eliminates the dependency on a static VAULT_TOKEN; the pod's own
158        // SA JWT (mounted by k8s at the standard path) is the only credential
159        // the container needs to carry.
160        if let Ok(role) = std::env::var("VAULT_ROLE") {
161            let role = role.trim().to_string();
162            if !role.is_empty() {
163                let k8s_mount =
164                    std::env::var("VAULT_AUTH_MOUNT").unwrap_or_else(|_| "kubernetes".to_string());
165
166                match Self::from_k8s_auth(
167                    &address,
168                    &role,
169                    &k8s_mount,
170                    kv_mount.as_deref(),
171                    kv_path.as_deref(),
172                )
173                .await
174                {
175                    Ok(manager) => {
176                        tracing::info!(
177                            role = %role,
178                            mount = %k8s_mount,
179                            "Authenticated to Vault via Kubernetes service account"
180                        );
181                        return Ok(manager);
182                    }
183                    Err(e) => {
184                        tracing::warn!(
185                            error = %e,
186                            "Vault Kubernetes auth failed; falling back to VAULT_TOKEN"
187                        );
188                    }
189                }
190            }
191        }
192
193        let token = std::env::var("VAULT_TOKEN").context("VAULT_TOKEN not set")?;
194        let config = VaultConfig {
195            address,
196            token,
197            mount: kv_mount,
198            path: kv_path,
199        };
200
201        Self::new(&config).await
202    }
203
204    /// Check if Vault is configured and connected
205    pub fn is_connected(&self) -> bool {
206        self.client.read().is_some()
207    }
208
209    fn client(&self) -> Option<Arc<VaultClient>> {
210        self.client.read().clone()
211    }
212
213    async fn refresh_kubernetes_auth(&self) -> Result<Option<Arc<VaultClient>>> {
214        let Some(auth) = self.k8s_auth.as_deref() else {
215            return Ok(self.client());
216        };
217
218        tracing::warn!("Vault token was rejected; refreshing Kubernetes auth token");
219        let client = Self::login_with_kubernetes(auth).await?;
220        {
221            let mut current = self.client.write();
222            *current = Some(client.clone());
223        }
224        self.clear_cache().await;
225        tracing::info!(
226            role = %auth.role,
227            mount = %auth.auth_mount,
228            "Refreshed Vault Kubernetes auth token"
229        );
230        Ok(Some(client))
231    }
232
233    fn should_refresh_vault_token(err: &ClientError) -> bool {
234        match err {
235            ClientError::APIError { code, errors } => {
236                *code == 403
237                    || errors.iter().any(|msg| {
238                        let msg = msg.to_ascii_lowercase();
239                        msg.contains("invalid token") || msg.contains("permission denied")
240                    })
241            }
242            _ => false,
243        }
244    }
245
246    /// Get an API key for a provider from Vault
247    pub async fn get_api_key(&self, provider_id: &str) -> Result<Option<String>> {
248        // Check cache first
249        {
250            let cache = self.cache.read().await;
251            if let Some(key) = cache.get(provider_id) {
252                return Ok(Some(key.clone()));
253            }
254        }
255
256        // Fetch from Vault
257        let client = match self.client() {
258            Some(c) => c,
259            None => return Ok(None),
260        };
261
262        let secret_path = format!("{}/{}", self.path, provider_id);
263
264        let mut result =
265            kv2::read::<ProviderSecrets>(client.as_ref(), &self.mount, &secret_path).await;
266        if matches!(result.as_ref().err(), Some(err) if Self::should_refresh_vault_token(err)) {
267            if let Some(client) = self.refresh_kubernetes_auth().await? {
268                result =
269                    kv2::read::<ProviderSecrets>(client.as_ref(), &self.mount, &secret_path).await;
270            }
271        }
272
273        match result {
274            Ok(secret) => {
275                // Cache the result
276                if let Some(ref api_key) = secret.api_key {
277                    let mut cache = self.cache.write().await;
278                    cache.insert(provider_id.to_string(), api_key.clone());
279                }
280                Ok(secret.api_key)
281            }
282            Err(vaultrs::error::ClientError::APIError { code: 404, .. }) => Ok(None),
283            Err(e) => {
284                tracing::warn!("Failed to fetch secret for {}: {}", provider_id, e);
285                Ok(None)
286            }
287        }
288    }
289
290    /// Get all secrets for a provider
291    pub async fn get_provider_secrets(&self, provider_id: &str) -> Result<Option<ProviderSecrets>> {
292        let client = match self.client() {
293            Some(c) => c,
294            None => return Ok(None),
295        };
296
297        let secret_path = format!("{}/{}", self.path, provider_id);
298
299        let mut result =
300            kv2::read::<ProviderSecrets>(client.as_ref(), &self.mount, &secret_path).await;
301        if matches!(result.as_ref().err(), Some(err) if Self::should_refresh_vault_token(err)) {
302            if let Some(client) = self.refresh_kubernetes_auth().await? {
303                result =
304                    kv2::read::<ProviderSecrets>(client.as_ref(), &self.mount, &secret_path).await;
305            }
306        }
307
308        match result {
309            Ok(secret) => Ok(Some(secret)),
310            Err(vaultrs::error::ClientError::APIError { code: 404, .. }) => Ok(None),
311            Err(e) => {
312                tracing::warn!("Failed to fetch secrets for {}: {}", provider_id, e);
313                Ok(None)
314            }
315        }
316    }
317
318    /// Set/update secrets for a provider in Vault
319    pub async fn set_provider_secrets(
320        &self,
321        provider_id: &str,
322        secrets: &ProviderSecrets,
323    ) -> Result<()> {
324        let client = match self.client() {
325            Some(c) => c,
326            None => anyhow::bail!("Vault client not configured"),
327        };
328
329        let secret_path = format!("{}/{}", self.path, provider_id);
330        let mut result = kv2::set(client.as_ref(), &self.mount, &secret_path, secrets).await;
331        if matches!(result.as_ref().err(), Some(err) if Self::should_refresh_vault_token(err)) {
332            if let Some(client) = self.refresh_kubernetes_auth().await? {
333                result = kv2::set(client.as_ref(), &self.mount, &secret_path, secrets).await;
334            }
335        }
336        result.with_context(|| format!("Failed to write provider secrets for {}", provider_id))?;
337
338        // Update cache with latest API key value
339        let mut cache = self.cache.write().await;
340        if let Some(api_key) = secrets.api_key.clone() {
341            cache.insert(provider_id.to_string(), api_key);
342        } else {
343            cache.remove(provider_id);
344        }
345
346        Ok(())
347    }
348
349    /// Check if a provider has an API key in Vault
350    pub async fn has_api_key(&self, provider_id: &str) -> bool {
351        matches!(self.get_api_key(provider_id).await, Ok(Some(_)))
352    }
353
354    /// List all providers that have secrets configured
355    pub async fn list_configured_providers(&self) -> Result<Vec<String>> {
356        let client = match self.client() {
357            Some(c) => c,
358            None => return Ok(Vec::new()),
359        };
360
361        let mut result = kv2::list(client.as_ref(), &self.mount, &self.path).await;
362        if matches!(result.as_ref().err(), Some(err) if Self::should_refresh_vault_token(err)) {
363            if let Some(client) = self.refresh_kubernetes_auth().await? {
364                result = kv2::list(client.as_ref(), &self.mount, &self.path).await;
365            }
366        }
367
368        match result {
369            Ok(keys) => Ok(keys),
370            Err(vaultrs::error::ClientError::APIError { code: 404, .. }) => Ok(Vec::new()),
371            Err(e) => {
372                tracing::warn!("Failed to list providers: {}", e);
373                Ok(Vec::new())
374            }
375        }
376    }
377
378    /// Clear the cache (useful when secrets are rotated)
379    pub async fn clear_cache(&self) {
380        let mut cache = self.cache.write().await;
381        cache.clear();
382    }
383}
384
385/// Vault configuration
386#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
387pub struct VaultConfig {
388    /// Vault server address (e.g., "https://vault.example.com:8200")
389    pub address: String,
390
391    /// Vault token for authentication
392    pub token: String,
393
394    /// KV secrets engine mount path (default: "secret")
395    #[serde(default)]
396    pub mount: Option<String>,
397
398    /// Path prefix for provider secrets (default: "codetether/providers")
399    #[serde(default)]
400    pub path: Option<String>,
401}
402
403impl Default for VaultConfig {
404    fn default() -> Self {
405        Self {
406            address: String::new(),
407            token: String::new(),
408            mount: Some("secret".to_string()),
409            path: Some("codetether/providers".to_string()),
410        }
411    }
412}
413
414/// Provider secrets stored in Vault
415#[derive(Clone, serde::Serialize, serde::Deserialize)]
416pub struct ProviderSecrets {
417    /// API key for the provider
418    #[serde(default)]
419    pub api_key: Option<String>,
420
421    /// Base URL override
422    #[serde(default)]
423    pub base_url: Option<String>,
424
425    /// Organization ID (for OpenAI)
426    #[serde(default)]
427    pub organization: Option<String>,
428
429    /// Additional headers as JSON
430    #[serde(default)]
431    pub headers: Option<HashMap<String, String>>,
432
433    /// Any provider-specific extra fields
434    #[serde(flatten)]
435    pub extra: HashMap<String, serde_json::Value>,
436}
437
438impl std::fmt::Debug for ProviderSecrets {
439    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
440        f.debug_struct("ProviderSecrets")
441            .field("api_key", &self.api_key.as_ref().map(|_| "<REDACTED>"))
442            .field("api_key_len", &self.api_key.as_ref().map(|k| k.len()))
443            .field("base_url", &self.base_url)
444            .field("organization", &self.organization)
445            .field("headers_present", &self.headers.is_some())
446            .field("extra_fields", &self.extra.len())
447            .finish()
448    }
449}
450
451impl ProviderSecrets {
452    /// Check if API key is present and valid (non-empty)
453    pub fn has_valid_api_key(&self) -> bool {
454        self.api_key
455            .as_ref()
456            .map(|k| !k.is_empty())
457            .unwrap_or(false)
458    }
459
460    /// Get API key length without exposing the key
461    pub fn api_key_len(&self) -> Option<usize> {
462        self.api_key.as_ref().map(|k| k.len())
463    }
464}
465
466/// Global secrets manager instance
467static SECRETS_MANAGER: tokio::sync::OnceCell<SecretsManager> = tokio::sync::OnceCell::const_new();
468
469/// Initialize the global secrets manager
470pub async fn init_secrets_manager(config: &VaultConfig) -> Result<()> {
471    let manager = SecretsManager::new(config).await?;
472    SECRETS_MANAGER
473        .set(manager)
474        .map_err(|_| anyhow::anyhow!("Secrets manager already initialized"))?;
475    Ok(())
476}
477
478/// Initialize the global secrets manager from an existing manager instance
479pub fn init_from_manager(manager: SecretsManager) -> Result<()> {
480    SECRETS_MANAGER
481        .set(manager)
482        .map_err(|_| anyhow::anyhow!("Secrets manager already initialized"))?;
483    Ok(())
484}
485
486/// Get the global secrets manager
487pub fn secrets_manager() -> Option<&'static SecretsManager> {
488    SECRETS_MANAGER.get()
489}
490
491/// Get API key for a provider (convenience function)
492pub async fn get_api_key(provider_id: &str) -> Option<String> {
493    match SECRETS_MANAGER.get() {
494        Some(manager) => manager.get_api_key(provider_id).await.ok().flatten(),
495        None => None,
496    }
497}
498
499/// Check if a provider has an API key (convenience function)
500pub async fn has_api_key(provider_id: &str) -> bool {
501    match SECRETS_MANAGER.get() {
502        Some(manager) => manager.has_api_key(provider_id).await,
503        None => false,
504    }
505}
506
507/// Get full provider secrets (convenience function)
508pub async fn get_provider_secrets(provider_id: &str) -> Option<ProviderSecrets> {
509    match SECRETS_MANAGER.get() {
510        Some(manager) => manager
511            .get_provider_secrets(provider_id)
512            .await
513            .ok()
514            .flatten(),
515        None => None,
516    }
517}
518
519/// Set full provider secrets (convenience function)
520pub async fn set_provider_secrets(provider_id: &str, secrets: &ProviderSecrets) -> Result<()> {
521    match SECRETS_MANAGER.get() {
522        Some(manager) => manager.set_provider_secrets(provider_id, secrets).await,
523        None => anyhow::bail!("Secrets manager not initialized"),
524    }
525}