1use super::channel::Channel;
11
/// Pixel predictors, identified by their `u8` discriminant.
///
/// Each variant predicts a sample from already-known neighbors (see [`Neighbors`]);
/// the exact formulas live in `Predictor::predict_from_neighbors`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum Predictor {
    /// Always predicts 0. This is the `Default` variant.
    #[default]
    Zero = 0,
    /// The west (left) neighbor.
    Left = 1,
    /// The north (top) neighbor.
    Top = 2,
    /// `(W + N) / 2` with integer division.
    Average0 = 3,
    /// Picks either W or N depending on how far each of them is from NW.
    Select = 4,
    /// `W + N - NW`, clamped to the range spanned by W and N.
    Gradient = 5,
    /// Weighted predictor. The stateless evaluation in this enum uses the clamped
    /// gradient; see `WeightedPredictorState` for the stateful weighted predictor.
    Weighted = 6,
    /// The north-east neighbor.
    TopRight = 7,
    /// The north-west neighbor.
    TopLeft = 8,
    /// The sample two positions to the left.
    LeftLeft = 9,
    /// Mean of `Average0` and the unclamped gradient `W + N - NW`.
    Average1 = 10,
    /// Mean of `Average0` and W.
    Average2 = 11,
    /// Mean of `Average0` and N.
    Average3 = 12,
    /// Mean of N and NE.
    Average4 = 13,
}
46
47impl Predictor {
48 pub const NUM_SIMPLE: usize = 14;
50
51 pub fn all_simple() -> &'static [Predictor] {
53 &[
54 Predictor::Zero,
55 Predictor::Left,
56 Predictor::Top,
57 Predictor::Average0,
58 Predictor::Select,
59 Predictor::Gradient,
60 Predictor::Weighted,
61 Predictor::TopRight,
62 Predictor::TopLeft,
63 Predictor::LeftLeft,
64 Predictor::Average1,
65 Predictor::Average2,
66 Predictor::Average3,
67 Predictor::Average4,
68 ]
69 }
70
71 #[inline]
73 pub fn predict(self, channel: &Channel, x: usize, y: usize) -> i32 {
74 let neighbors = Neighbors::gather(channel, x, y);
75 self.predict_from_neighbors(&neighbors)
76 }
77
78 #[inline]
80 pub fn predict_from_neighbors(self, n: &Neighbors) -> i32 {
81 match self {
82 Predictor::Zero => 0,
83 Predictor::Left => n.w,
84 Predictor::Top => n.n,
85 Predictor::Average0 => (n.w + n.n) / 2,
86 Predictor::Select => {
87 if n.w.abs_diff(n.nw) < n.n.abs_diff(n.nw) {
89 n.n
90 } else {
91 n.w
92 }
93 }
94 Predictor::Gradient => {
95 let gradient = n.w.saturating_add(n.n).saturating_sub(n.nw);
97 gradient.clamp(n.w.min(n.n), n.w.max(n.n))
98 }
99 Predictor::Weighted => {
100 let gradient = n.w.saturating_add(n.n).saturating_sub(n.nw);
103 gradient.clamp(n.w.min(n.n), n.w.max(n.n))
104 }
105 Predictor::TopRight => n.ne,
106 Predictor::TopLeft => n.nw,
107 Predictor::LeftLeft => n.ww,
108 Predictor::Average1 => {
109 let avg = (n.w + n.n) / 2;
110 let grad = n.w.saturating_add(n.n).saturating_sub(n.nw);
111 (avg + grad) / 2
112 }
113 Predictor::Average2 => {
114 let avg = (n.w + n.n) / 2;
115 (avg + n.w) / 2
116 }
117 Predictor::Average3 => {
118 let avg = (n.w + n.n) / 2;
119 (avg + n.n) / 2
120 }
121 Predictor::Average4 => (n.n + n.ne) / 2,
122 }
123 }
124}
125
/// Neighborhood of a sample at `(x, y)`, as consumed by the predictors.
///
/// Neighbors that fall outside the image are replaced by fallback values chosen by the
/// gathering functions (`Neighbors::gather`, `Neighbors::gather_fast`).
#[derive(Debug, Clone, Copy, Default)]
pub struct Neighbors {
    /// North: the sample at `(x, y - 1)`.
    pub n: i32,
    /// West: the sample at `(x - 1, y)`.
    pub w: i32,
    /// North-west: the sample at `(x - 1, y - 1)`.
    pub nw: i32,
    /// North-east: the sample at `(x + 1, y - 1)`.
    pub ne: i32,
    /// North-north: the sample at `(x, y - 2)`.
    pub nn: i32,
    /// West-west: the sample at `(x - 2, y)`.
    pub ww: i32,
}
142
143impl Neighbors {
144 #[inline]
154 pub fn gather(channel: &Channel, x: usize, y: usize) -> Self {
155 let width = channel.width();
156
157 let w = if x > 0 {
158 channel.get(x - 1, y)
159 } else if y > 0 {
160 channel.get(0, y - 1)
161 } else {
162 0
163 };
164
165 let n = if y > 0 { channel.get(x, y - 1) } else { w };
166
167 let nw = if x > 0 && y > 0 {
168 channel.get(x - 1, y - 1)
169 } else {
170 w
171 };
172
173 let ne = if x + 1 < width && y > 0 {
174 channel.get(x + 1, y - 1)
175 } else {
176 n
177 };
178
179 let ww = if x > 1 { channel.get(x - 2, y) } else { w };
180
181 let nn = if y > 1 { channel.get(x, y - 2) } else { n };
182
183 Self {
184 n,
185 w,
186 nw,
187 ne,
188 nn,
189 ww,
190 }
191 }
192
193 #[inline]
195 pub fn gather_fast(
196 row: &[i32],
197 prev_row: Option<&[i32]>,
198 prev_prev_row: Option<&[i32]>,
199 x: usize,
200 _width: usize,
201 ) -> Self {
202 let w = if x > 0 {
203 row[x - 1]
204 } else if let Some(prev) = prev_row {
205 prev[0]
206 } else {
207 0
208 };
209
210 let n = if let Some(prev) = prev_row {
211 prev[x]
212 } else {
213 w
214 };
215
216 let nw = if x > 0 {
217 if let Some(prev) = prev_row {
218 prev[x - 1]
219 } else {
220 w
221 }
222 } else {
223 w
224 };
225
226 let ne = if let Some(prev) = prev_row {
227 if x + 1 < prev.len() { prev[x + 1] } else { n }
228 } else {
229 n
230 };
231
232 let ww = if x > 1 { row[x - 2] } else { w };
233
234 let nn = if let Some(pp) = prev_prev_row {
235 pp[x]
236 } else {
237 n
238 };
239
240 Self {
241 n,
242 w,
243 nw,
244 ne,
245 nn,
246 ww,
247 }
248 }
249}
250
/// Number of sub-predictors blended by the weighted predictor.
const NUM_WP_PREDICTORS: usize = 4;
/// Fractional bits carried by the weighted predictor's fixed-point values.
const PRED_EXTRA_BITS: i64 = 3;
/// Bias added before dropping the `PRED_EXTRA_BITS` fractional bits:
/// `2^(PRED_EXTRA_BITS - 1) - 1`, i.e. 3.
const PREDICTION_ROUND: i64 = (1 << (PRED_EXTRA_BITS - 1)) - 1;
257
/// Fixed-point reciprocal table: `DIVLOOKUP[i] == floor(2^24 / (i + 1))`.
///
/// The weighted predictor multiplies by an entry and shifts right instead of dividing
/// by `i + 1`. The table is evaluated at compile time.
const DIVLOOKUP: [u32; 64] = {
    let mut table = [0u32; 64];
    let mut i = 0;
    while i < 64 {
        table[i] = (1u32 << 24) / (i as u32 + 1);
        i += 1;
    }
    table
};
268
/// Parameters of the weighted predictor.
///
/// `p1c`, `p2c` and `p3ca`..`p3ce` are the coefficients of the correction terms of
/// sub-predictors 1, 2 and 3 (each applied as `(... * coeff) >> 5`). `w0`..`w3` are the
/// maximum weights of sub-predictors 0..3.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WeightedPredictorParams {
    /// Correction coefficient of sub-predictor 1.
    pub p1c: u32,
    /// Correction coefficient of sub-predictor 2.
    pub p2c: u32,
    /// Sub-predictor 3: coefficient of the NW error.
    pub p3ca: u32,
    /// Sub-predictor 3: coefficient of the N error.
    pub p3cb: u32,
    /// Sub-predictor 3: coefficient of the NE error.
    pub p3cc: u32,
    /// Sub-predictor 3: coefficient of the `NN - N` difference.
    pub p3cd: u32,
    /// Sub-predictor 3: coefficient of the `NW - W` difference.
    pub p3ce: u32,
    /// Maximum weight of sub-predictor 0.
    pub w0: u32,
    /// Maximum weight of sub-predictor 1.
    pub w1: u32,
    /// Maximum weight of sub-predictor 2.
    pub w2: u32,
    /// Maximum weight of sub-predictor 3.
    pub w3: u32,
}

