Skip to main content

bitrouter_google/generate_content/
types.rs

1use std::collections::HashMap;
2
3use bitrouter_core::{
4    errors::{BitrouterError, Result},
5    models::{
6        language::{
7            finish_reason::LanguageModelFinishReason,
8            tool::LanguageModelTool,
9            tool_choice::LanguageModelToolChoice,
10            usage::{LanguageModelInputTokens, LanguageModelOutputTokens, LanguageModelUsage},
11        },
12        shared::{provider::ProviderMetadata, types::JsonValue},
13    },
14};
15use serde::{Deserialize, Serialize};
16use serde_json::json;
17
18pub(super) const GOOGLE_PROVIDER_NAME: &str = "google";
19pub(super) const STREAM_TEXT_ID: &str = "text";
20
21// ── Response types ──────────────────────────────────────────────────────────
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(rename_all = "camelCase")]
25pub struct GoogleGenerateContentResponse {
26    #[serde(default)]
27    pub candidates: Option<Vec<GoogleCandidate>>,
28    #[serde(default)]
29    pub usage_metadata: Option<GoogleUsageMetadata>,
30    #[serde(default)]
31    pub model_version: Option<String>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35#[serde(rename_all = "camelCase")]
36pub struct GoogleCandidate {
37    #[serde(default)]
38    pub content: Option<GoogleContent>,
39    #[serde(default)]
40    pub finish_reason: Option<String>,
41    #[serde(default)]
42    pub index: Option<u32>,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46#[serde(rename_all = "camelCase")]
47pub struct GoogleContent {
48    #[serde(default)]
49    pub role: Option<String>,
50    #[serde(default)]
51    pub parts: Option<Vec<GooglePart>>,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
55#[serde(rename_all = "camelCase")]
56pub struct GooglePart {
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub text: Option<String>,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub inline_data: Option<GoogleInlineData>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub function_call: Option<GoogleFunctionCall>,
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub function_response: Option<GoogleFunctionResponse>,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
68#[serde(rename_all = "camelCase")]
69pub struct GoogleInlineData {
70    pub mime_type: String,
71    pub data: String,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75#[serde(rename_all = "camelCase")]
76pub struct GoogleFunctionCall {
77    pub name: String,
78    #[serde(default)]
79    pub args: Option<JsonValue>,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
83#[serde(rename_all = "camelCase")]
84pub struct GoogleFunctionResponse {
85    pub name: String,
86    pub response: JsonValue,
87}
88
89// ── Usage types ─────────────────────────────────────────────────────────────
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92#[serde(rename_all = "camelCase")]
93pub struct GoogleUsageMetadata {
94    #[serde(default)]
95    pub prompt_token_count: Option<u32>,
96    #[serde(default)]
97    pub candidates_token_count: Option<u32>,
98    #[serde(default)]
99    pub total_token_count: Option<u32>,
100    #[serde(default)]
101    pub cached_content_token_count: Option<u32>,
102}
103
104impl From<GoogleUsageMetadata> for LanguageModelUsage {
105    fn from(usage: GoogleUsageMetadata) -> Self {
106        let raw = serde_json::to_value(&usage).ok();
107        LanguageModelUsage {
108            input_tokens: LanguageModelInputTokens {
109                total: usage.prompt_token_count,
110                no_cache: usage.prompt_token_count.map(|total| {
111                    total.saturating_sub(usage.cached_content_token_count.unwrap_or(0))
112                }),
113                cache_read: usage.cached_content_token_count,
114                cache_write: None,
115            },
116            output_tokens: LanguageModelOutputTokens {
117                total: usage.candidates_token_count,
118                text: usage.candidates_token_count,
119                reasoning: None,
120            },
121            raw,
122        }
123    }
124}
125
126// ── Error types ─────────────────────────────────────────────────────────────
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct GoogleErrorEnvelope {
130    pub error: GoogleApiError,
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize)]
134pub struct GoogleApiError {
135    #[serde(default)]
136    pub code: Option<u16>,
137    #[serde(default)]
138    pub message: Option<String>,
139    #[serde(default)]
140    pub status: Option<String>,
141}
142
143// ── Request types ───────────────────────────────────────────────────────────
144
145#[derive(Debug, Clone, Serialize, Deserialize)]
146#[serde(rename_all = "camelCase")]
147pub struct GoogleGenerateContentRequest {
148    pub contents: Vec<GoogleContent>,
149    #[serde(skip_serializing_if = "Option::is_none")]
150    pub system_instruction: Option<GoogleContent>,
151    #[serde(skip_serializing_if = "Option::is_none")]
152    pub tools: Option<Vec<GoogleTool>>,
153    #[serde(skip_serializing_if = "Option::is_none")]
154    pub tool_config: Option<GoogleToolConfig>,
155    #[serde(skip_serializing_if = "Option::is_none")]
156    pub generation_config: Option<GoogleGenerationConfig>,
157}
158
159#[derive(Debug, Clone, Serialize, Deserialize)]
160#[serde(rename_all = "camelCase")]
161pub struct GoogleGenerationConfig {
162    #[serde(skip_serializing_if = "Option::is_none")]
163    pub temperature: Option<f32>,
164    #[serde(skip_serializing_if = "Option::is_none")]
165    pub top_p: Option<f32>,
166    #[serde(skip_serializing_if = "Option::is_none")]
167    pub top_k: Option<u32>,
168    #[serde(skip_serializing_if = "Option::is_none")]
169    pub max_output_tokens: Option<u32>,
170    #[serde(skip_serializing_if = "Option::is_none")]
171    pub stop_sequences: Option<Vec<String>>,
172    #[serde(skip_serializing_if = "Option::is_none")]
173    pub presence_penalty: Option<f32>,
174    #[serde(skip_serializing_if = "Option::is_none")]
175    pub frequency_penalty: Option<f32>,
176    #[serde(skip_serializing_if = "Option::is_none")]
177    pub seed: Option<i64>,
178    #[serde(skip_serializing_if = "Option::is_none")]
179    pub response_mime_type: Option<String>,
180    #[serde(skip_serializing_if = "Option::is_none")]
181    pub response_schema: Option<schemars::Schema>,
182}
183
184#[derive(Debug, Clone, Serialize, Deserialize)]
185#[serde(rename_all = "camelCase")]
186pub struct GoogleTool {
187    #[serde(skip_serializing_if = "Option::is_none")]
188    pub function_declarations: Option<Vec<GoogleFunctionDeclaration>>,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
192#[serde(rename_all = "camelCase")]
193pub struct GoogleFunctionDeclaration {
194    pub name: String,
195    #[serde(skip_serializing_if = "Option::is_none")]
196    pub description: Option<String>,
197    #[serde(skip_serializing_if = "Option::is_none")]
198    pub parameters: Option<schemars::Schema>,
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize)]
202#[serde(rename_all = "camelCase")]
203pub struct GoogleToolConfig {
204    #[serde(skip_serializing_if = "Option::is_none")]
205    pub function_calling_config: Option<GoogleFunctionCallingConfig>,
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize)]
209#[serde(rename_all = "camelCase")]
210pub struct GoogleFunctionCallingConfig {
211    pub mode: String,
212    #[serde(skip_serializing_if = "Option::is_none")]
213    pub allowed_function_names: Option<Vec<String>>,
214}
215
216// ── From / TryFrom conversions ──────────────────────────────────────────────
217
218impl From<&LanguageModelToolChoice> for GoogleFunctionCallingConfig {
219    fn from(choice: &LanguageModelToolChoice) -> Self {
220        match choice {
221            LanguageModelToolChoice::Auto => GoogleFunctionCallingConfig {
222                mode: "AUTO".to_owned(),
223                allowed_function_names: None,
224            },
225            LanguageModelToolChoice::None => GoogleFunctionCallingConfig {
226                mode: "NONE".to_owned(),
227                allowed_function_names: None,
228            },
229            LanguageModelToolChoice::Required => GoogleFunctionCallingConfig {
230                mode: "ANY".to_owned(),
231                allowed_function_names: None,
232            },
233            LanguageModelToolChoice::Tool { tool_name } => GoogleFunctionCallingConfig {
234                mode: "ANY".to_owned(),
235                allowed_function_names: Some(vec![tool_name.clone()]),
236            },
237        }
238    }
239}
240
241impl TryFrom<&LanguageModelTool> for GoogleFunctionDeclaration {
242    type Error = BitrouterError;
243
244    fn try_from(tool: &LanguageModelTool) -> Result<Self> {
245        match tool {
246            LanguageModelTool::Function {
247                name,
248                description,
249                input_schema,
250                ..
251            } => Ok(GoogleFunctionDeclaration {
252                name: name.clone(),
253                description: description.clone(),
254                parameters: Some(input_schema.clone()),
255            }),
256            LanguageModelTool::Provider { id, .. } => Err(BitrouterError::unsupported(
257                GOOGLE_PROVIDER_NAME,
258                format!("provider tool {}:{}", id.provider_name, id.tool_id),
259                Some(
260                    "Google Generative AI API supports function declarations, \
261                     but bitrouter-core provider tools do not map cleanly here"
262                        .to_owned(),
263                ),
264            )),
265        }
266    }
267}
268
269// ── Helper functions ────────────────────────────────────────────────────────
270
271pub(super) fn map_finish_reason(finish_reason: Option<&str>) -> LanguageModelFinishReason {
272    match finish_reason {
273        Some("STOP") | None => LanguageModelFinishReason::Stop,
274        Some("MAX_TOKENS") => LanguageModelFinishReason::Length,
275        Some("SAFETY")
276        | Some("RECITATION")
277        | Some("BLOCKLIST")
278        | Some("PROHIBITED_CONTENT")
279        | Some("SPII") => LanguageModelFinishReason::ContentFilter,
280        Some("MALFORMED_FUNCTION_CALL") => LanguageModelFinishReason::Error,
281        Some("LANGUAGE") => LanguageModelFinishReason::Other("LANGUAGE".to_owned()),
282        Some(other) => LanguageModelFinishReason::Other(other.to_owned()),
283    }
284}
285
286pub(super) fn google_metadata(model_version: Option<String>) -> Option<ProviderMetadata> {
287    let mut inner = HashMap::new();
288    if let Some(version) = model_version {
289        inner.insert("model_version".to_owned(), JsonValue::String(version));
290    }
291
292    if inner.is_empty() {
293        None
294    } else {
295        Some(HashMap::from([(
296            GOOGLE_PROVIDER_NAME.to_owned(),
297            json!(inner),
298        )]))
299    }
300}
301
302pub(super) fn empty_usage() -> LanguageModelUsage {
303    LanguageModelUsage {
304        input_tokens: LanguageModelInputTokens {
305            total: None,
306            no_cache: None,
307            cache_read: None,
308            cache_write: None,
309        },
310        output_tokens: LanguageModelOutputTokens {
311            total: None,
312            text: None,
313            reasoning: None,
314        },
315        raw: None,
316    }
317}
318
#[cfg(test)]
mod tests {
    // `use super::*` also re-exports the parent's imports (HashMap,
    // LanguageModelUsage, LanguageModelFinishReason, ...).
    use super::*;

    #[test]
    fn maps_stop_finish_reason() {
        assert_eq!(
            map_finish_reason(Some("STOP")),
            LanguageModelFinishReason::Stop
        );
    }

    #[test]
    fn maps_all_finish_reasons() {
        assert_eq!(
            map_finish_reason(Some("STOP")),
            LanguageModelFinishReason::Stop
        );
        assert_eq!(map_finish_reason(None), LanguageModelFinishReason::Stop);
        assert_eq!(
            map_finish_reason(Some("MAX_TOKENS")),
            LanguageModelFinishReason::Length
        );
        assert_eq!(
            map_finish_reason(Some("SAFETY")),
            LanguageModelFinishReason::ContentFilter
        );
        assert_eq!(
            map_finish_reason(Some("RECITATION")),
            LanguageModelFinishReason::ContentFilter
        );
        assert_eq!(
            map_finish_reason(Some("BLOCKLIST")),
            LanguageModelFinishReason::ContentFilter
        );
        assert_eq!(
            map_finish_reason(Some("PROHIBITED_CONTENT")),
            LanguageModelFinishReason::ContentFilter
        );
        assert_eq!(
            map_finish_reason(Some("SPII")),
            LanguageModelFinishReason::ContentFilter
        );
        assert_eq!(
            map_finish_reason(Some("MALFORMED_FUNCTION_CALL")),
            LanguageModelFinishReason::Error
        );
        assert_eq!(
            map_finish_reason(Some("LANGUAGE")),
            LanguageModelFinishReason::Other("LANGUAGE".to_owned())
        );
        assert_eq!(
            map_finish_reason(Some("unknown_reason")),
            LanguageModelFinishReason::Other("unknown_reason".to_owned())
        );
    }

    #[test]
    fn google_usage_to_language_model_usage() {
        let usage = GoogleUsageMetadata {
            prompt_token_count: Some(100),
            candidates_token_count: Some(50),
            total_token_count: Some(150),
            cached_content_token_count: Some(20),
        };
        let lm_usage: LanguageModelUsage = usage.into();
        assert_eq!(lm_usage.input_tokens.total, Some(100));
        assert_eq!(lm_usage.input_tokens.no_cache, Some(80));
        assert_eq!(lm_usage.input_tokens.cache_read, Some(20));
        assert_eq!(lm_usage.input_tokens.cache_write, None);
        assert_eq!(lm_usage.output_tokens.total, Some(50));
        assert_eq!(lm_usage.output_tokens.text, Some(50));
        assert_eq!(lm_usage.output_tokens.reasoning, None);
    }

    #[test]
    fn google_usage_without_cache() {
        let usage = GoogleUsageMetadata {
            prompt_token_count: Some(100),
            candidates_token_count: Some(50),
            total_token_count: Some(150),
            cached_content_token_count: None,
        };
        let lm_usage: LanguageModelUsage = usage.into();
        assert_eq!(lm_usage.input_tokens.total, Some(100));
        assert_eq!(lm_usage.input_tokens.no_cache, Some(100));
        assert_eq!(lm_usage.input_tokens.cache_read, None);
    }

    #[test]
    fn deserialize_text_response() {
        let json = r#"{
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{"text": "Hello!"}]
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 10,
                "candidatesTokenCount": 5,
                "totalTokenCount": 15
            },
            "modelVersion": "gemini-2.0-flash"
        }"#;
        let response: GoogleGenerateContentResponse = serde_json::from_str(json).unwrap();
        let candidates = response.candidates.unwrap();
        assert_eq!(candidates.len(), 1);
        let parts = candidates[0]
            .content
            .as_ref()
            .unwrap()
            .parts
            .as_ref()
            .unwrap();
        assert_eq!(parts[0].text.as_deref(), Some("Hello!"));
        assert_eq!(candidates[0].finish_reason.as_deref(), Some("STOP"));
        assert_eq!(response.model_version.as_deref(), Some("gemini-2.0-flash"));
    }

    #[test]
    fn deserialize_function_call_response() {
        let json = r#"{
            "candidates": [{
                "content": {
                    "role": "model",
                    "parts": [{
                        "functionCall": {
                            "name": "get_weather",
                            "args": {"location": "Paris"}
                        }
                    }]
                },
                "finishReason": "STOP",
                "index": 0
            }],
            "usageMetadata": {
                "promptTokenCount": 20,
                "candidatesTokenCount": 15,
                "totalTokenCount": 35
            }
        }"#;
        let response: GoogleGenerateContentResponse = serde_json::from_str(json).unwrap();
        let candidates = response.candidates.unwrap();
        let parts = candidates[0]
            .content
            .as_ref()
            .unwrap()
            .parts
            .as_ref()
            .unwrap();
        assert!(parts[0].function_call.is_some());
        assert_eq!(parts[0].function_call.as_ref().unwrap().name, "get_weather");
    }

    #[test]
    fn serialize_request() {
        let request = GoogleGenerateContentRequest {
            contents: vec![GoogleContent {
                role: Some("user".to_owned()),
                parts: Some(vec![GooglePart {
                    text: Some("Hello".to_owned()),
                    inline_data: None,
                    function_call: None,
                    function_response: None,
                }]),
            }],
            system_instruction: Some(GoogleContent {
                role: None,
                parts: Some(vec![GooglePart {
                    text: Some("You are a helpful assistant.".to_owned()),
                    inline_data: None,
                    function_call: None,
                    function_response: None,
                }]),
            }),
            tools: None,
            tool_config: None,
            generation_config: Some(GoogleGenerationConfig {
                temperature: Some(0.7),
                top_p: None,
                top_k: None,
                max_output_tokens: Some(1024),
                stop_sequences: None,
                presence_penalty: None,
                frequency_penalty: None,
                seed: None,
                response_mime_type: None,
                response_schema: None,
            }),
        };
        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["contents"][0]["role"], "user");
        assert_eq!(json["contents"][0]["parts"][0]["text"], "Hello");
        assert_eq!(
            json["systemInstruction"]["parts"][0]["text"],
            "You are a helpful assistant."
        );
        // Compare the absolute difference: the f32 0.7 widens to
        // 0.699999988... as f64, and a one-sided check would accept any
        // smaller value.
        let temperature = json["generationConfig"]["temperature"].as_f64().unwrap();
        assert!((temperature - 0.7).abs() < 1e-6);
        assert_eq!(json["generationConfig"]["maxOutputTokens"], 1024);
        assert!(json.get("tools").is_none());
    }

    #[test]
    fn tool_choice_auto() {
        let config = GoogleFunctionCallingConfig::from(&LanguageModelToolChoice::Auto);
        assert_eq!(config.mode, "AUTO");
        assert!(config.allowed_function_names.is_none());
    }

    #[test]
    fn tool_choice_none() {
        let config = GoogleFunctionCallingConfig::from(&LanguageModelToolChoice::None);
        assert_eq!(config.mode, "NONE");
    }

    #[test]
    fn tool_choice_required_maps_to_any() {
        let config = GoogleFunctionCallingConfig::from(&LanguageModelToolChoice::Required);
        assert_eq!(config.mode, "ANY");
        assert!(config.allowed_function_names.is_none());
    }

    #[test]
    fn tool_choice_named() {
        let config = GoogleFunctionCallingConfig::from(&LanguageModelToolChoice::Tool {
            tool_name: "get_weather".to_owned(),
        });
        assert_eq!(config.mode, "ANY");
        assert_eq!(
            config.allowed_function_names.as_ref().unwrap(),
            &["get_weather"]
        );
    }

    #[test]
    fn tool_conversion_function() {
        let tool = LanguageModelTool::Function {
            name: "test_tool".to_owned(),
            description: Some("A test tool".to_owned()),
            input_schema: schemars::Schema::default(),
            input_examples: vec![],
            strict: None,
            provider_options: None,
        };
        let result = GoogleFunctionDeclaration::try_from(&tool);
        assert!(result.is_ok());
        let decl = result.unwrap();
        assert_eq!(decl.name, "test_tool");
        assert_eq!(decl.description.as_deref(), Some("A test tool"));
    }

    #[test]
    fn tool_conversion_provider_fails() {
        let tool = LanguageModelTool::Provider {
            id: bitrouter_core::models::language::tool::ProviderToolId {
                provider_name: "test".to_owned(),
                tool_id: "123".to_owned(),
            },
            name: "test_tool".to_owned(),
            args: HashMap::new(),
            provider_options: None,
        };
        let result = GoogleFunctionDeclaration::try_from(&tool);
        assert!(result.is_err());
    }

    #[test]
    fn deserialize_error_envelope() {
        let json = r#"{
            "error": {
                "code": 400,
                "message": "Invalid value at 'contents'",
                "status": "INVALID_ARGUMENT"
            }
        }"#;
        let envelope: GoogleErrorEnvelope = serde_json::from_str(json).unwrap();
        assert_eq!(envelope.error.code, Some(400));
        assert_eq!(
            envelope.error.message.as_deref(),
            Some("Invalid value at 'contents'")
        );
        assert_eq!(envelope.error.status.as_deref(), Some("INVALID_ARGUMENT"));
    }

    #[test]
    fn serialize_inline_data_part() {
        let part = GooglePart {
            text: None,
            inline_data: Some(GoogleInlineData {
                mime_type: "image/png".to_owned(),
                data: "abc123".to_owned(),
            }),
            function_call: None,
            function_response: None,
        };
        let json = serde_json::to_value(&part).unwrap();
        assert_eq!(json["inlineData"]["mimeType"], "image/png");
        assert_eq!(json["inlineData"]["data"], "abc123");
        assert!(json.get("text").is_none());
    }

    #[test]
    fn google_metadata_with_model_version() {
        let meta = google_metadata(Some("gemini-2.0-flash".to_owned()));
        assert!(meta.is_some());
        let meta = meta.unwrap();
        let inner = meta.get(GOOGLE_PROVIDER_NAME).unwrap();
        assert_eq!(inner["model_version"], "gemini-2.0-flash");
    }

    #[test]
    fn google_metadata_empty() {
        let meta = google_metadata(None);
        assert!(meta.is_none());
    }

    #[test]
    fn request_roundtrip_with_tools() {
        let request = GoogleGenerateContentRequest {
            contents: vec![GoogleContent {
                role: Some("user".to_owned()),
                parts: Some(vec![GooglePart {
                    text: Some("Hello".to_owned()),
                    inline_data: None,
                    function_call: None,
                    function_response: None,
                }]),
            }],
            system_instruction: None,
            tools: Some(vec![GoogleTool {
                function_declarations: Some(vec![GoogleFunctionDeclaration {
                    name: "get_weather".to_owned(),
                    description: Some("Get the weather".to_owned()),
                    parameters: Some(schemars::Schema::default()),
                }]),
            }]),
            tool_config: Some(GoogleToolConfig {
                function_calling_config: Some(GoogleFunctionCallingConfig {
                    mode: "AUTO".to_owned(),
                    allowed_function_names: None,
                }),
            }),
            generation_config: None,
        };
        let json = serde_json::to_string(&request).unwrap();
        let parsed: GoogleGenerateContentRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.contents.len(), 1);
        assert_eq!(
            parsed.tools.as_ref().unwrap()[0]
                .function_declarations
                .as_ref()
                .unwrap()
                .len(),
            1
        );
        assert_eq!(
            parsed
                .tool_config
                .as_ref()
                .unwrap()
                .function_calling_config
                .as_ref()
                .unwrap()
                .mode,
            "AUTO"
        );
    }
}