Skip to main content

eventsource_client/
client.rs

1use base64::prelude::*;
2
3use futures::{ready, Stream};
4use http::{HeaderMap, HeaderName, HeaderValue, Request, Uri};
5use log::{debug, info, trace, warn};
6use pin_project::pin_project;
7use std::{
8    boxed,
9    fmt::{self, Debug, Formatter},
10    future::Future,
11    io::ErrorKind,
12    pin::Pin,
13    str::FromStr,
14    sync::Arc,
15    task::{Context, Poll},
16    time::{Duration, Instant},
17};
18
19use tokio::time::Sleep;
20
21use crate::{
22    config::ReconnectOptions,
23    response::{ErrorBody, Response},
24};
25use crate::{
26    error::{Error, Result},
27    event_parser::ConnectionDetails,
28};
29use launchdarkly_sdk_transport::{ByteStream, HttpTransport, ResponseFuture};
30
31use crate::event_parser::EventParser;
32use crate::event_parser::SSE;
33
34use crate::retry::{BackoffRetry, RetryStrategy};
35use std::error::Error as StdError;
36
37/// Represents a [`Pin`]'d [`Send`] + [`Sync`] stream, returned by [`Client`]'s stream method.
38pub type BoxStream<T> = Pin<boxed::Box<dyn Stream<Item = T> + Send + Sync>>;
39
40/// Client is the Server-Sent-Events interface.
41/// This trait is sealed and cannot be implemented for types outside this crate.
42pub trait Client: Send + Sync + private::Sealed {
43    fn stream(&self) -> BoxStream<Result<SSE>>;
44}
45
/*
 * TODO remove debug output
 * TODO specify list of statuses that should not be retried (e.g. 204)
 */
