1//! Google Vertex AI provider implementation.
2//!
3//! This module implements the Provider trait for Google Cloud Vertex AI,
4//! supporting both Google-native models (Gemini via Vertex) and Anthropic
5//! models hosted on Vertex AI.
6//!
7//! Vertex AI URL format (Google models):
8//! `https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/google/models/{model}:streamGenerateContent`
9//!
10//! Vertex AI URL format (Anthropic models):
11//! `https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/anthropic/models/{model}:streamRawPredict`
12
13use crate::error::{Error, Result};
14use crate::http::client::Client;
15use crate::model::{
16    AssistantMessage, ContentBlock, StopReason, StreamEvent, TextContent, ToolCall, Usage,
17};
18use crate::models::CompatConfig;
19use crate::provider::{Context, Provider, StreamOptions};
20use crate::providers::gemini::{
21    self, GeminiCandidate, GeminiContent, GeminiFunctionCall, GeminiFunctionCallingConfig,
22    GeminiGenerationConfig, GeminiPart, GeminiRequest, GeminiStreamResponse, GeminiTool,
23    GeminiToolConfig,
24};
25use crate::sse::SseStream;
26use async_trait::async_trait;
27use futures::StreamExt;
28use futures::stream::{self, Stream};
29use std::collections::VecDeque;
30use std::pin::Pin;
31
32// ============================================================================
33// Constants
34// ============================================================================
35
/// Default Vertex AI region used when no location is configured explicitly
/// or via environment variables.
const VERTEX_DEFAULT_REGION: &str = "us-central1";

/// Environment variable for the Google Cloud project ID.
const VERTEX_PROJECT_ENV: &str = "GOOGLE_CLOUD_PROJECT";
/// Fallback: `VERTEX_PROJECT` is a common alternative.
const VERTEX_PROJECT_ENV_ALT: &str = "VERTEX_PROJECT";

