1use std::{
7 pin::Pin,
8 task::{Context, Poll},
9 time::Duration,
10};
11
12use bytes::Bytes;
13use eventsource_stream::{Event as MessageEvent, EventStreamError, Eventsource};
14use futures::Stream;
15#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
16use futures::{future::BoxFuture, stream::BoxStream};
17#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
18use futures::{future::LocalBoxFuture, stream::LocalBoxStream};
19use futures_timer::Delay;
20use http::Response;
21use http::{HeaderName, HeaderValue, Request, StatusCode};
22use mime_guess::mime;
23use pin_project_lite::pin_project;
24
25use crate::{
26 http_client::{
27 HttpClientExt, Result as StreamResult, instance_error,
28 retry::{DEFAULT_RETRY, RetryPolicy},
29 },
30 wasm_compat::{WasmCompatSend, WasmCompatSendStream},
31};
32
/// Pinned, boxed stream whose items are `StreamResult<Bytes>`.
///
/// Used in this file as the response-body type parameter of `GenericEventSource`.
pub type BoxedStream = Pin<Box<dyn WasmCompatSendStream<InnerItem = StreamResult<Bytes>>>>;
34
35#[cfg(not(target_arch = "wasm32"))]
36type ResponseFuture<T> = BoxFuture<'static, Result<Response<T>, super::Error>>;
37#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
38type ResponseFuture<T> = LocalBoxFuture<'static, Result<Response<T>, super::Error>>;
39
40#[cfg(not(target_arch = "wasm32"))]
41type EventStream = BoxStream<'static, Result<MessageEvent, EventStreamError<super::Error>>>;
42#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
43type EventStream = LocalBoxStream<'static, Result<MessageEvent, EventStreamError<super::Error>>>;
44type BoxedRetry = Box<dyn RetryPolicy + Send + Unpin + 'static>;
45
/// Connection state reported by `GenericEventSource::ready_state`.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
#[repr(u8)]
pub enum ReadyState {
    /// A request is in flight or a retry delay is pending.
    Connecting = 0,
    /// Neither closed nor connecting: no request or retry delay is pending.
    Open = 1,
    /// The source has been closed; polling it yields no further items.
    Closed = 2,
}
57
pin_project! {
    #[project = GenericEventSourceProjection]
    /// Event source that issues a streaming HTTP request and yields the events parsed
    /// from the response body, re-issuing the request according to a retry policy.
    pub struct GenericEventSource<HttpClient, RequestBody, ResponseBody>
    where
        HttpClient: HttpClientExt,
    {
        // Client used to send the initial request and every retry.
        client: HttpClient,
        // Request kept around so it can be cloned for each (re)connection attempt.
        req: Request<RequestBody>,
        // Response future of the request currently in flight, if any.
        #[pin]
        next_response: Option<ResponseFuture<ResponseBody>>,
        // Event stream parsed from the body of the current response, if any.
        #[pin]
        cur_stream: Option<EventStream>,
        // Timer that must elapse before the next retry is sent, if one is scheduled.
        #[pin]
        delay: Option<Delay>,
        // Once true, polling yields `None` immediately.
        is_closed: bool,
        // Policy consulted after an error to decide whether, and after how long, to retry.
        retry_policy: BoxedRetry,
        // Id carried by the most recently received event; sent back when reconnecting.
        last_event_id: String,
        // (retry number, delay) of the most recently scheduled retry; cleared once a
        // response is accepted.
        last_retry: Option<(usize, Duration)>,
    }
}
80
81impl<HttpClient, RequestBody>
82 GenericEventSource<
83 HttpClient,
84 RequestBody,
85 Pin<Box<dyn WasmCompatSendStream<InnerItem = StreamResult<Bytes>>>>,
86 >
87where
88 HttpClient: HttpClientExt + Clone + 'static,
89 RequestBody: Into<Bytes> + Clone + Send + 'static,
90{
91 pub fn new(client: HttpClient, req: Request<RequestBody>) -> Self {
92 let client_clone = client.clone();
93 let mut req_clone = req.clone();
94 req_clone
95 .headers_mut()
96 .entry("Accept")
97 .or_insert(HeaderValue::from_static("text/event-stream"));
98 let res_fut = Box::pin(async move { client_clone.clone().send_streaming(req_clone).await });
99 Self {
100 client,
101 next_response: Some(res_fut),
102 cur_stream: None,
103 req,
104 delay: None,
105 is_closed: false,
106 retry_policy: Box::new(DEFAULT_RETRY),
107 last_event_id: String::new(),
108 last_retry: None,
109 }
110 }
111
112 pub fn close(&mut self) {
114 self.is_closed = true;
115 }
116
117 pub fn last_event_id(&self) -> &str {
119 &self.last_event_id
120 }
121
122 pub fn ready_state(&self) -> ReadyState {
124 if self.is_closed {
125 ReadyState::Closed
126 } else if self.delay.is_some() || self.next_response.is_some() {
127 ReadyState::Connecting
128 } else {
129 ReadyState::Open
130 }
131 }
132}
133
134impl<'a, HttpClient, RequestBody>
135 GenericEventSourceProjection<'a, HttpClient, RequestBody, BoxedStream>
136where
137 HttpClient: HttpClientExt + Clone + 'static,
138 RequestBody: Into<Bytes> + Clone + WasmCompatSend + 'static,
139{
140 fn clear_fetch(&mut self) {
141 self.next_response.take();
142 self.cur_stream.take();
143 }
144
145 fn retry_fetch(&mut self) -> Result<(), super::Error> {
146 self.cur_stream.take();
147 let mut req = self.req.clone();
148 req.headers_mut().insert(
149 HeaderName::from_static("last-event-id"),
150 HeaderValue::from_str(self.last_event_id).map_err(instance_error)?,
151 );
152 let client = self.client.clone();
153 let res_future = Box::pin(async move { client.send_streaming(req).await });
154 self.next_response.replace(res_future);
155 Ok(())
156 }
157
158 fn handle_response<T>(&mut self, res: Response<T>)
159 where
160 T: Stream<Item = StreamResult<Bytes>> + WasmCompatSend + 'static,
161 {
162 self.last_retry.take();
163 let mut stream = res.into_body().eventsource();
164 stream.set_last_event_id(self.last_event_id.clone());
165 self.cur_stream.replace(Box::pin(stream));
166 }
167
168 fn handle_event(&mut self, event: &eventsource_stream::Event) {
169 *self.last_event_id = event.id.clone();
170 if let Some(duration) = event.retry {
171 self.retry_policy.set_reconnection_time(duration)
172 }
173 }
174
175 fn handle_error(&mut self, error: &super::Error) {
176 self.clear_fetch();
177 if let Some(retry_delay) = self.retry_policy.retry(error, *self.last_retry) {
178 let retry_num = self
179 .last_retry
180 .map(|retry| retry.0.saturating_add(1))
181 .unwrap_or(1);
182 *self.last_retry = Some((retry_num, retry_delay));
183 self.delay.replace(Delay::new(retry_delay));
184 } else {
185 *self.is_closed = true;
186 }
187 }
188}
189
/// Item yielded by the event source stream.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Event {
    /// A response passed validation and its event stream is now being read.
    Open,
    /// An event parsed from the response body.
    Message(MessageEvent),
}
198
199impl From<MessageEvent> for Event {
200 fn from(event: MessageEvent) -> Self {
201 Event::Message(event)
202 }
203}
204
205impl<HttpClient, RequestBody> Stream for GenericEventSource<HttpClient, RequestBody, BoxedStream>
206where
207 HttpClient: HttpClientExt + Clone + 'static,
208 RequestBody: Into<Bytes> + Clone + WasmCompatSend + 'static,
209{
210 type Item = Result<Event, super::Error>;
211
212 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
213 let mut this = self.project();
214
215 if *this.is_closed {
216 return Poll::Ready(None);
217 }
218
219 if let Some(delay) = this.delay.as_mut().as_pin_mut() {
220 match delay.poll(cx) {
221 Poll::Ready(_) => {
222 this.delay.take();
223 if let Err(err) = this.retry_fetch() {
224 *this.is_closed = true;
225 return Poll::Ready(Some(Err(err)));
226 }
227 }
228 Poll::Pending => return Poll::Pending,
229 }
230 }
231
232 if let Some(response_future) = this.next_response.as_mut().as_pin_mut() {
233 match response_future.poll(cx) {
234 Poll::Ready(Ok(res)) => {
235 this.clear_fetch();
236 match check_response(res) {
237 Ok(res) => {
238 this.handle_response(res);
239 return Poll::Ready(Some(Ok(Event::Open)));
240 }
241 Err(err) => {
242 *this.is_closed = true;
243 return Poll::Ready(Some(Err(err)));
244 }
245 }
246 }
247 Poll::Ready(Err(err)) => {
248 this.handle_error(&err);
249 return Poll::Ready(Some(Err(err)));
250 }
251 Poll::Pending => {
252 return Poll::Pending;
253 }
254 }
255 }
256
257 match this
258 .cur_stream
259 .as_mut()
260 .as_pin_mut()
261 .unwrap()
262 .as_mut()
263 .poll_next(cx)
264 {
265 Poll::Ready(Some(Err(err))) => {
266 let EventStreamError::Transport(err) = err else {
267 panic!("u");
268 };
269 this.handle_error(&err);
270 Poll::Ready(Some(Err(err)))
271 }
272 Poll::Ready(Some(Ok(event))) => {
273 this.handle_event(&event);
274 Poll::Ready(Some(Ok(event.into())))
275 }
276 Poll::Ready(None) => Poll::Ready(None),
277 Poll::Pending => Poll::Pending,
278 }
279 }
280}
281
282fn check_response<T>(response: Response<T>) -> Result<Response<T>, super::Error> {
283 match response.status() {
284 StatusCode::OK => {}
285 status => {
286 return Err(super::Error::InvalidStatusCode(status));
287 }
288 }
289 let content_type =
290 if let Some(content_type) = response.headers().get(&reqwest::header::CONTENT_TYPE) {
291 content_type
292 } else {
293 return Err(super::Error::InvalidContentType(HeaderValue::from_static(
294 "",
295 )));
296 };
297 if content_type
298 .to_str()
299 .map_err(|_| ())
300 .and_then(|s| s.parse::<mime::Mime>().map_err(|_| ()))
301 .map(|mime_type| {
302 matches!(
303 (mime_type.type_(), mime_type.subtype()),
304 (mime::TEXT, mime::EVENT_STREAM)
305 )
306 })
307 .unwrap_or(false)
308 {
309 Ok(response)
310 } else {
311 Err(super::Error::InvalidContentType(content_type.clone()))
312 }
313}