use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Configuration knobs for [`PersistedQueryManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistedQueryConfig {
/// When `false`, `PersistedQueryManager::store_query` rejects every request.
pub enable_apq: bool,
/// NOTE(review): not read by any code visible in this file — confirm where (or whether) it is enforced.
pub enable_pqd: bool,
/// When `true`, `store_query` refuses to add queries, so only entries added via
/// `register_query` / `import_queries` / `load_from_file` can be served.
pub allowlist_mode: bool,
/// Cache capacity enforced on the `register_query` / `store_query` insertion path
/// by evicting the least-recently-accessed entry; `0` disables the limit.
pub max_cached_queries: usize,
/// Presumably the entry lifetime in seconds.
/// NOTE(review): not read by any code visible in this file, so entries currently never expire — confirm.
pub query_ttl_seconds: u64,
/// Whether `get_query` updates the hit/miss counters reported by `get_statistics`.
pub enable_statistics: bool,
}
impl Default for PersistedQueryConfig {
fn default() -> Self {
Self {
enable_apq: true,
enable_pqd: true,
allowlist_mode: false,
max_cached_queries: 10000,
query_ttl_seconds: 3600, enable_statistics: true,
}
}
}
/// One cached query entry, keyed in the manager's cache by `hash`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistedQuery {
/// Lookup key for the entry (the hex SHA-256 of `query` for APQ submissions).
pub hash: String,
/// The full query text returned to callers on a hit.
pub query: String,
/// Optional caller-supplied version label.
pub version: Option<String>,
/// Optional caller-supplied human-readable name.
pub name: Option<String>,
/// When the entry was constructed.
pub created_at: chrono::DateTime<chrono::Utc>,
/// Last time `record_access` ran; also the eviction key (oldest goes first).
pub last_accessed: chrono::DateTime<chrono::Utc>,
/// Number of `record_access` calls; used to rank `top_queries`.
pub access_count: u64,
/// Whether this entry counts as registered: `get_statistics` counts these
/// and `export_queries` only exports these.
pub is_registered: bool,
}
impl PersistedQuery {
    /// Builds an entry with no name or version, zero recorded accesses, and
    /// both `created_at` and `last_accessed` set to the current UTC time.
    pub fn new(hash: String, query: String, is_registered: bool) -> Self {
        let created_at = chrono::Utc::now();
        Self {
            hash,
            query,
            is_registered,
            name: None,
            version: None,
            access_count: 0,
            created_at,
            last_accessed: created_at,
        }
    }

    /// Builder-style setter: returns the entry with `version` filled in.
    pub fn with_version(self, version: String) -> Self {
        Self {
            version: Some(version),
            ..self
        }
    }

    /// Builder-style setter: returns the entry with `name` filled in.
    pub fn with_name(self, name: String) -> Self {
        Self {
            name: Some(name),
            ..self
        }
    }

    /// Bumps the access counter and stamps `last_accessed` with the current time.
    pub fn record_access(&mut self) {
        self.access_count += 1;
        self.last_accessed = chrono::Utc::now();
    }
}
/// Wrapper matching the `{"persistedQuery": {...}}` object used by Automatic
/// Persisted Queries (APQ).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApqExtension {
/// Serialized under the camelCase key `persistedQuery`.
#[serde(rename = "persistedQuery")]
pub persisted_query: ApqPersistedQuery,
}
/// The APQ payload: a protocol version plus the hash identifying the query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApqPersistedQuery {
/// Protocol version supplied by the client.
pub version: u32,
/// Hex SHA-256 of the query text, serialized as `sha256Hash`;
/// `PersistedQueryStore::handle_apq` uses it as the lookup/verification key.
#[serde(rename = "sha256Hash")]
pub sha256_hash: String,
}
/// Snapshot returned by `PersistedQueryManager::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryStatistics {
/// Number of entries currently in the cache.
pub total_cached: usize,
/// Number of cached entries whose `is_registered` flag is set.
pub registered_queries: usize,
/// Successful `get_query` lookups counted while statistics were enabled.
pub apq_hits: u64,
/// Failed `get_query` lookups counted while statistics were enabled.
pub apq_misses: u64,
/// `apq_hits / (apq_hits + apq_misses)`, or `0.0` when no lookups were counted.
pub hit_rate: f64,
/// Up to ten `(hash, access_count)` pairs, highest access count first.
pub top_queries: Vec<(String, u64)>,
}
/// Cache of persisted queries with a denylist and hit/miss counters.
/// All shared state sits behind `Arc<tokio::sync::RwLock<_>>`.
pub struct PersistedQueryManager {
/// Immutable configuration shared by all methods.
config: Arc<PersistedQueryConfig>,
/// Query hash -> cached entry.
cache: Arc<RwLock<HashMap<String, PersistedQuery>>>,
/// Query hash -> reason; `get_query` refuses these hashes with the reason.
denylist: Arc<RwLock<HashMap<String, String>>>,
/// Successful lookups counted by `get_query`.
apq_hits: Arc<RwLock<u64>>,
/// Failed lookups counted by `get_query`.
apq_misses: Arc<RwLock<u64>>,
}
impl PersistedQueryManager {
    /// Maximum number of entries reported in `QueryStatistics::top_queries`.
    const TOP_QUERY_COUNT: usize = 10;

    /// Creates a manager with an empty cache, an empty denylist and zeroed
    /// hit/miss counters.
    pub fn new(config: PersistedQueryConfig) -> Self {
        Self {
            config: Arc::new(config),
            cache: Arc::new(RwLock::new(HashMap::new())),
            denylist: Arc::new(RwLock::new(HashMap::new())),
            apq_hits: Arc::new(RwLock::new(0)),
            apq_misses: Arc::new(RwLock::new(0)),
        }
    }

    /// Pre-registers `query` under `hash`, with an optional `name` and `version`.
    ///
    /// The entry is marked `is_registered = true`, so it is counted in
    /// `QueryStatistics::registered_queries` and included by
    /// [`Self::export_queries`]. `hash` is stored as given and is NOT checked
    /// against `query`, so callers may use their own identifiers. Any existing
    /// entry under the same hash is replaced (resetting its access statistics).
    /// When the cache is full, the least-recently-accessed entry is evicted.
    ///
    /// # Errors
    /// Never fails at present; the `Result` is kept for API stability.
    pub async fn register_query(
        &self,
        hash: String,
        query: String,
        name: Option<String>,
        version: Option<String>,
    ) -> Result<()> {
        let mut entry = PersistedQuery::new(hash, query, true);
        if let Some(name) = name {
            entry = entry.with_name(name);
        }
        if let Some(version) = version {
            entry = entry.with_version(version);
        }
        let mut cache = self.cache.write().await;
        Self::insert_entry(&mut cache, self.config.max_cached_queries, entry);
        Ok(())
    }

    /// Inserts `entry` keyed by its hash. If that would grow the cache beyond
    /// `max_entries` (`0` = unbounded), the least-recently-accessed entry is
    /// evicted first.
    fn insert_entry(
        cache: &mut HashMap<String, PersistedQuery>,
        max_entries: usize,
        entry: PersistedQuery,
    ) {
        // Replacing an existing key does not grow the cache; evicting an
        // unrelated entry in that case would only discard a live query.
        let grows = !cache.contains_key(&entry.hash);
        if grows && max_entries > 0 && cache.len() >= max_entries {
            // Only the key is needed to remove the victim, so don't clone the entry.
            let oldest = cache
                .iter()
                .min_by_key(|(_, q)| q.last_accessed)
                .map(|(h, _)| h.clone());
            if let Some(oldest) = oldest {
                cache.remove(&oldest);
            }
        }
        cache.insert(entry.hash.clone(), entry);
    }

    /// Looks up the query text for `hash` and records the access.
    ///
    /// When `enable_statistics` is set, a hit or a miss is counted. Denylisted
    /// hashes are rejected before the cache is consulted and are not counted.
    ///
    /// # Errors
    /// * `"Query is blocked: <reason>"` if `hash` is on the denylist.
    /// * `"PersistedQueryNotFound"` if no entry exists for `hash`.
    pub async fn get_query(&self, hash: &str) -> Result<String> {
        {
            let denylist = self.denylist.read().await;
            if let Some(reason) = denylist.get(hash) {
                return Err(anyhow!("Query is blocked: {}", reason));
            }
        }
        // Update the entry and copy the text out under the cache lock, then
        // release it before waiting on the statistics counters.
        let found = {
            let mut cache = self.cache.write().await;
            cache.get_mut(hash).map(|entry| {
                entry.record_access();
                entry.query.clone()
            })
        };
        if self.config.enable_statistics {
            let counter = if found.is_some() {
                &self.apq_hits
            } else {
                &self.apq_misses
            };
            *counter.write().await += 1;
        }
        found.ok_or_else(|| anyhow!("PersistedQueryNotFound"))
    }

    /// Stores a client-submitted (APQ) query after checking that `hash` is the
    /// SHA-256 of `query`.
    ///
    /// The entry is stored with `is_registered = false`. Otherwise arbitrary
    /// client queries would be counted as registered and exported by
    /// [`Self::save_to_file`], and could become allowlisted on a later load.
    /// If an entry already exists under `hash` (for example a pre-registered
    /// query with a name and version), it is kept unchanged.
    ///
    /// # Errors
    /// * `allowlist_mode` is enabled.
    /// * `enable_apq` is disabled.
    /// * `hash` does not match the computed hash of `query`.
    pub async fn store_query(&self, hash: String, query: String) -> Result<()> {
        if self.config.allowlist_mode {
            return Err(anyhow!(
                "Cannot register new queries in allowlist mode. Only pre-registered queries are allowed."
            ));
        }
        if !self.config.enable_apq {
            return Err(anyhow!("Automatic persisted queries are disabled"));
        }
        let computed_hash = Self::compute_hash(&query);
        if computed_hash != hash {
            return Err(anyhow!(
                "Query hash mismatch. Expected: {}, Got: {}",
                computed_hash,
                hash
            ));
        }
        let mut cache = self.cache.write().await;
        if !cache.contains_key(&hash) {
            let entry = PersistedQuery::new(hash, query, false);
            Self::insert_entry(&mut cache, self.config.max_cached_queries, entry);
        }
        Ok(())
    }

    /// Returns the lowercase hex SHA-256 digest of `query` (64 characters).
    pub fn compute_hash(query: &str) -> String {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(query.as_bytes());
        format!("{:x}", hasher.finalize())
    }

    /// Blocks `hash`; subsequent `get_query` calls fail with `reason`.
    /// Re-denying a hash replaces its reason.
    ///
    /// # Errors
    /// Never fails at present.
    pub async fn deny_query(&self, hash: String, reason: String) -> Result<()> {
        let mut denylist = self.denylist.write().await;
        denylist.insert(hash, reason);
        Ok(())
    }

    /// Removes `hash` from the denylist; a no-op if it was not denied.
    ///
    /// # Errors
    /// Never fails at present.
    pub async fn allow_query(&self, hash: &str) -> Result<()> {
        let mut denylist = self.denylist.write().await;
        denylist.remove(hash);
        Ok(())
    }

    /// Builds a snapshot of cache size, registration count, hit/miss counters
    /// and the most frequently accessed queries.
    ///
    /// `top_queries` is ordered by access count (descending). Ties are broken
    /// by hash so that the result does not depend on `HashMap` iteration order.
    pub async fn get_statistics(&self) -> QueryStatistics {
        let cache = self.cache.read().await;
        let hits = *self.apq_hits.read().await;
        let misses = *self.apq_misses.read().await;
        let total_requests = hits + misses;
        let hit_rate = if total_requests > 0 {
            hits as f64 / total_requests as f64
        } else {
            0.0
        };
        let registered_queries = cache.values().filter(|q| q.is_registered).count();
        // Rank borrowed hashes and clone only the ones that are reported.
        let mut ranked: Vec<(&String, u64)> = cache
            .values()
            .map(|q| (&q.hash, q.access_count))
            .collect();
        ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        let top_queries = ranked
            .into_iter()
            .take(Self::TOP_QUERY_COUNT)
            .map(|(hash, count)| (hash.clone(), count))
            .collect();
        QueryStatistics {
            total_cached: cache.len(),
            registered_queries,
            apq_hits: hits,
            apq_misses: misses,
            hit_rate,
            top_queries,
        }
    }

    /// Removes every cached entry. The denylist and counters are untouched.
    ///
    /// # Errors
    /// Never fails at present.
    pub async fn clear_cache(&self) -> Result<()> {
        let mut cache = self.cache.write().await;
        cache.clear();
        Ok(())
    }

    /// Resets the hit and miss counters to zero.
    ///
    /// # Errors
    /// Never fails at present.
    pub async fn clear_statistics(&self) -> Result<()> {
        let mut hits = self.apq_hits.write().await;
        let mut misses = self.apq_misses.write().await;
        *hits = 0;
        *misses = 0;
        Ok(())
    }

    /// Returns clones of all entries flagged `is_registered`, in unspecified order.
    pub async fn export_queries(&self) -> Vec<PersistedQuery> {
        let cache = self.cache.read().await;
        cache
            .values()
            .filter(|q| q.is_registered)
            .cloned()
            .collect()
    }

    /// Inserts `queries` as-is (keeping their flags, timestamps and counters)
    /// and replaces any entry with the same hash.
    ///
    /// NOTE: `max_cached_queries` is not applied here, so a large import can
    /// leave the cache above the configured limit.
    ///
    /// # Errors
    /// Never fails at present.
    pub async fn import_queries(&self, queries: Vec<PersistedQuery>) -> Result<()> {
        let mut cache = self.cache.write().await;
        for query in queries {
            cache.insert(query.hash.clone(), query);
        }
        Ok(())
    }

    /// Reads a JSON array of [`PersistedQuery`] from `path` and imports it.
    ///
    /// # Errors
    /// Fails if the file cannot be read or is not valid JSON for that type.
    pub async fn load_from_file(&self, path: &std::path::Path) -> Result<()> {
        let content = tokio::fs::read_to_string(path).await?;
        let queries: Vec<PersistedQuery> = serde_json::from_str(&content)?;
        self.import_queries(queries).await
    }

    /// Writes the registered entries (see [`Self::export_queries`]) to `path`
    /// as pretty-printed JSON, overwriting the file.
    ///
    /// # Errors
    /// Fails on serialization or write errors.
    pub async fn save_to_file(&self, path: &std::path::Path) -> Result<()> {
        let queries = self.export_queries().await;
        let content = serde_json::to_string_pretty(&queries)?;
        tokio::fs::write(path, content).await?;
        Ok(())
    }
}
/// Minimal content-addressed query store: SHA-256 hash -> query text, with
/// hit/miss counters for `get`.
pub struct PersistedQueryStore {
/// Hex SHA-256 hash -> query text.
queries: Arc<RwLock<HashMap<String, String>>>,
/// Successful `get` lookups.
hits: Arc<RwLock<u64>>,
/// Failed `get` lookups.
misses: Arc<RwLock<u64>>,
}
impl PersistedQueryStore {
    /// The only Automatic Persisted Query protocol version `handle_apq` accepts.
    const APQ_VERSION: u32 = 1;

    /// Creates an empty store with zeroed hit/miss counters.
    pub fn new() -> Self {
        Self {
            queries: Arc::new(RwLock::new(HashMap::new())),
            hits: Arc::new(RwLock::new(0)),
            misses: Arc::new(RwLock::new(0)),
        }
    }

    /// Returns the lowercase hex SHA-256 digest of `query` (64 characters).
    pub fn hash(query: &str) -> String {
        use sha2::{Digest, Sha256};
        let mut h = Sha256::new();
        h.update(query.as_bytes());
        format!("{:x}", h.finalize())
    }

    /// Stores `query` under its SHA-256 hash and returns that hash.
    /// Storing identical text again is a no-op.
    ///
    /// # Errors
    /// Fails if different text is already stored under the same hash
    /// (a SHA-256 collision).
    pub async fn store(&self, query: &str) -> Result<String> {
        let key = Self::hash(query);
        let mut map = self.queries.write().await;
        if let Some(existing) = map.get(&key) {
            if existing != query {
                return Err(anyhow!("SHA-256 collision detected for hash {}", key));
            }
            // Already present with identical text: skip the redundant re-insert.
            return Ok(key);
        }
        map.insert(key.clone(), query.to_owned());
        Ok(key)
    }

    /// Stores `query` under a caller-supplied `hash` after verifying it.
    ///
    /// # Errors
    /// * `hash` is not the SHA-256 of `query`.
    /// * Different text is already stored under `hash` (same guard as [`Self::store`]).
    pub async fn store_with_hash(&self, hash: &str, query: &str) -> Result<()> {
        let computed = Self::hash(query);
        if computed != hash {
            return Err(anyhow!(
                "Hash mismatch: provided {hash}, computed {computed}"
            ));
        }
        let mut map = self.queries.write().await;
        if let Some(existing) = map.get(hash) {
            if existing != query {
                return Err(anyhow!("SHA-256 collision detected for hash {}", hash));
            }
            return Ok(());
        }
        map.insert(hash.to_owned(), query.to_owned());
        Ok(())
    }

    /// Returns the query stored under `hash` and counts a hit or a miss.
    ///
    /// # Errors
    /// `"PersistedQueryNotFound: <hash>"` when nothing is stored under `hash`.
    pub async fn get(&self, hash: &str) -> Result<String> {
        // Release the map lock before touching the counters.
        let cloned = {
            let map = self.queries.read().await;
            map.get(hash).cloned()
        };
        if let Some(q) = cloned {
            let mut hits = self.hits.write().await;
            *hits += 1;
            Ok(q)
        } else {
            let mut misses = self.misses.write().await;
            *misses += 1;
            Err(anyhow!("PersistedQueryNotFound: {hash}"))
        }
    }

    /// Handles one leg of the APQ protocol.
    ///
    /// With `query == None` (first leg), the hash is looked up. With
    /// `Some(query)` (second leg), the query is verified against the hash,
    /// stored, and echoed back.
    ///
    /// # Errors
    /// * `ext.version` is not the supported APQ version (1).
    /// * Any error from [`Self::get`] or [`Self::store_with_hash`].
    pub async fn handle_apq(&self, ext: &ApqPersistedQuery, query: Option<&str>) -> Result<String> {
        if ext.version != Self::APQ_VERSION {
            return Err(anyhow!(
                "Unsupported persisted query version: {}",
                ext.version
            ));
        }
        match query {
            None => self.get(&ext.sha256_hash).await,
            Some(q) => {
                self.store_with_hash(&ext.sha256_hash, q).await?;
                Ok(q.to_owned())
            }
        }
    }

    /// Removes the entry for `hash`; returns whether one existed.
    pub async fn remove(&self, hash: &str) -> bool {
        let mut map = self.queries.write().await;
        map.remove(hash).is_some()
    }

    /// Number of stored queries.
    pub async fn len(&self) -> usize {
        self.queries.read().await.len()
    }

    /// Whether no queries are stored.
    pub async fn is_empty(&self) -> bool {
        self.queries.read().await.is_empty()
    }

    /// Returns `(hits, misses)` counted by [`Self::get`].
    pub async fn stats(&self) -> (u64, u64) {
        let hits = *self.hits.read().await;
        let misses = *self.misses.read().await;
        (hits, misses)
    }

    /// Resets both counters to zero.
    pub async fn reset_stats(&self) {
        *self.hits.write().await = 0;
        *self.misses.write().await = 0;
    }

    /// Removes every stored query. The counters are untouched.
    pub async fn clear(&self) {
        self.queries.write().await.clear();
    }
}
impl Default for PersistedQueryStore {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager built from the default configuration.
    fn default_manager() -> PersistedQueryManager {
        PersistedQueryManager::new(PersistedQueryConfig::default())
    }

    /// APQ payload at protocol version 1 for `hash`.
    fn apq_ext(hash: impl Into<String>) -> ApqPersistedQuery {
        ApqPersistedQuery {
            version: 1,
            sha256_hash: hash.into(),
        }
    }

    // Pure, synchronous checks need no async runtime.
    #[test]
    fn test_persisted_query_config_default() {
        let config = PersistedQueryConfig::default();
        assert!(config.enable_apq);
        assert!(config.enable_pqd);
        assert!(!config.allowlist_mode);
        assert_eq!(config.max_cached_queries, 10000);
        assert_eq!(config.query_ttl_seconds, 3600);
        assert!(config.enable_statistics);
    }

    #[test]
    fn test_compute_hash() {
        let hash = PersistedQueryManager::compute_hash("{ hello }");
        assert_eq!(hash.len(), 64);
        // Standard SHA-256 test vectors pin the exact encoding (lowercase hex),
        // not just the length.
        assert_eq!(
            PersistedQueryManager::compute_hash(""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            PersistedQueryManager::compute_hash("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn test_manager_and_store_hashes_agree() {
        for query in &["", "{ hello }", "{ user(id: 1) { name } }"] {
            assert_eq!(
                PersistedQueryManager::compute_hash(query),
                PersistedQueryStore::hash(query)
            );
        }
    }

    #[tokio::test]
    async fn test_register_and_get_query() {
        let manager = default_manager();
        let query = "{ hello }";
        let hash = PersistedQueryManager::compute_hash(query);
        manager
            .register_query(hash.clone(), query.to_string(), None, None)
            .await
            .expect("should succeed");
        let result = manager.get_query(&hash).await;
        assert!(result.is_ok());
        assert_eq!(result.expect("should succeed"), query);
    }

    #[tokio::test]
    async fn test_apq_flow() {
        let manager = default_manager();
        let query = "{ user(id: 1) { name } }";
        let hash = PersistedQueryManager::compute_hash(query);
        let result = manager.get_query(&hash).await;
        assert!(result.is_err());
        manager
            .store_query(hash.clone(), query.to_string())
            .await
            .expect("should succeed");
        let result = manager.get_query(&hash).await;
        assert!(result.is_ok());
        assert_eq!(result.expect("should succeed"), query);
        // One miss before the store, one hit after it.
        let stats = manager.get_statistics().await;
        assert_eq!(stats.apq_misses, 1);
        assert_eq!(stats.apq_hits, 1);
    }

    #[tokio::test]
    async fn test_allowlist_mode() {
        let config = PersistedQueryConfig {
            allowlist_mode: true,
            ..Default::default()
        };
        let manager = PersistedQueryManager::new(config);
        let query = "{ hello }";
        let hash = PersistedQueryManager::compute_hash(query);
        let result = manager.store_query(hash, query.to_string()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_query_denylist() {
        let manager = default_manager();
        let query = "{ malicious }";
        let hash = PersistedQueryManager::compute_hash(query);
        manager
            .register_query(hash.clone(), query.to_string(), None, None)
            .await
            .expect("should succeed");
        manager
            .deny_query(hash.clone(), "Malicious query".to_string())
            .await
            .expect("should succeed");
        let result = manager.get_query(&hash).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("blocked"));
    }

    #[tokio::test]
    async fn test_allow_query_unblocks() {
        let manager = default_manager();
        let query = "{ hello }";
        let hash = PersistedQueryManager::compute_hash(query);
        manager
            .register_query(hash.clone(), query.to_string(), None, None)
            .await
            .expect("should succeed");
        manager
            .deny_query(hash.clone(), "temporary".to_string())
            .await
            .expect("should succeed");
        assert!(manager.get_query(&hash).await.is_err());
        manager.allow_query(&hash).await.expect("should succeed");
        assert_eq!(
            manager.get_query(&hash).await.expect("should succeed"),
            query
        );
    }

    #[tokio::test]
    async fn test_statistics() {
        let manager = default_manager();
        let query = "{ hello }";
        let hash = PersistedQueryManager::compute_hash(query);
        manager
            .register_query(hash.clone(), query.to_string(), None, None)
            .await
            .expect("should succeed");
        for _ in 0..5 {
            let _ = manager.get_query(&hash).await;
        }
        let stats = manager.get_statistics().await;
        assert_eq!(stats.total_cached, 1);
        assert_eq!(stats.registered_queries, 1);
        assert_eq!(stats.apq_hits, 5);
        assert_eq!(stats.hit_rate, 1.0);
    }

    #[tokio::test]
    async fn test_cache_size_limit() {
        let config = PersistedQueryConfig {
            max_cached_queries: 2,
            ..Default::default()
        };
        let manager = PersistedQueryManager::new(config);
        for i in 0..3 {
            let query = format!("{{ query{} }}", i);
            let hash = PersistedQueryManager::compute_hash(&query);
            manager
                .register_query(hash, query, None, None)
                .await
                .expect("should succeed");
        }
        let stats = manager.get_statistics().await;
        assert_eq!(stats.total_cached, 2);
    }

    #[tokio::test]
    async fn test_query_versioning() {
        let manager = default_manager();
        let query = "{ hello }";
        let hash = PersistedQueryManager::compute_hash(query);
        manager
            .register_query(
                hash.clone(),
                query.to_string(),
                Some("HelloQuery".to_string()),
                Some("v1".to_string()),
            )
            .await
            .expect("should succeed");
        let result = manager.get_query(&hash).await;
        assert!(result.is_ok());
        // Name and version must survive into the exported entry.
        let exported = manager.export_queries().await;
        assert_eq!(exported.len(), 1);
        assert_eq!(exported[0].name.as_deref(), Some("HelloQuery"));
        assert_eq!(exported[0].version.as_deref(), Some("v1"));
    }

    #[tokio::test]
    async fn test_hash_verification() {
        let manager = default_manager();
        let query = "{ hello }";
        let wrong_hash = "wronghash123";
        let result = manager
            .store_query(wrong_hash.to_string(), query.to_string())
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("mismatch"));
    }

    #[tokio::test]
    async fn test_store_new_is_empty() {
        let store = PersistedQueryStore::new();
        assert!(store.is_empty().await);
        assert_eq!(store.len().await, 0);
    }

    #[test]
    fn test_store_hash_is_sha256() {
        let hash = PersistedQueryStore::hash("{ hello }");
        assert_eq!(hash.len(), 64);
        assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[tokio::test]
    async fn test_store_and_retrieve() {
        let store = PersistedQueryStore::new();
        let query = "{ user { id name } }";
        let hash = store.store(query).await.expect("should succeed");
        let retrieved = store.get(&hash).await.expect("should succeed");
        assert_eq!(retrieved, query);
    }

    #[tokio::test]
    async fn test_store_returns_correct_hash() {
        let store = PersistedQueryStore::new();
        let query = "{ products { sku } }";
        let expected_hash = PersistedQueryStore::hash(query);
        let returned_hash = store.store(query).await.expect("should succeed");
        assert_eq!(returned_hash, expected_hash);
    }

    #[tokio::test]
    async fn test_get_missing_hash_returns_error() {
        let store = PersistedQueryStore::new();
        let result = store.get("nonexistenthash").await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("PersistedQueryNotFound"));
    }

    #[tokio::test]
    async fn test_store_with_hash_valid() {
        let store = PersistedQueryStore::new();
        let query = "{ reviews { rating } }";
        let hash = PersistedQueryStore::hash(query);
        store
            .store_with_hash(&hash, query)
            .await
            .expect("should succeed");
        assert_eq!(store.get(&hash).await.expect("should succeed"), query);
    }

    #[tokio::test]
    async fn test_store_with_hash_mismatch_returns_error() {
        let store = PersistedQueryStore::new();
        let result = store.store_with_hash("badhash", "{ hello }").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("mismatch"));
    }

    #[tokio::test]
    async fn test_handle_apq_first_leg_hit() {
        let store = PersistedQueryStore::new();
        let query = "{ status }";
        let hash = store.store(query).await.expect("should succeed");
        let result = store
            .handle_apq(&apq_ext(hash), None)
            .await
            .expect("should succeed");
        assert_eq!(result, query);
    }

    #[tokio::test]
    async fn test_handle_apq_first_leg_miss() {
        let store = PersistedQueryStore::new();
        let result = store.handle_apq(&apq_ext("notpresent"), None).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_handle_apq_second_leg_stores_query() {
        let store = PersistedQueryStore::new();
        let query = "{ nodes { id } }";
        let hash = PersistedQueryStore::hash(query);
        let result = store
            .handle_apq(&apq_ext(hash.clone()), Some(query))
            .await
            .expect("should succeed");
        assert_eq!(result, query);
        let hit = store
            .handle_apq(&apq_ext(hash), None)
            .await
            .expect("should succeed");
        assert_eq!(hit, query);
    }

    #[tokio::test]
    async fn test_remove_query() {
        let store = PersistedQueryStore::new();
        let query = "{ remove_me }";
        let hash = store.store(query).await.expect("should succeed");
        assert!(store.remove(&hash).await);
        // A second removal finds nothing.
        assert!(!store.remove(&hash).await);
        assert!(store.is_empty().await);
    }

    #[tokio::test]
    async fn test_stats_hit_miss_counts() {
        let store = PersistedQueryStore::new();
        let query = "{ stat_test }";
        let hash = store.store(query).await.expect("should succeed");
        let _ = store.get(&hash).await;
        let _ = store.get(&hash).await;
        let _ = store.get("missing").await;
        let (hits, misses) = store.stats().await;
        assert_eq!(hits, 2);
        assert_eq!(misses, 1);
    }

    #[tokio::test]
    async fn test_stats_reset() {
        let store = PersistedQueryStore::new();
        let query = "{ reset_test }";
        let hash = store.store(query).await.expect("should succeed");
        let _ = store.get(&hash).await;
        store.reset_stats().await;
        let (hits, misses) = store.stats().await;
        assert_eq!(hits, 0);
        assert_eq!(misses, 0);
    }

    #[tokio::test]
    async fn test_clear_empties_store() {
        let store = PersistedQueryStore::new();
        for i in 0..5 {
            let query = format!("{{ q{i} }}");
            store.store(&query).await.expect("should succeed");
        }
        assert_eq!(store.len().await, 5);
        store.clear().await;
        assert!(store.is_empty().await);
    }

    #[tokio::test]
    async fn test_multiple_queries_stored() {
        let store = PersistedQueryStore::new();
        let queries = ["{ a }", "{ b }", "{ c }"];
        for q in &queries {
            store.store(q).await.expect("should succeed");
        }
        assert_eq!(store.len().await, 3);
    }

    #[tokio::test]
    async fn test_idempotent_store() {
        let store = PersistedQueryStore::new();
        let query = "{ idempotent }";
        let h1 = store.store(query).await.expect("should succeed");
        let h2 = store.store(query).await.expect("should succeed");
        assert_eq!(h1, h2);
        assert_eq!(store.len().await, 1);
    }

    #[tokio::test]
    async fn test_store_round_trip_with_temp_file() {
        let store = PersistedQueryStore::new();
        let query = "{ temp_file_test }";
        let hash = store.store(query).await.expect("should succeed");
        // Include the process id so concurrent test runs don't share a file.
        let pid = std::process::id();
        let path = std::env::temp_dir().join(format!("{hash}-{pid}.hash"));
        tokio::fs::write(&path, hash.as_bytes())
            .await
            .expect("should succeed");
        let read_back = tokio::fs::read_to_string(&path)
            .await
            .expect("should succeed");
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = tokio::fs::remove_file(&path).await;
        let retrieved = store.get(&read_back).await.expect("should succeed");
        assert_eq!(retrieved, query);
    }
}