Skip to main content

jxl_encoder/modular/
tree.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! Decision tree for modular encoding context selection.
6//!
7//! The tree determines which predictor and context to use for each pixel
8//! based on properties of the neighborhood.
9
10use super::predictor::Predictor;
11
/// A node in the property decision tree.
///
/// A node is either a split node (`property >= 0`) or a leaf (`property < 0`).
/// A split node routes a pixel to `lchild` or `rchild` by comparing one of the
/// pixel's properties with `splitval`. A leaf supplies the predictor
/// parameters and the context used to code the pixel. Nodes refer to each
/// other by index into the enclosing [`Tree`].
#[derive(Debug, Clone)]
pub struct PropertyDecisionNode {
    /// Property index to split on (a [`Property`] discriminant); any negative
    /// value (conventionally `-1`) marks a leaf node.
    pub property: i32,
    /// Split threshold: a property value `<= splitval` goes to `lchild`,
    /// a larger value goes to `rchild`.
    pub splitval: i32,
    /// Predictor to use (for leaf nodes; not serialized for split nodes).
    pub predictor: Predictor,
    /// Offset for predictor (for leaf nodes); serialized as a signed value.
    pub predictor_offset: i32,
    /// Multiplier for residual (for leaf nodes). Serialized as
    /// `(mul_bits + 1) << mul_log`, so it is expected to be non-zero.
    pub multiplier: i32,
    /// Index of the child taken when the property value is `<= splitval`.
    pub lchild: usize,
    /// Index of the child taken when the property value is `> splitval`.
    pub rchild: usize,
    /// Context ID for ANS coding (leaf nodes only). To agree with the decoder,
    /// IDs should follow the breadth-first leaf order used for serialization
    /// (`assign_sequential_contexts` assigns them that way).
    pub context_id: u32,
}
32
33impl Default for PropertyDecisionNode {
34    fn default() -> Self {
35        Self {
36            property: -1, // Leaf node
37            splitval: 0,
38            predictor: Predictor::Gradient,
39            predictor_offset: 0,
40            multiplier: 1,
41            lchild: 0,
42            rchild: 0,
43            context_id: 0,
44        }
45    }
46}
47
/// A decision tree for context selection, stored as a flat vector of nodes.
///
/// Index 0 is the root. Split nodes reference their children by index
/// (`lchild` / `rchild`) into this same vector.
pub type Tree = Vec<PropertyDecisionNode>;
50
/// Per-pixel properties that split nodes of the decision tree can test.
///
/// Each discriminant is both the index of the property inside
/// `PixelProperties::values` and the number stored in
/// `PropertyDecisionNode::property`.
///
/// NOTE(review): confirm this numbering matches the property list the decoder
/// computes. The trees built in this module only split on `Channel` and
/// `WpMaxError`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum Property {
    /// Index of the channel being coded.
    Channel = 0,
    /// Identifier of the group being coded.
    GroupId = 1,
    /// Row of the current pixel.
    Y = 2,
    /// Column of the current pixel.
    X = 3,
    /// `|N - NW|`: north minus north-west, absolute.
    AbsNMinusNw = 4,
    /// `|N - W|`: north minus west, absolute.
    AbsNMinusW = 5,
    /// `floor(log2(|W|))`, with 0 when `W == 0`.
    FloorLog2W = 6,
    /// `floor(log2(|N|))`, with 0 when `N == 0`.
    FloorLog2N = 7,
    /// `floor(log2(|NW|))`, with 0 when `NW == 0`.
    FloorLog2Nw = 8,
    /// `|N - NN|`: north minus north-north, absolute.
    AbsNMinusNn = 9,
    /// `|W - WW|`: west minus west-west, absolute.
    AbsWMinusWw = 10,
    /// `|NW - NWW|`: north-west minus north-west-west, absolute.
    AbsNwMinusNww = 11,
    /// `|NE - N|`: north-east minus north, absolute.
    AbsNeMinusN = 12,
    /// `|NW - W|`: north-west minus west, absolute.
    AbsNwMinusW = 13,
    /// `|W| + |N| + |NW|`.
    SumWNNw = 14,
    /// Maximum error reported by the weighted predictor.
    WpMaxError = 15,
}

