use crate::xxhash::xxhash64_with_seed;
use oxiarc_core::error::{OxiArcError, Result};
use std::collections::HashMap;
/// Hard upper bound on dictionary size in bytes (1 MiB); enforced by [`ZstdDict::new`].
pub const MAX_DICT_SIZE: usize = 1024 * 1024;
/// Smallest n-gram length considered during dictionary training.
const MIN_NGRAM: usize = 4;
/// Largest n-gram length considered during dictionary training.
const MAX_NGRAM: usize = 16;
/// An n-gram must occur at least this many times across all samples to be scored.
const MIN_FREQUENCY: usize = 2;
/// A trained compression dictionary identified by a content hash.
///
/// The `id` is derived from an xxhash64 of the dictionary bytes (see
/// [`ZstdDict::new`]), so equal contents always yield equal ids.
/// `PartialEq`/`Eq` are derived so dictionaries can be compared directly
/// (Rust API guidelines recommend eagerly deriving common traits).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ZstdDict {
    // Raw dictionary bytes; at most MAX_DICT_SIZE (validated in `new`).
    data: Vec<u8>,
    // Content identifier: truncated xxhash64 of `data`.
    id: u32,
}
impl ZstdDict {
    /// Builds a dictionary from raw bytes, deriving its id from an
    /// xxhash64 (seed 0) of the contents.
    ///
    /// # Errors
    ///
    /// Returns [`OxiArcError::CorruptedData`] when `data` is larger than
    /// [`MAX_DICT_SIZE`].
    pub fn new(data: Vec<u8>) -> Result<Self> {
        let byte_len = data.len();
        if byte_len > MAX_DICT_SIZE {
            return Err(OxiArcError::CorruptedData {
                offset: 0,
                message: format!(
                    "dictionary too large: {} bytes exceeds maximum {} bytes",
                    byte_len, MAX_DICT_SIZE
                ),
            });
        }
        // Hash before `data` is moved into the struct.
        let id = xxhash64_with_seed(&data, 0) as u32;
        Ok(Self { data, id })
    }

    /// Content-derived identifier of this dictionary.
    pub fn id(&self) -> u32 {
        self.id
    }

    /// Borrows the raw dictionary bytes.
    pub fn data(&self) -> &[u8] {
        self.data.as_slice()
    }

