Skip to main content

apollo_router/uplink/
mod.rs

1use std::error::Error as stdError;
2use std::fmt::Debug;
3use std::time::Duration;
4use std::time::Instant;
5
6use futures::Future;
7use futures::Stream;
8use futures::StreamExt;
9use graphql_client::QueryBody;
10use thiserror::Error;
11use tokio::sync::mpsc::channel;
12use tokio_stream::wrappers::ReceiverStream;
13use tower::BoxError;
14use tracing::instrument::WithSubscriber;
15use url::Url;
16
17pub(crate) mod license_enforcement;
18pub(crate) mod license_stream;
19pub(crate) mod persisted_queries_manifest_stream;
20pub(crate) mod schema;
21pub(crate) mod schema_stream;
22
/// Primary default Uplink endpoint; `Endpoints::default()` lists it first.
const GCP_URL: &str = "https://uplink.api.apollographql.com";
/// Secondary default Uplink endpoint; with the default fallback strategy it is
/// only tried when the primary endpoint fails.
const AWS_URL: &str = "https://aws.uplink.api.apollographql.com";
25
/// Errors emitted on the stream returned by `stream_from_uplink` and
/// `stream_from_uplink_transforming_new_response`.
#[derive(Debug, Error)]
pub(crate) enum Error {
    /// Wraps a `reqwest::Error` (convertible with `?` thanks to `#[from]`).
    /// The message is fixed; the wrapped error is exposed as the error source.
    #[error("http error")]
    Http(#[from] reqwest::Error),

    /// The single configured endpoint was tried and did not yield a usable answer.
    #[error("fetch failed from uplink endpoint, and there are no fallback endpoints configured")]
    FetchFailedSingle,

    /// Every configured endpoint was tried and none yielded a usable answer.
    /// `url_count` is the number of configured endpoints (any count other than one).
    #[error("fetch failed from all {url_count} uplink endpoints")]
    FetchFailedMultiple { url_count: usize },

    /// Uplink reported a fetch error that may succeed later; polling continues.
    // The variant name ends with the enum name (`Error`), which clippy's
    // `enum_variant_names` lint would otherwise flag.
    #[allow(clippy::enum_variant_names)]
    #[error("uplink error: code={code} message={message}")]
    UplinkError { code: String, message: String },

    /// Uplink reported a fetch error that must not be retried; the polling task
    /// stops after emitting it.
    #[error("uplink error, the request will not be retried: code={code} message={message}")]
    UplinkErrorNoRetry { code: String, message: String },
}
44
/// Variables shared by every Uplink query. Each query's generated `Variables`
/// type is built from this via `From<UplinkRequest>`.
#[derive(Debug)]
pub(crate) struct UplinkRequest {
    /// Apollo API key sent with the query.
    api_key: String,
    /// Graph reference (`<YOUR_GRAPH_ID>@<VARIANT>`) being fetched.
    graph_ref: String,
    /// Id of the last payload received; `None` until a response has supplied one.
    id: Option<String>,
}
51
/// Query-independent view of one Uplink answer. Each query's `ResponseData` is
/// converted into this type via `Into`.
#[derive(Debug)]
pub(crate) enum UplinkResponse<Response>
where
    Response: Send + Debug + 'static,
{
    /// A new payload.
    New {
        /// The payload itself.
        response: Response,
        /// Identifier of this payload; sent back as the request `id` on later polls.
        id: String,
        /// Seconds to wait before the next poll.
        delay: u64,
    },
    /// Nothing new since the last poll.
    Unchanged {
        /// Replaces the stored id when present; otherwise the previous id is kept.
        id: Option<String>,
        /// Replaces the poll interval (in seconds) when present.
        delay: Option<u64>,
    },
    /// Uplink reported a fetch error.
    Error {
        /// `true` if polling should continue; `false` ends the stream after this error.
        retry_later: bool,
        code: String,
        message: String,
    },
}
72
/// The Uplink URLs to poll, together with the strategy for ordering them.
#[derive(Debug, Clone)]
pub enum Endpoints {
    /// Try the URLs in the same order on every poll, stopping at the first
    /// endpoint that gives a usable answer.
    Fallback {
        urls: Vec<Url>,
    },
    /// Rotate the starting URL between polls; `current` is the rotation cursor.
    // NOTE(review): `dead_code` is allowed presumably because only tests build
    // this variant today — confirm before removing the attribute.
    #[allow(dead_code)]
    RoundRobin {
        urls: Vec<Url>,
        current: usize,
    },
}
84
85impl Default for Endpoints {
86    fn default() -> Self {
87        Self::fallback(
88            [GCP_URL, AWS_URL]
89                .iter()
90                .map(|url| Url::parse(url).expect("default urls must be valid"))
91                .collect(),
92        )
93    }
94}
95
96impl Endpoints {
97    pub(crate) fn fallback(urls: Vec<Url>) -> Self {
98        Endpoints::Fallback { urls }
99    }
100    #[allow(dead_code)]
101    pub(crate) fn round_robin(urls: Vec<Url>) -> Self {
102        Endpoints::RoundRobin { urls, current: 0 }
103    }
104
105    /// Return an iterator of endpoints to check on a poll of uplink.
106    /// Fallback will always return URLs in the same order.
107    /// Round-robin will return an iterator that cycles over the URLS starting at the next URL
108    fn iter<'a>(&'a mut self) -> Box<dyn Iterator<Item = &'a Url> + Send + 'a> {
109        match self {
110            Endpoints::Fallback { urls } => Box::new(urls.iter()),
111            Endpoints::RoundRobin { urls, current } => {
112                // Prevent current from getting large.
113                *current %= urls.len();
114
115                // The iterator cycles, but will skip to the next untried URL and is finally limited by the number of URLs.
116                // This gives us a sliding window of URLs to try on each poll to uplink.
117                // The returned iterator will increment current each time it is called.
118                Box::new(
119                    urls.iter()
120                        .cycle()
121                        .skip(*current)
122                        .inspect(|_| {
123                            *current += 1;
124                        })
125                        .take(urls.len()),
126                )
127            }
128        }
129    }
130
131    pub(crate) fn url_count(&self) -> usize {
132        match self {
133            Endpoints::Fallback { urls } => urls.len(),
134            Endpoints::RoundRobin { urls, current: _ } => urls.len(),
135        }
136    }
137}
138
/// Configuration for polling Apollo Uplink.
/// This struct does not change on router reloads - its values are all sourced from CLI options.
///
/// NOTE(review): the derived `Default` leaves `endpoints` as `None` (the default
/// endpoint list is then used) and both durations at zero.
#[derive(Debug, Clone, Default)]
pub struct UplinkConfig {
    /// The Apollo key: `<YOUR_GRAPH_API_KEY>`
    pub apollo_key: String,

