1use std::{
12 collections::HashMap,
13 path::Path,
14 sync::{
15 Arc,
16 atomic::{AtomicU64, Ordering},
17 },
18};
19
/// Largest manifest, in bytes, that `TrustedDocumentStore::from_manifest_file`
/// will accept (10 MiB).
const MAX_MANIFEST_BYTES: u64 = 10 * 1024 * 1024;
use serde::Deserialize;
28use tokio::sync::RwLock;
29
/// Policy applied by [`TrustedDocumentStore::resolve`] to requests that carry
/// no document id.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum TrustedDocumentMode {
    /// Only document ids are accepted. A request without one is rejected with
    /// [`TrustedDocumentError::ForbiddenRawQuery`], even if it includes raw
    /// query text.
    Strict,
    /// A request without a document id falls back to its raw query text.
    Permissive,
}
39
/// On-disk JSON shape read by `TrustedDocumentStore::from_manifest_file`.
#[derive(Debug, Deserialize)]
struct Manifest {
    // Required when deserializing, but never read after parsing: only
    // `documents` is consumed. Derived `Debug` does not count as a use, hence
    // the allow.
    #[allow(dead_code)]
    version: u32,
    // Document id (optionally `sha256:`-prefixed) -> query text.
    documents: HashMap<String, String>,
}
49
/// In-memory table of trusted query documents plus the policy for requests
/// that send raw query text instead of a document id.
pub struct TrustedDocumentStore {
    // Query text keyed by document hash. Any `sha256:` prefix has already been
    // stripped, because every constructor and `replace_documents` pass keys
    // through `normalize_keys`. The async `RwLock` lets `replace_documents`
    // swap the whole map through `&self`.
    documents: Arc<RwLock<HashMap<String, String>>>,
    // Policy applied by `resolve` when a request has no document id.
    mode: TrustedDocumentMode,
}
56
57impl TrustedDocumentStore {
58 pub fn from_manifest_file(
64 path: &Path,
65 mode: TrustedDocumentMode,
66 ) -> Result<Self, TrustedDocumentError> {
67 let file_size = std::fs::metadata(path)
69 .map_err(|e| {
70 TrustedDocumentError::ManifestLoad(format!(
71 "Failed to stat manifest {}: {e}",
72 path.display()
73 ))
74 })?
75 .len();
76 if file_size > MAX_MANIFEST_BYTES {
77 return Err(TrustedDocumentError::ManifestLoad(format!(
78 "Manifest {} is too large ({file_size} bytes, max {MAX_MANIFEST_BYTES})",
79 path.display()
80 )));
81 }
82
83 let contents = std::fs::read_to_string(path).map_err(|e| {
84 TrustedDocumentError::ManifestLoad(format!(
85 "Failed to read manifest {}: {e}",
86 path.display()
87 ))
88 })?;
89 let manifest: Manifest = serde_json::from_str(&contents).map_err(|e| {
90 TrustedDocumentError::ManifestLoad(format!(
91 "Failed to parse manifest {}: {e}",
92 path.display()
93 ))
94 })?;
95 let documents = normalize_keys(manifest.documents);
96 Ok(Self {
97 documents: Arc::new(RwLock::new(documents)),
98 mode,
99 })
100 }
101
102 pub fn from_documents(documents: HashMap<String, String>, mode: TrustedDocumentMode) -> Self {
104 let documents = normalize_keys(documents);
105 Self {
106 documents: Arc::new(RwLock::new(documents)),
107 mode,
108 }
109 }
110
111 pub fn disabled() -> Self {
113 Self {
114 documents: Arc::new(RwLock::new(HashMap::new())),
115 mode: TrustedDocumentMode::Permissive,
116 }
117 }
118
119 pub const fn mode(&self) -> TrustedDocumentMode {
121 self.mode
122 }
123
124 pub async fn document_count(&self) -> usize {
126 self.documents.read().await.len()
127 }
128
129 pub async fn replace_documents(&self, documents: HashMap<String, String>) {
131 let documents = normalize_keys(documents);
132 *self.documents.write().await = documents;
133 }
134
135 pub async fn resolve(
148 &self,
149 document_id: Option<&str>,
150 raw_query: Option<&str>,
151 ) -> Result<String, TrustedDocumentError> {
152 if let Some(doc_id) = document_id {
153 let hash = doc_id.strip_prefix("sha256:").unwrap_or(doc_id);
154 let docs = self.documents.read().await;
155 return docs.get(hash).cloned().ok_or_else(|| TrustedDocumentError::DocumentNotFound {
156 id: doc_id.to_string(),
157 });
158 }
159 match self.mode {
160 TrustedDocumentMode::Strict => Err(TrustedDocumentError::ForbiddenRawQuery),
161 TrustedDocumentMode::Permissive => {
162 raw_query.map(|s| s.to_string()).ok_or(TrustedDocumentError::ForbiddenRawQuery)
163 },
164 }
165 }
166}
167
/// Strips one leading `sha256:` prefix from every key of `documents`.
///
/// Keys without the prefix are kept unchanged, and values are untouched. If two
/// input keys normalize to the same string, only one entry survives; which one
/// depends on the input map's iteration order.
fn normalize_keys(documents: HashMap<String, String>) -> HashMap<String, String> {
    let mut normalized = HashMap::with_capacity(documents.len());
    for (raw_key, query) in documents {
        let key = match raw_key.strip_prefix("sha256:") {
            Some(bare) => bare.to_owned(),
            // No prefix: reuse the existing allocation instead of copying it.
            None => raw_key,
        };
        normalized.insert(key, query);
    }
    normalized
}
178
/// Errors produced while loading manifests or resolving documents.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TrustedDocumentError {
    /// Returned by `TrustedDocumentStore::resolve` when the request has no
    /// document id and either the store is strict or no raw query was
    /// supplied.
    #[error("Raw queries are not permitted. Send a documentId instead.")]
    ForbiddenRawQuery,

    /// Returned by `TrustedDocumentStore::resolve` when the document id is not
    /// in the store.
    #[error("Unknown document: {id}")]
    DocumentNotFound {
        /// The id exactly as the caller supplied it (any `sha256:` prefix kept).
        id: String,
    },

    /// Returned by `TrustedDocumentStore::from_manifest_file`; the payload is a
    /// human-readable description that includes the manifest path.
    #[error("Manifest load error: {0}")]
    ManifestLoad(String),
}
198
// Process-wide tallies, bumped by the `record_*` functions and read by the
// `*_total` functions. `Relaxed` ordering suffices: each counter is an
// independent monotonic tally, and no other data is synchronized through it.
static TRUSTED_DOC_HITS: AtomicU64 = AtomicU64::new(0);
static TRUSTED_DOC_MISSES: AtomicU64 = AtomicU64::new(0);
static TRUSTED_DOC_REJECTED: AtomicU64 = AtomicU64::new(0);

