1use async_trait::async_trait;
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use serde_json::{Map, Value, json};
5use std::collections::HashSet;
6
7use crate::error::ProviderError;
8use crate::llm::{
9 ChatModel, ModelCompletion, ModelMessage, ModelToolCall, ModelToolChoice, ModelToolDefinition,
10 ModelUsage,
11};
12
/// Gemini REST API root used when `GoogleModelConfig::api_base_url` is `None`.
const DEFAULT_API_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
/// Text of the lone user turn that `ensure_non_empty_contents` inserts when the transcript
/// produced no contents.
// NOTE(review): presumably Gemini rejects a request whose `contents` array is empty — confirm.
const EMPTY_USER_CONTENT_FALLBACK: &str = " ";
15
/// Connection and generation settings for a Gemini model.
///
/// Optional settings left as `None` are not sent with the request.
#[derive(Debug, Clone)]
pub struct GoogleModelConfig {
    /// API key, sent in the `x-goog-api-key` header.
    pub api_key: String,
    /// Model identifier used to build the `generateContent` URL (e.g. `gemini-2.5-flash`).
    pub model: String,
    /// Override for the API root; `None` means the public `v1beta` endpoint.
    pub api_base_url: Option<String>,
    /// Sampling temperature (`generationConfig.temperature`).
    pub temperature: Option<f32>,
    /// Nucleus-sampling cutoff (`generationConfig.topP`).
    pub top_p: Option<f32>,
    /// Cap on generated tokens (`generationConfig.maxOutputTokens`).
    pub max_output_tokens: Option<u32>,
    /// Thinking budget in tokens; when set, a `thinkingConfig` is sent.
    pub thinking_budget_tokens: Option<u32>,
    /// Whether thoughts should be returned; only sent together with a thinking budget.
    pub include_thoughts: Option<bool>,
}

impl GoogleModelConfig {
    /// Output-token cap applied by [`GoogleModelConfig::new`].
    const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 4096;

    /// Create a config for `model` authenticated with `api_key`.
    ///
    /// Every optional setting starts unset except `max_output_tokens`, which defaults to 4096.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let api_key: String = api_key.into();
        let model: String = model.into();

        Self {
            api_key,
            model,
            max_output_tokens: Some(Self::DEFAULT_MAX_OUTPUT_TOKENS),
            api_base_url: None,
            temperature: None,
            top_p: None,
            thinking_budget_tokens: None,
            include_thoughts: None,
        }
    }
}
42
43#[derive(Debug, Clone)]
44pub struct GoogleModel {
45 client: Client,
46 config: GoogleModelConfig,
47}
48
49impl GoogleModel {
50 pub fn new(config: GoogleModelConfig) -> Result<Self, ProviderError> {
51 let client = Client::builder()
52 .build()
53 .map_err(|err| ProviderError::Request(err.to_string()))?;
54
55 Ok(Self { client, config })
56 }
57
58 pub fn from_env(model: impl Into<String>) -> Result<Self, ProviderError> {
59 let api_key = std::env::var("GOOGLE_API_KEY")
60 .or_else(|_| std::env::var("GEMINI_API_KEY"))
61 .map_err(|_| {
62 ProviderError::Request("GOOGLE_API_KEY (or GEMINI_API_KEY) is not set".to_string())
63 })?;
64
65 Self::new(GoogleModelConfig::new(api_key, model))
66 }
67
68 fn endpoint(&self) -> String {
69 let base = self
70 .config
71 .api_base_url
72 .as_deref()
73 .unwrap_or(DEFAULT_API_BASE_URL)
74 .trim_end_matches('/');
75 format!("{base}/models/{}:generateContent", self.config.model)
76 }
77}
78
79#[async_trait]
80impl ChatModel for GoogleModel {
81 async fn invoke(
82 &self,
83 messages: &[ModelMessage],
84 tools: &[ModelToolDefinition],
85 tool_choice: ModelToolChoice,
86 ) -> Result<ModelCompletion, ProviderError> {
87 let request = build_request(messages, tools, tool_choice, &self.config);
88
89 let response = self
90 .client
91 .post(self.endpoint())
92 .header("x-goog-api-key", &self.config.api_key)
93 .header("content-type", "application/json")
94 .json(&request)
95 .send()
96 .await
97 .map_err(|err| ProviderError::Request(err.to_string()))?;
98
99 if !response.status().is_success() {
100 return Err(ProviderError::Request(extract_api_error(response).await));
101 }
102
103 let payload = response
104 .json::<GenerateContentResponse>()
105 .await
106 .map_err(|err| ProviderError::Response(err.to_string()))?;
107
108 normalize_response(payload)
109 }
110}
111
/// JSON body of a `generateContent` call, assembled by `build_request`.
///
/// Fields serialize in camelCase; `None` fields are omitted from the JSON entirely.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GenerateContentRequest {
    /// Conversation turns (system messages excluded).
    contents: Vec<GoogleContent>,
    /// System messages, joined and sent outside `contents`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GoogleSystemInstruction>,
    /// Function declarations; omitted when no tools are offered.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GoogleTool>>,
    /// Function-calling mode; omitted when no tools are offered.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<GoogleToolConfig>,
    /// Sampling and thinking settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GoogleGenerationConfig>,
}
125
126#[derive(Debug, Serialize, Deserialize, Clone)]
127#[serde(rename_all = "camelCase")]
128struct GoogleContent {
129 role: String,
130 parts: Vec<GooglePart>,
131}
132
/// `systemInstruction` payload; `build_request` fills it with a single text part.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleSystemInstruction {
    parts: Vec<GooglePart>,
}
138
/// One entry of the request's `tools` list; `build_request` puts every function
/// declaration into a single entry.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleTool {
    function_declarations: Vec<GoogleFunctionDeclaration>,
}