    /// The apollo graph reference: `<YOUR_GRAPH_ID>@<VARIANT>`
    pub apollo_graph_ref: String,

    /// The endpoints polled. `None` selects `Endpoints::default()`.
    pub endpoints: Option<Endpoints>,

    /// The initial duration between polls; replaced by the delay Uplink returns
    /// with each response that carries one.
    pub poll_interval: Duration,

    /// The HTTP client timeout for each poll
    pub timeout: Duration,
}
158
159impl UplinkConfig {
160    /// Mock uplink configuration options for use in tests
161    /// A nice pattern is to use wiremock to start an uplink mocker and pass the URL here.
162    pub fn for_tests(uplink_endpoints: Endpoints) -> Self {
163        Self {
164            apollo_key: "key".to_string(),
165            apollo_graph_ref: "graph".to_string(),
166            endpoints: Some(uplink_endpoints),
167            poll_interval: Duration::from_secs(2),
168            timeout: Duration::from_secs(5),
169        }
170    }
171}
172
173/// Regularly fetch from Uplink
174/// If urls are supplied then they will be called round robin
175pub(crate) fn stream_from_uplink<Query, Response>(
176    uplink_config: UplinkConfig,
177) -> impl Stream<Item = Result<Response, Error>>
178where
179    Query: graphql_client::GraphQLQuery,
180    <Query as graphql_client::GraphQLQuery>::ResponseData: Into<UplinkResponse<Response>> + Send,
181    <Query as graphql_client::GraphQLQuery>::Variables: From<UplinkRequest> + Send + Sync,
182    Response: Send + 'static + Debug,
183{
184    stream_from_uplink_transforming_new_response::<Query, Response, Response>(
185        uplink_config,
186        |response| Box::new(Box::pin(async { Ok(response) })),
187    )
188}
189
190/// Like stream_from_uplink, but applies an async transformation function to the
191/// result of the HTTP fetch if the response is an UplinkResponse::New. If this
192/// function returns Err, we fail over to the next Uplink endpoint, just like if
193/// the HTTP fetch itself failed. This serves the use case where an Uplink
194/// endpoint's response includes another URL located close to the Uplink
195/// endpoint; if that second URL is down, we want to try the next Uplink
196/// endpoint rather than fully giving up.
197pub(crate) fn stream_from_uplink_transforming_new_response<Query, Response, TransformedResponse>(
198    mut uplink_config: UplinkConfig,
199    transform_new_response: impl Fn(
200        Response,
201    ) -> Box<
202        dyn Future<Output = Result<TransformedResponse, BoxError>> + Send + Unpin,
203    > + Send
204    + Sync
205    + 'static,
206) -> impl Stream<Item = Result<TransformedResponse, Error>>
207where
208    Query: graphql_client::GraphQLQuery,
209    <Query as graphql_client::GraphQLQuery>::ResponseData: Into<UplinkResponse<Response>> + Send,
210    <Query as graphql_client::GraphQLQuery>::Variables: From<UplinkRequest> + Send + Sync,
211    Response: Send + 'static + Debug,
212    TransformedResponse: Send + 'static + Debug,
213{
214    let query_name = query_name::<Query>();
215    let (sender, receiver) = channel(2);
216    let client = match reqwest::Client::builder()
217        .no_gzip()
218        .timeout(uplink_config.timeout)
219        .build()
220    {
221        Ok(client) => client,
222        Err(err) => {
223            tracing::error!("unable to create client to query uplink: {err}", err = err);
224            return futures::stream::empty().boxed();
225        }
226    };
227
228    let task = async move {
229        let mut last_id = None;
230        let mut endpoints = uplink_config.endpoints.unwrap_or_default();
231        loop {
232            let variables = UplinkRequest {
233                graph_ref: uplink_config.apollo_graph_ref.to_string(),
234                api_key: uplink_config.apollo_key.to_string(),
235                id: last_id.clone(),
236            };
237
238            let query_body = Query::build_query(variables.into());
239
240            match fetch::<Query, Response, TransformedResponse>(
241                &client,
242                &query_body,
243                &mut endpoints,
244                &transform_new_response,
245            )
246            .await
247            {
248                Ok(response) => {
249                    u64_counter!(
250                        "apollo_router_uplink_fetch_count_total",
251                        "Total number of requests to Apollo Uplink",
252                        1u64,
253                        status = "success",
254                        query = query_name
255                    );
256                    match response {
257                        UplinkResponse::New {
258                            id,
259                            response,
260                            delay,
261                        } => {
262                            last_id = Some(id);
263                            uplink_config.poll_interval = Duration::from_secs(delay);
264
265                            if let Err(e) = sender.send(Ok(response)).await {
266                                tracing::debug!(
267                                    "failed to push to stream. This is likely to be because the router is shutting down: {e}"
268                                );
269                                break;
270                            }
271                        }
272                        UplinkResponse::Unchanged { id, delay } => {
273                            // Preserve behavior for schema uplink errors where id and delay are not reset if they are not provided on error.
274                            if let Some(id) = id {
275                                last_id = Some(id);
276                            }
277                            if let Some(delay) = delay {
278                                uplink_config.poll_interval = Duration::from_secs(delay);
279                            }
280                        }
281                        UplinkResponse::Error {
282                            retry_later,
283                            message,
284                            code,
285                        } => {
286                            let err = if retry_later {
287                                Err(Error::UplinkError { code, message })
288                            } else {
289                                Err(Error::UplinkErrorNoRetry { code, message })
290                            };
291                            if let Err(e) = sender.send(err).await {
292                                tracing::debug!(
293                                    "failed to send error to uplink stream. This is likely to be because the router is shutting down: {e}"
294                                );
295                                break;
296                            }
297                            if !retry_later {
298                                break;
299                            }
300                        }
301                    }
302                }
303                Err(err) => {
304                    u64_counter!(
305                        "apollo_router_uplink_fetch_count_total",
306                        "Total number of requests to Apollo Uplink",
307                        1u64,
308                        status = "failure",
309                        query = query_name
310                    );
311                    if let Err(e) = sender.send(Err(err)).await {
312                        tracing::debug!(
313                            "failed to send error to uplink stream. This is likely to be because the router is shutting down: {e}"
314                        );
315                        break;
316                    }
317                }
318            }
319
320            tokio::time::sleep(uplink_config.poll_interval).await;
321        }
322    };
323    drop(tokio::task::spawn(task.with_current_subscriber()));
324
325    ReceiverStream::new(receiver).boxed()
326}
327
328pub(crate) async fn fetch<Query, Response, TransformedResponse>(
329    client: &reqwest::Client,
330    request_body: &QueryBody<Query::Variables>,
331    endpoints: &mut Endpoints,
332    // See stream_from_uplink_transforming_new_response for an explanation of
333    // this argument.
334    transform_new_response: &(
335         impl Fn(
336        Response,
337    ) -> Box<dyn Future<Output = Result<TransformedResponse, BoxError>> + Send + Unpin>
338         + Send
339         + Sync
340         + 'static
341     ),
342) -> Result<UplinkResponse<TransformedResponse>, Error>
343where
344    Query: graphql_client::GraphQLQuery,
345    <Query as graphql_client::GraphQLQuery>::ResponseData: Into<UplinkResponse<Response>> + Send,
346    <Query as graphql_client::GraphQLQuery>::Variables: From<UplinkRequest> + Send + Sync,
347    Response: Send + Debug + 'static,
348    TransformedResponse: Send + Debug + 'static,
349{
350    let query = query_name::<Query>();
351    for url in endpoints.iter() {
352        let now = Instant::now();
353        match http_request::<Query>(client, url.as_str(), request_body).await {
354            Ok(response) => match response.data.map(Into::into) {
355                None => {
356                    f64_histogram!(
357                        "apollo_router_uplink_fetch_duration_seconds",
358                        "Duration of Apollo Uplink fetches.",
359                        now.elapsed().as_secs_f64(),
360                        query = query,
361                        url = url.to_string(),
362                        kind = "uplink_error",
363                        error = "empty response from uplink"
364                    );
365                }
366                Some(UplinkResponse::New {
367                    response,
368                    id,
369                    delay,
370                }) => {
371                    f64_histogram!(
372                        "apollo_router_uplink_fetch_duration_seconds",
373                        "Duration of Apollo Uplink fetches.",
374                        now.elapsed().as_secs_f64(),
375                        query = query,
376                        url = url.to_string(),
377                        kind = "new"
378                    );
379                    match transform_new_response(response).await {
380                        Ok(res) => {
381                            return Ok(UplinkResponse::New {
382                                response: res,
383                                id,
384                                delay,
385                            });
386                        }
387                        Err(err) => {
388                            tracing::debug!(
389                                "failed to process results of Uplink response from {url}: {err}. Other endpoints will be tried"
390                            );
391                            continue;
392                        }
393                    }
394                }
395                Some(UplinkResponse::Unchanged { id, delay }) => {
396                    f64_histogram!(
397                        "apollo_router_uplink_fetch_duration_seconds",
398                        "Duration of Apollo Uplink fetches.",
399                        now.elapsed().as_secs_f64(),
400                        query = query,
401                        url = url.to_string(),
402                        kind = "unchanged"
403                    );
404                    return Ok(UplinkResponse::Unchanged { id, delay });
405                }
406                Some(UplinkResponse::Error {
407                    message,
408                    code,
409                    retry_later,
410                }) => {
411                    f64_histogram!(
412                        "apollo_router_uplink_fetch_duration_seconds",
413                        "Duration of Apollo Uplink fetches.",
414                        now.elapsed().as_secs_f64(),
415                        query = query,
416                        url = url.to_string(),
417                        kind = "uplink_error",
418                        error = message.clone(),
419                        code = code.clone()
420                    );
421                    return Ok(UplinkResponse::Error {
422                        message,
423                        code,
424                        retry_later,
425                    });
426                }
427            },
428            Err(err) => {
429                f64_histogram!(
430                    "apollo_router_uplink_fetch_duration_seconds",
431                    "Duration of Apollo Uplink fetches.",
432                    now.elapsed().as_secs_f64(),
433                    query = query,
434                    url = url.to_string(),
435                    kind = "http_error",
436                    error = err.to_string(),
437                    code = err.status().unwrap_or_default().to_string()
438                );
439                tracing::debug!(
440                    "failed to fetch from Uplink endpoint {url}: {err}. Other endpoints will be tried"
441                );
442            }
443        };
444    }
445
446    let url_count = endpoints.url_count();
447    if url_count == 1 {
448        Err(Error::FetchFailedSingle)
449    } else {
450        Err(Error::FetchFailedMultiple { url_count })
451    }
452}
453
/// Short metric/label name for an Uplink query type: the last path segment of
/// the type name with its `Query` suffix removed
/// (e.g. `foo::SupergraphSdlQuery` -> `SupergraphSdl`).
///
/// # Panics
/// If the type name does not end in `Query`.
fn query_name<Query>() -> &'static str {
    let full_name = std::any::type_name::<Query>();
    let without_suffix = full_name
        .strip_suffix("Query")
        .expect("Uplink structs must be named xxxQuery");
    // The last `::` sits before the suffix, so its offset is valid in both strings.
    let last_segment_start = full_name.rfind("::").map_or(0, |pos| pos + 2);
    without_suffix
        .get(last_segment_start..)
        .expect("cannot fail")
}
463
464async fn http_request<Query>(
465    client: &reqwest::Client,
466    url: &str,
467    request_body: &QueryBody<Query::Variables>,
468) -> Result<graphql_client::Response<Query::ResponseData>, reqwest::Error>
469where
470    Query: graphql_client::GraphQLQuery,
471{
472    // It is possible that istio-proxy is re-configuring networking beneath us. If it is, we'll see an error something like this:
473    // level: "ERROR"
474    // message: "fetch failed from all endpoints"
475    // target: "apollo_router::router::event::schema"
476    // timestamp: "2023-08-01T10:40:28.831196Z"
477    // That's deeply confusing and very hard to debug. Let's try to help by printing out a helpful error message here
478    let res = client
479        .post(url)
480        .header("x-router-version", env!("CARGO_PKG_VERSION"))
481        .json(request_body)
482        .send()
483        .await
484        .inspect_err(|e| {
485            if let Some(hyper_err) = e.source() {
486                if let Some(os_err) = hyper_err.source() {
487                    if os_err.to_string().contains("tcp connect error: Cannot assign requested address (os error 99)") {
488                        tracing::warn!("If your router is executing within a kubernetes pod, this failure may be caused by istio-proxy injection. See https://github.com/apollographql/router/issues/3533 for more details about how to solve this");
489                    }
490                }
491            }
492        })?;
493    tracing::debug!("uplink response {:?}", res);
494    let response_body: graphql_client::Response<Query::ResponseData> = res.json().await?;
495    Ok(response_body)
496}
497
498#[cfg(test)]
499mod test {
500    use std::collections::VecDeque;
501    use std::sync::Mutex;
502    use std::time::Duration;
503
504    use buildstructor::buildstructor;
505    use futures::StreamExt;
506    use graphql_client::GraphQLQuery;
507    use http::StatusCode;
508    use insta::assert_yaml_snapshot;
509    use serde_json::json;
510    use test_query::FetchErrorCode;
511    use test_query::TestQueryUplinkQuery;
512    use url::Url;
513    use wiremock::Mock;
514    use wiremock::MockServer;
515    use wiremock::Request;
516    use wiremock::Respond;
517    use wiremock::ResponseTemplate;
518    use wiremock::matchers::method;
519    use wiremock::matchers::path;
520
521    use crate::uplink::Endpoints;
522    use crate::uplink::Error;
523    use crate::uplink::UplinkConfig;
524    use crate::uplink::UplinkRequest;
525    use crate::uplink::UplinkResponse;
526    use crate::uplink::stream_from_uplink;
527    use crate::uplink::stream_from_uplink_transforming_new_response;
528
529    #[derive(GraphQLQuery)]
530    #[graphql(
531        query_path = "src/uplink/testdata/test_query.graphql",
532        schema_path = "src/uplink/testdata/test_uplink.graphql",
533        request_derives = "Debug",
534        response_derives = "PartialEq, Debug, Deserialize",
535        deprecated = "warn"
536    )]
537    pub(crate) struct TestQuery {}
538
539    #[derive(Debug, Eq, PartialEq)]
540    struct QueryResult {
541        name: String,
542        ordering: i64,
543    }
544
545    #[allow(dead_code)]
546    #[derive(Debug)]
547    struct TransformedQueryResult {
548        name: String,
549        halved_ordering: i64,
550    }
551
552    impl From<UplinkRequest> for test_query::Variables {
553        fn from(req: UplinkRequest) -> Self {
554            test_query::Variables {
555                api_key: req.api_key,
556                graph_ref: req.graph_ref,
557                if_after_id: req.id,
558            }
559        }
560    }
561
562    impl From<test_query::ResponseData> for UplinkResponse<QueryResult> {
563        fn from(response: test_query::ResponseData) -> Self {
564            match response.uplink_query {
565                TestQueryUplinkQuery::New(response) => UplinkResponse::New {
566                    id: response.id,
567                    delay: response.min_delay_seconds as u64,
568                    response: QueryResult {
569                        name: response.data.name,
570                        ordering: response.data.ordering,
571                    },
572                },
573                TestQueryUplinkQuery::Unchanged(response) => UplinkResponse::Unchanged {
574                    id: Some(response.id),
575                    delay: Some(response.min_delay_seconds as u64),
576                },
577                TestQueryUplinkQuery::FetchError(error) => UplinkResponse::Error {
578                    retry_later: error.code == FetchErrorCode::RETRY_LATER,
579                    code: match error.code {
580                        FetchErrorCode::AUTHENTICATION_FAILED => {
581                            "AUTHENTICATION_FAILED".to_string()
582                        }
583                        FetchErrorCode::ACCESS_DENIED => "ACCESS_DENIED".to_string(),
584                        FetchErrorCode::UNKNOWN_REF => "UNKNOWN_REF".to_string(),
585                        FetchErrorCode::RETRY_LATER => "RETRY_LATER".to_string(),
586                        FetchErrorCode::Other(other) => other,
587                    },
588                    message: error.message,
589                },
590            }
591        }
592    }
593
594    fn mock_uplink_config_with_fallback_urls(urls: Vec<Url>) -> UplinkConfig {
595        UplinkConfig {
596            apollo_key: "dummy_key".to_string(),
597            apollo_graph_ref: "dummy_graph_ref".to_string(),
598            endpoints: Some(Endpoints::fallback(urls)),
599            poll_interval: Duration::from_secs(0),
600            timeout: Duration::from_secs(1),
601        }
602    }
603
604    fn mock_uplink_config_with_round_robin_urls(urls: Vec<Url>) -> UplinkConfig {
605        UplinkConfig {
606            apollo_key: "dummy_key".to_string(),
607            apollo_graph_ref: "dummy_graph_ref".to_string(),
608            endpoints: Some(Endpoints::round_robin(urls)),
609            poll_interval: Duration::from_secs(0),
610            timeout: Duration::from_secs(1),
611        }
612    }
613
614    #[test]
615    fn test_round_robin_endpoints() {
616        let url1 = Url::parse("http://example1.com").expect("url must be valid");
617        let url2 = Url::parse("http://example2.com").expect("url must be valid");
618        let mut endpoints = Endpoints::round_robin(vec![url1.clone(), url2.clone()]);
619        assert_eq!(endpoints.iter().collect::<Vec<_>>(), vec![&url1, &url2]);
620        assert_eq!(endpoints.iter().next(), Some(&url1));
621        assert_eq!(endpoints.iter().collect::<Vec<_>>(), vec![&url2, &url1]);
622    }
623
624    #[test]
625    fn test_fallback_endpoints() {
626        let url1 = Url::parse("http://example1.com").expect("url must be valid");
627        let url2 = Url::parse("http://example2.com").expect("url must be valid");
628        let mut endpoints = Endpoints::fallback(vec![url1.clone(), url2.clone()]);
629        assert_eq!(endpoints.iter().collect::<Vec<_>>(), vec![&url1, &url2]);
630        assert_eq!(endpoints.iter().next(), Some(&url1));
631        assert_eq!(endpoints.iter().collect::<Vec<_>>(), vec![&url1, &url2]);
632    }
633
634    #[tokio::test(flavor = "multi_thread")]
635    async fn stream_from_uplink_fallback() {
636        let (mock_server, url1, url2, _url3) = init_mock_server().await;
637        MockResponses::builder()
638            .mock_server(&mock_server)
639            .endpoint(&url1)
640            .response(response_ok(1))
641            .response(response_ok(2))
642            .build()
643            .await;
644        MockResponses::builder()
645            .mock_server(&mock_server)
646            .endpoint(&url2)
647            .build()
648            .await;
649
650        let results = stream_from_uplink::<TestQuery, QueryResult>(
651            mock_uplink_config_with_fallback_urls(vec![url1, url2]),
652        )
653        .take(2)
654        .collect::<Vec<_>>()
655        .await;
656        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
657    }
658
659    #[tokio::test(flavor = "multi_thread")]
660    async fn stream_from_uplink_round_robin() {
661        let (mock_server, url1, url2, _url3) = init_mock_server().await;
662        MockResponses::builder()
663            .mock_server(&mock_server)
664            .endpoint(&url1)
665            .response(response_ok(1))
666            .build()
667            .await;
668        MockResponses::builder()
669            .mock_server(&mock_server)
670            .response(response_ok(2))
671            .endpoint(&url2)
672            .build()
673            .await;
674
675        let results = stream_from_uplink::<TestQuery, QueryResult>(
676            mock_uplink_config_with_round_robin_urls(vec![url1, url2]),
677        )
678        .take(2)
679        .collect::<Vec<_>>()
680        .await;
681        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
682    }
683
684    #[tokio::test(flavor = "multi_thread")]
685    async fn stream_from_uplink_error_retry() {
686        let (mock_server, url1, url2, _url3) = init_mock_server().await;
687        MockResponses::builder()
688            .mock_server(&mock_server)
689            .endpoint(&url1)
690            .response(response_fetch_error_retry())
691            .response(response_ok(1))
692            .build()
693            .await;
694        let results = stream_from_uplink::<TestQuery, QueryResult>(
695            mock_uplink_config_with_fallback_urls(vec![url1, url2]),
696        )
697        .take(2)
698        .collect::<Vec<_>>()
699        .await;
700        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
701    }
702
703    #[tokio::test(flavor = "multi_thread")]
704    async fn stream_from_uplink_error_no_retry() {
705        let (mock_server, url1, url2, _url3) = init_mock_server().await;
706        MockResponses::builder()
707            .mock_server(&mock_server)
708            .endpoint(&url1)
709            .response(response_fetch_error_no_retry())
710            .build()
711            .await;
712        let results = stream_from_uplink::<TestQuery, QueryResult>(
713            mock_uplink_config_with_fallback_urls(vec![url1, url2]),
714        )
715        .collect::<Vec<_>>()
716        .await;
717        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
718    }
719
720    #[tokio::test(flavor = "multi_thread")]
721    async fn stream_from_uplink_error_http_fallback() {
722        let (mock_server, url1, url2, url3) = init_mock_server().await;
723        MockResponses::builder()
724            .mock_server(&mock_server)
725            .endpoint(&url1)
726            .response(response_fetch_error_http())
727            .response(response_fetch_error_http())
728            .build()
729            .await;
730        MockResponses::builder()
731            .mock_server(&mock_server)
732            .endpoint(&url2)
733            .response(response_ok(1))
734            .response(response_ok(2))
735            .build()
736            .await;
737        MockResponses::builder()
738            .mock_server(&mock_server)
739            .endpoint(&url3)
740            .build()
741            .await;
742        let results = stream_from_uplink::<TestQuery, QueryResult>(
743            mock_uplink_config_with_fallback_urls(vec![url1, url2, url3]),
744        )
745        .take(2)
746        .collect::<Vec<_>>()
747        .await;
748        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
749    }
750
751    #[tokio::test(flavor = "multi_thread")]
752    async fn stream_from_uplink_empty_http_fallback() {
753        let (mock_server, url1, url2, url3) = init_mock_server().await;
754        MockResponses::builder()
755            .mock_server(&mock_server)
756            .endpoint(&url1)
757            .response(response_empty())
758            .response(response_empty())
759            .build()
760            .await;
761        MockResponses::builder()
762            .mock_server(&mock_server)
763            .endpoint(&url2)
764            .response(response_ok(1))
765            .response(response_ok(2))
766            .build()
767            .await;
768        MockResponses::builder()
769            .mock_server(&mock_server)
770            .endpoint(&url3)
771            .build()
772            .await;
773        let results = stream_from_uplink::<TestQuery, QueryResult>(
774            mock_uplink_config_with_fallback_urls(vec![url1, url2, url3]),
775        )
776        .take(2)
777        .collect::<Vec<_>>()
778        .await;
779        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
780    }
781
782    #[tokio::test(flavor = "multi_thread")]
783    async fn stream_from_uplink_error_http_round_robin() {
784        let (mock_server, url1, url2, url3) = init_mock_server().await;
785        MockResponses::builder()
786            .mock_server(&mock_server)
787            .endpoint(&url1)
788            .response(response_fetch_error_http())
789            .build()
790            .await;
791        MockResponses::builder()
792            .mock_server(&mock_server)
793            .endpoint(&url2)
794            .response(response_ok(1))
795            .build()
796            .await;
797        MockResponses::builder()
798            .mock_server(&mock_server)
799            .endpoint(&url3)
800            .response(response_ok(2))
801            .build()
802            .await;
803        let results = stream_from_uplink::<TestQuery, QueryResult>(
804            mock_uplink_config_with_round_robin_urls(vec![url1, url2, url3]),
805        )
806        .take(2)
807        .collect::<Vec<_>>()
808        .await;
809        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
810    }
811
812    #[tokio::test(flavor = "multi_thread")]
813    async fn stream_from_uplink_empty_http_round_robin() {
814        let (mock_server, url1, url2, url3) = init_mock_server().await;
815        MockResponses::builder()
816            .mock_server(&mock_server)
817            .endpoint(&url1)
818            .response(response_empty())
819            .build()
820            .await;
821        MockResponses::builder()
822            .mock_server(&mock_server)
823            .endpoint(&url2)
824            .response(response_ok(1))
825            .build()
826            .await;
827        MockResponses::builder()
828            .mock_server(&mock_server)
829            .endpoint(&url3)
830            .response(response_ok(2))
831            .build()
832            .await;
833        let results = stream_from_uplink::<TestQuery, QueryResult>(
834            mock_uplink_config_with_round_robin_urls(vec![url1, url2, url3]),
835        )
836        .take(2)
837        .collect::<Vec<_>>()
838        .await;
839        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
840    }
841
842    #[tokio::test(flavor = "multi_thread")]
843    async fn stream_from_uplink_invalid() {
844        let (mock_server, url1, url2, url3) = init_mock_server().await;
845        MockResponses::builder()
846            .mock_server(&mock_server)
847            .endpoint(&url1)
848            .response(response_invalid_license())
849            .build()
850            .await;
851        let results = stream_from_uplink::<TestQuery, QueryResult>(
852            mock_uplink_config_with_round_robin_urls(vec![url1, url2, url3]),
853        )
854        .take(1)
855        .collect::<Vec<_>>()
856        .await;
857        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
858    }
859
860    #[tokio::test(flavor = "multi_thread")]
861    async fn stream_from_uplink_unchanged() {
862        let (mock_server, url1, url2, url3) = init_mock_server().await;
863        MockResponses::builder()
864            .mock_server(&mock_server)
865            .endpoint(&url1)
866            .response(response_ok(1))
867            .response(response_unchanged())
868            .response(response_ok(2))
869            .build()
870            .await;
871        let results = stream_from_uplink::<TestQuery, QueryResult>(
872            mock_uplink_config_with_round_robin_urls(vec![url1, url2, url3]),
873        )
874        .take(2)
875        .collect::<Vec<_>>()
876        .await;
877        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
878    }
879
880    #[tokio::test(flavor = "multi_thread")]
881    async fn stream_from_uplink_failed_from_all() {
882        let (mock_server, url1, url2, _url3) = init_mock_server().await;
883        MockResponses::builder()
884            .mock_server(&mock_server)
885            .endpoint(&url1)
886            .response(response_fetch_error_http())
887            .build()
888            .await;
889        MockResponses::builder()
890            .mock_server(&mock_server)
891            .endpoint(&url2)
892            .response(response_fetch_error_http())
893            .build()
894            .await;
895        let results = stream_from_uplink::<TestQuery, QueryResult>(
896            mock_uplink_config_with_round_robin_urls(vec![url1, url2]),
897        )
898        .take(1)
899        .collect::<Vec<_>>()
900        .await;
901        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
902    }
903
904    #[tokio::test(flavor = "multi_thread")]
905    async fn stream_from_uplink_failed_from_single() {
906        let (mock_server, url1, _url2, _url3) = init_mock_server().await;
907        MockResponses::builder()
908            .mock_server(&mock_server)
909            .endpoint(&url1)
910            .response(response_fetch_error_http())
911            .build()
912            .await;
913        let results = stream_from_uplink::<TestQuery, QueryResult>(
914            mock_uplink_config_with_fallback_urls(vec![url1]),
915        )
916        .take(1)
917        .collect::<Vec<_>>()
918        .await;
919        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
920    }
921
922    #[tokio::test(flavor = "multi_thread")]
923    async fn stream_from_uplink_transforming_new_response_first_response_transform_fails() {
924        let (mock_server, url1, url2, _url3) = init_mock_server().await;
925        MockResponses::builder()
926            .mock_server(&mock_server)
927            .endpoint(&url1)
928            .response(response_ok(15))
929            .build()
930            .await;
931        MockResponses::builder()
932            .mock_server(&mock_server)
933            .endpoint(&url2)
934            .response(response_ok(100))
935            .build()
936            .await;
937        let results = stream_from_uplink_transforming_new_response::<
938            TestQuery,
939            QueryResult,
940            TransformedQueryResult,
941        >(
942            mock_uplink_config_with_fallback_urls(vec![url1, url2]),
943            |result| {
944                Box::new(Box::pin(async move {
945                    let QueryResult { name, ordering } = result;
946                    if ordering % 2 == 0 {
947                        // This will trigger on url2's response.
948                        Ok(TransformedQueryResult {
949                            name,
950                            halved_ordering: ordering / 2,
951                        })
952                    } else {
953                        // This will trigger on url1's response.
954                        Err("cannot halve an odd number".into())
955                    }
956                }))
957            },
958        )
959        .take(1)
960        .collect::<Vec<_>>()
961        .await;
962        assert_yaml_snapshot!(results.into_iter().map(to_friendly).collect::<Vec<_>>());
963    }
964
965    fn to_friendly<R: std::fmt::Debug>(r: Result<R, Error>) -> Result<String, String> {
966        match r {
967            Ok(e) => Ok(format!("result {:?}", e)),
968            Err(e) => Err(e.to_string()),
969        }
970    }
971
972    async fn init_mock_server() -> (MockServer, Url, Url, Url) {
973        let mock_server = MockServer::start().await;
974        let url1 =
975            Url::parse(&format!("{}/endpoint1", mock_server.uri())).expect("url must be valid");
976        let url2 =
977            Url::parse(&format!("{}/endpoint2", mock_server.uri())).expect("url must be valid");
978        let url3 =
979            Url::parse(&format!("{}/endpoint3", mock_server.uri())).expect("url must be valid");
980        (mock_server, url1, url2, url3)
981    }
982
    /// wiremock responder that replays a queue of canned responses in FIFO order.
    ///
    /// The queue sits behind a `Mutex` because `Respond::respond` only receives `&self`.
    struct MockResponses {
        /// Responses not yet served; the front entry answers the next request.
        responses: Mutex<VecDeque<ResponseTemplate>>,
    }
