Skip to main content

pi/providers/
vertex.rs

1//! Google Vertex AI provider implementation.
2//!
3//! This module implements the Provider trait for Google Cloud Vertex AI,
4//! supporting both Google-native models (Gemini via Vertex) and Anthropic
5//! models hosted on Vertex AI.
6//!
7//! Vertex AI URL format (Google models):
8//! `https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/google/models/{model}:streamGenerateContent`
9//!
10//! Vertex AI URL format (Anthropic models):
11//! `https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/anthropic/models/{model}:streamRawPredict`
12
13use crate::error::{Error, Result};
14use crate::http::client::Client;
15use crate::model::{
16    AssistantMessage, ContentBlock, StopReason, StreamEvent, TextContent, ToolCall, Usage,
17};
18use crate::models::CompatConfig;
19use crate::provider::{Context, Provider, StreamOptions};
20use crate::providers::gemini::{
21    self, GeminiCandidate, GeminiContent, GeminiFunctionCall, GeminiFunctionCallingConfig,
22    GeminiGenerationConfig, GeminiPart, GeminiRequest, GeminiStreamResponse, GeminiTool,
23    GeminiToolConfig,
24};
25use crate::sse::SseStream;
26use async_trait::async_trait;
27use futures::StreamExt;
28use futures::stream::{self, Stream};
29use std::collections::VecDeque;
30use std::pin::Pin;
31
32// ============================================================================
33// Constants
34// ============================================================================
35
/// Region used when neither explicit configuration nor the environment names one.
const VERTEX_DEFAULT_REGION: &str = "us-central1";

// Project ID lookup: the primary variable is read first; the alternate is only
// consulted when the primary one cannot be read.

/// Primary environment variable holding the Google Cloud project ID.
const VERTEX_PROJECT_ENV: &str = "GOOGLE_CLOUD_PROJECT";
/// Alternate project-ID variable (`VERTEX_PROJECT`).
const VERTEX_PROJECT_ENV_ALT: &str = "VERTEX_PROJECT";

// Region / location lookup follows the same primary-then-alternate order.

