1use std::time::Duration;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use once_cell::sync::Lazy;
6use reqwest::{Client, StatusCode};
7use serde_json::Value;
8use tokio::time::sleep;
9use url::Url;
10
11use crate::conversation::message::Message;
12use crate::model::ModelConfig;
13use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
14
15use crate::providers::errors::ProviderError;
16use crate::providers::formats::gcpvertexai::{
17 create_request, get_usage, response_to_message, ClaudeVersion, GcpVertexAIModel, GeminiVersion,
18 ModelProvider, RequestContext,
19};
20
21use crate::providers::formats::gcpvertexai::GcpLocation::Iowa;
22use crate::providers::gcpauth::GcpAuth;
23use crate::providers::retry::RetryConfig;
24use crate::providers::utils::RequestLog;
25use rmcp::model::Tool;
26
/// Link surfaced in provider metadata pointing users at Vertex AI docs.
const GCP_VERTEX_AI_DOC_URL: &str = "https://cloud.google.com/vertex-ai";
/// Per-request HTTP timeout; generous because model generation can be slow.
const DEFAULT_TIMEOUT_SECS: u64 = 600;
/// First backoff delay used when a request is throttled (overridable via config).
const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 5000;
/// Maximum retry attempts per throttling status (429 / 529 counted separately).
const DEFAULT_MAX_RETRIES: usize = 6;
/// Exponential growth factor applied to the backoff delay on each retry.
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Upper bound on a single backoff delay, regardless of attempt count.
const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 320_000;
// Non-standard status code Vertex AI uses for "API overloaded"; not a
// reqwest constant, so it is built once at first use.
static STATUS_API_OVERLOADED: Lazy<StatusCode> =
    Lazy::new(|| StatusCode::from_u16(529).expect("Valid status code 529 for API_OVERLOADED"));
42
/// Errors raised while preparing a Vertex AI request (before it is sent).
#[derive(Debug, thiserror::Error)]
enum GcpVertexAIError {
    /// The host/location/path combination did not form a valid endpoint URL.
    #[error("Invalid URL configuration: {0}")]
    InvalidUrl(String),

