Skip to main content

lz4rip_encode/
dict.rs

1//! COVER dictionary trainer.
2
3use alloc::collections::VecDeque;
4use alloc::vec;
5use alloc::vec::Vec;
6
7use lz4rip_core::MAX_DISTANCE;
8use lz4rip_core::MINMATCH;
9
/// Trains an LZ4 dictionary from sample messages using the COVER
/// algorithm.
///
/// Samples are buffered (up to a fixed byte budget, oldest evicted first)
/// until [`DictTrainer::train`] is called.
///
/// # Examples
/// ```
/// use lz4rip_encode::DictTrainer;
///
/// let mut trainer = DictTrainer::new(2048);
/// for msg in &[b"hello world" as &[u8], b"hello rust", b"hello lz4"] {
///     trainer.add_sample(msg);
/// }
/// let dict = trainer.train();
/// let compressor = lz4rip_encode::DictCompressor::new(&dict);
/// ```
#[derive(Debug, Clone)]
pub struct DictTrainer {
    /// Target dictionary size in bytes (already clamped by `new`).
    max_dict_size: usize,
    /// Retained samples, oldest at the front.
    samples: VecDeque<Vec<u8>>,
    /// Sum of the lengths of all retained samples.
    total_bytes: usize,
}
29
30impl DictTrainer {
31    /// Create a trainer targeting `max_dict_size` bytes of output.
32    ///
33    /// Typical values: 2048 for small messages, 4096 for larger ones.
34    /// The dict is capped at 65535 bytes (LZ4 max match distance).
35    pub fn new(max_dict_size: usize) -> Self {
36        let max_dict_size = max_dict_size.min(MAX_DISTANCE);
37        DictTrainer {
38            max_dict_size,
39            samples: VecDeque::new(),
40            total_bytes: 0,
41        }
42    }
43
44    /// Add a training sample.
45    ///
46    /// Samples shorter than 4 bytes or longer than `max_dict_size` are
47    /// silently skipped.
48    pub fn add_sample(&mut self, data: &[u8]) {
49        if data.len() < MINMATCH || data.len() > self.max_dict_size {
50            return;
51        }
52
53        let budget = self.max_dict_size * 8;
54
55        while self.total_bytes + data.len() > budget && !self.samples.is_empty() {
56            self.total_bytes -= self.samples.pop_front().unwrap().len();
57        }
58
59        self.total_bytes += data.len();
60        self.samples.push_back(data.to_vec());
61    }
62
63    /// Number of samples added so far.
64    pub fn sample_count(&self) -> usize {
65        self.samples.len()
66    }
67
68    /// Total bytes of sample data added so far.
69    pub fn total_bytes(&self) -> usize {
70        self.total_bytes
71    }
72
73    /// Train a dictionary from the collected samples. Consumes the
74    /// trainer, freeing all sample data.
75    ///
76    /// Returns a raw byte buffer. If fewer than 2 samples were added,
77    /// returns an empty vec.
78    pub fn train(self) -> Vec<u8> {
79        if self.samples.len() < 2 {
80            return Vec::new();
81        }
82        let sample_refs: Vec<&[u8]> = self.samples.iter().map(|s| s.as_slice()).collect();
83        cover_select(&sample_refs, self.max_dict_size)
84    }
85}
86
/// Length in bytes of a d-mer, the unit whose frequency is counted.
const D: usize = 8;
/// log2 of the number of buckets in the d-mer frequency table.
const FREQ_BITS: usize = 16;
/// Number of buckets in the d-mer frequency table.
const FREQ_SIZE: usize = 1usize << FREQ_BITS;
/// Mask that reduces a hash to a valid frequency-table index.
const FREQ_MASK: usize = (1usize << FREQ_BITS) - 1;
91
92#[inline(always)]
93fn hash_dmer(data: &[u8]) -> usize {
94    let v = u64::from_le_bytes(data[..8].try_into().unwrap());
95    v.wrapping_mul(0x9E37_79B9_7F4A_7C15) as usize
96}
97
98fn cover_select(samples: &[&[u8]], dict_size: usize) -> Vec<u8> {
99    let mut concat = Vec::new();
100    let mut offsets = Vec::new();
101    for &sample in samples {
102        offsets.push(concat.len());
103        concat.extend_from_slice(sample);
104    }
105    offsets.push(concat.len());
106
107    if concat.len() < D {
108        return concat[..dict_size.min(concat.len())].to_vec();
109    }
110
111    let num_dmers = concat.len() - D + 1;
112
113    let mut hashes = vec![0u32; num_dmers];
114    for i in 0..num_dmers {
115        hashes[i] = (hash_dmer(&concat[i..i + D]) & FREQ_MASK) as u32;
116    }
117
118    let mut freqs = vec![0u32; FREQ_SIZE];
119    for s in 0..samples.len() {
120        let start = offsets[s];
121        let end = offsets[s + 1];
122        if end - start < D {
123            continue;
124        }
125        for i in start..end - D + 1 {
126            freqs[hashes[i] as usize] += 1;
127        }
128    }
129
130    let k = dict_size / 4;
131    if concat.len() < k {
132        return concat;
133    }
134
135    let seg_dmers = k - D + 1;
136    let mut used = vec![false; concat.len()];
137    let mut segments: Vec<(usize, u64)> = Vec::new();
138    let mut collected = 0usize;
139
140    while collected < dict_size {
141        let mut prefix = vec![0u64; num_dmers + 1];
142        prefix[0] = 0;
143        for i in 0..num_dmers {
144            prefix[i + 1] = prefix[i] + freqs[hashes[i] as usize] as u64;
145        }
146
147        let mut best_pos = 0;
148        let mut best_score = 0u64;
149        for pos in 0..=concat.len() - k {
150            if !used[pos] {
151                let score = prefix[pos + seg_dmers] - prefix[pos];
152                if score > best_score {
153                    best_score = score;
154                    best_pos = pos;
155                }
156            }
157        }
158
159        if best_score == 0 {
160            break;
161        }
162
163        segments.push((best_pos, best_score));
164
165        for i in best_pos..best_pos + seg_dmers {
166            freqs[hashes[i] as usize] = 0;
167        }
168        used[best_pos..best_pos + k].fill(true);
169
170        collected += k;
171    }
172
173    let mut dict = Vec::with_capacity(dict_size);
174    for &(pos, _) in segments.iter().rev() {
175        let end = (pos + k).min(concat.len());
176        dict.extend_from_slice(&concat[pos..end]);
177        if dict.len() >= dict_size {
178            break;
179        }
180    }
181    dict.truncate(dict_size);
182    dict
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders the `i`-th synthetic API-gateway log line as JSON bytes.
    /// Consecutive lines share most of their structure, which is the kind of
    /// redundancy a trained dictionary is meant to capture.
    fn json_msg(i: u32) -> Vec<u8> {
        let latency = 10 + i % 490;
        let line = alloc::format!(
            r#"{{"ts":"2026-04-27T12:00:00.{i:04}Z","level":"INFO","service":"api-gw","trace":"{i:08x}","method":"GET","path":"/v1/users/{i:04}","status":200,"latency_ms":{lat},"region":"us-east-1"}}"#,
            i = i,
            lat = latency,
        );
        line.into_bytes()
    }

    /// Feeds the first `count` synthetic log lines into `trainer`.
    fn feed(trainer: &mut DictTrainer, count: u32) {
        (0..count).for_each(|i| trainer.add_sample(&json_msg(i)));
    }

    #[test]
    fn train_produces_nonempty_dict() {
        let mut trainer = DictTrainer::new(2048);
        feed(&mut trainer, 100);
        let dict = trainer.train();
        assert!(!dict.is_empty(), "dict should not be empty");
        assert!(dict.len() <= 2048, "dict should respect max size");
    }

    #[test]
    fn skips_too_short() {
        let mut trainer = DictTrainer::new(2048);
        trainer.add_sample(b"hi");
        assert_eq!(trainer.sample_count(), 0);
    }

    #[test]
    fn skips_too_long() {
        let mut trainer = DictTrainer::new(64);
        let oversized = [0u8; 100];
        trainer.add_sample(&oversized);
        assert_eq!(trainer.sample_count(), 0);
    }

    #[test]
    fn evicts_old_samples() {
        let mut trainer = DictTrainer::new(2048);
        feed(&mut trainer, 200);
        assert!(trainer.total_bytes() <= 2048 * 8);
        assert!(trainer.sample_count() < 200);
    }

    #[test]
    fn too_few_samples_returns_empty() {
        let mut trainer = DictTrainer::new(2048);
        trainer.add_sample(b"hello world");
        let dict = trainer.train();
        assert!(dict.is_empty());
    }
}