impl Property {
    /// Number of static properties (those other than the weighted-predictor one).
    ///
    /// NOTE(review): the enum declares 15 variants ahead of `WpMaxError`
    /// (indices 0..=14), yet this constant is 14. Confirm which value callers
    /// expect before relying on it.
    pub const NUM_STATIC: usize = 14;

    /// Length of a full property vector: one slot per variant, with
    /// `WpMaxError` occupying the last index.
    pub const NUM_PROPERTIES: usize = Self::WpMaxError as usize + 1;
}
96
/// Properties computed for a pixel location.
///
/// `values` is indexed by `Property as usize`; the derived `Default` yields
/// all zeros.
#[derive(Debug, Clone, Default)]
pub struct PixelProperties {
    /// Property values, one per [`Property`] variant.
    pub values: [i32; Property::NUM_PROPERTIES],
}
103
104impl PixelProperties {
105    /// Computes properties for a pixel.
106    #[allow(clippy::too_many_arguments)]
107    pub fn compute(
108        channel_idx: u32,
109        group_id: u32,
110        x: usize,
111        y: usize,
112        n: i32,
113        w: i32,
114        nw: i32,
115        ne: i32,
116        nn: i32,
117        ww: i32,
118        nww: i32,
119    ) -> Self {
120        let mut values = [0i32; Property::NUM_PROPERTIES];
121
122        values[Property::Channel as usize] = channel_idx as i32;
123        values[Property::GroupId as usize] = group_id as i32;
124        values[Property::Y as usize] = y as i32;
125        values[Property::X as usize] = x as i32;
126        values[Property::AbsNMinusNw as usize] = (n - nw).abs();
127        values[Property::AbsNMinusW as usize] = (n - w).abs();
128        values[Property::FloorLog2W as usize] = floor_log2(w.unsigned_abs());
129        values[Property::FloorLog2N as usize] = floor_log2(n.unsigned_abs());
130        values[Property::FloorLog2Nw as usize] = floor_log2(nw.unsigned_abs());
131        values[Property::AbsNMinusNn as usize] = (n - nn).abs();
132        values[Property::AbsWMinusWw as usize] = (w - ww).abs();
133        values[Property::AbsNwMinusNww as usize] = (nw - nww).abs();
134        values[Property::AbsNeMinusN as usize] = (ne - n).abs();
135        values[Property::AbsNwMinusW as usize] = (nw - w).abs();
136        values[Property::SumWNNw as usize] = w.abs() + n.abs() + nw.abs();
137        values[Property::WpMaxError as usize] = 0; // Filled in by WP state
138
139        Self { values }
140    }
141
142    /// Gets a property value.
143    #[inline]
144    pub fn get(&self, property: i32) -> i32 {
145        if property >= 0 && (property as usize) < self.values.len() {
146            self.values[property as usize]
147        } else {
148            0
149        }
150    }
151}
152
/// Returns `floor(log2(value))` as an `i32`, i.e. the index of the highest set
/// bit. Zero has no logarithm and is mapped to 0.
#[inline]
fn floor_log2(value: u32) -> i32 {
    match value {
        0 => 0,
        nonzero => (u32::BITS - 1 - nonzero.leading_zeros()) as i32,
    }
}
162
163/// Creates a simple tree that uses a single predictor for all pixels.
164pub fn simple_tree(predictor: Predictor) -> Tree {
165    vec![PropertyDecisionNode {
166        property: -1, // Leaf
167        predictor,
168        context_id: 0,
169        ..Default::default()
170    }]
171}
172
173/// Creates a gradient tree (most common for lossless).
174pub fn gradient_tree() -> Tree {
175    simple_tree(Predictor::Gradient)
176}
177
178/// Creates a tree that selects predictor based on channel.
179#[allow(dead_code)]
180pub fn per_channel_tree(num_channels: usize) -> Tree {
181    let mut tree = Vec::with_capacity(num_channels * 2);
182
183    // Build a simple chain: if channel == 0, use ctx 0; if channel == 1, use ctx 1; etc.
184    for c in 0..num_channels {
185        if c < num_channels - 1 {
186            // Internal node: split on channel
187            tree.push(PropertyDecisionNode {
188                property: Property::Channel as i32,
189                splitval: c as i32,
190                lchild: tree.len() + num_channels - c, // Leaf for this channel
191                rchild: tree.len() + 1,                // Next decision
192                ..Default::default()
193            });
194        }
195    }
196
197    // Leaf nodes
198    for c in 0..num_channels {
199        tree.push(PropertyDecisionNode {
200            property: -1,
201            predictor: Predictor::Gradient,
202            context_id: c as u32,
203            ..Default::default()
204        });
205    }
206
207    tree
208}
209
210/// Traverses the tree to find the leaf node for given properties.
211pub fn traverse_tree<'a>(tree: &'a Tree, properties: &PixelProperties) -> &'a PropertyDecisionNode {
212    let mut node_idx = 0;
213
214    loop {
215        let node = &tree[node_idx];
216
217        // Leaf node?
218        if node.property < 0 {
219            return node;
220        }
221
222        // Get property value and decide direction
223        let prop_value = properties.get(node.property);
224        if prop_value <= node.splitval {
225            node_idx = node.lchild;
226        } else {
227            node_idx = node.rchild;
228        }
229    }
230}
231
// Tree serialization context indices: one context per kind of token emitted
// when the tree itself is serialized.

