// async_graphql/extensions/apollo_persisted_queries.rs
use std::{num::NonZeroUsize, sync::Arc};

use async_graphql_parser::types::ExecutableDocument;
use futures_util::lock::Mutex;
use serde::Deserialize;
use sha2::{Digest, Sha256};

use crate::{
    Request, ServerError, ServerResult,
    extensions::{Extension, ExtensionContext, ExtensionFactory, NextPrepareRequest},
    from_value,
};

16#[derive(Deserialize)]
17struct PersistedQuery {
18 version: i32,
19 #[serde(rename = "sha256Hash")]
20 sha256_hash: String,
21}
22
23#[async_trait::async_trait]
25pub trait CacheStorage: Send + Sync + Clone + 'static {
26 async fn get(&self, key: String) -> Option<ExecutableDocument>;
28
29 async fn set(&self, key: String, query: ExecutableDocument);
31}
32
33#[derive(Clone)]
35pub struct LruCacheStorage(Arc<Mutex<lru::LruCache<String, ExecutableDocument>>>);
36
37impl LruCacheStorage {
38 pub fn new(cap: usize) -> Self {
40 Self(Arc::new(Mutex::new(lru::LruCache::new(
41 NonZeroUsize::new(cap).unwrap(),
42 ))))
43 }
44}
45
46#[async_trait::async_trait]
47impl CacheStorage for LruCacheStorage {
48 async fn get(&self, key: String) -> Option<ExecutableDocument> {
49 let mut cache = self.0.lock().await;
50 cache.get(&key).cloned()
51 }
52
53 async fn set(&self, key: String, query: ExecutableDocument) {
54 let mut cache = self.0.lock().await;
55 cache.put(key, query);
56 }
57}
58
/// Apollo persisted-queries extension factory.
///
/// Install it with `Schema::build(..).extension(ApolloPersistedQueries::new(storage))`.
/// Requests that carry a `persistedQuery` extension are then resolved through
/// the wrapped storage.
#[cfg_attr(docsrs, doc(cfg(feature = "apollo_persisted_queries")))]
pub struct ApolloPersistedQueries<T>(T);

65impl<T: CacheStorage> ApolloPersistedQueries<T> {
66 pub fn new(cache_storage: T) -> ApolloPersistedQueries<T> {
68 Self(cache_storage)
69 }
70}
71
72impl<T: CacheStorage> ExtensionFactory for ApolloPersistedQueries<T> {
73 fn create(&self) -> Arc<dyn Extension> {
74 Arc::new(ApolloPersistedQueriesExtension {
75 storage: self.0.clone(),
76 })
77 }
78}
79
/// Extension instance created by the `ApolloPersistedQueries` factory.
struct ApolloPersistedQueriesExtension<T> {
    /// Storage for parsed documents, keyed by the hex SHA-256 of the query text.
    storage: T,
}

84#[async_trait::async_trait]
85impl<T: CacheStorage> Extension for ApolloPersistedQueriesExtension<T> {
86 async fn prepare_request(
87 &self,
88 ctx: &ExtensionContext<'_>,
89 mut request: Request,
90 next: NextPrepareRequest<'_>,
91 ) -> ServerResult<Request> {
92 let res = if let Some(value) = request.extensions.remove("persistedQuery") {
93 let persisted_query: PersistedQuery = from_value(value).map_err(|_| {
94 ServerError::new("Invalid \"PersistedQuery\" extension configuration.", None)
95 })?;
96 if persisted_query.version != 1 {
97 return Err(ServerError::new(
98 format!(
99 "Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"{}\".",
100 persisted_query.version
101 ),
102 None,
103 ));
104 }
105
106 if request.query.is_empty() {
107 if let Some(doc) = self.storage.get(persisted_query.sha256_hash).await {
108 Ok(Request {
109 parsed_query: Some(doc),
110 ..request
111 })
112 } else {
113 Err(ServerError::new("PersistedQueryNotFound", None))
114 }
115 } else {
116 let sha256_hash = format!("{:x}", Sha256::digest(request.query.as_bytes()));
117
118 if persisted_query.sha256_hash != sha256_hash {
119 Err(ServerError::new("provided sha does not match query", None))
120 } else {
121 let doc = async_graphql_parser::parse_query(&request.query)?;
122 self.storage.set(sha256_hash, doc.clone()).await;
123 Ok(Request {
124 query: String::new(),
125 parsed_query: Some(doc),
126 ..request
127 })
128 }
129 }
130 } else {
131 Ok(request)
132 };
133 next.run(ctx, res?).await
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    /// Lowercase hex SHA-256 of the query text `{ value }`.
    const VALUE_QUERY_HASH: &str =
        "854174ebed716fe24fd6659c30290aecd9bc1d17dc4f47939a1848a1b8ed3c6b";

    struct Query;

    #[Object(internal)]
    impl Query {
        async fn value(&self) -> i32 {
            100
        }
    }

    /// Builds a schema with the persisted-queries extension backed by an LRU cache.
    fn build_schema() -> Schema<Query, EmptyMutation, EmptySubscription> {
        Schema::build(Query, EmptyMutation, EmptySubscription)
            .extension(ApolloPersistedQueries::new(LruCacheStorage::new(256)))
            .finish()
    }

    /// Builds a request that carries a `persistedQuery` extension with the given fields.
    fn persisted_request(query: &str, version: i32, sha256_hash: &str) -> Request {
        let mut request = Request::new(query);
        request.extensions.insert(
            "persistedQuery".to_string(),
            value!({
                "version": version,
                "sha256Hash": sha256_hash,
            }),
        );
        request
    }

    #[tokio::test]
    async fn test() {
        let schema = build_schema();

        // Full query text plus its hash executes and populates the cache.
        assert_eq!(
            schema
                .execute(persisted_request("{ value }", 1, VALUE_QUERY_HASH))
                .await
                .into_result()
                .unwrap()
                .data,
            value!({
                "value": 100
            })
        );

        // A hash-only request is now served from the cache.
        assert_eq!(
            schema
                .execute(persisted_request("", 1, VALUE_QUERY_HASH))
                .await
                .into_result()
                .unwrap()
                .data,
            value!({
                "value": 100
            })
        );

        // An unknown hash is reported as not found.
        assert_eq!(
            schema
                .execute(persisted_request("", 1, "def"))
                .await
                .into_result()
                .unwrap_err(),
            vec![ServerError::new("PersistedQueryNotFound", None)]
        );
    }

    #[tokio::test]
    async fn rejects_invalid_persisted_queries() {
        let schema = build_schema();

        // The advertised hash must match the supplied query text.
        assert_eq!(
            schema
                .execute(persisted_request("{ value }", 1, "def"))
                .await
                .into_result()
                .unwrap_err(),
            vec![ServerError::new("provided sha does not match query", None)]
        );

        // Only protocol version 1 is accepted; the version is checked first.
        assert_eq!(
            schema
                .execute(persisted_request("", 2, VALUE_QUERY_HASH))
                .await
                .into_result()
                .unwrap_err(),
            vec![ServerError::new(
                "Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"2\".",
                None
            )]
        );

        // Requests without the extension pass through untouched.
        assert_eq!(
            schema.execute("{ value }").await.into_result().unwrap().data,
            value!({
                "value": 100
            })
        );
    }
}