codetether_agent/provider/init_vault.rs
1//! Build a [`ProviderRegistry`](super::ProviderRegistry) from HashiCorp Vault.
2//!
3//! Iterates all providers configured in Vault, delegates each to
4//! [`super::init_dispatch::dispatch`], then adds env-var / AWS auto-detection
5//! unless [`CODETETHER_DISABLE_ENV_FALLBACK`](super::fallback_policy::DISABLE_ENV_FALLBACK)
6//! is set.
7
8use super::bedrock;
9use super::fallback_policy;
10use super::init_dispatch;
11use super::init_env;
12use super::registry::ProviderRegistry;
13use anyhow::Result;
14use std::sync::Arc;
15
16impl ProviderRegistry {
17 /// Initialize providers from HashiCorp Vault with optional env/AWS fallback.
18 ///
19 /// See [module-level docs](super) for the security model and fallback order.
20 ///
21 /// # Examples
22 ///
23 /// ```rust,no_run
24 /// use codetether_agent::provider::ProviderRegistry;
25 /// # async fn demo() {
26 /// let registry = ProviderRegistry::from_vault().await.unwrap();
27 /// # }
28 /// ```
29 pub async fn from_vault() -> Result<Self> {
30 let mut registry = Self::new();
31 let disable_env = fallback_policy::env_fallback_disabled();
32
33 if let Some(mgr) = crate::secrets::secrets_manager() {
34 let providers = mgr.list_configured_providers().await?;
35 tracing::info!("Found {} providers configured in Vault", providers.len());
36
37 // Fetch every provider's secrets concurrently; each fetch is an
38 // independent Vault HTTP round-trip so there's no reason to
39 // serialize them. With ~10 providers this turns ~10 * RTT
40 // latency into ~1 * RTT.
41 let fetches = providers.into_iter().map(|pid| async move {
42 let secrets = mgr.get_provider_secrets(&pid).await;
43 (pid, secrets)
44 });
45 let results = futures::future::join_all(fetches).await;
46
47 for (pid, secrets) in results {
48 let secrets = match secrets {
49 Ok(Some(s)) => s,
50 Ok(None) => continue,
51 Err(err) => {
52 tracing::warn!(provider = %pid, %err, "vault fetch failed; skipping");
53 continue;
54 }
55 };
56 if let Some(provider) = init_dispatch::dispatch(&pid, &secrets) {
57 registry.register(provider);
58 }
59 }
60 } else {
61 tracing::warn!("Vault not configured, no providers loaded from Vault");
62 }
63
64 // Bedrock auto-detect from local AWS creds if Vault didn't register it
65 if !registry.providers.contains_key("bedrock") && !disable_env {
66 if let Some(creds) = bedrock::AwsCredentials::from_environment() {
67 let region =
68 bedrock::AwsCredentials::detect_region().unwrap_or_else(|| "us-east-1".into());
69 match bedrock::BedrockProvider::with_credentials(creds, region) {
70 Ok(p) => {
71 tracing::info!("Registered Bedrock from local AWS credentials");
72 registry.register(Arc::new(p));
73 }
74 Err(e) => tracing::warn!("Failed to init bedrock: {e}"),
75 }
76 }
77 }
78
79 if !disable_env {
80 init_env::register_env_fallbacks(&mut registry);
81 } else {
82 tracing::info!(
83 env = fallback_policy::DISABLE_ENV_FALLBACK,
84 "Env/AWS fallback disabled"
85 );
86 }
87
88 tracing::info!(
89 mode = fallback_policy::registry_mode_label(disable_env),
90 "Registered {} providers",
91 registry.providers.len(),
92 );
93 Ok(registry)
94 }
95
96 /// Process-wide cached [`from_vault`](Self::from_vault) registry.
97 ///
98 /// `from_vault` performs vault fetches plus env-var / AWS probing on
99 /// every call. Compression paths (e.g. RLM model resolution inside
100 /// [`enforce_on_messages`](crate::session::helper::compression::enforce_on_messages))
101 /// invoke it once per keep-last attempt per turn, which can add up
102 /// to several Vault round-trips of unnecessary latency in the hot
103 /// loop. This accessor lazily builds the registry exactly once and
104 /// hands out `Arc` clones thereafter.
105 ///
106 /// The cache is process-global. Restart the binary to pick up
107 /// re-keyed providers.
108 ///
109 /// # Errors
110 ///
111 /// Propagates the underlying [`from_vault`](Self::from_vault) error
112 /// on the first call. Subsequent calls reuse the cached value and
113 /// cannot fail.
114 pub async fn shared_from_vault() -> Result<Arc<Self>> {
115 use tokio::sync::OnceCell;
116 static CACHED: OnceCell<Arc<ProviderRegistry>> = OnceCell::const_new();
117 CACHED
118 .get_or_try_init(|| async { Self::from_vault().await.map(Arc::new) })
119 .await
120 .cloned()
121 }
122}