1use super::predictor::Predictor;
11
/// One node of a decision tree.
///
/// A node with a negative `property` is a leaf: it selects the predictor and context used for
/// the samples that reach it. Any other node is a split that sends a sample to `lchild` when
/// the sample's value of `property` is `<= splitval`, and to `rchild` otherwise.
#[derive(Debug, Clone)]
pub struct PropertyDecisionNode {
    /// Index of the [`Property`] tested by a split node; negative marks a leaf.
    pub property: i32,
    /// Split threshold: property values `<= splitval` go to `lchild`, larger ones to `rchild`.
    pub splitval: i32,
    /// Predictor of a leaf (only serialized for leaves).
    pub predictor: Predictor,
    /// Signed predictor offset of a leaf.
    /// NOTE(review): presumably added to the prediction — confirm against the predictor code.
    pub predictor_offset: i32,
    /// Leaf multiplier; serialized as `(log, bits)` with `multiplier == (bits + 1) << log`.
    pub multiplier: i32,
    /// Index (into the [`Tree`] vector) of the child taken when the value is `<= splitval`.
    pub lchild: usize,
    /// Index (into the [`Tree`] vector) of the child taken when the value is `> splitval`.
    pub rchild: usize,
    /// Context id of a leaf; `assign_sequential_contexts` renumbers these.
    pub context_id: u32,
}
32
33impl Default for PropertyDecisionNode {
34 fn default() -> Self {
35 Self {
36 property: -1, splitval: 0,
38 predictor: Predictor::Gradient,
39 predictor_offset: 0,
40 multiplier: 1,
41 lchild: 0,
42 rchild: 0,
43 context_id: 0,
44 }
45 }
46}
47
/// A decision tree stored as a flat vector of nodes. Node 0 is the root; split nodes refer to
/// their children by index into this vector.
pub type Tree = Vec<PropertyDecisionNode>;
50
/// Per-sample properties that split nodes can test.
///
/// The discriminant is the property index stored in `PropertyDecisionNode::property` and used
/// to index `PixelProperties::values`. The neighbour names (N, W, NW, NE, NN, WW, NWW) refer to
/// the sample values passed to `PixelProperties::compute`.
///
/// NOTE(review): a decoder evaluates split properties purely by index. Confirm that indices
/// 4..=14 here mean the same thing to the decoder before emitting trees that split on them.
/// The trees built in this module only split on `Channel` and `WpMaxError`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum Property {
    /// Channel index.
    Channel = 0,
    /// Group id.
    GroupId = 1,
    /// Row of the sample.
    Y = 2,
    /// Column of the sample.
    X = 3,
    /// `|N - NW|`.
    AbsNMinusNw = 4,
    /// `|N - W|`.
    AbsNMinusW = 5,
    /// `floor(log2(|W|))`, 0 when `W == 0`.
    FloorLog2W = 6,
    /// `floor(log2(|N|))`, 0 when `N == 0`.
    FloorLog2N = 7,
    /// `floor(log2(|NW|))`, 0 when `NW == 0`.
    FloorLog2Nw = 8,
    /// `|N - NN|`.
    AbsNMinusNn = 9,
    /// `|W - WW|`.
    AbsWMinusWw = 10,
    /// `|NW - NWW|`.
    AbsNwMinusNww = 11,
    /// `|NE - N|`.
    AbsNeMinusN = 12,
    /// `|NW - W|`.
    AbsNwMinusW = 13,
    /// `|W| + |N| + |NW|`.
    SumWNNw = 14,
    /// Presumably the weighted predictor's maximum error; `PixelProperties::compute` stores 0.
    WpMaxError = 15,
}
88
89impl Property {
90 pub const NUM_STATIC: usize = 14;
92
93 pub const NUM_PROPERTIES: usize = 16;
95}
96
/// Property values of a single sample, indexed by [`Property`] discriminant.
#[derive(Debug, Clone, Default)]
pub struct PixelProperties {
    /// `values[p as usize]` holds the value of property `p`.
    pub values: [i32; Property::NUM_PROPERTIES],
}
103
104impl PixelProperties {
105 #[allow(clippy::too_many_arguments)]
107 pub fn compute(
108 channel_idx: u32,
109 group_id: u32,
110 x: usize,
111 y: usize,
112 n: i32,
113 w: i32,
114 nw: i32,
115 ne: i32,
116 nn: i32,
117 ww: i32,
118 nww: i32,
119 ) -> Self {
120 let mut values = [0i32; Property::NUM_PROPERTIES];
121
122 values[Property::Channel as usize] = channel_idx as i32;
123 values[Property::GroupId as usize] = group_id as i32;
124 values[Property::Y as usize] = y as i32;
125 values[Property::X as usize] = x as i32;
126 values[Property::AbsNMinusNw as usize] = (n - nw).abs();
127 values[Property::AbsNMinusW as usize] = (n - w).abs();
128 values[Property::FloorLog2W as usize] = floor_log2(w.unsigned_abs());
129 values[Property::FloorLog2N as usize] = floor_log2(n.unsigned_abs());
130 values[Property::FloorLog2Nw as usize] = floor_log2(nw.unsigned_abs());
131 values[Property::AbsNMinusNn as usize] = (n - nn).abs();
132 values[Property::AbsWMinusWw as usize] = (w - ww).abs();
133 values[Property::AbsNwMinusNww as usize] = (nw - nww).abs();
134 values[Property::AbsNeMinusN as usize] = (ne - n).abs();
135 values[Property::AbsNwMinusW as usize] = (nw - w).abs();
136 values[Property::SumWNNw as usize] = w.abs() + n.abs() + nw.abs();
137 values[Property::WpMaxError as usize] = 0; Self { values }
140 }
141
142 #[inline]
144 pub fn get(&self, property: i32) -> i32 {
145 if property >= 0 && (property as usize) < self.values.len() {
146 self.values[property as usize]
147 } else {
148 0
149 }
150 }
151}
152
/// Returns `floor(log2(value))`, treating 0 like 1 (both give 0).
#[inline]
fn floor_log2(value: u32) -> i32 {
    // Setting bit 0 keeps the highest set bit of any non-zero value and turns 0 into 1, so
    // the zero case needs no separate branch.
    31 - (value | 1).leading_zeros() as i32
}
162
163pub fn simple_tree(predictor: Predictor) -> Tree {
165 vec![PropertyDecisionNode {
166 property: -1, predictor,
168 context_id: 0,
169 ..Default::default()
170 }]
171}
172
/// Returns a single-leaf tree that predicts every sample with `Predictor::Gradient` in context 0.
pub fn gradient_tree() -> Tree {
    simple_tree(Predictor::Gradient)
}
177
178#[allow(dead_code)]
180pub fn per_channel_tree(num_channels: usize) -> Tree {
181 let mut tree = Vec::with_capacity(num_channels * 2);
182
183 for c in 0..num_channels {
185 if c < num_channels - 1 {
186 tree.push(PropertyDecisionNode {
188 property: Property::Channel as i32,
189 splitval: c as i32,
190 lchild: tree.len() + num_channels - c, rchild: tree.len() + 1, ..Default::default()
193 });
194 }
195 }
196
197 for c in 0..num_channels {
199 tree.push(PropertyDecisionNode {
200 property: -1,
201 predictor: Predictor::Gradient,
202 context_id: c as u32,
203 ..Default::default()
204 });
205 }
206
207 tree
208}
209
210pub fn traverse_tree<'a>(tree: &'a Tree, properties: &PixelProperties) -> &'a PropertyDecisionNode {
212 let mut node_idx = 0;
213
214 loop {
215 let node = &tree[node_idx];
216
217 if node.property < 0 {
219 return node;
220 }
221
222 let prop_value = properties.get(node.property);
224 if prop_value <= node.splitval {
225 node_idx = node.lchild;
226 } else {
227 node_idx = node.rchild;
228 }
229 }
230}
231
// Context indices for the symbols of a serialized tree (emitted by `collect_tree_tokens`).

