Skip to main content

apollo_router/services/layers/
apq.rs

//! (A)utomatic (P)ersisted (Q)ueries cache.
//!
//! For more information on APQ see:
//! <https://www.apollographql.com/docs/apollo-server/performance/apq/>

6use http::HeaderValue;
7use http::StatusCode;
8use http::header::CACHE_CONTROL;
9use serde::Deserialize;
10use serde_json_bytes::Value;
11use sha2::Digest;
12use sha2::Sha256;
13
14use crate::cache::DeduplicatingCache;
15use crate::services::SupergraphRequest;
16use crate::services::SupergraphResponse;
17
18const DONT_CACHE_RESPONSE_VALUE: &str = "private, no-cache, must-revalidate";
19static DONT_CACHE_HEADER_VALUE: HeaderValue = HeaderValue::from_static(DONT_CACHE_RESPONSE_VALUE);
20pub(crate) const PERSISTED_QUERY_CACHE_HIT: &str = "apollo::apq::cache_hit";
21pub(crate) const PERSISTED_QUERY_REGISTERED: &str = "apollo::apq::registered";
22
23/// A persisted query.
24#[derive(Deserialize, Clone, Debug)]
25pub(crate) struct PersistedQuery {
26    #[allow(unused)]
27    pub(crate) version: u8,
28    #[serde(rename = "sha256Hash")]
29    pub(crate) sha256hash: String,
30}
31
32impl PersistedQuery {
33    /// Attempt to extract a `PersistedQuery` from a `&SupergraphRequest`
34    pub(crate) fn maybe_from_request(request: &SupergraphRequest) -> Option<Self> {
35        request
36            .supergraph_request
37            .body()
38            .extensions
39            .get("persistedQuery")
40            .and_then(|value| serde_json_bytes::from_value(value.clone()).ok())
41    }
42
43    /// Attempt to decode the sha256 hash in a [`PersistedQuery`]
44    pub(crate) fn decode_hash(self) -> Option<(String, Vec<u8>)> {
45        hex::decode(self.sha256hash.as_bytes())
46            .ok()
47            .map(|decoded| (self.sha256hash, decoded))
48    }
49}
50
51/// A layer-like type implementing Automatic Persisted Queries.
52#[derive(Clone)]
53pub(crate) struct APQLayer {
54    /// set to None if APQ is disabled
55    cache: Option<DeduplicatingCache<String, String>>,
56}
57
58impl APQLayer {
59    pub(crate) fn activate(&self) {
60        if let Some(cache) = &self.cache {
61            cache.activate();
62        }
63    }
64}
65
66impl APQLayer {
67    pub(crate) fn with_cache(cache: DeduplicatingCache<String, String>) -> Self {
68        Self { cache: Some(cache) }
69    }
70
71    pub(crate) fn disabled() -> Self {
72        Self { cache: None }
73    }
74
75    /// Supergraph service implementation for Automatic Persisted Queries.
76    ///
77    /// For more information about APQ:
78    /// https://www.apollographql.com/docs/apollo-server/performance/apq.
79    ///
80    /// If APQ is disabled, it rejects requests that try to use a persisted query hash.
81    /// If APQ is enabled, requests using APQ will populate the cache and use the cache as needed,
82    /// see [`apq_request`] for details.
83    ///
84    /// This must happen before GraphQL query parsing.
85    ///
86    /// This functions similarly to a checkpoint service, short-circuiting the pipeline on error
87    /// (using an `Err()` return value).
88    /// The user of this function is responsible for propagating short-circuiting.
89    pub(crate) async fn supergraph_request(
90        &self,
91        request: SupergraphRequest,
92    ) -> Result<SupergraphRequest, SupergraphResponse> {
93        match self.cache.as_ref() {
94            Some(cache) => apq_request(cache, request).await,
95            None => disabled_apq_request(request).await,
96        }
97    }
98}
99
100/// Used when APQ is enabled.
101///
102/// If the request contains a hash and a query string, that query is added to the APQ cache.
103/// Then, the client can submit only the hash and not the query string on subsequent requests.
104/// The request is rejected if the hash does not match the query string.
105///
106/// If the request contains only a hash, attempts to read the query from the APQ cache, and
107/// populates the query string in the request body.
108/// The request is rejected if the hash is not present in the cache.
109async fn apq_request(
110    cache: &DeduplicatingCache<String, String>,
111    mut request: SupergraphRequest,
112) -> Result<SupergraphRequest, SupergraphResponse> {
113    let maybe_query_hash =
114        PersistedQuery::maybe_from_request(&request).and_then(PersistedQuery::decode_hash);
115
116    let body_query = request.supergraph_request.body().query.clone();
117
118    match (maybe_query_hash, body_query) {
119        (Some((query_hash, query_hash_bytes)), Some(query)) => {
120            if query_matches_hash(query.as_str(), query_hash_bytes.as_slice()) {
121                tracing::trace!("apq: cache insert");
122                let _ = request.context.insert(PERSISTED_QUERY_REGISTERED, true);
123                let query = query.to_owned();
124                let cache = cache.clone();
125                tokio::spawn(async move {
126                    cache.insert(redis_key(&query_hash), query).await;
127                });
128                Ok(request)
129            } else {
130                tracing::debug!("apq: graphql request doesn't match provided sha256Hash");
131                let errors = vec![
132                    crate::error::Error::builder()
133                        .message("provided sha does not match query".to_string())
134                        .locations(Default::default())
135                        .extension_code("PERSISTED_QUERY_HASH_MISMATCH")
136                        .build(),
137                ];
138                let res = SupergraphResponse::builder()
139                    .status_code(StatusCode::BAD_REQUEST)
140                    .data(Value::default())
141                    .errors(errors)
142                    .context(request.context)
143                    .build()
144                    .expect("response is valid");
145                Err(res)
146            }
147        }
148        (Some((apq_hash, _)), _) => {
149            if let Ok(cached_query) = cache
150                .get(&redis_key(&apq_hash), |_| Ok(()))
151                .await
152                .get()
153                .await
154            {
155                let _ = request.context.insert(PERSISTED_QUERY_CACHE_HIT, true);
156                tracing::trace!("apq: cache hit");
157                request.supergraph_request.body_mut().query = Some(cached_query);
158                Ok(request)
159            } else {
160                let _ = request.context.insert(PERSISTED_QUERY_CACHE_HIT, false);
161                tracing::trace!("apq: cache miss");
162                let errors = vec![
163                    crate::error::Error::builder()
164                        .message("PersistedQueryNotFound".to_string())
165                        .locations(Default::default())
166                        .extension_code("PERSISTED_QUERY_NOT_FOUND")
167                        .build(),
168                ];
169                let res = SupergraphResponse::builder()
170                    .data(Value::default())
171                    .errors(errors)
172                    // Persisted query errors (especially "not found") need to be uncached, because
173                    // hopefully we're about to fill in the APQ cache and the same request will
174                    // succeed next time.
175                    .header(CACHE_CONTROL, DONT_CACHE_HEADER_VALUE.clone())
176                    .context(request.context)
177                    .build()
178                    .expect("response is valid");
179
180                Err(res)
181            }
182        }
183        _ => Ok(request),
184    }
185}
186
187fn query_matches_hash(query: &str, hash: &[u8]) -> bool {
188    let mut digest = Sha256::new();
189    digest.update(query.as_bytes());
190    hash == digest.finalize().as_slice()
191}
192
/// Builds the cache key under which a persisted query is stored: `apq:<hash>`.
fn redis_key(query_hash: &str) -> String {
    ["apq:", query_hash].concat()
}
196
197pub(crate) fn calculate_hash_for_query(query: &str) -> String {
198    let mut hasher = Sha256::new();
199    hasher.update(query);
200    hex::encode(hasher.finalize())
201}
202
203/// Used when APQ is disabled. Rejects requests that try to use a persisted query hash anyways.
204async fn disabled_apq_request(
205    request: SupergraphRequest,
206) -> Result<SupergraphRequest, SupergraphResponse> {
207    if request
208        .supergraph_request
209        .body()
210        .extensions
211        .contains_key("persistedQuery")
212    {
213        let errors = vec![
214            crate::error::Error::builder()
215                .message("PersistedQueryNotSupported".to_string())
216                .locations(Default::default())
217                .extension_code("PERSISTED_QUERY_NOT_SUPPORTED")
218                .build(),
219        ];
220        let res = SupergraphResponse::builder()
221            .data(Value::default())
222            .errors(errors)
223            .context(request.context)
224            .build()
225            .expect("response is valid");
226
227        Err(res)
228    } else {
229        Ok(request)
230    }
231}
#[cfg(test)]
mod apq_tests {
    use std::sync::Arc;