impl Default for WeightedPredictorParams {
    fn default() -> Self {
        Self {
            p1c: 16,
            p2c: 10,
            p3ca: 7,
            p3cb: 7,
            p3cc: 7,
            p3cd: 0,
            p3ce: 0,
            w0: 0xd,
            w1: 0xc,
            w2: 0xc,
            w3: 0xc,
        }
    }
}

impl WeightedPredictorParams {
    /// Returns the maximum weight of sub-predictor `i` (`w0`..`w3`).
    ///
    /// # Panics
    ///
    /// Panics if `i > 3`.
    pub fn w(&self, i: usize) -> u32 {
        match i {
            0 => self.w0,
            1 => self.w1,
            2 => self.w2,
            3 => self.w3,
            _ => panic!("Invalid weight index"),
        }
    }

    /// Returns `true` when every field equals its [`Default`] value.
    pub fn is_default(&self) -> bool {
        *self == Self::default()
    }
}
341
/// Index of the most significant set bit of `x`, i.e. `floor(log2(x))`.
///
/// `x` must be non-zero; for zero the subtraction underflows (a panic in debug builds).
#[inline]
fn floor_log2_nonzero(x: u64) -> u32 {
    let top_bit_index = u64::BITS - 1;
    top_bit_index - x.leading_zeros()
}
347
348#[inline]
350fn add_bits(x: i32) -> i64 {
351 (x as i64) << PRED_EXTRA_BITS
352}
353
354#[inline]
356fn error_weight(x: u32, maxweight: u32) -> u32 {
357 let shift = floor_log2_nonzero(x as u64 + 1) as i32 - 5;
358 if shift < 0 {
359 4u32 + maxweight * DIVLOOKUP[x as usize & 63]
360 } else {
361 4u32 + ((maxweight * DIVLOOKUP[(x as usize >> shift) & 63]) >> shift)
362 }
363}
364
365fn weighted_average(
367 pixels: &[i64; NUM_WP_PREDICTORS],
368 weights: &mut [u32; NUM_WP_PREDICTORS],
369) -> i64 {
370 let log_weight = floor_log2_nonzero(weights.iter().fold(0u64, |sum, el| sum + *el as u64));
371 let weight_sum = weights.iter_mut().fold(0, |sum, el| {
372 *el >>= log_weight - 4;
373 sum + *el
374 });
375 let sum = weights
376 .iter()
377 .enumerate()
378 .fold(((weight_sum >> 1) - 1) as i64, |sum, (i, weight)| {
379 sum + pixels[i] * *weight as i64
380 });
381 (sum * DIVLOOKUP[(weight_sum - 1) as usize] as i64) >> 24
382}
383
/// Mutable state of the weighted predictor.
///
/// It holds two rows of error history, each `xsize + 2` slots wide. The row used as
/// "current" alternates with the parity of `y`. For every sample, call
/// `predict_and_property` and then `update_errors` with the actual value.
#[derive(Debug)]
pub struct WeightedPredictorState {
    /// Outputs of the four sub-predictors for the current sample, in fixed point
    /// (sample value shifted left by `PRED_EXTRA_BITS`).
    prediction: [i64; NUM_WP_PREDICTORS],
    /// Final blended (and possibly clamped) prediction for the current sample, in fixed point.
    pred: i64,
    /// Absolute error of each sub-predictor, rounded back to sample precision;
    /// `NUM_WP_PREDICTORS` consecutive entries per slot.
    pred_errors_buffer: Vec<u32>,
    /// Signed error `pred - actual` of the final prediction per slot, in fixed point.
    error: Vec<i32>,
    /// Parameters this state was created with.
    params: WeightedPredictorParams,
}
400
impl WeightedPredictorState {
    /// Creates a state for rows of `xsize` samples using `params`; all errors start at 0.
    pub fn new(params: &WeightedPredictorParams, xsize: usize) -> Self {
        // Two rows (selected by the parity of y) of `xsize + 2` slots each.
        let num_errors = (xsize + 2) * 2;
        Self {
            prediction: [0; NUM_WP_PREDICTORS],
            pred: 0,
            pred_errors_buffer: vec![0; num_errors * NUM_WP_PREDICTORS],
            error: vec![0; num_errors],
            params: *params,
        }
    }

