1use crate::{
6 http_client::{
7 HttpClientExt, Result as StreamResult,
8 retry::{DEFAULT_RETRY, ExponentialBackoff, RetryPolicy},
9 },
10 wasm_compat::{WasmCompatSend, WasmCompatSendStream},
11};
12use bytes::Bytes;
13use eventsource_stream::{Event as MessageEvent, EventStreamError, Eventsource};
14use futures::Stream;
15#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
16use futures::{future::BoxFuture, stream::BoxStream};
17#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
18use futures::{future::LocalBoxFuture, stream::LocalBoxStream};
19use futures_timer::Delay;
20use http::Response;
21use http::{HeaderName, HeaderValue, Request, StatusCode};
22use mime_guess::mime;
23use pin_project_lite::pin_project;
24use std::{
25 pin::Pin,
26 task::{Context, Poll},
27 time::Duration,
28};
29
30pub type BoxedStream = Pin<Box<dyn WasmCompatSendStream<InnerItem = StreamResult<Bytes>>>>;
31
32#[cfg(not(target_arch = "wasm32"))]
33type ResponseFuture = BoxFuture<'static, Result<Response<BoxedStream>, super::Error>>;
34#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
35type ResponseFuture = LocalBoxFuture<'static, Result<Response<BoxedStream>, super::Error>>;
36
37#[cfg(not(target_arch = "wasm32"))]
38type EventStream = BoxStream<'static, Result<MessageEvent, EventStreamError<super::Error>>>;
39#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
40type EventStream = LocalBoxStream<'static, Result<MessageEvent, EventStreamError<super::Error>>>;
41
pin_project! {
    #[project = SourceStateProjection]
    // Connection lifecycle of a `GenericEventSource`. `poll_next` drives the
    // transitions: Connecting/Reconnecting -> Open on a valid response,
    // -> WaitingToRetry on a transport error the retry policy accepts,
    // -> Closed otherwise; WaitingToRetry -> Reconnecting once the delay fires.
    enum SourceState {
        // First connection attempt; no retry has happened yet.
        Connecting {
            #[pin]
            response_future: ResponseFuture,
        },
        // A retry attempt. `last_retry` is the `(attempt number, delay)` pair of
        // the retry that started this attempt; it is handed back to the retry
        // policy if this attempt fails too.
        Reconnecting {
            #[pin]
            response_future: ResponseFuture,
            last_retry: (usize, Duration),
        },
        // Connected: events are parsed from the response body.
        Open {
            #[pin]
            event_stream: EventStream,
        },
        // Sleeping before the next attempt. `current_retry` is the
        // `(attempt number, delay)` of the attempt that will follow the delay.
        WaitingToRetry {
            #[pin]
            retry_delay: Delay,
            current_retry: (usize, Duration),
        },
        // Terminal: the stream yields `None` from here on.
        Closed,
    }
}
72
pin_project! {
    #[project = GenericEventSourceProjection]
    // Server-Sent Events client. It sends `req` through `client`, parses the
    // response body as SSE and yields `Event`s, reconnecting after transport
    // failures as allowed by the `Retry` policy.
    pub struct GenericEventSource<HttpClient, RequestBody, Retry = ExponentialBackoff> {
        // Cloned into every connection attempt.
        client: HttpClient,
        // Template request; cloned (and given SSE headers) for every attempt.
        req: Request<RequestBody>,
        // Decides whether/after how long to reconnect; its reconnection time
        // is updated from the `retry` field of received events.
        retry_policy: Retry,
        // ID of the most recent event carrying a non-empty `id`; sent as the
        // `Last-Event-ID` header on reconnection attempts.
        last_event_id: Option<String>,
        #[pin]
        state: SourceState,
    }
}
85
86impl<HttpClient, RequestBody> GenericEventSource<HttpClient, RequestBody>
87where
88 HttpClient: HttpClientExt + Clone + 'static,
89 RequestBody: Into<Bytes> + Clone + WasmCompatSend + 'static,
90{
91 pub fn new(client: HttpClient, req: Request<RequestBody>) -> Self {
93 let response_future = Self::create_response_future(&client, &req, None);
94 let state = SourceState::Connecting { response_future };
95
96 Self {
97 client,
98 req,
99 retry_policy: DEFAULT_RETRY,
100 last_event_id: None,
101 state,
102 }
103 }
104
105 pub fn with_retry_policy<R>(
106 client: HttpClient,
107 req: Request<RequestBody>,
108 retry_policy: R,
109 ) -> GenericEventSource<HttpClient, RequestBody, R>
110 where
111 R: RetryPolicy,
112 {
113 let response_future = Self::create_response_future(&client, &req, None);
114 let state = SourceState::Connecting { response_future };
115
116 GenericEventSource {
117 client,
118 req,
119 retry_policy,
120 last_event_id: None,
121 state,
122 }
123 }
124
125 fn create_response_future(
127 client: &HttpClient,
128 req: &Request<RequestBody>,
129 last_event_id: Option<&str>,
130 ) -> ResponseFuture {
131 let mut req_clone = req.clone();
132 req_clone
133 .headers_mut()
134 .entry("Accept")
135 .or_insert(HeaderValue::from_static("text/event-stream"));
136
137 if let Some(id) = last_event_id
138 && let Ok(value) = HeaderValue::from_str(id)
139 {
140 req_clone
141 .headers_mut()
142 .insert(HeaderName::from_static("last-event-id"), value);
143 }
144
145 let client_clone = client.clone();
146 Box::pin(async move { client_clone.send_streaming(req_clone).await })
147 }
148
149 pub fn last_event_id(&self) -> Option<&str> {
151 self.last_event_id.as_deref()
152 }
153
154 pub fn close(&mut self) {
157 self.state = SourceState::Closed;
158 }
159}
160
161#[derive(Debug, Clone, Eq, PartialEq)]
163pub enum Event {
164 Open,
166 Message(MessageEvent),
168}
169
170impl From<MessageEvent> for Event {
171 fn from(event: MessageEvent) -> Self {
172 Event::Message(event)
173 }
174}
175
176impl<HttpClient, RequestBody> Stream for GenericEventSource<HttpClient, RequestBody>
177where
178 HttpClient: HttpClientExt + Clone + 'static,
179 RequestBody: Into<Bytes> + Clone + WasmCompatSend + 'static,
180{
181 type Item = Result<Event, super::Error>;
182
183 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
184 let mut this = self.project();
185
186 loop {
187 match this.state.as_mut().project() {
188 SourceStateProjection::Connecting { response_future } => {
189 match response_future.poll(cx) {
190 Poll::Pending => return Poll::Pending,
191 Poll::Ready(Ok(response)) => {
192 match check_response(response) {
193 Ok(response) => {
194 let mut event_stream = response.into_body().eventsource();
196 if let Some(id) = &this.last_event_id {
197 event_stream.set_last_event_id(id.clone());
198 }
199 this.state.set(SourceState::Open {
200 event_stream: Box::pin(event_stream),
201 });
202 return Poll::Ready(Some(Ok(Event::Open)));
203 }
204 Err(err) => {
205 this.state.set(SourceState::Closed);
207 return Poll::Ready(Some(Err(err)));
208 }
209 }
210 }
211 Poll::Ready(Err(err)) => {
212 if let Some(delay_duration) = this.retry_policy.retry(&err, None) {
214 this.state.set(SourceState::WaitingToRetry {
216 retry_delay: Delay::new(delay_duration),
217 current_retry: (1, delay_duration),
218 });
219 return Poll::Ready(Some(Err(err)));
220 } else {
221 this.state.set(SourceState::Closed);
223 return Poll::Ready(Some(Err(err)));
224 }
225 }
226 }
227 }
228
229 SourceStateProjection::Reconnecting {
230 response_future,
231 last_retry,
232 } => {
233 match response_future.poll(cx) {
234 Poll::Pending => return Poll::Pending,
235 Poll::Ready(Ok(response)) => {
236 match check_response(response) {
237 Ok(response) => {
238 let mut event_stream = response.into_body().eventsource();
240 if let Some(id) = &this.last_event_id {
241 event_stream.set_last_event_id(id.clone());
242 }
243 this.state.set(SourceState::Open {
244 event_stream: Box::pin(event_stream),
245 });
246 return Poll::Ready(Some(Ok(Event::Open)));
247 }
248 Err(err) => {
249 this.state.set(SourceState::Closed);
251 return Poll::Ready(Some(Err(err)));
252 }
253 }
254 }
255 Poll::Ready(Err(err)) => {
256 if let Some(delay_duration) =
258 this.retry_policy.retry(&err, Some(*last_retry))
259 {
260 let (retry_num, _) = *last_retry;
261 this.state.set(SourceState::WaitingToRetry {
263 retry_delay: Delay::new(delay_duration),
264 current_retry: (retry_num + 1, delay_duration),
265 });
266 return Poll::Ready(Some(Err(err)));
267 } else {
268 this.state.set(SourceState::Closed);
270 return Poll::Ready(Some(Err(err)));
271 }
272 }
273 }
274 }
275
276 SourceStateProjection::Open { event_stream } => {
277 match event_stream.poll_next(cx) {
278 Poll::Pending => return Poll::Pending,
279 Poll::Ready(Some(Ok(event))) => {
280 if !event.id.is_empty() {
281 *this.last_event_id = Some(event.id.clone());
282 }
283 if let Some(duration) = event.retry {
284 this.retry_policy.set_reconnection_time(duration);
285 }
286 return Poll::Ready(Some(Ok(Event::Message(event))));
287 }
288 Poll::Ready(Some(Err(EventStreamError::Transport(err)))) => {
289 if let Some(delay_duration) = this.retry_policy.retry(&err, None) {
291 this.state.set(SourceState::WaitingToRetry {
293 retry_delay: Delay::new(delay_duration),
294 current_retry: (1, delay_duration),
295 });
296 return Poll::Ready(Some(Err(err)));
297 } else {
298 this.state.set(SourceState::Closed);
300 return Poll::Ready(Some(Err(err)));
301 }
302 }
303 Poll::Ready(Some(Err(EventStreamError::Parser(_)))) => {
304 continue;
306 }
307 Poll::Ready(Some(Err(EventStreamError::Utf8(_)))) => {
308 continue;
310 }
311 Poll::Ready(None) => {
312 this.state.set(SourceState::Closed);
314 return Poll::Ready(None);
315 }
316 }
317 }
318
319 SourceStateProjection::WaitingToRetry {
320 retry_delay,
321 current_retry,
322 } => {
323 let retry_info = *current_retry;
325 match retry_delay.poll(cx) {
326 Poll::Pending => return Poll::Pending,
327 Poll::Ready(()) => {
328 let response_future =
330 GenericEventSource::<HttpClient, RequestBody>::create_response_future(
331 this.client,
332 this.req,
333 this.last_event_id.as_deref(),
334 );
335 this.state.set(SourceState::Reconnecting {
336 response_future,
337 last_retry: retry_info,
338 });
339 continue;
340 }
341 }
342 }
343
344 SourceStateProjection::Closed => {
345 return Poll::Ready(None);
346 }
347 }
348 }
349 }
350}
351
352fn check_response<T>(response: Response<T>) -> Result<Response<T>, super::Error> {
353 let StatusCode::OK = response.status() else {
354 return Err(super::Error::InvalidStatusCode(response.status()));
355 };
356
357 let content_type =
358 if let Some(content_type) = response.headers().get(&reqwest::header::CONTENT_TYPE) {
359 content_type
360 } else {
361 return Err(super::Error::InvalidContentType(HeaderValue::from_static(
362 "",
363 )));
364 };
365
366 if content_type
367 .to_str()
368 .map_err(|_| ())
369 .and_then(|s| s.parse::<mime::Mime>().map_err(|_| ()))
370 .map(|mime_type| {
371 matches!(
372 (mime_type.type_(), mime_type.subtype()),
373 (mime::TEXT, mime::EVENT_STREAM)
374 )
375 })
376 .unwrap_or(false)
377 {
378 Ok(response)
379 } else {
380 Err(super::Error::InvalidContentType(content_type.clone()))
381 }
382}