// bitrouter_core/api/google/generate_content/convert.rs

1//! Conversion between Google Generative AI format and core LanguageModel types.
2
3use std::collections::HashMap;
4
5use crate::models::{
6    language::{
7        call_options::LanguageModelCallOptions,
8        content::LanguageModelContent,
9        finish_reason::LanguageModelFinishReason,
10        generate_result::LanguageModelGenerateResult,
11        prompt::{
12            LanguageModelAssistantContent, LanguageModelMessage, LanguageModelToolResult,
13            LanguageModelToolResultOutput, LanguageModelUserContent,
14        },
15        stream_part::LanguageModelStreamPart,
16        tool::LanguageModelTool,
17        tool_choice::LanguageModelToolChoice,
18    },
19    shared::types::JsonSchema,
20};
21
22use super::types::{
23    GenerateContentCandidate, GenerateContentRequest, GenerateContentResponse,
24    GenerateContentUsageMetadata, GoogleContent, GoogleFunctionCall, GooglePart, GoogleToolConfig,
25};
26use crate::api::util::generate_id;
27
28/// Extracts the model name from a generate content request body.
29pub fn extract_model_name(request: &GenerateContentRequest) -> &str {
30    &request.model
31}
32
33/// Converts a [`GenerateContentRequest`] into [`LanguageModelCallOptions`].
34pub fn to_call_options(request: GenerateContentRequest) -> LanguageModelCallOptions {
35    let mut prompt: Vec<LanguageModelMessage> = Vec::new();
36
37    // Google system instruction is a top-level field.
38    if let Some(system) = request.system_instruction
39        && let Some(parts) = system.parts
40    {
41        let system_text: String = parts
42            .into_iter()
43            .filter_map(|p| p.text)
44            .collect::<Vec<_>>()
45            .join("");
46        if !system_text.is_empty() {
47            prompt.push(LanguageModelMessage::System {
48                content: system_text,
49                provider_options: None,
50            });
51        }
52    }
53
54    for content in request.contents {
55        match content.role.as_deref() {
56            Some("model") => {
57                let assistant_content = convert_model_parts(content.parts);
58                prompt.push(LanguageModelMessage::Assistant {
59                    content: assistant_content,
60                    provider_options: None,
61                });
62            }
63            _ => {
64                let (user_parts, tool_results) = split_google_parts(content.parts);
65                if !tool_results.is_empty() {
66                    prompt.push(LanguageModelMessage::Tool {
67                        content: tool_results,
68                        provider_options: None,
69                    });
70                }
71                if !user_parts.is_empty() {
72                    prompt.push(LanguageModelMessage::User {
73                        content: user_parts,
74                        provider_options: None,
75                    });
76                }
77            }
78        }
79    }
80
81    let tools = request.tools.map(|tool_groups| {
82        tool_groups
83            .into_iter()
84            .flat_map(|t| t.function_declarations.unwrap_or_default())
85            .map(|fd| {
86                let schema_value = fd.parameters.unwrap_or(serde_json::json!({}));
87                let input_schema: JsonSchema =
88                    serde_json::from_value(schema_value).unwrap_or_default();
89                LanguageModelTool::Function {
90                    name: fd.name,
91                    description: fd.description,
92                    input_schema,
93                    input_examples: vec![],
94                    strict: None,
95                    provider_options: None,
96                }
97            })
98            .collect()
99    });
100
101    let tool_choice = request.tool_config.and_then(convert_tool_config);
102
103    let (max_output_tokens, temperature, top_p, top_k, stop_sequences) =
104        if let Some(config) = request.generation_config {
105            (
106                config.max_output_tokens,
107                config.temperature,
108                config.top_p,
109                config.top_k,
110                config.stop_sequences,
111            )
112        } else {
113            (None, None, None, None, None)
114        };
115
116    LanguageModelCallOptions {
117        prompt,
118        stream: request.stream,
119        max_output_tokens,
120        temperature,
121        top_p,
122        top_k,
123        stop_sequences,
124        presence_penalty: None,
125        frequency_penalty: None,
126        response_format: None,
127        seed: None,
128        tools,
129        tool_choice,
130        include_raw_chunks: None,
131        abort_signal: None,
132        headers: None,
133        provider_options: None,
134    }
135}
136
137/// Converts a [`LanguageModelGenerateResult`] into a [`GenerateContentResponse`].
138pub fn from_generate_result(
139    model_id: &str,
140    result: LanguageModelGenerateResult,
141) -> GenerateContentResponse {
142    let parts = extract_response_parts(&result.content);
143    let finish_reason = map_finish_reason(&result.finish_reason);
144    let input_tokens = result.usage.input_tokens.total.unwrap_or(0);
145    let output_tokens = result.usage.output_tokens.total.unwrap_or(0);
146
147    GenerateContentResponse {
148        candidates: Some(vec![GenerateContentCandidate {
149            content: Some(GoogleContent {
150                role: Some("model".to_owned()),
151                parts: Some(parts),
152            }),
153            finish_reason: Some(finish_reason),
154            index: Some(0),
155        }]),
156        usage_metadata: Some(GenerateContentUsageMetadata {
157            prompt_token_count: Some(input_tokens),
158            candidates_token_count: Some(output_tokens),
159            total_token_count: Some(input_tokens + output_tokens),
160            cached_content_token_count: None,
161        }),
162        model_version: Some(model_id.to_owned()),
163    }
164}
165
166// ── Streaming ───────────────────────────────────────────────────────────────
167
/// Stateful converter that accumulates incremental tool-call data.
///
/// Tool-call arguments can arrive as a sequence of `ToolInputDelta`
/// fragments; this converter buffers them per call id until the matching
/// `ToolInputEnd` arrives.
pub struct StreamConverter {
    // Echoed back as `model_version` on the final (Finish) chunk.
    model_id: String,
    // In-flight tool calls keyed by call id: inserted on `ToolInputStart`,
    // appended to on `ToolInputDelta`, removed on `ToolInputEnd`.
    pending_calls: HashMap<String, PendingFunctionCall>,
}

