codetether_agent/provider/bedrock/
eventstream.rs1use anyhow::{Result, anyhow};
47use std::collections::HashMap;
48
/// A single decoded AWS event-stream message: its headers and raw payload.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EventMessage {
    /// Header name → value. The frame parser only records string-valued
    /// headers here (e.g. `:event-type`, `:message-type`).
    pub headers: HashMap<String, String>,
    /// Payload bytes exactly as they appeared in the frame (not decoded).
    pub payload: Vec<u8>,
}

impl EventMessage {
    /// Returns the `:event-type` header, if present.
    pub fn event_type(&self) -> Option<&str> {
        self.header(":event-type")
    }

    /// Returns the `:message-type` header, if present.
    pub fn message_type(&self) -> Option<&str> {
        self.header(":message-type")
    }

    /// Returns the value of the string header named exactly `name`, if present.
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers.get(name).map(String::as_str)
    }
}
69
70#[derive(Debug, Default)]
73pub struct FrameBuffer {
74 buf: Vec<u8>,
75}
76
77impl FrameBuffer {
78 pub fn new() -> Self {
80 Self::default()
81 }
82
83 pub fn extend(&mut self, chunk: &[u8]) {
85 self.buf.extend_from_slice(chunk);
86 }
87
88 pub fn next_frame(&mut self) -> Result<Option<EventMessage>> {
95 if self.buf.len() < 12 {
96 return Ok(None);
97 }
98 let total_len = u32::from_be_bytes(self.buf[0..4].try_into().unwrap()) as usize;
99 let headers_len = u32::from_be_bytes(self.buf[4..8].try_into().unwrap()) as usize;
100
101 if total_len < 16 || headers_len + 16 > total_len {
102 return Err(anyhow!(
103 "invalid eventstream frame: total_len={total_len}, headers_len={headers_len}"
104 ));
105 }
106 if self.buf.len() < total_len {
107 return Ok(None);
108 }
109
110 let headers_start = 12usize;
113 let headers_end = headers_start + headers_len;
114 let payload_end = total_len - 4;
115
116 let headers = parse_headers(&self.buf[headers_start..headers_end])?;
117 let payload = self.buf[headers_end..payload_end].to_vec();
118
119 self.buf.drain(..total_len);
120 Ok(Some(EventMessage { headers, payload }))
121 }
122}
123
124fn parse_headers(mut bytes: &[u8]) -> Result<HashMap<String, String>> {
125 let mut out = HashMap::new();
126 while !bytes.is_empty() {
127 if bytes.is_empty() {
128 break;
129 }
130 let name_len = bytes[0] as usize;
131 bytes = &bytes[1..];
132 if bytes.len() < name_len + 1 {
133 return Err(anyhow!("truncated header name"));
134 }
135 let name = std::str::from_utf8(&bytes[..name_len])
136 .map_err(|e| anyhow!("bad header name utf8: {e}"))?
137 .to_string();
138 bytes = &bytes[name_len..];
139 let value_type = bytes[0];
140 bytes = &bytes[1..];
141
142 match value_type {
143 7 => {
145 if bytes.len() < 2 {
146 return Err(anyhow!("truncated header value length"));
147 }
148 let vlen = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
149 bytes = &bytes[2..];
150 if bytes.len() < vlen {
151 return Err(anyhow!("truncated header value"));
152 }
153 let value = std::str::from_utf8(&bytes[..vlen])
154 .map_err(|e| anyhow!("bad header value utf8: {e}"))?
155 .to_string();
156 bytes = &bytes[vlen..];
157 out.insert(name, value);
158 }
159 0 | 1 => {} 2 => bytes = &bytes[1..],
162 3 => bytes = &bytes[2..],
163 4 => bytes = &bytes[4..],
164 5 => bytes = &bytes[8..],
165 6 | 8 => {
166 if value_type == 6 {
168 let vlen = u16::from_be_bytes([bytes[0], bytes[1]]) as usize;
169 bytes = &bytes[2 + vlen..];
170 } else {
171 bytes = &bytes[8..];
172 }
173 }
174 9 => bytes = &bytes[16..], _ => return Err(anyhow!("unknown header type {value_type}")),
176 }
177 }
178 Ok(out)
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps an already-encoded header section and a payload in a frame.
    ///
    /// Both CRC fields are written as zero because `FrameBuffer` does not check them.
    fn frame_from_parts(header_bytes: &[u8], payload: &[u8]) -> Vec<u8> {
        let total_len = 16 + header_bytes.len() + payload.len();
        let mut frame = Vec::with_capacity(total_len);
        frame.extend_from_slice(&(total_len as u32).to_be_bytes());
        frame.extend_from_slice(&(header_bytes.len() as u32).to_be_bytes());
        frame.extend_from_slice(&0u32.to_be_bytes()); // prelude CRC (unchecked)
        frame.extend_from_slice(header_bytes);
        frame.extend_from_slice(payload);
        frame.extend_from_slice(&0u32.to_be_bytes()); // message CRC (unchecked)
        frame
    }

    /// Encodes `headers` as string-valued (type 7) headers and builds a frame.
    fn build_frame(headers: &[(&str, &str)], payload: &[u8]) -> Vec<u8> {
        let mut header_bytes = Vec::new();
        for (name, value) in headers {
            header_bytes.push(name.len() as u8);
            header_bytes.extend_from_slice(name.as_bytes());
            header_bytes.push(7u8);
            header_bytes.extend_from_slice(&(value.len() as u16).to_be_bytes());
            header_bytes.extend_from_slice(value.as_bytes());
        }
        frame_from_parts(&header_bytes, payload)
    }

    #[test]
    fn parses_single_frame_with_headers() {
        let frame = build_frame(
            &[(":event-type", "messageStart"), (":message-type", "event")],
            br#"{"role":"assistant"}"#,
        );
        let mut buf = FrameBuffer::new();
        buf.extend(&frame);
        let msg = buf.next_frame().unwrap().unwrap();
        assert_eq!(msg.event_type(), Some("messageStart"));
        assert_eq!(msg.message_type(), Some("event"));
        assert_eq!(msg.payload, br#"{"role":"assistant"}"#);
    }

    #[test]
    fn handles_chunked_delivery() {
        let frame = build_frame(&[(":event-type", "x")], b"hello");
        let mut buf = FrameBuffer::new();
        buf.extend(&frame[..5]);
        assert!(buf.next_frame().unwrap().is_none());
        buf.extend(&frame[5..]);
        assert!(buf.next_frame().unwrap().is_some());
    }

    #[test]
    fn parses_multiple_frames() {
        let mut all = Vec::new();
        all.extend(build_frame(&[(":event-type", "a")], b"1"));
        all.extend(build_frame(&[(":event-type", "b")], b"22"));
        let mut buf = FrameBuffer::new();
        buf.extend(&all);
        assert_eq!(buf.next_frame().unwrap().unwrap().event_type(), Some("a"));
        assert_eq!(buf.next_frame().unwrap().unwrap().event_type(), Some("b"));
        assert!(buf.next_frame().unwrap().is_none());
    }

    #[test]
    fn parses_frame_without_headers_or_payload() {
        let mut buf = FrameBuffer::new();
        buf.extend(&frame_from_parts(&[], &[]));
        let msg = buf.next_frame().unwrap().unwrap();
        assert!(msg.headers.is_empty());
        assert!(msg.payload.is_empty());
        assert!(buf.next_frame().unwrap().is_none());
    }

    #[test]
    fn skips_non_string_headers() {
        let mut header_bytes = Vec::new();
        // bool true (type 0): no value bytes.
        header_bytes.extend_from_slice(&[4, b'f', b'l', b'a', b'g', 0]);
        // integer (type 4): 4 value bytes.
        header_bytes.extend_from_slice(&[1, b'n', 4, 0, 0, 0, 42]);
        // byte_array (type 6): u16 length prefix, then the bytes.
        header_bytes.extend_from_slice(&[3, b'b', b'i', b'n', 6, 0, 3, 1, 2, 3]);
        // uuid (type 9): 16 value bytes.
        header_bytes.extend_from_slice(&[2, b'i', b'd', 9]);
        header_bytes.extend_from_slice(&[0xAB; 16]);
        // string (type 7): the only kind that is kept.
        header_bytes.push(11);
        header_bytes.extend_from_slice(b":event-type");
        header_bytes.extend_from_slice(&[7, 0, 1, b'x']);

        let mut buf = FrameBuffer::new();
        buf.extend(&frame_from_parts(&header_bytes, b"p"));
        let msg = buf.next_frame().unwrap().unwrap();
        assert_eq!(msg.headers.len(), 1);
        assert_eq!(msg.event_type(), Some("x"));
        assert_eq!(msg.payload, b"p");
    }

    #[test]
    fn rejects_inconsistent_prelude() {
        // total_len below the 16-byte minimum frame size.
        let mut buf = FrameBuffer::new();
        buf.extend(&[0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0]);
        assert!(buf.next_frame().is_err());

        // headers_len larger than the frame can hold.
        let mut buf = FrameBuffer::new();
        buf.extend(&[0, 0, 0, 16, 0, 0, 0, 1, 0, 0, 0, 0]);
        assert!(buf.next_frame().is_err());
    }

    #[test]
    fn rejects_truncated_or_unknown_headers() {
        // Header claims a 5-byte name but the section holds only 2 bytes.
        let mut buf = FrameBuffer::new();
        buf.extend(&frame_from_parts(&[5, b'a', b'b'], b""));
        assert!(buf.next_frame().is_err());

        // Value type 42 is not defined by the encoding.
        let mut buf = FrameBuffer::new();
        buf.extend(&frame_from_parts(&[1, b'k', 42], b""));
        assert!(buf.next_frame().is_err());
    }
}