1use std::collections::HashMap;
2
3use bitrouter_core::{
4 errors::{BitrouterError, Result},
5 models::{
6 language::{
7 finish_reason::LanguageModelFinishReason,
8 tool::LanguageModelTool,
9 tool_choice::LanguageModelToolChoice,
10 usage::{LanguageModelInputTokens, LanguageModelOutputTokens, LanguageModelUsage},
11 },
12 shared::{provider::ProviderMetadata, types::JsonValue},
13 },
14};
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17
/// Provider name for this module: the key of the Google entry in provider
/// metadata and the provider label attached to "unsupported" errors.
pub(super) const GOOGLE_PROVIDER_NAME: &str = "google";
/// Fixed identifier for text content.
///
/// NOTE(review): not referenced in this section; presumably the id of the text
/// part in streamed output — confirm at the call sites.
pub(super) const STREAM_TEXT_ID: &str = "text";
20
/// Response payload of a Google generate-content call.
///
/// Every field is an `Option` with `#[serde(default)]`, so payloads that omit
/// any of these keys still deserialize. Keys are camelCase on the wire
/// (`candidates`, `usageMetadata`, `modelVersion`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleGenerateContentResponse {
    /// Generated candidates; `None` when the payload has no `candidates` key.
    #[serde(default)]
    pub candidates: Option<Vec<GoogleCandidate>>,
    /// Token accounting for the call, when reported.
    #[serde(default)]
    pub usage_metadata: Option<GoogleUsageMetadata>,
    /// Model version string reported by the API (e.g. `gemini-2.0-flash`).
    #[serde(default)]
    pub model_version: Option<String>,
}
33
/// A single generated candidate inside a [`GoogleGenerateContentResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleCandidate {
    /// Generated content (role plus parts), if present.
    #[serde(default)]
    pub content: Option<GoogleContent>,
    /// Raw finish-reason string such as `"STOP"`; translate it with
    /// `map_finish_reason`.
    #[serde(default)]
    pub finish_reason: Option<String>,
    /// Candidate index as reported by the API.
    #[serde(default)]
    pub index: Option<u32>,
}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46#[serde(rename_all = "camelCase")]
47pub struct GoogleContent {
48 #[serde(default)]
49 pub role: Option<String>,
50 #[serde(default)]
51 pub parts: Option<Vec<GooglePart>>,
52}
53
/// One part of a [`GoogleContent`]: text, inline data, a function call, or a
/// function response. Unset fields are omitted when serializing.
///
/// NOTE(review): nothing here enforces that exactly one field is set; the API
/// presumably expects a single kind per part — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GooglePart {
    /// Plain text content.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Inline binary payload with its MIME type (serialized as `inlineData`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub inline_data: Option<GoogleInlineData>,
    /// A function call emitted by the model (serialized as `functionCall`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<GoogleFunctionCall>,
    /// A function result (serialized as `functionResponse`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_response: Option<GoogleFunctionResponse>,
}
66
/// Inline binary payload tagged with its MIME type (e.g. `image/png`).
///
/// NOTE(review): `data` is presumably base64-encoded bytes — confirm at the
/// call site that fills it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleInlineData {
    /// MIME type of `data` (serialized as `mimeType`).
    pub mime_type: String,
    /// Encoded payload.
    pub data: String,
}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75#[serde(rename_all = "camelCase")]
76pub struct GoogleFunctionCall {
77 pub name: String,
78 #[serde(default)]
79 pub args: Option<JsonValue>,
80}
81
/// The result of a function call, carried in a [`GooglePart`].
///
/// NOTE(review): `name` presumably identifies the function call this result
/// answers — confirm against the request builder.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleFunctionResponse {
    /// Name of the function that produced this result.
    pub name: String,
    /// Arbitrary JSON result payload.
    pub response: JsonValue,
}
88
/// Token accounting reported by Google (`usageMetadata`).
///
/// All counters are optional. Convert with the `From` impl into
/// `LanguageModelUsage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleUsageMetadata {
    /// Prompt tokens. The `LanguageModelUsage` conversion treats
    /// `cached_content_token_count` as a subset of this and subtracts it.
    #[serde(default)]
    pub prompt_token_count: Option<u32>,
    /// Tokens in the generated candidates.
    #[serde(default)]
    pub candidates_token_count: Option<u32>,
    /// Total reported by the API; the usage conversion only keeps it in `raw`.
    #[serde(default)]
    pub total_token_count: Option<u32>,
    /// Prompt tokens served from cached content.
    #[serde(default)]
    pub cached_content_token_count: Option<u32>,
}
103
104impl From<GoogleUsageMetadata> for LanguageModelUsage {
105 fn from(usage: GoogleUsageMetadata) -> Self {
106 let raw = serde_json::to_value(&usage).ok();
107 LanguageModelUsage {
108 input_tokens: LanguageModelInputTokens {
109 total: usage.prompt_token_count,
110 no_cache: usage.prompt_token_count.map(|total| {
111 total.saturating_sub(usage.cached_content_token_count.unwrap_or(0))
112 }),
113 cache_read: usage.cached_content_token_count,
114 cache_write: None,
115 },
116 output_tokens: LanguageModelOutputTokens {
117 total: usage.candidates_token_count,
118 text: usage.candidates_token_count,
119 reasoning: None,
120 },
121 raw,
122 }
123 }
124}
125
/// Error body returned by the API: `{"error": {"code", "message", "status"}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GoogleErrorEnvelope {
    /// The wrapped error details.
    pub error: GoogleApiError,
}
132
/// Details of an API error; every field is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GoogleApiError {
    /// Numeric status code, e.g. `400`.
    #[serde(default)]
    pub code: Option<u16>,
    /// Human-readable error message.
    #[serde(default)]
    pub message: Option<String>,
    /// Symbolic status, e.g. `"INVALID_ARGUMENT"`.
    #[serde(default)]
    pub status: Option<String>,
}
142
/// Request body for a Google generate-content call.
///
/// Optional sections are omitted from the serialized JSON when `None`. Keys
/// are camelCase on the wire (e.g. `systemInstruction`, `generationConfig`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleGenerateContentRequest {
    /// Content items sent to the model, in order.
    pub contents: Vec<GoogleContent>,
    /// Optional system prompt, expressed as content parts.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_instruction: Option<GoogleContent>,
    /// Tools (function declarations) available to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<GoogleTool>>,
    /// How the model may call the declared tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_config: Option<GoogleToolConfig>,
    /// Sampling and output settings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_config: Option<GoogleGenerationConfig>,
}
158
/// Generation settings sent as `generationConfig`.
///
/// Every field is optional and omitted from the serialized JSON when `None`,
/// so only explicitly chosen settings reach the API. Keys are camelCase on the
/// wire (e.g. `maxOutputTokens`, `stopSequences`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleGenerationConfig {
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Top-p (nucleus) sampling threshold.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Top-k sampling cutoff.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    /// Upper bound on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u32>,
    /// Sequences that end generation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
    /// Presence penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Frequency penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Sampling seed.
    ///
    /// NOTE(review): declared as `i64`; confirm the API accepts values outside
    /// the 32-bit range.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    /// MIME type requested for the response (presumably e.g.
    /// `application/json` for structured output — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_mime_type: Option<String>,
    /// Schema the structured response should follow.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_schema: Option<schemars::Schema>,
}
183
/// A tool entry in a request; here it only carries function declarations.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleTool {
    /// Functions the model may call (serialized as `functionDeclarations`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_declarations: Option<Vec<GoogleFunctionDeclaration>>,
}
190
/// A callable function exposed to the model.
///
/// Built from `LanguageModelTool::Function` by the `TryFrom` impl in this
/// module.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleFunctionDeclaration {
    /// Function name the model uses when calling it.
    pub name: String,
    /// Optional description shown to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Schema of the function's input arguments.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<schemars::Schema>,
}
200
/// Tool-usage configuration sent as `toolConfig`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleToolConfig {
    /// Function-calling mode and restrictions (`functionCallingConfig`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_calling_config: Option<GoogleFunctionCallingConfig>,
}
207
/// Function-calling behaviour requested from the model.
///
/// `mode` is the raw API string. The `From<&LanguageModelToolChoice>` impl in
/// this module produces `"AUTO"`, `"NONE"`, or `"ANY"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleFunctionCallingConfig {
    /// Calling mode string.
    pub mode: String,
    /// Restricts calls to these functions. Set by this module only when a
    /// specific tool is chosen.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_function_names: Option<Vec<String>>,
}
215
216impl From<&LanguageModelToolChoice> for GoogleFunctionCallingConfig {
219 fn from(choice: &LanguageModelToolChoice) -> Self {
220 match choice {
221 LanguageModelToolChoice::Auto => GoogleFunctionCallingConfig {
222 mode: "AUTO".to_owned(),
223 allowed_function_names: None,
224 },
225 LanguageModelToolChoice::None => GoogleFunctionCallingConfig {
226 mode: "NONE".to_owned(),
227 allowed_function_names: None,
228 },
229 LanguageModelToolChoice::Required => GoogleFunctionCallingConfig {
230 mode: "ANY".to_owned(),
231 allowed_function_names: None,
232 },
233 LanguageModelToolChoice::Tool { tool_name } => GoogleFunctionCallingConfig {
234 mode: "ANY".to_owned(),
235 allowed_function_names: Some(vec![tool_name.clone()]),
236 },
237 }
238 }
239}
240
241impl TryFrom<&LanguageModelTool> for GoogleFunctionDeclaration {
242 type Error = BitrouterError;
243
244 fn try_from(tool: &LanguageModelTool) -> Result<Self> {
245 match tool {
246 LanguageModelTool::Function {
247 name,
248 description,
249 input_schema,
250 ..
251 } => Ok(GoogleFunctionDeclaration {
252 name: name.clone(),
253 description: description.clone(),
254 parameters: Some(input_schema.clone()),
255 }),
256 LanguageModelTool::Provider { id, .. } => Err(BitrouterError::unsupported(
257 GOOGLE_PROVIDER_NAME,
258 format!("provider tool {}:{}", id.provider_name, id.tool_id),
259 Some(
260 "Google Generative AI API supports function declarations, \
261 but bitrouter-core provider tools do not map cleanly here"
262 .to_owned(),
263 ),
264 )),
265 }
266 }
267}
268
269pub(super) fn map_finish_reason(finish_reason: Option<&str>) -> LanguageModelFinishReason {
272 match finish_reason {
273 Some("STOP") | None => LanguageModelFinishReason::Stop,
274 Some("MAX_TOKENS") => LanguageModelFinishReason::Length,
275 Some("SAFETY")
276 | Some("RECITATION")
277 | Some("BLOCKLIST")
278 | Some("PROHIBITED_CONTENT")
279 | Some("SPII") => LanguageModelFinishReason::ContentFilter,
280 Some("MALFORMED_FUNCTION_CALL") => LanguageModelFinishReason::Error,
281 Some("LANGUAGE") => LanguageModelFinishReason::Other("LANGUAGE".to_owned()),
282 Some(other) => LanguageModelFinishReason::Other(other.to_owned()),
283 }
284}
285
286pub(super) fn google_metadata(model_version: Option<String>) -> Option<ProviderMetadata> {
287 let mut inner = HashMap::new();
288 if let Some(version) = model_version {
289 inner.insert("model_version".to_owned(), JsonValue::String(version));
290 }
291
292 if inner.is_empty() {
293 None
294 } else {
295 Some(HashMap::from([(
296 GOOGLE_PROVIDER_NAME.to_owned(),
297 json!(inner),
298 )]))
299 }
300}
301
302pub(super) fn empty_usage() -> LanguageModelUsage {
303 LanguageModelUsage {
304 input_tokens: LanguageModelInputTokens {
305 total: None,
306 no_cache: None,
307 cache_read: None,
308 cache_write: None,
309 },
310 output_tokens: LanguageModelOutputTokens {
311 total: None,
312 text: None,
313 reasoning: None,
314 },
315 raw: None,
316 }
317}
318
#[cfg(test)]
mod tests {
    // Unit tests for the wire types and conversion helpers in the parent module.
    use super::*;
    use bitrouter_core::models::language::usage::LanguageModelUsage;

