Skip to main content

jxl_encoder/modular/
predictor.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! Predictor implementations for modular encoding.
6//!
7//! Predictors estimate the value of a pixel based on its neighbors.
8//! The prediction residual (actual - predicted) is what gets entropy coded.
9
10use super::channel::Channel;
11
12/// Available predictors for modular encoding.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
14#[repr(u8)]
15pub enum Predictor {
16    /// Always predicts 0.
17    #[default]
18    Zero = 0,
19    /// Uses the left neighbor (west).
20    Left = 1,
21    /// Uses the top neighbor (north).
22    Top = 2,
23    /// Average of left and top.
24    Average0 = 3,
25    /// Select between left/top based on top-left.
26    Select = 4,
27    /// Gradient: left + top - topleft (clamped).
28    Gradient = 5,
29    /// Weighted average favoring left.
30    Weighted = 6,
31    /// Top-right neighbor.
32    TopRight = 7,
33    /// Top-left neighbor.
34    TopLeft = 8,
35    /// Left-left neighbor (2 pixels left).
36    LeftLeft = 9,
37    /// Average of average and gradient.
38    Average1 = 10,
39    /// Average of average and left.
40    Average2 = 11,
41    /// Average of average and top.
42    Average3 = 12,
43    /// Average of top and top-right.
44    Average4 = 13,
45}
46
47impl Predictor {
48    /// Number of simple predictors (excluding weighted/variable).
49    pub const NUM_SIMPLE: usize = 14;
50
51    /// Returns all simple predictors.
52    pub fn all_simple() -> &'static [Predictor] {
53        &[
54            Predictor::Zero,
55            Predictor::Left,
56            Predictor::Top,
57            Predictor::Average0,
58            Predictor::Select,
59            Predictor::Gradient,
60            Predictor::Weighted,
61            Predictor::TopRight,
62            Predictor::TopLeft,
63            Predictor::LeftLeft,
64            Predictor::Average1,
65            Predictor::Average2,
66            Predictor::Average3,
67            Predictor::Average4,
68        ]
69    }
70
71    /// Predicts the value at (x, y) using this predictor.
72    #[inline]
73    pub fn predict(self, channel: &Channel, x: usize, y: usize) -> i32 {
74        let neighbors = Neighbors::gather(channel, x, y);
75        self.predict_from_neighbors(&neighbors)
76    }
77
78    /// Predicts from pre-gathered neighbor values.
79    #[inline]
80    pub fn predict_from_neighbors(self, n: &Neighbors) -> i32 {
81        match self {
82            Predictor::Zero => 0,
83            Predictor::Left => n.w,
84            Predictor::Top => n.n,
85            Predictor::Average0 => (n.w + n.n) / 2,
86            Predictor::Select => {
87                // Median-like selection
88                if n.w.abs_diff(n.nw) < n.n.abs_diff(n.nw) {
89                    n.n
90                } else {
91                    n.w
92                }
93            }
94            Predictor::Gradient => {
95                // Clamped gradient: W + N - NW, clamped to [min(W,N), max(W,N)]
96                let gradient = n.w.saturating_add(n.n).saturating_sub(n.nw);
97                gradient.clamp(n.w.min(n.n), n.w.max(n.n))
98            }
99            Predictor::Weighted => {
100                // Simplified weighted predictor (full version uses adaptive weights)
101                // This is a placeholder - full weighted uses WP state
102                let gradient = n.w.saturating_add(n.n).saturating_sub(n.nw);
103                gradient.clamp(n.w.min(n.n), n.w.max(n.n))
104            }
105            Predictor::TopRight => n.ne,
106            Predictor::TopLeft => n.nw,
107            Predictor::LeftLeft => n.ww,
108            Predictor::Average1 => {
109                let avg = (n.w + n.n) / 2;
110                let grad = n.w.saturating_add(n.n).saturating_sub(n.nw);
111                (avg + grad) / 2
112            }
113            Predictor::Average2 => {
114                let avg = (n.w + n.n) / 2;
115                (avg + n.w) / 2
116            }
117            Predictor::Average3 => {
118                let avg = (n.w + n.n) / 2;
119                (avg + n.n) / 2
120            }
121            Predictor::Average4 => (n.n + n.ne) / 2,
122        }
123    }
124}
125
/// The causal neighbourhood of one pixel, as consumed by the predictors.
///
/// Fields are named by compass direction relative to the pixel being predicted.
/// Samples that fall outside the image are replaced according to the edge rules
/// applied by `Neighbors::gather` / `Neighbors::gather_fast`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Neighbors {
    /// N: the sample directly above.
    pub n: i32,
    /// W: the sample directly to the left.
    pub w: i32,
    /// NW: the sample above and to the left.
    pub nw: i32,
    /// NE: the sample above and to the right.
    pub ne: i32,
    /// NN: the sample two rows above.
    pub nn: i32,
    /// WW: the sample two columns to the left.
    pub ww: i32,
}
142
143impl Neighbors {
144    /// Gathers neighbor values from a channel, matching the JXL spec's edge handling.
145    ///
146    /// Edge clamping rules (from jxl-rs PredictionData::get_rows):
147    /// - left: x>0 ? row[x-1] : (y>0 ? top_row[0] : 0)
148    /// - top: y>0 ? top_row[x] : left
149    /// - topleft: x>0 && y>0 ? top_row[x-1] : left
150    /// - topright: x+1 < width && y>0 ? top_row[x+1] : top
151    /// - leftleft: x>1 ? row[x-2] : left
152    /// - toptop: y>1 ? toptop_row[x] : top
153    #[inline]
154    pub fn gather(channel: &Channel, x: usize, y: usize) -> Self {
155        let width = channel.width();
156
157        let w = if x > 0 {
158            channel.get(x - 1, y)
159        } else if y > 0 {
160            channel.get(0, y - 1)
161        } else {
162            0
163        };
164
165        let n = if y > 0 { channel.get(x, y - 1) } else { w };
166
167        let nw = if x > 0 && y > 0 {
168            channel.get(x - 1, y - 1)
169        } else {
170            w
171        };
172
173        let ne = if x + 1 < width && y > 0 {
174            channel.get(x + 1, y - 1)
175        } else {
176            n
177        };
178
179        let ww = if x > 1 { channel.get(x - 2, y) } else { w };
180
181        let nn = if y > 1 { channel.get(x, y - 2) } else { n };
182
183        Self {
184            n,
185            w,
186            nw,
187            ne,
188            nn,
189            ww,
190        }
191    }
192
193    /// Gathers neighbors with explicit row pointers for speed, matching JXL spec edge handling.
194    #[inline]
195    pub fn gather_fast(
196        row: &[i32],
197        prev_row: Option<&[i32]>,
198        prev_prev_row: Option<&[i32]>,
199        x: usize,
200        _width: usize,
201    ) -> Self {
202        let w = if x > 0 {
203            row[x - 1]
204        } else if let Some(prev) = prev_row {
205            prev[0]
206        } else {
207            0
208        };
209
210        let n = if let Some(prev) = prev_row {
211            prev[x]
212        } else {
213            w
214        };
215
216        let nw = if x > 0 {
217            if let Some(prev) = prev_row {
218                prev[x - 1]
219            } else {
220                w
221            }
222        } else {
223            w
224        };
225
226        let ne = if let Some(prev) = prev_row {
227            if x + 1 < prev.len() { prev[x + 1] } else { n }
228        } else {
229            n
230        };
231
232        let ww = if x > 1 { row[x - 2] } else { w };
233
234        let nn = if let Some(pp) = prev_prev_row {
235            pp[x]
236        } else {
237            n
238        };
239
240        Self {
241            n,
242            w,
243            nw,
244            ne,
245            nn,
246            ww,
247        }
248    }
249}
250
/// How many sub-predictors the weighted predictor blends.
const NUM_WP_PREDICTORS: usize = 4;
/// Fractional bits the weighted predictor carries on every sample.
const PRED_EXTRA_BITS: i64 = 3;
/// Bias added before shifting out `PRED_EXTRA_BITS` (half a unit, minus one).
const PREDICTION_ROUND: i64 = (1 << (PRED_EXTRA_BITS - 1)) - 1;