    use futures::StreamExt;
    use http::StatusCode;
    use serde_json_bytes::json;
    use tower::Service;
    use tower::ServiceExt;

    use super::*;
    use crate::Configuration;
    use crate::Context;
    use crate::assert_error_eq_ignoring_id;
    use crate::error::Error;
    use crate::services::router::ClientRequestAccepts;
    use crate::services::router::service::from_supergraph_mock_callback;
    use crate::services::router::service::from_supergraph_mock_callback_and_configuration;

    /// SHA-256 of `{__typename}`, the query sent by these tests.
    const TYPENAME_HASH: &str = "ecf4edb46db40b5132295c0291d62fb65d6759a9eedfa4d5d612dd5ec54a6b38";
    /// `TYPENAME_HASH` with a different last digit: valid hex that does not match the query.
    const WRONG_HASH: &str = "ecf4edb46db40b5132295c0291d62fb65d6759a9eedfa4d5d612dd5ec54a6b36";

    /// Builds the GraphQL error the APQ layer is expected to return.
    fn expected_error(message: &str, code: &'static str) -> Error {
        Error::builder()
            .message(message.to_string())
            .locations(Default::default())
            .extension_code(code)
            .build()
    }

    /// Builds a `persistedQuery` extension value carrying `hash`.
    fn persisted_query_extension(hash: &str) -> Value {
        json!({
            "version" : 1,
            "sha256Hash" : hash
        })
    }

