opencode_sdk_rs/
streaming.rs1use std::{
8 pin::Pin,
9 task::{Context, Poll},
10};
11
12use bytes::Bytes;
13use futures_core::Stream;
14use pin_project_lite::pin_project;
15use serde::de::DeserializeOwned;
16
17use crate::error::OpencodeError;
18
/// A single Server-Sent Event as decoded from the wire.
#[derive(Debug, Clone, Default)]
pub struct ServerSentEvent {
    /// Value of the last `event:` field seen before the event was dispatched, if any.
    pub event: Option<String>,
    /// All `data:` field values of the event, joined with `\n`.
    pub data: String,
    /// Value of the last `id:` field seen before the event was dispatched, if any.
    pub id: Option<String>,
}

/// Incremental decoder that turns raw SSE bytes into [`ServerSentEvent`]s.
///
/// Bytes are buffered until a complete line is available. Lines end in `\n`, optionally
/// preceded by `\r`; a lone `\r` is not treated as a terminator. Lines are decoded as UTF-8
/// only once complete, so a multi-byte character split across two chunks is reassembled
/// correctly instead of being replaced with U+FFFD.
struct SseDecoder {
    /// Raw bytes of the current line that has not yet seen its `\n`.
    buffer: Vec<u8>,
    /// `event:` field of the event being assembled.
    current_event: Option<String>,
    /// `data:` field values of the event being assembled, in arrival order.
    current_data: Vec<String>,
    /// `id:` field of the event being assembled.
    current_id: Option<String>,
}

impl SseDecoder {
    /// Creates a decoder with no buffered input and no partially assembled event.
    const fn new() -> Self {
        Self {
            buffer: Vec::new(),
            current_event: None,
            current_data: Vec::new(),
            current_id: None,
        }
    }

    /// Feeds a chunk of bytes and returns every event completed by it.
    ///
    /// A trailing line without its `\n` is kept buffered for the next call (or `flush`).
    fn feed(&mut self, chunk: &[u8]) -> Vec<ServerSentEvent> {
        self.buffer.extend_from_slice(chunk);

        // Work on a detached buffer so completed lines can be borrowed while the event state
        // in `self` is mutated. Consumed bytes are removed with one drain at the end instead
        // of copying the remainder once per line.
        let mut buffer = std::mem::take(&mut self.buffer);
        let mut events = Vec::new();
        let mut consumed = 0;

        while let Some(offset) = buffer[consumed..].iter().position(|&b| b == b'\n') {
            let line_end = consumed + offset;
            // `\n` never occurs inside a multi-byte UTF-8 sequence, so a complete line never
            // splits a character; genuinely invalid bytes still become U+FFFD.
            let line = String::from_utf8_lossy(&buffer[consumed..line_end]);
            let line = line.trim_end_matches('\r');
            consumed = line_end + 1;

            if line.is_empty() {
                // A blank line dispatches the event assembled so far.
                events.extend(self.emit_event());
            } else {
                self.apply_line(line);
            }
        }

        buffer.drain(..consumed);
        self.buffer = buffer;
        events
    }

    /// Applies one non-empty line to the event being assembled.
    ///
    /// Comment lines (leading `:`) are skipped; `event`, `data` and `id` fields update the
    /// pending event; any other field name is ignored.
    fn apply_line(&mut self, line: &str) {
        if line.starts_with(':') {
            return;
        }

        // `field: value` — a single space right after the colon is not part of the value.
        // A line without a colon is a field name with an empty value.
        let (field, value) = match line.split_once(':') {
            Some((field, value)) => (field, value.strip_prefix(' ').unwrap_or(value)),
            None => (line, ""),
        };

        match field {
            "event" => self.current_event = Some(value.to_owned()),
            "data" => self.current_data.push(value.to_owned()),
            "id" => self.current_id = Some(value.to_owned()),
            _ => {}
        }
    }

    /// Finalizes the event assembled so far and resets the per-event state.
    ///
    /// Returns `None` when no `event`, `data` or `id` field has been seen since the last
    /// dispatch, so stray blank lines do not produce empty events.
    fn emit_event(&mut self) -> Option<ServerSentEvent> {
        if self.current_data.is_empty() && self.current_event.is_none() && self.current_id.is_none()
        {
            return None;
        }

        let event = ServerSentEvent {
            event: self.current_event.take(),
            data: self.current_data.join("\n"),
            id: self.current_id.take(),
        };
        // `clear` keeps the Vec's allocation for the next event.
        self.current_data.clear();

        Some(event)
    }

