Skip to main content

gproxy_protocol/transform/gemini/stream_generate_content/claude/
response.rs

1use std::collections::BTreeMap;
2
3use crate::claude::create_message::stream::{BetaRawContentBlockDelta, ClaudeStreamEvent};
4use crate::claude::create_message::types::{BetaContentBlock, BetaStopReason};
5use crate::claude::types::BetaError;
6use crate::gemini::count_tokens::types::{GeminiContentRole, GeminiFunctionCall, GeminiPart};
7use crate::gemini::generate_content::response::ResponseBody as GeminiGenerateContentResponseBody;
8use crate::gemini::generate_content::types::{
9    GeminiBlockReason, GeminiCandidate, GeminiContent, GeminiFinishReason, GeminiPromptFeedback,
10    GeminiUsageMetadata,
11};
12use crate::transform::claude::utils::claude_model_to_string;
13use crate::transform::gemini::stream_generate_content::utils::parse_json_object_or_empty;
14use crate::transform::utils::TransformError;
15
/// Per-content-block bookkeeping kept between a Claude `content_block_start`
/// and the matching `content_block_stop`.
#[derive(Clone, Debug)]
enum ClaudeBlockState {
    /// A thinking block; `signature` holds the most recent signature seen for it.
    Thinking {
        signature: String,
    },
    /// A tool-use block; `partial_json` accumulates the streamed argument JSON.
    ToolUse {
        id: String,
        name: String,
        partial_json: String,
    },
    /// Any other block kind; no state is tracked for these.
    Other,
}
28
29#[derive(Debug, Default, Clone)]
30pub struct ClaudeToGeminiStream {
31    response_id: Option<String>,
32    model_version: Option<String>,
33    input_tokens: u64,
34    cache_creation_input_tokens: u64,
35    cached_input_tokens: u64,
36    output_tokens: u64,
37    usage_metadata: Option<GeminiUsageMetadata>,
38    blocks: BTreeMap<u64, ClaudeBlockState>,
39    finished: bool,
40}
41
42impl ClaudeToGeminiStream {
43    pub fn is_finished(&self) -> bool {
44        self.finished
45    }
46
47    fn usage_from_counts(
48        input_tokens: u64,
49        cache_creation_tokens: u64,
50        cached_tokens: u64,
51        output_tokens: u64,
52    ) -> GeminiUsageMetadata {
53        let prompt_tokens = input_tokens.saturating_add(cache_creation_tokens);
54        GeminiUsageMetadata {
55            prompt_token_count: Some(prompt_tokens),
56            cached_content_token_count: Some(cached_tokens),
57            candidates_token_count: Some(output_tokens),
58            total_token_count: Some(
59                prompt_tokens
60                    .saturating_add(cached_tokens)
61                    .saturating_add(output_tokens),
62            ),
63            ..GeminiUsageMetadata::default()
64        }
65    }
66
67    fn sync_usage_metadata(&mut self) {
68        self.usage_metadata = Some(Self::usage_from_counts(
69            self.input_tokens,
70            self.cache_creation_input_tokens,
71            self.cached_input_tokens,
72            self.output_tokens,
73        ));
74    }
75
76    fn finish_reason_from_stop_reason(stop_reason: Option<BetaStopReason>) -> GeminiFinishReason {
77        match stop_reason {
78            Some(BetaStopReason::MaxTokens) | Some(BetaStopReason::ModelContextWindowExceeded) => {
79                GeminiFinishReason::MaxTokens
80            }
81            Some(BetaStopReason::ToolUse) => GeminiFinishReason::UnexpectedToolCall,
82            Some(BetaStopReason::Refusal) => GeminiFinishReason::Safety,
83            Some(BetaStopReason::Compaction) | Some(BetaStopReason::PauseTurn) => {
84                GeminiFinishReason::Other
85            }
86            Some(BetaStopReason::EndTurn) | Some(BetaStopReason::StopSequence) | None => {
87                GeminiFinishReason::Stop
88            }
89        }
90    }
91
92    fn error_message(error: BetaError) -> String {
93        match error {
94            BetaError::InvalidRequest(error) => error.message,
95            BetaError::Authentication(error) => error.message,
96            BetaError::Billing(error) => error.message,
97            BetaError::Permission(error) => error.message,
98            BetaError::NotFound(error) => error.message,
99            BetaError::RateLimit(error) => error.message,
100            BetaError::GatewayTimeout(error) => error.message,
101            BetaError::Api(error) => error.message,
102            BetaError::Overloaded(error) => error.message,
103        }
104    }
105
106    fn chunk_from_parts(
107        &self,
108        parts: Vec<GeminiPart>,
109        finish_reason: Option<GeminiFinishReason>,
110        prompt_feedback: Option<GeminiPromptFeedback>,
111    ) -> GeminiGenerateContentResponseBody {
112        GeminiGenerateContentResponseBody {
113            candidates: Some(vec![GeminiCandidate {
114                content: Some(GeminiContent {
115                    parts,
116                    role: Some(GeminiContentRole::Model),
117                }),
118                finish_reason,
119                index: Some(0),
120                ..GeminiCandidate::default()
121            }]),
122            prompt_feedback,
123            usage_metadata: self.usage_metadata.clone(),
124            model_version: self.model_version.clone(),
125            response_id: self.response_id.clone(),
126            model_status: None,
127        }
128    }
129
130    fn text_chunk(&self, text: String) -> Option<GeminiGenerateContentResponseBody> {
131        if text.is_empty() {
132            None
133        } else {
134            Some(self.chunk_from_parts(
135                vec![GeminiPart {
136                    text: Some(text),
137                    ..GeminiPart::default()
138                }],
139                None,
140                None,
141            ))
142        }
143    }
144
145    fn thinking_chunk(
146        &self,
147        signature: String,
148        thinking: String,
149    ) -> Option<GeminiGenerateContentResponseBody> {
150        if thinking.is_empty() {
151            None
152        } else {
153            Some(self.chunk_from_parts(
154                vec![GeminiPart {
155                    thought: Some(true),
156                    thought_signature: Some(signature),
157                    text: Some(thinking),
158                    ..GeminiPart::default()
159                }],
160                None,
161                None,
162            ))
163        }
164    }
165
166    fn function_call_chunk(
167        &self,
168        id: String,
169        name: String,
170        arguments: String,
171    ) -> GeminiGenerateContentResponseBody {
172        self.chunk_from_parts(
173            vec![GeminiPart {
174                function_call: Some(GeminiFunctionCall {
175                    id: Some(id),
176                    name,
177                    args: Some(parse_json_object_or_empty(&arguments)),
178                }),
179                ..GeminiPart::default()
180            }],
181            None,
182            None,
183        )
184    }
185
186    pub fn on_event(
187        &mut self,
188        event: ClaudeStreamEvent,
189        out: &mut Vec<GeminiGenerateContentResponseBody>,
190    ) -> Result<(), TransformError> {
191        if self.finished {
192            return Ok(());
193        }
194
195        match event {
196            ClaudeStreamEvent::MessageStart { message } => {
197                self.response_id = Some(message.id);
198                self.model_version = Some(claude_model_to_string(&message.model));
199                self.input_tokens = message.usage.input_tokens;
200                self.cache_creation_input_tokens = message.usage.cache_creation_input_tokens;
201                self.cached_input_tokens = message.usage.cache_read_input_tokens;
202                self.output_tokens = message.usage.output_tokens;
203                self.sync_usage_metadata();
204            }
205            ClaudeStreamEvent::ContentBlockStart {
206                content_block,
207                index,
208            } => {
209                let state = match content_block {
210                    BetaContentBlock::Thinking(block) => ClaudeBlockState::Thinking {
211                        signature: block.signature,
212                    },
213                    BetaContentBlock::ToolUse(block) => ClaudeBlockState::ToolUse {
214                        id: block.id,
215                        name: block.name,
216                        partial_json: String::new(),
217                    },
218                    _ => ClaudeBlockState::Other,
219                };
220                self.blocks.insert(index, state);
221            }
222            ClaudeStreamEvent::ContentBlockDelta { delta, index } => match delta {
223                BetaRawContentBlockDelta::Text { text } => {
224                    if let Some(chunk) = self.text_chunk(text) {
225                        out.push(chunk);
226                    }
227                }
228                BetaRawContentBlockDelta::Thinking { thinking } => {
229                    let signature = match self.blocks.get(&index) {
230                        Some(ClaudeBlockState::Thinking { signature }) => signature.clone(),
231                        _ => format!("thought_{index}"),
232                    };
233                    if let Some(chunk) = self.thinking_chunk(signature, thinking) {
234                        out.push(chunk);
235                    }
236                }
237                BetaRawContentBlockDelta::InputJson { partial_json } => {
238                    let mut tool_snapshot = None;
239                    if let Some(ClaudeBlockState::ToolUse {
240                        id,
241                        name,
242                        partial_json: accumulated,
243                    }) = self.blocks.get_mut(&index)
244                    {
245                        accumulated.push_str(&partial_json);
246                        tool_snapshot = Some((id.clone(), name.clone(), accumulated.clone()));
247                    }
248                    if let Some((id, name, arguments)) = tool_snapshot {
249                        out.push(self.function_call_chunk(id, name, arguments));
250                    }
251                }
252                BetaRawContentBlockDelta::Signature { signature } => {
253                    if let Some(ClaudeBlockState::Thinking { signature: sig }) =
254                        self.blocks.get_mut(&index)
255                    {
256                        *sig = signature;
257                    }
258                }
259                BetaRawContentBlockDelta::Compaction { content } => {
260                    if let Some(content) = content
261                        && let Some(chunk) = self.text_chunk(content)
262                    {
263                        out.push(chunk);
264                    }
265                }
266                BetaRawContentBlockDelta::Citations { .. } => {}
267            },
268            ClaudeStreamEvent::ContentBlockStop { index } => {
269                self.blocks.remove(&index);
270            }
271            ClaudeStreamEvent::MessageDelta {
272                delta,
273                usage,
274                context_management: _,
275            } => {
276                if let Some(input_tokens) = usage.input_tokens {
277                    self.input_tokens = input_tokens;
278                }
279                if let Some(cache_creation_input_tokens) = usage.cache_creation_input_tokens {
280                    self.cache_creation_input_tokens = cache_creation_input_tokens;
281                }
282                if let Some(cached_input_tokens) = usage.cache_read_input_tokens {
283                    self.cached_input_tokens = cached_input_tokens;
284                }
285                self.output_tokens = usage.output_tokens;
286                self.sync_usage_metadata();
287
288                let finish_reason = Self::finish_reason_from_stop_reason(delta.stop_reason);
289                let prompt_feedback = if matches!(finish_reason, GeminiFinishReason::Safety) {
290                    Some(GeminiPromptFeedback {
291                        block_reason: Some(GeminiBlockReason::Safety),
292                        safety_ratings: None,
293                    })
294                } else {
295                    None
296                };
297
298                out.push(self.chunk_from_parts(Vec::new(), Some(finish_reason), prompt_feedback));
299            }
300            ClaudeStreamEvent::MessageStop {} => {
301                self.finished = true;
302            }
303            ClaudeStreamEvent::Error { error } => {
304                let message = Self::error_message(error);
305                if let Some(chunk) = self.text_chunk(message) {
306                    out.push(chunk);
307                }
308                self.finished = true;
309            }
310            ClaudeStreamEvent::Ping {} => {}
311        }
312
313        Ok(())
314    }
315}