1use base64::prelude::*;
2
3use futures::{ready, Stream};
4use http::{HeaderMap, HeaderName, HeaderValue, Request, Uri};
5use log::{debug, info, trace, warn};
6use pin_project::pin_project;
7use std::{
8 boxed,
9 fmt::{self, Debug, Formatter},
10 future::Future,
11 io::ErrorKind,
12 pin::Pin,
13 str::FromStr,
14 sync::Arc,
15 task::{Context, Poll},
16 time::{Duration, Instant},
17};
18
19use tokio::time::Sleep;
20
21use crate::{
22 config::ReconnectOptions,
23 response::{ErrorBody, Response},
24};
25use crate::{
26 error::{Error, Result},
27 event_parser::ConnectionDetails,
28};
29use launchdarkly_sdk_transport::{ByteStream, HttpTransport, ResponseFuture};
30
31use crate::event_parser::EventParser;
32use crate::event_parser::SSE;
33
34use crate::retry::{BackoffRetry, RetryStrategy};
35use std::error::Error as StdError;
36
37pub type BoxStream<T> = Pin<boxed::Box<dyn Stream<Item = T> + Send + Sync>>;
39
40pub trait Client: Send + Sync + private::Sealed {
43 fn stream(&self) -> BoxStream<Result<SSE>>;
44}
45
/// Redirect limit used by [`ClientBuilder::build_with_transport`] when no explicit limit was
/// configured through [`ClientBuilder::redirect_limit`].
pub const DEFAULT_REDIRECT_LIMIT: u32 = 16;
54
55pub struct ClientBuilder {
57 url: Uri,
58 headers: HeaderMap,
59 reconnect_opts: ReconnectOptions,
60 last_event_id: Option<String>,
61 method: String,
62 body: Option<String>,
63 max_redirects: Option<u32>,
64}
65
66impl ClientBuilder {
67 pub fn for_url(url: &str) -> Result<ClientBuilder> {
69 let url = url
70 .parse()
71 .map_err(|e| Error::InvalidParameter(Box::new(e)))?;
72
73 let mut header_map = HeaderMap::new();
74 header_map.insert("Accept", HeaderValue::from_static("text/event-stream"));
75 header_map.insert("Cache-Control", HeaderValue::from_static("no-cache"));
76
77 Ok(ClientBuilder {
78 url,
79 headers: header_map,
80 reconnect_opts: ReconnectOptions::default(),
81 last_event_id: None,
82 method: String::from("GET"),
83 max_redirects: None,
84 body: None,
85 })
86 }
87
88 pub fn method(mut self, method: String) -> ClientBuilder {
90 self.method = method;
91 self
92 }
93
94 pub fn body(mut self, body: String) -> ClientBuilder {
96 self.body = Some(body);
97 self
98 }
99
100 pub fn last_event_id(mut self, last_event_id: String) -> ClientBuilder {
103 self.last_event_id = Some(last_event_id);
104 self
105 }
106
107 pub fn header(mut self, name: &str, value: &str) -> Result<ClientBuilder> {
109 let name = HeaderName::from_str(name).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
110
111 let value =
112 HeaderValue::from_str(value).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
113
114 self.headers.insert(name, value);
115 Ok(self)
116 }
117
118 pub fn basic_auth(self, username: &str, password: &str) -> Result<ClientBuilder> {
120 let auth = format!("{username}:{password}");
121 let encoded = BASE64_STANDARD.encode(auth);
122 let value = format!("Basic {encoded}");
123
124 self.header("Authorization", &value)
125 }
126
127 pub fn reconnect(mut self, opts: ReconnectOptions) -> ClientBuilder {
132 self.reconnect_opts = opts;
133 self
134 }
135
136 pub fn redirect_limit(mut self, limit: u32) -> ClientBuilder {
140 self.max_redirects = Some(limit);
141 self
142 }
143
144 pub fn build_with_transport<T>(self, transport: T) -> impl Client
162 where
163 T: HttpTransport,
164 {
165 ClientImpl {
166 transport: Arc::new(transport),
167 request_props: RequestProps {
168 url: self.url,
169 headers: self.headers,
170 method: self.method,
171 body: self.body,
172 reconnect_opts: self.reconnect_opts,
173 max_redirects: self.max_redirects.unwrap_or(DEFAULT_REDIRECT_LIMIT),
174 },
175 last_event_id: self.last_event_id,
176 }
177 }
178}
179
180#[derive(Clone)]
181struct RequestProps {
182 url: Uri,
183 headers: HeaderMap,
184 method: String,
185 body: Option<String>,
186 reconnect_opts: ReconnectOptions,
187 max_redirects: u32,
188}
189
190struct ClientImpl<T: HttpTransport> {
193 transport: Arc<T>,
194 request_props: RequestProps,
195 last_event_id: Option<String>,
196}
197
198impl<T: HttpTransport> Client for ClientImpl<T> {
199 fn stream(&self) -> BoxStream<Result<SSE>> {
207 Box::pin(ReconnectingRequest::new(
208 Arc::clone(&self.transport),
209 self.request_props.clone(),
210 self.last_event_id.clone(),
211 ))
212 }
213}
214
215#[allow(clippy::large_enum_variant)] #[pin_project(project = StateProj)]
217enum State {
218 New,
219 Connecting {
220 retry: bool,
221 #[pin]
222 resp: ResponseFuture,
223 },
224 Connected(#[pin] ByteStream),
225 WaitingToReconnect(#[pin] Sleep),
226 FollowingRedirect(Option<HeaderValue>),
227 StreamClosed,
228}
229
230impl State {
231 fn name(&self) -> &'static str {
232 match self {
233 State::New => "new",
234 State::Connecting { retry: false, .. } => "connecting(no-retry)",
235 State::Connecting { retry: true, .. } => "connecting(retry)",
236 State::Connected(_) => "connected",
237 State::WaitingToReconnect(_) => "waiting-to-reconnect",
238 State::FollowingRedirect(_) => "following-redirect",
239 State::StreamClosed => "closed",
240 }
241 }
242}
243
244impl Debug for State {
245 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
246 write!(f, "{}", self.name())
247 }
248}
249
250#[must_use = "streams do nothing unless polled"]
251#[pin_project]
252pub struct ReconnectingRequest<T: HttpTransport> {
253 transport: Arc<T>,
254 props: RequestProps,
255 #[pin]
256 state: State,
257 retry_strategy: Box<dyn RetryStrategy + Send + Sync>,
258 current_url: Uri,
259 redirect_count: u32,
260 event_parser: EventParser,
261 last_event_id: Option<String>,
262 #[pin]
263 initial_connection: bool,
264}
265
266impl<T: HttpTransport> ReconnectingRequest<T> {
267 fn new(
268 transport: Arc<T>,
269 props: RequestProps,
270 last_event_id: Option<String>,
271 ) -> ReconnectingRequest<T> {
272 let reconnect_delay = props.reconnect_opts.delay;
273 let delay_max = props.reconnect_opts.delay_max;
274 let backoff_factor = props.reconnect_opts.backoff_factor;
275
276 let url = props.url.clone();
277 ReconnectingRequest {
278 props,
279 transport,
280 state: State::New,
281 retry_strategy: Box::new(BackoffRetry::new(
282 reconnect_delay,
283 delay_max,
284 backoff_factor,
285 true,
286 )),
287 redirect_count: 0,
288 current_url: url,
289 event_parser: EventParser::new(),
290 last_event_id,
291 initial_connection: true,
292 }
293 }
294
295 fn send_request(&self) -> Result<ResponseFuture> {
296 let mut request_builder = Request::builder()
297 .method(self.props.method.as_str())
298 .uri(&self.current_url);
299
300 for (name, value) in &self.props.headers {
301 request_builder = request_builder.header(name, value);
302 }
303
304 if let Some(id) = self.last_event_id.as_ref() {
305 if !id.is_empty() {
306 let id_as_header =
307 HeaderValue::from_str(id).map_err(|e| Error::InvalidParameter(Box::new(e)))?;
308
309 request_builder = request_builder.header("last-event-id", id_as_header);
310 }
311 }
312
313 let request = request_builder
316 .body(self.props.body.clone().map(|b| b.into()))
317 .map_err(|e| Error::InvalidParameter(Box::new(e)))?;
318
319 Ok(self.transport.request(request))
320 }
321
322 fn reset_redirects(self: Pin<&mut Self>) {
323 let url = self.props.url.clone();
324 let this = self.project();
325 *this.current_url = url;
326 *this.redirect_count = 0;
327 }
328
329 fn increment_redirect_counter(self: Pin<&mut Self>) -> bool {
330 if self.redirect_count == self.props.max_redirects {
331 return false;
332 }
333 *self.project().redirect_count += 1;
334 true
335 }
336}
337
338impl<T: HttpTransport> Stream for ReconnectingRequest<T> {
339 type Item = Result<SSE>;
340
341 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
342 trace!("ReconnectingRequest::poll({:?})", &self.state);
343
344 loop {
345 let this = self.as_mut().project();
346 if let Some(event) = this.event_parser.get_event() {
347 return match event {
348 SSE::Connected(_) => Poll::Ready(Some(Ok(event))),
349 SSE::Event(ref evt) => {
350 this.last_event_id.clone_from(&evt.id);
351
352 if let Some(retry) = evt.retry {
353 this.retry_strategy
354 .change_base_delay(Duration::from_millis(retry));
355 }
356 Poll::Ready(Some(Ok(event)))
357 }
358 SSE::Comment(_) => Poll::Ready(Some(Ok(event))),
359 };
360 }
361
362 trace!("ReconnectingRequest::poll loop({:?})", &this.state);
363
364 let state = this.state.project();
365 match state {
366 StateProj::StreamClosed => return Poll::Ready(None),
367 StateProj::New => {
370 *self.as_mut().project().event_parser = EventParser::new();
371 match self.send_request() {
372 Ok(resp) => {
373 let retry = if self.initial_connection {
374 self.props.reconnect_opts.retry_initial
375 } else {
376 self.props.reconnect_opts.reconnect
377 };
378 self.as_mut()
379 .project()
380 .state
381 .set(State::Connecting { resp, retry })
382 }
383 Err(e) => {
384 self.as_mut().project().state.set(State::StreamClosed);
387 return Poll::Ready(Some(Err(e)));
388 }
389 }
390 }
391 StateProj::Connecting { retry, resp } => match ready!(resp.poll(cx)) {
392 Ok(resp) => {
393 debug!(
394 "HTTP response status: {}, headers: {:?}",
395 resp.status(),
396 resp.headers()
397 );
398
399 if resp.status().is_success() {
400 self.as_mut().project().retry_strategy.reset(Instant::now());
401 self.as_mut().reset_redirects();
402
403 let status = resp.status();
404 let headers = resp.headers().clone();
405
406 self.as_mut()
407 .project()
408 .state
409 .set(State::Connected(resp.into_body()));
410 self.as_mut().project().initial_connection.set(false);
411
412 return Poll::Ready(Some(Ok(SSE::Connected(ConnectionDetails::new(
413 Response::new(status, headers),
414 )))));
415 }
416
417 if resp.status() == 301 || resp.status() == 307 {
418 debug!("got redirected ({})", resp.status());
419
420 if self.as_mut().increment_redirect_counter() {
421 debug!("following redirect {}", self.redirect_count);
422
423 self.as_mut().project().state.set(State::FollowingRedirect(
424 resp.headers().get("location").cloned(),
425 ));
426 continue;
427 } else {
428 debug!("redirect limit reached ({})", self.props.max_redirects);
429
430 self.as_mut().project().state.set(State::StreamClosed);
431 return Poll::Ready(Some(Err(Error::MaxRedirectLimitReached(
432 self.props.max_redirects,
433 ))));
434 }
435 }
436
437 let status = resp.status();
438 let headers = resp.headers().clone();
439 let body = resp.into_body();
440
441 let error = Error::UnexpectedResponse(
442 Response::new(status, headers),
443 ErrorBody::new(body),
444 );
445
446 if !*retry {
447 self.as_mut().project().state.set(State::StreamClosed);
448 return Poll::Ready(Some(Err(error)));
449 }
450
451 self.as_mut().reset_redirects();
452
453 let duration = self
454 .as_mut()
455 .project()
456 .retry_strategy
457 .next_delay(Instant::now());
458
459 self.as_mut()
460 .project()
461 .state
462 .set(State::WaitingToReconnect(delay(duration, "retrying")));
463
464 return Poll::Ready(Some(Err(error)));
465 }
466 Err(e) => {
467 warn!("request returned an error: {e}");
469 if !*retry {
470 self.as_mut().project().state.set(State::StreamClosed);
471 return Poll::Ready(Some(Err(Error::Transport(e))));
472 }
473
474 let duration = self
475 .as_mut()
476 .project()
477 .retry_strategy
478 .next_delay(Instant::now());
479
480 self.as_mut()
481 .project()
482 .state
483 .set(State::WaitingToReconnect(delay(duration, "retrying")));
484 }
485 },
486 StateProj::FollowingRedirect(maybe_header) => match uri_from_header(maybe_header) {
487 Ok(uri) => {
488 *self.as_mut().project().current_url = uri;
489 self.as_mut().project().state.set(State::New);
490 }
491 Err(e) => {
492 self.as_mut().project().state.set(State::StreamClosed);
493 return Poll::Ready(Some(Err(e)));
494 }
495 },
496 StateProj::Connected(mut body) => match ready!(body.as_mut().poll_next(cx)) {
497 Some(Ok(result)) => {
498 this.event_parser.process_bytes(result)?;
499 continue;
500 }
501 Some(Err(e)) => {
502 if self.props.reconnect_opts.reconnect {
503 let duration = self
504 .as_mut()
505 .project()
506 .retry_strategy
507 .next_delay(Instant::now());
508 self.as_mut()
509 .project()
510 .state
511 .set(State::WaitingToReconnect(delay(duration, "reconnecting")));
512 }
513
514 if let Some(cause) = e.source() {
516 if let Some(downcast) = cause.downcast_ref::<std::io::Error>() {
517 if let std::io::ErrorKind::TimedOut = downcast.kind() {
518 return Poll::Ready(Some(Err(Error::TimedOut)));
519 }
520 }
521 }
522
523 return Poll::Ready(Some(Err(Error::Transport(e))));
524 }
525 None => {
526 let duration = self
527 .as_mut()
528 .project()
529 .retry_strategy
530 .next_delay(Instant::now());
531 self.as_mut()
532 .project()
533 .state
534 .set(State::WaitingToReconnect(delay(duration, "retrying")));
535
536 if self.event_parser.was_processing() {
537 return Poll::Ready(Some(Err(Error::UnexpectedEof)));
538 }
539 return Poll::Ready(Some(Err(Error::Eof)));
540 }
541 },
542 StateProj::WaitingToReconnect(delay) => {
543 ready!(delay.poll(cx));
544 info!("Reconnecting");
545 self.as_mut().project().state.set(State::New);
546 }
547 };
548 }
549 }
550}
551
552fn uri_from_header(maybe_header: &Option<HeaderValue>) -> Result<Uri> {
553 let header = maybe_header.as_ref().ok_or_else(|| {
554 Error::MalformedLocationHeader(Box::new(std::io::Error::new(
555 ErrorKind::NotFound,
556 "missing Location header",
557 )))
558 })?;
559
560 let header_string = header
561 .to_str()
562 .map_err(|e| Error::MalformedLocationHeader(Box::new(e)))?;
563
564 header_string
565 .parse::<Uri>()
566 .map_err(|e| Error::MalformedLocationHeader(Box::new(e)))
567}
568
569fn delay(dur: Duration, description: &str) -> Sleep {
570 info!("Waiting {dur:?} before {description}");
571 tokio::time::sleep(dur)
572}
573
574mod private {
575 use crate::client::ClientImpl;
576 use launchdarkly_sdk_transport::HttpTransport;
577
578 pub trait Sealed {}
579 impl<T: HttpTransport> Sealed for ClientImpl<T> {}
580}
581
#[cfg(test)]
mod tests {
    use std::{pin::pin, sync::Arc, time::Duration};

