Skip to main content

gproxy_protocol/transform/gemini/websocket/to_http/
response.rs

1use http::StatusCode;
2
3use crate::gemini::count_tokens::types::{GeminiContentRole, GeminiFunctionCall, GeminiPart};
4use crate::gemini::generate_content::response::GeminiGenerateContentResponse;
5use crate::gemini::generate_content::response::ResponseBody as GeminiGenerateContentResponseBody;
6use crate::gemini::generate_content::types::{
7    GeminiCandidate, GeminiContent, GeminiFinishReason, GeminiUsageMetadata,
8};
9use crate::gemini::live::response::GeminiLiveMessageResponse;
10use crate::gemini::live::types::GeminiBidiGenerateContentServerMessageType;
11use crate::gemini::live::types::GeminiLiveUsageMetadata;
12use crate::gemini::stream_generate_content::response::GeminiStreamGenerateContentResponse;
13use crate::gemini::types::GeminiResponseHeaders;
14use crate::transform::gemini::websocket::context::GeminiWebsocketTransformContext;
15use crate::transform::utils::TransformError;
16
17pub fn usage_live_to_generate(
18    usage: Option<GeminiLiveUsageMetadata>,
19) -> Option<GeminiUsageMetadata> {
20    usage.map(|usage| GeminiUsageMetadata {
21        prompt_token_count: usage.prompt_token_count,
22        cached_content_token_count: usage.cached_content_token_count,
23        candidates_token_count: usage.response_token_count,
24        tool_use_prompt_token_count: usage.tool_use_prompt_token_count,
25        thoughts_token_count: usage.thoughts_token_count,
26        total_token_count: usage.total_token_count,
27        prompt_tokens_details: usage.prompt_tokens_details,
28        cache_tokens_details: usage.cache_tokens_details,
29        candidates_tokens_details: usage.response_tokens_details,
30        tool_use_prompt_tokens_details: usage.tool_use_prompt_tokens_details,
31    })
32}
33
34pub fn server_message_to_chunk(
35    message: crate::gemini::live::types::GeminiBidiGenerateContentServerMessage,
36    ctx: &mut GeminiWebsocketTransformContext,
37) -> Option<GeminiGenerateContentResponseBody> {
38    let usage_metadata = usage_live_to_generate(message.usage_metadata);
39
40    match message.message_type {
41        GeminiBidiGenerateContentServerMessageType::SetupComplete { .. } => {
42            ctx.push_warning(
43                "gemini websocket to_http response: dropped setupComplete event".to_string(),
44            );
45            usage_metadata.map(|usage| GeminiGenerateContentResponseBody {
46                usage_metadata: Some(usage),
47                ..GeminiGenerateContentResponseBody::default()
48            })
49        }
50        GeminiBidiGenerateContentServerMessageType::GoAway { go_away } => {
51            ctx.push_warning(format!(
52                "gemini websocket to_http response: dropped goAway event (timeLeft={})",
53                go_away.time_left
54            ));
55            usage_metadata.map(|usage| GeminiGenerateContentResponseBody {
56                usage_metadata: Some(usage),
57                ..GeminiGenerateContentResponseBody::default()
58            })
59        }
60        GeminiBidiGenerateContentServerMessageType::SessionResumptionUpdate { .. } => {
61            ctx.push_warning(
62                "gemini websocket to_http response: dropped sessionResumptionUpdate event"
63                    .to_string(),
64            );
65            usage_metadata.map(|usage| GeminiGenerateContentResponseBody {
66                usage_metadata: Some(usage),
67                ..GeminiGenerateContentResponseBody::default()
68            })
69        }
70        GeminiBidiGenerateContentServerMessageType::ToolCallCancellation { .. } => {
71            ctx.push_warning(
72                "gemini websocket to_http response: dropped toolCallCancellation event".to_string(),
73            );
74            usage_metadata.map(|usage| GeminiGenerateContentResponseBody {
75                usage_metadata: Some(usage),
76                ..GeminiGenerateContentResponseBody::default()
77            })
78        }
79        GeminiBidiGenerateContentServerMessageType::ServerContent { server_content } => {
80            if server_content.interrupted == Some(true) {
81                ctx.push_warning(
82                    "gemini websocket to_http response: dropped interrupted=true flag".to_string(),
83                );
84            }
85            if server_content.input_transcription.is_some() {
86                ctx.push_warning(
87                    "gemini websocket to_http response: dropped inputTranscription".to_string(),
88                );
89            }
90            if server_content.output_transcription.is_some() {
91                ctx.push_warning(
92                    "gemini websocket to_http response: dropped outputTranscription".to_string(),
93                );
94            }
95            let candidates = server_content.model_turn.map(|model_turn| {
96                vec![GeminiCandidate {
97                    content: Some(model_turn),
98                    finish_reason: if server_content.generation_complete == Some(true)
99                        || server_content.turn_complete == Some(true)
100                    {
101                        Some(GeminiFinishReason::Stop)
102                    } else {
103                        None
104                    },
105                    grounding_metadata: server_content.grounding_metadata,
106                    url_context_metadata: server_content.url_context_metadata,
107                    index: Some(0),
108                    ..GeminiCandidate::default()
109                }]
110            });
111
112            if candidates.is_none() && usage_metadata.is_none() {
113                return None;
114            }
115
116            Some(GeminiGenerateContentResponseBody {
117                candidates,
118                usage_metadata,
119                ..GeminiGenerateContentResponseBody::default()
120            })
121        }
122        GeminiBidiGenerateContentServerMessageType::ToolCall { tool_call } => {
123            let calls = tool_call.function_calls.unwrap_or_default();
124            if calls.is_empty() && usage_metadata.is_none() {
125                return None;
126            }
127
128            let model_turn = GeminiContent {
129                role: Some(GeminiContentRole::Model),
130                parts: calls
131                    .into_iter()
132                    .map(|call| GeminiPart {
133                        function_call: Some(GeminiFunctionCall {
134                            id: call.id,
135                            name: call.name,
136                            args: call.args,
137                        }),
138                        ..GeminiPart::default()
139                    })
140                    .collect(),
141            };
142
143            Some(GeminiGenerateContentResponseBody {
144                candidates: Some(vec![GeminiCandidate {
145                    content: Some(model_turn),
146                    index: Some(0),
147                    ..GeminiCandidate::default()
148                }]),
149                usage_metadata,
150                ..GeminiGenerateContentResponseBody::default()
151            })
152        }
153    }
154}
155
156impl TryFrom<Vec<GeminiLiveMessageResponse>> for GeminiStreamGenerateContentResponse {
157    type Error = TransformError;
158
159    fn try_from(value: Vec<GeminiLiveMessageResponse>) -> Result<Self, TransformError> {
160        Ok(gemini_live_messages_to_stream_response_with_context(value)?.0)
161    }
162}
163
164impl TryFrom<Vec<GeminiLiveMessageResponse>> for GeminiGenerateContentResponse {
165    type Error = TransformError;
166
167    fn try_from(value: Vec<GeminiLiveMessageResponse>) -> Result<Self, TransformError> {
168        Ok(gemini_live_messages_to_nonstream_response_with_context(value)?.0)
169    }
170}
171
172pub fn gemini_live_messages_to_nonstream_response_with_context(
173    value: Vec<GeminiLiveMessageResponse>,
174) -> Result<
175    (
176        GeminiGenerateContentResponse,
177        GeminiWebsocketTransformContext,
178    ),
179    TransformError,
180> {
181    let mut ctx = GeminiWebsocketTransformContext::default();
182    let mut chunks = Vec::new();
183
184    for message in value {
185        match message {
186            GeminiLiveMessageResponse::Error(body) => {
187                return Ok((
188                    GeminiGenerateContentResponse::Error {
189                        stats_code: StatusCode::BAD_REQUEST,
190                        headers: GeminiResponseHeaders::default(),
191                        body,
192                    },
193                    ctx,
194                ));
195            }
196            GeminiLiveMessageResponse::Message(server) => {
197                if let Some(chunk) = server_message_to_chunk(server, &mut ctx) {
198                    chunks.push(chunk);
199                }
200            }
201        }
202    }
203
204    // Merge all chunks into a single response body.
205    let mut merged = GeminiGenerateContentResponseBody::default();
206    for chunk in chunks {
207        if let Some(candidates) = chunk.candidates {
208            merged
209                .candidates
210                .get_or_insert_with(Vec::new)
211                .extend(candidates);
212        }
213        if chunk.usage_metadata.is_some() {
214            merged.usage_metadata = chunk.usage_metadata;
215        }
216    }
217
218    Ok((
219        GeminiGenerateContentResponse::Success {
220            stats_code: StatusCode::OK,
221            headers: GeminiResponseHeaders::default(),
222            body: merged,
223        },
224        ctx,
225    ))
226}
227
228pub fn gemini_live_messages_to_stream_response_with_context(
229    value: Vec<GeminiLiveMessageResponse>,
230) -> Result<
231    (
232        GeminiStreamGenerateContentResponse,
233        GeminiWebsocketTransformContext,
234    ),
235    TransformError,
236> {
237    let mut ctx = GeminiWebsocketTransformContext::default();
238
239    for message in &value {
240        if let GeminiLiveMessageResponse::Error(body) = message {
241            return Ok((
242                GeminiStreamGenerateContentResponse::Error {
243                    stats_code: StatusCode::BAD_REQUEST,
244                    headers: GeminiResponseHeaders::default(),
245                    body: body.clone(),
246                },
247                ctx,
248            ));
249        }
250    }
251
252    // Validate chunks can be produced (side-effects into ctx).
253    for message in value {
254        if let GeminiLiveMessageResponse::Message(server) = message {
255            let _ = server_message_to_chunk(server, &mut ctx);
256        }
257    }
258
259    Ok((
260        GeminiStreamGenerateContentResponse::Success {
261            stats_code: StatusCode::OK,
262            headers: GeminiResponseHeaders::default(),
263        },
264        ctx,
265    ))
266}