// datacortex_core/model/match_model.rs
//! MatchModel -- ring buffer + hash table for longest match prediction.
//!
//! Phase 5: Multi-candidate match finding with extended length verification.
//! Finds up to 4 candidates via different hash functions, verifies each,
//! and uses the one with the longest actual match length.
//!
//! CRITICAL V2 LESSONS:
//! - Rolling hash must NOT be cumulative
//! - Confidence ramp must be linear (not step function)
//! - Length tracking must reset on mismatch
/// Default ring-buffer capacity: 16 MiB of byte history.
const DEFAULT_BUF_SIZE: usize = 1 << 24;

/// Default number of hash-table slots: 8 Mi entries.
const DEFAULT_HASH_SIZE: usize = 1 << 23;

/// Shortest match length that is allowed to influence predictions.
const MIN_MATCH: usize = 2;

/// Match lengths are capped at this value when computing confidence.
const MAX_MATCH_FOR_CONF: usize = 64;

/// Upper bound on the forward scan used to verify a candidate match length.
const MAX_VERIFY_LEN: usize = 128;
26
/// Match model: locates the longest recent occurrence of the current
/// context in a ring buffer of history and predicts the next bits from
/// the byte that followed that occurrence.
pub struct MatchModel {
    /// Ring buffer holding the most recent `buf_size` bytes of history.
    buf: Vec<u8>,
    /// Index in `buf` where the next byte will be written.
    buf_pos: usize,
    /// Count of all bytes ever written (used to gate validity checks).
    total_written: usize,
    /// Context hash -> ring-buffer position of the byte last seen there.
    hash_table: Vec<u32>,
    /// Capacity of `buf`; must be a power of 2 so masking works.
    buf_size: usize,
    /// Capacity of `hash_table`; must be a power of 2 so masking works.
    hash_size: usize,
    /// Ring position of the byte the match predicts next; -1 when no match.
    match_pos: i64,
    /// Length in bytes of the currently tracked match.
    match_len: usize,
    /// Bit offset (0-7) already consumed inside the predicted byte.
    match_bpos: u8,
    /// Rolling hash of the most recent bytes; recomputed fresh each byte
    /// (non-cumulative by design -- see the V2 lessons in the module docs).
    hash: u32,
    /// Probability returned by the most recent `predict` call.
    last_p: u32,
}
52
53impl MatchModel {
    /// Create a match model with the default 16 MiB buffer and 8M-entry table.
    pub fn new() -> Self {
        Self::with_sizes(DEFAULT_BUF_SIZE, DEFAULT_HASH_SIZE)
    }
57
58    /// Create a match model with custom ring buffer and hash table sizes.
59    /// Both sizes must be powers of 2.
60    pub fn with_sizes(buf_size: usize, hash_size: usize) -> Self {
61        debug_assert!(buf_size.is_power_of_two());
62        debug_assert!(hash_size.is_power_of_two());
63        MatchModel {
64            buf: vec![0u8; buf_size],
65            buf_pos: 0,
66            total_written: 0,
67            hash_table: vec![0u32; hash_size],
68            buf_size,
69            hash_size,
70            match_pos: -1,
71            match_len: 0,
72            match_bpos: 0,
73            hash: 0,
74            last_p: 2048,
75        }
76    }
77
78    /// Predict probability of bit=1 based on match continuation.
79    /// Returns 12-bit probability in [1, 4095].
80    ///
81    /// `c0`: partial byte being decoded (1-255).
82    /// `bpos`: bit position (0-7).
83    /// `c1`: last completed byte.
84    /// `c2`: second-to-last byte.
85    /// `c3`: third-to-last byte.
86    #[inline]
87    pub fn predict(&mut self, _c0: u32, bpos: u8, c1: u8, c2: u8, c3: u8) -> u32 {
88        if bpos == 0 {
89            // At byte boundary: look for new match or extend existing one.
90            self.find_match(c1, c2, c3);
91        }
92
93        if self.match_pos < 0 || self.match_len < MIN_MATCH {
94            self.last_p = 2048;
95            return 2048; // No match -> neutral prediction
96        }
97
98        // Predict from match continuation.
99        let mpos = self.match_pos as usize & (self.buf_size - 1);
100        let match_byte = self.buf[mpos];
101        let match_bit = (match_byte >> (7 - bpos)) & 1;
102
103        // Improved confidence ramp: slow start, steep middle, saturating.
104        // Uses a piecewise function:
105        //   len 2-3: 80 per byte (tentative)
106        //   len 4-8: 200 per byte (building confidence)
107        //   len 9-32: 120 per byte (strong)
108        //   len 33+: 60 per byte (saturating)
109        let len = self.match_len.min(MAX_MATCH_FOR_CONF);
110        let conf = if len <= 3 {
111            (len as u32) * 80
112        } else if len <= 8 {
113            240 + ((len as u32 - 3) * 200)
114        } else if len <= 32 {
115            1240 + ((len as u32 - 8) * 120)
116        } else {
117            4120u32.min(1240 + 2880 + ((len as u32 - 32) * 60))
118        };
119        let conf = conf.min(3800);
120
121        let p = if match_bit == 1 {
122            2048 + conf
123        } else {
124            2048u32.saturating_sub(conf)
125        };
126        let p = p.clamp(1, 4095);
127        self.last_p = p;
128        p
129    }
130
    /// Fold one observed bit into the model: advance or reset the active
    /// match, and at byte boundaries store the completed byte and index its
    /// context under 3-, 4-, and 5-byte hashes.
    ///
    /// `bit`: observed bit.
    /// `bpos`: bit position (0-7).
    /// `c0`: partial byte (after this bit; at bpos==7 its low 8 bits are the
    ///       completed byte).
    /// `c1`: last completed byte (if bpos==7, the byte before the one that
    ///       just completed).
    /// `c2`: second-to-last byte.
    #[inline]
    pub fn update(&mut self, bit: u8, bpos: u8, c0: u32, c1: u8, c2: u8) {
        // Step 1: check whether the active match predicted this bit.
        if self.match_pos >= 0 {
            let mpos = self.match_pos as usize & (self.buf_size - 1);
            let match_bit = (self.buf[mpos] >> (7 - self.match_bpos)) & 1;
            if match_bit == bit {
                self.match_bpos += 1;
                if self.match_bpos >= 8 {
                    // A full byte matched: advance to the next predicted byte.
                    self.match_bpos = 0;
                    self.match_len += 1;
                    self.match_pos = (self.match_pos + 1) & (self.buf_size as i64 - 1);
                }
            } else {
                // Mismatch: reset match state entirely (V2 lesson: length
                // must not survive a mismatch).
                self.match_pos = -1;
                self.match_len = 0;
                self.match_bpos = 0;
            }
        }

        // Step 2: at byte boundary (after last bit of byte): store byte and
        // update hash-table entries for the new context.
        if bpos == 7 {
            let byte = (c0 & 0xFF) as u8;
            self.buf[self.buf_pos] = byte;

            // Non-cumulative rolling hash: hash of last 4 bytes for better match finding.
            // CRITICAL: must NOT be cumulative (V2 bug lesson).
            // At this point buf_pos still addresses `byte`, so prev_byte(3)
            // is the 4th-newest byte.
            self.hash = hash4(byte, c1, c2, self.prev_byte(3));

            // Store position in hash table using primary (4-byte) hash.
            // NOTE(review): this makes the freshest entry for the current
            // context point at the byte just written, so the very next lookup
            // for the same context can return the context's own position.
            let idx = self.hash as usize & (self.hash_size - 1);
            self.hash_table[idx] = self.buf_pos as u32;

            // Store using 3-byte hash for fallback match opportunities.
            let h3 = hash3(byte, c1, c2);
            let idx3 = h3 as usize & (self.hash_size - 1);
            // Only store if slot is empty (don't overwrite better 4-byte matches).
            // NOTE(review): slot value 0 is ambiguous -- it means both "empty"
            // and "ring position 0" -- and 3-byte slots become effectively
            // write-once after warmup; confirm this is the intended tradeoff.
            if self.hash_table[idx3] == 0 || self.total_written < 4 {
                self.hash_table[idx3] = self.buf_pos as u32;
            }

            // Store using 5-byte hash for longer match opportunities.
            let c3 = self.prev_byte(3);
            let c4 = self.prev_byte(4);
            let h5 = hash5(byte, c1, c2, c3, c4);
            let idx5 = h5 as usize & (self.hash_size - 1);
            self.hash_table[idx5] = self.buf_pos as u32;

            // Advance the write position; mask keeps it inside the ring.
            self.buf_pos = (self.buf_pos + 1) & (self.buf_size - 1);
            self.total_written += 1;
        }
    }
190
191    /// Get byte from N positions before current write position.
192    #[inline]
193    fn prev_byte(&self, n: usize) -> u8 {
194        if self.total_written >= n {
195            self.buf[(self.buf_pos.wrapping_sub(n)) & (self.buf_size - 1)]
196        } else {
197            0
198        }
199    }
200
    /// Verify actual match length at a candidate position.
    /// Returns the number of matching bytes starting from the candidate + 1
    /// (i.e., matching bytes after the context that was used to find it).
    ///
    /// NOTE(review): `data_start` is the slot the *next* byte will be written
    /// to, so `buf[dp]` reads ring contents that have not been produced for
    /// the current context (zero-fill before the first wrap, bytes from the
    /// previous trip around the ring afterwards).  The returned length thus
    /// ranks candidates against stale data.  It is deterministic -- encoder
    /// and decoder see identical buffers -- but confirm that this forward
    /// scan is the intended design rather than a backward comparison of
    /// already-seen bytes.
    #[inline]
    fn verify_match_length(&self, candidate_pos: usize) -> usize {
        let verify_start = (candidate_pos + 1) & (self.buf_size - 1);
        let data_start = self.buf_pos; // current write position = where next byte goes
        // Scan at most what has ever been written, capped at MAX_VERIFY_LEN.
        let max_len = self.total_written.min(MAX_VERIFY_LEN);
        let mut len = 0;
        while len < max_len {
            let mp = (verify_start + len) & (self.buf_size - 1);
            let dp = (data_start + len) & (self.buf_size - 1);
            // Don't compare beyond what we've written or wrap into the match itself.
            if mp == self.buf_pos {
                break;
            }
            if self.buf[mp] != self.buf[dp] {
                break;
            }
            len += 1;
        }
        len
    }
224
225    /// Find the best match among multiple candidates using different hash functions.
226    fn find_match(&mut self, c1: u8, c2: u8, c3: u8) {
227        if self.total_written < 3 {
228            self.match_pos = -1;
229            self.match_len = 0;
230            return;
231        }
232
233        let c4 = self.prev_byte(3); // 4th-to-last byte
234        let c5 = self.prev_byte(4); // 5th-to-last byte
235
236        // Collect up to 4 candidate positions from different hash functions.
237        // Each candidate is verified for actual match quality.
238        let mut best_pos: i64 = -1;
239        let mut best_len: usize = 0;
240
241        // Candidate 1: 5-byte hash (best precision)
242        if self.total_written >= 5 {
243            let h5 = hash5(c1, c2, c3, c4, c5);
244            let idx5 = h5 as usize & (self.hash_size - 1);
245            let cand = self.hash_table[idx5] as usize;
246            self.check_candidate(cand, c1, c2, c3, &mut best_pos, &mut best_len);
247        }
248
249        // Candidate 2: 4-byte hash
250        let h4 = hash4(c1, c2, c3, c4);
251        let idx4 = h4 as usize & (self.hash_size - 1);
252        let cand4 = self.hash_table[idx4] as usize;
253        self.check_candidate(cand4, c1, c2, c3, &mut best_pos, &mut best_len);
254
255        // Candidate 3: 3-byte hash (wider net)
256        let h3 = hash3(c1, c2, c3);
257        let idx3 = h3 as usize & (self.hash_size - 1);
258        let cand3 = self.hash_table[idx3] as usize;
259        self.check_candidate(cand3, c1, c2, c3, &mut best_pos, &mut best_len);
260
261        // Candidate 4: alternate 4-byte hash with different mixing
262        let h4b = hash4_alt(c1, c2, c3, c4);
263        let idx4b = h4b as usize & (self.hash_size - 1);
264        let cand4b = self.hash_table[idx4b] as usize;
265        self.check_candidate(cand4b, c1, c2, c3, &mut best_pos, &mut best_len);
266
267        if best_len >= MIN_MATCH {
268            self.match_pos = best_pos;
269            self.match_len = best_len;
270            self.match_bpos = 0;
271        } else {
272            self.match_pos = -1;
273            self.match_len = 0;
274        }
275    }
276
277    /// Check a candidate position and update best match if it's better.
278    #[inline]
279    fn check_candidate(
280        &self,
281        candidate_pos: usize,
282        c1: u8,
283        c2: u8,
284        c3: u8,
285        best_pos: &mut i64,
286        best_len: &mut usize,
287    ) {
288        let bp = candidate_pos;
289        let p1 = bp.wrapping_sub(1) & (self.buf_size - 1);
290        let p2 = bp.wrapping_sub(2) & (self.buf_size - 1);
291
292        // Verify the context bytes match.
293        if self.buf[bp] == c1 && self.buf[p1] == c2 && self.buf[p2] == c3 {
294            // Context matches. Now verify how far the match extends forward.
295            let fwd_len = self.verify_match_length(bp);
296            let total_match = 3 + fwd_len; // 3 context bytes + forward extension
297            if total_match > *best_len {
298                *best_len = total_match;
299                *best_pos = ((bp + 1) & (self.buf_size - 1)) as i64;
300            }
301        }
302    }
303
304    /// Return the quantized match length for mixer context.
305    /// 0=no match, 1=short, 2=medium, 3=long.
306    #[inline]
307    pub fn match_length_quantized(&self) -> u8 {
308        if self.match_pos < 0 || self.match_len < MIN_MATCH {
309            0
310        } else if self.match_len < 8 {
311            1
312        } else if self.match_len < 32 {
313            2
314        } else {
315            3
316        }
317    }
318
    /// Return the probability produced by the most recent `predict` call
    /// (used as context for the APM / SSE stage).
    #[inline]
    pub fn last_prediction(&self) -> u32 {
        self.last_p
    }
324}
325
326impl Default for MatchModel {
327    fn default() -> Self {
328        Self::new()
329    }
330}
331
/// Order-sensitive FNV-style mix of 3 bytes.  Stateless: the result depends
/// only on the arguments, never on prior calls (the V2 "cumulative hash"
/// bug is structurally impossible here).
#[inline]
fn hash3(b1: u8, b2: u8, b3: u8) -> u32 {
    const FNV_PRIME: u32 = 0x0100_0193;
    [b2, b1]
        .iter()
        .fold(b3 as u32, |h, &b| h.wrapping_mul(FNV_PRIME) ^ u32::from(b))
}
340
/// Order-sensitive FNV-style mix of 4 bytes for higher match precision.
/// Stateless by construction -- no accumulation across calls.
#[inline]
fn hash4(b1: u8, b2: u8, b3: u8, b4: u8) -> u32 {
    const FNV_PRIME: u32 = 0x0100_0193;
    [b3, b2, b1]
        .iter()
        .fold(b4 as u32, |h, &b| h.wrapping_mul(FNV_PRIME) ^ u32::from(b))
}
350
/// Second 4-byte hash with a different seed and mixing order, so the same
/// byte sequence lands in a different table slot than `hash4`.
#[inline]
fn hash4_alt(b1: u8, b2: u8, b3: u8, b4: u8) -> u32 {
    const FNV_PRIME: u32 = 0x0100_0193;
    const SEED: u32 = 0x9E37_79B9; // 2^32 / golden ratio
    // XOR-then-multiply for the outer three bytes, final XOR of b1.
    let mixed = [b4, b3, b2]
        .iter()
        .fold(SEED, |h, &b| (h ^ u32::from(b)).wrapping_mul(FNV_PRIME));
    mixed ^ u32::from(b1)
}
364
/// Order-sensitive FNV-style mix of 5 bytes for longer-context matching.
/// Stateless by construction -- no accumulation across calls.
#[inline]
fn hash5(b1: u8, b2: u8, b3: u8, b4: u8, b5: u8) -> u32 {
    const FNV_PRIME: u32 = 0x0100_0193;
    [b4, b3, b2, b1]
        .iter()
        .fold(b5 as u32, |h, &b| h.wrapping_mul(FNV_PRIME) ^ u32::from(b))
}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh model has no history, so the very first prediction must be
    /// the neutral probability 2048.
    #[test]
    fn new_model_predicts_neutral() {
        let mut mm = MatchModel::new();
        let p = mm.predict(1, 0, 0, 0, 0);
        assert_eq!(p, 2048);
    }

    /// After feeding arbitrary bytes, predictions must stay inside the
    /// legal 12-bit range [1, 4095].
    #[test]
    fn prediction_in_range() {
        let mut mm = MatchModel::new();
        // Feed some bytes first. Only the low 8 bits of c0 at bpos==7 are
        // consumed by update, so the partial-byte values elsewhere are
        // placeholders.
        for i in 0..100u8 {
            for bpos in 0..8u8 {
                let bit = (i >> (7 - bpos)) & 1;
                let c0 = if bpos == 7 {
                    (i as u32) | 0x100
                } else {
                    1u32 << (bpos + 1)
                };
                mm.update(bit, bpos, c0, i.wrapping_sub(1), i.wrapping_sub(2));
            }
        }
        let p = mm.predict(1, 0, 99, 98, 97);
        assert!((1..=4095).contains(&p));
    }

    /// hash3 must be a pure function of its arguments (V2 lesson).
    #[test]
    fn hash3_not_cumulative() {
        // Same inputs should give same hash regardless of call order.
        let h1 = hash3(10, 20, 30);
        let h2 = hash3(10, 20, 30);
        assert_eq!(h1, h2);

        // Different inputs should differ.
        let h3 = hash3(11, 20, 30);
        assert_ne!(h1, h3);
    }

    /// hash4 must be a pure function of its arguments (V2 lesson).
    #[test]
    fn hash4_not_cumulative() {
        let h1 = hash4(10, 20, 30, 40);
        let h2 = hash4(10, 20, 30, 40);
        assert_eq!(h1, h2);

        let h3 = hash4(11, 20, 30, 40);
        assert_ne!(h1, h3);
    }

    /// hash5 must be a pure function of its arguments (V2 lesson).
    #[test]
    fn hash5_not_cumulative() {
        let h1 = hash5(10, 20, 30, 40, 50);
        let h2 = hash5(10, 20, 30, 40, 50);
        assert_eq!(h1, h2);

        let h3 = hash5(11, 20, 30, 40, 50);
        assert_ne!(h1, h3);
    }

    /// The alternate 4-byte hash must map the same bytes to a different
    /// value than the primary, or the two table probes would collide.
    #[test]
    fn hash4_alt_differs_from_hash4() {
        let h1 = hash4(10, 20, 30, 40);
        let h2 = hash4_alt(10, 20, 30, 40);
        assert_ne!(h1, h2, "alt hash should differ from primary");
    }

    /// With no history there is no match, so the quantized bucket is 0.
    #[test]
    fn match_quantization() {
        let mm = MatchModel::new();
        assert_eq!(mm.match_length_quantized(), 0); // no match
    }
}
450}