    /// Finishes decoding at end of input.
    ///
    /// Applies any final line that never received its `\n`, then returns the pending event
    /// if one was started. Calling it again afterwards returns `None`.
    fn flush(&mut self) -> Option<ServerSentEvent> {
        let remaining = std::mem::take(&mut self.buffer);
        let line = String::from_utf8_lossy(&remaining);
        let line = line.trim_end_matches('\r');
        if !line.is_empty() {
            self.apply_line(line);
        }

        self.emit_event()
    }
}
160
161pin_project! {
166 pub struct SseStream<T> {
171 #[pin]
172 inner: Pin<Box<dyn Stream<Item = Result<Bytes, hpx::Error>> + Send>>,
173 decoder: SseDecoder,
174 pending: Vec<ServerSentEvent>,
175 _marker: std::marker::PhantomData<T>,
176 }
177}
178
179impl<T: DeserializeOwned> SseStream<T> {
180 pub(crate) fn new(
182 byte_stream: impl Stream<Item = Result<Bytes, hpx::Error>> + Send + 'static,
183 ) -> Self {
184 Self {
185 inner: Box::pin(byte_stream),
186 decoder: SseDecoder::new(),
187 pending: Vec::new(),
188 _marker: std::marker::PhantomData,
189 }
190 }
191}
192
193impl<T: DeserializeOwned> Stream for SseStream<T> {
194 type Item = Result<T, OpencodeError>;
195
196 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
197 let mut this = self.project();
198
199 if !this.pending.is_empty() {
201 let event = this.pending.remove(0);
202 if event.data.is_empty() {
203 cx.waker().wake_by_ref();
205 return Poll::Pending;
206 }
207 let parsed =
208 serde_json::from_str::<T>(&event.data).map_err(OpencodeError::Serialization);
209 return Poll::Ready(Some(parsed));
210 }
211
212 match this.inner.as_mut().poll_next(cx) {
214 Poll::Ready(Some(Ok(bytes))) => {
215 let events = this.decoder.feed(&bytes);
216 *this.pending = events;
217 cx.waker().wake_by_ref();
218 Poll::Pending
219 }
220 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(OpencodeError::Connection {
221 message: e.to_string(),
222 source: Some(Box::new(e)),
223 }))),
224 Poll::Ready(None) => {
225 if let Some(event) = this.decoder.flush() &&
227 !event.data.is_empty()
228 {
229 let parsed = serde_json::from_str::<T>(&event.data)
230 .map_err(OpencodeError::Serialization);
231 return Poll::Ready(Some(parsed));
232 }
233 Poll::Ready(None)
234 }
235 Poll::Pending => Poll::Pending,
236 }
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_simple_event() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data: {\"key\":\"value\"}\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "{\"key\":\"value\"}");
        assert!(events[0].event.is_none());
    }

    #[test]
    fn test_parse_event_with_type() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"event: message\ndata: hello\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event.as_deref(), Some("message"));
        assert_eq!(events[0].data, "hello");
    }

    #[test]
    fn test_parse_multiline_data() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data: line1\ndata: line2\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "line1\nline2");
    }

    #[test]
    fn test_parse_multiple_events() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data: event1\n\ndata: event2\n\n");
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].data, "event1");
        assert_eq!(events[1].data, "event2");
    }

    #[test]
    fn test_ignore_comments() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b": this is a comment\ndata: actual\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "actual");
    }

    #[test]
    fn test_chunked_data() {
        let mut decoder = SseDecoder::new();
        let events1 = decoder.feed(b"data: hel");
        assert!(events1.is_empty());
        let events2 = decoder.feed(b"lo\n\n");
        assert_eq!(events2.len(), 1);
        assert_eq!(events2[0].data, "hello");
    }

    #[test]
    fn test_id_field() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"id: 42\ndata: test\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].id.as_deref(), Some("42"));
        assert_eq!(events[0].data, "test");
    }

    #[test]
    fn test_flush_remaining() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data: partial");
        assert!(events.is_empty());
        let event = decoder.flush();
        assert!(event.is_some());
        assert_eq!(event.as_ref().unwrap().data, "partial");
    }

    #[test]
    fn test_flush_is_idempotent() {
        let mut decoder = SseDecoder::new();
        assert!(decoder.feed(b"data: done").is_empty());
        assert_eq!(decoder.flush().map(|e| e.data).as_deref(), Some("done"));
        // Nothing is left buffered, so a second flush must not emit anything.
        assert!(decoder.flush().is_none());
    }

    #[test]
    fn test_empty_line_no_data() {
        let mut decoder = SseDecoder::new();
        // A blank line with no preceding fields must not produce an event.
        let events = decoder.feed(b"\n");
        assert!(events.is_empty());
    }

    #[test]
    fn test_field_without_value() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "");
    }

    #[test]
    fn test_unknown_fields_ignored() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"retry: 1000\nfoo: bar\ndata: x\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "x");
        assert!(events[0].event.is_none());
        assert!(events[0].id.is_none());
    }

    #[test]
    fn test_crlf_line_endings() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data: hello\r\n\r\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "hello");
    }

    #[test]
    fn test_crlf_split_across_chunks() {
        let mut decoder = SseDecoder::new();
        // The `\r` arrives in one chunk and its `\n` in the next.
        assert!(decoder.feed(b"data: hi\r").is_empty());
        let events = decoder.feed(b"\n\r\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "hi");
    }

    #[test]
    fn test_sse_stream_typed_compiles() {
        fn assert_stream<S: Stream<Item = Result<serde_json::Value, OpencodeError>>>() {}
        fn assert_send<S: Send>() {}

        // Instantiating the helpers with a concrete `SseStream` makes the compiler actually
        // check the trait bounds; merely declaring them checks nothing.
        assert_stream::<SseStream<serde_json::Value>>();
        assert_send::<SseStream<serde_json::Value>>();
    }
}