/// A partially received function call: its name plus the JSON argument text
/// accumulated so far.
struct PendingFunctionCall {
    name: String,
    args_buffer: String,
}
178
179impl StreamConverter {
180    pub fn new(model_id: String) -> Self {
181        Self {
182            model_id,
183            pending_calls: HashMap::new(),
184        }
185    }
186
187    /// Converts a [`LanguageModelStreamPart`] into a [`GenerateContentResponse`].
188    pub fn convert(&mut self, part: &LanguageModelStreamPart) -> Option<GenerateContentResponse> {
189        match part {
190            LanguageModelStreamPart::TextDelta { delta, .. } => Some(self.make_chunk(
191                vec![GooglePart {
192                    text: Some(delta.clone()),
193                    inline_data: None,
194                    function_call: None,
195                    function_response: None,
196                }],
197                None,
198                None,
199                None,
200            )),
201            LanguageModelStreamPart::ToolCall {
202                tool_name,
203                tool_input,
204                ..
205            } => {
206                let args: serde_json::Value = serde_json::from_str(tool_input).unwrap_or_default();
207                Some(self.make_chunk(
208                    vec![GooglePart {
209                        text: None,
210                        inline_data: None,
211                        function_call: Some(GoogleFunctionCall {
212                            name: tool_name.clone(),
213                            args: Some(args),
214                        }),
215                        function_response: None,
216                    }],
217                    None,
218                    None,
219                    None,
220                ))
221            }
222            LanguageModelStreamPart::ToolInputStart { id, tool_name, .. } => {
223                self.pending_calls.insert(
224                    id.clone(),
225                    PendingFunctionCall {
226                        name: tool_name.clone(),
227                        args_buffer: String::new(),
228                    },
229                );
230                None
231            }
232            LanguageModelStreamPart::ToolInputDelta { id, delta, .. } => {
233                if let Some(pending) = self.pending_calls.get_mut(id) {
234                    pending.args_buffer.push_str(delta);
235                }
236                None
237            }
238            LanguageModelStreamPart::ToolInputEnd { id, .. } => {
239                if let Some(pending) = self.pending_calls.remove(id) {
240                    let args: serde_json::Value =
241                        serde_json::from_str(&pending.args_buffer).unwrap_or_default();
242                    Some(self.make_chunk(
243                        vec![GooglePart {
244                            text: None,
245                            inline_data: None,
246                            function_call: Some(GoogleFunctionCall {
247                                name: pending.name,
248                                args: Some(args),
249                            }),
250                            function_response: None,
251                        }],
252                        None,
253                        None,
254                        None,
255                    ))
256                } else {
257                    None
258                }
259            }
260            LanguageModelStreamPart::Finish {
261                finish_reason,
262                usage,
263                ..
264            } => {
265                let input_tokens = usage.input_tokens.total.unwrap_or(0);
266                let output_tokens = usage.output_tokens.total.unwrap_or(0);
267                Some(self.make_chunk(
268                    vec![GooglePart {
269                        text: Some(String::new()),
270                        inline_data: None,
271                        function_call: None,
272                        function_response: None,
273                    }],
274                    Some(map_finish_reason(finish_reason)),
275                    Some(GenerateContentUsageMetadata {
276                        prompt_token_count: Some(input_tokens),
277                        candidates_token_count: Some(output_tokens),
278                        total_token_count: Some(input_tokens + output_tokens),
279                        cached_content_token_count: None,
280                    }),
281                    Some(self.model_id.clone()),
282                ))
283            }
284            _ => None,
285        }
286    }
287
288    fn make_chunk(
289        &self,
290        parts: Vec<GooglePart>,
291        finish_reason: Option<String>,
292        usage_metadata: Option<GenerateContentUsageMetadata>,
293        model_version: Option<String>,
294    ) -> GenerateContentResponse {
295        GenerateContentResponse {
296            candidates: Some(vec![GenerateContentCandidate {
297                content: Some(GoogleContent {
298                    role: Some("model".to_owned()),
299                    parts: Some(parts),
300                }),
301                finish_reason,
302                index: Some(0),
303            }]),
304            usage_metadata,
305            model_version,
306        }
307    }
308}
309
310// ── Helpers ─────────────────────────────────────────────────────────────────
311
312fn convert_model_parts(parts: Option<Vec<GooglePart>>) -> Vec<LanguageModelAssistantContent> {
313    parts
314        .unwrap_or_default()
315        .into_iter()
316        .filter_map(|p| {
317            if let Some(fc) = p.function_call {
318                Some(LanguageModelAssistantContent::ToolCall {
319                    tool_call_id: format!("call-{}", generate_id()),
320                    tool_name: fc.name,
321                    input: fc.args.unwrap_or_default(),
322                    provider_executed: None,
323                    provider_options: None,
324                })
325            } else {
326                p.text.map(|text| LanguageModelAssistantContent::Text {
327                    text,
328                    provider_options: None,
329                })
330            }
331        })
332        .collect()
333}
334
335fn split_google_parts(
336    parts: Option<Vec<GooglePart>>,
337) -> (Vec<LanguageModelUserContent>, Vec<LanguageModelToolResult>) {
338    let mut user_parts = Vec::new();
339    let mut tool_results = Vec::new();
340    for part in parts.unwrap_or_default() {
341        if let Some(fr) = part.function_response {
342            let output_text = match fr.response {
343                serde_json::Value::String(s) => s,
344                other => serde_json::to_string(&other).unwrap_or_default(),
345            };
346            tool_results.push(LanguageModelToolResult::ToolResult {
347                tool_call_id: String::new(),
348                tool_name: fr.name,
349                output: LanguageModelToolResultOutput::Text {
350                    value: output_text,
351                    provider_options: None,
352                },
353                provider_options: None,
354            });
355        } else if let Some(text) = part.text {
356            user_parts.push(LanguageModelUserContent::Text {
357                text,
358                provider_options: None,
359            });
360        }
361    }
362    (user_parts, tool_results)
363}
364
365fn convert_tool_config(config: GoogleToolConfig) -> Option<LanguageModelToolChoice> {
366    let fcc = config.function_calling_config?;
367    let mode = fcc.mode?;
368    match mode.as_str() {
369        "AUTO" => Some(LanguageModelToolChoice::Auto),
370        "NONE" => Some(LanguageModelToolChoice::None),
371        "ANY" => {
372            if let Some(names) = fcc.allowed_function_names
373                && names.len() == 1
374            {
375                Some(LanguageModelToolChoice::Tool {
376                    tool_name: names.into_iter().next().unwrap_or_default(),
377                })
378            } else {
379                Some(LanguageModelToolChoice::Required)
380            }
381        }
382        _ => None,
383    }
384}
385
386fn extract_response_parts(content: &LanguageModelContent) -> Vec<GooglePart> {
387    match content {
388        LanguageModelContent::Text { text, .. } => vec![GooglePart {
389            text: Some(text.clone()),
390            inline_data: None,
391            function_call: None,
392            function_response: None,
393        }],
394        LanguageModelContent::ToolCall {
395            tool_name,
396            tool_input,
397            ..
398        } => {
399            let args: serde_json::Value = serde_json::from_str(tool_input).unwrap_or_default();
400            vec![GooglePart {
401                text: None,
402                inline_data: None,
403                function_call: Some(GoogleFunctionCall {
404                    name: tool_name.clone(),
405                    args: Some(args),
406                }),
407                function_response: None,
408            }]
409        }
410        _ => vec![],
411    }
412}
413
414fn map_finish_reason(reason: &LanguageModelFinishReason) -> String {
415    match reason {
416        LanguageModelFinishReason::Stop => "STOP".to_owned(),
417        LanguageModelFinishReason::Length => "MAX_TOKENS".to_owned(),
418        LanguageModelFinishReason::FunctionCall => "STOP".to_owned(),
419        LanguageModelFinishReason::ContentFilter => "SAFETY".to_owned(),
420        LanguageModelFinishReason::Error => "OTHER".to_owned(),
421        LanguageModelFinishReason::Other(other) => other.clone(),
422    }
423}