gproxy_protocol/transform/gemini/stream_generate_content/claude/
response.rs1use std::collections::BTreeMap;
2
3use crate::claude::create_message::stream::{BetaRawContentBlockDelta, ClaudeStreamEvent};
4use crate::claude::create_message::types::{BetaContentBlock, BetaStopReason};
5use crate::claude::types::BetaError;
6use crate::gemini::count_tokens::types::{GeminiContentRole, GeminiFunctionCall, GeminiPart};
7use crate::gemini::generate_content::response::ResponseBody as GeminiGenerateContentResponseBody;
8use crate::gemini::generate_content::types::{
9 GeminiBlockReason, GeminiCandidate, GeminiContent, GeminiFinishReason, GeminiPromptFeedback,
10 GeminiUsageMetadata,
11};
12use crate::transform::claude::utils::claude_model_to_string;
13use crate::transform::gemini::stream_generate_content::utils::parse_json_object_or_empty;
14use crate::transform::utils::TransformError;
15
/// Per-content-block bookkeeping kept while a Claude block is open.
///
/// An entry is created on `content_block_start` and dropped on
/// `content_block_stop`; only thinking and tool-use blocks carry data.
#[derive(Debug, Clone)]
enum ClaudeBlockState {
    /// A thinking block.
    Thinking {
        /// Signature taken from the block start and overwritten by any later
        /// signature delta for the same block.
        signature: String,
    },
    /// A tool-use block whose JSON input arrives in fragments.
    ToolUse {
        /// Tool-use identifier copied from the block start.
        id: String,
        /// Tool name copied from the block start.
        name: String,
        /// Concatenation of the input-JSON fragments received so far.
        partial_json: String,
    },
    /// Any other block kind; nothing needs to be remembered for it.
    Other,
}
28
29#[derive(Debug, Default, Clone)]
30pub struct ClaudeToGeminiStream {
31 response_id: Option<String>,
32 model_version: Option<String>,
33 input_tokens: u64,
34 cache_creation_input_tokens: u64,
35 cached_input_tokens: u64,
36 output_tokens: u64,
37 usage_metadata: Option<GeminiUsageMetadata>,
38 blocks: BTreeMap<u64, ClaudeBlockState>,
39 finished: bool,
40}
41
42impl ClaudeToGeminiStream {
43 pub fn is_finished(&self) -> bool {
44 self.finished
45 }
46
47 fn usage_from_counts(
48 input_tokens: u64,
49 cache_creation_tokens: u64,
50 cached_tokens: u64,
51 output_tokens: u64,
52 ) -> GeminiUsageMetadata {
53 let prompt_tokens = input_tokens.saturating_add(cache_creation_tokens);
54 GeminiUsageMetadata {
55 prompt_token_count: Some(prompt_tokens),
56 cached_content_token_count: Some(cached_tokens),
57 candidates_token_count: Some(output_tokens),
58 total_token_count: Some(
59 prompt_tokens
60 .saturating_add(cached_tokens)
61 .saturating_add(output_tokens),
62 ),
63 ..GeminiUsageMetadata::default()
64 }
65 }
66
67 fn sync_usage_metadata(&mut self) {
68 self.usage_metadata = Some(Self::usage_from_counts(
69 self.input_tokens,
70 self.cache_creation_input_tokens,
71 self.cached_input_tokens,
72 self.output_tokens,
73 ));
74 }
75
76 fn finish_reason_from_stop_reason(stop_reason: Option<BetaStopReason>) -> GeminiFinishReason {
77 match stop_reason {
78 Some(BetaStopReason::MaxTokens) | Some(BetaStopReason::ModelContextWindowExceeded) => {
79 GeminiFinishReason::MaxTokens
80 }
81 Some(BetaStopReason::ToolUse) => GeminiFinishReason::UnexpectedToolCall,
82 Some(BetaStopReason::Refusal) => GeminiFinishReason::Safety,
83 Some(BetaStopReason::Compaction) | Some(BetaStopReason::PauseTurn) => {
84 GeminiFinishReason::Other
85 }
86 Some(BetaStopReason::EndTurn) | Some(BetaStopReason::StopSequence) | None => {
87 GeminiFinishReason::Stop
88 }
89 }
90 }
91
92 fn error_message(error: BetaError) -> String {
93 match error {
94 BetaError::InvalidRequest(error) => error.message,
95 BetaError::Authentication(error) => error.message,
96 BetaError::Billing(error) => error.message,
97 BetaError::Permission(error) => error.message,
98 BetaError::NotFound(error) => error.message,
99 BetaError::RateLimit(error) => error.message,
100 BetaError::GatewayTimeout(error) => error.message,
101 BetaError::Api(error) => error.message,
102 BetaError::Overloaded(error) => error.message,
103 }
104 }
105
106 fn chunk_from_parts(
107 &self,
108 parts: Vec<GeminiPart>,
109 finish_reason: Option<GeminiFinishReason>,
110 prompt_feedback: Option<GeminiPromptFeedback>,
111 ) -> GeminiGenerateContentResponseBody {
112 GeminiGenerateContentResponseBody {
113 candidates: Some(vec![GeminiCandidate {
114 content: Some(GeminiContent {
115 parts,
116 role: Some(GeminiContentRole::Model),
117 }),
118 finish_reason,
119 index: Some(0),
120 ..GeminiCandidate::default()
121 }]),
122 prompt_feedback,
123 usage_metadata: self.usage_metadata.clone(),
124 model_version: self.model_version.clone(),
125 response_id: self.response_id.clone(),
126 model_status: None,
127 }
128 }
129
130 fn text_chunk(&self, text: String) -> Option<GeminiGenerateContentResponseBody> {
131 if text.is_empty() {
132 None
133 } else {
134 Some(self.chunk_from_parts(
135 vec![GeminiPart {
136 text: Some(text),
137 ..GeminiPart::default()
138 }],
139 None,
140 None,
141 ))
142 }
143 }
144
145 fn thinking_chunk(
146 &self,
147 signature: String,
148 thinking: String,
149 ) -> Option<GeminiGenerateContentResponseBody> {
150 if thinking.is_empty() {
151 None
152 } else {
153 Some(self.chunk_from_parts(
154 vec![GeminiPart {
155 thought: Some(true),
156 thought_signature: Some(signature),
157 text: Some(thinking),
158 ..GeminiPart::default()
159 }],
160 None,
161 None,
162 ))
163 }
164 }
165
166 fn function_call_chunk(
167 &self,
168 id: String,
169 name: String,
170 arguments: String,
171 ) -> GeminiGenerateContentResponseBody {
172 self.chunk_from_parts(
173 vec![GeminiPart {
174 function_call: Some(GeminiFunctionCall {
175 id: Some(id),
176 name,
177 args: Some(parse_json_object_or_empty(&arguments)),
178 }),
179 ..GeminiPart::default()
180 }],
181 None,
182 None,
183 )
184 }
185
186 pub fn on_event(
187 &mut self,
188 event: ClaudeStreamEvent,
189 out: &mut Vec<GeminiGenerateContentResponseBody>,
190 ) -> Result<(), TransformError> {
191 if self.finished {
192 return Ok(());
193 }
194
195 match event {
196 ClaudeStreamEvent::MessageStart { message } => {
197 self.response_id = Some(message.id);
198 self.model_version = Some(claude_model_to_string(&message.model));
199 self.input_tokens = message.usage.input_tokens;
200 self.cache_creation_input_tokens = message.usage.cache_creation_input_tokens;
201 self.cached_input_tokens = message.usage.cache_read_input_tokens;
202 self.output_tokens = message.usage.output_tokens;
203 self.sync_usage_metadata();
204 }
205 ClaudeStreamEvent::ContentBlockStart {
206 content_block,
207 index,
208 } => {
209 let state = match content_block {
210 BetaContentBlock::Thinking(block) => ClaudeBlockState::Thinking {
211 signature: block.signature,
212 },
213 BetaContentBlock::ToolUse(block) => ClaudeBlockState::ToolUse {
214 id: block.id,
215 name: block.name,
216 partial_json: String::new(),
217 },
218 _ => ClaudeBlockState::Other,
219 };
220 self.blocks.insert(index, state);
221 }
222 ClaudeStreamEvent::ContentBlockDelta { delta, index } => match delta {
223 BetaRawContentBlockDelta::Text { text } => {
224 if let Some(chunk) = self.text_chunk(text) {
225 out.push(chunk);
226 }
227 }
228 BetaRawContentBlockDelta::Thinking { thinking } => {
229 let signature = match self.blocks.get(&index) {
230 Some(ClaudeBlockState::Thinking { signature }) => signature.clone(),
231 _ => format!("thought_{index}"),
232 };
233 if let Some(chunk) = self.thinking_chunk(signature, thinking) {
234 out.push(chunk);
235 }
236 }
237 BetaRawContentBlockDelta::InputJson { partial_json } => {
238 let mut tool_snapshot = None;
239 if let Some(ClaudeBlockState::ToolUse {
240 id,
241 name,
242 partial_json: accumulated,
243 }) = self.blocks.get_mut(&index)
244 {
245 accumulated.push_str(&partial_json);
246 tool_snapshot = Some((id.clone(), name.clone(), accumulated.clone()));
247 }
248 if let Some((id, name, arguments)) = tool_snapshot {
249 out.push(self.function_call_chunk(id, name, arguments));
250 }
251 }
252 BetaRawContentBlockDelta::Signature { signature } => {
253 if let Some(ClaudeBlockState::Thinking { signature: sig }) =
254 self.blocks.get_mut(&index)
255 {
256 *sig = signature;
257 }
258 }
259 BetaRawContentBlockDelta::Compaction { content } => {
260 if let Some(content) = content
261 && let Some(chunk) = self.text_chunk(content)
262 {
263 out.push(chunk);
264 }
265 }
266 BetaRawContentBlockDelta::Citations { .. } => {}
267 },
268 ClaudeStreamEvent::ContentBlockStop { index } => {
269 self.blocks.remove(&index);
270 }
271 ClaudeStreamEvent::MessageDelta {
272 delta,
273 usage,
274 context_management: _,
275 } => {
276 if let Some(input_tokens) = usage.input_tokens {
277 self.input_tokens = input_tokens;
278 }
279 if let Some(cache_creation_input_tokens) = usage.cache_creation_input_tokens {
280 self.cache_creation_input_tokens = cache_creation_input_tokens;
281 }
282 if let Some(cached_input_tokens) = usage.cache_read_input_tokens {
283 self.cached_input_tokens = cached_input_tokens;
284 }
285 self.output_tokens = usage.output_tokens;
286 self.sync_usage_metadata();
287
288 let finish_reason = Self::finish_reason_from_stop_reason(delta.stop_reason);
289 let prompt_feedback = if matches!(finish_reason, GeminiFinishReason::Safety) {
290 Some(GeminiPromptFeedback {
291 block_reason: Some(GeminiBlockReason::Safety),
292 safety_ratings: None,
293 })
294 } else {
295 None
296 };
297
298 out.push(self.chunk_from_parts(Vec::new(), Some(finish_reason), prompt_feedback));
299 }
300 ClaudeStreamEvent::MessageStop {} => {
301 self.finished = true;
302 }
303 ClaudeStreamEvent::Error { error } => {
304 let message = Self::error_message(error);
305 if let Some(chunk) = self.text_chunk(message) {
306 out.push(chunk);
307 }
308 self.finished = true;
309 }
310 ClaudeStreamEvent::Ping {} => {}
311 }
312
313 Ok(())
314 }
315}