50
/// How many redirects a client follows before giving up, unless a different limit
/// is configured through [ClientBuilder::redirect_limit].
pub const DEFAULT_REDIRECT_LIMIT: u32 = 16;
54
55/// ClientBuilder provides a series of builder methods to easily construct a [`Client`].
56pub struct ClientBuilder {
57    url: Uri,
58    headers: HeaderMap,
59    reconnect_opts: ReconnectOptions,
60    last_event_id: Option<String>,
61    method: String,
62    body: Option<String>,
63    max_redirects: Option<u32>,
64}
65
66impl ClientBuilder {
67    /// Create a builder for a given URL.
68    pub fn for_url(url: &str) -> Result<ClientBuilder> {
69        let url = url
70            .parse()
71            .map_err(|e| Error::InvalidParameter(Box::new(e)))?;
72
73        let mut header_map = HeaderMap::new();
74        header_map.insert("Accept", HeaderValue::from_static("text/event-stream"));
75        header_map.insert("Cache-Control", HeaderValue::from_static("no-cache"));
76
77        Ok(ClientBuilder {
78            url,
79            headers: header_map,
80            reconnect_opts: ReconnectOptions::default(),
81            last_event_id: None,
82            method: String::from("GET"),
83            max_redirects: None,
84            body: None,
85        })
86    }
87
88    /// Set the request method used for the initial connection to the SSE endpoint.
89    pub fn method(mut self, method: String) -> ClientBuilder {
90        self.method = method;
91        self
92    }
93
94    /// Set the request body used for the initial connection to the SSE endpoint.
95    pub fn body(mut self, body: String) -> ClientBuilder {
96        self.body = Some(body);
97        self
98    }
99
100    /// Set the last event id for a stream when it is created. If it is set, it will be sent to the
101    /// server in case it can replay missed events.
102    pub fn last_event_id(mut self, last_event_id: String) -> ClientBuilder {
103        self.last_event_id = Some(last_event_id);
104        self
105    }
106
107    /// Set a HTTP header on the SSE request.
108    pub fn header(mut self, name: &str, value: &str) -> Result<ClientBuilder> {
109        let name = HeaderName::from_str(name).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
110
111        let value =
112            HeaderValue::from_str(value).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
113
114        self.headers.insert(name, value);
115        Ok(self)
116    }
117
118    /// Set the Authorization header with the calculated basic authentication value.
119    pub fn basic_auth(self, username: &str, password: &str) -> Result<ClientBuilder> {
120        let auth = format!("{username}:{password}");
121        let encoded = BASE64_STANDARD.encode(auth);
122        let value = format!("Basic {encoded}");
123
124        self.header("Authorization", &value)
125    }
126
127    /// Configure the client's reconnect behaviour according to the supplied
128    /// [`ReconnectOptions`].
129    ///
130    /// [`ReconnectOptions`]: struct.ReconnectOptions.html
131    pub fn reconnect(mut self, opts: ReconnectOptions) -> ClientBuilder {
132        self.reconnect_opts = opts;
133        self
134    }
135
136    /// Customize the client's following behavior when served a redirect.
137    /// To disable following redirects, pass `0`.
138    /// By default, the limit is [`DEFAULT_REDIRECT_LIMIT`].
139    pub fn redirect_limit(mut self, limit: u32) -> ClientBuilder {
140        self.max_redirects = Some(limit);
141        self
142    }
143
144    /// Build a client with a custom HTTP transport implementation.
145    ///
146    /// # Arguments
147    ///
148    /// * `transport` - An implementation of the [`HttpTransport`] trait that will handle
149    ///   HTTP requests. See the `examples/` directory for reference implementations.
150    ///
151    /// # Example
152    ///
153    /// ```ignore
154    /// use eventsource_client::ClientBuilder;
155    ///
156    /// let transport = MyTransport::new();
157    /// let client = ClientBuilder::for_url("https://live-test-scores.herokuapp.com/scores")
158    ///     .expect("failed to create client builder")
159    ///     .build_with_transport(transport);
160    /// ```
161    pub fn build_with_transport<T>(self, transport: T) -> impl Client
162    where
163        T: HttpTransport,
164    {
165        ClientImpl {
166            transport: Arc::new(transport),
167            request_props: RequestProps {
168                url: self.url,
169                headers: self.headers,
170                method: self.method,
171                body: self.body,
172                reconnect_opts: self.reconnect_opts,
173                max_redirects: self.max_redirects.unwrap_or(DEFAULT_REDIRECT_LIMIT),
174            },
175            last_event_id: self.last_event_id,
176        }
177    }
178}
179
180#[derive(Clone)]
181struct RequestProps {
182    url: Uri,
183    headers: HeaderMap,
184    method: String,
185    body: Option<String>,
186    reconnect_opts: ReconnectOptions,
187    max_redirects: u32,
188}
189
190/// A client implementation that connects to a server using the Server-Sent Events protocol
191/// and consumes the event stream indefinitely.
192struct ClientImpl<T: HttpTransport> {
193    transport: Arc<T>,
194    request_props: RequestProps,
195    last_event_id: Option<String>,
196}
197
198impl<T: HttpTransport> Client for ClientImpl<T> {
199    /// Connect to the server and begin consuming the stream. Produces a
200    /// [`Stream`] of [`Event`](crate::Event)s wrapped in [`Result`].
201    ///
202    /// Do not use the stream after it returned an error!
203    ///
204    /// After the first successful connection, the stream will
205    /// reconnect for retryable errors.
206    fn stream(&self) -> BoxStream<Result<SSE>> {
207        Box::pin(ReconnectingRequest::new(
208            Arc::clone(&self.transport),
209            self.request_props.clone(),
210            self.last_event_id.clone(),
211        ))
212    }
213}
214
215#[allow(clippy::large_enum_variant)] // false positive
216#[pin_project(project = StateProj)]
217enum State {
218    New,
219    Connecting {
220        retry: bool,
221        #[pin]
222        resp: ResponseFuture,
223    },
224    Connected(#[pin] ByteStream),
225    WaitingToReconnect(#[pin] Sleep),
226    FollowingRedirect(Option<HeaderValue>),
227    StreamClosed,
228}
229
230impl State {
231    fn name(&self) -> &'static str {
232        match self {
233            State::New => "new",
234            State::Connecting { retry: false, .. } => "connecting(no-retry)",
235            State::Connecting { retry: true, .. } => "connecting(retry)",
236            State::Connected(_) => "connected",
237            State::WaitingToReconnect(_) => "waiting-to-reconnect",
238            State::FollowingRedirect(_) => "following-redirect",
239            State::StreamClosed => "closed",
240        }
241    }
242}
243
244impl Debug for State {
245    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
246        write!(f, "{}", self.name())
247    }
248}
249
250#[must_use = "streams do nothing unless polled"]
251#[pin_project]
252pub struct ReconnectingRequest<T: HttpTransport> {
253    transport: Arc<T>,
254    props: RequestProps,
255    #[pin]
256    state: State,
257    retry_strategy: Box<dyn RetryStrategy + Send + Sync>,
258    current_url: Uri,
259    redirect_count: u32,
260    event_parser: EventParser,
261    last_event_id: Option<String>,
262    #[pin]
263    initial_connection: bool,
264}
265
266impl<T: HttpTransport> ReconnectingRequest<T> {
267    fn new(
268        transport: Arc<T>,
269        props: RequestProps,
270        last_event_id: Option<String>,
271    ) -> ReconnectingRequest<T> {
272        let reconnect_delay = props.reconnect_opts.delay;
273        let delay_max = props.reconnect_opts.delay_max;
274        let backoff_factor = props.reconnect_opts.backoff_factor;
275
276        let url = props.url.clone();
277        ReconnectingRequest {
278            props,
279            transport,
280            state: State::New,
281            retry_strategy: Box::new(BackoffRetry::new(
282                reconnect_delay,
283                delay_max,
284                backoff_factor,
285                true,
286            )),
287            redirect_count: 0,
288            current_url: url,
289            event_parser: EventParser::new(),
290            last_event_id,
291            initial_connection: true,
292        }
293    }
294
295    fn send_request(&self) -> Result<ResponseFuture> {
296        let mut request_builder = Request::builder()
297            .method(self.props.method.as_str())
298            .uri(&self.current_url);
299
300        for (name, value) in &self.props.headers {
301            request_builder = request_builder.header(name, value);
302        }
303
304        if let Some(id) = self.last_event_id.as_ref() {
305            if !id.is_empty() {
306                let id_as_header =
307                    HeaderValue::from_str(id).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
308
309                request_builder = request_builder.header("last-event-id", id_as_header);
310            }
311        }
312
313        // Include the request body if set. Most SSE requests use GET and will have None,
314        // but some implementations (e.g., using REPORT method) may include a body.
315        let request = request_builder
316            .body(self.props.body.clone().map(|b| b.into()))
317            .map_err(|e| Error::InvalidParameter(Box::new(e)))?;
318
319        Ok(self.transport.request(request))
320    }
321
322    fn reset_redirects(self: Pin<&mut Self>) {
323        let url = self.props.url.clone();
324        let this = self.project();
325        *this.current_url = url;
326        *this.redirect_count = 0;
327    }
328
329    fn increment_redirect_counter(self: Pin<&mut Self>) -> bool {
330        if self.redirect_count == self.props.max_redirects {
331            return false;
332        }
333        *self.project().redirect_count += 1;
334        true
335    }
336}
337
338impl<T: HttpTransport> Stream for ReconnectingRequest<T> {
339    type Item = Result<SSE>;
340
341    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
342        trace!("ReconnectingRequest::poll({:?})", &self.state);
343
344        loop {
345            let this = self.as_mut().project();
346            if let Some(event) = this.event_parser.get_event() {
347                return match event {
348                    SSE::Connected(_) => Poll::Ready(Some(Ok(event))),
349                    SSE::Event(ref evt) => {
350                        this.last_event_id.clone_from(&evt.id);
351
352                        if let Some(retry) = evt.retry {
353                            this.retry_strategy
354                                .change_base_delay(Duration::from_millis(retry));
355                        }
356                        Poll::Ready(Some(Ok(event)))
357                    }
358                    SSE::Comment(_) => Poll::Ready(Some(Ok(event))),
359                };
360            }
361
362            trace!("ReconnectingRequest::poll loop({:?})", &this.state);
363
364            let state = this.state.project();
365            match state {
366                StateProj::StreamClosed => return Poll::Ready(None),
367                // New immediately transitions to Connecting, and exists only
368                // to ensure that we only connect when polled.
369                StateProj::New => {
370                    *self.as_mut().project().event_parser = EventParser::new();
371                    match self.send_request() {
372                        Ok(resp) => {
373                            let retry = if self.initial_connection {
374                                self.props.reconnect_opts.retry_initial
375                            } else {
376                                self.props.reconnect_opts.reconnect
377                            };
378                            self.as_mut()
379                                .project()
380                                .state
381                                .set(State::Connecting { resp, retry })
382                        }
383                        Err(e) => {
384                            // This error seems to be unrecoverable. So we should just shut down the
385                            // stream.
386                            self.as_mut().project().state.set(State::StreamClosed);
387                            return Poll::Ready(Some(Err(e)));
388                        }
389                    }
390                }
391                StateProj::Connecting { retry, resp } => match ready!(resp.poll(cx)) {
392                    Ok(resp) => {
393                        debug!(
394                            "HTTP response status: {}, headers: {:?}",
395                            resp.status(),
396                            resp.headers()
397                        );
398
399                        if resp.status().is_success() {
400                            self.as_mut().project().retry_strategy.reset(Instant::now());
401                            self.as_mut().reset_redirects();
402
403                            let status = resp.status();
404                            let headers = resp.headers().clone();
405
406                            self.as_mut()
407                                .project()
408                                .state
409                                .set(State::Connected(resp.into_body()));
410                            self.as_mut().project().initial_connection.set(false);
411
412                            return Poll::Ready(Some(Ok(SSE::Connected(ConnectionDetails::new(
413                                Response::new(status, headers),
414                            )))));
415                        }
416
417                        if resp.status() == 301 || resp.status() == 307 {
418                            debug!("got redirected ({})", resp.status());
419
420                            if self.as_mut().increment_redirect_counter() {
421                                debug!("following redirect {}", self.redirect_count);
422
423                                self.as_mut().project().state.set(State::FollowingRedirect(
424                                    resp.headers().get("location").cloned(),
425                                ));
426                                continue;
427                            } else {
428                                debug!("redirect limit reached ({})", self.props.max_redirects);
429
430                                self.as_mut().project().state.set(State::StreamClosed);
431                                return Poll::Ready(Some(Err(Error::MaxRedirectLimitReached(
432                                    self.props.max_redirects,
433                                ))));
434                            }
435                        }
436
437                        let status = resp.status();
438                        let headers = resp.headers().clone();
439                        let body = resp.into_body();
440
441                        let error = Error::UnexpectedResponse(
442                            Response::new(status, headers),
443                            ErrorBody::new(body),
444                        );
445
446                        if !*retry {
447                            self.as_mut().project().state.set(State::StreamClosed);
448                            return Poll::Ready(Some(Err(error)));
449                        }
450
451                        self.as_mut().reset_redirects();
452
453                        let duration = self
454                            .as_mut()
455                            .project()
456                            .retry_strategy
457                            .next_delay(Instant::now());
458
459                        self.as_mut()
460                            .project()
461                            .state
462                            .set(State::WaitingToReconnect(delay(duration, "retrying")));
463
464                        return Poll::Ready(Some(Err(error)));
465                    }
466                    Err(e) => {
467                        // This happens when the server is unreachable, e.g. connection refused.
468                        warn!("request returned an error: {e}");
469                        if !*retry {
470                            self.as_mut().project().state.set(State::StreamClosed);
471                            return Poll::Ready(Some(Err(Error::Transport(e))));
472                        }
473
474                        let duration = self
475                            .as_mut()
476                            .project()
477                            .retry_strategy
478                            .next_delay(Instant::now());
479
480                        self.as_mut()
481                            .project()
482                            .state
483                            .set(State::WaitingToReconnect(delay(duration, "retrying")));
484                    }
485                },
486                StateProj::FollowingRedirect(maybe_header) => match uri_from_header(maybe_header) {
487                    Ok(uri) => {
488                        *self.as_mut().project().current_url = uri;
489                        self.as_mut().project().state.set(State::New);
490                    }
491                    Err(e) => {
492                        self.as_mut().project().state.set(State::StreamClosed);
493                        return Poll::Ready(Some(Err(e)));
494                    }
495                },
496                StateProj::Connected(mut body) => match ready!(body.as_mut().poll_next(cx)) {
497                    Some(Ok(result)) => {
498                        this.event_parser.process_bytes(result)?;
499                        continue;
500                    }
501                    Some(Err(e)) => {
502                        if self.props.reconnect_opts.reconnect {
503                            let duration = self
504                                .as_mut()
505                                .project()
506                                .retry_strategy
507                                .next_delay(Instant::now());
508                            self.as_mut()
509                                .project()
510                                .state
511                                .set(State::WaitingToReconnect(delay(duration, "reconnecting")));
512                        }
513
514                        // Check if the underlying error is a timeout
515                        if let Some(cause) = e.source() {
516                            if let Some(downcast) = cause.downcast_ref::<std::io::Error>() {
517                                if let std::io::ErrorKind::TimedOut = downcast.kind() {
518                                    return Poll::Ready(Some(Err(Error::TimedOut)));
519                                }
520                            }
521                        }
522
523                        return Poll::Ready(Some(Err(Error::Transport(e))));
524                    }
525                    None => {
526                        let duration = self
527                            .as_mut()
528                            .project()
529                            .retry_strategy
530                            .next_delay(Instant::now());
531                        self.as_mut()
532                            .project()
533                            .state
534                            .set(State::WaitingToReconnect(delay(duration, "retrying")));
535
536                        if self.event_parser.was_processing() {
537                            return Poll::Ready(Some(Err(Error::UnexpectedEof)));
538                        }
539                        return Poll::Ready(Some(Err(Error::Eof)));
540                    }
541                },
542                StateProj::WaitingToReconnect(delay) => {
543                    ready!(delay.poll(cx));
544                    info!("Reconnecting");
545                    self.as_mut().project().state.set(State::New);
546                }
547            };
548        }
549    }
550}
551
552fn uri_from_header(maybe_header: &Option<HeaderValue>) -> Result<Uri> {
553    let header = maybe_header.as_ref().ok_or_else(|| {
554        Error::MalformedLocationHeader(Box::new(std::io::Error::new(
555            ErrorKind::NotFound,
556            "missing Location header",
557        )))
558    })?;
559
560    let header_string = header
561        .to_str()
562        .map_err(|e| Error::MalformedLocationHeader(Box::new(e)))?;
563
564    header_string
565        .parse::<Uri>()
566        .map_err(|e| Error::MalformedLocationHeader(Box::new(e)))
567}
568
569fn delay(dur: Duration, description: &str) -> Sleep {
570    info!("Waiting {dur:?} before {description}");
571    tokio::time::sleep(dur)
572}
573
574mod private {
575    use crate::client::ClientImpl;
576    use launchdarkly_sdk_transport::HttpTransport;
577
578    pub trait Sealed {}
579    impl<T: HttpTransport> Sealed for ClientImpl<T> {}
580}
581
#[cfg(test)]
mod tests {
    use std::{pin::pin, sync::Arc, time::Duration};