    /// Creates a state using `WeightedPredictorParams::default()`.
    pub fn with_defaults(xsize: usize) -> Self {
        Self::new(&WeightedPredictorParams::default(), xsize)
    }

    /// Returns the per-sub-predictor errors stored in slot `pos`.
    ///
    /// Panics if `pos` is outside the two-row buffer.
    #[inline(always)]
    fn get_errors_at_pos(&self, pos: usize) -> &[u32; NUM_WP_PREDICTORS] {
        // Slot `pos` occupies NUM_WP_PREDICTORS consecutive entries.
        let start = pos * NUM_WP_PREDICTORS;
        self.pred_errors_buffer[start..start + NUM_WP_PREDICTORS]
            .try_into()
            .unwrap()
    }

    /// Mutable counterpart of `get_errors_at_pos`.
    #[inline(always)]
    fn get_errors_at_pos_mut(&mut self, pos: usize) -> &mut [u32; NUM_WP_PREDICTORS] {
        let start = pos * NUM_WP_PREDICTORS;
        (&mut self.pred_errors_buffer[start..start + NUM_WP_PREDICTORS])
            .try_into()
            .unwrap()
    }

    /// Predicts the sample at `(x, y)` and returns `(prediction, property)`.
    ///
    /// `prediction` is rounded back to sample precision. `property` is whichever of the
    /// W, N, NW, NE final-prediction errors has the largest magnitude (earlier in that order
    /// wins ties), in fixed point. NOTE(review): presumably used by callers as a context
    /// property — confirm.
    ///
    /// The sub-predictor outputs and the blended prediction are kept in `self`, so this
    /// must be followed by `update_errors` for the same `(x, y)`. `xsize` must be at least 1
    /// (`xsize - 1` underflows otherwise).
    #[inline]
    pub fn predict_and_property(
        &mut self,
        x: usize,
        y: usize,
        xsize: usize,
        neighbors: &Neighbors,
    ) -> (i64, i32) {
        // Odd rows write slots [0, xsize + 2) and read the other half; even rows the reverse.
        let (cur_row, prev_row) = if y & 1 != 0 {
            (0, xsize + 2)
        } else {
            (xsize + 2, 0)
        };
        let pos_n = prev_row + x;
        // At the row edges, NE / NW reuse the N slot.
        let pos_ne = if x < xsize - 1 { pos_n + 1 } else { pos_n };
        let pos_nw = if x > 0 { pos_n - 1 } else { pos_n };

        let errors_n = self.get_errors_at_pos(pos_n);
        let errors_ne = self.get_errors_at_pos(pos_ne);
        let errors_nw = self.get_errors_at_pos(pos_nw);

        // Each sub-predictor is weighted by its accumulated error in the N, NE and NW slots.
        // (`update_errors` also folds the W sample's error into the N slot.)
        let mut weights = [0u32; NUM_WP_PREDICTORS];
        for i in 0..NUM_WP_PREDICTORS {
            weights[i] = error_weight(
                errors_n[i]
                    .wrapping_add(errors_ne[i])
                    .wrapping_add(errors_nw[i]),
                self.params.w(i),
            );
        }

        // Neighbor samples in fixed point.
        let n = add_bits(neighbors.n);
        let w = add_bits(neighbors.w);
        let ne = add_bits(neighbors.ne);
        let nw = add_bits(neighbors.nw);
        let nn = add_bits(neighbors.nn);

        // Signed final-prediction errors of the neighbors; W is 0 at the start of a row.
        let te_w = if x == 0 {
            0
        } else {
            self.error[cur_row + x - 1] as i64
        };
        let te_n = self.error[pos_n] as i64;
        let te_nw = self.error[pos_nw] as i64;
        let sum_wn = te_n + te_w;
        let te_ne = self.error[pos_ne] as i64;

        // Property: the largest-magnitude neighbor error (strict `>` keeps the earlier one).
        let mut p = te_w;
        if te_n.abs() > p.abs() {
            p = te_n;
        }
        if te_nw.abs() > p.abs() {
            p = te_nw;
        }
        if te_ne.abs() > p.abs() {
            p = te_ne;
        }

        // Sub-predictor 0: W + NE - N.
        self.prediction[0] = w + ne - n;
        // Sub-predictors 1 and 2: N / W corrected by neighbor errors scaled by p1c / p2c (/32).
        self.prediction[1] = n - (((sum_wn + te_ne) * self.params.p1c as i64) >> 5);
        self.prediction[2] = w - (((sum_wn + te_nw) * self.params.p2c as i64) >> 5);
        // Sub-predictor 3: N corrected by a p3c*-weighted mix of errors and gradients (/32).
        self.prediction[3] = n
            - ((te_nw * (self.params.p3ca as i64)
                + (te_n * (self.params.p3cb as i64))
                + (te_ne * (self.params.p3cc as i64))
                + ((nn - n) * (self.params.p3cd as i64))
                + ((nw - w) * (self.params.p3ce as i64)))
                >> 5);

        self.pred = weighted_average(&self.prediction, &mut weights);

        // `a ^ b` is negative exactly when a and b differ in sign bit, so this holds when
        // te_n's sign differs from te_w's or te_nw's, or when all three are equal. In that
        // case the prediction is clamped into the range of W, NE and N.
        if ((te_n ^ te_w) | (te_n ^ te_nw)) <= 0 {
            let mx = w.max(ne.max(n));
            let mn = w.min(ne.min(n));
            self.pred = mn.max(mx.min(self.pred));
        }

        // Drop the fractional bits with the PREDICTION_ROUND bias.
        ((self.pred + PREDICTION_ROUND) >> PRED_EXTRA_BITS, p as i32)
    }

