1use crate::error::{Error, Result};
14use crate::http::client::Client;
15use crate::model::{
16 AssistantMessage, ContentBlock, StopReason, StreamEvent, TextContent, ToolCall, Usage,
17};
18use crate::models::CompatConfig;
19use crate::provider::{Context, Provider, StreamOptions};
20use crate::providers::gemini::{
21 self, GeminiCandidate, GeminiContent, GeminiFunctionCall, GeminiFunctionCallingConfig,
22 GeminiGenerationConfig, GeminiPart, GeminiRequest, GeminiStreamResponse, GeminiTool,
23 GeminiToolConfig,
24};
25use crate::sse::SseStream;
26use async_trait::async_trait;
27use futures::StreamExt;
28use futures::stream::{self, Stream};
29use std::collections::VecDeque;
30use std::pin::Pin;
31
/// Fallback GCP region used when neither configuration nor environment
/// variables specify a location.
const VERTEX_DEFAULT_REGION: &str = "us-central1";

/// Primary environment variable consulted for the GCP project id.
const VERTEX_PROJECT_ENV: &str = "GOOGLE_CLOUD_PROJECT";
/// Alternate environment variable for the GCP project id.
const VERTEX_PROJECT_ENV_ALT: &str = "VERTEX_PROJECT";

/// Primary environment variable consulted for the GCP region/location.
const VERTEX_LOCATION_ENV: &str = "GOOGLE_CLOUD_LOCATION";
/// Alternate environment variable for the GCP region/location.
const VERTEX_LOCATION_ENV_ALT: &str = "VERTEX_LOCATION";
/// Streaming provider for Google Vertex AI (Gemini and partner-published
/// models such as Anthropic's).
///
/// Endpoint parameters combine explicit builder settings, environment
/// variables, and defaults; see `resolve_project` / `resolve_location`.
pub struct VertexProvider {
    /// HTTP client used to issue the streaming request.
    client: Client,
    /// Model id, e.g. `gemini-2.0-flash`.
    model: String,
    /// GCP project id; when `None`, resolved from the environment at request time.
    project: Option<String>,
    /// GCP region; defaults to `us-central1`.
    location: String,
    /// Publisher segment of the endpoint URL (`google` or `anthropic`).
    publisher: String,
    /// When set, used verbatim instead of the computed endpoint URL
    /// (useful for tests / mock servers).
    endpoint_url_override: Option<String>,
    /// Optional compatibility settings (e.g. extra request headers).
    compat: Option<CompatConfig>,
}
67
68impl VertexProvider {
69 pub fn new(model: impl Into<String>) -> Self {
71 Self {
72 client: Client::new(),
73 model: model.into(),
74 project: None,
75 location: VERTEX_DEFAULT_REGION.to_string(),
76 publisher: "google".to_string(),
77 endpoint_url_override: None,
78 compat: None,
79 }
80 }
81
82 #[must_use]
84 pub fn with_project(mut self, project: impl Into<String>) -> Self {
85 self.project = Some(project.into());
86 self
87 }
88
89 #[must_use]
91 pub fn with_location(mut self, location: impl Into<String>) -> Self {
92 self.location = location.into();
93 self
94 }
95
96 #[must_use]
98 pub fn with_publisher(mut self, publisher: impl Into<String>) -> Self {
99 self.publisher = publisher.into();
100 self
101 }
102
103 #[must_use]
105 pub fn with_endpoint_url(mut self, url: impl Into<String>) -> Self {
106 self.endpoint_url_override = Some(url.into());
107 self
108 }
109
110 #[must_use]
112 pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
113 self.compat = compat;
114 self
115 }
116
117 #[must_use]
119 pub fn with_client(mut self, client: Client) -> Self {
120 self.client = client;
121 self
122 }
123
124 fn resolve_project(&self) -> Result<String> {
126 if let Some(project) = &self.project {
127 return Ok(project.clone());
128 }
129 std::env::var(VERTEX_PROJECT_ENV)
130 .or_else(|_| std::env::var(VERTEX_PROJECT_ENV_ALT))
131 .map_err(|_| {
132 Error::provider(
133 "google-vertex",
134 format!(
135 "Missing GCP project. Set {VERTEX_PROJECT_ENV} or {VERTEX_PROJECT_ENV_ALT}, \
136 or configure `project` in provider settings."
137 ),
138 )
139 })
140 }
141
142 fn resolve_location(&self) -> String {
144 if self.location != VERTEX_DEFAULT_REGION {
145 return self.location.clone();
146 }
147 std::env::var(VERTEX_LOCATION_ENV)
148 .or_else(|_| std::env::var(VERTEX_LOCATION_ENV_ALT))
149 .unwrap_or_else(|_| VERTEX_DEFAULT_REGION.to_string())
150 }
151
152 fn streaming_url(&self, project: &str, location: &str) -> String {
157 if let Some(url) = &self.endpoint_url_override {
158 return url.clone();
159 }
160
161 let method = if self.publisher == "anthropic" {
162 "streamRawPredict"
163 } else {
164 "streamGenerateContent"
165 };
166
167 format!(
168 "https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/{publisher}/models/{model}:{method}",
169 location = location,
170 project = project,
171 publisher = self.publisher,
172 model = self.model,
173 method = method,
174 )
175 }
176
177 #[allow(clippy::unused_self)]
179 pub fn build_gemini_request(
180 &self,
181 context: &Context<'_>,
182 options: &StreamOptions,
183 ) -> GeminiRequest {
184 let contents = Self::build_contents(context);
185 let system_instruction = context.system_prompt.as_deref().map(|s| GeminiContent {
186 role: None,
187 parts: vec![GeminiPart::Text {
188 text: s.to_string(),
189 }],
190 });
191
192 let tools: Option<Vec<GeminiTool>> = if context.tools.is_empty() {
193 None
194 } else {
195 Some(vec![GeminiTool {
196 function_declarations: context
197 .tools
198 .iter()
199 .map(gemini::convert_tool_to_gemini)
200 .collect(),
201 }])
202 };
203
204 let tool_config = if tools.is_some() {
205 Some(GeminiToolConfig {
206 function_calling_config: GeminiFunctionCallingConfig { mode: "AUTO" },
207 })
208 } else {
209 None
210 };
211
212 GeminiRequest {
213 contents,
214 system_instruction,
215 tools,
216 tool_config,
217 generation_config: Some(GeminiGenerationConfig {
218 max_output_tokens: options.max_tokens.or(Some(gemini::DEFAULT_MAX_TOKENS)),
219 temperature: options.temperature,
220 candidate_count: Some(1),
221 }),
222 }
223 }
224
225 fn build_contents(context: &Context<'_>) -> Vec<GeminiContent> {
227 let mut contents = Vec::new();
228 for message in context.messages.iter() {
229 contents.extend(gemini::convert_message_to_gemini(message));
230 }
231 contents
232 }
233}
234
#[async_trait]
impl Provider for VertexProvider {
    fn name(&self) -> &'static str {
        "google-vertex"
    }

    fn api(&self) -> &'static str {
        "google-vertex"
    }

    fn model_id(&self) -> &str {
        &self.model
    }

    /// Opens a streaming completion against Vertex AI and returns a stream of
    /// [`StreamEvent`]s.
    ///
    /// Authentication precedence: `options.api_key`, then the
    /// `GOOGLE_CLOUD_API_KEY` / `VERTEX_API_KEY` environment variables. The
    /// value is sent as a `Bearer` token.
    ///
    /// # Errors
    /// Fails when no credential or project can be resolved, on a non-2xx HTTP
    /// response, or on transport errors while opening the connection. Errors
    /// occurring mid-stream are yielded as `Err` items of the stream.
    #[allow(clippy::too_many_lines)]
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        let auth_value = options
            .api_key
            .clone()
            .or_else(|| std::env::var("GOOGLE_CLOUD_API_KEY").ok())
            .or_else(|| std::env::var("VERTEX_API_KEY").ok())
            .ok_or_else(|| {
                Error::provider(
                    "google-vertex",
                    "Missing Vertex AI API key / access token. \
                     Set GOOGLE_CLOUD_API_KEY or VERTEX_API_KEY.",
                )
            })?;

        let project = self.resolve_project()?;
        let location = self.resolve_location();
        let url = self.streaming_url(&project, &location);

        let request_body = self.build_gemini_request(context, options);

        let mut request = self
            .client
            .post(&url)
            .header("Accept", "text/event-stream")
            .header("Authorization", format!("Bearer {auth_value}"));

        // Compat-configured headers first, then per-call headers, so that
        // per-call values win when the client keeps the last occurrence.
        if let Some(compat) = &self.compat {
            if let Some(custom_headers) = &compat.custom_headers {
                for (key, value) in custom_headers {
                    request = request.header(key, value);
                }
            }
        }

        for (key, value) in &options.headers {
            request = request.header(key, value);
        }

        let request = request.json(&request_body)?;

        let response = Box::pin(request.send()).await?;
        let status = response.status();
        if !(200..300).contains(&status) {
            // Surface the response body in the error to aid debugging;
            // a body read failure is reported inline rather than propagated.
            let body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
            return Err(Error::provider(
                "google-vertex",
                format!("Vertex AI API error (HTTP {status}): {body}"),
            ));
        }

        let event_source = SseStream::new(response.bytes_stream());

        let model = self.model.clone();
        let api = self.api().to_string();
        let provider = self.name().to_string();

        // Drive the SSE parser through an unfold: each turn either drains a
        // queued event, processes the next SSE message, or terminates.
        let stream = stream::unfold(
            StreamState::new(event_source, model, api, provider),
            |mut state| async move {
                if state.finished {
                    return None;
                }
                loop {
                    // Flush events queued by a previous process_event call
                    // before pulling more data off the wire.
                    if let Some(event) = state.pending_events.pop_front() {
                        return Some((Ok(event), state));
                    }

                    match state.event_source.next().await {
                        Some(Ok(msg)) => {
                            // Keep-alive pings carry no payload.
                            if msg.event == "ping" {
                                continue;
                            }

                            if let Err(e) = state.process_event(&msg.data) {
                                state.finished = true;
                                return Some((Err(e), state));
                            }
                        }
                        Some(Err(e)) => {
                            state.finished = true;
                            let err = Error::api(format!("SSE error: {e}"));
                            return Some((Err(err), state));
                        }
                        None => {
                            // Upstream closed: emit the final Done event with
                            // the accumulated message.
                            state.finished = true;
                            let reason = state.partial.stop_reason;
                            let message = std::mem::take(&mut state.partial);
                            return Some((Ok(StreamEvent::Done { reason, message }), state));
                        }
                    }
                }
            },
        );

        Ok(Box::pin(stream))
    }
}
363
/// Incremental state for translating a Vertex AI SSE stream into
/// [`StreamEvent`]s.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Parsed SSE event source over the raw byte stream.
    event_source: SseStream<S>,
    /// Message accumulated so far; emitted with the final `Done` event.
    partial: AssistantMessage,
    /// Events produced by processing but not yet yielded to the consumer.
    pending_events: VecDeque<StreamEvent>,
    /// Whether a `Start` event has been queued yet.
    started: bool,
    /// Set once the stream has terminated (normally or with an error).
    finished: bool,
}
378
impl<S> StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Creates fresh state with an empty partial message attributed to the
    /// given model / api / provider identifiers.
    fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
        Self {
            event_source,
            partial: AssistantMessage {
                content: Vec::new(),
                api,
                provider,
                model,
                usage: Usage::default(),
                stop_reason: StopReason::Stop,
                error_message: None,
                timestamp: chrono::Utc::now().timestamp_millis(),
            },
            pending_events: VecDeque::new(),
            started: false,
            finished: false,
        }
    }

    /// Parses one SSE `data` payload as a Gemini stream response and folds it
    /// into the accumulated state, queueing any resulting events.
    ///
    /// # Errors
    /// Returns an API error when the payload is not valid JSON for
    /// [`GeminiStreamResponse`].
    fn process_event(&mut self, data: &str) -> Result<()> {
        let response: GeminiStreamResponse = serde_json::from_str(data)
            .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;

        // Usage metadata may arrive on any chunk; later values overwrite
        // earlier ones.
        if let Some(metadata) = response.usage_metadata {
            self.partial.usage.input = metadata.prompt_token_count.unwrap_or(0);
            self.partial.usage.output = metadata.candidates_token_count.unwrap_or(0);
            self.partial.usage.total_tokens = metadata.total_token_count.unwrap_or(0);
        }

        // Only the first candidate is consumed (requests set candidate_count
        // to 1).
        if let Some(candidates) = response.candidates {
            if let Some(candidate) = candidates.into_iter().next() {
                self.process_candidate(candidate)?;
            }
        }

        Ok(())
    }

    /// Folds a single candidate chunk into the partial message: maps the
    /// finish reason, appends text / tool-call content, and queues the
    /// corresponding stream events.
    #[allow(clippy::unnecessary_wraps)]
    fn process_candidate(&mut self, candidate: GeminiCandidate) -> Result<()> {
        if let Some(ref reason) = candidate.finish_reason {
            self.partial.stop_reason = match reason.as_str() {
                "MAX_TOKENS" => StopReason::Length,
                "SAFETY" | "RECITATION" | "OTHER" => StopReason::Error,
                _ => StopReason::Stop,
            };
        }

        if let Some(content) = candidate.content {
            for part in content.parts {
                match part {
                    GeminiPart::Text { text } => {
                        // Consecutive text parts are merged into one text
                        // block; a new block (and TextStart) is only opened
                        // when the previous block is not text.
                        let last_is_text =
                            matches!(self.partial.content.last(), Some(ContentBlock::Text(_)));
                        if !last_is_text {
                            let content_index = self.partial.content.len();
                            self.partial
                                .content
                                .push(ContentBlock::Text(TextContent::new("")));

                            self.ensure_started();

                            self.pending_events
                                .push_back(StreamEvent::TextStart { content_index });
                        }
                        // The text block being appended to is always last.
                        let content_index = self.partial.content.len() - 1;

                        if let Some(ContentBlock::Text(t)) =
                            self.partial.content.get_mut(content_index)
                        {
                            t.text.push_str(&text);
                        }

                        self.ensure_started();

                        self.pending_events.push_back(StreamEvent::TextDelta {
                            content_index,
                            delta: text,
                        });
                    }
                    GeminiPart::FunctionCall { function_call } => {
                        // Gemini does not supply call ids; synthesize one.
                        let id = format!("call_{}", uuid::Uuid::new_v4().simple());

                        // Serialize args before destructuring consumes them.
                        let args_str = serde_json::to_string(&function_call.args)
                            .unwrap_or_else(|_| "{}".to_string());
                        let GeminiFunctionCall { name, args } = function_call;

                        let tool_call = ToolCall {
                            id,
                            name,
                            arguments: args,
                            thought_signature: None,
                        };

                        self.partial
                            .content
                            .push(ContentBlock::ToolCall(tool_call.clone()));
                        let content_index = self.partial.content.len() - 1;

                        // A tool call forces the ToolUse stop reason.
                        self.partial.stop_reason = StopReason::ToolUse;

                        self.ensure_started();

                        // Function calls arrive whole, so start/delta/end are
                        // emitted back-to-back.
                        self.pending_events
                            .push_back(StreamEvent::ToolCallStart { content_index });
                        self.pending_events.push_back(StreamEvent::ToolCallDelta {
                            content_index,
                            delta: args_str,
                        });
                        self.pending_events.push_back(StreamEvent::ToolCallEnd {
                            content_index,
                            tool_call,
                        });
                    }
                    // Unsupported / unrecognized part kinds are skipped.
                    GeminiPart::InlineData { .. }
                    | GeminiPart::FunctionResponse { .. }
                    | GeminiPart::Unknown(_) => {
                    }
                }
            }
        }

        // On the final chunk, close out every open text/thinking block with
        // its accumulated content.
        if candidate.finish_reason.is_some() {
            for (content_index, block) in self.partial.content.iter().enumerate() {
                if let ContentBlock::Text(t) = block {
                    self.pending_events.push_back(StreamEvent::TextEnd {
                        content_index,
                        content: t.text.clone(),
                    });
                } else if let ContentBlock::Thinking(t) = block {
                    self.pending_events.push_back(StreamEvent::ThinkingEnd {
                        content_index,
                        content: t.thinking.clone(),
                    });
                }
            }
        }

        Ok(())
    }

    /// Queues the one-time `Start` event (with a snapshot of the partial
    /// message) before the first content event.
    fn ensure_started(&mut self) {
        if !self.started {
            self.started = true;
            self.pending_events.push_back(StreamEvent::Start {
                partial: self.partial.clone(),
            });
        }
    }
}
542
543#[derive(Debug, Clone, PartialEq, Eq)]
549pub(crate) struct VertexProviderRuntime {
550 pub(crate) project: String,
551 pub(crate) location: String,
552 pub(crate) publisher: String,
553 pub(crate) model: String,
554}
555
/// Resolves project / location / publisher for a registry [`crate::models::ModelEntry`].
///
/// Precedence: values parsed from the entry's base URL, then environment
/// variables, then defaults (region `us-central1`, publisher `google`).
///
/// # Errors
/// Returns a provider error when no project id can be determined from either
/// the base URL or the environment.
pub(crate) fn resolve_vertex_provider_runtime(
    entry: &crate::models::ModelEntry,
) -> Result<VertexProviderRuntime> {
    let (url_project, url_location, url_publisher) = parse_vertex_base_url(&entry.model.base_url);

    let project = url_project
        .or_else(|| std::env::var(VERTEX_PROJECT_ENV).ok())
        .or_else(|| std::env::var(VERTEX_PROJECT_ENV_ALT).ok())
        .ok_or_else(|| {
            Error::provider(
                "google-vertex",
                format!(
                    "Missing GCP project. Set {VERTEX_PROJECT_ENV} or provide a Vertex AI base URL \
                     like https://REGION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/REGION/..."
                ),
            )
        })?;

    // Location has a safe default, so a missing value is not an error.
    let location = url_location
        .or_else(|| std::env::var(VERTEX_LOCATION_ENV).ok())
        .or_else(|| std::env::var(VERTEX_LOCATION_ENV_ALT).ok())
        .unwrap_or_else(|| VERTEX_DEFAULT_REGION.to_string());

    let publisher = url_publisher.unwrap_or_else(|| "google".to_string());

    Ok(VertexProviderRuntime {
        project,
        location,
        publisher,
        model: entry.model.id.clone(),
    })
}
595
/// Best-effort extraction of `(project, location, publisher)` from a Vertex AI
/// base URL such as
/// `https://REGION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/REGION/publishers/PUBLISHER/models/MODEL`.
///
/// Each component is `None` when it cannot be determined. Path segments
/// (`projects/…`, `locations/…`, `publishers/…`) take precedence; the region
/// embedded in the hostname is used only as a fallback for the location.
fn parse_vertex_base_url(base_url: &str) -> (Option<String>, Option<String>, Option<String>) {
    if base_url.is_empty() {
        return (None, None, None);
    }

    // Fallback: pull the region out of the hostname, e.g.
    // `us-central1-aiplatform.googleapis.com` -> `us-central1`.
    //
    // Bug fix: the previous version split the remainder on `-` and took the
    // first piece, yielding `us` instead of `us-central1`, and its character
    // filter rejected region names containing digits (all real regions), so
    // the fallback never produced a correct value.
    let location_from_host = base_url
        .strip_prefix("https://")
        .or_else(|| base_url.strip_prefix("http://"))
        .and_then(|rest| rest.split('/').next())
        .and_then(|host| host.strip_suffix("-aiplatform.googleapis.com"))
        .and_then(|loc| {
            let valid = !loc.is_empty()
                && loc
                    .chars()
                    .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '-');
            valid.then(|| loc.to_string())
        });

    let path_segments: Vec<&str> = base_url.split('/').collect();

    // Returns the path segment immediately following `key`, if present.
    let segment_after = |key: &str| {
        path_segments
            .iter()
            .zip(path_segments.iter().skip(1))
            .find(|(k, _)| **k == key)
            .map(|(_, val)| (*val).to_string())
    };

    let project = segment_after("projects");
    let location = segment_after("locations").or(location_from_host);
    let publisher = segment_after("publishers");

    (project, location, publisher)
}
643
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{Message, UserContent};
    use crate::provider::ToolDef;
    use asupersync::runtime::RuntimeBuilder;
    use futures::{StreamExt, stream};
    use serde_json::Value;

    // Identity accessors report the configured model and the fixed
    // "google-vertex" name/api.
    #[test]
    fn test_provider_info() {
        let provider = VertexProvider::new("gemini-2.0-flash");
        assert_eq!(provider.name(), "google-vertex");
        assert_eq!(provider.api(), "google-vertex");
        assert_eq!(provider.model_id(), "gemini-2.0-flash");
    }

    // Gemini models use the streamGenerateContent method.
    #[test]
    fn test_streaming_url_google_publisher() {
        let provider = VertexProvider::new("gemini-2.0-flash")
            .with_project("my-project")
            .with_location("us-central1");

        let url = provider.streaming_url("my-project", "us-central1");
        assert_eq!(
            url,
            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:streamGenerateContent"
        );
    }

    // Anthropic-published models switch to the streamRawPredict method.
    #[test]
    fn test_streaming_url_anthropic_publisher() {
        let provider = VertexProvider::new("claude-sonnet-4-20250514")
            .with_project("my-project")
            .with_location("europe-west1")
            .with_publisher("anthropic");

        let url = provider.streaming_url("my-project", "europe-west1");
        assert_eq!(
            url,
            "https://europe-west1-aiplatform.googleapis.com/v1/projects/my-project/locations/europe-west1/publishers/anthropic/models/claude-sonnet-4-20250514:streamRawPredict"
        );
    }

    // An endpoint override is returned verbatim, ignoring project/location.
    #[test]
    fn test_streaming_url_override() {
        let provider =
            VertexProvider::new("gemini-2.0-flash").with_endpoint_url("http://127.0.0.1:8080/mock");

        let url = provider.streaming_url("ignored", "ignored");
        assert_eq!(url, "http://127.0.0.1:8080/mock");
    }

    // System prompt, messages, and generation config map into the expected
    // Gemini JSON shape.
    #[test]
    fn test_build_gemini_request_basic() {
        let provider = VertexProvider::new("gemini-2.0-flash");
        let context = Context::owned(
            Some("You are helpful.".to_string()),
            vec![Message::User(crate::model::UserMessage {
                content: UserContent::Text("What is Vertex AI?".to_string()),
                timestamp: 0,
            })],
            vec![],
        );
        let options = StreamOptions {
            max_tokens: Some(1024),
            temperature: Some(0.7),
            ..Default::default()
        };

        let req = provider.build_gemini_request(&context, &options);
        let json = serde_json::to_value(&req).expect("serialize");

        let contents = json["contents"].as_array().expect("contents");
        assert_eq!(contents.len(), 1);
        assert_eq!(contents[0]["role"], "user");
        assert_eq!(contents[0]["parts"][0]["text"], "What is Vertex AI?");

        assert_eq!(
            json["systemInstruction"]["parts"][0]["text"],
            "You are helpful."
        );
        assert_eq!(json["generationConfig"]["maxOutputTokens"], 1024);
    }

    // Tool definitions become functionDeclarations and enable AUTO calling.
    #[test]
    fn test_build_gemini_request_with_tools() {
        let provider = VertexProvider::new("gemini-2.0-flash");
        let context = Context::owned(
            None,
            vec![Message::User(crate::model::UserMessage {
                content: UserContent::Text("Read a file".to_string()),
                timestamp: 0,
            })],
            vec![ToolDef {
                name: "read".to_string(),
                description: "Read a file".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": { "path": {"type": "string"} },
                    "required": ["path"]
                }),
            }],
        );
        let options = StreamOptions::default();

        let req = provider.build_gemini_request(&context, &options);
        let json = serde_json::to_value(&req).expect("serialize");

        let tools = json["tools"].as_array().expect("tools");
        assert_eq!(tools.len(), 1);
        let decls = tools[0]["functionDeclarations"]
            .as_array()
            .expect("declarations");
        assert_eq!(decls[0]["name"], "read");
        assert_eq!(json["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");
    }

    // A fully-specified base URL yields all three components.
    #[test]
    fn test_parse_vertex_base_url_full() {
        let url = "https://us-central1-aiplatform.googleapis.com/v1/projects/my-proj/locations/us-central1/publishers/google/models/gemini-2.0-flash";
        let (project, location, publisher) = parse_vertex_base_url(url);
        assert_eq!(project.as_deref(), Some("my-proj"));
        assert_eq!(location.as_deref(), Some("us-central1"));
        assert_eq!(publisher.as_deref(), Some("google"));
    }

    // Non-google publishers in the URL are parsed as-is.
    #[test]
    fn test_parse_vertex_base_url_anthropic() {
        let url = "https://europe-west1-aiplatform.googleapis.com/v1/projects/corp-ai/locations/europe-west1/publishers/anthropic/models/claude-sonnet-4-20250514";
        let (project, location, publisher) = parse_vertex_base_url(url);
        assert_eq!(project.as_deref(), Some("corp-ai"));
        assert_eq!(location.as_deref(), Some("europe-west1"));
        assert_eq!(publisher.as_deref(), Some("anthropic"));
    }

    // An empty URL resolves nothing.
    #[test]
    fn test_parse_vertex_base_url_empty() {
        let (project, location, publisher) = parse_vertex_base_url("");
        assert!(project.is_none());
        assert!(location.is_none());
        assert!(publisher.is_none());
    }

    // Missing trailing segments leave only the absent components as None.
    #[test]
    fn test_parse_vertex_base_url_partial() {
        let url = "https://us-central1-aiplatform.googleapis.com/v1/projects/my-proj/locations/us-central1";
        let (project, location, publisher) = parse_vertex_base_url(url);
        assert_eq!(project.as_deref(), Some("my-proj"));
        assert_eq!(location.as_deref(), Some("us-central1"));
        assert!(publisher.is_none());
    }

    // A model entry with a full base URL resolves without env vars.
    #[test]
    fn test_resolve_vertex_provider_runtime_from_url() {
        let entry = crate::models::ModelEntry {
            model: crate::provider::Model {
                id: "gemini-2.0-flash".to_string(),
                name: "Gemini 2.0 Flash".to_string(),
                api: "google-vertex".to_string(),
                provider: "google-vertex".to_string(),
                base_url: "https://us-central1-aiplatform.googleapis.com/v1/projects/test-proj/locations/us-central1/publishers/google/models/gemini-2.0-flash".to_string(),
                reasoning: false,
                input: vec![],
                cost: crate::provider::ModelCost {
                    input: 0.0,
                    output: 0.0,
                    cache_read: 0.0,
                    cache_write: 0.0,
                },
                context_window: 128_000,
                max_tokens: 8192,
                headers: std::collections::HashMap::new(),
            },
            api_key: None,
            headers: std::collections::HashMap::new(),
            auth_header: true,
            compat: None,
            oauth_config: None,
        };

        let runtime = resolve_vertex_provider_runtime(&entry).expect("resolve");
        assert_eq!(runtime.project, "test-proj");
        assert_eq!(runtime.location, "us-central1");
        assert_eq!(runtime.publisher, "google");
        assert_eq!(runtime.model, "gemini-2.0-flash");
    }

    // Text chunks become Start + TextDelta events and the final Done carries
    // accumulated usage.
    #[test]
    fn test_stream_text_response() {
        let events = vec![
            serde_json::json!({
                "candidates": [{
                    "content": {
                        "role": "model",
                        "parts": [{"text": "Hello from "}]
                    }
                }]
            }),
            serde_json::json!({
                "candidates": [{
                    "content": {
                        "role": "model",
                        "parts": [{"text": "Vertex AI!"}]
                    },
                    "finishReason": "STOP"
                }],
                "usageMetadata": {
                    "promptTokenCount": 10,
                    "candidatesTokenCount": 5,
                    "totalTokenCount": 15
                }
            }),
        ];

        let stream_events = collect_events(&events);

        assert!(
            stream_events
                .iter()
                .any(|e| matches!(e, StreamEvent::Start { .. })),
            "should emit Start"
        );

        let text_deltas: Vec<&str> = stream_events
            .iter()
            .filter_map(|e| match e {
                StreamEvent::TextDelta { delta, .. } => Some(delta.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(text_deltas, vec!["Hello from ", "Vertex AI!"]);

        let done = stream_events
            .iter()
            .find_map(|e| match e {
                StreamEvent::Done { message, .. } => Some(message),
                _ => None,
            })
            .expect("done event");
        assert_eq!(done.usage.input, 10);
        assert_eq!(done.usage.output, 5);
    }

    // A functionCall part produces ToolCallStart/End and a ToolUse stop reason,
    // even though the upstream finishReason is STOP.
    #[test]
    fn test_stream_tool_call_response() {
        let events = vec![serde_json::json!({
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{
                        "functionCall": {
                            "name": "read",
                            "args": {"path": "/tmp/test.txt"}
                        }
                    }]
                },
                "finishReason": "STOP"
            }]
        })];

        let stream_events = collect_events(&events);

        assert!(
            stream_events
                .iter()
                .any(|e| matches!(e, StreamEvent::ToolCallStart { .. })),
            "should emit ToolCallStart"
        );
        assert!(
            stream_events
                .iter()
                .any(|e| matches!(e, StreamEvent::ToolCallEnd { .. })),
            "should emit ToolCallEnd"
        );

        let done = stream_events
            .iter()
            .find_map(|e| match e {
                StreamEvent::Done { message, .. } => Some(message),
                _ => None,
            })
            .expect("done event");
        assert_eq!(done.stop_reason, StopReason::ToolUse);
    }

    // Unrecognized part kinds are skipped without aborting the stream.
    #[test]
    fn test_stream_ignores_unknown_parts() {
        let events = vec![serde_json::json!({
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [
                        {
                            "executableCode": {
                                "language": "python",
                                "code": "print('x')"
                            }
                        },
                        {"text": "still works"}
                    ]
                },
                "finishReason": "STOP"
            }]
        })];

        let stream_events = collect_events(&events);

        let text_deltas: Vec<&str> = stream_events
            .iter()
            .filter_map(|e| match e {
                StreamEvent::TextDelta { delta, .. } => Some(delta.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(text_deltas, vec!["still works"]);
        assert!(
            stream_events
                .iter()
                .any(|e| matches!(e, StreamEvent::Done { .. })),
            "should emit Done even when unknown parts are present"
        );
    }

    // Test helper: serializes the JSON payloads as SSE `data:` frames, drives
    // a StreamState over them on a local runtime, and returns every emitted
    // StreamEvent (including a synthesized Done at end-of-stream, mirroring
    // the production unfold loop).
    fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        runtime.block_on(async move {
            let byte_stream = stream::iter(
                events
                    .iter()
                    .map(|event| {
                        let data = serde_json::to_string(event).expect("serialize event");
                        format!("data: {data}\n\n").into_bytes()
                    })
                    .map(Ok),
            );
            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
            let mut state = StreamState::new(
                event_source,
                "gemini-test".to_string(),
                "google-vertex".to_string(),
                "google-vertex".to_string(),
            );
            let mut out = Vec::new();

            loop {
                let Some(item) = state.event_source.next().await else {
                    // End of stream: emit the final Done exactly once.
                    if !state.finished {
                        state.finished = true;
                        out.push(StreamEvent::Done {
                            reason: state.partial.stop_reason,
                            message: std::mem::take(&mut state.partial),
                        });
                    }
                    break;
                };

                let msg = item.expect("SSE event");
                if msg.event == "ping" {
                    continue;
                }
                state.process_event(&msg.data).expect("process_event");
                out.extend(state.pending_events.drain(..));
            }

            out
        })
    }
}
1024
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    //! Fuzzing entry points: expose the SSE payload processor with an empty
    //! byte stream so fuzz targets can feed raw payloads directly.

    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// Stream type used only to satisfy `StreamState`'s generic parameter;
    /// it never yields any bytes.
    type FuzzStream =
        Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;

    /// Thin wrapper around [`StreamState`] for fuzz targets.
    pub struct Processor(StreamState<FuzzStream>);

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Processor {
        /// Creates a processor backed by an empty event source; payloads are
        /// injected via [`Processor::process_event`] instead of the stream.
        pub fn new() -> Self {
            let empty = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
            Self(StreamState::new(
                crate::sse::SseStream::new(Box::pin(empty)),
                "vertex-fuzz".into(),
                "vertex-ai".into(),
                "vertex".into(),
            ))
        }

        /// Processes one raw SSE `data` payload and returns the events it
        /// produced.
        ///
        /// # Errors
        /// Propagates JSON parse failures from the underlying processor.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            self.0.process_event(data)?;
            Ok(self.0.pending_events.drain(..).collect())
        }
    }
}