1use std::{
7 pin::Pin,
8 task::{Context, Poll},
9 time::Duration,
10};
11
12use bytes::Bytes;
13use eventsource_stream::{Event as MessageEvent, EventStreamError, Eventsource};
14use futures::Stream;
15#[cfg(not(target_arch = "wasm32"))]
16use futures::{future::BoxFuture, stream::BoxStream};
17#[cfg(target_arch = "wasm32")]
18use futures::{future::LocalBoxFuture, stream::LocalBoxStream};
19use futures_timer::Delay;
20use http::Response;
21use http::{HeaderName, HeaderValue, Request, StatusCode};
22use mime_guess::mime;
23use pin_project_lite::pin_project;
24
25use crate::{
26 http_client::{
27 HttpClientExt, Result as StreamResult, instance_error,
28 retry::{DEFAULT_RETRY, RetryPolicy},
29 },
30 wasm_compat::{WasmCompatSend, WasmCompatSendStream},
31};
32
/// Pinned, boxed byte stream used as the response-body type of [`GenericEventSource`].
pub type BoxedStream = Pin<Box<dyn WasmCompatSendStream<InnerItem = StreamResult<Bytes>>>>;

// Boxed future for an in-flight connection attempt, resolving to the raw (not yet
// validated) streaming response. Non-wasm32 targets use `BoxFuture` (which is `Send`);
// wasm32 uses `LocalBoxFuture`, which carries no `Send` bound.
#[cfg(not(target_arch = "wasm32"))]
type ResponseFuture<T> = BoxFuture<'static, Result<Response<T>, super::Error>>;
#[cfg(target_arch = "wasm32")]
type ResponseFuture<T> = LocalBoxFuture<'static, Result<Response<T>, super::Error>>;