    #[test]
    fn maps_stop_finish_reason() {
        assert_eq!(
            map_finish_reason(Some("STOP")),
            LanguageModelFinishReason::Stop
        );
    }

    #[test]
    fn maps_all_finish_reasons() {
        assert_eq!(
            map_finish_reason(Some("STOP")),
            LanguageModelFinishReason::Stop
        );
        assert_eq!(map_finish_reason(None), LanguageModelFinishReason::Stop);
        assert_eq!(
            map_finish_reason(Some("MAX_TOKENS")),
            LanguageModelFinishReason::Length
        );
        assert_eq!(
            map_finish_reason(Some("SAFETY")),
            LanguageModelFinishReason::ContentFilter
        );
        assert_eq!(
            map_finish_reason(Some("RECITATION")),
            LanguageModelFinishReason::ContentFilter
        );
        assert_eq!(
            map_finish_reason(Some("BLOCKLIST")),
            LanguageModelFinishReason::ContentFilter
        );
        assert_eq!(
            map_finish_reason(Some("PROHIBITED_CONTENT")),
            LanguageModelFinishReason::ContentFilter
        );
        assert_eq!(
            map_finish_reason(Some("SPII")),
            LanguageModelFinishReason::ContentFilter
        );
        assert_eq!(
            map_finish_reason(Some("MALFORMED_FUNCTION_CALL")),
            LanguageModelFinishReason::Error
        );
        assert_eq!(
            map_finish_reason(Some("LANGUAGE")),
            LanguageModelFinishReason::Other("LANGUAGE".to_owned())
        );
        assert_eq!(
            map_finish_reason(Some("unknown_reason")),
            LanguageModelFinishReason::Other("unknown_reason".to_owned())
        );
    }

