1use crate::error::{Error, Result};
7use crate::http::client::Client;
8use crate::model::{
9 AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ToolCall, Usage,
10 UserContent,
11};
12use crate::models::CompatConfig;
13use crate::provider::{Context, Provider, StreamOptions, ToolDef};
14use crate::sse::SseStream;
15use async_trait::async_trait;
16use futures::StreamExt;
17use futures::stream::{self, Stream};
18use serde::{Deserialize, Serialize};
19use std::collections::VecDeque;
20use std::pin::Pin;
21
/// Base URL of the public Gemini API; the default for new providers and the
/// fallback for a blank base URL outside Google CLI mode.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
/// Fallback base URL in Google CLI mode for any provider other than
/// `google-antigravity`.
const GOOGLE_GEMINI_CLI_BASE: &str = "https://cloudcode-pa.googleapis.com";
/// Fallback base URL in Google CLI mode when the provider is `google-antigravity`.
const GOOGLE_ANTIGRAVITY_BASE: &str = "https://daily-cloudcode-pa.sandbox.googleapis.com";
/// Output-token limit sent when the caller's options do not specify `max_tokens`.
pub(crate) const DEFAULT_MAX_TOKENS: u32 = 8192;
30
/// Streaming provider for Google Gemini models.
///
/// Depending on `google_cli_mode` it targets either the per-model
/// `streamGenerateContent` endpoint or the `v1internal:streamGenerateContent`
/// endpoint.
pub struct GeminiProvider {
    /// HTTP client used to send requests.
    client: Client,
    /// Model identifier; part of the URL in standard mode and of the request
    /// envelope in Google CLI mode.
    model: String,
    /// Base URL override; a blank value selects a built-in default.
    base_url: String,
    /// Name returned by `Provider::name`. The value `google-antigravity`
    /// (case-insensitive) switches CLI-mode defaults and headers.
    provider: String,
    /// Name returned by `Provider::api`.
    api: String,
    /// When `true`, requests use the Google CLI flow (bearer token plus project ID).
    google_cli_mode: bool,
    /// Optional compatibility settings; its `custom_headers` are added to
    /// outgoing requests.
    compat: Option<CompatConfig>,
}
45
46impl GeminiProvider {
47 pub fn new(model: impl Into<String>) -> Self {
49 Self {
50 client: Client::new(),
51 model: model.into(),
52 base_url: GEMINI_API_BASE.to_string(),
53 provider: "google".to_string(),
54 api: "google-generative-ai".to_string(),
55 google_cli_mode: false,
56 compat: None,
57 }
58 }
59
60 #[must_use]
62 pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
63 self.provider = provider.into();
64 self
65 }
66
67 #[must_use]
69 pub fn with_api_name(mut self, api: impl Into<String>) -> Self {
70 self.api = api.into();
71 self
72 }
73
74 #[must_use]
76 pub const fn with_google_cli_mode(mut self, enabled: bool) -> Self {
77 self.google_cli_mode = enabled;
78 self
79 }
80
81 #[must_use]
83 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
84 self.base_url = base_url.into();
85 self
86 }
87
88 #[must_use]
90 pub fn with_client(mut self, client: Client) -> Self {
91 self.client = client;
92 self
93 }
94
95 #[must_use]
97 pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
98 self.compat = compat;
99 self
100 }
101
102 pub fn streaming_url(&self) -> String {
104 let base = {
105 let trimmed = self.base_url.trim();
106 if trimmed.is_empty() {
107 if self.google_cli_mode {
108 if self.provider.eq_ignore_ascii_case("google-antigravity") {
109 GOOGLE_ANTIGRAVITY_BASE
110 } else {
111 GOOGLE_GEMINI_CLI_BASE
112 }
113 } else {
114 GEMINI_API_BASE
115 }
116 } else {
117 trimmed
118 }
119 };
120 if self.google_cli_mode {
121 format!("{base}/v1internal:streamGenerateContent?alt=sse")
122 } else {
123 format!("{base}/models/{}:streamGenerateContent?alt=sse", self.model)
124 }
125 }
126
127 #[allow(clippy::unused_self)]
129 pub fn build_request(&self, context: &Context<'_>, options: &StreamOptions) -> GeminiRequest {
130 let contents = Self::build_contents(context);
131 let system_instruction = context.system_prompt.as_deref().map(|s| GeminiContent {
132 role: None,
133 parts: vec![GeminiPart::Text {
134 text: s.to_string(),
135 }],
136 });
137
138 let tools: Option<Vec<GeminiTool>> = if context.tools.is_empty() {
139 None
140 } else {
141 Some(vec![GeminiTool {
142 function_declarations: context.tools.iter().map(convert_tool_to_gemini).collect(),
143 }])
144 };
145
146 let tool_config = if tools.is_some() {
147 Some(GeminiToolConfig {
148 function_calling_config: GeminiFunctionCallingConfig { mode: "AUTO" },
149 })
150 } else {
151 None
152 };
153
154 GeminiRequest {
155 contents,
156 system_instruction,
157 tools,
158 tool_config,
159 generation_config: Some(GeminiGenerationConfig {
160 max_output_tokens: options.max_tokens.or(Some(DEFAULT_MAX_TOKENS)),
161 temperature: options.temperature,
162 candidate_count: Some(1),
163 }),
164 }
165 }
166
167 fn build_contents(context: &Context<'_>) -> Vec<GeminiContent> {
169 let mut contents = Vec::with_capacity(context.messages.len());
170
171 for message in context.messages.iter() {
172 contents.extend(convert_message_to_gemini(message));
173 }
174
175 contents
176 }
177}
178
/// Envelope that wraps a [`GeminiRequest`] for the Google CLI endpoint.
///
/// Serialized with camelCase field names.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct CloudCodeAssistRequest {
    /// Project resource name (`projects/...`).
    project: String,
    /// Model identifier.
    model: String,
    /// The wrapped generate-content request.
    request: GeminiRequest,
    /// Request type; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    request_type: Option<String>,
    /// Client identifier placed in the body (separate from the HTTP `User-Agent` header).
    user_agent: String,
    /// Per-request identifier.
    request_id: String,
}
190
191fn build_google_cli_request(
192 model_id: &str,
193 project_id: &str,
194 request: GeminiRequest,
195 is_antigravity: bool,
196) -> std::result::Result<CloudCodeAssistRequest, &'static str> {
197 let safe_project = project_id.trim();
198 if safe_project.is_empty() {
199 return Err(
200 "Missing Google Cloud project ID for Gemini CLI. Set GOOGLE_CLOUD_PROJECT (or configure gcloud) and re-authenticate with /login google-gemini-cli.",
201 );
202 }
203 let project = if safe_project.starts_with("projects/") {
204 safe_project.to_string()
205 } else {
206 format!("projects/{safe_project}/locations/global")
207 };
208 Ok(CloudCodeAssistRequest {
209 project,
210 model: model_id.to_string(),
211 request,
212 request_type: is_antigravity.then(|| "agent".to_string()),
213 user_agent: if is_antigravity {
214 "antigravity".to_string()
215 } else {
216 "pi-coding-agent".to_string()
217 },
218 request_id: format!(
219 "{}-{}",
220 if is_antigravity { "agent" } else { "pi" },
221 uuid::Uuid::new_v4().simple()
222 ),
223 })
224}
225
226fn decode_project_scoped_access_payload(payload: &str) -> Option<(String, String)> {
227 let value: serde_json::Value = serde_json::from_str(payload).ok()?;
228 let token = value
229 .get("token")
230 .and_then(serde_json::Value::as_str)
231 .map(str::trim)
232 .filter(|value| !value.is_empty())?
233 .to_string();
234 let project_id = value
235 .get("projectId")
236 .or_else(|| value.get("project_id"))
237 .and_then(serde_json::Value::as_str)
238 .map(str::trim)
239 .filter(|value| !value.is_empty())?
240 .to_string();
241 Some((token, project_id))
242}
243
244#[async_trait]
245impl Provider for GeminiProvider {
246 fn name(&self) -> &str {
247 &self.provider
248 }
249
250 fn api(&self) -> &str {
251 &self.api
252 }
253
254 fn model_id(&self) -> &str {
255 &self.model
256 }
257
258 #[allow(clippy::too_many_lines)]
259 async fn stream(
260 &self,
261 context: &Context<'_>,
262 options: &StreamOptions,
263 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
264 let request_body = self.build_request(context, options);
265 let url = self.streaming_url();
266
267 let mut request = self.client.post(&url).header("Accept", "text/event-stream");
269
270 if self.google_cli_mode {
271 let api_payload = options.api_key.clone().ok_or_else(|| {
272 Error::provider(
273 self.name(),
274 "Google Gemini CLI requires OAuth credentials. Run /login google-gemini-cli.",
275 )
276 })?;
277 let (access_token, project_id) = decode_project_scoped_access_payload(&api_payload)
278 .ok_or_else(|| {
279 Error::provider(
280 self.name(),
281 "Invalid Google Gemini CLI OAuth payload (expected JSON {token, projectId}). Run /login google-gemini-cli again.",
282 )
283 })?;
284 let is_antigravity = self.provider.eq_ignore_ascii_case("google-antigravity");
285
286 request = request
287 .header("Authorization", format!("Bearer {access_token}"))
288 .header("Content-Type", "application/json")
289 .header("x-goog-api-client", "gl-node/22.17.0")
290 .header(
291 "client-metadata",
292 r#"{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}"#,
293 );
294
295 if is_antigravity {
296 request = request.header("User-Agent", "antigravity/1.15.8 darwin/arm64");
297 } else {
298 request =
299 request.header("User-Agent", "google-cloud-sdk vscode_cloudshelleditor/0.1");
300 }
301
302 if let Some(compat) = &self.compat {
304 if let Some(custom_headers) = &compat.custom_headers {
305 for (key, value) in custom_headers {
306 request = request.header(key, value);
307 }
308 }
309 }
310
311 for (key, value) in &options.headers {
313 request = request.header(key, value);
314 }
315
316 let cli_request =
317 build_google_cli_request(&self.model, &project_id, request_body, is_antigravity)
318 .map_err(|message| Error::provider(self.name(), message.to_string()))?;
319 let request = request.json(&cli_request)?;
320 let response = Box::pin(request.send()).await?;
321 let status = response.status();
322 if !(200..300).contains(&status) {
323 let body = response
324 .text()
325 .await
326 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
327 return Err(Error::provider(
328 self.name(),
329 format!("Gemini CLI API error (HTTP {status}): {body}"),
330 ));
331 }
332
333 let event_source = SseStream::new(response.bytes_stream());
335 let model = self.model.clone();
336 let api = self.api().to_string();
337 let provider = self.name().to_string();
338 let cloud_cli_mode = self.google_cli_mode;
339
340 let stream = stream::unfold(
341 StreamState::new(event_source, model, api, provider),
342 move |mut state| async move {
343 if state.finished {
344 return None;
345 }
346 loop {
347 if let Some(event) = state.pending_events.pop_front() {
349 return Some((Ok(event), state));
350 }
351
352 match state.event_source.next().await {
353 Some(Ok(msg)) => {
354 if msg.event == "ping" {
355 continue;
356 }
357
358 let processing = if cloud_cli_mode {
359 state.process_cloud_code_event(&msg.data)
360 } else {
361 state.process_event(&msg.data)
362 };
363 if let Err(e) = processing {
364 state.finished = true;
365 return Some((Err(e), state));
366 }
367 }
368 Some(Err(e)) => {
369 state.finished = true;
370 let err = Error::api(format!("SSE error: {e}"));
371 return Some((Err(err), state));
372 }
373 None => {
374 state.finished = true;
376 let reason = state.partial.stop_reason;
377 let message = std::mem::take(&mut state.partial);
378 return Some((Ok(StreamEvent::Done { reason, message }), state));
379 }
380 }
381 }
382 },
383 );
384
385 return Ok(Box::pin(stream));
386 }
387
388 let auth_value = options
389 .api_key
390 .clone()
391 .or_else(|| std::env::var("GOOGLE_API_KEY").ok())
392 .or_else(|| std::env::var("GEMINI_API_KEY").ok())
393 .ok_or_else(|| {
394 Error::provider(
395 self.name(),
396 "Missing API key for Google/Gemini. Set GOOGLE_API_KEY or GEMINI_API_KEY.",
397 )
398 })?;
399
400 request = request.header("x-goog-api-key", &auth_value);
401
402 if let Some(compat) = &self.compat {
404 if let Some(custom_headers) = &compat.custom_headers {
405 for (key, value) in custom_headers {
406 request = request.header(key, value);
407 }
408 }
409 }
410
411 for (key, value) in &options.headers {
413 request = request.header(key, value);
414 }
415
416 let request = request.json(&request_body)?;
417
418 let response = Box::pin(request.send()).await?;
419 let status = response.status();
420 if !(200..300).contains(&status) {
421 let body = response
422 .text()
423 .await
424 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
425 return Err(Error::provider(
426 self.name(),
427 format!("Gemini API error (HTTP {status}): {body}"),
428 ));
429 }
430
431 let event_source = SseStream::new(response.bytes_stream());
433
434 let model = self.model.clone();
436 let api = self.api().to_string();
437 let provider = self.name().to_string();
438 let cloud_cli_mode = self.google_cli_mode;
439
440 let stream = stream::unfold(
441 StreamState::new(event_source, model, api, provider),
442 move |mut state| async move {
443 if state.finished {
444 return None;
445 }
446 loop {
447 if let Some(event) = state.pending_events.pop_front() {
449 return Some((Ok(event), state));
450 }
451
452 match state.event_source.next().await {
453 Some(Ok(msg)) => {
454 if msg.event == "ping" {
455 continue;
456 }
457
458 let processing = if cloud_cli_mode {
459 state.process_cloud_code_event(&msg.data)
460 } else {
461 state.process_event(&msg.data)
462 };
463 if let Err(e) = processing {
464 state.finished = true;
465 return Some((Err(e), state));
466 }
467 }
468 Some(Err(e)) => {
469 state.finished = true;
470 let err = Error::api(format!("SSE error: {e}"));
471 return Some((Err(err), state));
472 }
473 None => {
474 state.finished = true;
476 let reason = state.partial.stop_reason;
477 let message = std::mem::take(&mut state.partial);
478 return Some((Ok(StreamEvent::Done { reason, message }), state));
479 }
480 }
481 }
482 },
483 );
484
485 Ok(Box::pin(stream))
486 }
487}
488
/// Mutable state carried through the SSE-to-`StreamEvent` conversion.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Parsed SSE events from the HTTP response body.
    event_source: SseStream<S>,
    /// Assistant message accumulated so far.
    partial: AssistantMessage,
    /// Events produced by processing but not yet yielded to the consumer.
    pending_events: VecDeque<StreamEvent>,
    /// Whether the `Start` event has already been queued.
    started: bool,
    /// Set once the stream has yielded its terminal item.
    finished: bool,
}
503
504impl<S> StreamState<S>
505where
506 S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
507{
508 fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
509 Self {
510 event_source,
511 partial: AssistantMessage {
512 content: Vec::new(),
513 api,
514 provider,
515 model,
516 usage: Usage::default(),
517 stop_reason: StopReason::Stop,
518 error_message: None,
519 timestamp: chrono::Utc::now().timestamp_millis(),
520 },
521 pending_events: VecDeque::new(),
522 started: false,
523 finished: false,
524 }
525 }
526
527 fn process_event(&mut self, data: &str) -> Result<()> {
528 let response: GeminiStreamResponse = serde_json::from_str(data)
529 .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;
530 self.process_response(response)
531 }
532
533 fn process_response(&mut self, response: GeminiStreamResponse) -> Result<()> {
534 if let Some(metadata) = response.usage_metadata {
536 self.partial.usage.input = metadata.prompt_token_count.unwrap_or(0);
537 self.partial.usage.output = metadata.candidates_token_count.unwrap_or(0);
538 self.partial.usage.total_tokens = metadata.total_token_count.unwrap_or(0);
539 }
540
541 if let Some(candidates) = response.candidates {
543 if let Some(candidate) = candidates.into_iter().next() {
544 self.process_candidate(candidate)?;
545 }
546 }
547
548 Ok(())
549 }
550
551 fn process_cloud_code_event(&mut self, data: &str) -> Result<()> {
552 let wrapped: CloudCodeAssistResponseChunk = serde_json::from_str(data)
553 .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;
554 let Some(response) = wrapped.response else {
555 return Ok(());
556 };
557 self.process_response(GeminiStreamResponse {
558 candidates: response.candidates,
559 usage_metadata: response.usage_metadata,
560 })
561 }
562
563 #[allow(clippy::unnecessary_wraps)]
564 fn process_candidate(&mut self, candidate: GeminiCandidate) -> Result<()> {
565 let has_finish_reason = candidate.finish_reason.is_some();
566
567 if let Some(reason) = candidate.finish_reason.as_deref() {
569 self.partial.stop_reason = match reason {
570 "MAX_TOKENS" => StopReason::Length,
571 "SAFETY" | "RECITATION" | "OTHER" => StopReason::Error,
572 "FUNCTION_CALL" => StopReason::ToolUse,
573 _ => StopReason::Stop,
575 };
576 }
577
578 if let Some(content) = candidate.content {
580 for part in content.parts {
581 match part {
582 GeminiPart::Text { text } => {
583 let last_is_text =
585 matches!(self.partial.content.last(), Some(ContentBlock::Text(_)));
586
587 self.ensure_started();
591
592 let content_index = if last_is_text {
593 self.partial.content.len() - 1
594 } else {
595 let idx = self.partial.content.len();
596 self.partial
597 .content
598 .push(ContentBlock::Text(TextContent::new("")));
599 self.pending_events
600 .push_back(StreamEvent::TextStart { content_index: idx });
601 idx
602 };
603
604 if let Some(ContentBlock::Text(t)) =
605 self.partial.content.get_mut(content_index)
606 {
607 t.text.push_str(&text);
608 }
609
610 self.pending_events.push_back(StreamEvent::TextDelta {
611 content_index,
612 delta: text,
613 });
614 }
615 GeminiPart::FunctionCall { function_call } => {
616 let id = format!("call_{}", uuid::Uuid::new_v4().simple());
618
619 let args_str = serde_json::to_string(&function_call.args)
621 .unwrap_or_else(|_| "{}".to_string());
622 let GeminiFunctionCall { name, args } = function_call;
623
624 let tool_call = ToolCall {
625 id,
626 name,
627 arguments: args,
628 thought_signature: None,
629 };
630
631 self.partial
632 .content
633 .push(ContentBlock::ToolCall(tool_call.clone()));
634 let content_index = self.partial.content.len() - 1;
635
636 self.partial.stop_reason = StopReason::ToolUse;
638
639 self.ensure_started();
640
641 self.pending_events
643 .push_back(StreamEvent::ToolCallStart { content_index });
644 self.pending_events.push_back(StreamEvent::ToolCallDelta {
645 content_index,
646 delta: args_str,
647 });
648 self.pending_events.push_back(StreamEvent::ToolCallEnd {
649 content_index,
650 tool_call,
651 });
652 }
653 GeminiPart::InlineData { .. }
654 | GeminiPart::FunctionResponse { .. }
655 | GeminiPart::Unknown(_) => {
656 }
660 }
661 }
662 }
663
664 if has_finish_reason {
667 for (content_index, block) in self.partial.content.iter().enumerate() {
668 if let ContentBlock::Text(t) = block {
669 self.pending_events.push_back(StreamEvent::TextEnd {
670 content_index,
671 content: t.text.clone(),
672 });
673 } else if let ContentBlock::Thinking(t) = block {
674 self.pending_events.push_back(StreamEvent::ThinkingEnd {
675 content_index,
676 content: t.thinking.clone(),
677 });
678 }
679 }
680 }
681
682 Ok(())
683 }
684
685 fn ensure_started(&mut self) {
686 if !self.started {
687 self.started = true;
688 self.pending_events.push_back(StreamEvent::Start {
689 partial: self.partial.clone(),
690 });
691 }
692 }
693}
694
/// Body of a Gemini generate-content request.
///
/// Serialized with camelCase field names; optional fields are omitted from
/// the JSON when `None`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiRequest {
    /// Conversation turns, in order.
    pub(crate) contents: Vec<GeminiContent>,
    /// System prompt, carried as role-less content.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) system_instruction: Option<GeminiContent>,
    /// Tool declarations available to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) tools: Option<Vec<GeminiTool>>,
    /// Function-calling configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) tool_config: Option<GeminiToolConfig>,
    /// Generation parameters (token limit, temperature, candidate count).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) generation_config: Option<GeminiGenerationConfig>,
}
712
/// One content entry: an optional role plus an ordered list of parts.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiContent {
    /// Role of the author; this file writes `user` or `model`, and `None`
    /// for the system instruction. Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) role: Option<String>,
    /// Parts that make up this entry.
    pub(crate) parts: Vec<GeminiPart>,
}
720
/// A single part of a [`GeminiContent`].
///
/// The enum is untagged, so the variant is recognised by its field name when
/// deserializing; anything that matches no other variant lands in `Unknown`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub(crate) enum GeminiPart {
    /// Plain text.
    Text {
        text: String,
    },
    /// Inline binary data such as an image.
    // NOTE(review): unlike the function-call variants, this field has no
    // camelCase rename, so it is written and read as `inline_data` — confirm
    // this is intended.
    InlineData {
        inline_data: GeminiBlob,
    },
    /// A function call requested by the model.
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiFunctionCall,
    },
    /// The result of a function call, sent back to the model.
    FunctionResponse {
        #[serde(rename = "functionResponse")]
        function_response: GeminiFunctionResponse,
    },
    /// Catch-all for part shapes not modelled above; the raw JSON is kept.
    Unknown(serde_json::Value),
}
743
/// Inline binary payload with its MIME type (camelCase on the wire).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiBlob {
    /// MIME type of `data`.
    pub(crate) mime_type: String,
    /// The payload as a string (presumably base64-encoded — TODO confirm against callers).
    pub(crate) data: String,
}
750
/// A function call: the function name and its JSON arguments.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct GeminiFunctionCall {
    /// Name of the function to call.
    pub(crate) name: String,
    /// Arguments as arbitrary JSON.
    pub(crate) args: serde_json::Value,
}
756
/// The result of a function call, reported back to the model.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct GeminiFunctionResponse {
    /// Name of the function that produced the result.
    pub(crate) name: String,
    /// Result payload; this file sends `{"result": ...}` or `{"error": ...}`.
    pub(crate) response: serde_json::Value,
}
762
/// A tool entry holding a group of function declarations (camelCase on the wire).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiTool {
    /// Functions the model may call.
    pub(crate) function_declarations: Vec<GeminiFunctionDeclaration>,
}
768
/// Declaration of one callable function.
#[derive(Debug, Serialize)]
pub(crate) struct GeminiFunctionDeclaration {
    /// Function name.
    pub(crate) name: String,
    /// Human-readable description of the function.
    pub(crate) description: String,
    /// Parameter schema, passed through unchanged from the tool definition.
    pub(crate) parameters: serde_json::Value,
}
775
/// Tool configuration wrapper (camelCase on the wire).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiToolConfig {
    /// Function-calling settings.
    pub(crate) function_calling_config: GeminiFunctionCallingConfig,
}
781
/// Function-calling settings.
#[derive(Debug, Serialize)]
pub(crate) struct GeminiFunctionCallingConfig {
    /// Function-calling mode; this file always sends `AUTO`.
    pub(crate) mode: &'static str,
}
786
/// Generation parameters; every field is omitted from the JSON when `None`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiGenerationConfig {
    /// Upper bound on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) max_output_tokens: Option<u32>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) temperature: Option<f32>,
    /// Number of candidates requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) candidate_count: Option<u32>,
}
797
/// One streamed response chunk; both fields default to `None` when absent.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiStreamResponse {
    /// Candidate outputs carried by this chunk.
    #[serde(default)]
    pub(crate) candidates: Option<Vec<GeminiCandidate>>,
    /// Token usage reported with this chunk.
    #[serde(default)]
    pub(crate) usage_metadata: Option<GeminiUsageMetadata>,
}
810
/// One streamed chunk in the Google CLI envelope format.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CloudCodeAssistResponseChunk {
    /// The wrapped response; `None` when the chunk has no `response` field.
    #[serde(default)]
    response: Option<CloudCodeAssistResponse>,
}
817
/// Response carried inside a [`CloudCodeAssistResponseChunk`]; it has the
/// same fields as [`GeminiStreamResponse`].
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CloudCodeAssistResponse {
    /// Candidate outputs carried by this chunk.
    #[serde(default)]
    candidates: Option<Vec<GeminiCandidate>>,
    /// Token usage reported with this chunk.
    #[serde(default)]
    usage_metadata: Option<GeminiUsageMetadata>,
}
826
/// One candidate within a response chunk.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiCandidate {
    /// Content produced in this chunk, if any.
    #[serde(default)]
    pub(crate) content: Option<GeminiContent>,
    /// Finish reason string (for example `MAX_TOKENS`), if the chunk carries one.
    #[serde(default)]
    pub(crate) finish_reason: Option<String>,
}
835
/// Token usage counters; each is `None` when missing from the JSON.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
#[allow(clippy::struct_field_names)]
pub(crate) struct GeminiUsageMetadata {
    /// Tokens in the prompt; stored as input usage.
    #[serde(default)]
    pub(crate) prompt_token_count: Option<u64>,
    /// Tokens in the candidates; stored as output usage.
    #[serde(default)]
    pub(crate) candidates_token_count: Option<u64>,
    /// Total token count as reported.
    #[serde(default)]
    pub(crate) total_token_count: Option<u64>,
}
847
848pub(crate) fn convert_message_to_gemini(message: &Message) -> Vec<GeminiContent> {
853 match message {
854 Message::User(user) => vec![GeminiContent {
855 role: Some("user".into()),
856 parts: convert_user_content_to_parts(&user.content),
857 }],
858 Message::Custom(custom) => vec![GeminiContent {
859 role: Some("user".into()),
860 parts: vec![GeminiPart::Text {
861 text: custom.content.clone(),
862 }],
863 }],
864 Message::Assistant(assistant) => {
865 let mut parts = Vec::new();
866
867 for block in &assistant.content {
868 match block {
869 ContentBlock::Text(t) => {
870 parts.push(GeminiPart::Text {
871 text: t.text.clone(),
872 });
873 }
874 ContentBlock::ToolCall(tc) => {
875 parts.push(GeminiPart::FunctionCall {
876 function_call: GeminiFunctionCall {
877 name: tc.name.clone(),
878 args: tc.arguments.clone(),
879 },
880 });
881 }
882 ContentBlock::Thinking(_) | ContentBlock::Image(_) => {
883 }
885 }
886 }
887
888 if parts.is_empty() {
889 return Vec::new();
890 }
891
892 vec![GeminiContent {
893 role: Some("model".into()),
894 parts,
895 }]
896 }
897 Message::ToolResult(result) => {
898 let content_text = result
900 .content
901 .iter()
902 .map(|b| match b {
903 ContentBlock::Text(t) => t.text.clone(),
904 ContentBlock::Image(img) => format!("[Image ({}) omitted]", img.mime_type),
905 _ => String::new(),
906 })
907 .filter(|s| !s.is_empty())
908 .collect::<Vec<_>>()
909 .join("\n");
910
911 let response_value = if result.is_error {
912 serde_json::json!({ "error": content_text })
913 } else {
914 serde_json::json!({ "result": content_text })
915 };
916
917 vec![GeminiContent {
918 role: Some("user".into()),
919 parts: vec![GeminiPart::FunctionResponse {
920 function_response: GeminiFunctionResponse {
921 name: result.tool_name.clone(),
922 response: response_value,
923 },
924 }],
925 }]
926 }
927 }
928}
929
930pub(crate) fn convert_user_content_to_parts(content: &UserContent) -> Vec<GeminiPart> {
931 match content {
932 UserContent::Text(text) => vec![GeminiPart::Text { text: text.clone() }],
933 UserContent::Blocks(blocks) => blocks
934 .iter()
935 .filter_map(|block| match block {
936 ContentBlock::Text(t) => Some(GeminiPart::Text {
937 text: t.text.clone(),
938 }),
939 ContentBlock::Image(img) => Some(GeminiPart::InlineData {
940 inline_data: GeminiBlob {
941 mime_type: img.mime_type.clone(),
942 data: img.data.clone(),
943 },
944 }),
945 _ => None,
946 })
947 .collect(),
948 }
949}
950
951pub(crate) fn convert_tool_to_gemini(tool: &ToolDef) -> GeminiFunctionDeclaration {
952 GeminiFunctionDeclaration {
953 name: tool.name.clone(),
954 description: tool.description.clone(),
955 parameters: tool.parameters.clone(),
956 }
957}
958
959#[cfg(test)]
964mod tests {
965 use super::*;
966 use asupersync::runtime::RuntimeBuilder;
967 use futures::{StreamExt, stream};
968 use serde::{Deserialize, Serialize};
969 use serde_json::Value;
970 use std::path::PathBuf;
971
972 #[test]
973 fn test_convert_user_text_message() {
974 let message = Message::User(crate::model::UserMessage {
975 content: UserContent::Text("Hello".to_string()),
976 timestamp: 0,
977 });
978
979 let converted = convert_message_to_gemini(&message);
980 assert_eq!(converted.len(), 1);
981 assert_eq!(converted[0].role, Some("user".to_string()));
982 }
983
984 #[test]
985 fn test_tool_conversion() {
986 let tool = ToolDef {
987 name: "test_tool".to_string(),
988 description: "A test tool".to_string(),
989 parameters: serde_json::json!({
990 "type": "object",
991 "properties": {
992 "arg": {"type": "string"}
993 }
994 }),
995 };
996
997 let converted = convert_tool_to_gemini(&tool);
998 assert_eq!(converted.name, "test_tool");
999 assert_eq!(converted.description, "A test tool");
1000 }
1001
1002 #[test]
1003 fn test_provider_info() {
1004 let provider = GeminiProvider::new("gemini-2.0-flash");
1005 assert_eq!(provider.name(), "google");
1006 assert_eq!(provider.api(), "google-generative-ai");
1007 }
1008
1009 #[test]
1010 fn test_streaming_url() {
1011 let provider = GeminiProvider::new("gemini-2.0-flash");
1012 let url = provider.streaming_url();
1013 assert!(url.contains("gemini-2.0-flash"));
1014 assert!(url.contains("streamGenerateContent"));
1015 assert!(!url.contains("key="));
1016 }
1017
    /// Top-level shape of a provider-response fixture file.
    #[derive(Debug, Deserialize)]
    struct ProviderFixture {
        /// Test cases contained in the fixture.
        cases: Vec<ProviderCase>,
    }
1022
    /// One fixture case: input SSE payloads and the expected event summaries.
    #[derive(Debug, Deserialize)]
    struct ProviderCase {
        /// Case name, used in assertion messages.
        name: String,
        /// SSE data payloads; a JSON string is sent verbatim, any other value
        /// is serialized to JSON first.
        events: Vec<Value>,
        /// Summaries the produced events must equal.
        expected: Vec<EventSummary>,
    }
1029
    /// Comparable summary of a `StreamEvent`; fields that do not apply to the
    /// event kind are `None`.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct EventSummary {
        /// Event kind label such as `start`, `text_delta` or `done`.
        kind: String,
        /// Index of the content block the event refers to.
        #[serde(default)]
        content_index: Option<usize>,
        /// Text delta carried by the event.
        #[serde(default)]
        delta: Option<String>,
        /// Full content carried by an end event.
        #[serde(default)]
        content: Option<String>,
        /// Stop reason label carried by `done` and `error` events.
        #[serde(default)]
        reason: Option<String>,
    }
