Skip to main content

jxl_encoder/modular/
tree.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! Decision tree for modular encoding context selection.
6//!
7//! The tree determines which predictor and context to use for each pixel
8//! based on properties of the neighborhood.
9
10use super::predictor::Predictor;
11
/// A node in the property decision tree.
///
/// A node is either a split (`property >= 0`) or a leaf (`property < 0`).
/// Splits use `property`, `splitval`, `lchild` and `rchild`; leaves use
/// `predictor`, `predictor_offset`, `multiplier` and `context_id`.
#[derive(Debug, Clone)]
pub struct PropertyDecisionNode {
    /// Property to split on (-1 = leaf node).
    ///
    /// Every function in this module treats any negative value as a leaf.
    pub property: i32,
    /// Split threshold value.
    pub splitval: i32,
    /// Predictor to use (for leaf nodes).
    pub predictor: Predictor,
    /// Offset for predictor (for leaf nodes).
    pub predictor_offset: i32,
    /// Multiplier for residual (for leaf nodes).
    ///
    /// Serialized as `(mul_bits + 1) << mul_log`. NOTE(review): a value of 0 is
    /// written as multiplier 1, and negative values do not survive the `as u32`
    /// conversion in `collect_tree_tokens`; keep this >= 1.
    pub multiplier: i32,
    /// Left child index (value <= splitval).
    ///
    /// The decoder's first child is the `> splitval` branch, so
    /// `collect_tree_tokens` emits `rchild` before `lchild`.
    pub lchild: usize,
    /// Right child index (value > splitval).
    pub rchild: usize,
    /// Context ID for ANS coding.
    ///
    /// Overwritten by `assign_sequential_contexts`, which numbers leaves in BFS order.
    pub context_id: u32,
}
32
33impl Default for PropertyDecisionNode {
34    fn default() -> Self {
35        Self {
36            property: -1, // Leaf node
37            splitval: 0,
38            predictor: Predictor::Gradient,
39            predictor_offset: 0,
40            multiplier: 1,
41            lchild: 0,
42            rchild: 0,
43            context_id: 0,
44        }
45    }
46}
47
/// A decision tree for context selection.
///
/// Nodes refer to their children by index into this vector; every traversal
/// in this module starts at index 0, which is therefore the root.
pub type Tree = Vec<PropertyDecisionNode>;
50
/// Property indices for tree decisions.
///
/// Each discriminant is the index stored in [`PropertyDecisionNode::property`]
/// and used to look up [`PixelProperties::values`]; `collect_tree_tokens`
/// serializes it as `property + 1` (0 is reserved for "leaf").
///
/// NOTE(review): in libjxl (`context_predict.h`) the properties after `x`
/// (index 3) are, in order: |N|, |W|, N, W, W minus the previous pixel's
/// W+N-NW, W+N-NW, W-NW, NW-N, N-NE, N-NN, W-WW, then the WP max error at 15.
/// Indices 4-14 below do not follow that layout, so a serialized tree that
/// splits on them would be evaluated differently by a conforming decoder.
/// Indices 0-3 and 15 appear to line up. Confirm before splitting on 4-14.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum Property {
    /// Channel index.
    Channel = 0,
    /// Group ID.
    GroupId = 1,
    /// Y coordinate.
    Y = 2,
    /// X coordinate.
    X = 3,
    /// |N - NW|
    AbsNMinusNw = 4,
    /// |N - W|
    AbsNMinusW = 5,
    /// FloorLog2(W)
    FloorLog2W = 6,
    /// FloorLog2(N)
    FloorLog2N = 7,
    /// FloorLog2(NW)
    FloorLog2Nw = 8,
    /// |N - NN|
    AbsNMinusNn = 9,
    /// |W - WW|
    AbsWMinusWw = 10,
    /// |NW - NWW|
    AbsNwMinusNww = 11,
    /// |NE - N|
    AbsNeMinusN = 12,
    /// |NW - W|
    AbsNwMinusW = 13,
    /// |W| + |N| + |NW|
    SumWNNw = 14,
    /// Max error in weighted predictor.
    WpMaxError = 15,
}
88
impl Property {
    /// Total number of static properties (not including WP properties).
    ///
    /// NOTE(review): [`Property`] has 15 variants before `WpMaxError`
    /// (indices 0-14), yet this constant is 14; libjxl's `kNumStaticProperties`
    /// (channel and group id only) is 2. Confirm which meaning callers rely on.
    pub const NUM_STATIC: usize = 14;