/// Increments the counter reported by [`hits_total`] by one.
pub fn record_hit() {
    TRUSTED_DOC_HITS.fetch_add(1, Ordering::Relaxed);
}

/// Increments the counter reported by [`misses_total`] by one.
pub fn record_miss() {
    TRUSTED_DOC_MISSES.fetch_add(1, Ordering::Relaxed);
}

/// Increments the counter reported by [`rejected_total`] by one.
pub fn record_rejected() {
    TRUSTED_DOC_REJECTED.fetch_add(1, Ordering::Relaxed);
}

/// Returns how many times [`record_hit`] has been called in this process.
#[must_use]
pub fn hits_total() -> u64 {
    TRUSTED_DOC_HITS.load(Ordering::Relaxed)
}

/// Returns how many times [`record_miss`] has been called in this process.
#[must_use]
pub fn misses_total() -> u64 {
    TRUSTED_DOC_MISSES.load(Ordering::Relaxed)
}

/// Returns how many times [`record_rejected`] has been called in this process.
#[must_use]
pub fn rejected_total() -> u64 {
    TRUSTED_DOC_REJECTED.load(Ordering::Relaxed)
}
234
#[cfg(test)]
mod tests {
    // Tests may panic freely and use small numeric casts, so these lints are
    // relaxed for this module only.
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::cast_precision_loss)]
    #![allow(clippy::cast_sign_loss)]
    #![allow(clippy::cast_possible_truncation)]
    #![allow(clippy::cast_possible_wrap)]
    #![allow(clippy::missing_panics_doc)]
    #![allow(clippy::missing_errors_doc)]
    #![allow(missing_docs)]
    #![allow(clippy::items_after_statements)]
    use super::*;

    fn test_documents() -> HashMap<String, String> {
        let mut docs = HashMap::new();
        docs.insert("sha256:abc123".to_string(), "{ users { id } }".to_string());
        docs.insert("sha256:def456".to_string(), "mutation { createUser { id } }".to_string());
        docs
    }

    #[tokio::test]
    async fn strict_mode_rejects_raw_query() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Strict);
        let result = store.resolve(None, Some("{ users { id } }")).await;
        assert!(matches!(result, Err(TrustedDocumentError::ForbiddenRawQuery)));
    }

    #[tokio::test]
    async fn strict_mode_accepts_valid_document_id() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Strict);
        let result = store.resolve(Some("sha256:abc123"), None).await;
        assert_eq!(result.unwrap(), "{ users { id } }");
    }

    #[tokio::test]
    async fn strict_mode_rejects_unknown_document_id() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Strict);
        let result = store.resolve(Some("sha256:unknown"), None).await;
        assert!(matches!(result, Err(TrustedDocumentError::DocumentNotFound { .. })));
    }

    #[tokio::test]
    async fn permissive_mode_allows_raw_queries() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Permissive);
        let result = store.resolve(None, Some("{ arbitrary { query } }")).await;
        assert_eq!(result.unwrap(), "{ arbitrary { query } }");
    }

    #[tokio::test]
    async fn permissive_mode_uses_manifest_for_document_id() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Permissive);
        let result = store.resolve(Some("sha256:abc123"), None).await;
        assert_eq!(result.unwrap(), "{ users { id } }");
    }

    #[tokio::test]
    async fn document_id_without_prefix_is_resolved() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Strict);
        // Stored keys are normalized, so a bare hash matches a `sha256:`-prefixed entry.
        let result = store.resolve(Some("abc123"), None).await;
        assert_eq!(result.unwrap(), "{ users { id } }");
    }

    #[tokio::test]
    async fn disabled_store_passes_through() {
        let store = TrustedDocumentStore::disabled();
        let result = store.resolve(None, Some("{ anything }")).await;
        assert_eq!(result.unwrap(), "{ anything }");
    }

    #[tokio::test]
    async fn hot_reload_replaces_documents() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Strict);
        assert_eq!(store.document_count().await, 2);

        let mut new_docs = HashMap::new();
        new_docs.insert("sha256:new123".to_string(), "{ new query }".to_string());
        store.replace_documents(new_docs).await;

        assert_eq!(store.document_count().await, 1);
        let result = store.resolve(Some("sha256:new123"), None).await;
        assert_eq!(result.unwrap(), "{ new query }");

        // The reload replaces the table wholesale; it does not merge into it.
        let result = store.resolve(Some("sha256:abc123"), None).await;
        assert!(
            matches!(result, Err(TrustedDocumentError::DocumentNotFound { .. })),
            "expected DocumentNotFound after hot-reload removed old document, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn manifest_file_loading() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("trusted-documents.json");
        let manifest = serde_json::json!({
            "version": 1,
            "documents": {
                "sha256:aaa": "{ users { id } }",
                "sha256:bbb": "{ posts { title } }"
            }
        });
        std::fs::write(&path, serde_json::to_string(&manifest).unwrap()).unwrap();

        let store =
            TrustedDocumentStore::from_manifest_file(&path, TrustedDocumentMode::Strict).unwrap();
        assert_eq!(store.document_count().await, 2);
        let result = store.resolve(Some("sha256:aaa"), None).await;
        assert_eq!(result.unwrap(), "{ users { id } }");
    }

    #[test]
    fn manifest_file_exceeding_size_limit_is_rejected() {
        use std::io::Write as _;

        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("huge-manifest.json");

        // Valid JSON followed by more than the limit's worth of whitespace, so
        // the only possible reason for rejection is the size cap.
        let mut f = std::fs::File::create(&path).unwrap();
        f.write_all(b"{\"version\":1,\"documents\":{}}").unwrap();
        let padding = vec![b' '; (MAX_MANIFEST_BYTES + 1) as usize];
        f.write_all(&padding).unwrap();
        drop(f);

        let result = TrustedDocumentStore::from_manifest_file(&path, TrustedDocumentMode::Strict);
        assert!(result.is_err(), "oversized manifest must be rejected");
        // `unwrap_err` would need `TrustedDocumentStore: Debug`, which it does not implement.
        let msg = result.err().unwrap().to_string();
        assert!(
            msg.contains("too large") || msg.contains("10485760"),
            "error must mention size limit: {msg}"
        );
    }

    #[test]
    fn manifest_file_at_size_limit_is_accepted_if_valid() {
        use std::io::Write as _;

        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("limit-manifest.json");
        let body: &[u8] = b"{\"version\":1,\"documents\":{}}";
        let limit = usize::try_from(MAX_MANIFEST_BYTES).unwrap();

        // Pad with trailing whitespace (still valid JSON) so the file is
        // exactly at the limit. This pins the boundary check as `>`, not `>=`.
        let mut f = std::fs::File::create(&path).unwrap();
        f.write_all(body).unwrap();
        f.write_all(&vec![b' '; limit - body.len()]).unwrap();
        drop(f);
        assert_eq!(std::fs::metadata(&path).unwrap().len(), MAX_MANIFEST_BYTES);

        TrustedDocumentStore::from_manifest_file(&path, TrustedDocumentMode::Permissive)
            .unwrap_or_else(|e| panic!("manifest exactly at the size limit must be accepted: {e}"));
    }

    #[test]
    fn small_valid_manifest_is_accepted() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("small-manifest.json");
        let manifest = serde_json::json!({"version": 1, "documents": {}});
        std::fs::write(&path, serde_json::to_string(&manifest).unwrap()).unwrap();
        TrustedDocumentStore::from_manifest_file(&path, TrustedDocumentMode::Permissive)
            .unwrap_or_else(|e| panic!("small valid manifest must be accepted: {e}"));
    }
}