    /// Obtaining a GCP access token failed.
    #[error("Authentication error: {0}")]
    AuthError(String),
}
54
/// Provider that routes chat completions through GCP Vertex AI
/// (both Anthropic Claude and Google Gemini publishers).
#[derive(Debug, serde::Serialize)]
pub struct GcpVertexAIProvider {
    // HTTP client shared across requests (holds the timeout configuration).
    #[serde(skip)]
    client: Client,
    // Token source for Authorization headers.
    #[serde(skip)]
    auth: GcpAuth,
    // Regional API host, e.g. "https://us-central1-aiplatform.googleapis.com".
    host: String,
    // GCP project that owns the Vertex AI resources.
    project_id: String,
    // Configured region; requests may fall back to a model's known location.
    location: String,
    // Model selection and generation parameters.
    model: ModelConfig,
    // Backoff policy for 429/529 responses.
    #[serde(skip)]
    retry_config: RetryConfig,
    // Cached provider name from `Self::metadata()`.
    #[serde(skip)]
    name: String,
}
82
83impl GcpVertexAIProvider {
84 pub async fn from_env(model: ModelConfig) -> Result<Self> {
92 let config = crate::config::Config::global();
93 let project_id = config.get_param("GCP_PROJECT_ID")?;
94 let location = Self::determine_location(config)?;
95 let host = format!("https://{}-aiplatform.googleapis.com", location);
96
97 let client = Client::builder()
98 .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
99 .build()?;
100
101 let auth = GcpAuth::new().await?;
102
103 let retry_config = Self::load_retry_config(config);
105
106 Ok(Self {
107 client,
108 auth,
109 host,
110 project_id,
111 location,
112 model,
113 retry_config,
114 name: Self::metadata().name,
115 })
116 }
117
118 fn load_retry_config(config: &crate::config::Config) -> RetryConfig {
120 let max_retries = config
122 .get_param("GCP_MAX_RETRIES")
123 .ok()
124 .and_then(|v: String| v.parse::<usize>().ok())
125 .unwrap_or(DEFAULT_MAX_RETRIES);
126
127 let initial_interval_ms = config
128 .get_param("GCP_INITIAL_RETRY_INTERVAL_MS")
129 .ok()
130 .and_then(|v: String| v.parse::<u64>().ok())
131 .unwrap_or(DEFAULT_INITIAL_RETRY_INTERVAL_MS);
132
133 let backoff_multiplier = config
134 .get_param("GCP_BACKOFF_MULTIPLIER")
135 .ok()
136 .and_then(|v: String| v.parse::<f64>().ok())
137 .unwrap_or(DEFAULT_BACKOFF_MULTIPLIER);
138
139 let max_interval_ms = config
140 .get_param("GCP_MAX_RETRY_INTERVAL_MS")
141 .ok()
142 .and_then(|v: String| v.parse::<u64>().ok())
143 .unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS);
144
145 RetryConfig::new(
146 max_retries,
147 initial_interval_ms,
148 backoff_multiplier,
149 max_interval_ms,
150 )
151 }
152
153 fn determine_location(config: &crate::config::Config) -> Result<String> {
159 Ok(config
160 .get_param("GCP_LOCATION")
161 .ok()
162 .filter(|location: &String| !location.trim().is_empty())
163 .unwrap_or_else(|| Iowa.to_string()))
164 }
165
166 async fn get_auth_header(&self) -> Result<String, GcpVertexAIError> {
168 self.auth
169 .get_token()
170 .await
171 .map(|token| format!("Bearer {}", token.token_value))
172 .map_err(|e| GcpVertexAIError::AuthError(e.to_string()))
173 }
174
175 fn build_request_url(
181 &self,
182 provider: ModelProvider,
183 location: &str,
184 ) -> Result<Url, GcpVertexAIError> {
185 let host_url = if self.location == location {
187 &self.host
188 } else {
189 &self.host.replace(&self.location, location)
191 };
192
193 let base_url =
194 Url::parse(host_url).map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))?;
195
196 let endpoint = match provider {
198 ModelProvider::Anthropic => "streamRawPredict",
199 ModelProvider::Google => "generateContent",
200 ModelProvider::MaaS(_) => "generateContent",
201 };
202
203 let path = format!(
205 "v1/projects/{}/locations/{}/publishers/{}/models/{}:{}",
206 self.project_id,
207 location,
208 provider.as_str(),
209 self.model.model_name,
210 endpoint
211 );
212
213 base_url
214 .join(&path)
215 .map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))
216 }
217
    /// POSTs `payload` to the model endpoint in the given GCP `location`,
    /// retrying with exponential backoff on throttling responses.
    ///
    /// Retries on HTTP 429 (rate limited) and 529 (API overloaded), each with
    /// its own attempt counter bounded by `retry_config.max_retries`. Any
    /// other status ends the loop: 200 returns the parsed JSON body, 401/403
    /// map to `ProviderError::Authentication`, everything else to
    /// `ProviderError::RequestFailed`.
    async fn post_with_location(
        &self,
        payload: &Value,
        context: &RequestContext,
        location: &str,
    ) -> Result<Value, ProviderError> {
        let url = self
            .build_request_url(context.provider(), location)
            .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;

        // Independent counters so a burst of 429s doesn't eat the 529 budget.
        let mut rate_limit_attempts = 0;
        let mut overloaded_attempts = 0;
        // Most recent throttling error, returned if the budget is exhausted.
        let mut last_error = None;

        loop {
            // NOTE(review): this guard requires BOTH counters to exceed the
            // limit; the per-status checks below return first, so this branch
            // appears unreachable in practice — confirm before relying on it.
            if rate_limit_attempts > self.retry_config.max_retries
                && overloaded_attempts > self.retry_config.max_retries
            {
                let error_msg = format!(
                    "Exceeded maximum retry attempts ({}) for rate limiting errors",
                    self.retry_config.max_retries
                );
                tracing::error!("{}", error_msg);
                return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded {
                    details: error_msg,
                    retry_delay: None,
                }));
            }

            // Token is fetched per attempt so long backoffs don't send a
            // stale Authorization header.
            let auth_header = self
                .get_auth_header()
                .await
                .map_err(|e| ProviderError::Authentication(e.to_string()))?;

            let response = self
                .client
                .post(url.clone())
                .json(payload)
                .header("Authorization", auth_header)
                .send()
                .await
                .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;

            let status = response.status();

            match status {
                // 429: rate limited — back off and retry up to the budget.
                status if status == StatusCode::TOO_MANY_REQUESTS => {
                    rate_limit_attempts += 1;

                    if rate_limit_attempts > self.retry_config.max_retries {
                        let error_msg = format!(
                            "Exceeded maximum retry attempts ({}) for rate limiting (429) errors",
                            self.retry_config.max_retries
                        );
                        tracing::error!("{}", error_msg);
                        return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded {
                            details: error_msg,
                            retry_delay: None,
                        }));
                    }

                    let cite_gcp_vertex_429 =
                        "See https://cloud.google.com/vertex-ai/generative-ai/docs/error-code-429";
                    // Body is only used to distinguish the two 429 flavors;
                    // a read failure falls through to the pay-as-you-go message.
                    let response_text = response.text().await.unwrap_or_default();

                    let error_message =
                        if response_text.contains("Exceeded the Provisioned Throughput") {
                            format!("Exceeded the Provisioned Throughput: {cite_gcp_vertex_429}")
                        } else {
                            format!("Pay-as-you-go resource exhausted: {cite_gcp_vertex_429}")
                        };

                    tracing::warn!(
                        "Rate limit exceeded error (429) (attempt {}/{}): {}. Retrying after backoff...",
                        rate_limit_attempts,
                        self.retry_config.max_retries,
                        error_message
                    );

                    last_error = Some(ProviderError::RateLimitExceeded {
                        details: error_message,
                        retry_delay: None,
                    });

                    let delay = self.retry_config.delay_for_attempt(rate_limit_attempts);
                    tracing::info!("Backing off for {:?} before retry (rate limit 429)", delay);
                    sleep(delay).await;
                }
                // 529: backend overloaded — same backoff dance, separate budget.
                status if status == *STATUS_API_OVERLOADED => {
                    overloaded_attempts += 1;

                    if overloaded_attempts > self.retry_config.max_retries {
                        let error_msg = format!(
                            "Exceeded maximum retry attempts ({}) for API overloaded (529) errors",
                            self.retry_config.max_retries
                        );
                        tracing::error!("{}", error_msg);
                        return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded {
                            details: error_msg,
                            retry_delay: None,
                        }));
                    }

                    let error_message =
                        "Vertex AI Provider API is temporarily overloaded. This is similar to a rate limit \
                        error but indicates backend processing capacity issues."
                            .to_string();

                    tracing::warn!(
                        "API overloaded error (529) (attempt {}/{}): {}. Retrying after backoff...",
                        overloaded_attempts,
                        self.retry_config.max_retries,
                        error_message
                    );

                    last_error = Some(ProviderError::RateLimitExceeded {
                        details: error_message,
                        retry_delay: None,
                    });

                    let delay = self.retry_config.delay_for_attempt(overloaded_attempts);
                    tracing::info!(
                        "Backing off for {:?} before retry (API overloaded 529)",
                        delay
                    );
                    sleep(delay).await;
                }
                // Terminal statuses: parse the body once, then classify.
                _ => {
                    let response_json = response.json::<Value>().await.map_err(|e| {
                        ProviderError::RequestFailed(format!("Failed to parse response: {e}"))
                    })?;

                    return match status {
                        StatusCode::OK => Ok(response_json),
                        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
                            tracing::debug!(
                                "Authentication failed. Status: {status}, Payload: {payload:?}"
                            );
                            Err(ProviderError::Authentication(format!(
                                "Authentication failed: {response_json:?}"
                            )))
                        }
                        _ => {
                            tracing::debug!(
                                "Request failed. Status: {status}, Response: {response_json:?}"
                            );
                            Err(ProviderError::RequestFailed(format!(
                                "Request failed with status {status}: {response_json:?}"
                            )))
                        }
                    };
                }
            }
        }
    }