    #[tokio::test]
    async fn it_works() {
        let not_found = expected_error("PersistedQueryNotFound", "PERSISTED_QUERY_NOT_FOUND");

        let mut service = from_supergraph_mock_callback(move |req| {
            // Anything reaching the supergraph must carry a query that matches its hash.
            let body = req.supergraph_request.body();
            let extension = body.extensions.get("persistedQuery").unwrap();
            let persisted_query: PersistedQuery =
                serde_json_bytes::from_value(extension.clone()).unwrap();
            assert_eq!(persisted_query.sha256hash, TYPENAME_HASH);

            let query = body.query.clone();
            assert!(query.is_some());
            let digest = hex::decode(TYPENAME_HASH.as_bytes()).unwrap();
            assert!(query_matches_hash(
                query.unwrap().as_str(),
                digest.as_slice()
            ));

            Ok(SupergraphResponse::fake_builder()
                .context(req.context)
                .build()
                .expect("expecting valid request"))
        })
        .await;

        let extension = persisted_query_extension(TYPENAME_HASH);

        // Step 1: hash only, nothing cached yet.
        let hash_only_request = SupergraphRequest::fake_builder()
            .extension("persistedQuery", extension.clone())
            .context(new_context())
            .build()
            .expect("expecting valid request")
            .try_into()
            .unwrap();
        let miss_response = service
            .ready()
            .await
            .expect("readied")
            .call(hash_only_request)
            .await
            .unwrap();

        // Clients must not cache the "not found" answer.
        assert_eq!(
            DONT_CACHE_RESPONSE_VALUE,
            miss_response.response.headers().get(CACHE_CONTROL).unwrap()
        );

        let miss_body = miss_response
            .into_graphql_response_stream()
            .await
            .next()
            .await
            .unwrap()
            .unwrap();
        assert_error_eq_ignoring_id!(not_found, miss_body.errors[0]);

        // Step 2: hash plus matching query, which registers the query.
        let registering_request = SupergraphRequest::fake_builder()
            .extension("persistedQuery", extension.clone())
            .query("{__typename}".to_string())
            .context(new_context())
            .build()
            .expect("expecting valid request")
            .try_into()
            .unwrap();
        let registered_response = service
            .ready()
            .await
            .expect("readied")
            .call(registering_request)
            .await
            .unwrap();

        // The APQ layer must not add a cache-control header to a successful response.
        assert!(
            registered_response
                .response
                .headers()
                .get(CACHE_CONTROL)
                .is_none()
        );

        // Yield so the router can finish pending work before the next lookup: the spawned
        // APQ cache insert and the Drop implementation of the deduplicating cache Entry.
        tokio::task::yield_now().await;

        // Step 3: hash only again, now served from the cache.
        let cached_request = SupergraphRequest::fake_builder()
            .extension("persistedQuery", extension.clone())
            .context(new_context())
            .build()
            .expect("expecting valid request")
            .try_into()
            .unwrap();
        let hit_response = service
            .ready()
            .await
            .expect("readied")
            .call(cached_request)
            .await
            .unwrap();

        // The APQ layer must not add a cache-control header to a successful response.
        assert!(hit_response.response.headers().get(CACHE_CONTROL).is_none());
    }

