Skip to main content

gproxy_protocol/transform/gemini/count_tokens/openai/
request.rs

1use crate::gemini::count_tokens::request::GeminiCountTokensRequest;
2use crate::gemini::count_tokens::types::{
3    GeminiContentRole, GeminiFunctionCallingMode, GeminiLanguage, GeminiOutcome,
4};
5use crate::openai::count_tokens::request::{
6    OpenAiCountTokensRequest, PathParameters, QueryParameters, RequestBody, RequestHeaders,
7};
8use crate::openai::count_tokens::types::{
9    HttpMethod, ResponseCodeInterpreterContainer, ResponseCodeInterpreterTool,
10    ResponseCodeInterpreterToolAuto, ResponseCodeInterpreterToolAutoType,
11    ResponseCodeInterpreterToolType, ResponseComputerEnvironment, ResponseComputerTool,
12    ResponseComputerToolType, ResponseFileSearchTool, ResponseFileSearchToolType,
13    ResponseFormatJsonObject, ResponseFormatJsonObjectType, ResponseFormatText,
14    ResponseFormatTextJsonSchemaConfig, ResponseFormatTextJsonSchemaConfigType,
15    ResponseFormatTextType, ResponseFunctionCallOutput, ResponseFunctionCallOutputContent,
16    ResponseFunctionCallOutputType, ResponseFunctionTool, ResponseFunctionToolCall,
17    ResponseFunctionToolCallType, ResponseInput, ResponseInputContent, ResponseInputFile,
18    ResponseInputFileType, ResponseInputImage, ResponseInputImageType, ResponseInputItem,
19    ResponseInputMessage, ResponseInputMessageContent, ResponseInputMessageRole,
20    ResponseInputMessageType, ResponseInputText, ResponseInputTextType, ResponseReasoning,
21    ResponseReasoningItem, ResponseReasoningItemType, ResponseSummaryTextContent,
22    ResponseSummaryTextContentType, ResponseTextConfig, ResponseTextFormatConfig, ResponseTool,
23    ResponseToolChoice, ResponseToolChoiceFunction, ResponseToolChoiceFunctionType,
24    ResponseToolChoiceOptions, ResponseWebSearchTool, ResponseWebSearchToolType,
25};
26use crate::transform::gemini::utils::{
27    openai_reasoning_effort_from_gemini_thinking, strip_models_prefix,
28};
29use crate::transform::utils::TransformError;
30
31impl TryFrom<GeminiCountTokensRequest> for OpenAiCountTokensRequest {
32    type Error = TransformError;
33
34    fn try_from(value: GeminiCountTokensRequest) -> Result<Self, TransformError> {
35        let (raw_model, contents, tools, tool_config, system_instruction, generation_config) =
36            if let Some(generate_content_request) = value.body.generate_content_request {
37                (
38                    generate_content_request.model,
39                    generate_content_request.contents,
40                    generate_content_request.tools,
41                    generate_content_request.tool_config,
42                    generate_content_request.system_instruction,
43                    generate_content_request.generation_config,
44                )
45            } else {
46                (
47                    value.path.model,
48                    value.body.contents.unwrap_or_default(),
49                    None,
50                    None,
51                    None,
52                    None,
53                )
54            };
55
56        let model = {
57            let stripped = strip_models_prefix(&raw_model);
58            if stripped.is_empty() {
59                None
60            } else {
61                Some(stripped)
62            }
63        };
64
65        let instructions = system_instruction.and_then(|content| {
66            let lines = content
67                .parts
68                .into_iter()
69                .filter_map(|part| {
70                    if let Some(text) = part.text {
71                        let text = text.trim().to_string();
72                        if !text.is_empty() {
73                            return Some(text);
74                        }
75                    }
76                    if let Some(file_data) = part.file_data
77                        && !file_data.file_uri.is_empty()
78                    {
79                        return Some(file_data.file_uri);
80                    }
81                    None
82                })
83                .collect::<Vec<_>>();
84            if lines.is_empty() {
85                None
86            } else {
87                Some(lines.join("\n"))
88            }
89        });
90
91        let mut input_items = Vec::new();
92        let mut reasoning_index: u64 = 0;
93        let mut tool_call_index: u64 = 0;
94
95        for content in contents {
96            let role = match content.role.unwrap_or(GeminiContentRole::User) {
97                GeminiContentRole::User => ResponseInputMessageRole::User,
98                GeminiContentRole::Model => ResponseInputMessageRole::Assistant,
99            };
100            let mut message_content = Vec::new();
101
102            for part in content.parts {
103                if let Some(text) = part.text
104                    && !text.is_empty()
105                {
106                    if part.thought.unwrap_or(false) {
107                        if !message_content.is_empty() {
108                            let content = if message_content.len() == 1 {
109                                match message_content.into_iter().next() {
110                                    Some(ResponseInputContent::Text(text_part)) => {
111                                        ResponseInputMessageContent::Text(text_part.text)
112                                    }
113                                    Some(other) => ResponseInputMessageContent::List(vec![other]),
114                                    None => ResponseInputMessageContent::Text(String::new()),
115                                }
116                            } else {
117                                ResponseInputMessageContent::List(message_content)
118                            };
119                            input_items.push(ResponseInputItem::Message(ResponseInputMessage {
120                                content,
121                                role: role.clone(),
122                                phase: None,
123                                status: None,
124                                type_: Some(ResponseInputMessageType::Message),
125                            }));
126                            message_content = Vec::new();
127                        }
128
129                        let reasoning_id = part.thought_signature.unwrap_or_else(|| {
130                            let id = format!("reasoning_{reasoning_index}");
131                            reasoning_index += 1;
132                            id
133                        });
134                        input_items.push(ResponseInputItem::ReasoningItem(ResponseReasoningItem {
135                            id: Some(reasoning_id),
136                            summary: vec![ResponseSummaryTextContent {
137                                text,
138                                type_: ResponseSummaryTextContentType::SummaryText,
139                            }],
140                            type_: ResponseReasoningItemType::Reasoning,
141                            content: None,
142                            encrypted_content: None,
143                            status: None,
144                        }));
145                    } else {
146                        message_content.push(ResponseInputContent::Text(ResponseInputText {
147                            text,
148                            type_: ResponseInputTextType::InputText,
149                        }));
150                    }
151                }
152
153                if let Some(inline_data) = part.inline_data {
154                    if inline_data.mime_type.starts_with("image/") {
155                        message_content.push(ResponseInputContent::Image(ResponseInputImage {
156                            detail: None,
157                            type_: ResponseInputImageType::InputImage,
158                            file_id: None,
159                            image_url: Some(format!(
160                                "data:{};base64,{}",
161                                inline_data.mime_type, inline_data.data
162                            )),
163                        }));
164                    } else {
165                        message_content.push(ResponseInputContent::File(ResponseInputFile {
166                            type_: ResponseInputFileType::InputFile,
167                            detail: None,
168                            file_data: Some(inline_data.data),
169                            file_id: None,
170                            file_url: None,
171                            filename: Some(inline_data.mime_type),
172                        }));
173                    }
174                }
175
176                if let Some(file_data) = part.file_data {
177                    if file_data.file_uri.is_empty() {
178                        continue;
179                    }
180                    if file_data
181                        .mime_type
182                        .as_deref()
183                        .unwrap_or_default()
184                        .starts_with("image/")
185                    {
186                        message_content.push(ResponseInputContent::Image(ResponseInputImage {
187                            detail: None,
188                            type_: ResponseInputImageType::InputImage,
189                            file_id: None,
190                            image_url: Some(file_data.file_uri),
191                        }));
192                    } else {
193                        message_content.push(ResponseInputContent::File(ResponseInputFile {
194                            type_: ResponseInputFileType::InputFile,
195                            detail: None,
196                            file_data: None,
197                            file_id: None,
198                            file_url: Some(file_data.file_uri),
199                            filename: None,
200                        }));
201                    }
202                }
203
204                if let Some(code) = part.executable_code {
205                    let language = match code.language {
206                        GeminiLanguage::Python => "python",
207                        GeminiLanguage::LanguageUnspecified => "text",
208                    };
209                    message_content.push(ResponseInputContent::Text(ResponseInputText {
210                        text: format!("```{language}\n{}\n```", code.code),
211                        type_: ResponseInputTextType::InputText,
212                    }));
213                }
214
215                if let Some(result) = part.code_execution_result {
216                    let outcome = match result.outcome {
217                        GeminiOutcome::OutcomeUnspecified => "outcome_unspecified",
218                        GeminiOutcome::OutcomeOk => "outcome_ok",
219                        GeminiOutcome::OutcomeFailed => "outcome_failed",
220                        GeminiOutcome::OutcomeDeadlineExceeded => "outcome_deadline_exceeded",
221                    };
222                    let text = match result.output {
223                        Some(output) if !output.is_empty() => {
224                            format!("code_execution_result:{outcome}\n{output}")
225                        }
226                        _ => format!("code_execution_result:{outcome}"),
227                    };
228                    message_content.push(ResponseInputContent::Text(ResponseInputText {
229                        text,
230                        type_: ResponseInputTextType::InputText,
231                    }));
232                }
233
234                if let Some(function_call) = part.function_call {
235                    if !message_content.is_empty() {
236                        let content = if message_content.len() == 1 {
237                            match message_content.into_iter().next() {
238                                Some(ResponseInputContent::Text(text_part)) => {
239                                    ResponseInputMessageContent::Text(text_part.text)
240                                }
241                                Some(other) => ResponseInputMessageContent::List(vec![other]),
242                                None => ResponseInputMessageContent::Text(String::new()),
243                            }
244                        } else {
245                            ResponseInputMessageContent::List(message_content)
246                        };
247                        input_items.push(ResponseInputItem::Message(ResponseInputMessage {
248                            content,
249                            role: role.clone(),
250                            phase: None,
251                            status: None,
252                            type_: Some(ResponseInputMessageType::Message),
253                        }));
254                        message_content = Vec::new();
255                    }
256
257                    let call_id = function_call.id.unwrap_or_else(|| {
258                        let id = format!("call_{tool_call_index}");
259                        tool_call_index += 1;
260                        id
261                    });
262                    let arguments = function_call
263                        .args
264                        .and_then(|args| serde_json::to_string(&args).ok())
265                        .unwrap_or_else(|| "{}".to_string());
266                    input_items.push(ResponseInputItem::FunctionToolCall(
267                        ResponseFunctionToolCall {
268                            arguments,
269                            call_id: call_id.clone(),
270                            name: function_call.name,
271                            type_: ResponseFunctionToolCallType::FunctionCall,
272                            id: Some(call_id),
273                            status: None,
274                        },
275                    ));
276                }
277
278                if let Some(function_response) = part.function_response {
279                    if !message_content.is_empty() {
280                        let content = if message_content.len() == 1 {
281                            match message_content.into_iter().next() {
282                                Some(ResponseInputContent::Text(text_part)) => {
283                                    ResponseInputMessageContent::Text(text_part.text)
284                                }
285                                Some(other) => ResponseInputMessageContent::List(vec![other]),
286                                None => ResponseInputMessageContent::Text(String::new()),
287                            }
288                        } else {
289                            ResponseInputMessageContent::List(message_content)
290                        };
291                        input_items.push(ResponseInputItem::Message(ResponseInputMessage {
292                            content,
293                            role: role.clone(),
294                            phase: None,
295                            status: None,
296                            type_: Some(ResponseInputMessageType::Message),
297                        }));
298                        message_content = Vec::new();
299                    }
300
301                    let call_id = function_response
302                        .id
303                        .unwrap_or_else(|| function_response.name.clone());
304                    let output = match serde_json::to_string(&function_response.response) {
305                        Ok(text) if !text.is_empty() => {
306                            ResponseFunctionCallOutputContent::Text(text)
307                        }
308                        _ => ResponseFunctionCallOutputContent::Text("{}".to_string()),
309                    };
310
311                    input_items.push(ResponseInputItem::FunctionCallOutput(
312                        ResponseFunctionCallOutput {
313                            call_id,
314                            output,
315                            type_: ResponseFunctionCallOutputType::FunctionCallOutput,
316                            id: None,
317                            status: None,
318                        },
319                    ));
320                }
321            }
322
323            if !message_content.is_empty() {
324                let content = if message_content.len() == 1 {
325                    match message_content.into_iter().next() {
326                        Some(ResponseInputContent::Text(text_part)) => {
327                            ResponseInputMessageContent::Text(text_part.text)
328                        }
329                        Some(other) => ResponseInputMessageContent::List(vec![other]),
330                        None => ResponseInputMessageContent::Text(String::new()),
331                    }
332                } else {
333                    ResponseInputMessageContent::List(message_content)
334                };
335
336                input_items.push(ResponseInputItem::Message(ResponseInputMessage {
337                    content,
338                    role,
339                    phase: None,
340                    status: None,
341                    type_: Some(ResponseInputMessageType::Message),
342                }));
343            }
344        }
345
346        let input = if input_items.is_empty() {
347            None
348        } else {
349            Some(ResponseInput::Items(input_items))
350        };
351
352        let tools = tools.and_then(|tool_defs| {
353            let mut mapped = Vec::new();
354            for tool in tool_defs {
355                if let Some(function_declarations) = tool.function_declarations {
356                    for declaration in function_declarations {
357                        let parameters = declaration
358                            .parameters_json_schema
359                            .and_then(|value| match value {
360                                serde_json::Value::Object(map) => Some(map.into_iter().collect()),
361                                _ => None,
362                            })
363                            .unwrap_or_default();
364                        mapped.push(ResponseTool::Function(ResponseFunctionTool {
365                            name: declaration.name,
366                            parameters,
367                            strict: None,
368                            type_: crate::openai::count_tokens::types::ResponseFunctionToolType::Function,
369                            defer_loading: None,
370                            description: if declaration.description.is_empty() {
371                                None
372                            } else {
373                                Some(declaration.description)
374                            },
375                        }));
376                    }
377                }
378
379                if let Some(file_search) = tool.file_search {
380                    mapped.push(ResponseTool::FileSearch(ResponseFileSearchTool {
381                        type_: ResponseFileSearchToolType::FileSearch,
382                        vector_store_ids: file_search.file_search_store_names,
383                        filters: None,
384                        max_num_results: file_search.top_k.and_then(|v| u32::try_from(v).ok()),
385                        ranking_options: None,
386                    }));
387                }
388
389                if tool.computer_use.is_some() {
390                    mapped.push(ResponseTool::Computer(ResponseComputerTool {
391                        display_height: Some(1024),
392                        display_width: Some(1024),
393                        environment: Some(ResponseComputerEnvironment::Browser),
394                        type_: ResponseComputerToolType::ComputerUsePreview,
395                    }));
396                }
397
398                if tool.google_search.is_some()
399                    || tool.google_search_retrieval.is_some()
400                    || tool.url_context.is_some()
401                    || tool.google_maps.is_some()
402                {
403                    mapped.push(ResponseTool::WebSearch(ResponseWebSearchTool {
404                        type_: ResponseWebSearchToolType::WebSearch,
405                        filters: None,
406                        search_context_size: None,
407                        user_location: None,
408                    }));
409                }
410
411                if tool.code_execution.is_some() {
412                    mapped.push(ResponseTool::CodeInterpreter(ResponseCodeInterpreterTool {
413                        container: ResponseCodeInterpreterContainer::Auto(
414                            ResponseCodeInterpreterToolAuto {
415                                type_: ResponseCodeInterpreterToolAutoType::Auto,
416                                file_ids: None,
417                                memory_limit: None,
418                                network_policy: None,
419                            },
420                        ),
421                        type_: ResponseCodeInterpreterToolType::CodeInterpreter,
422                    }));
423                }
424            }
425
426            if mapped.is_empty() {
427                None
428            } else {
429                Some(mapped)
430            }
431        });
432
433        let tool_choice = tool_config
434            .and_then(|config| config.function_calling_config)
435            .map(|config| {
436                if let Some(name) = config
437                    .allowed_function_names
438                    .as_ref()
439                    .and_then(|names| names.first())
440                    .cloned()
441                {
442                    return ResponseToolChoice::Function(ResponseToolChoiceFunction {
443                        name,
444                        type_: ResponseToolChoiceFunctionType::Function,
445                    });
446                }
447
448                match config
449                    .mode
450                    .unwrap_or(GeminiFunctionCallingMode::ModeUnspecified)
451                {
452                    GeminiFunctionCallingMode::Auto
453                    | GeminiFunctionCallingMode::ModeUnspecified => {
454                        ResponseToolChoice::Options(ResponseToolChoiceOptions::Auto)
455                    }
456                    GeminiFunctionCallingMode::Any | GeminiFunctionCallingMode::Validated => {
457                        ResponseToolChoice::Options(ResponseToolChoiceOptions::Required)
458                    }
459                    GeminiFunctionCallingMode::None => {
460                        ResponseToolChoice::Options(ResponseToolChoiceOptions::None)
461                    }
462                }
463            });
464
465        let (reasoning, text) = if let Some(config) = generation_config {
466            let reasoning = config
467                .thinking_config
468                .as_ref()
469                .and_then(openai_reasoning_effort_from_gemini_thinking)
470                .map(|effort| ResponseReasoning {
471                    effort: Some(effort),
472                    generate_summary: None,
473                    summary: None,
474                });
475
476            let schema_value = config
477                .response_json_schema
478                .or(config.response_json_schema_legacy)
479                .or_else(|| {
480                    config
481                        .response_schema
482                        .and_then(|schema| serde_json::to_value(schema).ok())
483                });
484            let schema = schema_value.and_then(|value| match value {
485                serde_json::Value::Object(map) => Some(map.into_iter().collect()),
486                _ => None,
487            });
488            let mime = config.response_mime_type.as_deref().unwrap_or_default();
489            let format = match mime {
490                "application/json" => Some(if let Some(schema) = schema {
491                    ResponseTextFormatConfig::JsonSchema(ResponseFormatTextJsonSchemaConfig {
492                        name: "output".to_string(),
493                        schema,
494                        type_: ResponseFormatTextJsonSchemaConfigType::JsonSchema,
495                        description: None,
496                        strict: None,
497                    })
498                } else {
499                    ResponseTextFormatConfig::JsonObject(ResponseFormatJsonObject {
500                        type_: ResponseFormatJsonObjectType::JsonObject,
501                    })
502                }),
503                "text/plain" => Some(ResponseTextFormatConfig::Text(ResponseFormatText {
504                    type_: ResponseFormatTextType::Text,
505                })),
506                _ => schema.map(|schema| {
507                    ResponseTextFormatConfig::JsonSchema(ResponseFormatTextJsonSchemaConfig {
508                        name: "output".to_string(),
509                        schema,
510                        type_: ResponseFormatTextJsonSchemaConfigType::JsonSchema,
511                        description: None,
512                        strict: None,
513                    })
514                }),
515            };
516            let text = format.map(|format| ResponseTextConfig {
517                format: Some(format),
518                verbosity: None,
519            });
520            (reasoning, text)
521        } else {
522            (None, None)
523        };
524
525        Ok(Self {
526            method: HttpMethod::Post,
527            path: PathParameters::default(),
528            query: QueryParameters::default(),
529            headers: RequestHeaders::default(),
530            body: RequestBody {
531                conversation: None,
532                input,
533                instructions,
534                model,
535                parallel_tool_calls: None,
536                previous_response_id: None,
537                reasoning,
538                text,
539                tool_choice,
540                tools,
541                truncation: None,
542            },
543        })
544    }
545}