1use alloc::collections::VecDeque;
4use alloc::vec;
5use alloc::vec::Vec;
6
7use lz4rip_core::MAX_DISTANCE;
8use lz4rip_core::MINMATCH;
9
10pub struct DictTrainer {
25 max_dict_size: usize,
26 samples: VecDeque<Vec<u8>>,
27 total_bytes: usize,
28}
29
30impl DictTrainer {
31 pub fn new(max_dict_size: usize) -> Self {
36 let max_dict_size = max_dict_size.min(MAX_DISTANCE);
37 DictTrainer {
38 max_dict_size,
39 samples: VecDeque::new(),
40 total_bytes: 0,
41 }
42 }
43
44 pub fn add_sample(&mut self, data: &[u8]) {
49 if data.len() < MINMATCH || data.len() > self.max_dict_size {
50 return;
51 }
52
53 let budget = self.max_dict_size * 8;
54
55 while self.total_bytes + data.len() > budget && !self.samples.is_empty() {
56 self.total_bytes -= self.samples.pop_front().unwrap().len();
57 }
58
59 self.total_bytes += data.len();
60 self.samples.push_back(data.to_vec());
61 }
62
63 pub fn sample_count(&self) -> usize {
65 self.samples.len()
66 }
67
68 pub fn total_bytes(&self) -> usize {
70 self.total_bytes
71 }
72
73 pub fn train(self) -> Vec<u8> {
79 if self.samples.len() < 2 {
80 return Vec::new();
81 }
82 let sample_refs: Vec<&[u8]> = self.samples.iter().map(|s| s.as_slice()).collect();
83 cover_select(&sample_refs, self.max_dict_size)
84 }
85}
86
/// Length in bytes of the d-mers (fixed-size windows) whose frequencies drive
/// segment selection.
const D: usize = 8;
/// Base-2 logarithm of the number of buckets in the d-mer frequency table.
const FREQ_BITS: usize = 16;
/// Number of buckets in the d-mer frequency table.
const FREQ_SIZE: usize = 1 << FREQ_BITS;
/// Mask that reduces a hash value to a valid frequency-table index.
const FREQ_MASK: usize = FREQ_SIZE - 1;

/// Hashes the first 8 bytes of `data` to a bucket index in `0..FREQ_SIZE`.
///
/// The bytes are read as a little-endian `u64` and multiplied by a fixed odd
/// 64-bit constant; the top `FREQ_BITS` bits of the product form the result.
/// The high bits are used because the low `n` bits of a product depend only on
/// the low `n` bits of the input: masking the low 16 bits would make the bucket
/// a function of the first two bytes alone.
///
/// # Panics
///
/// Panics if `data` holds fewer than 8 bytes.
#[inline(always)]
fn hash_dmer(data: &[u8]) -> usize {
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&data[..8]);
    let product = u64::from_le_bytes(raw).wrapping_mul(0x9E37_79B9_7F4A_7C15);
    (product >> (64 - FREQ_BITS)) as usize
}