    /// Total number of properties including WP.
    ///
    /// Sizes [`PixelProperties::values`]; equals `WpMaxError as usize + 1`.
    pub const NUM_PROPERTIES: usize = 16;
}
96
/// Properties computed for a pixel location.
#[derive(Debug, Clone, Default)]
pub struct PixelProperties {
    /// Property values, indexed by `Property as usize`.
    ///
    /// Filled by [`PixelProperties::compute`]; the `Default` value is all zeros.
    pub values: [i32; Property::NUM_PROPERTIES],
}
103
104impl PixelProperties {
105    /// Computes properties for a pixel.
106    #[allow(clippy::too_many_arguments)]
107    pub fn compute(
108        channel_idx: u32,
109        group_id: u32,
110        x: usize,
111        y: usize,
112        n: i32,
113        w: i32,
114        nw: i32,
115        ne: i32,
116        nn: i32,
117        ww: i32,
118        nww: i32,
119    ) -> Self {
120        let mut values = [0i32; Property::NUM_PROPERTIES];
121
122        values[Property::Channel as usize] = channel_idx as i32;
123        values[Property::GroupId as usize] = group_id as i32;
124        values[Property::Y as usize] = y as i32;
125        values[Property::X as usize] = x as i32;
126        values[Property::AbsNMinusNw as usize] = (n - nw).abs();
127        values[Property::AbsNMinusW as usize] = (n - w).abs();
128        values[Property::FloorLog2W as usize] = floor_log2(w.unsigned_abs());
129        values[Property::FloorLog2N as usize] = floor_log2(n.unsigned_abs());
130        values[Property::FloorLog2Nw as usize] = floor_log2(nw.unsigned_abs());
131        values[Property::AbsNMinusNn as usize] = (n - nn).abs();
132        values[Property::AbsWMinusWw as usize] = (w - ww).abs();
133        values[Property::AbsNwMinusNww as usize] = (nw - nww).abs();
134        values[Property::AbsNeMinusN as usize] = (ne - n).abs();
135        values[Property::AbsNwMinusW as usize] = (nw - w).abs();
136        values[Property::SumWNNw as usize] = w.abs() + n.abs() + nw.abs();
137        values[Property::WpMaxError as usize] = 0; // Filled in by WP state
138
139        Self { values }
140    }
141
142    /// Gets a property value.
143    #[inline]
144    pub fn get(&self, property: i32) -> i32 {
145        if property >= 0 && (property as usize) < self.values.len() {
146            self.values[property as usize]
147        } else {
148            0
149        }
150    }
151}
152
/// Floor log2 for unsigned values (returns 0 for 0).
///
/// For a non-zero input this is the index of the most significant set bit;
/// zero has no set bit and is mapped to 0 by convention.
#[inline]
fn floor_log2(value: u32) -> i32 {
    match value {
        0 => 0,
        nonzero => (u32::BITS - 1 - nonzero.leading_zeros()) as i32,
    }
}
162
163/// Creates a simple tree that uses a single predictor for all pixels.
164pub fn simple_tree(predictor: Predictor) -> Tree {
165    vec![PropertyDecisionNode {
166        property: -1, // Leaf
167        predictor,
168        context_id: 0,
169        ..Default::default()
170    }]
171}
172
173/// Creates a gradient tree (most common for lossless).
174pub fn gradient_tree() -> Tree {
175    simple_tree(Predictor::Gradient)
176}
177
178/// Creates a tree that selects predictor based on channel.
179#[allow(dead_code)]
180pub fn per_channel_tree(num_channels: usize) -> Tree {
181    let mut tree = Vec::with_capacity(num_channels * 2);
182
183    // Build a simple chain: if channel == 0, use ctx 0; if channel == 1, use ctx 1; etc.
184    for c in 0..num_channels {
185        if c < num_channels - 1 {
186            // Internal node: split on channel
187            tree.push(PropertyDecisionNode {
188                property: Property::Channel as i32,
189                splitval: c as i32,
190                lchild: tree.len() + num_channels - c, // Leaf for this channel
191                rchild: tree.len() + 1,                // Next decision
192                ..Default::default()
193            });
194        }
195    }
196
197    // Leaf nodes
198    for c in 0..num_channels {
199        tree.push(PropertyDecisionNode {
200            property: -1,
201            predictor: Predictor::Gradient,
202            context_id: c as u32,
203            ..Default::default()
204        });
205    }
206
207    tree
208}
209
210/// Traverses the tree to find the leaf node for given properties.
211pub fn traverse_tree<'a>(tree: &'a Tree, properties: &PixelProperties) -> &'a PropertyDecisionNode {
212    let mut node_idx = 0;
213
214    loop {
215        let node = &tree[node_idx];
216
217        // Leaf node?
218        if node.property < 0 {
219            return node;
220        }
221
222        // Get property value and decide direction
223        let prop_value = properties.get(node.property);
224        if prop_value <= node.splitval {
225            node_idx = node.lchild;
226        } else {
227            node_idx = node.rchild;
228        }
229    }
230}
231
// Tree serialization context indices: every token produced by
// `collect_tree_tokens` is tagged with one of these. NOTE(review): the
// numbering presumably must match the decoder's tree-context layout (libjxl
// uses this same 0..=5 order); confirm against the decoder in use.

