Skip to main content

pi/providers/
gemini.rs

//! Google Gemini API provider implementation.
//!
//! This module implements the Provider trait for the Google Gemini API,
//! supporting streaming responses and function calling (tool use).

6use crate::error::{Error, Result};
7use crate::http::client::Client;
8use crate::model::{
9    AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ToolCall, Usage,
10    UserContent,
11};
12use crate::models::CompatConfig;
13use crate::provider::{Context, Provider, StreamOptions, ToolDef};
14use crate::sse::SseStream;
15use async_trait::async_trait;
16use futures::StreamExt;
17use futures::stream::{self, Stream};
18use serde::{Deserialize, Serialize};
19use std::collections::VecDeque;
20use std::pin::Pin;
21
// ============================================================================
// Constants
// ============================================================================

/// Default base URL of the public Generative Language (Gemini) API.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";

/// Default Cloud Code Assist base URL used in CLI mode by every provider
/// other than `google-antigravity`.
const GOOGLE_GEMINI_CLI_BASE: &str = "https://cloudcode-pa.googleapis.com";

/// Default Cloud Code Assist base URL used in CLI mode by the
/// `google-antigravity` provider.
const GOOGLE_ANTIGRAVITY_BASE: &str = "https://daily-cloudcode-pa.sandbox.googleapis.com";

