Skip to main content

codetether_agent/provider/bedrock/
eventstream.rs

1//! Minimal parser for the AWS `vnd.amazon.eventstream` binary frame format.
2//!
3//! Each frame layout:
4//!
5//! ```text
6//! +----------------------------------------------------------+
7//! | total_len:       u32 BE                                  |
8//! | headers_len:     u32 BE                                  |
9//! | prelude_crc:     u32 BE  (ignored — TLS provides integrity)
10//! | headers:         headers_len bytes                        |
11//! | payload:         total_len - headers_len - 16 bytes       |
12//! | message_crc:     u32 BE  (ignored)                        |
13//! +----------------------------------------------------------+
14//! ```
15//!
16//! Each header: `name_len u8`, `name utf8`, `value_type u8`, then a
17//! type-specific value. We only need type `7` (UTF-8 string, `u16 BE` length).
18//!
19//! This implementation is intentionally minimal: CRC validation is skipped
20//! (TLS already protects integrity), and only string-typed headers are
21//! decoded since Bedrock Converse only emits those.
22//!
23//! # Examples
24//!
25//! ```rust
26//! use codetether_agent::provider::bedrock::eventstream::FrameBuffer;
27//!
28//! // Build a tiny frame: no headers, payload = "hi".
29//! let payload = b"hi";
30//! let total_len: u32 = 16 + payload.len() as u32;
31//! let mut frame = Vec::new();
32//! frame.extend_from_slice(&total_len.to_be_bytes());
33//! frame.extend_from_slice(&0u32.to_be_bytes()); // headers_len
34//! frame.extend_from_slice(&0u32.to_be_bytes()); // prelude_crc (ignored)
35//! frame.extend_from_slice(payload);
36//! frame.extend_from_slice(&0u32.to_be_bytes()); // message_crc (ignored)
37//!
38//! let mut buf = FrameBuffer::new();
39//! buf.extend(&frame);
40//! let msg = buf.next_frame().unwrap().expect("one frame");
41//! assert_eq!(msg.payload, b"hi");
42//! assert!(msg.headers.is_empty());
43//! assert!(buf.next_frame().unwrap().is_none());
44//! ```
45
46use anyhow::{Result, anyhow};
47use std::collections::HashMap;
48
/// A single decoded AWS eventstream message: its string headers plus the raw
/// payload bytes.
#[derive(Debug, Clone, Default)]
pub struct EventMessage {
    /// Headers whose value type is UTF-8 string; headers of any other value
    /// type are dropped while decoding.
    pub headers: HashMap<String, String>,
    /// Payload bytes exactly as carried by the frame (application-defined;
    /// typically JSON).
    pub payload: Vec<u8>,
}

impl EventMessage {
    /// The `:event-type` header (event variant name), if present.
    pub fn event_type(&self) -> Option<&str> {
        self.header(":event-type")
    }

    /// The `:message-type` header (`event`, `exception`, `error`), if present.
    pub fn message_type(&self) -> Option<&str> {
        self.header(":message-type")
    }

