Skip to main content

gproxy_protocol/transform/gemini/generate_content/openai_response/
request.rs

1use crate::gemini::count_tokens::types::GeminiContentRole;
2use crate::gemini::generate_content::request::GeminiGenerateContentRequest;
3use crate::gemini::generate_content::types::GeminiFunctionCallingMode;
4use crate::openai::count_tokens::types::{
5    HttpMethod, ResponseCodeInterpreterContainer, ResponseCodeInterpreterTool,
6    ResponseCodeInterpreterToolAuto, ResponseCodeInterpreterToolAutoType,
7    ResponseCodeInterpreterToolType, ResponseComputerEnvironment, ResponseComputerTool,
8    ResponseComputerToolType, ResponseFileSearchTool, ResponseFileSearchToolType,
9    ResponseFormatJsonObject, ResponseFormatJsonObjectType, ResponseFormatText,
10    ResponseFormatTextJsonSchemaConfig, ResponseFormatTextJsonSchemaConfigType,
11    ResponseFormatTextType, ResponseFunctionCallOutput, ResponseFunctionCallOutputContent,
12    ResponseFunctionCallOutputType, ResponseFunctionTool, ResponseFunctionToolCall,
13    ResponseFunctionToolCallType, ResponseInput, ResponseInputContent, ResponseInputFile,
14    ResponseInputFileType, ResponseInputImage, ResponseInputImageType, ResponseInputItem,
15    ResponseInputMessage, ResponseInputMessageContent, ResponseInputMessageRole,
16    ResponseInputMessageType, ResponseInputText, ResponseInputTextType, ResponseReasoning,
17    ResponseReasoningItem, ResponseReasoningItemType, ResponseSummaryTextContent,
18    ResponseSummaryTextContentType, ResponseTextConfig, ResponseTextFormatConfig, ResponseTool,
19    ResponseToolChoice, ResponseToolChoiceFunction, ResponseToolChoiceFunctionType,
20    ResponseToolChoiceOptions, ResponseWebSearchTool, ResponseWebSearchToolType,
21};
22use crate::openai::create_response::request::{
23    OpenAiCreateResponseRequest, PathParameters, QueryParameters, RequestBody, RequestHeaders,
24};
25use crate::transform::gemini::utils::{
26    gemini_content_to_text, openai_reasoning_effort_from_gemini_thinking, strip_models_prefix,
27};
28use crate::transform::utils::TransformError;
29
30impl TryFrom<GeminiGenerateContentRequest> for OpenAiCreateResponseRequest {
31    type Error = TransformError;
32
33    fn try_from(value: GeminiGenerateContentRequest) -> Result<Self, TransformError> {
34        let body = value.body;
35
36        let instructions = body
37            .system_instruction
38            .as_ref()
39            .map(gemini_content_to_text)
40            .filter(|text| !text.is_empty());
41
42        let mut input_items = Vec::new();
43        let mut reasoning_index = 0u64;
44        let mut tool_call_index = 0u64;
45        for content in body.contents {
46            let role = match content.role.unwrap_or(GeminiContentRole::User) {
47                GeminiContentRole::User => ResponseInputMessageRole::User,
48                GeminiContentRole::Model => ResponseInputMessageRole::Assistant,
49            };
50            let mut message_parts = Vec::new();
51
52            for part in content.parts {
53                if let Some(text) = part.text
54                    && !text.is_empty()
55                {
56                    if part.thought.unwrap_or(false) {
57                        if !message_parts.is_empty() {
58                            let content = if message_parts.len() == 1 {
59                                match message_parts.into_iter().next() {
60                                    Some(ResponseInputContent::Text(text_part)) => {
61                                        ResponseInputMessageContent::Text(text_part.text)
62                                    }
63                                    Some(other) => ResponseInputMessageContent::List(vec![other]),
64                                    None => ResponseInputMessageContent::Text(String::new()),
65                                }
66                            } else {
67                                ResponseInputMessageContent::List(message_parts)
68                            };
69                            input_items.push(ResponseInputItem::Message(ResponseInputMessage {
70                                content,
71                                role: role.clone(),
72                                phase: None,
73                                status: None,
74                                type_: Some(ResponseInputMessageType::Message),
75                            }));
76                            message_parts = Vec::new();
77                        }
78
79                        let id = part.thought_signature.unwrap_or_else(|| {
80                            let id = format!("reasoning_{reasoning_index}");
81                            reasoning_index += 1;
82                            id
83                        });
84                        input_items.push(ResponseInputItem::ReasoningItem(ResponseReasoningItem {
85                            id: Some(id),
86                            summary: vec![ResponseSummaryTextContent {
87                                text,
88                                type_: ResponseSummaryTextContentType::SummaryText,
89                            }],
90                            type_: ResponseReasoningItemType::Reasoning,
91                            content: None,
92                            encrypted_content: None,
93                            status: None,
94                        }));
95                    } else {
96                        message_parts.push(ResponseInputContent::Text(ResponseInputText {
97                            text,
98                            type_: ResponseInputTextType::InputText,
99                        }));
100                    }
101                }
102
103                if let Some(inline_data) = part.inline_data {
104                    if inline_data.mime_type.starts_with("image/") {
105                        message_parts.push(ResponseInputContent::Image(ResponseInputImage {
106                            detail: None,
107                            type_: ResponseInputImageType::InputImage,
108                            file_id: None,
109                            image_url: Some(format!(
110                                "data:{};base64,{}",
111                                inline_data.mime_type, inline_data.data
112                            )),
113                        }));
114                    } else {
115                        message_parts.push(ResponseInputContent::File(ResponseInputFile {
116                            type_: ResponseInputFileType::InputFile,
117                            detail: None,
118                            file_data: Some(inline_data.data),
119                            file_id: None,
120                            file_url: None,
121                            filename: Some(inline_data.mime_type),
122                        }));
123                    }
124                }
125
126                if let Some(file_data) = part.file_data {
127                    if file_data.file_uri.is_empty() {
128                        continue;
129                    }
130                    if file_data
131                        .mime_type
132                        .as_deref()
133                        .unwrap_or_default()
134                        .starts_with("image/")
135                    {
136                        message_parts.push(ResponseInputContent::Image(ResponseInputImage {
137                            detail: None,
138                            type_: ResponseInputImageType::InputImage,
139                            file_id: None,
140                            image_url: Some(file_data.file_uri),
141                        }));
142                    } else {
143                        message_parts.push(ResponseInputContent::File(ResponseInputFile {
144                            type_: ResponseInputFileType::InputFile,
145                            detail: None,
146                            file_data: None,
147                            file_id: None,
148                            file_url: Some(file_data.file_uri),
149                            filename: None,
150                        }));
151                    }
152                }
153
154                if let Some(function_call) = part.function_call {
155                    if !message_parts.is_empty() {
156                        let content = if message_parts.len() == 1 {
157                            match message_parts.into_iter().next() {
158                                Some(ResponseInputContent::Text(text_part)) => {
159                                    ResponseInputMessageContent::Text(text_part.text)
160                                }
161                                Some(other) => ResponseInputMessageContent::List(vec![other]),
162                                None => ResponseInputMessageContent::Text(String::new()),
163                            }
164                        } else {
165                            ResponseInputMessageContent::List(message_parts)
166                        };
167                        input_items.push(ResponseInputItem::Message(ResponseInputMessage {
168                            content,
169                            role: role.clone(),
170                            phase: None,
171                            status: None,
172                            type_: Some(ResponseInputMessageType::Message),
173                        }));
174                        message_parts = Vec::new();
175                    }
176
177                    let call_id = function_call.id.unwrap_or_else(|| {
178                        let id = format!("call_{tool_call_index}");
179                        tool_call_index += 1;
180                        id
181                    });
182                    let arguments = function_call
183                        .args
184                        .and_then(|args| serde_json::to_string(&args).ok())
185                        .unwrap_or_else(|| "{}".to_string());
186                    input_items.push(ResponseInputItem::FunctionToolCall(
187                        ResponseFunctionToolCall {
188                            arguments,
189                            call_id: call_id.clone(),
190                            name: function_call.name,
191                            type_: ResponseFunctionToolCallType::FunctionCall,
192                            id: Some(call_id),
193                            status: None,
194                        },
195                    ));
196                }
197
198                if let Some(function_response) = part.function_response {
199                    if !message_parts.is_empty() {
200                        let content = if message_parts.len() == 1 {
201                            match message_parts.into_iter().next() {
202                                Some(ResponseInputContent::Text(text_part)) => {
203                                    ResponseInputMessageContent::Text(text_part.text)
204                                }
205                                Some(other) => ResponseInputMessageContent::List(vec![other]),
206                                None => ResponseInputMessageContent::Text(String::new()),
207                            }
208                        } else {
209                            ResponseInputMessageContent::List(message_parts)
210                        };
211                        input_items.push(ResponseInputItem::Message(ResponseInputMessage {
212                            content,
213                            role: role.clone(),
214                            phase: None,
215                            status: None,
216                            type_: Some(ResponseInputMessageType::Message),
217                        }));
218                        message_parts = Vec::new();
219                    }
220
221                    let call_id = function_response
222                        .id
223                        .unwrap_or_else(|| function_response.name.clone());
224                    let output = match serde_json::to_string(&function_response.response) {
225                        Ok(text) if !text.is_empty() => {
226                            ResponseFunctionCallOutputContent::Text(text)
227                        }
228                        _ => ResponseFunctionCallOutputContent::Text("{}".to_string()),
229                    };
230                    input_items.push(ResponseInputItem::FunctionCallOutput(
231                        ResponseFunctionCallOutput {
232                            call_id,
233                            output,
234                            type_: ResponseFunctionCallOutputType::FunctionCallOutput,
235                            id: None,
236                            status: None,
237                        },
238                    ));
239                }
240            }
241
242            if !message_parts.is_empty() {
243                let content = if message_parts.len() == 1 {
244                    match message_parts.into_iter().next() {
245                        Some(ResponseInputContent::Text(text_part)) => {
246                            ResponseInputMessageContent::Text(text_part.text)
247                        }
248                        Some(other) => ResponseInputMessageContent::List(vec![other]),
249                        None => ResponseInputMessageContent::Text(String::new()),
250                    }
251                } else {
252                    ResponseInputMessageContent::List(message_parts)
253                };
254                input_items.push(ResponseInputItem::Message(ResponseInputMessage {
255                    content,
256                    role,
257                    phase: None,
258                    status: None,
259                    type_: Some(ResponseInputMessageType::Message),
260                }));
261            }
262        }
263        let input = if input_items.is_empty() {
264            None
265        } else {
266            Some(ResponseInput::Items(input_items))
267        };
268
269        let tools = body.tools.and_then(|tools| {
270            let mut converted_tools = Vec::new();
271            for tool in tools {
272                if let Some(function_declarations) = tool.function_declarations {
273                    for declaration in function_declarations {
274                        let parameters = declaration
275                            .parameters_json_schema
276                            .and_then(|value| {
277                                serde_json::from_value::<crate::openai::count_tokens::types::JsonObject>(value).ok()
278                            })
279                            .unwrap_or_default();
280                        converted_tools.push(ResponseTool::Function(ResponseFunctionTool {
281                            name: declaration.name,
282                            parameters,
283                            strict: None,
284                            type_: crate::openai::count_tokens::types::ResponseFunctionToolType::Function,
285                            defer_loading: None,
286                            description: if declaration.description.is_empty() {
287                                None
288                            } else {
289                                Some(declaration.description)
290                            },
291                        }));
292                    }
293                }
294
295                if let Some(file_search) = tool.file_search {
296                    converted_tools.push(ResponseTool::FileSearch(ResponseFileSearchTool {
297                        type_: ResponseFileSearchToolType::FileSearch,
298                        vector_store_ids: file_search.file_search_store_names,
299                        filters: None,
300                        max_num_results: file_search.top_k.and_then(|v| u32::try_from(v).ok()),
301                        ranking_options: None,
302                    }));
303                }
304
305                if tool.computer_use.is_some() {
306                    converted_tools.push(ResponseTool::Computer(ResponseComputerTool {
307                        display_height: Some(1024),
308                        display_width: Some(1024),
309                        environment: Some(ResponseComputerEnvironment::Browser),
310                        type_: ResponseComputerToolType::ComputerUsePreview,
311                    }));
312                }
313
314                if tool.google_search.is_some()
315                    || tool.google_search_retrieval.is_some()
316                    || tool.url_context.is_some()
317                    || tool.google_maps.is_some()
318                {
319                    converted_tools.push(ResponseTool::WebSearch(ResponseWebSearchTool {
320                        type_: ResponseWebSearchToolType::WebSearch,
321                        filters: None,
322                        search_context_size: None,
323                        user_location: None,
324                    }));
325                }
326
327                if tool.code_execution.is_some() {
328                    converted_tools.push(ResponseTool::CodeInterpreter(ResponseCodeInterpreterTool {
329                        container: ResponseCodeInterpreterContainer::Auto(
330                            ResponseCodeInterpreterToolAuto {
331                                type_: ResponseCodeInterpreterToolAutoType::Auto,
332                                file_ids: None,
333                                memory_limit: None,
334                                network_policy: None,
335                            },
336                        ),
337                        type_: ResponseCodeInterpreterToolType::CodeInterpreter,
338                    }));
339                }
340            }
341            if converted_tools.is_empty() {
342                None
343            } else {
344                Some(converted_tools)
345            }
346        });
347
348        let tool_choice = body
349            .tool_config
350            .and_then(|config| config.function_calling_config)
351            .map(|config| {
352                if let Some(name) = config
353                    .allowed_function_names
354                    .as_ref()
355                    .and_then(|names| names.first())
356                    .cloned()
357                {
358                    return ResponseToolChoice::Function(ResponseToolChoiceFunction {
359                        name,
360                        type_: ResponseToolChoiceFunctionType::Function,
361                    });
362                }
363                match config
364                    .mode
365                    .unwrap_or(GeminiFunctionCallingMode::ModeUnspecified)
366                {
367                    GeminiFunctionCallingMode::Auto
368                    | GeminiFunctionCallingMode::ModeUnspecified => {
369                        ResponseToolChoice::Options(ResponseToolChoiceOptions::Auto)
370                    }
371                    GeminiFunctionCallingMode::Any | GeminiFunctionCallingMode::Validated => {
372                        ResponseToolChoice::Options(ResponseToolChoiceOptions::Required)
373                    }
374                    GeminiFunctionCallingMode::None => {
375                        ResponseToolChoice::Options(ResponseToolChoiceOptions::None)
376                    }
377                }
378            });
379
380        let max_output_tokens = body
381            .generation_config
382            .as_ref()
383            .and_then(|config| config.max_output_tokens)
384            .map(u64::from);
385        let temperature = body
386            .generation_config
387            .as_ref()
388            .and_then(|config| config.temperature);
389        let top_p = body
390            .generation_config
391            .as_ref()
392            .and_then(|config| config.top_p);
393
394        let reasoning = body
395            .generation_config
396            .as_ref()
397            .and_then(|config| config.thinking_config.as_ref())
398            .and_then(openai_reasoning_effort_from_gemini_thinking)
399            .map(|effort| ResponseReasoning {
400                effort: Some(effort),
401                generate_summary: None,
402                summary: None,
403            });
404
405        let text = body.generation_config.as_ref().and_then(|config| {
406            let schema = config
407                .response_json_schema
408                .clone()
409                .or(config.response_json_schema_legacy.clone())
410                .or_else(|| {
411                    config
412                        .response_schema
413                        .as_ref()
414                        .and_then(|schema| serde_json::to_value(schema).ok())
415                })
416                .and_then(|value| {
417                    serde_json::from_value::<crate::openai::count_tokens::types::JsonObject>(value)
418                        .ok()
419                });
420
421            let format = match config.response_mime_type.as_deref() {
422                Some("application/json") => Some(if let Some(schema) = schema {
423                    ResponseTextFormatConfig::JsonSchema(ResponseFormatTextJsonSchemaConfig {
424                        name: "output".to_string(),
425                        schema,
426                        type_: ResponseFormatTextJsonSchemaConfigType::JsonSchema,
427                        description: None,
428                        strict: None,
429                    })
430                } else {
431                    ResponseTextFormatConfig::JsonObject(ResponseFormatJsonObject {
432                        type_: ResponseFormatJsonObjectType::JsonObject,
433                    })
434                }),
435                Some("text/plain") => Some(ResponseTextFormatConfig::Text(ResponseFormatText {
436                    type_: ResponseFormatTextType::Text,
437                })),
438                _ => schema.map(|schema| {
439                    ResponseTextFormatConfig::JsonSchema(ResponseFormatTextJsonSchemaConfig {
440                        name: "output".to_string(),
441                        schema,
442                        type_: ResponseFormatTextJsonSchemaConfigType::JsonSchema,
443                        description: None,
444                        strict: None,
445                    })
446                }),
447            };
448
449            format.map(|format| ResponseTextConfig {
450                format: Some(format),
451                verbosity: None,
452            })
453        });
454
455        Ok(Self {
456            method: HttpMethod::Post,
457            path: PathParameters::default(),
458            query: QueryParameters::default(),
459            headers: RequestHeaders::default(),
460            body: RequestBody {
461                input,
462                instructions,
463                max_output_tokens,
464                model: Some(strip_models_prefix(&value.path.model)),
465                reasoning,
466                stream: None,
467                temperature,
468                text,
469                tool_choice,
470                tools,
471                top_p,
472                ..RequestBody::default()
473            },
474        })
475    }
476}