/// Context for split-node threshold tokens (signed values).
const SPLIT_VAL_CONTEXT: usize = 0;
/// Context for property tokens (0 for a leaf, `property + 1` for a split node).
const PROPERTY_CONTEXT: usize = 1;
/// Context for leaf predictor tokens.
const PREDICTOR_CONTEXT: usize = 2;
/// Context for leaf predictor-offset tokens (signed values).
const OFFSET_CONTEXT: usize = 3;
/// Context for the shift part (`mul_log`) of a leaf multiplier.
const MULTIPLIER_LOG_CONTEXT: usize = 4;
/// Context for the mantissa part (`mul_bits`) of a leaf multiplier.
const MULTIPLIER_BITS_CONTEXT: usize = 5;
239
/// Token for tree serialization: one value tagged with the context it is
/// coded in.
#[derive(Debug, Clone)]
pub struct TreeToken {
    /// Context for this token (one of the tree serialization context indices).
    pub context: usize,
    /// Token value (unsigned for property/predictor/log, signed for split_val/offset).
    pub value: i32,
    /// Whether this is a signed value (set for split values and predictor
    /// offsets).
    pub is_signed: bool,
}
250
251/// Collect tokens for tree serialization.
252pub fn collect_tree_tokens(tree: &Tree) -> Vec<TreeToken> {
253    let mut tokens = Vec::new();
254
255    // Process tree in BFS order
256    let mut queue = std::collections::VecDeque::new();
257    queue.push_back(0usize);
258
259    while let Some(idx) = queue.pop_front() {
260        let node = &tree[idx];
261
262        if node.property < 0 {
263            // Leaf node: property = 0 (indicator), then predictor, offset, mul_log, mul_bits
264            tokens.push(TreeToken {
265                context: PROPERTY_CONTEXT,
266                value: 0, // 0 means leaf
267                is_signed: false,
268            });
269
270            // Predictor
271            tokens.push(TreeToken {
272                context: PREDICTOR_CONTEXT,
273                value: node.predictor as i32,
274                is_signed: false,
275            });
276
277            // Offset (signed)
278            tokens.push(TreeToken {
279                context: OFFSET_CONTEXT,
280                value: node.predictor_offset,
281                is_signed: true,
282            });
283
284            // Multiplier is encoded as (mul_bits + 1) << mul_log
285            // For multiplier = 1: mul_log = 0, mul_bits = 0
286            let (mul_log, mul_bits) = decompose_multiplier(node.multiplier as u32);
287            tokens.push(TreeToken {
288                context: MULTIPLIER_LOG_CONTEXT,
289                value: mul_log as i32,
290                is_signed: false,
291            });
292
293            tokens.push(TreeToken {
294                context: MULTIPLIER_BITS_CONTEXT,
295                value: mul_bits as i32,
296                is_signed: false,
297            });
298        } else {
299            // Split node: property+1, splitval, then children
300            tokens.push(TreeToken {
301                context: PROPERTY_CONTEXT,
302                value: node.property + 1, // +1 because 0 means leaf
303                is_signed: false,
304            });
305
306            tokens.push(TreeToken {
307                context: SPLIT_VAL_CONTEXT,
308                value: node.splitval,
309                is_signed: true,
310            });
311
312            // Queue children: rchild first (value > splitval = decoder's "left"/first BFS child),
313            // then lchild (value <= splitval = decoder's "right"/second BFS child).
314            // jxl-rs reads first BFS child as "left" (property > splitval).
315            queue.push_back(node.rchild);
316            queue.push_back(node.lchild);
317        }
318    }
319
320    tokens
321}
322
/// Splits `multiplier` into `(log, bits)` such that
/// `multiplier == (bits + 1) << log`.
///
/// `log` is the number of trailing zero bits and `bits + 1` is the remaining
/// odd factor. A multiplier of 0 cannot be represented and yields `(0, 0)`.
fn decompose_multiplier(multiplier: u32) -> (u32, u32) {
    match multiplier {
        0 => (0, 0),
        m => {
            let shift = m.trailing_zeros();
            let odd_factor = m >> shift;
            (shift, odd_factor - 1)
        }
    }
}
335
336/// Creates a tree with the weighted predictor.
337pub fn weighted_tree() -> Tree {
338    simple_tree(Predictor::Weighted)
339}
340
341/// Creates a tree that selects between Gradient and Weighted based on WP max error.
342/// Uses Gradient when max error is low (WP is stable), Weighted when error is higher.
343pub fn adaptive_gradient_weighted_tree() -> Tree {
344    vec![
345        // Root: split on WP max error (property 15)
346        PropertyDecisionNode {
347            property: Property::WpMaxError as i32,
348            splitval: 100, // Threshold
349            lchild: 1,     // Low error -> gradient
350            rchild: 2,     // High error -> weighted
351            ..Default::default()
352        },
353        // Leaf: Gradient predictor (for stable regions)
354        PropertyDecisionNode {
355            property: -1,
356            predictor: Predictor::Gradient,
357            context_id: 0,
358            ..Default::default()
359        },
360        // Leaf: Weighted predictor (for complex regions)
361        PropertyDecisionNode {
362            property: -1,
363            predictor: Predictor::Weighted,
364            context_id: 1,
365            ..Default::default()
366        },
367    ]
368}
369
370/// Count the number of unique context IDs used in a tree.
371pub fn count_contexts(tree: &Tree) -> u32 {
372    tree.iter()
373        .filter(|n| n.property < 0)
374        .map(|n| n.context_id)
375        .max()
376        .map(|m| m + 1)
377        .unwrap_or(1)
378}
379
380/// Assign context IDs to leaf nodes sequentially in BFS order.
381///
382/// The decoder assigns context IDs to leaves in the order it encounters them
383/// during BFS deserialization (rchild first, then lchild — matching
384/// `collect_tree_tokens`). We must use the same traversal order here so that
385/// context IDs in the encoder match what the decoder derives.
386pub fn assign_sequential_contexts(tree: &mut Tree) {
387    let mut next_context = 0u32;
388    let mut queue = std::collections::VecDeque::new();
389    queue.push_back(0usize);
390
391    while let Some(idx) = queue.pop_front() {
392        if tree[idx].property < 0 {
393            tree[idx].context_id = next_context;
394            next_context += 1;
395        } else {
396            let rchild = tree[idx].rchild;
397            let lchild = tree[idx].lchild;
398            // Same child order as collect_tree_tokens: rchild first, lchild second
399            queue.push_back(rchild);
400            queue.push_back(lchild);
401        }
402    }
403}
404
#[cfg(test)]
mod tests {
    // Unit tests for tree construction, traversal and serialization helpers.
    use super::*;

