lattice_sdk/core/sse_stream.rs
1use crate::ApiError;
2use futures::Stream;
3use pin_project::pin_project;
4use reqwest::{header::CONTENT_TYPE, Response};
5use reqwest_sse::{error::EventError, Event, EventSource};
6use serde::de::DeserializeOwned;
7use std::{
8 marker::PhantomData,
9 pin::Pin,
10 task::{Context, Poll},
11};
12
/// Metadata from a Server-Sent Event
///
/// Carries the SSE protocol fields (`event`, `id`, `retry`) that accompany the
/// `data` payload, for callers that need more than the decoded payload.
///
/// Per the SSE specification:
/// - `event` defaults to "message" when not specified by the server
/// - `id` is optional and used for reconnection support (`Last-Event-ID` header)
/// - `retry` is optional and specifies the reconnection time in milliseconds
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseMetadata {
    /// The event type ("message" when the server did not specify one)
    pub event: String,
    /// The event ID for reconnection support (`None` if not specified)
    pub id: Option<String>,
    /// Reconnection time in milliseconds (`None` if not specified)
    pub retry: Option<u64>,
}
30}
31
32/// A Server-Sent Event with both data and metadata
33///
34/// Contains the deserialized data payload along with SSE protocol metadata.
35/// Use this when you need access to event IDs, event types, or retry information.
36#[derive(Debug)]
37pub struct SseEvent<T> {
38 /// The deserialized data payload
39 pub data: T,
40 /// SSE protocol metadata
41 pub metadata: SseMetadata,
42}
43
44/// A type-safe wrapper around Server-Sent Events (SSE) streams
45///
46/// Leverages `reqwest-sse` for SSE protocol parsing and adds:
47/// - Automatic JSON deserialization to typed structs
48/// - Stream terminator support (e.g., `[DONE]` for OpenAI-style APIs)
49/// - Integrated error handling with `ApiError`
50/// - Content-Type validation (`text/event-stream` required)
51///
52/// # Charset Handling
53///
54/// The `reqwest-sse` library automatically handles charset detection and decoding
55/// based on the Content-Type header. If no charset is specified, UTF-8 is assumed.
56/// This matches the SSE specification default behavior.
57///
58/// # Example
59///
60/// Basic usage with async iteration:
61///
62/// ```no_run
63/// use futures::StreamExt;
64///
65/// let stream: SseStream<CompletionChunk> = client.stream_completions(request).await?;
66/// let mut stream = std::pin::pin!(stream);
67///
68/// while let Some(result) = stream.next().await {
69/// match result {
70/// Ok(chunk) => println!("Received: {:?}", chunk),
71/// Err(e) => eprintln!("Error: {}", e),
72/// }
73/// }
74/// ```
75///
76/// # Error Handling
77///
78/// The stream yields `Result<T, ApiError>` items. Errors can occur from:
79/// - Invalid JSON in SSE data field (`ApiError::Serialization`)
80/// - SSE protocol errors (`ApiError::SseParseError`)
81/// - Network errors during streaming
82///
83/// **Important:** When an error occurs for a single event (e.g., malformed JSON),
84/// the stream yields `Err` for that item but **continues streaming** subsequent events.
85/// The stream only ends when:
86/// - A terminator is received (if configured)
87/// - The server closes the connection
88/// - A fatal network error occurs
89///
90/// This allows the client to handle per-event errors gracefully without losing
91/// the entire stream. Compare this to other error handling strategies where a single
92/// bad event might terminate the stream.
93///
94/// # Terminator Support
95///
96/// When a terminator string is specified (e.g., `[DONE]`), the stream automatically
97/// ends when an SSE event with that exact data is received. The terminator event
98/// itself is not yielded to the consumer.
99#[pin_project]
100pub struct SseStream<T> {
101 #[pin]
102 inner: Pin<Box<dyn Stream<Item = Result<Event, EventError>> + Send>>,
103 terminator: Option<String>,
104 _phantom: PhantomData<T>,
105}
106
107impl<T> SseStream<T>
108where
109 T: DeserializeOwned,
110{
111 /// Create a new SSE stream from a Response
112 ///
113 /// # Arguments
114 /// * `response` - The HTTP response to parse as SSE
115 /// * `terminator` - Optional terminator string (e.g., `"[DONE]"`) that signals end of stream
116 ///
117 /// # Errors
118 /// Returns `ApiError::SseParseError` if:
119 /// - Response Content-Type is not `text/event-stream`
120 /// - SSE stream cannot be created from response
121 pub(crate) async fn new(
122 response: Response,
123 terminator: Option<String>,
124 ) -> Result<Self, ApiError> {
125 // Validate Content-Type header (case-insensitive, handles parameters)
126 let content_type = response
127 .headers()
128 .get(CONTENT_TYPE)
129 .and_then(|v| v.to_str().ok())
130 .unwrap_or("");
131
132 // Extract main content type (before ';' parameter separator) and compare case-insensitively
133 let content_type_main = content_type.split(';').next().unwrap_or("").trim();
134
135 if !content_type_main.eq_ignore_ascii_case("text/event-stream") {
136 return Err(ApiError::SseParseError(format!(
137 "Expected Content-Type to be 'text/event-stream', got '{}'",
138 content_type
139 )));
140 }
141
142 // Use reqwest-sse's EventSource trait to get SSE stream
143 let events = response
144 .events()
145 .await
146 .map_err(|e| ApiError::SseParseError(e.to_string()))?;
147
148 Ok(Self {
149 inner: Box::pin(events),
150 terminator,
151 _phantom: PhantomData,
152 })
153 }
154}
155
156impl<T> SseStream<T>
157where
158 T: DeserializeOwned,
159{
160 /// Convert this stream into one that yields events with metadata
161 ///
162 /// This consumes the stream and returns a new stream that yields `SseEvent<T>`
163 /// instead of just `T`, providing access to SSE metadata (event type, id, retry).
164 ///
165 /// # Example
166 ///
167 /// ```no_run
168 /// use futures::StreamExt;
169 ///
170 /// let stream = client.stream_completions(request).await?;
171 /// let mut stream_with_metadata = stream.with_metadata();
172 /// let mut stream_with_metadata = std::pin::pin!(stream_with_metadata);
173 ///
174 /// while let Some(result) = stream_with_metadata.next().await {
175 /// match result {
176 /// Ok(event) => {
177 /// println!("Data: {:?}", event.data);
178 /// println!("Event type: {}", event.metadata.event);
179 /// if let Some(id) = &event.metadata.id {
180 /// println!("Event ID: {}", id);
181 /// }
182 /// }
183 /// Err(e) => eprintln!("Error: {}", e),
184 /// }
185 /// }
186 /// ```
187 pub fn with_metadata(self) -> SseStreamWithMetadata<T> {
188 SseStreamWithMetadata { inner: self }
189 }
190}
191
192impl<T> Stream for SseStream<T>
193where
194 T: DeserializeOwned,
195{
196 type Item = Result<T, ApiError>;
197
198 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
199 let this = self.project();
200 match this.inner.poll_next(cx) {
201 Poll::Ready(Some(Ok(event))) => {
202 // Check for terminator before parsing
203 if let Some(ref terminator) = this.terminator {
204 if event.data == *terminator {
205 // Terminator found - end stream cleanly
206 return Poll::Ready(None);
207 }
208 }
209
210 // Deserialize JSON data to typed struct
211 match serde_json::from_str(&event.data) {
212 Ok(value) => Poll::Ready(Some(Ok(value))),
213 Err(e) => Poll::Ready(Some(Err(ApiError::Serialization(e)))),
214 }
215 }
216 Poll::Ready(Some(Err(e))) => {
217 Poll::Ready(Some(Err(ApiError::SseParseError(e.to_string()))))
218 }
219 Poll::Ready(None) => Poll::Ready(None),
220 Poll::Pending => Poll::Pending,
221 }
222 }
223}
224
225/// Stream wrapper that yields events with metadata
226///
227/// Created by calling [`SseStream::with_metadata()`]. This stream yields `SseEvent<T>`
228/// which includes both the deserialized data and SSE protocol metadata.
229#[pin_project]
230pub struct SseStreamWithMetadata<T> {
231 #[pin]
232 inner: SseStream<T>,
233}
234
235impl<T> Stream for SseStreamWithMetadata<T>
236where
237 T: DeserializeOwned,
238{
239 type Item = Result<SseEvent<T>, ApiError>;
240
241 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
242 let this = self.project();
243
244 // Access the inner stream's fields through pin projection
245 let inner_pin = this.inner.project();
246
247 match inner_pin.inner.poll_next(cx) {
248 Poll::Ready(Some(Ok(event))) => {
249 // Check for terminator
250 if let Some(ref terminator) = inner_pin.terminator {
251 if event.data == *terminator {
252 return Poll::Ready(None);
253 }
254 }
255
256 // Extract metadata
257 let metadata = SseMetadata {
258 // Default to "message" if event type is empty (per SSE spec)
259 event: if event.event_type.is_empty() {
260 "message".to_string()
261 } else {
262 event.event_type.clone()
263 },
264 id: event.last_event_id.clone(),
265 retry: event.retry.map(|d| d.as_millis() as u64),
266 };
267
268 // Deserialize JSON data
269 match serde_json::from_str(&event.data) {
270 Ok(data) => Poll::Ready(Some(Ok(SseEvent { data, metadata }))),
271 Err(e) => Poll::Ready(Some(Err(ApiError::Serialization(e)))),
272 }
273 }
274 Poll::Ready(Some(Err(e))) => {
275 Poll::Ready(Some(Err(ApiError::SseParseError(e.to_string()))))
276 }
277 Poll::Ready(None) => Poll::Ready(None),
278 Poll::Pending => Poll::Pending,
279 }
280 }
281}