    #[test]
    fn google_usage_to_language_model_usage() {
        let usage = GoogleUsageMetadata {
            prompt_token_count: Some(100),
            candidates_token_count: Some(50),
            total_token_count: Some(150),
            cached_content_token_count: Some(20),
        };
        let lm_usage: LanguageModelUsage = usage.into();
        assert_eq!(lm_usage.input_tokens.total, Some(100));
        assert_eq!(lm_usage.input_tokens.no_cache, Some(80));
        assert_eq!(lm_usage.input_tokens.cache_read, Some(20));
        assert_eq!(lm_usage.input_tokens.cache_write, None);
        assert_eq!(lm_usage.output_tokens.total, Some(50));
        assert_eq!(lm_usage.output_tokens.text, Some(50));
        assert_eq!(lm_usage.output_tokens.reasoning, None);
    }

    #[test]
    fn google_usage_without_cache() {
        let usage = GoogleUsageMetadata {
            prompt_token_count: Some(100),
            candidates_token_count: Some(50),
            total_token_count: Some(150),
            cached_content_token_count: None,
        };
        let lm_usage: LanguageModelUsage = usage.into();
        assert_eq!(lm_usage.input_tokens.total, Some(100));
        assert_eq!(lm_usage.input_tokens.no_cache, Some(100));
        assert_eq!(lm_usage.input_tokens.cache_read, None);
    }

