use alloc::collections::BTreeMap;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use spin::Mutex;
use super::compress::{compress, decompress};
use super::train::train;
use super::types::{
CompressedHeader, CompressionDict, DictError, DictResult, DictStats, TrainingOptions,
};
lazy_static::lazy_static! {
// All registered dictionaries, keyed by dictionary id.
static ref DICTIONARIES: Mutex<BTreeMap<u64, CompressionDict>> = Mutex::new(BTreeMap::new());
// Per-dataset lookup index: dataset name -> list of (path pattern, dict id).
// Entries are scanned in insertion order by find_dict_for_path.
static ref DATASET_DICTS: Mutex<BTreeMap<String, Vec<(String, u64)>>> = Mutex::new(BTreeMap::new());
// Usage counters per dictionary id, created on register and dropped on remove.
static ref DICT_STATS: Mutex<BTreeMap<u64, DictStats>> = Mutex::new(BTreeMap::new());
// Monotonic id allocator; see next_dict_id(). Starts at 1 (0 is never handed out).
static ref NEXT_ID: Mutex<u64> = Mutex::new(1);
}
pub fn register_dict(dict: CompressionDict) -> u64 {
let id = dict.id;
let dataset = dict.dataset.clone();
let pattern = dict.pattern.clone();
DICTIONARIES.lock().insert(id, dict);
DATASET_DICTS
.lock()
.entry(dataset)
.or_default()
.push((pattern, id));
DICT_STATS.lock().insert(id, DictStats::default());
id
}
/// Look up a dictionary by id, returning an owned clone (the registry lock
/// is released before returning).
pub fn get_dict(id: u64) -> Option<CompressionDict> {
    let registry = DICTIONARIES.lock();
    registry.get(&id).cloned()
}
/// Unregister a dictionary: drop it from the registry, prune its entry from
/// the dataset index, and discard its stats. Returns the removed dictionary,
/// or `None` if the id was unknown.
pub fn remove_dict(id: u64) -> Option<CompressionDict> {
    let removed = DICTIONARIES.lock().remove(&id)?;
    if let Some(entries) = DATASET_DICTS.lock().get_mut(&removed.dataset) {
        entries.retain(|&(_, did)| did != id);
    }
    DICT_STATS.lock().remove(&id);
    Some(removed)
}
/// Snapshot of every registered dictionary as `(id, name, dataset, pattern)`
/// tuples, in id order (the registry is a BTreeMap).
pub fn list_dicts() -> Vec<(u64, String, String, String)> {
    let registry = DICTIONARIES.lock();
    let mut out = Vec::with_capacity(registry.len());
    for d in registry.values() {
        out.push((d.id, d.name.clone(), d.dataset.clone(), d.pattern.clone()));
    }
    out
}
/// Find the first dictionary registered for `dataset` whose pattern matches
/// `path`. Entries are scanned in registration order; `None` if the dataset
/// is unknown or nothing matches.
pub fn find_dict_for_path(dataset: &str, path: &str) -> Option<CompressionDict> {
    let index = DATASET_DICTS.lock();
    let entries = index.get(dataset)?;
    let registry = DICTIONARIES.lock();
    entries.iter().find_map(|(_, did)| {
        registry
            .get(did)
            .filter(|d| d.matches_pattern(path))
            .cloned()
    })
}
/// Allocate the next dictionary id: returns the current counter value and
/// advances it by one, atomically under the NEXT_ID lock.
pub fn next_dict_id() -> u64 {
    let mut counter = NEXT_ID.lock();
    let bumped = *counter + 1;
    core::mem::replace(&mut *counter, bumped)
}
/// Reset the whole module: drop every dictionary, dataset mapping, and stat
/// record, and restart the id allocator at 1.
pub fn clear_all() {
    *NEXT_ID.lock() = 1;
    DICT_STATS.lock().clear();
    DATASET_DICTS.lock().clear();
    DICTIONARIES.lock().clear();
}
/// Train a dictionary from `samples` with the given `options`, then register
/// it globally. A fresh id is allocated up front (and is consumed even if
/// training fails, matching the allocator's fire-and-forget semantics).
/// Returns the registered dictionary.
pub fn auto_train(
    name: &str,
    dataset: &str,
    pattern: &str,
    samples: &[&[u8]],
    options: &TrainingOptions,
    timestamp: u64,
) -> DictResult<CompressionDict> {
    train(next_dict_id(), name, pattern, dataset, samples, options, timestamp).map(|dict| {
        register_dict(dict.clone());
        dict
    })
}
/// Convenience wrapper around [`auto_train`] that trains with
/// `TrainingOptions::default()`.
pub fn auto_train_default(
    name: &str,
    dataset: &str,
    pattern: &str,
    samples: &[&[u8]],
    timestamp: u64,
) -> DictResult<CompressionDict> {
    let options = TrainingOptions::default();
    auto_train(name, dataset, pattern, samples, &options, timestamp)
}
/// Compress `data` with the first dictionary that matches `path` within
/// `dataset`, recording byte counts in that dictionary's stats.
/// Fails with `NoMatchingDict` when no registered pattern matches.
pub fn compress_auto(dataset: &str, path: &str, data: &[u8]) -> DictResult<Vec<u8>> {
    let dict = match find_dict_for_path(dataset, path) {
        Some(d) => d,
        None => return Err(DictError::NoMatchingDict(path.to_string())),
    };
    let compressed = compress(data, &dict)?;
    // Best-effort stats update. The trailing zeros presumably stand for
    // timing fields that are not measured here — confirm against
    // DictStats::record_compression's signature.
    if let Some(entry) = DICT_STATS.lock().get_mut(&dict.id) {
        entry.record_compression(data.len() as u64, compressed.len() as u64, 0, 0);
    }
    Ok(compressed)
}
/// Decompress a dict-compressed blob. The dictionary id is read from the
/// embedded header; fails with `InvalidData` on a bad header and
/// `DictNotFound` when the referenced dictionary is no longer registered.
pub fn decompress_auto(compressed: &[u8]) -> DictResult<Vec<u8>> {
    let header = match CompressedHeader::from_bytes(compressed) {
        Some(h) => h,
        None => return Err(DictError::InvalidData("invalid header".to_string())),
    };
    let dict = get_dict(header.dict_id).ok_or(DictError::DictNotFound(header.dict_id))?;
    let plain = decompress(compressed, &dict)?;
    // Best-effort stats bump; missing stats are silently ignored.
    if let Some(entry) = DICT_STATS.lock().get_mut(&dict.id) {
        entry.record_decompression();
    }
    Ok(plain)
}
/// Whether `data` starts with a dictionary-compression header.
/// Pure delegation to [`CompressedHeader::is_dict_compressed`].
pub fn is_dict_compressed(data: &[u8]) -> bool {
CompressedHeader::is_dict_compressed(data)
}
/// Usage stats for a single dictionary id, cloned out from under the lock.
pub fn get_stats(id: u64) -> Option<DictStats> {
    let stats_map = DICT_STATS.lock();
    stats_map.get(&id).map(Clone::clone)
}
/// Owned snapshot of every dictionary's stats, keyed by id.
pub fn get_all_stats() -> BTreeMap<u64, DictStats> {
    DICT_STATS
        .lock()
        .iter()
        .map(|(id, s)| (*id, s.clone()))
        .collect()
}
/// Module-wide aggregate of dictionary usage, produced by [`get_global_stats`].
#[derive(Debug, Clone, Default)]
pub struct GlobalStats {
// Number of currently registered dictionaries.
pub dict_count: usize,
// Sum of CompressionDict::size() over all registered dictionaries.
pub dict_bytes: u64,
// Total compression operations recorded across all dictionaries.
pub total_compressions: u64,
// Total decompression operations recorded across all dictionaries.
pub total_decompressions: u64,
// Sum of (bytes_in - bytes_out) over dictionaries that achieved a net
// reduction; dictionaries whose output grew contribute nothing.
pub bytes_saved: u64,
}
pub fn get_global_stats() -> GlobalStats {
let dicts = DICTIONARIES.lock();
let stats = DICT_STATS.lock();
let dict_count = dicts.len();
let dict_bytes: u64 = dicts.values().map(|d| d.size() as u64).sum();
let mut total_compressions = 0u64;
let mut total_decompressions = 0u64;
let mut bytes_saved = 0u64;
for s in stats.values() {
total_compressions += s.compressions;
total_decompressions += s.decompressions;
if s.bytes_in > s.bytes_out {
bytes_saved += s.bytes_in - s.bytes_out;
}
}
GlobalStats {
dict_count,
dict_bytes,
total_compressions,
total_decompressions,
bytes_saved,
}
}
#[cfg(test)]
mod tests {
    //! Registry tests. NOTE: the registry is process-global and Rust runs
    //! tests concurrently by default, so each test uses its own disjoint
    //! ids/dataset names instead of resetting shared state. (A previous
    //! unused `setup()` helper that called `clear_all()` was removed — it
    //! was never invoked, and calling it per test would race with parallel
    //! tests.)
    use super::*;
    #[test]
    fn test_register_get_dict() {
        let id = 1000u64;
        let dict = CompressionDict::new(
            id,
            "test_register",
            alloc::vec![1, 2, 3],
            "*.txt",
            "test_ds",
            0,
        );
        register_dict(dict.clone());
        let retrieved = get_dict(id).unwrap();
        assert_eq!(retrieved.id, id);
        assert_eq!(retrieved.name, "test_register");
    }
    #[test]
    fn test_remove_dict() {
        let id = 2000u64;
        let dict = CompressionDict::new(id, "test_remove", alloc::vec![], "*", "ds", 0);
        register_dict(dict);
        assert!(get_dict(id).is_some());
        remove_dict(id);
        assert!(get_dict(id).is_none());
    }
    #[test]
    fn test_find_dict_for_path() {
        let id = 3000u64;
        let dict = CompressionDict::new(
            id,
            "json_dict",
            alloc::vec![1, 2, 3],
            "*.json",
            "find_ds",
            0,
        );
        register_dict(dict);
        let found = find_dict_for_path("find_ds", "data.json");
        assert!(found.is_some());
        assert_eq!(found.unwrap().id, id);
        let not_found = find_dict_for_path("find_ds", "data.xml");
        assert!(not_found.is_none());
    }
    #[test]
    fn test_next_dict_id() {
        // Only monotonicity is asserted: other tests may allocate ids
        // concurrently, so exact values are not stable.
        let id1 = next_dict_id();
        let id2 = next_dict_id();
        let id3 = next_dict_id();
        assert!(id2 > id1);
        assert!(id3 > id2);
    }
    #[test]
    fn test_auto_train() {
        let samples: &[&[u8]] = &[b"json data with fields", b"json data with values"];
        let dict = auto_train_default("auto_test", "auto_ds", "*.json", samples, 12345).unwrap();
        assert!(!dict.data.is_empty());
        assert!(get_dict(dict.id).is_some());
    }
    #[test]
    fn test_compress_decompress_auto() {
        let id = 4000u64;
        let dict = CompressionDict::new(
            id,
            "auto_comp",
            b"common pattern".to_vec(),
            "*.txt",
            "auto_comp_ds",
            0,
        );
        register_dict(dict);
        let data = b"common pattern is used here";
        let compressed = compress_auto("auto_comp_ds", "file.txt", data).unwrap();
        let decompressed = decompress_auto(&compressed).unwrap();
        assert_eq!(decompressed, data);
    }
    #[test]
    fn test_is_dict_compressed() {
        let id = 5000u64;
        let dict = CompressionDict::new(id, "check", alloc::vec![], "*", "ds", 0);
        register_dict(dict.clone());
        let data = b"test";
        let compressed = compress(data, &dict).unwrap();
        assert!(is_dict_compressed(&compressed));
        assert!(!is_dict_compressed(data));
    }
    #[test]
    fn test_stats() {
        let id = 6000u64;
        let dict =
            CompressionDict::new(id, "stats_test", b"dictionary".to_vec(), "*", "stats_ds", 0);
        register_dict(dict.clone());
        let data = b"dictionary is here";
        let _ = compress(data, &dict).unwrap();
        let global = get_global_stats();
        assert!(global.dict_count >= 1);
    }
    #[test]
    fn test_list_dicts() {
        let id = 7000u64;
        let dict = CompressionDict::new(id, "list_test", alloc::vec![], "*.log", "list_ds", 0);
        register_dict(dict);
        let list = list_dicts();
        assert!(list.iter().any(|(did, _, _, _)| *did == id));
    }
}