1use base64::prelude::*;
2
3use futures::{ready, Stream};
4use hyper::{
5 body::HttpBody,
6 client::{
7 connect::{Connect, Connection},
8 ResponseFuture,
9 },
10 header::{HeaderMap, HeaderName, HeaderValue},
11 service::Service,
12 Body, Request, Uri,
13};
14use log::{debug, info, trace, warn};
15use pin_project::pin_project;
16use std::{
17 boxed,
18 fmt::{self, Debug, Formatter},
19 future::Future,
20 io::ErrorKind,
21 pin::{pin, Pin},
22 str::FromStr,
23 task::{Context, Poll},
24 time::{Duration, Instant},
25};
26
27use tokio::{
28 io::{AsyncRead, AsyncWrite},
29 time::Sleep,
30};
31
32use crate::{
33 config::ReconnectOptions,
34 response::{ErrorBody, Response},
35};
36use crate::{
37 error::{Error, Result},
38 event_parser::ConnectionDetails,
39};
40
41use hyper::client::HttpConnector;
42use hyper_timeout::TimeoutConnector;
43
44use crate::event_parser::EventParser;
45use crate::event_parser::SSE;
46
47use crate::retry::{BackoffRetry, RetryStrategy};
48use std::error::Error as StdError;
49
50#[cfg(feature = "rustls")]
51use hyper_rustls::HttpsConnectorBuilder;
52
53type BoxError = Box<dyn std::error::Error + Send + Sync>;
54
55pub type BoxStream<T> = Pin<boxed::Box<dyn Stream<Item = T> + Send + Sync>>;
57
58pub trait Client: Send + Sync + private::Sealed {
61 fn stream(&self) -> BoxStream<Result<SSE>>;
62}
63
64pub const DEFAULT_REDIRECT_LIMIT: u32 = 16;
72
73pub struct ClientBuilder {
75 url: Uri,
76 headers: HeaderMap,
77 reconnect_opts: ReconnectOptions,
78 connect_timeout: Option<Duration>,
79 read_timeout: Option<Duration>,
80 write_timeout: Option<Duration>,
81 last_event_id: Option<String>,
82 method: String,
83 body: Option<String>,
84 max_redirects: Option<u32>,
85}
86
87impl ClientBuilder {
88 pub fn for_url(url: &str) -> Result<ClientBuilder> {
90 let url = url
91 .parse()
92 .map_err(|e| Error::InvalidParameter(Box::new(e)))?;
93
94 let mut header_map = HeaderMap::new();
95 header_map.insert("Accept", HeaderValue::from_static("text/event-stream"));
96 header_map.insert("Cache-Control", HeaderValue::from_static("no-cache"));
97
98 Ok(ClientBuilder {
99 url,
100 headers: header_map,
101 reconnect_opts: ReconnectOptions::default(),
102 connect_timeout: None,
103 read_timeout: None,
104 write_timeout: None,
105 last_event_id: None,
106 method: String::from("GET"),
107 max_redirects: None,
108 body: None,
109 })
110 }
111
112 pub fn method(mut self, method: String) -> ClientBuilder {
114 self.method = method;
115 self
116 }
117
118 pub fn body(mut self, body: String) -> ClientBuilder {
120 self.body = Some(body);
121 self
122 }
123
124 pub fn last_event_id(mut self, last_event_id: String) -> ClientBuilder {
127 self.last_event_id = Some(last_event_id);
128 self
129 }
130
131 pub fn header(mut self, name: &str, value: &str) -> Result<ClientBuilder> {
133 let name = HeaderName::from_str(name).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
134
135 let value =
136 HeaderValue::from_str(value).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
137
138 self.headers.insert(name, value);
139 Ok(self)
140 }
141
142 pub fn basic_auth(self, username: &str, password: &str) -> Result<ClientBuilder> {
144 let auth = format!("{}:{}", username, password);
145 let encoded = BASE64_STANDARD.encode(auth);
146 let value = format!("Basic {}", encoded);
147
148 self.header("Authorization", &value)
149 }
150
151 pub fn connect_timeout(mut self, connect_timeout: Duration) -> ClientBuilder {
154 self.connect_timeout = Some(connect_timeout);
155 self
156 }
157
158 pub fn read_timeout(mut self, read_timeout: Duration) -> ClientBuilder {
160 self.read_timeout = Some(read_timeout);
161 self
162 }
163
164 pub fn write_timeout(mut self, write_timeout: Duration) -> ClientBuilder {
166 self.write_timeout = Some(write_timeout);
167 self
168 }
169
170 pub fn reconnect(mut self, opts: ReconnectOptions) -> ClientBuilder {
175 self.reconnect_opts = opts;
176 self
177 }
178
179 pub fn redirect_limit(mut self, limit: u32) -> ClientBuilder {
183 self.max_redirects = Some(limit);
184 self
185 }
186
187 pub fn build_with_conn<C>(self, conn: C) -> impl Client
189 where
190 C: Service<Uri> + Clone + Send + Sync + 'static,
191 C::Response: Connection + AsyncRead + AsyncWrite + Send + Unpin,
192 C::Future: Send + 'static,
193 C::Error: Into<BoxError>,
194 {
195 let mut connector = TimeoutConnector::new(conn);
196 connector.set_connect_timeout(self.connect_timeout);
197 connector.set_read_timeout(self.read_timeout);
198 connector.set_write_timeout(self.write_timeout);
199
200 let client = hyper::Client::builder().build::<_, hyper::Body>(connector);
201
202 ClientImpl {
203 http: client,
204 request_props: RequestProps {
205 url: self.url,
206 headers: self.headers,
207 method: self.method,
208 body: self.body,
209 reconnect_opts: self.reconnect_opts,
210 max_redirects: self.max_redirects.unwrap_or(DEFAULT_REDIRECT_LIMIT),
211 },
212 last_event_id: self.last_event_id,
213 }
214 }
215
216 pub fn build_http(self) -> impl Client {
218 self.build_with_conn(HttpConnector::new())
219 }
220
221 #[cfg(feature = "rustls")]
222 pub fn build(self) -> impl Client {
224 let conn = HttpsConnectorBuilder::new()
225 .with_native_roots()
226 .https_or_http()
227 .enable_http1()
228 .enable_http2()
229 .build();
230
231 self.build_with_conn(conn)
232 }
233
234 pub fn build_with_http_client<C>(self, http: hyper::Client<C>) -> impl Client
236 where
237 C: Connect + Clone + Send + Sync + 'static,
238 {
239 ClientImpl {
240 http,
241 request_props: RequestProps {
242 url: self.url,
243 headers: self.headers,
244 method: self.method,
245 body: self.body,
246 reconnect_opts: self.reconnect_opts,
247 max_redirects: self.max_redirects.unwrap_or(DEFAULT_REDIRECT_LIMIT),
248 },
249 last_event_id: self.last_event_id,
250 }
251 }
252}
253
254#[derive(Clone)]
255struct RequestProps {
256 url: Uri,
257 headers: HeaderMap,
258 method: String,
259 body: Option<String>,
260 reconnect_opts: ReconnectOptions,
261 max_redirects: u32,
262}
263
264struct ClientImpl<C> {
268 http: hyper::Client<C>,
269 request_props: RequestProps,
270 last_event_id: Option<String>,
271}
272
273impl<C> Client for ClientImpl<C>
274where
275 C: Connect + Clone + Send + Sync + 'static,
276{
277 fn stream(&self) -> BoxStream<Result<SSE>> {
285 Box::pin(ReconnectingRequest::new(
286 self.http.clone(),
287 self.request_props.clone(),
288 self.last_event_id.clone(),
289 ))
290 }
291}
292
293#[allow(clippy::large_enum_variant)] #[pin_project(project = StateProj)]
295enum State {
296 New,
297 Connecting {
298 retry: bool,
299 #[pin]
300 resp: ResponseFuture,
301 },
302 Connected(#[pin] hyper::Body),
303 WaitingToReconnect(#[pin] Sleep),
304 FollowingRedirect(Option<HeaderValue>),
305 StreamClosed,
306}
307
308impl State {
309 fn name(&self) -> &'static str {
310 match self {
311 State::New => "new",
312 State::Connecting { retry: false, .. } => "connecting(no-retry)",
313 State::Connecting { retry: true, .. } => "connecting(retry)",
314 State::Connected(_) => "connected",
315 State::WaitingToReconnect(_) => "waiting-to-reconnect",
316 State::FollowingRedirect(_) => "following-redirect",
317 State::StreamClosed => "closed",
318 }
319 }
320}
321
322impl Debug for State {
323 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
324 write!(f, "{}", self.name())
325 }
326}
327
328#[must_use = "streams do nothing unless polled"]
329#[pin_project]
330pub struct ReconnectingRequest<C> {
331 http: hyper::Client<C>,
332 props: RequestProps,
333 #[pin]
334 state: State,
335 retry_strategy: Box<dyn RetryStrategy + Send + Sync>,
336 current_url: Uri,
337 redirect_count: u32,
338 event_parser: EventParser,
339 last_event_id: Option<String>,
340 #[pin]
341 initial_connection: bool,
342}
343
344impl<C> ReconnectingRequest<C> {
345 fn new(
346 http: hyper::Client<C>,
347 props: RequestProps,
348 last_event_id: Option<String>,
349 ) -> ReconnectingRequest<C> {
350 let reconnect_delay = props.reconnect_opts.delay;
351 let delay_max = props.reconnect_opts.delay_max;
352 let backoff_factor = props.reconnect_opts.backoff_factor;
353
354 let url = props.url.clone();
355 ReconnectingRequest {
356 props,
357 http,
358 state: State::New,
359 retry_strategy: Box::new(BackoffRetry::new(
360 reconnect_delay,
361 delay_max,
362 backoff_factor,
363 true,
364 )),
365 redirect_count: 0,
366 current_url: url,
367 event_parser: EventParser::new(),
368 last_event_id,
369 initial_connection: true,
370 }
371 }
372
373 fn send_request(&self) -> Result<ResponseFuture>
374 where
375 C: Connect + Clone + Send + Sync + 'static,
376 {
377 let mut request_builder = Request::builder()
378 .method(self.props.method.as_str())
379 .uri(&self.current_url);
380
381 for (name, value) in &self.props.headers {
382 request_builder = request_builder.header(name, value);
383 }
384
385 if let Some(id) = self.last_event_id.as_ref() {
386 if !id.is_empty() {
387 let id_as_header =
388 HeaderValue::from_str(id).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
389
390 request_builder = request_builder.header("last-event-id", id_as_header);
391 }
392 }
393
394 let body = match &self.props.body {
395 Some(body) => Body::from(body.to_string()),
396 None => Body::empty(),
397 };
398
399 let request = request_builder
400 .body(body)
401 .map_err(|e| Error::InvalidParameter(Box::new(e)))?;
402
403 Ok(self.http.request(request))
404 }
405
406 fn reset_redirects(self: Pin<&mut Self>) {
407 let url = self.props.url.clone();
408 let this = self.project();
409 *this.current_url = url;
410 *this.redirect_count = 0;
411 }
412
413 fn increment_redirect_counter(self: Pin<&mut Self>) -> bool {
414 if self.redirect_count == self.props.max_redirects {
415 return false;
416 }
417 *self.project().redirect_count += 1;
418 true
419 }
420}
421
422impl<C> Stream for ReconnectingRequest<C>
423where
424 C: Connect + Clone + Send + Sync + 'static,
425{
426 type Item = Result<SSE>;
427
428 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
429 trace!("ReconnectingRequest::poll({:?})", &self.state);
430
431 loop {
432 let this = self.as_mut().project();
433 if let Some(event) = this.event_parser.get_event() {
434 return match event {
435 SSE::Connected(_) => Poll::Ready(Some(Ok(event))),
436 SSE::Event(ref evt) => {
437 this.last_event_id.clone_from(&evt.id);
438
439 if let Some(retry) = evt.retry {
440 this.retry_strategy
441 .change_base_delay(Duration::from_millis(retry));
442 }
443 Poll::Ready(Some(Ok(event)))
444 }
445 SSE::Comment(_) => Poll::Ready(Some(Ok(event))),
446 };
447 }
448
449 trace!("ReconnectingRequest::poll loop({:?})", &this.state);
450
451 let state = this.state.project();
452 match state {
453 StateProj::StreamClosed => return Poll::Ready(None),
454 StateProj::New => {
457 *self.as_mut().project().event_parser = EventParser::new();
458 match self.send_request() {
459 Ok(resp) => {
460 let retry = if self.initial_connection {
461 self.props.reconnect_opts.retry_initial
462 } else {
463 self.props.reconnect_opts.reconnect
464 };
465 self.as_mut()
466 .project()
467 .state
468 .set(State::Connecting { resp, retry })
469 }
470 Err(e) => {
471 self.as_mut().project().state.set(State::StreamClosed);
474 return Poll::Ready(Some(Err(e)));
475 }
476 }
477 }
478 StateProj::Connecting { retry, resp } => match ready!(resp.poll(cx)) {
479 Ok(resp) => {
480 debug!("HTTP response: {:#?}", resp);
481
482 if resp.status().is_success() {
483 self.as_mut().project().retry_strategy.reset(Instant::now());
484 self.as_mut().reset_redirects();
485
486 let status = resp.status();
487 let headers = resp.headers().clone();
488
489 self.as_mut()
490 .project()
491 .state
492 .set(State::Connected(resp.into_body()));
493 self.as_mut().project().initial_connection.set(false);
494
495 return Poll::Ready(Some(Ok(SSE::Connected(ConnectionDetails::new(
496 Response::new(status, headers),
497 )))));
498 }
499
500 if resp.status() == 301 || resp.status() == 307 {
501 debug!("got redirected ({})", resp.status());
502
503 if self.as_mut().increment_redirect_counter() {
504 debug!("following redirect {}", self.redirect_count);
505
506 self.as_mut().project().state.set(State::FollowingRedirect(
507 resp.headers().get(hyper::header::LOCATION).cloned(),
508 ));
509 continue;
510 } else {
511 debug!("redirect limit reached ({})", self.props.max_redirects);
512
513 self.as_mut().project().state.set(State::StreamClosed);
514 return Poll::Ready(Some(Err(Error::MaxRedirectLimitReached(
515 self.props.max_redirects,
516 ))));
517 }
518 }
519
520 let error = Error::UnexpectedResponse(
521 Response::new(resp.status(), resp.headers().clone()),
522 ErrorBody::new(resp.into_body()),
523 );
524
525 if !*retry {
526 self.as_mut().project().state.set(State::StreamClosed);
527 return Poll::Ready(Some(Err(error)));
528 }
529
530 self.as_mut().reset_redirects();
531
532 let duration = self
533 .as_mut()
534 .project()
535 .retry_strategy
536 .next_delay(Instant::now());
537
538 self.as_mut()
539 .project()
540 .state
541 .set(State::WaitingToReconnect(delay(duration, "retrying")));
542
543 return Poll::Ready(Some(Err(error)));
544 }
545 Err(e) => {
546 warn!("request returned an error: {}", e);
548 if !*retry {
549 self.as_mut().project().state.set(State::StreamClosed);
550 return Poll::Ready(Some(Err(Error::HttpStream(Box::new(e)))));
551 }
552
553 let duration = self
554 .as_mut()
555 .project()
556 .retry_strategy
557 .next_delay(Instant::now());
558
559 self.as_mut()
560 .project()
561 .state
562 .set(State::WaitingToReconnect(delay(duration, "retrying")));
563 }
564 },
565 StateProj::FollowingRedirect(maybe_header) => match uri_from_header(maybe_header) {
566 Ok(uri) => {
567 *self.as_mut().project().current_url = uri;
568 self.as_mut().project().state.set(State::New);
569 }
570 Err(e) => {
571 self.as_mut().project().state.set(State::StreamClosed);
572 return Poll::Ready(Some(Err(e)));
573 }
574 },
575 StateProj::Connected(body) => match ready!(body.poll_data(cx)) {
576 Some(Ok(result)) => {
577 this.event_parser.process_bytes(result)?;
578 continue;
579 }
580 Some(Err(e)) => {
581 if self.props.reconnect_opts.reconnect {
582 let duration = self
583 .as_mut()
584 .project()
585 .retry_strategy
586 .next_delay(Instant::now());
587 self.as_mut()
588 .project()
589 .state
590 .set(State::WaitingToReconnect(delay(duration, "reconnecting")));
591 }
592
593 if let Some(cause) = e.source() {
594 if let Some(downcast) = cause.downcast_ref::<std::io::Error>() {
595 if let std::io::ErrorKind::TimedOut = downcast.kind() {
596 return Poll::Ready(Some(Err(Error::TimedOut)));
597 }
598 }
599 } else {
600 return Poll::Ready(Some(Err(Error::HttpStream(Box::new(e)))));
601 }
602 }
603 None => {
604 let duration = self
605 .as_mut()
606 .project()
607 .retry_strategy
608 .next_delay(Instant::now());
609 self.as_mut()
610 .project()
611 .state
612 .set(State::WaitingToReconnect(delay(duration, "retrying")));
613
614 if self.event_parser.was_processing() {
615 return Poll::Ready(Some(Err(Error::UnexpectedEof)));
616 }
617 return Poll::Ready(Some(Err(Error::Eof)));
618 }
619 },
620 StateProj::WaitingToReconnect(delay) => {
621 ready!(delay.poll(cx));
622 info!("Reconnecting");
623 self.as_mut().project().state.set(State::New);
624 }
625 };
626 }
627 }
628}
629
630fn uri_from_header(maybe_header: &Option<HeaderValue>) -> Result<Uri> {
631 let header = maybe_header.as_ref().ok_or_else(|| {
632 Error::MalformedLocationHeader(Box::new(std::io::Error::new(
633 ErrorKind::NotFound,
634 "missing Location header",
635 )))
636 })?;
637
638 let header_string = header
639 .to_str()
640 .map_err(|e| Error::MalformedLocationHeader(Box::new(e)))?;
641
642 header_string
643 .parse::<Uri>()
644 .map_err(|e| Error::MalformedLocationHeader(Box::new(e)))
645}
646
647fn delay(dur: Duration, description: &str) -> Sleep {
648 info!("Waiting {:?} before {}", dur, description);
649 tokio::time::sleep(dur)
650}
651
652mod private {
653 use crate::client::ClientImpl;
654
655 pub trait Sealed {}
656 impl<C> Sealed for ClientImpl<C> {}
657}
658
#[cfg(test)]
mod tests {
    use std::{pin::pin, str::FromStr, time::Duration};

