use std::{
collections::HashMap,
path::Path,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
};
/// Largest manifest file, in bytes (10 MiB), that `from_manifest_file` will accept.
const MAX_MANIFEST_BYTES: u64 = 10 << 20;
use serde::Deserialize;
use tokio::sync::RwLock;
/// Policy applied to requests that do not carry a document id.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum TrustedDocumentMode {
    /// Only stored documents can be resolved; a request with just a raw query is rejected.
    Strict,
    /// A raw query is accepted when no document id is supplied.
    Permissive,
}
/// On-disk JSON shape of a trusted-document manifest.
#[derive(Debug, Deserialize)]
struct Manifest {
    /// Manifest format version. It must be present for deserialization to
    /// succeed, but nothing reads it afterwards, hence the `dead_code` allow.
    #[allow(dead_code)]
    version: u32,
    /// Document id -> document text. Ids may carry a `sha256:` prefix.
    documents: HashMap<String, String>,
}
/// Registry of trusted query documents, keyed by document id.
///
/// Ids are stored with any leading `sha256:` prefix removed, and the same
/// prefix is stripped from ids passed to [`TrustedDocumentStore::resolve`],
/// so lookups succeed with or without it. The map lives behind an
/// [`Arc`]-wrapped async [`RwLock`] so it can be swapped at runtime with
/// [`TrustedDocumentStore::replace_documents`].
pub struct TrustedDocumentStore {
    /// Normalized document id -> document text.
    documents: Arc<RwLock<HashMap<String, String>>>,
    /// How requests without a document id are handled.
    mode: TrustedDocumentMode,
}

impl TrustedDocumentStore {
    /// Loads a store from the JSON manifest at `path`.
    ///
    /// The manifest must contain a numeric `version` and a `documents` object
    /// mapping document id to document text. Ids are normalized as described
    /// on [`TrustedDocumentStore`].
    ///
    /// # Errors
    ///
    /// Returns [`TrustedDocumentError::ManifestLoad`] in any of these cases:
    /// - the file cannot be opened, stat'ed, or read;
    /// - it is larger than [`MAX_MANIFEST_BYTES`];
    /// - it is not valid UTF-8;
    /// - it does not deserialize as a manifest.
    pub fn from_manifest_file(
        path: &Path,
        mode: TrustedDocumentMode,
    ) -> Result<Self, TrustedDocumentError> {
        let contents = Self::read_manifest_capped(path)?;
        let manifest: Manifest = serde_json::from_str(&contents).map_err(|e| {
            TrustedDocumentError::ManifestLoad(format!(
                "Failed to parse manifest {}: {e}",
                path.display()
            ))
        })?;
        Ok(Self::from_documents(manifest.documents, mode))
    }

    /// Reads the file at `path` into a `String`, never buffering more than
    /// `MAX_MANIFEST_BYTES + 1` bytes.
    ///
    /// The stat and the read share one open handle, so the file cannot be
    /// swapped between them. The read itself is also capped, because
    /// `Metadata::len` is only a hint: it can be 0 for pipes and character
    /// devices, and a regular file may grow after being stat'ed.
    fn read_manifest_capped(path: &Path) -> Result<String, TrustedDocumentError> {
        let load_error = |action: &str, cause: &dyn std::fmt::Display| {
            TrustedDocumentError::ManifestLoad(format!(
                "Failed to {action} manifest {}: {cause}",
                path.display()
            ))
        };

        let file = std::fs::File::open(path).map_err(|e| load_error("open", &e))?;
        let file_size = file.metadata().map_err(|e| load_error("stat", &e))?.len();
        if file_size > MAX_MANIFEST_BYTES {
            return Err(TrustedDocumentError::ManifestLoad(format!(
                "Manifest {} is too large ({file_size} bytes, max {MAX_MANIFEST_BYTES})",
                path.display()
            )));
        }

        // Read at most one byte past the limit: enough to detect an over-long
        // stream without buffering all of it.
        let mut bytes = Vec::with_capacity(usize::try_from(file_size).unwrap_or(0));
        let mut limited = std::io::Read::take(file, MAX_MANIFEST_BYTES + 1);
        std::io::Read::read_to_end(&mut limited, &mut bytes)
            .map_err(|e| load_error("read", &e))?;
        if u64::try_from(bytes.len()).unwrap_or(u64::MAX) > MAX_MANIFEST_BYTES {
            return Err(TrustedDocumentError::ManifestLoad(format!(
                "Manifest {} is too large (more than {MAX_MANIFEST_BYTES} bytes)",
                path.display()
            )));
        }

        // Decode only after the size check so a truncated multi-byte character
        // at the cap is reported as "too large", not as bad UTF-8.
        String::from_utf8(bytes).map_err(|e| load_error("read", &e))
    }

