1use crate::{
6 http_client::{
7 HttpClientExt, Result as StreamResult,
8 retry::{DEFAULT_RETRY, ExponentialBackoff, RetryPolicy},
9 },
10 wasm_compat::{WasmCompatSend, WasmCompatSendStream},
11};
12use bytes::Bytes;
13use eventsource_stream::{Event as MessageEvent, EventStreamError, Eventsource};
14use futures::Stream;
15#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
16use futures::{future::BoxFuture, stream::BoxStream};
17#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
18use futures::{future::LocalBoxFuture, stream::LocalBoxStream};
19use futures_timer::Delay;
20use http::Response;
21use http::{HeaderName, HeaderValue, Request, StatusCode};
22use mime_guess::mime;
23use pin_project_lite::pin_project;
24use std::{
25 pin::Pin,
26 task::{Context, Poll},
27 time::Duration,
28};
29
30pub type BoxedStream = Pin<Box<dyn WasmCompatSendStream<InnerItem = StreamResult<Bytes>>>>;
31
32#[cfg(not(target_arch = "wasm32"))]
33type ResponseFuture = BoxFuture<'static, Result<Response<BoxedStream>, super::Error>>;
34#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
35type ResponseFuture = LocalBoxFuture<'static, Result<Response<BoxedStream>, super::Error>>;
36
37#[cfg(not(target_arch = "wasm32"))]
38type EventStream = BoxStream<'static, Result<MessageEvent, EventStreamError<super::Error>>>;
39#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
40type EventStream = LocalBoxStream<'static, Result<MessageEvent, EventStreamError<super::Error>>>;
41
pin_project! {
    #[project = SourceStateProjection]
    /// Connection lifecycle of a `GenericEventSource`; advanced by its `poll_next`.
    enum SourceState {
        // Initial connection attempt, created by the constructors.
        Connecting {
            // In-flight request for the first connection.
            #[pin]
            response_future: ResponseFuture,
        },
        // Reconnection attempt, started once a retry delay has elapsed.
        Reconnecting {
            // In-flight reconnection request.
            #[pin]
            response_future: ResponseFuture,
            // `(retry count, delay)` of the retry that scheduled this attempt; handed back
            // to the retry policy if this attempt fails as well.
            last_retry: (usize, Duration),
        },
        // Connected: events are read from the response body.
        Open {
            // Response body parsed as Server-Sent Events.
            #[pin]
            event_stream: EventStream,
        },
        // Sleeping before the next reconnection attempt.
        WaitingToRetry {
            // Timer for the delay chosen by the retry policy.
            #[pin]
            retry_delay: Delay,
            // `(retry count, delay)`: the count is 1 after a failure that was not itself a
            // retry and grows by one for each consecutive failed reconnection.
            current_retry: (usize, Duration),
        },
        // Terminal state: polling yields `None` from here on.
        Closed,
    }
}
72
pin_project! {
    #[project = GenericEventSourceProjection]
    /// A Server-Sent Events client that sends `req` through `HttpClient` and consults a
    /// `Retry` policy after errors to decide whether to reconnect.
    ///
    /// As a stream it yields [`Event::Open`] after each successful (re)connection and
    /// [`Event::Message`] for every parsed event.
    pub struct GenericEventSource<HttpClient, RequestBody, Retry = ExponentialBackoff> {
        // Client used for every (re)connection attempt; cloned into each request future.
        client: HttpClient,
        // Original request; cloned for every attempt so it can be re-sent on reconnect.
        req: Request<RequestBody>,
        // Decides whether, and after what delay, to reconnect after an error.
        retry_policy: Retry,
        // ID of the most recent event with a non-empty `id`; sent as the
        // `last-event-id` header on reconnection attempts.
        last_event_id: Option<String>,
        // When true, a `200 OK` response without a `Content-Type` header is accepted.
        allow_missing_content_type: bool,
        // Current position in the connection lifecycle.
        #[pin]
        state: SourceState,
    }
}
86
87impl<HttpClient, RequestBody> GenericEventSource<HttpClient, RequestBody>
88where
89 HttpClient: HttpClientExt + Clone + 'static,
90 RequestBody: Into<Bytes> + Clone + WasmCompatSend + 'static,
91{
92 pub fn new(client: HttpClient, req: Request<RequestBody>) -> Self {
94 let response_future = Self::create_response_future(&client, &req, None);
95 let state = SourceState::Connecting { response_future };
96
97 Self {
98 client,
99 req,
100 retry_policy: DEFAULT_RETRY,
101 last_event_id: None,
102 allow_missing_content_type: false,
103 state,
104 }
105 }
106
107 pub fn with_retry_policy<R>(
108 client: HttpClient,
109 req: Request<RequestBody>,
110 retry_policy: R,
111 ) -> GenericEventSource<HttpClient, RequestBody, R>
112 where
113 R: RetryPolicy,
114 {
115 let response_future = Self::create_response_future(&client, &req, None);
116 let state = SourceState::Connecting { response_future };
117
118 GenericEventSource {
119 client,
120 req,
121 retry_policy,
122 last_event_id: None,
123 allow_missing_content_type: false,
124 state,
125 }
126 }
127
128 pub fn allow_missing_content_type(mut self) -> Self {
129 self.allow_missing_content_type = true;
130 self
131 }
132
133 fn create_response_future(
135 client: &HttpClient,
136 req: &Request<RequestBody>,
137 last_event_id: Option<&str>,
138 ) -> ResponseFuture {
139 let mut req_clone = req.clone();
140 req_clone
141 .headers_mut()
142 .entry("Accept")
143 .or_insert(HeaderValue::from_static("text/event-stream"));
144
145 if let Some(id) = last_event_id
146 && let Ok(value) = HeaderValue::from_str(id)
147 {
148 req_clone
149 .headers_mut()
150 .insert(HeaderName::from_static("last-event-id"), value);
151 }
152
153 let client_clone = client.clone();
154 Box::pin(async move { client_clone.send_streaming(req_clone).await })
155 }
156
157 pub fn last_event_id(&self) -> Option<&str> {
159 self.last_event_id.as_deref()
160 }
161
162 pub fn close(&mut self) {
165 self.state = SourceState::Closed;
166 }
167}
168
169#[derive(Debug, Clone, Eq, PartialEq)]
171pub enum Event {
172 Open,
174 Message(MessageEvent),
176}
177
178impl From<MessageEvent> for Event {
179 fn from(event: MessageEvent) -> Self {
180 Event::Message(event)
181 }
182}
183
184impl<HttpClient, RequestBody> Stream for GenericEventSource<HttpClient, RequestBody>
185where
186 HttpClient: HttpClientExt + Clone + 'static,
187 RequestBody: Into<Bytes> + Clone + WasmCompatSend + 'static,
188{
189 type Item = Result<Event, super::Error>;
190
191 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
192 let mut this = self.project();
193
194 loop {
195 match this.state.as_mut().project() {
196 SourceStateProjection::Connecting { response_future } => {
197 match response_future.poll(cx) {
198 Poll::Pending => return Poll::Pending,
199 Poll::Ready(Ok(response)) => {
200 match check_response(response, *this.allow_missing_content_type) {
201 Ok(response) => {
202 let mut event_stream = response.into_body().eventsource();
204 if let Some(id) = &this.last_event_id {
205 event_stream.set_last_event_id(id.clone());
206 }
207 this.state.set(SourceState::Open {
208 event_stream: Box::pin(event_stream),
209 });
210 return Poll::Ready(Some(Ok(Event::Open)));
211 }
212 Err(err) => {
213 this.state.set(SourceState::Closed);
215 return Poll::Ready(Some(Err(err)));
216 }
217 }
218 }
219 Poll::Ready(Err(err)) => {
220 if let Some(delay_duration) = this.retry_policy.retry(&err, None) {
222 this.state.set(SourceState::WaitingToRetry {
224 retry_delay: Delay::new(delay_duration),
225 current_retry: (1, delay_duration),
226 });
227 return Poll::Ready(Some(Err(err)));
228 } else {
229 this.state.set(SourceState::Closed);
231 return Poll::Ready(Some(Err(err)));
232 }
233 }
234 }
235 }
236
237 SourceStateProjection::Reconnecting {
238 response_future,
239 last_retry,
240 } => {
241 match response_future.poll(cx) {
242 Poll::Pending => return Poll::Pending,
243 Poll::Ready(Ok(response)) => {
244 match check_response(response, *this.allow_missing_content_type) {
245 Ok(response) => {
246 let mut event_stream = response.into_body().eventsource();
248 if let Some(id) = &this.last_event_id {
249 event_stream.set_last_event_id(id.clone());
250 }
251 this.state.set(SourceState::Open {
252 event_stream: Box::pin(event_stream),
253 });
254 return Poll::Ready(Some(Ok(Event::Open)));
255 }
256 Err(err) => {
257 this.state.set(SourceState::Closed);
259 return Poll::Ready(Some(Err(err)));
260 }
261 }
262 }
263 Poll::Ready(Err(err)) => {
264 if let Some(delay_duration) =
266 this.retry_policy.retry(&err, Some(*last_retry))
267 {
268 let (retry_num, _) = *last_retry;
269 this.state.set(SourceState::WaitingToRetry {
271 retry_delay: Delay::new(delay_duration),
272 current_retry: (retry_num + 1, delay_duration),
273 });
274 return Poll::Ready(Some(Err(err)));
275 } else {
276 this.state.set(SourceState::Closed);
278 return Poll::Ready(Some(Err(err)));
279 }
280 }
281 }
282 }
283
284 SourceStateProjection::Open { event_stream } => {
285 match event_stream.poll_next(cx) {
286 Poll::Pending => return Poll::Pending,
287 Poll::Ready(Some(Ok(event))) => {
288 if !event.id.is_empty() {
289 *this.last_event_id = Some(event.id.clone());
290 }
291 if let Some(duration) = event.retry {
292 this.retry_policy.set_reconnection_time(duration);
293 }
294 return Poll::Ready(Some(Ok(Event::Message(event))));
295 }
296 Poll::Ready(Some(Err(EventStreamError::Transport(err)))) => {
297 if let Some(delay_duration) = this.retry_policy.retry(&err, None) {
299 this.state.set(SourceState::WaitingToRetry {
301 retry_delay: Delay::new(delay_duration),
302 current_retry: (1, delay_duration),
303 });
304 return Poll::Ready(Some(Err(err)));
305 } else {
306 this.state.set(SourceState::Closed);
308 return Poll::Ready(Some(Err(err)));
309 }
310 }
311 Poll::Ready(Some(Err(EventStreamError::Parser(_)))) => {
312 continue;
314 }
315 Poll::Ready(Some(Err(EventStreamError::Utf8(_)))) => {
316 continue;
318 }
319 Poll::Ready(None) => {
320 this.state.set(SourceState::Closed);
322 return Poll::Ready(None);
323 }
324 }
325 }
326
327 SourceStateProjection::WaitingToRetry {
328 retry_delay,
329 current_retry,
330 } => {
331 let retry_info = *current_retry;
333 match retry_delay.poll(cx) {
334 Poll::Pending => return Poll::Pending,
335 Poll::Ready(()) => {
336 let response_future =
338 GenericEventSource::<HttpClient, RequestBody>::create_response_future(
339 this.client,
340 this.req,
341 this.last_event_id.as_deref(),
342 );
343 this.state.set(SourceState::Reconnecting {
344 response_future,
345 last_retry: retry_info,
346 });
347 continue;
348 }
349 }
350 }
351
352 SourceStateProjection::Closed => {
353 return Poll::Ready(None);
354 }
355 }
356 }
357 }
358}
359
360fn check_response<T>(
361 response: Response<T>,
362 allow_missing_content_type: bool,
363) -> Result<Response<T>, super::Error> {
364 let StatusCode::OK = response.status() else {
365 return Err(super::Error::InvalidStatusCode(response.status()));
366 };
367
368 let content_type =
369 if let Some(content_type) = response.headers().get(&reqwest::header::CONTENT_TYPE) {
370 content_type
371 } else if allow_missing_content_type {
372 return Ok(response);
373 } else {
374 return Err(super::Error::InvalidContentType(HeaderValue::from_static(
375 "",
376 )));
377 };
378
379 if content_type
380 .to_str()
381 .map_err(|_| ())
382 .and_then(|s| s.parse::<mime::Mime>().map_err(|_| ()))
383 .map(|mime_type| {
384 matches!(
385 (mime_type.type_(), mime_type.subtype()),
386 (mime::TEXT, mime::EVENT_STREAM)
387 )
388 })
389 .unwrap_or(false)
390 {
391 Ok(response)
392 } else {
393 Err(super::Error::InvalidContentType(content_type.clone()))
394 }
395}