use std::sync::Arc;
use dashmap::DashMap;
use pjson_rs_domain::value_objects::SessionId;
use tokio::sync::{Mutex, OnceCell};
use crate::{
Error, Result,
compression::zstd::{MAX_DICT_SIZE, N_TRAIN, ZstdDictCompressor, ZstdDictionary},
domain::ports::dictionary_store::{DictionaryFuture, DictionaryStore},
security::CompressionBombDetector,
};
/// Per-session training state: the samples collected so far and the
/// dictionary slot for that session.
struct SessionDictState {
/// Samples gathered by `train_if_ready`; it stops appending once
/// `N_TRAIN` entries have been collected.
corpus: Mutex<Vec<Vec<u8>>>,
/// Write-once slot for the session's dictionary, filled either by
/// `register` or by a successful training run.
dict: OnceCell<Arc<ZstdDictionary>>,
}
/// In-memory `DictionaryStore` that keeps one training state per session.
pub struct InMemoryDictionaryStore {
/// Session id -> shared training state; entries are created on demand.
sessions: DashMap<SessionId, Arc<SessionDictState>>,
/// Validates the size of every dictionary before it is stored.
bomb_detector: Arc<CompressionBombDetector>,
/// Size requested from the trainer; `new` caps it at `MAX_DICT_SIZE`.
target_dict_size: usize,
}
impl InMemoryDictionaryStore {
    /// Creates an empty store.
    ///
    /// `target_dict_size` is the size requested when training a dictionary;
    /// values above [`MAX_DICT_SIZE`] are capped to that limit.
    pub fn new(bomb_detector: Arc<CompressionBombDetector>, target_dict_size: usize) -> Self {
        Self {
            sessions: DashMap::new(),
            bomb_detector,
            target_dict_size: target_dict_size.min(MAX_DICT_SIZE),
        }
    }

    /// Stores an already-built dictionary for `session_id`.
    ///
    /// The dictionary slot is write-once: if the session already has a
    /// dictionary, the existing one is kept and this call still returns `Ok`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::CompressionError`] when the bomb detector rejects the
    /// dictionary's length. In that case no session entry is created.
    pub fn register(&self, session_id: SessionId, dict: ZstdDictionary) -> Result<()> {
        self.bomb_detector
            .validate_pre_decompression(dict.len())
            .map_err(|e| {
                Error::CompressionError(format!("dictionary rejected by bomb detector: {e}"))
            })?;

        // Reuse the shared get-or-create helper instead of duplicating the
        // entry construction here.
        let state = self.session_state(session_id);

        // First write wins: `set` fails only when a dictionary is already
        // present, which is deliberately not treated as an error.
        let _ = state.dict.set(Arc::new(dict));
        Ok(())
    }

