use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Identifier of a zstd dictionary: the little-endian `u32` stored in bytes
/// `4..8` of the dictionary header (see `extract_dict_id`).
pub type DictId = u32;
/// Errors produced by [`DictRegistry`] and dictionary-header parsing.
#[derive(Debug, Error)]
pub enum DictError {
    /// A filesystem operation on the registry directory failed.
    #[error("dictionary I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// The requested dictionary ID is not registered (or, for an in-memory
    /// registry, its bytes are not available).
    #[error("dictionary {0} not found in registry")]
    NotFound(DictId),
    /// `register` was called with a dictionary whose embedded ID is already known.
    #[error("dictionary {0} is already registered")]
    DuplicateId(DictId),
    /// The bytes are shorter than a dictionary header or do not start with the
    /// zstd dictionary magic.
    #[error("dictionary bytes are missing the zstd dict header magic")]
    InvalidDictBytes,
    /// The dictionary header's ID field is zero.
    #[error("dictionary bytes have no embedded ID and one was not supplied")]
    NoEmbeddedId,
    /// NOTE(review): not returned by any function visible in this file; may be
    /// used elsewhere in the crate — confirm before removing.
    #[error("dictionary {0} has no info sidecar")]
    MissingInfo(DictId),
    /// A `DictInfo` sidecar could not be serialized to or parsed from JSON.
    #[error("registry serialization error: {0}")]
    Serialization(String),
}
/// Convenience alias for results whose error type is [`DictError`].
pub type DictResult<T> = Result<T, DictError>;
/// Metadata describing one registered dictionary.
///
/// For a directory-backed registry this is persisted as the `{id}.json`
/// sidecar next to the dictionary bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DictInfo {
    /// Dictionary ID, taken from the dictionary header at registration.
    pub id: DictId,
    /// Time (UTC) at which the dictionary was registered.
    pub created_at: DateTime<Utc>,
    /// Sample count supplied by the caller of `register` (presumably the number
    /// of training samples — not validated here).
    pub sample_count: u32,
    /// Length of the dictionary bytes, in bytes.
    pub size_bytes: u64,
    /// Whether this dictionary is the registry's active one.
    pub is_active: bool,
}
/// Cloneable handle to a registry of zstd dictionaries.
///
/// All clones share one `Arc<RwLock<..>>`-guarded state, so registrations and
/// activations made through any clone are visible through every other clone.
/// A registry is either purely in memory (`DictRegistry::in_memory`) or backed
/// by a directory (`DictRegistry::open`).
#[derive(Clone)]
pub struct DictRegistry {
    inner: Arc<RwLock<RegistryState>>,
}
/// State shared by all clones of a `DictRegistry`, guarded by its `RwLock`.
struct RegistryState {
    /// Directory holding `{id}.zstd`, `{id}.json` and `active`; `None` for an
    /// in-memory registry.
    root: Option<PathBuf>,
    /// Dictionary bytes loaded so far. In-memory registries keep every
    /// registered dictionary here.
    cache: HashMap<DictId, Arc<Vec<u8>>>,
    /// Metadata for every registered dictionary, keyed by ID.
    known: HashMap<DictId, DictInfo>,
    /// ID of the active dictionary, if any.
    active: Option<DictId>,
}
impl DictRegistry {
pub fn in_memory() -> Self {
Self {
inner: Arc::new(RwLock::new(RegistryState {
root: None,
cache: HashMap::new(),
known: HashMap::new(),
active: None,
})),
}
}
pub fn open<P: AsRef<Path>>(root: P) -> DictResult<Self> {
let root = root.as_ref().to_path_buf();
std::fs::create_dir_all(&root)?;
let mut known: HashMap<DictId, DictInfo> = HashMap::new();
for entry in std::fs::read_dir(&root)? {
let entry = entry?;
let path = entry.path();
if path.extension().and_then(|e| e.to_str()) != Some("json") {
continue;
}
let raw = match std::fs::read(&path) {
Ok(b) => b,
Err(_) => continue,
};
let info: DictInfo = serde_json::from_slice(&raw)
.map_err(|e| DictError::Serialization(e.to_string()))?;
known.insert(info.id, info);
}
let active_path = root.join("active");
let active = if active_path.exists() {
let raw = std::fs::read_to_string(&active_path)?;
raw.trim()
.parse::<DictId>()
.ok()
.filter(|id| known.contains_key(id))
} else {
None
};
Ok(Self {
inner: Arc::new(RwLock::new(RegistryState {
root: Some(root),
cache: HashMap::new(),
known,
active,
})),
})
}
pub fn register(&self, bytes: Vec<u8>, sample_count: u32) -> DictResult<DictInfo> {
let id = extract_dict_id(&bytes)?;
let mut state = self.inner.write().unwrap();
if state.known.contains_key(&id) {
return Err(DictError::DuplicateId(id));
}
let info = DictInfo {
id,
created_at: Utc::now(),
sample_count,
size_bytes: bytes.len() as u64,
is_active: false,
};
if let Some(root) = state.root.clone() {
std::fs::write(root.join(format!("{id}.zstd")), &bytes)?;
let info_json = serde_json::to_vec_pretty(&info)
.map_err(|e| DictError::Serialization(e.to_string()))?;
std::fs::write(root.join(format!("{id}.json")), info_json)?;
}
state.cache.insert(id, Arc::new(bytes));
state.known.insert(id, info.clone());
Ok(info)
}
pub fn activate(&self, id: DictId) -> DictResult<()> {
let mut state = self.inner.write().unwrap();
if !state.known.contains_key(&id) {
return Err(DictError::NotFound(id));
}
for info in state.known.values_mut() {
info.is_active = info.id == id;
}
state.active = Some(id);
if let Some(root) = state.root.clone() {
std::fs::write(root.join("active"), format!("{id}\n"))?;
let snapshot: Vec<DictInfo> = state.known.values().cloned().collect();
for info in snapshot {
let info_path = root.join(format!("{}.json", info.id));
let info_json = serde_json::to_vec_pretty(&info)
.map_err(|e| DictError::Serialization(e.to_string()))?;
std::fs::write(info_path, info_json)?;
}
}
Ok(())
}
pub fn active_id(&self) -> Option<DictId> {
self.inner.read().unwrap().active
}
pub fn get_bytes(&self, id: DictId) -> DictResult<Arc<Vec<u8>>> {
{
let state = self.inner.read().unwrap();
if let Some(bytes) = state.cache.get(&id) {
return Ok(Arc::clone(bytes));
}
}
let mut state = self.inner.write().unwrap();
if let Some(bytes) = state.cache.get(&id) {
return Ok(Arc::clone(bytes));
}
if !state.known.contains_key(&id) {
return Err(DictError::NotFound(id));
}
let root = state
.root
.as_ref()
.ok_or(DictError::NotFound(id))?
.clone();
let bytes = std::fs::read(root.join(format!("{id}.zstd")))?;
let arc = Arc::new(bytes);
state.cache.insert(id, Arc::clone(&arc));
Ok(arc)
}
pub fn list(&self) -> Vec<DictInfo> {
let state = self.inner.read().unwrap();
let mut out: Vec<DictInfo> = state.known.values().cloned().collect();
out.sort_by_key(|d| d.created_at);
out
}
pub fn len(&self) -> usize {
self.inner.read().unwrap().known.len()
}
pub fn is_empty(&self) -> bool {
self.inner.read().unwrap().known.is_empty()
}
}
/// Pulls the dictionary ID out of a zstd dictionary header.
///
/// Header layout: bytes `0..4` hold the little-endian magic `0xEC30A437`;
/// bytes `4..8` hold the little-endian dictionary ID.
///
/// # Errors
///
/// * [`DictError::InvalidDictBytes`] when fewer than 8 bytes are supplied or
///   the magic does not match.
/// * [`DictError::NoEmbeddedId`] when the ID field is zero.
fn extract_dict_id(bytes: &[u8]) -> DictResult<DictId> {
    // Magic number that opens every zstd dictionary.
    const ZSTD_DICT_MAGIC: u32 = 0xEC30_A437;

    // Decodes the little-endian u32 starting at `at`; only called once the
    // length check below has guaranteed `at + 4 <= bytes.len()`.
    let le_u32_at =
        |at: usize| u32::from_le_bytes([bytes[at], bytes[at + 1], bytes[at + 2], bytes[at + 3]]);

    // `&&` short-circuits, so the magic is read only when 8 header bytes exist.
    let header_ok = bytes.len() >= 8 && le_u32_at(0) == ZSTD_DICT_MAGIC;
    if !header_ok {
        return Err(DictError::InvalidDictBytes);
    }

    match le_u32_at(4) {
        0 => Err(DictError::NoEmbeddedId),
        id => Ok(id),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compression::varint;
    use tempfile::TempDir;

    /// Trains a small zstd dictionary on 30 varint-encoded token streams whose
    /// contents are derived from `seed`, so different seeds give different corpora.
    fn train_dict(seed: u32) -> Vec<u8> {
        let samples: Vec<Vec<u8>> = (0u32..30)
            .map(|i| {
                let tokens: Vec<u32> = (0..200).map(|t| (t + i * seed) % 30_000).collect();
                varint::encode_tokens(&tokens)
            })
            .collect();
        let refs: Vec<&[u8]> = samples.iter().map(Vec::as_slice).collect();
        zstd::dict::from_samples(&refs, 4096).expect("train dict")
    }

    #[test]
    fn extract_id_from_real_dict() {
        let dict = train_dict(7);
        let id = extract_dict_id(&dict).unwrap();
        assert!(id != 0, "trained dict should have a non-zero embedded id");
    }

    #[test]
    fn extract_id_rejects_garbage() {
        assert!(matches!(
            extract_dict_id(b"not a real dict"),
            Err(DictError::InvalidDictBytes)
        ));
    }

    #[test]
    fn in_memory_register_get_activate() {
        let reg = DictRegistry::in_memory();
        let dict = train_dict(11);
        let info = reg.register(dict.clone(), 30).unwrap();
        assert_eq!(reg.len(), 1);
        assert!(reg.active_id().is_none());
        reg.activate(info.id).unwrap();
        assert_eq!(reg.active_id(), Some(info.id));
        let bytes = reg.get_bytes(info.id).unwrap();
        assert_eq!(bytes.as_slice(), dict.as_slice());
    }

    #[test]
    fn duplicate_register_errors() {
        let reg = DictRegistry::in_memory();
        let dict = train_dict(13);
        reg.register(dict.clone(), 30).unwrap();
        let err = reg.register(dict, 30).unwrap_err();
        assert!(matches!(err, DictError::DuplicateId(_)));
    }

    #[test]
    fn activate_unknown_errors() {
        let reg = DictRegistry::in_memory();
        let err = reg.activate(42).unwrap_err();
        assert!(matches!(err, DictError::NotFound(42)));
    }

    #[test]
    fn get_bytes_unknown_errors() {
        let reg = DictRegistry::in_memory();
        let err = reg.get_bytes(42).unwrap_err();
        assert!(matches!(err, DictError::NotFound(42)));
    }

    #[test]
    fn filesystem_persistence_round_trip() {
        let dir = TempDir::new().unwrap();
        let dict = train_dict(17);
        let id;
        {
            let reg = DictRegistry::open(dir.path()).unwrap();
            let info = reg.register(dict.clone(), 30).unwrap();
            reg.activate(info.id).unwrap();
            id = info.id;
        }
        let reg2 = DictRegistry::open(dir.path()).unwrap();
        assert_eq!(reg2.len(), 1);
        assert_eq!(reg2.active_id(), Some(id));
        let bytes = reg2.get_bytes(id).unwrap();
        assert_eq!(bytes.as_slice(), dict.as_slice());
    }

    #[test]
    fn multiple_dicts_coexist() {
        let dir = TempDir::new().unwrap();
        let reg = DictRegistry::open(dir.path()).unwrap();
        let d1 = train_dict(19);
        let d2 = train_dict(23);
        let i1 = reg.register(d1.clone(), 30).unwrap().id;
        let i2 = reg.register(d2.clone(), 30).unwrap().id;
        assert_ne!(i1, i2, "different training corpora should produce different ids");
        reg.activate(i2).unwrap();
        assert_eq!(reg.get_bytes(i1).unwrap().as_slice(), d1.as_slice());
        assert_eq!(reg.get_bytes(i2).unwrap().as_slice(), d2.as_slice());
        assert_eq!(reg.active_id(), Some(i2));
    }

    /// Switching the active dictionary must survive a reopen with exactly one
    /// dictionary flagged active.
    #[test]
    fn switching_active_persists_single_active_flag() {
        let dir = TempDir::new().unwrap();
        let i1;
        let i2;
        {
            let reg = DictRegistry::open(dir.path()).unwrap();
            i1 = reg.register(train_dict(19), 30).unwrap().id;
            i2 = reg.register(train_dict(23), 30).unwrap().id;
            reg.activate(i1).unwrap();
            reg.activate(i2).unwrap();
        }
        let reg = DictRegistry::open(dir.path()).unwrap();
        assert_eq!(reg.len(), 2);
        assert_eq!(reg.active_id(), Some(i2));
        let flagged: Vec<DictId> = reg
            .list()
            .into_iter()
            .filter(|d| d.is_active)
            .map(|d| d.id)
            .collect();
        assert_eq!(flagged, vec![i2]);
    }

    /// An `active` pointer naming an unregistered dictionary is ignored on open.
    #[test]
    fn open_ignores_dangling_active_pointer() {
        let dir = TempDir::new().unwrap();
        std::fs::write(dir.path().join("active"), "12345\n").unwrap();
        let reg = DictRegistry::open(dir.path()).unwrap();
        assert!(reg.is_empty());
        assert!(reg.active_id().is_none());
    }
}