claude_agent/client/adapter/
vertex.rs1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use async_trait::async_trait;
8use gcp_auth::TokenProvider;
9
10use super::base::RequestExecutor;
11use super::config::{BetaFeature, ProviderConfig};
12use super::request::{add_beta_features, build_messages_body};
13use super::token_cache::{CachedToken, TokenCache, new_token_cache};
14use super::traits::ProviderAdapter;
15use crate::client::messages::CreateMessageRequest;
16use crate::config::VertexConfig;
17use crate::types::ApiResponse;
18use crate::{Error, Result};
19
/// Anthropic API version string used for every Vertex AI request.
///
/// Handed to `build_messages_body` by `VertexAdapter::build_request_body`;
/// presumably emitted as the body's `anthropic_version` field — confirm in
/// `request::build_messages_body`.
const ANTHROPIC_VERSION: &str = "vertex-2023-10-16";
21
/// [`ProviderAdapter`] that sends Claude requests to Google Cloud Vertex AI.
///
/// Requests are authenticated with bearer tokens obtained from
/// `token_provider` and memoised in `token_cache`.
pub struct VertexAdapter {
    /// Shared provider settings (the request builder reads `thinking_budget`).
    config: ProviderConfig,
    /// GCP project that owns the Vertex AI endpoint; part of every URL.
    project_id: String,
    /// Region used when no `model_regions` key matches the model ID.
    default_region: String,
    /// Per-model region overrides, keyed by a substring of the model ID.
    model_regions: HashMap<String, String>,
    /// When true, the 1M-context beta feature is added to every request body.
    enable_1m_context: bool,
    /// Source of access tokens for the `Authorization` header.
    token_provider: Arc<dyn TokenProvider>,
    /// Most recently fetched access token, if any.
    token_cache: TokenCache,
}
31
32impl std::fmt::Debug for VertexAdapter {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("VertexAdapter")
35 .field("config", &self.config)
36 .field("project_id", &self.project_id)
37 .field("default_region", &self.default_region)
38 .field("model_regions", &self.model_regions)
39 .field("enable_1m_context", &self.enable_1m_context)
40 .finish_non_exhaustive()
41 }
42}
43
44impl VertexAdapter {
45 pub async fn from_env(config: ProviderConfig) -> Result<Self> {
46 let vertex_config = VertexConfig::from_env();
47 Self::from_config(config, vertex_config).await
48 }
49
50 pub async fn from_config(config: ProviderConfig, vertex: VertexConfig) -> Result<Self> {
51 let token_provider = gcp_auth::provider()
52 .await
53 .map_err(|e| Error::auth(e.to_string()))?;
54
55 let project_id = vertex
56 .project_id
57 .ok_or_else(|| Error::auth("No GCP project ID found"))?;
58
59 let default_region = vertex.region.unwrap_or_else(|| "us-central1".into());
60
61 Ok(Self {
62 config,
63 project_id,
64 default_region,
65 model_regions: vertex.model_regions,
66 enable_1m_context: vertex.enable_1m_context,
67 token_provider,
68 token_cache: new_token_cache(),
69 })
70 }
71
72 pub fn with_project(mut self, project_id: impl Into<String>) -> Self {
73 self.project_id = project_id.into();
74 self
75 }
76
77 pub fn with_region(mut self, region: impl Into<String>) -> Self {
78 self.default_region = region.into();
79 self
80 }
81
82 pub fn with_model_region(
83 mut self,
84 model_key: impl Into<String>,
85 region: impl Into<String>,
86 ) -> Self {
87 self.model_regions.insert(model_key.into(), region.into());
88 self
89 }
90
91 pub fn with_1m_context(mut self, enable: bool) -> Self {
92 self.enable_1m_context = enable;
93 self
94 }
95
96 fn region_for_model(&self, model: &str) -> &str {
97 for (key, region) in &self.model_regions {
98 if model.contains(key) {
99 return region;
100 }
101 }
102 &self.default_region
103 }
104
105 fn is_global(&self) -> bool {
106 self.default_region == "global"
107 }
108
109 fn build_url_for_model(&self, model: &str, stream: bool) -> String {
110 let region = self.region_for_model(model);
111 let endpoint = if stream {
112 "streamRawPredict"
113 } else {
114 "rawPredict"
115 };
116
117 if self.is_global() && region == "global" {
118 format!(
119 "https://aiplatform.googleapis.com/v1/projects/{}/locations/global/publishers/anthropic/models/{}:{}",
120 self.project_id, model, endpoint
121 )
122 } else {
123 format!(
124 "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/anthropic/models/{}:{}",
125 region, self.project_id, region, model, endpoint
126 )
127 }
128 }
129
130 fn build_request_body(&self, request: &CreateMessageRequest) -> serde_json::Value {
131 let mut body = build_messages_body(
132 request,
133 Some(ANTHROPIC_VERSION),
134 self.config.thinking_budget,
135 );
136
137 if let Some(obj) = body.as_object_mut() {
138 obj.remove("model");
139 }
140
141 if self.enable_1m_context {
142 add_beta_features(&mut body, &[BetaFeature::Context1M.header_value()]);
143 }
144
145 body
146 }
147
148 async fn get_token(&self) -> Result<String> {
149 {
150 let cache = self.token_cache.read().await;
151 if let Some(ref cached) = *cache
152 && !cached.is_expired()
153 {
154 return Ok(cached.token().to_string());
155 }
156 }
157
158 let scopes = &["https://www.googleapis.com/auth/cloud-platform"];
159 let token = self
160 .token_provider
161 .token(scopes)
162 .await
163 .map_err(|e| Error::auth(e.to_string()))?;
164
165 let token_str = token.as_str().to_string();
166 let cached = CachedToken::new(token_str.clone(), Duration::from_secs(3600));
167 *self.token_cache.write().await = Some(cached);
168
169 Ok(token_str)
170 }
171
172 async fn execute_request(
173 &self,
174 http: &reqwest::Client,
175 url: &str,
176 body: &serde_json::Value,
177 ) -> Result<reqwest::Response> {
178 let token = self.get_token().await?;
179 let headers = vec![("Authorization".into(), format!("Bearer {}", token))];
180 RequestExecutor::post(http, url, body, headers).await
181 }
182}
183
184#[async_trait]
185impl ProviderAdapter for VertexAdapter {
186 fn config(&self) -> &ProviderConfig {
187 &self.config
188 }
189
190 fn name(&self) -> &'static str {
191 "vertex"
192 }
193
194 async fn build_url(&self, model: &str, stream: bool) -> String {
195 self.build_url_for_model(model, stream)
196 }
197
198 async fn transform_request(&self, request: CreateMessageRequest) -> Result<serde_json::Value> {
199 Ok(self.build_request_body(&request))
200 }
201
202 fn transform_response(&self, response: serde_json::Value) -> Result<ApiResponse> {
203 serde_json::from_value(response).map_err(|e| Error::Parse(e.to_string()))
204 }
205
206 async fn send(
207 &self,
208 http: &reqwest::Client,
209 request: CreateMessageRequest,
210 ) -> Result<ApiResponse> {
211 let model = request.model.clone();
212 let url = self.build_url_for_model(&model, false);
213 let body = self.build_request_body(&request);
214
215 let response = self.execute_request(http, &url, &body).await?;
216 let json: serde_json::Value = response.json().await?;
217 self.transform_response(json)
218 }
219
220 async fn send_stream(
221 &self,
222 http: &reqwest::Client,
223 mut request: CreateMessageRequest,
224 ) -> Result<reqwest::Response> {
225 request.stream = Some(true);
226 let model = request.model.clone();
227 let url = self.build_url_for_model(&model, true);
228 let body = self.build_request_body(&request);
229
230 self.execute_request(http, &url, &body).await
231 }
232
233 async fn refresh_credentials(&self) -> Result<()> {
234 *self.token_cache.write().await = None;
235 self.get_token().await?;
236 Ok(())
237 }
238}
239
#[cfg(test)]
mod tests {
    use crate::client::adapter::ModelConfig;

    /// Pins the regional Vertex AI URL layout used for Claude models.
    ///
    /// NOTE(review): this checks a hand-written `format!`, not
    /// `VertexAdapter::build_url_for_model`, because building an adapter goes
    /// through `gcp_auth::provider()`. A test double for `TokenProvider` would
    /// let this exercise the real builder.
    #[test]
    fn test_build_url() {
        let url = format!(
            "https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/anthropic/models/{}:rawPredict",
            "us-central1", "my-project", "us-central1", "claude-sonnet-4-5@20250929"
        );
        assert_eq!(
            url,
            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/anthropic/models/claude-sonnet-4-5@20250929:rawPredict"
        );
    }

    /// Vertex model IDs carry an `@<version>` suffix.
    #[test]
    fn test_model_config() {
        let config = ModelConfig::vertex();
        assert!(config.primary.contains("@"));
    }
}