1use super::predictor::Predictor;
11
12#[derive(Debug, Clone)]
14pub struct PropertyDecisionNode {
15 pub property: i32,
17 pub splitval: i32,
19 pub predictor: Predictor,
21 pub predictor_offset: i32,
23 pub multiplier: i32,
25 pub lchild: usize,
27 pub rchild: usize,
29 pub context_id: u32,
31}
32
33impl Default for PropertyDecisionNode {
34 fn default() -> Self {
35 Self {
36 property: -1, splitval: 0,
38 predictor: Predictor::Gradient,
39 predictor_offset: 0,
40 multiplier: 1,
41 lchild: 0,
42 rchild: 0,
43 context_id: 0,
44 }
45 }
46}
47
48pub type Tree = Vec<PropertyDecisionNode>;
50
/// Identifiers of the per-pixel properties a decision node can test.
///
/// The discriminant of each variant is its slot in `PixelProperties::values`.
/// In the variant descriptions, `N`, `W`, `NW`, `NE`, `NN`, `WW` and `NWW`
/// are the neighbour samples passed to `PixelProperties::compute`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum Property {
    /// Channel index.
    Channel = 0,
    /// Group identifier.
    GroupId = 1,
    /// Row of the pixel.
    Y = 2,
    /// Column of the pixel.
    X = 3,
    /// `|N - NW|`.
    AbsNMinusNw = 4,
    /// `|N - W|`.
    AbsNMinusW = 5,
    /// `floor(log2(|W|))`, with 0 mapped to 0.
    FloorLog2W = 6,
    /// `floor(log2(|N|))`, with 0 mapped to 0.
    FloorLog2N = 7,
    /// `floor(log2(|NW|))`, with 0 mapped to 0.
    FloorLog2Nw = 8,
    /// `|N - NN|`.
    AbsNMinusNn = 9,
    /// `|W - WW|`.
    AbsWMinusWw = 10,
    /// `|NW - NWW|`.
    AbsNwMinusNww = 11,
    /// `|NE - N|`.
    AbsNeMinusN = 12,
    /// `|NW - W|`.
    AbsNwMinusW = 13,
    /// `|W| + |N| + |NW|`.
    SumWNNw = 14,
    /// Slot for the weighted predictor's maximum error; `PixelProperties::compute`
    /// leaves it at 0, so callers have to fill it in themselves.
    WpMaxError = 15,
}
88
89impl Property {
90 pub const NUM_STATIC: usize = 14;
92
93 pub const NUM_PROPERTIES: usize = 16;
95}
96
97#[derive(Debug, Clone, Default)]
99pub struct PixelProperties {
100 pub values: [i32; Property::NUM_PROPERTIES],
102}
103
104impl PixelProperties {
105 #[allow(clippy::too_many_arguments)]
107 pub fn compute(
108 channel_idx: u32,
109 group_id: u32,
110 x: usize,
111 y: usize,
112 n: i32,
113 w: i32,
114 nw: i32,
115 ne: i32,
116 nn: i32,
117 ww: i32,
118 nww: i32,
119 ) -> Self {
120 let mut values = [0i32; Property::NUM_PROPERTIES];
121
122 values[Property::Channel as usize] = channel_idx as i32;
123 values[Property::GroupId as usize] = group_id as i32;
124 values[Property::Y as usize] = y as i32;
125 values[Property::X as usize] = x as i32;
126 values[Property::AbsNMinusNw as usize] = (n - nw).abs();
127 values[Property::AbsNMinusW as usize] = (n - w).abs();
128 values[Property::FloorLog2W as usize] = floor_log2(w.unsigned_abs());
129 values[Property::FloorLog2N as usize] = floor_log2(n.unsigned_abs());
130 values[Property::FloorLog2Nw as usize] = floor_log2(nw.unsigned_abs());
131 values[Property::AbsNMinusNn as usize] = (n - nn).abs();
132 values[Property::AbsWMinusWw as usize] = (w - ww).abs();
133 values[Property::AbsNwMinusNww as usize] = (nw - nww).abs();
134 values[Property::AbsNeMinusN as usize] = (ne - n).abs();
135 values[Property::AbsNwMinusW as usize] = (nw - w).abs();
136 values[Property::SumWNNw as usize] = w.abs() + n.abs() + nw.abs();
137 values[Property::WpMaxError as usize] = 0; Self { values }
140 }
141
142 #[inline]
144 pub fn get(&self, property: i32) -> i32 {
145 if property >= 0 && (property as usize) < self.values.len() {
146 self.values[property as usize]
147 } else {
148 0
149 }
150 }
151}
152
/// Returns `floor(log2(value))`, i.e. the index of the highest set bit.
///
/// Zero has no logarithm; it is mapped to 0, the same result as for 1.
#[inline]
fn floor_log2(value: u32) -> i32 {
    match value {
        0 => 0,
        nonzero => 31 - nonzero.leading_zeros() as i32,
    }
}
162
163pub fn simple_tree(predictor: Predictor) -> Tree {
165 vec![PropertyDecisionNode {
166 property: -1, predictor,
168 context_id: 0,
169 ..Default::default()
170 }]
171}
172
173pub fn gradient_tree() -> Tree {
175 simple_tree(Predictor::Gradient)
176}
177
178#[allow(dead_code)]
180pub fn per_channel_tree(num_channels: usize) -> Tree {
181 let mut tree = Vec::with_capacity(num_channels * 2);
182
183 for c in 0..num_channels {
185 if c < num_channels - 1 {
186 tree.push(PropertyDecisionNode {
188 property: Property::Channel as i32,
189 splitval: c as i32,
190 lchild: tree.len() + num_channels - c, rchild: tree.len() + 1, ..Default::default()
193 });
194 }
195 }
196
197 for c in 0..num_channels {
199 tree.push(PropertyDecisionNode {
200 property: -1,
201 predictor: Predictor::Gradient,
202 context_id: c as u32,
203 ..Default::default()
204 });
205 }
206
207 tree
208}
209
210pub fn traverse_tree<'a>(tree: &'a Tree, properties: &PixelProperties) -> &'a PropertyDecisionNode {
212 let mut node_idx = 0;
213
214 loop {
215 let node = &tree[node_idx];
216
217 if node.property < 0 {
219 return node;
220 }
221
222 let prop_value = properties.get(node.property);
224 if prop_value <= node.splitval {
225 node_idx = node.lchild;
226 } else {
227 node_idx = node.rchild;
228 }
229 }
230}
231
// Context indices attached to the tokens produced by `collect_tree_tokens`.