/// Output-token cap sent when the caller does not specify `max_tokens`.
pub(crate) const DEFAULT_MAX_TOKENS: u32 = 8192;
30
31// ============================================================================
32// Gemini Provider
33// ============================================================================
34
35/// Google Gemini API provider.
36pub struct GeminiProvider {
37    client: Client,
38    model: String,
39    base_url: String,
40    provider: String,
41    api: String,
42    google_cli_mode: bool,
43    compat: Option<CompatConfig>,
44}
45
46impl GeminiProvider {
47    /// Create a new Gemini provider.
48    pub fn new(model: impl Into<String>) -> Self {
49        Self {
50            client: Client::new(),
51            model: model.into(),
52            base_url: GEMINI_API_BASE.to_string(),
53            provider: "google".to_string(),
54            api: "google-generative-ai".to_string(),
55            google_cli_mode: false,
56            compat: None,
57        }
58    }
59
60    /// Override provider name reported in streamed events.
61    #[must_use]
62    pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
63        self.provider = provider.into();
64        self
65    }
66
67    /// Override API identifier reported in streamed events.
68    #[must_use]
69    pub fn with_api_name(mut self, api: impl Into<String>) -> Self {
70        self.api = api.into();
71        self
72    }
73
74    /// Enable Google Cloud Code Assist mode (`google-gemini-cli` / `google-antigravity`).
75    #[must_use]
76    pub const fn with_google_cli_mode(mut self, enabled: bool) -> Self {
77        self.google_cli_mode = enabled;
78        self
79    }
80
81    /// Create with a custom base URL.
82    #[must_use]
83    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
84        self.base_url = base_url.into();
85        self
86    }
87
88    /// Create with a custom HTTP client (VCR, test harness, etc.).
89    #[must_use]
90    pub fn with_client(mut self, client: Client) -> Self {
91        self.client = client;
92        self
93    }
94
95    /// Attach provider-specific compatibility overrides.
96    #[must_use]
97    pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
98        self.compat = compat;
99        self
100    }
101
102    /// Build the streaming URL.
103    pub fn streaming_url(&self) -> String {
104        let base = {
105            let trimmed = self.base_url.trim();
106            if trimmed.is_empty() {
107                if self.google_cli_mode {
108                    if self.provider.eq_ignore_ascii_case("google-antigravity") {
109                        GOOGLE_ANTIGRAVITY_BASE
110                    } else {
111                        GOOGLE_GEMINI_CLI_BASE
112                    }
113                } else {
114                    GEMINI_API_BASE
115                }
116            } else {
117                trimmed
118            }
119        };
120        if self.google_cli_mode {
121            format!("{base}/v1internal:streamGenerateContent?alt=sse")
122        } else {
123            format!("{base}/models/{}:streamGenerateContent?alt=sse", self.model)
124        }
125    }
126
127    /// Build the request body for the Gemini API.
128    #[allow(clippy::unused_self)]
129    pub fn build_request(&self, context: &Context<'_>, options: &StreamOptions) -> GeminiRequest {
130        let contents = Self::build_contents(context);
131        let system_instruction = context.system_prompt.as_deref().map(|s| GeminiContent {
132            role: None,
133            parts: vec![GeminiPart::Text {
134                text: s.to_string(),
135            }],
136        });
137
138        let tools: Option<Vec<GeminiTool>> = if context.tools.is_empty() {
139            None
140        } else {
141            Some(vec![GeminiTool {
142                function_declarations: context.tools.iter().map(convert_tool_to_gemini).collect(),
143            }])
144        };
145
146        let tool_config = if tools.is_some() {
147            Some(GeminiToolConfig {
148                function_calling_config: GeminiFunctionCallingConfig { mode: "AUTO" },
149            })
150        } else {
151            None
152        };
153
154        GeminiRequest {
155            contents,
156            system_instruction,
157            tools,
158            tool_config,
159            generation_config: Some(GeminiGenerationConfig {
160                max_output_tokens: options.max_tokens.or(Some(DEFAULT_MAX_TOKENS)),
161                temperature: options.temperature,
162                candidate_count: Some(1),
163            }),
164        }
165    }
166
167    /// Build the contents array from context messages.
168    fn build_contents(context: &Context<'_>) -> Vec<GeminiContent> {
169        let mut contents = Vec::with_capacity(context.messages.len());
170
171        for message in context.messages.iter() {
172            contents.extend(convert_message_to_gemini(message));
173        }
174
175        contents
176    }
177}
178
179#[derive(Debug, Serialize)]
180#[serde(rename_all = "camelCase")]
181struct CloudCodeAssistRequest {
182    project: String,
183    model: String,
184    request: GeminiRequest,
185    #[serde(skip_serializing_if = "Option::is_none")]
186    request_type: Option<String>,
187    user_agent: String,
188    request_id: String,
189}
190
191fn build_google_cli_request(
192    model_id: &str,
193    project_id: &str,
194    request: GeminiRequest,
195    is_antigravity: bool,
196) -> std::result::Result<CloudCodeAssistRequest, &'static str> {
197    let safe_project = project_id.trim();
198    if safe_project.is_empty() {
199        return Err(
200            "Missing Google Cloud project ID for Gemini CLI. Set GOOGLE_CLOUD_PROJECT (or configure gcloud) and re-authenticate with /login google-gemini-cli.",
201        );
202    }
203    let project = if safe_project.starts_with("projects/") {
204        safe_project.to_string()
205    } else {
206        format!("projects/{safe_project}/locations/global")
207    };
208    Ok(CloudCodeAssistRequest {
209        project,
210        model: model_id.to_string(),
211        request,
212        request_type: is_antigravity.then(|| "agent".to_string()),
213        user_agent: if is_antigravity {
214            "antigravity".to_string()
215        } else {
216            "pi-coding-agent".to_string()
217        },
218        request_id: format!(
219            "{}-{}",
220            if is_antigravity { "agent" } else { "pi" },
221            uuid::Uuid::new_v4().simple()
222        ),
223    })
224}
225
226fn decode_project_scoped_access_payload(payload: &str) -> Option<(String, String)> {
227    let value: serde_json::Value = serde_json::from_str(payload).ok()?;
228    let token = value
229        .get("token")
230        .and_then(serde_json::Value::as_str)
231        .map(str::trim)
232        .filter(|value| !value.is_empty())?
233        .to_string();
234    let project_id = value
235        .get("projectId")
236        .or_else(|| value.get("project_id"))
237        .and_then(serde_json::Value::as_str)
238        .map(str::trim)
239        .filter(|value| !value.is_empty())?
240        .to_string();
241    Some((token, project_id))
242}
243
244#[async_trait]
245impl Provider for GeminiProvider {
246    fn name(&self) -> &str {
247        &self.provider
248    }
249
250    fn api(&self) -> &str {
251        &self.api
252    }
253
254    fn model_id(&self) -> &str {
255        &self.model
256    }
257
258    #[allow(clippy::too_many_lines)]
259    async fn stream(
260        &self,
261        context: &Context<'_>,
262        options: &StreamOptions,
263    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
264        let request_body = self.build_request(context, options);
265        let url = self.streaming_url();
266
267        // Build request (Content-Type set by .json() below)
268        let mut request = self.client.post(&url).header("Accept", "text/event-stream");
269
270        if self.google_cli_mode {
271            let api_payload = options.api_key.clone().ok_or_else(|| {
272                Error::provider(
273                    self.name(),
274                    "Google Gemini CLI requires OAuth credentials. Run /login google-gemini-cli.",
275                )
276            })?;
277            let (access_token, project_id) = decode_project_scoped_access_payload(&api_payload)
278                .ok_or_else(|| {
279                    Error::provider(
280                        self.name(),
281                        "Invalid Google Gemini CLI OAuth payload (expected JSON {token, projectId}). Run /login google-gemini-cli again.",
282                    )
283                })?;
284            let is_antigravity = self.provider.eq_ignore_ascii_case("google-antigravity");
285
286            request = request
287                .header("Authorization", format!("Bearer {access_token}"))
288                .header("Content-Type", "application/json")
289                .header("x-goog-api-client", "gl-node/22.17.0")
290                .header(
291                    "client-metadata",
292                    r#"{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}"#,
293                );
294
295            if is_antigravity {
296                request = request.header("User-Agent", "antigravity/1.15.8 darwin/arm64");
297            } else {
298                request =
299                    request.header("User-Agent", "google-cloud-sdk vscode_cloudshelleditor/0.1");
300            }
301
302            // Apply provider-specific custom headers from compat config.
303            if let Some(compat) = &self.compat {
304                if let Some(custom_headers) = &compat.custom_headers {
305                    for (key, value) in custom_headers {
306                        request = request.header(key, value);
307                    }
308                }
309            }
310
311            // Per-request headers from StreamOptions (highest priority).
312            for (key, value) in &options.headers {
313                request = request.header(key, value);
314            }
315
316            let cli_request =
317                build_google_cli_request(&self.model, &project_id, request_body, is_antigravity)
318                    .map_err(|message| Error::provider(self.name(), message.to_string()))?;
319            let request = request.json(&cli_request)?;
320            let response = Box::pin(request.send()).await?;
321            let status = response.status();
322            if !(200..300).contains(&status) {
323                let body = response
324                    .text()
325                    .await
326                    .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
327                return Err(Error::provider(
328                    self.name(),
329                    format!("Gemini CLI API error (HTTP {status}): {body}"),
330                ));
331            }
332
333            // Create SSE stream for streaming responses.
334            let event_source = SseStream::new(response.bytes_stream());
335            let model = self.model.clone();
336            let api = self.api().to_string();
337            let provider = self.name().to_string();
338            let cloud_cli_mode = self.google_cli_mode;
339
340            let stream = stream::unfold(
341                StreamState::new(event_source, model, api, provider),
342                move |mut state| async move {
343                    if state.finished {
344                        return None;
345                    }
346                    loop {
347                        // Drain pending events before polling for more SSE data
348                        if let Some(event) = state.pending_events.pop_front() {
349                            return Some((Ok(event), state));
350                        }
351
352                        match state.event_source.next().await {
353                            Some(Ok(msg)) => {
354                                if msg.event == "ping" {
355                                    continue;
356                                }
357
358                                let processing = if cloud_cli_mode {
359                                    state.process_cloud_code_event(&msg.data)
360                                } else {
361                                    state.process_event(&msg.data)
362                                };
363                                if let Err(e) = processing {
364                                    state.finished = true;
365                                    return Some((Err(e), state));
366                                }
367                            }
368                            Some(Err(e)) => {
369                                state.finished = true;
370                                let err = Error::api(format!("SSE error: {e}"));
371                                return Some((Err(err), state));
372                            }
373                            None => {
374                                // Stream ended naturally
375                                state.finished = true;
376                                let reason = state.partial.stop_reason;
377                                let message = std::mem::take(&mut state.partial);
378                                return Some((Ok(StreamEvent::Done { reason, message }), state));
379                            }
380                        }
381                    }
382                },
383            );
384
385            return Ok(Box::pin(stream));
386        }
387
388        let auth_value = options
389            .api_key
390            .clone()
391            .or_else(|| std::env::var("GOOGLE_API_KEY").ok())
392            .or_else(|| std::env::var("GEMINI_API_KEY").ok())
393            .ok_or_else(|| {
394                Error::provider(
395                    self.name(),
396                    "Missing API key for Google/Gemini. Set GOOGLE_API_KEY or GEMINI_API_KEY.",
397                )
398            })?;
399
400        request = request.header("x-goog-api-key", &auth_value);
401
402        // Apply provider-specific custom headers from compat config.
403        if let Some(compat) = &self.compat {
404            if let Some(custom_headers) = &compat.custom_headers {
405                for (key, value) in custom_headers {
406                    request = request.header(key, value);
407                }
408            }
409        }
410
411        // Per-request headers from StreamOptions (highest priority).
412        for (key, value) in &options.headers {
413            request = request.header(key, value);
414        }
415
416        let request = request.json(&request_body)?;
417
418        let response = Box::pin(request.send()).await?;
419        let status = response.status();
420        if !(200..300).contains(&status) {
421            let body = response
422                .text()
423                .await
424                .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
425            return Err(Error::provider(
426                self.name(),
427                format!("Gemini API error (HTTP {status}): {body}"),
428            ));
429        }
430
431        // Create SSE stream for streaming responses.
432        let event_source = SseStream::new(response.bytes_stream());
433
434        // Create stream state
435        let model = self.model.clone();
436        let api = self.api().to_string();
437        let provider = self.name().to_string();
438        let cloud_cli_mode = self.google_cli_mode;
439
440        let stream = stream::unfold(
441            StreamState::new(event_source, model, api, provider),
442            move |mut state| async move {
443                if state.finished {
444                    return None;
445                }
446                loop {
447                    // Drain pending events before polling for more SSE data
448                    if let Some(event) = state.pending_events.pop_front() {
449                        return Some((Ok(event), state));
450                    }
451
452                    match state.event_source.next().await {
453                        Some(Ok(msg)) => {
454                            if msg.event == "ping" {
455                                continue;
456                            }
457
458                            let processing = if cloud_cli_mode {
459                                state.process_cloud_code_event(&msg.data)
460                            } else {
461                                state.process_event(&msg.data)
462                            };
463                            if let Err(e) = processing {
464                                state.finished = true;
465                                return Some((Err(e), state));
466                            }
467                        }
468                        Some(Err(e)) => {
469                            state.finished = true;
470                            let err = Error::api(format!("SSE error: {e}"));
471                            return Some((Err(err), state));
472                        }
473                        None => {
474                            // Stream ended naturally
475                            state.finished = true;
476                            let reason = state.partial.stop_reason;
477                            let message = std::mem::take(&mut state.partial);
478                            return Some((Ok(StreamEvent::Done { reason, message }), state));
479                        }
480                    }
481                }
482            },
483        );
484
485        Ok(Box::pin(stream))
486    }
487}
488
489// ============================================================================
490// Stream State
491// ============================================================================
492
493struct StreamState<S>
494where
495    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
496{
497    event_source: SseStream<S>,
498    partial: AssistantMessage,
499    pending_events: VecDeque<StreamEvent>,
500    started: bool,
501    finished: bool,
502}
503
504impl<S> StreamState<S>
505where
506    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
507{
508    fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
509        Self {
510            event_source,
511            partial: AssistantMessage {
512                content: Vec::new(),
513                api,
514                provider,
515                model,
516                usage: Usage::default(),
517                stop_reason: StopReason::Stop,
518                error_message: None,
519                timestamp: chrono::Utc::now().timestamp_millis(),
520            },
521            pending_events: VecDeque::new(),
522            started: false,
523            finished: false,
524        }
525    }
526
527    fn process_event(&mut self, data: &str) -> Result<()> {
528        let response: GeminiStreamResponse = serde_json::from_str(data)
529            .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;
530        self.process_response(response)
531    }
532
533    fn process_response(&mut self, response: GeminiStreamResponse) -> Result<()> {
534        // Handle usage metadata
535        if let Some(metadata) = response.usage_metadata {
536            self.partial.usage.input = metadata.prompt_token_count.unwrap_or(0);
537            self.partial.usage.output = metadata.candidates_token_count.unwrap_or(0);
538            self.partial.usage.total_tokens = metadata.total_token_count.unwrap_or(0);
539        }
540
541        // Process candidates
542        if let Some(candidates) = response.candidates {
543            if let Some(candidate) = candidates.into_iter().next() {
544                self.process_candidate(candidate)?;
545            }
546        }
547
548        Ok(())
549    }
550
551    fn process_cloud_code_event(&mut self, data: &str) -> Result<()> {
552        let wrapped: CloudCodeAssistResponseChunk = serde_json::from_str(data)
553            .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;
554        let Some(response) = wrapped.response else {
555            return Ok(());
556        };
557        self.process_response(GeminiStreamResponse {
558            candidates: response.candidates,
559            usage_metadata: response.usage_metadata,
560        })
561    }
562
563    #[allow(clippy::unnecessary_wraps)]
564    fn process_candidate(&mut self, candidate: GeminiCandidate) -> Result<()> {
565        let has_finish_reason = candidate.finish_reason.is_some();
566
567        // Handle finish reason
568        if let Some(reason) = candidate.finish_reason.as_deref() {
569            self.partial.stop_reason = match reason {
570                "MAX_TOKENS" => StopReason::Length,
571                "SAFETY" | "RECITATION" | "OTHER" => StopReason::Error,
572                "FUNCTION_CALL" => StopReason::ToolUse,
573                // STOP and any other reason treated as normal stop
574                _ => StopReason::Stop,
575            };
576        }
577
578        // Process content parts — queue all events into pending_events
579        if let Some(content) = candidate.content {
580            for part in content.parts {
581                match part {
582                    GeminiPart::Text { text } => {
583                        // Accumulate text into partial
584                        let last_is_text =
585                            matches!(self.partial.content.last(), Some(ContentBlock::Text(_)));
586
587                        // Ensure Start is emitted before any TextStart/TextDelta events
588                        // so downstream consumers see the correct event order:
589                        // Start → TextStart → TextDelta
590                        self.ensure_started();
591
592                        let content_index = if last_is_text {
593                            self.partial.content.len() - 1
594                        } else {
595                            let idx = self.partial.content.len();
596                            self.partial
597                                .content
598                                .push(ContentBlock::Text(TextContent::new("")));
599                            self.pending_events
600                                .push_back(StreamEvent::TextStart { content_index: idx });
601                            idx
602                        };
603
604                        if let Some(ContentBlock::Text(t)) =
605                            self.partial.content.get_mut(content_index)
606                        {
607                            t.text.push_str(&text);
608                        }
609
610                        self.pending_events.push_back(StreamEvent::TextDelta {
611                            content_index,
612                            delta: text,
613                        });
614                    }
615                    GeminiPart::FunctionCall { function_call } => {
616                        // Generate a unique ID for this tool call
617                        let id = format!("call_{}", uuid::Uuid::new_v4().simple());
618
619                        // Serialize args for the delta event
620                        let args_str = serde_json::to_string(&function_call.args)
621                            .unwrap_or_else(|_| "{}".to_string());
622                        let GeminiFunctionCall { name, args } = function_call;
623
624                        let tool_call = ToolCall {
625                            id,
626                            name,
627                            arguments: args,
628                            thought_signature: None,
629                        };
630
631                        self.partial
632                            .content
633                            .push(ContentBlock::ToolCall(tool_call.clone()));
634                        let content_index = self.partial.content.len() - 1;
635
636                        // Update stop reason for tool use
637                        self.partial.stop_reason = StopReason::ToolUse;
638
639                        self.ensure_started();
640
641                        // Emit full ToolCallStart → ToolCallDelta → ToolCallEnd sequence
642                        self.pending_events
643                            .push_back(StreamEvent::ToolCallStart { content_index });
644                        self.pending_events.push_back(StreamEvent::ToolCallDelta {
645                            content_index,
646                            delta: args_str,
647                        });
648                        self.pending_events.push_back(StreamEvent::ToolCallEnd {
649                            content_index,
650                            tool_call,
651                        });
652                    }
653                    GeminiPart::InlineData { .. }
654                    | GeminiPart::FunctionResponse { .. }
655                    | GeminiPart::Unknown(_) => {
656                        // InlineData/FunctionResponse are for input, not output.
657                        // Unknown parts are silently skipped so new Gemini API
658                        // features don't break existing streams.
659                    }
660                }
661            }
662        }
663
664        // Emit TextEnd/ThinkingEnd for all open text/thinking blocks (not just the last
665        // one, since text/thinking may precede tool calls).
666        if has_finish_reason {
667            for (content_index, block) in self.partial.content.iter().enumerate() {
668                if let ContentBlock::Text(t) = block {
669                    self.pending_events.push_back(StreamEvent::TextEnd {
670                        content_index,
671                        content: t.text.clone(),
672                    });
673                } else if let ContentBlock::Thinking(t) = block {
674                    self.pending_events.push_back(StreamEvent::ThinkingEnd {
675                        content_index,
676                        content: t.thinking.clone(),
677                    });
678                }
679            }
680        }
681
682        Ok(())
683    }
684
685    fn ensure_started(&mut self) {
686        if !self.started {
687            self.started = true;
688            self.pending_events.push_back(StreamEvent::Start {
689                partial: self.partial.clone(),
690            });
691        }
692    }
693}
694
695// ============================================================================
696// Gemini API Types
697// ============================================================================
698
699#[derive(Debug, Serialize)]
700#[serde(rename_all = "camelCase")]
701pub struct GeminiRequest {
702    pub(crate) contents: Vec<GeminiContent>,
703    #[serde(skip_serializing_if = "Option::is_none")]
704    pub(crate) system_instruction: Option<GeminiContent>,
705    #[serde(skip_serializing_if = "Option::is_none")]
706    pub(crate) tools: Option<Vec<GeminiTool>>,
707    #[serde(skip_serializing_if = "Option::is_none")]
708    pub(crate) tool_config: Option<GeminiToolConfig>,
709    #[serde(skip_serializing_if = "Option::is_none")]
710    pub(crate) generation_config: Option<GeminiGenerationConfig>,
711}
712
713#[derive(Debug, Serialize, Deserialize)]
714#[serde(rename_all = "camelCase")]
715pub(crate) struct GeminiContent {
716    #[serde(skip_serializing_if = "Option::is_none")]
717    pub(crate) role: Option<String>,
718    pub(crate) parts: Vec<GeminiPart>,
719}
720
721#[derive(Debug, Serialize, Deserialize)]
722#[serde(untagged)]
723pub(crate) enum GeminiPart {
724    Text {
725        text: String,
726    },
727    InlineData {
728        inline_data: GeminiBlob,
729    },
730    FunctionCall {
731        #[serde(rename = "functionCall")]
732        function_call: GeminiFunctionCall,
733    },
734    FunctionResponse {
735        #[serde(rename = "functionResponse")]
736        function_response: GeminiFunctionResponse,
737    },
738    /// Catch-all for unrecognized part types (e.g. `executableCode`,
739    /// `codeExecutionResult`) so that new Gemini API features don't
740    /// cause hard deserialization failures that terminate the stream.
741    Unknown(serde_json::Value),
742}
743
744#[derive(Debug, Serialize, Deserialize)]
745#[serde(rename_all = "camelCase")]
746pub(crate) struct GeminiBlob {
747    pub(crate) mime_type: String,
748    pub(crate) data: String,
749}
750
751#[derive(Debug, Serialize, Deserialize)]
752pub(crate) struct GeminiFunctionCall {
753    pub(crate) name: String,
754    pub(crate) args: serde_json::Value,
755}
756
757#[derive(Debug, Serialize, Deserialize)]
758pub(crate) struct GeminiFunctionResponse {
759    pub(crate) name: String,
760    pub(crate) response: serde_json::Value,
761}
762
763#[derive(Debug, Serialize)]
764#[serde(rename_all = "camelCase")]
765pub(crate) struct GeminiTool {
766    pub(crate) function_declarations: Vec<GeminiFunctionDeclaration>,
767}
768
769#[derive(Debug, Serialize)]
770pub(crate) struct GeminiFunctionDeclaration {
771    pub(crate) name: String,
772    pub(crate) description: String,
773    pub(crate) parameters: serde_json::Value,
774}
775
776#[derive(Debug, Serialize)]
777#[serde(rename_all = "camelCase")]
778pub(crate) struct GeminiToolConfig {
779    pub(crate) function_calling_config: GeminiFunctionCallingConfig,
780}
781
782#[derive(Debug, Serialize)]
783pub(crate) struct GeminiFunctionCallingConfig {
784    pub(crate) mode: &'static str,
785}
786
787#[derive(Debug, Serialize)]
788#[serde(rename_all = "camelCase")]
789pub(crate) struct GeminiGenerationConfig {
790    #[serde(skip_serializing_if = "Option::is_none")]
791    pub(crate) max_output_tokens: Option<u32>,
792    #[serde(skip_serializing_if = "Option::is_none")]
793    pub(crate) temperature: Option<f32>,
794    #[serde(skip_serializing_if = "Option::is_none")]
795    pub(crate) candidate_count: Option<u32>,
796}
797
798// ============================================================================
799// Streaming Response Types
800// ============================================================================
801
/// One decoded SSE `data:` payload from the `streamGenerateContent` endpoint.
///
/// Both fields may be absent; for example, a chunk can carry only `usageMetadata`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiStreamResponse {
    #[serde(default)]
    pub(crate) candidates: Option<Vec<GeminiCandidate>>,
    #[serde(default)]
    pub(crate) usage_metadata: Option<GeminiUsageMetadata>,
}

