Skip to main content

gproxy_protocol/transform/claude/
stream_to_nonstream.rs

1use std::collections::BTreeMap;
2
3use http::StatusCode;
4
5use crate::claude::create_message::response::ClaudeCreateMessageResponse;
6use crate::claude::create_message::stream::{BetaRawContentBlockDelta, ClaudeStreamEvent};
7use crate::claude::create_message::types::{
8    BetaContentBlock, BetaErrorResponse, BetaErrorResponseType, BetaMessage, BetaTextBlock,
9    BetaThinkingBlock, BetaToolUseBlock, JsonObject,
10};
11use crate::claude::types::{BetaError, ClaudeResponseHeaders};
12use crate::transform::utils::TransformError;
13
14#[derive(Debug, Clone)]
15enum PendingBlock {
16    Text(BetaTextBlock),
17    Thinking(BetaThinkingBlock),
18    ToolUse {
19        block: BetaToolUseBlock,
20        input_json_buf: String,
21    },
22    Other(BetaContentBlock),
23}
24
25impl PendingBlock {
26    fn apply_delta(&mut self, delta: BetaRawContentBlockDelta) {
27        match (self, delta) {
28            (Self::Text(block), BetaRawContentBlockDelta::Text { text }) => {
29                block.text.push_str(&text);
30            }
31            (Self::Text(block), BetaRawContentBlockDelta::Citations { citation }) => {
32                if let Some(citations) = block.citations.as_mut() {
33                    citations.push(citation);
34                } else {
35                    block.citations = Some(vec![citation]);
36                }
37            }
38            (Self::Thinking(block), BetaRawContentBlockDelta::Thinking { thinking }) => {
39                block.thinking.push_str(&thinking);
40            }
41            (Self::Thinking(block), BetaRawContentBlockDelta::Signature { signature }) => {
42                block.signature = signature;
43            }
44            (
45                Self::ToolUse { input_json_buf, .. },
46                BetaRawContentBlockDelta::InputJson { partial_json },
47            ) => {
48                input_json_buf.push_str(&partial_json);
49            }
50            (
51                Self::Other(BetaContentBlock::Compaction(block)),
52                BetaRawContentBlockDelta::Compaction { content },
53            ) => {
54                block.content = content;
55            }
56            _ => {}
57        }
58    }
59
60    fn into_content_block(self) -> BetaContentBlock {
61        match self {
62            Self::Text(block) => BetaContentBlock::Text(block),
63            Self::Thinking(block) => BetaContentBlock::Thinking(block),
64            Self::ToolUse {
65                mut block,
66                input_json_buf,
67            } => {
68                if !input_json_buf.is_empty() {
69                    block.input =
70                        serde_json::from_str::<JsonObject>(&input_json_buf).unwrap_or_default();
71                }
72                BetaContentBlock::ToolUse(block)
73            }
74            Self::Other(block) => block,
75        }
76    }
77}
78
79fn pending_from_content_block(content_block: BetaContentBlock) -> PendingBlock {
80    match content_block {
81        BetaContentBlock::Text(block) => PendingBlock::Text(block),
82        BetaContentBlock::Thinking(block) => PendingBlock::Thinking(block),
83        BetaContentBlock::ToolUse(block) => PendingBlock::ToolUse {
84            block,
85            input_json_buf: String::new(),
86        },
87        other => PendingBlock::Other(other),
88    }
89}
90
91fn status_code_from_stream_error(error: &BetaError) -> StatusCode {
92    match error {
93        BetaError::InvalidRequest(_) => StatusCode::BAD_REQUEST,
94        BetaError::Authentication(_) => StatusCode::UNAUTHORIZED,
95        BetaError::Billing(_) => StatusCode::PAYMENT_REQUIRED,
96        BetaError::Permission(_) => StatusCode::FORBIDDEN,
97        BetaError::NotFound(_) => StatusCode::NOT_FOUND,
98        BetaError::RateLimit(_) => StatusCode::TOO_MANY_REQUESTS,
99        BetaError::GatewayTimeout(_) => StatusCode::GATEWAY_TIMEOUT,
100        BetaError::Api(_) => StatusCode::INTERNAL_SERVER_ERROR,
101        BetaError::Overloaded(_) => {
102            StatusCode::from_u16(529).unwrap_or(StatusCode::SERVICE_UNAVAILABLE)
103        }
104    }
105}
106
107impl TryFrom<Vec<ClaudeStreamEvent>> for ClaudeCreateMessageResponse {
108    type Error = TransformError;
109
110    fn try_from(value: Vec<ClaudeStreamEvent>) -> Result<Self, TransformError> {
111        let mut message: Option<BetaMessage> = None;
112        let mut open_blocks: BTreeMap<u64, PendingBlock> = BTreeMap::new();
113        let mut closed_blocks: BTreeMap<u64, BetaContentBlock> = BTreeMap::new();
114
115        for event in value {
116            match event {
117                ClaudeStreamEvent::MessageStart { message: msg } => {
118                    if message.is_some() {
119                        return Err(TransformError::not_implemented(
120                            "multiple message_start events are not supported",
121                        ));
122                    }
123                    message = Some(msg);
124                }
125                ClaudeStreamEvent::ContentBlockStart {
126                    content_block,
127                    index,
128                } => {
129                    open_blocks.insert(index, pending_from_content_block(content_block));
130                }
131                ClaudeStreamEvent::ContentBlockDelta { delta, index } => {
132                    let Some(block) = open_blocks.get_mut(&index) else {
133                        return Err(TransformError::not_implemented(
134                            "content_block_delta received before content_block_start",
135                        ));
136                    };
137                    block.apply_delta(delta);
138                }
139                ClaudeStreamEvent::ContentBlockStop { index } => {
140                    let Some(block) = open_blocks.remove(&index) else {
141                        return Err(TransformError::not_implemented(
142                            "content_block_stop received before content_block_start",
143                        ));
144                    };
145                    closed_blocks.insert(index, block.into_content_block());
146                }
147                ClaudeStreamEvent::MessageDelta {
148                    context_management,
149                    delta,
150                    usage,
151                } => {
152                    let Some(message) = message.as_mut() else {
153                        return Err(TransformError::not_implemented(
154                            "message_delta received before message_start",
155                        ));
156                    };
157
158                    if let Some(context_management) = context_management {
159                        message.context_management = Some(context_management);
160                    }
161
162                    message.stop_reason = delta.stop_reason;
163                    message.stop_sequence = delta.stop_sequence;
164                    if let Some(container) = delta.container {
165                        message.container = Some(container);
166                    }
167
168                    if let Some(input_tokens) = usage.input_tokens {
169                        message.usage.input_tokens = input_tokens;
170                    }
171                    if let Some(cache_read_input_tokens) = usage.cache_read_input_tokens {
172                        message.usage.cache_read_input_tokens = cache_read_input_tokens;
173                    }
174                    if let Some(cache_creation_input_tokens) = usage.cache_creation_input_tokens {
175                        message.usage.cache_creation_input_tokens = cache_creation_input_tokens;
176                    }
177                    if let Some(iterations) = usage.iterations {
178                        message.usage.iterations = iterations;
179                    }
180                    if let Some(server_tool_use) = usage.server_tool_use {
181                        message.usage.server_tool_use = server_tool_use;
182                    }
183                    message.usage.output_tokens = usage.output_tokens;
184                }
185                ClaudeStreamEvent::MessageStop {} => {}
186                ClaudeStreamEvent::Ping {} => {}
187                ClaudeStreamEvent::Error { error } => {
188                    return Ok(ClaudeCreateMessageResponse::Error {
189                        stats_code: status_code_from_stream_error(&error),
190                        headers: ClaudeResponseHeaders {
191                            extra: BTreeMap::new(),
192                        },
193                        body: BetaErrorResponse {
194                            error,
195                            request_id: String::new(),
196                            type_: BetaErrorResponseType::Error,
197                        },
198                    });
199                }
200            }
201        }
202
203        let Some(mut message) = message else {
204            return Err(TransformError::not_implemented(
205                "message_start event is required for stream_to_nonstream conversion",
206            ));
207        };
208
209        for (index, block) in open_blocks {
210            closed_blocks.insert(index, block.into_content_block());
211        }
212
213        if !closed_blocks.is_empty() {
214            message.content = closed_blocks.into_values().collect();
215        }
216
217        Ok(ClaudeCreateMessageResponse::Success {
218            stats_code: StatusCode::OK,
219            headers: ClaudeResponseHeaders {
220                extra: BTreeMap::new(),
221            },
222            body: message,
223        })
224    }
225}