gproxy_protocol/transform/gemini/websocket/to_http/
response.rs1use http::StatusCode;
2
3use crate::gemini::count_tokens::types::{GeminiContentRole, GeminiFunctionCall, GeminiPart};
4use crate::gemini::generate_content::response::GeminiGenerateContentResponse;
5use crate::gemini::generate_content::response::ResponseBody as GeminiGenerateContentResponseBody;
6use crate::gemini::generate_content::types::{
7 GeminiCandidate, GeminiContent, GeminiFinishReason, GeminiUsageMetadata,
8};
9use crate::gemini::live::response::GeminiLiveMessageResponse;
10use crate::gemini::live::types::GeminiBidiGenerateContentServerMessageType;
11use crate::gemini::live::types::GeminiLiveUsageMetadata;
12use crate::gemini::stream_generate_content::response::GeminiStreamGenerateContentResponse;
13use crate::gemini::types::GeminiResponseHeaders;
14use crate::transform::gemini::websocket::context::GeminiWebsocketTransformContext;
15use crate::transform::utils::TransformError;
16
17pub fn usage_live_to_generate(
18 usage: Option<GeminiLiveUsageMetadata>,
19) -> Option<GeminiUsageMetadata> {
20 usage.map(|usage| GeminiUsageMetadata {
21 prompt_token_count: usage.prompt_token_count,
22 cached_content_token_count: usage.cached_content_token_count,
23 candidates_token_count: usage.response_token_count,
24 tool_use_prompt_token_count: usage.tool_use_prompt_token_count,
25 thoughts_token_count: usage.thoughts_token_count,
26 total_token_count: usage.total_token_count,
27 prompt_tokens_details: usage.prompt_tokens_details,
28 cache_tokens_details: usage.cache_tokens_details,
29 candidates_tokens_details: usage.response_tokens_details,
30 tool_use_prompt_tokens_details: usage.tool_use_prompt_tokens_details,
31 })
32}
33
34pub fn server_message_to_chunk(
35 message: crate::gemini::live::types::GeminiBidiGenerateContentServerMessage,
36 ctx: &mut GeminiWebsocketTransformContext,
37) -> Option<GeminiGenerateContentResponseBody> {
38 let usage_metadata = usage_live_to_generate(message.usage_metadata);
39
40 match message.message_type {
41 GeminiBidiGenerateContentServerMessageType::SetupComplete { .. } => {
42 ctx.push_warning(
43 "gemini websocket to_http response: dropped setupComplete event".to_string(),
44 );
45 usage_metadata.map(|usage| GeminiGenerateContentResponseBody {
46 usage_metadata: Some(usage),
47 ..GeminiGenerateContentResponseBody::default()
48 })
49 }
50 GeminiBidiGenerateContentServerMessageType::GoAway { go_away } => {
51 ctx.push_warning(format!(
52 "gemini websocket to_http response: dropped goAway event (timeLeft={})",
53 go_away.time_left
54 ));
55 usage_metadata.map(|usage| GeminiGenerateContentResponseBody {
56 usage_metadata: Some(usage),
57 ..GeminiGenerateContentResponseBody::default()
58 })
59 }
60 GeminiBidiGenerateContentServerMessageType::SessionResumptionUpdate { .. } => {
61 ctx.push_warning(
62 "gemini websocket to_http response: dropped sessionResumptionUpdate event"
63 .to_string(),
64 );
65 usage_metadata.map(|usage| GeminiGenerateContentResponseBody {
66 usage_metadata: Some(usage),
67 ..GeminiGenerateContentResponseBody::default()
68 })
69 }
70 GeminiBidiGenerateContentServerMessageType::ToolCallCancellation { .. } => {
71 ctx.push_warning(
72 "gemini websocket to_http response: dropped toolCallCancellation event".to_string(),
73 );
74 usage_metadata.map(|usage| GeminiGenerateContentResponseBody {
75 usage_metadata: Some(usage),
76 ..GeminiGenerateContentResponseBody::default()
77 })
78 }
79 GeminiBidiGenerateContentServerMessageType::ServerContent { server_content } => {
80 if server_content.interrupted == Some(true) {
81 ctx.push_warning(
82 "gemini websocket to_http response: dropped interrupted=true flag".to_string(),
83 );
84 }
85 if server_content.input_transcription.is_some() {
86 ctx.push_warning(
87 "gemini websocket to_http response: dropped inputTranscription".to_string(),
88 );
89 }
90 if server_content.output_transcription.is_some() {
91 ctx.push_warning(
92 "gemini websocket to_http response: dropped outputTranscription".to_string(),
93 );
94 }
95 let candidates = server_content.model_turn.map(|model_turn| {
96 vec![GeminiCandidate {
97 content: Some(model_turn),
98 finish_reason: if server_content.generation_complete == Some(true)
99 || server_content.turn_complete == Some(true)
100 {
101 Some(GeminiFinishReason::Stop)
102 } else {
103 None
104 },
105 grounding_metadata: server_content.grounding_metadata,
106 url_context_metadata: server_content.url_context_metadata,
107 index: Some(0),
108 ..GeminiCandidate::default()
109 }]
110 });
111
112 if candidates.is_none() && usage_metadata.is_none() {
113 return None;
114 }
115
116 Some(GeminiGenerateContentResponseBody {
117 candidates,
118 usage_metadata,
119 ..GeminiGenerateContentResponseBody::default()
120 })
121 }
122 GeminiBidiGenerateContentServerMessageType::ToolCall { tool_call } => {
123 let calls = tool_call.function_calls.unwrap_or_default();
124 if calls.is_empty() && usage_metadata.is_none() {
125 return None;
126 }
127
128 let model_turn = GeminiContent {
129 role: Some(GeminiContentRole::Model),
130 parts: calls
131 .into_iter()
132 .map(|call| GeminiPart {
133 function_call: Some(GeminiFunctionCall {
134 id: call.id,
135 name: call.name,
136 args: call.args,
137 }),
138 ..GeminiPart::default()
139 })
140 .collect(),
141 };
142
143 Some(GeminiGenerateContentResponseBody {
144 candidates: Some(vec![GeminiCandidate {
145 content: Some(model_turn),
146 index: Some(0),
147 ..GeminiCandidate::default()
148 }]),
149 usage_metadata,
150 ..GeminiGenerateContentResponseBody::default()
151 })
152 }
153 }
154}
155
156impl TryFrom<Vec<GeminiLiveMessageResponse>> for GeminiStreamGenerateContentResponse {
157 type Error = TransformError;
158
159 fn try_from(value: Vec<GeminiLiveMessageResponse>) -> Result<Self, TransformError> {
160 Ok(gemini_live_messages_to_stream_response_with_context(value)?.0)
161 }
162}
163
164impl TryFrom<Vec<GeminiLiveMessageResponse>> for GeminiGenerateContentResponse {
165 type Error = TransformError;
166
167 fn try_from(value: Vec<GeminiLiveMessageResponse>) -> Result<Self, TransformError> {
168 Ok(gemini_live_messages_to_nonstream_response_with_context(value)?.0)
169 }
170}
171
172pub fn gemini_live_messages_to_nonstream_response_with_context(
173 value: Vec<GeminiLiveMessageResponse>,
174) -> Result<
175 (
176 GeminiGenerateContentResponse,
177 GeminiWebsocketTransformContext,
178 ),
179 TransformError,
180> {
181 let mut ctx = GeminiWebsocketTransformContext::default();
182 let mut chunks = Vec::new();
183
184 for message in value {
185 match message {
186 GeminiLiveMessageResponse::Error(body) => {
187 return Ok((
188 GeminiGenerateContentResponse::Error {
189 stats_code: StatusCode::BAD_REQUEST,
190 headers: GeminiResponseHeaders::default(),
191 body,
192 },
193 ctx,
194 ));
195 }
196 GeminiLiveMessageResponse::Message(server) => {
197 if let Some(chunk) = server_message_to_chunk(server, &mut ctx) {
198 chunks.push(chunk);
199 }
200 }
201 }
202 }
203
204 let mut merged = GeminiGenerateContentResponseBody::default();
206 for chunk in chunks {
207 if let Some(candidates) = chunk.candidates {
208 merged
209 .candidates
210 .get_or_insert_with(Vec::new)
211 .extend(candidates);
212 }
213 if chunk.usage_metadata.is_some() {
214 merged.usage_metadata = chunk.usage_metadata;
215 }
216 }
217
218 Ok((
219 GeminiGenerateContentResponse::Success {
220 stats_code: StatusCode::OK,
221 headers: GeminiResponseHeaders::default(),
222 body: merged,
223 },
224 ctx,
225 ))
226}
227
228pub fn gemini_live_messages_to_stream_response_with_context(
229 value: Vec<GeminiLiveMessageResponse>,
230) -> Result<
231 (
232 GeminiStreamGenerateContentResponse,
233 GeminiWebsocketTransformContext,
234 ),
235 TransformError,
236> {
237 let mut ctx = GeminiWebsocketTransformContext::default();
238
239 for message in &value {
240 if let GeminiLiveMessageResponse::Error(body) = message {
241 return Ok((
242 GeminiStreamGenerateContentResponse::Error {
243 stats_code: StatusCode::BAD_REQUEST,
244 headers: GeminiResponseHeaders::default(),
245 body: body.clone(),
246 },
247 ctx,
248 ));
249 }
250 }
251
252 for message in value {
254 if let GeminiLiveMessageResponse::Message(server) = message {
255 let _ = server_message_to_chunk(server, &mut ctx);
256 }
257 }
258
259 Ok((
260 GeminiStreamGenerateContentResponse::Success {
261 stats_code: StatusCode::OK,
262 headers: GeminiResponseHeaders::default(),
263 },
264 ctx,
265 ))
266}