/// Context for split values (signed).
const SPLIT_VAL_CONTEXT: usize = 0;
/// Context for the property token (0 = leaf, otherwise property + 1).
const PROPERTY_CONTEXT: usize = 1;
/// Context for leaf predictor IDs.
const PREDICTOR_CONTEXT: usize = 2;
/// Context for leaf predictor offsets (signed).
const OFFSET_CONTEXT: usize = 3;
/// Context for the multiplier's log2 shift (`mul_log`).
const MULTIPLIER_LOG_CONTEXT: usize = 4;
/// Context for the multiplier's mantissa (`mul_bits`).
const MULTIPLIER_BITS_CONTEXT: usize = 5;
239
/// Token for tree serialization.
///
/// Produced by `collect_tree_tokens` in the order the decoder reads them.
#[derive(Debug, Clone)]
pub struct TreeToken {
    /// Context for this token (one of the `*_CONTEXT` constants in this module).
    pub context: usize,
    /// Token value (unsigned for property/predictor/log, signed for split_val/offset).
    pub value: i32,
    /// Whether this is a signed value.
    pub is_signed: bool,
}
250
251/// Collect tokens for tree serialization.
252pub fn collect_tree_tokens(tree: &Tree) -> Vec<TreeToken> {
253    let mut tokens = Vec::new();
254
255    // Process tree in BFS order
256    let mut queue = std::collections::VecDeque::new();
257    queue.push_back(0usize);
258
259    while let Some(idx) = queue.pop_front() {
260        let node = &tree[idx];
261
262        if node.property < 0 {
263            // Leaf node: property = 0 (indicator), then predictor, offset, mul_log, mul_bits
264            tokens.push(TreeToken {
265                context: PROPERTY_CONTEXT,
266                value: 0, // 0 means leaf
267                is_signed: false,
268            });
269
270            // Predictor
271            tokens.push(TreeToken {
272                context: PREDICTOR_CONTEXT,
273                value: node.predictor as i32,
274                is_signed: false,
275            });
276
277            // Offset (signed)
278            tokens.push(TreeToken {
279                context: OFFSET_CONTEXT,
280                value: node.predictor_offset,
281                is_signed: true,
282            });
283
284            // Multiplier is encoded as (mul_bits + 1) << mul_log
285            // For multiplier = 1: mul_log = 0, mul_bits = 0
286            let (mul_log, mul_bits) = decompose_multiplier(node.multiplier as u32);
287            tokens.push(TreeToken {
288                context: MULTIPLIER_LOG_CONTEXT,
289                value: mul_log as i32,
290                is_signed: false,
291            });
292
293            tokens.push(TreeToken {
294                context: MULTIPLIER_BITS_CONTEXT,
295                value: mul_bits as i32,
296                is_signed: false,
297            });
298        } else {
299            // Split node: property+1, splitval, then children
300            tokens.push(TreeToken {
301                context: PROPERTY_CONTEXT,
302                value: node.property + 1, // +1 because 0 means leaf
303                is_signed: false,
304            });
305
306            tokens.push(TreeToken {
307                context: SPLIT_VAL_CONTEXT,
308                value: node.splitval,
309                is_signed: true,
310            });
311
312            // Queue children: rchild first (value > splitval = decoder's "left"/first BFS child),
313            // then lchild (value <= splitval = decoder's "right"/second BFS child).
314            // jxl-rs reads first BFS child as "left" (property > splitval).
315            queue.push_back(node.rchild);
316            queue.push_back(node.lchild);
317        }
318    }
319
320    tokens
321}
322
/// Decompose multiplier into (log, bits) where multiplier = (bits + 1) << log.
///
/// `log` is the number of trailing zero bits and `bits + 1` is the remaining odd
/// factor. A zero multiplier has no such form and yields `(0, 0)`, which encodes
/// a multiplier of 1.
fn decompose_multiplier(multiplier: u32) -> (u32, u32) {
    match multiplier {
        0 => (0, 0),
        m => {
            let shift = m.trailing_zeros();
            let odd_part = m >> shift;
            (shift, odd_part - 1)
        }
    }
}
335
336/// Creates a tree with the weighted predictor.
337pub fn weighted_tree() -> Tree {
338    simple_tree(Predictor::Weighted)
339}
340
341/// Creates a tree that selects between Gradient and Weighted based on WP max error.
342/// Uses Gradient when max error is low (WP is stable), Weighted when error is higher.
343pub fn adaptive_gradient_weighted_tree() -> Tree {
344    vec![
345        // Root: split on WP max error (property 15)
346        PropertyDecisionNode {
347            property: Property::WpMaxError as i32,
348            splitval: 100, // Threshold
349            lchild: 1,     // Low error -> gradient
350            rchild: 2,     // High error -> weighted
351            ..Default::default()
352        },
353        // Leaf: Gradient predictor (for stable regions)
354        PropertyDecisionNode {
355            property: -1,
356            predictor: Predictor::Gradient,
357            context_id: 0,
358            ..Default::default()
359        },
360        // Leaf: Weighted predictor (for complex regions)
361        PropertyDecisionNode {
362            property: -1,
363            predictor: Predictor::Weighted,
364            context_id: 1,
365            ..Default::default()
366        },
367    ]
368}
369
370/// Validate tree structure matching libjxl's ValidateTree in dec_ma.cc.
371///
372/// Tracks property ranges as the tree narrows them through splits.
373/// Returns Ok(()) if valid, Err with details of the failing node.
374///
375/// Convention: lchild = value <= splitval, rchild = value > splitval.
376///
377/// But the decoder reads BFS where first child is "lchild" (value > splitval)
378/// and second child is "rchild" (value <= splitval). So we map:
379/// - Our rchild → decoder lchild: range [val+1, u]
380/// - Our lchild → decoder rchild: range [l, val]
381pub fn validate_tree_djxl(tree: &Tree) -> Result<(), String> {
382    if tree.is_empty() {
383        return Ok(());
384    }
385
386    let mut num_properties = 0i32;
387    for node in tree {
388        if node.property >= num_properties {
389            num_properties = node.property + 1;
390        }
391    }
392    let np = num_properties as usize;
393
394    // Track (lo, hi) range per property per node
395    // Range is [lo, hi] inclusive; split at val requires lo <= val && val < hi
396    // (in libjxl terms: u > val, meaning hi > val)
397    let mut ranges: Vec<(i32, i32)> = vec![(i32::MIN, i32::MAX); np * tree.len()];
398
399    for (i, node) in tree.iter().enumerate() {
400        if node.property < 0 {
401            continue; // leaf
402        }
403        let p = node.property as usize;
404        let val = node.splitval;
405        let lo = ranges[i * np + p].0;
406        let hi = ranges[i * np + p].1;
407
408        // libjxl check: if (l > val || u <= val) return FAILURE
409        if lo > val || hi <= val {
410            return Err(format!(
411                "Node {} (property={}, splitval={}): range [{}, {}] invalid \
412                 (lo > val = {}, hi <= val = {})",
413                i,
414                node.property,
415                val,
416                lo,
417                hi,
418                lo > val,
419                hi <= val
420            ));
421        }
422
423        let lchild = node.lchild; // value <= splitval
424        let rchild = node.rchild; // value > splitval
425
426        // Copy all property ranges to children
427        for pp in 0..np {
428            ranges[rchild * np + pp] = ranges[i * np + pp];
429            ranges[lchild * np + pp] = ranges[i * np + pp];
430        }
431
432        // Narrow property p for children
433        // rchild (value > splitval): lo = val + 1
434        ranges[rchild * np + p] = (val + 1, hi);
435        // lchild (value <= splitval): hi = val
436        ranges[lchild * np + p] = (lo, val);
437    }
438
439    Ok(())
440}
441
442/// Count the number of unique context IDs used in a tree.
443/// Count the number of BFS-reachable leaf contexts in the tree.
444///
445/// Only counts leaves reachable from root via BFS traversal, ignoring
446/// unreachable orphan nodes that may exist after tree validation.
447pub fn count_contexts(tree: &Tree) -> u32 {
448    let mut count = 0u32;
449    let mut queue = std::collections::VecDeque::new();
450    queue.push_back(0usize);
451
452    while let Some(idx) = queue.pop_front() {
453        if tree[idx].property < 0 {
454            count += 1;
455        } else {
456            queue.push_back(tree[idx].rchild);
457            queue.push_back(tree[idx].lchild);
458        }
459    }
460    count.max(1)
461}
462
463/// Assign context IDs to leaf nodes sequentially in BFS order.
464///
465/// The decoder assigns context IDs to leaves in the order it encounters them
466/// during BFS deserialization (rchild first, then lchild — matching
467/// `collect_tree_tokens`). We must use the same traversal order here so that
468/// context IDs in the encoder match what the decoder derives.
469///
470/// Returns the number of contexts assigned.
471pub fn assign_sequential_contexts(tree: &mut Tree) -> u32 {
472    let mut next_context = 0u32;
473    let mut queue = std::collections::VecDeque::new();
474    queue.push_back(0usize);
475
476    while let Some(idx) = queue.pop_front() {
477        if tree[idx].property < 0 {
478            tree[idx].context_id = next_context;
479            next_context += 1;
480        } else {
481            let rchild = tree[idx].rchild;
482            let lchild = tree[idx].lchild;
483            // Same child order as collect_tree_tokens: rchild first, lchild second
484            queue.push_back(rchild);
485            queue.push_back(lchild);
486        }
487    }
488    next_context
489}
490
// Unit tests for the helpers and tree builders in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // 0 maps to 0; otherwise the result is the index of the highest set bit.
    #[test]
    fn test_floor_log2() {
        assert_eq!(floor_log2(0), 0);
        assert_eq!(floor_log2(1), 0);
        assert_eq!(floor_log2(2), 1);
        assert_eq!(floor_log2(3), 1);
        assert_eq!(floor_log2(4), 2);
        assert_eq!(floor_log2(255), 7);
        assert_eq!(floor_log2(256), 8);
    }

    // A simple tree is a single leaf carrying the requested predictor.
    #[test]
    fn test_simple_tree() {
        let tree = simple_tree(Predictor::Left);
        assert_eq!(tree.len(), 1);
        assert_eq!(tree[0].property, -1);
        assert_eq!(tree[0].predictor, Predictor::Left);
    }

    // Traversing a single-leaf tree lands on the root leaf (context 0).
    #[test]
    fn test_traverse_simple() {
        let tree = gradient_tree();
        let props = PixelProperties::default();
        let leaf = traverse_tree(&tree, &props);
        assert_eq!(leaf.predictor, Predictor::Gradient);
        assert_eq!(leaf.context_id, 0);
    }

    #[test]
    fn test_weighted_tree() {
        let tree = weighted_tree();
        assert_eq!(tree.len(), 1);
        assert_eq!(tree[0].predictor, Predictor::Weighted);
    }

    // Each case checks multiplier == (mul_bits + 1) << mul_log.
    #[test]
    fn test_decompose_multiplier() {
        assert_eq!(decompose_multiplier(1), (0, 0)); // (0+1) << 0 = 1
        assert_eq!(decompose_multiplier(2), (1, 0)); // (0+1) << 1 = 2
        assert_eq!(decompose_multiplier(4), (2, 0)); // (0+1) << 2 = 4
        assert_eq!(decompose_multiplier(3), (0, 2)); // (2+1) << 0 = 3
        assert_eq!(decompose_multiplier(6), (1, 2)); // (2+1) << 1 = 6
    }

    // A single leaf serializes to exactly five tokens.
    #[test]
    fn test_collect_tree_tokens_simple() {
        let tree = gradient_tree();
        let tokens = collect_tree_tokens(&tree);
        // Single leaf: property(0), predictor(5), offset(0), mul_log(0), mul_bits(0)
        assert_eq!(tokens.len(), 5);
        assert_eq!(tokens[0].value, 0); // property = 0 (leaf)
        assert_eq!(tokens[1].value, Predictor::Gradient as i32);
    }

    // WP max error 50 (<= 100) selects Gradient; 150 (> 100) selects Weighted.
    #[test]
    fn test_adaptive_tree() {
        let tree = adaptive_gradient_weighted_tree();
        assert_eq!(tree.len(), 3);

        // Test traversal with low error -> gradient
        let mut props = PixelProperties::default();
        props.values[Property::WpMaxError as usize] = 50;
        let leaf = traverse_tree(&tree, &props);
        assert_eq!(leaf.predictor, Predictor::Gradient);

        // Test traversal with high error -> weighted
        props.values[Property::WpMaxError as usize] = 150;
        let leaf = traverse_tree(&tree, &props);
        assert_eq!(leaf.predictor, Predictor::Weighted);
    }

    // One reachable leaf vs. two reachable leaves.
    #[test]
    fn test_count_contexts() {
        let tree = gradient_tree();
        assert_eq!(count_contexts(&tree), 1);

        let tree = adaptive_gradient_weighted_tree();
        assert_eq!(count_contexts(&tree), 2);
    }
}