986
987    impl Respond for MockResponses {
988        fn respond(&self, _request: &Request) -> ResponseTemplate {
989            self.responses
990                .lock()
991                .expect("lock poisoned")
992                .pop_front()
993                .unwrap_or_else(response_fetch_error_test_error)
994        }
995    }
996
997    #[buildstructor]
998    impl MockResponses {
999        #[builder(entry = "builder")]
1000        async fn setup<'a>(
1001            mock_server: &'a MockServer,
1002            endpoint: &'a Url,
1003            responses: Vec<ResponseTemplate>,
1004        ) {
1005            let len = responses.len() as u64;
1006            Mock::given(method("POST"))
1007                .and(path(endpoint.path()))
1008                .respond_with(Self {
1009                    responses: Mutex::new(responses.into()),
1010                })
1011                .expect(len..len + 2)
1012                .mount(mock_server)
1013                .await;
1014        }
1015    }
1016
1017    fn response_ok(ordering: u64) -> ResponseTemplate {
1018        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
1019        {
1020            "data":{
1021                "uplinkQuery": {
1022                "__typename": "New",
1023                "id": ordering.to_string(),
1024                "minDelaySeconds": 0,
1025                "data": {
1026                    "name": "ok",
1027                    "ordering": ordering,
1028                    }
1029                }
1030            }
1031        }))
1032    }
1033
1034    fn response_invalid_license() -> ResponseTemplate {
1035        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
1036        {
1037            "data":{
1038                "uplinkQuery": {
1039                    "__typename": "New",
1040                    "id": "3",
1041                    "minDelaySeconds": 0,
1042                    "garbage": "garbage"
1043                    }
1044                }
1045        }))
1046    }
1047
1048    fn response_unchanged() -> ResponseTemplate {
1049        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
1050        {
1051            "data":{
1052                "uplinkQuery": {
1053                    "__typename": "Unchanged",
1054                    "id": "2",
1055                    "minDelaySeconds": 0,
1056                }
1057            }
1058        }))
1059    }
1060
1061    fn response_fetch_error_retry() -> ResponseTemplate {
1062        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
1063        {
1064            "data":{
1065                "uplinkQuery": {
1066                    "__typename": "FetchError",
1067                    "code": "RETRY_LATER",
1068                    "message": "error message",
1069                }
1070            }
1071        }))
1072    }
1073
1074    fn response_fetch_error_no_retry() -> ResponseTemplate {
1075        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
1076        {
1077            "data":{
1078                "uplinkQuery": {
1079                    "__typename": "FetchError",
1080                    "code": "NO_RETRY",
1081                    "message": "error message",
1082                }
1083            }
1084        }))
1085    }
1086
1087    fn response_fetch_error_test_error() -> ResponseTemplate {
1088        ResponseTemplate::new(StatusCode::OK).set_body_json(json!(
1089        {
1090            "data":{
1091                "uplinkQuery": {
1092                    "__typename": "FetchError",
1093                    "code": "NO_RETRY",
1094                    "message": "unexpected mock request, make sure you have set up appropriate responses",
1095                }
1096            }
1097        }))
1098    }
1099
1100    fn response_fetch_error_http() -> ResponseTemplate {
1101        ResponseTemplate::new(StatusCode::INTERNAL_SERVER_ERROR)
1102    }
1103
1104    fn response_empty() -> ResponseTemplate {
1105        ResponseTemplate::new(StatusCode::OK).set_body_json(json!({ "data": null }))
1106    }
1107}