Skip to main content

gproxy_protocol/transform/gemini/generate_content/openai_chat_completions/
response.rs

1use crate::gemini::count_tokens::types::{GeminiContentRole, GeminiFunctionCall, GeminiPart};
2use crate::gemini::generate_content::response::{GeminiGenerateContentResponse, ResponseBody};
3use crate::gemini::generate_content::types::{
4    GeminiBlockReason, GeminiCandidate, GeminiContent, GeminiFinishReason, GeminiPromptFeedback,
5    GeminiUsageMetadata,
6};
7use crate::gemini::types::GeminiResponseHeaders;
8use crate::openai::create_chat_completions::response::OpenAiChatCompletionsResponse;
9use crate::openai::create_chat_completions::types::{
10    ChatCompletionFinishReason, ChatCompletionMessageToolCall,
11};
12use crate::transform::gemini::generate_content::utils::{
13    gemini_error_response_from_openai, parse_json_object_or_empty,
14};
15use crate::transform::utils::TransformError;
16
17impl TryFrom<OpenAiChatCompletionsResponse> for GeminiGenerateContentResponse {
18    type Error = TransformError;
19
20    fn try_from(value: OpenAiChatCompletionsResponse) -> Result<Self, TransformError> {
21        Ok(match value {
22            OpenAiChatCompletionsResponse::Success {
23                stats_code,
24                headers,
25                body,
26            } => {
27                let mut parts = Vec::new();
28                let mut has_refusal = false;
29                let mut finish_reason = Some(GeminiFinishReason::Stop);
30                let choice = body.choices.into_iter().next();
31                if let Some(choice) = choice {
32                    finish_reason = Some(match choice.finish_reason {
33                        ChatCompletionFinishReason::Stop => GeminiFinishReason::Stop,
34                        ChatCompletionFinishReason::Length => GeminiFinishReason::MaxTokens,
35                        ChatCompletionFinishReason::ToolCalls
36                        | ChatCompletionFinishReason::FunctionCall => {
37                            GeminiFinishReason::UnexpectedToolCall
38                        }
39                        ChatCompletionFinishReason::ContentFilter => GeminiFinishReason::Safety,
40                    });
41
42                    if let Some(text) = choice.message.content
43                        && !text.is_empty()
44                    {
45                        parts.push(GeminiPart {
46                            text: Some(text),
47                            ..GeminiPart::default()
48                        });
49                    }
50
51                    if let Some(refusal) = choice.message.refusal {
52                        has_refusal = true;
53                        if !refusal.is_empty() {
54                            parts.push(GeminiPart {
55                                text: Some(refusal),
56                                ..GeminiPart::default()
57                            });
58                        }
59                    }
60
61                    if let Some(function_call) = choice.message.function_call {
62                        parts.push(GeminiPart {
63                            function_call: Some(GeminiFunctionCall {
64                                id: Some("function_call".to_string()),
65                                name: function_call.name,
66                                args: Some(parse_json_object_or_empty(&function_call.arguments)),
67                            }),
68                            ..GeminiPart::default()
69                        });
70                    }
71
72                    if let Some(tool_calls) = choice.message.tool_calls {
73                        for call in tool_calls {
74                            match call {
75                                ChatCompletionMessageToolCall::Function(call) => {
76                                    parts.push(GeminiPart {
77                                        function_call: Some(GeminiFunctionCall {
78                                            id: Some(call.id),
79                                            name: call.function.name,
80                                            args: Some(parse_json_object_or_empty(
81                                                &call.function.arguments,
82                                            )),
83                                        }),
84                                        ..GeminiPart::default()
85                                    });
86                                }
87                                ChatCompletionMessageToolCall::Custom(call) => {
88                                    parts.push(GeminiPart {
89                                        function_call: Some(GeminiFunctionCall {
90                                            id: Some(call.id),
91                                            name: call.custom.name,
92                                            args: Some(parse_json_object_or_empty(
93                                                &call.custom.input,
94                                            )),
95                                        }),
96                                        ..GeminiPart::default()
97                                    });
98                                }
99                            }
100                        }
101                    }
102                }
103
104                if parts.is_empty() {
105                    parts.push(GeminiPart {
106                        text: Some(String::new()),
107                        ..GeminiPart::default()
108                    });
109                }
110
111                let prompt_feedback =
112                    if has_refusal || matches!(finish_reason, Some(GeminiFinishReason::Safety)) {
113                        Some(GeminiPromptFeedback {
114                            block_reason: Some(GeminiBlockReason::Safety),
115                            safety_ratings: None,
116                        })
117                    } else {
118                        None
119                    };
120
121                let usage_metadata = body.usage.map(|usage| GeminiUsageMetadata {
122                    prompt_token_count: Some(usage.prompt_tokens),
123                    cached_content_token_count: usage
124                        .prompt_tokens_details
125                        .as_ref()
126                        .and_then(|details| details.cached_tokens),
127                    candidates_token_count: Some(usage.completion_tokens),
128                    thoughts_token_count: usage
129                        .completion_tokens_details
130                        .as_ref()
131                        .and_then(|details| details.reasoning_tokens),
132                    total_token_count: Some(usage.total_tokens),
133                    ..GeminiUsageMetadata::default()
134                });
135
136                GeminiGenerateContentResponse::Success {
137                    stats_code,
138                    headers: GeminiResponseHeaders {
139                        extra: headers.extra,
140                    },
141                    body: ResponseBody {
142                        candidates: Some(vec![GeminiCandidate {
143                            content: Some(GeminiContent {
144                                parts,
145                                role: Some(GeminiContentRole::Model),
146                            }),
147                            finish_reason,
148                            index: Some(0),
149                            ..GeminiCandidate::default()
150                        }]),
151                        prompt_feedback,
152                        usage_metadata,
153                        model_version: Some(body.model),
154                        response_id: Some(body.id),
155                        model_status: None,
156                    },
157                }
158            }
159            OpenAiChatCompletionsResponse::Error {
160                stats_code,
161                headers,
162                body,
163            } => GeminiGenerateContentResponse::Error {
164                stats_code,
165                headers: GeminiResponseHeaders {
166                    extra: headers.extra,
167                },
168                body: gemini_error_response_from_openai(stats_code, body),
169            },
170        })
171    }
172}