/// Primary environment variable holding the Vertex AI region/location.
const VERTEX_LOCATION_ENV: &str = "GOOGLE_CLOUD_LOCATION";
/// Alternate region/location variable (`VERTEX_LOCATION`).
const VERTEX_LOCATION_ENV_ALT: &str = "VERTEX_LOCATION";
47
48// ============================================================================
49// Vertex AI Provider
50// ============================================================================
51
52/// Google Vertex AI provider supporting both Google-native (Gemini) and
53/// Anthropic models via Vertex endpoints.
54pub struct VertexProvider {
55    client: Client,
56    model: String,
57    /// GCP project ID (required).
58    project: Option<String>,
59    /// GCP region / location (default: `us-central1`).
60    location: String,
61    /// Publisher: `"google"` for Gemini models, `"anthropic"` for Claude models.
62    publisher: String,
63    /// Optional override for the full endpoint URL (for tests).
64    endpoint_url_override: Option<String>,
65    compat: Option<CompatConfig>,
66}
67
68impl VertexProvider {
69    /// Create a new Vertex AI provider for Google-native (Gemini) models.
70    pub fn new(model: impl Into<String>) -> Self {
71        Self {
72            client: Client::new(),
73            model: model.into(),
74            project: None,
75            location: VERTEX_DEFAULT_REGION.to_string(),
76            publisher: "google".to_string(),
77            endpoint_url_override: None,
78            compat: None,
79        }
80    }
81
82    /// Set the GCP project ID.
83    #[must_use]
84    pub fn with_project(mut self, project: impl Into<String>) -> Self {
85        self.project = Some(project.into());
86        self
87    }
88
89    /// Set the GCP region/location.
90    #[must_use]
91    pub fn with_location(mut self, location: impl Into<String>) -> Self {
92        self.location = location.into();
93        self
94    }
95
96    /// Set the publisher (`"google"` or `"anthropic"`).
97    #[must_use]
98    pub fn with_publisher(mut self, publisher: impl Into<String>) -> Self {
99        self.publisher = publisher.into();
100        self
101    }
102
103    /// Override the full endpoint URL (for deterministic tests).
104    #[must_use]
105    pub fn with_endpoint_url(mut self, url: impl Into<String>) -> Self {
106        self.endpoint_url_override = Some(url.into());
107        self
108    }
109
110    /// Attach provider-specific compatibility overrides.
111    #[must_use]
112    pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
113        self.compat = compat;
114        self
115    }
116
117    /// Create with a custom HTTP client (VCR, test harness, etc.).
118    #[must_use]
119    pub fn with_client(mut self, client: Client) -> Self {
120        self.client = client;
121        self
122    }
123
124    /// Resolve the GCP project from explicit config or environment.
125    fn resolve_project(&self) -> Result<String> {
126        if let Some(project) = &self.project {
127            return Ok(project.clone());
128        }
129        std::env::var(VERTEX_PROJECT_ENV)
130            .or_else(|_| std::env::var(VERTEX_PROJECT_ENV_ALT))
131            .map_err(|_| {
132                Error::provider(
133                    "google-vertex",
134                    format!(
135                        "Missing GCP project. Set {VERTEX_PROJECT_ENV} or {VERTEX_PROJECT_ENV_ALT}, \
136                         or configure `project` in provider settings."
137                    ),
138                )
139            })
140    }
141
142    /// Resolve the GCP location from explicit config or environment.
143    fn resolve_location(&self) -> String {
144        if self.location != VERTEX_DEFAULT_REGION {
145            return self.location.clone();
146        }
147        std::env::var(VERTEX_LOCATION_ENV)
148            .or_else(|_| std::env::var(VERTEX_LOCATION_ENV_ALT))
149            .unwrap_or_else(|_| VERTEX_DEFAULT_REGION.to_string())
150    }
151
152    /// Build the streaming endpoint URL.
153    ///
154    /// Google models: `.../publishers/google/models/{model}:streamGenerateContent`
155    /// Anthropic models: `.../publishers/anthropic/models/{model}:streamRawPredict`
156    fn streaming_url(&self, project: &str, location: &str) -> String {
157        if let Some(url) = &self.endpoint_url_override {
158            return url.clone();
159        }
160
161        let method = if self.publisher == "anthropic" {
162            "streamRawPredict"
163        } else {
164            "streamGenerateContent"
165        };
166
167        format!(
168            "https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}:{method}",
169            location = location,
170            project = project,
171            publisher = self.publisher,
172            model = self.model,
173            method = method,
174        )
175    }
176
177    /// Build the Gemini-format request body (for Google-native models).
178    #[allow(clippy::unused_self)]
179    pub fn build_gemini_request(
180        &self,
181        context: &Context<'_>,
182        options: &StreamOptions,
183    ) -> GeminiRequest {
184        let contents = Self::build_contents(context);
185        let system_instruction = context.system_prompt.as_deref().map(|s| GeminiContent {
186            role: None,
187            parts: vec![GeminiPart::Text {
188                text: s.to_string(),
189            }],
190        });
191
192        let tools: Option<Vec<GeminiTool>> = if context.tools.is_empty() {
193            None
194        } else {
195            Some(vec![GeminiTool {
196                function_declarations: context
197                    .tools
198                    .iter()
199                    .map(gemini::convert_tool_to_gemini)
200                    .collect(),
201            }])
202        };
203
204        let tool_config = if tools.is_some() {
205            Some(GeminiToolConfig {
206                function_calling_config: GeminiFunctionCallingConfig { mode: "AUTO" },
207            })
208        } else {
209            None
210        };
211
212        GeminiRequest {
213            contents,
214            system_instruction,
215            tools,
216            tool_config,
217            generation_config: Some(GeminiGenerationConfig {
218                max_output_tokens: options.max_tokens.or(Some(gemini::DEFAULT_MAX_TOKENS)),
219                temperature: options.temperature,
220                candidate_count: Some(1),
221            }),
222        }
223    }
224
225    /// Build the contents array from context messages.
226    fn build_contents(context: &Context<'_>) -> Vec<GeminiContent> {
227        let mut contents = Vec::new();
228        for message in context.messages.iter() {
229            contents.extend(gemini::convert_message_to_gemini(message));
230        }
231        contents
232    }
233}
234
235#[async_trait]
236impl Provider for VertexProvider {
237    fn name(&self) -> &'static str {
238        "google-vertex"
239    }
240
241    fn api(&self) -> &'static str {
242        "google-vertex"
243    }
244
245    fn model_id(&self) -> &str {
246        &self.model
247    }
248
249    #[allow(clippy::too_many_lines)]
250    async fn stream(
251        &self,
252        context: &Context<'_>,
253        options: &StreamOptions,
254    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
255        // Resolve auth: Bearer token for Vertex AI.
256        let auth_value = options
257            .api_key
258            .clone()
259            .or_else(|| std::env::var("GOOGLE_CLOUD_API_KEY").ok())
260            .or_else(|| std::env::var("VERTEX_API_KEY").ok())
261            .ok_or_else(|| {
262                Error::provider(
263                    "google-vertex",
264                    "Missing Vertex AI API key / access token. \
265                     Set GOOGLE_CLOUD_API_KEY or VERTEX_API_KEY.",
266                )
267            })?;
268
269        let project = self.resolve_project()?;
270        let location = self.resolve_location();
271        let url = self.streaming_url(&project, &location);
272
273        // Build request body in Gemini format (Google-native models).
274        let request_body = self.build_gemini_request(context, options);
275
276        // Build HTTP request with Bearer auth.
277        let mut request = self
278            .client
279            .post(&url)
280            .header("Accept", "text/event-stream")
281            .header("Authorization", format!("Bearer {auth_value}"));
282
283        // Apply provider-specific custom headers from compat config.
284        if let Some(compat) = &self.compat {
285            if let Some(custom_headers) = &compat.custom_headers {
286                for (key, value) in custom_headers {
287                    request = request.header(key, value);
288                }
289            }
290        }
291
292        // Per-request headers from `StreamOptions` (highest priority).
293        for (key, value) in &options.headers {
294            request = request.header(key, value);
295        }
296
297        let request = request.json(&request_body)?;
298
299        let response = Box::pin(request.send()).await?;
300        let status = response.status();
301        if !(200..300).contains(&status) {
302            let body = response
303                .text()
304                .await
305                .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
306            return Err(Error::provider(
307                "google-vertex",
308                format!("Vertex AI API error (HTTP {status}): {body}"),
309            ));
310        }
311
312        // Create SSE stream for streaming responses.
313        let event_source = SseStream::new(response.bytes_stream());
314
315        // Create stream state — same response format as Gemini.
316        let model = self.model.clone();
317        let api = self.api().to_string();
318        let provider = self.name().to_string();
319
320        let stream = stream::unfold(
321            StreamState::new(event_source, model, api, provider),
322            |mut state| async move {
323                if state.finished {
324                    return None;
325                }
326                loop {
327                    // Drain pending events before polling for more SSE data.
328                    if let Some(event) = state.pending_events.pop_front() {
329                        return Some((Ok(event), state));
330                    }
331
332                    match state.event_source.next().await {
333                        Some(Ok(msg)) => {
334                            if msg.event == "ping" {
335                                continue;
336                            }
337
338                            if let Err(e) = state.process_event(&msg.data) {
339                                state.finished = true;
340                                return Some((Err(e), state));
341                            }
342                        }
343                        Some(Err(e)) => {
344                            state.finished = true;
345                            let err = Error::api(format!("SSE error: {e}"));
346                            return Some((Err(err), state));
347                        }
348                        None => {
349                            // Stream ended naturally.
350                            state.finished = true;
351                            let reason = state.partial.stop_reason;
352                            let message = std::mem::take(&mut state.partial);
353                            return Some((Ok(StreamEvent::Done { reason, message }), state));
354                        }
355                    }
356                }
357            },
358        );
359
360        Ok(Box::pin(stream))
361    }
362}
363
364// ============================================================================
365// Stream State (reuses Gemini response format)
366// ============================================================================
367
368struct StreamState<S>
369where
370    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
371{
372    event_source: SseStream<S>,
373    partial: AssistantMessage,
374    pending_events: VecDeque<StreamEvent>,
375    started: bool,
376    finished: bool,
377}
378
379impl<S> StreamState<S>
380where
381    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
382{
383    fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
384        Self {
385            event_source,
386            partial: AssistantMessage {
387                content: Vec::new(),
388                api,
389                provider,
390                model,
391                usage: Usage::default(),
392                stop_reason: StopReason::Stop,
393                error_message: None,
394                timestamp: chrono::Utc::now().timestamp_millis(),
395            },
396            pending_events: VecDeque::new(),
397            started: false,
398            finished: false,
399        }
400    }
401
402    fn process_event(&mut self, data: &str) -> Result<()> {
403        let response: GeminiStreamResponse = serde_json::from_str(data)
404            .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;
405
406        // Handle usage metadata.
407        if let Some(metadata) = response.usage_metadata {
408            self.partial.usage.input = metadata.prompt_token_count.unwrap_or(0);
409            self.partial.usage.output = metadata.candidates_token_count.unwrap_or(0);
410            self.partial.usage.total_tokens = metadata.total_token_count.unwrap_or(0);
411        }
412
413        // Process candidates.
414        if let Some(candidates) = response.candidates {
415            if let Some(candidate) = candidates.into_iter().next() {
416                self.process_candidate(candidate)?;
417            }
418        }
419
420        Ok(())
421    }
422
423    #[allow(clippy::unnecessary_wraps)]
424    fn process_candidate(&mut self, candidate: GeminiCandidate) -> Result<()> {
425        // Handle finish reason.
426        if let Some(ref reason) = candidate.finish_reason {
427            self.partial.stop_reason = match reason.as_str() {
428                "MAX_TOKENS" => StopReason::Length,
429                "SAFETY" | "RECITATION" | "OTHER" => StopReason::Error,
430                _ => StopReason::Stop,
431            };
432        }
433
434        // Process content parts — queue all events into pending_events.
435        if let Some(content) = candidate.content {
436            for part in content.parts {
437                match part {
438                    GeminiPart::Text { text } => {
439                        let last_is_text =
440                            matches!(self.partial.content.last(), Some(ContentBlock::Text(_)));
441                        if !last_is_text {
442                            let content_index = self.partial.content.len();
443                            self.partial
444                                .content
445                                .push(ContentBlock::Text(TextContent::new("")));
446
447                            self.ensure_started();
448
449                            self.pending_events
450                                .push_back(StreamEvent::TextStart { content_index });
451                        }
452                        let content_index = self.partial.content.len() - 1;
453
454                        if let Some(ContentBlock::Text(t)) =
455                            self.partial.content.get_mut(content_index)
456                        {
457                            t.text.push_str(&text);
458                        }
459
460                        self.ensure_started();
461
462                        self.pending_events.push_back(StreamEvent::TextDelta {
463                            content_index,
464                            delta: text,
465                        });
466                    }
467                    GeminiPart::FunctionCall { function_call } => {
468                        let id = format!("call_{}", uuid::Uuid::new_v4().simple());
469
470                        let args_str = serde_json::to_string(&function_call.args)
471                            .unwrap_or_else(|_| "{}".to_string());
472                        let GeminiFunctionCall { name, args } = function_call;
473
474                        let tool_call = ToolCall {
475                            id,
476                            name,
477                            arguments: args,
478                            thought_signature: None,
479                        };
480
481                        self.partial
482                            .content
483                            .push(ContentBlock::ToolCall(tool_call.clone()));
484                        let content_index = self.partial.content.len() - 1;
485
486                        self.partial.stop_reason = StopReason::ToolUse;
487
488                        self.ensure_started();
489
490                        self.pending_events
491                            .push_back(StreamEvent::ToolCallStart { content_index });
492                        self.pending_events.push_back(StreamEvent::ToolCallDelta {
493                            content_index,
494                            delta: args_str,
495                        });
496                        self.pending_events.push_back(StreamEvent::ToolCallEnd {
497                            content_index,
498                            tool_call,
499                        });
500                    }
501                    GeminiPart::InlineData { .. }
502                    | GeminiPart::FunctionResponse { .. }
503                    | GeminiPart::Unknown(_) => {
504                        // Input-only parts are skipped.
505                        // Unknown parts are also skipped so new Gemini API part
506                        // variants don't break streaming.
507                    }
508                }
509            }
510        }
511
512        // Emit TextEnd/ThinkingEnd for all open text/thinking blocks when a finish reason
513        // is present.
514        if candidate.finish_reason.is_some() {
515            for (content_index, block) in self.partial.content.iter().enumerate() {
516                if let ContentBlock::Text(t) = block {
517                    self.pending_events.push_back(StreamEvent::TextEnd {
518                        content_index,
519                        content: t.text.clone(),
520                    });
521                } else if let ContentBlock::Thinking(t) = block {
522                    self.pending_events.push_back(StreamEvent::ThinkingEnd {
523                        content_index,
524                        content: t.thinking.clone(),
525                    });
526                }
527            }
528        }
529
530        Ok(())
531    }
532
533    fn ensure_started(&mut self) {
534        if !self.started {
535            self.started = true;
536            self.pending_events.push_back(StreamEvent::Start {
537                partial: self.partial.clone(),
538            });
539        }
540    }
541}
542
543// ============================================================================
544// Vertex Runtime Resolution (similar to Azure runtime resolution)
545// ============================================================================
546
/// Fully resolved Vertex AI endpoint coordinates for one model entry.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct VertexProviderRuntime {
    /// GCP project ID.
    pub(crate) project: String,
    /// GCP region / location, e.g. `us-central1`.
    pub(crate) location: String,
    /// Publisher path segment, e.g. `google`.
    pub(crate) publisher: String,
    /// Model identifier taken from the model entry.
    pub(crate) model: String,
}
555
556/// Resolve Vertex AI provider runtime from a `ModelEntry`.
557///
558/// Configuration sources (highest priority first):
559/// 1. Explicit fields parsed from `base_url`
560/// 2. Environment variables (`GOOGLE_CLOUD_PROJECT`, `GOOGLE_CLOUD_LOCATION`)
561/// 3. Defaults (location: `us-central1`, publisher: `google`)
562pub(crate) fn resolve_vertex_provider_runtime(
563    entry: &crate::models::ModelEntry,
564) -> Result<VertexProviderRuntime> {
565    // Try to parse project/location/publisher from base_url.
566    let (url_project, url_location, url_publisher) = parse_vertex_base_url(&entry.model.base_url);
567
568    let project = url_project
569        .or_else(|| std::env::var(VERTEX_PROJECT_ENV).ok())
570        .or_else(|| std::env::var(VERTEX_PROJECT_ENV_ALT).ok())
571        .ok_or_else(|| {
572            Error::provider(
573                "google-vertex",
574                format!(
575                    "Missing GCP project. Set {VERTEX_PROJECT_ENV} or provide a Vertex AI base URL \
576                     like https://REGION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/REGION/..."
577                ),
578            )
579        })?;
580
581    let location = url_location
582        .or_else(|| std::env::var(VERTEX_LOCATION_ENV).ok())
583        .or_else(|| std::env::var(VERTEX_LOCATION_ENV_ALT).ok())
584        .unwrap_or_else(|| VERTEX_DEFAULT_REGION.to_string());
585
586    let publisher = url_publisher.unwrap_or_else(|| "google".to_string());
587
588    Ok(VertexProviderRuntime {
589        project,
590        location,
591        publisher,
592        model: entry.model.id.clone(),
593    })
594}
595
/// Parse project, location, and publisher from a Vertex AI base URL.
///
/// Expected format:
/// `https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/{publisher}/...`
///
/// Every component is optional. The location is taken from the
/// `locations/{location}` path segment when present, otherwise from a regional
/// `{location}-aiplatform.googleapis.com` host (e.g. `us-central1`). Empty
/// segments are treated as missing, and any query string or fragment is ignored.
fn parse_vertex_base_url(base_url: &str) -> (Option<String>, Option<String>, Option<String>) {
    const HOST_SUFFIX: &str = "-aiplatform.googleapis.com";

    let base_url = base_url.trim();
    if base_url.is_empty() {
        return (None, None, None);
    }

    // Drop any query string / fragment so it cannot leak into extracted values.
    let url = base_url
        .find(|c: char| c == '?' || c == '#')
        .map_or(base_url, |end| &base_url[..end]);

    // Regional host "{location}-aiplatform.googleapis.com" (port stripped).
    // The whole prefix before the suffix is the region, e.g. "europe-west4".
    let location_from_host = url
        .strip_prefix("https://")
        .or_else(|| url.strip_prefix("http://"))
        .and_then(|rest| rest.split('/').next())
        .map(|host| host.split(':').next().unwrap_or(host))
        .and_then(|host| host.strip_suffix(HOST_SUFFIX))
        .filter(|loc| {
            !loc.is_empty()
                && loc
                    .chars()
                    .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-')
        })
        .map(str::to_string);

    // Value of the segment right after the first `key` segment, e.g. `projects/{value}`.
    let segments: Vec<&str> = url.split('/').collect();
    let value_after = |key: &str| {
        segments
            .windows(2)
            .find(|pair| pair[0] == key)
            .map(|pair| pair[1])
            .filter(|value| !value.is_empty())
            .map(str::to_string)
    };

    let project = value_after("projects");
    let location = value_after("locations").or(location_from_host);
    let publisher = value_after("publishers");

    (project, location, publisher)
}
643
644// ============================================================================
645// Tests
646// ============================================================================
647
648#[cfg(test)]
649mod tests {
650    use super::*;
651    use crate::model::{Message, UserContent};
652    use crate::provider::ToolDef;
653    use asupersync::runtime::RuntimeBuilder;
654    use futures::{StreamExt, stream};
655    use serde_json::Value;
656
657    #[test]
658    fn test_provider_info() {
659        let provider = VertexProvider::new("gemini-2.0-flash");
660        assert_eq!(provider.name(), "google-vertex");
661        assert_eq!(provider.api(), "google-vertex");
662        assert_eq!(provider.model_id(), "gemini-2.0-flash");
663    }
664
665    #[test]
666    fn test_streaming_url_google_publisher() {
667        let provider = VertexProvider::new("gemini-2.0-flash")
668            .with_project("my-project")
669            .with_location("us-central1");
670
671        let url = provider.streaming_url("my-project", "us-central1");
672        assert_eq!(
673            url,
674            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:streamGenerateContent"
675        );
676    }
677
678    #[test]
679    fn test_streaming_url_anthropic_publisher() {
680        let provider = VertexProvider::new("claude-sonnet-4-20250514")
681            .with_project("my-project")
682            .with_location("europe-west1")
683            .with_publisher("anthropic");
684
685        let url = provider.streaming_url("my-project", "europe-west1");
686        assert_eq!(
687            url,
688            "https://europe-west1-aiplatform.googleapis.com/v1/projects/my-project/locations/europe-west1/publishers/anthropic/models/claude-sonnet-4-20250514:streamRawPredict"
689        );
690    }
691
692    #[test]
693    fn test_streaming_url_override() {
694        let provider =
695            VertexProvider::new("gemini-2.0-flash").with_endpoint_url("http://127.0.0.1:8080/mock");
696
697        let url = provider.streaming_url("ignored", "ignored");
698        assert_eq!(url, "http://127.0.0.1:8080/mock");
699    }
700
701    #[test]
702    fn test_build_gemini_request_basic() {
703        let provider = VertexProvider::new("gemini-2.0-flash");
704        let context = Context::owned(
705            Some("You are helpful.".to_string()),
706            vec![Message::User(crate::model::UserMessage {
707                content: UserContent::Text("What is Vertex AI?".to_string()),
708                timestamp: 0,
709            })],
710            vec![],
711        );
712        let options = StreamOptions {
713            max_tokens: Some(1024),
714            temperature: Some(0.7),
715            ..Default::default()
716        };
717
718        let req = provider.build_gemini_request(&context, &options);
719        let json = serde_json::to_value(&req).expect("serialize");
720
721        let contents = json["contents"].as_array().expect("contents");
722        assert_eq!(contents.len(), 1);
723        assert_eq!(contents[0]["role"], "user");
724        assert_eq!(contents[0]["parts"][0]["text"], "What is Vertex AI?");
725
726        assert_eq!(
727            json["systemInstruction"]["parts"][0]["text"],
728            "You are helpful."
729        );
730        assert_eq!(json["generationConfig"]["maxOutputTokens"], 1024);
731    }
732
733    #[test]
734    fn test_build_gemini_request_with_tools() {
735        let provider = VertexProvider::new("gemini-2.0-flash");
736        let context = Context::owned(
737            None,
738            vec![Message::User(crate::model::UserMessage {
739                content: UserContent::Text("Read a file".to_string()),
740                timestamp: 0,
741            })],
742            vec![ToolDef {
743                name: "read".to_string(),
744                description: "Read a file".to_string(),
745                parameters: serde_json::json!({
746                    "type": "object",
747                    "properties": { "path": {"type": "string"} },
748                    "required": ["path"]
749                }),
750            }],
751        );
752        let options = StreamOptions::default();
753
754        let req = provider.build_gemini_request(&context, &options);
755        let json = serde_json::to_value(&req).expect("serialize");
756
757        let tools = json["tools"].as_array().expect("tools");
758        assert_eq!(tools.len(), 1);
759        let decls = tools[0]["functionDeclarations"]
760            .as_array()
761            .expect("declarations");
762        assert_eq!(decls[0]["name"], "read");
763        assert_eq!(json["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");
764    }
765
766    #[test]
767    fn test_parse_vertex_base_url_full() {
768        let url = "https://us-central1-aiplatform.googleapis.com/v1/projects/my-proj/locations/us-central1/publishers/google/models/gemini-2.0-flash";
769        let (project, location, publisher) = parse_vertex_base_url(url);
770        assert_eq!(project.as_deref(), Some("my-proj"));
771        assert_eq!(location.as_deref(), Some("us-central1"));
772        assert_eq!(publisher.as_deref(), Some("google"));
773    }
774
775    #[test]
776    fn test_parse_vertex_base_url_anthropic() {
777        let url = "https://europe-west1-aiplatform.googleapis.com/v1/projects/corp-ai/locations/europe-west1/publishers/anthropic/models/claude-sonnet-4-20250514";
778        let (project, location, publisher) = parse_vertex_base_url(url);
779        assert_eq!(project.as_deref(), Some("corp-ai"));
780        assert_eq!(location.as_deref(), Some("europe-west1"));
781        assert_eq!(publisher.as_deref(), Some("anthropic"));
782    }
783
784    #[test]
785    fn test_parse_vertex_base_url_empty() {
786        let (project, location, publisher) = parse_vertex_base_url("");
787        assert!(project.is_none());
788        assert!(location.is_none());
789        assert!(publisher.is_none());
790    }
791
792    #[test]
793    fn test_parse_vertex_base_url_partial() {
794        let url = "https://us-central1-aiplatform.googleapis.com/v1/projects/my-proj/locations/us-central1";
795        let (project, location, publisher) = parse_vertex_base_url(url);
796        assert_eq!(project.as_deref(), Some("my-proj"));
797        assert_eq!(location.as_deref(), Some("us-central1"));
798        assert!(publisher.is_none());
799    }
800
801    #[test]
802    fn test_resolve_vertex_provider_runtime_from_url() {
803        let entry = crate::models::ModelEntry {
804            model: crate::provider::Model {
805                id: "gemini-2.0-flash".to_string(),
806                name: "Gemini 2.0 Flash".to_string(),
807                api: "google-vertex".to_string(),
808                provider: "google-vertex".to_string(),
809                base_url: "https://us-central1-aiplatform.googleapis.com/v1/projects/test-proj/locations/us-central1/publishers/google/models/gemini-2.0-flash".to_string(),
810                reasoning: false,
811                input: vec![],
812                cost: crate::provider::ModelCost {
813                    input: 0.0,
814                    output: 0.0,
815                    cache_read: 0.0,
816                    cache_write: 0.0,
817                },
818                context_window: 128_000,
819                max_tokens: 8192,
820                headers: std::collections::HashMap::new(),
821            },
822            api_key: None,
823            headers: std::collections::HashMap::new(),
824            auth_header: true,
825            compat: None,
826            oauth_config: None,
827        };
828
829        let runtime = resolve_vertex_provider_runtime(&entry).expect("resolve");
830        assert_eq!(runtime.project, "test-proj");
831        assert_eq!(runtime.location, "us-central1");
832        assert_eq!(runtime.publisher, "google");
833        assert_eq!(runtime.model, "gemini-2.0-flash");
834    }
835
836    // ─── Streaming response parsing ──────────────────────────────────────
837
838    #[test]
839    fn test_stream_text_response() {
840        let events = vec![
841            serde_json::json!({
842                "candidates": [{
843                    "content": {
844                        "role": "model",
845                        "parts": [{"text": "Hello from "}]
846                    }
847                }]
848            }),
849            serde_json::json!({
850                "candidates": [{
851                    "content": {
852                        "role": "model",
853                        "parts": [{"text": "Vertex AI!"}]
854                    },
855                    "finishReason": "STOP"
856                }],
857                "usageMetadata": {
858                    "promptTokenCount": 10,
859                    "candidatesTokenCount": 5,
860                    "totalTokenCount": 15
861                }
862            }),
863        ];
864
865        let stream_events = collect_events(&events);
866
867        // Should have: Start, TextDelta("Hello from "), TextDelta("Vertex AI!"), Done
868        assert!(
869            stream_events
870                .iter()
871                .any(|e| matches!(e, StreamEvent::Start { .. })),
872            "should emit Start"
873        );
874
875        let text_deltas: Vec<&str> = stream_events
876            .iter()
877            .filter_map(|e| match e {
878                StreamEvent::TextDelta { delta, .. } => Some(delta.as_str()),
879                _ => None,
880            })
881            .collect();
882        assert_eq!(text_deltas, vec!["Hello from ", "Vertex AI!"]);
883
884        let done = stream_events
885            .iter()
886            .find_map(|e| match e {
887                StreamEvent::Done { message, .. } => Some(message),
888                _ => None,
889            })
890            .expect("done event");
891        assert_eq!(done.usage.input, 10);
892        assert_eq!(done.usage.output, 5);
893    }
894
895    #[test]
896    fn test_stream_tool_call_response() {
897        let events = vec![serde_json::json!({
898            "candidates": [{
899                "content": {
900                    "role": "model",
901                    "parts": [{
902                        "functionCall": {
903                            "name": "read",
904                            "args": {"path": "/tmp/test.txt"}
905                        }
906                    }]
907                },
908                "finishReason": "STOP"
909            }]
910        })];
911
912        let stream_events = collect_events(&events);
913
914        assert!(
915            stream_events
916                .iter()
917                .any(|e| matches!(e, StreamEvent::ToolCallStart { .. })),
918            "should emit ToolCallStart"
919        );
920        assert!(
921            stream_events
922                .iter()
923                .any(|e| matches!(e, StreamEvent::ToolCallEnd { .. })),
924            "should emit ToolCallEnd"
925        );
926
927        let done = stream_events
928            .iter()
929            .find_map(|e| match e {
930                StreamEvent::Done { message, .. } => Some(message),
931                _ => None,
932            })
933            .expect("done event");
934        assert_eq!(done.stop_reason, StopReason::ToolUse);
935    }
936
937    #[test]
938    fn test_stream_ignores_unknown_parts() {
939        let events = vec![serde_json::json!({
940            "candidates": [{
941                "content": {
942                    "role": "model",
943                    "parts": [
944                        {
945                            "executableCode": {
946                                "language": "python",
947                                "code": "print('x')"
948                            }
949                        },
950                        {"text": "still works"}
951                    ]
952                },
953                "finishReason": "STOP"
954            }]
955        })];
956
957        let stream_events = collect_events(&events);
958
959        let text_deltas: Vec<&str> = stream_events
960            .iter()
961            .filter_map(|e| match e {
962                StreamEvent::TextDelta { delta, .. } => Some(delta.as_str()),
963                _ => None,
964            })
965            .collect();
966        assert_eq!(text_deltas, vec!["still works"]);
967        assert!(
968            stream_events
969                .iter()
970                .any(|e| matches!(e, StreamEvent::Done { .. })),
971            "should emit Done even when unknown parts are present"
972        );
973    }
974
975    // ─── Test helpers ────────────────────────────────────────────────────
976
977    fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
978        let runtime = RuntimeBuilder::current_thread()
979            .build()
980            .expect("runtime build");
981        runtime.block_on(async move {
982            let byte_stream = stream::iter(
983                events
984                    .iter()
985                    .map(|event| {
986                        let data = serde_json::to_string(event).expect("serialize event");
987                        format!("data: {data}\n\n").into_bytes()
988                    })
989                    .map(Ok),
990            );
991            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
992            let mut state = StreamState::new(
993                event_source,
994                "gemini-test".to_string(),
995                "google-vertex".to_string(),
996                "google-vertex".to_string(),
997            );
998            let mut out = Vec::new();
999
1000            loop {
1001                let Some(item) = state.event_source.next().await else {
1002                    if !state.finished {
1003                        state.finished = true;
1004                        out.push(StreamEvent::Done {
1005                            reason: state.partial.stop_reason,
1006                            message: std::mem::take(&mut state.partial),
1007                        });
1008                    }
1009                    break;
1010                };
1011
1012                let msg = item.expect("SSE event");
1013                if msg.event == "ping" {
1014                    continue;
1015                }
1016                state.process_event(&msg.data).expect("process_event");
1017                out.extend(state.pending_events.drain(..));
1018            }
1019
1020            out
1021        })
1022    }
1023}
1024
1025// ============================================================================
1026// Fuzzing support
1027// ============================================================================
1028
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// Byte source backing the fuzz processor. It never yields anything;
    /// payloads are fed directly through [`Processor::process_event`].
    type FuzzStream =
        Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;

    /// Opaque wrapper around the Vertex AI stream processor state.
    pub struct Processor(StreamState<FuzzStream>);

    impl Processor {
        /// Builds a processor whose SSE source is empty, so all input arrives
        /// via [`Processor::process_event`].
        pub fn new() -> Self {
            let source: FuzzStream = Box::pin(stream::empty());
            let state = StreamState::new(
                crate::sse::SseStream::new(source),
                "vertex-fuzz".into(),
                "vertex-ai".into(),
                "vertex".into(),
            );
            Self(state)
        }

        /// Processes one SSE `data` payload and hands back the events it queued.
        ///
        /// # Errors
        ///
        /// Propagates any error from the underlying stream state's
        /// `process_event`; in that case the queued events are not drained.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            self.0.process_event(data)?;
            let emitted: Vec<StreamEvent> = self.0.pending_events.drain(..).collect();
            Ok(emitted)
        }
    }

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }
}