use std::{
collections::HashMap,
path::Path,
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
};
/// Largest manifest size, in bytes, that `TrustedDocumentStore::from_manifest_file`
/// accepts (10 MiB). A file whose reported length exceeds this is rejected.
pub(crate) const MAX_MANIFEST_BYTES: u64 = 10 * 1024 * 1024;
use dashmap::DashMap;
use serde::Deserialize;
/// Policy applied by `TrustedDocumentStore::resolve` to requests that carry no
/// document id.
///
/// Lookups by document id behave the same in both modes; the mode only decides
/// what happens to a raw query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum TrustedDocumentMode {
/// Raw queries are always refused with `TrustedDocumentError::ForbiddenRawQuery`.
Strict,
/// A supplied raw query is returned unchanged; a request with neither an id
/// nor a raw query is still refused.
Permissive,
}
/// Shape of the JSON manifest parsed by `TrustedDocumentStore::from_manifest_file`.
#[derive(Debug, Deserialize)]
struct Manifest {
/// Required by the derived deserializer, but never read afterwards in the
/// code shown here, which is why the dead-code lint is silenced.
#[allow(dead_code)] version: u32,
/// Document id mapped to document text. A leading `sha256:` on an id is
/// stripped when the store is built.
documents: HashMap<String, String>,
}
/// Set of trusted documents addressable by id, paired with the policy used for
/// requests that arrive without an id.
pub struct TrustedDocumentStore {
/// Document text keyed by id, with any leading `sha256:` already removed.
documents: Arc<DashMap<String, String>>,
/// Policy for requests that supply no document id.
mode: TrustedDocumentMode,
}
impl TrustedDocumentStore {
    /// Builds a store from the JSON manifest at `path`.
    ///
    /// The manifest must deserialize into a `version` number and a `documents`
    /// object mapping document ids to document text. A leading `sha256:` on an
    /// id is stripped before the entry is stored.
    ///
    /// The size limit is enforced twice: once against the length reported by
    /// the file system, and again while reading. The second check matters
    /// because the reported length can be stale (the file may grow after the
    /// stat) or meaningless (some special files report a length of zero).
    ///
    /// # Errors
    ///
    /// Returns [`TrustedDocumentError::ManifestLoad`] when the file cannot be
    /// inspected or read, exceeds [`MAX_MANIFEST_BYTES`], is not valid UTF-8,
    /// or does not parse as a manifest.
    pub fn from_manifest_file(
        path: &Path,
        mode: TrustedDocumentMode,
    ) -> Result<Self, TrustedDocumentError> {
        // Brings `take` / `read_to_end` into scope for this function only.
        use std::io::Read as _;

        let reported_size = std::fs::metadata(path)
            .map_err(|e| {
                TrustedDocumentError::ManifestLoad(format!(
                    "Failed to stat manifest {}: {e}",
                    path.display()
                ))
            })?
            .len();
        if reported_size > MAX_MANIFEST_BYTES {
            return Err(TrustedDocumentError::ManifestLoad(format!(
                "Manifest {} is too large ({reported_size} bytes, max {MAX_MANIFEST_BYTES})",
                path.display()
            )));
        }

        let read_failed = |reason: String| {
            TrustedDocumentError::ManifestLoad(format!(
                "Failed to read manifest {}: {reason}",
                path.display()
            ))
        };

        // `reported_size` is at most `MAX_MANIFEST_BYTES` here, so this is a
        // bounded pre-allocation.
        let mut bytes = Vec::with_capacity(usize::try_from(reported_size).unwrap_or(0));
        // Read at most one byte past the limit: enough to detect an oversized
        // file without ever buffering an unbounded amount of data.
        std::fs::File::open(path)
            .map_err(|e| read_failed(e.to_string()))?
            .take(MAX_MANIFEST_BYTES + 1)
            .read_to_end(&mut bytes)
            .map_err(|e| read_failed(e.to_string()))?;
        if u64::try_from(bytes.len()).unwrap_or(u64::MAX) > MAX_MANIFEST_BYTES {
            return Err(TrustedDocumentError::ManifestLoad(format!(
                "Manifest {} is too large (more than {MAX_MANIFEST_BYTES} bytes read, max {MAX_MANIFEST_BYTES})",
                path.display()
            )));
        }
        // Decode only after the size check so an oversized file is never
        // misreported as a UTF-8 problem caused by the truncated read.
        let contents = String::from_utf8(bytes).map_err(|e| read_failed(e.to_string()))?;

        let manifest: Manifest = serde_json::from_str(&contents).map_err(|e| {
            TrustedDocumentError::ManifestLoad(format!(
                "Failed to parse manifest {}: {e}",
                path.display()
            ))
        })?;

        Ok(Self::from_documents(manifest.documents, mode))
    }

    /// Builds a store directly from an id-to-text map.
    ///
    /// A leading `sha256:` on an id is stripped before the entry is stored.
    #[must_use]
    pub fn from_documents(documents: HashMap<String, String>, mode: TrustedDocumentMode) -> Self {
        Self {
            documents: Arc::new(normalize_keys(documents)),
            mode,
        }
    }

    /// Builds an empty store in [`TrustedDocumentMode::Permissive`] mode.
    ///
    /// With such a store, raw queries pass through `resolve` unchanged and any
    /// lookup by id fails with [`TrustedDocumentError::DocumentNotFound`].
    #[must_use]
    pub fn disabled() -> Self {
        Self {
            documents: Arc::new(DashMap::new()),
            mode: TrustedDocumentMode::Permissive,
        }
    }