    // floor_log2 maps 0 to 0 and otherwise returns the index of the highest set bit.
    #[test]
    fn test_floor_log2() {
        assert_eq!(floor_log2(0), 0);
        assert_eq!(floor_log2(1), 0);
        assert_eq!(floor_log2(2), 1);
        assert_eq!(floor_log2(3), 1);
        assert_eq!(floor_log2(4), 2);
        assert_eq!(floor_log2(255), 7);
        assert_eq!(floor_log2(256), 8);
    }

    // simple_tree yields exactly one leaf carrying the requested predictor.
    #[test]
    fn test_simple_tree() {
        let tree = simple_tree(Predictor::Left);
        assert_eq!(tree.len(), 1);
        assert_eq!(tree[0].property, -1);
        assert_eq!(tree[0].predictor, Predictor::Left);
    }

    // A one-leaf tree returns that leaf regardless of the property values.
    #[test]
    fn test_traverse_simple() {
        let tree = gradient_tree();
        let props = PixelProperties::default();
        let leaf = traverse_tree(&tree, &props);
        assert_eq!(leaf.predictor, Predictor::Gradient);
        assert_eq!(leaf.context_id, 0);
    }

    #[test]
    fn test_weighted_tree() {
        let tree = weighted_tree();
        assert_eq!(tree.len(), 1);
        assert_eq!(tree[0].predictor, Predictor::Weighted);
    }

