codetether_agent/provider/bedrock/
eventstream.rs1use anyhow::{Result, anyhow};
47use std::collections::HashMap;
48
/// One decoded eventstream message: its headers plus the raw payload bytes.
#[derive(Debug, Clone, Default)]
pub struct EventMessage {
    /// Header values keyed by header name (e.g. `:event-type`).
    pub headers: HashMap<String, String>,
    /// Payload bytes of the frame, unmodified.
    pub payload: Vec<u8>,
}

impl EventMessage {
    /// Returns the value of the `:event-type` header, if present.
    pub fn event_type(&self) -> Option<&str> {
        self.header_value(":event-type")
    }

    /// Returns the value of the `:message-type` header, if present.
    pub fn message_type(&self) -> Option<&str> {
        self.header_value(":message-type")
    }

    /// Looks up a header by name and borrows its value as `&str`.
    fn header_value(&self, name: &str) -> Option<&str> {
        match self.headers.get(name) {
            Some(value) => Some(value.as_str()),
            None => None,
        }
    }
}
69
70#[derive(Debug, Default)]
73pub struct FrameBuffer {
74 buf: Vec<u8>,
75}
76
77impl FrameBuffer {
78 pub fn new() -> Self {
80 Self::default()
81 }
82
83 pub fn extend(&mut self, chunk: &[u8]) {
85 self.buf.extend_from_slice(chunk);
86 }
87
88 pub fn next_frame(&mut self) -> Result<Option<EventMessage>> {
95 if self.buf.len() < 12 {
96 return Ok(None);
97 }
98 let total_len = u32::from_be_bytes(
99 self.buf[0..4]
100 .try_into()
101 .map_err(|_| anyhow!("short read"))?,
102 ) as usize;
103 let headers_len = u32::from_be_bytes(
104 self.buf[4..8]
105 .try_into()
106 .map_err(|_| anyhow!("short read"))?,
107 ) as usize;
108
109 if total_len < 16 || headers_len + 16 > total_len {
110 return Err(anyhow!(
111 "invalid eventstream frame: total_len={total_len}, headers_len={headers_len}"
112 ));
113 }
114 if self.buf.len() < total_len {
115 return Ok(None);
116 }
117
118 let headers_start = 12usize;
121 let headers_end = headers_start + headers_len;
122 let payload_end = total_len - 4;
123
124 let headers = parse_headers(&self.buf[headers_start..headers_end])?;
125 let payload = self.buf[headers_end..payload_end].to_vec();
126
127 self.buf.drain(..total_len);
128 Ok(Some(EventMessage { headers, payload }))
129 }
130}
131
132fn parse_headers(mut bytes: &[u8]) -> Result<HashMap<String, String>> {
133 let mut out = HashMap::new();
134 while !bytes.is_empty() {
135 if bytes.is_empty() {
136 break;
137 }
138 let name_len = bytes[0] as usize;
139 bytes = &bytes[1..];
140 if bytes.len() < name_len + 1 {
141 return Err(anyhow!("truncated header name"));
142 }
143 let name = std::str::from_utf8(&bytes[..name_len])
144 .map_err(|e| anyhow!("bad header name utf8: {e}"))?
145 .to_string();
146 bytes = &bytes[name_len..];
147 let value_type = bytes[0];
148 bytes = &bytes[1..];
149
150 match value_type {
151 7 => {
153 if bytes.len() < 2 {
154 return Err(anyhow!("truncated header value length"));
155 }
156 let vlen = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
157 bytes = &bytes[2..];
158 if bytes.len() < vlen {
159 return Err(anyhow!("truncated header value"));
160 }
161 let value = std::str::from_utf8(&bytes[..vlen])
162 .map_err(|e| anyhow!("bad header value utf8: {e}"))?
163 .to_string();
164 bytes = &bytes[vlen..];
165 out.insert(name, value);
166 }
167 0 | 1 => {} 2 => bytes = &bytes[1..],
170 3 => bytes = &bytes[2..],
171 4 => bytes = &bytes[4..],
172 5 => bytes = &bytes[8..],
173 6 | 8 => {
174 if value_type == 6 {
176 let vlen = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
177 bytes = &bytes[2 + vlen..];
178 } else {
179 bytes = &bytes[8..];
180 }
181 }
182 9 => bytes = &bytes[16..], _ => return Err(anyhow!("unknown header type {value_type}")),
184 }
185 }
186 Ok(out)
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps an already-encoded header block and a payload in a frame. The two
    /// unchecked 4-byte fields are zeroed.
    fn frame_from_raw_headers(header_bytes: &[u8], payload: &[u8]) -> Vec<u8> {
        let total_len = 16 + header_bytes.len() + payload.len();
        let mut frame = Vec::with_capacity(total_len);
        frame.extend_from_slice(&(total_len as u32).to_be_bytes());
        frame.extend_from_slice(&(header_bytes.len() as u32).to_be_bytes());
        frame.extend_from_slice(&0u32.to_be_bytes());
        frame.extend_from_slice(header_bytes);
        frame.extend_from_slice(payload);
        frame.extend_from_slice(&0u32.to_be_bytes());
        frame
    }

    /// Encodes `headers` as string-typed (type 7) headers and builds a frame.
    fn build_frame(headers: &[(&str, &str)], payload: &[u8]) -> Vec<u8> {
        let mut header_bytes = Vec::new();
        for (k, v) in headers {
            header_bytes.push(k.len() as u8);
            header_bytes.extend_from_slice(k.as_bytes());
            header_bytes.push(7u8);
            header_bytes.extend_from_slice(&(v.len() as u16).to_be_bytes());
            header_bytes.extend_from_slice(v.as_bytes());
        }
        frame_from_raw_headers(&header_bytes, payload)
    }

    #[test]
    fn parses_single_frame_with_headers() {
        let frame = build_frame(
            &[(":event-type", "messageStart"), (":message-type", "event")],
            br#"{"role":"assistant"}"#,
        );
        let mut buf = FrameBuffer::new();
        buf.extend(&frame);
        let msg = buf.next_frame().unwrap().unwrap();
        assert_eq!(msg.event_type(), Some("messageStart"));
        assert_eq!(msg.message_type(), Some("event"));
        assert_eq!(msg.payload, br#"{"role":"assistant"}"#);
    }

    #[test]
    fn handles_chunked_delivery() {
        let frame = build_frame(&[(":event-type", "x")], b"hello");
        let mut buf = FrameBuffer::new();
        buf.extend(&frame[..5]);
        assert!(buf.next_frame().unwrap().is_none());
        buf.extend(&frame[5..]);
        assert!(buf.next_frame().unwrap().is_some());
    }

    #[test]
    fn parses_multiple_frames() {
        let mut all = Vec::new();
        all.extend(build_frame(&[(":event-type", "a")], b"1"));
        all.extend(build_frame(&[(":event-type", "b")], b"22"));
        let mut buf = FrameBuffer::new();
        buf.extend(&all);
        assert_eq!(buf.next_frame().unwrap().unwrap().event_type(), Some("a"));
        assert_eq!(buf.next_frame().unwrap().unwrap().event_type(), Some("b"));
        assert!(buf.next_frame().unwrap().is_none());
    }

    #[test]
    fn skips_non_string_headers() {
        let mut header_bytes = Vec::new();
        // Type 0: no value bytes.
        header_bytes.push(4);
        header_bytes.extend_from_slice(b"flag");
        header_bytes.push(0);
        // Type 4: 4-byte value.
        header_bytes.push(3);
        header_bytes.extend_from_slice(b"num");
        header_bytes.push(4);
        header_bytes.extend_from_slice(&42u32.to_be_bytes());
        // Type 7: string value, which is the only kind that is kept.
        header_bytes.push(11);
        header_bytes.extend_from_slice(b":event-type");
        header_bytes.push(7);
        header_bytes.extend_from_slice(&1u16.to_be_bytes());
        header_bytes.extend_from_slice(b"z");

        let mut buf = FrameBuffer::new();
        buf.extend(&frame_from_raw_headers(&header_bytes, b""));
        let msg = buf.next_frame().unwrap().unwrap();
        assert_eq!(msg.event_type(), Some("z"));
        assert_eq!(msg.headers.len(), 1);
        assert!(msg.payload.is_empty());
    }

    #[test]
    fn rejects_impossible_lengths() {
        // total_len below the 16-byte minimum.
        let mut short = Vec::new();
        short.extend_from_slice(&8u32.to_be_bytes());
        short.extend_from_slice(&0u32.to_be_bytes());
        short.extend_from_slice(&0u32.to_be_bytes());
        let mut buf = FrameBuffer::new();
        buf.extend(&short);
        assert!(buf.next_frame().is_err());

        // headers_len that would run past the end of the frame.
        let mut overlong = Vec::new();
        overlong.extend_from_slice(&20u32.to_be_bytes());
        overlong.extend_from_slice(&10u32.to_be_bytes());
        overlong.extend_from_slice(&0u32.to_be_bytes());
        let mut buf = FrameBuffer::new();
        buf.extend(&overlong);
        assert!(buf.next_frame().is_err());
    }

    #[test]
    fn rejects_unknown_header_type() {
        let mut buf = FrameBuffer::new();
        buf.extend(&frame_from_raw_headers(&[1, b'x', 10], b""));
        assert!(buf.next_frame().is_err());
    }

    #[test]
    fn rejects_truncated_string_header() {
        // Type-7 header whose 2-byte length prefix is cut off after one byte.
        let mut buf = FrameBuffer::new();
        buf.extend(&frame_from_raw_headers(&[1, b'x', 7, 0], b""));
        assert!(buf.next_frame().is_err());
    }
}