Skip to main content

async_graphql/extensions/
apollo_persisted_queries.rs

1//! Apollo persisted queries extension.
2
3use std::sync::Arc;
4
5use async_graphql_parser::types::ExecutableDocument;
6use serde::Deserialize;
7use sha2::{Digest, Sha256};
8
9use crate::{
10    Request, ServerError, ServerResult,
11    extensions::{Extension, ExtensionContext, ExtensionFactory, NextPrepareRequest},
12    from_value,
13};
14
15#[derive(Deserialize)]
16struct PersistedQuery {
17    version: i32,
18    #[serde(rename = "sha256Hash")]
19    sha256_hash: String,
20}
21
22/// Cache storage for persisted queries.
23#[async_trait::async_trait]
24pub trait CacheStorage: Send + Sync + Clone + 'static {
25    /// Load the query by `key`.
26    async fn get(&self, key: String) -> Option<ExecutableDocument>;
27
28    /// Save the query by `key`.
29    async fn set(&self, key: String, query: ExecutableDocument);
30}
31
32/// Memory-based LRU cache.
33#[derive(Clone)]
34pub struct LruCacheStorage(Arc<scc::HashCache<String, ExecutableDocument>>);
35
36impl LruCacheStorage {
37    /// Creates a new LRU Cache that holds at most `cap` items.
38    pub fn new(cap: usize) -> Self {
39        Self(Arc::new(scc::HashCache::with_capacity(0, cap)))
40    }
41}
42
43#[async_trait::async_trait]
44impl CacheStorage for LruCacheStorage {
45    async fn get(&self, key: String) -> Option<ExecutableDocument> {
46        self.0
47            .get_async(&key)
48            .await
49            .map(|entry| entry.get().clone())
50    }
51
52    async fn set(&self, key: String, query: ExecutableDocument) {
53        let _ = self.0.put_async(key, query).await;
54    }
55}
56
57/// Apollo persisted queries extension.
58///
59/// [Reference](https://www.apollographql.com/docs/react/api/link/persisted-queries/)
60#[cfg_attr(docsrs, doc(cfg(feature = "apollo_persisted_queries")))]
61pub struct ApolloPersistedQueries<T>(T);
62
63impl<T: CacheStorage> ApolloPersistedQueries<T> {
64    /// Creates an apollo persisted queries extension.
65    pub fn new(cache_storage: T) -> ApolloPersistedQueries<T> {
66        Self(cache_storage)
67    }
68}
69
70impl<T: CacheStorage> ExtensionFactory for ApolloPersistedQueries<T> {
71    fn create(&self) -> Arc<dyn Extension> {
72        Arc::new(ApolloPersistedQueriesExtension {
73            storage: self.0.clone(),
74        })
75    }
76}
77
78struct ApolloPersistedQueriesExtension<T> {
79    storage: T,
80}
81
82#[async_trait::async_trait]
83impl<T: CacheStorage> Extension for ApolloPersistedQueriesExtension<T> {
84    async fn prepare_request(
85        &self,
86        ctx: &ExtensionContext<'_>,
87        mut request: Request,
88        next: NextPrepareRequest<'_>,
89    ) -> ServerResult<Request> {
90        let res = if let Some(value) = request.extensions.remove("persistedQuery") {
91            let persisted_query: PersistedQuery = from_value(value).map_err(|_| {
92                ServerError::new("Invalid \"PersistedQuery\" extension configuration.", None)
93            })?;
94            if persisted_query.version != 1 {
95                return Err(ServerError::new(
96                    format!(
97                        "Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"{}\".",
98                        persisted_query.version
99                    ),
100                    None,
101                ));
102            }
103
104            if request.query.is_empty() {
105                if let Some(doc) = self.storage.get(persisted_query.sha256_hash).await {
106                    Ok(Request {
107                        parsed_query: Some(doc),
108                        ..request
109                    })
110                } else {
111                    Err(ServerError::new("PersistedQueryNotFound", None))
112                }
113            } else {
114                let sha256_hash = format!("{:x}", Sha256::digest(request.query.as_bytes()));
115
116                if persisted_query.sha256_hash != sha256_hash {
117                    Err(ServerError::new("provided sha does not match query", None))
118                } else {
119                    let doc = async_graphql_parser::parse_query(&request.query)?;
120                    self.storage.set(sha256_hash, doc.clone()).await;
121                    Ok(Request {
122                        query: String::new(),
123                        parsed_query: Some(doc),
124                        ..request
125                    })
126                }
127            }
128        } else {
129            Ok(request)
130        };
131        next.run(ctx, res?).await
132    }
133}
134
135#[cfg(test)]
136mod tests {
137    #[tokio::test]
138    async fn test() {
139        use super::*;
140        use crate::*;
141
142        struct Query;
143
144        #[Object(internal)]
145        impl Query {
146            async fn value(&self) -> i32 {
147                100
148            }
149        }
150
151        let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
152            .extension(ApolloPersistedQueries::new(LruCacheStorage::new(256)))
153            .finish();
154
155        let mut request = Request::new("{ value }");
156        request.extensions.insert(
157            "persistedQuery".to_string(),
158            value!({
159                "version": 1,
160                "sha256Hash": "854174ebed716fe24fd6659c30290aecd9bc1d17dc4f47939a1848a1b8ed3c6b",
161            }),
162        );
163
164        assert_eq!(
165            schema.execute(request).await.into_result().unwrap().data,
166            value!({
167                "value": 100
168            })
169        );
170
171        let mut request = Request::new("");
172        request.extensions.insert(
173            "persistedQuery".to_string(),
174            value!({
175                "version": 1,
176                "sha256Hash": "854174ebed716fe24fd6659c30290aecd9bc1d17dc4f47939a1848a1b8ed3c6b",
177            }),
178        );
179
180        assert_eq!(
181            schema.execute(request).await.into_result().unwrap().data,
182            value!({
183                "value": 100
184            })
185        );
186
187        let mut request = Request::new("");
188        request.extensions.insert(
189            "persistedQuery".to_string(),
190            value!({
191                "version": 1,
192                "sha256Hash": "def",
193            }),
194        );
195
196        assert_eq!(
197            schema.execute(request).await.into_result().unwrap_err(),
198            vec![ServerError::new("PersistedQueryNotFound", None)]
199        );
200    }
201}