1use async_trait::async_trait;
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use serde_json::{Map, Value, json};
5use std::collections::HashSet;
6
7use crate::error::ProviderError;
8use crate::llm::{
9 ChatModel, ModelCompletion, ModelMessage, ModelToolCall, ModelToolChoice, ModelToolDefinition,
10 ModelUsage,
11};
12
/// Gemini REST base URL used when `GoogleModelConfig::api_base_url` is `None`.
const DEFAULT_API_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
/// Text of the single `user` turn that `ensure_non_empty_contents` inserts when the
/// conversation produced no contents at all (NOTE(review): presumably because the API
/// rejects an empty `contents` array — confirm).
const EMPTY_USER_CONTENT_FALLBACK: &str = " ";
15
/// Connection, sampling, and thinking settings for the Gemini chat model.
#[derive(Debug, Clone)]
pub struct GoogleModelConfig {
    /// API key, sent in the `x-goog-api-key` request header.
    pub api_key: String,
    /// Model id (e.g. `gemini-2.5-flash`), interpolated into the request URL.
    pub model: String,
    /// Optional base URL override; `None` uses the public `v1beta` endpoint.
    pub api_base_url: Option<String>,
    /// Forwarded as `generationConfig.temperature` when set.
    pub temperature: Option<f32>,
    /// Forwarded as `generationConfig.topP` when set.
    pub top_p: Option<f32>,
    /// Forwarded as `generationConfig.maxOutputTokens` when set.
    pub max_output_tokens: Option<u32>,
    /// When set, a `thinkingConfig` with this `thinkingBudget` is sent.
    pub thinking_budget_tokens: Option<u32>,
    /// Forwarded as `thinkingConfig.includeThoughts`; ignored unless
    /// `thinking_budget_tokens` is also set.
    pub include_thoughts: Option<bool>,
}

impl GoogleModelConfig {
    /// Output-token cap that `new` starts every configuration with.
    const DEFAULT_MAX_OUTPUT_TOKENS: u32 = 4096;

    /// Builds a configuration for `model` authenticated with `api_key`.
    ///
    /// The output cap starts at 4096 tokens; the base URL, sampling parameters,
    /// and thinking options are all left unset so the API defaults apply.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let api_key: String = api_key.into();
        let model: String = model.into();

        Self {
            max_output_tokens: Some(Self::DEFAULT_MAX_OUTPUT_TOKENS),
            api_base_url: None,
            temperature: None,
            top_p: None,
            thinking_budget_tokens: None,
            include_thoughts: None,
            api_key,
            model,
        }
    }
}
52
53#[derive(Debug, Clone)]
54pub struct GoogleModel {
56 client: Client,
57 config: GoogleModelConfig,
58}
59
60impl GoogleModel {
61 pub fn new(config: GoogleModelConfig) -> Result<Self, ProviderError> {
63 let client = Client::builder()
64 .build()
65 .map_err(|err| ProviderError::Request(err.to_string()))?;
66
67 Ok(Self { client, config })
68 }
69
70 pub fn from_env(model: impl Into<String>) -> Result<Self, ProviderError> {
72 let api_key = std::env::var("GOOGLE_API_KEY")
73 .or_else(|_| std::env::var("GEMINI_API_KEY"))
74 .map_err(|_| {
75 ProviderError::Request("GOOGLE_API_KEY (or GEMINI_API_KEY) is not set".to_string())
76 })?;
77
78 Self::new(GoogleModelConfig::new(api_key, model))
79 }
80
81 fn endpoint(&self) -> String {
82 let base = self
83 .config
84 .api_base_url
85 .as_deref()
86 .unwrap_or(DEFAULT_API_BASE_URL)
87 .trim_end_matches('/');
88 format!("{base}/models/{}:generateContent", self.config.model)
89 }
90}
91
92#[async_trait]
93impl ChatModel for GoogleModel {
94 async fn invoke(
95 &self,
96 messages: &[ModelMessage],
97 tools: &[ModelToolDefinition],
98 tool_choice: ModelToolChoice,
99 ) -> Result<ModelCompletion, ProviderError> {
100 let request = build_request(messages, tools, tool_choice, &self.config);
101
102 let response = self
103 .client
104 .post(self.endpoint())
105 .header("x-goog-api-key", &self.config.api_key)
106 .header("content-type", "application/json")
107 .json(&request)
108 .send()
109 .await
110 .map_err(|err| ProviderError::Request(err.to_string()))?;
111
112 if !response.status().is_success() {
113 return Err(ProviderError::Request(extract_api_error(response).await));
114 }
115
116 let payload = response
117 .json::<GenerateContentResponse>()
118 .await
119 .map_err(|err| ProviderError::Response(err.to_string()))?;
120
121 normalize_response(payload)
122 }
123}
124
/// JSON body of a `generateContent` request. Field names serialize in camelCase
/// (`systemInstruction`, `toolConfig`, …); `None` fields are omitted entirely.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GenerateContentRequest {
    /// Conversation turns in order; `build_request` always supplies at least one.
    contents: Vec<GoogleContent>,
    /// System prompt, sent separately from `contents`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GoogleSystemInstruction>,
    /// Function declarations; `None` when the caller passed no tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GoogleTool>>,
    /// Function-calling mode; only set when `tools` is set.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<GoogleToolConfig>,
    /// Sampling, length, and thinking settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GoogleGenerationConfig>,
}

