Skip to main content

rig/http_client/
sse.rs

//! An SSE implementation that leverages [`crate::http_client::HttpClientExt`] to provide streaming
//! with automatic retry handling for any implementor of `HttpClientExt`.
//!
//! Primarily intended for internal use. However, if you wish to implement generic HTTP streaming
//! for your own custom completion model, you may find this helpful.
5use crate::{
6    http_client::{
7        HttpClientExt, Result as StreamResult,
8        retry::{DEFAULT_RETRY, ExponentialBackoff, RetryPolicy},
9    },
10    wasm_compat::{WasmCompatSend, WasmCompatSendStream},
11};
12use bytes::Bytes;
13use eventsource_stream::{Event as MessageEvent, EventStreamError, Eventsource};
14use futures::Stream;
15#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
16use futures::{future::BoxFuture, stream::BoxStream};
17#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
18use futures::{future::LocalBoxFuture, stream::LocalBoxStream};
19use futures_timer::Delay;
20use http::Response;
21use http::{HeaderName, HeaderValue, Request, StatusCode};
22use mime_guess::mime;
23use pin_project_lite::pin_project;
24use std::{
25    pin::Pin,
26    task::{Context, Poll},
27    time::Duration,
28};
29
30pub type BoxedStream = Pin<Box<dyn WasmCompatSendStream<InnerItem = StreamResult<Bytes>>>>;
31
32#[cfg(not(target_arch = "wasm32"))]
33type ResponseFuture = BoxFuture<'static, Result<Response<BoxedStream>, super::Error>>;
34#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
35type ResponseFuture = LocalBoxFuture<'static, Result<Response<BoxedStream>, super::Error>>;
36
37#[cfg(not(target_arch = "wasm32"))]
38type EventStream = BoxStream<'static, Result<MessageEvent, EventStreamError<super::Error>>>;
39#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
40type EventStream = LocalBoxStream<'static, Result<MessageEvent, EventStreamError<super::Error>>>;
41
pin_project! {
    /// Internal state variants for the SSE state machine driven by the `Stream` impl of
    /// [`GenericEventSource`].
    ///
    /// `Connecting`/`Reconnecting` move to `Open` when the response is accepted and to `Closed`
    /// when it is rejected. When the request itself fails, they move to `WaitingToRetry`, or to
    /// `Closed` if the retry policy declines. `Open` moves to `WaitingToRetry`/`Closed` on a
    /// transport error and to `Closed` when the body ends. `WaitingToRetry` moves to
    /// `Reconnecting` once its delay has elapsed.
    #[project = SourceStateProjection]
    enum SourceState {
        /// Initial connection attempt (no retry history yet)
        Connecting {
            #[pin]
            response_future: ResponseFuture,
        },
        /// Reconnection attempt after a retry delay (always has retry history)
        Reconnecting {
            #[pin]
            response_future: ResponseFuture,
            // (retry number, delay waited) that led to this attempt; handed back to the
            // retry policy if this attempt fails as well.
            last_retry: (usize, Duration),
        },
        /// Actively receiving SSE events
        Open {
            #[pin]
            event_stream: EventStream,
        },
        /// Waiting before retry after an error
        WaitingToRetry {
            #[pin]
            retry_delay: Delay,
            // (retry number, delay) of the pending retry; the first retry of a cycle is 1.
            // Becomes `last_retry` of the following `Reconnecting` state.
            current_retry: (usize, Duration),
        },
        /// Terminal state: polling yields `None`.
        Closed,
    }
}
72
pin_project! {
    /// A generic SSE event source that works with any [`HttpClientExt`] implementation.
    ///
    /// Polled as a [`Stream`], it yields [`Event::Open`] whenever a (re)connection is accepted,
    /// [`Event::Message`] for each received SSE event, and `Err` items for failed connection
    /// attempts, rejected responses and transport errors. The `Retry` policy decides whether a
    /// failed attempt or a transport error is followed by a reconnect. A rejected response
    /// (non-200 status, or a missing/unsupported `Content-Type`) or the end of the response
    /// body ends the stream.
    #[project = GenericEventSourceProjection]
    pub struct GenericEventSource<HttpClient, RequestBody, Retry = ExponentialBackoff> {
        // Cloned for every connection attempt.
        client: HttpClient,
        // Template request; each attempt sends a clone with SSE headers added.
        req: Request<RequestBody>,
        // Decides whether, and after how long, to reconnect; its reconnection time is updated
        // from the `retry` field of received events.
        retry_policy: Retry,
        // Most recent non-empty event `id`; sent as `Last-Event-ID` when reconnecting.
        last_event_id: Option<String>,
        // When true, responses without a `Content-Type` header are accepted.
        allow_missing_content_type: bool,
        // Current position in the connection state machine.
        #[pin]
        state: SourceState,
    }
}
86
87impl<HttpClient, RequestBody> GenericEventSource<HttpClient, RequestBody>
88where
89    HttpClient: HttpClientExt + Clone + 'static,
90    RequestBody: Into<Bytes> + Clone + WasmCompatSend + 'static,
91{
92    /// Create a new event source that will connect to the given request.
93    pub fn new(client: HttpClient, req: Request<RequestBody>) -> Self {
94        let response_future = Self::create_response_future(&client, &req, None);
95        let state = SourceState::Connecting { response_future };
96
97        Self {
98            client,
99            req,
100            retry_policy: DEFAULT_RETRY,
101            last_event_id: None,
102            allow_missing_content_type: false,
103            state,
104        }
105    }
106
107    pub fn with_retry_policy<R>(
108        client: HttpClient,
109        req: Request<RequestBody>,
110        retry_policy: R,
111    ) -> GenericEventSource<HttpClient, RequestBody, R>
112    where
113        R: RetryPolicy,
114    {
115        let response_future = Self::create_response_future(&client, &req, None);
116        let state = SourceState::Connecting { response_future };
117
118        GenericEventSource {
119            client,
120            req,
121            retry_policy,
122            last_event_id: None,
123            allow_missing_content_type: false,
124            state,
125        }
126    }
127
128    pub fn allow_missing_content_type(mut self) -> Self {
129        self.allow_missing_content_type = true;
130        self
131    }
132
133    /// Create a response future for connecting/reconnecting
134    fn create_response_future(
135        client: &HttpClient,
136        req: &Request<RequestBody>,
137        last_event_id: Option<&str>,
138    ) -> ResponseFuture {
139        let mut req_clone = req.clone();
140        req_clone
141            .headers_mut()
142            .entry("Accept")
143            .or_insert(HeaderValue::from_static("text/event-stream"));
144
145        if let Some(id) = last_event_id
146            && let Ok(value) = HeaderValue::from_str(id)
147        {
148            req_clone
149                .headers_mut()
150                .insert(HeaderName::from_static("last-event-id"), value);
151        }
152
153        let client_clone = client.clone();
154        Box::pin(async move { client_clone.send_streaming(req_clone).await })
155    }
156
157    /// Get the last event id
158    pub fn last_event_id(&self) -> Option<&str> {
159        self.last_event_id.as_deref()
160    }
161
162    /// Close the event source, transitioning to the Closed state.
163    /// After calling this, the stream will yield `None` on the next poll.
164    pub fn close(&mut self) {
165        self.state = SourceState::Closed;
166    }
167}
168
169/// Events created by the [`GenericEventSource`]
170#[derive(Debug, Clone, Eq, PartialEq)]
171pub enum Event {
172    /// The event fired when the connection is opened
173    Open,
174    /// The event fired when a [`MessageEvent`] is received
175    Message(MessageEvent),
176}
177
178impl From<MessageEvent> for Event {
179    fn from(event: MessageEvent) -> Self {
180        Event::Message(event)
181    }
182}
183
184impl<HttpClient, RequestBody> Stream for GenericEventSource<HttpClient, RequestBody>
185where
186    HttpClient: HttpClientExt + Clone + 'static,
187    RequestBody: Into<Bytes> + Clone + WasmCompatSend + 'static,
188{
189    type Item = Result<Event, super::Error>;
190
191    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
192        let mut this = self.project();
193
194        loop {
195            match this.state.as_mut().project() {
196                SourceStateProjection::Connecting { response_future } => {
197                    match response_future.poll(cx) {
198                        Poll::Pending => return Poll::Pending,
199                        Poll::Ready(Ok(response)) => {
200                            match check_response(response, *this.allow_missing_content_type) {
201                                Ok(response) => {
202                                    // Transition: Connecting -> Open
203                                    let mut event_stream = response.into_body().eventsource();
204                                    if let Some(id) = &this.last_event_id {
205                                        event_stream.set_last_event_id(id.clone());
206                                    }
207                                    this.state.set(SourceState::Open {
208                                        event_stream: Box::pin(event_stream),
209                                    });
210                                    return Poll::Ready(Some(Ok(Event::Open)));
211                                }
212                                Err(err) => {
213                                    // Transition: Connecting -> Closed (non-retryable error)
214                                    this.state.set(SourceState::Closed);
215                                    return Poll::Ready(Some(Err(err)));
216                                }
217                            }
218                        }
219                        Poll::Ready(Err(err)) => {
220                            // First connection attempt failed - start retry cycle
221                            if let Some(delay_duration) = this.retry_policy.retry(&err, None) {
222                                // Transition: Connecting -> WaitingToRetry
223                                this.state.set(SourceState::WaitingToRetry {
224                                    retry_delay: Delay::new(delay_duration),
225                                    current_retry: (1, delay_duration),
226                                });
227                                return Poll::Ready(Some(Err(err)));
228                            } else {
229                                // Transition: Connecting -> Closed
230                                this.state.set(SourceState::Closed);
231                                return Poll::Ready(Some(Err(err)));
232                            }
233                        }
234                    }
235                }
236
237                SourceStateProjection::Reconnecting {
238                    response_future,
239                    last_retry,
240                } => {
241                    match response_future.poll(cx) {
242                        Poll::Pending => return Poll::Pending,
243                        Poll::Ready(Ok(response)) => {
244                            match check_response(response, *this.allow_missing_content_type) {
245                                Ok(response) => {
246                                    // Transition: Reconnecting -> Open (retry cycle complete)
247                                    let mut event_stream = response.into_body().eventsource();
248                                    if let Some(id) = &this.last_event_id {
249                                        event_stream.set_last_event_id(id.clone());
250                                    }
251                                    this.state.set(SourceState::Open {
252                                        event_stream: Box::pin(event_stream),
253                                    });
254                                    return Poll::Ready(Some(Ok(Event::Open)));
255                                }
256                                Err(err) => {
257                                    // Transition: Reconnecting -> Closed (non-retryable error)
258                                    this.state.set(SourceState::Closed);
259                                    return Poll::Ready(Some(Err(err)));
260                                }
261                            }
262                        }
263                        Poll::Ready(Err(err)) => {
264                            // Reconnection attempt failed - continue retry cycle
265                            if let Some(delay_duration) =
266                                this.retry_policy.retry(&err, Some(*last_retry))
267                            {
268                                let (retry_num, _) = *last_retry;
269                                // Transition: Reconnecting -> WaitingToRetry
270                                this.state.set(SourceState::WaitingToRetry {
271                                    retry_delay: Delay::new(delay_duration),
272                                    current_retry: (retry_num + 1, delay_duration),
273                                });
274                                return Poll::Ready(Some(Err(err)));
275                            } else {
276                                // Transition: Reconnecting -> Closed (max retries exceeded)
277                                this.state.set(SourceState::Closed);
278                                return Poll::Ready(Some(Err(err)));
279                            }
280                        }
281                    }
282                }
283
284                SourceStateProjection::Open { event_stream } => {
285                    match event_stream.poll_next(cx) {
286                        Poll::Pending => return Poll::Pending,
287                        Poll::Ready(Some(Ok(event))) => {
288                            if !event.id.is_empty() {
289                                *this.last_event_id = Some(event.id.clone());
290                            }
291                            if let Some(duration) = event.retry {
292                                this.retry_policy.set_reconnection_time(duration);
293                            }
294                            return Poll::Ready(Some(Ok(Event::Message(event))));
295                        }
296                        Poll::Ready(Some(Err(EventStreamError::Transport(err)))) => {
297                            // Connection error while open - start fresh retry cycle
298                            if let Some(delay_duration) = this.retry_policy.retry(&err, None) {
299                                // Transition: Open -> WaitingToRetry
300                                this.state.set(SourceState::WaitingToRetry {
301                                    retry_delay: Delay::new(delay_duration),
302                                    current_retry: (1, delay_duration),
303                                });
304                                return Poll::Ready(Some(Err(err)));
305                            } else {
306                                // Transition: Open -> Closed
307                                this.state.set(SourceState::Closed);
308                                return Poll::Ready(Some(Err(err)));
309                            }
310                        }
311                        Poll::Ready(Some(Err(EventStreamError::Parser(_)))) => {
312                            // Parser errors are recoverable - continue polling
313                            continue;
314                        }
315                        Poll::Ready(Some(Err(EventStreamError::Utf8(_)))) => {
316                            // UTF-8 errors are recoverable - continue polling
317                            continue;
318                        }
319                        Poll::Ready(None) => {
320                            // Transition: Open -> Closed
321                            this.state.set(SourceState::Closed);
322                            return Poll::Ready(None);
323                        }
324                    }
325                }
326
327                SourceStateProjection::WaitingToRetry {
328                    retry_delay,
329                    current_retry,
330                } => {
331                    // Copy before polling to avoid borrow conflicts
332                    let retry_info = *current_retry;
333                    match retry_delay.poll(cx) {
334                        Poll::Pending => return Poll::Pending,
335                        Poll::Ready(()) => {
336                            // Transition: WaitingToRetry -> Reconnecting
337                            let response_future =
338                                GenericEventSource::<HttpClient, RequestBody>::create_response_future(
339                                    this.client,
340                                    this.req,
341                                    this.last_event_id.as_deref(),
342                                );
343                            this.state.set(SourceState::Reconnecting {
344                                response_future,
345                                last_retry: retry_info,
346                            });
347                            continue;
348                        }
349                    }
350                }
351
352                SourceStateProjection::Closed => {
353                    return Poll::Ready(None);
354                }
355            }
356        }
357    }
358}
359
360fn check_response<T>(
361    response: Response<T>,
362    allow_missing_content_type: bool,
363) -> Result<Response<T>, super::Error> {
364    let StatusCode::OK = response.status() else {
365        return Err(super::Error::InvalidStatusCode(response.status()));
366    };
367
368    let content_type =
369        if let Some(content_type) = response.headers().get(&reqwest::header::CONTENT_TYPE) {
370            content_type
371        } else if allow_missing_content_type {
372            return Ok(response);
373        } else {
374            return Err(super::Error::InvalidContentType(HeaderValue::from_static(
375                "",
376            )));
377        };
378
379    if content_type
380        .to_str()
381        .map_err(|_| ())
382        .and_then(|s| s.parse::<mime::Mime>().map_err(|_| ()))
383        .map(|mime_type| {
384            matches!(
385                (mime_type.type_(), mime_type.subtype()),
386                (mime::TEXT, mime::EVENT_STREAM)
387            )
388        })
389        .unwrap_or(false)
390    {
391        Ok(response)
392    } else {
393        Err(super::Error::InvalidContentType(content_type.clone()))
394    }
395}