eventsource_client/
client.rs

1use base64::prelude::*;
2
3use futures::{ready, Stream};
4use hyper::{
5    body::HttpBody,
6    client::{
7        connect::{Connect, Connection},
8        ResponseFuture,
9    },
10    header::{HeaderMap, HeaderName, HeaderValue},
11    service::Service,
12    Body, Request, Uri,
13};
14use log::{debug, info, trace, warn};
15use pin_project::pin_project;
16use std::{
17    boxed,
18    fmt::{self, Debug, Formatter},
19    future::Future,
20    io::ErrorKind,
21    pin::{pin, Pin},
22    str::FromStr,
23    task::{Context, Poll},
24    time::{Duration, Instant},
25};
26
27use tokio::{
28    io::{AsyncRead, AsyncWrite},
29    time::Sleep,
30};
31
32use crate::{
33    config::ReconnectOptions,
34    response::{ErrorBody, Response},
35};
36use crate::{
37    error::{Error, Result},
38    event_parser::ConnectionDetails,
39};
40
41use hyper::client::HttpConnector;
42use hyper_timeout::TimeoutConnector;
43
44use crate::event_parser::EventParser;
45use crate::event_parser::SSE;
46
47use crate::retry::{BackoffRetry, RetryStrategy};
48use std::error::Error as StdError;
49
50#[cfg(feature = "rustls")]
51use hyper_rustls::HttpsConnectorBuilder;
52
/// Boxed, thread-safe error type accepted from connector implementations.
type BoxError = Box<dyn StdError + Send + Sync>;
54
55/// Represents a [`Pin`]'d [`Send`] + [`Sync`] stream, returned by [`Client`]'s stream method.
56pub type BoxStream<T> = Pin<boxed::Box<dyn Stream<Item = T> + Send + Sync>>;
57
58/// Client is the Server-Sent-Events interface.
59/// This trait is sealed and cannot be implemented for types outside this crate.
60pub trait Client: Send + Sync + private::Sealed {
61    fn stream(&self) -> BoxStream<Result<SSE>>;
62}
63
/*
 * TODO remove debug output
 * TODO specify a list of statuses that should not be retried (e.g. 204)
 */