    // Each case checks multiplier == (bits + 1) << log.
    #[test]
    fn test_decompose_multiplier() {
        assert_eq!(decompose_multiplier(1), (0, 0)); // (0+1) << 0 = 1
        assert_eq!(decompose_multiplier(2), (1, 0)); // (0+1) << 1 = 2
        assert_eq!(decompose_multiplier(4), (2, 0)); // (0+1) << 2 = 4
        assert_eq!(decompose_multiplier(3), (0, 2)); // (2+1) << 0 = 3
        assert_eq!(decompose_multiplier(6), (1, 2)); // (2+1) << 1 = 6
    }

    // A lone leaf serializes to exactly five tokens.
    #[test]
    fn test_collect_tree_tokens_simple() {
        let tree = gradient_tree();
        let tokens = collect_tree_tokens(&tree);
        // Single leaf: property(0), predictor(5), offset(0), mul_log(0), mul_bits(0)
        assert_eq!(tokens.len(), 5);
        assert_eq!(tokens[0].value, 0); // property = 0 (leaf)
        assert_eq!(tokens[1].value, Predictor::Gradient as i32);
    }

    // WpMaxError 50 (<= threshold 100) takes lchild; 150 (> 100) takes rchild.
    #[test]
    fn test_adaptive_tree() {
        let tree = adaptive_gradient_weighted_tree();
        assert_eq!(tree.len(), 3);

        // Test traversal with low error -> gradient
        let mut props = PixelProperties::default();
        props.values[Property::WpMaxError as usize] = 50;
        let leaf = traverse_tree(&tree, &props);
        assert_eq!(leaf.predictor, Predictor::Gradient);

        // Test traversal with high error -> weighted
        props.values[Property::WpMaxError as usize] = 150;
        let leaf = traverse_tree(&tree, &props);
        assert_eq!(leaf.predictor, Predictor::Weighted);
    }

    // count_contexts reports the largest leaf context_id plus one.
    #[test]
    fn test_count_contexts() {
        let tree = gradient_tree();
        assert_eq!(count_contexts(&tree), 1);

        let tree = adaptive_gradient_weighted_tree();
        assert_eq!(count_contexts(&tree), 2);
    }
}