/// A function the model may call.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionDeclaration {
    name: String,
    description: String,
    /// Parameter schema, already passed through `clean_gemini_schema` by `build_request`.
    parameters: Value,
}
152
/// `toolConfig` wrapper around the function-calling configuration.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleToolConfig {
    function_calling_config: GoogleFunctionCallingConfig,
}

/// How the model may call the declared functions.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallingConfig {
    /// Calling mode; `build_request` emits "AUTO", "ANY" or "NONE".
    mode: String,
    /// When present, restricts calls to these names (used to force one specific tool).
    #[serde(skip_serializing_if = "Option::is_none")]
    allowed_function_names: Option<Vec<String>>,
}
166
/// `generationConfig` payload; each `None` field is omitted from the JSON.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
    /// Present only when a thinking budget is configured.
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking_config: Option<GoogleThinkingConfig>,
}

/// `thinkingConfig` payload.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleThinkingConfig {
    /// Token budget for the model's reasoning.
    thinking_budget: u32,
    /// Whether thought parts should be returned; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    include_thoughts: Option<bool>,
}
187
/// One part of a turn. Every part this file builds sets exactly one of `text`,
/// `function_call` or `function_response`; unset fields are omitted when serializing.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
struct GooglePart {
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<String>,
    /// `Some(true)` on a response text part marks it as model reasoning
    /// (`normalize_response` routes it to `thinking`).
    #[serde(skip_serializing_if = "Option::is_none")]
    thought: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    function_call: Option<GoogleFunctionCall>,
    #[serde(skip_serializing_if = "Option::is_none")]
    function_response: Option<GoogleFunctionResponse>,
}
200
201#[derive(Debug, Serialize, Deserialize, Clone)]
202#[serde(rename_all = "camelCase")]
203struct GoogleFunctionCall {
204 id: Option<String>,
205 name: Option<String>,
206 args: Option<Value>,
207}
208
/// The result of a tool call, sent back to the model in a `user` turn.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponse {
    /// Name of the function that produced this result.
    name: String,
    /// Result payload, built by `tool_result_payload`.
    response: Value,
}
215
/// Subset of the `generateContent` response that this provider reads; other JSON fields
/// are ignored during deserialization.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GenerateContentResponse {
    /// Defaults to empty when absent; `normalize_response` treats empty as an error.
    #[serde(default)]
    candidates: Vec<GoogleCandidate>,
    usage_metadata: Option<GoogleUsageMetadata>,
}

