1use std::error::Error as stdError;
2use std::fmt::Debug;
3use std::time::Duration;
4use std::time::Instant;
5
6use futures::Future;
7use futures::Stream;
8use futures::StreamExt;
9use graphql_client::QueryBody;
10use thiserror::Error;
11use tokio::sync::mpsc::channel;
12use tokio_stream::wrappers::ReceiverStream;
13use tower::BoxError;
14use tracing::instrument::WithSubscriber;
15use url::Url;
16
17pub(crate) mod feature_gate_enforcement;
18pub(crate) mod license_enforcement;
19pub(crate) mod license_stream;
20mod parsed_link_spec;
21pub(crate) mod persisted_queries_manifest_stream;
22pub(crate) mod schema;
23pub(crate) mod schema_stream;
24
/// Uplink endpoint hosted on GCP; the first URL tried by `Endpoints::default()`.
const GCP_URL: &str = "https://uplink.api.apollographql.com";
/// Uplink endpoint hosted on AWS; the fallback URL used by `Endpoints::default()`.
const AWS_URL: &str = "https://aws.uplink.api.apollographql.com";
27
/// Errors surfaced through the Uplink polling stream.
#[derive(Debug, Error)]
pub(crate) enum Error {
    /// An HTTP-level failure from the underlying `reqwest` client.
    #[error("http error")]
    Http(#[from] reqwest::Error),

    /// The only configured endpoint could not provide a usable response.
    #[error("fetch failed from uplink endpoint, and there are no fallback endpoints configured")]
    FetchFailedSingle,

    /// None of the `url_count` configured endpoints could provide a usable response.
    #[error("fetch failed from all {url_count} uplink endpoints")]
    FetchFailedMultiple { url_count: usize },

    /// Uplink answered with an error it says is worth retrying later.
    // The variant name ends with the enum name (`Error`), which clippy's
    // `enum_variant_names` lint flags; the name is kept deliberately.
    #[allow(clippy::enum_variant_names)]
    #[error("uplink error: code={code} message={message}")]
    UplinkError { code: String, message: String },

    /// Uplink answered with an error that should not be retried; the polling task
    /// stops after emitting this.
    #[error("uplink error, the request will not be retried: code={code} message={message}")]
    UplinkErrorNoRetry { code: String, message: String },
}
46
/// Query-agnostic parameters for one Uplink request.
///
/// Each query's generated `Variables` type is built from this via `From<UplinkRequest>`.
#[derive(Debug)]
pub(crate) struct UplinkRequest {
    /// Apollo API key used to authenticate the request.
    api_key: String,
    /// Graph reference the request is for.
    graph_ref: String,
    /// Id of the last payload received, if any, so Uplink can answer "unchanged".
    id: Option<String>,
}
53
/// Query-agnostic view of an Uplink response.
///
/// Each query's generated `ResponseData` is converted into this via `Into`.
#[derive(Debug)]
pub(crate) enum UplinkResponse<Response>
where
    Response: Send + Debug + 'static,
{
    /// A new payload is available.
    New {
        /// The payload itself.
        response: Response,
        /// Id to send back on the next request.
        id: String,
        /// Delay in seconds before the next poll.
        delay: u64,
    },
    /// Nothing changed since the id that was sent.
    Unchanged {
        /// Id to send back on the next request, if Uplink provided one.
        id: Option<String>,
        /// Delay in seconds before the next poll, if Uplink provided one.
        delay: Option<u64>,
    },
    /// Uplink reported an error.
    Error {
        /// Whether polling should continue; when `false` the polling task stops.
        retry_later: bool,
        /// Uplink error code.
        code: String,
        /// Human-readable error message.
        message: String,
    },
}
74
/// The set of Uplink URLs to query and the order in which they are tried.
#[derive(Debug, Clone)]
pub enum Endpoints {
    /// Try `urls` in the same order on every fetch, moving to the next URL only when the
    /// previous one fails.
    Fallback {
        /// URLs to try, most preferred first.
        urls: Vec<Url>,
    },
    /// Rotate the starting URL between fetches.
    // NOTE(review): `dead_code` is presumably allowed because this variant is only built by
    // the round-robin constructor, which is only called from tests in this file — confirm
    // before removing the attribute.
    #[allow(dead_code)]
    RoundRobin {
        /// URLs to rotate through.
        urls: Vec<Url>,
        /// Position (taken modulo `urls.len()`) of the URL the next fetch starts from.
        current: usize,
    },
}
86
87impl Default for Endpoints {
88 fn default() -> Self {
89 Self::fallback(
90 [GCP_URL, AWS_URL]
91 .iter()
92 .map(|url| Url::parse(url).expect("default urls must be valid"))
93 .collect(),
94 )
95 }
96}
97
98impl Endpoints {
99 pub(crate) fn fallback(urls: Vec<Url>) -> Self {
100 Endpoints::Fallback { urls }
101 }
102 #[allow(dead_code)]
103 pub(crate) fn round_robin(urls: Vec<Url>) -> Self {
104 Endpoints::RoundRobin { urls, current: 0 }
105 }
106
107 fn iter<'a>(&'a mut self) -> Box<dyn Iterator<Item = &'a Url> + Send + 'a> {
111 match self {
112 Endpoints::Fallback { urls } => Box::new(urls.iter()),
113 Endpoints::RoundRobin { urls, current } => {
114 *current %= urls.len();
116
117 Box::new(
121 urls.iter()
122 .cycle()
123 .skip(*current)
124 .inspect(|_| {
125 *current += 1;
126 })
127 .take(urls.len()),
128 )
129 }
130 }
131 }
132
133 pub(crate) fn url_count(&self) -> usize {
134 match self {
135 Endpoints::Fallback { urls } => urls.len(),
136 Endpoints::RoundRobin { urls, current: _ } => urls.len(),
137 }
138 }
139}
140
/// Configuration for polling Apollo Uplink.
#[derive(Debug, Clone, Default)]
pub struct UplinkConfig {
    /// Apollo API key sent with every Uplink request.
    pub apollo_key: String,