    /// Builds a store from an in-memory map of document id -> document text.
    pub fn from_documents(documents: HashMap<String, String>, mode: TrustedDocumentMode) -> Self {
        Self {
            documents: Arc::new(RwLock::new(normalize_keys(documents))),
            mode,
        }
    }

    /// Returns an empty store in [`TrustedDocumentMode::Permissive`] mode.
    ///
    /// Raw queries pass through unchanged. Any document id lookup fails
    /// because the store holds no documents.
    pub fn disabled() -> Self {
        Self::from_documents(HashMap::new(), TrustedDocumentMode::Permissive)
    }

    /// Returns the mode this store was created with.
    pub const fn mode(&self) -> TrustedDocumentMode {
        self.mode
    }

    /// Returns the number of documents currently stored.
    pub async fn document_count(&self) -> usize {
        self.documents.read().await.len()
    }

    /// Replaces the stored documents with `documents`.
    ///
    /// Readers observe either the old map or the new one, never a mix.
    /// Normalization happens before the write lock is taken. The previous map
    /// is dropped after the lock is released, so readers only wait for the
    /// pointer swap.
    pub async fn replace_documents(&self, documents: HashMap<String, String>) {
        let documents = normalize_keys(documents);
        let previous = std::mem::replace(&mut *self.documents.write().await, documents);
        drop(previous);
    }

    /// Resolves a request to the query text it refers to.
    ///
    /// A present `document_id` always takes precedence:
    /// - it is looked up (with any `sha256:` prefix stripped);
    /// - `raw_query` is ignored.
    ///
    /// Otherwise `raw_query` is returned, but only in
    /// [`TrustedDocumentMode::Permissive`] mode.
    ///
    /// # Errors
    ///
    /// - [`TrustedDocumentError::DocumentNotFound`] if `document_id` is given
    ///   but not stored; `id` carries the id exactly as supplied.
    /// - [`TrustedDocumentError::ForbiddenRawQuery`] if no `document_id` is given
    ///   and either the mode is [`TrustedDocumentMode::Strict`] or `raw_query` is
    ///   also `None`.
    pub async fn resolve(
        &self,
        document_id: Option<&str>,
        raw_query: Option<&str>,
    ) -> Result<String, TrustedDocumentError> {
        if let Some(doc_id) = document_id {
            let hash = doc_id.strip_prefix("sha256:").unwrap_or(doc_id);
            let docs = self.documents.read().await;
            return docs.get(hash).cloned().ok_or_else(|| TrustedDocumentError::DocumentNotFound {
                id: doc_id.to_string(),
            });
        }
        match self.mode {
            TrustedDocumentMode::Strict => Err(TrustedDocumentError::ForbiddenRawQuery),
            TrustedDocumentMode::Permissive => {
                raw_query.map(str::to_string).ok_or(TrustedDocumentError::ForbiddenRawQuery)
            },
        }
    }
}
/// Strips a single leading `sha256:` from every key of `documents`.
///
/// Keys without the prefix are kept as-is. If two input keys normalize to the
/// same id (e.g. `sha256:abc` and `abc`), only one entry survives. Which one
/// survives depends on the input map's iteration order.
fn normalize_keys(documents: HashMap<String, String>) -> HashMap<String, String> {
    const HASH_PREFIX: &str = "sha256:";
    let mut normalized = HashMap::with_capacity(documents.len());
    for (raw_id, text) in documents {
        let id = match raw_id.strip_prefix(HASH_PREFIX) {
            Some(bare) => bare.to_owned(),
            None => raw_id,
        };
        normalized.insert(id, text);
    }
    normalized
}
/// Errors produced while loading a manifest or resolving a document.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TrustedDocumentError {
    /// No document id was supplied, and a raw query is either not allowed by
    /// the store's mode or was not supplied either.
    #[error("Raw queries are not permitted. Send a documentId instead.")]
    ForbiddenRawQuery,
    /// The requested document id is not in the store.
    #[error("Unknown document: {id}")]
    DocumentNotFound { id: String },
    /// The manifest could not be loaded; the payload is a human-readable reason.
    #[error("Manifest load error: {0}")]
    ManifestLoad(String),
}
// Process-wide trusted-document counters. Each counter is independent and is
// only incremented or read; no other memory is published through them, so
// `Relaxed` ordering is sufficient.
static TRUSTED_DOC_HITS: AtomicU64 = AtomicU64::new(0);
static TRUSTED_DOC_MISSES: AtomicU64 = AtomicU64::new(0);
static TRUSTED_DOC_REJECTED: AtomicU64 = AtomicU64::new(0);

