Skip to main content

gproxy_protocol/transform/openai/generate_content/openai_response/gemini/
response.rs

1use std::collections::BTreeMap;
2
3use super::utils::{
4    gemini_citation_annotations, gemini_grounding_to_web_search_item, gemini_logprobs,
5};
6use crate::gemini::generate_content::response::GeminiGenerateContentResponse;
7use crate::gemini::generate_content::types as gt;
8use crate::openai::count_tokens::types as ot;
9use crate::openai::create_response::response::{OpenAiCreateResponseResponse, ResponseBody};
10use crate::openai::create_response::types as rt;
11use crate::openai::types::OpenAiResponseHeaders;
12use crate::transform::openai::generate_content::openai_chat_completions::gemini::utils::{
13    gemini_function_response_to_text, json_object_to_string, prompt_feedback_refusal_text,
14};
15use crate::transform::openai::model_list::gemini::utils::{
16    openai_error_response_from_gemini, strip_models_prefix,
17};
18use crate::transform::utils::TransformError;
19
20impl TryFrom<GeminiGenerateContentResponse> for OpenAiCreateResponseResponse {
21    type Error = TransformError;
22
23    fn try_from(value: GeminiGenerateContentResponse) -> Result<Self, TransformError> {
24        Ok(match value {
25            GeminiGenerateContentResponse::Success {
26                stats_code,
27                headers,
28                body,
29            } => {
30                let response_id = body.response_id.unwrap_or_default();
31                let response_model = body
32                    .model_version
33                    .as_deref()
34                    .map(strip_models_prefix)
35                    .unwrap_or_default();
36                let usage = body.usage_metadata.map(|usage| {
37                    let input_tokens = usage
38                        .prompt_token_count
39                        .unwrap_or(0)
40                        .saturating_add(usage.tool_use_prompt_token_count.unwrap_or(0));
41                    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
42                    let output_tokens = usage
43                        .candidates_token_count
44                        .unwrap_or(0)
45                        .saturating_add(usage.thoughts_token_count.unwrap_or(0));
46                    let total_tokens = usage
47                        .total_token_count
48                        .unwrap_or_else(|| input_tokens.saturating_add(output_tokens));
49
50                    rt::ResponseUsage {
51                        input_tokens,
52                        input_tokens_details: rt::ResponseInputTokensDetails { cached_tokens },
53                        output_tokens,
54                        output_tokens_details: rt::ResponseOutputTokensDetails {
55                            reasoning_tokens: usage.thoughts_token_count.unwrap_or(0),
56                        },
57                        total_tokens,
58                    }
59                });
60                let prompt_feedback = body.prompt_feedback;
61
62                let mut output = Vec::new();
63                let mut output_text_parts = Vec::new();
64                let mut tool_call_count = 0usize;
65                let mut first_finish_reason = None;
66
67                for (candidate_pos, candidate) in
68                    body.candidates.unwrap_or_default().into_iter().enumerate()
69                {
70                    let candidate_index = candidate.index.unwrap_or(candidate_pos as u32);
71
72                    if first_finish_reason.is_none() {
73                        first_finish_reason = candidate.finish_reason.clone();
74                    }
75
76                    if let Some(web_search_item) = gemini_grounding_to_web_search_item(
77                        candidate_index,
78                        candidate.grounding_metadata.as_ref(),
79                    ) {
80                        tool_call_count += 1;
81                        output.push(web_search_item);
82                    }
83
84                    let annotations =
85                        gemini_citation_annotations(candidate.citation_metadata.as_ref());
86                    let logprobs = gemini_logprobs(candidate.logprobs_result.as_ref());
87                    let mut logprobs_attached = false;
88                    let mut message_content = Vec::new();
89
90                    if let Some(content) = candidate.content {
91                        for (part_index, part) in content.parts.into_iter().enumerate() {
92                            if part.thought.unwrap_or(false) {
93                                if let Some(thinking) = part.text
94                                    && !thinking.is_empty()
95                                {
96                                    let reasoning_id =
97                                        part.thought_signature.unwrap_or_else(|| {
98                                            format!(
99                                                "candidate_{candidate_index}_reasoning_{part_index}"
100                                            )
101                                        });
102                                    output.push(rt::ResponseOutputItem::ReasoningItem(
103                                            ot::ResponseReasoningItem {
104                                                id: Some(reasoning_id),
105                                                summary: vec![ot::ResponseSummaryTextContent {
106                                                    text: thinking.clone(),
107                                                    type_: ot::ResponseSummaryTextContentType::SummaryText,
108                                                }],
109                                                type_: ot::ResponseReasoningItemType::Reasoning,
110                                                content: Some(vec![ot::ResponseReasoningTextContent {
111                                                    text: thinking,
112                                                    type_: ot::ResponseReasoningTextContentType::ReasoningText,
113                                                }]),
114                                                encrypted_content: None,
115                                                status: Some(ot::ResponseItemStatus::Completed),
116                                            },
117                                        ));
118                                }
119                                continue;
120                            }
121
122                            if let Some(function_call) = part.function_call {
123                                tool_call_count += 1;
124                                let call_id = function_call.id.unwrap_or_else(|| {
125                                    format!("candidate_{candidate_index}_tool_{part_index}")
126                                });
127                                output.push(rt::ResponseOutputItem::FunctionToolCall(
128                                    ot::ResponseFunctionToolCall {
129                                        arguments: function_call
130                                            .args
131                                            .as_ref()
132                                            .map(json_object_to_string)
133                                            .unwrap_or_else(|| "{}".to_string()),
134                                        call_id: call_id.clone(),
135                                        name: function_call.name,
136                                        type_: ot::ResponseFunctionToolCallType::FunctionCall,
137                                        id: Some(call_id),
138                                        status: Some(ot::ResponseItemStatus::Completed),
139                                    },
140                                ));
141                            }
142
143                            if let Some(function_response) = part.function_response {
144                                let call_id = function_response
145                                    .id
146                                    .clone()
147                                    .unwrap_or_else(|| function_response.name.clone());
148                                let output_text =
149                                    gemini_function_response_to_text(function_response);
150                                output.push(rt::ResponseOutputItem::FunctionCallOutput(
151                                    ot::ResponseFunctionCallOutput {
152                                        call_id,
153                                        output: ot::ResponseFunctionCallOutputContent::Text(
154                                            output_text,
155                                        ),
156                                        type_:
157                                            ot::ResponseFunctionCallOutputType::FunctionCallOutput,
158                                        id: None,
159                                        status: Some(ot::ResponseItemStatus::Completed),
160                                    },
161                                ));
162                            }
163
164                            if let Some(executable_code) = part.executable_code {
165                                tool_call_count += 1;
166                                output.push(rt::ResponseOutputItem::CodeInterpreterToolCall(
167                                    ot::ResponseCodeInterpreterToolCall {
168                                        id: format!("code_interpreter_{candidate_index}_{part_index}"),
169                                        code: executable_code.code,
170                                        container_id: "gemini".to_string(),
171                                        outputs: None,
172                                        status: ot::ResponseCodeInterpreterToolCallStatus::Completed,
173                                        type_: ot::ResponseCodeInterpreterToolCallType::CodeInterpreterCall,
174                                    },
175                                ));
176                            }
177
178                            if let Some(code_execution_result) = part.code_execution_result
179                                && let Some(result_text) = code_execution_result.output
180                                && !result_text.is_empty()
181                            {
182                                output.push(rt::ResponseOutputItem::FunctionCallOutput(
183                                    ot::ResponseFunctionCallOutput {
184                                        call_id: format!(
185                                            "code_execution_{candidate_index}_{part_index}"
186                                        ),
187                                        output: ot::ResponseFunctionCallOutputContent::Text(
188                                            result_text,
189                                        ),
190                                        type_:
191                                            ot::ResponseFunctionCallOutputType::FunctionCallOutput,
192                                        id: None,
193                                        status: Some(ot::ResponseItemStatus::Completed),
194                                    },
195                                ));
196                            }
197
198                            if let Some(text) = part.text
199                                && !text.is_empty()
200                            {
201                                output_text_parts.push(text.clone());
202                                message_content.push(ot::ResponseOutputContent::Text(
203                                    ot::ResponseOutputText {
204                                        annotations: annotations.clone(),
205                                        logprobs: if !logprobs_attached {
206                                            logprobs_attached = true;
207                                            logprobs.clone()
208                                        } else {
209                                            None
210                                        },
211                                        text,
212                                        type_: ot::ResponseOutputTextType::OutputText,
213                                    },
214                                ));
215                                continue;
216                            }
217
218                            if let Some(inline_data) = part.inline_data {
219                                let text = format!(
220                                    "data:{};base64,{}",
221                                    inline_data.mime_type, inline_data.data
222                                );
223                                output_text_parts.push(text.clone());
224                                message_content.push(ot::ResponseOutputContent::Text(
225                                    ot::ResponseOutputText {
226                                        annotations: Vec::new(),
227                                        logprobs: None,
228                                        text,
229                                        type_: ot::ResponseOutputTextType::OutputText,
230                                    },
231                                ));
232                            } else if let Some(file_data) = part.file_data {
233                                output_text_parts.push(file_data.file_uri.clone());
234                                message_content.push(ot::ResponseOutputContent::Text(
235                                    ot::ResponseOutputText {
236                                        annotations: Vec::new(),
237                                        logprobs: None,
238                                        text: file_data.file_uri,
239                                        type_: ot::ResponseOutputTextType::OutputText,
240                                    },
241                                ));
242                            }
243                        }
244                    }
245
246                    if message_content.is_empty()
247                        && let Some(finish_message) = candidate.finish_message
248                        && !finish_message.is_empty()
249                    {
250                        output_text_parts.push(finish_message.clone());
251                        message_content.push(ot::ResponseOutputContent::Text(
252                            ot::ResponseOutputText {
253                                annotations: Vec::new(),
254                                logprobs: None,
255                                text: finish_message,
256                                type_: ot::ResponseOutputTextType::OutputText,
257                            },
258                        ));
259                    }
260
261                    if !message_content.is_empty() {
262                        output.push(rt::ResponseOutputItem::Message(ot::ResponseOutputMessage {
263                            id: format!("{}_message_{}", response_id, candidate_index),
264                            content: message_content,
265                            role: ot::ResponseOutputMessageRole::Assistant,
266                            phase: Some(ot::ResponseMessagePhase::FinalAnswer),
267                            status: Some(ot::ResponseItemStatus::Completed),
268                            type_: Some(ot::ResponseOutputMessageType::Message),
269                        }));
270                    }
271                }
272
273                if output.is_empty()
274                    && let Some(refusal) = prompt_feedback_refusal_text(prompt_feedback.as_ref())
275                {
276                    output.push(rt::ResponseOutputItem::Message(ot::ResponseOutputMessage {
277                        id: format!("{}_message_0", response_id),
278                        content: vec![ot::ResponseOutputContent::Refusal(
279                            ot::ResponseOutputRefusal {
280                                refusal,
281                                type_: ot::ResponseOutputRefusalType::Refusal,
282                            },
283                        )],
284                        role: ot::ResponseOutputMessageRole::Assistant,
285                        phase: Some(ot::ResponseMessagePhase::FinalAnswer),
286                        status: Some(ot::ResponseItemStatus::Completed),
287                        type_: Some(ot::ResponseOutputMessageType::Message),
288                    }));
289                }
290
291                let incomplete_reason = match first_finish_reason.as_ref() {
292                    Some(gt::GeminiFinishReason::MaxTokens) => {
293                        Some(rt::ResponseIncompleteReason::MaxOutputTokens)
294                    }
295                    Some(
296                        gt::GeminiFinishReason::Safety
297                        | gt::GeminiFinishReason::Recitation
298                        | gt::GeminiFinishReason::Blocklist
299                        | gt::GeminiFinishReason::ProhibitedContent
300                        | gt::GeminiFinishReason::Spii
301                        | gt::GeminiFinishReason::ImageSafety
302                        | gt::GeminiFinishReason::ImageProhibitedContent
303                        | gt::GeminiFinishReason::ImageRecitation,
304                    ) => Some(rt::ResponseIncompleteReason::ContentFilter),
305                    _ => None,
306                }
307                .or_else(|| {
308                    match prompt_feedback
309                        .as_ref()
310                        .and_then(|feedback| feedback.block_reason.as_ref())
311                    {
312                        Some(gt::GeminiBlockReason::Safety)
313                        | Some(gt::GeminiBlockReason::Blocklist)
314                        | Some(gt::GeminiBlockReason::ProhibitedContent)
315                        | Some(gt::GeminiBlockReason::ImageSafety) => {
316                            Some(rt::ResponseIncompleteReason::ContentFilter)
317                        }
318                        _ => None,
319                    }
320                });
321                let is_incomplete = incomplete_reason.is_some();
322
323                OpenAiCreateResponseResponse::Success {
324                    stats_code,
325                    headers: OpenAiResponseHeaders {
326                        extra: headers.extra,
327                    },
328                    body: ResponseBody {
329                        id: response_id,
330                        created_at: 0,
331                        error: None,
332                        incomplete_details: incomplete_reason.map(|reason| {
333                            rt::ResponseIncompleteDetails {
334                                reason: Some(reason),
335                            }
336                        }),
337                        instructions: Some(ot::ResponseInput::Text(String::new())),
338                        metadata: BTreeMap::new(),
339                        model: response_model,
340                        object: rt::ResponseObject::Response,
341                        output,
342                        parallel_tool_calls: tool_call_count > 1,
343                        temperature: 1.0,
344                        tool_choice: if tool_call_count > 0 {
345                            ot::ResponseToolChoice::Options(ot::ResponseToolChoiceOptions::Required)
346                        } else {
347                            ot::ResponseToolChoice::Options(ot::ResponseToolChoiceOptions::Auto)
348                        },
349                        tools: Vec::new(),
350                        top_p: 1.0,
351                        background: None,
352                        completed_at: None,
353                        conversation: None,
354                        max_output_tokens: None,
355                        max_tool_calls: None,
356                        output_text: if output_text_parts.is_empty() {
357                            None
358                        } else {
359                            Some(output_text_parts.join("\n"))
360                        },
361                        previous_response_id: None,
362                        prompt: None,
363                        prompt_cache_key: None,
364                        prompt_cache_retention: None,
365                        reasoning: None,
366                        safety_identifier: None,
367                        service_tier: None,
368                        status: Some(if is_incomplete {
369                            rt::ResponseStatus::Incomplete
370                        } else {
371                            rt::ResponseStatus::Completed
372                        }),
373                        text: None,
374                        top_logprobs: None,
375                        truncation: None,
376                        usage,
377                        user: None,
378                    },
379                }
380            }
381            GeminiGenerateContentResponse::Error {
382                stats_code,
383                headers,
384                body,
385            } => OpenAiCreateResponseResponse::Error {
386                stats_code,
387                headers: OpenAiResponseHeaders {
388                    extra: headers.extra,
389                },
390                body: openai_error_response_from_gemini(stats_code, body),
391            },
392        })
393    }
394}