1042
1043 #[test]
1044 fn test_stream_fixtures() {
1045 let fixture = load_fixture("gemini_stream.json");
1046 for case in fixture.cases {
1047 let events = collect_events(&case.events);
1048 let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
1049 assert_eq!(summaries, case.expected, "case {}", case.name);
1050 }
1051 }
1052
1053 fn load_fixture(file_name: &str) -> ProviderFixture {
1054 let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1055 .join("tests/fixtures/provider_responses")
1056 .join(file_name);
1057 let raw = std::fs::read_to_string(path).expect("fixture read");
1058 serde_json::from_str(&raw).expect("fixture parse")
1059 }
1060
1061 fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
1062 let runtime = RuntimeBuilder::current_thread()
1063 .build()
1064 .expect("runtime build");
1065 runtime.block_on(async move {
1066 let byte_stream = stream::iter(
1067 events
1068 .iter()
1069 .map(|event| {
1070 let data = match event {
1071 Value::String(text) => text.clone(),
1072 _ => serde_json::to_string(event).expect("serialize event"),
1073 };
1074 format!("data: {data}\n\n").into_bytes()
1075 })
1076 .map(Ok),
1077 );
1078 let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1079 let mut state = StreamState::new(
1080 event_source,
1081 "gemini-test".to_string(),
1082 "google-generative".to_string(),
1083 "google".to_string(),
1084 );
1085 let mut out = Vec::new();
1086
1087 loop {
1088 let Some(item) = state.event_source.next().await else {
1089 if !state.finished {
1090 state.finished = true;
1091 out.push(StreamEvent::Done {
1092 reason: state.partial.stop_reason,
1093 message: std::mem::take(&mut state.partial),
1094 });
1095 }
1096 break;
1097 };
1098
1099 let msg = item.expect("SSE event");
1100 if msg.event == "ping" {
1101 continue;
1102 }
1103 state.process_event(&msg.data).expect("process_event");
1104 out.extend(state.pending_events.drain(..));
1105 }
1106
1107 out
1108 })
1109 }
1110
1111 fn summarize_event(event: &StreamEvent) -> EventSummary {
1112 match event {
1113 StreamEvent::Start { .. } => EventSummary {
1114 kind: "start".to_string(),
1115 content_index: None,
1116 delta: None,
1117 content: None,
1118 reason: None,
1119 },
1120 StreamEvent::TextDelta {
1121 content_index,
1122 delta,
1123 ..
1124 } => EventSummary {
1125 kind: "text_delta".to_string(),
1126 content_index: Some(*content_index),
1127 delta: Some(delta.clone()),
1128 content: None,
1129 reason: None,
1130 },
1131 StreamEvent::Done { reason, .. } => EventSummary {
1132 kind: "done".to_string(),
1133 content_index: None,
1134 delta: None,
1135 content: None,
1136 reason: Some(reason_to_string(*reason)),
1137 },
1138 StreamEvent::Error { reason, .. } => EventSummary {
1139 kind: "error".to_string(),
1140 content_index: None,
1141 delta: None,
1142 content: None,
1143 reason: Some(reason_to_string(*reason)),
1144 },
1145 StreamEvent::TextStart { content_index, .. } => EventSummary {
1146 kind: "text_start".to_string(),
1147 content_index: Some(*content_index),
1148 delta: None,
1149 content: None,
1150 reason: None,
1151 },
1152 StreamEvent::TextEnd {
1153 content_index,
1154 content,
1155 ..
1156 } => EventSummary {
1157 kind: "text_end".to_string(),
1158 content_index: Some(*content_index),
1159 delta: None,
1160 content: Some(content.clone()),
1161 reason: None,
1162 },
1163 _ => EventSummary {
1164 kind: "other".to_string(),
1165 content_index: None,
1166 delta: None,
1167 content: None,
1168 reason: None,
1169 },
1170 }
1171 }
1172
1173 fn reason_to_string(reason: StopReason) -> String {
1174 match reason {
1175 StopReason::Stop => "stop",
1176 StopReason::Length => "length",
1177 StopReason::ToolUse => "tool_use",
1178 StopReason::Error => "error",
1179 StopReason::Aborted => "aborted",
1180 }
1181 .to_string()
1182 }
1183
1184 #[test]
1187 fn test_build_request_basic_text() {
1188 let provider = GeminiProvider::new("gemini-2.0-flash");
1189 let context = Context::owned(
1190 Some("You are helpful.".to_string()),
1191 vec![Message::User(crate::model::UserMessage {
1192 content: UserContent::Text("What is Rust?".to_string()),
1193 timestamp: 0,
1194 })],
1195 vec![],
1196 );
1197 let options = crate::provider::StreamOptions {
1198 max_tokens: Some(1024),
1199 temperature: Some(0.7),
1200 ..Default::default()
1201 };
1202
1203 let req = provider.build_request(&context, &options);
1204 let json = serde_json::to_value(&req).expect("serialize");
1205
1206 let contents = json["contents"].as_array().expect("contents array");
1208 assert_eq!(contents.len(), 1);
1209 assert_eq!(contents[0]["role"], "user");
1210 assert_eq!(contents[0]["parts"][0]["text"], "What is Rust?");
1211
1212 assert_eq!(
1214 json["systemInstruction"]["parts"][0]["text"],
1215 "You are helpful."
1216 );
1217
1218 assert!(json.get("tools").is_none() || json["tools"].is_null());
1220
1221 assert_eq!(json["generationConfig"]["maxOutputTokens"], 1024);
1223 assert!((json["generationConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.01);
1224 assert_eq!(json["generationConfig"]["candidateCount"], 1);
1225 }
1226
/// Tool definitions must be emitted as one tool group of function
/// declarations, with the function-calling mode set to `AUTO`.
#[test]
fn test_build_request_with_tools() {
    let read_tool = ToolDef {
        name: "read".to_string(),
        description: "Read a file".to_string(),
        parameters: serde_json::json!({
            "type": "object",
            "properties": {
                "path": {"type": "string"}
            },
            "required": ["path"]
        }),
    };
    let write_tool = ToolDef {
        name: "write".to_string(),
        description: "Write a file".to_string(),
        parameters: serde_json::json!({
            "type": "object",
            "properties": {
                "path": {"type": "string"},
                "content": {"type": "string"}
            }
        }),
    };
    let prompt = Message::User(crate::model::UserMessage {
        content: UserContent::Text("Read a file".to_string()),
        timestamp: 0,
    });
    let context = Context::owned(None, vec![prompt], vec![read_tool, write_tool]);
    let options = crate::provider::StreamOptions::default();

    let gemini = GeminiProvider::new("gemini-2.0-flash");
    let request = gemini.build_request(&context, &options);
    let body = serde_json::to_value(&request).expect("serialize");

    // No system prompt was given, so the key must be absent or null.
    let has_system = body
        .get("systemInstruction")
        .is_some_and(|instruction| !instruction.is_null());
    assert!(!has_system);

    // Both tools live in a single group, in the order they were supplied.
    let tool_groups = body["tools"].as_array().expect("tools array");
    assert_eq!(tool_groups.len(), 1);
    let declared = tool_groups[0]["functionDeclarations"]
        .as_array()
        .expect("declarations");
    assert_eq!(declared.len(), 2);
    let (first, second) = (&declared[0], &declared[1]);
    assert_eq!(first["name"], "read");
    assert_eq!(second["name"], "write");
    assert_eq!(first["description"], "Read a file");

    assert_eq!(body["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");
}
1283
/// With default stream options, `maxOutputTokens` must equal
/// `DEFAULT_MAX_TOKENS`.
#[test]
fn test_build_request_default_max_tokens() {
    let greeting = Message::User(crate::model::UserMessage {
        content: UserContent::Text("hi".to_string()),
        timestamp: 0,
    });
    let context = Context::owned(None, vec![greeting], vec![]);
    let options = crate::provider::StreamOptions::default();

    let gemini = GeminiProvider::new("gemini-2.0-flash");
    let request = gemini.build_request(&context, &options);
    let body = serde_json::to_value(&request).expect("serialize");

    let generation = &body["generationConfig"];
    assert_eq!(generation["maxOutputTokens"], DEFAULT_MAX_TOKENS);
}
1306
/// The streaming URL must target the SSE streaming endpoint and must not
/// carry a `key=` query parameter.
#[test]
fn test_streaming_url_no_key_query_param() {
    let gemini = GeminiProvider::new("gemini-2.0-flash");
    let endpoint = gemini.streaming_url();

    let leaks_key = endpoint.contains("key=");
    assert!(!leaks_key, "API key should not be in query param");

    let requests_sse = endpoint.contains("alt=sse");
    assert!(requests_sse, "alt=sse should be present");

    let is_streaming = endpoint.contains("streamGenerateContent");
    assert!(is_streaming, "should use streaming endpoint");
}
1325
/// A custom base URL must prefix the model path, still without a `key=`
/// query parameter.
#[test]
fn test_streaming_url_custom_base() {
    let gemini = GeminiProvider::new("gemini-pro").with_base_url("https://custom.example.com/v1");
    let endpoint = gemini.streaming_url();

    let expected_prefix = "https://custom.example.com/v1/models/gemini-pro";
    assert!(endpoint.starts_with(expected_prefix));

    let leaks_key = endpoint.contains("key=");
    assert!(!leaks_key);
}
1335
/// Plain user text converts to exactly one text part with the same content.
#[test]
fn test_convert_user_text_to_gemini_parts() {
    let input = UserContent::Text("hello world".to_string());
    let converted = convert_user_content_to_parts(&input);
    assert_eq!(converted.len(), 1);

    let GeminiPart::Text { text } = &converted[0] else {
        panic!("expected text part");
    };
    assert_eq!(text, "hello world");
}
1347
/// A text block followed by an image block converts to a text part followed
/// by an inline-data part that keeps the MIME type and payload.
#[test]
fn test_convert_user_blocks_with_image_to_gemini_parts() {
    let caption = ContentBlock::Text(TextContent::new("describe this"));
    let picture = ContentBlock::Image(crate::model::ImageContent {
        data: "aGVsbG8=".to_string(),
        mime_type: "image/png".to_string(),
    });
    let content = UserContent::Blocks(vec![caption, picture]);

    let converted = convert_user_content_to_parts(&content);
    assert_eq!(converted.len(), 2);

    let GeminiPart::Text { text } = &converted[0] else {
        panic!("expected text part");
    };
    assert_eq!(text, "describe this");

    let GeminiPart::InlineData { inline_data } = &converted[1] else {
        panic!("expected inline_data part");
    };
    assert_eq!(inline_data.mime_type, "image/png");
    assert_eq!(inline_data.data, "aGVsbG8=");
}
1372
/// An assistant message with text and a tool call converts to a single
/// `model` turn holding a text part followed by a function-call part.
#[test]
fn test_convert_assistant_message_with_tool_call() {
    let tool_call = ToolCall {
        id: "call_123".to_string(),
        name: "read".to_string(),
        arguments: serde_json::json!({"path": "/tmp/test.txt"}),
        thought_signature: None,
    };
    let message = Message::assistant(AssistantMessage {
        content: vec![
            ContentBlock::Text(TextContent::new("Let me read that file.")),
            ContentBlock::ToolCall(tool_call),
        ],
        api: "google".to_string(),
        provider: "google".to_string(),
        model: "gemini-2.0-flash".to_string(),
        usage: Usage::default(),
        stop_reason: StopReason::ToolUse,
        error_message: None,
        timestamp: 0,
    });

    let converted = convert_message_to_gemini(&message);
    assert_eq!(converted.len(), 1);
    let model_turn = &converted[0];
    assert_eq!(model_turn.role, Some("model".to_string()));
    assert_eq!(model_turn.parts.len(), 2);

    let GeminiPart::Text { text } = &model_turn.parts[0] else {
        panic!("expected text part");
    };
    assert_eq!(text, "Let me read that file.");

    let GeminiPart::FunctionCall { function_call } = &model_turn.parts[1] else {
        panic!("expected function_call part");
    };
    assert_eq!(function_call.name, "read");
    assert_eq!(function_call.args["path"], "/tmp/test.txt");
}
1411
/// An assistant message with no content blocks converts to no Gemini turns.
#[test]
fn test_convert_assistant_empty_content_returns_empty() {
    let empty_reply = AssistantMessage {
        content: vec![],
        api: "google".to_string(),
        provider: "google".to_string(),
        model: "gemini-2.0-flash".to_string(),
        usage: Usage::default(),
        stop_reason: StopReason::Stop,
        error_message: None,
        timestamp: 0,
    };
    let message = Message::assistant(empty_reply);

    assert!(convert_message_to_gemini(&message).is_empty());
}
1428
/// A successful tool result converts to a `user` turn whose function response
/// carries the text under `result` and has no `error` key.
#[test]
fn test_convert_tool_result_success() {
    let outcome = crate::model::ToolResultMessage {
        tool_call_id: "call_123".to_string(),
        tool_name: "read".to_string(),
        content: vec![ContentBlock::Text(TextContent::new("file contents here"))],
        details: None,
        is_error: false,
        timestamp: 0,
    };
    let message = Message::tool_result(outcome);

    let converted = convert_message_to_gemini(&message);
    assert_eq!(converted.len(), 1);
    let turn = &converted[0];
    assert_eq!(turn.role, Some("user".to_string()));

    let GeminiPart::FunctionResponse { function_response } = &turn.parts[0] else {
        panic!("expected function_response part");
    };
    assert_eq!(function_response.name, "read");
    let payload = &function_response.response;
    assert_eq!(payload["result"], "file contents here");
    assert!(payload.get("error").is_none());
}
1453
/// A failed tool result converts to a function response that carries the text
/// under `error` and has no `result` key.
#[test]
fn test_convert_tool_result_error() {
    let failure = crate::model::ToolResultMessage {
        tool_call_id: "call_456".to_string(),
        tool_name: "bash".to_string(),
        content: vec![ContentBlock::Text(TextContent::new("command not found"))],
        details: None,
        is_error: true,
        timestamp: 0,
    };
    let message = Message::tool_result(failure);

    let converted = convert_message_to_gemini(&message);
    assert_eq!(converted.len(), 1);

    let GeminiPart::FunctionResponse { function_response } = &converted[0].parts[0] else {
        panic!("expected function_response part");
    };
    assert_eq!(function_response.name, "bash");
    let payload = &function_response.response;
    assert_eq!(payload["error"], "command not found");
    assert!(payload.get("result").is_none());
}
1477
/// A custom message converts to a single `user` turn whose first part is the
/// message content as text.
#[test]
fn test_convert_custom_message() {
    let note = crate::model::CustomMessage {
        custom_type: "system_note".to_string(),
        content: "Context window approaching limit.".to_string(),
        display: false,
        details: None,
        timestamp: 0,
    };
    let message = Message::Custom(note);

    let converted = convert_message_to_gemini(&message);
    assert_eq!(converted.len(), 1);
    let turn = &converted[0];
    assert_eq!(turn.role, Some("user".to_string()));

    let GeminiPart::Text { text } = &turn.parts[0] else {
        panic!("expected text part");
    };
    assert_eq!(text, "Context window approaching limit.");
}
1498
/// Each Gemini `finishReason` string must map to the expected [`StopReason`].
///
/// The table includes an unrecognized reason, which is expected to map to
/// `StopReason::Stop`.
#[test]
fn test_stop_reason_mapping() {
    // A fixed-size array is enough here; no heap-allocated `Vec` is needed.
    let test_cases = [
        ("STOP", StopReason::Stop),
        ("MAX_TOKENS", StopReason::Length),
        ("SAFETY", StopReason::Error),
        ("RECITATION", StopReason::Error),
        ("OTHER", StopReason::Error),
        ("UNKNOWN_REASON", StopReason::Stop),
    ];

    // Building the runtime is loop-invariant, so one instance serves every case.
    let runtime = RuntimeBuilder::current_thread()
        .build()
        .expect("build current-thread runtime");

    for (reason_str, expected) in test_cases {
        let candidate = GeminiCandidate {
            content: None,
            finish_reason: Some(reason_str.to_string()),
        };

        runtime.block_on(async {
            // Each case gets a fresh state over an empty byte stream so that
            // results cannot leak between cases.
            let byte_stream = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
            let mut state = StreamState::new(
                event_source,
                "test".to_string(),
                "test".to_string(),
                "test".to_string(),
            );
            state.process_candidate(candidate).unwrap();
            assert_eq!(
                state.partial.stop_reason, expected,
                "finish_reason '{reason_str}' should map to {expected:?}"
            );
        });
    }
}
1537
/// An event containing only `usageMetadata` must populate the partial
/// message's input, output and total token counts.
#[test]
fn test_usage_metadata_parsing() {
    let payload = r#"{
            "usageMetadata": {
                "promptTokenCount": 42,
                "candidatesTokenCount": 100,
                "totalTokenCount": 142
            }
        }"#;

    let runtime = RuntimeBuilder::current_thread().build().unwrap();
    runtime.block_on(async {
        // The state is fed directly, so the underlying byte stream stays empty.
        let no_bytes = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
        let mut state = StreamState::new(
            crate::sse::SseStream::new(Box::pin(no_bytes)),
            "test".to_string(),
            "test".to_string(),
            "test".to_string(),
        );

        state.process_event(payload).unwrap();

        let usage = &state.partial.usage;
        assert_eq!(usage.input, 42);
        assert_eq!(usage.output, 100);
        assert_eq!(usage.total_tokens, 142);
    });
}
1564
/// A user turn, an assistant tool call and its tool result must serialize to
/// three content entries with roles `user`, `model`, `user`, in that order.
#[test]
fn test_build_request_full_conversation() {
    let user_request = Message::User(crate::model::UserMessage {
        content: UserContent::Text("Read /tmp/a.txt".to_string()),
        timestamp: 0,
    });
    let assistant_call = Message::assistant(AssistantMessage {
        content: vec![ContentBlock::ToolCall(ToolCall {
            id: "call_1".to_string(),
            name: "read".to_string(),
            arguments: serde_json::json!({"path": "/tmp/a.txt"}),
            thought_signature: None,
        })],
        api: "google".to_string(),
        provider: "google".to_string(),
        model: "gemini-2.0-flash".to_string(),
        usage: Usage::default(),
        stop_reason: StopReason::ToolUse,
        error_message: None,
        timestamp: 1,
    });
    let tool_output = Message::tool_result(crate::model::ToolResultMessage {
        tool_call_id: "call_1".to_string(),
        tool_name: "read".to_string(),
        content: vec![ContentBlock::Text(TextContent::new("file contents"))],
        details: None,
        is_error: false,
        timestamp: 2,
    });
    let read_tool = ToolDef {
        name: "read".to_string(),
        description: "Read a file".to_string(),
        parameters: serde_json::json!({"type": "object"}),
    };

    let context = Context::owned(
        Some("Be concise.".to_string()),
        vec![user_request, assistant_call, tool_output],
        vec![read_tool],
    );
    let options = crate::provider::StreamOptions::default();

    let gemini = GeminiProvider::new("gemini-2.0-flash");
    let request = gemini.build_request(&context, &options);
    let body = serde_json::to_value(&request).expect("serialize");

    let turns = body["contents"].as_array().expect("contents");
    assert_eq!(turns.len(), 3);
    let (user_turn, model_turn, result_turn) = (&turns[0], &turns[1], &turns[2]);

    // Turn 1: the user's request.
    assert_eq!(user_turn["role"], "user");
    assert_eq!(user_turn["parts"][0]["text"], "Read /tmp/a.txt");

    // Turn 2: the assistant's tool call, sent with the `model` role.
    assert_eq!(model_turn["role"], "model");
    assert_eq!(model_turn["parts"][0]["functionCall"]["name"], "read");

    // Turn 3: the tool result, sent back with the `user` role.
    assert_eq!(result_turn["role"], "user");
    let function_response = &result_turn["parts"][0]["functionResponse"];
    assert_eq!(function_response["name"], "read");
    assert_eq!(function_response["response"]["result"], "file contents");
}
1631
// Property tests for `StreamState::process_event`.
//
// Every property feeds generated payloads into a fresh state and discards the
// returned `Result`; the only thing checked is that the call does not panic.
mod proptest_process_event {
    use super::*;
    use proptest::prelude::*;

    /// Builds a `StreamState` over an empty byte stream, so the tests can feed
    /// event payloads to `process_event` directly.
    fn make_state()
    -> StreamState<impl Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin>
    {
        let empty = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
        let sse = crate::sse::SseStream::new(Box::pin(empty));
        StreamState::new(
            sse,
            "gemini-test".into(),
            "google-generative".into(),
            "google".into(),
        )
    }

    /// Short strings: the empty string, 1-16 identifier-like characters, or
    /// 0-32 printable ASCII characters.
    fn small_string() -> impl Strategy<Value = String> {
        prop_oneof![Just(String::new()), "[a-zA-Z0-9_]{1,16}", "[ -~]{0,32}",]
    }

    /// Token counts weighted toward small values, with extra weight on zero
    /// and on values at or within 100 of `u64::MAX`.
    fn token_count() -> impl Strategy<Value = u64> {
        prop_oneof![
            5 => 0u64..10_000u64,
            2 => Just(0u64),
            1 => Just(u64::MAX),
            1 => (u64::MAX - 100)..=u64::MAX,
        ]
    }

    /// Optional `finishReason`: absent (weight 3), one of five fixed reason
    /// strings, or an arbitrary `small_string`.
    fn finish_reason() -> impl Strategy<Value = Option<String>> {
        prop_oneof![
            3 => Just(None),
            1 => Just(Some("STOP".to_string())),
            1 => Just(Some("MAX_TOKENS".to_string())),
            1 => Just(Some("SAFETY".to_string())),
            1 => Just(Some("RECITATION".to_string())),
            1 => Just(Some("OTHER".to_string())),
            1 => small_string().prop_map(Some),
        ]
    }

    /// Function-call argument objects: empty, fixed samples, or a single
    /// `input` key holding a generated string.
    fn json_args() -> impl Strategy<Value = serde_json::Value> {
        prop_oneof![
            Just(serde_json::json!({})),
            Just(serde_json::json!({"key": "value"})),
            Just(serde_json::json!({"a": 1, "b": true, "c": null})),
            small_string().prop_map(|s| serde_json::json!({"input": s})),
        ]
    }

    /// A `{"text": ...}` part with generated text.
    fn text_part() -> impl Strategy<Value = serde_json::Value> {
        small_string().prop_map(|t| serde_json::json!({"text": t}))
    }

    /// A `{"functionCall": {"name", "args"}}` part with a generated name and
    /// arguments.
    fn function_call_part() -> impl Strategy<Value = serde_json::Value> {
        (small_string(), json_args()).prop_map(
            |(name, args)| serde_json::json!({"functionCall": {"name": name, "args": args}}),
        )
    }

    /// Zero to four parts, with text parts weighted 3:1 over function calls.
    fn parts_strategy() -> impl Strategy<Value = Vec<serde_json::Value>> {
        prop::collection::vec(
            prop_oneof![3 => text_part(), 1 => function_call_part(),],
            0..5,
        )
    }

    /// Serialized JSON payloads built from the response keys used in this
    /// module (`candidates`, `usageMetadata`).
    fn gemini_response_json() -> impl Strategy<Value = String> {
        prop_oneof![
            // One candidate with content parts and an optional finish reason.
            3 => (parts_strategy(), finish_reason()).prop_map(|(parts, fr)| {
                let mut candidate = serde_json::json!({
                    "content": {"parts": parts}
                });
                if let Some(r) = fr {
                    candidate["finishReason"] = serde_json::Value::String(r);
                }
                serde_json::json!({"candidates": [candidate]}).to_string()
            }),
            // Usage metadata only, with no candidates key.
            2 => (token_count(), token_count(), token_count()).prop_map(|(p, c, t)| {
                serde_json::json!({
                    "usageMetadata": {
                        "promptTokenCount": p,
                        "candidatesTokenCount": c,
                        "totalTokenCount": t
                    }
                })
                .to_string()
            }),
            // An empty candidates array.
            1 => Just(r#"{"candidates":[]}"#.to_string()),
            // An empty object.
            1 => Just(r"{}".to_string()),
            // A candidate with a finish reason but no content. The filter
            // guarantees `Some`, so the `unwrap` below cannot fail.
            1 => finish_reason()
                .prop_filter("some reason", Option::is_some)
                .prop_map(|fr| {
                    serde_json::json!({
                        "candidates": [{"finishReason": fr.unwrap()}]
                    })
                    .to_string()
                }),
            // A candidate and usage metadata in the same payload.
            2 => (parts_strategy(), finish_reason(), token_count(), token_count(), token_count())
                .prop_map(|(parts, fr, p, c, t)| {
                    let mut candidate = serde_json::json!({
                        "content": {"parts": parts}
                    });
                    if let Some(r) = fr {
                        candidate["finishReason"] = serde_json::Value::String(r);
                    }
                    serde_json::json!({
                        "candidates": [candidate],
                        "usageMetadata": {
                            "promptTokenCount": p,
                            "candidatesTokenCount": c,
                            "totalTokenCount": t
                        }
                    })
                    .to_string()
                }),
        ]
    }

    /// Payloads meant to stress parsing: empty input, non-object JSON,
    /// truncated JSON, wrongly typed fields, and arbitrary printable ASCII.
    fn chaos_json() -> impl Strategy<Value = String> {
        prop_oneof![
            Just(String::new()),
            Just("{}".to_string()),
            Just("[]".to_string()),
            Just("null".to_string()),
            Just("{".to_string()),
            Just(r#"{"candidates":"not_array"}"#.to_string()),
            Just(r#"{"candidates":[{"content":null}]}"#.to_string()),
            Just(r#"{"candidates":[{"content":{"parts":"not_array"}}]}"#.to_string()),
            "[ -~]{0,64}",
        ]
    }

    proptest! {
        // 256 cases per property, with shrinking capped at 100 iterations.
        #![proptest_config(ProptestConfig {
            cases: 256,
            max_shrink_iters: 100,
            .. ProptestConfig::default()
        })]

        // A single generated response payload must not panic.
        #[test]
        fn process_event_valid_never_panics(data in gemini_response_json()) {
            let mut state = make_state();
            let _ = state.process_event(&data);
        }

        // A single chaos payload must not panic.
        #[test]
        fn process_event_chaos_never_panics(data in chaos_json()) {
            let mut state = make_state();
            let _ = state.process_event(&data);
        }

        // One to seven response payloads fed into the same state must not panic.
        #[test]
        fn process_event_sequence_never_panics(
            events in prop::collection::vec(gemini_response_json(), 1..8)
        ) {
            let mut state = make_state();
            for event in &events {
                let _ = state.process_event(event);
            }
        }

        // One to eleven response and chaos payloads, mixed, fed into the same
        // state must not panic.
        #[test]
        fn process_event_mixed_sequence_never_panics(
            events in prop::collection::vec(
                prop_oneof![gemini_response_json(), chaos_json()],
                1..12
            )
        ) {
            let mut state = make_state();
            for event in &events {
                let _ = state.process_event(event);
            }
        }
    }
}
1825}
1826
/// Fuzzing entry points, compiled only when the `fuzzing` feature is enabled.
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// The always-empty byte stream that backs the fuzz processor's state.
    type FuzzStream =
        Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;

    /// Wraps a `StreamState` over an empty byte stream so that event payloads
    /// can be fed to it directly.
    pub struct Processor(StreamState<FuzzStream>);

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Processor {
        /// Creates a processor with fixed `gemini-fuzz`, `google-generative`
        /// and `google` identifiers.
        pub fn new() -> Self {
            let byte_source: FuzzStream = Box::pin(stream::empty());
            let state = StreamState::new(
                crate::sse::SseStream::new(byte_source),
                String::from("gemini-fuzz"),
                String::from("google-generative"),
                String::from("google"),
            );
            Self(state)
        }

        /// Processes one event payload and returns all events currently
        /// pending in the state.
        ///
        /// # Errors
        ///
        /// Propagates any error from the underlying `process_event`; in that
        /// case the pending events are not drained.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            let Self(state) = self;
            state.process_event(data)?;
            let emitted: Vec<StreamEvent> = state.pending_events.drain(..).collect();
            Ok(emitted)
        }
    }
}