/// Default history-buffer size in bytes (must be a power of two).
const DEFAULT_BUF_SIZE: usize = 16 * 1024 * 1024;

/// Default number of hash-table slots (must be a power of two).
const DEFAULT_HASH_SIZE: usize = 8 * 1024 * 1024;

/// Minimum match length (in bytes) before the model expresses an opinion.
const MIN_MATCH: usize = 2;

/// Match lengths are capped at this value when mapped to a confidence.
const MAX_MATCH_FOR_CONF: usize = 64;

/// Upper bound on how far a candidate match is verified byte-by-byte.
const MAX_VERIFY_LEN: usize = 128;
/// Bit-level match model: looks up previous occurrences of the current byte
/// context in a circular history buffer and predicts that upcoming bits will
/// repeat the bits that followed the earlier occurrence.
pub struct MatchModel {
    // Circular history buffer of all bytes seen (length is a power of two).
    buf: Vec<u8>,
    // Next write position in `buf` (always kept wrapped with `& (buf_size - 1)`).
    buf_pos: usize,
    // Total number of bytes ever written (monotonic, not wrapped).
    total_written: usize,
    // Context-hash -> buffer position of an occurrence of that context.
    hash_table: Vec<u32>,
    // Cached `buf.len()`; power of two so positions wrap with a mask.
    buf_size: usize,
    // Cached `hash_table.len()`; power of two so indices wrap with a mask.
    hash_size: usize,
    // Buffer position of the byte currently being predicted from,
    // or -1 when no match is active.
    match_pos: i64,
    // Length (in bytes) of the currently active match.
    match_len: usize,
    // Bit position (0-7, MSB first) inside the current match byte.
    match_bpos: u8,
    // Most recently computed order-4 context hash (set in `update`).
    hash: u32,
    // Probability returned by the last `predict` call (12-bit, 1..=4095).
    last_p: u32,
}
52
53impl MatchModel {
54 pub fn new() -> Self {
55 Self::with_sizes(DEFAULT_BUF_SIZE, DEFAULT_HASH_SIZE)
56 }
57
58 pub fn with_sizes(buf_size: usize, hash_size: usize) -> Self {
61 debug_assert!(buf_size.is_power_of_two());
62 debug_assert!(hash_size.is_power_of_two());
63 MatchModel {
64 buf: vec![0u8; buf_size],
65 buf_pos: 0,
66 total_written: 0,
67 hash_table: vec![0u32; hash_size],
68 buf_size,
69 hash_size,
70 match_pos: -1,
71 match_len: 0,
72 match_bpos: 0,
73 hash: 0,
74 last_p: 2048,
75 }
76 }
77
78 #[inline]
87 pub fn predict(&mut self, _c0: u32, bpos: u8, c1: u8, c2: u8, c3: u8) -> u32 {
88 if bpos == 0 {
89 self.find_match(c1, c2, c3);
91 }
92
93 if self.match_pos < 0 || self.match_len < MIN_MATCH {
94 self.last_p = 2048;
95 return 2048; }
97
98 let mpos = self.match_pos as usize & (self.buf_size - 1);
100 let match_byte = self.buf[mpos];
101 let match_bit = (match_byte >> (7 - bpos)) & 1;
102
103 let len = self.match_len.min(MAX_MATCH_FOR_CONF);
110 let conf = if len <= 3 {
111 (len as u32) * 80
112 } else if len <= 8 {
113 240 + ((len as u32 - 3) * 200)
114 } else if len <= 32 {
115 1240 + ((len as u32 - 8) * 120)
116 } else {
117 4120u32.min(1240 + 2880 + ((len as u32 - 32) * 60))
118 };
119 let conf = conf.min(3800);
120
121 let p = if match_bit == 1 {
122 2048 + conf
123 } else {
124 2048u32.saturating_sub(conf)
125 };
126 let p = p.clamp(1, 4095);
127 self.last_p = p;
128 p
129 }
130
131 #[inline]
138 pub fn update(&mut self, bit: u8, bpos: u8, c0: u32, c1: u8, c2: u8) {
139 if self.match_pos >= 0 {
141 let mpos = self.match_pos as usize & (self.buf_size - 1);
142 let match_bit = (self.buf[mpos] >> (7 - self.match_bpos)) & 1;
143 if match_bit == bit {
144 self.match_bpos += 1;
145 if self.match_bpos >= 8 {
146 self.match_bpos = 0;
147 self.match_len += 1;
148 self.match_pos = (self.match_pos + 1) & (self.buf_size as i64 - 1);
149 }
150 } else {
151 self.match_pos = -1;
153 self.match_len = 0;
154 self.match_bpos = 0;
155 }
156 }
157
158 if bpos == 7 {
160 let byte = (c0 & 0xFF) as u8;
161 self.buf[self.buf_pos] = byte;
162
163 self.hash = hash4(byte, c1, c2, self.prev_byte(3));
166
167 let idx = self.hash as usize & (self.hash_size - 1);
169 self.hash_table[idx] = self.buf_pos as u32;
170
171 let h3 = hash3(byte, c1, c2);
173 let idx3 = h3 as usize & (self.hash_size - 1);
174 if self.hash_table[idx3] == 0 || self.total_written < 4 {
176 self.hash_table[idx3] = self.buf_pos as u32;
177 }
178
179 let c3 = self.prev_byte(3);
181 let c4 = self.prev_byte(4);
182 let h5 = hash5(byte, c1, c2, c3, c4);
183 let idx5 = h5 as usize & (self.hash_size - 1);
184 self.hash_table[idx5] = self.buf_pos as u32;
185
186 self.buf_pos = (self.buf_pos + 1) & (self.buf_size - 1);
187 self.total_written += 1;
188 }
189 }
190
191 #[inline]
193 fn prev_byte(&self, n: usize) -> u8 {
194 if self.total_written >= n {
195 self.buf[(self.buf_pos.wrapping_sub(n)) & (self.buf_size - 1)]
196 } else {
197 0
198 }
199 }
200
201 #[inline]
205 fn verify_match_length(&self, candidate_pos: usize) -> usize {
206 let verify_start = (candidate_pos + 1) & (self.buf_size - 1);
207 let data_start = self.buf_pos; let max_len = self.total_written.min(MAX_VERIFY_LEN);
209 let mut len = 0;
210 while len < max_len {
211 let mp = (verify_start + len) & (self.buf_size - 1);
212 let dp = (data_start + len) & (self.buf_size - 1);
213 if mp == self.buf_pos {
215 break;
216 }
217 if self.buf[mp] != self.buf[dp] {
218 break;
219 }
220 len += 1;
221 }
222 len
223 }
224
225 fn find_match(&mut self, c1: u8, c2: u8, c3: u8) {
227 if self.total_written < 3 {
228 self.match_pos = -1;
229 self.match_len = 0;
230 return;
231 }
232
233 let c4 = self.prev_byte(3); let c5 = self.prev_byte(4); let mut best_pos: i64 = -1;
239 let mut best_len: usize = 0;
240
241 if self.total_written >= 5 {
243 let h5 = hash5(c1, c2, c3, c4, c5);
244 let idx5 = h5 as usize & (self.hash_size - 1);
245 let cand = self.hash_table[idx5] as usize;
246 self.check_candidate(cand, c1, c2, c3, &mut best_pos, &mut best_len);
247 }
248
249 let h4 = hash4(c1, c2, c3, c4);
251 let idx4 = h4 as usize & (self.hash_size - 1);
252 let cand4 = self.hash_table[idx4] as usize;
253 self.check_candidate(cand4, c1, c2, c3, &mut best_pos, &mut best_len);
254
255 let h3 = hash3(c1, c2, c3);
257 let idx3 = h3 as usize & (self.hash_size - 1);
258 let cand3 = self.hash_table[idx3] as usize;
259 self.check_candidate(cand3, c1, c2, c3, &mut best_pos, &mut best_len);
260
261 let h4b = hash4_alt(c1, c2, c3, c4);
263 let idx4b = h4b as usize & (self.hash_size - 1);
264 let cand4b = self.hash_table[idx4b] as usize;
265 self.check_candidate(cand4b, c1, c2, c3, &mut best_pos, &mut best_len);
266
267 if best_len >= MIN_MATCH {
268 self.match_pos = best_pos;
269 self.match_len = best_len;
270 self.match_bpos = 0;
271 } else {
272 self.match_pos = -1;
273 self.match_len = 0;
274 }
275 }
276
277 #[inline]
279 fn check_candidate(
280 &self,
281 candidate_pos: usize,
282 c1: u8,
283 c2: u8,
284 c3: u8,
285 best_pos: &mut i64,
286 best_len: &mut usize,
287 ) {
288 let bp = candidate_pos;
289 let p1 = bp.wrapping_sub(1) & (self.buf_size - 1);
290 let p2 = bp.wrapping_sub(2) & (self.buf_size - 1);
291
292 if self.buf[bp] == c1 && self.buf[p1] == c2 && self.buf[p2] == c3 {
294 let fwd_len = self.verify_match_length(bp);
296 let total_match = 3 + fwd_len; if total_match > *best_len {
298 *best_len = total_match;
299 *best_pos = ((bp + 1) & (self.buf_size - 1)) as i64;
300 }
301 }
302 }
303
304 #[inline]
307 pub fn match_length_quantized(&self) -> u8 {
308 if self.match_pos < 0 || self.match_len < MIN_MATCH {
309 0
310 } else if self.match_len < 8 {
311 1
312 } else if self.match_len < 32 {
313 2
314 } else {
315 3
316 }
317 }
318
319 #[inline]
321 pub fn last_prediction(&self) -> u32 {
322 self.last_p
323 }
324}
325
326impl Default for MatchModel {
327 fn default() -> Self {
328 Self::new()
329 }
330}
331
/// FNV-1a-style order-3 context hash over the last three bytes
/// (most recent byte first).
#[inline]
fn hash3(b1: u8, b2: u8, b3: u8) -> u32 {
    const PRIME: u32 = 0x01000193;
    [b2, b1]
        .iter()
        .fold(b3 as u32, |h, &b| h.wrapping_mul(PRIME) ^ b as u32)
}
340
/// FNV-1a-style order-4 context hash over the last four bytes
/// (most recent byte first).
#[inline]
fn hash4(b1: u8, b2: u8, b3: u8, b4: u8) -> u32 {
    const PRIME: u32 = 0x01000193;
    [b3, b2, b1]
        .iter()
        .fold(b4 as u32, |h, &b| h.wrapping_mul(PRIME) ^ b as u32)
}
350
/// Alternate order-4 hash: seeded with the golden-ratio constant and using
/// an XOR-then-multiply step order, so it maps the same bytes to different
/// table slots than [`hash4`].
#[inline]
fn hash4_alt(b1: u8, b2: u8, b3: u8, b4: u8) -> u32 {
    const PRIME: u32 = 0x01000193;
    let mixed = [b4, b3, b2]
        .iter()
        .fold(0x9E3779B9u32, |h, &b| (h ^ b as u32).wrapping_mul(PRIME));
    mixed ^ b1 as u32
}
364
/// FNV-1a-style order-5 context hash over the last five bytes
/// (most recent byte first).
#[inline]
fn hash5(b1: u8, b2: u8, b3: u8, b4: u8, b5: u8) -> u32 {
    const PRIME: u32 = 0x01000193;
    [b4, b3, b2, b1]
        .iter()
        .fold(b5 as u32, |h, &b| h.wrapping_mul(PRIME) ^ b as u32)
}
375
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_model_predicts_neutral() {
        // A model that has seen no data must have no opinion (p = 2048).
        let mut model = MatchModel::new();
        assert_eq!(model.predict(1, 0, 0, 0, 0), 2048);
    }

    #[test]
    fn prediction_in_range() {
        let mut model = MatchModel::new();
        // Stream 100 bytes through the model, bit by bit (MSB first).
        for byte in 0..100u8 {
            for bpos in 0..8u8 {
                let bit = (byte >> (7 - bpos)) & 1;
                let c0 = if bpos == 7 {
                    (byte as u32) | 0x100
                } else {
                    1u32 << (bpos + 1)
                };
                model.update(bit, bpos, c0, byte.wrapping_sub(1), byte.wrapping_sub(2));
            }
        }
        let p = model.predict(1, 0, 99, 98, 97);
        assert!((1..=4095).contains(&p));
    }

    #[test]
    fn hash3_not_cumulative() {
        // Deterministic, and sensitive to a change in any input byte.
        assert_eq!(hash3(10, 20, 30), hash3(10, 20, 30));
        assert_ne!(hash3(10, 20, 30), hash3(11, 20, 30));
    }

    #[test]
    fn hash4_not_cumulative() {
        assert_eq!(hash4(10, 20, 30, 40), hash4(10, 20, 30, 40));
        assert_ne!(hash4(10, 20, 30, 40), hash4(11, 20, 30, 40));
    }

    #[test]
    fn hash5_not_cumulative() {
        assert_eq!(hash5(10, 20, 30, 40, 50), hash5(10, 20, 30, 40, 50));
        assert_ne!(hash5(10, 20, 30, 40, 50), hash5(11, 20, 30, 40, 50));
    }

    #[test]
    fn hash4_alt_differs_from_hash4() {
        assert_ne!(
            hash4(10, 20, 30, 40),
            hash4_alt(10, 20, 30, 40),
            "alt hash should differ from primary"
        );
    }

    #[test]
    fn match_quantization() {
        // A fresh model has no match, which quantizes to bucket 0.
        assert_eq!(MatchModel::new().match_length_quantized(), 0);
    }
}