    /// Returns the policy applied to requests without a document id.
    #[must_use]
    pub const fn mode(&self) -> TrustedDocumentMode {
        self.mode
    }

    /// Returns the number of documents currently stored.
    #[must_use]
    pub fn document_count(&self) -> usize {
        self.documents.len()
    }

    /// Replaces the stored documents with `documents`.
    ///
    /// Ids are normalized the same way as in [`Self::from_documents`].
    ///
    /// The map is updated in place rather than emptied and refilled, so an id
    /// present in both the old and new sets stays resolvable for the whole
    /// call. The replacement as a whole is still not atomic: until the call
    /// returns, a concurrent `resolve` may see the old text for a shared id,
    /// or fail to find an id that exists only in the new set.
    pub fn replace_documents(&self, documents: HashMap<String, String>) {
        let incoming = normalize_keys(documents);
        // Drop only the ids that are going away.
        self.documents.retain(|id, _| incoming.contains_key(id));
        // Then add new ids and overwrite the text of surviving ones.
        for (id, text) in incoming {
            self.documents.insert(id, text);
        }
    }

    /// Resolves a request to the document text that should be executed.
    ///
    /// When `document_id` is given it always takes precedence: the id is looked
    /// up (ignoring a leading `sha256:`) and `raw_query` is not consulted.
    /// Without an id, the store's mode decides whether `raw_query` is returned.
    ///
    /// # Errors
    ///
    /// * [`TrustedDocumentError::DocumentNotFound`] when `document_id` is given
    ///   but not stored; the error carries the id exactly as supplied.
    /// * [`TrustedDocumentError::ForbiddenRawQuery`] when no id is given and
    ///   either the mode is `Strict` or `raw_query` is `None`.
    pub fn resolve(
        &self,
        document_id: Option<&str>,
        raw_query: Option<&str>,
    ) -> Result<String, TrustedDocumentError> {
        if let Some(doc_id) = document_id {
            let hash = doc_id.strip_prefix("sha256:").unwrap_or(doc_id);
            // Clone the text out so the map guard is released before returning.
            return self
                .documents
                .get(hash)
                .map(|entry| entry.value().clone())
                .ok_or_else(|| TrustedDocumentError::DocumentNotFound {
                    id: doc_id.to_string(),
                });
        }

        match self.mode {
            TrustedDocumentMode::Strict => Err(TrustedDocumentError::ForbiddenRawQuery),
            TrustedDocumentMode::Permissive => raw_query
                .map(str::to_owned)
                .ok_or(TrustedDocumentError::ForbiddenRawQuery),
        }
    }
}
/// Converts an id-to-text map into the concurrent map used by the store,
/// removing a leading `sha256:` from every id.
///
/// If two ids differ only by that prefix they collapse onto the same key, and
/// which text survives depends on the input map's iteration order.
fn normalize_keys(documents: HashMap<String, String>) -> DashMap<String, String> {
    let normalized = DashMap::with_capacity(documents.len());
    documents.into_iter().for_each(|(raw_id, text)| {
        // Ids without the prefix are moved into the new map as they are.
        let bare_id = raw_id
            .strip_prefix("sha256:")
            .map(str::to_owned)
            .unwrap_or(raw_id);
        normalized.insert(bare_id, text);
    });
    normalized
}
/// Failures reported while loading a manifest or resolving a request.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TrustedDocumentError {
/// No document id was supplied and a raw query could not be used: either the
/// store is in `Strict` mode or no raw query was supplied.
#[error("Raw queries are not permitted. Send a documentId instead.")]
ForbiddenRawQuery,
/// A document id was supplied but the store holds no document under it.
#[error("Unknown document: {id}")]
DocumentNotFound {
/// The id exactly as the caller supplied it, including any `sha256:` prefix.
id: String,
},
/// The manifest file could not be inspected, read, or parsed, or was too
/// large; the payload is a human-readable description of the failure.
#[error("Manifest load error: {0}")]
ManifestLoad(String),
}
// Process-wide event counters. `Ordering::Relaxed` is sufficient: each counter
// is only ever incremented and read on its own, and no other memory is
// published or synchronized through these values.
static TRUSTED_DOC_HITS: AtomicU64 = AtomicU64::new(0);
static TRUSTED_DOC_MISSES: AtomicU64 = AtomicU64::new(0);
static TRUSTED_DOC_REJECTED: AtomicU64 = AtomicU64::new(0);

/// Adds one to the hit counter.
///
/// Nothing in this block calls it; callers decide what counts as a hit.
pub fn record_hit() {
    TRUSTED_DOC_HITS.fetch_add(1, Ordering::Relaxed);
}

/// Adds one to the miss counter.
///
/// Nothing in this block calls it; callers decide what counts as a miss.
pub fn record_miss() {
    TRUSTED_DOC_MISSES.fetch_add(1, Ordering::Relaxed);
}

/// Adds one to the rejected counter.
///
/// Nothing in this block calls it; callers decide what counts as a rejection.
pub fn record_rejected() {
    TRUSTED_DOC_REJECTED.fetch_add(1, Ordering::Relaxed);
}

/// Returns the number of hits recorded so far in this process.
#[must_use]
pub fn hits_total() -> u64 {
    TRUSTED_DOC_HITS.load(Ordering::Relaxed)
}

/// Returns the number of misses recorded so far in this process.
#[must_use]
pub fn misses_total() -> u64 {
    TRUSTED_DOC_MISSES.load(Ordering::Relaxed)
}

/// Returns the number of rejections recorded so far in this process.
#[must_use]
pub fn rejected_total() -> u64 {
    TRUSTED_DOC_REJECTED.load(Ordering::Relaxed)
}