1use crate::error::{Error, Result};
7use crate::http::client::Client;
8use crate::model::{
9 AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ToolCall, Usage,
10 UserContent,
11};
12use crate::models::CompatConfig;
13use crate::provider::{Context, Provider, StreamOptions, ToolDef};
14use crate::sse::SseStream;
15use async_trait::async_trait;
16use futures::StreamExt;
17use futures::stream::{self, Stream};
18use serde::{Deserialize, Serialize};
19use std::collections::VecDeque;
20use std::pin::Pin;
21
/// Default base URL for the public Gemini (Generative Language) REST API.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";
/// Default base URL for Cloud Code Assist, used in Google Gemini CLI mode.
const GOOGLE_GEMINI_CLI_BASE: &str = "https://cloudcode-pa.googleapis.com";
/// Default base URL used in Google CLI mode when the provider name is `google-antigravity`.
const GOOGLE_ANTIGRAVITY_BASE: &str = "https://daily-cloudcode-pa.sandbox.googleapis.com";
/// `maxOutputTokens` sent when the caller does not set `StreamOptions::max_tokens`.
pub(crate) const DEFAULT_MAX_TOKENS: u32 = 8192;
30
31fn authorization_override(
32 options: &StreamOptions,
33 compat: Option<&CompatConfig>,
34) -> Option<String> {
35 super::first_non_empty_header_value_case_insensitive(&options.headers, &["authorization"])
36 .or_else(|| {
37 compat
38 .and_then(|compat| compat.custom_headers.as_ref())
39 .and_then(|headers| {
40 super::first_non_empty_header_value_case_insensitive(
41 headers,
42 &["authorization"],
43 )
44 })
45 })
46}
47
48fn google_api_key_override(
49 options: &StreamOptions,
50 compat: Option<&CompatConfig>,
51) -> Option<String> {
52 super::first_non_empty_header_value_case_insensitive(&options.headers, &["x-goog-api-key"])
53 .or_else(|| {
54 compat
55 .and_then(|compat| compat.custom_headers.as_ref())
56 .and_then(|headers| {
57 super::first_non_empty_header_value_case_insensitive(
58 headers,
59 &["x-goog-api-key"],
60 )
61 })
62 })
63}
64
/// Streaming provider for Google Gemini models.
///
/// Talks either to the public Gemini API (`models/{model}:streamGenerateContent`)
/// or, when `google_cli_mode` is enabled, to the Cloud Code Assist endpoint
/// (`v1internal:streamGenerateContent`). Configure it with the `with_*` builders.
pub struct GeminiProvider {
    /// HTTP client used to send requests.
    client: Client,
    /// Model id: part of the URL in public-API mode, part of the body in CLI mode.
    model: String,
    /// Base URL; a blank value selects the default endpoint for the current mode.
    base_url: String,
    /// Provider name reported by `Provider::name` (also selects Antigravity in CLI mode).
    provider: String,
    /// API identifier reported by `Provider::api`.
    api: String,
    /// Use the Cloud Code Assist (Gemini CLI / Antigravity) protocol instead of the public API.
    google_cli_mode: bool,
    /// Optional compatibility overrides (e.g. `custom_headers`).
    compat: Option<CompatConfig>,
}
79
80impl GeminiProvider {
81 pub fn new(model: impl Into<String>) -> Self {
83 Self {
84 client: Client::new(),
85 model: model.into(),
86 base_url: GEMINI_API_BASE.to_string(),
87 provider: "google".to_string(),
88 api: "google-generative-ai".to_string(),
89 google_cli_mode: false,
90 compat: None,
91 }
92 }
93
94 #[must_use]
96 pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
97 self.provider = provider.into();
98 self
99 }
100
101 #[must_use]
103 pub fn with_api_name(mut self, api: impl Into<String>) -> Self {
104 self.api = api.into();
105 self
106 }
107
108 #[must_use]
110 pub const fn with_google_cli_mode(mut self, enabled: bool) -> Self {
111 self.google_cli_mode = enabled;
112 self
113 }
114
115 #[must_use]
117 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
118 self.base_url = base_url.into();
119 self
120 }
121
122 #[must_use]
124 pub fn with_client(mut self, client: Client) -> Self {
125 self.client = client;
126 self
127 }
128
129 #[must_use]
131 pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
132 self.compat = compat;
133 self
134 }
135
136 pub fn streaming_url(&self) -> String {
138 let base = {
139 let trimmed = self.base_url.trim();
140 if trimmed.is_empty() {
141 if self.google_cli_mode {
142 if self.provider.eq_ignore_ascii_case("google-antigravity") {
143 GOOGLE_ANTIGRAVITY_BASE
144 } else {
145 GOOGLE_GEMINI_CLI_BASE
146 }
147 } else {
148 GEMINI_API_BASE
149 }
150 } else {
151 trimmed
152 }
153 };
154 if self.google_cli_mode {
155 format!("{base}/v1internal:streamGenerateContent?alt=sse")
156 } else {
157 format!("{base}/models/{}:streamGenerateContent?alt=sse", self.model)
158 }
159 }
160
161 #[allow(clippy::unused_self)]
163 pub fn build_request(&self, context: &Context<'_>, options: &StreamOptions) -> GeminiRequest {
164 let contents = Self::build_contents(context);
165 let system_instruction = context.system_prompt.as_deref().map(|s| GeminiContent {
166 role: None,
167 parts: vec![GeminiPart::Text {
168 text: s.to_string(),
169 }],
170 });
171
172 let tools: Option<Vec<GeminiTool>> = if context.tools.is_empty() {
173 None
174 } else {
175 Some(vec![GeminiTool {
176 function_declarations: context.tools.iter().map(convert_tool_to_gemini).collect(),
177 }])
178 };
179
180 let tool_config = if tools.is_some() {
181 Some(GeminiToolConfig {
182 function_calling_config: GeminiFunctionCallingConfig { mode: "AUTO" },
183 })
184 } else {
185 None
186 };
187
188 GeminiRequest {
189 contents,
190 system_instruction,
191 tools,
192 tool_config,
193 generation_config: Some(GeminiGenerationConfig {
194 max_output_tokens: options.max_tokens.or(Some(DEFAULT_MAX_TOKENS)),
195 temperature: options.temperature,
196 candidate_count: Some(1),
197 }),
198 }
199 }
200
201 fn build_contents(context: &Context<'_>) -> Vec<GeminiContent> {
203 let mut contents = Vec::with_capacity(context.messages.len());
204
205 for message in context.messages.iter() {
206 contents.extend(convert_message_to_gemini(message));
207 }
208
209 contents
210 }
211}
212
/// Envelope sent to the Cloud Code Assist `v1internal:streamGenerateContent`
/// endpoint in Google CLI mode; wraps a regular `GeminiRequest`.
/// Serialized with camelCase keys.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct CloudCodeAssistRequest {
    /// Project resource name, as produced by `build_google_cli_request`.
    project: String,
    /// Model id.
    model: String,
    /// The wrapped Gemini request body.
    request: GeminiRequest,
    /// `requestType`; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    request_type: Option<String>,
    /// `userAgent` client identifier.
    user_agent: String,
    /// `requestId`, unique per request.
    request_id: String,
}
224
225fn build_google_cli_request(
226 model_id: &str,
227 project_id: &str,
228 request: GeminiRequest,
229 is_antigravity: bool,
230) -> std::result::Result<CloudCodeAssistRequest, &'static str> {
231 let safe_project = project_id.trim();
232 if safe_project.is_empty() {
233 return Err(
234 "Missing Google Cloud project ID for Gemini CLI. Set GOOGLE_CLOUD_PROJECT (or configure gcloud) and re-authenticate with /login google-gemini-cli.",
235 );
236 }
237 let project = if safe_project.starts_with("projects/") {
238 safe_project.to_string()
239 } else {
240 format!("projects/{safe_project}/locations/global")
241 };
242 Ok(CloudCodeAssistRequest {
243 project,
244 model: model_id.to_string(),
245 request,
246 request_type: is_antigravity.then(|| "agent".to_string()),
247 user_agent: if is_antigravity {
248 "antigravity".to_string()
249 } else {
250 "pi-coding-agent".to_string()
251 },
252 request_id: format!(
253 "{}-{}",
254 if is_antigravity { "agent" } else { "pi" },
255 uuid::Uuid::new_v4().simple()
256 ),
257 })
258}
259
260fn decode_project_scoped_access_payload(payload: &str) -> Option<(String, String)> {
261 let value: serde_json::Value = serde_json::from_str(payload).ok()?;
262 let token = value
263 .get("token")
264 .and_then(serde_json::Value::as_str)
265 .map(str::trim)
266 .filter(|value| !value.is_empty())?
267 .to_string();
268 let project_id = value
269 .get("projectId")
270 .or_else(|| value.get("project_id"))
271 .and_then(serde_json::Value::as_str)
272 .map(str::trim)
273 .filter(|value| !value.is_empty())?
274 .to_string();
275 Some((token, project_id))
276}
277
278#[async_trait]
279impl Provider for GeminiProvider {
280 fn name(&self) -> &str {
281 &self.provider
282 }
283
284 fn api(&self) -> &str {
285 &self.api
286 }
287
288 fn model_id(&self) -> &str {
289 &self.model
290 }
291
292 #[allow(clippy::too_many_lines)]
293 async fn stream(
294 &self,
295 context: &Context<'_>,
296 options: &StreamOptions,
297 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
298 let request_body = self.build_request(context, options);
299 let url = self.streaming_url();
300
301 let mut request = self.client.post(&url).header("Accept", "text/event-stream");
303
304 if self.google_cli_mode {
305 let api_payload = options.api_key.clone().ok_or_else(|| {
306 Error::provider(
307 self.name(),
308 "Google Gemini CLI requires OAuth credentials. Run /login google-gemini-cli.",
309 )
310 })?;
311 let (access_token, project_id) = decode_project_scoped_access_payload(&api_payload)
312 .ok_or_else(|| {
313 Error::provider(
314 self.name(),
315 "Invalid Google Gemini CLI OAuth payload (expected JSON {token, projectId}). Run /login google-gemini-cli again.",
316 )
317 })?;
318 let is_antigravity = self.provider.eq_ignore_ascii_case("google-antigravity");
319
320 let platform = crate::platform::platform_tag();
322 let pi_version = crate::platform::VERSION;
323 let client_metadata = format!(
324 r#"{{"ideType":"CLI","platform":"{}","pluginType":"GEMINI"}}"#,
325 crate::platform::os_name().to_ascii_uppercase(),
326 );
327 let api_client_tag = format!("pi-agent-rust/{pi_version}");
328
329 request = request
330 .header("Authorization", format!("Bearer {access_token}"))
331 .header("Content-Type", "application/json")
332 .header("x-goog-api-client", api_client_tag)
333 .header("client-metadata", client_metadata);
334
335 if is_antigravity {
336 let antigravity_version = std::env::var("PI_AI_ANTIGRAVITY_VERSION")
337 .unwrap_or_else(|_| pi_version.to_string());
338 request = request.header(
339 "User-Agent",
340 format!("antigravity/{antigravity_version} {platform}"),
341 );
342 } else {
343 request = request.header("User-Agent", crate::platform::pi_user_agent());
344 }
345
346 if let Some(compat) = &self.compat {
348 if let Some(custom_headers) = &compat.custom_headers {
349 request = super::apply_headers_ignoring_blank_auth_overrides(
350 request,
351 custom_headers,
352 &["authorization", "x-goog-api-key"],
353 );
354 }
355 }
356
357 request = super::apply_headers_ignoring_blank_auth_overrides(
359 request,
360 &options.headers,
361 &["authorization", "x-goog-api-key"],
362 );
363
364 let cli_request =
365 build_google_cli_request(&self.model, &project_id, request_body, is_antigravity)
366 .map_err(|message| Error::provider(self.name(), message.to_string()))?;
367 let request = request.json(&cli_request)?;
368 let response = Box::pin(request.send()).await?;
369 let status = response.status();
370 if !(200..300).contains(&status) {
371 let body = response
372 .text()
373 .await
374 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
375 return Err(Error::provider(
376 self.name(),
377 format!("Gemini CLI API error (HTTP {status}): {body}"),
378 ));
379 }
380
381 let event_source = SseStream::new(response.bytes_stream());
383 let model = self.model.clone();
384 let api = self.api().to_string();
385 let provider = self.name().to_string();
386 let cloud_cli_mode = self.google_cli_mode;
387
388 let stream = stream::unfold(
389 StreamState::new(event_source, model, api, provider),
390 move |mut state| async move {
391 if state.finished {
392 return None;
393 }
394 loop {
395 if let Some(event) = state.pending_events.pop_front() {
397 return Some((Ok(event), state));
398 }
399
400 match state.event_source.next().await {
401 Some(Ok(msg)) => {
402 state.transient_error_count = 0;
403 if msg.event == "ping" {
404 continue;
405 }
406
407 let processing = if cloud_cli_mode {
408 state.process_cloud_code_event(&msg.data)
409 } else {
410 state.process_event(&msg.data)
411 };
412 if let Err(e) = processing {
413 state.finished = true;
414 return Some((Err(e), state));
415 }
416 }
417 Some(Err(e)) => {
418 const MAX_CONSECUTIVE_TRANSIENT_ERRORS: usize = 5;
422 if e.kind() == std::io::ErrorKind::WriteZero
423 || e.kind() == std::io::ErrorKind::WouldBlock
424 || e.kind() == std::io::ErrorKind::TimedOut
425 {
426 state.transient_error_count += 1;
427 if state.transient_error_count
428 <= MAX_CONSECUTIVE_TRANSIENT_ERRORS
429 {
430 tracing::warn!(
431 kind = ?e.kind(),
432 count = state.transient_error_count,
433 "Transient error in SSE stream, continuing"
434 );
435 continue;
436 }
437 tracing::warn!(
438 kind = ?e.kind(),
439 "Error persisted after {MAX_CONSECUTIVE_TRANSIENT_ERRORS} \
440 consecutive attempts, treating as fatal"
441 );
442 }
443 state.finished = true;
444 let err = Error::api(format!("SSE error: {e}"));
445 return Some((Err(err), state));
446 }
447 None => {
448 state.finished = true;
450 let reason = state.partial.stop_reason;
451 let message = std::mem::take(&mut state.partial);
452 return Some((Ok(StreamEvent::Done { reason, message }), state));
453 }
454 }
455 }
456 },
457 );
458
459 return Ok(Box::pin(stream));
460 }
461
462 let has_auth_override = google_api_key_override(options, self.compat.as_ref()).is_some()
463 || authorization_override(options, self.compat.as_ref()).is_some();
464 let auth_value = if has_auth_override {
465 None
466 } else {
467 Some(
468 options
469 .api_key
470 .clone()
471 .or_else(|| std::env::var("GOOGLE_API_KEY").ok())
472 .or_else(|| std::env::var("GEMINI_API_KEY").ok())
473 .ok_or_else(|| {
474 Error::provider(
475 self.name(),
476 "Missing API key for provider. Configure credentials with /login <provider> or set the provider's API key env var.",
477 )
478 })?,
479 )
480 };
481
482 if let Some(auth_value) = auth_value {
483 request = request.header("x-goog-api-key", &auth_value);
484 }
485
486 if let Some(compat) = &self.compat {
488 if let Some(custom_headers) = &compat.custom_headers {
489 request = super::apply_headers_ignoring_blank_auth_overrides(
490 request,
491 custom_headers,
492 &["authorization", "x-goog-api-key"],
493 );
494 }
495 }
496
497 request = super::apply_headers_ignoring_blank_auth_overrides(
499 request,
500 &options.headers,
501 &["authorization", "x-goog-api-key"],
502 );
503
504 let request = request.json(&request_body)?;
505
506 let response = Box::pin(request.send()).await?;
507 let status = response.status();
508 if !(200..300).contains(&status) {
509 let body = response
510 .text()
511 .await
512 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
513 return Err(Error::provider(
514 self.name(),
515 format!("Gemini API error (HTTP {status}): {body}"),
516 ));
517 }
518
519 let event_source = SseStream::new(response.bytes_stream());
521
522 let model = self.model.clone();
524 let api = self.api().to_string();
525 let provider = self.name().to_string();
526 let cloud_cli_mode = self.google_cli_mode;
527
528 let stream = stream::unfold(
529 StreamState::new(event_source, model, api, provider),
530 move |mut state| async move {
531 if state.finished {
532 return None;
533 }
534 loop {
535 if let Some(event) = state.pending_events.pop_front() {
537 return Some((Ok(event), state));
538 }
539
540 match state.event_source.next().await {
541 Some(Ok(msg)) => {
542 state.transient_error_count = 0;
543 if msg.event == "ping" {
544 continue;
545 }
546
547 let processing = if cloud_cli_mode {
548 state.process_cloud_code_event(&msg.data)
549 } else {
550 state.process_event(&msg.data)
551 };
552 if let Err(e) = processing {
553 state.finished = true;
554 return Some((Err(e), state));
555 }
556 }
557 Some(Err(e)) => {
558 const MAX_CONSECUTIVE_WRITE_ZERO: usize = 5;
559 if e.kind() == std::io::ErrorKind::WriteZero {
560 state.transient_error_count += 1;
561 if state.transient_error_count <= MAX_CONSECUTIVE_WRITE_ZERO {
562 tracing::warn!(
563 count = state.transient_error_count,
564 "Transient WriteZero error in SSE stream, continuing"
565 );
566 continue;
567 }
568 tracing::warn!(
569 "WriteZero error persisted after {MAX_CONSECUTIVE_WRITE_ZERO} \
570 consecutive attempts, treating as fatal"
571 );
572 }
573 state.finished = true;
574 let err = Error::api(format!("SSE error: {e}"));
575 return Some((Err(err), state));
576 }
577 None => {
578 state.finished = true;
580 let reason = state.partial.stop_reason;
581 let message = std::mem::take(&mut state.partial);
582 return Some((Ok(StreamEvent::Done { reason, message }), state));
583 }
584 }
585 }
586 },
587 );
588
589 Ok(Box::pin(stream))
590 }
591}
592
/// Mutable state threaded through `stream::unfold` while decoding one SSE response.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// SSE decoder over the HTTP response body bytes.
    event_source: SseStream<S>,
    /// Assistant message accumulated so far; moved into the final `Done` event.
    partial: AssistantMessage,
    /// Events produced by parsing but not yet yielded to the consumer.
    pending_events: VecDeque<StreamEvent>,
    /// Whether the `Start` event has already been queued.
    started: bool,
    /// Set once `Done` or an error has been yielded; the stream then ends.
    finished: bool,
    /// Consecutive transient SSE I/O errors; reset on every successfully read event.
    transient_error_count: usize,
}
609
610impl<S> StreamState<S>
611where
612 S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
613{
614 fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
615 Self {
616 event_source,
617 partial: AssistantMessage {
618 content: Vec::new(),
619 api,
620 provider,
621 model,
622 usage: Usage::default(),
623 stop_reason: StopReason::Stop,
624 error_message: None,
625 timestamp: chrono::Utc::now().timestamp_millis(),
626 },
627 pending_events: VecDeque::new(),
628 started: false,
629 finished: false,
630 transient_error_count: 0,
631 }
632 }
633
634 fn process_event(&mut self, data: &str) -> Result<()> {
635 let response: GeminiStreamResponse = serde_json::from_str(data)
636 .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;
637 self.process_response(response)
638 }
639
640 fn process_response(&mut self, response: GeminiStreamResponse) -> Result<()> {
641 if let Some(metadata) = response.usage_metadata {
643 self.partial.usage.input = metadata.prompt_token_count.unwrap_or(0);
644 self.partial.usage.output = metadata.candidates_token_count.unwrap_or(0);
645 self.partial.usage.total_tokens = metadata.total_token_count.unwrap_or(0);
646 }
647
648 if let Some(candidates) = response.candidates {
650 if let Some(candidate) = candidates.into_iter().next() {
651 self.process_candidate(candidate)?;
652 }
653 }
654
655 Ok(())
656 }
657
658 fn process_cloud_code_event(&mut self, data: &str) -> Result<()> {
659 let wrapped: CloudCodeAssistResponseChunk = serde_json::from_str(data)
660 .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;
661 let Some(response) = wrapped.response else {
662 return Ok(());
663 };
664 self.process_response(GeminiStreamResponse {
665 candidates: response.candidates,
666 usage_metadata: response.usage_metadata,
667 })
668 }
669
670 #[allow(clippy::unnecessary_wraps)]
671 fn process_candidate(&mut self, candidate: GeminiCandidate) -> Result<()> {
672 let has_finish_reason = candidate.finish_reason.is_some();
673
674 if let Some(reason) = candidate.finish_reason.as_deref() {
676 self.partial.stop_reason = match reason {
677 "MAX_TOKENS" => StopReason::Length,
678 "SAFETY" | "RECITATION" | "OTHER" => StopReason::Error,
679 "FUNCTION_CALL" => StopReason::ToolUse,
680 _ => StopReason::Stop,
682 };
683 }
684
685 if let Some(content) = candidate.content {
687 for part in content.parts {
688 match part {
689 GeminiPart::Text { text } => {
690 let last_is_text =
692 matches!(self.partial.content.last(), Some(ContentBlock::Text(_)));
693
694 self.ensure_started();
698
699 let content_index = if last_is_text {
700 self.partial.content.len() - 1
701 } else {
702 let idx = self.partial.content.len();
703 self.partial
704 .content
705 .push(ContentBlock::Text(TextContent::new("")));
706 self.pending_events
707 .push_back(StreamEvent::TextStart { content_index: idx });
708 idx
709 };
710
711 if let Some(ContentBlock::Text(t)) =
712 self.partial.content.get_mut(content_index)
713 {
714 t.text.push_str(&text);
715 }
716
717 self.pending_events.push_back(StreamEvent::TextDelta {
718 content_index,
719 delta: text,
720 });
721 }
722 GeminiPart::FunctionCall { function_call } => {
723 let id = format!("call_{}", uuid::Uuid::new_v4().simple());
725
726 let args_str = serde_json::to_string(&function_call.args)
728 .unwrap_or_else(|_| "{}".to_string());
729 let GeminiFunctionCall { name, args } = function_call;
730
731 let tool_call = ToolCall {
732 id,
733 name,
734 arguments: args,
735 thought_signature: None,
736 };
737
738 self.partial
739 .content
740 .push(ContentBlock::ToolCall(tool_call.clone()));
741 let content_index = self.partial.content.len() - 1;
742
743 self.partial.stop_reason = StopReason::ToolUse;
745
746 self.ensure_started();
747
748 self.pending_events
750 .push_back(StreamEvent::ToolCallStart { content_index });
751 self.pending_events.push_back(StreamEvent::ToolCallDelta {
752 content_index,
753 delta: args_str,
754 });
755 self.pending_events.push_back(StreamEvent::ToolCallEnd {
756 content_index,
757 tool_call,
758 });
759 }
760 GeminiPart::InlineData { .. }
761 | GeminiPart::FunctionResponse { .. }
762 | GeminiPart::Unknown(_) => {
763 }
767 }
768 }
769 }
770
771 if has_finish_reason {
774 for (content_index, block) in self.partial.content.iter().enumerate() {
775 if let ContentBlock::Text(t) = block {
776 self.pending_events.push_back(StreamEvent::TextEnd {
777 content_index,
778 content: t.text.clone(),
779 });
780 } else if let ContentBlock::Thinking(t) = block {
781 self.pending_events.push_back(StreamEvent::ThinkingEnd {
782 content_index,
783 content: t.thinking.clone(),
784 });
785 }
786 }
787 }
788
789 Ok(())
790 }
791
792 fn ensure_started(&mut self) {
793 if !self.started {
794 self.started = true;
795 self.pending_events.push_back(StreamEvent::Start {
796 partial: self.partial.clone(),
797 });
798 }
799 }
800}
801
/// Request body for Gemini `streamGenerateContent`, serialized with camelCase keys.
///
/// Also embedded as the `request` field of the Cloud Code Assist envelope in
/// Google CLI mode. Optional sections are omitted from the JSON when `None`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GeminiRequest {
    /// Conversation turns, in `Context::messages` order.
    pub(crate) contents: Vec<GeminiContent>,
    /// System prompt, sent as a content without a role.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) system_instruction: Option<GeminiContent>,
    /// Declared tools (function declarations).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) tools: Option<Vec<GeminiTool>>,
    /// Function-calling configuration.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) tool_config: Option<GeminiToolConfig>,
    /// Output length / sampling settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) generation_config: Option<GeminiGenerationConfig>,
}