    /// Records the errors made when predicting `actual` at `(x, y)`.
    ///
    /// Reads the predictions stored by the preceding `predict_and_property` call for the
    /// same position.
    #[inline]
    pub fn update_errors(&mut self, actual: i32, x: usize, y: usize, xsize: usize) {
        let (cur_row, prev_row) = if y & 1 != 0 {
            (0, xsize + 2)
        } else {
            (xsize + 2, 0)
        };
        let val = add_bits(actual);
        // Signed fixed-point error of the final prediction.
        self.error[cur_row + x] = (self.pred - val) as i32;

        // Absolute error of each sub-predictor, rounded back to sample precision.
        let mut errs = [0u32; NUM_WP_PREDICTORS];
        for (err, &pred) in errs.iter_mut().zip(self.prediction.iter()) {
            *err = (((pred - val).abs() + PREDICTION_ROUND) >> PRED_EXTRA_BITS) as u32;
        }

        *self.get_errors_at_pos_mut(cur_row + x) = errs;

        // Slot prev_row + x + 1 is the N slot of the next sample in this row. Adding this
        // sample's errors there makes that sample's error sum include its W neighbor too.
        let prev_errors = self.get_errors_at_pos_mut(prev_row + x + 1);
        for i in 0..NUM_WP_PREDICTORS {
            prev_errors[i] = prev_errors[i].wrapping_add(errs[i]);
        }
    }