// Boxed stream of parsed SSE message events for the currently open connection.
// Same `Send` / non-`Send` split between wasm32 and other targets as above.
#[cfg(not(target_arch = "wasm32"))]
type EventStream = BoxStream<'static, Result<MessageEvent, EventStreamError<super::Error>>>;
#[cfg(target_arch = "wasm32")]
type EventStream = LocalBoxStream<'static, Result<MessageEvent, EventStreamError<super::Error>>>;
// Type-erased retry policy: asked by `handle_error` whether/when to reconnect, and
// updated via `set_reconnection_time` when an event carries a `retry` field.
type BoxedRetry = Box<dyn RetryPolicy + Send + Unpin + 'static>;
45
/// Connection state reported by [`GenericEventSource::ready_state`].
///
/// The numeric discriminants equal the WHATWG `EventSource` constants
/// (`CONNECTING = 0`, `OPEN = 1`, `CLOSED = 2`), and the derived ordering follows them.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ReadyState {
    /// A connection attempt or a reconnection back-off is pending.
    Connecting = 0,
    /// A response has been accepted and its event stream is being read.
    Open = 1,
    /// The source is closed; polling it yields no further items.
    Closed = 2,
}
57
pin_project! {
    // Server-sent-events client that drives an `HttpClientExt` connection and
    // reconnects according to `retry_policy`. It is consumed as a `futures::Stream`
    // (see the `Stream` impl in this file).
    #[project = GenericEventSourceProjection]
    pub struct GenericEventSource<HttpClient, RequestBody, ResponseBody>
    where
        HttpClient: HttpClientExt,
    {
        // Client used for the initial request and for every reconnect.
        client: HttpClient,
        // Request template; `retry_fetch` clones it for each reconnect attempt.
        req: Request<RequestBody>,
        // Connection attempt currently in flight, if any.
        #[pin]
        next_response: Option<ResponseFuture<ResponseBody>>,
        // Parsed event stream of the currently open connection, if any.
        #[pin]
        cur_stream: Option<EventStream>,
        // Back-off timer that must elapse before the next reconnect is issued.
        #[pin]
        delay: Option<Delay>,
        // Once true, `poll_next` returns `Poll::Ready(None)` immediately.
        is_closed: bool,
        // Decides whether and how long to wait before reconnecting after an error.
        retry_policy: BoxedRetry,
        // Id of the most recent event; sent as `last-event-id` on reconnects.
        last_event_id: String,
        // `(attempt number, delay)` of the current retry streak; cleared when a
        // response is accepted in `handle_response`.
        last_retry: Option<(usize, Duration)>,
    }
}
80
81impl<HttpClient, RequestBody>
82 GenericEventSource<
83 HttpClient,
84 RequestBody,
85 Pin<Box<dyn WasmCompatSendStream<InnerItem = StreamResult<Bytes>>>>,
86 >
87where
88 HttpClient: HttpClientExt + Clone + 'static,
89 RequestBody: Into<Bytes> + Clone + Send + 'static,
90{
91 pub fn new(client: HttpClient, req: Request<RequestBody>) -> Self {
92 let client_clone = client.clone();
93 let mut req_clone = req.clone();
94 req_clone
95 .headers_mut()
96 .insert("Accept", HeaderValue::from_static("text/event-stream"));
97 let res_fut = Box::pin(async move { client_clone.clone().send_streaming(req_clone).await });
98 Self {
99 client,
100 next_response: Some(res_fut),
101 cur_stream: None,
102 req,
103 delay: None,
104 is_closed: false,
105 retry_policy: Box::new(DEFAULT_RETRY),
106 last_event_id: String::new(),
107 last_retry: None,
108 }
109 }
110
111 pub fn close(&mut self) {
113 self.is_closed = true;
114 }
115
116 pub fn last_event_id(&self) -> &str {
118 &self.last_event_id
119 }
120
121 pub fn ready_state(&self) -> ReadyState {
123 if self.is_closed {
124 ReadyState::Closed
125 } else if self.delay.is_some() || self.next_response.is_some() {
126 ReadyState::Connecting
127 } else {
128 ReadyState::Open
129 }
130 }
131}
132
133impl<'a, HttpClient, RequestBody>
134 GenericEventSourceProjection<'a, HttpClient, RequestBody, BoxedStream>
135where
136 HttpClient: HttpClientExt + Clone + 'static,
137 RequestBody: Into<Bytes> + Clone + WasmCompatSend + 'static,
138{
139 fn clear_fetch(&mut self) {
140 self.next_response.take();
141 self.cur_stream.take();
142 }
143
144 fn retry_fetch(&mut self) -> Result<(), super::Error> {
145 self.cur_stream.take();
146 let mut req = self.req.clone();
147 req.headers_mut().insert(
148 HeaderName::from_static("last-event-id"),
149 HeaderValue::from_str(self.last_event_id).map_err(instance_error)?,
150 );
151 let client = self.client.clone();
152 let res_future = Box::pin(async move { client.send_streaming(req).await });
153 self.next_response.replace(res_future);
154 Ok(())
155 }
156
157 fn handle_response<T>(&mut self, res: Response<T>)
158 where
159 T: Stream<Item = StreamResult<Bytes>> + WasmCompatSend + 'static,
160 {
161 self.last_retry.take();
162 let mut stream = res.into_body().eventsource();
163 stream.set_last_event_id(self.last_event_id.clone());
164 self.cur_stream.replace(Box::pin(stream));
165 }
166
167 fn handle_event(&mut self, event: &eventsource_stream::Event) {
168 *self.last_event_id = event.id.clone();
169 if let Some(duration) = event.retry {
170 self.retry_policy.set_reconnection_time(duration)
171 }
172 }
173
174 fn handle_error(&mut self, error: &super::Error) {
175 self.clear_fetch();
176 if let Some(retry_delay) = self.retry_policy.retry(error, *self.last_retry) {
177 let retry_num = self
178 .last_retry
179 .map(|retry| retry.0.saturating_add(1))
180 .unwrap_or(1);
181 *self.last_retry = Some((retry_num, retry_delay));
182 self.delay.replace(Delay::new(retry_delay));
183 } else {
184 *self.is_closed = true;
185 }
186 }
187}
188
189#[derive(Debug, Clone, Eq, PartialEq)]
191pub enum Event {
192 Open,
194 Message(MessageEvent),
196}
197
198impl From<MessageEvent> for Event {
199 fn from(event: MessageEvent) -> Self {
200 Event::Message(event)
201 }
202}
203
204impl<HttpClient, RequestBody> Stream for GenericEventSource<HttpClient, RequestBody, BoxedStream>
205where
206 HttpClient: HttpClientExt + Clone + 'static,
207 RequestBody: Into<Bytes> + Clone + WasmCompatSend + 'static,
208{
209 type Item = Result<Event, super::Error>;
210
211 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
212 let mut this = self.project();
213
214 if *this.is_closed {
215 return Poll::Ready(None);
216 }
217
218 if let Some(delay) = this.delay.as_mut().as_pin_mut() {
219 match delay.poll(cx) {
220 Poll::Ready(_) => {
221 this.delay.take();
222 if let Err(err) = this.retry_fetch() {
223 *this.is_closed = true;
224 return Poll::Ready(Some(Err(err)));
225 }
226 }
227 Poll::Pending => return Poll::Pending,
228 }
229 }
230
231 if let Some(response_future) = this.next_response.as_mut().as_pin_mut() {
232 match response_future.poll(cx) {
233 Poll::Ready(Ok(res)) => {
234 this.clear_fetch();
235 match check_response(res) {
236 Ok(res) => {
237 this.handle_response(res);
238 return Poll::Ready(Some(Ok(Event::Open)));
239 }
240 Err(err) => {
241 *this.is_closed = true;
242 return Poll::Ready(Some(Err(err)));
243 }
244 }
245 }
246 Poll::Ready(Err(err)) => {
247 this.handle_error(&err);
248 return Poll::Ready(Some(Err(err)));
249 }
250 Poll::Pending => {
251 return Poll::Pending;
252 }
253 }
254 }
255
256 match this
257 .cur_stream
258 .as_mut()
259 .as_pin_mut()
260 .unwrap()
261 .as_mut()
262 .poll_next(cx)
263 {
264 Poll::Ready(Some(Err(err))) => {
265 let EventStreamError::Transport(err) = err else {
266 panic!("u");
267 };
268 this.handle_error(&err);
269 Poll::Ready(Some(Err(err)))
270 }
271 Poll::Ready(Some(Ok(event))) => {
272 this.handle_event(&event);
273 Poll::Ready(Some(Ok(event.into())))
274 }
275 Poll::Ready(None) => Poll::Ready(None),
276 Poll::Pending => Poll::Pending,
277 }
278 }
279}
280
281fn check_response<T>(response: Response<T>) -> Result<Response<T>, super::Error> {
282 match response.status() {
283 StatusCode::OK => {}
284 status => {
285 return Err(super::Error::InvalidStatusCode(status));
286 }
287 }
288 let content_type =
289 if let Some(content_type) = response.headers().get(&reqwest::header::CONTENT_TYPE) {
290 content_type
291 } else {
292 return Err(super::Error::InvalidContentType(HeaderValue::from_static(
293 "",
294 )));
295 };
296 if content_type
297 .to_str()
298 .map_err(|_| ())
299 .and_then(|s| s.parse::<mime::Mime>().map_err(|_| ()))
300 .map(|mime_type| {
301 matches!(
302 (mime_type.type_(), mime_type.subtype()),
303 (mime::TEXT, mime::EVENT_STREAM)
304 )
305 })
306 .unwrap_or(false)
307 {
308 Ok(response)
309 } else {
310 Err(super::Error::InvalidContentType(content_type.clone()))
311 }
312}