394
395 async fn post(
401 &self,
402 payload: &Value,
403 context: &RequestContext,
404 ) -> Result<Value, ProviderError> {
405 let result = self
407 .post_with_location(payload, context, &self.location)
408 .await;
409
410 if self.location == context.model.known_location().to_string() || result.is_ok() {
412 return result;
413 }
414
415 match &result {
417 Err(ProviderError::RequestFailed(msg)) => {
418 let model_name = context.model.to_string();
419 let configured_location = &self.location;
420 let known_location = context.model.known_location().to_string();
421
422 tracing::error!(
423 "Trying known location {known_location} for {model_name} instead of {configured_location}: {msg}"
424 );
425
426 self.post_with_location(payload, context, &known_location)
427 .await
428 }
429 _ => result,
431 }
432 }
433}
434
435#[async_trait]
436impl Provider for GcpVertexAIProvider {
437 fn metadata() -> ProviderMetadata
439 where
440 Self: Sized,
441 {
442 let model_strings: Vec<String> = [
443 GcpVertexAIModel::Claude(ClaudeVersion::Sonnet37),
444 GcpVertexAIModel::Claude(ClaudeVersion::Sonnet4),
445 GcpVertexAIModel::Claude(ClaudeVersion::Opus4),
446 GcpVertexAIModel::Gemini(GeminiVersion::Pro15),
447 GcpVertexAIModel::Gemini(GeminiVersion::Flash20),
448 GcpVertexAIModel::Gemini(GeminiVersion::Pro20Exp),
449 GcpVertexAIModel::Gemini(GeminiVersion::Pro25Exp),
450 GcpVertexAIModel::Gemini(GeminiVersion::Flash25Preview),
451 GcpVertexAIModel::Gemini(GeminiVersion::Pro25Preview),
452 GcpVertexAIModel::Gemini(GeminiVersion::Flash25),
453 GcpVertexAIModel::Gemini(GeminiVersion::Pro25),
454 ]
455 .iter()
456 .map(|model| model.to_string())
457 .collect();
458
459 let known_models: Vec<&str> = model_strings.iter().map(|s| s.as_str()).collect();
460
461 ProviderMetadata::new(
462 "gcp_vertex_ai",
463 "GCP Vertex AI",
464 "Access variety of AI models such as Claude, Gemini through Vertex AI",
465 "gemini-2.5-flash",
466 known_models,
467 GCP_VERTEX_AI_DOC_URL,
468 vec![
469 ConfigKey::new("GCP_PROJECT_ID", true, false, None),
470 ConfigKey::new("GCP_LOCATION", true, false, Some(Iowa.to_string().as_str())),
471 ConfigKey::new(
472 "GCP_MAX_RETRIES",
473 false,
474 false,
475 Some(&DEFAULT_MAX_RETRIES.to_string()),
476 ),
477 ConfigKey::new(
478 "GCP_INITIAL_RETRY_INTERVAL_MS",
479 false,
480 false,
481 Some(&DEFAULT_INITIAL_RETRY_INTERVAL_MS.to_string()),
482 ),
483 ConfigKey::new(
484 "GCP_BACKOFF_MULTIPLIER",
485 false,
486 false,
487 Some(&DEFAULT_BACKOFF_MULTIPLIER.to_string()),
488 ),
489 ConfigKey::new(
490 "GCP_MAX_RETRY_INTERVAL_MS",
491 false,
492 false,
493 Some(&DEFAULT_MAX_RETRY_INTERVAL_MS.to_string()),
494 ),
495 ],
496 )
497 }
498
499 fn get_name(&self) -> &str {
500 &self.name
501 }
502
503 #[tracing::instrument(
510 skip(self, model_config, system, messages, tools),
511 fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
512 )]
513 async fn complete_with_model(
514 &self,
515 model_config: &ModelConfig,
516 system: &str,
517 messages: &[Message],
518 tools: &[Tool],
519 ) -> Result<(Message, ProviderUsage), ProviderError> {
520 let (request, context) = create_request(model_config, system, messages, tools)?;
522
523 let response = self.post(&request, &context).await?;
525 let usage = get_usage(&response, &context)?;
526
527 let mut log = RequestLog::start(model_config, &request)?;
528 log.write(&response, Some(&usage))?;
529
530 let message = response_to_message(response, context)?;
532 let provider_usage = ProviderUsage::new(self.model.model_name.clone(), usage);
533
534 Ok((message, provider_usage))
535 }
536
    fn get_model_config(&self) -> ModelConfig {
        // Owned copy so callers can adjust settings without touching the provider.
        self.model.clone()
    }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::StatusCode;

    /// Backoff delays stay within the expected jitter windows and respect
    /// the configured maximum interval.
    #[test]
    fn test_retry_config_delay_calculation() {
        let config = RetryConfig::new(5, 1000, 2.0, 32000);

        // Attempt 0 never waits.
        assert_eq!(config.delay_for_attempt(0).as_millis(), 0);

        // Attempts 1 and 2 fall within ±20% of 1s and 2s respectively.
        let first = config.delay_for_attempt(1).as_millis();
        assert!((800..=1200).contains(&first));
        let second = config.delay_for_attempt(2).as_millis();
        assert!((1600..=2400).contains(&second));

        // Late attempts are capped by the max interval (plus jitter headroom).
        assert!(config.delay_for_attempt(10).as_millis() <= 38400);
    }

    /// The custom 529 status is a server error distinct from other
    /// throttling-adjacent codes.
    #[test]
    fn test_status_overloaded_code() {
        assert_eq!(STATUS_API_OVERLOADED.as_u16(), 529);
        assert!(STATUS_API_OVERLOADED.is_server_error());
        assert_ne!(*STATUS_API_OVERLOADED, StatusCode::TOO_MANY_REQUESTS);
        assert_ne!(*STATUS_API_OVERLOADED, StatusCode::SERVICE_UNAVAILABLE);
    }

    /// Publisher identifiers render as the expected lowercase strings.
    #[test]
    fn test_model_provider_conversion() {
        assert_eq!(ModelProvider::Anthropic.as_str(), "anthropic".to_string());
        assert_eq!(ModelProvider::Google.as_str(), "google".to_string());
        assert_eq!(
            ModelProvider::MaaS("qwen".to_string()).as_str(),
            "qwen".to_string()
        );
    }

    /// The endpoint path embeds project, location, and publisher segments.
    #[test]
    fn test_url_construction() {
        use url::Url;

        let model_config = ModelConfig::new_or_fail("claude-sonnet-4-20250514");
        let context = RequestContext::new(&model_config.model_name).unwrap();
        let api_model_id = context.model.to_string();

        let path = format!(
            "v1/projects/{}/locations/{}/publishers/{}/models/{}:{}",
            "test-project",
            "us-east5",
            ModelProvider::Anthropic.as_str(),
            api_model_id,
            "streamRawPredict"
        );
        let url = Url::parse("https://us-east5-aiplatform.googleapis.com")
            .unwrap()
            .join(&path)
            .unwrap();

        let rendered = url.as_str();
        assert!(rendered.contains("publishers/anthropic"));
        assert!(rendered.contains("projects/test-project"));
        assert!(rendered.contains("locations/us-east5"));
    }

    /// Metadata advertises the expected models and exactly six config keys.
    #[test]
    fn test_provider_metadata() {
        let metadata = GcpVertexAIProvider::metadata();
        let model_names: Vec<String> = metadata
            .known_models
            .iter()
            .map(|m| m.name.clone())
            .collect();

        for expected in [
            "claude-3-7-sonnet@20250219",
            "claude-sonnet-4@20250514",
            "gemini-1.5-pro-002",
            "gemini-2.5-pro",
        ] {
            assert!(model_names.contains(&expected.to_string()));
        }
        assert_eq!(metadata.config_keys.len(), 6);
    }
}
639}