1use crate::error::{Error, Result};
14use crate::http::client::Client;
15use crate::model::{
16 AssistantMessage, ContentBlock, StopReason, StreamEvent, TextContent, ToolCall, Usage,
17};
18use crate::models::CompatConfig;
19use crate::provider::{Context, Provider, StreamOptions};
20use crate::providers::gemini::{
21 self, GeminiCandidate, GeminiContent, GeminiFunctionCall, GeminiFunctionCallingConfig,
22 GeminiGenerationConfig, GeminiPart, GeminiRequest, GeminiStreamResponse, GeminiTool,
23 GeminiToolConfig,
24};
25use crate::sse::SseStream;
26use async_trait::async_trait;
27use futures::StreamExt;
28use futures::stream::{self, Stream};
29use std::collections::VecDeque;
30use std::pin::Pin;
31
/// Default Vertex AI region used when no location is configured anywhere.
const VERTEX_DEFAULT_REGION: &str = "us-central1";

/// Primary environment variable consulted for the GCP project id.
const VERTEX_PROJECT_ENV: &str = "GOOGLE_CLOUD_PROJECT";
/// Fallback environment variable for the GCP project id.
const VERTEX_PROJECT_ENV_ALT: &str = "VERTEX_PROJECT";

/// Primary environment variable consulted for the Vertex AI location/region.
const VERTEX_LOCATION_ENV: &str = "GOOGLE_CLOUD_LOCATION";
/// Fallback environment variable for the Vertex AI location/region.
const VERTEX_LOCATION_ENV_ALT: &str = "VERTEX_LOCATION";
/// Provider for Google Vertex AI streaming endpoints.
///
/// Supports Google-published Gemini models (via `streamGenerateContent`) and
/// partner models such as Anthropic's (via `streamRawPredict`), selected by
/// the `publisher` field.
pub struct VertexProvider {
    /// HTTP client used for the streaming POST request.
    client: Client,
    /// Model id, e.g. "gemini-2.0-flash".
    model: String,
    /// GCP project id; when `None`, resolved from env vars at request time.
    project: Option<String>,
    /// Vertex AI region; defaults to `VERTEX_DEFAULT_REGION`.
    location: String,
    /// Model publisher path segment ("google" by default, "anthropic" for Claude).
    publisher: String,
    /// When set, used verbatim in place of the computed endpoint URL.
    endpoint_url_override: Option<String>,
    /// Optional compatibility settings (e.g. extra request headers).
    compat: Option<CompatConfig>,
}
67
68impl VertexProvider {
69 pub fn new(model: impl Into<String>) -> Self {
71 Self {
72 client: Client::new(),
73 model: model.into(),
74 project: None,
75 location: VERTEX_DEFAULT_REGION.to_string(),
76 publisher: "google".to_string(),
77 endpoint_url_override: None,
78 compat: None,
79 }
80 }
81
82 #[must_use]
84 pub fn with_project(mut self, project: impl Into<String>) -> Self {
85 self.project = Some(project.into());
86 self
87 }
88
89 #[must_use]
91 pub fn with_location(mut self, location: impl Into<String>) -> Self {
92 self.location = location.into();
93 self
94 }
95
96 #[must_use]
98 pub fn with_publisher(mut self, publisher: impl Into<String>) -> Self {
99 self.publisher = publisher.into();
100 self
101 }
102
103 #[must_use]
105 pub fn with_endpoint_url(mut self, url: impl Into<String>) -> Self {
106 self.endpoint_url_override = Some(url.into());
107 self
108 }
109
110 #[must_use]
112 pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
113 self.compat = compat;
114 self
115 }
116
117 #[must_use]
119 pub fn with_client(mut self, client: Client) -> Self {
120 self.client = client;
121 self
122 }
123
124 fn resolve_project(&self) -> Result<String> {
126 if let Some(project) = &self.project {
127 return Ok(project.clone());
128 }
129 std::env::var(VERTEX_PROJECT_ENV)
130 .or_else(|_| std::env::var(VERTEX_PROJECT_ENV_ALT))
131 .map_err(|_| {
132 Error::provider(
133 "google-vertex",
134 format!(
135 "Missing GCP project. Set {VERTEX_PROJECT_ENV} or {VERTEX_PROJECT_ENV_ALT}, \
136 or configure `project` in provider settings."
137 ),
138 )
139 })
140 }
141
142 fn resolve_location(&self) -> String {
144 if self.location != VERTEX_DEFAULT_REGION {
145 return self.location.clone();
146 }
147 std::env::var(VERTEX_LOCATION_ENV)
148 .or_else(|_| std::env::var(VERTEX_LOCATION_ENV_ALT))
149 .unwrap_or_else(|_| VERTEX_DEFAULT_REGION.to_string())
150 }
151
152 fn streaming_url(&self, project: &str, location: &str) -> String {
157 if let Some(url) = &self.endpoint_url_override {
158 return url.clone();
159 }
160
161 let method = if self.publisher == "anthropic" {
162 "streamRawPredict"
163 } else {
164 "streamGenerateContent"
165 };
166
167 format!(
168 "https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}:{method}",
169 location = location,
170 project = project,
171 publisher = self.publisher,
172 model = self.model,
173 method = method,
174 )
175 }
176
177 #[allow(clippy::unused_self)]
179 pub fn build_gemini_request(
180 &self,
181 context: &Context<'_>,
182 options: &StreamOptions,
183 ) -> GeminiRequest {
184 let contents = Self::build_contents(context);
185 let system_instruction = context.system_prompt.as_deref().map(|s| GeminiContent {
186 role: None,
187 parts: vec![GeminiPart::Text {
188 text: s.to_string(),
189 }],
190 });
191
192 let tools: Option<Vec<GeminiTool>> = if context.tools.is_empty() {
193 None
194 } else {
195 Some(vec![GeminiTool {
196 function_declarations: context
197 .tools
198 .iter()
199 .map(gemini::convert_tool_to_gemini)
200 .collect(),
201 }])
202 };
203
204 let tool_config = if tools.is_some() {
205 Some(GeminiToolConfig {
206 function_calling_config: GeminiFunctionCallingConfig { mode: "AUTO" },
207 })
208 } else {
209 None
210 };
211
212 GeminiRequest {
213 contents,
214 system_instruction,
215 tools,
216 tool_config,
217 generation_config: Some(GeminiGenerationConfig {
218 max_output_tokens: options.max_tokens.or(Some(gemini::DEFAULT_MAX_TOKENS)),
219 temperature: options.temperature,
220 candidate_count: Some(1),
221 }),
222 }
223 }
224
225 fn build_contents(context: &Context<'_>) -> Vec<GeminiContent> {
227 let mut contents = Vec::new();
228 for message in context.messages.iter() {
229 contents.extend(gemini::convert_message_to_gemini(message));
230 }
231 contents
232 }
233}
234
#[async_trait]
impl Provider for VertexProvider {
    /// Provider identifier used in error reporting and message metadata.
    fn name(&self) -> &'static str {
        "google-vertex"
    }

    /// API family identifier (same as the provider name here).
    fn api(&self) -> &'static str {
        "google-vertex"
    }

    /// The configured model id.
    fn model_id(&self) -> &str {
        &self.model
    }

    /// Opens a streaming request against Vertex AI and returns a stream of
    /// `StreamEvent`s translated from the Gemini SSE wire format.
    ///
    /// # Errors
    /// Fails when no API key / access token can be resolved, when the project
    /// cannot be determined, or when the HTTP response is non-2xx.
    #[allow(clippy::too_many_lines)]
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        // Auth precedence: explicit option, then GOOGLE_CLOUD_API_KEY, then
        // VERTEX_API_KEY. The value is sent as a Bearer token below —
        // presumably an OAuth access token or API key; verify against the
        // deployment's auth mode.
        let auth_value = options
            .api_key
            .clone()
            .or_else(|| std::env::var("GOOGLE_CLOUD_API_KEY").ok())
            .or_else(|| std::env::var("VERTEX_API_KEY").ok())
            .ok_or_else(|| {
                Error::provider(
                    "google-vertex",
                    "Missing Vertex AI API key / access token. \
                     Set GOOGLE_CLOUD_API_KEY or VERTEX_API_KEY.",
                )
            })?;

        let project = self.resolve_project()?;
        let location = self.resolve_location();
        let url = self.streaming_url(&project, &location);

        let request_body = self.build_gemini_request(context, options);

        let mut request = self
            .client
            .post(&url)
            .header("Accept", "text/event-stream")
            .header("Authorization", format!("Bearer {auth_value}"));

        // Compat headers first, then per-call option headers (which can
        // therefore override compat-supplied values).
        if let Some(compat) = &self.compat {
            if let Some(custom_headers) = &compat.custom_headers {
                for (key, value) in custom_headers {
                    request = request.header(key, value);
                }
            }
        }

        for (key, value) in &options.headers {
            request = request.header(key, value);
        }

        let request = request.json(&request_body)?;

        let response = Box::pin(request.send()).await?;
        let status = response.status();
        if !(200..300).contains(&status) {
            // Best-effort body read: surface the read failure inside the
            // error text rather than masking the HTTP status.
            let body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
            return Err(Error::provider(
                "google-vertex",
                format!("Vertex AI API error (HTTP {status}): {body}"),
            ));
        }

        let event_source = SseStream::new(response.bytes_stream());

        let model = self.model.clone();
        let api = self.api().to_string();
        let provider = self.name().to_string();

        // Drive the SSE source through StreamState: each unfold step drains
        // one pending event, or pulls/parses the next SSE frame to refill the
        // queue. `finished` makes the stream fuse after a terminal item.
        let stream = stream::unfold(
            StreamState::new(event_source, model, api, provider),
            |mut state| async move {
                if state.finished {
                    return None;
                }
                loop {
                    // Emit queued events before reading more from the wire.
                    if let Some(event) = state.pending_events.pop_front() {
                        return Some((Ok(event), state));
                    }

                    match state.event_source.next().await {
                        Some(Ok(msg)) => {
                            // Any successful frame resets the transient-error
                            // streak; keep-alive pings are skipped.
                            state.transient_error_count = 0;
                            if msg.event == "ping" {
                                continue;
                            }

                            if let Err(e) = state.process_event(&msg.data) {
                                state.finished = true;
                                return Some((Err(e), state));
                            }
                        }
                        Some(Err(e)) => {
                            const MAX_CONSECUTIVE_TRANSIENT_ERRORS: usize = 5;
                            // Retry a bounded number of consecutive transient
                            // I/O conditions; anything else (or exceeding the
                            // bound) terminates the stream with an error.
                            if e.kind() == std::io::ErrorKind::WriteZero
                                || e.kind() == std::io::ErrorKind::WouldBlock
                                || e.kind() == std::io::ErrorKind::TimedOut
                            {
                                state.transient_error_count += 1;
                                if state.transient_error_count <= MAX_CONSECUTIVE_TRANSIENT_ERRORS {
                                    tracing::warn!(
                                        kind = ?e.kind(),
                                        count = state.transient_error_count,
                                        "Transient error in SSE stream, continuing"
                                    );
                                    continue;
                                }
                                tracing::warn!(
                                    kind = ?e.kind(),
                                    "Error persisted after {MAX_CONSECUTIVE_TRANSIENT_ERRORS} \
                                     consecutive attempts, treating as fatal"
                                );
                            }
                            state.finished = true;
                            let err = Error::api(format!("SSE error: {e}"));
                            return Some((Err(err), state));
                        }
                        None => {
                            // Source exhausted: emit the final Done event with
                            // the accumulated message (taken to avoid a clone).
                            state.finished = true;
                            let reason = state.partial.stop_reason;
                            let message = std::mem::take(&mut state.partial);
                            return Some((Ok(StreamEvent::Done { reason, message }), state));
                        }
                    }
                }
            },
        );

        Ok(Box::pin(stream))
    }
}
387
/// Incremental state for translating a Gemini SSE stream into `StreamEvent`s.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Underlying SSE frame parser over the raw byte stream.
    event_source: SseStream<S>,
    /// Assistant message accumulated so far; moved out whole on `Done`.
    partial: AssistantMessage,
    /// Events produced by parsing but not yet yielded to the consumer.
    pending_events: VecDeque<StreamEvent>,
    /// Whether the `Start` event has been emitted yet.
    started: bool,
    /// Set once the stream has terminated (normally or with an error).
    finished: bool,
    /// Consecutive transient SSE errors seen; reset on any successful frame.
    transient_error_count: usize,
}
404
impl<S> StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Creates a fresh state with an empty partial message stamped with the
    /// given model/api/provider identifiers and the current timestamp.
    fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
        Self {
            event_source,
            partial: AssistantMessage {
                content: Vec::new(),
                api,
                provider,
                model,
                usage: Usage::default(),
                stop_reason: StopReason::Stop,
                error_message: None,
                timestamp: chrono::Utc::now().timestamp_millis(),
            },
            pending_events: VecDeque::new(),
            started: false,
            finished: false,
            transient_error_count: 0,
        }
    }

    /// Parses one SSE `data` payload as a Gemini stream response, folding
    /// usage metadata into the partial message and dispatching the first
    /// candidate (only one is requested; extras are ignored).
    ///
    /// # Errors
    /// Returns an API error when the payload is not valid JSON for
    /// `GeminiStreamResponse`.
    fn process_event(&mut self, data: &str) -> Result<()> {
        let response: GeminiStreamResponse = serde_json::from_str(data)
            .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;

        // Later chunks overwrite earlier usage values wholesale.
        if let Some(metadata) = response.usage_metadata {
            self.partial.usage.input = metadata.prompt_token_count.unwrap_or(0);
            self.partial.usage.output = metadata.candidates_token_count.unwrap_or(0);
            self.partial.usage.total_tokens = metadata.total_token_count.unwrap_or(0);
        }

        if let Some(candidates) = response.candidates {
            if let Some(candidate) = candidates.into_iter().next() {
                self.process_candidate(candidate)?;
            }
        }

        Ok(())
    }

    /// Translates one candidate chunk into pending `StreamEvent`s and updates
    /// the partial message (content blocks, stop reason, end events).
    #[allow(clippy::unnecessary_wraps)]
    fn process_candidate(&mut self, candidate: GeminiCandidate) -> Result<()> {
        // Map Gemini finish reasons onto generic stop reasons; unknown values
        // default to a normal stop.
        if let Some(ref reason) = candidate.finish_reason {
            self.partial.stop_reason = match reason.as_str() {
                "MAX_TOKENS" => StopReason::Length,
                "SAFETY" | "RECITATION" | "OTHER" => StopReason::Error,
                "FUNCTION_CALL" => StopReason::ToolUse,
                _ => StopReason::Stop,
            };
        }

        if let Some(content) = candidate.content {
            for part in content.parts {
                match part {
                    GeminiPart::Text { text } => {
                        // Consecutive text parts append into one block; a new
                        // block (and TextStart) is opened only when the last
                        // block is not text.
                        let last_is_text =
                            matches!(self.partial.content.last(), Some(ContentBlock::Text(_)));
                        if !last_is_text {
                            let content_index = self.partial.content.len();
                            self.partial
                                .content
                                .push(ContentBlock::Text(TextContent::new("")));

                            self.ensure_started();

                            self.pending_events
                                .push_back(StreamEvent::TextStart { content_index });
                        }
                        let content_index = self.partial.content.len() - 1;

                        if let Some(ContentBlock::Text(t)) =
                            self.partial.content.get_mut(content_index)
                        {
                            t.text.push_str(&text);
                        }

                        self.ensure_started();

                        self.pending_events.push_back(StreamEvent::TextDelta {
                            content_index,
                            delta: text,
                        });
                    }
                    GeminiPart::FunctionCall { function_call } => {
                        // Gemini does not supply call ids; synthesize one.
                        let id = format!("call_{}", uuid::Uuid::new_v4().simple());

                        // Serialize args before destructuring so the delta
                        // event carries the raw JSON string.
                        let args_str = serde_json::to_string(&function_call.args)
                            .unwrap_or_else(|_| "{}".to_string());
                        let GeminiFunctionCall { name, args } = function_call;

                        let tool_call = ToolCall {
                            id,
                            name,
                            arguments: args,
                            thought_signature: None,
                        };

                        self.partial
                            .content
                            .push(ContentBlock::ToolCall(tool_call.clone()));
                        let content_index = self.partial.content.len() - 1;

                        // A function call implies tool use regardless of the
                        // reported finish reason.
                        self.partial.stop_reason = StopReason::ToolUse;

                        self.ensure_started();

                        // Function calls arrive whole, so emit the full
                        // start/delta/end sequence at once.
                        self.pending_events
                            .push_back(StreamEvent::ToolCallStart { content_index });
                        self.pending_events.push_back(StreamEvent::ToolCallDelta {
                            content_index,
                            delta: args_str,
                        });
                        self.pending_events.push_back(StreamEvent::ToolCallEnd {
                            content_index,
                            tool_call,
                        });
                    }
                    GeminiPart::InlineData { .. }
                    | GeminiPart::FunctionResponse { .. }
                    | GeminiPart::Unknown(_) => {
                        // Intentionally ignored: these part kinds carry no
                        // assistant output we stream.
                    }
                }
            }
        }

        // On the final chunk, close out every text/thinking block with its
        // accumulated content. Tool calls already emitted their End events.
        if candidate.finish_reason.is_some() {
            for (content_index, block) in self.partial.content.iter().enumerate() {
                if let ContentBlock::Text(t) = block {
                    self.pending_events.push_back(StreamEvent::TextEnd {
                        content_index,
                        content: t.text.clone(),
                    });
                } else if let ContentBlock::Thinking(t) = block {
                    self.pending_events.push_back(StreamEvent::ThinkingEnd {
                        content_index,
                        content: t.thinking.clone(),
                    });
                }
            }
        }

        Ok(())
    }

    /// Emits the one-time `Start` event carrying a snapshot of the partial
    /// message; subsequent calls are no-ops.
    fn ensure_started(&mut self) {
        if !self.started {
            self.started = true;
            self.pending_events.push_back(StreamEvent::Start {
                partial: self.partial.clone(),
            });
        }
    }
}
570
/// Fully-resolved Vertex AI connection parameters derived from a model
/// entry's base URL plus environment-variable fallbacks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct VertexProviderRuntime {
    /// GCP project id.
    pub(crate) project: String,
    /// Vertex AI region, e.g. "us-central1".
    pub(crate) location: String,
    /// Model publisher path segment, e.g. "google" or "anthropic".
    pub(crate) publisher: String,
    /// Model id taken from the entry.
    pub(crate) model: String,
}
583
584pub(crate) fn resolve_vertex_provider_runtime(
591 entry: &crate::models::ModelEntry,
592) -> Result<VertexProviderRuntime> {
593 let (url_project, url_location, url_publisher) = parse_vertex_base_url(&entry.model.base_url);
595
596 let project = url_project
597 .or_else(|| std::env::var(VERTEX_PROJECT_ENV).ok())
598 .or_else(|| std::env::var(VERTEX_PROJECT_ENV_ALT).ok())
599 .ok_or_else(|| {
600 Error::provider(
601 "google-vertex",
602 format!(
603 "Missing GCP project. Set {VERTEX_PROJECT_ENV} or provide a Vertex AI base URL \
604 like https://REGION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/REGION/..."
605 ),
606 )
607 })?;
608
609 let location = url_location
610 .or_else(|| std::env::var(VERTEX_LOCATION_ENV).ok())
611 .or_else(|| std::env::var(VERTEX_LOCATION_ENV_ALT).ok())
612 .unwrap_or_else(|| VERTEX_DEFAULT_REGION.to_string());
613
614 let publisher = url_publisher.unwrap_or_else(|| "google".to_string());
615
616 Ok(VertexProviderRuntime {
617 project,
618 location,
619 publisher,
620 model: entry.model.id.clone(),
621 })
622}
623
/// Extracts `(project, location, publisher)` hints from a Vertex AI base URL.
///
/// Components are parsed from the URL path
/// (`/projects/{p}/locations/{l}/publishers/{pub}/...`); the location
/// additionally falls back to the region encoded in the host name
/// (`https://{region}-aiplatform.googleapis.com/...`). Missing components are
/// `None`; an empty input yields `(None, None, None)`.
fn parse_vertex_base_url(base_url: &str) -> (Option<String>, Option<String>, Option<String>) {
    if base_url.is_empty() {
        return (None, None, None);
    }

    // Fallback region from the host: everything before "-aiplatform." in
    // "{region}-aiplatform.googleapis.com". Regions contain hyphens and
    // digits (e.g. "us-central1"), so splitting on '-' would truncate them
    // (the previous code returned just "us"); instead take the full prefix
    // and validate the region charset (lowercase letters, digits, hyphens).
    let location_from_host = base_url
        .strip_prefix("https://")
        .or_else(|| base_url.strip_prefix("http://"))
        .and_then(|rest| rest.split('/').next())
        .and_then(|host| host.find("-aiplatform.").map(|idx| &host[..idx]))
        .and_then(|loc| {
            let valid = !loc.is_empty()
                && loc
                    .chars()
                    .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-');
            valid.then(|| loc.to_string())
        });

    let path_segments: Vec<&str> = base_url.split('/').collect();

    // Returns the segment immediately following the first occurrence of `key`.
    let value_after = |key: &str| {
        path_segments
            .iter()
            .zip(path_segments.iter().skip(1))
            .find(|(k, _)| **k == key)
            .map(|(_, val)| (*val).to_string())
    };

    let project = value_after("projects");
    // Explicit path location takes precedence over the host-derived region.
    let location = value_after("locations").or(location_from_host);
    let publisher = value_after("publishers");

    (project, location, publisher)
}
671
/// Unit tests: URL construction, request building, base-URL parsing, and an
/// end-to-end drive of `StreamState` over synthetic SSE frames.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{Message, UserContent};
    use crate::provider::ToolDef;
    use asupersync::runtime::RuntimeBuilder;
    use futures::{StreamExt, stream};
    use serde_json::Value;

    #[test]
    fn test_provider_info() {
        let provider = VertexProvider::new("gemini-2.0-flash");
        assert_eq!(provider.name(), "google-vertex");
        assert_eq!(provider.api(), "google-vertex");
        assert_eq!(provider.model_id(), "gemini-2.0-flash");
    }

    #[test]
    fn test_streaming_url_google_publisher() {
        let provider = VertexProvider::new("gemini-2.0-flash")
            .with_project("my-project")
            .with_location("us-central1");

        let url = provider.streaming_url("my-project", "us-central1");
        assert_eq!(
            url,
            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:streamGenerateContent"
        );
    }

    #[test]
    fn test_streaming_url_anthropic_publisher() {
        // Anthropic-published models switch to the raw-predict method.
        let provider = VertexProvider::new("claude-sonnet-4-20250514")
            .with_project("my-project")
            .with_location("europe-west1")
            .with_publisher("anthropic");

        let url = provider.streaming_url("my-project", "europe-west1");
        assert_eq!(
            url,
            "https://europe-west1-aiplatform.googleapis.com/v1/projects/my-project/locations/europe-west1/publishers/anthropic/models/claude-sonnet-4-20250514:streamRawPredict"
        );
    }

    #[test]
    fn test_streaming_url_override() {
        // An explicit endpoint override ignores project/location entirely.
        let provider =
            VertexProvider::new("gemini-2.0-flash").with_endpoint_url("http://127.0.0.1:8080/mock");

        let url = provider.streaming_url("ignored", "ignored");
        assert_eq!(url, "http://127.0.0.1:8080/mock");
    }

    #[test]
    fn test_build_gemini_request_basic() {
        let provider = VertexProvider::new("gemini-2.0-flash");
        let context = Context::owned(
            Some("You are helpful.".to_string()),
            vec![Message::User(crate::model::UserMessage {
                content: UserContent::Text("What is Vertex AI?".to_string()),
                timestamp: 0,
            })],
            vec![],
        );
        let options = StreamOptions {
            max_tokens: Some(1024),
            temperature: Some(0.7),
            ..Default::default()
        };

        let req = provider.build_gemini_request(&context, &options);
        let json = serde_json::to_value(&req).expect("serialize");

        let contents = json["contents"].as_array().expect("contents");
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0]["role"], "user");
        assert_eq!(contents[0]["parts"][0]["text"], "What is Vertex AI?");

        assert_eq!(
            json["systemInstruction"]["parts"][0]["text"],
            "You are helpful."
        );
        assert_eq!(json["generationConfig"]["maxOutputTokens"], 1024);
    }

    #[test]
    fn test_build_gemini_request_with_tools() {
        let provider = VertexProvider::new("gemini-2.0-flash");
        let context = Context::owned(
            None,
            vec![Message::User(crate::model::UserMessage {
                content: UserContent::Text("Read a file".to_string()),
                timestamp: 0,
            })],
            vec![ToolDef {
                name: "read".to_string(),
                description: "Read a file".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": { "path": {"type": "string"} },
                    "required": ["path"]
                }),
            }],
        );
        let options = StreamOptions::default();

        let req = provider.build_gemini_request(&context, &options);
        let json = serde_json::to_value(&req).expect("serialize");

        let tools = json["tools"].as_array().expect("tools");
        assert_eq!(tools.len(), 1);
        let decls = tools[0]["functionDeclarations"]
            .as_array()
            .expect("declarations");
        assert_eq!(decls[0]["name"], "read");
        assert_eq!(json["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");
    }

    #[test]
    fn test_parse_vertex_base_url_full() {
        let url = "https://us-central1-aiplatform.googleapis.com/v1/projects/my-proj/locations/us-central1/publishers/google/models/gemini-2.0-flash";
        let (project, location, publisher) = parse_vertex_base_url(url);
        assert_eq!(project.as_deref(), Some("my-proj"));
        assert_eq!(location.as_deref(), Some("us-central1"));
        assert_eq!(publisher.as_deref(), Some("google"));
    }

    #[test]
    fn test_parse_vertex_base_url_anthropic() {
        let url = "https://europe-west1-aiplatform.googleapis.com/v1/projects/corp-ai/locations/europe-west1/publishers/anthropic/models/claude-sonnet-4-20250514";
        let (project, location, publisher) = parse_vertex_base_url(url);
        assert_eq!(project.as_deref(), Some("corp-ai"));
        assert_eq!(location.as_deref(), Some("europe-west1"));
        assert_eq!(publisher.as_deref(), Some("anthropic"));
    }

    #[test]
    fn test_parse_vertex_base_url_empty() {
        let (project, location, publisher) = parse_vertex_base_url("");
        assert!(project.is_none());
        assert!(location.is_none());
        assert!(publisher.is_none());
    }

    #[test]
    fn test_parse_vertex_base_url_partial() {
        let url = "https://us-central1-aiplatform.googleapis.com/v1/projects/my-proj/locations/us-central1";
        let (project, location, publisher) = parse_vertex_base_url(url);
        assert_eq!(project.as_deref(), Some("my-proj"));
        assert_eq!(location.as_deref(), Some("us-central1"));
        assert!(publisher.is_none());
    }

    #[test]
    fn test_resolve_vertex_provider_runtime_from_url() {
        let entry = crate::models::ModelEntry {
            model: crate::provider::Model {
                id: "gemini-2.0-flash".to_string(),
                name: "Gemini 2.0 Flash".to_string(),
                api: "google-vertex".to_string(),
                provider: "google-vertex".to_string(),
                base_url: "https://us-central1-aiplatform.googleapis.com/v1/projects/test-proj/locations/us-central1/publishers/google/models/gemini-2.0-flash".to_string(),
                reasoning: false,
                input: vec![],
                cost: crate::provider::ModelCost {
                    input: 0.0,
                    output: 0.0,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: 128_000,
                max_tokens: 8192,
                headers: std::collections::HashMap::new(),
            },
            api_key: None,
            headers: std::collections::HashMap::new(),
            auth_header: true,
            compat: None,
            oauth_config: None,
        };

        let runtime = resolve_vertex_provider_runtime(&entry).expect("resolve");
        assert_eq!(runtime.project, "test-proj");
        assert_eq!(runtime.location, "us-central1");
        assert_eq!(runtime.publisher, "google");
        assert_eq!(runtime.model, "gemini-2.0-flash");
    }

    #[test]
    fn test_stream_text_response() {
        // Two chunks: a text delta, then a final chunk with usage metadata.
        let events = vec![
            serde_json::json!({
                "candidates": [{
                    "content": {
                        "role": "model",
                        "parts": [{"text": "Hello from "}]
                    }
                }]
            }),
            serde_json::json!({
                "candidates": [{
                    "content": {
                        "role": "model",
                        "parts": [{"text": "Vertex AI!"}]
                    },
                    "finishReason": "STOP"
                }],
                "usageMetadata": {
                    "promptTokenCount": 10,
                    "candidatesTokenCount": 5,
                    "totalTokenCount": 15
                }
            }),
        ];

        let stream_events = collect_events(&events);

        assert!(
            stream_events
                .iter()
                .any(|e| matches!(e, StreamEvent::Start { .. })),
            "should emit Start"
        );

        let text_deltas: Vec<&str> = stream_events
            .iter()
            .filter_map(|e| match e {
                StreamEvent::TextDelta { delta, .. } => Some(delta.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(text_deltas, vec!["Hello from ", "Vertex AI!"]);

        let done = stream_events
            .iter()
            .find_map(|e| match e {
                StreamEvent::Done { message, .. } => Some(message),
                _ => None,
            })
            .expect("done event");
        assert_eq!(done.usage.input, 10);
        assert_eq!(done.usage.output, 5);
    }

    #[test]
    fn test_stream_tool_call_response() {
        let events = vec![serde_json::json!({
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{
                        "functionCall": {
                            "name": "read",
                            "args": {"path": "/tmp/test.txt"}
                        }
                    }]
                },
                "finishReason": "STOP"
            }]
        })];

        let stream_events = collect_events(&events);

        assert!(
            stream_events
                .iter()
                .any(|e| matches!(e, StreamEvent::ToolCallStart { .. })),
            "should emit ToolCallStart"
        );
        assert!(
            stream_events
                .iter()
                .any(|e| matches!(e, StreamEvent::ToolCallEnd { .. })),
            "should emit ToolCallEnd"
        );

        let done = stream_events
            .iter()
            .find_map(|e| match e {
                StreamEvent::Done { message, .. } => Some(message),
                _ => None,
            })
            .expect("done event");
        // A function call forces ToolUse even though the API said "STOP".
        assert_eq!(done.stop_reason, StopReason::ToolUse);
    }

    #[test]
    fn test_stream_ignores_unknown_parts() {
        let events = vec![serde_json::json!({
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [
                        {
                            "executableCode": {
                                "language": "python",
                                "code": "print('x')"
                            }
                        },
                        {"text": "still works"}
                    ]
                },
                "finishReason": "STOP"
            }]
        })];

        let stream_events = collect_events(&events);

        let text_deltas: Vec<&str> = stream_events
            .iter()
            .filter_map(|e| match e {
                StreamEvent::TextDelta { delta, .. } => Some(delta.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(text_deltas, vec!["still works"]);
        assert!(
            stream_events
                .iter()
                .any(|e| matches!(e, StreamEvent::Done { .. })),
            "should emit Done even when unknown parts are present"
        );
    }

    /// Drives `StreamState` directly over synthetic SSE frames (bypassing the
    /// HTTP layer) and collects every emitted event, including the final
    /// `Done` synthesized when the source is exhausted.
    fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        runtime.block_on(async move {
            let byte_stream = stream::iter(
                events
                    .iter()
                    .map(|event| {
                        let data = serde_json::to_string(event).expect("serialize event");
                        format!("data: {data}\n\n").into_bytes()
                    })
                    .map(Ok),
            );
            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
            let mut state = StreamState::new(
                event_source,
                "gemini-test".to_string(),
                "google-vertex".to_string(),
                "google-vertex".to_string(),
            );
            let mut out = Vec::new();

            loop {
                let Some(item) = state.event_source.next().await else {
                    // Mirror the provider's end-of-stream Done emission.
                    if !state.finished {
                        state.finished = true;
                        out.push(StreamEvent::Done {
                            reason: state.partial.stop_reason,
                            message: std::mem::take(&mut state.partial),
                        });
                    }
                    break;
                };

                let msg = item.expect("SSE event");
                if msg.event == "ping" {
                    continue;
                }
                state.process_event(&msg.data).expect("process_event");
                out.extend(state.pending_events.drain(..));
            }

            out
        })
    }
}
1052
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    //! Fuzzing entry points: wraps `StreamState` over a never-polled empty
    //! SSE source so fuzzers can feed arbitrary JSON payloads directly into
    //! the event parser.
    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// Stream type that only exists to satisfy `StreamState`'s generic bound;
    /// the harness never reads from it.
    type FuzzStream =
        Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;

    /// Fuzz-facing wrapper around `StreamState`.
    pub struct Processor(StreamState<FuzzStream>);

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Processor {
        /// Builds a processor backed by an empty byte stream.
        pub fn new() -> Self {
            let source = Box::pin(stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>());
            let state = StreamState::new(
                crate::sse::SseStream::new(source),
                String::from("vertex-fuzz"),
                String::from("vertex-ai"),
                String::from("vertex"),
            );
            Self(state)
        }

        /// Feeds one raw SSE `data` payload through the parser and returns
        /// whatever events it produced.
        ///
        /// # Errors
        /// Propagates JSON parse failures from `StreamState::process_event`.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            self.0.process_event(data)?;
            let produced: Vec<StreamEvent> = self.0.pending_events.drain(..).collect();
            Ok(produced)
        }
    }
}