Skip to main content

caliban_provider_vertex/
lib.rs

//! Google Vertex AI provider for the caliban agent harness.
//!
//! Thin wrapper around `caliban_provider_anthropic::AnthropicProvider<VertexTransport>`
//! that adds:
//!
//! - A [`VertexConfig`] with a `from_env` constructor.
//! - An [`AuthRefresh`] background task that periodically refreshes the
//!   GCP bearer token via `gcp_auth::TokenProvider`.
//! - A `refresh_models` (the `Provider` trait's live-discovery hook) that calls
//!   `https://{region}-aiplatform.googleapis.com/v1/publishers/anthropic/models`;
//!   the sync `list_models` returns the vendored fallback.
//! - `name() -> "vertex"` so the model router and telemetry attribute it
//!   correctly.
//!
//! See `docs/adr/0034-bedrock-and-vertex-providers.md` and
//! `docs/superpowers/specs/2026-05-24-bedrock-vertex-providers-design.md`.
17
18#![allow(clippy::missing_errors_doc)]
19#![allow(clippy::multiple_crate_versions)]
20
21pub mod auth;
22pub mod config;
23pub mod error;
24pub mod models;
25
26use std::sync::Arc;
27
28use async_trait::async_trait;
29use caliban_provider::{
30    Capabilities, CompletionRequest, CompletionResponse, Error, MessageStream, ModelInfo, Provider,
31    Result,
32};
33use caliban_provider_anthropic::AnthropicProvider;
34use caliban_provider_anthropic::config::VertexConfig as InnerVertexConfig;
35use caliban_provider_anthropic::transport::vertex::VertexTransport;
36use gcp_auth::TokenProvider;
37
38pub use auth::AuthRefresh;
39pub use config::VertexConfig;
40pub use error::VertexError;
41
42/// Provider that talks to Anthropic Claude on Google Vertex AI.
43pub struct VertexProvider {
44    inner: AnthropicProvider<VertexTransport>,
45    config: VertexConfig,
46    token_provider: Arc<dyn TokenProvider>,
47    auth: Arc<AuthRefresh>,
48    list_client: reqwest::Client,
49}
50
51impl std::fmt::Debug for VertexProvider {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("VertexProvider")
54            .field("project_id", &self.config.project_id)
55            .field("region", &self.config.region)
56            .field("auth_refresh", &self.config.auth_refresh)
57            .finish_non_exhaustive()
58    }
59}
60
61impl VertexProvider {
62    /// Build a `VertexProvider` from environment variables.
63    pub async fn from_env() -> std::result::Result<Self, VertexError> {
64        let cfg = VertexConfig::from_env()?;
65        Self::from_config(cfg).await
66    }
67
68    /// Build a `VertexProvider` from an explicit [`VertexConfig`].
69    ///
70    /// If `cfg.service_account_key_path` is set, the file is loaded via
71    /// `gcp_auth::CustomServiceAccount::from_file`. Otherwise the default
72    /// `gcp_auth::provider()` chain is used (ADC, gcloud user creds, GCE
73    /// metadata server).
74    pub async fn from_config(cfg: VertexConfig) -> std::result::Result<Self, VertexError> {
75        let token_provider: Arc<dyn TokenProvider> = if let Some(path) =
76            cfg.service_account_key_path.as_deref()
77        {
78            let sa = gcp_auth::CustomServiceAccount::from_file(path).map_err(VertexError::Auth)?;
79            Arc::new(sa)
80        } else {
81            gcp_auth::provider().await.map_err(VertexError::Auth)?
82        };
83        Self::from_parts(cfg, token_provider).await
84    }
85
86    /// Build a `VertexProvider` with an explicit token provider (mainly
87    /// for tests; production callers want `from_env` / `from_config`).
88    #[allow(clippy::unused_async)] // async for API symmetry with from_config
89    pub async fn from_parts(
90        cfg: VertexConfig,
91        token_provider: Arc<dyn TokenProvider>,
92    ) -> std::result::Result<Self, VertexError> {
93        let inner_cfg = InnerVertexConfig {
94            token_provider: token_provider.clone(),
95            project: cfg.project_id.clone(),
96            region: cfg.region.clone(),
97            timeout: std::time::Duration::from_mins(1),
98            anthropic_version: "vertex-2023-10-16".to_string(),
99        };
100        let inner = AnthropicProvider::vertex(inner_cfg)
101            .map_err(|e| VertexError::Transport(Box::new(e)))?;
102        let auth = AuthRefresh::spawn(token_provider.clone(), cfg.auth_refresh);
103        let list_client = caliban_common::http::default_client_builder()
104            .build()
105            .map_err(VertexError::Http)?;
106        Ok(Self {
107            inner,
108            config: cfg,
109            token_provider,
110            auth: Arc::new(auth),
111            list_client,
112        })
113    }
114
115    /// Access the `AuthRefresh` task (mainly for tests and graceful shutdown).
116    #[must_use]
117    pub fn auth_refresh(&self) -> &AuthRefresh {
118        &self.auth
119    }
120
121    /// Access the `VertexConfig` this provider was constructed with.
122    #[must_use]
123    pub fn config(&self) -> &VertexConfig {
124        &self.config
125    }
126}
127
128#[async_trait]
129impl Provider for VertexProvider {
130    async fn complete(&self, req: CompletionRequest) -> Result<CompletionResponse> {
131        self.inner.complete(req).await
132    }
133
134    async fn stream(&self, req: CompletionRequest) -> Result<MessageStream> {
135        self.inner.stream(req).await
136    }
137
138    fn capabilities(&self, model: &str) -> Capabilities {
139        models::capabilities_for_vertex(model)
140    }
141
142    fn list_models(&self) -> Vec<ModelInfo> {
143        // Offline view: the vendored catalog. Live callers use `refresh_models`.
144        models::vendored_vertex_models()
145    }
146
147    async fn refresh_models(&self) -> Result<Vec<ModelInfo>> {
148        // The real publisher catalog from Vertex, replacing the vendored list
149        // the sync `list_models` must return. Errors propagate so the caller
150        // (e.g. the #34 refresh path) can decide whether to fall back.
151        let base = models::default_base_url(&self.config.region);
152        models::list_models_remote(&self.list_client, &self.token_provider, &base)
153            .await
154            .map_err(Error::adapter)
155    }
156
157    fn name(&self) -> &'static str {
158        "vertex"
159    }
160}