/// Reciprocal table for cheap division by 1..=64: `DIVLOOKUP[i] = (1 << 24) / (i + 1)`.
///
/// Generated at compile time; the values are identical to libjxl's `kDivLookup`.
const DIVLOOKUP: [u32; 64] = {
    let mut table = [0u32; 64];
    let mut i = 0;
    while i < 64 {
        table[i] = (1u32 << 24) / (i as u32 + 1);
        i += 1;
    }
    table
};
268
/// Tuning parameters of the weighted (self-correcting) predictor, as signalled in
/// the bitstream header.
///
/// The `p*` values are correction multipliers applied in 1/32 units (the products
/// are shifted right by 5); `w0..w3` are the per-sub-predictor maximum weights.
#[derive(Debug, Clone, Copy)]
pub struct WeightedPredictorParams {
    /// Multiplier on `(err_N + err_W + err_NE)` in sub-predictor 1 (based on N).
    pub p1c: u32,
    /// Multiplier on `(err_N + err_W + err_NW)` in sub-predictor 2 (based on W).
    pub p2c: u32,
    /// Sub-predictor 3 multiplier on the NW error.
    pub p3ca: u32,
    /// Sub-predictor 3 multiplier on the N error.
    pub p3cb: u32,
    /// Sub-predictor 3 multiplier on the NE error.
    pub p3cc: u32,
    /// Sub-predictor 3 multiplier on `NN - N`.
    pub p3cd: u32,
    /// Sub-predictor 3 multiplier on `NW - W`.
    pub p3ce: u32,
    /// Maximum weight of sub-predictor 0.
    pub w0: u32,
    /// Maximum weight of sub-predictor 1.
    pub w1: u32,
    /// Maximum weight of sub-predictor 2.
    pub w2: u32,
    /// Maximum weight of sub-predictor 3.
    pub w3: u32,
}
288
289impl Default for WeightedPredictorParams {
290    fn default() -> Self {
291        // Default values from JXL spec
292        Self {
293            p1c: 16,
294            p2c: 10,
295            p3ca: 7,
296            p3cb: 7,
297            p3cc: 7,
298            p3cd: 0,
299            p3ce: 0,
300            w0: 0xd,
301            w1: 0xc,
302            w2: 0xc,
303            w3: 0xc,
304        }
305    }
306}
307
308impl WeightedPredictorParams {
309    /// Get weight multiplier by index.
310    pub fn w(&self, i: usize) -> u32 {
311        match i {
312            0 => self.w0,
313            1 => self.w1,
314            2 => self.w2,
315            3 => self.w3,
316            _ => panic!("Invalid weight index"),
317        }
318    }
319
320    /// Returns true if all parameters are at default values.
321    pub fn is_default(&self) -> bool {
322        *self == Self::default()
323    }
324}
325
326impl PartialEq for WeightedPredictorParams {
327    fn eq(&self, other: &Self) -> bool {
328        self.p1c == other.p1c
329            && self.p2c == other.p2c
330            && self.p3ca == other.p3ca
331            && self.p3cb == other.p3cb
332            && self.p3cc == other.p3cc
333            && self.p3cd == other.p3cd
334            && self.p3ce == other.p3ce
335            && self.w0 == other.w0
336            && self.w1 == other.w1
337            && self.w2 == other.w2
338            && self.w3 == other.w3
339    }
340}
341
/// Index of the highest set bit of `x`, i.e. `floor(log2(x))`.
///
/// `x` must be non-zero; callers in this file always pass a value of at least 1.
#[inline]
fn floor_log2_nonzero(x: u64) -> u32 {
    (u64::BITS - 1) - x.leading_zeros()
}
347
348/// Add extra precision bits.
349#[inline]
350fn add_bits(x: i32) -> i64 {
351    (x as i64) << PRED_EXTRA_BITS
352}
353
354/// Compute error weight from accumulated error.
355#[inline]
356fn error_weight(x: u32, maxweight: u32) -> u32 {
357    let shift = floor_log2_nonzero(x as u64 + 1) as i32 - 5;
358    if shift < 0 {
359        4u32 + maxweight * DIVLOOKUP[x as usize & 63]
360    } else {
361        4u32 + ((maxweight * DIVLOOKUP[(x as usize >> shift) & 63]) >> shift)
362    }
363}
364
365/// Compute weighted average of predictions.
366fn weighted_average(
367    pixels: &[i64; NUM_WP_PREDICTORS],
368    weights: &mut [u32; NUM_WP_PREDICTORS],
369) -> i64 {
370    let log_weight = floor_log2_nonzero(weights.iter().fold(0u64, |sum, el| sum + *el as u64));
371    let weight_sum = weights.iter_mut().fold(0, |sum, el| {
372        *el >>= log_weight - 4;
373        sum + *el
374    });
375    let sum = weights
376        .iter()
377        .enumerate()
378        .fold(((weight_sum >> 1) - 1) as i64, |sum, (i, weight)| {
379            sum + pixels[i] * *weight as i64
380        });
381    (sum * DIVLOOKUP[(weight_sum - 1) as usize] as i64) >> 24
382}
383
384/// Full weighted predictor state for adaptive prediction.
385/// Matches libjxl/jxl-rs implementation for encoding parity.
386#[derive(Debug)]
387pub struct WeightedPredictorState {
388    /// Current predictions from each sub-predictor.
389    prediction: [i64; NUM_WP_PREDICTORS],
390    /// Final weighted prediction.
391    pred: i64,
392    /// Per-position error buffer (position-major layout).
393    /// Layout: [pos0: p0,p1,p2,p3] [pos1: p0,p1,p2,p3] ...
394    pred_errors_buffer: Vec<u32>,
395    /// Prediction errors per position.
396    error: Vec<i32>,
397    /// Weighted predictor parameters.
398    params: WeightedPredictorParams,
399}
400
401impl WeightedPredictorState {
402    /// Creates a new weighted predictor state.
403    pub fn new(params: &WeightedPredictorParams, xsize: usize) -> Self {
404        let num_errors = (xsize + 2) * 2;
405        Self {
406            prediction: [0; NUM_WP_PREDICTORS],
407            pred: 0,
408            pred_errors_buffer: vec![0; num_errors * NUM_WP_PREDICTORS],
409            error: vec![0; num_errors],
410            params: *params,
411        }
412    }
413
414    /// Creates with default parameters.
415    pub fn with_defaults(xsize: usize) -> Self {
416        Self::new(&WeightedPredictorParams::default(), xsize)
417    }
418
419    /// Get all predictor errors for a given position (contiguous in memory).
420    #[inline(always)]
421    fn get_errors_at_pos(&self, pos: usize) -> &[u32; NUM_WP_PREDICTORS] {
422        let start = pos * NUM_WP_PREDICTORS;
423        self.pred_errors_buffer[start..start + NUM_WP_PREDICTORS]
424            .try_into()
425            .unwrap()
426    }
427
428    /// Get mutable reference to all predictor errors for a given position.
429    #[inline(always)]
430    fn get_errors_at_pos_mut(&mut self, pos: usize) -> &mut [u32; NUM_WP_PREDICTORS] {
431        let start = pos * NUM_WP_PREDICTORS;
432        (&mut self.pred_errors_buffer[start..start + NUM_WP_PREDICTORS])
433            .try_into()
434            .unwrap()
435    }
436
437    /// Compute prediction and property value.
438    /// Returns (prediction, max_error_property).
439    #[inline]
440    pub fn predict_and_property(
441        &mut self,
442        x: usize,
443        y: usize,
444        xsize: usize,
445        neighbors: &Neighbors,
446    ) -> (i64, i32) {
447        let (cur_row, prev_row) = if y & 1 != 0 {
448            (0, xsize + 2)
449        } else {
450            (xsize + 2, 0)
451        };
452        let pos_n = prev_row + x;
453        let pos_ne = if x < xsize - 1 { pos_n + 1 } else { pos_n };
454        let pos_nw = if x > 0 { pos_n - 1 } else { pos_n };
455
456        // Get errors at neighboring positions
457        let errors_n = self.get_errors_at_pos(pos_n);
458        let errors_ne = self.get_errors_at_pos(pos_ne);
459        let errors_nw = self.get_errors_at_pos(pos_nw);
460
461        // Compute weights from errors
462        let mut weights = [0u32; NUM_WP_PREDICTORS];
463        for i in 0..NUM_WP_PREDICTORS {
464            weights[i] = error_weight(
465                errors_n[i]
466                    .wrapping_add(errors_ne[i])
467                    .wrapping_add(errors_nw[i]),
468                self.params.w(i),
469            );
470        }
471
472        // Convert neighbors to higher precision
473        let n = add_bits(neighbors.n);
474        let w = add_bits(neighbors.w);
475        let ne = add_bits(neighbors.ne);
476        let nw = add_bits(neighbors.nw);
477        let nn = add_bits(neighbors.nn);
478
479        // Get transmission errors from neighboring positions
480        let te_w = if x == 0 {
481            0
482        } else {
483            self.error[cur_row + x - 1] as i64
484        };
485        let te_n = self.error[pos_n] as i64;
486        let te_nw = self.error[pos_nw] as i64;
487        let sum_wn = te_n + te_w;
488        let te_ne = self.error[pos_ne] as i64;
489
490        // Find max absolute error for property
491        let mut p = te_w;
492        if te_n.abs() > p.abs() {
493            p = te_n;
494        }
495        if te_nw.abs() > p.abs() {
496            p = te_nw;
497        }
498        if te_ne.abs() > p.abs() {
499            p = te_ne;
500        }
501
502        // Compute 4 sub-predictions with corrections
503        self.prediction[0] = w + ne - n;
504        self.prediction[1] = n - (((sum_wn + te_ne) * self.params.p1c as i64) >> 5);
505        self.prediction[2] = w - (((sum_wn + te_nw) * self.params.p2c as i64) >> 5);
506        self.prediction[3] = n
507            - ((te_nw * (self.params.p3ca as i64)
508                + (te_n * (self.params.p3cb as i64))
509                + (te_ne * (self.params.p3cc as i64))
510                + ((nn - n) * (self.params.p3cd as i64))
511                + ((nw - w) * (self.params.p3ce as i64)))
512                >> 5);
513
514        // Compute weighted average
515        self.pred = weighted_average(&self.prediction, &mut weights);
516
517        // Apply clamping when errors have consistent signs
518        if ((te_n ^ te_w) | (te_n ^ te_nw)) <= 0 {
519            let mx = w.max(ne.max(n));
520            let mn = w.min(ne.min(n));
521            self.pred = mn.max(mx.min(self.pred));
522        }
523
524        ((self.pred + PREDICTION_ROUND) >> PRED_EXTRA_BITS, p as i32)
525    }
526
527    /// Update error buffers after seeing actual value.
528    #[inline]
529    pub fn update_errors(&mut self, actual: i32, x: usize, y: usize, xsize: usize) {
530        let (cur_row, prev_row) = if y & 1 != 0 {
531            (0, xsize + 2)
532        } else {
533            (xsize + 2, 0)
534        };
535        let val = add_bits(actual);
536        self.error[cur_row + x] = (self.pred - val) as i32;
537
538        // Compute errors for all predictors
539        let mut errs = [0u32; NUM_WP_PREDICTORS];
540        for (err, &pred) in errs.iter_mut().zip(self.prediction.iter()) {
541            *err = (((pred - val).abs() + PREDICTION_ROUND) >> PRED_EXTRA_BITS) as u32;
542        }
543
544        // Write to current position
545        *self.get_errors_at_pos_mut(cur_row + x) = errs;
546
547        // Update previous row position
548        let prev_errors = self.get_errors_at_pos_mut(prev_row + x + 1);
549        for i in 0..NUM_WP_PREDICTORS {
550            prev_errors[i] = prev_errors[i].wrapping_add(errs[i]);
551        }
552    }
553
554    /// Simple predict method for compatibility.
555    pub fn predict(&mut self, x: usize, y: usize, xsize: usize, neighbors: &Neighbors) -> i32 {
556        let (pred, _) = self.predict_and_property(x, y, xsize, neighbors);
557        pred as i32
558    }
559}
560
561impl Default for WeightedPredictorState {
562    fn default() -> Self {
563        Self::with_defaults(256)
564    }
565}
566
/// Packs a signed integer for entropy coding using zig-zag encoding.
///
/// Non-negative `v` maps to `2v` and negative `v` to `-2v - 1`, so
/// `0, -1, 1, -2, 2, ...` become `0, 1, 2, 3, 4, ...`. Every `i32` is handled
/// without overflow, including `i32::MIN` (which maps to `u32::MAX`). This is the
/// inverse of `unpack_signed`.
#[inline]
pub fn pack_signed(value: i32) -> u32 {
    // `value << 1` discards the sign bit (never overflow-checked for bit loss);
    // `value >> 31` is all ones for negatives, flipping the bits into `-2v - 1`.
    ((value << 1) ^ (value >> 31)) as u32
}
577
/// Reverses zig-zag encoding: even `u` becomes `u / 2`, odd `u` becomes `-(u / 2) - 1`.
///
/// Total over all `u32` inputs (`u32::MAX` maps to `i32::MIN`).
#[inline]
pub fn unpack_signed(value: u32) -> i32 {
    let magnitude = (value >> 1) as i32;
    // All ones when the low bit is set, so XOR yields `-magnitude - 1`.
    let sign_mask = -((value & 1) as i32);
    magnitude ^ sign_mask
}
587
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_predictors() {
        // 4x4 ramp where every sample is x + 4 * y; we predict the bracketed (2, 2):
        //    0  1  2   3
        //    4  5  6   7
        //    8  9 [10] 11
        //   12 13 14  15
        let mut channel = Channel::new(4, 4).unwrap();
        for y in 0..4 {
            for x in 0..4 {
                channel.set(x, y, (x + y * 4) as i32);
            }
        }

        let neighbors = Neighbors::gather(&channel, 2, 2);
        assert_eq!(neighbors.n, 6, "top");
        assert_eq!(neighbors.w, 9, "left");
        assert_eq!(neighbors.nw, 5, "top-left");
        assert_eq!(neighbors.ne, 7, "top-right");

        let cases = [
            (Predictor::Zero, 0),
            (Predictor::Left, 9),
            (Predictor::Top, 6),
            // Gradient: 9 + 6 - 5 = 10, clamped into [6, 9].
            (Predictor::Gradient, 9),
        ];
        for &(predictor, expected) in cases.iter() {
            assert_eq!(
                predictor.predict_from_neighbors(&neighbors),
                expected,
                "{:?}",
                predictor
            );
        }
    }

    #[test]
    fn test_pack_signed() {
        let cases = [(0, 0), (1, 2), (-1, 1), (2, 4), (-2, 3)];
        for &(signed, packed) in cases.iter() {
            assert_eq!(pack_signed(signed), packed);
        }
    }

    #[test]
    fn test_unpack_signed() {
        let cases = [(0, 0), (1, -1), (2, 1), (3, -2), (4, 2)];
        for &(packed, signed) in cases.iter() {
            assert_eq!(unpack_signed(packed), signed);
        }
    }

    #[test]
    fn test_pack_unpack_roundtrip() {
        for value in -1000..=1000 {
            assert_eq!(unpack_signed(pack_signed(value)), value);
        }
    }

    #[test]
    fn test_weighted_predictor_params_default() {
        let defaults = WeightedPredictorParams::default();
        assert_eq!(defaults.p1c, 16);
        assert_eq!(defaults.p2c, 10);
        assert_eq!(defaults.w0, 0xd);
        assert!(defaults.is_default());
    }

    #[test]
    fn test_weighted_predictor_state() {
        let xsize = 8;
        let mut state = WeightedPredictorState::with_defaults(xsize);

        // A perfectly flat neighbourhood.
        let flat = Neighbors {
            n: 100,
            w: 100,
            nw: 100,
            ne: 100,
            nn: 100,
            ww: 100,
        };

        let (prediction, _property) = state.predict_and_property(4, 2, xsize, &flat);
        // On uniform data the blend should land next to the common value.
        assert!((prediction - 100).abs() <= 2);

        state.update_errors(100, 4, 2, xsize);
    }

    #[test]
    fn test_weighted_predictor_adapts() {
        let xsize = 8;
        let mut state = WeightedPredictorState::with_defaults(xsize);

        // Feed one row of a horizontal ramp; the check is simply that nothing panics.
        for x in 0..xsize {
            let ramp = |k: usize| (k * 10) as i32;
            let actual = ramp(x);
            let neighbors = Neighbors {
                n: if x > 0 { ramp(x - 1) } else { 0 },
                w: if x > 0 { ramp(x - 1) } else { 0 },
                nw: if x > 1 { ramp(x - 2) } else { 0 },
                ne: ramp(x),
                nn: 0,
                ww: if x > 1 { ramp(x - 2) } else { 0 },
            };

            let (_prediction, _property) = state.predict_and_property(x, 1, xsize, &neighbors);
            state.update_errors(actual, x, 1, xsize);
        }
    }

    /// Replays the jxl-rs golden-number test to confirm bit-exact parity with libjxl.
    #[test]
    fn test_wp_matches_jxl_rs_golden() {
        /// Park-Miller minimal-standard LCG, identical to the one in jxl-rs.
        struct SimpleRandom {
            out: i64,
        }
        impl SimpleRandom {
            fn new() -> Self {
                Self { out: 1 }
            }
            fn next(&mut self) -> i64 {
                self.out = self.out * 48271 % 0x7fffffff;
                self.out
            }
        }

        // Field initialisers run in source order, so the draw order below is
        // significant and must not be rearranged.
        let mut rng = SimpleRandom::new();
        let params = WeightedPredictorParams {
            p1c: rng.next() as u32 % 32,
            p2c: rng.next() as u32 % 32,
            p3ca: rng.next() as u32 % 32,
            p3cb: rng.next() as u32 % 32,
            p3cc: rng.next() as u32 % 32,
            p3cd: rng.next() as u32 % 32,
            p3ce: rng.next() as u32 % 32,
            w0: rng.next() as u32 % 16,
            w1: rng.next() as u32 % 16,
            w2: rng.next() as u32 % 16,
            w3: rng.next() as u32 % 16,
        };
        let xsize = 8;
        let ysize = 8;
        let mut state = WeightedPredictorState::new(&params, xsize);

        // One random predict + update step; returns (prediction, property).
        let step = |rng: &mut SimpleRandom, state: &mut WeightedPredictorState| -> (i64, i32) {
            let x = rng.next() as usize % xsize;
            let y = rng.next() as usize % ysize;
            let neighbors = Neighbors {
                n: rng.next() as i32 % 256,  // top
                w: rng.next() as i32 % 256,  // left
                ne: rng.next() as i32 % 256, // topright
                nw: rng.next() as i32 % 256, // topleft
                nn: rng.next() as i32 % 256, // toptop
                ww: 0,
            };
            let result = state.predict_and_property(x, y, xsize, &neighbors);
            state.update_errors((rng.next() % 256) as i32, x, y, xsize);
            result
        };

        // Golden numbers from libjxl (as verified in the jxl-rs test).
        let expected = [(135, 0), (110, -60), (165, 0), (153, -60)];
        for (index, &golden) in expected.iter().enumerate() {
            assert_eq!(step(&mut rng, &mut state), golden, "step {}", index + 1);
        }
    }
}