Skip to main content

opencode_sdk_rs/
streaming.rs

//! Server-Sent Events (SSE) streaming support.
//!
//! Provides [`SseStream`], a `futures_core::Stream` that wraps an HTTP byte
//! stream (from `hpx::Response::bytes_stream()`) and yields typed items
//! parsed from SSE `data:` fields.
6
7use std::{
8    pin::Pin,
9    task::{Context, Poll},
10};
11
12use bytes::Bytes;
13use futures_core::Stream;
14use pin_project_lite::pin_project;
15use serde::de::DeserializeOwned;
16
17use crate::error::OpencodeError;
18
19// ---------------------------------------------------------------------------
20// ServerSentEvent
21// ---------------------------------------------------------------------------
22
/// One event decoded from the Server-Sent Events (`text/event-stream`) wire
/// format.
///
/// All fields are optional on the wire; `Default` yields an event with no
/// type, empty data and no id.
#[derive(Clone, Debug, Default)]
pub struct ServerSentEvent {
    /// Value of the most recent `event:` field seen for this event, if any.
    pub event: Option<String>,
    /// Values of every `data:` field seen for this event, joined with `\n`.
    pub data: String,
    /// Value of the most recent `id:` field seen for this event, if any.
    pub id: Option<String>,
}
33
34// ---------------------------------------------------------------------------
35// SseDecoder
36// ---------------------------------------------------------------------------
37
38/// Internal buffer that accumulates SSE lines from a byte stream
39/// and yields complete [`ServerSentEvent`]s on empty-line boundaries.
40struct SseDecoder {
41    /// Accumulates partial lines across chunk boundaries.
42    buffer: String,
43    /// Current `event:` value being built.
44    current_event: Option<String>,
45    /// Accumulated `data:` lines for the current event.
46    current_data: Vec<String>,
47    /// Current `id:` value being built.
48    current_id: Option<String>,
49}
50
51impl SseDecoder {
52    const fn new() -> Self {
53        Self {
54            buffer: String::new(),
55            current_event: None,
56            current_data: Vec::new(),
57            current_id: None,
58        }
59    }
60
61    /// Feed a chunk of bytes into the decoder, returning any complete events.
62    fn feed(&mut self, chunk: &[u8]) -> Vec<ServerSentEvent> {
63        let text = String::from_utf8_lossy(chunk);
64        self.buffer.push_str(&text);
65
66        let mut events = Vec::new();
67
68        // Process all complete lines (terminated by \n).
69        // Partial lines remain in `self.buffer` for the next call.
70        while let Some(newline_pos) = self.buffer.find('\n') {
71            let line = self.buffer[..newline_pos].trim_end_matches('\r').to_owned();
72            self.buffer = self.buffer[newline_pos + 1..].to_owned();
73
74            if line.is_empty() {
75                // Empty line marks the end of an event.
76                if let Some(event) = self.emit_event() {
77                    events.push(event);
78                }
79                continue;
80            }
81
82            if line.starts_with(':') {
83                // Comment line — ignore.
84                continue;
85            }
86
87            let (field, value) = if let Some(colon_pos) = line.find(':') {
88                let field = &line[..colon_pos];
89                let mut value = &line[colon_pos + 1..];
90                // Strip a single leading space after the colon (per SSE spec).
91                if value.starts_with(' ') {
92                    value = &value[1..];
93                }
94                (field.to_owned(), value.to_owned())
95            } else {
96                // Field with no value.
97                (line, String::new())
98            };
99
100            match field.as_str() {
101                "event" => self.current_event = Some(value),
102                "data" => self.current_data.push(value),
103                "id" => self.current_id = Some(value),
104                // Unknown fields are ignored per the SSE spec.
105                _ => {}
106            }
107        }
108
109        events
110    }
111
112    /// Emit the current event (if any data has been accumulated) and reset.
113    fn emit_event(&mut self) -> Option<ServerSentEvent> {
114        if self.current_data.is_empty() && self.current_event.is_none() && self.current_id.is_none()
115        {
116            return None;
117        }
118
119        let event = ServerSentEvent {
120            event: self.current_event.take(),
121            data: self.current_data.join("\n"),
122            id: self.current_id.take(),
123        };
124        self.current_data.clear();
125
126        Some(event)
127    }
128
129    /// Flush any remaining partial event when the stream ends.
130    fn flush(&mut self) -> Option<ServerSentEvent> {
131        // If there is leftover text in the buffer, treat it as a final line.
132        if !self.buffer.is_empty() {
133            let remaining = std::mem::take(&mut self.buffer);
134            let trimmed = remaining.trim_end_matches('\r');
135            if !trimmed.is_empty() && !trimmed.starts_with(':') {
136                let (field, value) = trimmed.find(':').map_or_else(
137                    || (trimmed.to_owned(), String::new()),
138                    |colon_pos| {
139                        let field = &trimmed[..colon_pos];
140                        let mut value = &trimmed[colon_pos + 1..];
141                        if value.starts_with(' ') {
142                            value = &value[1..];
143                        }
144                        (field.to_owned(), value.to_owned())
145                    },
146                );
147
148                match field.as_str() {
149                    "event" => self.current_event = Some(value),
150                    "data" => self.current_data.push(value),
151                    "id" => self.current_id = Some(value),
152                    _ => {}
153                }
154            }
155        }
156
157        self.emit_event()
158    }
159}
160
161// ---------------------------------------------------------------------------
162// SseStream
163// ---------------------------------------------------------------------------
164
165pin_project! {
166    /// A stream of typed items parsed from Server-Sent Events.
167    ///
168    /// Wraps an inner byte stream (from `hpx::Response::bytes_stream()`)
169    /// and parses each SSE event's `data` field as JSON of type `T`.
170    pub struct SseStream<T> {
171        #[pin]
172        inner: Pin<Box<dyn Stream<Item = Result<Bytes, hpx::Error>> + Send>>,
173        decoder: SseDecoder,
174        pending: Vec<ServerSentEvent>,
175        _marker: std::marker::PhantomData<T>,
176    }
177}
178
179impl<T: DeserializeOwned> SseStream<T> {
180    /// Create an `SseStream` from an hpx response byte stream.
181    pub(crate) fn new(
182        byte_stream: impl Stream<Item = Result<Bytes, hpx::Error>> + Send + 'static,
183    ) -> Self {
184        Self {
185            inner: Box::pin(byte_stream),
186            decoder: SseDecoder::new(),
187            pending: Vec::new(),
188            _marker: std::marker::PhantomData,
189        }
190    }
191}
192
193impl<T: DeserializeOwned> Stream for SseStream<T> {
194    type Item = Result<T, OpencodeError>;
195
196    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
197        let mut this = self.project();
198
199        // First, drain any pending events from a previous chunk.
200        if !this.pending.is_empty() {
201            let event = this.pending.remove(0);
202            if event.data.is_empty() {
203                // Skip events with no data (heartbeats, etc.).
204                cx.waker().wake_by_ref();
205                return Poll::Pending;
206            }
207            let parsed =
208                serde_json::from_str::<T>(&event.data).map_err(OpencodeError::Serialization);
209            return Poll::Ready(Some(parsed));
210        }
211
212        // Poll the inner byte stream for more data.
213        match this.inner.as_mut().poll_next(cx) {
214            Poll::Ready(Some(Ok(bytes))) => {
215                let events = this.decoder.feed(&bytes);
216                *this.pending = events;
217                cx.waker().wake_by_ref();
218                Poll::Pending
219            }
220            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(OpencodeError::Connection {
221                message: e.to_string(),
222                source: Some(Box::new(e)),
223            }))),
224            Poll::Ready(None) => {
225                // Stream ended — flush any remaining partial event.
226                if let Some(event) = this.decoder.flush() &&
227                    !event.data.is_empty()
228                {
229                    let parsed = serde_json::from_str::<T>(&event.data)
230                        .map_err(OpencodeError::Serialization);
231                    return Poll::Ready(Some(parsed));
232                }
233                Poll::Ready(None)
234            }
235            Poll::Pending => Poll::Pending,
236        }
237    }
238}
239
240// ---------------------------------------------------------------------------
241// Tests
242// ---------------------------------------------------------------------------
243
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_simple_event() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data: {\"key\":\"value\"}\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "{\"key\":\"value\"}");
        assert!(events[0].event.is_none());
    }

    #[test]
    fn test_parse_event_with_type() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"event: message\ndata: hello\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event.as_deref(), Some("message"));
        assert_eq!(events[0].data, "hello");
    }

    #[test]
    fn test_parse_multiline_data() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data: line1\ndata: line2\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "line1\nline2");
    }

    #[test]
    fn test_parse_multiple_events() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data: event1\n\ndata: event2\n\n");
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].data, "event1");
        assert_eq!(events[1].data, "event2");
    }

    #[test]
    fn test_ignore_comments() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b": this is a comment\ndata: actual\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "actual");
    }

    #[test]
    fn test_ignore_unknown_fields() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"retry: 1000\ndata: x\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "x");
    }

    #[test]
    fn test_chunked_data() {
        let mut decoder = SseDecoder::new();
        let events1 = decoder.feed(b"data: hel");
        assert!(events1.is_empty());
        let events2 = decoder.feed(b"lo\n\n");
        assert_eq!(events2.len(), 1);
        assert_eq!(events2[0].data, "hello");
    }

    #[test]
    fn test_id_field() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"id: 42\ndata: test\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].id.as_deref(), Some("42"));
        assert_eq!(events[0].data, "test");
    }

    #[test]
    fn test_event_without_data() {
        let mut decoder = SseDecoder::new();
        // An event with only an `event:` field is still emitted, with empty data.
        let events = decoder.feed(b"event: ping\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].event.as_deref(), Some("ping"));
        assert_eq!(events[0].data, "");
    }

    #[test]
    fn test_flush_remaining() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data: partial");
        assert!(events.is_empty());
        let event = decoder.flush();
        assert!(event.is_some());
        assert_eq!(event.as_ref().unwrap().data, "partial");
    }

    #[test]
    fn test_flush_ignores_trailing_comment() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data: a\n: trailing comment");
        assert!(events.is_empty());
        let event = decoder.flush();
        assert_eq!(event.as_ref().map(|e| e.data.as_str()), Some("a"));
        assert!(decoder.flush().is_none());
    }

    #[test]
    fn test_empty_line_no_data() {
        let mut decoder = SseDecoder::new();
        // An empty line without prior fields produces nothing.
        let events = decoder.feed(b"\n");
        assert!(events.is_empty());
    }

    #[test]
    fn test_field_without_value() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data\n\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "");
    }

    #[test]
    fn test_crlf_line_endings() {
        let mut decoder = SseDecoder::new();
        let events = decoder.feed(b"data: hello\r\n\r\n");
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].data, "hello");
    }

    #[test]
    fn test_sse_stream_typed_compiles() {
        // Type-level checks: these calls only compile if the bounds hold.
        fn assert_stream<S: Stream<Item = Result<serde_json::Value, OpencodeError>>>() {}
        fn assert_send<S: Send>() {}

        // Verify that SseStream implements Stream with the expected Item.
        assert_stream::<SseStream<serde_json::Value>>();
        // Verify SseStream is Send (required for async runtimes).
        assert_send::<SseStream<serde_json::Value>>();
    }
}