Skip to main content

codetether_agent/provider/
init_vault.rs

//! Build a [`ProviderRegistry`](super::ProviderRegistry) from HashiCorp Vault.
//!
//! Fetches the secrets of every provider configured in Vault (concurrently),
//! delegates each to [`super::init_dispatch::dispatch`], then adds env-var /
//! AWS auto-detection unless
//! [`CODETETHER_DISABLE_ENV_FALLBACK`](super::fallback_policy::DISABLE_ENV_FALLBACK)
//! is set.

8use super::bedrock;
9use super::fallback_policy;
10use super::init_dispatch;
11use super::init_env;
12use super::registry::ProviderRegistry;
13use anyhow::Result;
14use std::sync::Arc;
15
16impl ProviderRegistry {
17    /// Initialize providers from HashiCorp Vault with optional env/AWS fallback.
18    ///
19    /// See [module-level docs](super) for the security model and fallback order.
20    ///
21    /// # Examples
22    ///
23    /// ```rust,no_run
24    /// use codetether_agent::provider::ProviderRegistry;
25    /// # async fn demo() {
26    /// let registry = ProviderRegistry::from_vault().await.unwrap();
27    /// # }
28    /// ```
29    pub async fn from_vault() -> Result<Self> {
30        let mut registry = Self::new();
31        let disable_env = fallback_policy::env_fallback_disabled();
32
33        if let Some(mgr) = crate::secrets::secrets_manager() {
34            let providers = mgr.list_configured_providers().await?;
35            tracing::info!("Found {} providers configured in Vault", providers.len());
36
37            // Fetch every provider's secrets concurrently; each fetch is an
38            // independent Vault HTTP round-trip so there's no reason to
39            // serialize them. With ~10 providers this turns ~10 * RTT
40            // latency into ~1 * RTT.
41            let fetches = providers.into_iter().map(|pid| async move {
42                let secrets = mgr.get_provider_secrets(&pid).await;
43                (pid, secrets)
44            });
45            let results = futures::future::join_all(fetches).await;
46
47            for (pid, secrets) in results {
48                let secrets = match secrets {
49                    Ok(Some(s)) => s,
50                    Ok(None) => continue,
51                    Err(err) => {
52                        tracing::warn!(provider = %pid, %err, "vault fetch failed; skipping");
53                        continue;
54                    }
55                };
56                if let Some(provider) = init_dispatch::dispatch(&pid, &secrets) {
57                    registry.register(provider);
58                }
59            }
60        } else {
61            tracing::warn!("Vault not configured, no providers loaded from Vault");
62        }
63
64        // Bedrock auto-detect from local AWS creds if Vault didn't register it
65        if !registry.providers.contains_key("bedrock") && !disable_env {
66            if let Some(creds) = bedrock::AwsCredentials::from_environment() {
67                let region =
68                    bedrock::AwsCredentials::detect_region().unwrap_or_else(|| "us-east-1".into());
69                match bedrock::BedrockProvider::with_credentials(creds, region) {
70                    Ok(p) => {
71                        tracing::info!("Registered Bedrock from local AWS credentials");
72                        registry.register(Arc::new(p));
73                    }
74                    Err(e) => tracing::warn!("Failed to init bedrock: {e}"),
75                }
76            }
77        }
78
79        if !disable_env {
80            init_env::register_env_fallbacks(&mut registry);
81        } else {
82            tracing::info!(
83                env = fallback_policy::DISABLE_ENV_FALLBACK,
84                "Env/AWS fallback disabled"
85            );
86        }
87
88        tracing::info!(
89            mode = fallback_policy::registry_mode_label(disable_env),
90            "Registered {} providers",
91            registry.providers.len(),
92        );
93        Ok(registry)
94    }
95
96    /// Process-wide cached [`from_vault`](Self::from_vault) registry.
97    ///
98    /// `from_vault` performs vault fetches plus env-var / AWS probing on
99    /// every call. Compression paths (e.g. RLM model resolution inside
100    /// [`enforce_on_messages`](crate::session::helper::compression::enforce_on_messages))
101    /// invoke it once per keep-last attempt per turn, which can add up
102    /// to several Vault round-trips of unnecessary latency in the hot
103    /// loop. This accessor lazily builds the registry exactly once and
104    /// hands out `Arc` clones thereafter.
105    ///
106    /// The cache is process-global. Restart the binary to pick up
107    /// re-keyed providers.
108    ///
109    /// # Errors
110    ///
111    /// Propagates the underlying [`from_vault`](Self::from_vault) error
112    /// on the first call. Subsequent calls reuse the cached value and
113    /// cannot fail.
114    pub async fn shared_from_vault() -> Result<Arc<Self>> {
115        use tokio::sync::OnceCell;
116        static CACHED: OnceCell<Arc<ProviderRegistry>> = OnceCell::const_new();
117        CACHED
118            .get_or_try_init(|| async { Self::from_vault().await.map(Arc::new) })
119            .await
120            .cloned()
121    }
122}