Skip to main content

codetether_agent/provider/bedrock/
eventstream.rs

1//! Minimal parser for the AWS `vnd.amazon.eventstream` binary frame format.
2//!
3//! Each frame layout:
4//!
5//! ```text
//! +----------------------------------------------------------+
//! | total_len:    u32 BE                                     |
//! | headers_len:  u32 BE                                     |
//! | prelude_crc:  u32 BE  (ignored — TLS provides integrity) |
//! | headers:      headers_len bytes                          |
//! | payload:      total_len - headers_len - 16 bytes         |
//! | message_crc:  u32 BE  (ignored)                          |
//! +----------------------------------------------------------+
14//! ```
15//!
16//! Each header: `name_len u8`, `name utf8`, `value_type u8`, then a
17//! type-specific value. We only need type `7` (UTF-8 string, `u16 BE` length).
18//!
19//! This implementation is intentionally minimal: CRC validation is skipped
20//! (TLS already protects integrity), and only string-typed headers are
21//! decoded since Bedrock Converse only emits those.
22//!
23//! # Examples
24//!
25//! ```rust
26//! use codetether_agent::provider::bedrock::eventstream::FrameBuffer;
27//!
28//! // Build a tiny frame: no headers, payload = "hi".
29//! let payload = b"hi";
30//! let total_len: u32 = 16 + payload.len() as u32;
31//! let mut frame = Vec::new();
32//! frame.extend_from_slice(&total_len.to_be_bytes());
33//! frame.extend_from_slice(&0u32.to_be_bytes()); // headers_len
34//! frame.extend_from_slice(&0u32.to_be_bytes()); // prelude_crc (ignored)
35//! frame.extend_from_slice(payload);
36//! frame.extend_from_slice(&0u32.to_be_bytes()); // message_crc (ignored)
37//!
38//! let mut buf = FrameBuffer::new();
39//! buf.extend(&frame);
40//! let msg = buf.next_frame().unwrap().expect("one frame");
41//! assert_eq!(msg.payload, b"hi");
42//! assert!(msg.headers.is_empty());
43//! assert!(buf.next_frame().unwrap().is_none());
44//! ```
45
46use anyhow::{Result, anyhow};
47use std::collections::HashMap;
48
/// A single frame decoded from an AWS eventstream: its headers plus payload.
#[derive(Clone, Debug, Default)]
pub struct EventMessage {
    /// Header name → value, restricted to string-typed headers; values of
    /// every other header type are dropped while parsing.
    pub headers: HashMap<String, String>,
    /// Undecoded payload bytes whose meaning is up to the application
    /// (typically JSON).
    pub payload: Vec<u8>,
}
57
58impl EventMessage {
59    /// Return the `:event-type` header (event variant name).
60    pub fn event_type(&self) -> Option<&str> {
61        self.headers.get(":event-type").map(String::as_str)
62    }
63
64    /// Return the `:message-type` header (`event`, `exception`, `error`).
65    pub fn message_type(&self) -> Option<&str> {
66        self.headers.get(":message-type").map(String::as_str)
67    }
68}
69
/// Incremental decoder state for an eventstream byte stream.
///
/// Push transport chunks of any size with [`FrameBuffer::extend`], then call
/// [`FrameBuffer::next_frame`] repeatedly to drain every complete message.
#[derive(Default, Debug)]
pub struct FrameBuffer {
    /// Bytes received but not yet consumed as part of a complete frame.
    buf: Vec<u8>,
}
76
77impl FrameBuffer {
78    /// Create an empty buffer.
79    pub fn new() -> Self {
80        Self::default()
81    }
82
83    /// Append more bytes from the transport.
84    pub fn extend(&mut self, chunk: &[u8]) {
85        self.buf.extend_from_slice(chunk);
86    }
87
88    /// Try to parse the next full frame from the buffer.
89    ///
90    /// Returns:
91    /// - `Ok(Some(msg))` — a complete message was decoded and removed.
92    /// - `Ok(None)` — need more bytes.
93    /// - `Err(_)` — malformed frame.
94    pub fn next_frame(&mut self) -> Result<Option<EventMessage>> {
95        if self.buf.len() < 12 {
96            return Ok(None);
97        }
98        let total_len = u32::from_be_bytes(self.buf[0..4].try_into().unwrap()) as usize;
99        let headers_len = u32::from_be_bytes(self.buf[4..8].try_into().unwrap()) as usize;
100
101        if total_len < 16 || headers_len + 16 > total_len {
102            return Err(anyhow!(
103                "invalid eventstream frame: total_len={total_len}, headers_len={headers_len}"
104            ));
105        }
106        if self.buf.len() < total_len {
107            return Ok(None);
108        }
109
110        // prelude ends at 12; headers [12, 12+headers_len); payload
111        // [12+headers_len, total_len-4); trailing CRC 4 bytes.
112        let headers_start = 12usize;
113        let headers_end = headers_start + headers_len;
114        let payload_end = total_len - 4;
115
116        let headers = parse_headers(&self.buf[headers_start..headers_end])?;
117        let payload = self.buf[headers_end..payload_end].to_vec();
118
119        self.buf.drain(..total_len);
120        Ok(Some(EventMessage { headers, payload }))
121    }
122}
123
124fn parse_headers(mut bytes: &[u8]) -> Result<HashMap<String, String>> {
125    let mut out = HashMap::new();
126    while !bytes.is_empty() {
127        if bytes.is_empty() {
128            break;
129        }
130        let name_len = bytes[0] as usize;
131        bytes = &bytes[1..];
132        if bytes.len() < name_len + 1 {
133            return Err(anyhow!("truncated header name"));
134        }
135        let name = std::str::from_utf8(&bytes[..name_len])
136            .map_err(|e| anyhow!("bad header name utf8: {e}"))?
137            .to_string();
138        bytes = &bytes[name_len..];
139        let value_type = bytes[0];
140        bytes = &bytes[1..];
141
142        match value_type {
143            // UTF-8 string, u16 BE length
144            7 => {
145                if bytes.len() < 2 {
146                    return Err(anyhow!("truncated header value length"));
147                }
148                let vlen = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
149                bytes = &bytes[2..];
150                if bytes.len() < vlen {
151                    return Err(anyhow!("truncated header value"));
152                }
153                let value = std::str::from_utf8(&bytes[..vlen])
154                    .map_err(|e| anyhow!("bad header value utf8: {e}"))?
155                    .to_string();
156                bytes = &bytes[vlen..];
157                out.insert(name, value);
158            }
159            // Skip other types (bool, byte, int16/32/64, bytes, timestamp, uuid)
160            0 | 1 => {} // true / false — no value bytes
161            2 => bytes = &bytes[1..],
162            3 => bytes = &bytes[2..],
163            4 => bytes = &bytes[4..],
164            5 => bytes = &bytes[8..],
165            6 | 8 => {
166                // byte array / timestamp - u16-prefixed or 8 bytes
167                if value_type == 6 {
168                    let vlen = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
169                    bytes = &bytes[2 + vlen..];
170                } else {
171                    bytes = &bytes[8..];
172                }
173            }
174            9 => bytes = &bytes[16..], // uuid
175            _ => return Err(anyhow!("unknown header type {value_type}")),
176        }
177    }
178    Ok(out)
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `(name, value)` pairs as string-typed (type `7`) headers.
    fn encode_string_headers(headers: &[(&str, &str)]) -> Vec<u8> {
        let mut header_bytes = Vec::new();
        for (k, v) in headers {
            header_bytes.push(k.len() as u8);
            header_bytes.extend_from_slice(k.as_bytes());
            header_bytes.push(7u8); // string type
            header_bytes.extend_from_slice(&(v.len() as u16).to_be_bytes());
            header_bytes.extend_from_slice(v.as_bytes());
        }
        header_bytes
    }

    /// Wrap pre-encoded header bytes and a payload in a frame with zeroed CRCs.
    fn frame_from_parts(header_bytes: &[u8], payload: &[u8]) -> Vec<u8> {
        let total_len = 16 + header_bytes.len() + payload.len();
        let mut frame = Vec::with_capacity(total_len);
        frame.extend_from_slice(&(total_len as u32).to_be_bytes());
        frame.extend_from_slice(&(header_bytes.len() as u32).to_be_bytes());
        frame.extend_from_slice(&0u32.to_be_bytes()); // prelude_crc (ignored)
        frame.extend_from_slice(header_bytes);
        frame.extend_from_slice(payload);
        frame.extend_from_slice(&0u32.to_be_bytes()); // message_crc (ignored)
        frame
    }

    /// Build a frame whose headers are all string-typed.
    fn build_frame(headers: &[(&str, &str)], payload: &[u8]) -> Vec<u8> {
        frame_from_parts(&encode_string_headers(headers), payload)
    }

    /// Build a bare 12-byte prelude with the given lengths and a zero CRC.
    fn prelude(total_len: u32, headers_len: u32) -> Vec<u8> {
        [total_len, headers_len, 0]
            .iter()
            .flat_map(|word| word.to_be_bytes())
            .collect()
    }

    #[test]
    fn parses_single_frame_with_headers() {
        let frame = build_frame(
            &[(":event-type", "messageStart"), (":message-type", "event")],
            br#"{"role":"assistant"}"#,
        );
        let mut buf = FrameBuffer::new();
        buf.extend(&frame);
        let msg = buf.next_frame().unwrap().unwrap();
        assert_eq!(msg.event_type(), Some("messageStart"));
        assert_eq!(msg.message_type(), Some("event"));
        assert_eq!(msg.payload, br#"{"role":"assistant"}"#);
    }

    #[test]
    fn handles_chunked_delivery() {
        let frame = build_frame(&[(":event-type", "x")], b"hello");
        let mut buf = FrameBuffer::new();
        buf.extend(&frame[..5]);
        assert!(buf.next_frame().unwrap().is_none());
        buf.extend(&frame[5..]);
        assert!(buf.next_frame().unwrap().is_some());
    }

    #[test]
    fn parses_multiple_frames() {
        let mut all = Vec::new();
        all.extend(build_frame(&[(":event-type", "a")], b"1"));
        all.extend(build_frame(&[(":event-type", "b")], b"22"));
        let mut buf = FrameBuffer::new();
        buf.extend(&all);
        assert_eq!(buf.next_frame().unwrap().unwrap().event_type(), Some("a"));
        assert_eq!(buf.next_frame().unwrap().unwrap().event_type(), Some("b"));
        assert!(buf.next_frame().unwrap().is_none());
    }

    #[test]
    fn skips_non_string_headers() {
        let empty: [u8; 0] = [];
        let count = 42i32.to_be_bytes();
        let blob = [0u8, 3, 1, 2, 3]; // u16 length 3, then 3 bytes
        let uuid = [0u8; 16];
        let mut header_bytes = Vec::new();
        for (name, value_type, value) in [
            ("flag", 0u8, &empty[..]),  // bool true: no value bytes
            ("count", 4, &count[..]),   // int32
            ("blob", 6, &blob[..]),     // byte array
            ("id", 9, &uuid[..]),       // uuid
        ] {
            header_bytes.push(name.len() as u8);
            header_bytes.extend_from_slice(name.as_bytes());
            header_bytes.push(value_type);
            header_bytes.extend_from_slice(value);
        }
        header_bytes.extend(encode_string_headers(&[(":event-type", "x")]));

        let mut buf = FrameBuffer::new();
        buf.extend(&frame_from_parts(&header_bytes, b"p"));
        let msg = buf.next_frame().unwrap().expect("one frame");
        assert_eq!(msg.headers.len(), 1);
        assert_eq!(msg.event_type(), Some("x"));
        assert_eq!(msg.payload, b"p");
    }

    #[test]
    fn rejects_inconsistent_prelude_lengths() {
        // total_len below the 16-byte minimum; headers that overflow total_len.
        for (total_len, headers_len) in [(8u32, 0u32), (16, 1)] {
            let mut buf = FrameBuffer::new();
            buf.extend(&prelude(total_len, headers_len));
            assert!(
                buf.next_frame().is_err(),
                "total_len={total_len}, headers_len={headers_len}"
            );
        }
    }
}