Skip to main content

ccs_proxy/provider/
codex.rs

1//! OpenAI Responses API streaming reassembler.
2//!
3//! Spec: <https://platform.openai.com/docs/api-reference/responses-streaming>
4
5use crate::capture::Usage;
6use serde_json::Value;
7
/// Accumulated view of one streamed Responses API reply.
///
/// Every field starts out empty (via `Default`) and is filled in by
/// `CodexReassembler` as events are processed.
#[derive(Debug, Default)]
pub struct CodexResponse {
    /// `response.id` copied from a response lifecycle event, if one was seen.
    pub id: Option<String>,
    /// `response.model` copied from a response lifecycle event, if one was seen.
    pub model: Option<String>,
    /// Most recent `response.status` string seen on a lifecycle event.
    pub status: Option<String>,
    /// One entry per `response.output_text.delta` event, in arrival order.
    pub text_parts: Vec<String>,
    /// Token usage parsed from the `usage` object of a response event, when
    /// one was received.
    pub usage: Option<Usage>,
}
16
17impl CodexResponse {
18    pub fn text_content(&self) -> String {
19        self.text_parts.concat()
20    }
21
22    pub fn to_json(&self) -> Value {
23        serde_json::json!({
24            "id": self.id,
25            "model": self.model,
26            "status": self.status,
27            "output_text": self.text_content(),
28            "usage": self.usage,
29        })
30    }
31}
32
/// Incremental parser that turns raw SSE bytes into a `CodexResponse`.
pub struct CodexReassembler {
    /// Bytes received so far that do not yet form a complete frame.
    buffer: Vec<u8>,
    /// Response state built up from the events processed so far.
    resp: CodexResponse,
    /// Number of frames handed to `process_frame`, including frames that
    /// carried no usable `data:` payload.
    frames_count: u64,
}
38
39impl Default for CodexReassembler {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl CodexReassembler {
46    pub fn new() -> Self {
47        Self {
48            buffer: Vec::new(),
49            resp: CodexResponse::default(),
50            frames_count: 0,
51        }
52    }
53
54    pub fn feed(&mut self, chunk: &[u8]) {
55        self.buffer.extend_from_slice(chunk);
56        while let Some(end) = find_double_newline(&self.buffer) {
57            let frame_bytes = self.buffer.drain(..end + 2).collect::<Vec<u8>>();
58            self.process_frame(&frame_bytes);
59        }
60    }
61
62    pub fn frames_count(&self) -> u64 {
63        self.frames_count
64    }
65
66    pub fn finish(mut self) -> Option<CodexResponse> {
67        if !self.buffer.is_empty() {
68            let leftover = std::mem::take(&mut self.buffer);
69            self.process_frame(&leftover);
70        }
71        if self.frames_count == 0 {
72            return None;
73        }
74        Some(self.resp)
75    }
76
77    fn process_frame(&mut self, raw: &[u8]) {
78        self.frames_count += 1;
79        let mut data_lines: Vec<&[u8]> = Vec::new();
80        for line in raw.split(|b| *b == b'\n') {
81            let line = strip_cr(line);
82            if let Some(rest) = line.strip_prefix(b"data:") {
83                let trimmed = trim_ascii_start(rest);
84                data_lines.push(trimmed);
85            }
86        }
87        if data_lines.is_empty() {
88            return;
89        }
90        let mut joined = Vec::new();
91        for (i, l) in data_lines.iter().enumerate() {
92            if i > 0 {
93                joined.push(b'\n');
94            }
95            joined.extend_from_slice(l);
96        }
97        let Ok(text) = std::str::from_utf8(&joined) else {
98            return;
99        };
100        let Ok(value) = serde_json::from_str::<Value>(text) else {
101            return;
102        };
103        self.apply_event(&value);
104    }
105
106    fn apply_event(&mut self, v: &Value) {
107        let ty = v.get("type").and_then(|x| x.as_str()).unwrap_or("");
108        match ty {
109            "response.created" | "response.in_progress" => {
110                if let Some(r) = v.get("response") {
111                    if let Some(id) = r.get("id").and_then(|x| x.as_str()) {
112                        self.resp.id = Some(id.to_string());
113                    }
114                    if let Some(model) = r.get("model").and_then(|x| x.as_str()) {
115                        self.resp.model = Some(model.to_string());
116                    }
117                    if let Some(status) = r.get("status").and_then(|x| x.as_str()) {
118                        self.resp.status = Some(status.to_string());
119                    }
120                }
121            }
122            "response.output_text.delta" => {
123                if let Some(delta) = v.get("delta").and_then(|x| x.as_str()) {
124                    self.resp.text_parts.push(delta.to_string());
125                }
126            }
127            "response.completed" => {
128                if let Some(r) = v.get("response") {
129                    if let Some(status) = r.get("status").and_then(|x| x.as_str()) {
130                        self.resp.status = Some(status.to_string());
131                    }
132                    if let Some(u) = r.get("usage") {
133                        self.resp.usage = Some(parse_usage(u));
134                    }
135                }
136            }
137            _ => {}
138        }
139    }
140}
141
142fn parse_usage(v: &Value) -> Usage {
143    Usage {
144        input_tokens: v.get("input_tokens").and_then(|x| x.as_u64()).unwrap_or(0),
145        output_tokens: v.get("output_tokens").and_then(|x| x.as_u64()).unwrap_or(0),
146        cache_creation_input_tokens: 0,
147        cache_read_input_tokens: v
148            .get("cache_read_input_tokens")
149            .and_then(|x| x.as_u64())
150            .unwrap_or(0),
151    }
152}
153
/// Locates the first SSE frame terminator in `buf`.
///
/// The returned index is chosen so that the caller's `drain(..idx + 2)`
/// removes the frame together with its full terminator:
/// * `"\n\n"` starting at `i` yields `i`;
/// * `"\r\n\r\n"` starting at `i` yields `i + 2`, so `drain(..i + 4)` covers
///   all four bytes.
///
/// Returns `None` when no terminator is present yet.
///
/// Inlined here rather than cross-imported from `super::claude` to avoid
/// polluting that module's public surface with a `#[doc(hidden)]` helper.
fn find_double_newline(buf: &[u8]) -> Option<usize> {
    (0..buf.len()).find_map(|start| {
        let tail = &buf[start..];
        if tail.starts_with(b"\n\n") {
            Some(start)
        } else if tail.starts_with(b"\r\n\r\n") {
            Some(start + 2)
        } else {
            None
        }
    })
}
172
/// Drops a single trailing `\r` from `line`, if present; otherwise returns
/// `line` unchanged.
fn strip_cr(line: &[u8]) -> &[u8] {
    line.strip_suffix(b"\r").unwrap_or(line)
}
180
/// Skips leading spaces and tabs (and only those) and returns the rest of
/// `s`. An input made up entirely of spaces/tabs yields an empty slice.
fn trim_ascii_start(s: &[u8]) -> &[u8] {
    let first_kept = s
        .iter()
        .position(|&byte| byte != b' ' && byte != b'\t')
        .unwrap_or(s.len());
    &s[first_kept..]
}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// The returned index plus two must land just past the terminator for
    /// both LF-only and CRLF streams.
    #[test]
    fn find_double_newline_handles_lf_and_crlf() {
        let cases: [(&[u8], Option<usize>); 2] = [
            (&b"abc\n\ndef"[..], Some(3)),
            (&b"abc\r\n\r\ndef"[..], Some(5)),
        ];
        for (input, expected) in cases {
            assert_eq!(find_double_newline(input), expected);
        }
    }
}