68
/// Default number of redirects the client follows before giving up.
///
/// Override it per client with [ClientBuilder::redirect_limit].
pub const DEFAULT_REDIRECT_LIMIT: u32 = 16;
72
73/// ClientBuilder provides a series of builder methods to easily construct a [`Client`].
74pub struct ClientBuilder {
75    url: Uri,
76    headers: HeaderMap,
77    reconnect_opts: ReconnectOptions,
78    connect_timeout: Option<Duration>,
79    read_timeout: Option<Duration>,
80    write_timeout: Option<Duration>,
81    last_event_id: Option<String>,
82    method: String,
83    body: Option<String>,
84    max_redirects: Option<u32>,
85}
86
87impl ClientBuilder {
88    /// Create a builder for a given URL.
89    pub fn for_url(url: &str) -> Result<ClientBuilder> {
90        let url = url
91            .parse()
92            .map_err(|e| Error::InvalidParameter(Box::new(e)))?;
93
94        let mut header_map = HeaderMap::new();
95        header_map.insert("Accept", HeaderValue::from_static("text/event-stream"));
96        header_map.insert("Cache-Control", HeaderValue::from_static("no-cache"));
97
98        Ok(ClientBuilder {
99            url,
100            headers: header_map,
101            reconnect_opts: ReconnectOptions::default(),
102            connect_timeout: None,
103            read_timeout: None,
104            write_timeout: None,
105            last_event_id: None,
106            method: String::from("GET"),
107            max_redirects: None,
108            body: None,
109        })
110    }
111
112    /// Set the request method used for the initial connection to the SSE endpoint.
113    pub fn method(mut self, method: String) -> ClientBuilder {
114        self.method = method;
115        self
116    }
117
118    /// Set the request body used for the initial connection to the SSE endpoint.
119    pub fn body(mut self, body: String) -> ClientBuilder {
120        self.body = Some(body);
121        self
122    }
123
124    /// Set the last event id for a stream when it is created. If it is set, it will be sent to the
125    /// server in case it can replay missed events.
126    pub fn last_event_id(mut self, last_event_id: String) -> ClientBuilder {
127        self.last_event_id = Some(last_event_id);
128        self
129    }
130
131    /// Set a HTTP header on the SSE request.
132    pub fn header(mut self, name: &str, value: &str) -> Result<ClientBuilder> {
133        let name = HeaderName::from_str(name).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
134
135        let value =
136            HeaderValue::from_str(value).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
137
138        self.headers.insert(name, value);
139        Ok(self)
140    }
141
142    /// Set the Authorization header with the calculated basic authentication value.
143    pub fn basic_auth(self, username: &str, password: &str) -> Result<ClientBuilder> {
144        let auth = format!("{}:{}", username, password);
145        let encoded = BASE64_STANDARD.encode(auth);
146        let value = format!("Basic {}", encoded);
147
148        self.header("Authorization", &value)
149    }
150
151    /// Set a connect timeout for the underlying connection. There is no connect timeout by
152    /// default.
153    pub fn connect_timeout(mut self, connect_timeout: Duration) -> ClientBuilder {
154        self.connect_timeout = Some(connect_timeout);
155        self
156    }
157
158    /// Set a read timeout for the underlying connection. There is no read timeout by default.
159    pub fn read_timeout(mut self, read_timeout: Duration) -> ClientBuilder {
160        self.read_timeout = Some(read_timeout);
161        self
162    }
163
164    /// Set a write timeout for the underlying connection. There is no write timeout by default.
165    pub fn write_timeout(mut self, write_timeout: Duration) -> ClientBuilder {
166        self.write_timeout = Some(write_timeout);
167        self
168    }
169
170    /// Configure the client's reconnect behaviour according to the supplied
171    /// [`ReconnectOptions`].
172    ///
173    /// [`ReconnectOptions`]: struct.ReconnectOptions.html
174    pub fn reconnect(mut self, opts: ReconnectOptions) -> ClientBuilder {
175        self.reconnect_opts = opts;
176        self
177    }
178
179    /// Customize the client's following behavior when served a redirect.
180    /// To disable following redirects, pass `0`.
181    /// By default, the limit is [`DEFAULT_REDIRECT_LIMIT`].
182    pub fn redirect_limit(mut self, limit: u32) -> ClientBuilder {
183        self.max_redirects = Some(limit);
184        self
185    }
186
187    /// Build with a specific client connector.
188    pub fn build_with_conn<C>(self, conn: C) -> impl Client
189    where
190        C: Service<Uri> + Clone + Send + Sync + 'static,
191        C::Response: Connection + AsyncRead + AsyncWrite + Send + Unpin,
192        C::Future: Send + 'static,
193        C::Error: Into<BoxError>,
194    {
195        let mut connector = TimeoutConnector::new(conn);
196        connector.set_connect_timeout(self.connect_timeout);
197        connector.set_read_timeout(self.read_timeout);
198        connector.set_write_timeout(self.write_timeout);
199
200        let client = hyper::Client::builder().build::<_, hyper::Body>(connector);
201
202        ClientImpl {
203            http: client,
204            request_props: RequestProps {
205                url: self.url,
206                headers: self.headers,
207                method: self.method,
208                body: self.body,
209                reconnect_opts: self.reconnect_opts,
210                max_redirects: self.max_redirects.unwrap_or(DEFAULT_REDIRECT_LIMIT),
211            },
212            last_event_id: self.last_event_id,
213        }
214    }
215
216    /// Build with an HTTP client connector.
217    pub fn build_http(self) -> impl Client {
218        self.build_with_conn(HttpConnector::new())
219    }
220
221    #[cfg(feature = "rustls")]
222    /// Build with an HTTPS client connector, using the OS root certificate store.
223    pub fn build(self) -> impl Client {
224        let conn = HttpsConnectorBuilder::new()
225            .with_native_roots()
226            .https_or_http()
227            .enable_http1()
228            .enable_http2()
229            .build();
230
231        self.build_with_conn(conn)
232    }
233
234    /// Build with the given [`hyper::client::Client`].
235    pub fn build_with_http_client<C>(self, http: hyper::Client<C>) -> impl Client
236    where
237        C: Connect + Clone + Send + Sync + 'static,
238    {
239        ClientImpl {
240            http,
241            request_props: RequestProps {
242                url: self.url,
243                headers: self.headers,
244                method: self.method,
245                body: self.body,
246                reconnect_opts: self.reconnect_opts,
247                max_redirects: self.max_redirects.unwrap_or(DEFAULT_REDIRECT_LIMIT),
248            },
249            last_event_id: self.last_event_id,
250        }
251    }
252}
253
254#[derive(Clone)]
255struct RequestProps {
256    url: Uri,
257    headers: HeaderMap,
258    method: String,
259    body: Option<String>,
260    reconnect_opts: ReconnectOptions,
261    max_redirects: u32,
262}
263
264/// A client implementation that connects to a server using the Server-Sent Events protocol
265/// and consumes the event stream indefinitely.
266/// Can be parameterized with different hyper Connectors, such as HTTP or HTTPS.
267struct ClientImpl<C> {
268    http: hyper::Client<C>,
269    request_props: RequestProps,
270    last_event_id: Option<String>,
271}
272
273impl<C> Client for ClientImpl<C>
274where
275    C: Connect + Clone + Send + Sync + 'static,
276{
277    /// Connect to the server and begin consuming the stream. Produces a
278    /// [`Stream`] of [`Event`](crate::Event)s wrapped in [`Result`].
279    ///
280    /// Do not use the stream after it returned an error!
281    ///
282    /// After the first successful connection, the stream will
283    /// reconnect for retryable errors.
284    fn stream(&self) -> BoxStream<Result<SSE>> {
285        Box::pin(ReconnectingRequest::new(
286            self.http.clone(),
287            self.request_props.clone(),
288            self.last_event_id.clone(),
289        ))
290    }
291}
292
293#[allow(clippy::large_enum_variant)] // false positive
294#[pin_project(project = StateProj)]
295enum State {
296    New,
297    Connecting {
298        retry: bool,
299        #[pin]
300        resp: ResponseFuture,
301    },
302    Connected(#[pin] hyper::Body),
303    WaitingToReconnect(#[pin] Sleep),
304    FollowingRedirect(Option<HeaderValue>),
305    StreamClosed,
306}
307
308impl State {
309    fn name(&self) -> &'static str {
310        match self {
311            State::New => "new",
312            State::Connecting { retry: false, .. } => "connecting(no-retry)",
313            State::Connecting { retry: true, .. } => "connecting(retry)",
314            State::Connected(_) => "connected",
315            State::WaitingToReconnect(_) => "waiting-to-reconnect",
316            State::FollowingRedirect(_) => "following-redirect",
317            State::StreamClosed => "closed",
318        }
319    }
320}
321
322impl Debug for State {
323    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
324        write!(f, "{}", self.name())
325    }
326}
327
328#[must_use = "streams do nothing unless polled"]
329#[pin_project]
330pub struct ReconnectingRequest<C> {
331    http: hyper::Client<C>,
332    props: RequestProps,
333    #[pin]
334    state: State,
335    retry_strategy: Box<dyn RetryStrategy + Send + Sync>,
336    current_url: Uri,
337    redirect_count: u32,
338    event_parser: EventParser,
339    last_event_id: Option<String>,
340    #[pin]
341    initial_connection: bool,
342}
343
344impl<C> ReconnectingRequest<C> {
345    fn new(
346        http: hyper::Client<C>,
347        props: RequestProps,
348        last_event_id: Option<String>,
349    ) -> ReconnectingRequest<C> {
350        let reconnect_delay = props.reconnect_opts.delay;
351        let delay_max = props.reconnect_opts.delay_max;
352        let backoff_factor = props.reconnect_opts.backoff_factor;
353
354        let url = props.url.clone();
355        ReconnectingRequest {
356            props,
357            http,
358            state: State::New,
359            retry_strategy: Box::new(BackoffRetry::new(
360                reconnect_delay,
361                delay_max,
362                backoff_factor,
363                true,
364            )),
365            redirect_count: 0,
366            current_url: url,
367            event_parser: EventParser::new(),
368            last_event_id,
369            initial_connection: true,
370        }
371    }
372
373    fn send_request(&self) -> Result<ResponseFuture>
374    where
375        C: Connect + Clone + Send + Sync + 'static,
376    {
377        let mut request_builder = Request::builder()
378            .method(self.props.method.as_str())
379            .uri(&self.current_url);
380
381        for (name, value) in &self.props.headers {
382            request_builder = request_builder.header(name, value);
383        }
384
385        if let Some(id) = self.last_event_id.as_ref() {
386            if !id.is_empty() {
387                let id_as_header =
388                    HeaderValue::from_str(id).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
389
390                request_builder = request_builder.header("last-event-id", id_as_header);
391            }
392        }
393
394        let body = match &self.props.body {
395            Some(body) => Body::from(body.to_string()),
396            None => Body::empty(),
397        };
398
399        let request = request_builder
400            .body(body)
401            .map_err(|e| Error::InvalidParameter(Box::new(e)))?;
402
403        Ok(self.http.request(request))
404    }
405
406    fn reset_redirects(self: Pin<&mut Self>) {
407        let url = self.props.url.clone();
408        let this = self.project();
409        *this.current_url = url;
410        *this.redirect_count = 0;
411    }
412
413    fn increment_redirect_counter(self: Pin<&mut Self>) -> bool {
414        if self.redirect_count == self.props.max_redirects {
415            return false;
416        }
417        *self.project().redirect_count += 1;
418        true
419    }
420}
421
422impl<C> Stream for ReconnectingRequest<C>
423where
424    C: Connect + Clone + Send + Sync + 'static,
425{
426    type Item = Result<SSE>;
427
428    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
429        trace!("ReconnectingRequest::poll({:?})", &self.state);
430
431        loop {
432            let this = self.as_mut().project();
433            if let Some(event) = this.event_parser.get_event() {
434                return match event {
435                    SSE::Connected(_) => Poll::Ready(Some(Ok(event))),
436                    SSE::Event(ref evt) => {
437                        this.last_event_id.clone_from(&evt.id);
438
439                        if let Some(retry) = evt.retry {
440                            this.retry_strategy
441                                .change_base_delay(Duration::from_millis(retry));
442                        }
443                        Poll::Ready(Some(Ok(event)))
444                    }
445                    SSE::Comment(_) => Poll::Ready(Some(Ok(event))),
446                };
447            }
448
449            trace!("ReconnectingRequest::poll loop({:?})", &this.state);
450
451            let state = this.state.project();
452            match state {
453                StateProj::StreamClosed => return Poll::Ready(None),
454                // New immediately transitions to Connecting, and exists only
455                // to ensure that we only connect when polled.
456                StateProj::New => {
457                    *self.as_mut().project().event_parser = EventParser::new();
458                    match self.send_request() {
459                        Ok(resp) => {
460                            let retry = if self.initial_connection {
461                                self.props.reconnect_opts.retry_initial
462                            } else {
463                                self.props.reconnect_opts.reconnect
464                            };
465                            self.as_mut()
466                                .project()
467                                .state
468                                .set(State::Connecting { resp, retry })
469                        }
470                        Err(e) => {
471                            // This error seems to be unrecoverable. So we should just shut down the
472                            // stream.
473                            self.as_mut().project().state.set(State::StreamClosed);
474                            return Poll::Ready(Some(Err(e)));
475                        }
476                    }
477                }
478                StateProj::Connecting { retry, resp } => match ready!(resp.poll(cx)) {
479                    Ok(resp) => {
480                        debug!("HTTP response: {:#?}", resp);
481
482                        if resp.status().is_success() {
483                            self.as_mut().project().retry_strategy.reset(Instant::now());
484                            self.as_mut().reset_redirects();
485
486                            let status = resp.status();
487                            let headers = resp.headers().clone();
488
489                            self.as_mut()
490                                .project()
491                                .state
492                                .set(State::Connected(resp.into_body()));
493                            self.as_mut().project().initial_connection.set(false);
494
495                            return Poll::Ready(Some(Ok(SSE::Connected(ConnectionDetails::new(
496                                Response::new(status, headers),
497                            )))));
498                        }
499
500                        if resp.status() == 301 || resp.status() == 307 {
501                            debug!("got redirected ({})", resp.status());
502
503                            if self.as_mut().increment_redirect_counter() {
504                                debug!("following redirect {}", self.redirect_count);
505
506                                self.as_mut().project().state.set(State::FollowingRedirect(
507                                    resp.headers().get(hyper::header::LOCATION).cloned(),
508                                ));
509                                continue;
510                            } else {
511                                debug!("redirect limit reached ({})", self.props.max_redirects);
512
513                                self.as_mut().project().state.set(State::StreamClosed);
514                                return Poll::Ready(Some(Err(Error::MaxRedirectLimitReached(
515                                    self.props.max_redirects,
516                                ))));
517                            }
518                        }
519
520                        let error = Error::UnexpectedResponse(
521                            Response::new(resp.status(), resp.headers().clone()),
522                            ErrorBody::new(resp.into_body()),
523                        );
524
525                        if !*retry {
526                            self.as_mut().project().state.set(State::StreamClosed);
527                            return Poll::Ready(Some(Err(error)));
528                        }
529
530                        self.as_mut().reset_redirects();
531
532                        let duration = self
533                            .as_mut()
534                            .project()
535                            .retry_strategy
536                            .next_delay(Instant::now());
537
538                        self.as_mut()
539                            .project()
540                            .state
541                            .set(State::WaitingToReconnect(delay(duration, "retrying")));
542
543                        return Poll::Ready(Some(Err(error)));
544                    }
545                    Err(e) => {
546                        // This happens when the server is unreachable, e.g. connection refused.
547                        warn!("request returned an error: {}", e);
548                        if !*retry {
549                            self.as_mut().project().state.set(State::StreamClosed);
550                            return Poll::Ready(Some(Err(Error::HttpStream(Box::new(e)))));
551                        }
552
553                        let duration = self
554                            .as_mut()
555                            .project()
556                            .retry_strategy
557                            .next_delay(Instant::now());
558
559                        self.as_mut()
560                            .project()
561                            .state
562                            .set(State::WaitingToReconnect(delay(duration, "retrying")));
563                    }
564                },
565                StateProj::FollowingRedirect(maybe_header) => match uri_from_header(maybe_header) {
566                    Ok(uri) => {
567                        *self.as_mut().project().current_url = uri;
568                        self.as_mut().project().state.set(State::New);
569                    }
570                    Err(e) => {
571                        self.as_mut().project().state.set(State::StreamClosed);
572                        return Poll::Ready(Some(Err(e)));
573                    }
574                },
575                StateProj::Connected(body) => match ready!(body.poll_data(cx)) {
576                    Some(Ok(result)) => {
577                        this.event_parser.process_bytes(result)?;
578                        continue;
579                    }
580                    Some(Err(e)) => {
581                        if self.props.reconnect_opts.reconnect {
582                            let duration = self
583                                .as_mut()
584                                .project()
585                                .retry_strategy
586                                .next_delay(Instant::now());
587                            self.as_mut()
588                                .project()
589                                .state
590                                .set(State::WaitingToReconnect(delay(duration, "reconnecting")));
591                        }
592
593                        if let Some(cause) = e.source() {
594                            if let Some(downcast) = cause.downcast_ref::<std::io::Error>() {
595                                if let std::io::ErrorKind::TimedOut = downcast.kind() {
596                                    return Poll::Ready(Some(Err(Error::TimedOut)));
597                                }
598                            }
599                        } else {
600                            return Poll::Ready(Some(Err(Error::HttpStream(Box::new(e)))));
601                        }
602                    }
603                    None => {
604                        let duration = self
605                            .as_mut()
606                            .project()
607                            .retry_strategy
608                            .next_delay(Instant::now());
609                        self.as_mut()
610                            .project()
611                            .state
612                            .set(State::WaitingToReconnect(delay(duration, "retrying")));
613
614                        if self.event_parser.was_processing() {
615                            return Poll::Ready(Some(Err(Error::UnexpectedEof)));
616                        }
617                        return Poll::Ready(Some(Err(Error::Eof)));
618                    }
619                },
620                StateProj::WaitingToReconnect(delay) => {
621                    ready!(delay.poll(cx));
622                    info!("Reconnecting");
623                    self.as_mut().project().state.set(State::New);
624                }
625            };
626        }
627    }
628}
629
630fn uri_from_header(maybe_header: &Option<HeaderValue>) -> Result<Uri> {
631    let header = maybe_header.as_ref().ok_or_else(|| {
632        Error::MalformedLocationHeader(Box::new(std::io::Error::new(
633            ErrorKind::NotFound,
634            "missing Location header",
635        )))
636    })?;
637
638    let header_string = header
639        .to_str()
640        .map_err(|e| Error::MalformedLocationHeader(Box::new(e)))?;
641
642    header_string
643        .parse::<Uri>()
644        .map_err(|e| Error::MalformedLocationHeader(Box::new(e)))
645}
646
647fn delay(dur: Duration, description: &str) -> Sleep {
648    info!("Waiting {:?} before {}", dur, description);
649    tokio::time::sleep(dur)
650}
651
652mod private {
653    use crate::client::ClientImpl;
654
655    pub trait Sealed {}
656    impl<C> Sealed for ClientImpl<C> {}
657}
658
#[cfg(test)]
mod tests {
    use std::{pin::pin, str::FromStr, time::Duration};

