Skip to main content

ccs_proxy/provider/
claude.rs

1//! Anthropic Messages reassembler (server-sent events).
2//!
3//! Spec: <https://docs.anthropic.com/en/api/messages-streaming>
4
5use crate::capture::Usage;
6use serde_json::Value;
7
8#[derive(Debug, Default)]
9pub struct ClaudeMessage {
10    pub model: Option<String>,
11    pub stop_reason: Option<String>,
12    pub content_blocks: Vec<ContentBlock>,
13    pub usage: Option<Usage>,
14}
15
16#[derive(Debug)]
17pub enum ContentBlock {
18    Text(String),
19    ToolUse {
20        id: String,
21        name: String,
22        input: Value,
23    },
24}
25
26impl ClaudeMessage {
27    pub fn text_content(&self) -> String {
28        let mut out = String::new();
29        for b in &self.content_blocks {
30            if let ContentBlock::Text(t) = b {
31                out.push_str(t);
32            }
33        }
34        out
35    }
36
37    pub fn to_json(&self) -> Value {
38        let blocks: Vec<Value> = self
39            .content_blocks
40            .iter()
41            .map(|b| match b {
42                ContentBlock::Text(t) => serde_json::json!({"type":"text","text":t}),
43                ContentBlock::ToolUse { id, name, input } => serde_json::json!({
44                    "type":"tool_use","id":id,"name":name,"input":input
45                }),
46            })
47            .collect();
48        serde_json::json!({
49            "model": self.model,
50            "stop_reason": self.stop_reason,
51            "content": blocks,
52            "usage": self.usage,
53        })
54    }
55}
56
57pub struct ClaudeReassembler {
58    buffer: Vec<u8>,
59    msg: ClaudeMessage,
60    frames_count: u64,
61    saw_message_stop: bool,
62}
63
64impl Default for ClaudeReassembler {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl ClaudeReassembler {
71    pub fn new() -> Self {
72        Self {
73            buffer: Vec::new(),
74            msg: ClaudeMessage::default(),
75            frames_count: 0,
76            saw_message_stop: false,
77        }
78    }
79
80    pub fn feed(&mut self, chunk: &[u8]) {
81        self.buffer.extend_from_slice(chunk);
82        while let Some(end) = find_double_newline(&self.buffer) {
83            let frame_bytes = self.buffer.drain(..end + 2).collect::<Vec<u8>>();
84            // SSE frames are blank-line-terminated. drain takes through the \n\n.
85            self.process_frame(&frame_bytes);
86        }
87    }
88
89    pub fn frames_count(&self) -> u64 {
90        self.frames_count
91    }
92
93    pub fn saw_message_stop(&self) -> bool {
94        self.saw_message_stop
95    }
96
97    pub fn finish(mut self) -> Option<ClaudeMessage> {
98        // Process anything still in buffer (no trailing blank line case).
99        if !self.buffer.is_empty() {
100            let leftover = std::mem::take(&mut self.buffer);
101            self.process_frame(&leftover);
102        }
103        if self.frames_count == 0 {
104            return None;
105        }
106        Some(self.msg)
107    }
108
109    fn process_frame(&mut self, raw: &[u8]) {
110        self.frames_count += 1;
111        // Each frame has lines like "event: foo" and "data: {...}".
112        let mut data_lines: Vec<&[u8]> = Vec::new();
113        for line in raw.split(|b| *b == b'\n') {
114            let line = strip_cr(line);
115            if let Some(rest) = line.strip_prefix(b"data:") {
116                let trimmed = trim_ascii_start(rest);
117                data_lines.push(trimmed);
118            }
119        }
120        if data_lines.is_empty() {
121            return;
122        }
123        let mut joined = Vec::new();
124        for (i, l) in data_lines.iter().enumerate() {
125            if i > 0 {
126                joined.push(b'\n');
127            }
128            joined.extend_from_slice(l);
129        }
130        let Ok(text) = std::str::from_utf8(&joined) else {
131            return;
132        };
133        let Ok(value) = serde_json::from_str::<Value>(text) else {
134            return;
135        };
136        self.apply_event(&value);
137    }
138
139    fn apply_event(&mut self, v: &Value) {
140        let Some(ty) = v.get("type").and_then(|t| t.as_str()) else {
141            return;
142        };
143        match ty {
144            "message_start" => {
145                if let Some(m) = v.get("message") {
146                    if let Some(model) = m.get("model").and_then(|x| x.as_str()) {
147                        self.msg.model = Some(model.to_string());
148                    }
149                    if let Some(u) = m.get("usage") {
150                        self.msg.usage = parse_usage(u);
151                    }
152                }
153            }
154            "content_block_start" => {
155                if let Some(cb) = v.get("content_block") {
156                    let kind = cb.get("type").and_then(|x| x.as_str()).unwrap_or("");
157                    match kind {
158                        "text" => self
159                            .msg
160                            .content_blocks
161                            .push(ContentBlock::Text(String::new())),
162                        "tool_use" => self.msg.content_blocks.push(ContentBlock::ToolUse {
163                            id: cb
164                                .get("id")
165                                .and_then(|x| x.as_str())
166                                .unwrap_or("")
167                                .to_string(),
168                            name: cb
169                                .get("name")
170                                .and_then(|x| x.as_str())
171                                .unwrap_or("")
172                                .to_string(),
173                            input: cb.get("input").cloned().unwrap_or(Value::Null),
174                        }),
175                        _ => self
176                            .msg
177                            .content_blocks
178                            .push(ContentBlock::Text(String::new())),
179                    }
180                }
181            }
182            "content_block_delta" => {
183                if let Some(delta) = v.get("delta") {
184                    let delta_type = delta.get("type").and_then(|x| x.as_str()).unwrap_or("");
185                    let idx = v.get("index").and_then(|x| x.as_u64()).unwrap_or(0) as usize;
186                    if let Some(block) = self.msg.content_blocks.get_mut(idx) {
187                        match (block, delta_type) {
188                            (ContentBlock::Text(s), "text_delta") => {
189                                if let Some(t) = delta.get("text").and_then(|x| x.as_str()) {
190                                    s.push_str(t);
191                                }
192                            }
193                            (ContentBlock::ToolUse { input, .. }, "input_json_delta") => {
194                                if let Some(partial) =
195                                    delta.get("partial_json").and_then(|x| x.as_str())
196                                {
197                                    // Accumulate raw partial JSON in a string under input,
198                                    // serialized as string fragment list. v1 stores last seen.
199                                    let key = "_partial".to_string();
200                                    if let Value::Null = input {
201                                        *input = Value::Object(Default::default());
202                                    }
203                                    if let Value::Object(m) = input {
204                                        let cur = m
205                                            .entry(key)
206                                            .or_insert_with(|| Value::String(String::new()));
207                                        if let Value::String(s) = cur {
208                                            s.push_str(partial);
209                                        }
210                                    }
211                                }
212                            }
213                            _ => {}
214                        }
215                    }
216                }
217            }
218            "message_delta" => {
219                if let Some(d) = v.get("delta")
220                    && let Some(sr) = d.get("stop_reason").and_then(|x| x.as_str())
221                {
222                    self.msg.stop_reason = Some(sr.to_string());
223                }
224                if let Some(u) = v.get("usage") {
225                    if let Some(existing) = self.msg.usage.as_mut() {
226                        if let Some(ot) = u.get("output_tokens").and_then(|x| x.as_u64()) {
227                            existing.output_tokens = ot;
228                        }
229                    } else {
230                        self.msg.usage = parse_usage(u);
231                    }
232                }
233            }
234            "message_stop" => {
235                self.saw_message_stop = true;
236            }
237            _ => {}
238        }
239    }
240}
241
242fn parse_usage(v: &Value) -> Option<Usage> {
243    let mut u = Usage::default();
244    if let Some(x) = v.get("input_tokens").and_then(|x| x.as_u64()) {
245        u.input_tokens = x;
246    }
247    if let Some(x) = v.get("output_tokens").and_then(|x| x.as_u64()) {
248        u.output_tokens = x;
249    }
250    if let Some(x) = v
251        .get("cache_creation_input_tokens")
252        .and_then(|x| x.as_u64())
253    {
254        u.cache_creation_input_tokens = x;
255    }
256    if let Some(x) = v.get("cache_read_input_tokens").and_then(|x| x.as_u64()) {
257        u.cache_read_input_tokens = x;
258    }
259    Some(u)
260}
261
/// Locates the first blank-line frame terminator in `buf`.
///
/// Returns an index `idx` such that `idx + 2` is one past the end of the
/// terminator, so a caller can consume the frame and its terminator with
/// `drain(..idx + 2)`. Returns `None` when no complete terminator is present
/// yet.
///
/// A terminator is a line feed that ends one line, immediately followed by
/// an empty line ending in either LF or CRLF. That covers `"\n\n"`,
/// `"\r\n\r\n"`, and the mixed forms `"\r\n\n"` and `"\n\r\n"`.
fn find_double_newline(buf: &[u8]) -> Option<usize> {
    for (i, &byte) in buf.iter().enumerate() {
        if byte != b'\n' {
            continue;
        }
        let rest = &buf[i + 1..];
        if rest.starts_with(b"\n") {
            // LF LF: the terminator ends at i + 2.
            return Some(i);
        }
        if rest.starts_with(b"\r\n") {
            // LF CR LF: the terminator ends at i + 3. For "\r\n\r\n" this LF
            // is the second byte of the sequence, giving the documented
            // "start + 2".
            return Some(i + 1);
        }
    }
    None
}
281
/// Removes a single trailing carriage return from `line`, if there is one.
///
/// Lines are split on `\n`, so a CRLF-terminated line arrives here with its
/// `\r` still attached.
fn strip_cr(line: &[u8]) -> &[u8] {
    line.strip_suffix(b"\r").unwrap_or(line)
}
291
/// Returns `s` without its leading spaces and horizontal tabs.
///
/// Only `b' '` and `b'\t'` are skipped; other whitespace bytes are kept.
fn trim_ascii_start(s: &[u8]) -> &[u8] {
    let first_kept = s
        .iter()
        .position(|b| !matches!(b, b' ' | b'\t'))
        .unwrap_or(s.len());
    &s[first_kept..]
}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// LF-terminated frame: the returned index plus two must consume the
    /// frame together with its blank-line terminator.
    #[test]
    fn find_double_newline_handles_lf() {
        let mut buf = b"abc\n\ndef".to_vec();
        let end = find_double_newline(&buf).expect("expected LF separator");
        assert_eq!(end, 3);
        // Drain the same way the reassembler's `feed` does.
        let frame: Vec<u8> = buf.drain(..end + 2).collect();
        assert_eq!(frame, b"abc\n\n");
        assert_eq!(buf, b"def");
    }

    /// CRLF-terminated frame: all four terminator bytes must be consumed,
    /// leaving no stray `\n` behind.
    #[test]
    fn find_double_newline_handles_crlf() {
        let mut buf = b"abc\r\n\r\ndef".to_vec();
        let end = find_double_newline(&buf).expect("expected CRLF separator");
        assert_eq!(end, 5);
        let frame: Vec<u8> = buf.drain(..end + 2).collect();
        assert_eq!(frame, b"abc\r\n\r\n");
        assert_eq!(buf, b"def");
    }

    /// A terminator that has only partly arrived must not be reported.
    #[test]
    fn find_double_newline_waits_for_complete_terminator() {
        assert_eq!(find_double_newline(b""), None);
        assert_eq!(find_double_newline(b"abc"), None);
        assert_eq!(find_double_newline(b"abc\n"), None);
        assert_eq!(find_double_newline(b"abc\r\n"), None);
        assert_eq!(find_double_newline(b"abc\r\n\r"), None);
    }

    #[test]
    fn crlf_terminated_frame_reassembles() {
        let raw: &[u8] = b"event: message_start\r\n\
            data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg\",\"type\":\"message\",\"role\":\"assistant\",\"model\":\"crlf-test\",\"content\":[],\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}\r\n\r\n\
            event: message_stop\r\n\
            data: {\"type\":\"message_stop\"}\r\n\r\n";
        let mut r = ClaudeReassembler::new();
        r.feed(raw);
        // Both frames were complete, so both are processed by `feed` itself.
        assert_eq!(r.frames_count(), 2);
        assert!(r.saw_message_stop());
        let out = r.finish().expect("message");
        assert_eq!(out.model.as_deref(), Some("crlf-test"));
        assert_eq!(out.stop_reason, None);
        assert_eq!(out.text_content(), "");
    }

    /// Frames split at arbitrary chunk boundaries must reassemble the same
    /// way as frames delivered whole.
    #[test]
    fn lf_text_stream_split_across_chunks_reassembles() {
        let raw: &[u8] = b"data: {\"type\":\"message_start\",\"message\":{\"model\":\"m\",\"usage\":{\"input_tokens\":1,\"output_tokens\":1}}}\n\n\
            data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n\
            data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hel\"}}\n\n\
            data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"lo\"}}\n\n\
            data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":5}}\n\n\
            data: {\"type\":\"message_stop\"}\n\n";
        let mut r = ClaudeReassembler::new();
        for chunk in raw.chunks(7) {
            r.feed(chunk);
        }
        assert_eq!(r.frames_count(), 6);
        assert!(r.saw_message_stop());
        let out = r.finish().expect("message");
        assert_eq!(out.model.as_deref(), Some("m"));
        assert_eq!(out.stop_reason.as_deref(), Some("end_turn"));
        assert_eq!(out.text_content(), "Hello");
        let usage = out.usage.expect("usage");
        assert_eq!(usage.input_tokens, 1);
        assert_eq!(usage.output_tokens, 5);
    }

    /// With no input at all there is no message to return.
    #[test]
    fn finish_without_input_returns_none() {
        assert!(ClaudeReassembler::new().finish().is_none());
    }
}