    /// Look up a header by name and borrow its value as `&str`.
    fn header(&self, name: &str) -> Option<&str> {
        self.headers.get(name).map(|value| value.as_str())
    }
}
69
70/// Streaming frame buffer: feed arbitrary byte chunks via [`FrameBuffer::extend`]
71/// and pull out complete messages via [`FrameBuffer::next_frame`].
72#[derive(Debug, Default)]
73pub struct FrameBuffer {
74    buf: Vec<u8>,
75}
76
77impl FrameBuffer {
78    /// Create an empty buffer.
79    pub fn new() -> Self {
80        Self::default()
81    }
82
83    /// Append more bytes from the transport.
84    pub fn extend(&mut self, chunk: &[u8]) {
85        self.buf.extend_from_slice(chunk);
86    }
87
88    /// Try to parse the next full frame from the buffer.
89    ///
90    /// Returns:
91    /// - `Ok(Some(msg))` — a complete message was decoded and removed.
92    /// - `Ok(None)` — need more bytes.
93    /// - `Err(_)` — malformed frame.
94    pub fn next_frame(&mut self) -> Result<Option<EventMessage>> {
95        if self.buf.len() < 12 {
96            return Ok(None);
97        }
98        let total_len = u32::from_be_bytes(
99            self.buf[0..4]
100                .try_into()
101                .map_err(|_| anyhow!("short read"))?,
102        ) as usize;
103        let headers_len = u32::from_be_bytes(
104            self.buf[4..8]
105                .try_into()
106                .map_err(|_| anyhow!("short read"))?,
107        ) as usize;
108
109        if total_len < 16 || headers_len + 16 > total_len {
110            return Err(anyhow!(
111                "invalid eventstream frame: total_len={total_len}, headers_len={headers_len}"
112            ));
113        }
114        if self.buf.len() < total_len {
115            return Ok(None);
116        }
117
118        // prelude ends at 12; headers [12, 12+headers_len); payload
119        // [12+headers_len, total_len-4); trailing CRC 4 bytes.
120        let headers_start = 12usize;
121        let headers_end = headers_start + headers_len;
122        let payload_end = total_len - 4;
123
124        let headers = parse_headers(&self.buf[headers_start..headers_end])?;
125        let payload = self.buf[headers_end..payload_end].to_vec();
126
127        self.buf.drain(..total_len);
128        Ok(Some(EventMessage { headers, payload }))
129    }
130}
131
132fn parse_headers(mut bytes: &[u8]) -> Result<HashMap<String, String>> {
133    let mut out = HashMap::new();
134    while !bytes.is_empty() {
135        if bytes.is_empty() {
136            break;
137        }
138        let name_len = bytes[0] as usize;
139        bytes = &bytes[1..];
140        if bytes.len() < name_len + 1 {
141            return Err(anyhow!("truncated header name"));
142        }
143        let name = std::str::from_utf8(&bytes[..name_len])
144            .map_err(|e| anyhow!("bad header name utf8: {e}"))?
145            .to_string();
146        bytes = &bytes[name_len..];
147        let value_type = bytes[0];
148        bytes = &bytes[1..];
149
150        match value_type {
151            // UTF-8 string, u16 BE length
152            7 => {
153                if bytes.len() < 2 {
154                    return Err(anyhow!("truncated header value length"));
155                }
156                let vlen = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
157                bytes = &bytes[2..];
158                if bytes.len() < vlen {
159                    return Err(anyhow!("truncated header value"));
160                }
161                let value = std::str::from_utf8(&bytes[..vlen])
162                    .map_err(|e| anyhow!("bad header value utf8: {e}"))?
163                    .to_string();
164                bytes = &bytes[vlen..];
165                out.insert(name, value);
166            }
167            // Skip other types (bool, byte, int16/32/64, bytes, timestamp, uuid)
168            0 | 1 => {} // true / false — no value bytes
169            2 => bytes = &bytes[1..],
170            3 => bytes = &bytes[2..],
171            4 => bytes = &bytes[4..],
172            5 => bytes = &bytes[8..],
173            6 | 8 => {
174                // byte array / timestamp - u16-prefixed or 8 bytes
175                if value_type == 6 {
176                    let vlen = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
177                    bytes = &bytes[2 + vlen..];
178                } else {
179                    bytes = &bytes[8..];
180                }
181            }
182            9 => bytes = &bytes[16..], // uuid
183            _ => return Err(anyhow!("unknown header type {value_type}")),
184        }
185    }
186    Ok(out)
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode a frame whose headers are all string-typed (type `7`), followed
    /// by `payload`. Both CRC fields are written as zero.
    fn build_frame(headers: &[(&str, &str)], payload: &[u8]) -> Vec<u8> {
        let encoded_headers: Vec<u8> = headers
            .iter()
            .flat_map(|&(name, value)| {
                let mut entry = vec![name.len() as u8];
                entry.extend_from_slice(name.as_bytes());
                entry.push(7u8); // string type
                entry.extend_from_slice(&(value.len() as u16).to_be_bytes());
                entry.extend_from_slice(value.as_bytes());
                entry
            })
            .collect();

        let total_len = 16 + encoded_headers.len() + payload.len();
        let zero_crc = 0u32.to_be_bytes();
        let mut frame = Vec::with_capacity(total_len);
        for part in [
            &(total_len as u32).to_be_bytes()[..],
            &(encoded_headers.len() as u32).to_be_bytes()[..],
            &zero_crc[..],
            &encoded_headers[..],
            payload,
            &zero_crc[..],
        ] {
            frame.extend_from_slice(part);
        }
        frame
    }

    #[test]
    fn parses_single_frame_with_headers() {
        let body = br#"{"role":"assistant"}"#;
        let mut buf = FrameBuffer::new();
        buf.extend(&build_frame(
            &[(":event-type", "messageStart"), (":message-type", "event")],
            body,
        ));
        let msg = buf.next_frame().unwrap().unwrap();
        assert_eq!(msg.event_type(), Some("messageStart"));
        assert_eq!(msg.message_type(), Some("event"));
        assert_eq!(msg.payload, body);
    }

    #[test]
    fn handles_chunked_delivery() {
        let frame = build_frame(&[(":event-type", "x")], b"hello");
        let (head, tail) = frame.split_at(5);
        let mut buf = FrameBuffer::new();
        buf.extend(head);
        assert!(buf.next_frame().unwrap().is_none());
        buf.extend(tail);
        assert!(buf.next_frame().unwrap().is_some());
    }

    #[test]
    fn parses_multiple_frames() {
        let stream: Vec<u8> = [("a", &b"1"[..]), ("b", &b"22"[..])]
            .into_iter()
            .flat_map(|(event, body)| build_frame(&[(":event-type", event)], body))
            .collect();
        let mut buf = FrameBuffer::new();
        buf.extend(&stream);
        for expected in ["a", "b"] {
            let msg = buf.next_frame().unwrap().unwrap();
            assert_eq!(msg.event_type(), Some(expected));
        }
        assert!(buf.next_frame().unwrap().is_none());
    }
}