/// One generated candidate; only its content is read.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleCandidate {
    content: Option<GoogleContent>,
}

/// Token accounting; missing counts are treated as zero by `normalize_response`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleUsageMetadata {
    prompt_token_count: Option<u32>,
    candidates_token_count: Option<u32>,
    /// Added to `candidates_token_count` to form the reported output tokens.
    thoughts_token_count: Option<u32>,
}
237
/// JSON error body of the form `{"error": {...}}`, parsed by `extract_api_error`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleErrorEnvelope {
    error: GoogleApiError,
}

/// Fields of a Google API error; each may be absent, and `extract_api_error` fills gaps
/// from the HTTP status.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleApiError {
    code: Option<u16>,
    status: Option<String>,
    message: Option<String>,
}
251
252fn build_request(
253 messages: &[ModelMessage],
254 tools: &[ModelToolDefinition],
255 tool_choice: ModelToolChoice,
256 config: &GoogleModelConfig,
257) -> GenerateContentRequest {
258 let (contents, system_instruction) = to_google_contents(messages);
259 let contents = ensure_non_empty_contents(contents);
260
261 let tools_payload = if tools.is_empty() {
262 None
263 } else {
264 let declarations = tools
265 .iter()
266 .map(|tool| GoogleFunctionDeclaration {
267 name: tool.name.clone(),
268 description: tool.description.clone(),
269 parameters: clean_gemini_schema(tool.parameters.clone()),
270 })
271 .collect::<Vec<_>>();
272 Some(vec![GoogleTool {
273 function_declarations: declarations,
274 }])
275 };
276
277 let tool_config = if tools.is_empty() {
278 None
279 } else {
280 Some(match tool_choice {
281 ModelToolChoice::Auto => GoogleToolConfig {
282 function_calling_config: GoogleFunctionCallingConfig {
283 mode: "AUTO".to_string(),
284 allowed_function_names: None,
285 },
286 },
287 ModelToolChoice::Required => GoogleToolConfig {
288 function_calling_config: GoogleFunctionCallingConfig {
289 mode: "ANY".to_string(),
290 allowed_function_names: None,
291 },
292 },
293 ModelToolChoice::None => GoogleToolConfig {
294 function_calling_config: GoogleFunctionCallingConfig {
295 mode: "NONE".to_string(),
296 allowed_function_names: None,
297 },
298 },
299 ModelToolChoice::Tool(name) => GoogleToolConfig {
300 function_calling_config: GoogleFunctionCallingConfig {
301 mode: "ANY".to_string(),
302 allowed_function_names: Some(vec![name]),
303 },
304 },
305 })
306 };
307
308 let thinking_config = config
309 .thinking_budget_tokens
310 .map(|budget| GoogleThinkingConfig {
311 thinking_budget: budget,
312 include_thoughts: config.include_thoughts,
313 });
314
315 let generation_config = GoogleGenerationConfig {
316 temperature: config.temperature,
317 top_p: config.top_p,
318 max_output_tokens: config.max_output_tokens,
319 thinking_config,
320 };
321
322 GenerateContentRequest {
323 contents,
324 system_instruction: system_instruction.map(|instruction| GoogleSystemInstruction {
325 parts: vec![GooglePart {
326 text: Some(instruction),
327 thought: None,
328 function_call: None,
329 function_response: None,
330 }],
331 }),
332 tools: tools_payload,
333 tool_config,
334 generation_config: Some(generation_config),
335 }
336}
337
338fn ensure_non_empty_contents(mut contents: Vec<GoogleContent>) -> Vec<GoogleContent> {
339 if contents.is_empty() {
340 contents.push(GoogleContent {
341 role: "user".to_string(),
342 parts: vec![GooglePart {
343 text: Some(EMPTY_USER_CONTENT_FALLBACK.to_string()),
344 thought: None,
345 function_call: None,
346 function_response: None,
347 }],
348 });
349 }
350 contents
351}
352
353fn to_google_contents(messages: &[ModelMessage]) -> (Vec<GoogleContent>, Option<String>) {
354 let mut system_lines = Vec::new();
355 let mut contents = Vec::new();
356
357 for message in messages {
358 match message {
359 ModelMessage::System(content) => {
360 if !content.is_empty() {
361 system_lines.push(content.clone());
362 }
363 }
364 ModelMessage::User(content) => {
365 if content.is_empty() {
366 continue;
367 }
368 contents.push(GoogleContent {
369 role: "user".to_string(),
370 parts: vec![GooglePart {
371 text: Some(content.clone()),
372 thought: None,
373 function_call: None,
374 function_response: None,
375 }],
376 });
377 }
378 ModelMessage::Assistant {
379 content,
380 tool_calls,
381 } => {
382 let mut parts = Vec::new();
383
384 if let Some(text) = content
385 && !text.is_empty()
386 {
387 parts.push(GooglePart {
388 text: Some(text.clone()),
389 thought: None,
390 function_call: None,
391 function_response: None,
392 });
393 }
394
395 for call in tool_calls {
396 parts.push(GooglePart {
397 text: None,
398 thought: None,
399 function_call: Some(GoogleFunctionCall {
400 id: Some(call.id.clone()),
401 name: Some(call.name.clone()),
402 args: Some(call.arguments.clone()),
403 }),
404 function_response: None,
405 });
406 }
407
408 if !parts.is_empty() {
409 contents.push(GoogleContent {
410 role: "model".to_string(),
411 parts,
412 });
413 }
414 }
415 ModelMessage::ToolResult {
416 tool_call_id: _,
417 tool_name,
418 content,
419 is_error,
420 } => contents.push(GoogleContent {
421 role: "user".to_string(),
422 parts: vec![GooglePart {
423 text: None,
424 thought: None,
425 function_call: None,
426 function_response: Some(GoogleFunctionResponse {
427 name: tool_name.clone(),
428 response: tool_result_payload(content, *is_error),
429 }),
430 }],
431 }),
432 }
433 }
434
435 let system = if system_lines.is_empty() {
436 None
437 } else {
438 Some(system_lines.join("\n\n"))
439 };
440
441 (contents, system)
442}
443
444fn tool_result_payload(content: &str, is_error: bool) -> Value {
445 if is_error {
446 return json!({"error": content});
447 }
448
449 if let Ok(parsed) = serde_json::from_str::<Value>(content) {
450 parsed
451 } else {
452 json!({"result": content})
453 }
454}
455
456fn normalize_response(response: GenerateContentResponse) -> Result<ModelCompletion, ProviderError> {
457 let Some(candidate) = response.candidates.into_iter().next() else {
458 return Err(ProviderError::Response(
459 "google response missing candidates".to_string(),
460 ));
461 };
462
463 let mut text_parts = Vec::new();
464 let mut thinking_parts = Vec::new();
465 let mut tool_calls = Vec::new();
466
467 if let Some(content) = candidate.content {
468 for (index, part) in content.parts.into_iter().enumerate() {
469 if let Some(text) = part.text {
470 if part.thought.unwrap_or(false) {
471 thinking_parts.push(text);
472 } else {
473 text_parts.push(text);
474 }
475 }
476
477 if let Some(function_call) = part.function_call {
478 let Some(name) = function_call.name else {
479 return Err(ProviderError::Response(
480 "google functionCall missing name".to_string(),
481 ));
482 };
483
484 tool_calls.push(ModelToolCall {
485 id: function_call
486 .id
487 .unwrap_or_else(|| format!("call_{}", index + 1)),
488 name,
489 arguments: function_call.args.unwrap_or_else(|| json!({})),
490 });
491 }
492 }
493 }
494
495 let usage = response.usage_metadata.map(|usage| ModelUsage {
496 input_tokens: usage.prompt_token_count.unwrap_or(0),
497 output_tokens: usage
498 .candidates_token_count
499 .unwrap_or(0)
500 .saturating_add(usage.thoughts_token_count.unwrap_or(0)),
501 });
502
503 let text = if text_parts.is_empty() {
504 None
505 } else {
506 Some(text_parts.join("\n"))
507 };
508
509 let thinking = if thinking_parts.is_empty() {
510 None
511 } else {
512 Some(thinking_parts.join("\n"))
513 };
514
515 Ok(ModelCompletion {
516 text,
517 thinking,
518 tool_calls,
519 usage,
520 })
521}
522
523async fn extract_api_error(response: reqwest::Response) -> String {
524 let status = response.status();
525 let body = response.text().await.unwrap_or_default();
526
527 if let Ok(parsed) = serde_json::from_str::<GoogleErrorEnvelope>(&body) {
528 let code = parsed.error.code.unwrap_or(status.as_u16());
529 let status_name = parsed
530 .error
531 .status
532 .unwrap_or_else(|| status.to_string().to_uppercase());
533 let message = parsed
534 .error
535 .message
536 .unwrap_or_else(|| "unknown google api error".to_string());
537 return format!("google api error {code} {status_name}: {message}");
538 }
539
540 if body.is_empty() {
541 format!("google api request failed ({status})")
542 } else {
543 format!("google api request failed ({status}): {body}")
544 }
545}
546
547fn clean_gemini_schema(schema: Value) -> Value {
548 let mut root = schema;
549 let defs = match &mut root {
550 Value::Object(map) => {
551 let mut defs = Map::new();
552 for key in ["$defs", "definitions"] {
553 if let Some(Value::Object(found)) = map.remove(key) {
554 defs.extend(found);
555 }
556 }
557 defs
558 }
559 _ => Map::new(),
560 };
561
562 let resolved = resolve_schema_refs(root, &defs);
563 clean_schema_node(resolved, None)
564}
565
566fn resolve_schema_refs(value: Value, defs: &Map<String, Value>) -> Value {
567 let mut active_refs = HashSet::new();
568 resolve_schema_refs_inner(value, defs, &mut active_refs)
569}
570
571fn resolve_schema_refs_inner(
572 value: Value,
573 defs: &Map<String, Value>,
574 active_refs: &mut HashSet<String>,
575) -> Value {
576 match value {
577 Value::Object(mut map) => {
578 if let Some(reference) = map
579 .get("$ref")
580 .and_then(Value::as_str)
581 .map(ToString::to_string)
582 {
583 let ref_name = reference.rsplit('/').next().unwrap_or("").to_string();
584 if let Some(definition) = defs.get(&ref_name) {
585 if active_refs.contains(&ref_name) {
586 map.remove("$ref");
587 if map.is_empty() {
588 return json!({"type": "string"});
589 }
590 } else {
591 active_refs.insert(ref_name.clone());
592 let mut resolved = definition.clone();
593 if let Value::Object(ref mut resolved_map) = resolved {
594 map.remove("$ref");
595 for (key, value) in map {
596 resolved_map.insert(key, value);
597 }
598 }
599 let output = resolve_schema_refs_inner(resolved, defs, active_refs);
600 active_refs.remove(&ref_name);
601 return output;
602 }
603 } else {
604 map.remove("$ref");
605 if map.is_empty() {
606 return json!({"type": "string"});
607 }
608 }
609 }
610
611 let mut out = Map::new();
612 for (key, value) in map {
613 out.insert(key, resolve_schema_refs_inner(value, defs, active_refs));
614 }
615 Value::Object(out)
616 }
617 Value::Array(values) => Value::Array(
618 values
619 .into_iter()
620 .map(|value| resolve_schema_refs_inner(value, defs, active_refs))
621 .collect(),
622 ),
623 other => other,
624 }
625}
626
627fn clean_schema_node(value: Value, parent_key: Option<&str>) -> Value {
628 match value {
629 Value::Object(map) => {
630 let mut cleaned = Map::new();
631
632 for (key, value) in map {
633 let is_metadata_title = key == "title" && parent_key != Some("properties");
634 if key == "additionalProperties" || key == "default" || is_metadata_title {
635 continue;
636 }
637
638 cleaned.insert(key.clone(), clean_schema_node(value, Some(&key)));
639 }
640
641 let type_name = cleaned
642 .get("type")
643 .and_then(Value::as_str)
644 .map(|t| t.to_ascii_lowercase());
645 if type_name.as_deref() == Some("object") {
646 let needs_placeholder = cleaned
647 .get("properties")
648 .and_then(Value::as_object)
649 .map(|properties| properties.is_empty())
650 .unwrap_or(false);
651
652 if needs_placeholder {
653 cleaned.insert(
654 "properties".to_string(),
655 json!({"_placeholder": {"type": "string"}}),
656 );
657 }
658 }
659
660 Value::Object(cleaned)
661 }
662 Value::Array(values) => Value::Array(
663 values
664 .into_iter()
665 .map(|value| clean_schema_node(value, parent_key))
666 .collect(),
667 ),
668 other => other,
669 }
670}
671
// Unit tests for request building, response normalization and schema cleaning. They
// exercise the private helpers directly, so no network access is needed.
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// A tool whose schema carries keywords `clean_gemini_schema` strips
    /// (`default`, `additionalProperties`, root `title`).
    fn tool_definition() -> ModelToolDefinition {
        ModelToolDefinition {
            name: "lookup".to_string(),
            description: "Look up something".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "query": {"type": "string", "default": "x"}
                },
                "required": ["query"],
                "additionalProperties": false,
                "title": "LookupTool"
            }),
        }
    }

    // Full transcript round-trip: system hoisting, function call/response parts, forced
    // tool choice, thinking budget and schema cleaning.
    #[test]
    fn build_request_serializes_messages_tools_and_tool_choice() {
        let messages = vec![
            ModelMessage::System("You are helpful".to_string()),
            ModelMessage::User("Find docs".to_string()),
            ModelMessage::Assistant {
                content: Some("Calling tool".to_string()),
                tool_calls: vec![ModelToolCall {
                    id: "call_1".to_string(),
                    name: "lookup".to_string(),
                    arguments: json!({"query": "rust"}),
                }],
            },
            ModelMessage::ToolResult {
                tool_call_id: "call_1".to_string(),
                tool_name: "lookup".to_string(),
                content: "{\"result\":\"ok\"}".to_string(),
                is_error: false,
            },
        ];

        let mut config = GoogleModelConfig::new("key", "gemini-2.5-flash");
        config.temperature = Some(0.2);
        config.thinking_budget_tokens = Some(256);

        let request = build_request(
            &messages,
            &[tool_definition()],
            ModelToolChoice::Tool("lookup".to_string()),
            &config,
        );
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(
            value["systemInstruction"]["parts"][0]["text"],
            "You are helpful"
        );
        assert_eq!(value["contents"][0]["role"], "user");
        assert_eq!(
            value["contents"][1]["parts"][1]["functionCall"]["name"],
            "lookup"
        );
        assert_eq!(
            value["contents"][2]["parts"][0]["functionResponse"]["response"]["result"],
            "ok"
        );
        assert_eq!(value["toolConfig"]["functionCallingConfig"]["mode"], "ANY");
        assert_eq!(
            value["toolConfig"]["functionCallingConfig"]["allowedFunctionNames"][0],
            "lookup"
        );
        assert_eq!(
            value["generationConfig"]["thinkingConfig"]["thinkingBudget"],
            256
        );
        assert!(
            value["tools"][0]["functionDeclarations"][0]["parameters"]
                .get("additionalProperties")
                .is_none()
        );
    }

    // An empty user message is skipped, so the fallback single-space user turn is sent.
    #[test]
    fn build_request_adds_fallback_content_for_empty_user_message() {
        let messages = vec![ModelMessage::User(String::new())];
        let config = GoogleModelConfig::new("key", "gemini-2.5-flash");

        let request = build_request(&messages, &[], ModelToolChoice::Auto, &config);
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(value["contents"].as_array().map(|v| v.len()), Some(1));
        assert_eq!(value["contents"][0]["role"], "user");
        assert_eq!(value["contents"][0]["parts"][0]["text"], " ");
    }

    // Text, thought and functionCall parts are routed to their fields; output tokens
    // include thinking tokens (7 + 3).
    #[test]
    fn normalize_response_extracts_text_thinking_tool_calls_and_usage() {
        let response = GenerateContentResponse {
            candidates: vec![GoogleCandidate {
                content: Some(GoogleContent {
                    role: "model".to_string(),
                    parts: vec![
                        GooglePart {
                            text: Some("answer".to_string()),
                            thought: None,
                            function_call: None,
                            function_response: None,
                        },
                        GooglePart {
                            text: Some("reasoning".to_string()),
                            thought: Some(true),
                            function_call: None,
                            function_response: None,
                        },
                        GooglePart {
                            text: None,
                            thought: None,
                            function_call: Some(GoogleFunctionCall {
                                id: Some("call_x".to_string()),
                                name: Some("lookup".to_string()),
                                args: Some(json!({"q": "rust"})),
                            }),
                            function_response: None,
                        },
                    ],
                }),
            }],
            usage_metadata: Some(GoogleUsageMetadata {
                prompt_token_count: Some(11),
                candidates_token_count: Some(7),
                thoughts_token_count: Some(3),
            }),
        };

        let completion = normalize_response(response).expect("response normalizes");

        assert_eq!(completion.text.as_deref(), Some("answer"));
        assert_eq!(completion.thinking.as_deref(), Some("reasoning"));
        assert_eq!(completion.tool_calls.len(), 1);
        assert_eq!(completion.tool_calls[0].name, "lookup");
        assert_eq!(completion.tool_calls[0].id, "call_x");
        assert_eq!(
            completion.usage,
            Some(ModelUsage {
                input_tokens: 11,
                output_tokens: 10,
            })
        );
    }

    // A response without candidates is a `ProviderError::Response`.
    #[test]
    fn normalize_response_requires_candidates() {
        let err = normalize_response(GenerateContentResponse {
            candidates: Vec::new(),
            usage_metadata: None,
        })
        .expect_err("should fail");

        match err {
            ProviderError::Response(message) => {
                assert!(message.contains("missing candidates"));
            }
            other => panic!("unexpected error: {other}"),
        }
    }

    // `$defs` are inlined and removed, and an empty object gets the `_placeholder` property.
    #[test]
    fn clean_gemini_schema_resolves_refs_and_handles_empty_objects() {
        let schema = json!({
            "$defs": {
                "Inner": {
                    "type": "object",
                    "properties": {},
                    "additionalProperties": false
                }
            },
            "type": "object",
            "properties": {
                "inner": {
                    "$ref": "#/$defs/Inner"
                }
            },
            "additionalProperties": false
        });

        let cleaned = clean_gemini_schema(schema);
        assert!(cleaned.get("$defs").is_none());
        assert!(cleaned.get("additionalProperties").is_none());
        assert_eq!(
            cleaned["properties"]["inner"]["properties"]["_placeholder"]["type"],
            "string"
        );
    }

    // Legacy `definitions` resolve; an unknown `$ref` degrades to a string schema.
    #[test]
    fn clean_gemini_schema_handles_unresolved_ref_and_legacy_definitions() {
        let schema = json!({
            "definitions": {
                "Legacy": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string"}
                    }
                }
            },
            "type": "object",
            "properties": {
                "legacy": {"$ref": "#/definitions/Legacy"},
                "broken": {"$ref": "#/$defs/Unknown"}
            }
        });

        let cleaned = clean_gemini_schema(schema);

        assert_eq!(cleaned["properties"]["legacy"]["properties"]["name"]["type"], "string");
        assert!(cleaned["properties"]["broken"].get("$ref").is_none());
        assert_eq!(cleaned["properties"]["broken"]["type"], "string");
    }

    // A self-referencing definition is expanded once; the inner cycle becomes a string.
    #[test]
    fn clean_gemini_schema_handles_circular_refs_without_recursing_forever() {
        let schema = json!({
            "$defs": {
                "Node": {
                    "type": "object",
                    "properties": {
                        "next": { "$ref": "#/$defs/Node" }
                    }
                }
            },
            "type": "object",
            "properties": {
                "root": { "$ref": "#/$defs/Node" }
            }
        });

        let cleaned = clean_gemini_schema(schema);

        assert!(cleaned["properties"]["root"].get("$ref").is_none());
        assert_eq!(cleaned["properties"]["root"]["properties"]["next"]["type"], "string");
    }
}