// a2a_protocol_client/streaming/sse_parser.rs

/// One event produced by [`SseParser`] once a blank line terminates it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseFrame {
    /// Payload assembled from the event's `data:` lines, joined with `\n`.
    pub data: String,

    /// Value of the last `event:` field seen in this event, if any.
    pub event_type: Option<String>,

    /// Last event ID in effect when the frame was dispatched; it carries over
    /// between events until the stream supplies a new `id:` field.
    pub id: Option<String>,

    /// Last accepted `retry:` value (reconnection time, milliseconds in SSE
    /// terms); like `id`, it carries over between events.
    pub retry: Option<u64>,
}
/// Errors that [`SseParser`] queues in place of a frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SseParseError {
    /// The field values of one event added up to more than the parser's limit.
    EventTooLarge {
        /// Configured maximum event size, in bytes.
        limit: usize,
        /// Accumulated size, in bytes, at the moment the limit was crossed.
        actual: usize,
    },
}
52impl std::fmt::Display for SseParseError {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 Self::EventTooLarge { limit, actual } => {
56 write!(
57 f,
58 "SSE event too large: {actual} bytes exceeds {limit} byte limit"
59 )
60 }
61 }
62 }
63}
64
65impl std::error::Error for SseParseError {}
66
/// Per-event size limit used by [`SseParser::new`] / `Default`: 4 MiB.
const DEFAULT_MAX_EVENT_SIZE: usize = 4 << 20;
70#[derive(Debug)]
87pub struct SseParser {
88 line_buf: Vec<u8>,
90 data_lines: Vec<String>,
92 current_event_size: usize,
94 max_event_size: usize,
96 event_type: Option<String>,
98 id: Option<String>,
100 retry: Option<u64>,
102 ready: Vec<Result<SseFrame, SseParseError>>,
104 bom_checked: bool,
106}
107
108impl Default for SseParser {
109 fn default() -> Self {
110 Self {
111 line_buf: Vec::new(),
112 data_lines: Vec::new(),
113 current_event_size: 0,
114 max_event_size: DEFAULT_MAX_EVENT_SIZE,
115 event_type: None,
116 id: None,
117 retry: None,
118 ready: Vec::new(),
119 bom_checked: false,
120 }
121 }
122}
123
124impl SseParser {
125 #[must_use]
127 pub fn new() -> Self {
128 Self::default()
129 }
130
131 #[must_use]
135 pub fn with_max_event_size(max_event_size: usize) -> Self {
136 Self {
137 max_event_size,
138 ..Self::default()
139 }
140 }
141
142 #[must_use]
144 pub const fn pending_count(&self) -> usize {
145 self.ready.len()
146 }
147
148 pub fn feed(&mut self, bytes: &[u8]) {
153 let mut input = bytes;
154 if !self.bom_checked && self.line_buf.is_empty() {
156 if input.starts_with(b"\xEF\xBB\xBF") {
157 input = &input[3..];
158 }
159 if !input.is_empty() || bytes.len() >= 3 {
162 self.bom_checked = true;
163 }
164 }
165 for &byte in input {
166 if byte == b'\n' {
167 self.process_line();
168 self.line_buf.clear();
169 } else if byte != b'\r' {
170 self.line_buf.push(byte);
172 }
173 }
174 }
175
176 pub fn next_frame(&mut self) -> Option<Result<SseFrame, SseParseError>> {
180 if self.ready.is_empty() {
181 None
182 } else {
183 Some(self.ready.remove(0))
184 }
185 }
186
187 fn process_line(&mut self) {
190 if !self.bom_checked {
192 if self.line_buf.starts_with(b"\xEF\xBB\xBF") {
193 self.line_buf.drain(..3);
194 }
195 self.bom_checked = true;
196 }
197 let line = match std::str::from_utf8(&self.line_buf) {
198 Ok(s) => s.to_owned(),
199 Err(_) => return, };
201
202 if line.is_empty() {
203 self.dispatch_frame();
205 return;
206 }
207
208 if line.starts_with(':') {
209 return;
211 }
212
213 let (field, value) = line.find(':').map_or_else(
215 || (line.as_str(), String::new()),
216 |pos| {
217 let field = &line[..pos];
218 let value = line[pos + 1..].trim_start_matches(' ');
219 (field, value.to_owned())
220 },
221 );
222
223 self.current_event_size += value.len();
225 if self.current_event_size > self.max_event_size {
226 let error = SseParseError::EventTooLarge {
228 limit: self.max_event_size,
229 actual: self.current_event_size,
230 };
231 self.data_lines.clear();
232 self.event_type = None;
233 self.current_event_size = 0;
234 self.ready.push(Err(error));
235 return;
236 }
237
238 match field {
239 "data" => self.data_lines.push(value),
240 "event" => self.event_type = Some(value),
241 "id" => {
242 if value.contains('\0') {
243 self.id = None;
245 } else {
246 self.id = Some(value);
247 }
248 }
249 "retry" => {
250 if let Ok(ms) = value.parse::<u64>() {
251 self.retry = Some(ms);
252 }
253 }
254 _ => {
255 }
257 }
258 }
259
260 fn dispatch_frame(&mut self) {
261 if self.data_lines.is_empty() {
262 self.event_type = None;
264 self.current_event_size = 0;
265 return;
266 }
267
268 let mut data = self.data_lines.join("\n");
270 if data.ends_with('\n') {
271 data.pop();
272 }
273
274 let frame = SseFrame {
275 data,
276 event_type: self.event_type.take(),
277 id: self.id.clone(), retry: self.retry,
279 };
280
281 self.data_lines.clear();
282 self.current_event_size = 0;
283 self.ready.push(Ok(frame));
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `input` as a single chunk and collects every queued frame,
    /// panicking if the parser reported an error.
    fn parse_all(input: &str) -> Vec<SseFrame> {
        let mut parser = SseParser::new();
        parser.feed(input.as_bytes());
        std::iter::from_fn(|| parser.next_frame())
            .map(|item| item.expect("unexpected error"))
            .collect()
    }

    /// The `data` payloads of every frame returned by `parse_all`.
    fn payloads(input: &str) -> Vec<String> {
        parse_all(input).into_iter().map(|frame| frame.data).collect()
    }

    #[test]
    fn parse_single_data_event() {
        assert_eq!(payloads("data: hello world\n\n"), ["hello world"]);
    }

    #[test]
    fn parse_multiline_data() {
        assert_eq!(payloads("data: line1\ndata: line2\n\n"), ["line1\nline2"]);
    }

    #[test]
    fn parse_two_events() {
        assert_eq!(payloads("data: first\n\ndata: second\n\n"), ["first", "second"]);
    }

    #[test]
    fn ignore_keepalive_comment() {
        assert_eq!(payloads(": keep-alive\n\ndata: real\n\n"), ["real"]);
    }

    #[test]
    fn parse_event_type() {
        let frames = parse_all("event: status-update\ndata: {}\n\n");
        let kinds: Vec<Option<&str>> = frames.iter().map(|f| f.event_type.as_deref()).collect();
        assert_eq!(kinds, [Some("status-update")]);
    }

    #[test]
    fn parse_id_field() {
        let frames = parse_all("id: 42\ndata: hello\n\n");
        let ids: Vec<Option<&str>> = frames.iter().map(|f| f.id.as_deref()).collect();
        assert_eq!(ids, [Some("42")]);
    }

    #[test]
    fn parse_retry_field() {
        let frames = parse_all("retry: 5000\ndata: hello\n\n");
        let retries: Vec<Option<u64>> = frames.iter().map(|f| f.retry).collect();
        assert_eq!(retries, [Some(5000)]);
    }

    #[test]
    fn fragmented_delivery() {
        let mut parser = SseParser::new();
        b"data: fragmented\n\n"
            .chunks(1)
            .for_each(|chunk| parser.feed(chunk));
        let frame = parser.next_frame().expect("expected frame").expect("no error");
        assert_eq!(frame.data, "fragmented");
    }

    #[test]
    fn blank_line_without_data_is_ignored() {
        assert_eq!(payloads("event: ping\n\ndata: real\n\n"), ["real"]);
    }

    #[test]
    fn json_data_roundtrip() {
        let json = r#"{"jsonrpc":"2.0","id":"1","result":{"kind":"task"}}"#;
        assert_eq!(payloads(&format!("data: {json}\n\n")), [json]);
    }

    #[test]
    fn event_too_large_returns_error() {
        let mut parser = SseParser::with_max_event_size(32);
        parser.feed(format!("data: {}\n\n", "x".repeat(64)).as_bytes());
        let result = parser.next_frame().expect("expected result");
        assert!(result.is_err());
        let SseParseError::EventTooLarge { limit, .. } = result.unwrap_err();
        assert_eq!(limit, 32);
    }

    #[test]
    fn events_after_oversized_event_still_parse() {
        let mut parser = SseParser::with_max_event_size(16);
        parser.feed(format!("data: {}\n\n", "x".repeat(32)).as_bytes());
        parser.feed(b"data: ok\n\n");

        let first = parser.next_frame().expect("expected result");
        assert!(first.is_err());

        let second = parser.next_frame().expect("expected result");
        assert_eq!(second.unwrap().data, "ok");
    }
}
412}