    /// Number of bytes in the dictionary.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// True when the dictionary holds no bytes.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Consumes the dictionary, yielding the underlying byte buffer.
    pub fn into_data(self) -> Vec<u8> {
        self.data
    }
}
/// Trains a compression dictionary from `samples` using a frequency-scored
/// n-gram heuristic.
///
/// Every n-gram of length `MIN_NGRAM..=MAX_NGRAM` occurring at least
/// `MIN_FREQUENCY` times across the samples is ranked by `count * length`
/// and greedily packed into a buffer of at most `dict_size` bytes (itself
/// capped at [`MAX_DICT_SIZE`]). If no n-gram qualifies (e.g. every sample
/// is shorter than `MIN_NGRAM`), raw sample prefixes are used as a
/// fallback so the result is still non-empty when input bytes exist.
///
/// # Errors
///
/// Returns `OxiArcError::CorruptedData` when `samples` is empty.
pub fn train_dictionary(samples: &[&[u8]], dict_size: usize) -> Result<ZstdDict> {
    if samples.is_empty() {
        return Err(OxiArcError::CorruptedData {
            offset: 0,
            message: "no samples provided for dictionary training".to_string(),
        });
    }
    let dict_size = dict_size.min(MAX_DICT_SIZE);

    // Count every qualifying n-gram across all samples.
    let mut ngram_counts: HashMap<Vec<u8>, usize> = HashMap::new();
    for sample in samples {
        let max_window = MAX_NGRAM.min(sample.len());
        if max_window < MIN_NGRAM {
            // Sample too short to contribute any n-gram.
            continue;
        }
        for window_size in MIN_NGRAM..=max_window {
            for window in sample.windows(window_size) {
                *ngram_counts.entry(window.to_vec()).or_insert(0) += 1;
            }
        }
    }

    // Score = frequency * length, so longer common substrings are preferred.
    let mut scored: Vec<(Vec<u8>, usize)> = ngram_counts
        .into_iter()
        .filter(|(_, count)| *count >= MIN_FREQUENCY)
        .map(|(ngram, count)| {
            let score = count * ngram.len();
            (ngram, score)
        })
        .collect();
    // FIX: `HashMap` iteration order is randomized, and the previous
    // comparator only ordered by (score desc, length desc). N-grams tying on
    // both landed in nondeterministic order, so identical samples could
    // produce different dictionaries — and different dictionary ids —
    // between runs. Breaking the final tie on the n-gram bytes makes the
    // comparator a total order (training is now deterministic), which also
    // makes `sort_unstable_by` safe to use (faster, no allocation).
    scored.sort_unstable_by(|a, b| {
        b.1.cmp(&a.1)
            .then_with(|| b.0.len().cmp(&a.0.len()))
            .then_with(|| a.0.cmp(&b.0))
    });

    // Greedily pack the highest-scoring n-grams, skipping any that already
    // appears inside a previously included n-gram.
    let mut dict_data = Vec::with_capacity(dict_size);
    let mut included_ngrams: Vec<Vec<u8>> = Vec::new();
    for (ngram, _score) in &scored {
        if dict_data.len() + ngram.len() > dict_size {
            if dict_data.len() >= dict_size {
                break;
            }
            // This n-gram doesn't fit, but a shorter one later still might.
            continue;
        }
        let is_substring = included_ngrams
            .iter()
            .any(|included| included.windows(ngram.len()).any(|w| w == ngram.as_slice()));
        if is_substring {
            continue;
        }
        dict_data.extend_from_slice(ngram);
        included_ngrams.push(ngram.clone());
    }

    // Fallback for short or non-repetitive samples: no n-gram qualified,
    // so fill the dictionary with raw sample prefixes instead.
    if dict_data.is_empty() {
        for sample in samples {
            let remaining = dict_size.saturating_sub(dict_data.len());
            if remaining == 0 {
                break;
            }
            let take = remaining.min(sample.len());
            dict_data.extend_from_slice(&sample[..take]);
        }
    }

    ZstdDict::new(dict_data)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Construction stores the bytes verbatim and reports size correctly.
    #[test]
    fn test_dict_new_basic() {
        let bytes = b"hello world".to_vec();
        let dict = ZstdDict::new(bytes.clone()).unwrap();
        assert_eq!(dict.len(), bytes.len());
        assert_eq!(dict.data(), bytes.as_slice());
        assert!(!dict.is_empty());
    }

    // An empty dictionary is valid.
    #[test]
    fn test_dict_new_empty() {
        let dict = ZstdDict::new(Vec::new()).unwrap();
        assert_eq!(dict.len(), 0);
        assert!(dict.is_empty());
    }

    // Anything beyond MAX_DICT_SIZE must be rejected.
    #[test]
    fn test_dict_too_large() {
        let oversized = vec![0u8; MAX_DICT_SIZE + 1];
        assert!(ZstdDict::new(oversized).is_err());
    }

    // The id is a pure function of the dictionary bytes.
    #[test]
    fn test_dict_id_deterministic() {
        let bytes = b"test dictionary".to_vec();
        let first = ZstdDict::new(bytes.clone()).unwrap();
        let second = ZstdDict::new(bytes).unwrap();
        assert_eq!(first.id(), second.id());
    }

    // Different contents should hash to different ids.
    #[test]
    fn test_dict_id_differs_for_different_data() {
        let dict_a = ZstdDict::new(b"data A".to_vec()).unwrap();
        let dict_b = ZstdDict::new(b"data B".to_vec()).unwrap();
        assert_ne!(dict_a.id(), dict_b.id());
    }

    // into_data round-trips the original buffer.
    #[test]
    fn test_dict_into_data() {
        let bytes = b"round-trip".to_vec();
        let dict = ZstdDict::new(bytes.clone()).unwrap();
        assert_eq!(dict.into_data(), bytes);
    }

    // Training with zero samples is an error.
    #[test]
    fn test_train_dictionary_no_samples() {
        assert!(train_dictionary(&[], 4096).is_err());
    }

    // Repetitive samples yield a non-empty dictionary within the budget.
    #[test]
    fn test_train_dictionary_basic() {
        let corpus: Vec<&[u8]> = vec![
            b"the quick brown fox jumps",
            b"the quick brown dog runs",
            b"the quick brown cat sleeps",
        ];
        let dict = train_dictionary(&corpus, 256).unwrap();
        assert!(!dict.is_empty());
        assert!(dict.len() <= 256);
    }

    // The requested budget is never exceeded.
    #[test]
    fn test_train_dictionary_respects_size_limit() {
        let corpus: Vec<&[u8]> = vec![
            b"AAAA BBBB CCCC DDDD EEEE",
            b"AAAA BBBB CCCC DDDD FFFF",
            b"AAAA BBBB CCCC DDDD GGGG",
        ];
        let dict = train_dictionary(&corpus, 32).unwrap();
        assert!(dict.len() <= 32);
    }

    // Samples shorter than MIN_NGRAM still produce output via the fallback.
    #[test]
    fn test_train_dictionary_short_samples() {
        let corpus: Vec<&[u8]> = vec![b"AB", b"CD", b"EF"];
        let dict = train_dictionary(&corpus, 64).unwrap();
        assert!(!dict.is_empty());
    }

    // Identical samples maximize n-gram frequency and still train fine.
    #[test]
    fn test_train_dictionary_identical_samples() {
        let sample = b"identical content repeated";
        let corpus: Vec<&[u8]> = vec![sample, sample, sample];
        let dict = train_dictionary(&corpus, 256).unwrap();
        assert!(!dict.is_empty());
    }

    // A budget above MAX_DICT_SIZE is silently clamped.
    #[test]
    fn test_train_dictionary_caps_at_max_dict_size() {
        let corpus: Vec<&[u8]> = vec![b"data"];
        let dict = train_dictionary(&corpus, MAX_DICT_SIZE + 100).unwrap();
        assert!(dict.len() <= MAX_DICT_SIZE);
    }

    // Shared prefixes/suffixes across samples should end up in the dictionary.
    #[test]
    fn test_train_dictionary_common_prefix() {
        let corpus: Vec<&[u8]> = vec![
            b"prefix_alpha_suffix",
            b"prefix_beta_suffix",
            b"prefix_gamma_suffix",
        ];
        let dict = train_dictionary(&corpus, 1024).unwrap();
        let dict_str = String::from_utf8_lossy(dict.data());
        let has_common = dict_str.contains("prefix_") || dict_str.contains("_suffix");
        assert!(
            has_common,
            "dictionary should contain common substrings: {:?}",
            dict_str
        );
    }
}