async_graphql/extensions/
apollo_persisted_queries.rs1use std::sync::Arc;
4
5use async_graphql_parser::types::ExecutableDocument;
6use serde::Deserialize;
7use sha2::{Digest, Sha256};
8
9use crate::{
10 Request, ServerError, ServerResult,
11 extensions::{Extension, ExtensionContext, ExtensionFactory, NextPrepareRequest},
12 from_value,
13};
14
15#[derive(Deserialize)]
16struct PersistedQuery {
17 version: i32,
18 #[serde(rename = "sha256Hash")]
19 sha256_hash: String,
20}
21
22#[async_trait::async_trait]
24pub trait CacheStorage: Send + Sync + Clone + 'static {
25 async fn get(&self, key: String) -> Option<ExecutableDocument>;
27
28 async fn set(&self, key: String, query: ExecutableDocument);
30}
31
32#[derive(Clone)]
34pub struct LruCacheStorage(Arc<scc::HashCache<String, ExecutableDocument>>);
35
36impl LruCacheStorage {
37 pub fn new(cap: usize) -> Self {
39 Self(Arc::new(scc::HashCache::with_capacity(0, cap)))
40 }
41}
42
43#[async_trait::async_trait]
44impl CacheStorage for LruCacheStorage {
45 async fn get(&self, key: String) -> Option<ExecutableDocument> {
46 self.0
47 .get_async(&key)
48 .await
49 .map(|entry| entry.get().clone())
50 }
51
52 async fn set(&self, key: String, query: ExecutableDocument) {
53 let _ = self.0.put_async(key, query).await;
54 }
55}
56
57#[cfg_attr(docsrs, doc(cfg(feature = "apollo_persisted_queries")))]
61pub struct ApolloPersistedQueries<T>(T);
62
63impl<T: CacheStorage> ApolloPersistedQueries<T> {
64 pub fn new(cache_storage: T) -> ApolloPersistedQueries<T> {
66 Self(cache_storage)
67 }
68}
69
70impl<T: CacheStorage> ExtensionFactory for ApolloPersistedQueries<T> {
71 fn create(&self) -> Arc<dyn Extension> {
72 Arc::new(ApolloPersistedQueriesExtension {
73 storage: self.0.clone(),
74 })
75 }
76}
77
/// Extension instance produced by the `ApolloPersistedQueries` factory.
struct ApolloPersistedQueriesExtension<T> {
    /// Backend used to look up and remember parsed documents by hash.
    storage: T,
}
81
82#[async_trait::async_trait]
83impl<T: CacheStorage> Extension for ApolloPersistedQueriesExtension<T> {
84 async fn prepare_request(
85 &self,
86 ctx: &ExtensionContext<'_>,
87 mut request: Request,
88 next: NextPrepareRequest<'_>,
89 ) -> ServerResult<Request> {
90 let res = if let Some(value) = request.extensions.remove("persistedQuery") {
91 let persisted_query: PersistedQuery = from_value(value).map_err(|_| {
92 ServerError::new("Invalid \"PersistedQuery\" extension configuration.", None)
93 })?;
94 if persisted_query.version != 1 {
95 return Err(ServerError::new(
96 format!(
97 "Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"{}\".",
98 persisted_query.version
99 ),
100 None,
101 ));
102 }
103
104 if request.query.is_empty() {
105 if let Some(doc) = self.storage.get(persisted_query.sha256_hash).await {
106 Ok(Request {
107 parsed_query: Some(doc),
108 ..request
109 })
110 } else {
111 Err(ServerError::new("PersistedQueryNotFound", None))
112 }
113 } else {
114 let sha256_hash = format!("{:x}", Sha256::digest(request.query.as_bytes()));
115
116 if persisted_query.sha256_hash != sha256_hash {
117 Err(ServerError::new("provided sha does not match query", None))
118 } else {
119 let doc = async_graphql_parser::parse_query(&request.query)?;
120 self.storage.set(sha256_hash, doc.clone()).await;
121 Ok(Request {
122 query: String::new(),
123 parsed_query: Some(doc),
124 ..request
125 })
126 }
127 }
128 } else {
129 Ok(request)
130 };
131 next.run(ctx, res?).await
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;

    /// Lower-case hex SHA-256 digest of the query text `{ value }`.
    const VALUE_QUERY_HASH: &str =
        "854174ebed716fe24fd6659c30290aecd9bc1d17dc4f47939a1848a1b8ed3c6b";

    struct Query;

    #[Object(internal)]
    impl Query {
        async fn value(&self) -> i32 {
            100
        }
    }

    /// Build a schema with the persisted-queries extension backed by a fresh cache.
    fn schema() -> Schema<Query, EmptyMutation, EmptySubscription> {
        Schema::build(Query, EmptyMutation, EmptySubscription)
            .extension(ApolloPersistedQueries::new(LruCacheStorage::new(256)))
            .finish()
    }

    /// Build a request whose `persistedQuery` extension carries `extension`.
    fn persisted_request(query: &str, extension: Value) -> Request {
        let mut request = Request::new(query);
        request
            .extensions
            .insert("persistedQuery".to_string(), extension);
        request
    }

    #[tokio::test]
    async fn test() {
        let schema = schema();

        // Registration: full text plus matching hash executes and caches the document.
        let request = persisted_request(
            "{ value }",
            value!({ "version": 1, "sha256Hash": VALUE_QUERY_HASH }),
        );
        assert_eq!(
            schema.execute(request).await.into_result().unwrap().data,
            value!({ "value": 100 })
        );

        // Hash-only request is served from the cache populated above.
        let request = persisted_request(
            "",
            value!({ "version": 1, "sha256Hash": VALUE_QUERY_HASH }),
        );
        assert_eq!(
            schema.execute(request).await.into_result().unwrap().data,
            value!({ "value": 100 })
        );

        // Hash-only request for an unknown hash.
        let request = persisted_request("", value!({ "version": 1, "sha256Hash": "def" }));
        assert_eq!(
            schema.execute(request).await.into_result().unwrap_err(),
            vec![ServerError::new("PersistedQueryNotFound", None)]
        );
    }

    #[tokio::test]
    async fn request_without_extension_is_untouched() {
        let request = Request::new("{ value }");
        assert_eq!(
            schema().execute(request).await.into_result().unwrap().data,
            value!({ "value": 100 })
        );
    }

    #[tokio::test]
    async fn rejects_hash_mismatch() {
        let request =
            persisted_request("{ value }", value!({ "version": 1, "sha256Hash": "def" }));
        assert_eq!(
            schema().execute(request).await.into_result().unwrap_err(),
            vec![ServerError::new("provided sha does not match query", None)]
        );
    }

    #[tokio::test]
    async fn rejects_unsupported_version() {
        let request = persisted_request(
            "{ value }",
            value!({ "version": 2, "sha256Hash": VALUE_QUERY_HASH }),
        );
        assert_eq!(
            schema().execute(request).await.into_result().unwrap_err(),
            vec![ServerError::new(
                "Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"2\".",
                None
            )]
        );
    }

    #[tokio::test]
    async fn rejects_malformed_extension() {
        // `sha256Hash` is missing, so the payload cannot be deserialized.
        let request = persisted_request("{ value }", value!({ "version": 1 }));
        assert_eq!(
            schema().execute(request).await.into_result().unwrap_err(),
            vec![ServerError::new(
                "Invalid \"PersistedQuery\" extension configuration.",
                None
            )]
        );
    }
}
201}