1use std::collections::HashMap;
4
5use crate::models::{
6 language::{
7 call_options::LanguageModelCallOptions,
8 content::LanguageModelContent,
9 finish_reason::LanguageModelFinishReason,
10 generate_result::LanguageModelGenerateResult,
11 prompt::{
12 LanguageModelAssistantContent, LanguageModelMessage, LanguageModelToolResult,
13 LanguageModelToolResultOutput, LanguageModelUserContent,
14 },
15 stream_part::LanguageModelStreamPart,
16 tool::LanguageModelTool,
17 tool_choice::LanguageModelToolChoice,
18 },
19 shared::types::JsonSchema,
20};
21
22use super::types::{
23 GenerateContentCandidate, GenerateContentRequest, GenerateContentResponse,
24 GenerateContentUsageMetadata, GoogleContent, GoogleFunctionCall, GooglePart, GoogleToolConfig,
25};
26use crate::api::util::generate_id;
27
28pub fn extract_model_name(request: &GenerateContentRequest) -> &str {
30 &request.model
31}
32
33pub fn to_call_options(request: GenerateContentRequest) -> LanguageModelCallOptions {
35 let mut prompt: Vec<LanguageModelMessage> = Vec::new();
36
37 if let Some(system) = request.system_instruction
39 && let Some(parts) = system.parts
40 {
41 let system_text: String = parts
42 .into_iter()
43 .filter_map(|p| p.text)
44 .collect::<Vec<_>>()
45 .join("");
46 if !system_text.is_empty() {
47 prompt.push(LanguageModelMessage::System {
48 content: system_text,
49 provider_options: None,
50 });
51 }
52 }
53
54 for content in request.contents {
55 match content.role.as_deref() {
56 Some("model") => {
57 let assistant_content = convert_model_parts(content.parts);
58 prompt.push(LanguageModelMessage::Assistant {
59 content: assistant_content,
60 provider_options: None,
61 });
62 }
63 _ => {
64 let (user_parts, tool_results) = split_google_parts(content.parts);
65 if !tool_results.is_empty() {
66 prompt.push(LanguageModelMessage::Tool {
67 content: tool_results,
68 provider_options: None,
69 });
70 }
71 if !user_parts.is_empty() {
72 prompt.push(LanguageModelMessage::User {
73 content: user_parts,
74 provider_options: None,
75 });
76 }
77 }
78 }
79 }
80
81 let tools = request.tools.map(|tool_groups| {
82 tool_groups
83 .into_iter()
84 .flat_map(|t| t.function_declarations.unwrap_or_default())
85 .map(|fd| {
86 let schema_value = fd.parameters.unwrap_or(serde_json::json!({}));
87 let input_schema: JsonSchema =
88 serde_json::from_value(schema_value).unwrap_or_default();
89 LanguageModelTool::Function {
90 name: fd.name,
91 description: fd.description,
92 input_schema,
93 input_examples: vec![],
94 strict: None,
95 provider_options: None,
96 }
97 })
98 .collect()
99 });
100
101 let tool_choice = request.tool_config.and_then(convert_tool_config);
102
103 let (max_output_tokens, temperature, top_p, top_k, stop_sequences) =
104 if let Some(config) = request.generation_config {
105 (
106 config.max_output_tokens,
107 config.temperature,
108 config.top_p,
109 config.top_k,
110 config.stop_sequences,
111 )
112 } else {
113 (None, None, None, None, None)
114 };
115
116 LanguageModelCallOptions {
117 prompt,
118 stream: request.stream,
119 max_output_tokens,
120 temperature,
121 top_p,
122 top_k,
123 stop_sequences,
124 presence_penalty: None,
125 frequency_penalty: None,
126 response_format: None,
127 seed: None,
128 tools,
129 tool_choice,
130 include_raw_chunks: None,
131 abort_signal: None,
132 headers: None,
133 provider_options: None,
134 }
135}
136
137pub fn from_generate_result(
139 model_id: &str,
140 result: LanguageModelGenerateResult,
141) -> GenerateContentResponse {
142 let parts = extract_response_parts(&result.content);
143 let finish_reason = map_finish_reason(&result.finish_reason);
144 let input_tokens = result.usage.input_tokens.total.unwrap_or(0);
145 let output_tokens = result.usage.output_tokens.total.unwrap_or(0);
146
147 GenerateContentResponse {
148 candidates: Some(vec![GenerateContentCandidate {
149 content: Some(GoogleContent {
150 role: Some("model".to_owned()),
151 parts: Some(parts),
152 }),
153 finish_reason: Some(finish_reason),
154 index: Some(0),
155 }]),
156 usage_metadata: Some(GenerateContentUsageMetadata {
157 prompt_token_count: Some(input_tokens),
158 candidates_token_count: Some(output_tokens),
159 total_token_count: Some(input_tokens + output_tokens),
160 cached_content_token_count: None,
161 }),
162 model_version: Some(model_id.to_owned()),
163 }
164}
165
/// Incrementally converts provider stream parts into Google-style
/// `GenerateContentResponse` chunks, buffering streamed tool-call input
/// until it is complete.
pub struct StreamConverter {
    // Model identifier echoed in the terminal chunk's `model_version`.
    model_id: String,
    // Tool calls whose JSON argument text is still arriving, keyed by
    // tool-call id (populated on `ToolInputStart`, drained on `ToolInputEnd`).
    pending_calls: HashMap<String, PendingFunctionCall>,
}
173
/// A tool call announced by `ToolInputStart` whose argument JSON is still
/// being accumulated from `ToolInputDelta` parts.
struct PendingFunctionCall {
    // Function name reported at `ToolInputStart`.
    name: String,
    // Raw JSON argument text, appended delta by delta.
    args_buffer: String,
}
178
179impl StreamConverter {
180 pub fn new(model_id: String) -> Self {
181 Self {
182 model_id,
183 pending_calls: HashMap::new(),
184 }
185 }
186
187 pub fn convert(&mut self, part: &LanguageModelStreamPart) -> Option<GenerateContentResponse> {
189 match part {
190 LanguageModelStreamPart::TextDelta { delta, .. } => Some(self.make_chunk(
191 vec![GooglePart {
192 text: Some(delta.clone()),
193 inline_data: None,
194 function_call: None,
195 function_response: None,
196 }],
197 None,
198 None,
199 None,
200 )),
201 LanguageModelStreamPart::ToolCall {
202 tool_name,
203 tool_input,
204 ..
205 } => {
206 let args: serde_json::Value = serde_json::from_str(tool_input).unwrap_or_default();
207 Some(self.make_chunk(
208 vec![GooglePart {
209 text: None,
210 inline_data: None,
211 function_call: Some(GoogleFunctionCall {
212 name: tool_name.clone(),
213 args: Some(args),
214 }),
215 function_response: None,
216 }],
217 None,
218 None,
219 None,
220 ))
221 }
222 LanguageModelStreamPart::ToolInputStart { id, tool_name, .. } => {
223 self.pending_calls.insert(
224 id.clone(),
225 PendingFunctionCall {
226 name: tool_name.clone(),
227 args_buffer: String::new(),
228 },
229 );
230 None
231 }
232 LanguageModelStreamPart::ToolInputDelta { id, delta, .. } => {
233 if let Some(pending) = self.pending_calls.get_mut(id) {
234 pending.args_buffer.push_str(delta);
235 }
236 None
237 }
238 LanguageModelStreamPart::ToolInputEnd { id, .. } => {
239 if let Some(pending) = self.pending_calls.remove(id) {
240 let args: serde_json::Value =
241 serde_json::from_str(&pending.args_buffer).unwrap_or_default();
242 Some(self.make_chunk(
243 vec![GooglePart {
244 text: None,
245 inline_data: None,
246 function_call: Some(GoogleFunctionCall {
247 name: pending.name,
248 args: Some(args),
249 }),
250 function_response: None,
251 }],
252 None,
253 None,
254 None,
255 ))
256 } else {
257 None
258 }
259 }
260 LanguageModelStreamPart::Finish {
261 finish_reason,
262 usage,
263 ..
264 } => {
265 let input_tokens = usage.input_tokens.total.unwrap_or(0);
266 let output_tokens = usage.output_tokens.total.unwrap_or(0);
267 Some(self.make_chunk(
268 vec![GooglePart {
269 text: Some(String::new()),
270 inline_data: None,
271 function_call: None,
272 function_response: None,
273 }],
274 Some(map_finish_reason(finish_reason)),
275 Some(GenerateContentUsageMetadata {
276 prompt_token_count: Some(input_tokens),
277 candidates_token_count: Some(output_tokens),
278 total_token_count: Some(input_tokens + output_tokens),
279 cached_content_token_count: None,
280 }),
281 Some(self.model_id.clone()),
282 ))
283 }
284 _ => None,
285 }
286 }
287
288 fn make_chunk(
289 &self,
290 parts: Vec<GooglePart>,
291 finish_reason: Option<String>,
292 usage_metadata: Option<GenerateContentUsageMetadata>,
293 model_version: Option<String>,
294 ) -> GenerateContentResponse {
295 GenerateContentResponse {
296 candidates: Some(vec![GenerateContentCandidate {
297 content: Some(GoogleContent {
298 role: Some("model".to_owned()),
299 parts: Some(parts),
300 }),
301 finish_reason,
302 index: Some(0),
303 }]),
304 usage_metadata,
305 model_version,
306 }
307 }
308}
309
310fn convert_model_parts(parts: Option<Vec<GooglePart>>) -> Vec<LanguageModelAssistantContent> {
313 parts
314 .unwrap_or_default()
315 .into_iter()
316 .filter_map(|p| {
317 if let Some(fc) = p.function_call {
318 Some(LanguageModelAssistantContent::ToolCall {
319 tool_call_id: format!("call-{}", generate_id()),
320 tool_name: fc.name,
321 input: fc.args.unwrap_or_default(),
322 provider_executed: None,
323 provider_options: None,
324 })
325 } else {
326 p.text.map(|text| LanguageModelAssistantContent::Text {
327 text,
328 provider_options: None,
329 })
330 }
331 })
332 .collect()
333}
334
335fn split_google_parts(
336 parts: Option<Vec<GooglePart>>,
337) -> (Vec<LanguageModelUserContent>, Vec<LanguageModelToolResult>) {
338 let mut user_parts = Vec::new();
339 let mut tool_results = Vec::new();
340 for part in parts.unwrap_or_default() {
341 if let Some(fr) = part.function_response {
342 let output_text = match fr.response {
343 serde_json::Value::String(s) => s,
344 other => serde_json::to_string(&other).unwrap_or_default(),
345 };
346 tool_results.push(LanguageModelToolResult::ToolResult {
347 tool_call_id: String::new(),
348 tool_name: fr.name,
349 output: LanguageModelToolResultOutput::Text {
350 value: output_text,
351 provider_options: None,
352 },
353 provider_options: None,
354 });
355 } else if let Some(text) = part.text {
356 user_parts.push(LanguageModelUserContent::Text {
357 text,
358 provider_options: None,
359 });
360 }
361 }
362 (user_parts, tool_results)
363}
364
365fn convert_tool_config(config: GoogleToolConfig) -> Option<LanguageModelToolChoice> {
366 let fcc = config.function_calling_config?;
367 let mode = fcc.mode?;
368 match mode.as_str() {
369 "AUTO" => Some(LanguageModelToolChoice::Auto),
370 "NONE" => Some(LanguageModelToolChoice::None),
371 "ANY" => {
372 if let Some(names) = fcc.allowed_function_names
373 && names.len() == 1
374 {
375 Some(LanguageModelToolChoice::Tool {
376 tool_name: names.into_iter().next().unwrap_or_default(),
377 })
378 } else {
379 Some(LanguageModelToolChoice::Required)
380 }
381 }
382 _ => None,
383 }
384}
385
386fn extract_response_parts(content: &LanguageModelContent) -> Vec<GooglePart> {
387 match content {
388 LanguageModelContent::Text { text, .. } => vec![GooglePart {
389 text: Some(text.clone()),
390 inline_data: None,
391 function_call: None,
392 function_response: None,
393 }],
394 LanguageModelContent::ToolCall {
395 tool_name,
396 tool_input,
397 ..
398 } => {
399 let args: serde_json::Value = serde_json::from_str(tool_input).unwrap_or_default();
400 vec![GooglePart {
401 text: None,
402 inline_data: None,
403 function_call: Some(GoogleFunctionCall {
404 name: tool_name.clone(),
405 args: Some(args),
406 }),
407 function_response: None,
408 }]
409 }
410 _ => vec![],
411 }
412}
413
414fn map_finish_reason(reason: &LanguageModelFinishReason) -> String {
415 match reason {
416 LanguageModelFinishReason::Stop => "STOP".to_owned(),
417 LanguageModelFinishReason::Length => "MAX_TOKENS".to_owned(),
418 LanguageModelFinishReason::FunctionCall => "STOP".to_owned(),
419 LanguageModelFinishReason::ContentFilter => "SAFETY".to_owned(),
420 LanguageModelFinishReason::Error => "OTHER".to_owned(),
421 LanguageModelFinishReason::Other(other) => other.clone(),
422 }
423}