1use std::error::Error as stdError;
2use std::fmt::Debug;
3use std::time::Duration;
4use std::time::Instant;
5
6use futures::Future;
7use futures::Stream;
8use futures::StreamExt;
9use graphql_client::QueryBody;
10use thiserror::Error;
11use tokio::sync::mpsc::channel;
12use tokio_stream::wrappers::ReceiverStream;
13use tower::BoxError;
14use tracing::instrument::WithSubscriber;
15use url::Url;
16
17pub(crate) mod license_enforcement;
18pub(crate) mod license_stream;
19pub(crate) mod persisted_queries_manifest_stream;
20pub(crate) mod schema;
21pub(crate) mod schema_stream;
22
/// Default Uplink endpoint tried first by `Endpoints::default()` (GCP-hosted, per its name).
const GCP_URL: &str = "https://uplink.api.apollographql.com";
/// Default Uplink endpoint tried second by `Endpoints::default()` (AWS-hosted, per its name).
const AWS_URL: &str = "https://aws.uplink.api.apollographql.com";
25
/// Errors surfaced on the Uplink polling stream.
#[derive(Debug, Error)]
pub(crate) enum Error {
    /// Wraps a `reqwest` error; convertible from it via `From`.
    #[error("http error")]
    Http(#[from] reqwest::Error),

    /// No usable response was obtained and exactly one endpoint is configured.
    #[error("fetch failed from uplink endpoint, and there are no fallback endpoints configured")]
    FetchFailedSingle,

    /// No usable response was obtained from any of the `url_count` configured endpoints.
    #[error("fetch failed from all {url_count} uplink endpoints")]
    FetchFailedMultiple { url_count: usize },

    /// Uplink returned an error flagged as retryable; polling continues afterwards.
    // The variant name ends with the enum's name, which this lint flags.
    #[allow(clippy::enum_variant_names)]
    #[error("uplink error: code={code} message={message}")]
    UplinkError { code: String, message: String },

    /// Uplink returned a non-retryable error; the polling task stops after emitting it.
    #[error("uplink error, the request will not be retried: code={code} message={message}")]
    UplinkErrorNoRetry { code: String, message: String },
}
44
/// Variables shared by every Uplink query, converted into each query's
/// generated `Variables` type through `From<UplinkRequest>`.
#[derive(Debug)]
pub(crate) struct UplinkRequest {
    /// Apollo API key, taken from `UplinkConfig::apollo_key`.
    api_key: String,
    /// Graph reference, taken from `UplinkConfig::apollo_graph_ref`.
    graph_ref: String,
    /// Id of the last payload received, if any. Presumably lets Uplink answer
    /// `Unchanged` when nothing newer exists.
    id: Option<String>,
}
51
/// Outcome of one Uplink query, produced from each query's generated
/// `ResponseData` via `Into<UplinkResponse<_>>`.
#[derive(Debug)]
pub(crate) enum UplinkResponse<Response>
where
    Response: Send + Debug + 'static,
{
    /// A new payload, its id, and the delay in seconds before the next poll.
    New {
        response: Response,
        id: String,
        delay: u64,
    },
    /// Nothing new. When present, `id` replaces the stored id and `delay`
    /// (seconds) replaces the poll interval.
    Unchanged {
        id: Option<String>,
        delay: Option<u64>,
    },
    /// Uplink reported an error. When `retry_later` is false the polling task
    /// stops after forwarding it.
    Error {
        retry_later: bool,
        code: String,
        message: String,
    },
}
72
/// Which Uplink URLs to query, and in what order.
#[derive(Debug, Clone)]
pub enum Endpoints {
    /// Every fetch tries `urls` in order, moving on only when one fails.
    Fallback {
        urls: Vec<Url>,
    },
    /// The starting URL rotates between fetches. `current` is the cursor the
    /// next fetch starts from; it is reduced modulo `urls.len()` when iterating.
    // NOTE(review): the allow suggests this variant is only constructed by
    // tests at present; confirm before relying on it elsewhere.
    #[allow(dead_code)]
    RoundRobin {
        urls: Vec<Url>,
        current: usize,
    },
}
84
85impl Default for Endpoints {
86 fn default() -> Self {
87 Self::fallback(
88 [GCP_URL, AWS_URL]
89 .iter()
90 .map(|url| Url::parse(url).expect("default urls must be valid"))
91 .collect(),
92 )
93 }
94}
95
96impl Endpoints {
97 pub(crate) fn fallback(urls: Vec<Url>) -> Self {
98 Endpoints::Fallback { urls }
99 }
100 #[allow(dead_code)]
101 pub(crate) fn round_robin(urls: Vec<Url>) -> Self {
102 Endpoints::RoundRobin { urls, current: 0 }
103 }
104
105 fn iter<'a>(&'a mut self) -> Box<dyn Iterator<Item = &'a Url> + Send + 'a> {
109 match self {
110 Endpoints::Fallback { urls } => Box::new(urls.iter()),
111 Endpoints::RoundRobin { urls, current } => {
112 *current %= urls.len();
114
115 Box::new(
119 urls.iter()
120 .cycle()
121 .skip(*current)
122 .inspect(|_| {
123 *current += 1;
124 })
125 .take(urls.len()),
126 )
127 }
128 }
129 }
130
131 pub(crate) fn url_count(&self) -> usize {
132 match self {
133 Endpoints::Fallback { urls } => urls.len(),
134 Endpoints::RoundRobin { urls, current: _ } => urls.len(),
135 }
136 }
137}
138
/// Settings for polling Apollo Uplink.
#[derive(Debug, Clone, Default)]
pub struct UplinkConfig {
    /// Apollo API key sent with every Uplink query.
    pub apollo_key: String,

