Skip to main content

gproxy_protocol/transform/gemini/generate_content/claude/
response.rs

1use crate::claude::count_tokens::types::BetaServerToolUseName;
2use crate::claude::create_message::response::ClaudeCreateMessageResponse;
3use crate::claude::create_message::types::{BetaContentBlock, BetaStopReason};
4use crate::gemini::count_tokens::types::{GeminiContentRole, GeminiFunctionCall, GeminiPart};
5use crate::gemini::generate_content::response::{GeminiGenerateContentResponse, ResponseBody};
6use crate::gemini::generate_content::types::{
7    GeminiCandidate, GeminiContent, GeminiFinishReason, GeminiUsageMetadata,
8};
9use crate::gemini::types::GeminiResponseHeaders;
10use crate::transform::claude::utils::claude_model_to_string;
11use crate::transform::gemini::generate_content::utils::gemini_error_response_from_claude;
12use crate::transform::utils::TransformError;
13
14impl TryFrom<ClaudeCreateMessageResponse> for GeminiGenerateContentResponse {
15    type Error = TransformError;
16
17    fn try_from(value: ClaudeCreateMessageResponse) -> Result<Self, TransformError> {
18        Ok(match value {
19            ClaudeCreateMessageResponse::Success {
20                stats_code,
21                headers,
22                body,
23            } => {
24                let mut parts = Vec::new();
25                for block in body.content {
26                    match block {
27                        BetaContentBlock::Text(block) if !block.text.is_empty() => {
28                            parts.push(GeminiPart {
29                                text: Some(block.text),
30                                ..GeminiPart::default()
31                            });
32                        }
33                        BetaContentBlock::Thinking(block) if !block.thinking.is_empty() => {
34                            parts.push(GeminiPart {
35                                thought: Some(true),
36                                thought_signature: Some(block.signature),
37                                text: Some(block.thinking),
38                                ..GeminiPart::default()
39                            });
40                        }
41                        BetaContentBlock::ToolUse(block) => {
42                            parts.push(GeminiPart {
43                                function_call: Some(GeminiFunctionCall {
44                                    id: Some(block.id),
45                                    name: block.name,
46                                    args: Some(block.input),
47                                }),
48                                ..GeminiPart::default()
49                            });
50                        }
51                        BetaContentBlock::ServerToolUse(block) => {
52                            let name = match block.name {
53                                BetaServerToolUseName::WebSearch => "web_search",
54                                BetaServerToolUseName::WebFetch => "web_fetch",
55                                BetaServerToolUseName::CodeExecution => "code_execution",
56                                BetaServerToolUseName::BashCodeExecution => "bash_code_execution",
57                                BetaServerToolUseName::TextEditorCodeExecution => {
58                                    "text_editor_code_execution"
59                                }
60                                BetaServerToolUseName::ToolSearchToolRegex => "tool_search_regex",
61                                BetaServerToolUseName::ToolSearchToolBm25 => "tool_search_bm25",
62                            }
63                            .to_string();
64                            parts.push(GeminiPart {
65                                function_call: Some(GeminiFunctionCall {
66                                    id: Some(block.id),
67                                    name,
68                                    args: Some(block.input),
69                                }),
70                                ..GeminiPart::default()
71                            });
72                        }
73                        BetaContentBlock::McpToolUse(block) => {
74                            parts.push(GeminiPart {
75                                function_call: Some(GeminiFunctionCall {
76                                    id: Some(block.id),
77                                    name: format!("mcp:{}:{}", block.server_name, block.name),
78                                    args: Some(block.input),
79                                }),
80                                ..GeminiPart::default()
81                            });
82                        }
83                        _ => {}
84                    }
85                }
86
87                if parts.is_empty() {
88                    parts.push(GeminiPart {
89                        text: Some(String::new()),
90                        ..GeminiPart::default()
91                    });
92                }
93
94                let finish_reason = match body.stop_reason {
95                    Some(BetaStopReason::MaxTokens)
96                    | Some(BetaStopReason::ModelContextWindowExceeded) => {
97                        Some(GeminiFinishReason::MaxTokens)
98                    }
99                    Some(BetaStopReason::ToolUse) => Some(GeminiFinishReason::UnexpectedToolCall),
100                    Some(BetaStopReason::Refusal) => Some(GeminiFinishReason::Safety),
101                    Some(BetaStopReason::Compaction) => Some(GeminiFinishReason::Other),
102                    Some(BetaStopReason::PauseTurn) => Some(GeminiFinishReason::Other),
103                    Some(BetaStopReason::EndTurn) | Some(BetaStopReason::StopSequence) | None => {
104                        Some(GeminiFinishReason::Stop)
105                    }
106                };
107
108                let usage_metadata = GeminiUsageMetadata {
109                    prompt_token_count: Some(
110                        body.usage
111                            .input_tokens
112                            .saturating_add(body.usage.cache_creation_input_tokens),
113                    ),
114                    cached_content_token_count: Some(body.usage.cache_read_input_tokens),
115                    candidates_token_count: Some(body.usage.output_tokens),
116                    total_token_count: Some(
117                        body.usage
118                            .input_tokens
119                            .saturating_add(body.usage.cache_creation_input_tokens)
120                            .saturating_add(body.usage.cache_read_input_tokens)
121                            .saturating_add(body.usage.output_tokens),
122                    ),
123                    ..GeminiUsageMetadata::default()
124                };
125
126                GeminiGenerateContentResponse::Success {
127                    stats_code,
128                    headers: GeminiResponseHeaders {
129                        extra: headers.extra,
130                    },
131                    body: ResponseBody {
132                        candidates: Some(vec![GeminiCandidate {
133                            content: Some(GeminiContent {
134                                parts,
135                                role: Some(GeminiContentRole::Model),
136                            }),
137                            finish_reason,
138                            token_count: Some(body.usage.output_tokens),
139                            index: Some(0),
140                            ..GeminiCandidate::default()
141                        }]),
142                        prompt_feedback: None,
143                        usage_metadata: Some(usage_metadata),
144                        model_version: Some(claude_model_to_string(&body.model)),
145                        response_id: Some(body.id),
146                        model_status: None,
147                    },
148                }
149            }
150            ClaudeCreateMessageResponse::Error {
151                stats_code,
152                headers,
153                body,
154            } => GeminiGenerateContentResponse::Error {
155                stats_code,
156                headers: GeminiResponseHeaders {
157                    extra: headers.extra,
158                },
159                body: gemini_error_response_from_claude(stats_code, body),
160            },
161        })
162    }
163}