gproxy_protocol/transform/openai/generate_content/openai_response/gemini/
response.rs1use std::collections::BTreeMap;
2
3use super::utils::{
4 gemini_citation_annotations, gemini_grounding_to_web_search_item, gemini_logprobs,
5};
6use crate::gemini::generate_content::response::GeminiGenerateContentResponse;
7use crate::gemini::generate_content::types as gt;
8use crate::openai::count_tokens::types as ot;
9use crate::openai::create_response::response::{OpenAiCreateResponseResponse, ResponseBody};
10use crate::openai::create_response::types as rt;
11use crate::openai::types::OpenAiResponseHeaders;
12use crate::transform::openai::generate_content::openai_chat_completions::gemini::utils::{
13 gemini_function_response_to_text, json_object_to_string, prompt_feedback_refusal_text,
14};
15use crate::transform::openai::model_list::gemini::utils::{
16 openai_error_response_from_gemini, strip_models_prefix,
17};
18use crate::transform::utils::TransformError;
19
20impl TryFrom<GeminiGenerateContentResponse> for OpenAiCreateResponseResponse {
21 type Error = TransformError;
22
23 fn try_from(value: GeminiGenerateContentResponse) -> Result<Self, TransformError> {
24 Ok(match value {
25 GeminiGenerateContentResponse::Success {
26 stats_code,
27 headers,
28 body,
29 } => {
30 let response_id = body.response_id.unwrap_or_default();
31 let response_model = body
32 .model_version
33 .as_deref()
34 .map(strip_models_prefix)
35 .unwrap_or_default();
36 let usage = body.usage_metadata.map(|usage| {
37 let input_tokens = usage
38 .prompt_token_count
39 .unwrap_or(0)
40 .saturating_add(usage.tool_use_prompt_token_count.unwrap_or(0));
41 let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
42 let output_tokens = usage
43 .candidates_token_count
44 .unwrap_or(0)
45 .saturating_add(usage.thoughts_token_count.unwrap_or(0));
46 let total_tokens = usage
47 .total_token_count
48 .unwrap_or_else(|| input_tokens.saturating_add(output_tokens));
49
50 rt::ResponseUsage {
51 input_tokens,
52 input_tokens_details: rt::ResponseInputTokensDetails { cached_tokens },
53 output_tokens,
54 output_tokens_details: rt::ResponseOutputTokensDetails {
55 reasoning_tokens: usage.thoughts_token_count.unwrap_or(0),
56 },
57 total_tokens,
58 }
59 });
60 let prompt_feedback = body.prompt_feedback;
61
62 let mut output = Vec::new();
63 let mut output_text_parts = Vec::new();
64 let mut tool_call_count = 0usize;
65 let mut first_finish_reason = None;
66
67 for (candidate_pos, candidate) in
68 body.candidates.unwrap_or_default().into_iter().enumerate()
69 {
70 let candidate_index = candidate.index.unwrap_or(candidate_pos as u32);
71
72 if first_finish_reason.is_none() {
73 first_finish_reason = candidate.finish_reason.clone();
74 }
75
76 if let Some(web_search_item) = gemini_grounding_to_web_search_item(
77 candidate_index,
78 candidate.grounding_metadata.as_ref(),
79 ) {
80 tool_call_count += 1;
81 output.push(web_search_item);
82 }
83
84 let annotations =
85 gemini_citation_annotations(candidate.citation_metadata.as_ref());
86 let logprobs = gemini_logprobs(candidate.logprobs_result.as_ref());
87 let mut logprobs_attached = false;
88 let mut message_content = Vec::new();
89
90 if let Some(content) = candidate.content {
91 for (part_index, part) in content.parts.into_iter().enumerate() {
92 if part.thought.unwrap_or(false) {
93 if let Some(thinking) = part.text
94 && !thinking.is_empty()
95 {
96 let reasoning_id =
97 part.thought_signature.unwrap_or_else(|| {
98 format!(
99 "candidate_{candidate_index}_reasoning_{part_index}"
100 )
101 });
102 output.push(rt::ResponseOutputItem::ReasoningItem(
103 ot::ResponseReasoningItem {
104 id: Some(reasoning_id),
105 summary: vec![ot::ResponseSummaryTextContent {
106 text: thinking.clone(),
107 type_: ot::ResponseSummaryTextContentType::SummaryText,
108 }],
109 type_: ot::ResponseReasoningItemType::Reasoning,
110 content: Some(vec![ot::ResponseReasoningTextContent {
111 text: thinking,
112 type_: ot::ResponseReasoningTextContentType::ReasoningText,
113 }]),
114 encrypted_content: None,
115 status: Some(ot::ResponseItemStatus::Completed),
116 },
117 ));
118 }
119 continue;
120 }
121
122 if let Some(function_call) = part.function_call {
123 tool_call_count += 1;
124 let call_id = function_call.id.unwrap_or_else(|| {
125 format!("candidate_{candidate_index}_tool_{part_index}")
126 });
127 output.push(rt::ResponseOutputItem::FunctionToolCall(
128 ot::ResponseFunctionToolCall {
129 arguments: function_call
130 .args
131 .as_ref()
132 .map(json_object_to_string)
133 .unwrap_or_else(|| "{}".to_string()),
134 call_id: call_id.clone(),
135 name: function_call.name,
136 type_: ot::ResponseFunctionToolCallType::FunctionCall,
137 id: Some(call_id),
138 status: Some(ot::ResponseItemStatus::Completed),
139 },
140 ));
141 }
142
143 if let Some(function_response) = part.function_response {
144 let call_id = function_response
145 .id
146 .clone()
147 .unwrap_or_else(|| function_response.name.clone());
148 let output_text =
149 gemini_function_response_to_text(function_response);
150 output.push(rt::ResponseOutputItem::FunctionCallOutput(
151 ot::ResponseFunctionCallOutput {
152 call_id,
153 output: ot::ResponseFunctionCallOutputContent::Text(
154 output_text,
155 ),
156 type_:
157 ot::ResponseFunctionCallOutputType::FunctionCallOutput,
158 id: None,
159 status: Some(ot::ResponseItemStatus::Completed),
160 },
161 ));
162 }
163
164 if let Some(executable_code) = part.executable_code {
165 tool_call_count += 1;
166 output.push(rt::ResponseOutputItem::CodeInterpreterToolCall(
167 ot::ResponseCodeInterpreterToolCall {
168 id: format!("code_interpreter_{candidate_index}_{part_index}"),
169 code: executable_code.code,
170 container_id: "gemini".to_string(),
171 outputs: None,
172 status: ot::ResponseCodeInterpreterToolCallStatus::Completed,
173 type_: ot::ResponseCodeInterpreterToolCallType::CodeInterpreterCall,
174 },
175 ));
176 }
177
178 if let Some(code_execution_result) = part.code_execution_result
179 && let Some(result_text) = code_execution_result.output
180 && !result_text.is_empty()
181 {
182 output.push(rt::ResponseOutputItem::FunctionCallOutput(
183 ot::ResponseFunctionCallOutput {
184 call_id: format!(
185 "code_execution_{candidate_index}_{part_index}"
186 ),
187 output: ot::ResponseFunctionCallOutputContent::Text(
188 result_text,
189 ),
190 type_:
191 ot::ResponseFunctionCallOutputType::FunctionCallOutput,
192 id: None,
193 status: Some(ot::ResponseItemStatus::Completed),
194 },
195 ));
196 }
197
198 if let Some(text) = part.text
199 && !text.is_empty()
200 {
201 output_text_parts.push(text.clone());
202 message_content.push(ot::ResponseOutputContent::Text(
203 ot::ResponseOutputText {
204 annotations: annotations.clone(),
205 logprobs: if !logprobs_attached {
206 logprobs_attached = true;
207 logprobs.clone()
208 } else {
209 None
210 },
211 text,
212 type_: ot::ResponseOutputTextType::OutputText,
213 },
214 ));
215 continue;
216 }
217
218 if let Some(inline_data) = part.inline_data {
219 let text = format!(
220 "data:{};base64,{}",
221 inline_data.mime_type, inline_data.data
222 );
223 output_text_parts.push(text.clone());
224 message_content.push(ot::ResponseOutputContent::Text(
225 ot::ResponseOutputText {
226 annotations: Vec::new(),
227 logprobs: None,
228 text,
229 type_: ot::ResponseOutputTextType::OutputText,
230 },
231 ));
232 } else if let Some(file_data) = part.file_data {
233 output_text_parts.push(file_data.file_uri.clone());
234 message_content.push(ot::ResponseOutputContent::Text(
235 ot::ResponseOutputText {
236 annotations: Vec::new(),
237 logprobs: None,
238 text: file_data.file_uri,
239 type_: ot::ResponseOutputTextType::OutputText,
240 },
241 ));
242 }
243 }
244 }
245
246 if message_content.is_empty()
247 && let Some(finish_message) = candidate.finish_message
248 && !finish_message.is_empty()
249 {
250 output_text_parts.push(finish_message.clone());
251 message_content.push(ot::ResponseOutputContent::Text(
252 ot::ResponseOutputText {
253 annotations: Vec::new(),
254 logprobs: None,
255 text: finish_message,
256 type_: ot::ResponseOutputTextType::OutputText,
257 },
258 ));
259 }
260
261 if !message_content.is_empty() {
262 output.push(rt::ResponseOutputItem::Message(ot::ResponseOutputMessage {
263 id: format!("{}_message_{}", response_id, candidate_index),
264 content: message_content,
265 role: ot::ResponseOutputMessageRole::Assistant,
266 phase: Some(ot::ResponseMessagePhase::FinalAnswer),
267 status: Some(ot::ResponseItemStatus::Completed),
268 type_: Some(ot::ResponseOutputMessageType::Message),
269 }));
270 }
271 }
272
273 if output.is_empty()
274 && let Some(refusal) = prompt_feedback_refusal_text(prompt_feedback.as_ref())
275 {
276 output.push(rt::ResponseOutputItem::Message(ot::ResponseOutputMessage {
277 id: format!("{}_message_0", response_id),
278 content: vec![ot::ResponseOutputContent::Refusal(
279 ot::ResponseOutputRefusal {
280 refusal,
281 type_: ot::ResponseOutputRefusalType::Refusal,
282 },
283 )],
284 role: ot::ResponseOutputMessageRole::Assistant,
285 phase: Some(ot::ResponseMessagePhase::FinalAnswer),
286 status: Some(ot::ResponseItemStatus::Completed),
287 type_: Some(ot::ResponseOutputMessageType::Message),
288 }));
289 }
290
291 let incomplete_reason = match first_finish_reason.as_ref() {
292 Some(gt::GeminiFinishReason::MaxTokens) => {
293 Some(rt::ResponseIncompleteReason::MaxOutputTokens)
294 }
295 Some(
296 gt::GeminiFinishReason::Safety
297 | gt::GeminiFinishReason::Recitation
298 | gt::GeminiFinishReason::Blocklist
299 | gt::GeminiFinishReason::ProhibitedContent
300 | gt::GeminiFinishReason::Spii
301 | gt::GeminiFinishReason::ImageSafety
302 | gt::GeminiFinishReason::ImageProhibitedContent
303 | gt::GeminiFinishReason::ImageRecitation,
304 ) => Some(rt::ResponseIncompleteReason::ContentFilter),
305 _ => None,
306 }
307 .or_else(|| {
308 match prompt_feedback
309 .as_ref()
310 .and_then(|feedback| feedback.block_reason.as_ref())
311 {
312 Some(gt::GeminiBlockReason::Safety)
313 | Some(gt::GeminiBlockReason::Blocklist)
314 | Some(gt::GeminiBlockReason::ProhibitedContent)
315 | Some(gt::GeminiBlockReason::ImageSafety) => {
316 Some(rt::ResponseIncompleteReason::ContentFilter)
317 }
318 _ => None,
319 }
320 });
321 let is_incomplete = incomplete_reason.is_some();
322
323 OpenAiCreateResponseResponse::Success {
324 stats_code,
325 headers: OpenAiResponseHeaders {
326 extra: headers.extra,
327 },
328 body: ResponseBody {
329 id: response_id,
330 created_at: 0,
331 error: None,
332 incomplete_details: incomplete_reason.map(|reason| {
333 rt::ResponseIncompleteDetails {
334 reason: Some(reason),
335 }
336 }),
337 instructions: Some(ot::ResponseInput::Text(String::new())),
338 metadata: BTreeMap::new(),
339 model: response_model,
340 object: rt::ResponseObject::Response,
341 output,
342 parallel_tool_calls: tool_call_count > 1,
343 temperature: 1.0,
344 tool_choice: if tool_call_count > 0 {
345 ot::ResponseToolChoice::Options(ot::ResponseToolChoiceOptions::Required)
346 } else {
347 ot::ResponseToolChoice::Options(ot::ResponseToolChoiceOptions::Auto)
348 },
349 tools: Vec::new(),
350 top_p: 1.0,
351 background: None,
352 completed_at: None,
353 conversation: None,
354 max_output_tokens: None,
355 max_tool_calls: None,
356 output_text: if output_text_parts.is_empty() {
357 None
358 } else {
359 Some(output_text_parts.join("\n"))
360 },
361 previous_response_id: None,
362 prompt: None,
363 prompt_cache_key: None,
364 prompt_cache_retention: None,
365 reasoning: None,
366 safety_identifier: None,
367 service_tier: None,
368 status: Some(if is_incomplete {
369 rt::ResponseStatus::Incomplete
370 } else {
371 rt::ResponseStatus::Completed
372 }),
373 text: None,
374 top_logprobs: None,
375 truncation: None,
376 usage,
377 user: None,
378 },
379 }
380 }
381 GeminiGenerateContentResponse::Error {
382 stats_code,
383 headers,
384 body,
385 } => OpenAiCreateResponseResponse::Error {
386 stats_code,
387 headers: OpenAiResponseHeaders {
388 extra: headers.extra,
389 },
390 body: openai_error_response_from_gemini(stats_code, body),
391 },
392 })
393 }
394}