rig/http_client/
sse.rs

//! An SSE (Server-Sent Events) implementation built on [`crate::http_client::HttpClientExt`],
//! providing streaming with automatic retry handling for any implementor of `HttpClientExt`.
//!
//! Primarily intended for internal use. However, if you want to implement generic HTTP streaming
//! for your own custom completion model, you may find it helpful.
5
6use std::{
7    pin::Pin,
8    task::{Context, Poll},
9    time::Duration,
10};
11
12use bytes::Bytes;
13use eventsource_stream::{Event as MessageEvent, EventStreamError, Eventsource};
14use futures::Stream;
15#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
16use futures::{future::BoxFuture, stream::BoxStream};
17#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
18use futures::{future::LocalBoxFuture, stream::LocalBoxStream};
19use futures_timer::Delay;
20use http::Response;
21use http::{HeaderName, HeaderValue, Request, StatusCode};
22use mime_guess::mime;
23use pin_project_lite::pin_project;
24
25use crate::{
26    http_client::{
27        HttpClientExt, Result as StreamResult, instance_error,
28        retry::{DEFAULT_RETRY, RetryPolicy},
29    },
30    wasm_compat::{WasmCompatSend, WasmCompatSendStream},
31};
32
33pub type BoxedStream = Pin<Box<dyn WasmCompatSendStream<InnerItem = StreamResult<Bytes>>>>;
34
35#[cfg(not(target_arch = "wasm32"))]
36type ResponseFuture<T> = BoxFuture<'static, Result<Response<T>, super::Error>>;
37#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
38type ResponseFuture<T> = LocalBoxFuture<'static, Result<Response<T>, super::Error>>;
39
40#[cfg(not(target_arch = "wasm32"))]
41type EventStream = BoxStream<'static, Result<MessageEvent, EventStreamError<super::Error>>>;
42#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
43type EventStream = LocalBoxStream<'static, Result<MessageEvent, EventStreamError<super::Error>>>;
44type BoxedRetry = Box<dyn RetryPolicy + Send + Unpin + 'static>;
45
/// Connection lifecycle of a [`GenericEventSource`].
///
/// The numeric values match the WHATWG `EventSource.readyState` constants
/// (`CONNECTING = 0`, `OPEN = 1`, `CLOSED = 2`), and variants order by those values.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ReadyState {
    /// Waiting for the endpoint to respond, either on the first connection or while a
    /// reconnect (or its back-off delay) is pending.
    Connecting = 0,
    /// A response was accepted and events are being read from it.
    Open = 1,
    /// Closed for good: no further events are emitted and no reconnect is attempted.
    Closed = 2,
}
57
pin_project! {
    /// A generic event source that can use any HTTP client.
    /// Modeled heavily on the `reqwest-eventsource` implementation.
    #[project = GenericEventSourceProjection]
    pub struct GenericEventSource<HttpClient, RequestBody, ResponseBody>
    where
        HttpClient: HttpClientExt,
    {
        // Client used for the initial request and for every reconnect attempt.
        client: HttpClient,
        // Template request; each reconnect sends a fresh clone of it.
        req: Request<RequestBody>,
        // In-flight connection attempt, if one is pending.
        #[pin]
        next_response: Option<ResponseFuture<ResponseBody>>,
        // Parsed event stream of the current connection, once a response is accepted.
        #[pin]
        cur_stream: Option<EventStream>,
        // Back-off timer that must elapse before the next reconnect attempt.
        #[pin]
        delay: Option<Delay>,
        // Once true, polling yields `None` and no reconnect is attempted.
        is_closed: bool,
        // Decides whether to reconnect after an error and how long to wait first.
        retry_policy: BoxedRetry,
        // Id of the most recent event; sent as `last-event-id` when reconnecting.
        last_event_id: String,
        // (retry number, delay) of the last scheduled retry; cleared on a successful response.
        last_retry: Option<(usize, Duration)>,
    }
}
80
81impl<HttpClient, RequestBody>
82    GenericEventSource<
83        HttpClient,
84        RequestBody,
85        Pin<Box<dyn WasmCompatSendStream<InnerItem = StreamResult<Bytes>>>>,
86    >
87where
88    HttpClient: HttpClientExt + Clone + 'static,
89    RequestBody: Into<Bytes> + Clone + Send + 'static,
90{
91    pub fn new(client: HttpClient, req: Request<RequestBody>) -> Self {
92        let client_clone = client.clone();
93        let mut req_clone = req.clone();
94        req_clone
95            .headers_mut()
96            .entry("Accept")
97            .or_insert(HeaderValue::from_static("text/event-stream"));
98        let res_fut = Box::pin(async move { client_clone.clone().send_streaming(req_clone).await });
99        Self {
100            client,
101            next_response: Some(res_fut),
102            cur_stream: None,
103            req,
104            delay: None,
105            is_closed: false,
106            retry_policy: Box::new(DEFAULT_RETRY),
107            last_event_id: String::new(),
108            last_retry: None,
109        }
110    }
111
112    /// Close the EventSource stream and stop trying to reconnect
113    pub fn close(&mut self) {
114        self.is_closed = true;
115    }
116
117    /// Get the last event id
118    pub fn last_event_id(&self) -> &str {
119        &self.last_event_id
120    }
121
122    /// Get the current ready state
123    pub fn ready_state(&self) -> ReadyState {
124        if self.is_closed {
125            ReadyState::Closed
126        } else if self.delay.is_some() || self.next_response.is_some() {
127            ReadyState::Connecting
128        } else {
129            ReadyState::Open
130        }
131    }
132}
133
134impl<'a, HttpClient, RequestBody>
135    GenericEventSourceProjection<'a, HttpClient, RequestBody, BoxedStream>
136where
137    HttpClient: HttpClientExt + Clone + 'static,
138    RequestBody: Into<Bytes> + Clone + WasmCompatSend + 'static,
139{
140    fn clear_fetch(&mut self) {
141        self.next_response.take();
142        self.cur_stream.take();
143    }
144
145    fn retry_fetch(&mut self) -> Result<(), super::Error> {
146        self.cur_stream.take();
147        let mut req = self.req.clone();
148        req.headers_mut().insert(
149            HeaderName::from_static("last-event-id"),
150            HeaderValue::from_str(self.last_event_id).map_err(instance_error)?,
151        );
152        let client = self.client.clone();
153        let res_future = Box::pin(async move { client.send_streaming(req).await });
154        self.next_response.replace(res_future);
155        Ok(())
156    }
157
158    fn handle_response<T>(&mut self, res: Response<T>)
159    where
160        T: Stream<Item = StreamResult<Bytes>> + WasmCompatSend + 'static,
161    {
162        self.last_retry.take();
163        let mut stream = res.into_body().eventsource();
164        stream.set_last_event_id(self.last_event_id.clone());
165        self.cur_stream.replace(Box::pin(stream));
166    }
167
168    fn handle_event(&mut self, event: &eventsource_stream::Event) {
169        *self.last_event_id = event.id.clone();
170        if let Some(duration) = event.retry {
171            self.retry_policy.set_reconnection_time(duration)
172        }
173    }
174
175    fn handle_error(&mut self, error: &super::Error) {
176        self.clear_fetch();
177        if let Some(retry_delay) = self.retry_policy.retry(error, *self.last_retry) {
178            let retry_num = self
179                .last_retry
180                .map(|retry| retry.0.saturating_add(1))
181                .unwrap_or(1);
182            *self.last_retry = Some((retry_num, retry_delay));
183            self.delay.replace(Delay::new(retry_delay));
184        } else {
185            *self.is_closed = true;
186        }
187    }
188}
189
190/// Events created by the [`GenericEventSource`]
191#[derive(Debug, Clone, Eq, PartialEq)]
192pub enum Event {
193    /// The event fired when the connection is opened
194    Open,
195    /// The event fired when a [`MessageEvent`] is received
196    Message(MessageEvent),
197}
198
199impl From<MessageEvent> for Event {
200    fn from(event: MessageEvent) -> Self {
201        Event::Message(event)
202    }
203}
204
205impl<HttpClient, RequestBody> Stream for GenericEventSource<HttpClient, RequestBody, BoxedStream>
206where
207    HttpClient: HttpClientExt + Clone + 'static,
208    RequestBody: Into<Bytes> + Clone + WasmCompatSend + 'static,
209{
210    type Item = Result<Event, super::Error>;
211
212    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
213        let mut this = self.project();
214
215        if *this.is_closed {
216            return Poll::Ready(None);
217        }
218
219        if let Some(delay) = this.delay.as_mut().as_pin_mut() {
220            match delay.poll(cx) {
221                Poll::Ready(_) => {
222                    this.delay.take();
223                    if let Err(err) = this.retry_fetch() {
224                        *this.is_closed = true;
225                        return Poll::Ready(Some(Err(err)));
226                    }
227                }
228                Poll::Pending => return Poll::Pending,
229            }
230        }
231
232        if let Some(response_future) = this.next_response.as_mut().as_pin_mut() {
233            match response_future.poll(cx) {
234                Poll::Ready(Ok(res)) => {
235                    this.clear_fetch();
236                    match check_response(res) {
237                        Ok(res) => {
238                            this.handle_response(res);
239                            return Poll::Ready(Some(Ok(Event::Open)));
240                        }
241                        Err(err) => {
242                            *this.is_closed = true;
243                            return Poll::Ready(Some(Err(err)));
244                        }
245                    }
246                }
247                Poll::Ready(Err(err)) => {
248                    this.handle_error(&err);
249                    return Poll::Ready(Some(Err(err)));
250                }
251                Poll::Pending => {
252                    return Poll::Pending;
253                }
254            }
255        }
256
257        match this
258            .cur_stream
259            .as_mut()
260            .as_pin_mut()
261            .unwrap()
262            .as_mut()
263            .poll_next(cx)
264        {
265            Poll::Ready(Some(Err(err))) => {
266                let EventStreamError::Transport(err) = err else {
267                    panic!("u");
268                };
269                this.handle_error(&err);
270                Poll::Ready(Some(Err(err)))
271            }
272            Poll::Ready(Some(Ok(event))) => {
273                this.handle_event(&event);
274                Poll::Ready(Some(Ok(event.into())))
275            }
276            Poll::Ready(None) => Poll::Ready(None),
277            Poll::Pending => Poll::Pending,
278        }
279    }
280}
281
282fn check_response<T>(response: Response<T>) -> Result<Response<T>, super::Error> {
283    match response.status() {
284        StatusCode::OK => {}
285        status => {
286            return Err(super::Error::InvalidStatusCode(status));
287        }
288    }
289    let content_type =
290        if let Some(content_type) = response.headers().get(&reqwest::header::CONTENT_TYPE) {
291            content_type
292        } else {
293            return Err(super::Error::InvalidContentType(HeaderValue::from_static(
294                "",
295            )));
296        };
297    if content_type
298        .to_str()
299        .map_err(|_| ())
300        .and_then(|s| s.parse::<mime::Mime>().map_err(|_| ()))
301        .map(|mime_type| {
302            matches!(
303                (mime_type.type_(), mime_type.subtype()),
304                (mime::TEXT, mime::EVENT_STREAM)
305            )
306        })
307        .unwrap_or(false)
308    {
309        Ok(response)
310    } else {
311        Err(super::Error::InvalidContentType(content_type.clone()))
312    }
313}