// aster/providers/gcpvertexai.rs
1use std::time::Duration;
2
3use anyhow::Result;
4use async_trait::async_trait;
5use once_cell::sync::Lazy;
6use reqwest::{Client, StatusCode};
7use serde_json::Value;
8use tokio::time::sleep;
9use url::Url;
10
11use crate::conversation::message::Message;
12use crate::model::ModelConfig;
13use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
14
15use crate::providers::errors::ProviderError;
16use crate::providers::formats::gcpvertexai::{
17    create_request, get_usage, response_to_message, ClaudeVersion, GcpVertexAIModel, GeminiVersion,
18    ModelProvider, RequestContext,
19};
20
21use crate::providers::formats::gcpvertexai::GcpLocation::Iowa;
22use crate::providers::gcpauth::GcpAuth;
23use crate::providers::retry::RetryConfig;
24use crate::providers::utils::RequestLog;
25use rmcp::model::Tool;
26
/// Base URL for GCP Vertex AI documentation
const GCP_VERTEX_AI_DOC_URL: &str = "https://cloud.google.com/vertex-ai";
/// Default timeout for API requests in seconds (10 minutes; large generations can be slow)
const DEFAULT_TIMEOUT_SECS: u64 = 600;
/// Default initial interval for retry (in milliseconds)
const DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 5000;
/// Default maximum number of retries (applied per retryable status code)
const DEFAULT_MAX_RETRIES: usize = 6;
/// Default retry backoff multiplier (each successive retry waits twice as long)
const DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Default maximum interval for retry (in milliseconds)
const DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 320_000;
/// Status code for Anthropic's API overloaded error (529).
/// 529 is not a registered HTTP status, so `StatusCode` has no named constant
/// for it; it is built at runtime. `from_u16` accepts any code in 100..1000,
/// so the `expect` cannot fire for 529.
static STATUS_API_OVERLOADED: Lazy<StatusCode> =
    Lazy::new(|| StatusCode::from_u16(529).expect("Valid status code 529 for API_OVERLOADED"));
42
/// Represents errors specific to GCP Vertex AI operations.
///
/// These are internal to this module; before reaching callers they are
/// converted into `ProviderError` variants (see `post_with_location`).
#[derive(Debug, thiserror::Error)]
enum GcpVertexAIError {
    /// Error when URL construction fails (unparseable host or joined path)
    #[error("Invalid URL configuration: {0}")]
    InvalidUrl(String),