/// Selects at most `dict_size` bytes of dictionary content from `samples`.
///
/// The samples are concatenated and every `D`-byte window is hashed into a
/// frequency table; only windows lying entirely inside one sample are counted.
/// Segments of `max(dict_size / 4, D)` bytes are then chosen greedily by the sum
/// of the bucket counts of their windows. After each pick the buckets it covers
/// are zeroed, so later picks are rewarded only for content not yet covered.
/// Selection stops once the picked bytes reach `dict_size` or no remaining
/// segment scores above zero.
///
/// Segments are emitted in reverse pick order, so the first (highest-scoring)
/// pick sits at the end of the dictionary. If the picks exceed `dict_size`, the
/// surplus is dropped from the front, i.e. from the lowest-scoring segment.
///
/// Short inputs are returned directly: a concatenation shorter than `D` is cut
/// to `dict_size` bytes, and one shorter than a single segment is returned whole.
fn cover_select(samples: &[&[u8]], dict_size: usize) -> Vec<u8> {
    // Concatenate the samples, remembering where each one starts and ends.
    let total_len: usize = samples.iter().map(|s| s.len()).sum();
    let mut concat = Vec::with_capacity(total_len);
    let mut offsets = Vec::with_capacity(samples.len() + 1);
    for &sample in samples {
        offsets.push(concat.len());
        concat.extend_from_slice(sample);
    }
    offsets.push(concat.len());

    if concat.len() < D {
        concat.truncate(dict_size);
        return concat;
    }

    // One bucket index per d-mer start position in the concatenation.
    let hashes: Vec<u32> = concat
        .windows(D)
        .map(|dmer| (hash_dmer(dmer) & FREQ_MASK) as u32)
        .collect();
    let num_dmers = hashes.len();

    // Count only d-mers that lie entirely within a single sample.
    let mut freqs = vec![0u32; FREQ_SIZE];
    for bounds in offsets.windows(2) {
        let (start, end) = (bounds[0], bounds[1]);
        if end - start < D {
            continue;
        }
        for &bucket in &hashes[start..=end - D] {
            freqs[bucket as usize] += 1;
        }
    }

    // A segment must hold at least one whole d-mer; without the clamp,
    // `k - D + 1` underflows whenever `dict_size < 4 * D`.
    let k = (dict_size / 4).max(D);
    if concat.len() < k {
        return concat;
    }

    let seg_dmers = k - D + 1;
    let last_start = concat.len() - k;
    let mut used = vec![false; concat.len()];
    let mut segments: Vec<usize> = Vec::new();
    let mut prefix = vec![0u64; num_dmers + 1];
    let mut collected = 0usize;

    while collected < dict_size {
        // Prefix sums of the current bucket counts make each segment score O(1).
        for (i, &bucket) in hashes.iter().enumerate() {
            prefix[i + 1] = prefix[i] + u64::from(freqs[bucket as usize]);
        }

        // Earliest unused start position with the strictly highest score.
        let mut best_pos = 0;
        let mut best_score = 0u64;
        for pos in 0..=last_start {
            if used[pos] {
                continue;
            }
            let score = prefix[pos + seg_dmers] - prefix[pos];
            if score > best_score {
                best_score = score;
                best_pos = pos;
            }
        }

        if best_score == 0 {
            break;
        }

        segments.push(best_pos);

        // Content covered by this pick must not be rewarded again.
        for &bucket in &hashes[best_pos..best_pos + seg_dmers] {
            freqs[bucket as usize] = 0;
        }
        used[best_pos..best_pos + k].fill(true);

        collected += k;
    }

    // Best segment last; every start is <= `last_start`, so `pos + k` is in range.
    let mut dict = Vec::with_capacity(segments.len() * k);
    for &pos in segments.iter().rev() {
        dict.extend_from_slice(&concat[pos..pos + k]);
    }

    // Over budget: shed bytes from the front (the weakest segment), never from
    // the tail, which holds the strongest one.
    if dict.len() > dict_size {
        let excess = dict.len() - dict_size;
        dict.drain(..excess);
    }
    dict
}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one synthetic JSON log line whose variable fields derive from `i`.
    fn json_msg(i: u32) -> Vec<u8> {
        let lat = 10 + i % 490;
        let line = alloc::format!(
            r#"{{"ts":"2026-04-27T12:00:00.{i:04}Z","level":"INFO","service":"api-gw","trace":"{i:08x}","method":"GET","path":"/v1/users/{i:04}","status":200,"latency_ms":{lat},"region":"us-east-1"}}"#,
            i = i,
            lat = lat,
        );
        line.into_bytes()
    }

    /// Creates a trainer limited to `max_dict_size` and feeds it
    /// `json_msg(0)` through `json_msg(count - 1)` in order.
    fn trainer_with_json(max_dict_size: usize, count: u32) -> DictTrainer {
        let mut trainer = DictTrainer::new(max_dict_size);
        (0..count).for_each(|i| trainer.add_sample(&json_msg(i)));
        trainer
    }

    /// Training on many similar samples yields a non-empty dictionary within the limit.
    #[test]
    fn train_produces_nonempty_dict() {
        let dict = trainer_with_json(2048, 100).train();
        assert!(!dict.is_empty(), "dict should not be empty");
        assert!(dict.len() <= 2048, "dict should respect max size");
    }

    /// A two-byte sample is not stored.
    #[test]
    fn skips_too_short() {
        let mut trainer = DictTrainer::new(2048);
        let sample = b"hi";
        trainer.add_sample(sample);
        assert_eq!(trainer.sample_count(), 0);
    }

    /// A sample longer than the dictionary size limit is not stored.
    #[test]
    fn skips_too_long() {
        let mut trainer = DictTrainer::new(64);
        let oversized = [0u8; 100];
        trainer.add_sample(&oversized);
        assert_eq!(trainer.sample_count(), 0);
    }

    /// Feeding more than the byte budget drops some of the earlier samples.
    #[test]
    fn evicts_old_samples() {
        let trainer = trainer_with_json(2048, 200);
        assert!(trainer.total_bytes() <= 2048 * 8);
        assert!(trainer.sample_count() < 200);
    }

    /// A single retained sample is not enough to train on.
    #[test]
    fn too_few_samples_returns_empty() {
        let mut trainer = DictTrainer::new(2048);
        trainer.add_sample(b"hello world");
        let dict = trainer.train();
        assert!(dict.is_empty());
    }
}