/// Context of a split node's `splitval` (signed).
const SPLIT_VAL_CONTEXT: usize = 0;
/// Context of every node's property symbol (`property + 1` for splits, 0 for leaves).
const PROPERTY_CONTEXT: usize = 1;
/// Context of a leaf's predictor.
const PREDICTOR_CONTEXT: usize = 2;
/// Context of a leaf's predictor offset (signed).
const OFFSET_CONTEXT: usize = 3;
/// Context of a leaf multiplier's shift (`log`).
const MULTIPLIER_LOG_CONTEXT: usize = 4;
/// Context of a leaf multiplier's odd part minus one (`bits`).
const MULTIPLIER_BITS_CONTEXT: usize = 5;
239
/// One symbol of a serialized decision tree, as produced by [`collect_tree_tokens`].
#[derive(Debug, Clone)]
pub struct TreeToken {
    /// Which tree-stream context (one of the `*_CONTEXT` constants) the symbol belongs to.
    pub context: usize,
    /// The symbol value.
    pub value: i32,
    /// Set for values that may be negative (split values and predictor offsets).
    /// NOTE(review): presumably the entropy coder packs these before coding — confirm.
    pub is_signed: bool,
}
250
251pub fn collect_tree_tokens(tree: &Tree) -> Vec<TreeToken> {
253 let mut tokens = Vec::new();
254
255 let mut queue = std::collections::VecDeque::new();
257 queue.push_back(0usize);
258
259 while let Some(idx) = queue.pop_front() {
260 let node = &tree[idx];
261
262 if node.property < 0 {
263 tokens.push(TreeToken {
265 context: PROPERTY_CONTEXT,
266 value: 0, is_signed: false,
268 });
269
270 tokens.push(TreeToken {
272 context: PREDICTOR_CONTEXT,
273 value: node.predictor as i32,
274 is_signed: false,
275 });
276
277 tokens.push(TreeToken {
279 context: OFFSET_CONTEXT,
280 value: node.predictor_offset,
281 is_signed: true,
282 });
283
284 let (mul_log, mul_bits) = decompose_multiplier(node.multiplier as u32);
287 tokens.push(TreeToken {
288 context: MULTIPLIER_LOG_CONTEXT,
289 value: mul_log as i32,
290 is_signed: false,
291 });
292
293 tokens.push(TreeToken {
294 context: MULTIPLIER_BITS_CONTEXT,
295 value: mul_bits as i32,
296 is_signed: false,
297 });
298 } else {
299 tokens.push(TreeToken {
301 context: PROPERTY_CONTEXT,
302 value: node.property + 1, is_signed: false,
304 });
305
306 tokens.push(TreeToken {
307 context: SPLIT_VAL_CONTEXT,
308 value: node.splitval,
309 is_signed: true,
310 });
311
312 queue.push_back(node.rchild);
316 queue.push_back(node.lchild);
317 }
318 }
319
320 tokens
321}
322
/// Splits `multiplier` into the `(log, bits)` pair written to the tree stream, so that
/// `multiplier == (bits + 1) << log`. Here `log` is the number of trailing zero bits and
/// `bits` is the remaining odd factor minus one.
///
/// Zero has no such decomposition and maps to `(0, 0)`, which reconstructs as 1.
fn decompose_multiplier(multiplier: u32) -> (u32, u32) {
    match multiplier {
        0 => (0, 0),
        m => {
            let shift = m.trailing_zeros();
            let odd_part = m >> shift;
            (shift, odd_part - 1)
        }
    }
}
335
/// Returns a single-leaf tree that predicts every sample with `Predictor::Weighted` in context 0.
pub fn weighted_tree() -> Tree {
    simple_tree(Predictor::Weighted)
}
340
341pub fn adaptive_gradient_weighted_tree() -> Tree {
344 vec![
345 PropertyDecisionNode {
347 property: Property::WpMaxError as i32,
348 splitval: 100, lchild: 1, rchild: 2, ..Default::default()
352 },
353 PropertyDecisionNode {
355 property: -1,
356 predictor: Predictor::Gradient,
357 context_id: 0,
358 ..Default::default()
359 },
360 PropertyDecisionNode {
362 property: -1,
363 predictor: Predictor::Weighted,
364 context_id: 1,
365 ..Default::default()
366 },
367 ]
368}
369
370pub fn validate_tree_djxl(tree: &Tree) -> Result<(), String> {
382 if tree.is_empty() {
383 return Ok(());
384 }
385
386 let mut num_properties = 0i32;
387 for node in tree {
388 if node.property >= num_properties {
389 num_properties = node.property + 1;
390 }
391 }
392 let np = num_properties as usize;
393
394 let mut ranges: Vec<(i32, i32)> = vec![(i32::MIN, i32::MAX); np * tree.len()];
398
399 for (i, node) in tree.iter().enumerate() {
400 if node.property < 0 {
401 continue; }
403 let p = node.property as usize;
404 let val = node.splitval;
405 let lo = ranges[i * np + p].0;
406 let hi = ranges[i * np + p].1;
407
408 if lo > val || hi <= val {
410 return Err(format!(
411 "Node {} (property={}, splitval={}): range [{}, {}] invalid \
412 (lo > val = {}, hi <= val = {})",
413 i,
414 node.property,
415 val,
416 lo,
417 hi,
418 lo > val,
419 hi <= val
420 ));
421 }
422
423 let lchild = node.lchild; let rchild = node.rchild; for pp in 0..np {
428 ranges[rchild * np + pp] = ranges[i * np + pp];
429 ranges[lchild * np + pp] = ranges[i * np + pp];
430 }
431
432 ranges[rchild * np + p] = (val + 1, hi);
435 ranges[lchild * np + p] = (lo, val);
437 }
438
439 Ok(())
440}
441
442pub fn count_contexts(tree: &Tree) -> u32 {
448 let mut count = 0u32;
449 let mut queue = std::collections::VecDeque::new();
450 queue.push_back(0usize);
451
452 while let Some(idx) = queue.pop_front() {
453 if tree[idx].property < 0 {
454 count += 1;
455 } else {
456 queue.push_back(tree[idx].rchild);
457 queue.push_back(tree[idx].lchild);
458 }
459 }
460 count.max(1)
461}
462
463pub fn assign_sequential_contexts(tree: &mut Tree) -> u32 {
472 let mut next_context = 0u32;
473 let mut queue = std::collections::VecDeque::new();
474 queue.push_back(0usize);
475
476 while let Some(idx) = queue.pop_front() {
477 if tree[idx].property < 0 {
478 tree[idx].context_id = next_context;
479 next_context += 1;
480 } else {
481 let rchild = tree[idx].rchild;
482 let lchild = tree[idx].lchild;
483 queue.push_back(rchild);
485 queue.push_back(lchild);
486 }
487 }
488 next_context
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_floor_log2() {
        assert_eq!(floor_log2(0), 0);
        assert_eq!(floor_log2(1), 0);
        assert_eq!(floor_log2(2), 1);
        assert_eq!(floor_log2(3), 1);
        assert_eq!(floor_log2(4), 2);
        assert_eq!(floor_log2(255), 7);
        assert_eq!(floor_log2(256), 8);
        assert_eq!(floor_log2(u32::MAX), 31);
    }

    #[test]
    fn test_simple_tree() {
        let tree = simple_tree(Predictor::Left);
        assert_eq!(tree.len(), 1);
        assert_eq!(tree[0].property, -1);
        assert_eq!(tree[0].predictor, Predictor::Left);
    }

    #[test]
    fn test_traverse_simple() {
        let tree = gradient_tree();
        let props = PixelProperties::default();
        let leaf = traverse_tree(&tree, &props);
        assert_eq!(leaf.predictor, Predictor::Gradient);
        assert_eq!(leaf.context_id, 0);
    }

    #[test]
    fn test_weighted_tree() {
        let tree = weighted_tree();
        assert_eq!(tree.len(), 1);
        assert_eq!(tree[0].predictor, Predictor::Weighted);
    }

    #[test]
    fn test_decompose_multiplier() {
        assert_eq!(decompose_multiplier(1), (0, 0));
        assert_eq!(decompose_multiplier(2), (1, 0));
        assert_eq!(decompose_multiplier(4), (2, 0));
        assert_eq!(decompose_multiplier(3), (0, 2));
        assert_eq!(decompose_multiplier(6), (1, 2));
    }

    #[test]
    fn test_decompose_multiplier_zero() {
        // Zero has no (bits + 1) << log form; it is written as (0, 0).
        assert_eq!(decompose_multiplier(0), (0, 0));
    }

    #[test]
    fn test_collect_tree_tokens_simple() {
        let tree = gradient_tree();
        let tokens = collect_tree_tokens(&tree);
        assert_eq!(tokens.len(), 5);
        assert_eq!(tokens[0].value, 0);
        assert_eq!(tokens[1].value, Predictor::Gradient as i32);
    }

    #[test]
    fn test_collect_tree_tokens_bfs_order() {
        let tokens = collect_tree_tokens(&adaptive_gradient_weighted_tree());
        // One split (2 tokens) and two leaves (5 tokens each).
        assert_eq!(tokens.len(), 12);
        assert_eq!(tokens[0].context, PROPERTY_CONTEXT);
        assert_eq!(tokens[0].value, Property::WpMaxError as i32 + 1);
        assert_eq!(tokens[1].context, SPLIT_VAL_CONTEXT);
        assert_eq!(tokens[1].value, 100);
        assert!(tokens[1].is_signed);
        // The `>` branch (rchild, weighted leaf) is serialized before the `<=` branch.
        assert_eq!(tokens[3].context, PREDICTOR_CONTEXT);
        assert_eq!(tokens[3].value, Predictor::Weighted as i32);
        assert_eq!(tokens[8].context, PREDICTOR_CONTEXT);
        assert_eq!(tokens[8].value, Predictor::Gradient as i32);
    }

    #[test]
    fn test_adaptive_tree() {
        let tree = adaptive_gradient_weighted_tree();
        assert_eq!(tree.len(), 3);

        let mut props = PixelProperties::default();
        props.values[Property::WpMaxError as usize] = 50;
        let leaf = traverse_tree(&tree, &props);
        assert_eq!(leaf.predictor, Predictor::Gradient);

        props.values[Property::WpMaxError as usize] = 150;
        let leaf = traverse_tree(&tree, &props);
        assert_eq!(leaf.predictor, Predictor::Weighted);
    }

    #[test]
    fn test_count_contexts() {
        let tree = gradient_tree();
        assert_eq!(count_contexts(&tree), 1);

        let tree = adaptive_gradient_weighted_tree();
        assert_eq!(count_contexts(&tree), 2);
    }

    #[test]
    fn test_assign_sequential_contexts_follows_token_order() {
        let mut tree = adaptive_gradient_weighted_tree();
        assert_eq!(assign_sequential_contexts(&mut tree), 2);
        // BFS visits the `>` branch (index 2) before the `<=` branch (index 1).
        assert_eq!(tree[2].context_id, 0);
        assert_eq!(tree[1].context_id, 1);
    }

    #[test]
    fn test_validate_tree_djxl() {
        assert!(validate_tree_djxl(&gradient_tree()).is_ok());
        assert!(validate_tree_djxl(&adaptive_gradient_weighted_tree()).is_ok());

        let split = |splitval, lchild, rchild| PropertyDecisionNode {
            property: Property::Channel as i32,
            splitval,
            lchild,
            rchild,
            ..Default::default()
        };
        let leaf = PropertyDecisionNode::default();

        // Inside `channel <= 5`, splitting at 10 leaves the `>` branch unreachable.
        let bad = vec![split(5, 1, 2), split(10, 3, 4), leaf.clone(), leaf.clone(), leaf.clone()];
        assert!(validate_tree_djxl(&bad).is_err());

        let good = vec![split(5, 1, 2), split(3, 3, 4), leaf.clone(), leaf.clone(), leaf];
        assert!(validate_tree_djxl(&good).is_ok());
    }

    #[test]
    fn test_pixel_properties() {
        let p = PixelProperties::compute(2, 5, 7, 9, 10, -4, 3, 12, 1, 0, 8);
        assert_eq!(p.get(Property::Channel as i32), 2);
        assert_eq!(p.get(Property::Y as i32), 9);
        assert_eq!(p.get(Property::X as i32), 7);
        assert_eq!(p.get(Property::AbsNMinusW as i32), 14);
        assert_eq!(p.get(Property::FloorLog2N as i32), 3);
        assert_eq!(p.get(Property::SumWNNw as i32), 17);
        assert_eq!(p.get(Property::WpMaxError as i32), 0);
        assert_eq!(p.get(-1), 0);
        assert_eq!(p.get(Property::NUM_PROPERTIES as i32), 0);
    }
}