    /// Like `predict_and_property`, but discards the property and returns the
    /// prediction cast to `i32`.
    pub fn predict(&mut self, x: usize, y: usize, xsize: usize, neighbors: &Neighbors) -> i32 {
        let (pred, _) = self.predict_and_property(x, y, xsize, neighbors);
        pred as i32
    }
}
560
561impl Default for WeightedPredictorState {
562 fn default() -> Self {
563 Self::with_defaults(256)
564 }
565}
566
/// Maps a signed integer onto the unsigned integers by interleaving signs ("zig-zag"):
/// `0 → 0, -1 → 1, 1 → 2, -2 → 3, 2 → 4, …`.
///
/// This is the inverse of [`unpack_signed`]. It is defined for the whole `i32` range,
/// including `i32::MIN`, which maps to `u32::MAX`.
#[inline]
pub fn pack_signed(value: i32) -> u32 {
    // `value << 1` doubles the magnitude bits. XOR-ing with the sign mask (all ones for
    // negative values) turns a negative v into 2|v| - 1 without negating it, which would
    // overflow for i32::MIN.
    ((value as u32) << 1) ^ ((value >> 31) as u32)
}
577
/// Inverse of [`pack_signed`]: even codes become non-negative values and odd codes
/// become negative values (`0 → 0, 1 → -1, 2 → 1, 3 → -2, …`).
#[inline]
pub fn unpack_signed(value: u32) -> i32 {
    let magnitude = (value >> 1) as i32;
    // All ones for odd codes, zero for even codes.
    let sign_mask = -((value & 1) as i32);
    magnitude ^ sign_mask
}
587
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 4x4 channel whose sample at `(x, y)` is its raster index `x + 4 * y`.
    fn raster_channel() -> Channel {
        let mut channel = Channel::new(4, 4).unwrap();
        for y in 0..4 {
            for x in 0..4 {
                channel.set(x, y, (x + y * 4) as i32);
            }
        }
        channel
    }

    #[test]
    fn test_predictors() {
        let channel = raster_channel();

        let neighbors = Neighbors::gather(&channel, 2, 2);
        assert_eq!(neighbors.n, 6);
        assert_eq!(neighbors.w, 9);
        assert_eq!(neighbors.nw, 5);
        assert_eq!(neighbors.ne, 7);
        assert_eq!(neighbors.nn, 2);
        assert_eq!(neighbors.ww, 8);

        assert_eq!(Predictor::Zero.predict_from_neighbors(&neighbors), 0);
        assert_eq!(Predictor::Left.predict_from_neighbors(&neighbors), 9);
        assert_eq!(Predictor::Top.predict_from_neighbors(&neighbors), 6);
        assert_eq!(Predictor::Average0.predict_from_neighbors(&neighbors), 7);
        // |N - NW| = 1 < |W - NW| = 4, so Select picks W.
        assert_eq!(Predictor::Select.predict_from_neighbors(&neighbors), 9);
        // W + N - NW = 10, clamped to [min(W, N), max(W, N)] = [6, 9].
        assert_eq!(Predictor::Gradient.predict_from_neighbors(&neighbors), 9);
        assert_eq!(Predictor::TopRight.predict_from_neighbors(&neighbors), 7);
        assert_eq!(Predictor::TopLeft.predict_from_neighbors(&neighbors), 5);
        assert_eq!(Predictor::LeftLeft.predict_from_neighbors(&neighbors), 8);
    }

    #[test]
    fn test_gather_edges() {
        let channel = raster_channel();

        // Origin: nothing is known, everything falls back to 0.
        let origin = Neighbors::gather(&channel, 0, 0);
        assert_eq!(
            (origin.n, origin.w, origin.nw, origin.ne, origin.nn, origin.ww),
            (0, 0, 0, 0, 0, 0)
        );

        // First column: W is the sample above; NW and WW fall back to W; NN falls back to N.
        let left_edge = Neighbors::gather(&channel, 0, 1);
        assert_eq!(
            (left_edge.n, left_edge.w, left_edge.nw, left_edge.ne, left_edge.nn, left_edge.ww),
            (0, 0, 0, 1, 0, 0)
        );

        // Last column: NE falls back to N.
        let right_edge = Neighbors::gather(&channel, 3, 1);
        assert_eq!(
            (right_edge.n, right_edge.w, right_edge.nw, right_edge.ne, right_edge.nn, right_edge.ww),
            (3, 6, 2, 3, 3, 5)
        );
    }

    #[test]
    fn test_pack_signed() {
        assert_eq!(pack_signed(0), 0);
        assert_eq!(pack_signed(1), 2);
        assert_eq!(pack_signed(-1), 1);
        assert_eq!(pack_signed(2), 4);
        assert_eq!(pack_signed(-2), 3);
    }

    #[test]
    fn test_unpack_signed() {
        assert_eq!(unpack_signed(0), 0);
        assert_eq!(unpack_signed(1), -1);
        assert_eq!(unpack_signed(2), 1);
        assert_eq!(unpack_signed(3), -2);
        assert_eq!(unpack_signed(4), 2);
    }

    #[test]
    fn test_pack_unpack_roundtrip() {
        for i in -1000..=1000 {
            assert_eq!(unpack_signed(pack_signed(i)), i);
        }
    }

    #[test]
    fn test_weighted_predictor_params_default() {
        let params = WeightedPredictorParams::default();
        assert_eq!(params.p1c, 16);
        assert_eq!(params.p2c, 10);
        assert_eq!(params.w0, 0xd);
        assert!(params.is_default());
    }

    #[test]
    fn test_weighted_predictor_state() {
        let xsize = 8;
        let mut wp = WeightedPredictorState::with_defaults(xsize);

        let neighbors = Neighbors {
            n: 100,
            w: 100,
            nw: 100,
            ne: 100,
            nn: 100,
            ww: 100,
        };

        // With a fresh state all neighbor errors are 0, so the prediction is clamped to the
        // (single-valued) W/N/NE range and the property is 0.
        let (pred, prop) = wp.predict_and_property(4, 2, xsize, &neighbors);
        assert_eq!(pred, 100);
        assert_eq!(prop, 0);

        wp.update_errors(100, 4, 2, xsize);
    }

    /// Smoke test: feeding a ramp through one row must not panic.
    #[test]
    fn test_weighted_predictor_adapts() {
        let xsize = 8;
        let mut wp = WeightedPredictorState::with_defaults(xsize);

        for x in 0..xsize {
            let actual = (x * 10) as i32;
            let neighbors = Neighbors {
                n: if x > 0 { ((x - 1) * 10) as i32 } else { 0 },
                w: if x > 0 { ((x - 1) * 10) as i32 } else { 0 },
                nw: if x > 1 { ((x - 2) * 10) as i32 } else { 0 },
                ne: (x * 10) as i32,
                nn: 0,
                ww: if x > 1 { ((x - 2) * 10) as i32 } else { 0 },
            };

            let (_pred, _prop) = wp.predict_and_property(x, 1, xsize, &neighbors);
            wp.update_errors(actual, x, 1, xsize);
        }
    }

    /// Replays a pseudo-random sequence whose expected outputs come from jxl-rs.
    #[test]
    fn test_wp_matches_jxl_rs_golden() {
        struct SimpleRandom {
            out: i64,
        }
        impl SimpleRandom {
            fn new() -> Self {
                Self { out: 1 }
            }
            fn next(&mut self) -> i64 {
                self.out = self.out * 48271 % 0x7fffffff;
                self.out
            }
        }

        let mut rng = SimpleRandom::new();
        let params = WeightedPredictorParams {
            p1c: rng.next() as u32 % 32,
            p2c: rng.next() as u32 % 32,
            p3ca: rng.next() as u32 % 32,
            p3cb: rng.next() as u32 % 32,
            p3cc: rng.next() as u32 % 32,
            p3cd: rng.next() as u32 % 32,
            p3ce: rng.next() as u32 % 32,
            w0: rng.next() as u32 % 16,
            w1: rng.next() as u32 % 16,
            w2: rng.next() as u32 % 16,
            w3: rng.next() as u32 % 16,
        };
        let xsize = 8;
        let ysize = 8;
        let mut state = WeightedPredictorState::new(&params, xsize);

        let step = |rng: &mut SimpleRandom, state: &mut WeightedPredictorState| -> (i64, i32) {
            let x = rng.next() as usize % xsize;
            let y = rng.next() as usize % ysize;
            // Field initializers run in source order, which fixes the RNG draw order.
            let neighbors = Neighbors {
                n: rng.next() as i32 % 256,
                w: rng.next() as i32 % 256,
                ne: rng.next() as i32 % 256,
                nw: rng.next() as i32 % 256,
                nn: rng.next() as i32 % 256,
                ww: 0,
            };
            let res = state.predict_and_property(x, y, xsize, &neighbors);
            state.update_errors((rng.next() % 256) as i32, x, y, xsize);
            res
        };

        assert_eq!(step(&mut rng, &mut state), (135, 0), "step 1");
        assert_eq!(step(&mut rng, &mut state), (110, -60), "step 2");
        assert_eq!(step(&mut rng, &mut state), (165, 0), "step 3");
        assert_eq!(step(&mut rng, &mut state), (153, -60), "step 4");
    }
}