use super::predictor::Predictor;
/// One node of a property decision tree stored in a flat [`Tree`] vector.
///
/// A node is either a decision node (`property >= 0`) that routes to one of
/// its two children, or a leaf (`property < 0`) that carries the prediction
/// parameters and the context id.
#[derive(Debug, Clone)]
pub struct PropertyDecisionNode {
/// Index of the property tested by this node; a negative value marks a leaf.
pub property: i32,
/// Threshold for the test: values `<= splitval` go to `lchild`, others to `rchild`.
pub splitval: i32,
/// Predictor carried by the node (emitted for leaves when the tree is tokenised).
pub predictor: Predictor,
/// Offset emitted alongside the predictor for leaves.
/// NOTE(review): presumably added to the prediction — confirm against the consumer.
pub predictor_offset: i32,
/// Multiplier emitted for leaves, split into a log part and a bits part when tokenised.
/// NOTE(review): presumably scales the residual — confirm against the consumer.
pub multiplier: i32,
/// Index (into the tree vector) of the child taken when the property value is `<= splitval`.
pub lchild: usize,
/// Index (into the tree vector) of the child taken when the property value is `> splitval`.
pub rchild: usize,
/// Context id associated with a leaf.
pub context_id: u32,
}
impl Default for PropertyDecisionNode {
fn default() -> Self {
Self {
property: -1, splitval: 0,
predictor: Predictor::Gradient,
predictor_offset: 0,
multiplier: 1,
lchild: 0,
rchild: 0,
context_id: 0,
}
}
}
/// A decision tree stored as a flat vector of nodes; child links are indices
/// into this vector and traversal starts at index 0.
pub type Tree = Vec<PropertyDecisionNode>;
/// Identifiers of the per-pixel properties a decision node can test.
///
/// The discriminant of each variant is used directly as the index into
/// `PixelProperties::values`, so the numeric values must stay stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum Property {
// Position-like properties, copied from the arguments of `PixelProperties::compute`.
Channel = 0,
GroupId = 1,
Y = 2,
X = 3,
// Absolute differences between two neighbouring samples.
AbsNMinusNw = 4,
AbsNMinusW = 5,
// Floor of log2 of the absolute value of one neighbouring sample (0 maps to 0).
FloorLog2W = 6,
FloorLog2N = 7,
FloorLog2Nw = 8,
// Further absolute differences between two neighbouring samples.
AbsNMinusNn = 9,
AbsWMinusWw = 10,
AbsNwMinusNww = 11,
AbsNeMinusN = 12,
AbsNwMinusW = 13,
// Sum of the absolute values of W, N and NW.
SumWNNw = 14,
// Left at 0 by `PixelProperties::compute`; set by other code when needed.
WpMaxError = 15,
}
impl Property {
// NOTE(review): not referenced in this part of the file; which properties
// count as "static" is not visible here — confirm against the users of this constant.
pub const NUM_STATIC: usize = 14;
/// Total number of properties (variants 0 through 15); size of `PixelProperties::values`.
pub const NUM_PROPERTIES: usize = 16;
}
/// Property values of a single pixel, indexed by `Property` discriminant.
///
/// The derived `Default` yields all zeros.
#[derive(Debug, Clone, Default)]
pub struct PixelProperties {
/// One value per property; slot `p as usize` holds the value of property `p`.
pub values: [i32; Property::NUM_PROPERTIES],
}
impl PixelProperties {
#[allow(clippy::too_many_arguments)]
pub fn compute(
channel_idx: u32,
group_id: u32,
x: usize,
y: usize,
n: i32,
w: i32,
nw: i32,
ne: i32,
nn: i32,
ww: i32,
nww: i32,
) -> Self {
let mut values = [0i32; Property::NUM_PROPERTIES];
values[Property::Channel as usize] = channel_idx as i32;
values[Property::GroupId as usize] = group_id as i32;
values[Property::Y as usize] = y as i32;
values[Property::X as usize] = x as i32;
values[Property::AbsNMinusNw as usize] = (n - nw).abs();
values[Property::AbsNMinusW as usize] = (n - w).abs();
values[Property::FloorLog2W as usize] = floor_log2(w.unsigned_abs());
values[Property::FloorLog2N as usize] = floor_log2(n.unsigned_abs());
values[Property::FloorLog2Nw as usize] = floor_log2(nw.unsigned_abs());
values[Property::AbsNMinusNn as usize] = (n - nn).abs();
values[Property::AbsWMinusWw as usize] = (w - ww).abs();
values[Property::AbsNwMinusNww as usize] = (nw - nww).abs();
values[Property::AbsNeMinusN as usize] = (ne - n).abs();
values[Property::AbsNwMinusW as usize] = (nw - w).abs();
values[Property::SumWNNw as usize] = w.abs() + n.abs() + nw.abs();
values[Property::WpMaxError as usize] = 0;
Self { values }
}
#[inline]
pub fn get(&self, property: i32) -> i32 {
if property >= 0 && (property as usize) < self.values.len() {
self.values[property as usize]
} else {
0
}
}
}
/// Returns `floor(log2(value))`, with the convention that 0 maps to 0.
#[inline]
fn floor_log2(value: u32) -> i32 {
    // `leading_zeros()` is 32 for zero, so the saturating subtraction clamps
    // the would-be -1 to 0; for every non-zero input it is plain `31 - lz`.
    31u32.saturating_sub(value.leading_zeros()) as i32
}
/// Builds a tree consisting of a single leaf that uses `predictor` and
/// context id 0; all other fields take their default values.
pub fn simple_tree(predictor: Predictor) -> Tree {
    let only_leaf = PropertyDecisionNode {
        property: -1,
        predictor,
        context_id: 0,
        ..Default::default()
    };
    vec![only_leaf]
}
/// Builds a single-leaf tree that uses the `Gradient` predictor everywhere.
pub fn gradient_tree() -> Tree {
simple_tree(Predictor::Gradient)
}
/// Builds a tree that gives every channel its own leaf (and context).
///
/// The layout is a chain of `num_channels - 1` decision nodes on
/// `Property::Channel` (indices `0..num_channels - 1`) followed by
/// `num_channels` `Gradient` leaves. Decision node `c` sends channel `c` to
/// its leaf and everything larger to the next decision node; the last
/// decision node sends the remaining channel to the final leaf. The leaf for
/// channel `c` has `context_id == c`.
///
/// With `num_channels == 1` the result is a single leaf; with
/// `num_channels == 0` the result is an empty tree.
#[allow(dead_code)]
pub fn per_channel_tree(num_channels: usize) -> Tree {
    // n channels need n - 1 splits; saturate so that 0 channels yields no nodes.
    let num_splits = num_channels.saturating_sub(1);
    let mut tree = Vec::with_capacity(num_splits + num_channels);

    for c in 0..num_splits {
        // Leaves start right after the decision nodes, one per channel.
        let leaf_for_c = num_splits + c;
        // Larger channels continue down the chain; after the last split only
        // the final channel is left, so go straight to its leaf.
        let above_c = if c + 1 == num_splits {
            leaf_for_c + 1
        } else {
            c + 1
        };
        tree.push(PropertyDecisionNode {
            property: Property::Channel as i32,
            splitval: c as i32,
            // Channels below `c` were peeled off earlier, so `<= c` means `== c`.
            lchild: leaf_for_c,
            rchild: above_c,
            ..Default::default()
        });
    }

    for c in 0..num_channels {
        tree.push(PropertyDecisionNode {
            property: -1,
            predictor: Predictor::Gradient,
            context_id: c as u32,
            ..Default::default()
        });
    }

    tree
}
/// Walks `tree` from the root (index 0) down to the leaf selected by
/// `properties` and returns that leaf.
///
/// At each decision node the tested property value is compared with
/// `splitval`: `<=` follows `lchild`, otherwise `rchild`.
///
/// # Panics
///
/// Panics if `tree` is empty or a child index lies outside the vector. A tree
/// whose child links form a cycle makes this loop forever.
pub fn traverse_tree<'a>(tree: &'a Tree, properties: &PixelProperties) -> &'a PropertyDecisionNode {
    let mut current = &tree[0];
    // A negative property id marks a leaf.
    while current.property >= 0 {
        let takes_left = properties.get(current.property) <= current.splitval;
        let next = if takes_left {
            current.lchild
        } else {
            current.rchild
        };
        current = &tree[next];
    }
    current
}
// Context ids attached to the tokens produced by `collect_tree_tokens`.
// Each kind of value written for a tree node gets its own context.
// Split threshold of a decision node.
const SPLIT_VAL_CONTEXT: usize = 0;
// Property id (+1) of a decision node, or 0 for a leaf.
const PROPERTY_CONTEXT: usize = 1;
// Predictor of a leaf.
const PREDICTOR_CONTEXT: usize = 2;
// Predictor offset of a leaf.
const OFFSET_CONTEXT: usize = 3;
// Power-of-two part (trailing zero count) of a leaf's multiplier.
const MULTIPLIER_LOG_CONTEXT: usize = 4;
// Remaining part of a leaf's multiplier after removing the power of two, minus one.
const MULTIPLIER_BITS_CONTEXT: usize = 5;
/// One value produced when a tree is flattened by `collect_tree_tokens`.
#[derive(Debug, Clone)]
pub struct TreeToken {
/// Context id of the token (one of the `*_CONTEXT` constants).
pub context: usize,
/// The raw value to be written.
pub value: i32,
/// True for values that may be negative (split values and predictor offsets).
/// NOTE(review): presumably tells the writer to sign-pack the value — confirm against the consumer.
pub is_signed: bool,
}
/// Flattens `tree` into the token sequence that describes it.
///
/// Nodes are visited breadth-first starting at index 0; for a decision node
/// `rchild` is queued before `lchild`.
///
/// * Decision node: property id + 1, then the signed split value.
/// * Leaf: property token 0, predictor, signed predictor offset, then the
///   multiplier split into its log and bits parts.
///
/// # Panics
///
/// Panics if `tree` is empty or a child index lies outside the vector.
pub fn collect_tree_tokens(tree: &Tree) -> Vec<TreeToken> {
    // Local constructors keep the emission code below to one line per token.
    fn unsigned(context: usize, value: i32) -> TreeToken {
        TreeToken {
            context,
            value,
            is_signed: false,
        }
    }
    fn signed(context: usize, value: i32) -> TreeToken {
        TreeToken {
            context,
            value,
            is_signed: true,
        }
    }

    let mut out = Vec::new();
    let mut pending = std::collections::VecDeque::new();
    pending.push_back(0usize);

    while let Some(idx) = pending.pop_front() {
        let node = &tree[idx];

        if node.property >= 0 {
            // Decision node: 0 is reserved for "leaf", hence the +1.
            out.push(unsigned(PROPERTY_CONTEXT, node.property + 1));
            out.push(signed(SPLIT_VAL_CONTEXT, node.splitval));
            pending.push_back(node.rchild);
            pending.push_back(node.lchild);
            continue;
        }

        // Leaf node.
        out.push(unsigned(PROPERTY_CONTEXT, 0));
        out.push(unsigned(PREDICTOR_CONTEXT, node.predictor as i32));
        out.push(signed(OFFSET_CONTEXT, node.predictor_offset));
        let (mul_log, mul_bits) = decompose_multiplier(node.multiplier as u32);
        out.push(unsigned(MULTIPLIER_LOG_CONTEXT, mul_log as i32));
        out.push(unsigned(MULTIPLIER_BITS_CONTEXT, mul_bits as i32));
    }

    out
}
/// Splits a multiplier into `(mul_log, mul_bits)` such that, for a non-zero
/// input, `multiplier == (mul_bits + 1) << mul_log` with `mul_bits + 1` odd.
///
/// A zero multiplier yields `(0, 0)`.
fn decompose_multiplier(multiplier: u32) -> (u32, u32) {
    match multiplier {
        0 => (0, 0),
        m => {
            // Number of factors of two; what remains after shifting is odd and >= 1.
            let shift = m.trailing_zeros();
            (shift, (m >> shift) - 1)
        }
    }
}
/// Builds a single-leaf tree that uses the `Weighted` predictor everywhere.
pub fn weighted_tree() -> Tree {
simple_tree(Predictor::Weighted)
}
/// Builds a three-node tree that switches predictor on `Property::WpMaxError`.
///
/// Values `<= 100` select the `Gradient` leaf (context 0); larger values
/// select the `Weighted` leaf (context 1).
pub fn adaptive_gradient_weighted_tree() -> Tree {
    let root = PropertyDecisionNode {
        property: Property::WpMaxError as i32,
        splitval: 100,
        lchild: 1,
        rchild: 2,
        ..Default::default()
    };
    let gradient_leaf = PropertyDecisionNode {
        property: -1,
        predictor: Predictor::Gradient,
        context_id: 0,
        ..Default::default()
    };
    let weighted_leaf = PropertyDecisionNode {
        property: -1,
        predictor: Predictor::Weighted,
        context_id: 1,
        ..Default::default()
    };
    // Order matters: the root refers to the leaves by index 1 and 2.
    vec![root, gradient_leaf, weighted_leaf]
}
/// Checks that every decision node splits inside the value range that can
/// still reach it.
///
/// Nodes are processed in index order. Each node carries, per property, the
/// inclusive range `(lo, hi)` of values that can arrive there; the root starts
/// with the full `i32` range. A decision node on property `p` with split value
/// `val` is rejected when `lo > val` or `hi <= val`, i.e. when one of its
/// branches could never be taken. Otherwise its ranges are copied to both
/// children, narrowed to `(lo, val)` for `lchild` and `(val + 1, hi)` for
/// `rchild`.
///
/// Because nodes are handled in index order, the ranges are only meaningful
/// for trees whose children have larger indices than their parent.
///
/// An empty tree is accepted.
///
/// # Errors
///
/// Returns a description of the first offending node, either because its
/// split is outside the reachable range or because one of its child indices
/// lies outside the tree.
pub fn validate_tree_djxl(tree: &Tree) -> Result<(), String> {
    if tree.is_empty() {
        return Ok(());
    }

    // One past the largest property id tested anywhere (0 if there are only leaves).
    let np = tree
        .iter()
        .map(|node| node.property)
        .filter(|&property| property >= 0)
        .max()
        .map_or(0, |property| property as usize + 1);

    // Row `i` holds the per-property ranges of node `i`.
    let mut ranges: Vec<(i32, i32)> = vec![(i32::MIN, i32::MAX); np * tree.len()];

    for (i, node) in tree.iter().enumerate() {
        if node.property < 0 {
            continue;
        }
        let p = node.property as usize;
        let val = node.splitval;
        let (lo, hi) = ranges[i * np + p];
        if lo > val || hi <= val {
            return Err(format!(
                "Node {} (property={}, splitval={}): range [{}, {}] invalid \
                 (lo > val = {}, hi <= val = {})",
                i,
                node.property,
                val,
                lo,
                hi,
                lo > val,
                hi <= val
            ));
        }

        let (lchild, rchild) = (node.lchild, node.rchild);
        // A validator must report malformed links instead of panicking on them.
        if lchild >= tree.len() || rchild >= tree.len() {
            return Err(format!(
                "Node {} (property={}): child index out of bounds \
                 (lchild={}, rchild={}, tree size={})",
                i,
                node.property,
                lchild,
                rchild,
                tree.len()
            ));
        }

        // Children inherit the parent's ranges for every property...
        let parent_row = i * np..(i + 1) * np;
        ranges.copy_within(parent_row.clone(), rchild * np);
        ranges.copy_within(parent_row, lchild * np);
        // ...and are narrowed on the property that was just tested.
        // `val + 1` cannot overflow: `hi > val` was established above.
        ranges[rchild * np + p] = (val + 1, hi);
        ranges[lchild * np + p] = (lo, val);
    }

    Ok(())
}
/// Counts the leaves reachable from the root of `tree`, i.e. the number of
/// contexts the tree needs.
///
/// The result is never less than 1; in particular an empty tree yields 1.
///
/// # Panics
///
/// Panics if a child index lies outside the vector.
pub fn count_contexts(tree: &Tree) -> u32 {
    let mut count = 0u32;
    let mut queue = std::collections::VecDeque::new();
    // Only start the walk when there is a root; an empty tree falls through
    // to the `max(1)` floor below instead of indexing out of bounds.
    if !tree.is_empty() {
        queue.push_back(0usize);
    }
    while let Some(idx) = queue.pop_front() {
        let node = &tree[idx];
        if node.property < 0 {
            count += 1;
        } else {
            queue.push_back(node.rchild);
            queue.push_back(node.lchild);
        }
    }
    count.max(1)
}
/// Renumbers the context ids of all reachable leaves as 0, 1, 2, ... and
/// returns how many ids were handed out.
///
/// Leaves are numbered in the order a breadth-first walk from index 0 meets
/// them, where each decision node queues `rchild` before `lchild`.
///
/// # Panics
///
/// Panics if `tree` is empty or a child index lies outside the vector.
pub fn assign_sequential_contexts(tree: &mut Tree) -> u32 {
    let mut assigned = 0u32;
    let mut pending = std::collections::VecDeque::new();
    pending.push_back(0usize);

    while let Some(idx) = pending.pop_front() {
        let node = &mut tree[idx];
        if node.property >= 0 {
            // Decision node: keep walking, right branch first.
            pending.push_back(node.rchild);
            pending.push_back(node.lchild);
            continue;
        }
        node.context_id = assigned;
        assigned += 1;
    }

    assigned
}
// Unit tests for the tree helpers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// floor_log2 maps 0 to 0 and otherwise returns the index of the highest set bit.
#[test]
fn test_floor_log2() {
assert_eq!(floor_log2(0), 0);
assert_eq!(floor_log2(1), 0);
assert_eq!(floor_log2(2), 1);
assert_eq!(floor_log2(3), 1);
assert_eq!(floor_log2(4), 2);
assert_eq!(floor_log2(255), 7);
assert_eq!(floor_log2(256), 8);
}
// A simple tree is a single leaf carrying the requested predictor.
#[test]
fn test_simple_tree() {
let tree = simple_tree(Predictor::Left);
assert_eq!(tree.len(), 1);
assert_eq!(tree[0].property, -1);
assert_eq!(tree[0].predictor, Predictor::Left);
}
// Traversing a single-leaf tree returns that leaf regardless of the properties.
#[test]
fn test_traverse_simple() {
let tree = gradient_tree();
let props = PixelProperties::default();
let leaf = traverse_tree(&tree, &props);
assert_eq!(leaf.predictor, Predictor::Gradient);
assert_eq!(leaf.context_id, 0);
}
// weighted_tree is a single leaf using the Weighted predictor.
#[test]
fn test_weighted_tree() {
let tree = weighted_tree();
assert_eq!(tree.len(), 1);
assert_eq!(tree[0].predictor, Predictor::Weighted);
}
// Multipliers split into (trailing zero count, odd remainder minus one).
#[test]
fn test_decompose_multiplier() {
assert_eq!(decompose_multiplier(1), (0, 0)); assert_eq!(decompose_multiplier(2), (1, 0)); assert_eq!(decompose_multiplier(4), (2, 0)); assert_eq!(decompose_multiplier(3), (0, 2)); assert_eq!(decompose_multiplier(6), (1, 2)); }
// A single leaf produces five tokens: property 0, predictor, offset, multiplier log, multiplier bits.
#[test]
fn test_collect_tree_tokens_simple() {
let tree = gradient_tree();
let tokens = collect_tree_tokens(&tree);
assert_eq!(tokens.len(), 5);
assert_eq!(tokens[0].value, 0); assert_eq!(tokens[1].value, Predictor::Gradient as i32);
}
// The adaptive tree picks Gradient at or below the threshold of 100 and Weighted above it.
#[test]
fn test_adaptive_tree() {
let tree = adaptive_gradient_weighted_tree();
assert_eq!(tree.len(), 3);
let mut props = PixelProperties::default();
props.values[Property::WpMaxError as usize] = 50;
let leaf = traverse_tree(&tree, &props);
assert_eq!(leaf.predictor, Predictor::Gradient);
props.values[Property::WpMaxError as usize] = 150;
let leaf = traverse_tree(&tree, &props);
assert_eq!(leaf.predictor, Predictor::Weighted);
}
// The context count equals the number of reachable leaves.
#[test]
fn test_count_contexts() {
let tree = gradient_tree();
assert_eq!(count_contexts(&tree), 1);
let tree = adaptive_gradient_weighted_tree();
assert_eq!(count_contexts(&tree), 2);
}
}