    #[test]
    fn deserialize_text_response() {
        let json = r#"{
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{"text": "Hello!"}]
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            },
            "modelVersion": "gemini-2.0-flash"
        }"#;
        let response: GoogleGenerateContentResponse = serde_json::from_str(json).unwrap();
        let candidates = response.candidates.unwrap();
        assert_eq!(candidates.len(), 1);
        let parts = candidates[0]
            .content
            .as_ref()
            .unwrap()
            .parts
            .as_ref()
            .unwrap();
        assert_eq!(parts[0].text.as_deref(), Some("Hello!"));
        assert_eq!(candidates[0].finish_reason.as_deref(), Some("STOP"));
        assert_eq!(response.model_version.as_deref(), Some("gemini-2.0-flash"));
    }

    #[test]
    fn deserialize_function_call_response() {
        let json = r#"{
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{
                        "functionCall": {
                            "name": "get_weather",
                            "args": {"location": "Paris"}
                        }
                    }]
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 20,
                "candidatesTokenCount": 15,
                "totalTokenCount": 35
            }
        }"#;
        let response: GoogleGenerateContentResponse = serde_json::from_str(json).unwrap();
        let candidates = response.candidates.unwrap();
        let parts = candidates[0]
            .content
            .as_ref()
            .unwrap()
            .parts
            .as_ref()
            .unwrap();
        assert!(parts[0].function_call.is_some());
        assert_eq!(parts[0].function_call.as_ref().unwrap().name, "get_weather");
    }

    #[test]
    fn serialize_request() {
        let request = GoogleGenerateContentRequest {
            contents: vec![GoogleContent {
                role: Some("user".to_owned()),
                parts: Some(vec![GooglePart {
                    text: Some("Hello".to_owned()),
                    inline_data: None,
                    function_call: None,
                    function_response: None,
                }]),
            }],
            system_instruction: Some(GoogleContent {
                role: None,
                parts: Some(vec![GooglePart {
                    text: Some("You are a helpful assistant.".to_owned()),
                    inline_data: None,
                    function_call: None,
                    function_response: None,
                }]),
            }),
            tools: None,
            tool_config: None,
            generation_config: Some(GoogleGenerationConfig {
                temperature: Some(0.7),
                top_p: None,
                top_k: None,
                max_output_tokens: Some(1024),
                stop_sequences: None,
                presence_penalty: None,
                frequency_penalty: None,
                seed: None,
                response_mime_type: None,
                response_schema: None,
            }),
        };
        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["contents"][0]["role"], "user");
        assert_eq!(json["contents"][0]["parts"][0]["text"], "Hello");
        assert_eq!(
            json["systemInstruction"]["parts"][0]["text"],
            "You are a helpful assistant."
        );
        let temperature = json["generationConfig"]["temperature"].as_f64().unwrap();
        // Compare the absolute difference: the one-sided `x - 0.7 < 0.01` also
        // accepted any temperature below 0.71, including 0.0.
        assert!((temperature - 0.7).abs() < 0.01);
        assert_eq!(json["generationConfig"]["maxOutputTokens"], 1024);
        assert!(json.get("tools").is_none());
    }

    #[test]
    fn tool_choice_auto() {
        let config = GoogleFunctionCallingConfig::from(&LanguageModelToolChoice::Auto);
        assert_eq!(config.mode, "AUTO");
        assert!(config.allowed_function_names.is_none());
    }

    #[test]
    fn tool_choice_none() {
        let config = GoogleFunctionCallingConfig::from(&LanguageModelToolChoice::None);
        assert_eq!(config.mode, "NONE");
    }

    #[test]
    fn tool_choice_required_maps_to_any() {
        let config = GoogleFunctionCallingConfig::from(&LanguageModelToolChoice::Required);
        assert_eq!(config.mode, "ANY");
        assert!(config.allowed_function_names.is_none());
    }

    #[test]
    fn tool_choice_named() {
        let config = GoogleFunctionCallingConfig::from(&LanguageModelToolChoice::Tool {
            tool_name: "get_weather".to_owned(),
        });
        assert_eq!(config.mode, "ANY");
        assert_eq!(
            config.allowed_function_names.as_ref().unwrap(),
            &["get_weather"]
        );
    }

    #[test]
    fn tool_conversion_function() {
        let tool = LanguageModelTool::Function {
            name: "test_tool".to_owned(),
            description: Some("A test tool".to_owned()),
            input_schema: schemars::Schema::default(),
            input_examples: vec![],
            strict: None,
            provider_options: None,
        };
        let result = GoogleFunctionDeclaration::try_from(&tool);
        assert!(result.is_ok());
        let decl = result.unwrap();
        assert_eq!(decl.name, "test_tool");
        assert_eq!(decl.description.as_deref(), Some("A test tool"));
    }

    #[test]
    fn tool_conversion_provider_fails() {
        let tool = LanguageModelTool::Provider {
            id: bitrouter_core::models::language::tool::ProviderToolId {
                provider_name: "test".to_owned(),
                tool_id: "123".to_owned(),
            },
            name: "test_tool".to_owned(),
            args: HashMap::new(),
            provider_options: None,
        };
        let result = GoogleFunctionDeclaration::try_from(&tool);
        assert!(result.is_err());
    }

    #[test]
    fn deserialize_error_envelope() {
        let json = r#"{
            "error": {
                "code": 400,
                "message": "Invalid value at 'contents'",
                "status": "INVALID_ARGUMENT"
            }
        }"#;
        let envelope: GoogleErrorEnvelope = serde_json::from_str(json).unwrap();
        assert_eq!(envelope.error.code, Some(400));
        assert_eq!(
            envelope.error.message.as_deref(),
            Some("Invalid value at 'contents'")
        );
        assert_eq!(envelope.error.status.as_deref(), Some("INVALID_ARGUMENT"));
    }

    #[test]
    fn serialize_inline_data_part() {
        let part = GooglePart {
            text: None,
            inline_data: Some(GoogleInlineData {
                mime_type: "image/png".to_owned(),
                data: "abc123".to_owned(),
            }),
            function_call: None,
            function_response: None,
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["inlineData"]["mimeType"], "image/png");
        assert_eq!(json["inlineData"]["data"], "abc123");
        assert!(json.get("text").is_none());
    }

    #[test]
    fn google_metadata_with_model_version() {
        let meta = google_metadata(Some("gemini-2.0-flash".to_owned()));
        assert!(meta.is_some());
        let meta = meta.unwrap();
        let inner = meta.get(GOOGLE_PROVIDER_NAME).unwrap();
        assert_eq!(inner["model_version"], "gemini-2.0-flash");
    }

    #[test]
    fn google_metadata_empty() {
        let meta = google_metadata(None);
        assert!(meta.is_none());
    }

    #[test]
    fn request_roundtrip_with_tools() {
        let request = GoogleGenerateContentRequest {
            contents: vec![GoogleContent {
                role: Some("user".to_owned()),
                parts: Some(vec![GooglePart {
                    text: Some("Hello".to_owned()),
                    inline_data: None,
                    function_call: None,
                    function_response: None,
                }]),
            }],
            system_instruction: None,
            tools: Some(vec![GoogleTool {
                function_declarations: Some(vec![GoogleFunctionDeclaration {
                    name: "get_weather".to_owned(),
                    description: Some("Get the weather".to_owned()),
                    parameters: Some(schemars::Schema::default()),
                }]),
            }]),
            tool_config: Some(GoogleToolConfig {
                function_calling_config: Some(GoogleFunctionCallingConfig {
                    mode: "AUTO".to_owned(),
                    allowed_function_names: None,
                }),
            }),
            generation_config: None,
        };
        let json = serde_json::to_string(&request).unwrap();
        let parsed: GoogleGenerateContentRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.contents.len(), 1);
        assert_eq!(
            parsed.tools.as_ref().unwrap()[0]
                .function_declarations
                .as_ref()
                .unwrap()
                .len(),
            1
        );
        assert_eq!(
            parsed
                .tool_config
                .as_ref()
                .unwrap()
                .function_calling_config
                .as_ref()
                .unwrap()
                .mode,
            "AUTO"
        );
    }
}