/// Envelope in which the Gemini-shaped payload is nested under a top-level `response` key.
///
/// NOTE(review): presumably the chunk format of the Cloud Code Assist hosts
/// (`GOOGLE_GEMINI_CLI_BASE` / `GOOGLE_ANTIGRAVITY_BASE`); confirm at the parse site.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CloudCodeAssistResponseChunk {
    #[serde(default)]
    response: Option<CloudCodeAssistResponse>,
}

/// Inner payload of [`CloudCodeAssistResponseChunk`]; same fields as `GeminiStreamResponse`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CloudCodeAssistResponse {
    #[serde(default)]
    candidates: Option<Vec<GeminiCandidate>>,
    #[serde(default)]
    usage_metadata: Option<GeminiUsageMetadata>,
}
826
/// A single response candidate inside a streamed chunk.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct GeminiCandidate {
    /// Content parts carried by this chunk for the candidate, if any.
    #[serde(default)]
    pub(crate) content: Option<GeminiContent>,
    /// Raw finish reason string such as `"STOP"` or `"MAX_TOKENS"`; the stream state's
    /// `process_candidate` translates it into a `StopReason`.
    #[serde(default)]
    pub(crate) finish_reason: Option<String>,
}

/// Token accounting from a chunk's `usageMetadata` object.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
// All fields share the `_token_count` suffix because they mirror the API's JSON keys;
// `clippy::struct_field_names` would otherwise flag the common suffix.
#[allow(clippy::struct_field_names)]
pub(crate) struct GeminiUsageMetadata {
    /// `promptTokenCount`; ends up in `usage.input`.
    #[serde(default)]
    pub(crate) prompt_token_count: Option<u64>,
    /// `candidatesTokenCount`; ends up in `usage.output`.
    #[serde(default)]
    pub(crate) candidates_token_count: Option<u64>,
    /// `totalTokenCount`; ends up in `usage.total_tokens`.
    #[serde(default)]
    pub(crate) total_token_count: Option<u64>,
}
847
848// ============================================================================
849// Conversion Functions
850// ============================================================================
851
852pub(crate) fn convert_message_to_gemini(message: &Message) -> Vec<GeminiContent> {
853    match message {
854        Message::User(user) => vec![GeminiContent {
855            role: Some("user".into()),
856            parts: convert_user_content_to_parts(&user.content),
857        }],
858        Message::Custom(custom) => vec![GeminiContent {
859            role: Some("user".into()),
860            parts: vec![GeminiPart::Text {
861                text: custom.content.clone(),
862            }],
863        }],
864        Message::Assistant(assistant) => {
865            let mut parts = Vec::new();
866
867            for block in &assistant.content {
868                match block {
869                    ContentBlock::Text(t) => {
870                        parts.push(GeminiPart::Text {
871                            text: t.text.clone(),
872                        });
873                    }
874                    ContentBlock::ToolCall(tc) => {
875                        parts.push(GeminiPart::FunctionCall {
876                            function_call: GeminiFunctionCall {
877                                name: tc.name.clone(),
878                                args: tc.arguments.clone(),
879                            },
880                        });
881                    }
882                    ContentBlock::Thinking(_) | ContentBlock::Image(_) => {
883                        // Skip thinking blocks and images in assistant output
884                    }
885                }
886            }
887
888            if parts.is_empty() {
889                return Vec::new();
890            }
891
892            vec![GeminiContent {
893                role: Some("model".into()),
894                parts,
895            }]
896        }
897        Message::ToolResult(result) => {
898            // Gemini expects function responses as user role with functionResponse part
899            let content_text = result
900                .content
901                .iter()
902                .map(|b| match b {
903                    ContentBlock::Text(t) => t.text.clone(),
904                    ContentBlock::Image(img) => format!("[Image ({}) omitted]", img.mime_type),
905                    _ => String::new(),
906                })
907                .filter(|s| !s.is_empty())
908                .collect::<Vec<_>>()
909                .join("\n");
910
911            let response_value = if result.is_error {
912                serde_json::json!({ "error": content_text })
913            } else {
914                serde_json::json!({ "result": content_text })
915            };
916
917            vec![GeminiContent {
918                role: Some("user".into()),
919                parts: vec![GeminiPart::FunctionResponse {
920                    function_response: GeminiFunctionResponse {
921                        name: result.tool_name.clone(),
922                        response: response_value,
923                    },
924                }],
925            }]
926        }
927    }
928}
929
930pub(crate) fn convert_user_content_to_parts(content: &UserContent) -> Vec<GeminiPart> {
931    match content {
932        UserContent::Text(text) => vec![GeminiPart::Text { text: text.clone() }],
933        UserContent::Blocks(blocks) => blocks
934            .iter()
935            .filter_map(|block| match block {
936                ContentBlock::Text(t) => Some(GeminiPart::Text {
937                    text: t.text.clone(),
938                }),
939                ContentBlock::Image(img) => Some(GeminiPart::InlineData {
940                    inline_data: GeminiBlob {
941                        mime_type: img.mime_type.clone(),
942                        data: img.data.clone(),
943                    },
944                }),
945                _ => None,
946            })
947            .collect(),
948    }
949}
950
951pub(crate) fn convert_tool_to_gemini(tool: &ToolDef) -> GeminiFunctionDeclaration {
952    GeminiFunctionDeclaration {
953        name: tool.name.clone(),
954        description: tool.description.clone(),
955        parameters: tool.parameters.clone(),
956    }
957}
958
959// ============================================================================
960// Tests
961// ============================================================================
962
963#[cfg(test)]
964mod tests {
965    use super::*;
966    use asupersync::runtime::RuntimeBuilder;
967    use futures::{StreamExt, stream};
968    use serde::{Deserialize, Serialize};
969    use serde_json::Value;
970    use std::path::PathBuf;
971
972    #[test]
973    fn test_convert_user_text_message() {
974        let message = Message::User(crate::model::UserMessage {
975            content: UserContent::Text("Hello".to_string()),
976            timestamp: 0,
977        });
978
979        let converted = convert_message_to_gemini(&message);
980        assert_eq!(converted.len(), 1);
981        assert_eq!(converted[0].role, Some("user".to_string()));
982    }
983
984    #[test]
985    fn test_tool_conversion() {
986        let tool = ToolDef {
987            name: "test_tool".to_string(),
988            description: "A test tool".to_string(),
989            parameters: serde_json::json!({
990                "type": "object",
991                "properties": {
992                    "arg": {"type": "string"}
993                }
994            }),
995        };
996
997        let converted = convert_tool_to_gemini(&tool);
998        assert_eq!(converted.name, "test_tool");
999        assert_eq!(converted.description, "A test tool");
1000    }
1001
1002    #[test]
1003    fn test_provider_info() {
1004        let provider = GeminiProvider::new("gemini-2.0-flash");
1005        assert_eq!(provider.name(), "google");
1006        assert_eq!(provider.api(), "google-generative-ai");
1007    }
1008
1009    #[test]
1010    fn test_streaming_url() {
1011        let provider = GeminiProvider::new("gemini-2.0-flash");
1012        let url = provider.streaming_url();
1013        assert!(url.contains("gemini-2.0-flash"));
1014        assert!(url.contains("streamGenerateContent"));
1015        assert!(!url.contains("key="));
1016    }
1017
    /// Top-level shape of a provider stream fixture file.
    #[derive(Debug, Deserialize)]
    struct ProviderFixture {
        cases: Vec<ProviderCase>,
    }

    /// One named scenario: raw SSE payloads in, expected event summaries out.
    #[derive(Debug, Deserialize)]
    struct ProviderCase {
        name: String,
        /// Payloads replayed as SSE `data:` frames; string entries are sent verbatim,
        /// other JSON values are serialized first (see `collect_events`).
        events: Vec<Value>,
        expected: Vec<EventSummary>,
    }

    /// Comparable projection of a `StreamEvent`; fields that do not apply stay `None`.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct EventSummary {
        /// Event label such as `"start"`, `"text_delta"` or `"done"`.
        kind: String,
        #[serde(default)]
        content_index: Option<usize>,
        /// Text delta for `text_delta` events.
        #[serde(default)]
        delta: Option<String>,
        /// Full text for `text_end` events.
        #[serde(default)]
        content: Option<String>,
        /// Stop-reason label for `done` / `error` events.
        #[serde(default)]
        reason: Option<String>,
    }
