use alloc::collections::BTreeMap;
use alloc::string::ToString;
use alloc::vec::Vec;
use super::types::{
CompressionDict, DEFAULT_DICT_SIZE, DictError, DictResult, MAX_DICT_SIZE, MIN_DICT_SIZE,
MIN_MATCH_LEN, SubstringEntry, TrainingOptions,
};
pub fn train_dictionary(samples: &[&[u8]], options: &TrainingOptions) -> DictResult<Vec<u8>> {
if samples.len() < 2 {
return Err(DictError::InsufficientSamples(samples.len()));
}
let dict_size = options.dict_size.clamp(MIN_DICT_SIZE, MAX_DICT_SIZE);
let min_len = options.min_substring_len.max(MIN_MATCH_LEN);
let max_len = options.max_substring_len.min(256);
let mut substring_counts: BTreeMap<Vec<u8>, SubstringEntry> = BTreeMap::new();
for sample in samples {
let mut seen_in_sample: alloc::collections::BTreeSet<Vec<u8>> =
alloc::collections::BTreeSet::new();
for len in min_len..=max_len.min(sample.len()) {
for start in 0..=sample.len().saturating_sub(len) {
let substr = sample[start..start + len].to_vec();
if seen_in_sample.contains(&substr) {
continue;
}
seen_in_sample.insert(substr.clone());
if let Some(entry) = substring_counts.get_mut(&substr) {
entry.increment();
} else {
substring_counts.insert(substr.clone(), SubstringEntry::new(substr));
}
}
}
}
let mut scored: Vec<(u64, Vec<u8>)> = substring_counts
.into_iter()
.filter(|(_, entry)| entry.count >= options.min_occurrences)
.map(|(data, entry)| (entry.score(), data))
.collect();
scored.sort_by(|a, b| b.0.cmp(&a.0));
let mut dict_data = Vec::with_capacity(dict_size);
let mut included: alloc::collections::BTreeSet<Vec<u8>> = alloc::collections::BTreeSet::new();
for (_, substr) in scored {
let dominated = included
.iter()
.any(|inc| contains_substring(&substr, inc) || contains_substring(inc, &substr));
if dominated {
continue;
}
if dict_data.len() + substr.len() > dict_size {
if dict_data.len() + MIN_MATCH_LEN > dict_size {
break;
}
continue;
}
included.insert(substr.clone());
dict_data.extend_from_slice(&substr);
}
Ok(dict_data)
}
/// Convenience wrapper: train a dictionary using [`TrainingOptions::default`].
pub fn train_dictionary_default(samples: &[&[u8]]) -> DictResult<Vec<u8>> {
    let options = TrainingOptions::default();
    train_dictionary(samples, &options)
}
pub fn train(
id: u64,
name: &str,
pattern: &str,
dataset: &str,
samples: &[&[u8]],
options: &TrainingOptions,
timestamp: u64,
) -> DictResult<CompressionDict> {
let data = train_dictionary(samples, options)?;
let mut total_ratio = 0.0f32;
let mut sample_count = 0u32;
for sample in samples {
let compressed_estimate = estimate_compressed_size(&data, sample);
if !sample.is_empty() {
total_ratio += compressed_estimate as f32 / sample.len() as f32;
sample_count += 1;
}
}
let avg_ratio = if sample_count > 0 {
total_ratio / sample_count as f32
} else {
1.0
};
Ok(
CompressionDict::new(id, name, data, pattern, dataset, timestamp)
.with_training_info(sample_count, avg_ratio),
)
}
/// Returns `true` if `needle` occurs as a contiguous byte substring of
/// `haystack`. An empty needle is reported as *not* contained — callers use
/// this for redundancy/domination checks where an empty match is meaningless.
fn contains_substring(haystack: &[u8], needle: &[u8]) -> bool {
    if needle.is_empty() || needle.len() > haystack.len() {
        return false;
    }
    // `windows(n)` yields every length-n subslice, avoiding manual indexing
    // and per-access bounds checks.
    haystack.windows(needle.len()).any(|window| window == needle)
}
/// Estimate the compressed size of `data` against `dict`.
///
/// Greedily consumes `data`: a dictionary match of at least `MIN_MATCH_LEN`
/// costs 5 bytes; a run of unmatched bytes costs 3 bytes of overhead plus
/// the literal bytes. The 5/3 token costs are heuristic — presumably they
/// mirror the encoder's wire format; verify against the compressor.
fn estimate_compressed_size(dict: &[u8], data: &[u8]) -> usize {
    if data.is_empty() {
        return 0;
    }
    let mut pos = 0;
    let mut compressed = 0;
    while pos < data.len() {
        let (_, match_len) = find_longest_match(dict, &data[pos..]);
        if match_len >= MIN_MATCH_LEN {
            // Match token: fixed 5-byte cost, skip the matched span.
            compressed += 5;
            pos += match_len;
        } else {
            // Literal run: advance until the next usable dictionary match.
            let literal_start = pos;
            // The current position is already known to have no usable match,
            // so skip it without re-running the (expensive) match search.
            pos += 1;
            while pos < data.len() {
                let (_, len) = find_longest_match(dict, &data[pos..]);
                if len >= MIN_MATCH_LEN {
                    break;
                }
                pos += 1;
            }
            compressed += 3 + (pos - literal_start);
        }
    }
    compressed
}
/// Find the longest common prefix between `data` and any suffix of `dict`.
///
/// Returns `(offset, len)` of the best match, or `(0, 0)` when no match of
/// at least `MIN_MATCH_LEN` bytes exists. Matches are capped at 256 bytes
/// (which is also well below the old `u16::MAX` cap, so that bound was dead).
fn find_longest_match(dict: &[u8], data: &[u8]) -> (usize, usize) {
    if dict.is_empty() || data.len() < MIN_MATCH_LEN {
        return (0, 0);
    }
    let max_check = data.len().min(256);
    let mut best_offset = 0;
    let mut best_len = 0;
    for offset in 0..dict.len() {
        // Comparison length is bounded by both the remaining dictionary and
        // the cap; `max_check <= data.len()` makes the data slice safe.
        let limit = (dict.len() - offset).min(max_check);
        let len = dict[offset..offset + limit]
            .iter()
            .zip(data[..limit].iter())
            .take_while(|(d, s)| d == s)
            .count();
        if len > best_len && len >= MIN_MATCH_LEN {
            best_offset = offset;
            best_len = len;
        }
        // No later offset can beat the cap — stop early.
        if best_len == max_check {
            break;
        }
    }
    (best_offset, best_len)
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    #[test]
    fn test_train_insufficient_samples() {
        // A single sample is below the two-sample minimum.
        let input: &[&[u8]] = &[b"only one"];
        let outcome = train_dictionary(input, &TrainingOptions::default());
        assert!(matches!(outcome, Err(DictError::InsufficientSamples(1))));
    }

    #[test]
    fn test_train_basic() {
        // Words shared across samples should produce a non-empty dictionary.
        let input: &[&[u8]] = &[
            b"hello world hello",
            b"hello there world",
            b"world hello world",
        ];
        let trained = train_dictionary_default(input).unwrap();
        assert!(!trained.is_empty());
    }

    #[test]
    fn test_train_with_options() {
        let input: &[&[u8]] = &[
            b"the quick brown fox jumps over",
            b"the lazy dog jumps over the",
            b"quick brown fox the lazy dog",
        ];
        let opts = TrainingOptions::default()
            .with_size(1024)
            .min_len(4)
            .min_count(2);
        let trained = train_dictionary(input, &opts).unwrap();
        // The dictionary must respect the requested size cap.
        assert!(trained.len() <= 1024);
    }

    #[test]
    fn test_train_repeated_patterns() {
        // Heavily repetitive samples still train successfully.
        let repeated = b"ABCD".repeat(100);
        let input: &[&[u8]] = &[&repeated, &repeated, &repeated];
        let trained = train_dictionary_default(input).unwrap();
        assert!(!trained.is_empty());
    }

    #[test]
    fn test_contains_substring() {
        assert!(contains_substring(b"hello world", b"world"));
        assert!(contains_substring(b"hello world", b"hello"));
        assert!(contains_substring(b"hello world", b"lo wo"));
        // A needle equal to the haystack counts as contained.
        assert!(contains_substring(b"aaa", b"aaa"));
        assert!(!contains_substring(b"hello", b"world"));
        assert!(!contains_substring(b"hi", b"hello"));
    }

    #[test]
    fn test_find_longest_match() {
        let dict = b"hello world";
        // "hello " (6 bytes, including the space) matches at offset 0.
        let (offset, len) = find_longest_match(dict, b"hello there");
        assert_eq!(offset, 0);
        assert_eq!(len, 6);
        // "world" (5 bytes) matches at offset 6.
        let (offset2, len2) = find_longest_match(dict, b"world hello");
        assert_eq!(offset2, 6);
        assert_eq!(len2, 5);
        // No overlap with the dictionary: no match at all.
        let (_, len3) = find_longest_match(dict, b"xyz");
        assert_eq!(len3, 0);
    }

    #[test]
    fn test_estimate_compressed_size() {
        let dict = b"hello world";
        let payload = b"hello world hello";
        assert!(estimate_compressed_size(dict, payload) > 0);
    }

    #[test]
    fn test_train_full() {
        let input: &[&[u8]] = &[
            b"json data with common patterns",
            b"json data with repeated fields",
        ];
        let result = train(
            42,
            "test_dict",
            "*.json",
            "pool/data",
            input,
            &TrainingOptions::default().with_size(1024),
            12345,
        )
        .unwrap();
        assert_eq!(result.id, 42);
        assert_eq!(result.name, "test_dict");
        assert_eq!(result.pattern, "*.json");
        assert!(result.sample_count > 0);
    }
}