Skip to main content

rig/http_client/
sse.rs

//! An SSE implementation that leverages [`crate::http_client::HttpClientExt`] to allow streaming with
//! automatic retry handling for any implementor of `HttpClientExt`.
//!
//! Primarily intended for internal usage. However, if you also wish to implement generic HTTP
//! streaming for your custom completion model, you may find this helpful.
5use crate::{
6    http_client::{
7        HttpClientExt, Result as StreamResult,
8        retry::{DEFAULT_RETRY, ExponentialBackoff, RetryPolicy},
9    },
10    wasm_compat::{WasmCompatSend, WasmCompatSendStream},
11};
12use bytes::Bytes;
13use eventsource_stream::{Event as MessageEvent, EventStreamError, Eventsource};
14use futures::Stream;
15#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
16use futures::{future::BoxFuture, stream::BoxStream};
17#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
18use futures::{future::LocalBoxFuture, stream::LocalBoxStream};
19use futures_timer::Delay;
20use http::Response;
21use http::{HeaderName, HeaderValue, Request, StatusCode};
22use mime_guess::mime;
23use pin_project_lite::pin_project;
24use std::{
25    pin::Pin,
26    task::{Context, Poll},
27    time::Duration,
28};
29
30pub type BoxedStream = Pin<Box<dyn WasmCompatSendStream<InnerItem = StreamResult<Bytes>>>>;
31
32#[cfg(not(target_arch = "wasm32"))]
33type ResponseFuture = BoxFuture<'static, Result<Response<BoxedStream>, super::Error>>;
34#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
35type ResponseFuture = LocalBoxFuture<'static, Result<Response<BoxedStream>, super::Error>>;
36
37#[cfg(not(target_arch = "wasm32"))]
38type EventStream = BoxStream<'static, Result<MessageEvent, EventStreamError<super::Error>>>;
39#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
40type EventStream = LocalBoxStream<'static, Result<MessageEvent, EventStreamError<super::Error>>>;
41
42pin_project! {
43    /// Internal state variants for the SSE state machine.
44    #[project = SourceStateProjection]
45    enum SourceState {
46        /// Initial connection attempt (no retry history yet)
47        Connecting {
48            #[pin]
49            response_future: ResponseFuture,
50        },
51        /// Reconnection attempt after a retry delay (always has retry history)
52        Reconnecting {
53            #[pin]
54            response_future: ResponseFuture,
55            last_retry: (usize, Duration),
56        },
57        /// Actively receiving SSE events
58        Open {
59            #[pin]
60            event_stream: EventStream,
61        },
62        /// Waiting before retry after an error
63        WaitingToRetry {
64            #[pin]
65            retry_delay: Delay,
66            current_retry: (usize, Duration),
67        },
68        /// Terminal state
69        Closed,
70    }
71}
72
73pin_project! {
74    /// A generic SSE event source that works with any [`HttpClientExt`] implementation.
75    #[project = GenericEventSourceProjection]
76    pub struct GenericEventSource<HttpClient, RequestBody, Retry = ExponentialBackoff> {
77        client: HttpClient,
78        req: Request<RequestBody>,
79        retry_policy: Retry,
80        last_event_id: Option<String>,
81        #[pin]
82        state: SourceState,
83    }
84}
85
86impl<HttpClient, RequestBody> GenericEventSource<HttpClient, RequestBody>
87where
88    HttpClient: HttpClientExt + Clone + 'static,
89    RequestBody: Into<Bytes> + Clone + WasmCompatSend + 'static,
90{
91    /// Create a new event source that will connect to the given request.
92    pub fn new(client: HttpClient, req: Request<RequestBody>) -> Self {
93        let response_future = Self::create_response_future(&client, &req, None);
94        let state = SourceState::Connecting { response_future };
95
96        Self {
97            client,
98            req,
99            retry_policy: DEFAULT_RETRY,
100            last_event_id: None,
101            state,
102        }
103    }
104
105    pub fn with_retry_policy<R>(
106        client: HttpClient,
107        req: Request<RequestBody>,
108        retry_policy: R,
109    ) -> GenericEventSource<HttpClient, RequestBody, R>
110    where
111        R: RetryPolicy,
112    {
113        let response_future = Self::create_response_future(&client, &req, None);
114        let state = SourceState::Connecting { response_future };
115
116        GenericEventSource {
117            client,
118            req,
119            retry_policy,
120            last_event_id: None,
121            state,
122        }
123    }
124
125    /// Create a response future for connecting/reconnecting
126    fn create_response_future(
127        client: &HttpClient,
128        req: &Request<RequestBody>,
129        last_event_id: Option<&str>,
130    ) -> ResponseFuture {
131        let mut req_clone = req.clone();
132        req_clone
133            .headers_mut()
134            .entry("Accept")
135            .or_insert(HeaderValue::from_static("text/event-stream"));
136
137        if let Some(id) = last_event_id
138            && let Ok(value) = HeaderValue::from_str(id)
139        {
140            req_clone
141                .headers_mut()
142                .insert(HeaderName::from_static("last-event-id"), value);
143        }
144
145        let client_clone = client.clone();
146        Box::pin(async move { client_clone.send_streaming(req_clone).await })
147    }
148
149    /// Get the last event id
150    pub fn last_event_id(&self) -> Option<&str> {
151        self.last_event_id.as_deref()
152    }
153
154    /// Close the event source, transitioning to the Closed state.
155    /// After calling this, the stream will yield `None` on the next poll.
156    pub fn close(&mut self) {
157        self.state = SourceState::Closed;
158    }
159}
160
161/// Events created by the [`GenericEventSource`]
162#[derive(Debug, Clone, Eq, PartialEq)]
163pub enum Event {
164    /// The event fired when the connection is opened
165    Open,
166    /// The event fired when a [`MessageEvent`] is received
167    Message(MessageEvent),
168}
169
170impl From<MessageEvent> for Event {
171    fn from(event: MessageEvent) -> Self {
172        Event::Message(event)
173    }
174}
175
176impl<HttpClient, RequestBody> Stream for GenericEventSource<HttpClient, RequestBody>
177where
178    HttpClient: HttpClientExt + Clone + 'static,
179    RequestBody: Into<Bytes> + Clone + WasmCompatSend + 'static,
180{
181    type Item = Result<Event, super::Error>;
182
183    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
184        let mut this = self.project();
185
186        loop {
187            match this.state.as_mut().project() {
188                SourceStateProjection::Connecting { response_future } => {
189                    match response_future.poll(cx) {
190                        Poll::Pending => return Poll::Pending,
191                        Poll::Ready(Ok(response)) => {
192                            match check_response(response) {
193                                Ok(response) => {
194                                    // Transition: Connecting -> Open
195                                    let mut event_stream = response.into_body().eventsource();
196                                    if let Some(id) = &this.last_event_id {
197                                        event_stream.set_last_event_id(id.clone());
198                                    }
199                                    this.state.set(SourceState::Open {
200                                        event_stream: Box::pin(event_stream),
201                                    });
202                                    return Poll::Ready(Some(Ok(Event::Open)));
203                                }
204                                Err(err) => {
205                                    // Transition: Connecting -> Closed (non-retryable error)
206                                    this.state.set(SourceState::Closed);
207                                    return Poll::Ready(Some(Err(err)));
208                                }
209                            }
210                        }
211                        Poll::Ready(Err(err)) => {
212                            // First connection attempt failed - start retry cycle
213                            if let Some(delay_duration) = this.retry_policy.retry(&err, None) {
214                                // Transition: Connecting -> WaitingToRetry
215                                this.state.set(SourceState::WaitingToRetry {
216                                    retry_delay: Delay::new(delay_duration),
217                                    current_retry: (1, delay_duration),
218                                });
219                                return Poll::Ready(Some(Err(err)));
220                            } else {
221                                // Transition: Connecting -> Closed
222                                this.state.set(SourceState::Closed);
223                                return Poll::Ready(Some(Err(err)));
224                            }
225                        }
226                    }
227                }
228
229                SourceStateProjection::Reconnecting {
230                    response_future,
231                    last_retry,
232                } => {
233                    match response_future.poll(cx) {
234                        Poll::Pending => return Poll::Pending,
235                        Poll::Ready(Ok(response)) => {
236                            match check_response(response) {
237                                Ok(response) => {
238                                    // Transition: Reconnecting -> Open (retry cycle complete)
239                                    let mut event_stream = response.into_body().eventsource();
240                                    if let Some(id) = &this.last_event_id {
241                                        event_stream.set_last_event_id(id.clone());
242                                    }
243                                    this.state.set(SourceState::Open {
244                                        event_stream: Box::pin(event_stream),
245                                    });
246                                    return Poll::Ready(Some(Ok(Event::Open)));
247                                }
248                                Err(err) => {
249                                    // Transition: Reconnecting -> Closed (non-retryable error)
250                                    this.state.set(SourceState::Closed);
251                                    return Poll::Ready(Some(Err(err)));
252                                }
253                            }
254                        }
255                        Poll::Ready(Err(err)) => {
256                            // Reconnection attempt failed - continue retry cycle
257                            if let Some(delay_duration) =
258                                this.retry_policy.retry(&err, Some(*last_retry))
259                            {
260                                let (retry_num, _) = *last_retry;
261                                // Transition: Reconnecting -> WaitingToRetry
262                                this.state.set(SourceState::WaitingToRetry {
263                                    retry_delay: Delay::new(delay_duration),
264                                    current_retry: (retry_num + 1, delay_duration),
265                                });
266                                return Poll::Ready(Some(Err(err)));
267                            } else {
268                                // Transition: Reconnecting -> Closed (max retries exceeded)
269                                this.state.set(SourceState::Closed);
270                                return Poll::Ready(Some(Err(err)));
271                            }
272                        }
273                    }
274                }
275
276                SourceStateProjection::Open { event_stream } => {
277                    match event_stream.poll_next(cx) {
278                        Poll::Pending => return Poll::Pending,
279                        Poll::Ready(Some(Ok(event))) => {
280                            if !event.id.is_empty() {
281                                *this.last_event_id = Some(event.id.clone());
282                            }
283                            if let Some(duration) = event.retry {
284                                this.retry_policy.set_reconnection_time(duration);
285                            }
286                            return Poll::Ready(Some(Ok(Event::Message(event))));
287                        }
288                        Poll::Ready(Some(Err(EventStreamError::Transport(err)))) => {
289                            // Connection error while open - start fresh retry cycle
290                            if let Some(delay_duration) = this.retry_policy.retry(&err, None) {
291                                // Transition: Open -> WaitingToRetry
292                                this.state.set(SourceState::WaitingToRetry {
293                                    retry_delay: Delay::new(delay_duration),
294                                    current_retry: (1, delay_duration),
295                                });
296                                return Poll::Ready(Some(Err(err)));
297                            } else {
298                                // Transition: Open -> Closed
299                                this.state.set(SourceState::Closed);
300                                return Poll::Ready(Some(Err(err)));
301                            }
302                        }
303                        Poll::Ready(Some(Err(EventStreamError::Parser(_)))) => {
304                            // Parser errors are recoverable - continue polling
305                            continue;
306                        }
307                        Poll::Ready(Some(Err(EventStreamError::Utf8(_)))) => {
308                            // UTF-8 errors are recoverable - continue polling
309                            continue;
310                        }
311                        Poll::Ready(None) => {
312                            // Transition: Open -> Closed
313                            this.state.set(SourceState::Closed);
314                            return Poll::Ready(None);
315                        }
316                    }
317                }
318
319                SourceStateProjection::WaitingToRetry {
320                    retry_delay,
321                    current_retry,
322                } => {
323                    // Copy before polling to avoid borrow conflicts
324                    let retry_info = *current_retry;
325                    match retry_delay.poll(cx) {
326                        Poll::Pending => return Poll::Pending,
327                        Poll::Ready(()) => {
328                            // Transition: WaitingToRetry -> Reconnecting
329                            let response_future =
330                                GenericEventSource::<HttpClient, RequestBody>::create_response_future(
331                                    this.client,
332                                    this.req,
333                                    this.last_event_id.as_deref(),
334                                );
335                            this.state.set(SourceState::Reconnecting {
336                                response_future,
337                                last_retry: retry_info,
338                            });
339                            continue;
340                        }
341                    }
342                }
343
344                SourceStateProjection::Closed => {
345                    return Poll::Ready(None);
346                }
347            }
348        }
349    }
350}
351
352fn check_response<T>(response: Response<T>) -> Result<Response<T>, super::Error> {
353    let StatusCode::OK = response.status() else {
354        return Err(super::Error::InvalidStatusCode(response.status()));
355    };
356
357    let content_type =
358        if let Some(content_type) = response.headers().get(&reqwest::header::CONTENT_TYPE) {
359            content_type
360        } else {
361            return Err(super::Error::InvalidContentType(HeaderValue::from_static(
362                "",
363            )));
364        };
365
366    if content_type
367        .to_str()
368        .map_err(|_| ())
369        .and_then(|s| s.parse::<mime::Mime>().map_err(|_| ()))
370        .map(|mime_type| {
371            matches!(
372                (mime_type.type_(), mime_type.subtype()),
373                (mime::TEXT, mime::EVENT_STREAM)
374            )
375        })
376        .unwrap_or(false)
377    {
378        Ok(response)
379    } else {
380        Err(super::Error::InvalidContentType(content_type.clone()))
381    }
382}