Skip to main content

just_common/transport/
sse.rs

1//! Shared SSE parsing for OpenAI-like JSON chunk streams.
2//!
3//! # Known limitation: unbounded buffering
4//!
5//! The pending byte buffer grows without limit until a blank line terminates an event, and each
6//! event's `data:` payload is materialized in full before deserialization. Unlike the
7//! non-streaming path (capped by `crate::transport::http::MAX_BODY_BYTES`), there is no size cap
8//! here, so a malicious or broken server emitting very large events — or withholding the
9//! terminating blank line — can drive unbounded memory growth. Acceptable today because providers
10//! bound event size server-side; tighten if this is ever exposed to untrusted endpoints.
11
12use std::{
13    fmt,
14    pin::Pin,
15    task::{Context, Poll},
16};
17
18use async_stream::try_stream;
19use futures_core::Stream;
20use futures_util::StreamExt;
21use reqwest::header::CONTENT_TYPE;
22use serde::de::DeserializeOwned;
23
24use crate::error::TransportError;
25
/// Pinned, boxed, `Send` trait-object stream yielding `Result<T, TransportError>` items; this is
/// the type-erased stream stored inside `JsonEventStream`.
type BoxedJsonStream<T> = Pin<Box<dyn Stream<Item = Result<T, TransportError>> + Send>>;
27
/// Generic JSON-over-SSE stream.
///
/// Implements [`Stream`] with `Result<T, TransportError>` items by forwarding every poll to the
/// boxed stream it wraps.
pub struct JsonEventStream<T> {
    /// Type-erased stream that actually produces the items.
    inner: BoxedJsonStream<T>,
}
32
33impl<T> JsonEventStream<T>
34where
35    T: DeserializeOwned + Send + 'static,
36{
37    /// Creates a stream from a successful SSE response.
38    pub fn from_response(response: reqwest::Response) -> Result<Self, TransportError> {
39        ensure_event_stream(&response)?;
40
41        let stream = try_stream! {
42            let mut bytes = response.bytes_stream();
43            let mut buffer = Vec::new();
44            let mut done = false;
45
46            while let Some(chunk) = bytes.next().await {
47                let chunk = chunk.map_err(TransportError::Transport)?;
48                buffer.extend_from_slice(&chunk);
49
50                while let Some((event_end, consumed)) = split_event(&buffer) {
51                    let event = buffer[..event_end].to_vec();
52                    buffer.drain(..consumed);
53
54                    match parse_event::<T>(&event)? {
55                        ParsedEvent::Done => {
56                            done = true;
57                            break;
58                        }
59                        ParsedEvent::Skip => {}
60                        ParsedEvent::Chunk(chunk) => yield chunk,
61                    }
62                }
63
64                if done {
65                    break;
66                }
67            }
68
69            if !done && !buffer.iter().all(u8::is_ascii_whitespace) {
70                match parse_event::<T>(&buffer)? {
71                    ParsedEvent::Done | ParsedEvent::Skip => {}
72                    ParsedEvent::Chunk(chunk) => yield chunk,
73                }
74            }
75        };
76
77        Ok(Self {
78            inner: Box::pin(stream),
79        })
80    }
81}
82
83impl<T> fmt::Debug for JsonEventStream<T> {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        f.debug_struct("JsonEventStream").finish_non_exhaustive()
86    }
87}
88
89impl<T> Stream for JsonEventStream<T> {
90    type Item = Result<T, TransportError>;
91
92    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
93        // `inner` is already pinned in a `Pin<Box<...>>`, so projecting to it is safe.
94        unsafe { self.map_unchecked_mut(|stream| &mut stream.inner) }.poll_next(cx)
95    }
96}
97
/// Outcome of parsing a single raw SSE event.
#[derive(Debug)]
enum ParsedEvent<T> {
    /// The event's data payload was the literal `[DONE]` sentinel.
    Done,
    /// The event was blank or contained no `data:` lines.
    Skip,
    /// The event's data payload deserialized into a `T`.
    Chunk(T),
}
104
105fn ensure_event_stream(response: &reqwest::Response) -> Result<(), TransportError> {
106    let Some(content_type) = response.headers().get(CONTENT_TYPE) else {
107        return Err(TransportError::InvalidResponse(
108            "streaming response was missing content-type".to_owned(),
109        ));
110    };
111
112    let content_type = content_type.to_str().map_err(|_| {
113        TransportError::InvalidResponse(
114            "streaming response content-type was not valid UTF-8".to_owned(),
115        )
116    })?;
117
118    if !content_type.starts_with("text/event-stream") {
119        return Err(TransportError::InvalidResponse(format!(
120            "expected text/event-stream response, got {content_type}"
121        )));
122    }
123
124    Ok(())
125}
126
/// Locates the first blank-line terminator in `buffer`.
///
/// Returns `Some((event_end, consumed))`, where `buffer[..event_end]` is the raw event and
/// `buffer[..consumed]` additionally covers the terminator, or `None` when no complete event is
/// buffered yet.
///
/// Recognised terminators are `\r\n\r\n`, `\n\r\n` (an LF-terminated line followed by a CRLF
/// blank line) and `\n\n`. A CRLF-terminated line followed by a bare LF is matched through its
/// `\n\n` tail, which leaves a trailing `\r` at the end of the event. Bare-CR line endings are
/// not recognised.
fn split_event(buffer: &[u8]) -> Option<(usize, usize)> {
    // Tried in this order at every offset; the first offset with any match wins.
    const TERMINATORS: [&[u8]; 3] = [b"\r\n\r\n", b"\n\r\n", b"\n\n"];

    (0..buffer.len()).find_map(|start| {
        let rest = &buffer[start..];
        TERMINATORS
            .iter()
            .find(|terminator| rest.starts_with(terminator))
            .map(|terminator| (start, start + terminator.len()))
    })
}
144
145fn parse_event<T>(raw_event: &[u8]) -> Result<ParsedEvent<T>, TransportError>
146where
147    T: DeserializeOwned,
148{
149    if raw_event.is_empty() || raw_event.iter().all(u8::is_ascii_whitespace) {
150        return Ok(ParsedEvent::Skip);
151    }
152
153    let event = String::from_utf8(raw_event.to_vec()).map_err(TransportError::Utf8)?;
154    let mut data_lines = Vec::new();
155
156    for line in event.lines() {
157        let line = line.trim_end_matches('\r');
158
159        if line.starts_with(':') {
160            continue;
161        }
162
163        if let Some(data) = line.strip_prefix("data:") {
164            data_lines.push(data.trim_start());
165        }
166    }
167
168    if data_lines.is_empty() {
169        return Ok(ParsedEvent::Skip);
170    }
171
172    let payload = data_lines.join("\n");
173
174    if payload == "[DONE]" {
175        return Ok(ParsedEvent::Done);
176    }
177
178    let chunk = serde_json::from_str(&payload).map_err(|source| TransportError::Deserialize {
179        source,
180        body: payload,
181    })?;
182
183    Ok(ParsedEvent::Chunk(chunk))
184}