// reqwest_eventsource/event_source.rs

1use crate::error::{CannotCloneRequestError, Error};
2use crate::retry::{RetryPolicy, DEFAULT_RETRY};
3use core::pin::Pin;
4use eventsource_stream::Eventsource;
5pub use eventsource_stream::{Event as MessageEvent, EventStreamError};
6#[cfg(not(target_arch = "wasm32"))]
7use futures_core::future::BoxFuture;
8use futures_core::future::Future;
9#[cfg(target_arch = "wasm32")]
10use futures_core::future::LocalBoxFuture;
11#[cfg(not(target_arch = "wasm32"))]
12use futures_core::stream::BoxStream;
13#[cfg(target_arch = "wasm32")]
14use futures_core::stream::LocalBoxStream;
15use futures_core::stream::Stream;
16use futures_core::task::{Context, Poll};
17use futures_timer::Delay;
18use pin_project_lite::pin_project;
19use reqwest::header::{HeaderName, HeaderValue};
20use reqwest::{Error as ReqwestError, IntoUrl, RequestBuilder, Response, StatusCode};
21use std::time::Duration;
22
23#[cfg(not(target_arch = "wasm32"))]
24type ResponseFuture = BoxFuture<'static, Result<Response, ReqwestError>>;
25#[cfg(target_arch = "wasm32")]
26type ResponseFuture = LocalBoxFuture<'static, Result<Response, ReqwestError>>;
27
28#[cfg(not(target_arch = "wasm32"))]
29type EventStream = BoxStream<'static, Result<MessageEvent, EventStreamError<ReqwestError>>>;
30#[cfg(target_arch = "wasm32")]
31type EventStream = LocalBoxStream<'static, Result<MessageEvent, EventStreamError<ReqwestError>>>;
32
33type BoxedRetry = Box<dyn RetryPolicy + Send + Unpin + 'static>;
34
/// The ready state of an [`EventSource`].
///
/// Variants carry explicit `u8` discriminants (`Connecting = 0`, `Open = 1`, `Closed = 2`),
/// so they can be cast with `as u8`, and the derived ordering follows those values.
/// The type is `Hash`, so it can be used as a key in hashed collections.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[repr(u8)]
pub enum ReadyState {
    /// The EventSource is waiting on a response from the endpoint
    Connecting = 0,
    /// The EventSource is connected
    Open = 1,
    /// The EventSource is closed and no longer emitting Events
    Closed = 2,
}
46
47pin_project! {
48/// Provides the [`Stream`] implementation for the [`Event`] items. This wraps the
49/// [`RequestBuilder`] and retries requests when they fail.
50#[project = EventSourceProjection]
51pub struct EventSource {
52    builder: RequestBuilder,
53    #[pin]
54    next_response: Option<ResponseFuture>,
55    #[pin]
56    cur_stream: Option<EventStream>,
57    #[pin]
58    delay: Option<Delay>,
59    is_closed: bool,
60    retry_policy: BoxedRetry,
61    last_event_id: String,
62    last_retry: Option<(usize, Duration)>
63}
64}
65
66impl EventSource {
67    /// Wrap a [`RequestBuilder`]
68    pub fn new(builder: RequestBuilder) -> Result<Self, CannotCloneRequestError> {
69        let builder = builder.header(
70            reqwest::header::ACCEPT,
71            HeaderValue::from_static("text/event-stream"),
72        );
73        let res_future = Box::pin(builder.try_clone().ok_or(CannotCloneRequestError)?.send());
74        Ok(Self {
75            builder,
76            next_response: Some(res_future),
77            cur_stream: None,
78            delay: None,
79            is_closed: false,
80            retry_policy: Box::new(DEFAULT_RETRY),
81            last_event_id: String::new(),
82            last_retry: None,
83        })
84    }
85
86    /// Create a simple EventSource based on a GET request
87    pub fn get<T: IntoUrl>(url: T) -> Self {
88        Self::new(reqwest::Client::new().get(url)).unwrap()
89    }
90
91    /// Close the EventSource stream and stop trying to reconnect
92    pub fn close(&mut self) {
93        self.is_closed = true;
94    }
95
96    /// Set the retry policy
97    pub fn set_retry_policy(&mut self, policy: BoxedRetry) {
98        self.retry_policy = policy
99    }
100
101    /// Get the last event id
102    pub fn last_event_id(&self) -> &str {
103        &self.last_event_id
104    }
105
106    /// Get the current ready state
107    pub fn ready_state(&self) -> ReadyState {
108        if self.is_closed {
109            ReadyState::Closed
110        } else if self.delay.is_some() || self.next_response.is_some() {
111            ReadyState::Connecting
112        } else {
113            ReadyState::Open
114        }
115    }
116}
117
118fn check_response(response: Response) -> Result<Response, Error> {
119    match response.status() {
120        StatusCode::OK => {}
121        status => {
122            return Err(Error::InvalidStatusCode(status, response));
123        }
124    }
125    let content_type =
126        if let Some(content_type) = response.headers().get(&reqwest::header::CONTENT_TYPE) {
127            content_type
128        } else {
129            return Err(Error::InvalidContentType(
130                HeaderValue::from_static(""),
131                response,
132            ));
133        };
134    if content_type
135        .to_str()
136        .map_err(|_| ())
137        .and_then(|s| s.parse::<mime::Mime>().map_err(|_| ()))
138        .map(|mime_type| {
139            matches!(
140                (mime_type.type_(), mime_type.subtype()),
141                (mime::TEXT, mime::EVENT_STREAM)
142            )
143        })
144        .unwrap_or(false)
145    {
146        Ok(response)
147    } else {
148        Err(Error::InvalidContentType(content_type.clone(), response))
149    }
150}
151
152impl<'a> EventSourceProjection<'a> {
153    fn clear_fetch(&mut self) {
154        self.next_response.take();
155        self.cur_stream.take();
156    }
157
158    fn retry_fetch(&mut self) -> Result<(), Error> {
159        self.cur_stream.take();
160        let req = self.builder.try_clone().unwrap().header(
161            HeaderName::from_static("last-event-id"),
162            HeaderValue::from_str(self.last_event_id)
163                .map_err(|_| Error::InvalidLastEventId(self.last_event_id.clone()))?,
164        );
165        let res_future = Box::pin(req.send());
166        self.next_response.replace(res_future);
167        Ok(())
168    }
169
170    fn handle_response(&mut self, res: Response) {
171        self.last_retry.take();
172        let mut stream = res.bytes_stream().eventsource();
173        stream.set_last_event_id(self.last_event_id.clone());
174        self.cur_stream.replace(Box::pin(stream));
175    }
176
177    fn handle_event(&mut self, event: &MessageEvent) {
178        *self.last_event_id = event.id.clone();
179        if let Some(duration) = event.retry {
180            self.retry_policy.set_reconnection_time(duration)
181        }
182    }
183
184    fn handle_error(&mut self, error: &Error) {
185        self.clear_fetch();
186        if let Some(retry_delay) = self.retry_policy.retry(error, *self.last_retry) {
187            let retry_num = self
188                .last_retry
189                .map(|retry| retry.0.saturating_add(1))
190                .unwrap_or(1);
191            *self.last_retry = Some((retry_num, retry_delay));
192            self.delay.replace(Delay::new(retry_delay));
193        } else {
194            *self.is_closed = true;
195        }
196    }
197}
198
199/// Events created by the [`EventSource`]
200#[derive(Debug, Clone, Eq, PartialEq)]
201pub enum Event {
202    /// The event fired when the connection is opened
203    Open,
204    /// The event fired when a [`MessageEvent`] is received
205    Message(MessageEvent),
206}
207
208impl From<MessageEvent> for Event {
209    fn from(event: MessageEvent) -> Self {
210        Event::Message(event)
211    }
212}
213
214impl Stream for EventSource {
215    type Item = Result<Event, Error>;
216
217    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
218        let mut this = self.project();
219
220        if *this.is_closed {
221            return Poll::Ready(None);
222        }
223
224        if let Some(delay) = this.delay.as_mut().as_pin_mut() {
225            match delay.poll(cx) {
226                Poll::Ready(_) => {
227                    this.delay.take();
228                    if let Err(err) = this.retry_fetch() {
229                        *this.is_closed = true;
230                        return Poll::Ready(Some(Err(err)));
231                    }
232                }
233                Poll::Pending => return Poll::Pending,
234            }
235        }
236
237        if let Some(response_future) = this.next_response.as_mut().as_pin_mut() {
238            match response_future.poll(cx) {
239                Poll::Ready(Ok(res)) => {
240                    this.clear_fetch();
241                    match check_response(res) {
242                        Ok(res) => {
243                            this.handle_response(res);
244                            return Poll::Ready(Some(Ok(Event::Open)));
245                        }
246                        Err(err) => {
247                            *this.is_closed = true;
248                            return Poll::Ready(Some(Err(err)));
249                        }
250                    }
251                }
252                Poll::Ready(Err(err)) => {
253                    let err = Error::Transport(err);
254                    this.handle_error(&err);
255                    return Poll::Ready(Some(Err(err)));
256                }
257                Poll::Pending => {
258                    return Poll::Pending;
259                }
260            }
261        }
262
263        match this
264            .cur_stream
265            .as_mut()
266            .as_pin_mut()
267            .unwrap()
268            .as_mut()
269            .poll_next(cx)
270        {
271            Poll::Ready(Some(Err(err))) => {
272                let err = err.into();
273                this.handle_error(&err);
274                Poll::Ready(Some(Err(err)))
275            }
276            Poll::Ready(Some(Ok(event))) => {
277                this.handle_event(&event);
278                Poll::Ready(Some(Ok(event.into())))
279            }
280            Poll::Ready(None) => {
281                let err = Error::StreamEnded;
282                this.handle_error(&err);
283                Poll::Ready(Some(Err(err)))
284            }
285            Poll::Pending => Poll::Pending,
286        }
287    }
288}