    /// Error during GCP authentication (token retrieval failure)
    #[error("Authentication error: {0}")]
    AuthError(String),
}
54
/// Provider implementation for Google Cloud Platform's Vertex AI service.
///
/// This provider enables interaction with various AI models hosted on GCP Vertex AI,
/// including Claude and Gemini model families. It handles authentication, request routing,
/// and response processing for the Vertex AI API endpoints.
///
/// Only the plain configuration fields (`host`, `project_id`, `location`, `model`)
/// are serialized; client, auth, retry state, and name are `#[serde(skip)]`ped.
#[derive(Debug, serde::Serialize)]
pub struct GcpVertexAIProvider {
    /// HTTP client for making API requests
    #[serde(skip)]
    client: Client,
    /// GCP authentication handler
    #[serde(skip)]
    auth: GcpAuth,
    /// Base URL for the Vertex AI API (regional: `https://{location}-aiplatform.googleapis.com`)
    host: String,
    /// GCP project identifier
    project_id: String,
    /// GCP region for model deployment
    location: String,
    /// Configuration for the specific model being used
    model: ModelConfig,
    /// Retry configuration for handling rate limit errors
    #[serde(skip)]
    retry_config: RetryConfig,
    // Provider display name, taken from `Self::metadata().name` at construction
    #[serde(skip)]
    name: String,
}
82
83impl GcpVertexAIProvider {
84    /// Creates a new provider instance from environment configuration.
85    ///
86    /// This is a convenience method that initializes the provider using
87    /// environment variables and default settings.
88    ///
89    /// # Arguments
90    /// * `model` - Configuration for the model to be used
91    pub async fn from_env(model: ModelConfig) -> Result<Self> {
92        let config = crate::config::Config::global();
93        let project_id = config.get_param("GCP_PROJECT_ID")?;
94        let location = Self::determine_location(config)?;
95        let host = format!("https://{}-aiplatform.googleapis.com", location);
96
97        let client = Client::builder()
98            .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
99            .build()?;
100
101        let auth = GcpAuth::new().await?;
102
103        // Load optional retry configuration from environment
104        let retry_config = Self::load_retry_config(config);
105
106        Ok(Self {
107            client,
108            auth,
109            host,
110            project_id,
111            location,
112            model,
113            retry_config,
114            name: Self::metadata().name,
115        })
116    }
117
118    /// Loads retry configuration from environment variables or uses defaults.
119    fn load_retry_config(config: &crate::config::Config) -> RetryConfig {
120        // Load max retries for 429 rate limit errors
121        let max_retries = config
122            .get_param("GCP_MAX_RETRIES")
123            .ok()
124            .and_then(|v: String| v.parse::<usize>().ok())
125            .unwrap_or(DEFAULT_MAX_RETRIES);
126
127        let initial_interval_ms = config
128            .get_param("GCP_INITIAL_RETRY_INTERVAL_MS")
129            .ok()
130            .and_then(|v: String| v.parse::<u64>().ok())
131            .unwrap_or(DEFAULT_INITIAL_RETRY_INTERVAL_MS);
132
133        let backoff_multiplier = config
134            .get_param("GCP_BACKOFF_MULTIPLIER")
135            .ok()
136            .and_then(|v: String| v.parse::<f64>().ok())
137            .unwrap_or(DEFAULT_BACKOFF_MULTIPLIER);
138
139        let max_interval_ms = config
140            .get_param("GCP_MAX_RETRY_INTERVAL_MS")
141            .ok()
142            .and_then(|v: String| v.parse::<u64>().ok())
143            .unwrap_or(DEFAULT_MAX_RETRY_INTERVAL_MS);
144
145        RetryConfig::new(
146            max_retries,
147            initial_interval_ms,
148            backoff_multiplier,
149            max_interval_ms,
150        )
151    }
152
153    /// Determines the appropriate GCP location for model deployment.
154    ///
155    /// Location is determined in the following order:
156    /// 1. Custom location from GCP_LOCATION environment variable
157    /// 2. Global default location (Iowa)
158    fn determine_location(config: &crate::config::Config) -> Result<String> {
159        Ok(config
160            .get_param("GCP_LOCATION")
161            .ok()
162            .filter(|location: &String| !location.trim().is_empty())
163            .unwrap_or_else(|| Iowa.to_string()))
164    }
165
166    /// Retrieves an authentication token for API requests.
167    async fn get_auth_header(&self) -> Result<String, GcpVertexAIError> {
168        self.auth
169            .get_token()
170            .await
171            .map(|token| format!("Bearer {}", token.token_value))
172            .map_err(|e| GcpVertexAIError::AuthError(e.to_string()))
173    }
174
175    /// Constructs the appropriate API endpoint URL for a given provider.
176    ///
177    /// # Arguments
178    /// * `provider` - The model provider (Anthropic or Google)
179    /// * `location` - The GCP location for model deployment
180    fn build_request_url(
181        &self,
182        provider: ModelProvider,
183        location: &str,
184    ) -> Result<Url, GcpVertexAIError> {
185        // Create host URL for the specified location
186        let host_url = if self.location == location {
187            &self.host
188        } else {
189            // Only allocate a new string if location differs
190            &self.host.replace(&self.location, location)
191        };
192
193        let base_url =
194            Url::parse(host_url).map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))?;
195
196        // Determine endpoint based on provider type
197        let endpoint = match provider {
198            ModelProvider::Anthropic => "streamRawPredict",
199            ModelProvider::Google => "generateContent",
200            ModelProvider::MaaS(_) => "generateContent",
201        };
202
203        // Construct path for URL
204        let path = format!(
205            "v1/projects/{}/locations/{}/publishers/{}/models/{}:{}",
206            self.project_id,
207            location,
208            provider.as_str(),
209            self.model.model_name,
210            endpoint
211        );
212
213        base_url
214            .join(&path)
215            .map_err(|e| GcpVertexAIError::InvalidUrl(e.to_string()))
216    }
217
218    /// Makes an authenticated POST request to the Vertex AI API at a specific location.
219    /// Includes retry logic for 429 (Too Many Requests) and 529 (API Overloaded) errors.
220    ///
221    /// # Arguments
222    /// * `payload` - The request payload to send
223    /// * `context` - Request context containing model information
224    /// * `location` - The GCP location for the request
225    async fn post_with_location(
226        &self,
227        payload: &Value,
228        context: &RequestContext,
229        location: &str,
230    ) -> Result<Value, ProviderError> {
231        let url = self
232            .build_request_url(context.provider(), location)
233            .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;
234
235        // Initialize separate counters for different error types
236        let mut rate_limit_attempts = 0;
237        let mut overloaded_attempts = 0;
238        let mut last_error = None;
239
240        loop {
241            // Check if we've exceeded max retries
242            if rate_limit_attempts > self.retry_config.max_retries
243                && overloaded_attempts > self.retry_config.max_retries
244            {
245                let error_msg = format!(
246                    "Exceeded maximum retry attempts ({}) for rate limiting errors",
247                    self.retry_config.max_retries
248                );
249                tracing::error!("{}", error_msg);
250                return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded {
251                    details: error_msg,
252                    retry_delay: None,
253                }));
254            }
255
256            // Get a fresh auth token for each attempt
257            let auth_header = self
258                .get_auth_header()
259                .await
260                .map_err(|e| ProviderError::Authentication(e.to_string()))?;
261
262            // Make the request
263            let response = self
264                .client
265                .post(url.clone())
266                .json(payload)
267                .header("Authorization", auth_header)
268                .send()
269                .await
270                .map_err(|e| ProviderError::RequestFailed(e.to_string()))?;
271
272            let status = response.status();
273
274            // Handle 429 Too Many Requests and 529 API Overloaded errors
275            match status {
276                status if status == StatusCode::TOO_MANY_REQUESTS => {
277                    rate_limit_attempts += 1;
278
279                    if rate_limit_attempts > self.retry_config.max_retries {
280                        let error_msg = format!(
281                            "Exceeded maximum retry attempts ({}) for rate limiting (429) errors",
282                            self.retry_config.max_retries
283                        );
284                        tracing::error!("{}", error_msg);
285                        return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded {
286                            details: error_msg,
287                            retry_delay: None,
288                        }));
289                    }
290
291                    // Try to parse response for more detailed error info
292                    let cite_gcp_vertex_429 =
293                        "See https://cloud.google.com/vertex-ai/generative-ai/docs/error-code-429";
294                    let response_text = response.text().await.unwrap_or_default();
295
296                    let error_message =
297                        if response_text.contains("Exceeded the Provisioned Throughput") {
298                            // Handle 429 rate limit due to throughput limits
299                            format!("Exceeded the Provisioned Throughput: {cite_gcp_vertex_429}")
300                        } else {
301                            // Handle generic 429 rate limit
302                            format!("Pay-as-you-go resource exhausted: {cite_gcp_vertex_429}")
303                        };
304
305                    tracing::warn!(
306                        "Rate limit exceeded error (429) (attempt {}/{}): {}. Retrying after backoff...",
307                        rate_limit_attempts,
308                        self.retry_config.max_retries,
309                        error_message
310                    );
311
312                    // Store the error in case we need to return it after max retries
313                    last_error = Some(ProviderError::RateLimitExceeded {
314                        details: error_message,
315                        retry_delay: None,
316                    });
317
318                    // Calculate and apply the backoff delay
319                    let delay = self.retry_config.delay_for_attempt(rate_limit_attempts);
320                    tracing::info!("Backing off for {:?} before retry (rate limit 429)", delay);
321                    sleep(delay).await;
322                }
323                status if status == *STATUS_API_OVERLOADED => {
324                    overloaded_attempts += 1;
325
326                    if overloaded_attempts > self.retry_config.max_retries {
327                        let error_msg = format!(
328                            "Exceeded maximum retry attempts ({}) for API overloaded (529) errors",
329                            self.retry_config.max_retries
330                        );
331                        tracing::error!("{}", error_msg);
332                        return Err(last_error.unwrap_or(ProviderError::RateLimitExceeded {
333                            details: error_msg,
334                            retry_delay: None,
335                        }));
336                    }
337
338                    // Handle 529 Overloaded error (https://docs.anthropic.com/en/api/errors)
339                    let error_message =
340                        "Vertex AI Provider API is temporarily overloaded. This is similar to a rate limit \
341                        error but indicates backend processing capacity issues."
342                            .to_string();
343
344                    tracing::warn!(
345                        "API overloaded error (529) (attempt {}/{}): {}. Retrying after backoff...",
346                        overloaded_attempts,
347                        self.retry_config.max_retries,
348                        error_message
349                    );
350
351                    // Store the error in case we need to return it after max retries
352                    last_error = Some(ProviderError::RateLimitExceeded {
353                        details: error_message,
354                        retry_delay: None,
355                    });
356
357                    // Calculate and apply the backoff delay
358                    let delay = self.retry_config.delay_for_attempt(overloaded_attempts);
359                    tracing::info!(
360                        "Backing off for {:?} before retry (API overloaded 529)",
361                        delay
362                    );
363                    sleep(delay).await;
364                }
365                // For any other status codes, process normally
366                _ => {
367                    let response_json = response.json::<Value>().await.map_err(|e| {
368                        ProviderError::RequestFailed(format!("Failed to parse response: {e}"))
369                    })?;
370
371                    return match status {
372                        StatusCode::OK => Ok(response_json),
373                        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
374                            tracing::debug!(
375                                "Authentication failed. Status: {status}, Payload: {payload:?}"
376                            );
377                            Err(ProviderError::Authentication(format!(
378                                "Authentication failed: {response_json:?}"
379                            )))
380                        }
381                        _ => {
382                            tracing::debug!(
383                                "Request failed. Status: {status}, Response: {response_json:?}"
384                            );
385                            Err(ProviderError::RequestFailed(format!(
386                                "Request failed with status {status}: {response_json:?}"
387                            )))
388                        }
389                    };
390                }
391            }
392        }
393    }
394
395    /// Makes an authenticated POST request to the Vertex AI API with fallback for invalid locations.
396    ///
397    /// # Arguments
398    /// * `payload` - The request payload to send
399    /// * `context` - Request context containing model information
400    async fn post(
401        &self,
402        payload: &Value,
403        context: &RequestContext,
404    ) -> Result<Value, ProviderError> {
405        // Try with user-specified location first
406        let result = self
407            .post_with_location(payload, context, &self.location)
408            .await;
409
410        // If location is already the known location for the model or request succeeded, return result
411        if self.location == context.model.known_location().to_string() || result.is_ok() {
412            return result;
413        }
414
415        // Check if we should retry with the model's known location
416        match &result {
417            Err(ProviderError::RequestFailed(msg)) => {
418                let model_name = context.model.to_string();
419                let configured_location = &self.location;
420                let known_location = context.model.known_location().to_string();
421
422                tracing::error!(
423                    "Trying known location {known_location} for {model_name} instead of {configured_location}: {msg}"
424                );
425
426                self.post_with_location(payload, context, &known_location)
427                    .await
428            }
429            // For any other error, return the original result
430            _ => result,
431        }
432    }
433}
434
435#[async_trait]
436impl Provider for GcpVertexAIProvider {
437    /// Returns metadata about the GCP Vertex AI provider.
438    fn metadata() -> ProviderMetadata
439    where
440        Self: Sized,
441    {
442        let model_strings: Vec<String> = [
443            GcpVertexAIModel::Claude(ClaudeVersion::Sonnet37),
444            GcpVertexAIModel::Claude(ClaudeVersion::Sonnet4),
445            GcpVertexAIModel::Claude(ClaudeVersion::Opus4),
446            GcpVertexAIModel::Gemini(GeminiVersion::Pro15),
447            GcpVertexAIModel::Gemini(GeminiVersion::Flash20),
448            GcpVertexAIModel::Gemini(GeminiVersion::Pro20Exp),
449            GcpVertexAIModel::Gemini(GeminiVersion::Pro25Exp),
450            GcpVertexAIModel::Gemini(GeminiVersion::Flash25Preview),
451            GcpVertexAIModel::Gemini(GeminiVersion::Pro25Preview),
452            GcpVertexAIModel::Gemini(GeminiVersion::Flash25),
453            GcpVertexAIModel::Gemini(GeminiVersion::Pro25),
454        ]
455        .iter()
456        .map(|model| model.to_string())
457        .collect();
458
459        let known_models: Vec<&str> = model_strings.iter().map(|s| s.as_str()).collect();
460
461        ProviderMetadata::new(
462            "gcp_vertex_ai",
463            "GCP Vertex AI",
464            "Access variety of AI models such as Claude, Gemini through Vertex AI",
465            "gemini-2.5-flash",
466            known_models,
467            GCP_VERTEX_AI_DOC_URL,
468            vec![
469                ConfigKey::new("GCP_PROJECT_ID", true, false, None),
470                ConfigKey::new("GCP_LOCATION", true, false, Some(Iowa.to_string().as_str())),
471                ConfigKey::new(
472                    "GCP_MAX_RETRIES",
473                    false,
474                    false,
475                    Some(&DEFAULT_MAX_RETRIES.to_string()),
476                ),
477                ConfigKey::new(
478                    "GCP_INITIAL_RETRY_INTERVAL_MS",
479                    false,
480                    false,
481                    Some(&DEFAULT_INITIAL_RETRY_INTERVAL_MS.to_string()),
482                ),
483                ConfigKey::new(
484                    "GCP_BACKOFF_MULTIPLIER",
485                    false,
486                    false,
487                    Some(&DEFAULT_BACKOFF_MULTIPLIER.to_string()),
488                ),
489                ConfigKey::new(
490                    "GCP_MAX_RETRY_INTERVAL_MS",
491                    false,
492                    false,
493                    Some(&DEFAULT_MAX_RETRY_INTERVAL_MS.to_string()),
494                ),
495            ],
496        )
497    }
498
499    fn get_name(&self) -> &str {
500        &self.name
501    }
502
503    /// Completes a model interaction by sending a request and processing the response.
504    ///
505    /// # Arguments
506    /// * `system` - System prompt or context
507    /// * `messages` - Array of previous messages in the conversation
508    /// * `tools` - Array of available tools for the model
509    #[tracing::instrument(
510        skip(self, model_config, system, messages, tools),
511        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
512    )]
513    async fn complete_with_model(
514        &self,
515        model_config: &ModelConfig,
516        system: &str,
517        messages: &[Message],
518        tools: &[Tool],
519    ) -> Result<(Message, ProviderUsage), ProviderError> {
520        // Create request and context
521        let (request, context) = create_request(model_config, system, messages, tools)?;
522
523        // Send request and process response
524        let response = self.post(&request, &context).await?;
525        let usage = get_usage(&response, &context)?;
526
527        let mut log = RequestLog::start(model_config, &request)?;
528        log.write(&response, Some(&usage))?;
529
530        // Convert response to message
531        let message = response_to_message(response, context)?;
532        let provider_usage = ProviderUsage::new(self.model.model_name.clone(), usage);
533
534        Ok((message, provider_usage))
535    }
536
537    /// Returns the current model configuration.
538    fn get_model_config(&self) -> ModelConfig {
539        self.model.clone()
540    }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::StatusCode;

    #[test]
    fn test_retry_config_delay_calculation() {
        let config = RetryConfig::new(5, 1000, 2.0, 32000);

        // Attempt 0 carries no delay at all.
        assert_eq!(config.delay_for_attempt(0).as_millis(), 0);

        // Attempt 1 lands near the initial interval, allowing for jitter.
        let first = config.delay_for_attempt(1).as_millis();
        assert!((800..=1200).contains(&first));

        // Attempt 2 roughly doubles (initial * multiplier), jitter included.
        let second = config.delay_for_attempt(2).as_millis();
        assert!((1600..=2400).contains(&second));

        // Far-out attempts stay within max_interval_ms plus the jitter ceiling.
        assert!(config.delay_for_attempt(10).as_millis() <= 38400); // max_interval_ms * 1.2
    }

    #[test]
    fn test_status_overloaded_code() {
        // The 529 "overloaded" code is built at runtime via from_u16 since
        // it is not a registered HTTP status.
        assert_eq!(STATUS_API_OVERLOADED.as_u16(), 529);

        // Being in the 5xx range, it classifies as a server error.
        assert!(STATUS_API_OVERLOADED.is_server_error());

        // It must stay distinct from the standard retryable statuses.
        assert_ne!(*STATUS_API_OVERLOADED, StatusCode::TOO_MANY_REQUESTS);
        assert_ne!(*STATUS_API_OVERLOADED, StatusCode::SERVICE_UNAVAILABLE);
    }

    #[test]
    fn test_model_provider_conversion() {
        assert_eq!(ModelProvider::Anthropic.as_str(), "anthropic");
        assert_eq!(ModelProvider::Google.as_str(), "google");
        let maas = ModelProvider::MaaS("qwen".to_string());
        assert_eq!(maas.as_str(), "qwen");
    }

    #[test]
    fn test_url_construction() {
        use url::Url;

        let model_config = ModelConfig::new_or_fail("claude-sonnet-4-20250514");
        let context = RequestContext::new(&model_config.model_name).unwrap();
        let api_model_id = context.model.to_string();

        let path = format!(
            "v1/projects/{}/locations/{}/publishers/{}/models/{}:{}",
            "test-project",
            "us-east5",
            ModelProvider::Anthropic.as_str(),
            api_model_id,
            "streamRawPredict"
        );

        let url = Url::parse("https://us-east5-aiplatform.googleapis.com")
            .unwrap()
            .join(&path)
            .unwrap();

        for fragment in [
            "publishers/anthropic",
            "projects/test-project",
            "locations/us-east5",
        ] {
            assert!(url.as_str().contains(fragment));
        }
    }

    #[test]
    fn test_provider_metadata() {
        let metadata = GcpVertexAIProvider::metadata();
        let model_names: Vec<String> = metadata
            .known_models
            .iter()
            .map(|m| m.name.clone())
            .collect();

        for expected in [
            "claude-3-7-sonnet@20250219",
            "claude-sonnet-4@20250514",
            "gemini-1.5-pro-002",
            "gemini-2.5-pro",
        ] {
            assert!(model_names.contains(&expected.to_string()));
        }

        // The original 2 config keys plus the 4 retry-tuning keys.
        assert_eq!(metadata.config_keys.len(), 6);
    }
}