// sh_layer1/streaming/sse.rs
use anyhow::Result;

/// Incremental parser for a server-sent-events (SSE) byte stream.
///
/// Bytes are accumulated in an internal buffer until a complete frame is
/// available; the parsing logic lives in the `impl` block.
#[derive(Debug, Default)]
pub struct SseParser {
    /// Bytes received so far that have not yet been consumed as a frame.
    buffer: Vec<u8>,
    /// Provider name, set through `with_context`; `None` by default.
    provider: Option<String>,
    /// Model name, set through `with_context`; `None` by default.
    model: Option<String>,
}

18impl SseParser {
19 #[must_use]
21 pub fn new() -> Self {
22 Self::default()
23 }
24
25 #[must_use]
27 pub fn with_context(mut self, provider: impl Into<String>, model: impl Into<String>) -> Self {
28 self.provider = Some(provider.into());
29 self.model = Some(model.into());
30 self
31 }
32
33 pub fn push(&mut self, chunk: &[u8]) -> Result<Vec<SseEvent>> {
35 self.buffer.extend_from_slice(chunk);
36 let mut events = Vec::new();
37
38 while let Some(frame) = self.next_frame() {
39 if let Some(event) = self.parse_frame(&frame)? {
40 events.push(event);
41 }
42 }
43
44 Ok(events)
45 }
46
47 pub fn finish(&mut self) -> Result<Vec<SseEvent>> {
49 if self.buffer.is_empty() {
50 return Ok(Vec::new());
51 }
52
53 let trailing = std::mem::take(&mut self.buffer);
54 match self.parse_frame(&String::from_utf8_lossy(&trailing))? {
55 Some(event) => Ok(vec![event]),
56 None => Ok(Vec::new()),
57 }
58 }
59
60 fn next_frame(&mut self) -> Option<String> {
61 let separator = self
63 .buffer
64 .windows(2)
65 .position(|window| window == b"\n\n")
66 .map(|position| (position, 2))
67 .or_else(|| {
68 self.buffer
69 .windows(4)
70 .position(|window| window == b"\r\n\r\n")
71 .map(|position| (position, 4))
72 })?;
73
74 let (position, separator_len) = separator;
75 let frame = self
76 .buffer
77 .drain(..position + separator_len)
78 .collect::<Vec<_>>();
79 let frame_len = frame.len().saturating_sub(separator_len);
80 Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned())
81 }
82
83 fn parse_frame(&self, frame: &str) -> Result<Option<SseEvent>> {
84 let trimmed = frame.trim();
85 if trimmed.is_empty() {
86 return Ok(None);
87 }
88
89 let mut data_lines = Vec::new();
90 let mut event_name: Option<String> = None;
91
92 for line in trimmed.lines() {
93 if line.starts_with(':') {
95 continue;
96 }
97 if let Some(name) = line.strip_prefix("event:") {
99 event_name = Some(name.trim().to_string());
100 continue;
101 }
102 if let Some(data) = line.strip_prefix("data:") {
104 data_lines.push(data.trim_start().to_string());
105 }
106 }
107
108 if matches!(event_name.as_deref(), Some("ping")) {
110 return Ok(None);
111 }
112
113 if data_lines.is_empty() {
114 return Ok(None);
115 }
116
117 let payload = data_lines.join("\n");
118
119 if payload == "[DONE]" {
121 return Ok(None);
122 }
123
124 Ok(Some(SseEvent {
125 event: event_name,
126 data: payload,
127 }))
128 }
129}
130
/// One event extracted from an SSE stream.
#[derive(Debug, Clone)]
pub struct SseEvent {
    /// Value of the frame's `event:` field, if the frame had one.
    pub event: Option<String>,
    /// Payload assembled from the frame's `data:` lines, joined with `\n`.
    pub data: String,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A complete frame delivered in one chunk yields exactly one event.
    #[test]
    fn sse_parser_parses_single_frame() {
        let frame = concat!(
            "event: content_block_start\n",
            "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n"
        );

        let mut parser = SseParser::new();
        let events = parser.push(frame.as_bytes()).expect("frame should parse");

        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event, Some("content_block_start".to_string()));
        assert!(events[0].data.contains("content_block_start"));
    }

    /// A frame split across two chunks is buffered until its separator arrives.
    #[test]
    fn sse_parser_handles_chunked_stream() {
        let mut parser = SseParser::new();
        let first = b"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hel";
        let second = b"lo\"}}\n\n";

        assert!(parser
            .push(first)
            .expect("first chunk should buffer")
            .is_empty());
        let events = parser.push(second).expect("second chunk should parse");

        assert_eq!(events.len(), 1);
        assert!(events[0].data.contains("Hello"));
    }

    /// Comments, `ping` events and the `[DONE]` sentinel produce no events.
    #[test]
    fn sse_parser_ignores_ping_and_done() {
        let mut parser = SseParser::new();
        let payload = concat!(
            ": keepalive\n",
            "event: ping\n",
            "data: {\"type\":\"ping\"}\n\n",
            "event: message_delta\n",
            "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"}}\n\n",
            "data: [DONE]\n\n"
        );

        let events = parser
            .push(payload.as_bytes())
            .expect("parser should succeed");
        assert_eq!(events.len(), 1);
        // The only surviving event must be the real one, not the ping.
        assert_eq!(events[0].event.as_deref(), Some("message_delta"));
    }

    /// Bytes left without a trailing separator are emitted by `finish`.
    #[test]
    fn sse_parser_finish_flushes_trailing_frame() {
        let mut parser = SseParser::new();
        assert!(parser
            .push(b"data: {\"type\":\"message_stop\"}")
            .expect("unterminated frame should buffer")
            .is_empty());

        let events = parser.finish().expect("finish should parse");
        assert_eq!(events.len(), 1);
        assert!(events[0].data.contains("message_stop"));
        assert!(parser.finish().expect("second finish should succeed").is_empty());
    }

    /// Frames separated by CRLF blank lines are recognised.
    #[test]
    fn sse_parser_handles_crlf_separator() {
        let mut parser = SseParser::new();
        let events = parser
            .push(b"event: message_start\r\ndata: {\"type\":\"message_start\"}\r\n\r\n")
            .expect("CRLF frame should parse");

        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event.as_deref(), Some("message_start"));
    }
}