    use futures::TryStreamExt;
    use hyper::{client::HttpConnector, http::HeaderValue, Body, HeaderMap, Request, Uri};
    use hyper_timeout::TimeoutConnector;
    use test_case::test_case;
    use tokio::time::timeout;

    use crate::{
        client::{RequestProps, State},
        ClientBuilder, ReconnectOptionsBuilder, ReconnectingRequest,
    };

    // A host that can never resolve. The `.invalid` TLD is reserved by RFC 6761, so lookups
    // are guaranteed to fail, and resolvers are expected to fail them immediately.
    const INVALID_URI: &str = "http://mycrazyunexsistenturl.invalid";

    // `basic_auth` must set `Authorization: Basic <base64(user:pass)>`.
    #[test_case("user", "pass", "dXNlcjpwYXNz")]
    #[test_case("user1", "password123", "dXNlcjE6cGFzc3dvcmQxMjM=")]
    #[test_case("user2", "", "dXNlcjI6")]
    #[test_case("user@name", "pass#word!", "dXNlckBuYW1lOnBhc3Mjd29yZCE=")]
    #[test_case("user3", "my pass", "dXNlcjM6bXkgcGFzcw==")]
    #[test_case(
        "weird@-/:stuff",
        "goes@-/:here",
        "d2VpcmRALS86c3R1ZmY6Z29lc0AtLzpoZXJl"
    )]
    fn basic_auth_generates_correct_headers(username: &str, password: &str, expected: &str) {
        let builder = ClientBuilder::for_url("http://example.com")
            .expect("failed to build client")
            .basic_auth(username, password)
            .expect("failed to add authentication");

        let actual = builder.headers.get("Authorization");
        let expected = HeaderValue::from_str(format!("Basic {}", expected).as_str())
            .expect("unable to create expected header");

        assert_eq!(Some(&expected), actual);
    }

    // Starts a request directly in the `Connecting` state against `uri` and polls it for one
    // item, bounded by a 500 ms timeout. Then it checks the resulting state with `expected`.
    #[test_case(INVALID_URI, false, |state| matches!(state, State::StreamClosed))]
    #[test_case(INVALID_URI, true, |state| matches!(state, State::WaitingToReconnect(_)))]
    #[tokio::test]
    async fn initial_connection(uri: &str, retry_initial: bool, expected: fn(&State) -> bool) {
        let default_timeout = Some(Duration::from_secs(1));
        let conn = HttpConnector::new();
        let mut connector = TimeoutConnector::new(conn);
        connector.set_connect_timeout(default_timeout);
        connector.set_read_timeout(default_timeout);
        connector.set_write_timeout(default_timeout);

        let reconnect_opts = ReconnectOptionsBuilder::new(false)
            .backoff_factor(1)
            .delay(Duration::from_secs(1))
            .retry_initial(retry_initial)
            .build();

        let http = hyper::Client::builder().build::<_, hyper::Body>(connector);
        let req_props = RequestProps {
            url: Uri::from_str(uri).unwrap(),
            headers: HeaderMap::new(),
            method: "GET".to_string(),
            body: None,
            reconnect_opts,
            max_redirects: 10,
        };

        let mut reconnecting_request = ReconnectingRequest::new(http, req_props, None);

        // sets initial state
        let resp = reconnecting_request.http.request(
            Request::builder()
                .method("GET")
                .uri(uri)
                .body(Body::empty())
                .unwrap(),
        );

        reconnecting_request.state = State::Connecting {
            retry: reconnecting_request.props.reconnect_opts.retry_initial,
            resp,
        };

        let mut reconnecting_request = pin!(reconnecting_request);

        timeout(Duration::from_millis(500), reconnecting_request.try_next())
            .await
            .ok();

        assert!(expected(&reconnecting_request.state));
    }

    // Same scenario as above, but the server is reachable and answers 404.
    #[test_case(false, |state| matches!(state, State::StreamClosed))]
    #[test_case(true, |state| matches!(state, State::WaitingToReconnect(_)))]
    #[tokio::test]
    async fn initial_connection_mocked_server(retry_initial: bool, expected: fn(&State) -> bool) {
        let mut mock_server = mockito::Server::new_async().await;
        let _mock = mock_server
            .mock("GET", "/")
            .with_status(404)
            .create_async()
            .await;

        initial_connection(&mock_server.url(), retry_initial, expected).await;
    }
}