Skip to main content

jxl_encoder/modular/
predictor.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! Predictor implementations for modular encoding.
6//!
7//! Predictors estimate the value of a pixel based on its neighbors.
8//! The prediction residual (actual - predicted) is what gets entropy coded.
9
10use super::channel::Channel;
11
/// The fixed set of pixel predictors offered by the modular encoder.
///
/// The enum is `#[repr(u8)]`; the explicit discriminants are the numeric
/// identifiers obtained with `predictor as u8`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(u8)]
pub enum Predictor {
    /// Predicts the constant 0.
    #[default]
    Zero = 0,
    /// Predicts the west (left) neighbor.
    Left = 1,
    /// Predicts the north (top) neighbor.
    Top = 2,
    /// Average of west and north: (W + N) / 2.
    Average0 = 3,
    /// Picks either west or north, depending on how each relates to north-west.
    Select = 4,
    /// Gradient W + N - NW, clamped to the range spanned by W and N.
    Gradient = 5,
    /// Weighted predictor. `predict_from_neighbors` evaluates it with a
    /// stateless clamped-gradient stand-in; the adaptive version is
    /// implemented by `WeightedPredictorState`.
    Weighted = 6,
    /// Predicts the north-east (top-right) neighbor.
    TopRight = 7,
    /// Predicts the north-west (top-left) neighbor.
    TopLeft = 8,
    /// Predicts the sample two positions to the west.
    LeftLeft = 9,
    /// Average of west and north-west: (W + NW) / 2.
    Average1 = 10,
    /// Average of north and north-west: (N + NW) / 2.
    Average2 = 11,
    /// Average of north and north-east: (N + NE) / 2.
    Average3 = 12,
    /// Six-tap weighted average:
    /// (6N - 2NN + 7W + WW + NEE + 3NE + 8) / 16.
    Average4 = 13,
}
46
47impl Predictor {
48    /// Number of simple predictors (excluding weighted/variable).
49    pub const NUM_SIMPLE: usize = 14;
50
51    /// Returns all simple predictors.
52    pub fn all_simple() -> &'static [Predictor] {
53        &[
54            Predictor::Zero,
55            Predictor::Left,
56            Predictor::Top,
57            Predictor::Average0,
58            Predictor::Select,
59            Predictor::Gradient,
60            Predictor::Weighted,
61            Predictor::TopRight,
62            Predictor::TopLeft,
63            Predictor::LeftLeft,
64            Predictor::Average1,
65            Predictor::Average2,
66            Predictor::Average3,
67            Predictor::Average4,
68        ]
69    }
70
71    /// Predicts the value at (x, y) using this predictor.
72    #[inline]
73    pub fn predict(self, channel: &Channel, x: usize, y: usize) -> i32 {
74        let neighbors = Neighbors::gather(channel, x, y);
75        self.predict_from_neighbors(&neighbors)
76    }
77
78    /// Predicts from pre-gathered neighbor values.
79    #[inline]
80    pub fn predict_from_neighbors(self, n: &Neighbors) -> i32 {
81        match self {
82            Predictor::Zero => 0,
83            Predictor::Left => n.w,
84            Predictor::Top => n.n,
85            Predictor::Average0 => (n.w + n.n) / 2,
86            Predictor::Select => {
87                // Select predictor (matches JXL spec):
88                // p = W + N - NW
89                // if abs(p - W) < abs(p - N) then W else N
90                // Since p - W = N - NW and p - N = W - NW:
91                // if abs(N - NW) < abs(W - NW) then W else N
92                if n.n.abs_diff(n.nw) < n.w.abs_diff(n.nw) {
93                    n.w
94                } else {
95                    n.n
96                }
97            }
98            Predictor::Gradient => {
99                // Clamped gradient: W + N - NW, clamped to [min(W,N), max(W,N)]
100                let gradient = n.w.saturating_add(n.n).saturating_sub(n.nw);
101                gradient.clamp(n.w.min(n.n), n.w.max(n.n))
102            }
103            Predictor::Weighted => {
104                // Simplified weighted predictor (full version uses adaptive weights)
105                // This is a placeholder - full weighted uses WP state
106                let gradient = n.w.saturating_add(n.n).saturating_sub(n.nw);
107                gradient.clamp(n.w.min(n.n), n.w.max(n.n))
108            }
109            Predictor::TopRight => n.ne,
110            Predictor::TopLeft => n.nw,
111            Predictor::LeftLeft => n.ww,
112            Predictor::Average1 => (n.w + n.nw) / 2,
113            Predictor::Average2 => (n.n + n.nw) / 2,
114            Predictor::Average3 => (n.n + n.ne) / 2,
115            Predictor::Average4 => {
116                // AverageAll: (6*N - 2*NN + 7*W + WW + NEE + 3*NE + 8) / 16
117                // where NEE = toprightright = pixel at (x+2, y-1)
118                // Use i64 intermediates to prevent overflow (libjxl PR #4574)
119                ((6i64 * n.n as i64 - 2 * n.nn as i64
120                    + 7 * n.w as i64
121                    + n.ww as i64
122                    + n.nee as i64
123                    + 3 * n.ne as i64
124                    + 8)
125                    / 16) as i32
126            }
127        }
128    }
129}
130
/// Samples surrounding the position being predicted.
///
/// Compass names are relative to the current sample: north is the row above,
/// west is the column to the left.
#[derive(Clone, Copy, Debug, Default)]
pub struct Neighbors {
    /// N: sample directly above.
    pub n: i32,
    /// W: sample directly to the left.
    pub w: i32,
    /// NW: sample above and to the left.
    pub nw: i32,
    /// NE: sample above and to the right.
    pub ne: i32,
    /// NN: sample two rows above.
    pub nn: i32,
    /// WW: sample two columns to the left.
    pub ww: i32,
    /// NEE: sample at (x + 2, y - 1), i.e. one step right of NE.
    /// Read by the `Average4` predictor.
    pub nee: i32,
}
149
150impl Neighbors {
151    /// Gathers neighbor values from a channel, matching the JXL spec's edge handling.
152    ///
153    /// Edge clamping rules (from jxl-rs PredictionData::get_rows):
154    /// - left: `x>0 ? row[x-1] : (y>0 ? top_row[0] : 0)`
155    /// - top: `y>0 ? top_row[x] : left`
156    /// - topleft: `x>0 && y>0 ? top_row[x-1] : left`
157    /// - topright: `x+1 < width && y>0 ? top_row[x+1] : top`
158    /// - leftleft: `x>1 ? row[x-2] : left`
159    /// - toptop: `y>1 ? toptop_row[x] : top`
160    #[inline]
161    pub fn gather(channel: &Channel, x: usize, y: usize) -> Self {
162        let width = channel.width();
163
164        let w = if x > 0 {
165            channel.get(x - 1, y)
166        } else if y > 0 {
167            channel.get(0, y - 1)
168        } else {
169            0
170        };
171
172        let n = if y > 0 { channel.get(x, y - 1) } else { w };
173
174        let nw = if x > 0 && y > 0 {
175            channel.get(x - 1, y - 1)
176        } else {
177            w
178        };
179
180        let ne = if x + 1 < width && y > 0 {
181            channel.get(x + 1, y - 1)
182        } else {
183            n
184        };
185
186        let ww = if x > 1 { channel.get(x - 2, y) } else { w };
187
188        let nn = if y > 1 { channel.get(x, y - 2) } else { n };
189
190        let nee = if x + 2 < width && y > 0 {
191            channel.get(x + 2, y - 1)
192        } else {
193            ne
194        };
195
196        Self {
197            n,
198            w,
199            nw,
200            ne,
201            nn,
202            ww,
203            nee,
204        }
205    }
206
207    /// Gathers neighbors with explicit row pointers for speed, matching JXL spec edge handling.
208    #[inline]
209    pub fn gather_fast(
210        row: &[i32],
211        prev_row: Option<&[i32]>,
212        prev_prev_row: Option<&[i32]>,
213        x: usize,
214        _width: usize,
215    ) -> Self {
216        let w = if x > 0 {
217            row[x - 1]
218        } else if let Some(prev) = prev_row {
219            prev[0]
220        } else {
221            0
222        };
223
224        let n = if let Some(prev) = prev_row {
225            prev[x]
226        } else {
227            w
228        };
229
230        let nw = if x > 0 {
231            if let Some(prev) = prev_row {
232                prev[x - 1]
233            } else {
234                w
235            }
236        } else {
237            w
238        };
239
240        let ne = if let Some(prev) = prev_row {
241            if x + 1 < prev.len() { prev[x + 1] } else { n }
242        } else {
243            n
244        };
245
246        let ww = if x > 1 { row[x - 2] } else { w };
247
248        let nn = if let Some(pp) = prev_prev_row {
249            pp[x]
250        } else {
251            n
252        };
253
254        let nee = if let Some(prev) = prev_row {
255            if x + 2 < prev.len() { prev[x + 2] } else { ne }
256        } else {
257            ne
258        };
259
260        Self {
261            n,
262            w,
263            nw,
264            ne,
265            nn,
266            ww,
267            nee,
268        }
269    }
270}
271
/// Number of sub-predictors blended by the weighted predictor.
const NUM_WP_PREDICTORS: usize = 4;
/// Number of extra fractional bits carried by weighted-predictor values.
const PRED_EXTRA_BITS: i64 = 3;
/// Offset added before shifting the extra bits back out: half a unit minus
/// one (3 for `PRED_EXTRA_BITS == 3`).
const PREDICTION_ROUND: i64 = (1 << (PRED_EXTRA_BITS - 1)) - 1;