    /// Returns the state for `session_id`, inserting an empty one (no samples,
    /// no dictionary) when the session has not been seen before.
    fn session_state(&self, session_id: SessionId) -> Arc<SessionDictState> {
        self.sessions
            .entry(session_id)
            .or_insert_with(|| {
                Arc::new(SessionDictState {
                    corpus: Mutex::new(Vec::new()),
                    dict: OnceCell::new(),
                })
            })
            .clone()
    }
}
impl DictionaryStore for InMemoryDictionaryStore {
fn get_dictionary<'a>(
&'a self,
session_id: SessionId,
) -> DictionaryFuture<'a, Option<Arc<ZstdDictionary>>> {
Box::pin(async move {
Ok(self
.sessions
.get(&session_id)
.and_then(|s| s.dict.get().cloned()))
})
}
fn train_if_ready<'a>(
&'a self,
session_id: SessionId,
sample: Vec<u8>,
) -> DictionaryFuture<'a, ()> {
Box::pin(async move {
let state = self.session_state(session_id);
if state.dict.initialized() {
return Ok(());
}
let snapshot = {
let mut guard = state.corpus.lock().await;
if guard.len() < N_TRAIN {
guard.push(sample);
}
if guard.len() < N_TRAIN {
return Ok(());
}
guard.clone()
};
let target = self.target_dict_size;
let bomb_detector = self.bomb_detector.clone();
let _ = state
.dict
.get_or_try_init(|| async move {
let dict = tokio::task::spawn_blocking(move || {
ZstdDictCompressor::train(&snapshot, target)
})
.await
.map_err(|e| {
Error::CompressionError(format!("zstd: train join error: {e}"))
})??;
bomb_detector
.validate_pre_decompression(dict.len())
.map_err(|e| {
Error::CompressionError(format!(
"trained dict rejected by bomb detector: {e}"
))
})?;
Ok::<_, Error>(Arc::new(dict))
})
.await?;
Ok(())
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use pjson_rs_domain::value_objects::SessionId;

    /// Store backed by the default bomb detector with a 64 * 1024 byte target.
    fn make_store() -> InMemoryDictionaryStore {
        let detector = Arc::new(CompressionBombDetector::default());
        InMemoryDictionaryStore::new(detector, 64 * 1024)
    }

    /// Builds `count` small JSON documents that all share the same shape.
    fn make_samples(count: usize) -> Vec<Vec<u8>> {
        let mut docs = Vec::with_capacity(count);
        for i in 0..count {
            let doc = format!(r#"{{"id":{i},"name":"item","value":{}}}"#, i * 10);
            docs.push(doc.into_bytes());
        }
        docs
    }

    /// An unknown session has no dictionary.
    #[tokio::test]
    async fn test_get_dictionary_returns_none_before_training() {
        let store = make_store();
        let session = SessionId::new();

        let found = store.get_dictionary(session).await.unwrap();

        assert!(found.is_none());
    }

    /// One sample short of the threshold must not trigger training.
    #[tokio::test]
    async fn test_train_if_ready_below_threshold_stays_none() {
        let store = make_store();
        let session = SessionId::new();

        for i in 0..(N_TRAIN - 1) {
            let payload = format!(r#"{{"i":{i}}}"#).into_bytes();
            store.train_if_ready(session, payload).await.unwrap();
        }

        let found = store.get_dictionary(session).await.unwrap();
        assert!(
            found.is_none(),
            "should still be None before N_TRAIN samples"
        );
    }

    /// Feeding exactly `N_TRAIN` samples produces a dictionary.
    #[tokio::test]
    async fn test_train_if_ready_fires_after_threshold() {
        let store = make_store();
        let session = SessionId::new();

        for payload in make_samples(N_TRAIN) {
            store.train_if_ready(session, payload).await.unwrap();
        }

        let found = store.get_dictionary(session).await.unwrap();
        assert!(
            found.is_some(),
            "dictionary should be Some after N_TRAIN samples"
        );
    }

    /// A registered dictionary is visible through `get_dictionary`.
    #[tokio::test]
    async fn test_register_then_get_returns_dict() {
        let store = make_store();
        let session = SessionId::new();
        let corpus = make_samples(N_TRAIN);
        let trained = ZstdDictCompressor::train(&corpus, MAX_DICT_SIZE).unwrap();

        store.register(session, trained).unwrap();

        let found = store.get_dictionary(session).await.unwrap();
        assert!(found.is_some());
    }

    /// Many concurrent callers all succeed and leave a dictionary behind.
    #[tokio::test]
    async fn test_concurrent_train_if_ready_produces_exactly_one_dict() {
        use futures::future::try_join_all;

        let store = Arc::new(make_store());
        let session = SessionId::new();

        let mut handles = Vec::new();
        for payload in make_samples(N_TRAIN * 2) {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                store.train_if_ready(session, payload).await
            }));
        }

        let outcomes = try_join_all(handles).await.unwrap();
        for outcome in outcomes {
            outcome.unwrap();
        }

        let found = store.get_dictionary(session).await.unwrap();
        assert!(found.is_some(), "exactly one dictionary should be trained");
    }

    /// A detector limit of 100 makes training fail and stores nothing.
    #[tokio::test]
    async fn test_train_if_ready_bomb_detector_rejects_trained_dict() {
        use crate::security::CompressionBombConfig;

        let strict = CompressionBombConfig {
            max_compressed_size: 100,
            ..Default::default()
        };
        let detector = Arc::new(CompressionBombDetector::new(strict));
        let store = InMemoryDictionaryStore::new(detector, MAX_DICT_SIZE);
        let session = SessionId::new();

        let mut training_error_seen = false;
        for payload in make_samples(N_TRAIN) {
            if store.train_if_ready(session, payload).await.is_err() {
                training_error_seen = true;
                break;
            }
        }

        assert!(
            training_error_seen,
            "expected bomb detector to reject the trained dict"
        );
        let found = store.get_dictionary(session).await.unwrap();
        assert!(
            found.is_none(),
            "bomb detector should have prevented dict from being stored"
        );
    }

    /// `register` refuses a dictionary when the detector limit is 10.
    #[test]
    fn test_register_rejects_oversized_dict_via_bomb_detector() {
        use crate::security::CompressionBombConfig;

        let strict = CompressionBombConfig {
            max_compressed_size: 10,
            ..Default::default()
        };
        let detector = Arc::new(CompressionBombDetector::new(strict));
        let store = InMemoryDictionaryStore::new(detector, MAX_DICT_SIZE);
        let session = SessionId::new();
        let corpus = make_samples(N_TRAIN);
        let trained = ZstdDictCompressor::train(&corpus, MAX_DICT_SIZE).unwrap();

        let outcome = store.register(session, trained);

        assert!(outcome.is_err(), "bomb detector must reject oversized dict");
    }
}