/// Adds one to `counter`.
fn bump_counter(counter: &AtomicU64) {
    counter.fetch_add(1, Ordering::Relaxed);
}

/// Returns the current value of `counter`.
fn read_counter(counter: &AtomicU64) -> u64 {
    counter.load(Ordering::Relaxed)
}

/// Increments the trusted-document hit counter.
pub fn record_hit() {
    bump_counter(&TRUSTED_DOC_HITS);
}

/// Increments the trusted-document miss counter.
pub fn record_miss() {
    bump_counter(&TRUSTED_DOC_MISSES);
}

/// Increments the trusted-document rejection counter.
pub fn record_rejected() {
    bump_counter(&TRUSTED_DOC_REJECTED);
}

/// Returns the number of hits recorded so far in this process.
pub fn hits_total() -> u64 {
    read_counter(&TRUSTED_DOC_HITS)
}

/// Returns the number of misses recorded so far in this process.
pub fn misses_total() -> u64 {
    read_counter(&TRUSTED_DOC_MISSES)
}

/// Returns the number of rejections recorded so far in this process.
pub fn rejected_total() -> u64 {
    read_counter(&TRUSTED_DOC_REJECTED)
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::cast_precision_loss)]
    #![allow(clippy::cast_sign_loss)]
    #![allow(clippy::cast_possible_truncation)]
    #![allow(clippy::cast_possible_wrap)]
    #![allow(clippy::missing_panics_doc)]
    #![allow(clippy::missing_errors_doc)]
    #![allow(missing_docs)]
    #![allow(clippy::items_after_statements)]
    use super::*;
    fn test_documents() -> HashMap<String, String> {
        let mut docs = HashMap::new();
        docs.insert("sha256:abc123".to_string(), "{ users { id } }".to_string());
        docs.insert("sha256:def456".to_string(), "mutation { createUser { id } }".to_string());
        docs
    }
    #[tokio::test]
    async fn strict_mode_rejects_raw_query() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Strict);
        let result = store.resolve(None, Some("{ users { id } }")).await;
        assert!(matches!(result, Err(TrustedDocumentError::ForbiddenRawQuery)));
    }
    #[tokio::test]
    async fn strict_mode_accepts_valid_document_id() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Strict);
        let result = store.resolve(Some("sha256:abc123"), None).await;
        assert_eq!(result.unwrap(), "{ users { id } }");
    }
    #[tokio::test]
    async fn strict_mode_rejects_unknown_document_id() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Strict);
        let result = store.resolve(Some("sha256:unknown"), None).await;
        assert!(matches!(result, Err(TrustedDocumentError::DocumentNotFound { .. })));
    }
    #[tokio::test]
    async fn permissive_mode_allows_raw_queries() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Permissive);
        let result = store.resolve(None, Some("{ arbitrary { query } }")).await;
        assert_eq!(result.unwrap(), "{ arbitrary { query } }");
    }
    #[tokio::test]
    async fn permissive_mode_without_any_query_is_rejected() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Permissive);
        let result = store.resolve(None, None).await;
        assert!(matches!(result, Err(TrustedDocumentError::ForbiddenRawQuery)));
    }
    #[tokio::test]
    async fn permissive_mode_uses_manifest_for_document_id() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Permissive);
        let result = store.resolve(Some("sha256:abc123"), None).await;
        assert_eq!(result.unwrap(), "{ users { id } }");
    }
    #[tokio::test]
    async fn document_id_takes_precedence_over_raw_query() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Permissive);
        let result = store.resolve(Some("sha256:abc123"), Some("{ other { query } }")).await;
        assert_eq!(result.unwrap(), "{ users { id } }");
    }
    #[tokio::test]
    async fn document_id_without_prefix_is_resolved() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Strict);
        let result = store.resolve(Some("abc123"), None).await;
        assert_eq!(result.unwrap(), "{ users { id } }");
    }
    #[tokio::test]
    async fn disabled_store_passes_through() {
        let store = TrustedDocumentStore::disabled();
        let result = store.resolve(None, Some("{ anything }")).await;
        assert_eq!(result.unwrap(), "{ anything }");
    }
    #[tokio::test]
    async fn hot_reload_replaces_documents() {
        let store =
            TrustedDocumentStore::from_documents(test_documents(), TrustedDocumentMode::Strict);
        assert_eq!(store.document_count().await, 2);
        let mut new_docs = HashMap::new();
        new_docs.insert("sha256:new123".to_string(), "{ new query }".to_string());
        store.replace_documents(new_docs).await;
        assert_eq!(store.document_count().await, 1);
        let result = store.resolve(Some("sha256:new123"), None).await;
        assert_eq!(result.unwrap(), "{ new query }");
        let result = store.resolve(Some("sha256:abc123"), None).await;
        assert!(
            matches!(result, Err(TrustedDocumentError::DocumentNotFound { .. })),
            "expected DocumentNotFound after hot-reload removed old document, got: {result:?}"
        );
    }
    #[tokio::test]
    async fn manifest_file_loading() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("trusted-documents.json");
        let manifest = serde_json::json!({
            "version": 1,
            "documents": {
                "sha256:aaa": "{ users { id } }",
                "sha256:bbb": "{ posts { title } }"
            }
        });
        std::fs::write(&path, serde_json::to_string(&manifest).unwrap()).unwrap();
        let store =
            TrustedDocumentStore::from_manifest_file(&path, TrustedDocumentMode::Strict).unwrap();
        assert_eq!(store.document_count().await, 2);
        let result = store.resolve(Some("sha256:aaa"), None).await;
        assert_eq!(result.unwrap(), "{ users { id } }");
    }
    #[test]
    fn manifest_file_exceeding_size_limit_is_rejected() {
        use std::io::Write as _;
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("huge-manifest.json");
        let mut f = std::fs::File::create(&path).unwrap();
        f.write_all(b"{\"version\":1,\"documents\":{}}").unwrap();
        let padding = vec![b' '; (MAX_MANIFEST_BYTES + 1) as usize];
        f.write_all(&padding).unwrap();
        drop(f);
        let result = TrustedDocumentStore::from_manifest_file(&path, TrustedDocumentMode::Strict);
        assert!(result.is_err(), "oversized manifest must be rejected");
        let msg = result.err().unwrap().to_string();
        assert!(
            msg.contains("too large") || msg.contains("10485760"),
            "error must mention size limit: {msg}"
        );
    }
    #[test]
    fn manifest_file_at_size_limit_is_accepted_if_valid() {
        use std::io::Write as _;
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("limit-manifest.json");
        let body: &[u8] = b"{\"version\":1,\"documents\":{}}";
        // Trailing whitespace is valid JSON; pad so the file is exactly
        // MAX_MANIFEST_BYTES long, the largest size that must be accepted.
        let padding = vec![b' '; MAX_MANIFEST_BYTES as usize - body.len()];
        let mut f = std::fs::File::create(&path).unwrap();
        f.write_all(body).unwrap();
        f.write_all(&padding).unwrap();
        drop(f);
        assert_eq!(std::fs::metadata(&path).unwrap().len(), MAX_MANIFEST_BYTES);
        let store =
            TrustedDocumentStore::from_manifest_file(&path, TrustedDocumentMode::Permissive)
                .unwrap_or_else(|e| {
                    panic!("manifest exactly at the size limit must be accepted: {e}")
                });
        assert_eq!(store.mode(), TrustedDocumentMode::Permissive);
    }
    #[test]
    fn small_valid_manifest_is_accepted() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("small-manifest.json");
        let manifest = serde_json::json!({"version": 1, "documents": {}});
        std::fs::write(&path, serde_json::to_string(&manifest).unwrap()).unwrap();
        TrustedDocumentStore::from_manifest_file(&path, TrustedDocumentMode::Permissive)
            .unwrap_or_else(|e| panic!("small valid manifest must be accepted: {e}"));
    }
}