    #[tokio::test]
    async fn it_doesnt_update_the_cache_if_the_hash_is_not_valid() {
        let not_found = expected_error("PersistedQueryNotFound", "PERSISTED_QUERY_NOT_FOUND");

        let mut service = from_supergraph_mock_callback(move |req| {
            let body = req.supergraph_request.body();
            let extension = body.extensions.get("persistedQuery").unwrap();
            let persisted_query: PersistedQuery =
                serde_json_bytes::from_value(extension.clone()).unwrap();
            assert_eq!(persisted_query.sha256hash, WRONG_HASH);
            assert!(body.query.is_some());

            Ok(SupergraphResponse::fake_builder()
                .context(req.context)
                .build()
                .expect("expecting valid request"))
        })
        .await;

        let extension = persisted_query_extension(WRONG_HASH);

        let first_hash_only = SupergraphRequest::fake_builder()
            .extension("persistedQuery", extension.clone())
            .context(new_context())
            .build()
            .expect("expecting valid request")
            .try_into()
            .unwrap();

        let second_hash_only = SupergraphRequest::fake_builder()
            .extension("persistedQuery", extension.clone())
            .context(new_context())
            .build()
            .expect("expecting valid request")
            .try_into()
            .unwrap();

        let mismatched_query = SupergraphRequest::fake_builder()
            .extension("persistedQuery", extension.clone())
            .query("{__typename}".to_string())
            .context(new_context())
            .build()
            .expect("expecting valid request")
            .try_into()
            .unwrap();

        // Nothing is cached yet, so this misses.
        let first_miss = service
            .ready()
            .await
            .expect("readied")
            .call(first_hash_only)
            .await
            .unwrap()
            .into_graphql_response_stream()
            .await
            .next()
            .await
            .unwrap()
            .unwrap();
        assert_error_eq_ignoring_id!(not_found, first_miss.errors[0]);

        // The hash does not match the query, so the insert is refused.
        let rejected_response = service
            .ready()
            .await
            .expect("readied")
            .call(mismatched_query)
            .await
            .unwrap();
        assert_eq!(StatusCode::BAD_REQUEST, rejected_response.response.status());

        let rejected_body = rejected_response
            .into_graphql_response_stream()
            .await
            .next()
            .await
            .unwrap()
            .unwrap();
        let mismatch = expected_error(
            "provided sha does not match query",
            "PERSISTED_QUERY_HASH_MISMATCH",
        );
        assert_error_eq_ignoring_id!(mismatch, rejected_body.errors[0]);

        // Because the insert was refused, this still misses.
        let second_miss = service
            .ready()
            .await
            .expect("readied")
            .call(second_hash_only)
            .await
            .unwrap()
            .into_graphql_response_stream()
            .await
            .next()
            .await
            .unwrap()
            .unwrap();
        assert_error_eq_ignoring_id!(not_found, second_miss.errors[0]);
    }

    #[tokio::test]
    async fn return_not_supported_when_disabled() {
        let not_supported = expected_error(
            "PersistedQueryNotSupported",
            "PERSISTED_QUERY_NOT_SUPPORTED",
        );

        let mut configuration = Configuration::default();
        configuration.apq.enabled = false;

        let mut service = from_supergraph_mock_callback_and_configuration(
            move |req| {
                Ok(SupergraphResponse::fake_builder()
                    .context(req.context)
                    .build()
                    .expect("expecting valid request"))
            },
            Arc::new(configuration),
        )
        .await;

        let extension = persisted_query_extension(TYPENAME_HASH);

        // A bare hash is rejected...
        let hash_only_request = SupergraphRequest::fake_builder()
            .extension("persistedQuery", extension.clone())
            .context(new_context())
            .build()
            .expect("expecting valid request")
            .try_into()
            .unwrap();
        let hash_only_body = service
            .ready()
            .await
            .expect("readied")
            .call(hash_only_request)
            .await
            .unwrap()
            .into_graphql_response_stream()
            .await
            .next()
            .await
            .unwrap()
            .unwrap();
        assert_error_eq_ignoring_id!(not_supported, hash_only_body.errors[0]);

        // ...and so is a hash accompanied by its query...
        let hash_and_query_request = SupergraphRequest::fake_builder()
            .extension("persistedQuery", extension.clone())
            .query("{__typename}".to_string())
            .context(new_context())
            .build()
            .expect("expecting valid request")
            .try_into()
            .unwrap();
        let hash_and_query_body = service
            .ready()
            .await
            .expect("readied")
            .call(hash_and_query_request)
            .await
            .unwrap()
            .into_graphql_response_stream()
            .await
            .next()
            .await
            .unwrap()
            .unwrap();
        assert_error_eq_ignoring_id!(not_supported, hash_and_query_body.errors[0]);

        // ...while a plain query without the extension goes through.
        let plain_request = SupergraphRequest::fake_builder()
            .query("{__typename}".to_string())
            .context(new_context())
            .build()
            .expect("expecting valid request")
            .try_into()
            .unwrap();
        let plain_body = service
            .ready()
            .await
            .expect("readied")
            .call(plain_request)
            .await
            .unwrap()
            .into_graphql_response_stream()
            .await
            .next()
            .await
            .unwrap()
            .unwrap();
        assert!(plain_body.errors.is_empty());
    }

    /// Builds a fresh context whose `ClientRequestAccepts` extension has `json` set.
    fn new_context() -> Context {
        let context = Context::new();
        context.extensions().with_lock(|lock| {
            lock.insert(ClientRequestAccepts {
                json: true,
                ..Default::default()
            })
        });
        context
    }
}