async_graphql/extensions/apollo_persisted_queries.rs

1//! Apollo persisted queries extension.
2
3use std::{num::NonZeroUsize, sync::Arc};
4
5use async_graphql_parser::types::ExecutableDocument;
6use futures_util::lock::Mutex;
7use serde::Deserialize;
8use sha2::{Digest, Sha256};
9
10use crate::{
11    Request, ServerError, ServerResult,
12    extensions::{Extension, ExtensionContext, ExtensionFactory, NextPrepareRequest},
13    from_value,
14};
15
16#[derive(Deserialize)]
17struct PersistedQuery {
18    version: i32,
19    #[serde(rename = "sha256Hash")]
20    sha256_hash: String,
21}
22
23/// Cache storage for persisted queries.
24#[async_trait::async_trait]
25pub trait CacheStorage: Send + Sync + Clone + 'static {
26    /// Load the query by `key`.
27    async fn get(&self, key: String) -> Option<ExecutableDocument>;
28
29    /// Save the query by `key`.
30    async fn set(&self, key: String, query: ExecutableDocument);
31}
32
33/// Memory-based LRU cache.
34#[derive(Clone)]
35pub struct LruCacheStorage(Arc<Mutex<lru::LruCache<String, ExecutableDocument>>>);
36
37impl LruCacheStorage {
38    /// Creates a new LRU Cache that holds at most `cap` items.
39    pub fn new(cap: usize) -> Self {
40        Self(Arc::new(Mutex::new(lru::LruCache::new(
41            NonZeroUsize::new(cap).unwrap(),
42        ))))
43    }
44}
45
46#[async_trait::async_trait]
47impl CacheStorage for LruCacheStorage {
48    async fn get(&self, key: String) -> Option<ExecutableDocument> {
49        let mut cache = self.0.lock().await;
50        cache.get(&key).cloned()
51    }
52
53    async fn set(&self, key: String, query: ExecutableDocument) {
54        let mut cache = self.0.lock().await;
55        cache.put(key, query);
56    }
57}
58
59/// Apollo persisted queries extension.
60///
61/// [Reference](https://www.apollographql.com/docs/react/api/link/persisted-queries/)
62#[cfg_attr(docsrs, doc(cfg(feature = "apollo_persisted_queries")))]
63pub struct ApolloPersistedQueries<T>(T);
64
65impl<T: CacheStorage> ApolloPersistedQueries<T> {
66    /// Creates an apollo persisted queries extension.
67    pub fn new(cache_storage: T) -> ApolloPersistedQueries<T> {
68        Self(cache_storage)
69    }
70}
71
72impl<T: CacheStorage> ExtensionFactory for ApolloPersistedQueries<T> {
73    fn create(&self) -> Arc<dyn Extension> {
74        Arc::new(ApolloPersistedQueriesExtension {
75            storage: self.0.clone(),
76        })
77    }
78}
79
/// Extension instance produced by [`ApolloPersistedQueries`].
struct ApolloPersistedQueriesExtension<T> {
    /// Cache of parsed documents, keyed by the query's SHA-256 hex digest.
    storage: T,
}
83
84#[async_trait::async_trait]
85impl<T: CacheStorage> Extension for ApolloPersistedQueriesExtension<T> {
86    async fn prepare_request(
87        &self,
88        ctx: &ExtensionContext<'_>,
89        mut request: Request,
90        next: NextPrepareRequest<'_>,
91    ) -> ServerResult<Request> {
92        let res = if let Some(value) = request.extensions.remove("persistedQuery") {
93            let persisted_query: PersistedQuery = from_value(value).map_err(|_| {
94                ServerError::new("Invalid \"PersistedQuery\" extension configuration.", None)
95            })?;
96            if persisted_query.version != 1 {
97                return Err(ServerError::new(
98                    format!(
99                        "Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"{}\".",
100                        persisted_query.version
101                    ),
102                    None,
103                ));
104            }
105
106            if request.query.is_empty() {
107                if let Some(doc) = self.storage.get(persisted_query.sha256_hash).await {
108                    Ok(Request {
109                        parsed_query: Some(doc),
110                        ..request
111                    })
112                } else {
113                    Err(ServerError::new("PersistedQueryNotFound", None))
114                }
115            } else {
116                let sha256_hash = format!("{:x}", Sha256::digest(request.query.as_bytes()));
117
118                if persisted_query.sha256_hash != sha256_hash {
119                    Err(ServerError::new("provided sha does not match query", None))
120                } else {
121                    let doc = async_graphql_parser::parse_query(&request.query)?;
122                    self.storage.set(sha256_hash, doc.clone()).await;
123                    Ok(Request {
124                        query: String::new(),
125                        parsed_query: Some(doc),
126                        ..request
127                    })
128                }
129            }
130        } else {
131            Ok(request)
132        };
133        next.run(ctx, res?).await
134    }
135}
136
#[cfg(test)]
mod tests {
    #[tokio::test]
    async fn test() {
        use super::*;
        use crate::*;

        /// Hash sent alongside `{ value }` in the registering request below.
        const VALUE_QUERY_HASH: &str =
            "854174ebed716fe24fd6659c30290aecd9bc1d17dc4f47939a1848a1b8ed3c6b";

        struct Query;

        #[Object(internal)]
        impl Query {
            async fn value(&self) -> i32 {
                100
            }
        }

        // Builds a request carrying a version-1 `persistedQuery` extension.
        fn persisted_request(query: &str, hash: &str) -> Request {
            let mut request = Request::new(query);
            request.extensions.insert(
                "persistedQuery".to_string(),
                value!({
                    "version": 1,
                    "sha256Hash": hash,
                }),
            );
            request
        }

        let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
            .extension(ApolloPersistedQueries::new(LruCacheStorage::new(256)))
            .finish();
        let expected = value!({
            "value": 100
        });

        // Text + matching hash: executed and registered in the cache.
        let response = schema
            .execute(persisted_request("{ value }", VALUE_QUERY_HASH))
            .await;
        assert_eq!(response.into_result().unwrap().data, expected);

        // Hash only: served from the cache.
        let response = schema.execute(persisted_request("", VALUE_QUERY_HASH)).await;
        assert_eq!(response.into_result().unwrap().data, expected);

        // Unknown hash with no query text.
        let response = schema.execute(persisted_request("", "def")).await;
        assert_eq!(
            response.into_result().unwrap_err(),
            vec![ServerError::new("PersistedQueryNotFound", None)]
        );
    }
}