lattice_sdk/core/
sse_stream.rs

1use crate::ApiError;
2use futures::Stream;
3use pin_project::pin_project;
4use reqwest::{header::CONTENT_TYPE, Response};
5use reqwest_sse::{error::EventError, Event, EventSource};
6use serde::de::DeserializeOwned;
7use std::{
8    marker::PhantomData,
9    pin::Pin,
10    task::{Context, Poll},
11};
12
/// Metadata from a Server-Sent Event.
///
/// Holds the SSE protocol fields (`event`, `id`, `retry`) that accompany an event's
/// data payload, for callers that need more than the deserialized data.
///
/// Per the SSE specification:
/// - `event` defaults to "message" when the server does not specify an event type
/// - `id` is optional; it is the value a client sends back in the `Last-Event-ID`
///   header when reconnecting
/// - `retry` is optional and gives the reconnection delay in milliseconds
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SseMetadata {
    /// The event type ("message" when the server did not specify one).
    pub event: String,
    /// The event ID used for reconnection support (`None` if not specified).
    ///
    /// NOTE(review): per the SSE spec the last-event-ID buffer persists across events,
    /// so this may repeat an ID set by an earlier event — confirm against `reqwest-sse`.
    pub id: Option<String>,
    /// Reconnection delay in milliseconds (`None` if the event did not set one).
    pub retry: Option<u64>,
}
31
32/// A Server-Sent Event with both data and metadata
33///
34/// Contains the deserialized data payload along with SSE protocol metadata.
35/// Use this when you need access to event IDs, event types, or retry information.
36#[derive(Debug)]
37pub struct SseEvent<T> {
38    /// The deserialized data payload
39    pub data: T,
40    /// SSE protocol metadata
41    pub metadata: SseMetadata,
42}
43
44/// A type-safe wrapper around Server-Sent Events (SSE) streams
45///
46/// Leverages `reqwest-sse` for SSE protocol parsing and adds:
47/// - Automatic JSON deserialization to typed structs
48/// - Stream terminator support (e.g., `[DONE]` for OpenAI-style APIs)
49/// - Integrated error handling with `ApiError`
50/// - Content-Type validation (`text/event-stream` required)
51///
52/// # Charset Handling
53///
54/// The `reqwest-sse` library automatically handles charset detection and decoding
55/// based on the Content-Type header. If no charset is specified, UTF-8 is assumed.
56/// This matches the SSE specification default behavior.
57///
58/// # Example
59///
60/// Basic usage with async iteration:
61///
62/// ```no_run
63/// use futures::StreamExt;
64///
65/// let stream: SseStream<CompletionChunk> = client.stream_completions(request).await?;
66/// let mut stream = std::pin::pin!(stream);
67///
68/// while let Some(result) = stream.next().await {
69///     match result {
70///         Ok(chunk) => println!("Received: {:?}", chunk),
71///         Err(e) => eprintln!("Error: {}", e),
72///     }
73/// }
74/// ```
75///
76/// # Error Handling
77///
78/// The stream yields `Result<T, ApiError>` items. Errors can occur from:
79/// - Invalid JSON in SSE data field (`ApiError::Serialization`)
80/// - SSE protocol errors (`ApiError::SseParseError`)
81/// - Network errors during streaming
82///
83/// **Important:** When an error occurs for a single event (e.g., malformed JSON),
84/// the stream yields `Err` for that item but **continues streaming** subsequent events.
85/// The stream only ends when:
86/// - A terminator is received (if configured)
87/// - The server closes the connection
88/// - A fatal network error occurs
89///
90/// This allows the client to handle per-event errors gracefully without losing
91/// the entire stream. Compare this to other error handling strategies where a single
92/// bad event might terminate the stream.
93///
94/// # Terminator Support
95///
96/// When a terminator string is specified (e.g., `[DONE]`), the stream automatically
97/// ends when an SSE event with that exact data is received. The terminator event
98/// itself is not yielded to the consumer.
99#[pin_project]
100pub struct SseStream<T> {
101    #[pin]
102    inner: Pin<Box<dyn Stream<Item = Result<Event, EventError>> + Send>>,
103    terminator: Option<String>,
104    _phantom: PhantomData<T>,
105}
106
107impl<T> SseStream<T>
108where
109    T: DeserializeOwned,
110{
111    /// Create a new SSE stream from a Response
112    ///
113    /// # Arguments
114    /// * `response` - The HTTP response to parse as SSE
115    /// * `terminator` - Optional terminator string (e.g., `"[DONE]"`) that signals end of stream
116    ///
117    /// # Errors
118    /// Returns `ApiError::SseParseError` if:
119    /// - Response Content-Type is not `text/event-stream`
120    /// - SSE stream cannot be created from response
121    pub(crate) async fn new(
122        response: Response,
123        terminator: Option<String>,
124    ) -> Result<Self, ApiError> {
125        // Validate Content-Type header (case-insensitive, handles parameters)
126        let content_type = response
127            .headers()
128            .get(CONTENT_TYPE)
129            .and_then(|v| v.to_str().ok())
130            .unwrap_or("");
131
132        // Extract main content type (before ';' parameter separator) and compare case-insensitively
133        let content_type_main = content_type.split(';').next().unwrap_or("").trim();
134
135        if !content_type_main.eq_ignore_ascii_case("text/event-stream") {
136            return Err(ApiError::SseParseError(format!(
137                "Expected Content-Type to be 'text/event-stream', got '{}'",
138                content_type
139            )));
140        }
141
142        // Use reqwest-sse's EventSource trait to get SSE stream
143        let events = response
144            .events()
145            .await
146            .map_err(|e| ApiError::SseParseError(e.to_string()))?;
147
148        Ok(Self {
149            inner: Box::pin(events),
150            terminator,
151            _phantom: PhantomData,
152        })
153    }
154}
155
156impl<T> SseStream<T>
157where
158    T: DeserializeOwned,
159{
160    /// Convert this stream into one that yields events with metadata
161    ///
162    /// This consumes the stream and returns a new stream that yields `SseEvent<T>`
163    /// instead of just `T`, providing access to SSE metadata (event type, id, retry).
164    ///
165    /// # Example
166    ///
167    /// ```no_run
168    /// use futures::StreamExt;
169    ///
170    /// let stream = client.stream_completions(request).await?;
171    /// let mut stream_with_metadata = stream.with_metadata();
172    /// let mut stream_with_metadata = std::pin::pin!(stream_with_metadata);
173    ///
174    /// while let Some(result) = stream_with_metadata.next().await {
175    ///     match result {
176    ///         Ok(event) => {
177    ///             println!("Data: {:?}", event.data);
178    ///             println!("Event type: {}", event.metadata.event);
179    ///             if let Some(id) = &event.metadata.id {
180    ///                 println!("Event ID: {}", id);
181    ///             }
182    ///         }
183    ///         Err(e) => eprintln!("Error: {}", e),
184    ///     }
185    /// }
186    /// ```
187    pub fn with_metadata(self) -> SseStreamWithMetadata<T> {
188        SseStreamWithMetadata { inner: self }
189    }
190}
191
192impl<T> Stream for SseStream<T>
193where
194    T: DeserializeOwned,
195{
196    type Item = Result<T, ApiError>;
197
198    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
199        let this = self.project();
200        match this.inner.poll_next(cx) {
201            Poll::Ready(Some(Ok(event))) => {
202                // Check for terminator before parsing
203                if let Some(ref terminator) = this.terminator {
204                    if event.data == *terminator {
205                        // Terminator found - end stream cleanly
206                        return Poll::Ready(None);
207                    }
208                }
209
210                // Deserialize JSON data to typed struct
211                match serde_json::from_str(&event.data) {
212                    Ok(value) => Poll::Ready(Some(Ok(value))),
213                    Err(e) => Poll::Ready(Some(Err(ApiError::Serialization(e)))),
214                }
215            }
216            Poll::Ready(Some(Err(e))) => {
217                Poll::Ready(Some(Err(ApiError::SseParseError(e.to_string()))))
218            }
219            Poll::Ready(None) => Poll::Ready(None),
220            Poll::Pending => Poll::Pending,
221        }
222    }
223}
224
225/// Stream wrapper that yields events with metadata
226///
227/// Created by calling [`SseStream::with_metadata()`]. This stream yields `SseEvent<T>`
228/// which includes both the deserialized data and SSE protocol metadata.
229#[pin_project]
230pub struct SseStreamWithMetadata<T> {
231    #[pin]
232    inner: SseStream<T>,
233}
234
235impl<T> Stream for SseStreamWithMetadata<T>
236where
237    T: DeserializeOwned,
238{
239    type Item = Result<SseEvent<T>, ApiError>;
240
241    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
242        let this = self.project();
243
244        // Access the inner stream's fields through pin projection
245        let inner_pin = this.inner.project();
246
247        match inner_pin.inner.poll_next(cx) {
248            Poll::Ready(Some(Ok(event))) => {
249                // Check for terminator
250                if let Some(ref terminator) = inner_pin.terminator {
251                    if event.data == *terminator {
252                        return Poll::Ready(None);
253                    }
254                }
255
256                // Extract metadata
257                let metadata = SseMetadata {
258                    // Default to "message" if event type is empty (per SSE spec)
259                    event: if event.event_type.is_empty() {
260                        "message".to_string()
261                    } else {
262                        event.event_type.clone()
263                    },
264                    id: event.last_event_id.clone(),
265                    retry: event.retry.map(|d| d.as_millis() as u64),
266                };
267
268                // Deserialize JSON data
269                match serde_json::from_str(&event.data) {
270                    Ok(data) => Poll::Ready(Some(Ok(SseEvent { data, metadata }))),
271                    Err(e) => Poll::Ready(Some(Err(ApiError::Serialization(e)))),
272                }
273            }
274            Poll::Ready(Some(Err(e))) => {
275                Poll::Ready(Some(Err(ApiError::SseParseError(e.to_string()))))
276            }
277            Poll::Ready(None) => Poll::Ready(None),
278            Poll::Pending => Poll::Pending,
279        }
280    }
281}