/// Environment variable for the Vertex AI region/location.
const VERTEX_LOCATION_ENV: &str = "GOOGLE_CLOUD_LOCATION";
/// Fallback: `VERTEX_LOCATION` is a common alternative.
const VERTEX_LOCATION_ENV_ALT: &str = "VERTEX_LOCATION";
47
48// ============================================================================
49// Vertex AI Provider
50// ============================================================================
51
/// Google Vertex AI provider supporting both Google-native (Gemini) and
/// Anthropic models via Vertex endpoints.
pub struct VertexProvider {
    /// HTTP client used for all requests (swappable for tests via `with_client`).
    client: Client,
    /// Model identifier appended to the endpoint path (e.g. `gemini-2.0-flash`).
    model: String,
    /// GCP project ID; when `None`, resolved from the environment at request time.
    project: Option<String>,
    /// GCP region / location (default: `us-central1`).
    location: String,
    /// Publisher: `"google"` for Gemini models, `"anthropic"` for Claude models.
    publisher: String,
    /// Optional override for the full endpoint URL (for tests).
    endpoint_url_override: Option<String>,
    /// Provider-specific compatibility overrides (e.g. custom headers).
    compat: Option<CompatConfig>,
}
67
68impl VertexProvider {
69    /// Create a new Vertex AI provider for Google-native (Gemini) models.
70    pub fn new(model: impl Into<String>) -> Self {
71        Self {
72            client: Client::new(),
73            model: model.into(),
74            project: None,
75            location: VERTEX_DEFAULT_REGION.to_string(),
76            publisher: "google".to_string(),
77            endpoint_url_override: None,
78            compat: None,
79        }
80    }
81
82    /// Set the GCP project ID.
83    #[must_use]
84    pub fn with_project(mut self, project: impl Into<String>) -> Self {
85        self.project = Some(project.into());
86        self
87    }
88
89    /// Set the GCP region/location.
90    #[must_use]
91    pub fn with_location(mut self, location: impl Into<String>) -> Self {
92        self.location = location.into();
93        self
94    }
95
96    /// Set the publisher (`"google"` or `"anthropic"`).
97    #[must_use]
98    pub fn with_publisher(mut self, publisher: impl Into<String>) -> Self {
99        self.publisher = publisher.into();
100        self
101    }
102
103    /// Override the full endpoint URL (for deterministic tests).
104    #[must_use]
105    pub fn with_endpoint_url(mut self, url: impl Into<String>) -> Self {
106        self.endpoint_url_override = Some(url.into());
107        self
108    }
109
110    /// Attach provider-specific compatibility overrides.
111    #[must_use]
112    pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
113        self.compat = compat;
114        self
115    }
116
117    /// Create with a custom HTTP client (VCR, test harness, etc.).
118    #[must_use]
119    pub fn with_client(mut self, client: Client) -> Self {
120        self.client = client;
121        self
122    }
123
124    /// Resolve the GCP project from explicit config or environment.
125    fn resolve_project(&self) -> Result<String> {
126        if let Some(project) = &self.project {
127            return Ok(project.clone());
128        }
129        std::env::var(VERTEX_PROJECT_ENV)
130            .or_else(|_| std::env::var(VERTEX_PROJECT_ENV_ALT))
131            .map_err(|_| {
132                Error::provider(
133                    "google-vertex",
134                    format!(
135                        "Missing GCP project. Set {VERTEX_PROJECT_ENV} or {VERTEX_PROJECT_ENV_ALT}, \
136                         or configure `project` in provider settings."
137                    ),
138                )
139            })
140    }
141
142    /// Resolve the GCP location from explicit config or environment.
143    fn resolve_location(&self) -> String {
144        if self.location != VERTEX_DEFAULT_REGION {
145            return self.location.clone();
146        }
147        std::env::var(VERTEX_LOCATION_ENV)
148            .or_else(|_| std::env::var(VERTEX_LOCATION_ENV_ALT))
149            .unwrap_or_else(|_| VERTEX_DEFAULT_REGION.to_string())
150    }
151
152    /// Build the streaming endpoint URL.
153    ///
154    /// Google models: `.../publishers/google/models/{model}:streamGenerateContent`
155    /// Anthropic models: `.../publishers/anthropic/models/{model}:streamRawPredict`
156    fn streaming_url(&self, project: &str, location: &str) -> String {
157        if let Some(url) = &self.endpoint_url_override {
158            return url.clone();
159        }
160
161        let method = if self.publisher == "anthropic" {
162            "streamRawPredict"
163        } else {
164            "streamGenerateContent"
165        };
166
167        format!(
168            "https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}:{method}",
169            location = location,
170            project = project,
171            publisher = self.publisher,
172            model = self.model,
173            method = method,
174        )
175    }
176
177    /// Build the Gemini-format request body (for Google-native models).
178    #[allow(clippy::unused_self)]
179    pub fn build_gemini_request(
180        &self,
181        context: &Context<'_>,
182        options: &StreamOptions,
183    ) -> GeminiRequest {
184        let contents = Self::build_contents(context);
185        let system_instruction = context.system_prompt.as_deref().map(|s| GeminiContent {
186            role: None,
187            parts: vec![GeminiPart::Text {
188                text: s.to_string(),
189            }],
190        });
191
192        let tools: Option<Vec<GeminiTool>> = if context.tools.is_empty() {
193            None
194        } else {
195            Some(vec![GeminiTool {
196                function_declarations: context
197                    .tools
198                    .iter()
199                    .map(gemini::convert_tool_to_gemini)
200                    .collect(),
201            }])
202        };
203
204        let tool_config = if tools.is_some() {
205            Some(GeminiToolConfig {
206                function_calling_config: GeminiFunctionCallingConfig { mode: "AUTO" },
207            })
208        } else {
209            None
210        };
211
212        GeminiRequest {
213            contents,
214            system_instruction,
215            tools,
216            tool_config,
217            generation_config: Some(GeminiGenerationConfig {
218                max_output_tokens: options.max_tokens.or(Some(gemini::DEFAULT_MAX_TOKENS)),
219                temperature: options.temperature,
220                candidate_count: Some(1),
221            }),
222        }
223    }
224
225    /// Build the contents array from context messages.
226    fn build_contents(context: &Context<'_>) -> Vec<GeminiContent> {
227        let mut contents = Vec::new();
228        for message in context.messages.iter() {
229            contents.extend(gemini::convert_message_to_gemini(message));
230        }
231        contents
232    }
233}
234
#[async_trait]
impl Provider for VertexProvider {
    /// Provider name recorded on messages and used in error reporting.
    fn name(&self) -> &'static str {
        "google-vertex"
    }

    /// API family identifier; matches the provider name for Vertex.
    fn api(&self) -> &'static str {
        "google-vertex"
    }

    /// The configured model identifier.
    fn model_id(&self) -> &str {
        &self.model
    }

    /// Open a streaming completion request against Vertex AI.
    ///
    /// Auth is a Bearer token taken from `options.api_key` or the
    /// `GOOGLE_CLOUD_API_KEY` / `VERTEX_API_KEY` environment variables. The
    /// response body is consumed as SSE and re-emitted as `StreamEvent`s.
    ///
    /// # Errors
    /// Fails when no credential or project can be resolved, on a non-2xx
    /// HTTP status, or when the SSE stream yields malformed payloads.
    #[allow(clippy::too_many_lines)]
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        // Resolve auth: Bearer token for Vertex AI.
        let auth_value = options
            .api_key
            .clone()
            .or_else(|| std::env::var("GOOGLE_CLOUD_API_KEY").ok())
            .or_else(|| std::env::var("VERTEX_API_KEY").ok())
            .ok_or_else(|| {
                Error::provider(
                    "google-vertex",
                    "Missing Vertex AI API key / access token. \
                     Set GOOGLE_CLOUD_API_KEY or VERTEX_API_KEY.",
                )
            })?;

        let project = self.resolve_project()?;
        let location = self.resolve_location();
        let url = self.streaming_url(&project, &location);

        // Build request body in Gemini format (Google-native models).
        // NOTE(review): for `publisher == "anthropic"` the URL targets
        // `streamRawPredict`, but the body built here is still Gemini-format —
        // confirm Anthropic-on-Vertex requests are constructed elsewhere.
        let request_body = self.build_gemini_request(context, options);

        // Build HTTP request with Bearer auth.
        let mut request = self
            .client
            .post(&url)
            .header("Accept", "text/event-stream")
            .header("Authorization", format!("Bearer {auth_value}"));

        // Apply provider-specific custom headers from compat config.
        if let Some(compat) = &self.compat {
            if let Some(custom_headers) = &compat.custom_headers {
                for (key, value) in custom_headers {
                    request = request.header(key, value);
                }
            }
        }

        // Per-request headers from `StreamOptions` (highest priority).
        for (key, value) in &options.headers {
            request = request.header(key, value);
        }

        let request = request.json(&request_body)?;

        let response = Box::pin(request.send()).await?;
        let status = response.status();
        if !(200..300).contains(&status) {
            // Include the response body in the error to aid debugging.
            let body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
            return Err(Error::provider(
                "google-vertex",
                format!("Vertex AI API error (HTTP {status}): {body}"),
            ));
        }

        // Create SSE stream for streaming responses.
        let event_source = SseStream::new(response.bytes_stream());

        // Create stream state — same response format as Gemini.
        let model = self.model.clone();
        let api = self.api().to_string();
        let provider = self.name().to_string();

        // unfold: each iteration yields at most one StreamEvent; StreamState
        // buffers events parsed from a single SSE message in pending_events.
        let stream = stream::unfold(
            StreamState::new(event_source, model, api, provider),
            |mut state| async move {
                if state.finished {
                    return None;
                }
                loop {
                    // Drain pending events before polling for more SSE data.
                    if let Some(event) = state.pending_events.pop_front() {
                        return Some((Ok(event), state));
                    }

                    match state.event_source.next().await {
                        Some(Ok(msg)) => {
                            // A successful read resets the transient-error streak.
                            state.transient_error_count = 0;
                            if msg.event == "ping" {
                                continue;
                            }

                            // A parse failure is fatal: surface it and stop.
                            if let Err(e) = state.process_event(&msg.data) {
                                state.finished = true;
                                return Some((Err(e), state));
                            }
                        }
                        Some(Err(e)) => {
                            // WriteZero, WouldBlock, and TimedOut errors are treated as transient.
                            // Skip them and keep reading the stream, but cap
                            // consecutive occurrences to avoid infinite loops.
                            const MAX_CONSECUTIVE_TRANSIENT_ERRORS: usize = 5;
                            if e.kind() == std::io::ErrorKind::WriteZero
                                || e.kind() == std::io::ErrorKind::WouldBlock
                                || e.kind() == std::io::ErrorKind::TimedOut
                            {
                                state.transient_error_count += 1;
                                if state.transient_error_count <= MAX_CONSECUTIVE_TRANSIENT_ERRORS {
                                    tracing::warn!(
                                        kind = ?e.kind(),
                                        count = state.transient_error_count,
                                        "Transient error in SSE stream, continuing"
                                    );
                                    continue;
                                }
                                tracing::warn!(
                                    kind = ?e.kind(),
                                    "Error persisted after {MAX_CONSECUTIVE_TRANSIENT_ERRORS} \
                                     consecutive attempts, treating as fatal"
                                );
                            }
                            state.finished = true;
                            let err = Error::api(format!("SSE error: {e}"));
                            return Some((Err(err), state));
                        }
                        None => {
                            // Stream ended naturally: emit the final Done event
                            // carrying the accumulated message.
                            state.finished = true;
                            let reason = state.partial.stop_reason;
                            let message = std::mem::take(&mut state.partial);
                            return Some((Ok(StreamEvent::Done { reason, message }), state));
                        }
                    }
                }
            },
        );

        Ok(Box::pin(stream))
    }
}
387
388// ============================================================================
389// Stream State (reuses Gemini response format)
390// ============================================================================
391
/// Incremental state for one streaming response.
///
/// The wire format is identical to the Gemini API, so parsing reuses the
/// `GeminiStreamResponse` types.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// SSE event source wrapping the HTTP byte stream.
    event_source: SseStream<S>,
    /// Accumulated assistant message, updated as deltas arrive.
    partial: AssistantMessage,
    /// Events parsed from SSE messages but not yet yielded to the consumer.
    pending_events: VecDeque<StreamEvent>,
    /// Whether `StreamEvent::Start` has been queued.
    started: bool,
    /// Whether the stream has terminated (done or errored).
    finished: bool,
    /// Consecutive transient I/O errors (WriteZero/WouldBlock/TimedOut) seen
    /// without a successful event in between.
    transient_error_count: usize,
}
404
impl<S> StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Create fresh stream state with an empty partial assistant message.
    fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
        Self {
            event_source,
            partial: AssistantMessage {
                content: Vec::new(),
                api,
                provider,
                model,
                usage: Usage::default(),
                // Default; overwritten once a finish reason arrives.
                stop_reason: StopReason::Stop,
                error_message: None,
                timestamp: chrono::Utc::now().timestamp_millis(),
            },
            pending_events: VecDeque::new(),
            started: false,
            finished: false,
            transient_error_count: 0,
        }
    }

    /// Parse one SSE data payload (Gemini-format JSON) and queue the
    /// resulting stream events.
    ///
    /// # Errors
    /// Returns an API error when the payload fails to deserialize as
    /// `GeminiStreamResponse`.
    fn process_event(&mut self, data: &str) -> Result<()> {
        let response: GeminiStreamResponse = serde_json::from_str(data)
            .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;

        // Handle usage metadata. Counts are overwritten on each chunk, so the
        // latest chunk's totals win.
        if let Some(metadata) = response.usage_metadata {
            self.partial.usage.input = metadata.prompt_token_count.unwrap_or(0);
            self.partial.usage.output = metadata.candidates_token_count.unwrap_or(0);
            self.partial.usage.total_tokens = metadata.total_token_count.unwrap_or(0);
        }

        // Process candidates. Only the first candidate is consumed
        // (requests are built with candidate_count = 1).
        if let Some(candidates) = response.candidates {
            if let Some(candidate) = candidates.into_iter().next() {
                self.process_candidate(candidate)?;
            }
        }

        Ok(())
    }

    /// Translate one Gemini candidate into queued stream events and fold its
    /// content into the accumulated assistant message.
    #[allow(clippy::unnecessary_wraps)]
    fn process_candidate(&mut self, candidate: GeminiCandidate) -> Result<()> {
        // Handle finish reason. A FunctionCall part processed below may still
        // override this to ToolUse.
        if let Some(ref reason) = candidate.finish_reason {
            self.partial.stop_reason = match reason.as_str() {
                "MAX_TOKENS" => StopReason::Length,
                "SAFETY" | "RECITATION" | "OTHER" => StopReason::Error,
                "FUNCTION_CALL" => StopReason::ToolUse,
                _ => StopReason::Stop,
            };
        }

        // Process content parts — queue all events into pending_events.
        if let Some(content) = candidate.content {
            for part in content.parts {
                match part {
                    GeminiPart::Text { text } => {
                        // Consecutive text parts are merged into one content
                        // block; TextStart is emitted only when a new block opens.
                        let last_is_text =
                            matches!(self.partial.content.last(), Some(ContentBlock::Text(_)));
                        if !last_is_text {
                            let content_index = self.partial.content.len();
                            self.partial
                                .content
                                .push(ContentBlock::Text(TextContent::new("")));

                            self.ensure_started();

                            self.pending_events
                                .push_back(StreamEvent::TextStart { content_index });
                        }
                        let content_index = self.partial.content.len() - 1;

                        if let Some(ContentBlock::Text(t)) =
                            self.partial.content.get_mut(content_index)
                        {
                            t.text.push_str(&text);
                        }

                        self.ensure_started();

                        self.pending_events.push_back(StreamEvent::TextDelta {
                            content_index,
                            delta: text,
                        });
                    }
                    GeminiPart::FunctionCall { function_call } => {
                        // Gemini does not supply tool-call IDs; synthesize one.
                        let id = format!("call_{}", uuid::Uuid::new_v4().simple());

                        // Serialize the args for the delta event before moving
                        // them into the ToolCall.
                        let args_str = serde_json::to_string(&function_call.args)
                            .unwrap_or_else(|_| "{}".to_string());
                        let GeminiFunctionCall { name, args } = function_call;

                        let tool_call = ToolCall {
                            id,
                            name,
                            arguments: args,
                            thought_signature: None,
                        };

                        self.partial
                            .content
                            .push(ContentBlock::ToolCall(tool_call.clone()));
                        let content_index = self.partial.content.len() - 1;

                        self.partial.stop_reason = StopReason::ToolUse;

                        self.ensure_started();

                        // Function calls arrive whole, so Start/Delta/End are
                        // queued back-to-back.
                        self.pending_events
                            .push_back(StreamEvent::ToolCallStart { content_index });
                        self.pending_events.push_back(StreamEvent::ToolCallDelta {
                            content_index,
                            delta: args_str,
                        });
                        self.pending_events.push_back(StreamEvent::ToolCallEnd {
                            content_index,
                            tool_call,
                        });
                    }
                    GeminiPart::InlineData { .. }
                    | GeminiPart::FunctionResponse { .. }
                    | GeminiPart::Unknown(_) => {
                        // Input-only parts are skipped.
                        // Unknown parts are also skipped so new Gemini API part
                        // variants don't break streaming.
                    }
                }
            }
        }

        // Emit TextEnd/ThinkingEnd for all open text/thinking blocks when a finish reason
        // is present.
        if candidate.finish_reason.is_some() {
            for (content_index, block) in self.partial.content.iter().enumerate() {
                if let ContentBlock::Text(t) = block {
                    self.pending_events.push_back(StreamEvent::TextEnd {
                        content_index,
                        content: t.text.clone(),
                    });
                } else if let ContentBlock::Thinking(t) = block {
                    self.pending_events.push_back(StreamEvent::ThinkingEnd {
                        content_index,
                        content: t.thinking.clone(),
                    });
                }
            }
        }

        Ok(())
    }

    /// Queue `StreamEvent::Start` exactly once, before the first content event.
    fn ensure_started(&mut self) {
        if !self.started {
            self.started = true;
            self.pending_events.push_back(StreamEvent::Start {
                partial: self.partial.clone(),
            });
        }
    }
}
570
571// ============================================================================
572// Vertex Runtime Resolution (similar to Azure runtime resolution)
573// ============================================================================
574
/// Resolved Vertex AI runtime configuration.
///
/// Produced by `resolve_vertex_provider_runtime` after combining the model
/// entry's base URL with environment variables and defaults.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct VertexProviderRuntime {
    /// GCP project ID.
    pub(crate) project: String,
    /// GCP region/location (e.g. `us-central1`).
    pub(crate) location: String,
    /// Model publisher: `"google"` or `"anthropic"`.
    pub(crate) publisher: String,
    /// Model identifier taken from the model entry.
    pub(crate) model: String,
}
583
584/// Resolve Vertex AI provider runtime from a `ModelEntry`.
585///
586/// Configuration sources (highest priority first):
587/// 1. Explicit fields parsed from `base_url`
588/// 2. Environment variables (`GOOGLE_CLOUD_PROJECT`, `GOOGLE_CLOUD_LOCATION`)
589/// 3. Defaults (location: `us-central1`, publisher: `google`)
590pub(crate) fn resolve_vertex_provider_runtime(
591    entry: &crate::models::ModelEntry,
592) -> Result<VertexProviderRuntime> {
593    // Try to parse project/location/publisher from base_url.
594    let (url_project, url_location, url_publisher) = parse_vertex_base_url(&entry.model.base_url);
595
596    let project = url_project
597        .or_else(|| std::env::var(VERTEX_PROJECT_ENV).ok())
598        .or_else(|| std::env::var(VERTEX_PROJECT_ENV_ALT).ok())
599        .ok_or_else(|| {
600            Error::provider(
601                "google-vertex",
602                format!(
603                    "Missing GCP project. Set {VERTEX_PROJECT_ENV} or provide a Vertex AI base URL \
604                     like https://REGION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/REGION/..."
605                ),
606            )
607        })?;
608
609    let location = url_location
610        .or_else(|| std::env::var(VERTEX_LOCATION_ENV).ok())
611        .or_else(|| std::env::var(VERTEX_LOCATION_ENV_ALT).ok())
612        .unwrap_or_else(|| VERTEX_DEFAULT_REGION.to_string());
613
614    let publisher = url_publisher.unwrap_or_else(|| "google".to_string());
615
616    Ok(VertexProviderRuntime {
617        project,
618        location,
619        publisher,
620        model: entry.model.id.clone(),
621    })
622}
623
/// Parse project, location, and publisher from a Vertex AI base URL.
///
/// Expected format:
/// `https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/{publisher}/...`
///
/// Returns `(project, location, publisher)`, each `None` when the
/// corresponding component is absent from the URL.
fn parse_vertex_base_url(base_url: &str) -> (Option<String>, Option<String>, Option<String>) {
    if base_url.is_empty() {
        return (None, None, None);
    }

    // Extract location from hostname: "{location}-aiplatform.googleapis.com".
    // The location itself contains hyphens and digits (e.g. "us-central1"),
    // so strip the known host suffix instead of splitting on '-' — splitting
    // would truncate "us-central1" to "us".
    let location_from_host = base_url
        .strip_prefix("https://")
        .or_else(|| base_url.strip_prefix("http://"))
        .and_then(|rest| rest.split('/').next())
        .and_then(|host| host.strip_suffix("-aiplatform.googleapis.com"))
        .and_then(|loc| {
            // Validate it looks like a region (e.g. "us-central1", "europe-west1").
            if !loc.is_empty()
                && loc
                    .chars()
                    .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-')
            {
                Some(loc.to_string())
            } else {
                None
            }
        });

    // Extract project, location, publisher from path segments: each value is
    // the segment immediately following its key segment.
    let path_segments: Vec<&str> = base_url.split('/').collect();
    let value_after = |key: &str| {
        path_segments
            .iter()
            .zip(path_segments.iter().skip(1))
            .find(|(k, _)| **k == key)
            .map(|(_, val)| (*val).to_string())
    };

    let project = value_after("projects");
    // Path segment wins; hostname is the fallback.
    let location = value_after("locations").or(location_from_host);
    let publisher = value_after("publishers");

    (project, location, publisher)
}
671
672// ============================================================================
673// Tests
674// ============================================================================
675
676#[cfg(test)]
677mod tests {
678    use super::*;
679    use crate::model::{Message, UserContent};
680    use crate::provider::ToolDef;
681    use asupersync::runtime::RuntimeBuilder;
682    use futures::{StreamExt, stream};
683    use serde_json::Value;
684
#[test]
fn test_provider_info() {
    // Name and API both identify the Vertex provider; the model id round-trips.
    let provider = VertexProvider::new("gemini-2.0-flash");
    assert_eq!(provider.model_id(), "gemini-2.0-flash");
    assert_eq!(provider.api(), "google-vertex");
    assert_eq!(provider.name(), "google-vertex");
}
692
#[test]
fn test_streaming_url_google_publisher() {
    // Default publisher ("google") maps to the streamGenerateContent method.
    let provider = VertexProvider::new("gemini-2.0-flash")
        .with_project("my-project")
        .with_location("us-central1");

    let expected = "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project\
                    /locations/us-central1/publishers/google/models/gemini-2.0-flash:streamGenerateContent";
    assert_eq!(provider.streaming_url("my-project", "us-central1"), expected);
}
705
#[test]
fn test_streaming_url_anthropic_publisher() {
    // The anthropic publisher switches the RPC method to streamRawPredict.
    let provider = VertexProvider::new("claude-sonnet-4-20250514")
        .with_project("my-project")
        .with_location("europe-west1")
        .with_publisher("anthropic");

    let expected = "https://europe-west1-aiplatform.googleapis.com/v1/projects/my-project\
                    /locations/europe-west1/publishers/anthropic/models/claude-sonnet-4-20250514:streamRawPredict";
    assert_eq!(provider.streaming_url("my-project", "europe-west1"), expected);
}
719
#[test]
fn test_streaming_url_override() {
    // An explicit endpoint override wins regardless of project/location.
    let provider = VertexProvider::new("gemini-2.0-flash")
        .with_endpoint_url("http://127.0.0.1:8080/mock");
    assert_eq!(
        provider.streaming_url("ignored", "ignored"),
        "http://127.0.0.1:8080/mock"
    );
}
728
#[test]
fn test_build_gemini_request_basic() {
    // A single user message with a system prompt and explicit generation options.
    let provider = VertexProvider::new("gemini-2.0-flash");
    let context = Context::owned(
        Some("You are helpful.".to_string()),
        vec![Message::User(crate::model::UserMessage {
            content: UserContent::Text("What is Vertex AI?".to_string()),
            timestamp: 0,
        })],
        vec![],
    );
    let options = StreamOptions {
        max_tokens: Some(1024),
        temperature: Some(0.7),
        ..Default::default()
    };

    let body = serde_json::to_value(provider.build_gemini_request(&context, &options))
        .expect("serialize");

    // Exactly one user turn carrying the prompt text.
    let contents = body["contents"].as_array().expect("contents");
    assert_eq!(contents.len(), 1);
    assert_eq!(contents[0]["role"], "user");
    assert_eq!(contents[0]["parts"][0]["text"], "What is Vertex AI?");

    // System prompt and generation config pass through.
    assert_eq!(
        body["systemInstruction"]["parts"][0]["text"],
        "You are helpful."
    );
    assert_eq!(body["generationConfig"]["maxOutputTokens"], 1024);
}
760
#[test]
fn test_build_gemini_request_with_tools() {
    // One tool definition should produce a single GeminiTool with one
    // function declaration and AUTO function-calling mode.
    let provider = VertexProvider::new("gemini-2.0-flash");
    let context = Context::owned(
        None,
        vec![Message::User(crate::model::UserMessage {
            content: UserContent::Text("Read a file".to_string()),
            timestamp: 0,
        })],
        vec![ToolDef {
            name: "read".to_string(),
            description: "Read a file".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": { "path": {"type": "string"} },
                "required": ["path"]
            }),
        }],
    );

    let body = serde_json::to_value(
        provider.build_gemini_request(&context, &StreamOptions::default()),
    )
    .expect("serialize");

    let tools = body["tools"].as_array().expect("tools");
    assert_eq!(tools.len(), 1);
    let declarations = tools[0]["functionDeclarations"]
        .as_array()
        .expect("declarations");
    assert_eq!(declarations[0]["name"], "read");
    assert_eq!(body["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");
}
793
#[test]
fn test_parse_vertex_base_url_full() {
    // A complete Vertex URL yields all three components.
    let (project, location, publisher) = parse_vertex_base_url(
        "https://us-central1-aiplatform.googleapis.com/v1/projects/my-proj/locations/us-central1/publishers/google/models/gemini-2.0-flash",
    );
    assert_eq!(
        (project.as_deref(), location.as_deref(), publisher.as_deref()),
        (Some("my-proj"), Some("us-central1"), Some("google")),
    );
}
802
803    #[test]
804    fn test_parse_vertex_base_url_anthropic() {
805        let url = "https://europe-west1-aiplatform.googleapis.com/v1/projects/corp-ai/locations/europe-west1/publishers/anthropic/models/claude-sonnet-4-20250514";
806        let (project, location, publisher) = parse_vertex_base_url(url);
807        assert_eq!(project.as_deref(), Some("corp-ai"));
808        assert_eq!(location.as_deref(), Some("europe-west1"));
809        assert_eq!(publisher.as_deref(), Some("anthropic"));
810    }
811
812    #[test]
813    fn test_parse_vertex_base_url_empty() {
814        let (project, location, publisher) = parse_vertex_base_url("");
815        assert!(project.is_none());
816        assert!(location.is_none());
817        assert!(publisher.is_none());
818    }
819
820    #[test]
821    fn test_parse_vertex_base_url_partial() {
822        let url = "https://us-central1-aiplatform.googleapis.com/v1/projects/my-proj/locations/us-central1";
823        let (project, location, publisher) = parse_vertex_base_url(url);
824        assert_eq!(project.as_deref(), Some("my-proj"));
825        assert_eq!(location.as_deref(), Some("us-central1"));
826        assert!(publisher.is_none());
827    }
828
829    #[test]
830    fn test_resolve_vertex_provider_runtime_from_url() {
831        let entry = crate::models::ModelEntry {
832            model: crate::provider::Model {
833                id: "gemini-2.0-flash".to_string(),
834                name: "Gemini 2.0 Flash".to_string(),
835                api: "google-vertex".to_string(),
836                provider: "google-vertex".to_string(),
837                base_url: "https://us-central1-aiplatform.googleapis.com/v1/projects/test-proj/locations/us-central1/publishers/google/models/gemini-2.0-flash".to_string(),
838                reasoning: false,
839                input: vec![],
840                cost: crate::provider::ModelCost {
841                    input: 0.0,
842                    output: 0.0,
843                    cache_read: 0.0,
844                    cache_write: 0.0,
845                },
846                context_window: 128_000,
847                max_tokens: 8192,
848                headers: std::collections::HashMap::new(),
849            },
850            api_key: None,
851            headers: std::collections::HashMap::new(),
852            auth_header: true,
853            compat: None,
854            oauth_config: None,
855        };
856
857        let runtime = resolve_vertex_provider_runtime(&entry).expect("resolve");
858        assert_eq!(runtime.project, "test-proj");
859        assert_eq!(runtime.location, "us-central1");
860        assert_eq!(runtime.publisher, "google");
861        assert_eq!(runtime.model, "gemini-2.0-flash");
862    }
863
864    // ─── Streaming response parsing ──────────────────────────────────────
865
866    #[test]
867    fn test_stream_text_response() {
868        let events = vec![
869            serde_json::json!({
870                "candidates": [{
871                    "content": {
872                        "role": "model",
873                        "parts": [{"text": "Hello from "}]
874                    }
875                }]
876            }),
877            serde_json::json!({
878                "candidates": [{
879                    "content": {
880                        "role": "model",
881                        "parts": [{"text": "Vertex AI!"}]
882                    },
883                    "finishReason": "STOP"
884                }],
885                "usageMetadata": {
886                    "promptTokenCount": 10,
887                    "candidatesTokenCount": 5,
888                    "totalTokenCount": 15
889                }
890            }),
891        ];
892
893        let stream_events = collect_events(&events);
894
895        // Should have: Start, TextDelta("Hello from "), TextDelta("Vertex AI!"), Done
896        assert!(
897            stream_events
898                .iter()
899                .any(|e| matches!(e, StreamEvent::Start { .. })),
900            "should emit Start"
901        );
902
903        let text_deltas: Vec<&str> = stream_events
904            .iter()
905            .filter_map(|e| match e {
906                StreamEvent::TextDelta { delta, .. } => Some(delta.as_str()),
907                _ => None,
908            })
909            .collect();
910        assert_eq!(text_deltas, vec!["Hello from ", "Vertex AI!"]);
911
912        let done = stream_events
913            .iter()
914            .find_map(|e| match e {
915                StreamEvent::Done { message, .. } => Some(message),
916                _ => None,
917            })
918            .expect("done event");
919        assert_eq!(done.usage.input, 10);
920        assert_eq!(done.usage.output, 5);
921    }
922
923    #[test]
924    fn test_stream_tool_call_response() {
925        let events = vec![serde_json::json!({
926            "candidates": [{
927                "content": {
928                    "role": "model",
929                    "parts": [{
930                        "functionCall": {
931                            "name": "read",
932                            "args": {"path": "/tmp/test.txt"}
933                        }
934                    }]
935                },
936                "finishReason": "STOP"
937            }]
938        })];
939
940        let stream_events = collect_events(&events);
941
942        assert!(
943            stream_events
944                .iter()
945                .any(|e| matches!(e, StreamEvent::ToolCallStart { .. })),
946            "should emit ToolCallStart"
947        );
948        assert!(
949            stream_events
950                .iter()
951                .any(|e| matches!(e, StreamEvent::ToolCallEnd { .. })),
952            "should emit ToolCallEnd"
953        );
954
955        let done = stream_events
956            .iter()
957            .find_map(|e| match e {
958                StreamEvent::Done { message, .. } => Some(message),
959                _ => None,
960            })
961            .expect("done event");
962        assert_eq!(done.stop_reason, StopReason::ToolUse);
963    }
964
965    #[test]
966    fn test_stream_ignores_unknown_parts() {
967        let events = vec![serde_json::json!({
968            "candidates": [{
969                "content": {
970                    "role": "model",
971                    "parts": [
972                        {
973                            "executableCode": {
974                                "language": "python",
975                                "code": "print('x')"
976                            }
977                        },
978                        {"text": "still works"}
979                    ]
980                },
981                "finishReason": "STOP"
982            }]
983        })];
984
985        let stream_events = collect_events(&events);
986
987        let text_deltas: Vec<&str> = stream_events
988            .iter()
989            .filter_map(|e| match e {
990                StreamEvent::TextDelta { delta, .. } => Some(delta.as_str()),
991                _ => None,
992            })
993            .collect();
994        assert_eq!(text_deltas, vec!["still works"]);
995        assert!(
996            stream_events
997                .iter()
998                .any(|e| matches!(e, StreamEvent::Done { .. })),
999            "should emit Done even when unknown parts are present"
1000        );
1001    }
1002
1003    // ─── Test helpers ────────────────────────────────────────────────────
1004
    /// Drive `StreamState` over a synthetic SSE stream built from `events`
    /// and collect every emitted `StreamEvent`.
    ///
    /// Each JSON value is serialized into its own `data: …\n\n` SSE frame,
    /// mimicking how streaming chunks arrive on the wire.
    fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
        // Tests are synchronous, so spin up a minimal single-threaded runtime.
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        runtime.block_on(async move {
            // Encode every JSON event as one SSE data frame.
            let byte_stream = stream::iter(
                events
                    .iter()
                    .map(|event| {
                        let data = serde_json::to_string(event).expect("serialize event");
                        format!("data: {data}\n\n").into_bytes()
                    })
                    .map(Ok),
            );
            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
            let mut state = StreamState::new(
                event_source,
                "gemini-test".to_string(),
                "google-vertex".to_string(),
                "google-vertex".to_string(),
            );
            let mut out = Vec::new();

            loop {
                let Some(item) = state.event_source.next().await else {
                    // Stream exhausted: synthesize a final Done if the state
                    // machine never marked itself finished on its own.
                    if !state.finished {
                        state.finished = true;
                        out.push(StreamEvent::Done {
                            reason: state.partial.stop_reason,
                            message: std::mem::take(&mut state.partial),
                        });
                    }
                    break;
                };

                let msg = item.expect("SSE event");
                // Keep-alive pings carry no payload worth processing.
                if msg.event == "ping" {
                    continue;
                }
                state.process_event(&msg.data).expect("process_event");
                // Drain whatever events this chunk produced, preserving order.
                out.extend(state.pending_events.drain(..));
            }

            out
        })
    }
1051}
1052
1053// ============================================================================
1054// Fuzzing support
1055// ============================================================================
1056
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// Byte stream that never yields; the fuzz harness feeds SSE payloads
    /// through [`Processor::process_event`] directly instead.
    type EmptyByteStream =
        Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;

    /// Opaque wrapper around the Vertex AI stream processor state.
    pub struct Processor(StreamState<EmptyByteStream>);

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Processor {
        /// Create a fresh processor with default state.
        pub fn new() -> Self {
            let source = crate::sse::SseStream::new(Box::pin(
                stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>(),
            ));
            Self(StreamState::new(
                source,
                "vertex-fuzz".into(),
                "vertex-ai".into(),
                "vertex".into(),
            ))
        }

        /// Feed one SSE data payload and return any emitted `StreamEvent`s.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            self.0.process_event(data)?;
            let emitted: Vec<StreamEvent> = self.0.pending_events.drain(..).collect();
            Ok(emitted)
        }
    }
}