    use bytes::Bytes;
    use futures::{stream, TryStreamExt};
    use http::{HeaderMap, HeaderValue};
    use launchdarkly_sdk_transport::{ByteStream, HttpTransport, ResponseFuture, TransportError};
    use test_case::test_case;
    use tokio::time::timeout;

    use crate::{
        client::{RequestProps, State},
        ClientBuilder, ReconnectOptionsBuilder, ReconnectingRequest,
    };

    #[test_case("user", "pass", "dXNlcjpwYXNz")]
    #[test_case("user1", "password123", "dXNlcjE6cGFzc3dvcmQxMjM=")]
    #[test_case("user2", "", "dXNlcjI6")]
    #[test_case("user@name", "pass#word!", "dXNlckBuYW1lOnBhc3Mjd29yZCE=")]
    #[test_case("user3", "my pass", "dXNlcjM6bXkgcGFzcw==")]
    #[test_case(
        "weird@-/:stuff",
        "goes@-/:here",
        "d2VpcmRALS86c3R1ZmY6Z29lc0AtLzpoZXJl"
    )]
    fn basic_auth_generates_correct_headers(username: &str, password: &str, expected: &str) {
        let builder = ClientBuilder::for_url("http://example.com")
            .expect("failed to build client")
            .basic_auth(username, password)
            .expect("failed to add authentication");

        let want = HeaderValue::from_str(&format!("Basic {expected}"))
            .expect("unable to create expected header");

        assert_eq!(Some(&want), builder.headers.get("Authorization"));
    }

    /// Transport that never touches the network. Each request either fails with
    /// `ConnectionRefused` or resolves to a canned 404 response.
    #[derive(Clone)]
    struct MockTransport {
        fail_request: bool,
    }

    impl MockTransport {
        /// The URL argument is ignored.
        fn new(_url: String, fail_request: bool) -> Self {
            Self { fail_request }
        }
    }

    impl HttpTransport for MockTransport {
        fn request(&self, _request: http::Request<Option<Bytes>>) -> ResponseFuture {
            if !self.fail_request {
                // Canned 404 with a short body.
                return Box::pin(async {
                    let body: ByteStream =
                        Box::pin(stream::iter(vec![Ok(Bytes::from("not found"))]));
                    Ok(http::Response::builder().status(404).body(body).unwrap())
                });
            }

            // Simulated connection failure.
            Box::pin(async {
                Err(TransportError::new(std::io::Error::new(
                    std::io::ErrorKind::ConnectionRefused,
                    "connection refused",
                )))
            })
        }
    }

    const INVALID_URI: &str = "http://mycrazyunexsistenturl.invaliddomainext";

    #[test_case(INVALID_URI, false, |state| matches!(state, State::StreamClosed))]
    #[test_case(INVALID_URI, true, |state| matches!(state, State::WaitingToReconnect(_)))]
    #[tokio::test]
    async fn initial_connection(uri: &str, retry_initial: bool, expected: fn(&State) -> bool) {
        let opts = ReconnectOptionsBuilder::new(false)
            .backoff_factor(1)
            .delay(Duration::from_secs(1))
            .retry_initial(retry_initial)
            .build();

        // NOTE(review): the transport is always built in failing mode, so no request ever
        // reaches `uri` (including the mockito server used below).
        let transport = Arc::new(MockTransport::new(uri.to_string(), true));
        let props = RequestProps {
            url: uri.parse().unwrap(),
            headers: HeaderMap::new(),
            method: "GET".to_string(),
            body: None,
            reconnect_opts: opts,
            max_redirects: 10,
        };

        let mut request = ReconnectingRequest::new(Arc::clone(&transport), props, None);

        // Skip `New`: start directly in `Connecting` with a request that will fail.
        let in_flight = transport.request(http::Request::builder().uri(uri).body(None).unwrap());
        request.state = State::Connecting {
            retry: request.props.reconnect_opts.retry_initial,
            resp: in_flight,
        };

        let mut request = pin!(request);

        // Without retry the poll returns the error at once. With retry it moves on to the
        // 1s backoff sleep, which the 500ms timeout cuts short.
        let _ = timeout(Duration::from_millis(500), request.try_next()).await;

        assert!(expected(&request.state));
    }

    #[test_case(false, |state| matches!(state, State::StreamClosed))]
    #[test_case(true, |state| matches!(state, State::WaitingToReconnect(_)))]
    #[tokio::test]
    async fn initial_connection_mocked_server(retry_initial: bool, expected: fn(&State) -> bool) {
        let mut server = mockito::Server::new_async().await;
        // Bound to a name so the mock stays registered for the whole test.
        // NOTE(review): `initial_connection` uses a failing `MockTransport`, so this 404
        // mock is never actually requested.
        let _mock = server
            .mock("GET", "/")
            .with_status(404)
            .create_async()
            .await;

        initial_connection(&server.url(), retry_initial, expected).await;
    }
}