    use futures::TryStreamExt;
    use hyper::{client::HttpConnector, http::HeaderValue, Body, HeaderMap, Request, Uri};
    use hyper_timeout::TimeoutConnector;
    use test_case::test_case;
    use tokio::time::timeout;

    use crate::{
        client::{RequestProps, State},
        ClientBuilder, ReconnectOptionsBuilder, ReconnectingRequest,
    };

    /// Host that should never resolve, used to force a connection failure.
    const INVALID_URI: &str = "http://mycrazyunexsistenturl.invaliddomainext";

    #[test_case("user", "pass", "dXNlcjpwYXNz")]
    #[test_case("user1", "password123", "dXNlcjE6cGFzc3dvcmQxMjM=")]
    #[test_case("user2", "", "dXNlcjI6")]
    #[test_case("user@name", "pass#word!", "dXNlckBuYW1lOnBhc3Mjd29yZCE=")]
    #[test_case("user3", "my pass", "dXNlcjM6bXkgcGFzcw==")]
    #[test_case(
        "weird@-/:stuff",
        "goes@-/:here",
        "d2VpcmRALS86c3R1ZmY6Z29lc0AtLzpoZXJl"
    )]
    fn basic_auth_generates_correct_headers(username: &str, password: &str, expected: &str) {
        let builder = ClientBuilder::for_url("http://example.com")
            .expect("failed to build client")
            .basic_auth(username, password)
            .expect("failed to add authentication");

        let wanted = HeaderValue::from_str(&format!("Basic {expected}"))
            .expect("unable to create expected header");

        assert_eq!(Some(&wanted), builder.headers.get("Authorization"));
    }

    /// hyper client whose connector applies `limit` to connect, read and write.
    fn timeout_http_client(limit: Duration) -> hyper::Client<TimeoutConnector<HttpConnector>> {
        let mut connector = TimeoutConnector::new(HttpConnector::new());
        connector.set_connect_timeout(Some(limit));
        connector.set_read_timeout(Some(limit));
        connector.set_write_timeout(Some(limit));
        hyper::Client::builder().build::<_, Body>(connector)
    }

    #[test_case(INVALID_URI, false, |state| matches!(state, State::StreamClosed))]
    #[test_case(INVALID_URI, true, |state| matches!(state, State::WaitingToReconnect(_)))]
    #[tokio::test]
    async fn initial_connection(uri: &str, retry_initial: bool, expected: fn(&State) -> bool) {
        let reconnect_opts = ReconnectOptionsBuilder::new(false)
            .backoff_factor(1)
            .delay(Duration::from_secs(1))
            .retry_initial(retry_initial)
            .build();

        let props = RequestProps {
            url: Uri::from_str(uri).unwrap(),
            headers: HeaderMap::new(),
            method: "GET".to_string(),
            body: None,
            reconnect_opts,
            max_redirects: 10,
        };

        let mut request =
            ReconnectingRequest::new(timeout_http_client(Duration::from_secs(1)), props, None);

        // Skip `New`: install an in-flight initial request directly.
        let in_flight = request.http.request(
            Request::builder()
                .method("GET")
                .uri(uri)
                .body(Body::empty())
                .unwrap(),
        );
        request.state = State::Connecting {
            retry: request.props.reconnect_opts.retry_initial,
            resp: in_flight,
        };

        let mut request = pin!(request);

        // Only the state left behind matters, not the item (or timeout) produced.
        let _ = timeout(Duration::from_millis(500), request.try_next()).await;

        assert!(expected(&request.state));
    }

    #[test_case(false, |state| matches!(state, State::StreamClosed))]
    #[test_case(true, |state| matches!(state, State::WaitingToReconnect(_)))]
    #[tokio::test]
    async fn initial_connection_mocked_server(retry_initial: bool, expected: fn(&State) -> bool) {
        let mut server = mockito::Server::new_async().await;
        let _not_found = server
            .mock("GET", "/")
            .with_status(404)
            .create_async()
            .await;

        initial_connection(&server.url(), retry_initial, expected).await;
    }
}
767}