// reqwest_eventsource/event_source.rs

1use crate::error::{CannotCloneRequestError, Error};
2use crate::retry::{RetryPolicy, DEFAULT_RETRY};
3use core::pin::Pin;
4use eventsource_stream::Eventsource;
5pub use eventsource_stream::{Event as MessageEvent, EventStreamError};
6#[cfg(not(target_arch = "wasm32"))]
7use futures_core::future::BoxFuture;
8use futures_core::future::Future;
9#[cfg(target_arch = "wasm32")]
10use futures_core::future::LocalBoxFuture;
11#[cfg(not(target_arch = "wasm32"))]
12use futures_core::stream::BoxStream;
13#[cfg(target_arch = "wasm32")]
14use futures_core::stream::LocalBoxStream;
15use futures_core::stream::Stream;
16use futures_core::task::{Context, Poll};
17use futures_timer::Delay;
18use pin_project_lite::pin_project;
19use reqwest::header::{HeaderName, HeaderValue};
20use reqwest::{Error as ReqwestError, IntoUrl, RequestBuilder, Response, StatusCode};
21use std::time::Duration;
22
23#[cfg(not(target_arch = "wasm32"))]
24type ResponseFuture = BoxFuture<'static, Result<Response, ReqwestError>>;
25#[cfg(target_arch = "wasm32")]
26type ResponseFuture = LocalBoxFuture<'static, Result<Response, ReqwestError>>;
27
28#[cfg(not(target_arch = "wasm32"))]
29type EventStream = BoxStream<'static, Result<MessageEvent, EventStreamError<ReqwestError>>>;
30#[cfg(target_arch = "wasm32")]
31type EventStream = LocalBoxStream<'static, Result<MessageEvent, EventStreamError<ReqwestError>>>;
32
33type BoxedRetry = Box<dyn RetryPolicy + Send + Unpin + 'static>;
34
/// Connection state reported by [`EventSource::ready_state`].
///
/// The enum is `#[repr(u8)]` with fixed discriminants, so a state can be
/// converted with `state as u8`; states also order as
/// `Connecting < Open < Closed`.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ReadyState {
    /// A request to the endpoint is in flight, or a reconnection is pending.
    Connecting = 0,
    /// A response was accepted and events are being read from it.
    Open = 1,
    /// The source was closed and will not emit any further events.
    Closed = 2,
}
46
pin_project! {
/// Provides the [`Stream`] implementation for the [`Event`] items. This wraps the
/// [`RequestBuilder`] and retries requests when they fail.
#[project = EventSourceProjection]
pub struct EventSource {
    // Template request (with `Accept: text/event-stream` added by `new`);
    // a clone of it is sent for every connection attempt.
    builder: RequestBuilder,
    // Future of the request currently in flight; `Some` while connecting.
    #[pin]
    next_response: Option<ResponseFuture>,
    // Parsed event stream of the response currently being read.
    #[pin]
    cur_stream: Option<EventStream>,
    // Timer that must elapse before the next reconnection attempt.
    #[pin]
    delay: Option<Delay>,
    // Set by `close()` or by a failure the retry policy declines to retry;
    // once set, `poll_next` yields `None`.
    is_closed: bool,
    // Consulted after every connection/stream error to pick a retry delay.
    retry_policy: BoxedRetry,
    // Id of the most recently received event; sent as the `last-event-id`
    // header when reconnecting.
    last_event_id: String,
    // (retry count, delay) of the most recent retry since the last accepted
    // response; cleared when a response is accepted.
    last_retry: Option<(usize, Duration)>
}
}
65
66impl EventSource {
67    /// Wrap a [`RequestBuilder`]
68    pub fn new(builder: RequestBuilder) -> Result<Self, CannotCloneRequestError> {
69        let builder = builder.header(
70            reqwest::header::ACCEPT,
71            HeaderValue::from_static("text/event-stream"),
72        );
73        let res_future = Box::pin(builder.try_clone().ok_or(CannotCloneRequestError)?.send());
74        Ok(Self {
75            builder,
76            next_response: Some(res_future),
77            cur_stream: None,
78            delay: None,
79            is_closed: false,
80            retry_policy: Box::new(DEFAULT_RETRY),
81            last_event_id: String::new(),
82            last_retry: None,
83        })
84    }
85
86    /// Create a simple EventSource based on a GET request
87    pub fn get<T: IntoUrl>(url: T) -> Self {
88        Self::new(reqwest::Client::new().get(url)).unwrap()
89    }
90
91    /// Close the EventSource stream and stop trying to reconnect
92    pub fn close(&mut self) {
93        self.is_closed = true;
94    }
95
96    /// Set the retry policy
97    pub fn set_retry_policy(&mut self, policy: BoxedRetry) {
98        self.retry_policy = policy
99    }
100
101    /// Get the last event id
102    pub fn last_event_id(&self) -> &str {
103        &self.last_event_id
104    }
105
106    /// Get the current ready state
107    pub fn ready_state(&self) -> ReadyState {
108        if self.is_closed {
109            ReadyState::Closed
110        } else if self.delay.is_some() || self.next_response.is_some() {
111            ReadyState::Connecting
112        } else {
113            ReadyState::Open
114        }
115    }
116}
117
118fn check_response(response: Response) -> Result<Response, Error> {
119    match response.status() {
120        StatusCode::OK => {}
121        status => {
122            return Err(Error::InvalidStatusCode(status, response));
123        }
124    }
125    let content_type =
126        if let Some(content_type) = response.headers().get(&reqwest::header::CONTENT_TYPE) {
127            content_type
128        } else {
129            return Err(Error::InvalidContentType(
130                HeaderValue::from_static(""),
131                response,
132            ));
133        };
134    if content_type
135        .to_str()
136        .map_err(|_| ())
137        .and_then(|s| s.parse::<mime::Mime>().map_err(|_| ()))
138        .map(|mime_type| {
139            matches!(
140                (mime_type.type_(), mime_type.subtype()),
141                (mime::TEXT, mime::EVENT_STREAM)
142            )
143        })
144        .unwrap_or(false)
145    {
146        Ok(response)
147    } else {
148        Err(Error::InvalidContentType(content_type.clone(), response))
149    }
150}
151
152impl<'a> EventSourceProjection<'a> {
153    fn clear_fetch(&mut self) {
154        self.next_response.take();
155        self.cur_stream.take();
156    }
157
158    fn retry_fetch(&mut self) -> Result<(), Error> {
159        self.cur_stream.take();
160        let req = self.builder.try_clone().unwrap().header(
161            HeaderName::from_static("last-event-id"),
162            HeaderValue::from_str(self.last_event_id)
163                .map_err(|_| Error::InvalidLastEventId(self.last_event_id.clone()))?,
164        );
165        let res_future = Box::pin(req.send());
166        self.next_response.replace(res_future);
167        Ok(())
168    }
169
170    fn handle_response(&mut self, res: Response) {
171        self.last_retry.take();
172        let mut stream = res.bytes_stream().eventsource();
173        stream.set_last_event_id(self.last_event_id.clone());
174        self.cur_stream.replace(Box::pin(stream));
175    }
176
177    fn handle_event(&mut self, event: &MessageEvent) {
178        *self.last_event_id = event.id.clone();
179        if let Some(duration) = event.retry {
180            self.retry_policy.set_reconnection_time(duration)
181        }
182    }
183
184    fn handle_error(&mut self, error: &Error) {
185        self.clear_fetch();
186        if let Some(retry_delay) = self.retry_policy.retry(error, *self.last_retry) {
187            let retry_num = self.last_retry.map(|retry| retry.0).unwrap_or(1);
188            *self.last_retry = Some((retry_num, retry_delay));
189            self.delay.replace(Delay::new(retry_delay));
190        } else {
191            *self.is_closed = true;
192        }
193    }
194}
195
196/// Events created by the [`EventSource`]
197#[derive(Debug, Clone, Eq, PartialEq)]
198pub enum Event {
199    /// The event fired when the connection is opened
200    Open,
201    /// The event fired when a [`MessageEvent`] is received
202    Message(MessageEvent),
203}
204
205impl From<MessageEvent> for Event {
206    fn from(event: MessageEvent) -> Self {
207        Event::Message(event)
208    }
209}
210
211impl Stream for EventSource {
212    type Item = Result<Event, Error>;
213
214    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
215        let mut this = self.project();
216
217        if *this.is_closed {
218            return Poll::Ready(None);
219        }
220
221        if let Some(delay) = this.delay.as_mut().as_pin_mut() {
222            match delay.poll(cx) {
223                Poll::Ready(_) => {
224                    this.delay.take();
225                    if let Err(err) = this.retry_fetch() {
226                        *this.is_closed = true;
227                        return Poll::Ready(Some(Err(err)));
228                    }
229                }
230                Poll::Pending => return Poll::Pending,
231            }
232        }
233
234        if let Some(response_future) = this.next_response.as_mut().as_pin_mut() {
235            match response_future.poll(cx) {
236                Poll::Ready(Ok(res)) => {
237                    this.clear_fetch();
238                    match check_response(res) {
239                        Ok(res) => {
240                            this.handle_response(res);
241                            return Poll::Ready(Some(Ok(Event::Open)));
242                        }
243                        Err(err) => {
244                            *this.is_closed = true;
245                            return Poll::Ready(Some(Err(err)));
246                        }
247                    }
248                }
249                Poll::Ready(Err(err)) => {
250                    let err = Error::Transport(err);
251                    this.handle_error(&err);
252                    return Poll::Ready(Some(Err(err)));
253                }
254                Poll::Pending => {
255                    return Poll::Pending;
256                }
257            }
258        }
259
260        match this
261            .cur_stream
262            .as_mut()
263            .as_pin_mut()
264            .unwrap()
265            .as_mut()
266            .poll_next(cx)
267        {
268            Poll::Ready(Some(Err(err))) => {
269                let err = err.into();
270                this.handle_error(&err);
271                Poll::Ready(Some(Err(err)))
272            }
273            Poll::Ready(Some(Ok(event))) => {
274                this.handle_event(&event);
275                Poll::Ready(Some(Ok(event.into())))
276            }
277            Poll::Ready(None) => {
278                let err = Error::StreamEnded;
279                this.handle_error(&err);
280                Poll::Ready(Some(Err(err)))
281            }
282            Poll::Pending => Poll::Pending,
283        }
284    }
285}