1use crate::gemini::count_tokens::types::GeminiContentRole;
2use crate::gemini::generate_content::request::GeminiGenerateContentRequest;
3use crate::gemini::generate_content::types::GeminiFunctionCallingMode;
4use crate::openai::count_tokens::types::{
5 HttpMethod, ResponseCodeInterpreterContainer, ResponseCodeInterpreterTool,
6 ResponseCodeInterpreterToolAuto, ResponseCodeInterpreterToolAutoType,
7 ResponseCodeInterpreterToolType, ResponseComputerEnvironment, ResponseComputerTool,
8 ResponseComputerToolType, ResponseFileSearchTool, ResponseFileSearchToolType,
9 ResponseFormatJsonObject, ResponseFormatJsonObjectType, ResponseFormatText,
10 ResponseFormatTextJsonSchemaConfig, ResponseFormatTextJsonSchemaConfigType,
11 ResponseFormatTextType, ResponseFunctionCallOutput, ResponseFunctionCallOutputContent,
12 ResponseFunctionCallOutputType, ResponseFunctionTool, ResponseFunctionToolCall,
13 ResponseFunctionToolCallType, ResponseInput, ResponseInputContent, ResponseInputFile,
14 ResponseInputFileType, ResponseInputImage, ResponseInputImageType, ResponseInputItem,
15 ResponseInputMessage, ResponseInputMessageContent, ResponseInputMessageRole,
16 ResponseInputMessageType, ResponseInputText, ResponseInputTextType, ResponseReasoning,
17 ResponseReasoningItem, ResponseReasoningItemType, ResponseSummaryTextContent,
18 ResponseSummaryTextContentType, ResponseTextConfig, ResponseTextFormatConfig, ResponseTool,
19 ResponseToolChoice, ResponseToolChoiceFunction, ResponseToolChoiceFunctionType,
20 ResponseToolChoiceOptions, ResponseWebSearchTool, ResponseWebSearchToolType,
21};
22use crate::openai::create_response::request::{
23 OpenAiCreateResponseRequest, PathParameters, QueryParameters, RequestBody, RequestHeaders,
24};
25use crate::transform::gemini::utils::{
26 gemini_content_to_text, openai_reasoning_effort_from_gemini_thinking, strip_models_prefix,
27};
28use crate::transform::utils::TransformError;
29
30impl TryFrom<GeminiGenerateContentRequest> for OpenAiCreateResponseRequest {
31 type Error = TransformError;
32
33 fn try_from(value: GeminiGenerateContentRequest) -> Result<Self, TransformError> {
34 let body = value.body;
35
36 let instructions = body
37 .system_instruction
38 .as_ref()
39 .map(gemini_content_to_text)
40 .filter(|text| !text.is_empty());
41
42 let mut input_items = Vec::new();
43 let mut reasoning_index = 0u64;
44 let mut tool_call_index = 0u64;
45 for content in body.contents {
46 let role = match content.role.unwrap_or(GeminiContentRole::User) {
47 GeminiContentRole::User => ResponseInputMessageRole::User,
48 GeminiContentRole::Model => ResponseInputMessageRole::Assistant,
49 };
50 let mut message_parts = Vec::new();
51
52 for part in content.parts {
53 if let Some(text) = part.text
54 && !text.is_empty()
55 {
56 if part.thought.unwrap_or(false) {
57 if !message_parts.is_empty() {
58 let content = if message_parts.len() == 1 {
59 match message_parts.into_iter().next() {
60 Some(ResponseInputContent::Text(text_part)) => {
61 ResponseInputMessageContent::Text(text_part.text)
62 }
63 Some(other) => ResponseInputMessageContent::List(vec![other]),
64 None => ResponseInputMessageContent::Text(String::new()),
65 }
66 } else {
67 ResponseInputMessageContent::List(message_parts)
68 };
69 input_items.push(ResponseInputItem::Message(ResponseInputMessage {
70 content,
71 role: role.clone(),
72 phase: None,
73 status: None,
74 type_: Some(ResponseInputMessageType::Message),
75 }));
76 message_parts = Vec::new();
77 }
78
79 let id = part.thought_signature.unwrap_or_else(|| {
80 let id = format!("reasoning_{reasoning_index}");
81 reasoning_index += 1;
82 id
83 });
84 input_items.push(ResponseInputItem::ReasoningItem(ResponseReasoningItem {
85 id: Some(id),
86 summary: vec![ResponseSummaryTextContent {
87 text,
88 type_: ResponseSummaryTextContentType::SummaryText,
89 }],
90 type_: ResponseReasoningItemType::Reasoning,
91 content: None,
92 encrypted_content: None,
93 status: None,
94 }));
95 } else {
96 message_parts.push(ResponseInputContent::Text(ResponseInputText {
97 text,
98 type_: ResponseInputTextType::InputText,
99 }));
100 }
101 }
102
103 if let Some(inline_data) = part.inline_data {
104 if inline_data.mime_type.starts_with("image/") {
105 message_parts.push(ResponseInputContent::Image(ResponseInputImage {
106 detail: None,
107 type_: ResponseInputImageType::InputImage,
108 file_id: None,
109 image_url: Some(format!(
110 "data:{};base64,{}",
111 inline_data.mime_type, inline_data.data
112 )),
113 }));
114 } else {
115 message_parts.push(ResponseInputContent::File(ResponseInputFile {
116 type_: ResponseInputFileType::InputFile,
117 detail: None,
118 file_data: Some(inline_data.data),
119 file_id: None,
120 file_url: None,
121 filename: Some(inline_data.mime_type),
122 }));
123 }
124 }
125
126 if let Some(file_data) = part.file_data {
127 if file_data.file_uri.is_empty() {
128 continue;
129 }
130 if file_data
131 .mime_type
132 .as_deref()
133 .unwrap_or_default()
134 .starts_with("image/")
135 {
136 message_parts.push(ResponseInputContent::Image(ResponseInputImage {
137 detail: None,
138 type_: ResponseInputImageType::InputImage,
139 file_id: None,
140 image_url: Some(file_data.file_uri),
141 }));
142 } else {
143 message_parts.push(ResponseInputContent::File(ResponseInputFile {
144 type_: ResponseInputFileType::InputFile,
145 detail: None,
146 file_data: None,
147 file_id: None,
148 file_url: Some(file_data.file_uri),
149 filename: None,
150 }));
151 }
152 }
153
154 if let Some(function_call) = part.function_call {
155 if !message_parts.is_empty() {
156 let content = if message_parts.len() == 1 {
157 match message_parts.into_iter().next() {
158 Some(ResponseInputContent::Text(text_part)) => {
159 ResponseInputMessageContent::Text(text_part.text)
160 }
161 Some(other) => ResponseInputMessageContent::List(vec![other]),
162 None => ResponseInputMessageContent::Text(String::new()),
163 }
164 } else {
165 ResponseInputMessageContent::List(message_parts)
166 };
167 input_items.push(ResponseInputItem::Message(ResponseInputMessage {
168 content,
169 role: role.clone(),
170 phase: None,
171 status: None,
172 type_: Some(ResponseInputMessageType::Message),
173 }));
174 message_parts = Vec::new();
175 }
176
177 let call_id = function_call.id.unwrap_or_else(|| {
178 let id = format!("call_{tool_call_index}");
179 tool_call_index += 1;
180 id
181 });
182 let arguments = function_call
183 .args
184 .and_then(|args| serde_json::to_string(&args).ok())
185 .unwrap_or_else(|| "{}".to_string());
186 input_items.push(ResponseInputItem::FunctionToolCall(
187 ResponseFunctionToolCall {
188 arguments,
189 call_id: call_id.clone(),
190 name: function_call.name,
191 type_: ResponseFunctionToolCallType::FunctionCall,
192 id: Some(call_id),
193 status: None,
194 },
195 ));
196 }
197
198 if let Some(function_response) = part.function_response {
199 if !message_parts.is_empty() {
200 let content = if message_parts.len() == 1 {
201 match message_parts.into_iter().next() {
202 Some(ResponseInputContent::Text(text_part)) => {
203 ResponseInputMessageContent::Text(text_part.text)
204 }
205 Some(other) => ResponseInputMessageContent::List(vec![other]),
206 None => ResponseInputMessageContent::Text(String::new()),
207 }
208 } else {
209 ResponseInputMessageContent::List(message_parts)
210 };
211 input_items.push(ResponseInputItem::Message(ResponseInputMessage {
212 content,
213 role: role.clone(),
214 phase: None,
215 status: None,
216 type_: Some(ResponseInputMessageType::Message),
217 }));
218 message_parts = Vec::new();
219 }
220
221 let call_id = function_response
222 .id
223 .unwrap_or_else(|| function_response.name.clone());
224 let output = match serde_json::to_string(&function_response.response) {
225 Ok(text) if !text.is_empty() => {
226 ResponseFunctionCallOutputContent::Text(text)
227 }
228 _ => ResponseFunctionCallOutputContent::Text("{}".to_string()),
229 };
230 input_items.push(ResponseInputItem::FunctionCallOutput(
231 ResponseFunctionCallOutput {
232 call_id,
233 output,
234 type_: ResponseFunctionCallOutputType::FunctionCallOutput,
235 id: None,
236 status: None,
237 },
238 ));
239 }
240 }
241
242 if !message_parts.is_empty() {
243 let content = if message_parts.len() == 1 {
244 match message_parts.into_iter().next() {
245 Some(ResponseInputContent::Text(text_part)) => {
246 ResponseInputMessageContent::Text(text_part.text)
247 }
248 Some(other) => ResponseInputMessageContent::List(vec![other]),
249 None => ResponseInputMessageContent::Text(String::new()),
250 }
251 } else {
252 ResponseInputMessageContent::List(message_parts)
253 };
254 input_items.push(ResponseInputItem::Message(ResponseInputMessage {
255 content,
256 role,
257 phase: None,
258 status: None,
259 type_: Some(ResponseInputMessageType::Message),
260 }));
261 }
262 }
263 let input = if input_items.is_empty() {
264 None
265 } else {
266 Some(ResponseInput::Items(input_items))
267 };
268
269 let tools = body.tools.and_then(|tools| {
270 let mut converted_tools = Vec::new();
271 for tool in tools {
272 if let Some(function_declarations) = tool.function_declarations {
273 for declaration in function_declarations {
274 let parameters = declaration
275 .parameters_json_schema
276 .and_then(|value| {
277 serde_json::from_value::<crate::openai::count_tokens::types::JsonObject>(value).ok()
278 })
279 .unwrap_or_default();
280 converted_tools.push(ResponseTool::Function(ResponseFunctionTool {
281 name: declaration.name,
282 parameters,
283 strict: None,
284 type_: crate::openai::count_tokens::types::ResponseFunctionToolType::Function,
285 defer_loading: None,
286 description: if declaration.description.is_empty() {
287 None
288 } else {
289 Some(declaration.description)
290 },
291 }));
292 }
293 }
294
295 if let Some(file_search) = tool.file_search {
296 converted_tools.push(ResponseTool::FileSearch(ResponseFileSearchTool {
297 type_: ResponseFileSearchToolType::FileSearch,
298 vector_store_ids: file_search.file_search_store_names,
299 filters: None,
300 max_num_results: file_search.top_k.and_then(|v| u32::try_from(v).ok()),
301 ranking_options: None,
302 }));
303 }
304
305 if tool.computer_use.is_some() {
306 converted_tools.push(ResponseTool::Computer(ResponseComputerTool {
307 display_height: Some(1024),
308 display_width: Some(1024),
309 environment: Some(ResponseComputerEnvironment::Browser),
310 type_: ResponseComputerToolType::ComputerUsePreview,
311 }));
312 }
313
314 if tool.google_search.is_some()
315 || tool.google_search_retrieval.is_some()
316 || tool.url_context.is_some()
317 || tool.google_maps.is_some()
318 {
319 converted_tools.push(ResponseTool::WebSearch(ResponseWebSearchTool {
320 type_: ResponseWebSearchToolType::WebSearch,
321 filters: None,
322 search_context_size: None,
323 user_location: None,
324 }));
325 }
326
327 if tool.code_execution.is_some() {
328 converted_tools.push(ResponseTool::CodeInterpreter(ResponseCodeInterpreterTool {
329 container: ResponseCodeInterpreterContainer::Auto(
330 ResponseCodeInterpreterToolAuto {
331 type_: ResponseCodeInterpreterToolAutoType::Auto,
332 file_ids: None,
333 memory_limit: None,
334 network_policy: None,
335 },
336 ),
337 type_: ResponseCodeInterpreterToolType::CodeInterpreter,
338 }));
339 }
340 }
341 if converted_tools.is_empty() {
342 None
343 } else {
344 Some(converted_tools)
345 }
346 });
347
348 let tool_choice = body
349 .tool_config
350 .and_then(|config| config.function_calling_config)
351 .map(|config| {
352 if let Some(name) = config
353 .allowed_function_names
354 .as_ref()
355 .and_then(|names| names.first())
356 .cloned()
357 {
358 return ResponseToolChoice::Function(ResponseToolChoiceFunction {
359 name,
360 type_: ResponseToolChoiceFunctionType::Function,
361 });
362 }
363 match config
364 .mode
365 .unwrap_or(GeminiFunctionCallingMode::ModeUnspecified)
366 {
367 GeminiFunctionCallingMode::Auto
368 | GeminiFunctionCallingMode::ModeUnspecified => {
369 ResponseToolChoice::Options(ResponseToolChoiceOptions::Auto)
370 }
371 GeminiFunctionCallingMode::Any | GeminiFunctionCallingMode::Validated => {
372 ResponseToolChoice::Options(ResponseToolChoiceOptions::Required)
373 }
374 GeminiFunctionCallingMode::None => {
375 ResponseToolChoice::Options(ResponseToolChoiceOptions::None)
376 }
377 }
378 });
379
380 let max_output_tokens = body
381 .generation_config
382 .as_ref()
383 .and_then(|config| config.max_output_tokens)
384 .map(u64::from);
385 let temperature = body
386 .generation_config
387 .as_ref()
388 .and_then(|config| config.temperature);
389 let top_p = body
390 .generation_config
391 .as_ref()
392 .and_then(|config| config.top_p);
393
394 let reasoning = body
395 .generation_config
396 .as_ref()
397 .and_then(|config| config.thinking_config.as_ref())
398 .and_then(openai_reasoning_effort_from_gemini_thinking)
399 .map(|effort| ResponseReasoning {
400 effort: Some(effort),
401 generate_summary: None,
402 summary: None,
403 });
404
405 let text = body.generation_config.as_ref().and_then(|config| {
406 let schema = config
407 .response_json_schema
408 .clone()
409 .or(config.response_json_schema_legacy.clone())
410 .or_else(|| {
411 config
412 .response_schema
413 .as_ref()
414 .and_then(|schema| serde_json::to_value(schema).ok())
415 })
416 .and_then(|value| {
417 serde_json::from_value::<crate::openai::count_tokens::types::JsonObject>(value)
418 .ok()
419 });
420
421 let format = match config.response_mime_type.as_deref() {
422 Some("application/json") => Some(if let Some(schema) = schema {
423 ResponseTextFormatConfig::JsonSchema(ResponseFormatTextJsonSchemaConfig {
424 name: "output".to_string(),
425 schema,
426 type_: ResponseFormatTextJsonSchemaConfigType::JsonSchema,
427 description: None,
428 strict: None,
429 })
430 } else {
431 ResponseTextFormatConfig::JsonObject(ResponseFormatJsonObject {
432 type_: ResponseFormatJsonObjectType::JsonObject,
433 })
434 }),
435 Some("text/plain") => Some(ResponseTextFormatConfig::Text(ResponseFormatText {
436 type_: ResponseFormatTextType::Text,
437 })),
438 _ => schema.map(|schema| {
439 ResponseTextFormatConfig::JsonSchema(ResponseFormatTextJsonSchemaConfig {
440 name: "output".to_string(),
441 schema,
442 type_: ResponseFormatTextJsonSchemaConfigType::JsonSchema,
443 description: None,
444 strict: None,
445 })
446 }),
447 };
448
449 format.map(|format| ResponseTextConfig {
450 format: Some(format),
451 verbosity: None,
452 })
453 });
454
455 Ok(Self {
456 method: HttpMethod::Post,
457 path: PathParameters::default(),
458 query: QueryParameters::default(),
459 headers: RequestHeaders::default(),
460 body: RequestBody {
461 input,
462 instructions,
463 max_output_tokens,
464 model: Some(strip_models_prefix(&value.path.model)),
465 reasoning,
466 stream: None,
467 temperature,
468 text,
469 tool_choice,
470 tools,
471 top_p,
472 ..RequestBody::default()
473 },
474 })
475 }
476}