/// One content entry: an optional role plus its parts.
///
/// This module produces the roles `user` and `model`; `None` (used for the
/// system instruction) is omitted from the JSON.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiContent {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) role: Option<String>,
    pub(crate) parts: Vec<GeminiPart>,
}
827
/// One element of `GeminiContent::parts`.
///
/// `#[serde(untagged)]`: on deserialization serde tries the variants in
/// declaration order and takes the first whose shape matches, so `Unknown`
/// catches any part that fits none of the preceding variants.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub(crate) enum GeminiPart {
    /// Plain text: `{"text": ...}`.
    Text {
        text: String,
    },
    /// Inline binary data (e.g. a base64 image).
    // NOTE(review): unlike its siblings this field has no `rename`, so it is
    // (de)serialized as `inline_data` rather than `inlineData`. Confirm the API
    // accepts this spelling on requests, and whether inbound `inlineData` parts
    // are meant to land in `Unknown`.
    InlineData {
        inline_data: GeminiBlob,
    },
    /// Model-issued function call: `{"functionCall": {...}}`.
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiFunctionCall,
    },
    /// Function result sent back to the model: `{"functionResponse": {...}}`.
    FunctionResponse {
        #[serde(rename = "functionResponse")]
        function_response: GeminiFunctionResponse,
    },
    /// Fallback for any other part shape, kept as raw JSON.
    Unknown(serde_json::Value),
}