    /// Graph reference sent with every Uplink request.
    pub apollo_graph_ref: String,

    /// Uplink URLs to query. `None` falls back to `Endpoints::default()`.
    pub endpoints: Option<Endpoints>,

    /// Initial delay between polls. It is replaced by the delay Uplink reports in its
    /// responses once one is received.
    pub poll_interval: Duration,

    /// Timeout applied to each HTTP request made to Uplink.
    pub timeout: Duration,
}
160
161impl UplinkConfig {
162 pub fn for_tests(uplink_endpoints: Endpoints) -> Self {
165 Self {
166 apollo_key: "key".to_string(),
167 apollo_graph_ref: "graph".to_string(),
168 endpoints: Some(uplink_endpoints),
169 poll_interval: Duration::from_secs(2),
170 timeout: Duration::from_secs(5),
171 }
172 }
173}
174
175pub(crate) fn stream_from_uplink<Query, Response>(
178 uplink_config: UplinkConfig,
179) -> impl Stream<Item = Result<Response, Error>>
180where
181 Query: graphql_client::GraphQLQuery,
182 <Query as graphql_client::GraphQLQuery>::ResponseData: Into<UplinkResponse<Response>> + Send,
183 <Query as graphql_client::GraphQLQuery>::Variables: From<UplinkRequest> + Send + Sync,
184 Response: Send + 'static + Debug,
185{
186 stream_from_uplink_transforming_new_response::<Query, Response, Response>(
187 uplink_config,
188 |response| Box::new(Box::pin(async { Ok(response) })),
189 )
190}
191
192pub(crate) fn stream_from_uplink_transforming_new_response<Query, Response, TransformedResponse>(
200 mut uplink_config: UplinkConfig,
201 transform_new_response: impl Fn(
202 Response,
203 ) -> Box<
204 dyn Future<Output = Result<TransformedResponse, BoxError>> + Send + Unpin,
205 > + Send
206 + Sync
207 + 'static,
208) -> impl Stream<Item = Result<TransformedResponse, Error>>
209where
210 Query: graphql_client::GraphQLQuery,
211 <Query as graphql_client::GraphQLQuery>::ResponseData: Into<UplinkResponse<Response>> + Send,
212 <Query as graphql_client::GraphQLQuery>::Variables: From<UplinkRequest> + Send + Sync,
213 Response: Send + 'static + Debug,
214 TransformedResponse: Send + 'static + Debug,
215{
216 let query_name = query_name::<Query>();
217 let (sender, receiver) = channel(2);
218 let client = match reqwest::Client::builder()
219 .no_gzip()
220 .timeout(uplink_config.timeout)
221 .build()
222 {
223 Ok(client) => client,
224 Err(err) => {
225 tracing::error!("unable to create client to query uplink: {err}", err = err);
226 return futures::stream::empty().boxed();
227 }
228 };
229
230 let task = async move {
231 let mut last_id = None;
232 let mut endpoints = uplink_config.endpoints.unwrap_or_default();
233 loop {
234 let variables = UplinkRequest {
235 graph_ref: uplink_config.apollo_graph_ref.to_string(),
236 api_key: uplink_config.apollo_key.to_string(),
237 id: last_id.clone(),
238 };
239
240 let query_body = Query::build_query(variables.into());
241
242 match fetch::<Query, Response, TransformedResponse>(
243 &client,
244 &query_body,
245 &mut endpoints,
246 &transform_new_response,
247 )
248 .await
249 {
250 Ok(response) => {
251 u64_counter!(
252 "apollo.router.uplink.fetch.count.total",
253 "Total number of requests to Apollo Uplink",
254 1u64,
255 status = "success",
256 query = query_name
257 );
258 match response {
259 UplinkResponse::New {
260 id,
261 response,
262 delay,
263 } => {
264 last_id = Some(id);
265 uplink_config.poll_interval = Duration::from_secs(delay);
266
267 if let Err(e) = sender.send(Ok(response)).await {
268 tracing::debug!(
269 "failed to push to stream. This is likely to be because the router is shutting down: {e}"
270 );
271 break;
272 }
273 }
274 UplinkResponse::Unchanged { id, delay } => {
275 if let Some(id) = id {
277 last_id = Some(id);
278 }
279 if let Some(delay) = delay {
280 uplink_config.poll_interval = Duration::from_secs(delay);
281 }
282 }
283 UplinkResponse::Error {
284 retry_later,
285 message,
286 code,
287 } => {
288 let err = if retry_later {
289 Err(Error::UplinkError { code, message })
290 } else {
291 Err(Error::UplinkErrorNoRetry { code, message })
292 };
293 if let Err(e) = sender.send(err).await {
294 tracing::debug!(
295 "failed to send error to uplink stream. This is likely to be because the router is shutting down: {e}"
296 );
297 break;
298 }
299 if !retry_later {
300 break;
301 }
302 }
303 }
304 }
305 Err(err) => {
306 u64_counter!(
307 "apollo.router.uplink.fetch.count.total",
308 "Total number of requests to Apollo Uplink",
309 1u64,
310 status = "failure",
311 query = query_name
312 );
313 if let Err(e) = sender.send(Err(err)).await {
314 tracing::debug!(
315 "failed to send error to uplink stream. This is likely to be because the router is shutting down: {e}"
316 );
317 break;
318 }
319 }
320 }
321
322 tokio::time::sleep(uplink_config.poll_interval).await;
323 }
324 };
325 drop(tokio::task::spawn(task.with_current_subscriber()));
326
327 ReceiverStream::new(receiver).boxed()
328}
329
330pub(crate) async fn fetch<Query, Response, TransformedResponse>(
331 client: &reqwest::Client,
332 request_body: &QueryBody<Query::Variables>,
333 endpoints: &mut Endpoints,
334 transform_new_response: &(
337 impl Fn(
338 Response,
339 ) -> Box<dyn Future<Output = Result<TransformedResponse, BoxError>> + Send + Unpin>
340 + Send
341 + Sync
342 + 'static
343 ),
344) -> Result<UplinkResponse<TransformedResponse>, Error>
345where
346 Query: graphql_client::GraphQLQuery,
347 <Query as graphql_client::GraphQLQuery>::ResponseData: Into<UplinkResponse<Response>> + Send,
348 <Query as graphql_client::GraphQLQuery>::Variables: From<UplinkRequest> + Send + Sync,
349 Response: Send + Debug + 'static,
350 TransformedResponse: Send + Debug + 'static,
351{
352 let query = query_name::<Query>();
353 for url in endpoints.iter() {
354 let now = Instant::now();
355 match http_request::<Query>(client, url.as_str(), request_body).await {
356 Ok(response) => match response.data.map(Into::into) {
357 None => {
358 f64_histogram!(
359 "apollo.router.uplink.fetch.duration.seconds",
360 "Duration of Apollo Uplink fetches.",
361 now.elapsed().as_secs_f64(),
362 query = query,
363 url = url.to_string(),
364 kind = "uplink_error",
365 error = "empty response from uplink"
366 );
367 }
368 Some(UplinkResponse::New {
369 response,
370 id,
371 delay,
372 }) => {
373 f64_histogram!(
374 "apollo.router.uplink.fetch.duration.seconds",
375 "Duration of Apollo Uplink fetches.",
376 now.elapsed().as_secs_f64(),
377 query = query,
378 url = url.to_string(),
379 kind = "new"
380 );
381 match transform_new_response(response).await {
382 Ok(res) => {
383 return Ok(UplinkResponse::New {
384 response: res,
385 id,
386 delay,
387 });
388 }
389 Err(err) => {
390 tracing::debug!(
391 "failed to process results of Uplink response from {url}: {err}. Other endpoints will be tried"
392 );
393 continue;
394 }
395 }
396 }
397 Some(UplinkResponse::Unchanged { id, delay }) => {
398 f64_histogram!(
399 "apollo.router.uplink.fetch.duration.seconds",
400 "Duration of Apollo Uplink fetches.",
401 now.elapsed().as_secs_f64(),
402 query = query,
403 url = url.to_string(),
404 kind = "unchanged"
405 );
406 return Ok(UplinkResponse::Unchanged { id, delay });
407 }
408 Some(UplinkResponse::Error {
409 message,
410 code,
411 retry_later,
412 }) => {
413 f64_histogram!(
414 "apollo.router.uplink.fetch.duration.seconds",
415 "Duration of Apollo Uplink fetches.",
416 now.elapsed().as_secs_f64(),
417 query = query,
418 url = url.to_string(),
419 kind = "uplink_error",
420 error = message.clone(),
421 code = code.clone()
422 );
423 return Ok(UplinkResponse::Error {
424 message,
425 code,
426 retry_later,
427 });
428 }
429 },
430 Err(err) => {
431 f64_histogram!(
432 "apollo.router.uplink.fetch.duration.seconds",
433 "Duration of Apollo Uplink fetches.",
434 now.elapsed().as_secs_f64(),
435 query = query,
436 url = url.to_string(),
437 kind = "http_error",
438 error = err.to_string(),
439 code = err.status().unwrap_or_default().to_string()
440 );
441 tracing::debug!(
442 "failed to fetch from Uplink endpoint {url}: {err}. Other endpoints will be tried"
443 );
444 }
445 };
446 }
447
448 let url_count = endpoints.url_count();
449 if url_count == 1 {
450 Err(Error::FetchFailedSingle)
451 } else {
452 Err(Error::FetchFailedMultiple { url_count })
453 }
454}
455
/// Derives a short query name from the Rust type name of `Query`, for use in metrics.
///
/// The module path and the trailing `Query` are removed, so `crate::uplink::SchemaQuery`
/// becomes `"Schema"`.
///
/// # Panics
///
/// Panics if the type name does not end with `Query`.
fn query_name<Query>() -> &'static str {
    let full_name: &'static str = std::any::type_name::<Query>();
    let without_suffix = full_name
        .strip_suffix("Query")
        .expect("Uplink structs must be named xxxQuery");
    // The last path segment, or the whole name when there is no module path.
    without_suffix
        .rsplit("::")
        .next()
        .unwrap_or(without_suffix)
}
465
466async fn http_request<Query>(
467 client: &reqwest::Client,
468 url: &str,
469 request_body: &QueryBody<Query::Variables>,
470) -> Result<graphql_client::Response<Query::ResponseData>, reqwest::Error>
471where
472 Query: graphql_client::GraphQLQuery,
473{
474 let res = client
481 .post(url)
482 .header("x-router-version", env!("CARGO_PKG_VERSION"))
483 .json(request_body)
484 .send()
485 .await
486 .inspect_err(|e| {
487 if let Some(hyper_err) = e.source()
488 && let Some(os_err) = hyper_err.source()
489 && os_err.to_string().contains("tcp connect error: Cannot assign requested address (os error 99)") {
490 tracing::warn!("If your router is executing within a kubernetes pod, this failure may be caused by istio-proxy injection. See https://github.com/apollographql/router/issues/3533 for more details about how to solve this");
491 }
492 })?;
493 tracing::debug!("uplink response {:?}", res);
494 let response_body: graphql_client::Response<Query::ResponseData> = res.json().await?;
495 Ok(response_body)
496}
497
// Tests for the Uplink polling stream. Each test mounts per-endpoint queues of canned
// responses on a wiremock server and snapshots the items the stream yields.
#[cfg(test)]
mod test {
    use std::collections::VecDeque;
    use std::time::Duration;