/// One conversation turn; used both in requests and in response candidates.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleContent {
    /// `"user"` or `"model"` for turns built by `to_google_contents`.
    role: String,
    /// Ordered parts of the turn (text, function call, or function response).
    parts: Vec<GooglePart>,
}

/// Wrapper for the system prompt; `build_request` fills it with a single text part.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleSystemInstruction {
    parts: Vec<GooglePart>,
}
151
/// A `tools` entry grouping the function declarations offered to the model.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleTool {
    function_declarations: Vec<GoogleFunctionDeclaration>,
}

/// One callable function as advertised to the model.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionDeclaration {
    name: String,
    description: String,
    /// Argument JSON Schema, already passed through `clean_gemini_schema`.
    parameters: Value,
}

/// The `toolConfig` wrapper.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleToolConfig {
    function_calling_config: GoogleFunctionCallingConfig,
}

/// Function-calling controls sent as `functionCallingConfig`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallingConfig {
    /// One of `"AUTO"`, `"ANY"`, or `"NONE"` — the values `build_request` emits.
    mode: String,
    /// Set (together with `"ANY"`) only when a specific tool is forced; omitted otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    allowed_function_names: Option<Vec<String>>,
}
179
/// The `generationConfig` object; each field is omitted when `None`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Serialized as `topP`.
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// Serialized as `maxOutputTokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
    /// Present only when a thinking budget is configured.
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking_config: Option<GoogleThinkingConfig>,
}

/// The `thinkingConfig` object.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleThinkingConfig {
    /// Serialized as `thinkingBudget`; always present in this struct.
    thinking_budget: u32,
    /// Serialized as `includeThoughts`; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    include_thoughts: Option<bool>,
}
200
/// One part of a turn. Every part built in this file sets exactly one of `text`,
/// `function_call`, or `function_response`; unset fields are omitted on the wire.
/// Response fields not modelled here are ignored when deserializing.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
struct GooglePart {
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<String>,
    /// `Some(true)` marks reasoning text; `normalize_response` routes it to `thinking`.
    #[serde(skip_serializing_if = "Option::is_none")]
    thought: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    function_call: Option<GoogleFunctionCall>,
    #[serde(skip_serializing_if = "Option::is_none")]
    function_response: Option<GoogleFunctionResponse>,
}

/// A function call requested by the model (or echoed back in history).
/// These fields have no `skip_serializing_if`, so a `None` would serialize as `null`;
/// `to_google_contents` always sets all three when echoing calls back.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCall {
    /// May be absent in responses; `normalize_response` then synthesizes `call_<n>`.
    id: Option<String>,
    /// Required by `normalize_response`; a missing name is an error there.
    name: Option<String>,
    /// Call arguments; treated as `{}` when absent.
    args: Option<Value>,
}

/// A tool's result sent back to the model.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponse {
    /// Name of the tool that produced the result.
    name: String,
    /// Payload produced by `tool_result_payload`.
    response: Value,
}
228
/// The subset of a `generateContent` response this provider reads.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GenerateContentResponse {
    /// Defaults to empty when absent; `normalize_response` rejects an empty list.
    #[serde(default)]
    candidates: Vec<GoogleCandidate>,
    usage_metadata: Option<GoogleUsageMetadata>,
}

/// A response candidate; only the first one is consumed.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleCandidate {
    /// When absent, the completion has no text, thinking, or tool calls.
    content: Option<GoogleContent>,
}