/// Inline data payload (`mimeType` + `data` in JSON).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiBlob {
    pub(crate) mime_type: String,
    pub(crate) data: String,
}
857
858#[derive(Debug, Serialize, Deserialize)]
859pub(crate) struct GeminiFunctionCall {
860 pub(crate) name: String,
861 pub(crate) args: serde_json::Value,
862}
863
/// Tool result returned to the model, keyed by the tool's name.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct GeminiFunctionResponse {
    pub(crate) name: String,
    /// Arbitrary JSON; this module sends `{"result": ...}` or `{"error": ...}`.
    pub(crate) response: serde_json::Value,
}

/// A tool entry in the request (`functionDeclarations` in JSON).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiTool {
    pub(crate) function_declarations: Vec<GeminiFunctionDeclaration>,
}

/// Declaration of one callable function.
#[derive(Debug, Serialize)]
pub(crate) struct GeminiFunctionDeclaration {
    pub(crate) name: String,
    pub(crate) description: String,
    /// Parameter schema, forwarded unchanged from `ToolDef::parameters`.
    pub(crate) parameters: serde_json::Value,
}

/// `toolConfig` section of the request.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiToolConfig {
    pub(crate) function_calling_config: GeminiFunctionCallingConfig,
}

/// Function-calling mode; this module only sends `"AUTO"`.
#[derive(Debug, Serialize)]
pub(crate) struct GeminiFunctionCallingConfig {
    pub(crate) mode: &'static str,
}