1042
1043    #[test]
1044    fn test_stream_fixtures() {
1045        let fixture = load_fixture("gemini_stream.json");
1046        for case in fixture.cases {
1047            let events = collect_events(&case.events);
1048            let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
1049            assert_eq!(summaries, case.expected, "case {}", case.name);
1050        }
1051    }
1052
1053    fn load_fixture(file_name: &str) -> ProviderFixture {
1054        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1055            .join("tests/fixtures/provider_responses")
1056            .join(file_name);
1057        let raw = std::fs::read_to_string(path).expect("fixture read");
1058        serde_json::from_str(&raw).expect("fixture parse")
1059    }
1060
1061    fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
1062        let runtime = RuntimeBuilder::current_thread()
1063            .build()
1064            .expect("runtime build");
1065        runtime.block_on(async move {
1066            let byte_stream = stream::iter(
1067                events
1068                    .iter()
1069                    .map(|event| {
1070                        let data = match event {
1071                            Value::String(text) => text.clone(),
1072                            _ => serde_json::to_string(event).expect("serialize event"),
1073                        };
1074                        format!("data: {data}\n\n").into_bytes()
1075                    })
1076                    .map(Ok),
1077            );
1078            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1079            let mut state = StreamState::new(
1080                event_source,
1081                "gemini-test".to_string(),
1082                "google-generative".to_string(),
1083                "google".to_string(),
1084            );
1085            let mut out = Vec::new();
1086
1087            loop {
1088                let Some(item) = state.event_source.next().await else {
1089                    if !state.finished {
1090                        state.finished = true;
1091                        out.push(StreamEvent::Done {
1092                            reason: state.partial.stop_reason,
1093                            message: std::mem::take(&mut state.partial),
1094                        });
1095                    }
1096                    break;
1097                };
1098
1099                let msg = item.expect("SSE event");
1100                if msg.event == "ping" {
1101                    continue;
1102                }
1103                state.process_event(&msg.data).expect("process_event");
1104                out.extend(state.pending_events.drain(..));
1105            }
1106
1107            out
1108        })
1109    }
1110
1111    fn summarize_event(event: &StreamEvent) -> EventSummary {
1112        match event {
1113            StreamEvent::Start { .. } => EventSummary {
1114                kind: "start".to_string(),
1115                content_index: None,
1116                delta: None,
1117                content: None,
1118                reason: None,
1119            },
1120            StreamEvent::TextDelta {
1121                content_index,
1122                delta,
1123                ..
1124            } => EventSummary {
1125                kind: "text_delta".to_string(),
1126                content_index: Some(*content_index),
1127                delta: Some(delta.clone()),
1128                content: None,
1129                reason: None,
1130            },
1131            StreamEvent::Done { reason, .. } => EventSummary {
1132                kind: "done".to_string(),
1133                content_index: None,
1134                delta: None,
1135                content: None,
1136                reason: Some(reason_to_string(*reason)),
1137            },
1138            StreamEvent::Error { reason, .. } => EventSummary {
1139                kind: "error".to_string(),
1140                content_index: None,
1141                delta: None,
1142                content: None,
1143                reason: Some(reason_to_string(*reason)),
1144            },
1145            StreamEvent::TextStart { content_index, .. } => EventSummary {
1146                kind: "text_start".to_string(),
1147                content_index: Some(*content_index),
1148                delta: None,
1149                content: None,
1150                reason: None,
1151            },
1152            StreamEvent::TextEnd {
1153                content_index,
1154                content,
1155                ..
1156            } => EventSummary {
1157                kind: "text_end".to_string(),
1158                content_index: Some(*content_index),
1159                delta: None,
1160                content: Some(content.clone()),
1161                reason: None,
1162            },
1163            _ => EventSummary {
1164                kind: "other".to_string(),
1165                content_index: None,
1166                delta: None,
1167                content: None,
1168                reason: None,
1169            },
1170        }
1171    }
1172
1173    fn reason_to_string(reason: StopReason) -> String {
1174        match reason {
1175            StopReason::Stop => "stop",
1176            StopReason::Length => "length",
1177            StopReason::ToolUse => "tool_use",
1178            StopReason::Error => "error",
1179            StopReason::Aborted => "aborted",
1180        }
1181        .to_string()
1182    }
1183
1184    // ─── Request body format tests ──────────────────────────────────────
1185
1186    #[test]
1187    fn test_build_request_basic_text() {
1188        let provider = GeminiProvider::new("gemini-2.0-flash");
1189        let context = Context::owned(
1190            Some("You are helpful.".to_string()),
1191            vec![Message::User(crate::model::UserMessage {
1192                content: UserContent::Text("What is Rust?".to_string()),
1193                timestamp: 0,
1194            })],
1195            vec![],
1196        );
1197        let options = crate::provider::StreamOptions {
1198            max_tokens: Some(1024),
1199            temperature: Some(0.7),
1200            ..Default::default()
1201        };
1202
1203        let req = provider.build_request(&context, &options);
1204        let json = serde_json::to_value(&req).expect("serialize");
1205
1206        // Contents should have exactly one user message.
1207        let contents = json["contents"].as_array().expect("contents array");
1208        assert_eq!(contents.len(), 1);
1209        assert_eq!(contents[0]["role"], "user");
1210        assert_eq!(contents[0]["parts"][0]["text"], "What is Rust?");
1211
1212        // System instruction should be present.
1213        assert_eq!(
1214            json["systemInstruction"]["parts"][0]["text"],
1215            "You are helpful."
1216        );
1217
1218        // No tools should be present.
1219        assert!(json.get("tools").is_none() || json["tools"].is_null());
1220
1221        // Generation config should match.
1222        assert_eq!(json["generationConfig"]["maxOutputTokens"], 1024);
1223        assert!((json["generationConfig"]["temperature"].as_f64().unwrap() - 0.7).abs() < 0.01);
1224        assert_eq!(json["generationConfig"]["candidateCount"], 1);
1225    }
1226
1227    #[test]
1228    fn test_build_request_with_tools() {
1229        let provider = GeminiProvider::new("gemini-2.0-flash");
1230        let context = Context::owned(
1231            None,
1232            vec![Message::User(crate::model::UserMessage {
1233                content: UserContent::Text("Read a file".to_string()),
1234                timestamp: 0,
1235            })],
1236            vec![
1237                ToolDef {
1238                    name: "read".to_string(),
1239                    description: "Read a file".to_string(),
1240                    parameters: serde_json::json!({
1241                        "type": "object",
1242                        "properties": {
1243                            "path": {"type": "string"}
1244                        },
1245                        "required": ["path"]
1246                    }),
1247                },
1248                ToolDef {
1249                    name: "write".to_string(),
1250                    description: "Write a file".to_string(),
1251                    parameters: serde_json::json!({
1252                        "type": "object",
1253                        "properties": {
1254                            "path": {"type": "string"},
1255                            "content": {"type": "string"}
1256                        }
1257                    }),
1258                },
1259            ],
1260        );
1261        let options = crate::provider::StreamOptions::default();
1262
1263        let req = provider.build_request(&context, &options);
1264        let json = serde_json::to_value(&req).expect("serialize");
1265
1266        // System instruction should be absent.
1267        assert!(json.get("systemInstruction").is_none() || json["systemInstruction"].is_null());
1268
1269        // Tools should be present as a single GeminiTool with function_declarations array.
1270        let tools = json["tools"].as_array().expect("tools array");
1271        assert_eq!(tools.len(), 1);
1272        let declarations = tools[0]["functionDeclarations"]
1273            .as_array()
1274            .expect("declarations");
1275        assert_eq!(declarations.len(), 2);
1276        assert_eq!(declarations[0]["name"], "read");
1277        assert_eq!(declarations[1]["name"], "write");
1278        assert_eq!(declarations[0]["description"], "Read a file");
1279
1280        // Tool config should be AUTO mode.
1281        assert_eq!(json["toolConfig"]["functionCallingConfig"]["mode"], "AUTO");
1282    }
1283
1284    #[test]
1285    fn test_build_request_default_max_tokens() {
1286        let provider = GeminiProvider::new("gemini-2.0-flash");
1287        let context = Context::owned(
1288            None,
1289            vec![Message::User(crate::model::UserMessage {
1290                content: UserContent::Text("hi".to_string()),
1291                timestamp: 0,
1292            })],
1293            vec![],
1294        );
1295        let options = crate::provider::StreamOptions::default();
1296
1297        let req = provider.build_request(&context, &options);
1298        let json = serde_json::to_value(&req).expect("serialize");
1299
1300        // Default max tokens should be DEFAULT_MAX_TOKENS (8192).
1301        assert_eq!(
1302            json["generationConfig"]["maxOutputTokens"],
1303            DEFAULT_MAX_TOKENS
1304        );
1305    }
1306
1307    // ─── API key as query parameter tests ───────────────────────────────
1308
1309    #[test]
1310    fn test_streaming_url_no_key_query_param() {
1311        let provider = GeminiProvider::new("gemini-2.0-flash");
1312        let url = provider.streaming_url();
1313
1314        // API key should NOT be in the query string.
1315        assert!(
1316            !url.contains("key="),
1317            "API key should not be in query param"
1318        );
1319        assert!(url.contains("alt=sse"), "alt=sse should be present");
1320        assert!(
1321            url.contains("streamGenerateContent"),
1322            "should use streaming endpoint"
1323        );
1324    }
1325
1326    #[test]
1327    fn test_streaming_url_custom_base() {
1328        let provider =
1329            GeminiProvider::new("gemini-pro").with_base_url("https://custom.example.com/v1");
1330        let url = provider.streaming_url();
1331
1332        assert!(url.starts_with("https://custom.example.com/v1/models/gemini-pro"));
1333        assert!(!url.contains("key="));
1334    }
1335
1336    // ─── Content part mapping tests ─────────────────────────────────────
1337
1338    #[test]
1339    fn test_convert_user_text_to_gemini_parts() {
1340        let parts = convert_user_content_to_parts(&UserContent::Text("hello world".to_string()));
1341        assert_eq!(parts.len(), 1);
1342        match &parts[0] {
1343            GeminiPart::Text { text } => assert_eq!(text, "hello world"),
1344            _ => panic!("expected text part"),
1345        }
1346    }
1347
1348    #[test]
1349    fn test_convert_user_blocks_with_image_to_gemini_parts() {
1350        let content = UserContent::Blocks(vec![
1351            ContentBlock::Text(TextContent::new("describe this")),
1352            ContentBlock::Image(crate::model::ImageContent {
1353                data: "aGVsbG8=".to_string(),
1354                mime_type: "image/png".to_string(),
1355            }),
1356        ]);
1357
1358        let parts = convert_user_content_to_parts(&content);
1359        assert_eq!(parts.len(), 2);
1360        match &parts[0] {
1361            GeminiPart::Text { text } => assert_eq!(text, "describe this"),
1362            _ => panic!("expected text part"),
1363        }
1364        match &parts[1] {
1365            GeminiPart::InlineData { inline_data } => {
1366                assert_eq!(inline_data.mime_type, "image/png");
1367                assert_eq!(inline_data.data, "aGVsbG8=");
1368            }
1369            _ => panic!("expected inline_data part"),
1370        }
1371    }
1372
1373    #[test]
1374    fn test_convert_assistant_message_with_tool_call() {
1375        let message = Message::assistant(AssistantMessage {
1376            content: vec![
1377                ContentBlock::Text(TextContent::new("Let me read that file.")),
1378                ContentBlock::ToolCall(ToolCall {
1379                    id: "call_123".to_string(),
1380                    name: "read".to_string(),
1381                    arguments: serde_json::json!({"path": "/tmp/test.txt"}),
1382                    thought_signature: None,
1383                }),
1384            ],
1385            api: "google".to_string(),
1386            provider: "google".to_string(),
1387            model: "gemini-2.0-flash".to_string(),
1388            usage: Usage::default(),
1389            stop_reason: StopReason::ToolUse,
1390            error_message: None,
1391            timestamp: 0,
1392        });
1393
1394        let converted = convert_message_to_gemini(&message);
1395        assert_eq!(converted.len(), 1);
1396        assert_eq!(converted[0].role, Some("model".to_string()));
1397        assert_eq!(converted[0].parts.len(), 2);
1398
1399        match &converted[0].parts[0] {
1400            GeminiPart::Text { text } => assert_eq!(text, "Let me read that file."),
1401            _ => panic!("expected text part"),
1402        }
1403        match &converted[0].parts[1] {
1404            GeminiPart::FunctionCall { function_call } => {
1405                assert_eq!(function_call.name, "read");
1406                assert_eq!(function_call.args["path"], "/tmp/test.txt");
1407            }
1408            _ => panic!("expected function_call part"),
1409        }
1410    }
1411
1412    #[test]
1413    fn test_convert_assistant_empty_content_returns_empty() {
1414        let message = Message::assistant(AssistantMessage {
1415            content: vec![],
1416            api: "google".to_string(),
1417            provider: "google".to_string(),
1418            model: "gemini-2.0-flash".to_string(),
1419            usage: Usage::default(),
1420            stop_reason: StopReason::Stop,
1421            error_message: None,
1422            timestamp: 0,
1423        });
1424
1425        let converted = convert_message_to_gemini(&message);
1426        assert!(converted.is_empty());
1427    }
1428
1429    #[test]
1430    fn test_convert_tool_result_success() {
1431        let message = Message::tool_result(crate::model::ToolResultMessage {
1432            tool_call_id: "call_123".to_string(),
1433            tool_name: "read".to_string(),
1434            content: vec![ContentBlock::Text(TextContent::new("file contents here"))],
1435            details: None,
1436            is_error: false,
1437            timestamp: 0,
1438        });
1439
1440        let converted = convert_message_to_gemini(&message);
1441        assert_eq!(converted.len(), 1);
1442        assert_eq!(converted[0].role, Some("user".to_string()));
1443
1444        match &converted[0].parts[0] {
1445            GeminiPart::FunctionResponse { function_response } => {
1446                assert_eq!(function_response.name, "read");
1447                assert_eq!(function_response.response["result"], "file contents here");
1448                assert!(function_response.response.get("error").is_none());
1449            }
1450            _ => panic!("expected function_response part"),
1451        }
1452    }
1453
1454    #[test]
1455    fn test_convert_tool_result_error() {
1456        let message = Message::tool_result(crate::model::ToolResultMessage {
1457            tool_call_id: "call_456".to_string(),
1458            tool_name: "bash".to_string(),
1459            content: vec![ContentBlock::Text(TextContent::new("command not found"))],
1460            details: None,
1461            is_error: true,
1462            timestamp: 0,
1463        });
1464
1465        let converted = convert_message_to_gemini(&message);
1466        assert_eq!(converted.len(), 1);
1467
1468        match &converted[0].parts[0] {
1469            GeminiPart::FunctionResponse { function_response } => {
1470                assert_eq!(function_response.name, "bash");
1471                assert_eq!(function_response.response["error"], "command not found");
1472                assert!(function_response.response.get("result").is_none());
1473            }
1474            _ => panic!("expected function_response part"),
1475        }
1476    }
1477
1478    #[test]
1479    fn test_convert_custom_message() {
1480        let message = Message::Custom(crate::model::CustomMessage {
1481            custom_type: "system_note".to_string(),
1482            content: "Context window approaching limit.".to_string(),
1483            display: false,
1484            details: None,
1485            timestamp: 0,
1486        });
1487
1488        let converted = convert_message_to_gemini(&message);
1489        assert_eq!(converted.len(), 1);
1490        assert_eq!(converted[0].role, Some("user".to_string()));
1491        match &converted[0].parts[0] {
1492            GeminiPart::Text { text } => {
1493                assert_eq!(text, "Context window approaching limit.");
1494            }
1495            _ => panic!("expected text part"),
1496        }
1497    }
1498
1499    // ─── Response parsing / stop reason tests ───────────────────────────
1500
1501    #[test]
1502    fn test_stop_reason_mapping() {
1503        // Test all finish reason strings map correctly.
1504        let test_cases = vec![
1505            ("STOP", StopReason::Stop),
1506            ("MAX_TOKENS", StopReason::Length),
1507            ("SAFETY", StopReason::Error),
1508            ("RECITATION", StopReason::Error),
1509            ("OTHER", StopReason::Error),
1510            ("UNKNOWN_REASON", StopReason::Stop), // unknown defaults to Stop
1511        ];
1512
1513        for (reason_str, expected) in test_cases {
1514            let candidate = GeminiCandidate {
1515                content: None,
1516                finish_reason: Some(reason_str.to_string()),
1517            };
1518
1519            let runtime = RuntimeBuilder::current_thread().build().unwrap();
1520            runtime.block_on(async {
1521                let byte_stream = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
1522                let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1523                let mut state = StreamState::new(
1524                    event_source,
1525                    "test".to_string(),
1526                    "test".to_string(),
1527                    "test".to_string(),
1528                );
1529                state.process_candidate(candidate).unwrap();
1530                assert_eq!(
1531                    state.partial.stop_reason, expected,
1532                    "finish_reason '{reason_str}' should map to {expected:?}"
1533                );
1534            });
1535        }
1536    }
1537
1538    #[test]
1539    fn test_usage_metadata_parsing() {
1540        let data = r#"{
1541            "usageMetadata": {
1542                "promptTokenCount": 42,
1543                "candidatesTokenCount": 100,
1544                "totalTokenCount": 142
1545            }
1546        }"#;
1547
1548        let runtime = RuntimeBuilder::current_thread().build().unwrap();
1549        runtime.block_on(async {
1550            let byte_stream = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
1551            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1552            let mut state = StreamState::new(
1553                event_source,
1554                "test".to_string(),
1555                "test".to_string(),
1556                "test".to_string(),
1557            );
1558            state.process_event(data).unwrap();
1559            assert_eq!(state.partial.usage.input, 42);
1560            assert_eq!(state.partial.usage.output, 100);
1561            assert_eq!(state.partial.usage.total_tokens, 142);
1562        });
1563    }
1564
1565    // ─── Full conversation round-trip tests ─────────────────────────────
1566
1567    #[test]
1568    fn test_build_request_full_conversation() {
1569        let provider = GeminiProvider::new("gemini-2.0-flash");
1570        let context = Context::owned(
1571            Some("Be concise.".to_string()),
1572            vec![
1573                Message::User(crate::model::UserMessage {
1574                    content: UserContent::Text("Read /tmp/a.txt".to_string()),
1575                    timestamp: 0,
1576                }),
1577                Message::assistant(AssistantMessage {
1578                    content: vec![ContentBlock::ToolCall(ToolCall {
1579                        id: "call_1".to_string(),
1580                        name: "read".to_string(),
1581                        arguments: serde_json::json!({"path": "/tmp/a.txt"}),
1582                        thought_signature: None,
1583                    })],
1584                    api: "google".to_string(),
1585                    provider: "google".to_string(),
1586                    model: "gemini-2.0-flash".to_string(),
1587                    usage: Usage::default(),
1588                    stop_reason: StopReason::ToolUse,
1589                    error_message: None,
1590                    timestamp: 1,
1591                }),
1592                Message::tool_result(crate::model::ToolResultMessage {
1593                    tool_call_id: "call_1".to_string(),
1594                    tool_name: "read".to_string(),
1595                    content: vec![ContentBlock::Text(TextContent::new("file contents"))],
1596                    details: None,
1597                    is_error: false,
1598                    timestamp: 2,
1599                }),
1600            ],
1601            vec![ToolDef {
1602                name: "read".to_string(),
1603                description: "Read a file".to_string(),
1604                parameters: serde_json::json!({"type": "object"}),
1605            }],
1606        );
1607        let options = crate::provider::StreamOptions::default();
1608
1609        let req = provider.build_request(&context, &options);
1610        let json = serde_json::to_value(&req).expect("serialize");
1611
1612        let contents = json["contents"].as_array().expect("contents");
1613        assert_eq!(contents.len(), 3); // user, model, user (tool result)
1614
1615        // First: user message
1616        assert_eq!(contents[0]["role"], "user");
1617        assert_eq!(contents[0]["parts"][0]["text"], "Read /tmp/a.txt");
1618
1619        // Second: model with function call
1620        assert_eq!(contents[1]["role"], "model");
1621        assert_eq!(contents[1]["parts"][0]["functionCall"]["name"], "read");
1622
1623        // Third: function response (sent as user role)
1624        assert_eq!(contents[2]["role"], "user");
1625        assert_eq!(contents[2]["parts"][0]["functionResponse"]["name"], "read");
1626        assert_eq!(
1627            contents[2]["parts"][0]["functionResponse"]["response"]["result"],
1628            "file contents"
1629        );
1630    }
1631
1632    // ========================================================================
1633    // Proptest — process_event() fuzz coverage (FUZZ-P1.3)
1634    // ========================================================================
1635
1636    mod proptest_process_event {
1637        use super::*;
1638        use proptest::prelude::*;
1639
1640        fn make_state()
1641        -> StreamState<impl Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin>
1642        {
1643            let empty = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
1644            let sse = crate::sse::SseStream::new(Box::pin(empty));
1645            StreamState::new(
1646                sse,
1647                "gemini-test".into(),
1648                "google-generative".into(),
1649                "google".into(),
1650            )
1651        }
1652
1653        fn small_string() -> impl Strategy<Value = String> {
1654            prop_oneof![Just(String::new()), "[a-zA-Z0-9_]{1,16}", "[ -~]{0,32}",]
1655        }
1656
1657        fn token_count() -> impl Strategy<Value = u64> {
1658            prop_oneof![
1659                5 => 0u64..10_000u64,
1660                2 => Just(0u64),
1661                1 => Just(u64::MAX),
1662                1 => (u64::MAX - 100)..=u64::MAX,
1663            ]
1664        }
1665
1666        fn finish_reason() -> impl Strategy<Value = Option<String>> {
1667            prop_oneof![
1668                3 => Just(None),
1669                1 => Just(Some("STOP".to_string())),
1670                1 => Just(Some("MAX_TOKENS".to_string())),
1671                1 => Just(Some("SAFETY".to_string())),
1672                1 => Just(Some("RECITATION".to_string())),
1673                1 => Just(Some("OTHER".to_string())),
1674                1 => small_string().prop_map(Some),
1675            ]
1676        }
1677
1678        /// Generate a JSON `Value` representing a Gemini function call args object.
1679        fn json_args() -> impl Strategy<Value = serde_json::Value> {
1680            prop_oneof![
1681                Just(serde_json::json!({})),
1682                Just(serde_json::json!({"key": "value"})),
1683                Just(serde_json::json!({"a": 1, "b": true, "c": null})),
1684                small_string().prop_map(|s| serde_json::json!({"input": s})),
1685            ]
1686        }
1687
1688        /// Strategy for Gemini text parts.
1689        fn text_part() -> impl Strategy<Value = serde_json::Value> {
1690            small_string().prop_map(|t| serde_json::json!({"text": t}))
1691        }
1692
1693        /// Strategy for Gemini function call parts.
1694        fn function_call_part() -> impl Strategy<Value = serde_json::Value> {
1695            (small_string(), json_args()).prop_map(
1696                |(name, args)| serde_json::json!({"functionCall": {"name": name, "args": args}}),
1697            )
1698        }
1699
1700        /// Strategy for content parts (mix of text and function calls).
1701        fn parts_strategy() -> impl Strategy<Value = Vec<serde_json::Value>> {
1702            prop::collection::vec(
1703                prop_oneof![3 => text_part(), 1 => function_call_part(),],
1704                0..5,
1705            )
1706        }
1707
1708        /// Generate valid `GeminiStreamResponse` JSON strings.
1709        fn gemini_response_json() -> impl Strategy<Value = String> {
1710            prop_oneof![
1711                // Text response with candidate
1712                3 => (parts_strategy(), finish_reason()).prop_map(|(parts, fr)| {
1713                    let mut candidate = serde_json::json!({
1714                        "content": {"parts": parts}
1715                    });
1716                    if let Some(r) = fr {
1717                        candidate["finishReason"] = serde_json::Value::String(r);
1718                    }
1719                    serde_json::json!({"candidates": [candidate]}).to_string()
1720                }),
1721                // Usage-only response
1722                2 => (token_count(), token_count(), token_count()).prop_map(|(p, c, t)| {
1723                    serde_json::json!({
1724                        "usageMetadata": {
1725                            "promptTokenCount": p,
1726                            "candidatesTokenCount": c,
1727                            "totalTokenCount": t
1728                        }
1729                    })
1730                    .to_string()
1731                }),
1732                // Empty candidates
1733                1 => Just(r#"{"candidates":[]}"#.to_string()),
1734                // No candidates, no usage
1735                1 => Just(r"{}".to_string()),
1736                // Candidate with finish reason only (no content)
1737                1 => finish_reason()
1738                    .prop_filter("some reason", Option::is_some)
1739                    .prop_map(|fr| {
1740                        serde_json::json!({
1741                            "candidates": [{"finishReason": fr.unwrap()}]
1742                        })
1743                        .to_string()
1744                    }),
1745                // Both candidate and usage
1746                2 => (parts_strategy(), finish_reason(), token_count(), token_count(), token_count())
1747                    .prop_map(|(parts, fr, p, c, t)| {
1748                        let mut candidate = serde_json::json!({
1749                            "content": {"parts": parts}
1750                        });
1751                        if let Some(r) = fr {
1752                            candidate["finishReason"] = serde_json::Value::String(r);
1753                        }
1754                        serde_json::json!({
1755                            "candidates": [candidate],
1756                            "usageMetadata": {
1757                                "promptTokenCount": p,
1758                                "candidatesTokenCount": c,
1759                                "totalTokenCount": t
1760                            }
1761                        })
1762                        .to_string()
1763                    }),
1764            ]
1765        }
1766
1767        /// Chaos — arbitrary JSON strings.
1768        fn chaos_json() -> impl Strategy<Value = String> {
1769            prop_oneof![
1770                Just(String::new()),
1771                Just("{}".to_string()),
1772                Just("[]".to_string()),
1773                Just("null".to_string()),
1774                Just("{".to_string()),
1775                Just(r#"{"candidates":"not_array"}"#.to_string()),
1776                Just(r#"{"candidates":[{"content":null}]}"#.to_string()),
1777                Just(r#"{"candidates":[{"content":{"parts":"not_array"}}]}"#.to_string()),
1778                "[ -~]{0,64}",
1779            ]
1780        }
1781
1782        proptest! {
1783            #![proptest_config(ProptestConfig {
1784                cases: 256,
1785                max_shrink_iters: 100,
1786                .. ProptestConfig::default()
1787            })]
1788
1789            #[test]
1790            fn process_event_valid_never_panics(data in gemini_response_json()) {
1791                let mut state = make_state();
1792                let _ = state.process_event(&data);
1793            }
1794
1795            #[test]
1796            fn process_event_chaos_never_panics(data in chaos_json()) {
1797                let mut state = make_state();
1798                let _ = state.process_event(&data);
1799            }
1800
1801            #[test]
1802            fn process_event_sequence_never_panics(
1803                events in prop::collection::vec(gemini_response_json(), 1..8)
1804            ) {
1805                let mut state = make_state();
1806                for event in &events {
1807                    let _ = state.process_event(event);
1808                }
1809            }
1810
1811            #[test]
1812            fn process_event_mixed_sequence_never_panics(
1813                events in prop::collection::vec(
1814                    prop_oneof![gemini_response_json(), chaos_json()],
1815                    1..12
1816                )
1817            ) {
1818                let mut state = make_state();
1819                for event in &events {
1820                    let _ = state.process_event(event);
1821                }
1822            }
1823        }
1824    }
1825}
1826
1827// ============================================================================
1828// Fuzzing support
1829// ============================================================================
1830
/// Fuzzing harness around the Gemini SSE payload processor.
///
/// Only compiled with the `fuzzing` feature. It wraps `StreamState` over an
/// empty byte stream so fuzz targets can inject SSE data payloads directly.
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// Byte stream backing the processor. It is always empty, because payloads
    /// are fed in through [`Processor::process_event`] instead.
    type FuzzStream =
        Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;

    /// Opaque wrapper around the Gemini stream processor state.
    pub struct Processor(StreamState<FuzzStream>);

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Processor {
        /// Create a fresh processor with default state.
        pub fn new() -> Self {
            let empty = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
            Self(StreamState::new(
                crate::sse::SseStream::new(Box::pin(empty)),
                "gemini-fuzz".into(),
                "google-generative".into(),
                "google".into(),
            ))
        }

        /// Feed one SSE data payload and return the `StreamEvent`s it emitted.
        ///
        /// The internal event buffer is drained on every call, including failing
        /// ones. Returned events therefore always belong to this `data` and never
        /// leak into a later call's output.
        ///
        /// # Errors
        ///
        /// Propagates any error reported while processing `data`. Events that
        /// were buffered during that failed call are discarded.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            let outcome = self.0.process_event(data);
            // Drain before propagating so a failed call cannot leave stale events
            // behind for the next payload.
            let events: Vec<StreamEvent> = self.0.pending_events.drain(..).collect();
            outcome?;
            Ok(events)
        }
    }
}