use core::cmp::Ordering;
use super::channel::{Channel, ModularImage};
use super::predictor::{
Neighbors, Predictor, WeightedPredictorParams, WeightedPredictorState, pack_signed,
};
use super::tree::{PropertyDecisionNode, Tree, assign_sequential_contexts, validate_tree_djxl};
use crate::entropy_coding::hybrid_uint::HybridUintConfig;
/// Hybrid-uint configuration used while tokenizing residuals for tree
/// learning: split exponent 4 (values below 16 coded directly), with 1 MSB
/// and 2 LSBs folded into the token.
const GATHER_HYBRID_UINT: HybridUintConfig = HybridUintConfig {
    split_exponent: 4,
    split: 16,
    msb_in_token: 1,
    lsb_in_token: 2,
};
/// Number of built-in per-sample properties; reference-channel properties
/// (4 per reference channel) are appended after these.
const NUM_PROPERTIES: usize = 16;
/// Predictors evaluated at every node when learning a tree for regular
/// (non-squeeze) channels.
const CANDIDATE_PREDICTORS: &[Predictor] = &[
    Predictor::Zero,
    Predictor::Left,
    Predictor::Top,
    Predictor::Average0,
    Predictor::Select,
    Predictor::Gradient,
    Predictor::Weighted,
    Predictor::TopRight,
    Predictor::TopLeft,
    Predictor::LeftLeft,
    Predictor::Average1,
    Predictor::Average2,
    Predictor::Average3,
    Predictor::Average4,
];
// Property evaluation orders (indices into the property vector). When the
// effort profile limits the number of properties, the first entries win.
/// Default order, including the group-id property (1).
const PROP_ORDER_NO_SQUEEZE: &[usize] = &[
    0, 1, 15, 9, 10, 11, 12, 13, 14, 2, 3, 4, 5, 6, 7, 8,
];
/// Same order without the group-id property (used below effort 9).
const PROP_ORDER_NO_SQUEEZE_NO_GID: &[usize] = &[
    0, 15, 9, 10, 11, 12, 13, 14, 2, 3, 4, 5, 6, 7, 8,
];
/// Order used for squeeze-transformed channels.
const PROP_ORDER_SQUEEZE: &[usize] = &[
    0, 1, 4, 5, 6, 7, 8, 15, 9, 10, 11, 12, 13, 14, 2, 3,
];
/// Squeeze channels only try the Zero predictor.
const CANDIDATE_PREDICTORS_SQUEEZE: &[Predictor] = &[Predictor::Zero];
/// Tuning knobs for the greedy MA-tree learner.
pub struct TreeLearningParams {
    /// Property indices to consider for splits, in evaluation order.
    pub properties: Vec<usize>,
    /// Maximum number of quantization buckets per property.
    pub max_property_values: usize,
    /// Minimum estimated entropy saving (bits) required to accept a split.
    pub split_threshold: f64,
    /// Upper bound on the number of tree nodes.
    pub max_nodes: usize,
    /// Fraction of pixels that were sampled; scales the split threshold.
    pub pixel_fraction: f64,
}
impl TreeLearningParams {
    /// Learning parameters for regular (non-squeeze) channels, derived from
    /// the encoder effort profile.
    pub fn from_profile(profile: &crate::effort::EffortProfile) -> Self {
        Self::from_profile_impl(profile, false)
    }
    /// Learning parameters for squeeze-transformed channels.
    pub fn from_profile_squeeze(profile: &crate::effort::EffortProfile) -> Self {
        Self::from_profile_impl(profile, true)
    }
    fn from_profile_impl(profile: &crate::effort::EffortProfile, is_squeeze: bool) -> Self {
        // Pick the property evaluation order; the group-id property is only
        // considered at the highest efforts.
        let order = if is_squeeze {
            PROP_ORDER_SQUEEZE
        } else if profile.effort >= 9 {
            PROP_ORDER_NO_SQUEEZE
        } else {
            PROP_ORDER_NO_SQUEEZE_NO_GID
        };
        let num_props = (profile.tree_num_properties as usize).min(order.len());
        Self {
            properties: order[..num_props].to_vec(),
            max_property_values: profile.tree_max_buckets as usize,
            split_threshold: profile.tree_threshold_base as f64,
            max_nodes: 1 << 22,
            pixel_fraction: 1.0,
        }
    }
    /// Test-only variant deriving parameters directly from an effort level
    /// rather than a full profile.
    #[cfg(test)]
    pub fn for_effort(effort: u8) -> Self {
        let order = if effort >= 9 {
            PROP_ORDER_NO_SQUEEZE
        } else {
            PROP_ORDER_NO_SQUEEZE_NO_GID
        };
        let speed_tier = 10u8.saturating_sub(effort);
        // Higher effort: more properties and finer property quantization.
        let (num_props, max_property_values) = match effort {
            0..=4 => (3, 32),
            5 => (4, 48),
            6 => (5, 64),
            7 => (7, 96),
            8 => (10, 128),
            _ => (order.len(), 256),
        };
        // Faster tiers require a larger entropy gain before splitting.
        let threshold_base = 75.0 + 14.0 * speed_tier as f64;
        let num_props = num_props.min(order.len());
        Self {
            properties: order[..num_props].to_vec(),
            max_property_values,
            split_threshold: threshold_base,
            max_nodes: 1 << 22,
            pixel_fraction: 1.0,
        }
    }
    /// Records the fraction of pixels actually sampled (clamped to [0, 1]).
    #[must_use]
    pub fn with_pixel_fraction(mut self, fraction: f64) -> Self {
        self.pixel_fraction = fraction.clamp(0.0, 1.0);
        self
    }
    /// Caps `max_nodes` by a size-dependent limit (1024 + pixels, at most
    /// 2^20).
    #[must_use]
    pub fn with_total_pixels(mut self, total_pixels: usize) -> Self {
        let decoder_limit = (1024 + total_pixels).min(1 << 20);
        self.max_nodes = self.max_nodes.min(decoder_limit);
        self
    }
    /// Appends reference-channel properties. At effort >= 9 all four
    /// properties per reference channel are tried; below that only the
    /// signed prediction-error property (offset +3) is used.
    #[must_use]
    pub fn with_ref_properties(mut self, num_ref_channels: usize, effort: u8) -> Self {
        if num_ref_channels == 0 {
            return self;
        }
        if effort >= 9 {
            for i in 0..num_ref_channels * 4 {
                self.properties.push(NUM_PROPERTIES + i);
            }
        } else {
            for i in 0..num_ref_channels {
                self.properties.push(NUM_PROPERTIES + i * 4 + 3);
            }
        }
        self
    }
}
/// Columnar storage of training samples gathered from image channels.
pub struct TreeSamples {
    /// Number of samples stored (after `dedup_samples`: unique rows).
    pub num_samples: usize,
    // Predictors whose residuals are recorded for every sample.
    candidate_predictors: &'static [Predictor],
    // residual_tokens[pred][sample]: hybrid-uint token (stored as u8).
    residual_tokens: Vec<Vec<u8>>,
    // extra_bits[pred][sample]: number of extra bits for that token.
    extra_bits: Vec<Vec<u8>>,
    // props[prop][sample]: raw property values.
    props: Vec<Vec<i32>>,
    // Multiplicity of each unique sample; filled by dedup_samples.
    sample_counts: Vec<u32>,
    // Number of reference channels whose properties are recorded per sample.
    num_ref_channels: usize,
}
impl Default for TreeSamples {
    /// Equivalent to [`TreeSamples::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl TreeSamples {
    /// Samples with the full predictor set and no reference channels.
    pub fn new() -> Self {
        Self::with_predictors_and_refs(CANDIDATE_PREDICTORS, 0)
    }
    /// Samples with the full predictor set plus `num_ref_channels` sets of
    /// reference-channel properties.
    pub fn new_with_ref_channels(num_ref_channels: usize) -> Self {
        Self::with_predictors_and_refs(CANDIDATE_PREDICTORS, num_ref_channels)
    }
    /// Samples for squeeze channels (Zero predictor only).
    pub fn new_for_squeeze() -> Self {
        Self::with_predictors_and_refs(CANDIDATE_PREDICTORS_SQUEEZE, 0)
    }
    fn with_predictors_and_refs(predictors: &'static [Predictor], num_ref_channels: usize) -> Self {
        let num_predictors = predictors.len();
        // 16 spec properties plus 4 per reference channel.
        let total_props = NUM_PROPERTIES + 4 * num_ref_channels;
        Self {
            num_samples: 0,
            candidate_predictors: predictors,
            residual_tokens: vec![Vec::new(); num_predictors],
            extra_bits: vec![Vec::new(); num_predictors],
            props: vec![Vec::new(); total_props],
            sample_counts: Vec::new(),
            num_ref_channels,
        }
    }
    /// Total number of properties recorded per sample.
    pub fn total_num_properties(&self) -> usize {
        NUM_PROPERTIES + 4 * self.num_ref_channels
    }
    /// Number of candidate predictors being evaluated.
    pub fn num_predictors(&self) -> usize {
        self.candidate_predictors.len()
    }
    /// Quantizes each requested property into at most
    /// `params.max_property_values` threshold buckets and records every
    /// sample's bucket index.
    ///
    /// Properties not listed in `params.properties`, or that are constant
    /// over the samples, get an empty threshold set (single bucket).
    fn pre_quantize(&self, params: &TreeLearningParams) -> PreQuantizedProps {
        let max_buckets = params.max_property_values;
        let n = self.num_samples;
        let total_props = self.total_num_properties();
        let mut threshold_sets = vec![Vec::new(); total_props];
        let mut bucket_indices = vec![Vec::new(); total_props];
        for &prop_idx in &params.properties {
            let props = &self.props[prop_idx];
            let mut min_val = i32::MAX;
            let mut max_val = i32::MIN;
            for &v in &props[..n] {
                if v < min_val {
                    min_val = v;
                }
                if v > max_val {
                    max_val = v;
                }
            }
            // Constant property: everything lands in bucket 0.
            if min_val == max_val {
                bucket_indices[prop_idx] = vec![0u8; n];
                continue;
            }
            let range = max_val as i64 - min_val as i64 + 1;
            let ts: Vec<i32>;
            if range <= (max_buckets * 4) as i64 {
                // Small value range: enumerate distinct values via a
                // presence bitmap instead of sorting.
                let range_usize = range as usize;
                let mut present = vec![false; range_usize];
                for i in 0..n {
                    present[(props[i] - min_val) as usize] = true;
                }
                let mut unique_vals: Vec<i32> = present
                    .iter()
                    .enumerate()
                    .filter(|(_, p)| **p)
                    .map(|(i, _)| min_val + i as i32)
                    .collect();
                if unique_vals.len() <= 1 {
                    bucket_indices[prop_idx] = vec![0u8; n];
                    continue;
                }
                // The largest value needs no threshold above it.
                unique_vals.pop();
                ts = if unique_vals.len() <= max_buckets {
                    unique_vals
                } else {
                    // Thin the thresholds evenly down to max_buckets.
                    let step = unique_vals.len().div_ceil(max_buckets);
                    unique_vals
                        .iter()
                        .step_by(step.max(1))
                        .take(max_buckets)
                        .copied()
                        .collect()
                };
            } else {
                // Large value range: sort-and-dedup the observed values.
                let mut sample_vals: Vec<i32> = props[..n].to_vec();
                sample_vals.sort_unstable();
                sample_vals.dedup();
                if sample_vals.len() <= 1 {
                    bucket_indices[prop_idx] = vec![0u8; n];
                    continue;
                }
                sample_vals.pop();
                ts = if sample_vals.len() <= max_buckets {
                    sample_vals
                } else {
                    let step = sample_vals.len() / max_buckets;
                    sample_vals
                        .iter()
                        .step_by(step.max(1))
                        .take(max_buckets)
                        .copied()
                        .collect()
                };
            }
            let num_thresholds = ts.len();
            // Bucket of v = number of thresholds strictly below v (an exact
            // threshold hit keeps that threshold's index).
            let mut bi = vec![0u8; n];
            for (bi_val, &v) in bi.iter_mut().zip(props[..n].iter()) {
                let bucket = match ts.binary_search(&v) {
                    Ok(pos) => pos,
                    Err(pos) => {
                        if pos == 0 {
                            0
                        } else {
                            pos
                        }
                    }
                };
                *bi_val = bucket.min(num_thresholds) as u8;
            }
            threshold_sets[prop_idx] = ts;
            bucket_indices[prop_idx] = bi;
        }
        PreQuantizedProps {
            threshold_sets,
            bucket_indices,
        }
    }
}
/// Returns indices of earlier channels with the same dimensions and shifts
/// as `channel_idx`, nearest first. Channel 0 has no reference channels.
fn find_ref_channels(image: &ModularImage, channel_idx: usize) -> Vec<usize> {
    if channel_idx == 0 {
        return Vec::new();
    }
    let target = &image.channels[channel_idx];
    (0..channel_idx)
        .rev()
        .filter(|&j| {
            let c = &image.channels[j];
            c.width() == target.width()
                && c.height() == target.height()
                && c.hshift == target.hshift
                && c.vshift == target.vshift
        })
        .collect()
}
/// Largest number of reference channels any single channel of `image` has.
pub fn max_ref_channels(image: &ModularImage) -> usize {
    (0..image.channels.len())
        .map(|idx| find_ref_channels(image, idx).len())
        .max()
        .unwrap_or(0)
}
/// Builds the 16 built-in property values for one pixel from its neighbor
/// context, using wrapping arithmetic throughout.
#[inline]
fn compute_spec_properties(
    channel_idx: u32,
    group_id: u32,
    x: usize,
    y: usize,
    n: &Neighbors,
    prev_gradient: i32,
    wp_max_error: i32,
) -> [i32; NUM_PROPERTIES] {
    let gradient = n.w.wrapping_add(n.n).wrapping_sub(n.nw);
    [
        channel_idx as i32,              // 0: channel index
        group_id as i32,                 // 1: group/stream id
        y as i32,                        // 2
        x as i32,                        // 3
        n.n.wrapping_abs(),              // 4: |N|
        n.w.wrapping_abs(),              // 5: |W|
        n.n,                             // 6
        n.w,                             // 7
        n.w.wrapping_sub(prev_gradient), // 8: W minus previous gradient
        gradient,                        // 9: W + N - NW
        n.w.wrapping_sub(n.nw),          // 10
        n.nw.wrapping_sub(n.n),          // 11
        n.n.wrapping_sub(n.ne),          // 12
        n.n.wrapping_sub(n.nn),          // 13
        n.w.wrapping_sub(n.ww),          // 14
        wp_max_error,                    // 15: weighted-predictor max error
    ]
}
/// Test helper: gathers samples from every channel at full density
/// (stride 1) with default weighted-predictor parameters.
#[cfg(test)]
pub fn gather_samples(samples: &mut TreeSamples, image: &ModularImage, group_id: u32) {
    gather_samples_strided(
        samples,
        image,
        group_id,
        0,
        1,
        &WeightedPredictorParams::default(),
    );
}
/// Gathers training samples from every channel of `image`, visiting one
/// pixel in `stride`. `channel_offset` is added to each channel index for
/// the channel property; reference-channel properties are collected only
/// when `samples` was built with reference channels enabled.
pub fn gather_samples_strided(
    samples: &mut TreeSamples,
    image: &ModularImage,
    group_id: u32,
    channel_offset: u32,
    stride: usize,
    wp_params: &WeightedPredictorParams,
) {
    for (ch_idx, channel) in image.channels.iter().enumerate() {
        let refs = if samples.num_ref_channels == 0 {
            Vec::new()
        } else {
            find_ref_channels(image, ch_idx)
        };
        gather_channel_samples(
            samples,
            channel,
            channel_offset + ch_idx as u32,
            group_id,
            stride,
            wp_params,
            image,
            &refs,
        );
    }
}
/// Sample budget for tree learning: a fraction of the pixel count (at least
/// 65536) when the profile specifies one, otherwise the profile's fixed
/// budget, otherwise 32768.
pub fn max_tree_samples_from_profile(
    profile: &crate::effort::EffortProfile,
    total_pixels: usize,
) -> usize {
    let fraction = profile.tree_sample_fraction;
    if fraction > 0.0 {
        let scaled = (total_pixels as f32 * fraction) as usize;
        scaled.max(65_536)
    } else if profile.tree_max_samples_fixed > 0 {
        profile.tree_max_samples_fixed as usize
    } else {
        32_768
    }
}
/// Pixel stride for sample gathering so that roughly
/// `max_tree_samples_from_profile` pixels are sampled in total.
pub fn compute_gather_stride_from_profile(
    total_pixels: usize,
    profile: &crate::effort::EffortProfile,
) -> usize {
    let budget = max_tree_samples_from_profile(profile, total_pixels);
    if total_pixels <= budget {
        1
    } else {
        total_pixels.div_ceil(budget)
    }
}
/// Walks one channel in scan order, recording every `stride`-th pixel:
/// residual tokens for each candidate predictor plus the per-pixel property
/// vector (and reference-channel properties when enabled).
///
/// The weighted-predictor state and the gradient chain are updated for every
/// pixel, not only the sampled ones, so they match a full encoding pass.
#[allow(clippy::too_many_arguments)]
fn gather_channel_samples(
    samples: &mut TreeSamples,
    channel: &Channel,
    channel_idx: u32,
    group_id: u32,
    stride: usize,
    wp_params: &WeightedPredictorParams,
    image: &ModularImage,
    ref_channel_indices: &[usize],
) {
    let width = channel.width();
    let height = channel.height();
    if width == 0 || height == 0 {
        return;
    }
    let mut wp_state = WeightedPredictorState::new(wp_params, width);
    let mut prev_gradient: i32;
    let mut subsample_counter: usize = 0;
    let max_refs = samples.num_ref_channels;
    for y in 0..height {
        prev_gradient = 0;
        for x in 0..width {
            let pixel = channel.get(x, y);
            let n = Neighbors::gather(channel, x, y);
            let (wp_pred, wp_max_error) = wp_state.predict_and_property(x, y, width, &n);
            wp_state.update_errors(pixel, x, y, width);
            if subsample_counter == 0 {
                let props = compute_spec_properties(
                    channel_idx,
                    group_id,
                    x,
                    y,
                    &n,
                    prev_gradient,
                    wp_max_error,
                );
                // Property 9 is this pixel's gradient; it feeds property 8
                // of the next pixel in the row.
                prev_gradient = props[9];
                for (pred_idx, &predictor) in samples.candidate_predictors.iter().enumerate() {
                    let prediction = if predictor == Predictor::Weighted {
                        wp_pred as i32
                    } else {
                        predictor.predict_from_neighbors(&n)
                    };
                    let residual = pixel - prediction;
                    let packed = pack_signed(residual);
                    // Only the token and extra-bit count matter for entropy
                    // estimation; the extra-bit payload is discarded.
                    let (token, _extra_bits, num_extra) = GATHER_HYBRID_UINT.encode(packed);
                    samples.residual_tokens[pred_idx].push(token as u8);
                    samples.extra_bits[pred_idx].push(num_extra as u8);
                }
                for (prop_list, &val) in samples
                    .props
                    .iter_mut()
                    .zip(props.iter())
                    .take(NUM_PROPERTIES)
                {
                    prop_list.push(val);
                }
                if max_refs > 0 {
                    for (r, &ref_ch_idx) in ref_channel_indices.iter().enumerate() {
                        let ref_ch = &image.channels[ref_ch_idx];
                        let v = ref_ch.get(x, y);
                        // Out-of-bounds neighbors fall back to the left
                        // sample (or 0 at the start of a row).
                        let ref_left = if x > 0 { ref_ch.get(x - 1, y) } else { 0 };
                        let ref_top = if y > 0 {
                            ref_ch.get(x, y - 1)
                        } else {
                            ref_left
                        };
                        let ref_topleft = if x > 0 && y > 0 {
                            ref_ch.get(x - 1, y - 1)
                        } else {
                            ref_left
                        };
                        let ref_predicted = crate::vardct::dc_coding::clamped_gradient(
                            ref_top,
                            ref_left,
                            ref_topleft,
                        );
                        let base = NUM_PROPERTIES + r * 4;
                        samples.props[base].push(v.wrapping_abs());
                        samples.props[base + 1].push(v);
                        samples.props[base + 2].push(v.wrapping_sub(ref_predicted).wrapping_abs());
                        samples.props[base + 3].push(v.wrapping_sub(ref_predicted));
                    }
                    // Channels with fewer references than the maximum get
                    // zero-filled columns so all property rows stay aligned.
                    for r in ref_channel_indices.len()..max_refs {
                        let base = NUM_PROPERTIES + r * 4;
                        samples.props[base].push(0);
                        samples.props[base + 1].push(0);
                        samples.props[base + 2].push(0);
                        samples.props[base + 3].push(0);
                    }
                }
                samples.num_samples += 1;
                subsample_counter = stride - 1;
            } else {
                // Skipped pixel: still advance the gradient chain.
                let grad = n.w.wrapping_add(n.n).wrapping_sub(n.nw);
                prev_gradient = grad;
                subsample_counter -= 1;
            }
        }
    }
}
/// Shannon-style cost estimate, in bits, of coding `total` symbols with the
/// given token histogram. Probabilities are floored at 1/4096; an empty
/// total costs zero bits.
#[inline]
pub fn estimate_bits(counts: &[u32], total: u32) -> f64 {
    if total == 0 {
        return 0.0;
    }
    let total_f = total as f64;
    let min_prob = 1.0 / 4096.0;
    counts.iter().fold(0.0_f64, |bits, &c| {
        if c == 0 {
            bits
        } else {
            let p = (c as f64 / total_f).max(min_prob);
            bits - c as f64 * jxl_simd::fast_log2f(p as f32) as f64
        }
    })
}
/// Per-property quantization produced by `TreeSamples::pre_quantize`.
struct PreQuantizedProps {
    // threshold_sets[prop]: sorted candidate split values (empty if the
    // property is unused or constant).
    threshold_sets: Vec<Vec<i32>>,
    // bucket_indices[prop][sample]: bucket index of the sample's value.
    bucket_indices: Vec<Vec<u8>>,
}
impl PreQuantizedProps {
    /// Number of candidate split thresholds for `prop_idx`.
    fn num_thresholds(&self, prop_idx: usize) -> usize {
        self.threshold_sets[prop_idx].len()
    }
}
/// Collapses identical samples into a single row with a multiplicity count.
///
/// Two samples are "identical" when they agree on every quantized property
/// bucket and on every predictor's (token, extra-bit count) pair — i.e. they
/// are indistinguishable to the split search and the entropy estimate.
/// Rewrites `samples` and `pq` in place and fills `samples.sample_counts`.
fn dedup_samples(
    samples: &mut TreeSamples,
    pq: &mut PreQuantizedProps,
    params: &TreeLearningParams,
) {
    let n = samples.num_samples;
    if n <= 1 {
        samples.sample_counts = vec![1; n];
        return;
    }
    let num_pred = samples.num_predictors();
    let properties = &params.properties;
    // Sort sample indices so that equal samples become adjacent; the key is
    // (property buckets, then per-predictor tokens and extra bits), matching
    // the equality test in `is_same_sample`.
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_unstable_by(|&a, &b| {
        for &prop_idx in properties {
            let bi = &pq.bucket_indices[prop_idx];
            if !bi.is_empty() {
                match bi[a].cmp(&bi[b]) {
                    Ordering::Equal => {}
                    ord => return ord,
                }
            }
        }
        for pred in 0..num_pred {
            match samples.residual_tokens[pred][a].cmp(&samples.residual_tokens[pred][b]) {
                Ordering::Equal => {}
                ord => return ord,
            }
            match samples.extra_bits[pred][a].cmp(&samples.extra_bits[pred][b]) {
                Ordering::Equal => {}
                ord => return ord,
            }
        }
        Ordering::Equal
    });
    // Run-length encode the sorted order into unique rows plus counts.
    let mut unique_indices: Vec<usize> = Vec::with_capacity(n / 2);
    let mut counts: Vec<u32> = Vec::with_capacity(n / 2);
    unique_indices.push(order[0]);
    counts.push(1);
    for &curr in &order[1..] {
        let prev = *unique_indices.last().unwrap();
        if is_same_sample(prev, curr, samples, pq, properties.as_slice(), num_pred) {
            *counts.last_mut().unwrap() += 1;
        } else {
            unique_indices.push(curr);
            counts.push(1);
        }
    }
    let num_unique = unique_indices.len();
    // Compact the per-predictor token/extra-bit columns to the unique rows.
    for pred in 0..num_pred {
        let old_tokens = &samples.residual_tokens[pred];
        let old_ebits = &samples.extra_bits[pred];
        let new_tokens: Vec<u8> = unique_indices.iter().map(|&i| old_tokens[i]).collect();
        let new_ebits: Vec<u8> = unique_indices.iter().map(|&i| old_ebits[i]).collect();
        samples.residual_tokens[pred] = new_tokens;
        samples.extra_bits[pred] = new_ebits;
    }
    // Compact the raw property columns (skipping never-filled ones).
    let total_props = samples.total_num_properties();
    for prop_idx in 0..total_props {
        let old_props = &samples.props[prop_idx];
        if old_props.is_empty() {
            continue;
        }
        let new_props: Vec<i32> = unique_indices.iter().map(|&i| old_props[i]).collect();
        samples.props[prop_idx] = new_props;
    }
    // Compact the pre-quantized bucket columns in step with the samples.
    for prop_idx in 0..total_props {
        if prop_idx >= pq.bucket_indices.len() {
            break;
        }
        let old_bi = &pq.bucket_indices[prop_idx];
        if old_bi.is_empty() {
            continue;
        }
        let new_bi: Vec<u8> = unique_indices.iter().map(|&i| old_bi[i]).collect();
        pq.bucket_indices[prop_idx] = new_bi;
    }
    samples.num_samples = num_unique;
    samples.sample_counts = counts;
}
/// True when samples `a` and `b` agree on every quantized property bucket
/// and on every predictor's (token, extra-bit count) pair.
#[inline]
fn is_same_sample(
    a: usize,
    b: usize,
    samples: &TreeSamples,
    pq: &PreQuantizedProps,
    properties: &[usize],
    num_pred: usize,
) -> bool {
    let buckets_match = properties.iter().all(|&prop_idx| {
        let bi = &pq.bucket_indices[prop_idx];
        bi.is_empty() || bi[a] == bi[b]
    });
    if !buckets_match {
        return false;
    }
    (0..num_pred).all(|pred| {
        samples.residual_tokens[pred][a] == samples.residual_tokens[pred][b]
            && samples.extra_bits[pred][a] == samples.extra_bits[pred][b]
    })
}
/// Work item for the greedy tree builder: a node plus the range of the
/// shared `indices` array holding the samples it covers.
struct SplitCandidate {
    // Index of the node in the tree under construction.
    node_idx: usize,
    // Half-open sample range [start, end) owned by this node.
    start: usize,
    end: usize,
    // Best predictor found for this range so far.
    best_predictor: usize,
    // Estimated cost (bits) of coding the range with `best_predictor`.
    base_bits: f64,
    // Quantization multiplier to record if this node becomes a leaf.
    multiplier: Option<u32>,
}
/// Greedily learns an MA decision tree from the gathered samples.
///
/// Pops work items off a stack; each is turned into a leaf (node budget
/// exhausted, fewer than two samples, or already cheap enough) or split on
/// the property/threshold pair found by `find_best_split` when the estimated
/// entropy saving exceeds `threshold` bits. The finished tree gets
/// sequential context ids and is repaired until `validate_tree_djxl`
/// accepts it.
pub fn compute_best_tree(samples: &mut TreeSamples, params: &TreeLearningParams) -> Tree {
    // Sub-sampled gathering lowers per-node costs, so scale the threshold
    // down with the sampled fraction (never below 10% of the base).
    let required_cost = params.pixel_fraction * 0.9 + 0.1;
    let threshold = params.split_threshold * required_cost;
    let n = samples.num_samples;
    if n == 0 {
        // No samples: a single Gradient leaf.
        return vec![PropertyDecisionNode {
            property: -1,
            predictor: Predictor::Gradient,
            context_id: 0,
            multiplier: 1,
            ..Default::default()
        }];
    }
    let mut pq = samples.pre_quantize(params);
    dedup_samples(samples, &mut pq, params);
    let n = samples.num_samples;
    let max_nodes = params.max_nodes;
    // Sample indices, permuted in place as nodes partition their ranges.
    let mut indices: Vec<usize> = (0..n).collect();
    // Histograms only need to span tokens that actually occur.
    let max_token = samples
        .residual_tokens
        .iter()
        .flat_map(|v| v.iter())
        .copied()
        .max()
        .unwrap_or(0) as usize;
    let histogram_size = max_token + 1;
    let mut tree: Tree = Vec::new();
    let mut entropy_counts = vec![0u32; histogram_size];
    let root_predictor =
        find_best_predictor(samples, &indices[..n], histogram_size, &mut entropy_counts);
    let root_bits = compute_predictor_entropy(
        samples,
        &indices[..n],
        root_predictor,
        histogram_size,
        &mut entropy_counts,
    );
    let mut stack: Vec<SplitCandidate> = Vec::new();
    tree.push(PropertyDecisionNode::default());
    stack.push(SplitCandidate {
        node_idx: 0,
        start: 0,
        end: n,
        best_predictor: root_predictor,
        base_bits: root_bits,
        multiplier: None,
    });
    let max_buckets = params.max_property_values + 1;
    let mut workspace = SplitWorkspace::new(n, histogram_size, max_buckets);
    while let Some(candidate) = stack.pop() {
        // Each accepted split appends two children; respect the budget.
        if tree.len() + 2 > max_nodes {
            finalize_leaf(&mut tree, &candidate, samples.candidate_predictors);
            continue;
        }
        let count = candidate.end - candidate.start;
        if count < 2 {
            finalize_leaf(&mut tree, &candidate, samples.candidate_predictors);
            continue;
        }
        // Already below the gain threshold: no split can pay off.
        if candidate.base_bits <= threshold {
            finalize_leaf(&mut tree, &candidate, samples.candidate_predictors);
            continue;
        }
        let best_split = find_best_split(
            samples,
            &indices[candidate.start..candidate.end],
            histogram_size,
            candidate.base_bits,
            params,
            candidate.best_predictor,
            threshold,
            &pq,
            &mut workspace,
        );
        match best_split {
            Some(split) if candidate.base_bits - split.total_bits > threshold => {
                // Partition this node's index range around the threshold;
                // `mid` is the boundary between the two children.
                let mid = partition_indices(
                    &mut indices[candidate.start..candidate.end],
                    samples,
                    split.property,
                    split.splitval,
                );
                let abs_mid = candidate.start + mid;
                let lchild_idx = tree.len();
                let rchild_idx = tree.len() + 1;
                tree.push(PropertyDecisionNode::default());
                tree.push(PropertyDecisionNode::default());
                tree[candidate.node_idx] = PropertyDecisionNode {
                    property: split.property as i32,
                    splitval: split.splitval,
                    lchild: lchild_idx,
                    rchild: rchild_idx,
                    ..Default::default()
                };
                let left_bits = compute_predictor_entropy(
                    samples,
                    &indices[candidate.start..abs_mid],
                    split.left_predictor,
                    histogram_size,
                    &mut entropy_counts,
                );
                let right_bits = compute_predictor_entropy(
                    samples,
                    &indices[abs_mid..candidate.end],
                    split.right_predictor,
                    histogram_size,
                    &mut entropy_counts,
                );
                stack.push(SplitCandidate {
                    node_idx: rchild_idx,
                    start: abs_mid,
                    end: candidate.end,
                    best_predictor: split.right_predictor,
                    base_bits: right_bits,
                    multiplier: None,
                });
                stack.push(SplitCandidate {
                    node_idx: lchild_idx,
                    start: candidate.start,
                    end: abs_mid,
                    best_predictor: split.left_predictor,
                    base_bits: left_bits,
                    multiplier: None,
                });
            }
            _ => {
                finalize_leaf(&mut tree, &candidate, samples.candidate_predictors);
            }
        }
    }
    assign_sequential_contexts(&mut tree);
    // Repair pass: the validator reports the first offending node as
    // "Node <idx> ..."; demote it to a Gradient leaf and re-check.
    loop {
        match validate_tree_djxl(&tree) {
            Ok(()) => break,
            Err(msg) => {
                #[cfg(feature = "debug-rect")]
                eprintln!("tree/validate: fixing invalid node: {}", msg);
                let node_idx = msg
                    .strip_prefix("Node ")
                    .and_then(|s| s.split_whitespace().next())
                    .and_then(|s| s.parse::<usize>().ok())
                    .expect("validate_tree_djxl error format changed");
                tree[node_idx] = PropertyDecisionNode {
                    property: -1,
                    splitval: 0,
                    predictor: super::predictor::Predictor::Gradient,
                    predictor_offset: 0,
                    multiplier: 1,
                    lchild: 0,
                    rchild: 0,
                    context_id: 0,
                };
                assign_sequential_contexts(&mut tree);
            }
        }
    }
    let _num_leaves = tree.iter().filter(|n| n.property == -1).count();
    crate::trace::debug_eprintln!(
        "compute_best_tree: {} samples, pf={:.3}, threshold={:.1} (base={:.0}*rc={:.3}), \
         {} nodes, {} leaves, max_nodes={}",
        n,
        params.pixel_fraction,
        threshold,
        params.split_threshold,
        required_cost,
        tree.len(),
        _num_leaves,
        max_nodes,
    );
    tree
}
/// Turns `candidate`'s node into a leaf using its best predictor and its
/// recorded multiplier (default 1). Context ids are assigned later.
fn finalize_leaf(tree: &mut Tree, candidate: &SplitCandidate, predictors: &[Predictor]) {
    let leaf = PropertyDecisionNode {
        property: -1,
        predictor: predictors[candidate.best_predictor],
        predictor_offset: 0,
        multiplier: candidate.multiplier.unwrap_or(1) as i32,
        context_id: 0,
        ..Default::default()
    };
    tree[candidate.node_idx] = leaf;
}
/// Variant of `compute_best_tree` that honours quantization-multiplier
/// regions.
///
/// `multiplier_info` describes boxes in the space of the two static
/// properties, each carrying a multiplier. A node whose static-property
/// range lies entirely inside a box becomes a leaf with that multiplier
/// (and the Zero predictor when the multiplier exceeds 1); a node that
/// partially overlaps a box is force-split on the box boundary before the
/// normal entropy-driven splitting takes over.
pub fn compute_best_tree_with_multipliers(
    samples: &mut TreeSamples,
    params: &TreeLearningParams,
    multiplier_info: &[super::quantize::ModularMultiplierInfo],
    initial_range: [[u32; 2]; 2],
) -> Tree {
    use super::quantize::{IntersectionType, box_intersects};
    let required_cost = params.pixel_fraction * 0.9 + 0.1;
    let threshold = params.split_threshold * required_cost;
    let n = samples.num_samples;
    if n == 0 {
        // No samples: a single Zero leaf.
        return vec![PropertyDecisionNode {
            property: -1,
            predictor: Predictor::Zero,
            context_id: 0,
            multiplier: 1,
            ..Default::default()
        }];
    }
    let mut pq = samples.pre_quantize(params);
    dedup_samples(samples, &mut pq, params);
    let n = samples.num_samples;
    let max_nodes = params.max_nodes;
    let mut indices: Vec<usize> = (0..n).collect();
    let max_token = samples
        .residual_tokens
        .iter()
        .flat_map(|v| v.iter())
        .copied()
        .max()
        .unwrap_or(0) as usize;
    let histogram_size = max_token + 1;
    let mut tree: Tree = Vec::new();
    let mut entropy_counts = vec![0u32; histogram_size];
    let root_predictor =
        find_best_predictor(samples, &indices[..n], histogram_size, &mut entropy_counts);
    let root_bits = compute_predictor_entropy(
        samples,
        &indices[..n],
        root_predictor,
        histogram_size,
        &mut entropy_counts,
    );
    // Like `SplitCandidate`, but also tracks the node's remaining
    // static-property bounding box.
    struct SplitCandidateWithRange {
        node_idx: usize,
        start: usize,
        end: usize,
        best_predictor: usize,
        base_bits: f64,
        static_prop_range: [[u32; 2]; 2],
    }
    let mut stack: Vec<SplitCandidateWithRange> = Vec::new();
    tree.push(PropertyDecisionNode::default());
    stack.push(SplitCandidateWithRange {
        node_idx: 0,
        start: 0,
        end: n,
        best_predictor: root_predictor,
        base_bits: root_bits,
        static_prop_range: initial_range,
    });
    let max_buckets = params.max_property_values + 1;
    let mut workspace = SplitWorkspace::new(n, histogram_size, max_buckets);
    while let Some(candidate) = stack.pop() {
        if candidate.end <= candidate.start {
            continue;
        }
        // First resolve any multiplier-region constraint on this node.
        let mut forced_split: Option<(usize, u32)> = None;
        let mut assigned_multiplier: Option<u32> = None;
        for mmi in multiplier_info {
            let (t, axis, val) = box_intersects(&candidate.static_prop_range, &mmi.range);
            match t {
                IntersectionType::None => continue,
                IntersectionType::Inside => {
                    // Node fully inside the region: adopt its multiplier.
                    assigned_multiplier = Some(mmi.multiplier);
                    break;
                }
                IntersectionType::Partial => {
                    // Node straddles the region: split at its boundary.
                    forced_split = Some((axis, val));
                    break;
                }
            }
        }
        if let Some(mult) = assigned_multiplier {
            // Multipliers > 1 force the Zero predictor.
            let predictor = if mult > 1 {
                Predictor::Zero
            } else {
                CANDIDATE_PREDICTORS[candidate.best_predictor]
            };
            tree[candidate.node_idx] = PropertyDecisionNode {
                property: -1,
                predictor,
                predictor_offset: 0,
                multiplier: mult as i32,
                context_id: 0,
                ..Default::default()
            };
            continue;
        }
        if let Some((axis, splitval)) = forced_split {
            if tree.len() + 2 > max_nodes {
                // Node budget exhausted: fall back to a plain leaf.
                tree[candidate.node_idx] = PropertyDecisionNode {
                    property: -1,
                    predictor: CANDIDATE_PREDICTORS[candidate.best_predictor],
                    predictor_offset: 0,
                    multiplier: 1,
                    context_id: 0,
                    ..Default::default()
                };
                continue;
            }
            let mid = partition_indices(
                &mut indices[candidate.start..candidate.end],
                samples,
                axis,
                splitval as i32,
            );
            let abs_mid = candidate.start + mid;
            let lchild_idx = tree.len();
            let rchild_idx = tree.len() + 1;
            tree.push(PropertyDecisionNode::default());
            tree.push(PropertyDecisionNode::default());
            tree[candidate.node_idx] = PropertyDecisionNode {
                property: axis as i32,
                splitval: splitval as i32,
                lchild: lchild_idx,
                rchild: rchild_idx,
                ..Default::default()
            };
            // Narrow each child's static-property box at the split plane.
            let mut lchild_range = candidate.static_prop_range;
            lchild_range[axis][1] = splitval + 1;
            let mut rchild_range = candidate.static_prop_range;
            rchild_range[axis][0] = splitval + 1;
            // Empty sides inherit the parent's predictor and cost 0.
            let left_predictor = if abs_mid > candidate.start {
                find_best_predictor(
                    samples,
                    &indices[candidate.start..abs_mid],
                    histogram_size,
                    &mut entropy_counts,
                )
            } else {
                candidate.best_predictor
            };
            let right_predictor = if abs_mid < candidate.end {
                find_best_predictor(
                    samples,
                    &indices[abs_mid..candidate.end],
                    histogram_size,
                    &mut entropy_counts,
                )
            } else {
                candidate.best_predictor
            };
            let left_bits = if abs_mid > candidate.start {
                compute_predictor_entropy(
                    samples,
                    &indices[candidate.start..abs_mid],
                    left_predictor,
                    histogram_size,
                    &mut entropy_counts,
                )
            } else {
                0.0
            };
            let right_bits = if abs_mid < candidate.end {
                compute_predictor_entropy(
                    samples,
                    &indices[abs_mid..candidate.end],
                    right_predictor,
                    histogram_size,
                    &mut entropy_counts,
                )
            } else {
                0.0
            };
            stack.push(SplitCandidateWithRange {
                node_idx: rchild_idx,
                start: abs_mid,
                end: candidate.end,
                best_predictor: right_predictor,
                base_bits: right_bits,
                static_prop_range: rchild_range,
            });
            stack.push(SplitCandidateWithRange {
                node_idx: lchild_idx,
                start: candidate.start,
                end: abs_mid,
                best_predictor: left_predictor,
                base_bits: left_bits,
                static_prop_range: lchild_range,
            });
            continue;
        }
        // No multiplier constraint: normal entropy-driven splitting.
        if tree.len() + 2 > max_nodes {
            tree[candidate.node_idx] = PropertyDecisionNode {
                property: -1,
                predictor: CANDIDATE_PREDICTORS[candidate.best_predictor],
                predictor_offset: 0,
                multiplier: 1,
                context_id: 0,
                ..Default::default()
            };
            continue;
        }
        let count = candidate.end - candidate.start;
        if count < 2 || candidate.base_bits <= threshold {
            tree[candidate.node_idx] = PropertyDecisionNode {
                property: -1,
                predictor: CANDIDATE_PREDICTORS[candidate.best_predictor],
                predictor_offset: 0,
                multiplier: 1,
                context_id: 0,
                ..Default::default()
            };
            continue;
        }
        let best_split = find_best_split(
            samples,
            &indices[candidate.start..candidate.end],
            histogram_size,
            candidate.base_bits,
            params,
            candidate.best_predictor,
            threshold,
            &pq,
            &mut workspace,
        );
        match best_split {
            Some(split) if candidate.base_bits - split.total_bits > threshold => {
                let mid = partition_indices(
                    &mut indices[candidate.start..candidate.end],
                    samples,
                    split.property,
                    split.splitval,
                );
                let abs_mid = candidate.start + mid;
                let lchild_idx = tree.len();
                let rchild_idx = tree.len() + 1;
                tree.push(PropertyDecisionNode::default());
                tree.push(PropertyDecisionNode::default());
                tree[candidate.node_idx] = PropertyDecisionNode {
                    property: split.property as i32,
                    splitval: split.splitval,
                    lchild: lchild_idx,
                    rchild: rchild_idx,
                    ..Default::default()
                };
                // Only splits on the two static properties (0 and 1) can
                // refine the static-property box.
                let mut lchild_range = candidate.static_prop_range;
                let mut rchild_range = candidate.static_prop_range;
                if split.property < 2 {
                    lchild_range[split.property][1] =
                        (split.splitval + 1).min(lchild_range[split.property][1] as i32) as u32;
                    rchild_range[split.property][0] =
                        (split.splitval + 1).max(rchild_range[split.property][0] as i32) as u32;
                }
                let left_bits = compute_predictor_entropy(
                    samples,
                    &indices[candidate.start..abs_mid],
                    split.left_predictor,
                    histogram_size,
                    &mut entropy_counts,
                );
                let right_bits = compute_predictor_entropy(
                    samples,
                    &indices[abs_mid..candidate.end],
                    split.right_predictor,
                    histogram_size,
                    &mut entropy_counts,
                );
                stack.push(SplitCandidateWithRange {
                    node_idx: rchild_idx,
                    start: abs_mid,
                    end: candidate.end,
                    best_predictor: split.right_predictor,
                    base_bits: right_bits,
                    static_prop_range: rchild_range,
                });
                stack.push(SplitCandidateWithRange {
                    node_idx: lchild_idx,
                    start: candidate.start,
                    end: abs_mid,
                    best_predictor: split.left_predictor,
                    base_bits: left_bits,
                    static_prop_range: lchild_range,
                });
            }
            _ => {
                tree[candidate.node_idx] = PropertyDecisionNode {
                    property: -1,
                    predictor: CANDIDATE_PREDICTORS[candidate.best_predictor],
                    predictor_offset: 0,
                    multiplier: 1,
                    context_id: 0,
                    ..Default::default()
                };
            }
        }
    }
    assign_sequential_contexts(&mut tree);
    // Repair pass: the validator reports the first offending node as
    // "Node <idx> ..."; demote it to a Gradient leaf and re-check.
    loop {
        match validate_tree_djxl(&tree) {
            Ok(()) => break,
            Err(msg) => {
                #[cfg(feature = "debug-rect")]
                eprintln!("tree/validate: fixing invalid node: {}", msg);
                let node_idx = msg
                    .strip_prefix("Node ")
                    .and_then(|s| s.split_whitespace().next())
                    .and_then(|s| s.parse::<usize>().ok())
                    .expect("validate_tree_djxl error format changed");
                tree[node_idx] = PropertyDecisionNode {
                    property: -1,
                    splitval: 0,
                    predictor: Predictor::Gradient,
                    predictor_offset: 0,
                    multiplier: 1,
                    lchild: 0,
                    rchild: 0,
                    context_id: 0,
                };
                assign_sequential_contexts(&mut tree);
            }
        }
    }
    let _num_leaves = tree.iter().filter(|n| n.property == -1).count();
    crate::trace::debug_eprintln!(
        "compute_best_tree_with_multipliers: {} samples, {} nodes, {} leaves, {} mul_info entries",
        n,
        tree.len(),
        _num_leaves,
        multiplier_info.len(),
    );
    tree
}
/// Stride of each per-bucket token histogram in `SplitWorkspace`; a power of
/// two so token indexing can use `& HISTO_MASK` instead of a bounds check.
const HISTO_PADDED: usize = 128;
const HISTO_MASK: usize = HISTO_PADDED - 1;
/// Scratch buffers reused across `find_best_split` calls so the hot split
/// search does not allocate per node.
struct SplitWorkspace {
    // Per-bucket token histograms, laid out at a HISTO_PADDED stride.
    count_increase: Vec<u32>,
    // Per-bucket weighted sum of extra bits.
    extra_bits_increase: Vec<u64>,
    // Weighted sample count per bucket.
    bucket_counts: Vec<u32>,
    // Running token histograms for the right / left side of a split.
    right_counts: Vec<u32>,
    left_counts: Vec<u32>,
    // Per-threshold best costs (plain and predictor-change-penalized) and
    // the predictor achieving them, for each side of the split.
    best_l_cost: Vec<f64>,
    best_r_cost: Vec<f64>,
    best_l_penalized: Vec<f64>,
    best_r_penalized: Vec<f64>,
    best_l_pred: Vec<usize>,
    best_r_pred: Vec<usize>,
    // Sample indices grouped by bucket (counting-sort layout), plus the
    // bucket start offsets and write cursors used to build it.
    sorted_by_bucket: Vec<usize>,
    bucket_starts: Vec<usize>,
    bucket_write_pos: Vec<usize>,
}
impl SplitWorkspace {
    /// Allocates all scratch buffers up front for at most `max_count`
    /// samples and `max_buckets` property buckets.
    ///
    /// Panics if `histogram_size` exceeds the fixed `HISTO_PADDED` stride,
    /// since per-bucket histograms are laid out at that stride.
    fn new(max_count: usize, histogram_size: usize, max_buckets: usize) -> Self {
        assert!(
            histogram_size <= HISTO_PADDED,
            "histogram_size {} exceeds HISTO_PADDED {}",
            histogram_size,
            HISTO_PADDED
        );
        Self {
            count_increase: vec![0u32; max_buckets * HISTO_PADDED],
            extra_bits_increase: vec![0u64; max_buckets],
            bucket_counts: vec![0u32; max_buckets],
            right_counts: vec![0u32; histogram_size],
            left_counts: vec![0u32; histogram_size],
            best_l_cost: vec![f64::MAX; max_buckets],
            best_r_cost: vec![f64::MAX; max_buckets],
            best_l_penalized: vec![f64::MAX; max_buckets],
            best_r_penalized: vec![f64::MAX; max_buckets],
            best_l_pred: vec![0usize; max_buckets],
            best_r_pred: vec![0usize; max_buckets],
            sorted_by_bucket: vec![0usize; max_count],
            bucket_starts: vec![0usize; max_buckets + 2],
            bucket_write_pos: vec![0usize; max_buckets],
        }
    }
}
/// Outcome of a split search: the chosen property/threshold plus the best
/// predictor for each side and the combined estimated cost.
struct BestSplit {
    property: usize,
    splitval: i32,
    left_predictor: usize,
    right_predictor: usize,
    // Estimated total bits (left + right) after the split.
    total_bits: f64,
}
#[allow(clippy::too_many_arguments)]
fn find_best_split(
samples: &TreeSamples,
indices: &[usize],
histogram_size: usize,
base_bits: f64,
params: &TreeLearningParams,
parent_predictor: usize,
threshold: f64,
pq: &PreQuantizedProps,
ws: &mut SplitWorkspace,
) -> Option<BestSplit> {
let count = indices.len();
if count < 2 {
return None;
}
let total_num_pred = samples.num_predictors();
let mut best: Option<BestSplit> = None;
let mut best_bits = base_bits;
let sample_counts = &samples.sample_counts;
let weighted_total: u32 = indices.iter().map(|&i| sample_counts[i]).sum();
let change_pred_penalty = 800.0 / (100.0 + threshold);
let weighted_idx = samples
.candidate_predictors
.iter()
.position(|&p| p == Predictor::Weighted)
.unwrap_or(usize::MAX);
let zero_idx = CANDIDATE_PREDICTORS
.iter()
.position(|&p| p == Predictor::Zero)
.unwrap_or(usize::MAX);
let num_pred = (if weighted_total >= 2048 {
total_num_pred } else if weighted_total >= 512 {
10
} else if weighted_total >= 64 {
7
} else {
4
})
.min(total_num_pred);
let effective_histo = histogram_size;
if effective_histo == 0 {
return None;
}
let count_increase = ws.count_increase.as_mut_slice();
let extra_bits_increase = ws.extra_bits_increase.as_mut_slice();
let bucket_counts = ws.bucket_counts.as_mut_slice();
let right_counts = ws.right_counts.as_mut_slice();
let left_counts = ws.left_counts.as_mut_slice();
let best_l_cost = ws.best_l_cost.as_mut_slice();
let best_r_cost = ws.best_r_cost.as_mut_slice();
let best_l_penalized = ws.best_l_penalized.as_mut_slice();
let best_r_penalized = ws.best_r_penalized.as_mut_slice();
let best_l_pred = ws.best_l_pred.as_mut_slice();
let best_r_pred = ws.best_r_pred.as_mut_slice();
let sorted_by_bucket = ws.sorted_by_bucket.as_mut_slice();
let bucket_starts = ws.bucket_starts.as_mut_slice();
let bucket_write_pos = ws.bucket_write_pos.as_mut_slice();
let num_props = if weighted_total >= 256 {
params.properties.len()
} else if weighted_total >= 32 {
params.properties.len().min(4)
} else {
params.properties.len().min(2)
};
for &prop_idx in ¶ms.properties[..num_props] {
let num_thresholds = pq.num_thresholds(prop_idx);
if num_thresholds == 0 {
continue;
}
let pq_buckets = &pq.bucket_indices[prop_idx];
let threshold_set = &pq.threshold_sets[prop_idx];
let mut bmin: u8 = u8::MAX;
let mut bmax: u8 = 0;
for &idx in indices {
let b = pq_buckets[idx];
if b < bmin {
bmin = b;
}
if b > bmax {
bmax = b;
}
}
if bmin == bmax {
continue; }
let bmin = bmin as usize;
let bmax = bmax as usize;
let local_num_buckets = bmax - bmin + 1;
let local_num_thresholds = bmax - bmin;
let mut unique_per_bucket = [0u32; 256];
bucket_counts[..local_num_buckets].fill(0); for &idx in indices {
let b = (pq_buckets[idx] as usize) - bmin;
unique_per_bucket[b] += 1;
bucket_counts[b] += sample_counts[idx];
}
bucket_starts[0] = 0;
for b in 0..local_num_buckets {
bucket_starts[b + 1] = bucket_starts[b] + unique_per_bucket[b] as usize;
}
bucket_write_pos[..local_num_buckets].copy_from_slice(&bucket_starts[..local_num_buckets]);
for &idx in indices {
let b = (pq_buckets[idx] as usize) - bmin;
sorted_by_bucket[bucket_write_pos[b]] = idx;
bucket_write_pos[b] += 1;
}
best_l_cost[..local_num_thresholds].fill(f64::MAX);
best_r_cost[..local_num_thresholds].fill(f64::MAX);
best_l_penalized[..local_num_thresholds].fill(f64::MAX);
best_r_penalized[..local_num_thresholds].fill(f64::MAX);
best_l_pred[..local_num_thresholds].fill(0);
best_r_pred[..local_num_thresholds].fill(0);
for pred in 0..num_pred {
let tokens = &samples.residual_tokens[pred];
let ebits = &samples.extra_bits[pred];
let mut penalty: f64 = 0.0;
if pred != parent_predictor && parent_predictor != weighted_idx {
penalty = change_pred_penalty;
}
if pred == weighted_idx {
penalty += 1e-8;
} else if pred == zero_idx {
penalty -= 1e-8;
}
for b in 0..local_num_buckets {
count_increase[b * HISTO_PADDED..b * HISTO_PADDED + effective_histo].fill(0);
}
extra_bits_increase[..local_num_buckets].fill(0);
for local_bucket in 0..local_num_buckets {
let start = bucket_starts[local_bucket];
let end = bucket_starts[local_bucket + 1];
let ci_base = local_bucket * HISTO_PADDED;
let ci_slice = &mut count_increase[ci_base..ci_base + HISTO_PADDED];
let mut eb_sum: u64 = 0;
for &idx in &sorted_by_bucket[start..end] {
let tok = tokens[idx];
let sc = sample_counts[idx];
ci_slice[tok as usize & HISTO_MASK] += sc;
eb_sum += ebits[idx] as u64 * sc as u64;
}
extra_bits_increase[local_bucket] = eb_sum;
}
right_counts[..effective_histo].fill(0);
let mut right_extra: u64 = 0;
let mut right_total: u32 = weighted_total;
for (local_bucket, &eb) in extra_bits_increase[..local_num_buckets].iter().enumerate() {
let ci_base = local_bucket * HISTO_PADDED;
let ci_row = &count_increase[ci_base..ci_base + effective_histo];
for (rc, &ci) in right_counts[..effective_histo]
.iter_mut()
.zip(ci_row.iter())
{
*rc += ci;
}
right_extra += eb;
}
left_counts[..effective_histo].fill(0);
let mut left_extra: u64 = 0;
let mut left_total: u32 = 0;
for local_k in 0..local_num_thresholds {
let bc = bucket_counts[local_k];
if bc == 0 {
continue;
}
let ci_base = local_k * HISTO_PADDED;
let ci_row = &count_increase[ci_base..ci_base + effective_histo];
for (i, &ci) in ci_row.iter().enumerate() {
if ci > 0 {
left_counts[i] += ci;
right_counts[i] -= ci;
}
}
left_extra += extra_bits_increase[local_k];
right_extra -= extra_bits_increase[local_k];
left_total += bc;
right_total -= bc;
if left_total == 0 || right_total == 0 {
continue;
}
let l_bits =
estimate_bits(&left_counts[..effective_histo], left_total) + left_extra as f64;
let r_bits = estimate_bits(&right_counts[..effective_histo], right_total)
+ right_extra as f64;
if l_bits + penalty < best_l_penalized[local_k] {
best_l_penalized[local_k] = l_bits + penalty;
best_l_cost[local_k] = l_bits;
best_l_pred[local_k] = pred;
}
if r_bits + penalty < best_r_penalized[local_k] {
best_r_penalized[local_k] = r_bits + penalty;
best_r_cost[local_k] = r_bits;
best_r_pred[local_k] = pred;
}
}
}
for local_k in 0..local_num_thresholds {
if best_l_cost[local_k] == f64::MAX || best_r_cost[local_k] == f64::MAX {
continue;
}
let total = best_l_cost[local_k] + best_r_cost[local_k];
if total < best_bits {
best_bits = total;
let global_k = bmin + local_k;
best = Some(BestSplit {
property: prop_idx,
splitval: threshold_set[global_k],
left_predictor: best_l_pred[local_k],
right_predictor: best_r_pred[local_k],
total_bits: total,
});
}
}
}
best
}
/// Exhaustively scores every candidate predictor over the given sample
/// `indices` and returns the index of the one with the lowest estimated
/// coded size. Ties keep the earliest predictor. `counts_buf` is scratch
/// space for the per-predictor histogram.
fn find_best_predictor(
    samples: &TreeSamples,
    indices: &[usize],
    histogram_size: usize,
    counts_buf: &mut [u32],
) -> usize {
    // (predictor index, estimated bits) of the current winner.
    let mut winner = (0usize, f64::MAX);
    for pred in 0..samples.num_predictors() {
        let cost = compute_predictor_entropy(samples, indices, pred, histogram_size, counts_buf);
        if cost < winner.1 {
            winner = (pred, cost);
        }
    }
    winner.0
}
/// Estimates the coded size (in bits) of the residuals produced by
/// `predictor_idx` over the selected sample `indices`: Shannon-entropy bits
/// of the weighted token histogram plus the raw extra bits.
///
/// Tokens outside `histogram_size` are excluded from the histogram and the
/// weight total, but their extra bits are still charged.
fn compute_predictor_entropy(
    samples: &TreeSamples,
    indices: &[usize],
    predictor_idx: usize,
    histogram_size: usize,
    counts_buf: &mut [u32],
) -> f64 {
    let tokens = &samples.residual_tokens[predictor_idx];
    let ebits = &samples.extra_bits[predictor_idx];
    let weights = &samples.sample_counts;
    let histo = &mut counts_buf[..histogram_size];
    histo.fill(0);
    let mut weight_sum = 0u32;
    let mut extra_total = 0u64;
    for &idx in indices {
        let w = weights[idx];
        // Extra bits are stored verbatim, so they cost exactly their length.
        extra_total += ebits[idx] as u64 * w as u64;
        if let Some(slot) = histo.get_mut(tokens[idx] as usize) {
            *slot += w;
            weight_sum += w;
        }
    }
    estimate_bits(histo, weight_sum) + extra_total as f64
}
/// In-place, two-pointer partition of `indices`: entries whose property
/// `prop_idx` value is <= `splitval` end up in the prefix, the rest in the
/// suffix. Returns the boundary position (length of the left partition).
///
/// The sweep is unstable; the exact element movement matches the classic
/// Hoare-style scheme (swap-from-the-back, re-test the swapped-in element).
fn partition_indices(
    indices: &mut [usize],
    samples: &TreeSamples,
    prop_idx: usize,
    splitval: i32,
) -> usize {
    let prop_values = &samples.props[prop_idx];
    let mut lo = 0usize;
    let mut hi = indices.len();
    while lo < hi {
        if prop_values[indices[lo]] <= splitval {
            lo += 1;
        } else {
            // Move the offending element to the tail; the element swapped
            // into `lo` is examined on the next iteration.
            hi -= 1;
            indices.swap(lo, hi);
        }
    }
    lo
}
/// Tokenizes the residuals of every channel in `image` using the MA decision
/// `tree`. Convenience wrapper around [`collect_residuals_with_tree_offset`]
/// with a channel offset of zero (channel indices are used as-is).
pub fn collect_residuals_with_tree(
    image: &ModularImage,
    tree: &Tree,
    group_id: u32,
    wp_params: &WeightedPredictorParams,
) -> Vec<crate::entropy_coding::token::Token> {
    collect_residuals_with_tree_offset(image, tree, group_id, 0, wp_params)
}
/// Tokenizes the residuals of every channel in `image` by walking the MA
/// decision `tree` at each pixel.
///
/// `channel_offset` is added to the channel index fed into the property
/// computation (used when these channels are a slice of a larger channel
/// list). Each tree leaf supplies the predictor, context id and multiplier
/// for its pixel; the (residual / multiplier) quotient is packed to an
/// unsigned value and emitted as one token per pixel, in channel-major,
/// row-major scan order.
pub fn collect_residuals_with_tree_offset(
    image: &ModularImage,
    tree: &Tree,
    group_id: u32,
    channel_offset: u32,
    wp_params: &WeightedPredictorParams,
) -> Vec<crate::entropy_coding::token::Token> {
    use crate::entropy_coding::token::Token as AnsToken;
    // Highest property index referenced by any inner node; decides whether
    // the extended previous-channel ("reference") properties are needed.
    let max_tree_prop = tree
        .iter()
        .filter(|n| n.property >= 0)
        .map(|n| n.property as usize)
        .max()
        .unwrap_or(0);
    let needs_ref_props = max_tree_prop >= NUM_PROPERTIES;
    let mut tokens = Vec::new();
    let num_extended_props = if needs_ref_props {
        max_tree_prop + 1
    } else {
        NUM_PROPERTIES
    };
    // Scratch property vector, reused for every pixel.
    let mut extended_props = vec![0i32; num_extended_props];
    for (ch_idx, channel) in image.channels.iter().enumerate() {
        let width = channel.width();
        let height = channel.height();
        if width == 0 || height == 0 {
            continue;
        }
        let ref_channel_indices = if needs_ref_props {
            find_ref_channels(image, ch_idx)
        } else {
            Vec::new()
        };
        // Weighted-predictor error state is per channel.
        let mut wp_state = WeightedPredictorState::new(wp_params, width);
        let mut prev_gradient: i32;
        for y in 0..height {
            // prev_gradient chains along a scanline only; reset per row.
            prev_gradient = 0;
            for x in 0..width {
                let pixel = channel.get(x, y);
                let n = Neighbors::gather(channel, x, y);
                let (wp_pred, wp_max_error) = wp_state.predict_and_property(x, y, width, &n);
                let base_props = compute_spec_properties(
                    ch_idx as u32 + channel_offset,
                    group_id,
                    x,
                    y,
                    &n,
                    prev_gradient,
                    wp_max_error,
                );
                // Property 9 of this pixel feeds the next pixel's property
                // computation (presumably the gradient error term — confirm
                // against compute_spec_properties).
                prev_gradient = base_props[9];
                let leaf = if needs_ref_props {
                    extended_props[..NUM_PROPERTIES].copy_from_slice(&base_props);
                    for (r, &ref_ch_idx) in ref_channel_indices.iter().enumerate() {
                        let ref_ch = &image.channels[ref_ch_idx];
                        let v = ref_ch.get(x, y);
                        // Out-of-bounds neighbors fall back to the left
                        // sample (or 0 at the row start), mirroring the
                        // base-property border handling.
                        let ref_left = if x > 0 { ref_ch.get(x - 1, y) } else { 0 };
                        let ref_top = if y > 0 {
                            ref_ch.get(x, y - 1)
                        } else {
                            ref_left
                        };
                        let ref_topleft = if x > 0 && y > 0 {
                            ref_ch.get(x - 1, y - 1)
                        } else {
                            ref_left
                        };
                        let ref_predicted = crate::vardct::dc_coding::clamped_gradient(
                            ref_top,
                            ref_left,
                            ref_topleft,
                        );
                        // Four properties per reference channel:
                        // |v|, v, |v - pred|, v - pred.
                        let base = NUM_PROPERTIES + r * 4;
                        if base + 3 < num_extended_props {
                            extended_props[base] = v.wrapping_abs();
                            extended_props[base + 1] = v;
                            extended_props[base + 2] = v.wrapping_sub(ref_predicted).wrapping_abs();
                            extended_props[base + 3] = v.wrapping_sub(ref_predicted);
                        }
                    }
                    // Zero the slots of reference channels this channel does
                    // not have, so values from the previous channel's pass
                    // never leak into this one (extended_props is reused).
                    let num_ref_slots = (num_extended_props - NUM_PROPERTIES) / 4;
                    for r in ref_channel_indices.len()..num_ref_slots {
                        let base = NUM_PROPERTIES + r * 4;
                        if base + 3 < num_extended_props {
                            extended_props[base] = 0;
                            extended_props[base + 1] = 0;
                            extended_props[base + 2] = 0;
                            extended_props[base + 3] = 0;
                        }
                    }
                    traverse_with_props(tree, &extended_props)
                } else {
                    traverse_with_spec_props(tree, &base_props)
                };
                // The Weighted predictor's value comes from the running WP
                // state, not from the static neighbor formulas.
                let prediction = if leaf.predictor == Predictor::Weighted {
                    wp_pred as i32
                } else {
                    leaf.predictor.predict_from_neighbors(&n)
                };
                let residual = pixel - prediction;
                let multiplier = leaf.multiplier;
                let divided = if multiplier == 1 {
                    residual
                } else {
                    // A leaf multiplier > 1 promises every residual in this
                    // context is a multiple of it; only the quotient is coded.
                    debug_assert!(
                        residual % multiplier == 0,
                        "residual {} not divisible by multiplier {} at ({},{}) ch={}",
                        residual,
                        multiplier,
                        x,
                        y,
                        ch_idx,
                    );
                    residual / multiplier
                };
                let packed = pack_signed(divided);
                // Feed the true pixel value back into the WP state before
                // advancing, so later pixels see the updated errors.
                wp_state.update_errors(pixel, x, y, width);
                tokens.push(AnsToken::new(leaf.context_id, packed));
            }
        }
    }
    tokens
}
/// Walks `tree` from the root using the fixed-size spec property array and
/// returns the reached leaf. Inner nodes (`property >= 0`) branch left when
/// the property value is <= the node's split threshold; a negative property
/// marks a leaf.
fn traverse_with_spec_props<'a>(
    tree: &'a Tree,
    props: &[i32; NUM_PROPERTIES],
) -> &'a PropertyDecisionNode {
    let mut node = &tree[0];
    while node.property >= 0 {
        let next = if props[node.property as usize] <= node.splitval {
            node.lchild
        } else {
            node.rchild
        };
        node = &tree[next];
    }
    node
}
/// Walks `tree` from the root using an extended (slice-sized) property
/// vector and returns the reached leaf. Same branching rule as
/// [`traverse_with_spec_props`]: left child when the property value is
/// <= the node's split threshold, leaf when `property < 0`.
fn traverse_with_props<'a>(tree: &'a Tree, props: &[i32]) -> &'a PropertyDecisionNode {
    let mut cursor = 0;
    loop {
        let node = &tree[cursor];
        if node.property < 0 {
            break node;
        }
        cursor = if props[node.property as usize] <= node.splitval {
            node.lchild
        } else {
            node.rchild
        };
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::modular::channel::ModularImage;

    // A flat 4-symbol histogram over 400 samples costs exactly 2 bits per
    // sample: 400 * 2 = 800 bits.
    #[test]
    fn test_estimate_bits_uniform() {
        let counts = [100u32, 100, 100, 100];
        let total = 400;
        let bits = estimate_bits(&counts, total);
        assert!(
            (bits - 800.0).abs() < 0.01,
            "expected 800 bits, got {}",
            bits
        );
    }

    // A single-symbol alphabet carries no information.
    #[test]
    fn test_estimate_bits_single_symbol() {
        let counts = [100u32];
        let total = 100;
        let bits = estimate_bits(&counts, total);
        assert!(
            bits < 1.0,
            "single symbol should have near-zero entropy, got {}",
            bits
        );
    }

    // Every pixel of a 4x4 image should yield one sample.
    #[test]
    fn test_gather_samples_simple() {
        let image = ModularImage::from_gray8(&[128u8; 16], 4, 4).unwrap();
        let mut samples = TreeSamples::new();
        gather_samples(&mut samples, &image, 0);
        assert_eq!(samples.num_samples, 16);
    }

    // A constant image has nothing to split on: the learned tree must be a
    // single leaf (root property == -1).
    #[test]
    fn test_compute_best_tree_constant() {
        let image = ModularImage::from_gray8(&[100u8; 64], 8, 8).unwrap();
        let mut samples = TreeSamples::new();
        gather_samples(&mut samples, &image, 0);
        let params = TreeLearningParams::for_effort(9);
        let tree = compute_best_tree(&mut samples, &params);
        assert!(!tree.is_empty());
        assert_eq!(tree[0].property, -1);
    }

    // Two channels with very different statistics (flat vs. gradient) should
    // force at least one split, i.e. >= 2 leaves.
    #[test]
    fn test_compute_best_tree_two_channels() {
        let mut image = ModularImage {
            channels: Vec::new(),
            bit_depth: 8,
            is_grayscale: false,
            has_alpha: false,
        };
        let mut ch0 = Channel::new(32, 32).unwrap();
        for y in 0..32 {
            for x in 0..32 {
                ch0.set(x, y, 100);
            }
        }
        image.channels.push(ch0);
        let mut ch1 = Channel::new(32, 32).unwrap();
        for y in 0..32 {
            for x in 0..32 {
                ch1.set(x, y, (x * 7 + y * 5) as i32);
            }
        }
        image.channels.push(ch1);
        let mut samples = TreeSamples::new();
        gather_samples(&mut samples, &image, 0);
        let params = TreeLearningParams::for_effort(9);
        let tree = compute_best_tree(&mut samples, &params);
        let num_leaves = tree.iter().filter(|n| n.property < 0).count();
        assert!(num_leaves >= 2, "expected >= 2 leaves, got {}", num_leaves);
    }

    // With a single-leaf tree, tokenization emits one token per pixel and
    // all tokens use the leaf's context id.
    #[test]
    fn test_collect_residuals_with_tree() {
        let tree = vec![PropertyDecisionNode {
            property: -1,
            predictor: Predictor::Gradient,
            context_id: 0,
            multiplier: 1,
            ..Default::default()
        }];
        let image = ModularImage::from_gray8(&[100u8; 16], 4, 4).unwrap();
        let tokens =
            collect_residuals_with_tree(&image, &tree, 0, &WeightedPredictorParams::default());
        assert_eq!(tokens.len(), 16);
        for t in &tokens {
            assert_eq!(t.context(), 0);
        }
    }

    // One inner node splitting on property 0 at 0: <= goes to the Zero leaf,
    // > goes to the Gradient leaf.
    #[test]
    fn test_traverse_with_spec_props() {
        let tree = vec![
            PropertyDecisionNode {
                property: 0,
                splitval: 0,
                lchild: 1,
                rchild: 2,
                ..Default::default()
            },
            PropertyDecisionNode {
                property: -1,
                predictor: Predictor::Zero,
                context_id: 0,
                multiplier: 1,
                ..Default::default()
            },
            PropertyDecisionNode {
                property: -1,
                predictor: Predictor::Gradient,
                context_id: 1,
                multiplier: 1,
                ..Default::default()
            },
        ];
        let mut props = [0i32; NUM_PROPERTIES];
        props[0] = 0;
        let leaf = traverse_with_spec_props(&tree, &props);
        assert_eq!(leaf.predictor, Predictor::Zero);
        props[0] = 1;
        let leaf = traverse_with_spec_props(&tree, &props);
        assert_eq!(leaf.predictor, Predictor::Gradient);
    }

    // Partitioning on property 3 (x coordinate in the spec ordering —
    // verify against compute_spec_properties) with splitval 1 should put
    // exactly half of the 4x4 samples on the left, and both halves must
    // satisfy the predicate.
    #[test]
    fn test_partition_indices() {
        let image = ModularImage::from_gray8(&[0u8; 16], 4, 4).unwrap();
        let mut samples = TreeSamples::new();
        gather_samples(&mut samples, &image, 0);
        let mut indices: Vec<usize> = (0..samples.num_samples).collect();
        let mid = partition_indices(&mut indices, &samples, 3, 1);
        assert_eq!(mid, 8);
        for &i in &indices[..mid] {
            assert!(samples.props[3][i] <= 1);
        }
        for &i in &indices[mid..] {
            assert!(samples.props[3][i] > 1);
        }
    }
}