    use buildstructor::buildstructor;
    use futures::StreamExt;
    use graphql_client::GraphQLQuery;
    use http::StatusCode;
    use insta::assert_yaml_snapshot;
    use parking_lot::Mutex;
    use serde_json::json;
    use test_query::FetchErrorCode;
    use test_query::TestQueryUplinkQuery;
    use url::Url;
    use wiremock::Mock;
    use wiremock::MockServer;
    use wiremock::Request;
    use wiremock::Respond;
    use wiremock::ResponseTemplate;
    use wiremock::matchers::method;
    use wiremock::matchers::path;

    use crate::uplink::Endpoints;
    use crate::uplink::Error;
    use crate::uplink::UplinkConfig;
    use crate::uplink::UplinkRequest;
    use crate::uplink::UplinkResponse;
    use crate::uplink::stream_from_uplink;
    use crate::uplink::stream_from_uplink_transforming_new_response;

    /// Test-only Uplink query generated from the schema/query files under `testdata`.
    #[derive(GraphQLQuery)]
    #[graphql(
        query_path = "src/uplink/testdata/test_query.graphql",
        schema_path = "src/uplink/testdata/test_uplink.graphql",
        request_derives = "Debug",
        response_derives = "PartialEq, Debug, Deserialize",
        deprecated = "warn"
    )]
    pub(crate) struct TestQuery {}

    /// Payload carried by a `New` test response.
    #[derive(Debug, Eq, PartialEq)]
    struct QueryResult {
        name: String,
        ordering: i64,
    }

    /// Output of the transform used in the transforming-stream test.
    // Fields are only observed through the derived `Debug` output (snapshots), which the
    // dead-code lint does not count as a read.
    #[allow(dead_code)]
    #[derive(Debug)]
    struct TransformedQueryResult {
        name: String,
        halved_ordering: i64,
    }

    impl From<UplinkRequest> for test_query::Variables {
        fn from(req: UplinkRequest) -> Self {
            test_query::Variables {
                api_key: req.api_key,
                graph_ref: req.graph_ref,
                if_after_id: req.id,
            }
        }
    }

    impl From<test_query::ResponseData> for UplinkResponse<QueryResult> {
        fn from(response: test_query::ResponseData) -> Self {
            match response.uplink_query {
                TestQueryUplinkQuery::New(response) => UplinkResponse::New {
                    id: response.id,
                    delay: response.min_delay_seconds as u64,
                    response: QueryResult {
                        name: response.data.name,
                        ordering: response.data.ordering,
                    },
                },
                TestQueryUplinkQuery::Unchanged(response) => UplinkResponse::Unchanged {
                    id: Some(response.id),
                    delay: Some(response.min_delay_seconds as u64),
                },
                // Only RETRY_LATER is retryable; any other code (including unknown ones such
                // as "NO_RETRY") makes the polling task stop.
                TestQueryUplinkQuery::FetchError(error) => UplinkResponse::Error {
                    retry_later: error.code == FetchErrorCode::RETRY_LATER,
                    code: match error.code {
                        FetchErrorCode::AUTHENTICATION_FAILED => {
                            "AUTHENTICATION_FAILED".to_string()
                        }
                        FetchErrorCode::ACCESS_DENIED => "ACCESS_DENIED".to_string(),
                        FetchErrorCode::UNKNOWN_REF => "UNKNOWN_REF".to_string(),
                        FetchErrorCode::RETRY_LATER => "RETRY_LATER".to_string(),
                        FetchErrorCode::Other(other) => other,
                    },
                    message: error.message,
                },
            }
        }
    }

    /// Config with dummy credentials, no poll delay and a 1s timeout, using fallback order.
    fn mock_uplink_config_with_fallback_urls(urls: Vec<Url>) -> UplinkConfig {
        UplinkConfig {
            apollo_key: "dummy_key".to_string(),
            apollo_graph_ref: "dummy_graph_ref".to_string(),
            endpoints: Some(Endpoints::fallback(urls)),
            poll_interval: Duration::from_secs(0),
            timeout: Duration::from_secs(1),
        }
    }

    /// Config with dummy credentials, no poll delay and a 1s timeout, using round-robin order.
    fn mock_uplink_config_with_round_robin_urls(urls: Vec<Url>) -> UplinkConfig {
        UplinkConfig {
            apollo_key: "dummy_key".to_string(),
            apollo_graph_ref: "dummy_graph_ref".to_string(),
            endpoints: Some(Endpoints::round_robin(urls)),
            poll_interval: Duration::from_secs(0),
            timeout: Duration::from_secs(1),
        }
    }

    // Round-robin: each yielded URL advances the start position for the next call.
    #[test]
    fn test_round_robin_endpoints() {
        let url1 = Url::parse("http://example1.com").expect("url must be valid");
        let url2 = Url::parse("http://example2.com").expect("url must be valid");
        let mut endpoints = Endpoints::round_robin(vec![url1.clone(), url2.clone()]);
        assert_eq!(endpoints.iter().collect::<Vec<_>>(), vec![&url1, &url2]);
        assert_eq!(endpoints.iter().next(), Some(&url1));
        assert_eq!(endpoints.iter().collect::<Vec<_>>(), vec![&url2, &url1]);
    }

    // Fallback: every call yields the URLs in configuration order.
    #[test]
    fn test_fallback_endpoints() {
        let url1 = Url::parse("http://example1.com").expect("url must be valid");
        let url2 = Url::parse("http://example2.com").expect("url must be valid");
        let mut endpoints = Endpoints::fallback(vec![url1.clone(), url2.clone()]);
        assert_eq!(endpoints.iter().collect::<Vec<_>>(), vec![&url1, &url2]);
        assert_eq!(endpoints.iter().next(), Some(&url1));
        assert_eq!(endpoints.iter().collect::<Vec<_>>(), vec![&url1, &url2]);
    }

    // Fallback with a healthy first endpoint: both results should come from url1; url2 has no
    // queued responses.
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_fallback() {
        let (mock_server, url1, url2, _url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_ok(1))
            .response(response_ok(2))
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url2)
            .build()
            .await;

        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_fallback_urls(vec![url1, url2]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // Round-robin with two healthy endpoints: successive fetches should alternate url1, url2.
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_round_robin() {
        let (mock_server, url1, url2, _url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_ok(1))
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .response(response_ok(2))
            .endpoint(&url2)
            .build()
            .await;

        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_round_robin_urls(vec![url1, url2]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // A RETRY_LATER error is emitted into the stream and polling continues to a success.
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_error_retry() {
        let (mock_server, url1, url2, _url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_fetch_error_retry())
            .response(response_ok(1))
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_fallback_urls(vec![url1, url2]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // A non-retryable error is emitted and then the stream ends (no `take`: collect to end).
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_error_no_retry() {
        let (mock_server, url1, url2, _url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_fetch_error_no_retry())
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_fallback_urls(vec![url1, url2]),
        )
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // Fallback: url1 fails at the HTTP level on every fetch, so url2 should serve both results.
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_error_http_fallback() {
        let (mock_server, url1, url2, url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_fetch_error_http())
            .response(response_fetch_error_http())
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url2)
            .response(response_ok(1))
            .response(response_ok(2))
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url3)
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_fallback_urls(vec![url1, url2, url3]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // Fallback: url1 returns `"data": null` on every fetch, so url2 should serve both results.
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_empty_http_fallback() {
        let (mock_server, url1, url2, url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_empty())
            .response(response_empty())
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url2)
            .response(response_ok(1))
            .response(response_ok(2))
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url3)
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_fallback_urls(vec![url1, url2, url3]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // Round-robin: url1 fails once, url2 serves the first result, and the next fetch should
    // start at url3.
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_error_http_round_robin() {
        let (mock_server, url1, url2, url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_fetch_error_http())
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url2)
            .response(response_ok(1))
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url3)
            .response(response_ok(2))
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_round_robin_urls(vec![url1, url2, url3]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // Same as above, but url1 returns `"data": null` instead of an HTTP error.
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_empty_http_round_robin() {
        let (mock_server, url1, url2, url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_empty())
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url2)
            .response(response_ok(1))
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url3)
            .response(response_ok(2))
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_round_robin_urls(vec![url1, url2, url3]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // url1 returns a `New` payload that does not match the query's shape; url2/url3 have no
    // mocks mounted. Snapshot of the first stream item.
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_invalid() {
        let (mock_server, url1, url2, url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_invalid_license())
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_round_robin_urls(vec![url1, url2, url3]),
        )
        .take(1)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // An `Unchanged` response between two `New` ones produces no stream item.
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_unchanged() {
        let (mock_server, url1, url2, url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_ok(1))
            .response(response_unchanged())
            .response(response_ok(2))
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_round_robin_urls(vec![url1, url2, url3]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // Both configured endpoints fail, which should surface as `FetchFailedMultiple`.
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_failed_from_all() {
        let (mock_server, url1, url2, _url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_fetch_error_http())
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url2)
            .response(response_fetch_error_http())
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_round_robin_urls(vec![url1, url2]),
        )
        .take(1)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // The single configured endpoint fails, which should surface as `FetchFailedSingle`.
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_failed_from_single() {
        let (mock_server, url1, _url2, _url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_fetch_error_http())
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_fallback_urls(vec![url1]),
        )
        .take(1)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // The transform rejects url1's odd `ordering` (15), so fetch should move on to url2
    // (ordering 100), whose transform succeeds.
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_transforming_new_response_first_response_transform_fails() {
        let (mock_server, url1, url2, _url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_ok(15))
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url2)
            .response(response_ok(100))
            .build()
            .await;
        let results = stream_from_uplink_transforming_new_response::<
            TestQuery,
            QueryResult,
            TransformedQueryResult,
        >(
            mock_uplink_config_with_fallback_urls(vec![url1, url2]),
            |result| {
                Box::new(Box::pin(async move {
                    let QueryResult { name, ordering } = result;
                    if ordering % 2 == 0 {
                        Ok(TransformedQueryResult {
                            name,
                            halved_ordering: ordering / 2,
                        })
                    } else {
                        Err("cannot halve an odd number".into())
                    }
                }))
            },
        )
        .take(1)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    /// Renders a stream item as strings so it can be snapshotted.
    fn to_friendly<R: std::fmt::Debug>(r: Result<R, Error>) -> Result<String, String> {
        match r {
            Ok(e) => Ok(format!("result {e:?}")),
            Err(e) => Err(e.to_string()),
        }
    }

    /// Starts a mock server and returns it with three distinct endpoint URLs on it.
    async fn init_mock_server() -> (MockServer, Url, Url, Url) {
        let mock_server = MockServer::start().await;
        let url1 =
            Url::parse(&format!("{}/endpoint1", mock_server.uri())).expect("url must be valid");
        let url2 =
            Url::parse(&format!("{}/endpoint2", mock_server.uri())).expect("url must be valid");
        let url3 =
            Url::parse(&format!("{}/endpoint3", mock_server.uri())).expect("url must be valid");
        (mock_server, url1, url2, url3)
    }

    /// A per-endpoint queue of canned responses, served in order.
    struct MockResponses {
        responses: Mutex<VecDeque<ResponseTemplate>>,
    }

    impl Respond for MockResponses {
        fn respond(&self, _request: &Request) -> ResponseTemplate {
            // Once the queue is exhausted, answer with a non-retryable error so an unexpected
            // extra request shows up in the results and stops the polling task.
            self.responses
                .lock()
                .pop_front()
                .unwrap_or_else(response_fetch_error_test_error)
        }
    }

    #[buildstructor]
    impl MockResponses {
        /// Mounts a POST mock on `endpoint`'s path that serves `responses` in order.
        #[builder(entry = "builder")]
        async fn setup<'a>(
            mock_server: &'a MockServer,
            endpoint: &'a Url,
            responses: Vec<ResponseTemplate>,
        ) {
            let len = responses.len() as u64;
            Mock::given(method("POST"))
                .and(path(endpoint.path()))
                .respond_with(Self {
                    responses: Mutex::new(responses.into()),
                })
                // Expect `len` or `len + 1` calls. This tolerates one extra request that the
                // background polling task may send before it notices the stream was dropped.
                .expect(len..len + 2)
                .mount(mock_server)
                .await;
        }
    }

    /// A `New` response whose id and `ordering` are both `ordering`.
    fn response_ok(ordering: u64) -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
        {
            "data":{
                "uplinkQuery": {
                    "__typename": "New",
                    "id": ordering.to_string(),
                    "minDelaySeconds": 0,
                    "data": {
                        "name": "ok",
                        "ordering": ordering,
                    }
                }
            }
        }))
    }

    /// A `New` response missing the required `data` field (it has `garbage` instead).
    fn response_invalid_license() -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
        {
            "data":{
                "uplinkQuery": {
                    "__typename": "New",
                    "id": "3",
                    "minDelaySeconds": 0,
                    "garbage": "garbage"
                }
            }
        }))
    }

    fn response_unchanged() -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
        {
            "data":{
                "uplinkQuery": {
                    "__typename": "Unchanged",
                    "id": "2",
                    "minDelaySeconds": 0,
                }
            }
        }))
    }

    fn response_fetch_error_retry() -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
        {
            "data":{
                "uplinkQuery": {
                    "__typename": "FetchError",
                    "code": "RETRY_LATER",
                    "message": "error message",
                }
            }
        }))
    }

    fn response_fetch_error_no_retry() -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
        {
            "data":{
                "uplinkQuery": {
                    "__typename": "FetchError",
                    "code": "NO_RETRY",
                    "message": "error message",
                }
            }
        }))
    }

    /// Served when an endpoint receives more requests than responses were queued for it.
    fn response_fetch_error_test_error() -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
        {
            "data":{
                "uplinkQuery": {
                    "__typename": "FetchError",
                    "code": "NO_RETRY",
                    "message": "unexpected mock request, make sure you have set up appropriate responses",
                }
            }
        }))
    }

    /// A 500 with no body.
    fn response_fetch_error_http() -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::INTERNAL_SERVER_ERROR)
    }

    /// A 200 whose GraphQL `data` is null.
    fn response_empty() -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::OK).set_body_json(json!({ "data": null }))
    }
}