Skip to main content

imp_llm/providers/
google.rs

1use std::pin::Pin;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use futures_core::Stream;
6use serde::{Deserialize, Serialize};
7
8use crate::auth::{ApiKey, AuthStore};
9use crate::error::{Error, Result};
10use crate::message::{AssistantMessage, ContentBlock, Message, StopReason, ToolResultMessage};
11use crate::model::{Capabilities, Model, ModelMeta, ModelPricing};
12use crate::provider::{Context, Provider, RequestOptions, ThinkingLevel, ToolDefinition};
13use crate::stream::StreamEvent;
14use crate::usage::Usage;
15
/// URL prefix for model endpoints; the model id and method name are appended per request.
const API_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
/// Upper bound applied to the thinking budget when the model id contains "flash".
const FLASH_MAX_THINKING_BUDGET: i32 = 24_576;
/// Upper bound applied to the thinking budget for every other model id.
const PRO_MAX_THINKING_BUDGET: i32 = 32_768;
19
20// ---------------------------------------------------------------------------
21// Gemini wire-format types (request)
22// ---------------------------------------------------------------------------
23
24#[derive(Debug, Serialize)]
25struct ApiRequest {
26    contents: Vec<ApiContent>,
27    #[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")]
28    system_instruction: Option<ApiInstruction>,
29    #[serde(skip_serializing_if = "Vec::is_empty")]
30    tools: Vec<ApiTool>,
31    #[serde(rename = "generationConfig")]
32    generation_config: ApiGenerationConfig,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36struct ApiInstruction {
37    parts: Vec<ApiPart>,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41struct ApiContent {
42    role: String,
43    parts: Vec<ApiPart>,
44}
45
46#[derive(Debug, Clone, Default, Serialize, Deserialize)]
47struct ApiPart {
48    #[serde(skip_serializing_if = "Option::is_none")]
49    text: Option<String>,
50    #[serde(skip_serializing_if = "Option::is_none")]
51    thought: Option<bool>,
52    #[serde(rename = "functionCall", skip_serializing_if = "Option::is_none")]
53    function_call: Option<ApiFunctionCall>,
54    #[serde(rename = "functionResponse", skip_serializing_if = "Option::is_none")]
55    function_response: Option<ApiFunctionResponse>,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59struct ApiFunctionCall {
60    #[serde(skip_serializing_if = "Option::is_none")]
61    id: Option<String>,
62    name: String,
63    args: serde_json::Value,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67struct ApiFunctionResponse {
68    #[serde(skip_serializing_if = "Option::is_none")]
69    id: Option<String>,
70    name: String,
71    response: serde_json::Value,
72}
73
74#[derive(Debug, Serialize)]
75struct ApiTool {
76    #[serde(rename = "functionDeclarations")]
77    function_declarations: Vec<ApiFunctionDeclaration>,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
81struct ApiFunctionDeclaration {
82    name: String,
83    description: String,
84    parameters: serde_json::Value,
85}
86
87#[derive(Debug, Serialize)]
88struct ApiGenerationConfig {
89    #[serde(rename = "maxOutputTokens", skip_serializing_if = "Option::is_none")]
90    max_output_tokens: Option<u32>,
91    #[serde(skip_serializing_if = "Option::is_none")]
92    temperature: Option<f32>,
93    #[serde(rename = "thinkingConfig", skip_serializing_if = "Option::is_none")]
94    thinking_config: Option<ApiThinkingConfig>,
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
98struct ApiThinkingConfig {
99    #[serde(rename = "includeThoughts")]
100    include_thoughts: bool,
101    #[serde(rename = "thinkingBudget")]
102    thinking_budget: i32,
103}
104
105// ---------------------------------------------------------------------------
106// Gemini wire-format types (SSE response)
107// ---------------------------------------------------------------------------
108
109#[derive(Debug, Clone, Deserialize)]
110struct GenerateContentResponse {
111    #[serde(default)]
112    candidates: Vec<ApiCandidate>,
113    #[serde(rename = "usageMetadata")]
114    usage_metadata: Option<ApiUsageMetadata>,
115}
116
117#[derive(Debug, Clone, Deserialize)]
118struct ApiCandidate {
119    content: Option<ApiContent>,
120    #[serde(rename = "finishReason")]
121    finish_reason: Option<String>,
122}
123
124#[derive(Debug, Clone, Deserialize)]
125struct ApiUsageMetadata {
126    #[serde(rename = "promptTokenCount", default)]
127    prompt_token_count: u32,
128    #[serde(rename = "candidatesTokenCount", default)]
129    candidates_token_count: u32,
130    #[serde(rename = "thoughtsTokenCount", default)]
131    thoughts_token_count: u32,
132    #[serde(rename = "cachedContentTokenCount", default)]
133    cached_content_token_count: u32,
134}
135
136// ---------------------------------------------------------------------------
137// SSE stream state
138// ---------------------------------------------------------------------------
139
140#[derive(Debug, Clone)]
141enum PartState {
142    Text(String),
143    Thinking(String),
144    ToolCall {
145        id: String,
146        name: String,
147        arguments: serde_json::Value,
148        emitted: bool,
149    },
150}
151
152#[derive(Debug)]
153struct StreamState {
154    model: String,
155    started: bool,
156    finished: bool,
157    parts: Vec<PartState>,
158    usage: Usage,
159    finish_reason: Option<String>,
160    saw_tool_call: bool,
161}
162
163impl StreamState {
164    fn new(model: String) -> Self {
165        Self {
166            model,
167            started: false,
168            finished: false,
169            parts: Vec::new(),
170            usage: Usage::default(),
171            finish_reason: None,
172            saw_tool_call: false,
173        }
174    }
175
176    fn ensure_index(&mut self, index: usize) {
177        while self.parts.len() <= index {
178            self.parts.push(PartState::Text(String::new()));
179        }
180    }
181
182    fn stop_reason(&self) -> StopReason {
183        if self.saw_tool_call {
184            return StopReason::ToolUse;
185        }
186
187        match self.finish_reason.as_deref() {
188            Some("STOP") | Some("FINISH_REASON_UNSPECIFIED") | None => StopReason::EndTurn,
189            Some("MAX_TOKENS") => StopReason::MaxTokens,
190            Some(other) => StopReason::Error(other.to_string()),
191        }
192    }
193
194    fn build_message(&self) -> AssistantMessage {
195        let content = self
196            .parts
197            .iter()
198            .filter_map(|part| match part {
199                PartState::Text(text) if !text.is_empty() => {
200                    Some(ContentBlock::Text { text: text.clone() })
201                }
202                PartState::Thinking(text) if !text.is_empty() => {
203                    Some(ContentBlock::Thinking { text: text.clone() })
204                }
205                PartState::ToolCall {
206                    id,
207                    name,
208                    arguments,
209                    ..
210                } => Some(ContentBlock::ToolCall {
211                    id: id.clone(),
212                    name: name.clone(),
213                    arguments: arguments.clone(),
214                }),
215                _ => None,
216            })
217            .collect();
218
219        AssistantMessage {
220            content,
221            usage: Some(self.usage.clone()),
222            stop_reason: self.stop_reason(),
223            timestamp: crate::now(),
224        }
225    }
226}
227
228/// Google Gemini API provider.
229pub struct GoogleProvider {
230    client: reqwest::Client,
231    models: Vec<ModelMeta>,
232}
233
234impl Default for GoogleProvider {
235    fn default() -> Self {
236        Self::new()
237    }
238}
239
240impl GoogleProvider {
241    pub fn new() -> Self {
242        Self {
243            client: super::streaming_http_client(),
244            models: builtin_models(),
245        }
246    }
247
248    pub fn into_arc(self) -> Arc<Self> {
249        Arc::new(self)
250    }
251}
252
253// ---------------------------------------------------------------------------
254// Request building
255// ---------------------------------------------------------------------------
256
257fn max_thinking_budget(model_id: &str) -> i32 {
258    if model_id.contains("flash") {
259        FLASH_MAX_THINKING_BUDGET
260    } else {
261        PRO_MAX_THINKING_BUDGET
262    }
263}
264
265fn thinking_budget(model: &Model, level: ThinkingLevel) -> Option<i32> {
266    let budget = match level {
267        ThinkingLevel::Off => return None,
268        ThinkingLevel::Minimal => 1024,
269        ThinkingLevel::Low => 4096,
270        ThinkingLevel::Medium => 10_000,
271        ThinkingLevel::High => 24_576,
272        ThinkingLevel::XHigh => max_thinking_budget(&model.meta.id),
273    };
274
275    Some(budget.min(max_thinking_budget(&model.meta.id)))
276}
277
278fn default_max_output_tokens(model: &Model, thinking_budget: Option<i32>) -> u32 {
279    let base = model.meta.max_output_tokens.min(8_192);
280    match thinking_budget {
281        Some(budget) => base.max((budget as u32).saturating_add(1024)),
282        None => base,
283    }
284}
285
286fn build_request(model: &Model, context: Context, options: RequestOptions) -> ApiRequest {
287    let thinking_config =
288        thinking_budget(model, options.thinking_level).map(|thinking_budget| ApiThinkingConfig {
289            include_thoughts: true,
290            thinking_budget,
291        });
292
293    ApiRequest {
294        contents: build_messages(&context.messages),
295        system_instruction: build_system_instruction(&options.system_prompt),
296        tools: build_tools(&options.tools),
297        generation_config: ApiGenerationConfig {
298            max_output_tokens: options.max_tokens.or(Some(default_max_output_tokens(
299                model,
300                thinking_budget(model, options.thinking_level),
301            ))),
302            temperature: options.temperature,
303            thinking_config,
304        },
305    }
306}
307
308fn build_system_instruction(prompt: &str) -> Option<ApiInstruction> {
309    if prompt.is_empty() {
310        return None;
311    }
312
313    Some(ApiInstruction {
314        parts: vec![ApiPart {
315            text: Some(prompt.to_string()),
316            ..Default::default()
317        }],
318    })
319}
320
321fn build_tools(tools: &[ToolDefinition]) -> Vec<ApiTool> {
322    if tools.is_empty() {
323        return Vec::new();
324    }
325
326    vec![ApiTool {
327        function_declarations: tools.iter().map(convert_tool_def).collect(),
328    }]
329}
330
331fn build_messages(messages: &[Message]) -> Vec<ApiContent> {
332    messages.iter().map(convert_message).collect()
333}
334
335fn convert_message(message: &Message) -> ApiContent {
336    match message {
337        Message::User(user) => ApiContent {
338            role: "user".into(),
339            parts: user
340                .content
341                .iter()
342                .filter_map(convert_content_block)
343                .collect(),
344        },
345        Message::Assistant(assistant) => ApiContent {
346            role: "model".into(),
347            parts: assistant
348                .content
349                .iter()
350                .filter_map(convert_content_block)
351                .collect(),
352        },
353        Message::ToolResult(tool_result) => ApiContent {
354            role: "user".into(),
355            parts: vec![ApiPart {
356                function_response: Some(ApiFunctionResponse {
357                    id: Some(tool_result.tool_call_id.clone()),
358                    name: tool_result.tool_name.clone(),
359                    response: convert_tool_result_response(tool_result),
360                }),
361                ..Default::default()
362            }],
363        },
364    }
365}
366
367fn convert_content_block(block: &ContentBlock) -> Option<ApiPart> {
368    match block {
369        ContentBlock::Text { text } => Some(ApiPart {
370            text: Some(text.clone()),
371            ..Default::default()
372        }),
373        ContentBlock::Thinking { text } => Some(ApiPart {
374            text: Some(text.clone()),
375            thought: Some(true),
376            ..Default::default()
377        }),
378        ContentBlock::ToolCall {
379            id,
380            name,
381            arguments,
382        } => Some(ApiPart {
383            function_call: Some(ApiFunctionCall {
384                id: Some(id.clone()),
385                name: name.clone(),
386                args: arguments.clone(),
387            }),
388            ..Default::default()
389        }),
390        ContentBlock::Image { .. } => None,
391    }
392}
393
394fn convert_tool_result_response(tool_result: &ToolResultMessage) -> serde_json::Value {
395    let output = tool_result
396        .content
397        .iter()
398        .filter_map(|block| match block {
399            ContentBlock::Text { text } => Some(text.as_str()),
400            _ => None,
401        })
402        .collect::<Vec<_>>()
403        .join("\n");
404
405    let mut response = serde_json::Map::new();
406    response.insert("result".into(), serde_json::Value::String(output));
407
408    if tool_result.is_error {
409        response.insert("isError".into(), serde_json::Value::Bool(true));
410    }
411
412    if !tool_result.details.is_null() {
413        response.insert("details".into(), tool_result.details.clone());
414    }
415
416    serde_json::Value::Object(response)
417}
418
419fn convert_tool_def(tool: &ToolDefinition) -> ApiFunctionDeclaration {
420    ApiFunctionDeclaration {
421        name: tool.name.clone(),
422        description: tool.description.clone(),
423        parameters: tool.parameters.clone(),
424    }
425}
426
427// ---------------------------------------------------------------------------
428// SSE parsing
429// ---------------------------------------------------------------------------
430
431fn parse_sse_event(data: &str) -> Result<Option<GenerateContentResponse>> {
432    let trimmed = data.trim();
433    if trimmed.is_empty() || trimmed == "[DONE]" {
434        return Ok(None);
435    }
436
437    serde_json::from_str(trimmed)
438        .map(Some)
439        .map_err(|e| Error::Stream(format!("Failed to parse Gemini SSE data: {e}: {trimmed}")))
440}
441
/// Returns the part of `current` that is new relative to `previous`.
///
/// When `current` starts with `previous`, only the remainder is returned
/// (possibly empty); otherwise `current` is returned whole.
fn text_delta(previous: &str, current: &str) -> String {
    match current.strip_prefix(previous) {
        Some(appended) => appended.to_owned(),
        None => current.to_owned(),
    }
}
448
449fn update_usage(usage: &ApiUsageMetadata, state: &mut StreamState) {
450    state.usage.input_tokens = usage.prompt_token_count;
451    state.usage.output_tokens = usage.candidates_token_count + usage.thoughts_token_count;
452    state.usage.cache_read_tokens = usage.cached_content_token_count;
453    state.usage.cache_write_tokens = 0;
454}
455
456fn process_response(
457    response: GenerateContentResponse,
458    state: &mut StreamState,
459) -> Vec<StreamEvent> {
460    let mut out = Vec::new();
461
462    if !state.started {
463        state.started = true;
464        out.push(StreamEvent::MessageStart {
465            model: state.model.clone(),
466        });
467    }
468
469    if let Some(usage) = &response.usage_metadata {
470        update_usage(usage, state);
471    }
472
473    if let Some(candidate) = response.candidates.first() {
474        if let Some(content) = &candidate.content {
475            for (index, part) in content.parts.iter().enumerate() {
476                if let Some(function_call) = &part.function_call {
477                    state.ensure_index(index);
478                    let id = function_call
479                        .id
480                        .clone()
481                        .unwrap_or_else(|| format!("call_{index}"));
482                    let name = function_call.name.clone();
483                    let arguments = function_call.args.clone();
484
485                    let emit = match state.parts.get_mut(index) {
486                        Some(PartState::ToolCall {
487                            id: existing_id,
488                            name: existing_name,
489                            arguments: existing_arguments,
490                            emitted,
491                        }) if *existing_id == id && *existing_name == name => {
492                            *existing_arguments = arguments.clone();
493                            if *emitted {
494                                false
495                            } else {
496                                *emitted = true;
497                                true
498                            }
499                        }
500                        Some(slot) => {
501                            *slot = PartState::ToolCall {
502                                id: id.clone(),
503                                name: name.clone(),
504                                arguments: arguments.clone(),
505                                emitted: true,
506                            };
507                            true
508                        }
509                        None => false,
510                    };
511
512                    state.saw_tool_call = true;
513                    if emit {
514                        out.push(StreamEvent::ToolCall {
515                            id,
516                            name,
517                            arguments,
518                        });
519                    }
520                    continue;
521                }
522
523                let Some(text) = part.text.as_deref() else {
524                    continue;
525                };
526
527                state.ensure_index(index);
528                if part.thought.unwrap_or(false) {
529                    let previous = match &state.parts[index] {
530                        PartState::Thinking(existing) => existing.clone(),
531                        _ => String::new(),
532                    };
533                    state.parts[index] = PartState::Thinking(text.to_string());
534                    let delta = text_delta(&previous, text);
535                    if !delta.is_empty() {
536                        out.push(StreamEvent::ThinkingDelta { text: delta });
537                    }
538                } else {
539                    let previous = match &state.parts[index] {
540                        PartState::Text(existing) => existing.clone(),
541                        _ => String::new(),
542                    };
543                    state.parts[index] = PartState::Text(text.to_string());
544                    let delta = text_delta(&previous, text);
545                    if !delta.is_empty() {
546                        out.push(StreamEvent::TextDelta { text: delta });
547                    }
548                }
549            }
550        }
551
552        if let Some(reason) = &candidate.finish_reason {
553            state.finish_reason = Some(reason.clone());
554        }
555
556        if candidate.finish_reason.is_some() && !state.finished {
557            state.finished = true;
558            out.push(StreamEvent::MessageEnd {
559                message: state.build_message(),
560            });
561        }
562    }
563
564    out
565}
566
/// Test helper: runs a whole raw SSE transcript through the parser.
///
/// Every line that starts with `data: ` after trimming is decoded and folded
/// into `state`. Produced events and parse errors are returned in order;
/// other lines are ignored.
#[cfg(test)]
fn parse_sse_stream(raw: &str, state: &mut StreamState) -> Vec<Result<StreamEvent>> {
    let mut collected: Vec<Result<StreamEvent>> = Vec::new();

    let payloads = raw
        .lines()
        .filter_map(|line| line.trim().strip_prefix("data: "));

    for data in payloads {
        match parse_sse_event(data) {
            Ok(Some(response)) => {
                collected.extend(process_response(response, state).into_iter().map(Ok));
            }
            Ok(None) => {}
            Err(error) => collected.push(Err(error)),
        }
    }

    collected
}
588
589// ---------------------------------------------------------------------------
590// Streaming implementation
591// ---------------------------------------------------------------------------
592
593fn stream_response(
594    client: reqwest::Client,
595    model_id: String,
596    api_key: String,
597    request: ApiRequest,
598) -> Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>> {
599    let (tx, rx) = futures::channel::mpsc::unbounded();
600
601    tokio::spawn(async move {
602        let result = client
603            .post(format!("{API_BASE_URL}/{model_id}:streamGenerateContent"))
604            .query(&[("alt", "sse"), ("key", api_key.as_str())])
605            .header("content-type", "application/json")
606            .json(&request)
607            .send()
608            .await;
609
610        let response = match result {
611            Ok(response) => response,
612            Err(error) => {
613                let _ = tx.unbounded_send(Err(Error::Http(error)));
614                return;
615            }
616        };
617
618        let status = response.status();
619        if !status.is_success() {
620            let body = response.text().await.unwrap_or_default();
621            let _ = tx.unbounded_send(Err(Error::Provider(format!("HTTP {status}: {body}"))));
622            return;
623        }
624
625        let mut state = StreamState::new(model_id);
626        let mut buffer = String::new();
627        let mut byte_stream = response.bytes_stream();
628
629        use futures::StreamExt;
630        while let Some(chunk) = byte_stream.next().await {
631            match chunk {
632                Ok(bytes) => {
633                    buffer.push_str(&String::from_utf8_lossy(&bytes));
634
635                    while let Some(pos) = buffer.find('\n') {
636                        let line = buffer[..pos].to_string();
637                        buffer = buffer[pos + 1..].to_string();
638
639                        let trimmed = line.trim();
640                        if let Some(data) = trimmed.strip_prefix("data: ") {
641                            match parse_sse_event(data) {
642                                Ok(Some(response)) => {
643                                    for event in process_response(response, &mut state) {
644                                        if tx.unbounded_send(Ok(event)).is_err() {
645                                            return;
646                                        }
647                                    }
648                                }
649                                Ok(None) => {}
650                                Err(error) => {
651                                    if tx.unbounded_send(Err(error)).is_err() {
652                                        return;
653                                    }
654                                }
655                            }
656                        }
657                    }
658                }
659                Err(error) => {
660                    let _ = tx.unbounded_send(Err(Error::Http(error)));
661                    return;
662                }
663            }
664        }
665
666        let trimmed = buffer.trim();
667        if let Some(data) = trimmed.strip_prefix("data: ") {
668            match parse_sse_event(data) {
669                Ok(Some(response)) => {
670                    for event in process_response(response, &mut state) {
671                        if tx.unbounded_send(Ok(event)).is_err() {
672                            return;
673                        }
674                    }
675                }
676                Ok(None) => {}
677                Err(error) => {
678                    let _ = tx.unbounded_send(Err(error));
679                    return;
680                }
681            }
682        }
683
684        if !state.finished {
685            let _ = tx.unbounded_send(Err(Error::Stream(
686                "Google stream ended before terminal finishReason".into(),
687            )));
688        }
689    });
690
691    Box::pin(rx)
692}
693
694#[async_trait]
695impl Provider for GoogleProvider {
696    fn stream(
697        &self,
698        model: &Model,
699        context: Context,
700        options: RequestOptions,
701        api_key: &str,
702    ) -> Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>> {
703        let request = build_request(model, context, options);
704        stream_response(
705            self.client.clone(),
706            model.meta.id.clone(),
707            api_key.to_string(),
708            request,
709        )
710    }
711
712    async fn resolve_auth(&self, auth: &AuthStore) -> Result<ApiKey> {
713        auth.resolve("google")
714    }
715
716    fn id(&self) -> &str {
717        "google"
718    }
719
720    fn models(&self) -> &[ModelMeta] {
721        &self.models
722    }
723}
724
725fn builtin_models() -> Vec<ModelMeta> {
726    vec![
727        ModelMeta {
728            id: "gemini-2.5-pro".into(),
729            provider: "google".into(),
730            name: "Gemini 2.5 Pro".into(),
731            context_window: 1_048_576,
732            max_output_tokens: 65_536,
733            pricing: ModelPricing {
734                input_per_mtok: 1.25,
735                output_per_mtok: 10.0,
736                cache_read_per_mtok: 0.315,
737                cache_write_per_mtok: 1.25,
738            },
739            capabilities: Capabilities {
740                reasoning: true,
741                images: true,
742                tool_use: true,
743            },
744        },
745        ModelMeta {
746            id: "gemini-2.5-flash".into(),
747            provider: "google".into(),
748            name: "Gemini 2.5 Flash".into(),
749            context_window: 1_048_576,
750            max_output_tokens: 65_536,
751            pricing: ModelPricing {
752                input_per_mtok: 0.15,
753                output_per_mtok: 3.5,
754                cache_read_per_mtok: 0.0375,
755                cache_write_per_mtok: 0.15,
756            },
757            capabilities: Capabilities {
758                reasoning: true,
759                images: true,
760                tool_use: true,
761            },
762        },
763    ]
764}
765
766#[cfg(test)]
767mod tests {
768    use super::*;
769    use crate::message::UserMessage;
770
771    fn test_model(id: &str) -> Model {
772        let provider = GoogleProvider::new();
773        Model {
774            meta: builtin_models()
775                .into_iter()
776                .find(|meta| meta.id == id)
777                .expect("test model should exist"),
778            provider: provider.into_arc(),
779        }
780    }
781
782    #[test]
783    fn serialize_text_user_message() {
784        let message = Message::User(UserMessage {
785            content: vec![ContentBlock::Text {
786                text: "Hello Gemini".into(),
787            }],
788            timestamp: 0,
789        });
790
791        let api = convert_message(&message);
792        let json = serde_json::to_value(&api).unwrap();
793
794        assert_eq!(json["role"], "user");
795        assert_eq!(json["parts"][0]["text"], "Hello Gemini");
796    }
797
798    #[test]
799    fn serialize_assistant_tool_call_block() {
800        let message = Message::Assistant(AssistantMessage {
801            content: vec![ContentBlock::ToolCall {
802                id: "call_1".into(),
803                name: "bash".into(),
804                arguments: serde_json::json!({"command": "ls"}),
805            }],
806            usage: None,
807            stop_reason: StopReason::ToolUse,
808            timestamp: 0,
809        });
810
811        let api = convert_message(&message);
812        let json = serde_json::to_value(&api).unwrap();
813
814        assert_eq!(json["role"], "model");
815        assert_eq!(json["parts"][0]["functionCall"]["id"], "call_1");
816        assert_eq!(json["parts"][0]["functionCall"]["name"], "bash");
817        assert_eq!(json["parts"][0]["functionCall"]["args"]["command"], "ls");
818    }
819
820    #[test]
821    fn serialize_tool_result_message() {
822        let message = Message::ToolResult(ToolResultMessage {
823            tool_call_id: "call_1".into(),
824            tool_name: "bash".into(),
825            content: vec![ContentBlock::Text {
826                text: "README.md\nsrc/".into(),
827            }],
828            is_error: false,
829            details: serde_json::json!({"cwd": "/tmp"}),
830            timestamp: 0,
831        });
832
833        let api = convert_message(&message);
834        let json = serde_json::to_value(&api).unwrap();
835
836        assert_eq!(json["role"], "user");
837        assert_eq!(json["parts"][0]["functionResponse"]["id"], "call_1");
838        assert_eq!(json["parts"][0]["functionResponse"]["name"], "bash");
839        assert_eq!(
840            json["parts"][0]["functionResponse"]["response"]["result"],
841            "README.md\nsrc/"
842        );
843        assert_eq!(
844            json["parts"][0]["functionResponse"]["response"]["details"]["cwd"],
845            "/tmp"
846        );
847    }
848
849    #[test]
850    fn thinking_budget_mapping_matches_model_limits() {
851        let pro = test_model("gemini-2.5-pro");
852        let flash = test_model("gemini-2.5-flash");
853
854        assert_eq!(thinking_budget(&pro, ThinkingLevel::Off), None);
855        assert_eq!(thinking_budget(&pro, ThinkingLevel::Minimal), Some(1024));
856        assert_eq!(thinking_budget(&pro, ThinkingLevel::Low), Some(4096));
857        assert_eq!(thinking_budget(&pro, ThinkingLevel::Medium), Some(10_000));
858        assert_eq!(thinking_budget(&pro, ThinkingLevel::High), Some(24_576));
859        assert_eq!(thinking_budget(&pro, ThinkingLevel::XHigh), Some(32_768));
860        assert_eq!(thinking_budget(&flash, ThinkingLevel::XHigh), Some(24_576));
861    }
862
863    #[test]
864    fn default_max_output_tokens_caps_google_models_without_thinking() {
865        let pro = test_model("gemini-2.5-pro");
866        assert_eq!(default_max_output_tokens(&pro, None), 8_192);
867    }
868
869    #[test]
870    fn default_max_output_tokens_grows_for_google_thinking_budget() {
871        let pro = test_model("gemini-2.5-pro");
872        assert_eq!(default_max_output_tokens(&pro, Some(24_576)), 25_600);
873    }
874
875    #[test]
876    fn build_request_serializes_system_tools_and_thinking() {
877        let model = test_model("gemini-2.5-pro");
878        let context = Context {
879            messages: vec![
880                Message::user("List the files in this directory."),
881                Message::Assistant(AssistantMessage {
882                    content: vec![ContentBlock::ToolCall {
883                        id: "call_1".into(),
884                        name: "bash".into(),
885                        arguments: serde_json::json!({"command": "ls"}),
886                    }],
887                    usage: None,
888                    stop_reason: StopReason::ToolUse,
889                    timestamp: 0,
890                }),
891                Message::ToolResult(ToolResultMessage {
892                    tool_call_id: "call_1".into(),
893                    tool_name: "bash".into(),
894                    content: vec![ContentBlock::Text {
895                        text: "Cargo.toml\nsrc/".into(),
896                    }],
897                    is_error: false,
898                    details: serde_json::Value::Null,
899                    timestamp: 0,
900                }),
901            ],
902        };
903        let options = RequestOptions {
904            system_prompt: "You are a helpful coding assistant.".into(),
905            max_tokens: Some(2048),
906            temperature: Some(0.2),
907            thinking_level: ThinkingLevel::High,
908            tools: vec![ToolDefinition {
909                name: "bash".into(),
910                description: "Run a shell command".into(),
911                parameters: serde_json::json!({
912                    "type": "object",
913                    "properties": {
914                        "command": { "type": "string" }
915                    },
916                    "required": ["command"]
917                }),
918            }],
919            ..Default::default()
920        };
921
922        let request = build_request(&model, context, options);
923        let json = serde_json::to_value(&request).unwrap();
924
925        assert_eq!(
926            json["systemInstruction"]["parts"][0]["text"],
927            "You are a helpful coding assistant."
928        );
929        assert_eq!(json["contents"].as_array().unwrap().len(), 3);
930        assert_eq!(json["contents"][0]["role"], "user");
931        assert_eq!(json["contents"][1]["role"], "model");
932        assert_eq!(
933            json["contents"][1]["parts"][0]["functionCall"]["name"],
934            "bash"
935        );
936        assert_eq!(
937            json["contents"][2]["parts"][0]["functionResponse"]["name"],
938            "bash"
939        );
940        assert_eq!(json["tools"][0]["functionDeclarations"][0]["name"], "bash");
941        assert_eq!(json["generationConfig"]["maxOutputTokens"], 2048);
942        assert!(
943            (json["generationConfig"]["temperature"]
944                .as_f64()
945                .expect("temperature should be numeric")
946                - 0.2)
947                .abs()
948                < 1e-6
949        );
950        assert_eq!(
951            json["generationConfig"]["thinkingConfig"]["includeThoughts"],
952            true
953        );
954        assert_eq!(
955            json["generationConfig"]["thinkingConfig"]["thinkingBudget"],
956            24_576
957        );
958    }
959
960    #[test]
961    fn parse_text_and_thinking_deltas() {
962        let raw = "\
963 data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"thought\":true,\"text\":\"Plan\"}]}}]}\n\
964 \n\
965 data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"thought\":true,\"text\":\"Planning\"},{\"text\":\"Answer\"}]}}]}\n\
966 \n\
967 data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"thought\":true,\"text\":\"Planning\"},{\"text\":\"Answer done\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":5,\"thoughtsTokenCount\":3}}\n";
968
969        let mut state = StreamState::new("gemini-2.5-pro".into());
970        let events = parse_sse_stream(raw, &mut state);
971        let events: Vec<_> = events
972            .into_iter()
973            .collect::<std::result::Result<Vec<_>, _>>()
974            .unwrap();
975
976        assert!(
977            matches!(&events[0], StreamEvent::MessageStart { model } if model == "gemini-2.5-pro")
978        );
979        assert!(matches!(&events[1], StreamEvent::ThinkingDelta { text } if text == "Plan"));
980        assert!(matches!(&events[2], StreamEvent::ThinkingDelta { text } if text == "ning"));
981        assert!(matches!(&events[3], StreamEvent::TextDelta { text } if text == "Answer"));
982        assert!(matches!(&events[4], StreamEvent::TextDelta { text } if text == " done"));
983        assert!(
984            matches!(&events[5], StreamEvent::MessageEnd { message } if message.stop_reason == StopReason::EndTurn)
985        );
986
987        if let StreamEvent::MessageEnd { message } = &events[5] {
988            assert_eq!(message.usage.as_ref().unwrap().input_tokens, 10);
989            assert_eq!(message.usage.as_ref().unwrap().output_tokens, 8);
990            assert_eq!(message.content.len(), 2);
991        } else {
992            panic!("expected MessageEnd");
993        }
994    }
995
996    #[test]
997    fn parse_tool_call_response() {
998        let raw = "\
999 data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"functionCall\":{\"id\":\"call_1\",\"name\":\"read\",\"args\":{\"path\":\"src/lib.rs\"}}}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":12,\"candidatesTokenCount\":4}}\n";
1000
1001        let mut state = StreamState::new("gemini-2.5-pro".into());
1002        let events = parse_sse_stream(raw, &mut state);
1003        let events: Vec<_> = events
1004            .into_iter()
1005            .collect::<std::result::Result<Vec<_>, _>>()
1006            .unwrap();
1007
1008        assert_eq!(events.len(), 3);
1009        assert!(matches!(&events[0], StreamEvent::MessageStart { .. }));
1010        assert!(
1011            matches!(&events[1], StreamEvent::ToolCall { id, name, arguments } if id == "call_1" && name == "read" && arguments["path"] == "src/lib.rs")
1012        );
1013        assert!(
1014            matches!(&events[2], StreamEvent::MessageEnd { message } if message.stop_reason == StopReason::ToolUse)
1015        );
1016    }
1017
1018    #[test]
1019    fn parse_invalid_sse_event_returns_error() {
1020        let error = parse_sse_event("not json").unwrap_err();
1021        assert!(matches!(error, Error::Stream(_)));
1022    }
1023
1024    #[test]
1025    fn builtin_models_include_flash_and_pro() {
1026        let models = builtin_models();
1027        assert_eq!(models.len(), 2);
1028        assert!(models.iter().any(|model| model.id == "gemini-2.5-pro"));
1029        assert!(models.iter().any(|model| model.id == "gemini-2.5-flash"));
1030    }
1031
1032    #[test]
1033    fn parse_multi_part_response_text_and_tool_call() {
1034        // A single candidate with both text and a function_call in the same response
1035        let raw = "\
1036 data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"Let me check\"},{\"functionCall\":{\"id\":\"call_1\",\"name\":\"read\",\"args\":{\"path\":\"a.rs\"}}}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":8,\"candidatesTokenCount\":6}}\n";
1037
1038        let mut state = StreamState::new("gemini-2.5-pro".into());
1039        let events = parse_sse_stream(raw, &mut state);
1040        let events: Vec<_> = events
1041            .into_iter()
1042            .collect::<std::result::Result<Vec<_>, _>>()
1043            .unwrap();
1044
1045        // MessageStart, TextDelta, ToolCall, MessageEnd
1046        assert_eq!(events.len(), 4);
1047        assert!(matches!(&events[0], StreamEvent::MessageStart { .. }));
1048        assert!(matches!(&events[1], StreamEvent::TextDelta { text } if text == "Let me check"));
1049        assert!(matches!(&events[2], StreamEvent::ToolCall { name, .. } if name == "read"));
1050        if let StreamEvent::MessageEnd { message } = &events[3] {
1051            assert_eq!(message.stop_reason, StopReason::ToolUse);
1052        } else {
1053            panic!("expected MessageEnd");
1054        }
1055    }
1056
1057    #[test]
1058    fn parse_usage_metadata_extraction() {
1059        let raw = "\
1060 data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"Hi\"}]},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":42,\"candidatesTokenCount\":10,\"thoughtsTokenCount\":5,\"cachedContentTokenCount\":3}}\n";
1061
1062        let mut state = StreamState::new("gemini-2.5-pro".into());
1063        let events = parse_sse_stream(raw, &mut state);
1064        let events: Vec<_> = events
1065            .into_iter()
1066            .collect::<std::result::Result<Vec<_>, _>>()
1067            .unwrap();
1068
1069        if let StreamEvent::MessageEnd { message } = events.last().unwrap() {
1070            let usage = message.usage.as_ref().unwrap();
1071            assert_eq!(usage.input_tokens, 42);
1072            assert_eq!(usage.output_tokens, 15); // candidates + thoughts
1073            assert_eq!(usage.cache_read_tokens, 3);
1074        } else {
1075            panic!("expected MessageEnd");
1076        }
1077    }
1078
1079    #[test]
1080    fn stop_reason_mapping() {
1081        let mut state = StreamState::new("test".into());
1082        state.finish_reason = Some("STOP".into());
1083        assert_eq!(state.stop_reason(), StopReason::EndTurn);
1084
1085        state.finish_reason = Some("MAX_TOKENS".into());
1086        assert_eq!(state.stop_reason(), StopReason::MaxTokens);
1087
1088        state.finish_reason = Some("SAFETY".into());
1089        assert_eq!(state.stop_reason(), StopReason::Error("SAFETY".into()));
1090
1091        state.finish_reason = None;
1092        assert_eq!(state.stop_reason(), StopReason::EndTurn);
1093
1094        state.saw_tool_call = true;
1095        assert_eq!(state.stop_reason(), StopReason::ToolUse);
1096    }
1097
1098    #[test]
1099    fn empty_candidates_produces_no_content_events() {
1100        let raw = "\
1101 data: {\"candidates\":[],\"usageMetadata\":{\"promptTokenCount\":5,\"candidatesTokenCount\":0}}\n";
1102
1103        let mut state = StreamState::new("gemini-2.5-pro".into());
1104        let events = parse_sse_stream(raw, &mut state);
1105        let events: Vec<_> = events
1106            .into_iter()
1107            .collect::<std::result::Result<Vec<_>, _>>()
1108            .unwrap();
1109
1110        // Only MessageStart (no content deltas, no MessageEnd since no finishReason)
1111        assert_eq!(events.len(), 1);
1112        assert!(matches!(&events[0], StreamEvent::MessageStart { .. }));
1113    }
1114
1115    #[test]
1116    fn parse_sse_event_done_marker_returns_none() {
1117        let result = parse_sse_event("[DONE]").unwrap();
1118        assert!(result.is_none());
1119    }
1120
1121    #[test]
1122    fn empty_system_prompt_produces_no_instruction() {
1123        let instruction = build_system_instruction("");
1124        assert!(instruction.is_none());
1125    }
1126
1127    #[test]
1128    fn empty_tools_produces_empty_vec() {
1129        let tools = build_tools(&[]);
1130        assert!(tools.is_empty());
1131    }
1132}