Skip to main content

agent_sdk_rs/llm/
google.rs

1use async_trait::async_trait;
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use serde_json::{Map, Value, json};
5use std::collections::HashSet;
6
7use crate::error::ProviderError;
8use crate::llm::{
9    ChatModel, ModelCompletion, ModelMessage, ModelToolCall, ModelToolChoice, ModelToolDefinition,
10    ModelUsage,
11};
12
/// Base URL used for requests when the config does not supply `api_base_url`.
const DEFAULT_API_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
/// Text of the placeholder user turn inserted when a request would otherwise carry no contents.
const EMPTY_USER_CONTENT_FALLBACK: &str = " ";
15
/// Connection and generation settings for a Google Gemini chat model.
#[derive(Clone, Debug)]
pub struct GoogleModelConfig {
    /// API key sent with every request.
    pub api_key: String,
    /// Model identifier placed in the endpoint path.
    pub model: String,
    /// Base URL override; the built-in default is used when `None`.
    pub api_base_url: Option<String>,
    /// `temperature` generation setting; omitted from the request when `None`.
    pub temperature: Option<f32>,
    /// `topP` generation setting; omitted from the request when `None`.
    pub top_p: Option<f32>,
    /// `maxOutputTokens` generation setting; omitted from the request when `None`.
    pub max_output_tokens: Option<u32>,
    /// Thinking budget; a thinking config is only sent when this is `Some`.
    pub thinking_budget_tokens: Option<u32>,
    /// Forwarded as `includeThoughts`, but only when `thinking_budget_tokens` is set.
    pub include_thoughts: Option<bool>,
}
27
28impl GoogleModelConfig {
29    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
30        Self {
31            api_key: api_key.into(),
32            model: model.into(),
33            api_base_url: None,
34            temperature: None,
35            top_p: None,
36            max_output_tokens: Some(4096),
37            thinking_budget_tokens: None,
38            include_thoughts: None,
39        }
40    }
41}
42
43#[derive(Debug, Clone)]
44pub struct GoogleModel {
45    client: Client,
46    config: GoogleModelConfig,
47}
48
49impl GoogleModel {
50    pub fn new(config: GoogleModelConfig) -> Result<Self, ProviderError> {
51        let client = Client::builder()
52            .build()
53            .map_err(|err| ProviderError::Request(err.to_string()))?;
54
55        Ok(Self { client, config })
56    }
57
58    pub fn from_env(model: impl Into<String>) -> Result<Self, ProviderError> {
59        let api_key = std::env::var("GOOGLE_API_KEY")
60            .or_else(|_| std::env::var("GEMINI_API_KEY"))
61            .map_err(|_| {
62                ProviderError::Request("GOOGLE_API_KEY (or GEMINI_API_KEY) is not set".to_string())
63            })?;
64
65        Self::new(GoogleModelConfig::new(api_key, model))
66    }
67
68    fn endpoint(&self) -> String {
69        let base = self
70            .config
71            .api_base_url
72            .as_deref()
73            .unwrap_or(DEFAULT_API_BASE_URL)
74            .trim_end_matches('/');
75        format!("{base}/models/{}:generateContent", self.config.model)
76    }
77}
78
79#[async_trait]
80impl ChatModel for GoogleModel {
81    async fn invoke(
82        &self,
83        messages: &[ModelMessage],
84        tools: &[ModelToolDefinition],
85        tool_choice: ModelToolChoice,
86    ) -> Result<ModelCompletion, ProviderError> {
87        let request = build_request(messages, tools, tool_choice, &self.config);
88
89        let response = self
90            .client
91            .post(self.endpoint())
92            .header("x-goog-api-key", &self.config.api_key)
93            .header("content-type", "application/json")
94            .json(&request)
95            .send()
96            .await
97            .map_err(|err| ProviderError::Request(err.to_string()))?;
98
99        if !response.status().is_success() {
100            return Err(ProviderError::Request(extract_api_error(response).await));
101        }
102
103        let payload = response
104            .json::<GenerateContentResponse>()
105            .await
106            .map_err(|err| ProviderError::Response(err.to_string()))?;
107
108        normalize_response(payload)
109    }
110}
111
112#[derive(Debug, Serialize)]
113#[serde(rename_all = "camelCase")]
114struct GenerateContentRequest {
115    contents: Vec<GoogleContent>,
116    #[serde(skip_serializing_if = "Option::is_none")]
117    system_instruction: Option<GoogleSystemInstruction>,
118    #[serde(skip_serializing_if = "Option::is_none")]
119    tools: Option<Vec<GoogleTool>>,
120    #[serde(skip_serializing_if = "Option::is_none")]
121    tool_config: Option<GoogleToolConfig>,
122    #[serde(skip_serializing_if = "Option::is_none")]
123    generation_config: Option<GoogleGenerationConfig>,
124}
125
126#[derive(Debug, Serialize, Deserialize, Clone)]
127#[serde(rename_all = "camelCase")]
128struct GoogleContent {
129    role: String,
130    parts: Vec<GooglePart>,
131}
132
133#[derive(Debug, Serialize)]
134#[serde(rename_all = "camelCase")]
135struct GoogleSystemInstruction {
136    parts: Vec<GooglePart>,
137}
138
139#[derive(Debug, Serialize)]
140#[serde(rename_all = "camelCase")]
141struct GoogleTool {
142    function_declarations: Vec<GoogleFunctionDeclaration>,
143}
144
145#[derive(Debug, Serialize)]
146#[serde(rename_all = "camelCase")]
147struct GoogleFunctionDeclaration {
148    name: String,
149    description: String,
150    parameters: Value,
151}
152
153#[derive(Debug, Serialize)]
154#[serde(rename_all = "camelCase")]
155struct GoogleToolConfig {
156    function_calling_config: GoogleFunctionCallingConfig,
157}
158
159#[derive(Debug, Serialize)]
160#[serde(rename_all = "camelCase")]
161struct GoogleFunctionCallingConfig {
162    mode: String,
163    #[serde(skip_serializing_if = "Option::is_none")]
164    allowed_function_names: Option<Vec<String>>,
165}
166
167#[derive(Debug, Serialize)]
168#[serde(rename_all = "camelCase")]
169struct GoogleGenerationConfig {
170    #[serde(skip_serializing_if = "Option::is_none")]
171    temperature: Option<f32>,
172    #[serde(skip_serializing_if = "Option::is_none")]
173    top_p: Option<f32>,
174    #[serde(skip_serializing_if = "Option::is_none")]
175    max_output_tokens: Option<u32>,
176    #[serde(skip_serializing_if = "Option::is_none")]
177    thinking_config: Option<GoogleThinkingConfig>,
178}
179
180#[derive(Debug, Serialize)]
181#[serde(rename_all = "camelCase")]
182struct GoogleThinkingConfig {
183    thinking_budget: u32,
184    #[serde(skip_serializing_if = "Option::is_none")]
185    include_thoughts: Option<bool>,
186}
187
188#[derive(Debug, Serialize, Deserialize, Clone)]
189#[serde(rename_all = "camelCase")]
190struct GooglePart {
191    #[serde(skip_serializing_if = "Option::is_none")]
192    text: Option<String>,
193    #[serde(skip_serializing_if = "Option::is_none")]
194    thought: Option<bool>,
195    #[serde(skip_serializing_if = "Option::is_none")]
196    function_call: Option<GoogleFunctionCall>,
197    #[serde(skip_serializing_if = "Option::is_none")]
198    function_response: Option<GoogleFunctionResponse>,
199}
200
201#[derive(Debug, Serialize, Deserialize, Clone)]
202#[serde(rename_all = "camelCase")]
203struct GoogleFunctionCall {
204    id: Option<String>,
205    name: Option<String>,
206    args: Option<Value>,
207}
208
209#[derive(Debug, Serialize, Deserialize, Clone)]
210#[serde(rename_all = "camelCase")]
211struct GoogleFunctionResponse {
212    name: String,
213    response: Value,
214}
215
216#[derive(Debug, Deserialize)]
217#[serde(rename_all = "camelCase")]
218struct GenerateContentResponse {
219    #[serde(default)]
220    candidates: Vec<GoogleCandidate>,
221    usage_metadata: Option<GoogleUsageMetadata>,
222}
223
224#[derive(Debug, Deserialize)]
225#[serde(rename_all = "camelCase")]
226struct GoogleCandidate {
227    content: Option<GoogleContent>,
228}
229
230#[derive(Debug, Deserialize)]
231#[serde(rename_all = "camelCase")]
232struct GoogleUsageMetadata {
233    prompt_token_count: Option<u32>,
234    candidates_token_count: Option<u32>,
235    thoughts_token_count: Option<u32>,
236}
237
238#[derive(Debug, Deserialize)]
239#[serde(rename_all = "camelCase")]
240struct GoogleErrorEnvelope {
241    error: GoogleApiError,
242}
243
244#[derive(Debug, Deserialize)]
245#[serde(rename_all = "camelCase")]
246struct GoogleApiError {
247    code: Option<u16>,
248    status: Option<String>,
249    message: Option<String>,
250}
251
252fn build_request(
253    messages: &[ModelMessage],
254    tools: &[ModelToolDefinition],
255    tool_choice: ModelToolChoice,
256    config: &GoogleModelConfig,
257) -> GenerateContentRequest {
258    let (contents, system_instruction) = to_google_contents(messages);
259    let contents = ensure_non_empty_contents(contents);
260
261    let tools_payload = if tools.is_empty() {
262        None
263    } else {
264        let declarations = tools
265            .iter()
266            .map(|tool| GoogleFunctionDeclaration {
267                name: tool.name.clone(),
268                description: tool.description.clone(),
269                parameters: clean_gemini_schema(tool.parameters.clone()),
270            })
271            .collect::<Vec<_>>();
272        Some(vec![GoogleTool {
273            function_declarations: declarations,
274        }])
275    };
276
277    let tool_config = if tools.is_empty() {
278        None
279    } else {
280        Some(match tool_choice {
281            ModelToolChoice::Auto => GoogleToolConfig {
282                function_calling_config: GoogleFunctionCallingConfig {
283                    mode: "AUTO".to_string(),
284                    allowed_function_names: None,
285                },
286            },
287            ModelToolChoice::Required => GoogleToolConfig {
288                function_calling_config: GoogleFunctionCallingConfig {
289                    mode: "ANY".to_string(),
290                    allowed_function_names: None,
291                },
292            },
293            ModelToolChoice::None => GoogleToolConfig {
294                function_calling_config: GoogleFunctionCallingConfig {
295                    mode: "NONE".to_string(),
296                    allowed_function_names: None,
297                },
298            },
299            ModelToolChoice::Tool(name) => GoogleToolConfig {
300                function_calling_config: GoogleFunctionCallingConfig {
301                    mode: "ANY".to_string(),
302                    allowed_function_names: Some(vec![name]),
303                },
304            },
305        })
306    };
307
308    let thinking_config = config
309        .thinking_budget_tokens
310        .map(|budget| GoogleThinkingConfig {
311            thinking_budget: budget,
312            include_thoughts: config.include_thoughts,
313        });
314
315    let generation_config = GoogleGenerationConfig {
316        temperature: config.temperature,
317        top_p: config.top_p,
318        max_output_tokens: config.max_output_tokens,
319        thinking_config,
320    };
321
322    GenerateContentRequest {
323        contents,
324        system_instruction: system_instruction.map(|instruction| GoogleSystemInstruction {
325            parts: vec![GooglePart {
326                text: Some(instruction),
327                thought: None,
328                function_call: None,
329                function_response: None,
330            }],
331        }),
332        tools: tools_payload,
333        tool_config,
334        generation_config: Some(generation_config),
335    }
336}
337
338fn ensure_non_empty_contents(mut contents: Vec<GoogleContent>) -> Vec<GoogleContent> {
339    if contents.is_empty() {
340        contents.push(GoogleContent {
341            role: "user".to_string(),
342            parts: vec![GooglePart {
343                text: Some(EMPTY_USER_CONTENT_FALLBACK.to_string()),
344                thought: None,
345                function_call: None,
346                function_response: None,
347            }],
348        });
349    }
350    contents
351}
352
353fn to_google_contents(messages: &[ModelMessage]) -> (Vec<GoogleContent>, Option<String>) {
354    let mut system_lines = Vec::new();
355    let mut contents = Vec::new();
356
357    for message in messages {
358        match message {
359            ModelMessage::System(content) => {
360                if !content.is_empty() {
361                    system_lines.push(content.clone());
362                }
363            }
364            ModelMessage::User(content) => {
365                if content.is_empty() {
366                    continue;
367                }
368                contents.push(GoogleContent {
369                    role: "user".to_string(),
370                    parts: vec![GooglePart {
371                        text: Some(content.clone()),
372                        thought: None,
373                        function_call: None,
374                        function_response: None,
375                    }],
376                });
377            }
378            ModelMessage::Assistant {
379                content,
380                tool_calls,
381            } => {
382                let mut parts = Vec::new();
383
384                if let Some(text) = content
385                    && !text.is_empty()
386                {
387                    parts.push(GooglePart {
388                        text: Some(text.clone()),
389                        thought: None,
390                        function_call: None,
391                        function_response: None,
392                    });
393                }
394
395                for call in tool_calls {
396                    parts.push(GooglePart {
397                        text: None,
398                        thought: None,
399                        function_call: Some(GoogleFunctionCall {
400                            id: Some(call.id.clone()),
401                            name: Some(call.name.clone()),
402                            args: Some(call.arguments.clone()),
403                        }),
404                        function_response: None,
405                    });
406                }
407
408                if !parts.is_empty() {
409                    contents.push(GoogleContent {
410                        role: "model".to_string(),
411                        parts,
412                    });
413                }
414            }
415            ModelMessage::ToolResult {
416                tool_call_id: _,
417                tool_name,
418                content,
419                is_error,
420            } => contents.push(GoogleContent {
421                role: "user".to_string(),
422                parts: vec![GooglePart {
423                    text: None,
424                    thought: None,
425                    function_call: None,
426                    function_response: Some(GoogleFunctionResponse {
427                        name: tool_name.clone(),
428                        response: tool_result_payload(content, *is_error),
429                    }),
430                }],
431            }),
432        }
433    }
434
435    let system = if system_lines.is_empty() {
436        None
437    } else {
438        Some(system_lines.join("\n\n"))
439    };
440
441    (contents, system)
442}
443
444fn tool_result_payload(content: &str, is_error: bool) -> Value {
445    if is_error {
446        return json!({"error": content});
447    }
448
449    if let Ok(parsed) = serde_json::from_str::<Value>(content) {
450        parsed
451    } else {
452        json!({"result": content})
453    }
454}
455
456fn normalize_response(response: GenerateContentResponse) -> Result<ModelCompletion, ProviderError> {
457    let Some(candidate) = response.candidates.into_iter().next() else {
458        return Err(ProviderError::Response(
459            "google response missing candidates".to_string(),
460        ));
461    };
462
463    let mut text_parts = Vec::new();
464    let mut thinking_parts = Vec::new();
465    let mut tool_calls = Vec::new();
466
467    if let Some(content) = candidate.content {
468        for (index, part) in content.parts.into_iter().enumerate() {
469            if let Some(text) = part.text {
470                if part.thought.unwrap_or(false) {
471                    thinking_parts.push(text);
472                } else {
473                    text_parts.push(text);
474                }
475            }
476
477            if let Some(function_call) = part.function_call {
478                let Some(name) = function_call.name else {
479                    return Err(ProviderError::Response(
480                        "google functionCall missing name".to_string(),
481                    ));
482                };
483
484                tool_calls.push(ModelToolCall {
485                    id: function_call
486                        .id
487                        .unwrap_or_else(|| format!("call_{}", index + 1)),
488                    name,
489                    arguments: function_call.args.unwrap_or_else(|| json!({})),
490                });
491            }
492        }
493    }
494
495    let usage = response.usage_metadata.map(|usage| ModelUsage {
496        input_tokens: usage.prompt_token_count.unwrap_or(0),
497        output_tokens: usage
498            .candidates_token_count
499            .unwrap_or(0)
500            .saturating_add(usage.thoughts_token_count.unwrap_or(0)),
501    });
502
503    let text = if text_parts.is_empty() {
504        None
505    } else {
506        Some(text_parts.join("\n"))
507    };
508
509    let thinking = if thinking_parts.is_empty() {
510        None
511    } else {
512        Some(thinking_parts.join("\n"))
513    };
514
515    Ok(ModelCompletion {
516        text,
517        thinking,
518        tool_calls,
519        usage,
520    })
521}
522
523async fn extract_api_error(response: reqwest::Response) -> String {
524    let status = response.status();
525    let body = response.text().await.unwrap_or_default();
526
527    if let Ok(parsed) = serde_json::from_str::<GoogleErrorEnvelope>(&body) {
528        let code = parsed.error.code.unwrap_or(status.as_u16());
529        let status_name = parsed
530            .error
531            .status
532            .unwrap_or_else(|| status.to_string().to_uppercase());
533        let message = parsed
534            .error
535            .message
536            .unwrap_or_else(|| "unknown google api error".to_string());
537        return format!("google api error {code} {status_name}: {message}");
538    }
539
540    if body.is_empty() {
541        format!("google api request failed ({status})")
542    } else {
543        format!("google api request failed ({status}): {body}")
544    }
545}
546
547fn clean_gemini_schema(schema: Value) -> Value {
548    let mut root = schema;
549    let defs = match &mut root {
550        Value::Object(map) => {
551            let mut defs = Map::new();
552            for key in ["$defs", "definitions"] {
553                if let Some(Value::Object(found)) = map.remove(key) {
554                    defs.extend(found);
555                }
556            }
557            defs
558        }
559        _ => Map::new(),
560    };
561
562    let resolved = resolve_schema_refs(root, &defs);
563    clean_schema_node(resolved, None)
564}
565
566fn resolve_schema_refs(value: Value, defs: &Map<String, Value>) -> Value {
567    let mut active_refs = HashSet::new();
568    resolve_schema_refs_inner(value, defs, &mut active_refs)
569}
570
571fn resolve_schema_refs_inner(
572    value: Value,
573    defs: &Map<String, Value>,
574    active_refs: &mut HashSet<String>,
575) -> Value {
576    match value {
577        Value::Object(mut map) => {
578            if let Some(reference) = map
579                .get("$ref")
580                .and_then(Value::as_str)
581                .map(ToString::to_string)
582            {
583                let ref_name = reference.rsplit('/').next().unwrap_or("").to_string();
584                if let Some(definition) = defs.get(&ref_name) {
585                    if active_refs.contains(&ref_name) {
586                        map.remove("$ref");
587                        if map.is_empty() {
588                            return json!({"type": "string"});
589                        }
590                    } else {
591                        active_refs.insert(ref_name.clone());
592                        let mut resolved = definition.clone();
593                        if let Value::Object(ref mut resolved_map) = resolved {
594                            map.remove("$ref");
595                            for (key, value) in map {
596                                resolved_map.insert(key, value);
597                            }
598                        }
599                        let output = resolve_schema_refs_inner(resolved, defs, active_refs);
600                        active_refs.remove(&ref_name);
601                        return output;
602                    }
603                } else {
604                    map.remove("$ref");
605                    if map.is_empty() {
606                        return json!({"type": "string"});
607                    }
608                }
609            }
610
611            let mut out = Map::new();
612            for (key, value) in map {
613                out.insert(key, resolve_schema_refs_inner(value, defs, active_refs));
614            }
615            Value::Object(out)
616        }
617        Value::Array(values) => Value::Array(
618            values
619                .into_iter()
620                .map(|value| resolve_schema_refs_inner(value, defs, active_refs))
621                .collect(),
622        ),
623        other => other,
624    }
625}
626
627fn clean_schema_node(value: Value, parent_key: Option<&str>) -> Value {
628    match value {
629        Value::Object(map) => {
630            let mut cleaned = Map::new();
631
632            for (key, value) in map {
633                let is_metadata_title = key == "title" && parent_key != Some("properties");
634                if key == "additionalProperties" || key == "default" || is_metadata_title {
635                    continue;
636                }
637
638                cleaned.insert(key.clone(), clean_schema_node(value, Some(&key)));
639            }
640
641            let type_name = cleaned
642                .get("type")
643                .and_then(Value::as_str)
644                .map(|t| t.to_ascii_lowercase());
645            if type_name.as_deref() == Some("object") {
646                let needs_placeholder = cleaned
647                    .get("properties")
648                    .and_then(Value::as_object)
649                    .map(|properties| properties.is_empty())
650                    .unwrap_or(false);
651
652                if needs_placeholder {
653                    cleaned.insert(
654                        "properties".to_string(),
655                        json!({"_placeholder": {"type": "string"}}),
656                    );
657                }
658            }
659
660            Value::Object(cleaned)
661        }
662        Value::Array(values) => Value::Array(
663            values
664                .into_iter()
665                .map(|value| clean_schema_node(value, parent_key))
666                .collect(),
667        ),
668        other => other,
669    }
670}
671
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Model name shared by the request-building tests.
    const TEST_MODEL: &str = "gemini-2.5-flash";

    /// Tool definition whose schema contains keys the cleaner must strip.
    fn tool_definition() -> ModelToolDefinition {
        let parameters = json!({
            "type": "object",
            "properties": {
                "query": {"type": "string", "default": "x"}
            },
            "required": ["query"],
            "additionalProperties": false,
            "title": "LookupTool"
        });

        ModelToolDefinition {
            name: "lookup".to_string(),
            description: "Look up something".to_string(),
            parameters,
        }
    }

    /// Text-only part, optionally flagged as a thought.
    fn text_part(text: &str, thought: Option<bool>) -> GooglePart {
        GooglePart {
            text: Some(text.to_string()),
            thought,
            function_call: None,
            function_response: None,
        }
    }

    /// Part carrying only a function call.
    fn call_part(id: &str, name: &str, args: Value) -> GooglePart {
        GooglePart {
            text: None,
            thought: None,
            function_call: Some(GoogleFunctionCall {
                id: Some(id.to_string()),
                name: Some(name.to_string()),
                args: Some(args),
            }),
            function_response: None,
        }
    }

    #[test]
    fn build_request_serializes_messages_tools_and_tool_choice() {
        let assistant_turn = ModelMessage::Assistant {
            content: Some("Calling tool".to_string()),
            tool_calls: vec![ModelToolCall {
                id: "call_1".to_string(),
                name: "lookup".to_string(),
                arguments: json!({"query": "rust"}),
            }],
        };
        let tool_result = ModelMessage::ToolResult {
            tool_call_id: "call_1".to_string(),
            tool_name: "lookup".to_string(),
            content: "{\"result\":\"ok\"}".to_string(),
            is_error: false,
        };
        let messages = vec![
            ModelMessage::System("You are helpful".to_string()),
            ModelMessage::User("Find docs".to_string()),
            assistant_turn,
            tool_result,
        ];

        let mut config = GoogleModelConfig::new("key", TEST_MODEL);
        config.temperature = Some(0.2);
        config.thinking_budget_tokens = Some(256);

        let request = build_request(
            &messages,
            &[tool_definition()],
            ModelToolChoice::Tool("lookup".to_string()),
            &config,
        );
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(
            value["systemInstruction"]["parts"][0]["text"],
            "You are helpful"
        );

        let contents = &value["contents"];
        assert_eq!(contents[0]["role"], "user");
        assert_eq!(contents[1]["parts"][1]["functionCall"]["name"], "lookup");
        assert_eq!(
            contents[2]["parts"][0]["functionResponse"]["response"]["result"],
            "ok"
        );

        let calling_config = &value["toolConfig"]["functionCallingConfig"];
        assert_eq!(calling_config["mode"], "ANY");
        assert_eq!(calling_config["allowedFunctionNames"][0], "lookup");

        assert_eq!(
            value["generationConfig"]["thinkingConfig"]["thinkingBudget"],
            256
        );

        let parameters = &value["tools"][0]["functionDeclarations"][0]["parameters"];
        assert!(parameters.get("additionalProperties").is_none());
    }

    #[test]
    fn build_request_adds_fallback_content_for_empty_user_message() {
        let config = GoogleModelConfig::new("key", TEST_MODEL);
        let messages = vec![ModelMessage::User(String::new())];

        let request = build_request(&messages, &[], ModelToolChoice::Auto, &config);
        let value = serde_json::to_value(request).expect("serializes");

        let contents = &value["contents"];
        assert_eq!(contents.as_array().map(|v| v.len()), Some(1));
        assert_eq!(contents[0]["role"], "user");
        assert_eq!(contents[0]["parts"][0]["text"], " ");
    }

    #[test]
    fn normalize_response_extracts_text_thinking_tool_calls_and_usage() {
        let parts = vec![
            text_part("answer", None),
            text_part("reasoning", Some(true)),
            call_part("call_x", "lookup", json!({"q": "rust"})),
        ];
        let response = GenerateContentResponse {
            candidates: vec![GoogleCandidate {
                content: Some(GoogleContent {
                    role: "model".to_string(),
                    parts,
                }),
            }],
            usage_metadata: Some(GoogleUsageMetadata {
                prompt_token_count: Some(11),
                candidates_token_count: Some(7),
                thoughts_token_count: Some(3),
            }),
        };

        let completion = normalize_response(response).expect("response normalizes");

        assert_eq!(completion.text.as_deref(), Some("answer"));
        assert_eq!(completion.thinking.as_deref(), Some("reasoning"));
        assert_eq!(completion.tool_calls.len(), 1);

        let call = &completion.tool_calls[0];
        assert_eq!(call.name, "lookup");
        assert_eq!(call.id, "call_x");

        // Output tokens combine candidate (7) and thought (3) counts.
        assert_eq!(
            completion.usage,
            Some(ModelUsage {
                input_tokens: 11,
                output_tokens: 10,
            })
        );
    }

    #[test]
    fn normalize_response_requires_candidates() {
        let empty = GenerateContentResponse {
            candidates: Vec::new(),
            usage_metadata: None,
        };

        let err = normalize_response(empty).expect_err("should fail");

        match &err {
            ProviderError::Response(message) => {
                assert!(message.contains("missing candidates"));
            }
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn clean_gemini_schema_resolves_refs_and_handles_empty_objects() {
        let cleaned = clean_gemini_schema(json!({
            "$defs": {
                "Inner": {
                    "type": "object",
                    "properties": {},
                    "additionalProperties": false
                }
            },
            "type": "object",
            "properties": {
                "inner": {
                    "$ref": "#/$defs/Inner"
                }
            },
            "additionalProperties": false
        }));

        assert!(cleaned.get("$defs").is_none());
        assert!(cleaned.get("additionalProperties").is_none());

        let inner = &cleaned["properties"]["inner"];
        assert_eq!(inner["properties"]["_placeholder"]["type"], "string");
    }

    #[test]
    fn clean_gemini_schema_handles_unresolved_ref_and_legacy_definitions() {
        let cleaned = clean_gemini_schema(json!({
            "definitions": {
                "Legacy": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string"}
                    }
                }
            },
            "type": "object",
            "properties": {
                "legacy": {"$ref": "#/definitions/Legacy"},
                "broken": {"$ref": "#/$defs/Unknown"}
            }
        }));

        let properties = &cleaned["properties"];
        assert_eq!(properties["legacy"]["properties"]["name"]["type"], "string");
        assert!(properties["broken"].get("$ref").is_none());
        assert_eq!(properties["broken"]["type"], "string");
    }

    #[test]
    fn clean_gemini_schema_handles_circular_refs_without_recursing_forever() {
        let cleaned = clean_gemini_schema(json!({
            "$defs": {
                "Node": {
                    "type": "object",
                    "properties": {
                        "next": { "$ref": "#/$defs/Node" }
                    }
                }
            },
            "type": "object",
            "properties": {
                "root": { "$ref": "#/$defs/Node" }
            }
        }));

        let root = &cleaned["properties"]["root"];
        assert!(root.get("$ref").is_none());
        assert_eq!(root["properties"]["next"]["type"], "string");
    }
}