    /// Graph reference sent with every Uplink query.
    pub apollo_graph_ref: String,

    /// URLs to query; `None` means `Endpoints::default()` (the Apollo-hosted
    /// URLs in fallback order).
    pub endpoints: Option<Endpoints>,

    /// Initial delay between polls. The polling task overwrites its copy with
    /// the delay Uplink returns.
    pub poll_interval: Duration,

    /// Timeout given to the HTTP client used for Uplink requests.
    pub timeout: Duration,
}
158
159impl UplinkConfig {
160 pub fn for_tests(uplink_endpoints: Endpoints) -> Self {
163 Self {
164 apollo_key: "key".to_string(),
165 apollo_graph_ref: "graph".to_string(),
166 endpoints: Some(uplink_endpoints),
167 poll_interval: Duration::from_secs(2),
168 timeout: Duration::from_secs(5),
169 }
170 }
171}
172
173pub(crate) fn stream_from_uplink<Query, Response>(
176 uplink_config: UplinkConfig,
177) -> impl Stream<Item = Result<Response, Error>>
178where
179 Query: graphql_client::GraphQLQuery,
180 <Query as graphql_client::GraphQLQuery>::ResponseData: Into<UplinkResponse<Response>> + Send,
181 <Query as graphql_client::GraphQLQuery>::Variables: From<UplinkRequest> + Send + Sync,
182 Response: Send + 'static + Debug,
183{
184 stream_from_uplink_transforming_new_response::<Query, Response, Response>(
185 uplink_config,
186 |response| Box::new(Box::pin(async { Ok(response) })),
187 )
188}
189
190pub(crate) fn stream_from_uplink_transforming_new_response<Query, Response, TransformedResponse>(
198 mut uplink_config: UplinkConfig,
199 transform_new_response: impl Fn(
200 Response,
201 ) -> Box<
202 dyn Future<Output = Result<TransformedResponse, BoxError>> + Send + Unpin,
203 > + Send
204 + Sync
205 + 'static,
206) -> impl Stream<Item = Result<TransformedResponse, Error>>
207where
208 Query: graphql_client::GraphQLQuery,
209 <Query as graphql_client::GraphQLQuery>::ResponseData: Into<UplinkResponse<Response>> + Send,
210 <Query as graphql_client::GraphQLQuery>::Variables: From<UplinkRequest> + Send + Sync,
211 Response: Send + 'static + Debug,
212 TransformedResponse: Send + 'static + Debug,
213{
214 let query_name = query_name::<Query>();
215 let (sender, receiver) = channel(2);
216 let client = match reqwest::Client::builder()
217 .no_gzip()
218 .timeout(uplink_config.timeout)
219 .build()
220 {
221 Ok(client) => client,
222 Err(err) => {
223 tracing::error!("unable to create client to query uplink: {err}", err = err);
224 return futures::stream::empty().boxed();
225 }
226 };
227
228 let task = async move {
229 let mut last_id = None;
230 let mut endpoints = uplink_config.endpoints.unwrap_or_default();
231 loop {
232 let variables = UplinkRequest {
233 graph_ref: uplink_config.apollo_graph_ref.to_string(),
234 api_key: uplink_config.apollo_key.to_string(),
235 id: last_id.clone(),
236 };
237
238 let query_body = Query::build_query(variables.into());
239
240 match fetch::<Query, Response, TransformedResponse>(
241 &client,
242 &query_body,
243 &mut endpoints,
244 &transform_new_response,
245 )
246 .await
247 {
248 Ok(response) => {
249 u64_counter!(
250 "apollo_router_uplink_fetch_count_total",
251 "Total number of requests to Apollo Uplink",
252 1u64,
253 status = "success",
254 query = query_name
255 );
256 match response {
257 UplinkResponse::New {
258 id,
259 response,
260 delay,
261 } => {
262 last_id = Some(id);
263 uplink_config.poll_interval = Duration::from_secs(delay);
264
265 if let Err(e) = sender.send(Ok(response)).await {
266 tracing::debug!(
267 "failed to push to stream. This is likely to be because the router is shutting down: {e}"
268 );
269 break;
270 }
271 }
272 UplinkResponse::Unchanged { id, delay } => {
273 if let Some(id) = id {
275 last_id = Some(id);
276 }
277 if let Some(delay) = delay {
278 uplink_config.poll_interval = Duration::from_secs(delay);
279 }
280 }
281 UplinkResponse::Error {
282 retry_later,
283 message,
284 code,
285 } => {
286 let err = if retry_later {
287 Err(Error::UplinkError { code, message })
288 } else {
289 Err(Error::UplinkErrorNoRetry { code, message })
290 };
291 if let Err(e) = sender.send(err).await {
292 tracing::debug!(
293 "failed to send error to uplink stream. This is likely to be because the router is shutting down: {e}"
294 );
295 break;
296 }
297 if !retry_later {
298 break;
299 }
300 }
301 }
302 }
303 Err(err) => {
304 u64_counter!(
305 "apollo_router_uplink_fetch_count_total",
306 "Total number of requests to Apollo Uplink",
307 1u64,
308 status = "failure",
309 query = query_name
310 );
311 if let Err(e) = sender.send(Err(err)).await {
312 tracing::debug!(
313 "failed to send error to uplink stream. This is likely to be because the router is shutting down: {e}"
314 );
315 break;
316 }
317 }
318 }
319
320 tokio::time::sleep(uplink_config.poll_interval).await;
321 }
322 };
323 drop(tokio::task::spawn(task.with_current_subscriber()));
324
325 ReceiverStream::new(receiver).boxed()
326}
327
328pub(crate) async fn fetch<Query, Response, TransformedResponse>(
329 client: &reqwest::Client,
330 request_body: &QueryBody<Query::Variables>,
331 endpoints: &mut Endpoints,
332 transform_new_response: &(
335 impl Fn(
336 Response,
337 ) -> Box<dyn Future<Output = Result<TransformedResponse, BoxError>> + Send + Unpin>
338 + Send
339 + Sync
340 + 'static
341 ),
342) -> Result<UplinkResponse<TransformedResponse>, Error>
343where
344 Query: graphql_client::GraphQLQuery,
345 <Query as graphql_client::GraphQLQuery>::ResponseData: Into<UplinkResponse<Response>> + Send,
346 <Query as graphql_client::GraphQLQuery>::Variables: From<UplinkRequest> + Send + Sync,
347 Response: Send + Debug + 'static,
348 TransformedResponse: Send + Debug + 'static,
349{
350 let query = query_name::<Query>();
351 for url in endpoints.iter() {
352 let now = Instant::now();
353 match http_request::<Query>(client, url.as_str(), request_body).await {
354 Ok(response) => match response.data.map(Into::into) {
355 None => {
356 f64_histogram!(
357 "apollo_router_uplink_fetch_duration_seconds",
358 "Duration of Apollo Uplink fetches.",
359 now.elapsed().as_secs_f64(),
360 query = query,
361 url = url.to_string(),
362 kind = "uplink_error",
363 error = "empty response from uplink"
364 );
365 }
366 Some(UplinkResponse::New {
367 response,
368 id,
369 delay,
370 }) => {
371 f64_histogram!(
372 "apollo_router_uplink_fetch_duration_seconds",
373 "Duration of Apollo Uplink fetches.",
374 now.elapsed().as_secs_f64(),
375 query = query,
376 url = url.to_string(),
377 kind = "new"
378 );
379 match transform_new_response(response).await {
380 Ok(res) => {
381 return Ok(UplinkResponse::New {
382 response: res,
383 id,
384 delay,
385 });
386 }
387 Err(err) => {
388 tracing::debug!(
389 "failed to process results of Uplink response from {url}: {err}. Other endpoints will be tried"
390 );
391 continue;
392 }
393 }
394 }
395 Some(UplinkResponse::Unchanged { id, delay }) => {
396 f64_histogram!(
397 "apollo_router_uplink_fetch_duration_seconds",
398 "Duration of Apollo Uplink fetches.",
399 now.elapsed().as_secs_f64(),
400 query = query,
401 url = url.to_string(),
402 kind = "unchanged"
403 );
404 return Ok(UplinkResponse::Unchanged { id, delay });
405 }
406 Some(UplinkResponse::Error {
407 message,
408 code,
409 retry_later,
410 }) => {
411 f64_histogram!(
412 "apollo_router_uplink_fetch_duration_seconds",
413 "Duration of Apollo Uplink fetches.",
414 now.elapsed().as_secs_f64(),
415 query = query,
416 url = url.to_string(),
417 kind = "uplink_error",
418 error = message.clone(),
419 code = code.clone()
420 );
421 return Ok(UplinkResponse::Error {
422 message,
423 code,
424 retry_later,
425 });
426 }
427 },
428 Err(err) => {
429 f64_histogram!(
430 "apollo_router_uplink_fetch_duration_seconds",
431 "Duration of Apollo Uplink fetches.",
432 now.elapsed().as_secs_f64(),
433 query = query,
434 url = url.to_string(),
435 kind = "http_error",
436 error = err.to_string(),
437 code = err.status().unwrap_or_default().to_string()
438 );
439 tracing::debug!(
440 "failed to fetch from Uplink endpoint {url}: {err}. Other endpoints will be tried"
441 );
442 }
443 };
444 }
445
446 let url_count = endpoints.url_count();
447 if url_count == 1 {
448 Err(Error::FetchFailedSingle)
449 } else {
450 Err(Error::FetchFailedMultiple { url_count })
451 }
452}
453
/// Short label for `Query`, used in metrics: the type name without its
/// module path and without the mandatory `Query` suffix
/// (e.g. `a::b::LicenseQuery` becomes `License`).
///
/// Relies on `std::any::type_name`, whose exact format std does not
/// guarantee.
///
/// # Panics
///
/// Panics if the type name does not end in `Query`.
fn query_name<Query>() -> &'static str {
    let full_name = std::any::type_name::<Query>();
    // Byte offset just past the last `::`, or 0 for an unqualified name.
    let name_start = match full_name.rfind("::") {
        Some(separator) => separator + 2,
        None => 0,
    };
    let without_suffix = full_name
        .strip_suffix("Query")
        .expect("Uplink structs must be named xxxQuery");
    without_suffix.get(name_start..).expect("cannot fail")
}
463
464async fn http_request<Query>(
465 client: &reqwest::Client,
466 url: &str,
467 request_body: &QueryBody<Query::Variables>,
468) -> Result<graphql_client::Response<Query::ResponseData>, reqwest::Error>
469where
470 Query: graphql_client::GraphQLQuery,
471{
472 let res = client
479 .post(url)
480 .header("x-router-version", env!("CARGO_PKG_VERSION"))
481 .json(request_body)
482 .send()
483 .await
484 .inspect_err(|e| {
485 if let Some(hyper_err) = e.source() {
486 if let Some(os_err) = hyper_err.source() {
487 if os_err.to_string().contains("tcp connect error: Cannot assign requested address (os error 99)") {
488 tracing::warn!("If your router is executing within a kubernetes pod, this failure may be caused by istio-proxy injection. See https://github.com/apollographql/router/issues/3533 for more details about how to solve this");
489 }
490 }
491 }
492 })?;
493 tracing::debug!("uplink response {:?}", res);
494 let response_body: graphql_client::Response<Query::ResponseData> = res.json().await?;
495 Ok(response_body)
496}
497
// Tests for Uplink endpoint iteration and the polling stream. Each stream test
// mounts wiremock endpoints that replay a scripted sequence of responses (see
// `MockResponses`) and snapshots the resulting stream items with `insta`.
#[cfg(test)]
mod test {
    use std::collections::VecDeque;
    use std::sync::Mutex;
    use std::time::Duration;

