claude_agent/client/adapter/
vertex.rs

1//! Google Vertex AI adapter with ADC authentication.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use gcp_auth::TokenProvider;
9
10use super::base::RequestExecutor;
11use super::config::{BetaFeature, ProviderConfig};
12use super::request::{add_beta_features, build_messages_body};
13use super::token_cache::{CachedToken, TokenCache, new_token_cache};
14use super::traits::ProviderAdapter;
15use crate::client::messages::CreateMessageRequest;
16use crate::config::VertexConfig;
17use crate::types::ApiResponse;
18use crate::{Error, Result};
19
20const ANTHROPIC_VERSION: &str = "vertex-2023-10-16";
21
22pub struct VertexAdapter {
23    config: ProviderConfig,
24    project_id: String,
25    default_region: String,
26    model_regions: HashMap<String, String>,
27    enable_1m_context: bool,
28    token_provider: Arc<dyn TokenProvider>,
29    token_cache: TokenCache,
30}
31
32impl std::fmt::Debug for VertexAdapter {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("VertexAdapter")
35            .field("config", &self.config)
36            .field("project_id", &self.project_id)
37            .field("default_region", &self.default_region)
38            .field("model_regions", &self.model_regions)
39            .field("enable_1m_context", &self.enable_1m_context)
40            .finish_non_exhaustive()
41    }
42}
43
44impl VertexAdapter {
45    pub async fn from_env(config: ProviderConfig) -> Result<Self> {
46        let vertex_config = VertexConfig::from_env();
47        Self::from_config(config, vertex_config).await
48    }
49
50    pub async fn from_config(config: ProviderConfig, vertex: VertexConfig) -> Result<Self> {
51        let token_provider = gcp_auth::provider()
52            .await
53            .map_err(|e| Error::auth(e.to_string()))?;
54
55        let project_id = vertex
56            .project_id
57            .ok_or_else(|| Error::auth("No GCP project ID found"))?;
58
59        let default_region = vertex.region.unwrap_or_else(|| "us-central1".into());
60
61        Ok(Self {
62            config,
63            project_id,
64            default_region,
65            model_regions: vertex.model_regions,
66            enable_1m_context: vertex.enable_1m_context,
67            token_provider,
68            token_cache: new_token_cache(),
69        })
70    }
71
72    pub fn with_project(mut self, project_id: impl Into<String>) -> Self {
73        self.project_id = project_id.into();
74        self
75    }
76
77    pub fn with_region(mut self, region: impl Into<String>) -> Self {
78        self.default_region = region.into();
79        self
80    }
81
82    pub fn with_model_region(
83        mut self,
84        model_key: impl Into<String>,
85        region: impl Into<String>,
86    ) -> Self {
87        self.model_regions.insert(model_key.into(), region.into());
88        self
89    }
90
91    pub fn with_1m_context(mut self, enable: bool) -> Self {
92        self.enable_1m_context = enable;
93        self
94    }
95
96    fn region_for_model(&self, model: &str) -> &str {
97        for (key, region) in &self.model_regions {
98            if model.contains(key) {
99                return region;
100            }
101        }
102        &self.default_region
103    }
104
105    fn is_global(&self) -> bool {
106        self.default_region == "global"
107    }
108
109    fn build_url_for_model(&self, model: &str, stream: bool) -> String {
110        let region = self.region_for_model(model);
111        let endpoint = if stream {
112            "streamRawPredict"
113        } else {
114            "rawPredict"
115        };
116
117        if self.is_global() && region == "global" {
118            format!(
119                "https://aiplatform.googleapis.com/v1/projects/{}/locations/global/publishers/anthropic/models/{}:{}",
120                self.project_id, model, endpoint
121            )
122        } else {
123            format!(
124                "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/anthropic/models/{}:{}",
125                region, self.project_id, region, model, endpoint
126            )
127        }
128    }
129
130    fn build_request_body(&self, request: &CreateMessageRequest) -> serde_json::Value {
131        let mut body = build_messages_body(
132            request,
133            Some(ANTHROPIC_VERSION),
134            self.config.thinking_budget,
135        );
136
137        if let Some(obj) = body.as_object_mut() {
138            obj.remove("model");
139        }
140
141        if self.enable_1m_context {
142            add_beta_features(&mut body, &[BetaFeature::Context1M.header_value()]);
143        }
144
145        body
146    }
147
148    async fn get_token(&self) -> Result<String> {
149        {
150            let cache = self.token_cache.read().await;
151            if let Some(ref cached) = *cache
152                && !cached.is_expired()
153            {
154                return Ok(cached.token().to_string());
155            }
156        }
157
158        let scopes = &["https://www.googleapis.com/auth/cloud-platform"];
159        let token = self
160            .token_provider
161            .token(scopes)
162            .await
163            .map_err(|e| Error::auth(e.to_string()))?;
164
165        let token_str = token.as_str().to_string();
166        let cached = CachedToken::new(token_str.clone(), Duration::from_secs(3600));
167        *self.token_cache.write().await = Some(cached);
168
169        Ok(token_str)
170    }
171
172    async fn execute_request(
173        &self,
174        http: &reqwest::Client,
175        url: &str,
176        body: &serde_json::Value,
177    ) -> Result<reqwest::Response> {
178        let token = self.get_token().await?;
179        let headers = vec![("Authorization".into(), format!("Bearer {}", token))];
180        RequestExecutor::post(http, url, body, headers).await
181    }
182}
183
184#[async_trait]
185impl ProviderAdapter for VertexAdapter {
186    fn config(&self) -> &ProviderConfig {
187        &self.config
188    }
189
190    fn name(&self) -> &'static str {
191        "vertex"
192    }
193
194    async fn build_url(&self, model: &str, stream: bool) -> String {
195        self.build_url_for_model(model, stream)
196    }
197
198    async fn transform_request(&self, request: CreateMessageRequest) -> Result<serde_json::Value> {
199        Ok(self.build_request_body(&request))
200    }
201
202    fn transform_response(&self, response: serde_json::Value) -> Result<ApiResponse> {
203        serde_json::from_value(response).map_err(|e| Error::Parse(e.to_string()))
204    }
205
206    async fn send(
207        &self,
208        http: &reqwest::Client,
209        request: CreateMessageRequest,
210    ) -> Result<ApiResponse> {
211        let model = request.model.clone();
212        let url = self.build_url_for_model(&model, false);
213        let body = self.build_request_body(&request);
214
215        let response = self.execute_request(http, &url, &body).await?;
216        let json: serde_json::Value = response.json().await?;
217        self.transform_response(json)
218    }
219
220    async fn send_stream(
221        &self,
222        http: &reqwest::Client,
223        mut request: CreateMessageRequest,
224    ) -> Result<reqwest::Response> {
225        request.stream = Some(true);
226        let model = request.model.clone();
227        let url = self.build_url_for_model(&model, true);
228        let body = self.build_request_body(&request);
229
230        self.execute_request(http, &url, &body).await
231    }
232
233    async fn refresh_credentials(&self) -> Result<()> {
234        *self.token_cache.write().await = None;
235        self.get_token().await?;
236        Ok(())
237    }
238}
239
#[cfg(test)]
mod tests {

    use crate::client::adapter::ModelConfig;

    /// Pins the regional endpoint URL template.
    ///
    /// This only exercises the format string written here, not the adapter
    /// itself: building a `VertexAdapter` requires a credential provider. The
    /// test is synchronous because nothing in it awaits.
    #[test]
    fn test_build_url() {
        let url = format!(
            "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/anthropic/models/{}:rawPredict",
            "us-central1", "my-project", "us-central1", "claude-sonnet-4-5@20250929"
        );
        assert_eq!(
            url,
            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict"
        );
    }

    /// The primary Vertex model identifier carries an `@` in its name.
    #[test]
    fn test_model_config() {
        let config = ModelConfig::vertex();
        assert!(config.primary.contains('@'));
    }
}