/// Reciprocal table for approximate division by 1..=64 in 24-bit fixed point.
///
/// `DIVLOOKUP[i] = (1 << 24) / (i + 1)`, rounded down.
const DIVLOOKUP: [u32; 64] = {
    let mut table = [0u32; 64];
    let mut i = 0;
    while i < 64 {
        table[i] = (1u32 << 24) / (i as u32 + 1);
        i += 1;
    }
    table
};
289
/// Tunable coefficients of the weighted predictor (from the bitstream header).
///
/// The `p*` fields scale the correction terms of the sub-predictors computed
/// in `WeightedPredictorState::predict_and_property`; the `w*` fields are the
/// per-sub-predictor maximum weights handed to `error_weight`.
#[derive(Clone, Copy, Debug)]
pub struct WeightedPredictorParams {
    /// Scale of the error correction applied to sub-predictor 1.
    pub p1c: u32,
    /// Scale of the error correction applied to sub-predictor 2.
    pub p2c: u32,
    /// Sub-predictor 3: multiplier of the north-west error.
    pub p3ca: u32,
    /// Sub-predictor 3: multiplier of the north error.
    pub p3cb: u32,
    /// Sub-predictor 3: multiplier of the north-east error.
    pub p3cc: u32,
    /// Sub-predictor 3: multiplier of the `NN - N` sample difference.
    pub p3cd: u32,
    /// Sub-predictor 3: multiplier of the `NW - W` sample difference.
    pub p3ce: u32,
    /// Maximum weight of sub-predictor 0.
    pub w0: u32,
    /// Maximum weight of sub-predictor 1.
    pub w1: u32,
    /// Maximum weight of sub-predictor 2.
    pub w2: u32,
    /// Maximum weight of sub-predictor 3.
    pub w3: u32,
}
309
310impl Default for WeightedPredictorParams {
311    fn default() -> Self {
312        // Default values from JXL spec
313        Self {
314            p1c: 16,
315            p2c: 10,
316            p3ca: 7,
317            p3cb: 7,
318            p3cc: 7,
319            p3cd: 0,
320            p3ce: 0,
321            w0: 0xd,
322            w1: 0xc,
323            w2: 0xc,
324            w3: 0xc,
325        }
326    }
327}
328
329impl WeightedPredictorParams {
330    /// Get weight multiplier by index.
331    pub fn w(&self, i: usize) -> u32 {
332        match i {
333            0 => self.w0,
334            1 => self.w1,
335            2 => self.w2,
336            3 => self.w3,
337            _ => panic!("Invalid weight index"),
338        }
339    }
340
341    /// Returns true if all parameters are at default values.
342    pub fn is_default(&self) -> bool {
343        *self == Self::default()
344    }
345
346    /// Get parameter set by mode index (0–4), matching libjxl's PredictorMode().
347    ///
348    /// - Mode 0: Default (lossless16)
349    /// - Mode 1: lossless8 variant
350    /// - Mode 2: West-biased lossless8
351    /// - Mode 3: North-biased lossless8
352    /// - Mode 4: Generic/balanced
353    pub fn for_mode(mode: u8) -> Self {
354        match mode {
355            0 => Self::default(),
356            1 => Self {
357                p1c: 8,
358                p2c: 8,
359                p3ca: 4,
360                p3cb: 0,
361                p3cc: 3,
362                p3cd: 23,
363                p3ce: 2,
364                w0: 0xd,
365                w1: 0xc,
366                w2: 0xc,
367                w3: 0xb,
368            },
369            2 => Self {
370                p1c: 10,
371                p2c: 9,
372                p3ca: 7,
373                p3cb: 0,
374                p3cc: 0,
375                p3cd: 16,
376                p3ce: 9,
377                w0: 0xd,
378                w1: 0xc,
379                w2: 0xd,
380                w3: 0xc,
381            },
382            3 => Self {
383                p1c: 16,
384                p2c: 8,
385                p3ca: 0,
386                p3cb: 16,
387                p3cc: 0,
388                p3cd: 23,
389                p3ce: 0,
390                w0: 0xd,
391                w1: 0xd,
392                w2: 0xc,
393                w3: 0xc,
394            },
395            _ => Self {
396                p1c: 10,
397                p2c: 10,
398                p3ca: 5,
399                p3cb: 5,
400                p3cc: 5,
401                p3cd: 12,
402                p3ce: 4,
403                w0: 0xd,
404                w1: 0xc,
405                w2: 0xc,
406                w3: 0xc,
407            },
408        }
409    }
410}
411
412impl PartialEq for WeightedPredictorParams {
413    fn eq(&self, other: &Self) -> bool {
414        self.p1c == other.p1c
415            && self.p2c == other.p2c
416            && self.p3ca == other.p3ca
417            && self.p3cb == other.p3cb
418            && self.p3cc == other.p3cc
419            && self.p3cd == other.p3cd
420            && self.p3ce == other.p3ce
421            && self.w0 == other.w0
422            && self.w1 == other.w1
423            && self.w2 == other.w2
424            && self.w3 == other.w3
425    }
426}
427
/// Index of the highest set bit of `x`, i.e. `floor(log2(x))`.
///
/// `x` must be non-zero; for zero the subtraction underflows.
#[inline]
fn floor_log2_nonzero(x: u64) -> u32 {
    let top_bit_index = u64::BITS - 1;
    top_bit_index - x.leading_zeros()
}
433
434/// Add extra precision bits.
435#[inline]
436fn add_bits(x: i32) -> i64 {
437    (x as i64) << PRED_EXTRA_BITS
438}
439
440/// Compute error weight from accumulated error.
441#[inline]
442fn error_weight(x: u32, maxweight: u32) -> u32 {
443    let shift = floor_log2_nonzero(x as u64 + 1) as i32 - 5;
444    if shift < 0 {
445        4u32 + maxweight * DIVLOOKUP[x as usize & 63]
446    } else {
447        4u32 + ((maxweight * DIVLOOKUP[(x as usize >> shift) & 63]) >> shift)
448    }
449}
450
451/// Compute weighted average of predictions.
452fn weighted_average(
453    pixels: &[i64; NUM_WP_PREDICTORS],
454    weights: &mut [u32; NUM_WP_PREDICTORS],
455) -> i64 {
456    let log_weight = floor_log2_nonzero(weights.iter().fold(0u64, |sum, el| sum + *el as u64));
457    let weight_sum = weights.iter_mut().fold(0, |sum, el| {
458        *el >>= log_weight - 4;
459        sum + *el
460    });
461    let sum = weights
462        .iter()
463        .enumerate()
464        .fold(((weight_sum >> 1) - 1) as i64, |sum, (i, weight)| {
465            sum + pixels[i] * *weight as i64
466        });
467    (sum * DIVLOOKUP[(weight_sum - 1) as usize] as i64) >> 24
468}
469
470/// Full weighted predictor state for adaptive prediction.
471/// Matches libjxl/jxl-rs implementation for encoding parity.
472#[derive(Debug)]
473pub struct WeightedPredictorState {
474    /// Current predictions from each sub-predictor.
475    prediction: [i64; NUM_WP_PREDICTORS],
476    /// Final weighted prediction.
477    pred: i64,
478    /// Per-position error buffer (position-major layout).
479    /// Layout: [pos0: p0,p1,p2,p3] [pos1: p0,p1,p2,p3] ...
480    pred_errors_buffer: Vec<u32>,
481    /// Prediction errors per position.
482    error: Vec<i32>,
483    /// Weighted predictor parameters.
484    params: WeightedPredictorParams,
485}
486
487impl WeightedPredictorState {
488    /// Creates a new weighted predictor state.
489    pub fn new(params: &WeightedPredictorParams, xsize: usize) -> Self {
490        let num_errors = (xsize + 2) * 2;
491        Self {
492            prediction: [0; NUM_WP_PREDICTORS],
493            pred: 0,
494            pred_errors_buffer: vec![0; num_errors * NUM_WP_PREDICTORS],
495            error: vec![0; num_errors],
496            params: *params,
497        }
498    }
499
500    /// Creates with default parameters.
501    pub fn with_defaults(xsize: usize) -> Self {
502        Self::new(&WeightedPredictorParams::default(), xsize)
503    }
504
505    /// Get all predictor errors for a given position (contiguous in memory).
506    #[inline(always)]
507    fn get_errors_at_pos(&self, pos: usize) -> &[u32; NUM_WP_PREDICTORS] {
508        let start = pos * NUM_WP_PREDICTORS;
509        self.pred_errors_buffer[start..start + NUM_WP_PREDICTORS]
510            .try_into()
511            .unwrap()
512    }
513
514    /// Get mutable reference to all predictor errors for a given position.
515    #[inline(always)]
516    fn get_errors_at_pos_mut(&mut self, pos: usize) -> &mut [u32; NUM_WP_PREDICTORS] {
517        let start = pos * NUM_WP_PREDICTORS;
518        (&mut self.pred_errors_buffer[start..start + NUM_WP_PREDICTORS])
519            .try_into()
520            .unwrap()
521    }
522
523    /// Compute prediction and property value.
524    /// Returns (prediction, max_error_property).
525    #[inline]
526    pub fn predict_and_property(
527        &mut self,
528        x: usize,
529        y: usize,
530        xsize: usize,
531        neighbors: &Neighbors,
532    ) -> (i64, i32) {
533        let (cur_row, prev_row) = if y & 1 != 0 {
534            (0, xsize + 2)
535        } else {
536            (xsize + 2, 0)
537        };
538        let pos_n = prev_row + x;
539        let pos_ne = if x < xsize - 1 { pos_n + 1 } else { pos_n };
540        let pos_nw = if x > 0 { pos_n - 1 } else { pos_n };
541
542        // Get errors at neighboring positions
543        let errors_n = self.get_errors_at_pos(pos_n);
544        let errors_ne = self.get_errors_at_pos(pos_ne);
545        let errors_nw = self.get_errors_at_pos(pos_nw);
546
547        // Compute weights from errors
548        let mut weights = [0u32; NUM_WP_PREDICTORS];
549        for i in 0..NUM_WP_PREDICTORS {
550            weights[i] = error_weight(
551                errors_n[i]
552                    .wrapping_add(errors_ne[i])
553                    .wrapping_add(errors_nw[i]),
554                self.params.w(i),
555            );
556        }
557
558        // Convert neighbors to higher precision
559        let n = add_bits(neighbors.n);
560        let w = add_bits(neighbors.w);
561        let ne = add_bits(neighbors.ne);
562        let nw = add_bits(neighbors.nw);
563        let nn = add_bits(neighbors.nn);
564
565        // Get transmission errors from neighboring positions
566        let te_w = if x == 0 {
567            0
568        } else {
569            self.error[cur_row + x - 1] as i64
570        };
571        let te_n = self.error[pos_n] as i64;
572        let te_nw = self.error[pos_nw] as i64;
573        let sum_wn = te_n + te_w;
574        let te_ne = self.error[pos_ne] as i64;
575
576        // Find max absolute error for property
577        let mut p = te_w;
578        if te_n.abs() > p.abs() {
579            p = te_n;
580        }
581        if te_nw.abs() > p.abs() {
582            p = te_nw;
583        }
584        if te_ne.abs() > p.abs() {
585            p = te_ne;
586        }
587
588        // Compute 4 sub-predictions with corrections
589        self.prediction[0] = w + ne - n;
590        self.prediction[1] = n - (((sum_wn + te_ne) * self.params.p1c as i64) >> 5);
591        self.prediction[2] = w - (((sum_wn + te_nw) * self.params.p2c as i64) >> 5);
592        self.prediction[3] = n
593            - ((te_nw * (self.params.p3ca as i64)
594                + (te_n * (self.params.p3cb as i64))
595                + (te_ne * (self.params.p3cc as i64))
596                + ((nn - n) * (self.params.p3cd as i64))
597                + ((nw - w) * (self.params.p3ce as i64)))
598                >> 5);
599
600        // Compute weighted average
601        self.pred = weighted_average(&self.prediction, &mut weights);
602
603        // Apply clamping when errors have consistent signs
604        if ((te_n ^ te_w) | (te_n ^ te_nw)) <= 0 {
605            let mx = w.max(ne.max(n));
606            let mn = w.min(ne.min(n));
607            self.pred = mn.max(mx.min(self.pred));
608        }
609
610        ((self.pred + PREDICTION_ROUND) >> PRED_EXTRA_BITS, p as i32)
611    }
612
613    /// Update error buffers after seeing actual value.
614    #[inline]
615    pub fn update_errors(&mut self, actual: i32, x: usize, y: usize, xsize: usize) {
616        let (cur_row, prev_row) = if y & 1 != 0 {
617            (0, xsize + 2)
618        } else {
619            (xsize + 2, 0)
620        };
621        let val = add_bits(actual);
622        self.error[cur_row + x] = (self.pred - val) as i32;
623
624        // Compute errors for all predictors
625        let mut errs = [0u32; NUM_WP_PREDICTORS];
626        for (err, &pred) in errs.iter_mut().zip(self.prediction.iter()) {
627            *err = (((pred - val).abs() + PREDICTION_ROUND) >> PRED_EXTRA_BITS) as u32;
628        }
629
630        // Write to current position
631        *self.get_errors_at_pos_mut(cur_row + x) = errs;
632
633        // Update previous row position
634        let prev_errors = self.get_errors_at_pos_mut(prev_row + x + 1);
635        for i in 0..NUM_WP_PREDICTORS {
636            prev_errors[i] = prev_errors[i].wrapping_add(errs[i]);
637        }
638    }
639
640    /// Simple predict method for compatibility.
641    pub fn predict(&mut self, x: usize, y: usize, xsize: usize, neighbors: &Neighbors) -> i32 {
642        let (pred, _) = self.predict_and_property(x, y, xsize, neighbors);
643        pred as i32
644    }
645}
646
647impl Default for WeightedPredictorState {
648    fn default() -> Self {
649        Self::with_defaults(256)
650    }
651}
652
/// Maps a signed value to an unsigned one for entropy coding (zig-zag).
///
/// Non-negative `v` becomes `2 * v`, negative `v` becomes `-2 * v - 1`, so
/// values of small magnitude get small codes: 0, -1, 1, -2, 2 map to
/// 0, 1, 2, 3, 4.
///
/// The mapping is total over `i32`; in particular `i32::MIN` maps to
/// `u32::MAX` instead of overflowing on negation.
#[inline]
pub fn pack_signed(value: i32) -> u32 {
    // `value >> 31` is an arithmetic shift: all ones for negative inputs,
    // zero otherwise. Xor-ing it in flips the doubled magnitude for negatives.
    let sign_mask = (value >> 31) as u32;
    ((value as u32) << 1) ^ sign_mask
}
663
/// Inverse of the zig-zag packing: even codes are non-negative values,
/// odd codes are negative ones.
#[inline]
pub fn unpack_signed(value: u32) -> i32 {
    let half = (value >> 1) as i32;
    match value & 1 {
        0 => half,
        // Bitwise NOT equals `-half - 1` in two's complement.
        _ => !half,
    }
}
673
674/// Estimate the total encoding cost of using a WP parameter set on the given channels.
675///
676/// Runs the weighted predictor over every pixel, computes residuals, and
677/// estimates Shannon entropy + HybridUint extra bits as a cost proxy.
678/// Matching libjxl's EstimateWPCost (enc_modular.cc:238-287).
679pub fn estimate_wp_cost(channels: &[super::Channel], params: &WeightedPredictorParams) -> f64 {
680    // Use 256-bin histogram for entropy estimation
681    const NUM_BINS: usize = 256;
682    let mut histogram = [0u32; NUM_BINS];
683    let mut total_extra_bits = 0u64;
684    let mut total_samples = 0u64;
685
686    for channel in channels {
687        let width = channel.width();
688        let height = channel.height();
689        if width == 0 || height == 0 {
690            continue;
691        }
692
693        let mut wp_state = WeightedPredictorState::new(params, width);
694
695        for y in 0..height {
696            for x in 0..width {
697                let pixel = channel.get(x, y);
698                let neighbors = Neighbors::gather(channel, x, y);
699                let prediction = wp_state.predict(x, y, width, &neighbors);
700
701                let residual = pixel - prediction;
702                let packed = pack_signed(residual);
703
704                // Bin the packed residual for histogram
705                let bin = if packed < NUM_BINS as u32 {
706                    packed as usize
707                } else {
708                    // For large residuals, count extra bits needed
709                    let bits = 32 - packed.leading_zeros();
710                    total_extra_bits += bits as u64;
711                    NUM_BINS - 1
712                };
713                histogram[bin] += 1;
714                total_samples += 1;
715
716                wp_state.update_errors(pixel, x, y, width);
717            }
718        }
719    }
720
721    if total_samples == 0 {
722        return 0.0;
723    }
724
725    // Estimate Shannon entropy from histogram
726    let total_f = total_samples as f64;
727    let mut entropy = 0.0f64;
728    for &count in &histogram {
729        if count > 0 {
730            let p = count as f64 / total_f;
731            entropy -= p * jxl_simd::fast_log2f(p as f32) as f64;
732        }
733    }
734
735    // Total cost = entropy bits + extra bits for large values
736    entropy * total_f + total_extra_bits as f64
737}
738
739/// Find the best WP parameter set by trying `num_sets` modes (0..num_sets).
740///
741/// Returns the best `WeightedPredictorParams` and whether it differs from default.
742/// At effort 8 (kKitten): `num_sets=2` (modes 0-1).
743/// At effort 9+ (kTortoise): `num_sets=5` (modes 0-4).
744pub fn find_best_wp_params(channels: &[super::Channel], num_sets: u8) -> WeightedPredictorParams {
745    if num_sets <= 1 {
746        return WeightedPredictorParams::default();
747    }
748
749    let mut best_cost = f64::MAX;
750    let mut best_mode = 0u8;
751
752    for mode in 0..num_sets.min(5) {
753        let params = WeightedPredictorParams::for_mode(mode);
754        let cost = estimate_wp_cost(channels, &params);
755        if cost < best_cost {
756            best_cost = cost;
757            best_mode = mode;
758        }
759    }
760
761    WeightedPredictorParams::for_mode(best_mode)
762}
763
#[cfg(test)]
mod tests {
    use super::*;

