Skip to main content

gproxy_protocol/transform/claude/stream_generate_content/gemini/
response.rs

1use crate::claude::create_message::stream::ClaudeStreamEvent;
2use crate::claude::create_message::types::{BetaServiceTier, BetaStopReason};
3use crate::gemini::count_tokens::types::{GeminiLanguage, GeminiOutcome};
4use crate::gemini::generate_content::response::ResponseBody as GeminiGenerateContentResponseBody;
5use crate::gemini::generate_content::types::{GeminiBlockReason, GeminiFinishReason};
6use crate::gemini::stream_generate_content::response::GeminiStreamGenerateContentResponse;
7use crate::transform::claude::stream_generate_content::utils::{
8    message_delta_event, message_start_event, message_stop_event, push_text_block,
9    push_thinking_block, push_tool_use_block, stream_error_event,
10};
11use crate::transform::utils::TransformError;
12
/// Lifecycle of a [`GeminiToClaudeStream`] conversion.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum StreamState {
    /// Nothing emitted yet; `message_start` is still pending.
    Init,
    /// `message_start` has been emitted; content blocks may follow.
    Running,
    /// `message_delta` and `message_stop` have been emitted; further input is ignored.
    Finished,
}
19
20#[derive(Debug, Clone)]
21pub struct GeminiToClaudeStream {
22    state: StreamState,
23    next_block_index: u64,
24    chunk_seq: u64,
25    message_id: String,
26    model: String,
27    input_tokens: u64,
28    cached_input_tokens: u64,
29    output_tokens: u64,
30    stop_reason: Option<BetaStopReason>,
31    has_tool_use: bool,
32    has_refusal: bool,
33}
34
35impl Default for GeminiToClaudeStream {
36    fn default() -> Self {
37        Self {
38            state: StreamState::Init,
39            next_block_index: 0,
40            chunk_seq: 0,
41            message_id: String::new(),
42            model: String::new(),
43            input_tokens: 0,
44            cached_input_tokens: 0,
45            output_tokens: 0,
46            stop_reason: None,
47            has_tool_use: false,
48            has_refusal: false,
49        }
50    }
51}
52
53impl GeminiToClaudeStream {
54    pub fn is_finished(&self) -> bool {
55        matches!(self.state, StreamState::Finished)
56    }
57
58    fn update_envelope_from_chunk(&mut self, chunk: &GeminiGenerateContentResponseBody) {
59        if let Some(response_id) = chunk.response_id.as_ref() {
60            self.message_id = response_id.clone();
61        }
62        if let Some(model_version) = chunk.model_version.as_ref() {
63            self.model = model_version.clone();
64        }
65        if let Some(usage_metadata) = chunk.usage_metadata.as_ref() {
66            let prompt_input_tokens = usage_metadata
67                .prompt_token_count
68                .unwrap_or(0)
69                .saturating_add(usage_metadata.tool_use_prompt_token_count.unwrap_or(0));
70            let cached_tokens = usage_metadata.cached_content_token_count.unwrap_or(0);
71            let output_tokens = usage_metadata
72                .candidates_token_count
73                .unwrap_or(0)
74                .saturating_add(usage_metadata.thoughts_token_count.unwrap_or(0));
75            let total_input_tokens = usage_metadata
76                .total_token_count
77                .map(|total| total.saturating_sub(output_tokens))
78                .unwrap_or_else(|| prompt_input_tokens.saturating_add(cached_tokens));
79
80            self.input_tokens = total_input_tokens.saturating_sub(cached_tokens);
81            self.cached_input_tokens = cached_tokens;
82            self.output_tokens = output_tokens;
83        }
84    }
85
86    fn ensure_running(&mut self, out: &mut Vec<ClaudeStreamEvent>) {
87        if matches!(self.state, StreamState::Init) {
88            out.push(message_start_event(
89                self.message_id.clone(),
90                self.model.clone(),
91                BetaServiceTier::Standard,
92                self.input_tokens,
93                self.cached_input_tokens,
94            ));
95            self.state = StreamState::Running;
96        }
97    }
98
99    fn emit_text_block(&mut self, out: &mut Vec<ClaudeStreamEvent>, text: String) {
100        self.ensure_running(out);
101        let _ = push_text_block(out, &mut self.next_block_index, text);
102    }
103
104    fn emit_thinking_block(
105        &mut self,
106        out: &mut Vec<ClaudeStreamEvent>,
107        signature: String,
108        thinking: String,
109    ) {
110        self.ensure_running(out);
111        let _ = push_thinking_block(out, &mut self.next_block_index, signature, thinking);
112    }
113
114    fn emit_tool_use_block(
115        &mut self,
116        out: &mut Vec<ClaudeStreamEvent>,
117        id: String,
118        name: String,
119        input_json: Option<String>,
120    ) {
121        self.ensure_running(out);
122        self.has_tool_use = true;
123        let _ = push_tool_use_block(out, &mut self.next_block_index, id, name, input_json);
124    }
125
126    pub fn on_chunk(
127        &mut self,
128        chunk: GeminiGenerateContentResponseBody,
129        out: &mut Vec<ClaudeStreamEvent>,
130    ) {
131        if self.is_finished() {
132            return;
133        }
134
135        self.update_envelope_from_chunk(&chunk);
136        let chunk_index = self.chunk_seq;
137        self.chunk_seq = self.chunk_seq.saturating_add(1);
138
139        let mut chunk_has_content = false;
140
141        if let Some(status_message) = chunk
142            .model_status
143            .as_ref()
144            .and_then(|status| status.message.as_ref())
145            && !status_message.is_empty()
146        {
147            chunk_has_content = true;
148            self.emit_text_block(out, format!("model_status: {status_message}"));
149        }
150
151        if let Some(candidates) = chunk.candidates {
152            for (candidate_index, candidate) in candidates.into_iter().enumerate() {
153                let mut candidate_has_content = false;
154                if let Some(content) = candidate.content {
155                    for (part_index, part) in content.parts.into_iter().enumerate() {
156                        if part.thought.unwrap_or(false) {
157                            if let Some(thinking) = part.text {
158                                candidate_has_content = true;
159                                chunk_has_content = true;
160                                self.emit_thinking_block(
161                                    out,
162                                    part.thought_signature.unwrap_or_else(|| {
163                                        format!(
164                                            "thought_{chunk_index}_{candidate_index}_{part_index}"
165                                        )
166                                    }),
167                                    thinking,
168                                );
169                            }
170                        } else if let Some(text) = part.text {
171                            candidate_has_content = true;
172                            chunk_has_content = true;
173                            self.emit_text_block(out, text);
174                        }
175
176                        if let Some(inline_data) = part.inline_data {
177                            candidate_has_content = true;
178                            chunk_has_content = true;
179                            self.emit_text_block(
180                                out,
181                                format!(
182                                    "inline_data({}): {}",
183                                    inline_data.mime_type, inline_data.data
184                                ),
185                            );
186                        }
187
188                        if let Some(function_call) = part.function_call {
189                            candidate_has_content = true;
190                            chunk_has_content = true;
191                            self.emit_tool_use_block(
192                                out,
193                                function_call.id.unwrap_or_else(|| {
194                                    format!(
195                                        "tool_call_{chunk_index}_{candidate_index}_{part_index}"
196                                    )
197                                }),
198                                function_call.name,
199                                function_call
200                                    .args
201                                    .and_then(|args| serde_json::to_string(&args).ok()),
202                            );
203                        }
204
205                        if let Some(function_response) = part.function_response {
206                            if let Ok(response_json) =
207                                serde_json::to_string(&function_response.response)
208                                && !response_json.is_empty()
209                            {
210                                candidate_has_content = true;
211                                chunk_has_content = true;
212                                self.emit_text_block(
213                                    out,
214                                    format!(
215                                        "function_response({}): {response_json}",
216                                        function_response.name
217                                    ),
218                                );
219                            }
220                            if let Some(parts) = function_response.parts {
221                                for (response_part_index, response_part) in
222                                    parts.into_iter().enumerate()
223                                {
224                                    if let Some(inline_data) = response_part.inline_data {
225                                        candidate_has_content = true;
226                                        chunk_has_content = true;
227                                        self.emit_text_block(
228                                            out,
229                                            format!(
230                                                "function_response.inline_data({candidate_index}:{part_index}:{response_part_index})({}): {}",
231                                                inline_data.mime_type, inline_data.data
232                                            ),
233                                        );
234                                    }
235                                }
236                            }
237                        }
238
239                        if let Some(executable_code) = part.executable_code {
240                            let language = match executable_code.language {
241                                GeminiLanguage::LanguageUnspecified => "unspecified",
242                                GeminiLanguage::Python => "python",
243                            };
244                            candidate_has_content = true;
245                            chunk_has_content = true;
246                            self.emit_text_block(
247                                out,
248                                format!("executable_code({language}): {}", executable_code.code),
249                            );
250                        }
251
252                        if let Some(code_execution_result) = part.code_execution_result {
253                            let outcome = match code_execution_result.outcome {
254                                GeminiOutcome::OutcomeUnspecified => "unspecified",
255                                GeminiOutcome::OutcomeOk => "ok",
256                                GeminiOutcome::OutcomeFailed => "failed",
257                                GeminiOutcome::OutcomeDeadlineExceeded => "deadline_exceeded",
258                            };
259                            let output_text = code_execution_result.output.unwrap_or_default();
260                            candidate_has_content = true;
261                            chunk_has_content = true;
262                            if output_text.is_empty() {
263                                self.emit_text_block(
264                                    out,
265                                    format!("code_execution_result({outcome})"),
266                                );
267                            } else {
268                                self.emit_text_block(
269                                    out,
270                                    format!("code_execution_result({outcome}): {output_text}"),
271                                );
272                            }
273                        }
274
275                        if let Some(file_data) = part.file_data {
276                            candidate_has_content = true;
277                            chunk_has_content = true;
278                            if let Some(mime_type) = file_data.mime_type {
279                                self.emit_text_block(
280                                    out,
281                                    format!("file_data({mime_type}): {}", file_data.file_uri),
282                                );
283                            } else {
284                                self.emit_text_block(out, file_data.file_uri);
285                            }
286                        }
287                    }
288                }
289
290                if !candidate_has_content
291                    && let Some(finish_message) = candidate.finish_message
292                    && !finish_message.is_empty()
293                {
294                    chunk_has_content = true;
295                    self.emit_text_block(out, finish_message);
296                }
297
298                if let Some(reason) = candidate.finish_reason {
299                    self.stop_reason = Some(match reason {
300                        GeminiFinishReason::MaxTokens => BetaStopReason::MaxTokens,
301                        GeminiFinishReason::MalformedFunctionCall
302                        | GeminiFinishReason::UnexpectedToolCall
303                        | GeminiFinishReason::TooManyToolCalls
304                        | GeminiFinishReason::MissingThoughtSignature => BetaStopReason::ToolUse,
305                        GeminiFinishReason::Safety
306                        | GeminiFinishReason::Recitation
307                        | GeminiFinishReason::Blocklist
308                        | GeminiFinishReason::ProhibitedContent
309                        | GeminiFinishReason::Spii
310                        | GeminiFinishReason::ImageSafety
311                        | GeminiFinishReason::ImageProhibitedContent
312                        | GeminiFinishReason::ImageRecitation => {
313                            self.has_refusal = true;
314                            BetaStopReason::Refusal
315                        }
316                        GeminiFinishReason::Stop
317                        | GeminiFinishReason::FinishReasonUnspecified
318                        | GeminiFinishReason::Language
319                        | GeminiFinishReason::Other
320                        | GeminiFinishReason::ImageOther
321                        | GeminiFinishReason::NoImage => BetaStopReason::EndTurn,
322                    });
323                }
324            }
325        } else {
326            self.stop_reason = Some(
327                match chunk
328                    .prompt_feedback
329                    .as_ref()
330                    .and_then(|feedback| feedback.block_reason.as_ref())
331                {
332                    Some(GeminiBlockReason::Safety)
333                    | Some(GeminiBlockReason::Blocklist)
334                    | Some(GeminiBlockReason::ProhibitedContent)
335                    | Some(GeminiBlockReason::ImageSafety) => {
336                        self.has_refusal = true;
337                        BetaStopReason::Refusal
338                    }
339                    Some(GeminiBlockReason::BlockReasonUnspecified)
340                    | Some(GeminiBlockReason::Other)
341                    | None => BetaStopReason::EndTurn,
342                },
343            );
344        }
345
346        if !chunk_has_content {
347            self.ensure_running(out);
348        }
349    }
350
351    pub fn finish(&mut self, out: &mut Vec<ClaudeStreamEvent>) {
352        if self.is_finished() {
353            return;
354        }
355
356        self.ensure_running(out);
357
358        let final_stop_reason = self.stop_reason.clone().or({
359            if self.has_tool_use {
360                Some(BetaStopReason::ToolUse)
361            } else if self.has_refusal {
362                Some(BetaStopReason::Refusal)
363            } else {
364                Some(BetaStopReason::EndTurn)
365            }
366        });
367        out.push(message_delta_event(
368            final_stop_reason,
369            self.input_tokens,
370            self.cached_input_tokens,
371            self.output_tokens,
372        ));
373        out.push(message_stop_event());
374        self.state = StreamState::Finished;
375    }
376}
377
378impl TryFrom<GeminiStreamGenerateContentResponse> for Vec<ClaudeStreamEvent> {
379    type Error = TransformError;
380
381    fn try_from(value: GeminiStreamGenerateContentResponse) -> Result<Self, TransformError> {
382        match value {
383            GeminiStreamGenerateContentResponse::Success { .. } => {
384                // The new response type no longer contains chunks inline;
385                // chunks are processed individually via on_chunk().
386                Ok(Vec::new())
387            }
388            GeminiStreamGenerateContentResponse::Error { body, .. } => {
389                Ok(vec![stream_error_event(body.error.message)])
390            }
391        }
392    }
393}