/// Context of the property token emitted first for every node
/// (0 for a leaf, `property + 1` for a decision node).
const PROPERTY_CONTEXT: usize = 1;
/// Context of a decision node's split value.
const SPLIT_VAL_CONTEXT: usize = 0;
/// Context of a leaf's predictor.
const PREDICTOR_CONTEXT: usize = 2;
/// Context of a leaf's predictor offset.
const OFFSET_CONTEXT: usize = 3;
/// Context of the trailing-zero count of a leaf's multiplier.
const MULTIPLIER_LOG_CONTEXT: usize = 4;
/// Context of the remaining bits of a leaf's multiplier.
const MULTIPLIER_BITS_CONTEXT: usize = 5;
239
/// One symbol of a tokenized tree, as produced by `collect_tree_tokens`.
#[derive(Debug, Clone)]
pub struct TreeToken {
    /// Index of the context the value belongs to.
    pub context: usize,
    /// The raw value of the token.
    pub value: i32,
    /// `true` for values that may be negative (split values and predictor
    /// offsets); presumably tells the consumer to apply a signed mapping -
    /// confirm against the code that writes the tokens.
    pub is_signed: bool,
}
250
251pub fn collect_tree_tokens(tree: &Tree) -> Vec<TreeToken> {
253 let mut tokens = Vec::new();
254
255 let mut queue = std::collections::VecDeque::new();
257 queue.push_back(0usize);
258
259 while let Some(idx) = queue.pop_front() {
260 let node = &tree[idx];
261
262 if node.property < 0 {
263 tokens.push(TreeToken {
265 context: PROPERTY_CONTEXT,
266 value: 0, is_signed: false,
268 });
269
270 tokens.push(TreeToken {
272 context: PREDICTOR_CONTEXT,
273 value: node.predictor as i32,
274 is_signed: false,
275 });
276
277 tokens.push(TreeToken {
279 context: OFFSET_CONTEXT,
280 value: node.predictor_offset,
281 is_signed: true,
282 });
283
284 let (mul_log, mul_bits) = decompose_multiplier(node.multiplier as u32);
287 tokens.push(TreeToken {
288 context: MULTIPLIER_LOG_CONTEXT,
289 value: mul_log as i32,
290 is_signed: false,
291 });
292
293 tokens.push(TreeToken {
294 context: MULTIPLIER_BITS_CONTEXT,
295 value: mul_bits as i32,
296 is_signed: false,
297 });
298 } else {
299 tokens.push(TreeToken {
301 context: PROPERTY_CONTEXT,
302 value: node.property + 1, is_signed: false,
304 });
305
306 tokens.push(TreeToken {
307 context: SPLIT_VAL_CONTEXT,
308 value: node.splitval,
309 is_signed: true,
310 });
311
312 queue.push_back(node.rchild);
316 queue.push_back(node.lchild);
317 }
318 }
319
320 tokens
321}
322
/// Splits `multiplier` into `(mul_log, mul_bits)` such that
/// `multiplier == (mul_bits + 1) << mul_log`.
///
/// `mul_log` is the number of trailing zero bits and `mul_bits` is the
/// remaining odd factor minus one. A multiplier of 0 yields `(0, 0)`.
fn decompose_multiplier(multiplier: u32) -> (u32, u32) {
    match multiplier {
        0 => (0, 0),
        m => {
            let shift = m.trailing_zeros();
            (shift, (m >> shift) - 1)
        }
    }
}
335
336pub fn weighted_tree() -> Tree {
338 simple_tree(Predictor::Weighted)
339}
340
341pub fn adaptive_gradient_weighted_tree() -> Tree {
344 vec![
345 PropertyDecisionNode {
347 property: Property::WpMaxError as i32,
348 splitval: 100, lchild: 1, rchild: 2, ..Default::default()
352 },
353 PropertyDecisionNode {
355 property: -1,
356 predictor: Predictor::Gradient,
357 context_id: 0,
358 ..Default::default()
359 },
360 PropertyDecisionNode {
362 property: -1,
363 predictor: Predictor::Weighted,
364 context_id: 1,
365 ..Default::default()
366 },
367 ]
368}
369
370pub fn count_contexts(tree: &Tree) -> u32 {
372 tree.iter()
373 .filter(|n| n.property < 0)
374 .map(|n| n.context_id)
375 .max()
376 .map(|m| m + 1)
377 .unwrap_or(1)
378}
379
380pub fn assign_sequential_contexts(tree: &mut Tree) {
387 let mut next_context = 0u32;
388 let mut queue = std::collections::VecDeque::new();
389 queue.push_back(0usize);
390
391 while let Some(idx) = queue.pop_front() {
392 if tree[idx].property < 0 {
393 tree[idx].context_id = next_context;
394 next_context += 1;
395 } else {
396 let rchild = tree[idx].rchild;
397 let lchild = tree[idx].lchild;
398 queue.push_back(rchild);
400 queue.push_back(lchild);
401 }
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// `floor_log2` on zero, powers of two and values just below them.
    #[test]
    fn test_floor_log2() {
        let cases: [(u32, i32); 7] = [(0, 0), (1, 0), (2, 1), (3, 1), (4, 2), (255, 7), (256, 8)];
        for &(input, expected) in cases.iter() {
            assert_eq!(floor_log2(input), expected);
        }
    }

    /// A simple tree is a single leaf carrying the requested predictor.
    #[test]
    fn test_simple_tree() {
        let tree = simple_tree(Predictor::Left);
        assert_eq!(tree.len(), 1);

        let root = &tree[0];
        assert_eq!(root.property, -1);
        assert_eq!(root.predictor, Predictor::Left);
    }

    /// Traversing a single-leaf tree returns that leaf for any properties.
    #[test]
    fn test_traverse_simple() {
        let tree = gradient_tree();
        let props = PixelProperties::default();

        let leaf = traverse_tree(&tree, &props);
        assert_eq!(leaf.predictor, Predictor::Gradient);
        assert_eq!(leaf.context_id, 0);
    }

    /// The weighted tree is a single leaf with the weighted predictor.
    #[test]
    fn test_weighted_tree() {
        let tree = weighted_tree();
        assert_eq!(tree.len(), 1);
        assert_eq!(tree[0].predictor, Predictor::Weighted);
    }

    /// Multipliers split into (trailing zeros, odd part minus one).
    #[test]
    fn test_decompose_multiplier() {
        let cases: [(u32, (u32, u32)); 5] = [
            (1, (0, 0)),
            (2, (1, 0)),
            (4, (2, 0)),
            (3, (0, 2)),
            (6, (1, 2)),
        ];
        for &(multiplier, expected) in cases.iter() {
            assert_eq!(decompose_multiplier(multiplier), expected);
        }
    }

    /// A single leaf tokenizes to five tokens: leaf marker, predictor, ...
    #[test]
    fn test_collect_tree_tokens_simple() {
        let tree = gradient_tree();
        let tokens = collect_tree_tokens(&tree);

        assert_eq!(tokens.len(), 5);
        assert_eq!(tokens[0].value, 0);
        assert_eq!(tokens[1].value, Predictor::Gradient as i32);
    }

    /// The adaptive tree switches predictor around its split value of 100.
    #[test]
    fn test_adaptive_tree() {
        let tree = adaptive_gradient_weighted_tree();
        assert_eq!(tree.len(), 3);

        let mut props = PixelProperties::default();

        // At or below the split value: gradient leaf.
        props.values[Property::WpMaxError as usize] = 50;
        let leaf = traverse_tree(&tree, &props);
        assert_eq!(leaf.predictor, Predictor::Gradient);

        // Above the split value: weighted leaf.
        props.values[Property::WpMaxError as usize] = 150;
        let leaf = traverse_tree(&tree, &props);
        assert_eq!(leaf.predictor, Predictor::Weighted);
    }

    /// Context count is the largest leaf context id plus one.
    #[test]
    fn test_count_contexts() {
        let single_leaf = gradient_tree();
        assert_eq!(count_contexts(&single_leaf), 1);

        let two_leaves = adaptive_gradient_weighted_tree();
        assert_eq!(count_contexts(&two_leaves), 2);
    }
}