    /// Neighborhood in which every sample equals `value`.
    fn flat_neighbors(value: i32) -> Neighbors {
        Neighbors {
            n: value,
            w: value,
            nw: value,
            ne: value,
            nn: value,
            ww: value,
            nee: value,
        }
    }

    #[test]
    fn test_predictors() {
        // 4x4 ramp where sample (x, y) holds x + 4 * y:
        //    0  1   2   3
        //    4  5   6   7
        //    8  9 [10] 11
        //   12 13  14  15
        let mut channel = Channel::new(4, 4).unwrap();
        for y in 0..4 {
            for x in 0..4 {
                channel.set(x, y, (x + y * 4) as i32);
            }
        }

        // Neighborhood of the bracketed sample at (2, 2).
        let neighbors = Neighbors::gather(&channel, 2, 2);
        assert_eq!(neighbors.n, 6); // Top
        assert_eq!(neighbors.w, 9); // Left
        assert_eq!(neighbors.nw, 5); // Top-left
        assert_eq!(neighbors.ne, 7); // Top-right

        assert_eq!(Predictor::Zero.predict_from_neighbors(&neighbors), 0);
        assert_eq!(Predictor::Left.predict_from_neighbors(&neighbors), 9);
        assert_eq!(Predictor::Top.predict_from_neighbors(&neighbors), 6);

        // Gradient: 9 + 6 - 5 = 10, clamped to [6, 9] = 9
        assert_eq!(Predictor::Gradient.predict_from_neighbors(&neighbors), 9);
    }