/// `generationConfig` section; unset fields are omitted from the JSON.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) max_output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) candidate_count: Option<u32>,
}
904
/// One SSE chunk from the public Gemini streaming API. Both fields may be absent.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiStreamResponse {
    #[serde(default)]
    pub(crate) candidates: Option<Vec<GeminiCandidate>>,
    #[serde(default)]
    pub(crate) usage_metadata: Option<GeminiUsageMetadata>,
}

/// One SSE chunk from Cloud Code Assist: a Gemini response under `response`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CloudCodeAssistResponseChunk {
    #[serde(default)]
    response: Option<CloudCodeAssistResponse>,
}

/// Inner payload of a Cloud Code Assist chunk; same shape as `GeminiStreamResponse`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CloudCodeAssistResponse {
    #[serde(default)]
    candidates: Option<Vec<GeminiCandidate>>,
    #[serde(default)]
    usage_metadata: Option<GeminiUsageMetadata>,
}

/// A generated candidate: incremental content plus, on a terminal chunk, a finish reason.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiCandidate {
    #[serde(default)]
    pub(crate) content: Option<GeminiContent>,
    #[serde(default)]
    pub(crate) finish_reason: Option<String>,
}

/// Token accounting reported by the API; every counter is optional.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
// Field names deliberately mirror the API (`*TokenCount`), hence the shared suffix.
#[allow(clippy::struct_field_names)]
pub(crate) struct GeminiUsageMetadata {
    #[serde(default)]
    pub(crate) prompt_token_count: Option<u64>,
    #[serde(default)]
    pub(crate) candidates_token_count: Option<u64>,
    #[serde(default)]
    pub(crate) total_token_count: Option<u64>,
}
954
955pub(crate) fn convert_message_to_gemini(message: &Message) -> Vec<GeminiContent> {
960 match message {
961 Message::User(user) => vec![GeminiContent {
962 role: Some("user".into()),
963 parts: convert_user_content_to_parts(&user.content),
964 }],
965 Message::Custom(custom) => vec![GeminiContent {
966 role: Some("user".into()),
967 parts: vec![GeminiPart::Text {
968 text: custom.content.clone(),
969 }],
970 }],
971 Message::Assistant(assistant) => {
972 let mut parts = Vec::new();
973
974 for block in &assistant.content {
975 match block {
976 ContentBlock::Text(t) => {
977 parts.push(GeminiPart::Text {
978 text: t.text.clone(),
979 });
980 }
981 ContentBlock::ToolCall(tc) => {
982 parts.push(GeminiPart::FunctionCall {
983 function_call: GeminiFunctionCall {
984 name: tc.name.clone(),
985 args: tc.arguments.clone(),
986 },
987 });
988 }
989 ContentBlock::Thinking(_) | ContentBlock::Image(_) => {
990 }
992 }
993 }
994
995 if parts.is_empty() {
996 return Vec::new();
997 }
998
999 vec![GeminiContent {
1000 role: Some("model".into()),
1001 parts,
1002 }]
1003 }
1004 Message::ToolResult(result) => {
1005 let content_text = result
1007 .content
1008 .iter()
1009 .map(|b| match b {
1010 ContentBlock::Text(t) => t.text.clone(),
1011 ContentBlock::Image(img) => format!("[Image ({}) omitted]", img.mime_type),
1012 _ => String::new(),
1013 })
1014 .filter(|s| !s.is_empty())
1015 .collect::<Vec<_>>()
1016 .join("\n");
1017
1018 let response_value = if result.is_error {
1019 serde_json::json!({ "error": content_text })
1020 } else {
1021 serde_json::json!({ "result": content_text })
1022 };
1023
1024 vec![GeminiContent {
1025 role: Some("user".into()),
1026 parts: vec![GeminiPart::FunctionResponse {
1027 function_response: GeminiFunctionResponse {
1028 name: result.tool_name.clone(),
1029 response: response_value,
1030 },
1031 }],
1032 }]
1033 }
1034 }
1035}
1036
1037pub(crate) fn convert_user_content_to_parts(content: &UserContent) -> Vec<GeminiPart> {
1038 match content {
1039 UserContent::Text(text) => vec![GeminiPart::Text { text: text.clone() }],
1040 UserContent::Blocks(blocks) => blocks
1041 .iter()
1042 .filter_map(|block| match block {
1043 ContentBlock::Text(t) => Some(GeminiPart::Text {
1044 text: t.text.clone(),
1045 }),
1046 ContentBlock::Image(img) => Some(GeminiPart::InlineData {
1047 inline_data: GeminiBlob {
1048 mime_type: img.mime_type.clone(),
1049 data: img.data.clone(),
1050 },
1051 }),
1052 _ => None,
1053 })
1054 .collect(),
1055 }
1056}
1057
1058pub(crate) fn convert_tool_to_gemini(tool: &ToolDef) -> GeminiFunctionDeclaration {
1059 GeminiFunctionDeclaration {
1060 name: tool.name.clone(),
1061 description: tool.description.clone(),
1062 parameters: tool.parameters.clone(),
1063 }
1064}
1065
1066#[cfg(test)]
1071mod tests {
1072 use super::*;
1073 use asupersync::runtime::RuntimeBuilder;
1074 use futures::{StreamExt, stream};
1075 use serde::{Deserialize, Serialize};
1076 use serde_json::Value;
1077 use std::collections::HashMap;
1078 use std::io::{Read, Write};
1079 use std::net::TcpListener;
1080 use std::path::PathBuf;
1081 use std::sync::mpsc;
1082 use std::time::Duration;
1083
1084 #[test]
1085 fn test_convert_user_text_message() {
1086 let message = Message::User(crate::model::UserMessage {
1087 content: UserContent::Text("Hello".to_string()),
1088 timestamp: 0,
1089 });
1090
1091 let converted = convert_message_to_gemini(&message);
1092 assert_eq!(converted.len(), 1);
1093 assert_eq!(converted[0].role, Some("user".to_string()));
1094 }
1095
1096 #[test]
1097 fn test_tool_conversion() {
1098 let tool = ToolDef {
1099 name: "test_tool".to_string(),
1100 description: "A test tool".to_string(),
1101 parameters: serde_json::json!({
1102 "type": "object",
1103 "properties": {
1104 "arg": {"type": "string"}
1105 }
1106 }),
1107 };
1108
1109 let converted = convert_tool_to_gemini(&tool);
1110 assert_eq!(converted.name, "test_tool");
1111 assert_eq!(converted.description, "A test tool");
1112 }
1113
1114 #[test]
1115 fn test_provider_info() {
1116 let provider = GeminiProvider::new("gemini-2.0-flash");
1117 assert_eq!(provider.name(), "google");
1118 assert_eq!(provider.api(), "google-generative-ai");
1119 }
1120
1121 #[test]
1122 fn test_streaming_url() {
1123 let provider = GeminiProvider::new("gemini-2.0-flash");
1124 let url = provider.streaming_url();
1125 assert!(url.contains("gemini-2.0-flash"));
1126 assert!(url.contains("streamGenerateContent"));
1127 assert!(!url.contains("key="));
1128 }
1129
1130 #[derive(Debug, Deserialize)]
1131 struct ProviderFixture {
1132 cases: Vec<ProviderCase>,
1133 }
1134
1135 #[derive(Debug, Deserialize)]
1136 struct ProviderCase {
1137 name: String,
1138 events: Vec<Value>,
1139 expected: Vec<EventSummary>,
1140 }
1141
1142 #[derive(Debug, Deserialize, Serialize, PartialEq)]
1143 struct EventSummary {
1144 kind: String,
1145 #[serde(default)]
1146 content_index: Option<usize>,
1147 #[serde(default)]
1148 delta: Option<String>,
1149 #[serde(default)]
1150 content: Option<String>,
1151 #[serde(default)]
1152 reason: Option<String>,
1153 }
1154
1155 #[test]
1156 fn test_stream_fixtures() {
1157 let fixture = load_fixture("gemini_stream.json");
1158 for case in fixture.cases {
1159 let events = collect_events(&case.events);
1160 let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
1161 assert_eq!(summaries, case.expected, "case {}", case.name);
1162 }
1163 }
1164
1165 #[derive(Debug)]
1166 struct CapturedRequest {
1167 headers: HashMap<String, String>,
1168 body: String,
1169 }
1170
1171 #[test]
1172 fn test_stream_compat_google_api_key_header_works_without_api_key() {
1173 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1174 let mut custom_headers = HashMap::new();
1175 custom_headers.insert(
1176 "x-goog-api-key".to_string(),
1177 "compat-google-key".to_string(),
1178 );
1179 let provider = GeminiProvider::new("gemini-2.0-flash")
1180 .with_base_url(base_url)
1181 .with_compat(Some(CompatConfig {
1182 custom_headers: Some(custom_headers),
1183 ..CompatConfig::default()
1184 }));
1185 let context = Context::owned(
1186 None,
1187 vec![Message::User(crate::model::UserMessage {
1188 content: UserContent::Text("ping".to_string()),
1189 timestamp: 0,
1190 })],
1191 Vec::new(),
1192 );
1193
1194 let runtime = RuntimeBuilder::current_thread()
1195 .build()
1196 .expect("runtime build");
1197 runtime.block_on(async {
1198 let mut stream = provider
1199 .stream(&context, &StreamOptions::default())
1200 .await
1201 .expect("stream");
1202 while let Some(event) = stream.next().await {
1203 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1204 break;
1205 }
1206 }
1207 });
1208
1209 let captured = rx.recv_timeout(Duration::from_secs(2)).expect("captured");
1210 assert_eq!(
1211 captured.headers.get("x-goog-api-key").map(String::as_str),
1212 Some("compat-google-key")
1213 );
1214 let body: Value = serde_json::from_str(&captured.body).expect("body json");
1215 assert_eq!(body["contents"][0]["role"], "user");
1216 }
1217
1218 #[test]
1219 fn test_stream_option_authorization_header_works_without_api_key() {
1220 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1221 let provider = GeminiProvider::new("gemini-2.0-flash").with_base_url(base_url);
1222 let context = Context::owned(
1223 None,
1224 vec![Message::User(crate::model::UserMessage {
1225 content: UserContent::Text("ping".to_string()),
1226 timestamp: 0,
1227 })],
1228 Vec::new(),
1229 );
1230 let mut headers = HashMap::new();
1231 headers.insert(
1232 "Authorization".to_string(),
1233 "Bearer compat-gemini-token".to_string(),
1234 );
1235
1236 let runtime = RuntimeBuilder::current_thread()
1237 .build()
1238 .expect("runtime build");
1239 runtime.block_on(async {
1240 let mut stream = provider
1241 .stream(
1242 &context,
1243 &StreamOptions {
1244 headers,
1245 ..Default::default()
1246 },
1247 )
1248 .await
1249 .expect("stream");
1250 while let Some(event) = stream.next().await {
1251 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1252 break;
1253 }
1254 }
1255 });
1256
1257 let captured = rx.recv_timeout(Duration::from_secs(2)).expect("captured");
1258 assert_eq!(
1259 captured.headers.get("authorization").map(String::as_str),
1260 Some("Bearer compat-gemini-token")
1261 );
1262 let body: Value = serde_json::from_str(&captured.body).expect("body json");
1263 assert_eq!(body["contents"][0]["role"], "user");
1264 }
1265
1266 #[test]
1267 fn test_blank_request_google_api_key_header_does_not_override_builtin_api_key() {
1268 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1269 let provider = GeminiProvider::new("gemini-2.0-flash").with_base_url(base_url);
1270 let context = Context::owned(
1271 None,
1272 vec![Message::User(crate::model::UserMessage {
1273 content: UserContent::Text("ping".to_string()),
1274 timestamp: 0,
1275 })],
1276 Vec::new(),
1277 );
1278 let mut headers = HashMap::new();
1279 headers.insert("X-Goog-Api-Key".to_string(), " ".to_string());
1280 let options = StreamOptions {
1281 api_key: Some("fallback-google-key".to_string()),
1282 headers,
1283 ..Default::default()
1284 };
1285
1286 let runtime = RuntimeBuilder::current_thread()
1287 .build()
1288 .expect("runtime build");
1289 runtime.block_on(async {
1290 let mut stream = provider.stream(&context, &options).await.expect("stream");
1291 while let Some(event) = stream.next().await {
1292 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1293 break;
1294 }
1295 }
1296 });
1297
1298 let captured = rx.recv_timeout(Duration::from_secs(2)).expect("captured");
1299 assert_eq!(
1300 captured.headers.get("x-goog-api-key").map(String::as_str),
1301 Some("fallback-google-key")
1302 );
1303 let body: Value = serde_json::from_str(&captured.body).expect("body json");
1304 assert_eq!(body["contents"][0]["role"], "user");
1305 }
1306
1307 fn load_fixture(file_name: &str) -> ProviderFixture {
1308 let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1309 .join("tests/fixtures/provider_responses")
1310 .join(file_name);
1311 let raw = std::fs::read_to_string(path).expect("fixture read");
1312 serde_json::from_str(&raw).expect("fixture parse")
1313 }
1314
1315 fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
1316 let runtime = RuntimeBuilder::current_thread()
1317 .build()
1318 .expect("runtime build");
1319 runtime.block_on(async move {
1320 let byte_stream = stream::iter(
1321 events
1322 .iter()
1323 .map(|event| {
1324 let data = match event {
1325 Value::String(text) => text.clone(),
1326 _ => serde_json::to_string(event).expect("serialize event"),
1327 };
1328 format!("data: {data}\n\n").into_bytes()
1329 })
1330 .map(Ok),
1331 );
1332 let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1333 let mut state = StreamState::new(
1334 event_source,
1335 "gemini-test".to_string(),
1336 "google-generative".to_string(),
1337 "google".to_string(),
1338 );
1339 let mut out = Vec::new();
1340
1341 loop {
1342 let Some(item) = state.event_source.next().await else {
1343 if !state.finished {
1344 state.finished = true;
1345 out.push(StreamEvent::Done {
1346 reason: state.partial.stop_reason,
1347 message: std::mem::take(&mut state.partial),
1348 });
1349 }
1350 break;
1351 };
1352
1353 let msg = item.expect("SSE event");
1354 if msg.event == "ping" {
1355 continue;
1356 }
1357 state.process_event(&msg.data).expect("process_event");
1358 out.extend(state.pending_events.drain(..));
1359 }
1360
1361 out
1362 })
1363 }
1364
1365 fn success_sse_body() -> String {
1366 [
1367 r#"data: {"candidates":[{"content":{"parts":[{"text":"ok"}]}}]}"#,
1368 "",
1369 r#"data: {"candidates":[{"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}"#,
1370 "",
1371 ]
1372 .join("\n")
1373 }
1374
1375 fn spawn_test_server(
1376 status_code: u16,
1377 content_type: &str,
1378 body: &str,
1379 ) -> (String, mpsc::Receiver<CapturedRequest>) {
1380 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
1381 let addr = listener.local_addr().expect("local addr");
1382 let (tx, rx) = mpsc::channel();
1383 let body = body.to_string();
1384 let content_type = content_type.to_string();
1385
1386 std::thread::spawn(move || {
1387 let (mut socket, _) = listener.accept().expect("accept");
1388 socket
1389 .set_read_timeout(Some(Duration::from_secs(2)))
1390 .expect("set read timeout");
1391
1392 let mut bytes = Vec::new();
1393 let mut chunk = [0_u8; 4096];
1394 loop {
1395 match socket.read(&mut chunk) {
1396 Ok(0) => break,
1397 Ok(n) => {
1398 bytes.extend_from_slice(&chunk[..n]);
1399 if bytes.windows(4).any(|window| window == b"\r\n\r\n") {
1400 break;
1401 }
1402 }
1403 Err(err)
1404 if err.kind() == std::io::ErrorKind::WouldBlock
1405 || err.kind() == std::io::ErrorKind::TimedOut =>
1406 {
1407 break;
1408 }
1409 Err(err) => panic!("{err}"),
1410 }
1411 }
1412
1413 let header_end = bytes
1414 .windows(4)
1415 .position(|window| window == b"\r\n\r\n")
1416 .expect("request header boundary");
1417 let header_text = String::from_utf8_lossy(&bytes[..header_end]).to_string();
1418 let headers = parse_headers(&header_text);
1419 let mut request_body = bytes[header_end + 4..].to_vec();
1420
1421 let content_length = headers
1422 .get("content-length")
1423 .and_then(|value| value.parse::<usize>().ok())
1424 .unwrap_or(0);
1425 while request_body.len() < content_length {
1426 match socket.read(&mut chunk) {
1427 Ok(0) => break,
1428 Ok(n) => request_body.extend_from_slice(&chunk[..n]),
1429 Err(err)
1430 if err.kind() == std::io::ErrorKind::WouldBlock
1431 || err.kind() == std::io::ErrorKind::TimedOut =>
1432 {
1433 break;
1434 }
1435 Err(err) => panic!("{err}"),
1436 }
1437 }
1438
1439 tx.send(CapturedRequest {
1440 headers,
1441 body: String::from_utf8_lossy(&request_body).to_string(),
1442 })
1443 .expect("send captured request");
1444
1445 let reason = match status_code {
1446 401 => "Unauthorized",
1447 500 => "Internal Server Error",
1448 _ => "OK",
1449 };
1450 let response = format!(
1451 "HTTP/1.1 {status_code} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
1452 body.len()
1453 );
1454 socket
1455 .write_all(response.as_bytes())
1456 .expect("write response");
1457 socket.flush().expect("flush response");
1458 });
1459
1460 (format!("http://{addr}"), rx)
1461 }
1462
1463 fn parse_headers(header_text: &str) -> HashMap<String, String> {
1464 let mut headers = HashMap::new();
1465 for line in header_text.lines().skip(1) {
1466 if let Some((name, value)) = line.split_once(':') {
1467 headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string());
1468 }
1469 }
1470 headers
1471 }
1472
1473 fn summarize_event(event: &StreamEvent) -> EventSummary {
1474 match event {
1475 StreamEvent::Start { .. } => EventSummary {
1476 kind: "start".to_string(),
1477 content_index: None,
1478 delta: None,
1479 content: None,
1480 reason: None,
1481 },
1482 StreamEvent::TextDelta {
1483 content_index,
1484 delta,
1485 ..
1486 } => EventSummary {
1487 kind: "text_delta".to_string(),
1488 content_index: Some(*content_index),
1489 delta: Some(delta.clone()),
1490 content: None,
1491 reason: None,
1492 },
1493 StreamEvent::Done { reason, .. } => EventSummary {
1494 kind: "done".to_string(),
1495 content_index: None,
1496 delta: None,
1497 content: None,
1498 reason: Some(reason_to_string(*reason)),
1499 },
1500 StreamEvent::Error { reason, .. } => EventSummary {
1501 kind: "error".to_string(),
1502 content_index: None,
1503 delta: None,
1504 content: None,
1505 reason: Some(reason_to_string(*reason)),
1506 },
1507 StreamEvent::TextStart { content_index, .. } => EventSummary {
1508 kind: "text_start".to_string(),
1509 content_index: Some(*content_index),
1510 delta: None,
1511 content: None,
1512 reason: None,
1513 },
1514 StreamEvent::TextEnd {
1515 content_index,
1516 content,
1517 ..
1518 } => EventSummary {
1519 kind: "text_end".to_string(),
1520 content_index: Some(*content_index),
1521 delta: None,
1522 content: Some(content.clone()),
1523 reason: None,
1524 },
1525 _ => EventSummary {
1526 kind: "other".to_string(),
1527 content_index: None,
1528 delta: None,
1529 content: None,
1530 reason: None,
1531 },
1532 }
1533 }
1534
1535 fn reason_to_string(reason: StopReason) -> String {
1536 match reason {
1537 StopReason::Stop => "stop",
1538 StopReason::Length => "length",
1539 StopReason::ToolUse => "tool_use",
1540 StopReason::Error => "error",
1541 StopReason::Aborted => "aborted",
1542 }
1543 .to_string()
1544 }
1545
1546 #[test]
1549 fn test_build_request_basic_text() {
1550 let provider = GeminiProvider::new("gemini-2.0-flash");
1551 let context = Context::owned(
1552 Some("You are helpful.".to_string()),
1553 vec![Message::User(crate::model::UserMessage {
1554 content: UserContent::Text("What is Rust?".to_string()),
1555 timestamp: 0,
1556 })],
1557 vec![],
1558 );
1559 let options = crate::provider::StreamOptions {
1560 max_tokens: Some(1024),
1561 temperature: Some(0.7),
1562 ..Default::default()
1563 };
1564
1565 let req = provider.build_request(&context, &options);
1566 let json = serde_json::to_value(&req).expect("serialize");
1567
1568 let contents = json["contents"].as_array().expect("contents array");
1570 assert_eq!(contents.len(), 1);
1571 assert_eq!(contents[0]["role"], "user");
1572 assert_eq!(contents[0]["parts"][0]["text"], "What is Rust?");
1573
1574 assert_eq!(
1576 json["systemInstruction"]["parts"][0]["text"],
1577 "You are helpful."
1578 );
1579
1580 assert!(json.get("tools").is_none() || json["tools"].is_null());
1582
1583 assert_eq!(json["generationConfig"]["maxOutputTokens"], 1024);
1585 assert!((json["generationConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.01);
1586 assert_eq!(json["generationConfig"]["candidateCount"], 1);
1587 }
1588
1589 #[test]
1590 fn test_build_request_with_tools() {
1591 let provider = GeminiProvider::new("gemini-2.0-flash");
1592 let context = Context::owned(
1593 None,
1594 vec![Message::User(crate::model::UserMessage {
1595 content: UserContent::Text("Read a file".to_string()),
1596 timestamp: 0,
1597 })],
1598 vec![
1599 ToolDef {
1600 name: "read".to_string(),
1601 description: "Read a file".to_string(),
1602 parameters: serde_json::json!({
1603 "type": "object",
1604 "properties": {
1605 "path": {"type": "string"}
1606 },
1607 "required": ["path"]
1608 }),
1609 },
1610 ToolDef {
1611 name: "write".to_string(),
1612 description: "Write a file".to_string(),
1613 parameters: serde_json::json!({
1614 "type": "object",
1615 "properties": {
1616 "path": {"type": "string"},
1617 "content": {"type": "string"}
1618 }
1619 }),
1620 },
1621 ],
1622 );
1623 let options = crate::provider::StreamOptions::default();
1624
1625 let req = provider.build_request(&context, &options);
1626 let json = serde_json::to_value(&req).expect("serialize");
1627
1628 assert!(json.get("systemInstruction").is_none() || json["systemInstruction"].is_null());
1630
1631 let tools = json["tools"].as_array().expect("tools array");
1633 assert_eq!(tools.len(), 1);
1634 let declarations = tools[0]["functionDeclarations"]
1635 .as_array()
1636 .expect("declarations");
1637 assert_eq!(declarations.len(), 2);
1638 assert_eq!(declarations[0]["name"], "read");
1639 assert_eq!(declarations[1]["name"], "write");
1640 assert_eq!(declarations[0]["description"], "Read a file");
1641
1642 assert_eq!(json["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");
1644 }
1645
1646 #[test]
1647 fn test_build_request_default_max_tokens() {
1648 let provider = GeminiProvider::new("gemini-2.0-flash");
1649 let context = Context::owned(
1650 None,
1651 vec![Message::User(crate::model::UserMessage {
1652 content: UserContent::Text("hi".to_string()),
1653 timestamp: 0,
1654 })],
1655 vec![],
1656 );
1657 let options = crate::provider::StreamOptions::default();
1658
1659 let req = provider.build_request(&context, &options);
1660 let json = serde_json::to_value(&req).expect("serialize");
1661
1662 assert_eq!(
1664 json["generationConfig"]["maxOutputTokens"],
1665 DEFAULT_MAX_TOKENS
1666 );
1667 }
1668
1669 #[test]
1672 fn test_streaming_url_no_key_query_param() {
1673 let provider = GeminiProvider::new("gemini-2.0-flash");
1674 let url = provider.streaming_url();
1675
1676 assert!(
1678 !url.contains("key="),
1679 "API key should not be in query param"
1680 );
1681 assert!(url.contains("alt=sse"), "alt=sse should be present");
1682 assert!(
1683 url.contains("streamGenerateContent"),
1684 "should use streaming endpoint"
1685 );
1686 }
1687
1688 #[test]
1689 fn test_streaming_url_custom_base() {
1690 let provider =
1691 GeminiProvider::new("gemini-pro").with_base_url("https://custom.example.com/v1");
1692 let url = provider.streaming_url();
1693
1694 assert!(url.starts_with("https://custom.example.com/v1/models/gemini-pro"));
1695 assert!(!url.contains("key="));
1696 }
1697
1698 #[test]
1701 fn test_convert_user_text_to_gemini_parts() {
1702 let parts = convert_user_content_to_parts(&UserContent::Text("hello world".to_string()));
1703 assert_eq!(parts.len(), 1);
1704 match &parts[0] {
1705 GeminiPart::Text { text } => assert_eq!(text, "hello world"),
1706 _ => panic!(),
1707 }
1708 }
1709
1710 #[test]
1711 fn test_convert_user_blocks_with_image_to_gemini_parts() {
1712 let content = UserContent::Blocks(vec![
1713 ContentBlock::Text(TextContent::new("describe this")),
1714 ContentBlock::Image(crate::model::ImageContent {
1715 data: "aGVsbG8=".to_string(),
1716 mime_type: "image/png".to_string(),
1717 }),
1718 ]);
1719
1720 let parts = convert_user_content_to_parts(&content);
1721 assert_eq!(parts.len(), 2);
1722 match &parts[0] {
1723 GeminiPart::Text { text } => assert_eq!(text, "describe this"),
1724 _ => panic!(),
1725 }
1726 match &parts[1] {
1727 GeminiPart::InlineData { inline_data } => {
1728 assert_eq!(inline_data.mime_type, "image/png");
1729 assert_eq!(inline_data.data, "aGVsbG8=");
1730 }
1731 _ => panic!(),
1732 }
1733 }
1734
1735 #[test]
1736 fn test_convert_assistant_message_with_tool_call() {
1737 let message = Message::assistant(AssistantMessage {
1738 content: vec![
1739 ContentBlock::Text(TextContent::new("Let me read that file.")),
1740 ContentBlock::ToolCall(ToolCall {
1741 id: "call_123".to_string(),
1742 name: "read".to_string(),
1743 arguments: serde_json::json!({"path": "/tmp/test.txt"}),
1744 thought_signature: None,
1745 }),
1746 ],
1747 api: "google".to_string(),
1748 provider: "google".to_string(),
1749 model: "gemini-2.0-flash".to_string(),
1750 usage: Usage::default(),
1751 stop_reason: StopReason::ToolUse,
1752 error_message: None,
1753 timestamp: 0,
1754 });
1755
1756 let converted = convert_message_to_gemini(&message);
1757 assert_eq!(converted.len(), 1);
1758 assert_eq!(converted[0].role, Some("model".to_string()));
1759 assert_eq!(converted[0].parts.len(), 2);
1760
1761 match &converted[0].parts[0] {
1762 GeminiPart::Text { text } => assert_eq!(text, "Let me read that file."),
1763 _ => panic!(),
1764 }
1765 match &converted[0].parts[1] {
1766 GeminiPart::FunctionCall { function_call } => {
1767 assert_eq!(function_call.name, "read");
1768 assert_eq!(function_call.args["path"], "/tmp/test.txt");
1769 }
1770 _ => panic!(),
1771 }
1772 }
1773
1774 #[test]
1775 fn test_convert_assistant_empty_content_returns_empty() {
1776 let message = Message::assistant(AssistantMessage {
1777 content: vec![],
1778 api: "google".to_string(),
1779 provider: "google".to_string(),
1780 model: "gemini-2.0-flash".to_string(),
1781 usage: Usage::default(),
1782 stop_reason: StopReason::Stop,
1783 error_message: None,
1784 timestamp: 0,
1785 });
1786
1787 let converted = convert_message_to_gemini(&message);
1788 assert!(converted.is_empty());
1789 }
1790
1791 #[test]
1792 fn test_convert_tool_result_success() {
1793 let message = Message::tool_result(crate::model::ToolResultMessage {
1794 tool_call_id: "call_123".to_string(),
1795 tool_name: "read".to_string(),
1796 content: vec![ContentBlock::Text(TextContent::new("file contents here"))],
1797 details: None,
1798 is_error: false,
1799 timestamp: 0,
1800 });
1801
1802 let converted = convert_message_to_gemini(&message);
1803 assert_eq!(converted.len(), 1);
1804 assert_eq!(converted[0].role, Some("user".to_string()));
1805
1806 match &converted[0].parts[0] {
1807 GeminiPart::FunctionResponse { function_response } => {
1808 assert_eq!(function_response.name, "read");
1809 assert_eq!(function_response.response["result"], "file contents here");
1810 assert!(function_response.response.get("error").is_none());
1811 }
1812 _ => panic!(),
1813 }
1814 }
1815
1816 #[test]
1817 fn test_convert_tool_result_error() {
1818 let message = Message::tool_result(crate::model::ToolResultMessage {
1819 tool_call_id: "call_456".to_string(),
1820 tool_name: "bash".to_string(),
1821 content: vec![ContentBlock::Text(TextContent::new("command not found"))],
1822 details: None,
1823 is_error: true,
1824 timestamp: 0,
1825 });
1826
1827 let converted = convert_message_to_gemini(&message);
1828 assert_eq!(converted.len(), 1);
1829
1830 match &converted[0].parts[0] {
1831 GeminiPart::FunctionResponse { function_response } => {
1832 assert_eq!(function_response.name, "bash");
1833 assert_eq!(function_response.response["error"], "command not found");
1834 assert!(function_response.response.get("result").is_none());
1835 }
1836 _ => panic!(),
1837 }
1838 }
1839
1840 #[test]
1841 fn test_convert_custom_message() {
1842 let message = Message::Custom(crate::model::CustomMessage {
1843 custom_type: "system_note".to_string(),
1844 content: "Context window approaching limit.".to_string(),
1845 display: false,
1846 details: None,
1847 timestamp: 0,
1848 });
1849
1850 let converted = convert_message_to_gemini(&message);
1851 assert_eq!(converted.len(), 1);
1852 assert_eq!(converted[0].role, Some("user".to_string()));
1853 match &converted[0].parts[0] {
1854 GeminiPart::Text { text } => {
1855 assert_eq!(text, "Context window approaching limit.");
1856 }
1857 _ => panic!(),
1858 }
1859 }
1860
1861 #[test]
1864 fn test_stop_reason_mapping() {
1865 let test_cases = vec![
1867 ("STOP", StopReason::Stop),
1868 ("MAX_TOKENS", StopReason::Length),
1869 ("SAFETY", StopReason::Error),
1870 ("RECITATION", StopReason::Error),
1871 ("OTHER", StopReason::Error),
1872 ("UNKNOWN_REASON", StopReason::Stop), ];
1874
1875 for (reason_str, expected) in test_cases {
1876 let candidate = GeminiCandidate {
1877 content: None,
1878 finish_reason: Some(reason_str.to_string()),
1879 };
1880
1881 let runtime = RuntimeBuilder::current_thread().build().unwrap();
1882 runtime.block_on(async {
1883 let byte_stream = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
1884 let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1885 let mut state = StreamState::new(
1886 event_source,
1887 "test".to_string(),
1888 "test".to_string(),
1889 "test".to_string(),
1890 );
1891 state.process_candidate(candidate).unwrap();
1892 assert_eq!(
1893 state.partial.stop_reason, expected,
1894 "finish_reason '{reason_str}' should map to {expected:?}"
1895 );
1896 });
1897 }
1898 }
1899
1900 #[test]
1901 fn test_usage_metadata_parsing() {
1902 let data = r#"{
1903 "usageMetadata": {
1904 "promptTokenCount": 42,
1905 "candidatesTokenCount": 100,
1906 "totalTokenCount": 142
1907 }
1908 }"#;
1909
1910 let runtime = RuntimeBuilder::current_thread().build().unwrap();
1911 runtime.block_on(async {
1912 let byte_stream = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
1913 let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1914 let mut state = StreamState::new(
1915 event_source,
1916 "test".to_string(),
1917 "test".to_string(),
1918 "test".to_string(),
1919 );
1920 state.process_event(data).unwrap();
1921 assert_eq!(state.partial.usage.input, 42);
1922 assert_eq!(state.partial.usage.output, 100);
1923 assert_eq!(state.partial.usage.total_tokens, 142);
1924 });
1925 }
1926
1927 #[test]
1930 fn test_build_request_full_conversation() {
1931 let provider = GeminiProvider::new("gemini-2.0-flash");
1932 let context = Context::owned(
1933 Some("Be concise.".to_string()),
1934 vec![
1935 Message::User(crate::model::UserMessage {
1936 content: UserContent::Text("Read /tmp/a.txt".to_string()),
1937 timestamp: 0,
1938 }),
1939 Message::assistant(AssistantMessage {
1940 content: vec![ContentBlock::ToolCall(ToolCall {
1941 id: "call_1".to_string(),
1942 name: "read".to_string(),
1943 arguments: serde_json::json!({"path": "/tmp/a.txt"}),
1944 thought_signature: None,
1945 })],
1946 api: "google".to_string(),
1947 provider: "google".to_string(),
1948 model: "gemini-2.0-flash".to_string(),
1949 usage: Usage::default(),
1950 stop_reason: StopReason::ToolUse,
1951 error_message: None,
1952 timestamp: 1,
1953 }),
1954 Message::tool_result(crate::model::ToolResultMessage {
1955 tool_call_id: "call_1".to_string(),
1956 tool_name: "read".to_string(),
1957 content: vec![ContentBlock::Text(TextContent::new("file contents"))],
1958 details: None,
1959 is_error: false,
1960 timestamp: 2,
1961 }),
1962 ],
1963 vec![ToolDef {
1964 name: "read".to_string(),
1965 description: "Read a file".to_string(),
1966 parameters: serde_json::json!({"type": "object"}),
1967 }],
1968 );
1969 let options = crate::provider::StreamOptions::default();
1970
1971 let req = provider.build_request(&context, &options);
1972 let json = serde_json::to_value(&req).expect("serialize");
1973
1974 let contents = json["contents"].as_array().expect("contents");
1975 assert_eq!(contents.len(), 3); assert_eq!(contents[0]["role"], "user");
1979 assert_eq!(contents[0]["parts"][0]["text"], "Read /tmp/a.txt");
1980
1981 assert_eq!(contents[1]["role"], "model");
1983 assert_eq!(contents[1]["parts"][0]["functionCall"]["name"], "read");
1984
1985 assert_eq!(contents[2]["role"], "user");
1987 assert_eq!(contents[2]["parts"][0]["functionResponse"]["name"], "read");
1988 assert_eq!(
1989 contents[2]["parts"][0]["functionResponse"]["response"]["result"],
1990 "file contents"
1991 );
1992 }
1993
    /// Property tests asserting that `StreamState::process_event` never panics,
    /// whether fed well-formed Gemini stream payloads, malformed/chaotic text,
    /// or arbitrary sequences mixing both. The returned `Result`s are
    /// deliberately ignored: errors are acceptable, panics are not.
    mod proptest_process_event {
        use super::*;
        use proptest::prelude::*;

        /// Builds a fresh `StreamState` over an empty byte stream. Tests inject
        /// payloads directly via `process_event`, so the stream is never read.
        fn make_state()
        -> StreamState<impl Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin>
        {
            let empty = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
            let sse = crate::sse::SseStream::new(Box::pin(empty));
            StreamState::new(
                sse,
                "gemini-test".into(),
                "google-generative".into(),
                "google".into(),
            )
        }

        /// Short strings: empty, identifier-like (1-16 chars), or up to 32
        /// chars of arbitrary printable ASCII.
        fn small_string() -> impl Strategy<Value = String> {
            prop_oneof![Just(String::new()), "[a-zA-Z0-9_]{1,16}", "[ -~]{0,32}",]
        }

        /// Token counts weighted toward small values and zero, with occasional
        /// values at or just below `u64::MAX`. NOTE(review): the extreme values
        /// presumably exercise overflow handling in usage accounting; confirm.
        fn token_count() -> impl Strategy<Value = u64> {
            prop_oneof![
                5 => 0u64..10_000u64,
                2 => Just(0u64),
                1 => Just(u64::MAX),
                1 => (u64::MAX - 100)..=u64::MAX,
            ]
        }

        /// Optional `finishReason`. It is absent most often (weight 3);
        /// otherwise it is one of the known Gemini reasons or an arbitrary
        /// short string.
        fn finish_reason() -> impl Strategy<Value = Option<String>> {
            prop_oneof![
                3 => Just(None),
                1 => Just(Some("STOP".to_string())),
                1 => Just(Some("MAX_TOKENS".to_string())),
                1 => Just(Some("SAFETY".to_string())),
                1 => Just(Some("RECITATION".to_string())),
                1 => Just(Some("OTHER".to_string())),
                1 => small_string().prop_map(Some),
            ]
        }

        /// Function-call argument objects: a few fixed shapes (empty, string
        /// field, mixed scalar/null fields) plus one arbitrary string value.
        fn json_args() -> impl Strategy<Value = serde_json::Value> {
            prop_oneof![
                Just(serde_json::json!({})),
                Just(serde_json::json!({"key": "value"})),
                Just(serde_json::json!({"a": 1, "b": true, "c": null})),
                small_string().prop_map(|s| serde_json::json!({"input": s})),
            ]
        }

        /// A `{"text": ...}` content part.
        fn text_part() -> impl Strategy<Value = serde_json::Value> {
            small_string().prop_map(|t| serde_json::json!({"text": t}))
        }

        /// A `{"functionCall": {"name": ..., "args": ...}}` content part.
        fn function_call_part() -> impl Strategy<Value = serde_json::Value> {
            (small_string(), json_args()).prop_map(
                |(name, args)| serde_json::json!({"functionCall": {"name": name, "args": args}}),
            )
        }

        /// Zero to four content parts; text parts are three times as likely as
        /// function-call parts.
        fn parts_strategy() -> impl Strategy<Value = Vec<serde_json::Value>> {
            prop::collection::vec(
                prop_oneof![3 => text_part(), 1 => function_call_part(),],
                0..5,
            )
        }

        /// Serialized, structurally plausible Gemini stream payloads.
        fn gemini_response_json() -> impl Strategy<Value = String> {
            prop_oneof![
                // Candidate with content parts and an optional finishReason.
                3 => (parts_strategy(), finish_reason()).prop_map(|(parts, fr)| {
                    let mut candidate = serde_json::json!({
                        "content": {"parts": parts}
                    });
                    if let Some(r) = fr {
                        candidate["finishReason"] = serde_json::Value::String(r);
                    }
                    serde_json::json!({"candidates": [candidate]}).to_string()
                }),
                // Usage metadata only, no candidates.
                2 => (token_count(), token_count(), token_count()).prop_map(|(p, c, t)| {
                    serde_json::json!({
                        "usageMetadata": {
                            "promptTokenCount": p,
                            "candidatesTokenCount": c,
                            "totalTokenCount": t
                        }
                    })
                    .to_string()
                }),
                // Empty candidates array.
                1 => Just(r#"{"candidates":[]}"#.to_string()),
                // Empty object.
                1 => Just(r"{}".to_string()),
                // Candidate carrying only a finishReason; the `None` draws of
                // `finish_reason()` are rejected locally by the filter.
                1 => finish_reason()
                    .prop_filter_map("some reason", |fr| fr)
                    .prop_map(|reason| {
                        serde_json::json!({
                            "candidates": [{"finishReason": reason}]
                        })
                        .to_string()
                    }),
                // Content, optional finishReason and usage metadata together.
                2 => (parts_strategy(), finish_reason(), token_count(), token_count(), token_count())
                    .prop_map(|(parts, fr, p, c, t)| {
                        let mut candidate = serde_json::json!({
                            "content": {"parts": parts}
                        });
                        if let Some(r) = fr {
                            candidate["finishReason"] = serde_json::Value::String(r);
                        }
                        serde_json::json!({
                            "candidates": [candidate],
                            "usageMetadata": {
                                "promptTokenCount": p,
                                "candidatesTokenCount": c,
                                "totalTokenCount": t
                            }
                        })
                        .to_string()
                    }),
            ]
        }

        /// Payloads outside the expected shape: empty input, bare JSON values,
        /// truncated JSON, fields of the wrong type or null, and up to 64
        /// chars of random printable ASCII.
        fn chaos_json() -> impl Strategy<Value = String> {
            prop_oneof![
                Just(String::new()),
                Just("{}".to_string()),
                Just("[]".to_string()),
                Just("null".to_string()),
                Just("{".to_string()),
                Just(r#"{"candidates":"not_array"}"#.to_string()),
                Just(r#"{"candidates":[{"content":null}]}"#.to_string()),
                Just(r#"{"candidates":[{"content":{"parts":"not_array"}}]}"#.to_string()),
                "[ -~]{0,64}",
            ]
        }

        proptest! {
            // 256 cases per property; shrinking capped at 100 iterations.
            #![proptest_config(ProptestConfig {
                cases: 256,
                max_shrink_iters: 100,
                .. ProptestConfig::default()
            })]

            // A single well-formed payload on a fresh state must not panic.
            #[test]
            fn process_event_valid_never_panics(data in gemini_response_json()) {
                let mut state = make_state();
                let _ = state.process_event(&data);
            }

            // A single malformed payload on a fresh state must not panic.
            #[test]
            fn process_event_chaos_never_panics(data in chaos_json()) {
                let mut state = make_state();
                let _ = state.process_event(&data);
            }

            // 1-7 well-formed payloads fed to the same state must not panic,
            // exercising state carried across events.
            #[test]
            fn process_event_sequence_never_panics(
                events in prop::collection::vec(gemini_response_json(), 1..8)
            ) {
                let mut state = make_state();
                for event in &events {
                    let _ = state.process_event(event);
                }
            }

            // 1-11 payloads drawn from both generators, fed to the same state,
            // must not panic even after earlier errors.
            #[test]
            fn process_event_mixed_sequence_never_panics(
                events in prop::collection::vec(
                    prop_oneof![gemini_response_json(), chaos_json()],
                    1..12
                )
            ) {
                let mut state = make_state();
                for event in &events {
                    let _ = state.process_event(event);
                }
            }
        }
    }
2187}
2188
#[cfg(feature = "fuzzing")]
/// Fuzzing entry points for the Gemini stream parser.
///
/// `Processor` wraps a stream state built over an empty byte stream, so fuzz
/// targets can push raw SSE `data:` payloads straight into `process_event` and
/// observe the events each payload queues.
pub mod fuzz {
    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// Byte source backing the fuzz state machine. It never yields; payloads
    /// are injected directly through `Processor::process_event`.
    type FuzzStream =
        Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;

    /// Feeds Gemini stream payloads into one stream state that persists
    /// across calls.
    pub struct Processor(StreamState<FuzzStream>);

    impl Processor {
        /// Creates a processor whose state is labelled model `gemini-fuzz`,
        /// api `google-generative`, provider `google`.
        pub fn new() -> Self {
            let no_bytes = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
            let source = crate::sse::SseStream::new(Box::pin(no_bytes));
            Self(StreamState::new(
                source,
                "gemini-fuzz".into(),
                "google-generative".into(),
                "google".into(),
            ))
        }

        /// Processes one SSE `data:` payload, then drains and returns every
        /// event it queued.
        ///
        /// # Errors
        ///
        /// Propagates any error from the underlying state's `process_event`;
        /// queued events are left in place in that case.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            let state = &mut self.0;
            state.process_event(data)?;
            let emitted: Vec<StreamEvent> = state.pending_events.drain(..).collect();
            Ok(emitted)
        }
    }

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }
}