caliban_provider_vertex/
lib.rs1#![allow(clippy::missing_errors_doc)]
19#![allow(clippy::multiple_crate_versions)]
20
21pub mod auth;
22pub mod config;
23pub mod error;
24pub mod models;
25
26use std::sync::Arc;
27
28use async_trait::async_trait;
29use caliban_provider::{
30 Capabilities, CompletionRequest, CompletionResponse, Error, MessageStream, ModelInfo, Provider,
31 Result,
32};
33use caliban_provider_anthropic::AnthropicProvider;
34use caliban_provider_anthropic::config::VertexConfig as InnerVertexConfig;
35use caliban_provider_anthropic::transport::vertex::VertexTransport;
36use gcp_auth::TokenProvider;
37
38pub use auth::AuthRefresh;
39pub use config::VertexConfig;
40pub use error::VertexError;
41
42pub struct VertexProvider {
44 inner: AnthropicProvider<VertexTransport>,
45 config: VertexConfig,
46 token_provider: Arc<dyn TokenProvider>,
47 auth: Arc<AuthRefresh>,
48 list_client: reqwest::Client,
49}
50
51impl std::fmt::Debug for VertexProvider {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("VertexProvider")
54 .field("project_id", &self.config.project_id)
55 .field("region", &self.config.region)
56 .field("auth_refresh", &self.config.auth_refresh)
57 .finish_non_exhaustive()
58 }
59}
60
61impl VertexProvider {
62 pub async fn from_env() -> std::result::Result<Self, VertexError> {
64 let cfg = VertexConfig::from_env()?;
65 Self::from_config(cfg).await
66 }
67
68 pub async fn from_config(cfg: VertexConfig) -> std::result::Result<Self, VertexError> {
75 let token_provider: Arc<dyn TokenProvider> = if let Some(path) =
76 cfg.service_account_key_path.as_deref()
77 {
78 let sa = gcp_auth::CustomServiceAccount::from_file(path).map_err(VertexError::Auth)?;
79 Arc::new(sa)
80 } else {
81 gcp_auth::provider().await.map_err(VertexError::Auth)?
82 };
83 Self::from_parts(cfg, token_provider).await
84 }
85
86 #[allow(clippy::unused_async)] pub async fn from_parts(
90 cfg: VertexConfig,
91 token_provider: Arc<dyn TokenProvider>,
92 ) -> std::result::Result<Self, VertexError> {
93 let inner_cfg = InnerVertexConfig {
94 token_provider: token_provider.clone(),
95 project: cfg.project_id.clone(),
96 region: cfg.region.clone(),
97 timeout: std::time::Duration::from_mins(1),
98 anthropic_version: "vertex-2023-10-16".to_string(),
99 };
100 let inner = AnthropicProvider::vertex(inner_cfg)
101 .map_err(|e| VertexError::Transport(Box::new(e)))?;
102 let auth = AuthRefresh::spawn(token_provider.clone(), cfg.auth_refresh);
103 let list_client = caliban_common::http::default_client_builder()
104 .build()
105 .map_err(VertexError::Http)?;
106 Ok(Self {
107 inner,
108 config: cfg,
109 token_provider,
110 auth: Arc::new(auth),
111 list_client,
112 })
113 }
114
115 #[must_use]
117 pub fn auth_refresh(&self) -> &AuthRefresh {
118 &self.auth
119 }
120
121 #[must_use]
123 pub fn config(&self) -> &VertexConfig {
124 &self.config
125 }
126}
127
128#[async_trait]
129impl Provider for VertexProvider {
130 async fn complete(&self, req: CompletionRequest) -> Result<CompletionResponse> {
131 self.inner.complete(req).await
132 }
133
134 async fn stream(&self, req: CompletionRequest) -> Result<MessageStream> {
135 self.inner.stream(req).await
136 }
137
138 fn capabilities(&self, model: &str) -> Capabilities {
139 models::capabilities_for_vertex(model)
140 }
141
142 fn list_models(&self) -> Vec<ModelInfo> {
143 models::vendored_vertex_models()
145 }
146
147 async fn refresh_models(&self) -> Result<Vec<ModelInfo>> {
148 let base = models::default_base_url(&self.config.region);
152 models::list_models_remote(&self.list_client, &self.token_provider, &base)
153 .await
154 .map_err(Error::adapter)
155 }
156
157 fn name(&self) -> &'static str {
158 "vertex"
159 }
160}