    use buildstructor::buildstructor;
    use futures::StreamExt;
    use graphql_client::GraphQLQuery;
    use http::StatusCode;
    use insta::assert_yaml_snapshot;
    use serde_json::json;
    use test_query::FetchErrorCode;
    use test_query::TestQueryUplinkQuery;
    use url::Url;
    use wiremock::Mock;
    use wiremock::MockServer;
    use wiremock::Request;
    use wiremock::Respond;
    use wiremock::ResponseTemplate;
    use wiremock::matchers::method;
    use wiremock::matchers::path;

    use crate::uplink::Endpoints;
    use crate::uplink::Error;
    use crate::uplink::UplinkConfig;
    use crate::uplink::UplinkRequest;
    use crate::uplink::UplinkResponse;
    use crate::uplink::stream_from_uplink;
    use crate::uplink::stream_from_uplink_transforming_new_response;

    // Test-only Uplink query generated from the files under `src/uplink/testdata`.
    // Its name ends in `Query`, as `query_name` requires.
    #[derive(GraphQLQuery)]
    #[graphql(
        query_path = "src/uplink/testdata/test_query.graphql",
        schema_path = "src/uplink/testdata/test_uplink.graphql",
        request_derives = "Debug",
        response_derives = "PartialEq, Debug, Deserialize",
        deprecated = "warn"
    )]
    pub(crate) struct TestQuery {}

    // Payload of a `New` response as seen by the stream.
    #[derive(Debug, Eq, PartialEq)]
    struct QueryResult {
        name: String,
        ordering: i64,
    }

    // Output of the transform used by the transforming-stream test.
    #[allow(dead_code)]
    #[derive(Debug)]
    struct TransformedQueryResult {
        name: String,
        halved_ordering: i64,
    }

    impl From<UplinkRequest> for test_query::Variables {
        fn from(req: UplinkRequest) -> Self {
            test_query::Variables {
                api_key: req.api_key,
                graph_ref: req.graph_ref,
                if_after_id: req.id,
            }
        }
    }

    impl From<test_query::ResponseData> for UplinkResponse<QueryResult> {
        fn from(response: test_query::ResponseData) -> Self {
            match response.uplink_query {
                TestQueryUplinkQuery::New(response) => UplinkResponse::New {
                    id: response.id,
                    delay: response.min_delay_seconds as u64,
                    response: QueryResult {
                        name: response.data.name,
                        ordering: response.data.ordering,
                    },
                },
                TestQueryUplinkQuery::Unchanged(response) => UplinkResponse::Unchanged {
                    id: Some(response.id),
                    delay: Some(response.min_delay_seconds as u64),
                },
                // Only `RETRY_LATER` is retryable; any other code stops the stream.
                TestQueryUplinkQuery::FetchError(error) => UplinkResponse::Error {
                    retry_later: error.code == FetchErrorCode::RETRY_LATER,
                    code: match error.code {
                        FetchErrorCode::AUTHENTICATION_FAILED => {
                            "AUTHENTICATION_FAILED".to_string()
                        }
                        FetchErrorCode::ACCESS_DENIED => "ACCESS_DENIED".to_string(),
                        FetchErrorCode::UNKNOWN_REF => "UNKNOWN_REF".to_string(),
                        FetchErrorCode::RETRY_LATER => "RETRY_LATER".to_string(),
                        FetchErrorCode::Other(other) => other,
                    },
                    message: error.message,
                },
            }
        }
    }

    // Fallback-strategy config with a zero poll interval so tests never wait
    // between polls, and a 1 second request timeout.
    fn mock_uplink_config_with_fallback_urls(urls: Vec<Url>) -> UplinkConfig {
        UplinkConfig {
            apollo_key: "dummy_key".to_string(),
            apollo_graph_ref: "dummy_graph_ref".to_string(),
            endpoints: Some(Endpoints::fallback(urls)),
            poll_interval: Duration::from_secs(0),
            timeout: Duration::from_secs(1),
        }
    }

    // Round-robin counterpart of `mock_uplink_config_with_fallback_urls`.
    fn mock_uplink_config_with_round_robin_urls(urls: Vec<Url>) -> UplinkConfig {
        UplinkConfig {
            apollo_key: "dummy_key".to_string(),
            apollo_graph_ref: "dummy_graph_ref".to_string(),
            endpoints: Some(Endpoints::round_robin(urls)),
            poll_interval: Duration::from_secs(0),
            timeout: Duration::from_secs(1),
        }
    }

    // Each URL handed out advances the cursor, so the next iteration starts
    // just after the last URL yielded.
    #[test]
    fn test_round_robin_endpoints() {
        let url1 = Url::parse("http://example1.com").expect("url must be valid");
        let url2 = Url::parse("http://example2.com").expect("url must be valid");
        let mut endpoints = Endpoints::round_robin(vec![url1.clone(), url2.clone()]);
        assert_eq!(endpoints.iter().collect::<Vec<_>>(), vec![&url1, &url2]);
        assert_eq!(endpoints.iter().next(), Some(&url1));
        assert_eq!(endpoints.iter().collect::<Vec<_>>(), vec![&url2, &url1]);
    }

    // Fallback order never changes, however much of the iterator was consumed.
    #[test]
    fn test_fallback_endpoints() {
        let url1 = Url::parse("http://example1.com").expect("url must be valid");
        let url2 = Url::parse("http://example2.com").expect("url must be valid");
        let mut endpoints = Endpoints::fallback(vec![url1.clone(), url2.clone()]);
        assert_eq!(endpoints.iter().collect::<Vec<_>>(), vec![&url1, &url2]);
        assert_eq!(endpoints.iter().next(), Some(&url1));
        assert_eq!(endpoints.iter().collect::<Vec<_>>(), vec![&url1, &url2]);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_fallback() {
        let (mock_server, url1, url2, _url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_ok(1))
            .response(response_ok(2))
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url2)
            .build()
            .await;

        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_fallback_urls(vec![url1, url2]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_round_robin() {
        let (mock_server, url1, url2, _url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_ok(1))
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .response(response_ok(2))
            .endpoint(&url2)
            .build()
            .await;

        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_round_robin_urls(vec![url1, url2]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_error_retry() {
        let (mock_server, url1, url2, _url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_fetch_error_retry())
            .response(response_ok(1))
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_fallback_urls(vec![url1, url2]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // No `take`: relies on the stream ending by itself after a non-retryable error.
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_error_no_retry() {
        let (mock_server, url1, url2, _url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_fetch_error_no_retry())
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_fallback_urls(vec![url1, url2]),
        )
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_error_http_fallback() {
        let (mock_server, url1, url2, url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_fetch_error_http())
            .response(response_fetch_error_http())
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url2)
            .response(response_ok(1))
            .response(response_ok(2))
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url3)
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_fallback_urls(vec![url1, url2, url3]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_empty_http_fallback() {
        let (mock_server, url1, url2, url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_empty())
            .response(response_empty())
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url2)
            .response(response_ok(1))
            .response(response_ok(2))
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url3)
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_fallback_urls(vec![url1, url2, url3]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_error_http_round_robin() {
        let (mock_server, url1, url2, url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_fetch_error_http())
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url2)
            .response(response_ok(1))
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url3)
            .response(response_ok(2))
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_round_robin_urls(vec![url1, url2, url3]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_empty_http_round_robin() {
        let (mock_server, url1, url2, url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_empty())
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url2)
            .response(response_ok(1))
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url3)
            .response(response_ok(2))
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_round_robin_urls(vec![url1, url2, url3]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_invalid() {
        let (mock_server, url1, url2, url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_invalid_license())
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_round_robin_urls(vec![url1, url2, url3]),
        )
        .take(1)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // `Unchanged` responses are not forwarded, so two items need three polls of url1.
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_unchanged() {
        let (mock_server, url1, url2, url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_ok(1))
            .response(response_unchanged())
            .response(response_ok(2))
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_round_robin_urls(vec![url1, url2, url3]),
        )
        .take(2)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_failed_from_all() {
        let (mock_server, url1, url2, _url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_fetch_error_http())
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url2)
            .response(response_fetch_error_http())
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_round_robin_urls(vec![url1, url2]),
        )
        .take(1)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_failed_from_single() {
        let (mock_server, url1, _url2, _url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_fetch_error_http())
            .build()
            .await;
        let results = stream_from_uplink::<TestQuery, QueryResult>(
            mock_uplink_config_with_fallback_urls(vec![url1]),
        )
        .take(1)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // The transform rejects odd orderings, so url1's payload (15) is refused
    // and the fetch falls through to url2 (100).
    #[tokio::test(flavor = "multi_thread")]
    async fn stream_from_uplink_transforming_new_response_first_response_transform_fails() {
        let (mock_server, url1, url2, _url3) = init_mock_server().await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url1)
            .response(response_ok(15))
            .build()
            .await;
        MockResponses::builder()
            .mock_server(&mock_server)
            .endpoint(&url2)
            .response(response_ok(100))
            .build()
            .await;
        let results = stream_from_uplink_transforming_new_response::<
            TestQuery,
            QueryResult,
            TransformedQueryResult,
        >(
            mock_uplink_config_with_fallback_urls(vec![url1, url2]),
            |result| {
                Box::new(Box::pin(async move {
                    let QueryResult { name, ordering } = result;
                    if ordering % 2 == 0 {
                        Ok(TransformedQueryResult {
                            name,
                            halved_ordering: ordering / 2,
                        })
                    } else {
                        Err("cannot halve an odd number".into())
                    }
                }))
            },
        )
        .take(1)
        .collect::<Vec<_>>()
        .await;
        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
    }

    // Renders stream items as strings so snapshots show error messages via `Display`.
    fn to_friendly<R: std::fmt::Debug>(r: Result<R, Error>) -> Result<String, String> {
        match r {
            Ok(e) => Ok(format!("result {:?}", e)),
            Err(e) => Err(e.to_string()),
        }
    }

    // Starts a mock server and returns three distinct endpoint URLs on it.
    async fn init_mock_server() -> (MockServer, Url, Url, Url) {
        let mock_server = MockServer::start().await;
        let url1 =
            Url::parse(&format!("{}/endpoint1", mock_server.uri())).expect("url must be valid");
        let url2 =
            Url::parse(&format!("{}/endpoint2", mock_server.uri())).expect("url must be valid");
        let url3 =
            Url::parse(&format!("{}/endpoint3", mock_server.uri())).expect("url must be valid");
        (mock_server, url1, url2, url3)
    }

    // Responder that replays a scripted queue of responses. Once the queue is
    // empty it answers with a non-retryable FetchError, so unexpected extra
    // requests show up in the snapshots.
    struct MockResponses {
        responses: Mutex<VecDeque<ResponseTemplate>>,
    }

    impl Respond for MockResponses {
        fn respond(&self, _request: &Request) -> ResponseTemplate {
            self.responses
                .lock()
                .expect("lock poisoned")
                .pop_front()
                .unwrap_or_else(response_fetch_error_test_error)
        }
    }

    #[buildstructor]
    impl MockResponses {
        // Mounts a POST mock on `endpoint`'s path that replays `responses`.
        // The half-open range `len..len + 2` allows one request beyond the
        // script, since the polling task may run ahead of the consumer.
        #[builder(entry = "builder")]
        async fn setup<'a>(
            mock_server: &'a MockServer,
            endpoint: &'a Url,
            responses: Vec<ResponseTemplate>,
        ) {
            let len = responses.len() as u64;
            Mock::given(method("POST"))
                .and(path(endpoint.path()))
                .respond_with(Self {
                    responses: Mutex::new(responses.into()),
                })
                .expect(len..len + 2)
                .mount(mock_server)
                .await;
        }
    }

    // `New` payload whose id and ordering are both `ordering`.
    fn response_ok(ordering: u64) -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
        {
            "data":{
                "uplinkQuery": {
                "__typename": "New",
                "id": ordering.to_string(),
                "minDelaySeconds": 0,
                "data": {
                    "name": "ok",
                    "ordering": ordering,
                    }
                }
            }
        }))
    }

    // `New` payload with a `garbage` field in place of `data`; presumably
    // fails to decode into the generated response type.
    fn response_invalid_license() -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
        {
            "data":{
                "uplinkQuery": {
                    "__typename": "New",
                    "id": "3",
                    "minDelaySeconds": 0,
                    "garbage": "garbage"
                    }
            }
        }))
    }

    fn response_unchanged() -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
        {
            "data":{
                "uplinkQuery": {
                    "__typename": "Unchanged",
                    "id": "2",
                    "minDelaySeconds": 0,
                }
            }
        }))
    }

    fn response_fetch_error_retry() -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
        {
            "data":{
                "uplinkQuery": {
                    "__typename": "FetchError",
                    "code": "RETRY_LATER",
                    "message": "error message",
                }
            }
        }))
    }

    // `NO_RETRY` is not one of the codes matched above; presumably it decodes
    // as `FetchErrorCode::Other`, which makes `retry_later` false.
    fn response_fetch_error_no_retry() -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
        {
            "data":{
                "uplinkQuery": {
                    "__typename": "FetchError",
                    "code": "NO_RETRY",
                    "message": "error message",
                }
            }
        }))
    }

    // Fallback answer once a mock's scripted responses are exhausted.
    fn response_fetch_error_test_error() -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
        {
            "data":{
                "uplinkQuery": {
                    "__typename": "FetchError",
                    "code": "NO_RETRY",
                    "message": "unexpected mock request, make sure you have set up appropriate responses",
                }
            }
        }))
    }

    // 500 with no body, so decoding the response fails.
    fn response_fetch_error_http() -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::INTERNAL_SERVER_ERROR)
    }

    // `data: null`, which `fetch` treats as an empty response.
    fn response_empty() -> ResponseTemplate {
        ResponseTemplate::new(StatusCode::OK).set_body_json(json!({ "data": null }))
    }
}