    #[test]
    fn test_pack_signed() {
        for (input, packed) in [(0, 0), (1, 2), (-1, 1), (2, 4), (-2, 3)] {
            assert_eq!(pack_signed(input), packed);
        }
    }

    #[test]
    fn test_unpack_signed() {
        for (packed, output) in [(0, 0), (1, -1), (2, 1), (3, -2), (4, 2)] {
            assert_eq!(unpack_signed(packed), output);
        }
    }

    #[test]
    fn test_pack_unpack_roundtrip() {
        for value in -1000..=1000 {
            assert_eq!(unpack_signed(pack_signed(value)), value);
        }
    }

    #[test]
    fn test_weighted_predictor_params_default() {
        let params = WeightedPredictorParams::default();
        assert_eq!(params.p1c, 16);
        assert_eq!(params.p2c, 10);
        assert_eq!(params.w0, 0xd);
        assert!(params.is_default());
    }

    #[test]
    fn test_weighted_predictor_state() {
        let xsize = 8;
        let mut wp = WeightedPredictorState::with_defaults(xsize);

        // A flat neighborhood should be predicted (almost) exactly.
        let neighbors = flat_neighbors(100);
        let (pred, _prop) = wp.predict_and_property(4, 2, xsize, &neighbors);
        assert!((pred - 100).abs() <= 2);

        // Feeding back the true value must not panic.
        wp.update_errors(100, 4, 2, xsize);
    }