/// Token counts from `usageMetadata`; missing counts are treated as 0.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleUsageMetadata {
    /// Mapped to `ModelUsage::input_tokens`.
    prompt_token_count: Option<u32>,
    /// Summed with `thoughts_token_count` into `ModelUsage::output_tokens`.
    candidates_token_count: Option<u32>,
    thoughts_token_count: Option<u32>,
}
250
/// Error body of the shape `{"error": {...}}`, parsed by `extract_api_error`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleErrorEnvelope {
    error: GoogleApiError,
}

/// Error details; `extract_api_error` substitutes fallbacks for missing fields.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GoogleApiError {
    /// Numeric code; falls back to the HTTP status code.
    code: Option<u16>,
    /// Symbolic status name (e.g. as sent by the API); falls back to one derived from HTTP.
    status: Option<String>,
    /// Human-readable message.
    message: Option<String>,
}
264
265fn build_request(
266 messages: &[ModelMessage],
267 tools: &[ModelToolDefinition],
268 tool_choice: ModelToolChoice,
269 config: &GoogleModelConfig,
270) -> GenerateContentRequest {
271 let (contents, system_instruction) = to_google_contents(messages);
272 let contents = ensure_non_empty_contents(contents);
273
274 let tools_payload = if tools.is_empty() {
275 None
276 } else {
277 let declarations = tools
278 .iter()
279 .map(|tool| GoogleFunctionDeclaration {
280 name: tool.name.clone(),
281 description: tool.description.clone(),
282 parameters: clean_gemini_schema(tool.parameters.clone()),
283 })
284 .collect::<Vec<_>>();
285 Some(vec![GoogleTool {
286 function_declarations: declarations,
287 }])
288 };
289
290 let tool_config = if tools.is_empty() {
291 None
292 } else {
293 Some(match tool_choice {
294 ModelToolChoice::Auto => GoogleToolConfig {
295 function_calling_config: GoogleFunctionCallingConfig {
296 mode: "AUTO".to_string(),
297 allowed_function_names: None,
298 },
299 },
300 ModelToolChoice::Required => GoogleToolConfig {
301 function_calling_config: GoogleFunctionCallingConfig {
302 mode: "ANY".to_string(),
303 allowed_function_names: None,
304 },
305 },
306 ModelToolChoice::None => GoogleToolConfig {
307 function_calling_config: GoogleFunctionCallingConfig {
308 mode: "NONE".to_string(),
309 allowed_function_names: None,
310 },
311 },
312 ModelToolChoice::Tool(name) => GoogleToolConfig {
313 function_calling_config: GoogleFunctionCallingConfig {
314 mode: "ANY".to_string(),
315 allowed_function_names: Some(vec![name]),
316 },
317 },
318 })
319 };
320
321 let thinking_config = config
322 .thinking_budget_tokens
323 .map(|budget| GoogleThinkingConfig {
324 thinking_budget: budget,
325 include_thoughts: config.include_thoughts,
326 });
327
328 let generation_config = GoogleGenerationConfig {
329 temperature: config.temperature,
330 top_p: config.top_p,
331 max_output_tokens: config.max_output_tokens,
332 thinking_config,
333 };
334
335 GenerateContentRequest {
336 contents,
337 system_instruction: system_instruction.map(|instruction| GoogleSystemInstruction {
338 parts: vec![GooglePart {
339 text: Some(instruction),
340 thought: None,
341 function_call: None,
342 function_response: None,
343 }],
344 }),
345 tools: tools_payload,
346 tool_config,
347 generation_config: Some(generation_config),
348 }
349}
350
351fn ensure_non_empty_contents(mut contents: Vec<GoogleContent>) -> Vec<GoogleContent> {
352 if contents.is_empty() {
353 contents.push(GoogleContent {
354 role: "user".to_string(),
355 parts: vec![GooglePart {
356 text: Some(EMPTY_USER_CONTENT_FALLBACK.to_string()),
357 thought: None,
358 function_call: None,
359 function_response: None,
360 }],
361 });
362 }
363 contents
364}
365
366fn to_google_contents(messages: &[ModelMessage]) -> (Vec<GoogleContent>, Option<String>) {
367 let mut system_lines = Vec::new();
368 let mut contents = Vec::new();
369
370 for message in messages {
371 match message {
372 ModelMessage::System(content) => {
373 if !content.is_empty() {
374 system_lines.push(content.clone());
375 }
376 }
377 ModelMessage::User(content) => {
378 if content.is_empty() {
379 continue;
380 }
381 contents.push(GoogleContent {
382 role: "user".to_string(),
383 parts: vec![GooglePart {
384 text: Some(content.clone()),
385 thought: None,
386 function_call: None,
387 function_response: None,
388 }],
389 });
390 }
391 ModelMessage::Assistant {
392 content,
393 tool_calls,
394 } => {
395 let mut parts = Vec::new();
396
397 if let Some(text) = content
398 && !text.is_empty()
399 {
400 parts.push(GooglePart {
401 text: Some(text.clone()),
402 thought: None,
403 function_call: None,
404 function_response: None,
405 });
406 }
407
408 for call in tool_calls {
409 parts.push(GooglePart {
410 text: None,
411 thought: None,
412 function_call: Some(GoogleFunctionCall {
413 id: Some(call.id.clone()),
414 name: Some(call.name.clone()),
415 args: Some(call.arguments.clone()),
416 }),
417 function_response: None,
418 });
419 }
420
421 if !parts.is_empty() {
422 contents.push(GoogleContent {
423 role: "model".to_string(),
424 parts,
425 });
426 }
427 }
428 ModelMessage::ToolResult {
429 tool_call_id: _,
430 tool_name,
431 content,
432 is_error,
433 } => contents.push(GoogleContent {
434 role: "user".to_string(),
435 parts: vec![GooglePart {
436 text: None,
437 thought: None,
438 function_call: None,
439 function_response: Some(GoogleFunctionResponse {
440 name: tool_name.clone(),
441 response: tool_result_payload(content, *is_error),
442 }),
443 }],
444 }),
445 }
446 }
447
448 let system = if system_lines.is_empty() {
449 None
450 } else {
451 Some(system_lines.join("\n\n"))
452 };
453
454 (contents, system)
455}
456
457fn tool_result_payload(content: &str, is_error: bool) -> Value {
458 if is_error {
459 return json!({"error": content});
460 }
461
462 if let Ok(parsed) = serde_json::from_str::<Value>(content) {
463 parsed
464 } else {
465 json!({"result": content})
466 }
467}
468
469fn normalize_response(response: GenerateContentResponse) -> Result<ModelCompletion, ProviderError> {
470 let Some(candidate) = response.candidates.into_iter().next() else {
471 return Err(ProviderError::Response(
472 "google response missing candidates".to_string(),
473 ));
474 };
475
476 let mut text_parts = Vec::new();
477 let mut thinking_parts = Vec::new();
478 let mut tool_calls = Vec::new();
479
480 if let Some(content) = candidate.content {
481 for (index, part) in content.parts.into_iter().enumerate() {
482 if let Some(text) = part.text {
483 if part.thought.unwrap_or(false) {
484 thinking_parts.push(text);
485 } else {
486 text_parts.push(text);
487 }
488 }
489
490 if let Some(function_call) = part.function_call {
491 let Some(name) = function_call.name else {
492 return Err(ProviderError::Response(
493 "google functionCall missing name".to_string(),
494 ));
495 };
496
497 tool_calls.push(ModelToolCall {
498 id: function_call
499 .id
500 .unwrap_or_else(|| format!("call_{}", index + 1)),
501 name,
502 arguments: function_call.args.unwrap_or_else(|| json!({})),
503 });
504 }
505 }
506 }
507
508 let usage = response.usage_metadata.map(|usage| ModelUsage {
509 input_tokens: usage.prompt_token_count.unwrap_or(0),
510 output_tokens: usage
511 .candidates_token_count
512 .unwrap_or(0)
513 .saturating_add(usage.thoughts_token_count.unwrap_or(0)),
514 });
515
516 let text = if text_parts.is_empty() {
517 None
518 } else {
519 Some(text_parts.join("\n"))
520 };
521
522 let thinking = if thinking_parts.is_empty() {
523 None
524 } else {
525 Some(thinking_parts.join("\n"))
526 };
527
528 Ok(ModelCompletion {
529 text,
530 thinking,
531 tool_calls,
532 usage,
533 })
534}
535
536async fn extract_api_error(response: reqwest::Response) -> String {
537 let status = response.status();
538 let body = response.text().await.unwrap_or_default();
539
540 if let Ok(parsed) = serde_json::from_str::<GoogleErrorEnvelope>(&body) {
541 let code = parsed.error.code.unwrap_or(status.as_u16());
542 let status_name = parsed
543 .error
544 .status
545 .unwrap_or_else(|| status.to_string().to_uppercase());
546 let message = parsed
547 .error
548 .message
549 .unwrap_or_else(|| "unknown google api error".to_string());
550 return format!("google api error {code} {status_name}: {message}");
551 }
552
553 if body.is_empty() {
554 format!("google api request failed ({status})")
555 } else {
556 format!("google api request failed ({status}): {body}")
557 }
558}
559
560fn clean_gemini_schema(schema: Value) -> Value {
561 let mut root = schema;
562 let defs = match &mut root {
563 Value::Object(map) => {
564 let mut defs = Map::new();
565 for key in ["$defs", "definitions"] {
566 if let Some(Value::Object(found)) = map.remove(key) {
567 defs.extend(found);
568 }
569 }
570 defs
571 }
572 _ => Map::new(),
573 };
574
575 let resolved = resolve_schema_refs(root, &defs);
576 clean_schema_node(resolved, None)
577}
578
579fn resolve_schema_refs(value: Value, defs: &Map<String, Value>) -> Value {
580 let mut active_refs = HashSet::new();
581 resolve_schema_refs_inner(value, defs, &mut active_refs)
582}
583
/// Recursively replaces `$ref` objects with their definitions from `defs`.
///
/// * The definition name is the last `/`-separated segment of the `$ref` string, so
///   `#/$defs/Foo` and `#/definitions/Foo` both look up `Foo`.
/// * Known ref not currently being expanded: the definition is cloned. If it is an object,
///   the referencing object's other keys are copied over it (they win on conflict); if it is
///   not an object, those keys are discarded. The result is resolved recursively while the
///   name is marked active.
/// * Ref already being expanded (a cycle) or unknown ref: `$ref` is dropped. If nothing else
///   remains the node becomes `{"type": "string"}`; otherwise the remaining keys are resolved.
/// * Objects without a string `$ref` have each value resolved; arrays have each element
///   resolved; scalars are returned unchanged.
///
/// `active_refs` holds the names on the current expansion path and is restored before returning.
fn resolve_schema_refs_inner(
    value: Value,
    defs: &Map<String, Value>,
    active_refs: &mut HashSet<String>,
) -> Value {
    match value {
        Value::Object(mut map) => {
            if let Some(reference) = map
                .get("$ref")
                .and_then(Value::as_str)
                .map(ToString::to_string)
            {
                let ref_name = reference.rsplit('/').next().unwrap_or("").to_string();
                if let Some(definition) = defs.get(&ref_name) {
                    if active_refs.contains(&ref_name) {
                        // Cycle: stop expanding here; an otherwise empty node degrades to a string.
                        map.remove("$ref");
                        if map.is_empty() {
                            return json!({"type": "string"});
                        }
                    } else {
                        active_refs.insert(ref_name.clone());
                        let mut resolved = definition.clone();
                        if let Value::Object(ref mut resolved_map) = resolved {
                            // Sibling keywords next to `$ref` override the definition's own.
                            map.remove("$ref");
                            for (key, value) in map {
                                resolved_map.insert(key, value);
                            }
                        }
                        let output = resolve_schema_refs_inner(resolved, defs, active_refs);
                        active_refs.remove(&ref_name);
                        return output;
                    }
                } else {
                    // Unknown target: drop the dangling ref rather than send it to the API.
                    map.remove("$ref");
                    if map.is_empty() {
                        return json!({"type": "string"});
                    }
                }
            }

            let mut out = Map::new();
            for (key, value) in map {
                out.insert(key, resolve_schema_refs_inner(value, defs, active_refs));
            }
            Value::Object(out)
        }
        Value::Array(values) => Value::Array(
            values
                .into_iter()
                .map(|value| resolve_schema_refs_inner(value, defs, active_refs))
                .collect(),
        ),
        other => other,
    }
}
639
640fn clean_schema_node(value: Value, parent_key: Option<&str>) -> Value {
641 match value {
642 Value::Object(map) => {
643 let mut cleaned = Map::new();
644
645 for (key, value) in map {
646 let is_metadata_title = key == "title" && parent_key != Some("properties");
647 if key == "additionalProperties" || key == "default" || is_metadata_title {
648 continue;
649 }
650
651 cleaned.insert(key.clone(), clean_schema_node(value, Some(&key)));
652 }
653
654 let type_name = cleaned
655 .get("type")
656 .and_then(Value::as_str)
657 .map(|t| t.to_ascii_lowercase());
658 if type_name.as_deref() == Some("object") {
659 let needs_placeholder = cleaned
660 .get("properties")
661 .and_then(Value::as_object)
662 .map(|properties| properties.is_empty())
663 .unwrap_or(false);
664
665 if needs_placeholder {
666 cleaned.insert(
667 "properties".to_string(),
668 json!({"_placeholder": {"type": "string"}}),
669 );
670 }
671 }
672
673 Value::Object(cleaned)
674 }
675 Value::Array(values) => Value::Array(
676 values
677 .into_iter()
678 .map(|value| clean_schema_node(value, parent_key))
679 .collect(),
680 ),
681 other => other,
682 }
683}
684
// Unit tests for the pure conversion helpers (request building, response
// normalization, and schema cleaning); none of them make HTTP requests.
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    // Tool fixture whose schema carries keywords that `clean_gemini_schema` strips
    // (`default`, `additionalProperties`, and a schema-level `title`).
    fn tool_definition() -> ModelToolDefinition {
        ModelToolDefinition {
            name: "lookup".to_string(),
            description: "Look up something".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "query": {"type": "string", "default": "x"}
                },
                "required": ["query"],
                "additionalProperties": false,
                "title": "LookupTool"
            }),
        }
    }

    // End-to-end request shape: system instruction, turn roles, function call and
    // response parts, forced-tool mode, thinking budget, and schema cleaning.
    #[test]
    fn build_request_serializes_messages_tools_and_tool_choice() {
        let messages = vec![
            ModelMessage::System("You are helpful".to_string()),
            ModelMessage::User("Find docs".to_string()),
            ModelMessage::Assistant {
                content: Some("Calling tool".to_string()),
                tool_calls: vec![ModelToolCall {
                    id: "call_1".to_string(),
                    name: "lookup".to_string(),
                    arguments: json!({"query": "rust"}),
                }],
            },
            ModelMessage::ToolResult {
                tool_call_id: "call_1".to_string(),
                tool_name: "lookup".to_string(),
                content: "{\"result\":\"ok\"}".to_string(),
                is_error: false,
            },
        ];

        let mut config = GoogleModelConfig::new("key", "gemini-2.5-flash");
        config.temperature = Some(0.2);
        config.thinking_budget_tokens = Some(256);

        let request = build_request(
            &messages,
            &[tool_definition()],
            ModelToolChoice::Tool("lookup".to_string()),
            &config,
        );
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(
            value["systemInstruction"]["parts"][0]["text"],
            "You are helpful"
        );
        assert_eq!(value["contents"][0]["role"], "user");
        // Part 0 of the model turn is the text; part 1 is the function call.
        assert_eq!(
            value["contents"][1]["parts"][1]["functionCall"]["name"],
            "lookup"
        );
        // JSON-object tool output is forwarded as the response object itself.
        assert_eq!(
            value["contents"][2]["parts"][0]["functionResponse"]["response"]["result"],
            "ok"
        );
        assert_eq!(value["toolConfig"]["functionCallingConfig"]["mode"], "ANY");
        assert_eq!(
            value["toolConfig"]["functionCallingConfig"]["allowedFunctionNames"][0],
            "lookup"
        );
        assert_eq!(
            value["generationConfig"]["thinkingConfig"]["thinkingBudget"],
            256
        );
        assert!(
            value["tools"][0]["functionDeclarations"][0]["parameters"]
                .get("additionalProperties")
                .is_none()
        );
    }

    // An all-empty conversation still yields one placeholder user turn.
    #[test]
    fn build_request_adds_fallback_content_for_empty_user_message() {
        let messages = vec![ModelMessage::User(String::new())];
        let config = GoogleModelConfig::new("key", "gemini-2.5-flash");

        let request = build_request(&messages, &[], ModelToolChoice::Auto, &config);
        let value = serde_json::to_value(request).expect("serializes");

        assert_eq!(value["contents"].as_array().map(|v| v.len()), Some(1));
        assert_eq!(value["contents"][0]["role"], "user");
        assert_eq!(value["contents"][0]["parts"][0]["text"], " ");
    }

    // Text, thought text, and function-call parts are split apart, and thought
    // tokens are counted as output tokens (7 + 3 = 10).
    #[test]
    fn normalize_response_extracts_text_thinking_tool_calls_and_usage() {
        let response = GenerateContentResponse {
            candidates: vec![GoogleCandidate {
                content: Some(GoogleContent {
                    role: "model".to_string(),
                    parts: vec![
                        GooglePart {
                            text: Some("answer".to_string()),
                            thought: None,
                            function_call: None,
                            function_response: None,
                        },
                        GooglePart {
                            text: Some("reasoning".to_string()),
                            thought: Some(true),
                            function_call: None,
                            function_response: None,
                        },
                        GooglePart {
                            text: None,
                            thought: None,
                            function_call: Some(GoogleFunctionCall {
                                id: Some("call_x".to_string()),
                                name: Some("lookup".to_string()),
                                args: Some(json!({"q": "rust"})),
                            }),
                            function_response: None,
                        },
                    ],
                }),
            }],
            usage_metadata: Some(GoogleUsageMetadata {
                prompt_token_count: Some(11),
                candidates_token_count: Some(7),
                thoughts_token_count: Some(3),
            }),
        };

        let completion = normalize_response(response).expect("response normalizes");

        assert_eq!(completion.text.as_deref(), Some("answer"));
        assert_eq!(completion.thinking.as_deref(), Some("reasoning"));
        assert_eq!(completion.tool_calls.len(), 1);
        assert_eq!(completion.tool_calls[0].name, "lookup");
        assert_eq!(completion.tool_calls[0].id, "call_x");
        assert_eq!(
            completion.usage,
            Some(ModelUsage {
                input_tokens: 11,
                output_tokens: 10,
            })
        );
    }

    // A response without candidates is a `Response` error, not an empty completion.
    #[test]
    fn normalize_response_requires_candidates() {
        let err = normalize_response(GenerateContentResponse {
            candidates: Vec::new(),
            usage_metadata: None,
        })
        .expect_err("should fail");

        match err {
            ProviderError::Response(message) => {
                assert!(message.contains("missing candidates"));
            }
            other => panic!("unexpected error: {other}"),
        }
    }

    // `$defs` are inlined and removed, and an inlined empty object schema gets the
    // `_placeholder` property.
    #[test]
    fn clean_gemini_schema_resolves_refs_and_handles_empty_objects() {
        let schema = json!({
            "$defs": {
                "Inner": {
                    "type": "object",
                    "properties": {},
                    "additionalProperties": false
                }
            },
            "type": "object",
            "properties": {
                "inner": {
                    "$ref": "#/$defs/Inner"
                }
            },
            "additionalProperties": false
        });

        let cleaned = clean_gemini_schema(schema);
        assert!(cleaned.get("$defs").is_none());
        assert!(cleaned.get("additionalProperties").is_none());
        assert_eq!(
            cleaned["properties"]["inner"]["properties"]["_placeholder"]["type"],
            "string"
        );
    }

    // Legacy `definitions` resolve like `$defs`; a dangling ref degrades to a string schema.
    #[test]
    fn clean_gemini_schema_handles_unresolved_ref_and_legacy_definitions() {
        let schema = json!({
            "definitions": {
                "Legacy": {
                    "type": "object",
                    "properties": {
                        "name": {"type": "string"}
                    }
                }
            },
            "type": "object",
            "properties": {
                "legacy": {"$ref": "#/definitions/Legacy"},
                "broken": {"$ref": "#/$defs/Unknown"}
            }
        });

        let cleaned = clean_gemini_schema(schema);

        assert_eq!(cleaned["properties"]["legacy"]["properties"]["name"]["type"], "string");
        assert!(cleaned["properties"]["broken"].get("$ref").is_none());
        assert_eq!(cleaned["properties"]["broken"]["type"], "string");
    }

    // A self-referential definition is expanded once; the inner self-reference
    // becomes a string schema instead of recursing forever.
    #[test]
    fn clean_gemini_schema_handles_circular_refs_without_recursing_forever() {
        let schema = json!({
            "$defs": {
                "Node": {
                    "type": "object",
                    "properties": {
                        "next": { "$ref": "#/$defs/Node" }
                    }
                }
            },
            "type": "object",
            "properties": {
                "root": { "$ref": "#/$defs/Node" }
            }
        });

        let cleaned = clean_gemini_schema(schema);

        assert!(cleaned["properties"]["root"].get("$ref").is_none());
        assert_eq!(cleaned["properties"]["root"]["properties"]["next"]["type"], "string");
    }
}