    use bytes::Bytes;
    use futures::{stream, TryStreamExt};
    use http::{HeaderMap, HeaderValue};
    use test_case::test_case;
    use tokio::time::timeout;

    use crate::{
        client::{RequestProps, State},
        ClientBuilder, ReconnectOptionsBuilder, ReconnectingRequest,
    };
    use launchdarkly_sdk_transport::{ByteStream, HttpTransport, ResponseFuture, TransportError};

    // `basic_auth` must store `Basic <base64(username:password)>` under `Authorization`.
    #[test_case("user", "pass", "dXNlcjpwYXNz")]
    #[test_case("user1", "password123", "dXNlcjE6cGFzc3dvcmQxMjM=")]
    #[test_case("user2", "", "dXNlcjI6")]
    #[test_case("user@name", "pass#word!", "dXNlckBuYW1lOnBhc3Mjd29yZCE=")]
    #[test_case("user3", "my pass", "dXNlcjM6bXkgcGFzcw==")]
    #[test_case(
        "weird@-/:stuff",
        "goes@-/:here",
        "d2VpcmRALS86c3R1ZmY6Z29lc0AtLzpoZXJl"
    )]
    fn basic_auth_generates_correct_headers(username: &str, password: &str, expected: &str) {
        let builder = ClientBuilder::for_url("http://example.com")
            .expect("failed to build client")
            .basic_auth(username, password)
            .expect("failed to add authentication");

        let expected_header = HeaderValue::from_str(&format!("Basic {expected}"))
            .expect("unable to create expected header");

        assert_eq!(Some(&expected_header), builder.headers.get("Authorization"));
    }

    /// Transport stub that never touches the network.
    ///
    /// With `fail_request` set, every request resolves to a "connection refused" transport
    /// error; otherwise every request resolves to a `404` response with the body `not found`.
    #[derive(Clone)]
    struct MockTransport {
        fail_request: bool,
    }

    impl MockTransport {
        /// The URL argument is accepted but ignored.
        fn new(_url: String, fail_request: bool) -> Self {
            Self { fail_request }
        }
    }

    impl HttpTransport for MockTransport {
        fn request(&self, _request: http::Request<Option<Bytes>>) -> ResponseFuture {
            if self.fail_request {
                Box::pin(async {
                    let refused = std::io::Error::new(
                        std::io::ErrorKind::ConnectionRefused,
                        "connection refused",
                    );
                    Err(TransportError::new(refused))
                })
            } else {
                Box::pin(async {
                    let byte_stream: ByteStream =
                        Box::pin(stream::iter(vec![Ok(Bytes::from("not found"))]));
                    let response = http::Response::builder()
                        .status(404)
                        .body(byte_stream)
                        .unwrap();
                    Ok(response)
                })
            }
        }
    }

    const INVALID_URI: &str = "http://mycrazyunexsistenturl.invaliddomainext";

    // Polls a request that starts in `Connecting` with a failing transport and checks which
    // state it ends up in. The transport is always built with `fail_request = true`.
    #[test_case(INVALID_URI, false, |state| matches!(state, State::StreamClosed))]
    #[test_case(INVALID_URI, true, |state| matches!(state, State::WaitingToReconnect(_)))]
    #[tokio::test]
    async fn initial_connection(uri: &str, retry_initial: bool, expected: fn(&State) -> bool) {
        let reconnect_opts = ReconnectOptionsBuilder::new(false)
            .backoff_factor(1)
            .delay(Duration::from_secs(1))
            .retry_initial(retry_initial)
            .build();

        let transport = Arc::new(MockTransport::new(uri.to_string(), true));
        let props = RequestProps {
            url: uri.parse().unwrap(),
            headers: HeaderMap::new(),
            method: "GET".to_string(),
            body: None,
            reconnect_opts,
            max_redirects: 10,
        };

        let mut request = ReconnectingRequest::new(Arc::clone(&transport), props, None);

        // Skip the `New` state by installing a pending response directly.
        let resp = transport.request(http::Request::builder().uri(uri).body(None).unwrap());
        let retry = request.props.reconnect_opts.retry_initial;
        request.state = State::Connecting { retry, resp };

        let mut request = pin!(request);

        // Both a timeout and any yielded item are acceptable; only the final state matters.
        let _ = timeout(Duration::from_millis(500), request.try_next()).await;

        assert!(expected(&request.state));
    }

    // Same check as `initial_connection`, using the address of a local mock server. The mock
    // transport is still built in failing mode by `initial_connection`.
    #[test_case(false, |state| matches!(state, State::StreamClosed))]
    #[test_case(true, |state| matches!(state, State::WaitingToReconnect(_)))]
    #[tokio::test]
    async fn initial_connection_mocked_server(retry_initial: bool, expected: fn(&State) -> bool) {
        let mut mock_server = mockito::Server::new_async().await;
        let _mock = mock_server
            .mock("GET", "/")
            .with_status(404)
            .create_async()
            .await;

        let server_url = mock_server.url();
        initial_connection(&server_url, retry_initial, expected).await;
    }
}
715}