    #[test]
    fn test_weighted_predictor_adapts() {
        let xsize = 8;
        let mut wp = WeightedPredictorState::with_defaults(xsize);
        let ramp = |i: usize| (i * 10) as i32;

        // Walk one row of a horizontal ramp; this is a smoke test that only
        // checks that predict + update never panics.
        for x in 0..xsize {
            let one_back = x.checked_sub(1).map_or(0, ramp);
            let two_back = x.checked_sub(2).map_or(0, ramp);
            let neighbors = Neighbors {
                n: one_back,
                w: one_back,
                nw: two_back,
                ne: ramp(x),
                nn: 0,
                ww: two_back,
                nee: 0,
            };

            let (_pred, _prop) = wp.predict_and_property(x, 1, xsize, &neighbors);
            wp.update_errors(ramp(x), x, 1, xsize);
        }
    }

    /// Reproduce jxl-rs golden-number test to verify bit-exactness.
    #[test]
    fn test_wp_matches_jxl_rs_golden() {
        /// Minimal multiplicative congruential generator used by the
        /// reference test. The order of `next` calls below is significant.
        struct SimpleRandom {
            out: i64,
        }
        impl SimpleRandom {
            fn new() -> Self {
                Self { out: 1 }
            }
            fn next(&mut self) -> i64 {
                self.out = self.out * 48271 % 0x7fffffff;
                self.out
            }
        }

        let mut rng = SimpleRandom::new();
        // Field initializers run in written order, fixing the RNG sequence.
        let params = WeightedPredictorParams {
            p1c: rng.next() as u32 % 32,
            p2c: rng.next() as u32 % 32,
            p3ca: rng.next() as u32 % 32,
            p3cb: rng.next() as u32 % 32,
            p3cc: rng.next() as u32 % 32,
            p3cd: rng.next() as u32 % 32,
            p3ce: rng.next() as u32 % 32,
            w0: rng.next() as u32 % 16,
            w1: rng.next() as u32 % 16,
            w2: rng.next() as u32 % 16,
            w3: rng.next() as u32 % 16,
        };
        let xsize = 8;
        let ysize = 8;
        let mut state = WeightedPredictorState::new(&params, xsize);

        // One predict + update round at a random position.
        let step = |rng: &mut SimpleRandom, state: &mut WeightedPredictorState| -> (i64, i32) {
            let x = rng.next() as usize % xsize;
            let y = rng.next() as usize % ysize;
            let neighbors = Neighbors {
                n: rng.next() as i32 % 256,  // top
                w: rng.next() as i32 % 256,  // left
                ne: rng.next() as i32 % 256, // topright
                nw: rng.next() as i32 % 256, // topleft
                nn: rng.next() as i32 % 256, // toptop
                ww: 0,
                nee: 0,
            };
            let outcome = state.predict_and_property(x, y, xsize, &neighbors);
            state.update_errors((rng.next() % 256) as i32, x, y, xsize);
            outcome
        };

        // Golden numbers from libjxl (verified in jxl-rs test)
        let golden = [(135, 0), (110, -60), (165, 0), (153, -60)];
        for (index, expected) in golden.into_iter().enumerate() {
            assert_eq!(step(&mut rng, &mut state), expected, "step {}", index + 1);
        }
    }
}