use super::channel::ModularImage;
use super::encode::{
write_gradient_tree_tokens, write_hybrid_data_histogram, write_palette_transform,
write_rct_transform, write_tree_histogram_for_gradient,
};
use super::predictor::pack_signed;
use super::rct::RctType;
use crate::bit_writer::BitWriter;
use crate::entropy_coding::encode::{
OwnedAnsEntropyCode, build_entropy_code_ans, write_tokens_ans,
};
use crate::entropy_coding::hybrid_uint::HybridUintConfig;
use crate::entropy_coding::token::Token as AnsToken;
use crate::error::Result;
/// Hybrid-uint configuration used for modular residual tokens:
/// split_exponent 4 (split = 2^4 = 16), 2 MSBs folded into the token,
/// no LSBs in the token.
const MODULAR_HYBRID_UINT: HybridUintConfig = HybridUintConfig {
    split_exponent: 4,
    split: 16,
    msb_in_token: 2,
    lsb_in_token: 0,
};
/// Gradient (clamped-gradient) predictor: `left + top - topleft`, clamped
/// to the closed range spanned by `left` and `top`.
#[inline]
fn predict_gradient(left: i32, top: i32, topleft: i32) -> i32 {
    let (lo, hi) = if left <= top { (left, top) } else { (top, left) };
    (left + top - topleft).clamp(lo, hi)
}
/// Computes gradient-predictor residuals for every pixel of every channel,
/// packed to unsigned via `pack_signed`.
///
/// Returns the residual stream (channel-major, row-major order) together
/// with the largest packed value seen (0 for an empty image).
pub fn collect_all_residuals(image: &ModularImage) -> (Vec<u32>, u32) {
    let mut packed_residuals = Vec::new();
    let mut largest: u32 = 0;
    for channel in &image.channels {
        for y in 0..channel.height() {
            for x in 0..channel.width() {
                // Missing neighbors collapse to `left` (and `left` itself
                // falls back to 0 in the first column).
                let left = if x == 0 { 0 } else { channel.get(x - 1, y) };
                let top = if y == 0 { left } else { channel.get(x, y - 1) };
                let topleft = if x == 0 || y == 0 {
                    left
                } else {
                    channel.get(x - 1, y - 1)
                };
                let predicted = predict_gradient(left, top, topleft);
                let packed = pack_signed(channel.get(x, y) - predicted);
                packed_residuals.push(packed);
                if packed > largest {
                    largest = packed;
                }
            }
        }
    }
    (packed_residuals, largest)
}
pub fn build_histogram_from_residuals(residuals: &[u32], _max_residual: u32) -> (Vec<u32>, u32) {
let mut max_token: u32 = 0;
for &r in residuals {
let (token, _, _) = MODULAR_HYBRID_UINT.encode(r);
max_token = max_token.max(token);
}
let histogram_size = (max_token + 1) as usize;
let mut histogram = vec![0u32; histogram_size];
for &r in residuals {
let (token, _, _) = MODULAR_HYBRID_UINT.encode(r);
histogram[token as usize] += 1;
}
(histogram, max_token)
}
/// Entropy-coding state produced when writing the global modular section,
/// reused afterwards to encode each per-group section consistently.
pub enum GlobalModularState {
    /// Prefix (Huffman-style) coding of hybrid-uint tokens.
    Huffman {
        /// Code length in bits per token; tokens with depth 0 are skipped
        /// when writing (see the group writer).
        depths: Vec<u8>,
        /// Code bit pattern per token, written with `depths[token]` bits.
        codes: Vec<u16>,
        max_token: u32,
    },
    /// Single-distribution ANS coding (all tokens use context 0).
    Ans {
        code: OwnedAnsEntropyCode,
    },
    /// ANS coding with contexts selected by a learned MA tree, plus the
    /// weighted-predictor parameters the tree was trained with.
    AnsWithTree {
        code: OwnedAnsEntropyCode,
        tree: super::tree::Tree,
        wp_params: super::predictor::WeightedPredictorParams,
    },
}
/// Returns `ceil(log2(x))` for `x > 0` (so 1 -> 0, 2 -> 1, 3 -> 2, 4 -> 2).
fn ceil_log2_nonzero(x: u32) -> u32 {
    debug_assert!(x > 0);
    // ceil(log2(x)) equals the number of significant bits in (x - 1);
    // for x == 1 this gives 32 - 32 == 0 as required.
    32 - (x - 1).leading_zeros()
}
/// Writes the entropy-code header for a single-distribution ANS code, as
/// used by modular sections whose context tree has a single leaf.
///
/// Layout: two zero flag bits (presumably lz77_enabled and use_prefix_code,
/// matching the JXL entropy-header layout — confirm against the decoder),
/// 2 bits of `log_alpha_size - 5`, the hybrid-uint configuration, then the
/// single histogram.
///
/// # Panics
/// Asserts that `code` holds exactly one histogram.
pub(super) fn write_ans_modular_header(
    writer: &mut BitWriter,
    code: &OwnedAnsEntropyCode,
) -> Result<()> {
    assert_eq!(
        code.histograms.len(),
        1,
        "modular ANS header only supports single-distribution (single-leaf tree)"
    );
    writer.write(1, 0)?;
    writer.write(1, 0)?;
    let las = code.log_alpha_size;
    // Alphabet size is coded as a 2-bit offset from the minimum of 5.
    writer.write(2, (las - 5) as u64)?;
    // Fall back to the default hybrid-uint config when none is recorded.
    let config = code
        .uint_configs
        .first()
        .copied()
        .unwrap_or(crate::entropy_coding::hybrid_uint::HybridUintConfig::default_config());
    // split_exponent is coded in ceil(log2(las + 1)) bits.
    let se_bits = ceil_log2_nonzero(las as u32 + 1);
    writer.write(se_bits as usize, config.split_exponent as u64)?;
    if (config.split_exponent as usize) != las {
        // msb/lsb token counts are only present when split_exponent != las.
        let msb_bits = ceil_log2_nonzero(config.split_exponent + 1);
        writer.write(msb_bits as usize, config.msb_in_token as u64)?;
        let lsb_bits = ceil_log2_nonzero(config.split_exponent - config.msb_in_token + 1);
        writer.write(lsb_bits as usize, config.lsb_in_token as u64)?;
    }
    code.histograms[0].write(writer)?;
    Ok(())
}
/// Writes the global modular section using a fixed gradient-predictor tree,
/// choosing ANS or prefix (Huffman) coding for the residual tokens.
///
/// Returns the entropy-coding state that group sections must reuse.
pub fn write_global_modular_section(
    all_residuals: &[u32],
    histogram: &[u32],
    max_token: u32,
    writer: &mut BitWriter,
    use_ans: bool,
    transforms: GlobalTransforms,
) -> Result<GlobalModularState> {
    crate::trace::debug_eprintln!(
        "GLOBAL_MODULAR [bit {}]: Starting global section (ans={})",
        writer.bits_written(),
        use_ans
    );
    // Two leading 1 bits, then the fixed gradient tree's histogram + tokens.
    writer.write(1, 1)?;
    writer.write(1, 1)?;
    let (grad_depths, grad_codes) = write_tree_histogram_for_gradient(writer)?;
    write_gradient_tree_tokens(writer, &grad_depths, &grad_codes)?;
    // Emit the data code (ANS header or Huffman histogram) and build the
    // state handed back for the group sections.
    let state = if use_ans {
        let tokens: Vec<AnsToken> = all_residuals
            .iter()
            .map(|&r| AnsToken::new(0, r))
            .collect();
        let code = build_entropy_code_ans(&tokens, 1);
        write_ans_modular_header(writer, &code)?;
        GlobalModularState::Ans { code }
    } else {
        let (depths, codes) = write_hybrid_data_histogram(writer, histogram, max_token)?;
        GlobalModularState::Huffman {
            depths,
            codes,
            max_token,
        }
    };
    // Common trailer: two 1 bits, the global transform list, byte alignment.
    writer.write(1, 1)?;
    writer.write(1, 1)?;
    write_global_transforms_full(writer, &transforms)?;
    writer.zero_pad_to_byte();
    if use_ans {
        crate::trace::debug_eprintln!(
            "GLOBAL_MODULAR [bit {}]: Global section done (ANS)",
            writer.bits_written()
        );
    } else {
        crate::trace::debug_eprintln!(
            "GLOBAL_MODULAR [bit {}]: Global section done (Huffman)",
            writer.bits_written()
        );
    }
    Ok(state)
}
/// Convenience wrapper around
/// [`write_global_modular_section_with_tree_dc_quant`] that uses the
/// default DC quantization (no custom values).
pub fn write_global_modular_section_with_tree(
    images: &[ModularImage],
    writer: &mut BitWriter,
    profile: &crate::effort::EffortProfile,
    transforms: GlobalTransforms,
    use_lz77: bool,
    lz77_method: crate::entropy_coding::lz77::Lz77Method,
    meta_image: Option<&ModularImage>,
) -> Result<GlobalModularState> {
    write_global_modular_section_with_tree_dc_quant(
        images,
        writer,
        profile,
        transforms,
        use_lz77,
        lz77_method,
        None,
        meta_image,
    )
}
/// Writes the global (LfGlobal-style) modular section using a learned MA
/// tree and weighted-predictor parameters, optionally with custom DC quant
/// values and a meta image whose tokens are written inline here.
///
/// Returns the `AnsWithTree` state that per-group sections must reuse.
/// `use_lz77`/`lz77_method` are currently accepted but ignored (LZ77 is
/// disabled below).
#[allow(clippy::too_many_arguments)]
pub(crate) fn write_global_modular_section_with_tree_dc_quant(
    images: &[ModularImage],
    writer: &mut BitWriter,
    profile: &crate::effort::EffortProfile,
    transforms: GlobalTransforms,
    use_lz77: bool,
    lz77_method: crate::entropy_coding::lz77::Lz77Method,
    dc_quant_custom: Option<[f32; 3]>,
    meta_image: Option<&ModularImage>,
) -> Result<GlobalModularState> {
    use super::encode::write_tree;
    use super::encode::write_wp_header;
    use super::predictor::WeightedPredictorParams;
    use super::tree::count_contexts;
    use super::tree_learn::{
        TreeLearningParams, TreeSamples, collect_residuals_with_tree, compute_best_tree,
        compute_gather_stride_from_profile, gather_samples_strided, max_ref_channels,
    };
    use crate::entropy_coding::encode::build_entropy_code_ans_with_options;
    use crate::entropy_coding::encode::write_entropy_code_ans;
    use crate::entropy_coding::lz77::write_lz77_header;
    // Search weighted-predictor parameters over all channels (meta first,
    // then group images) when the profile enables it; otherwise defaults.
    let all_channels: Vec<&super::channel::Channel> = meta_image
        .into_iter()
        .chain(images.iter())
        .flat_map(|img| img.channels.iter())
        .collect();
    let wp_params = if profile.wp_num_param_sets > 0 {
        let channels_for_wp: Vec<super::channel::Channel> =
            all_channels.iter().map(|c| (*c).clone()).collect();
        super::predictor::find_best_wp_params(&channels_for_wp, profile.wp_num_param_sets)
    } else {
        WeightedPredictorParams::default()
    };
    let total_pixels: usize = meta_image
        .into_iter()
        .chain(images.iter())
        .flat_map(|img| img.channels.iter())
        .map(|ch| ch.width() * ch.height())
        .sum();
    let stride = compute_gather_stride_from_profile(total_pixels, profile);
    // Maximum number of reference-channel properties needed by any image.
    let num_refs = {
        let mut mr = 0;
        if let Some(meta) = meta_image {
            mr = mr.max(max_ref_channels(meta));
        }
        for img in images.iter() {
            mr = mr.max(max_ref_channels(img));
        }
        mr
    };
    // Gather (strided) tree-learning samples. The meta image uses group id
    // 0; real groups are offset by one when a meta image is present.
    let mut samples = TreeSamples::new_with_ref_channels(num_refs);
    if let Some(meta) = meta_image {
        gather_samples_strided(&mut samples, meta, 0, 0, stride, &wp_params);
    }
    let per_group_id_offset = if meta_image.is_some() { 1u32 } else { 0u32 };
    for (group_idx, group_image) in images.iter().enumerate() {
        gather_samples_strided(
            &mut samples,
            group_image,
            group_idx as u32 + per_group_id_offset,
            0,
            stride,
            &wp_params,
        );
    }
    // Fraction of pixels actually sampled; feeds the split threshold.
    let pixel_fraction = if total_pixels > 0 {
        samples.num_samples as f64 / total_pixels as f64
    } else {
        1.0
    };
    let params = TreeLearningParams::from_profile(profile)
        .with_ref_properties(num_refs, profile.effort)
        .with_pixel_fraction(pixel_fraction)
        .with_total_pixels(total_pixels);
    let tree = compute_best_tree(&mut samples, &params);
    let num_contexts = count_contexts(&tree) as usize;
    crate::trace::debug_eprintln!(
        "GLOBAL_MODULAR_TREE: {} nodes, {} leaves/contexts from {} samples \
         (pixel_fraction={:.3}, threshold={:.1}*{:.3}={:.1})",
        tree.len(),
        num_contexts,
        samples.num_samples,
        pixel_fraction,
        params.split_threshold,
        pixel_fraction * 0.9 + 0.1,
        params.split_threshold * (pixel_fraction * 0.9 + 0.1),
    );
    // Collect token streams: meta tokens first (they are written inline in
    // this section), then each group's (used only to build the code here).
    let mut all_tokens = Vec::new();
    let nb_meta_tokens = if let Some(meta) = meta_image {
        let meta_tokens = collect_residuals_with_tree(meta, &tree, 0, &wp_params);
        let n = meta_tokens.len();
        all_tokens.extend(meta_tokens);
        n
    } else {
        0
    };
    for (group_idx, group_image) in images.iter().enumerate() {
        let group_tokens = collect_residuals_with_tree(
            group_image,
            &tree,
            group_idx as u32 + per_group_id_offset,
            &wp_params,
        );
        all_tokens.extend(group_tokens);
    }
    // LZ77 is deliberately disabled here regardless of the caller's request.
    let _ = (use_lz77, lz77_method);
    let lz77_params: Option<crate::entropy_coding::lz77::Lz77Params> = None;
    // With LZ77 an extra (length) context would follow the tree's contexts.
    let ans_num_contexts = if lz77_params.is_some() {
        num_contexts + 1
    } else {
        num_contexts
    };
    let code = build_entropy_code_ans_with_options(
        &all_tokens,
        ans_num_contexts,
        true,
        true,
        lz77_params.as_ref(),
        Some(total_pixels),
    );
    // Diagnostics now go through the gated debug macro for consistency with
    // the rest of this module (previously unconditional eprintln!).
    crate::trace::debug_eprintln!(
        "DIAG tree: {} nodes, {} contexts, {} samples, {} total_tokens, \
         max_nodes={}, threshold={:.1}, pixel_frac={:.3}",
        tree.len(),
        num_contexts,
        samples.num_samples,
        all_tokens.len(),
        params.max_nodes,
        params.split_threshold,
        pixel_fraction,
    );
    crate::trace::debug_eprintln!(
        "DIAG code: {} histograms (from {} contexts), rct={:?}, compact={}",
        code.histograms.len(),
        ans_num_contexts,
        transforms.rct_type,
        transforms.compact_info.len(),
    );
    let bits_before = writer.bits_written();
    crate::f16::write_lf_quant(writer, dc_quant_custom)?;
    writer.write(1, 1)?;
    let bits_before_tree = writer.bits_written();
    write_tree(writer, &tree)?;
    let tree_bits = writer.bits_written() - bits_before_tree;
    let bits_before_histo = writer.bits_written();
    if ans_num_contexts > 1 {
        // Multi-context: LZ77 header (disabled) plus the clustered code.
        write_lz77_header(lz77_params.as_ref(), writer)?;
        write_entropy_code_ans(&code, writer)?;
    } else {
        // Single context: the compact single-distribution header suffices.
        write_ans_modular_header(writer, &code)?;
    }
    let histo_bits = writer.bits_written() - bits_before_histo;
    writer.write(1, 1)?;
    write_wp_header(writer, &wp_params)?;
    write_global_transforms_full(writer, &transforms)?;
    // The meta image's tokens live in the global section itself.
    if nb_meta_tokens > 0 {
        let meta_token_slice = &all_tokens[..nb_meta_tokens];
        write_tokens_ans(meta_token_slice, &code, None, writer)?;
    }
    let total_lf_global_bits = writer.bits_written() - bits_before;
    crate::trace::debug_eprintln!(
        "DIAG LfGlobal: tree={} bits ({} B), histo={} bits ({} B), \
         meta_tokens={}, total={} bits ({} B)",
        tree_bits,
        tree_bits / 8,
        histo_bits,
        histo_bits / 8,
        nb_meta_tokens,
        total_lf_global_bits,
        total_lf_global_bits / 8,
    );
    writer.zero_pad_to_byte();
    Ok(GlobalModularState::AnsWithTree {
        code,
        tree,
        wp_params,
    })
}
/// Transforms written in the global modular section's transform list.
pub struct GlobalTransforms {
    /// Palette ("compact") transforms as (begin_channel, nb_colors) pairs.
    pub compact_info: Vec<(usize, usize)>,
    /// Optional reversible color transform, written after the palettes with
    /// its begin channel placed past the palette channels.
    pub rct_type: Option<RctType>,
}
impl GlobalTransforms {
    /// Builds a transform set with only an (optional) RCT and no palettes.
    pub fn rct_only(rct_type: Option<RctType>) -> Self {
        Self {
            compact_info: Vec::new(),
            rct_type,
        }
    }
}
/// Writes the global transform list: the transform count, one palette
/// transform per compacted channel, then the optional RCT whose begin
/// channel follows the palette channels.
fn write_global_transforms_full(
    writer: &mut BitWriter,
    transforms: &GlobalTransforms,
) -> Result<()> {
    let rct_count = u32::from(transforms.rct_type.is_some());
    let num_transforms = transforms.compact_info.len() as u32 + rct_count;
    super::encode::write_num_transforms(writer, num_transforms)?;
    for &(begin_c, nb_colors) in transforms.compact_info.iter() {
        write_palette_transform(writer, begin_c, 1, nb_colors, 0, 0)?;
    }
    if let Some(rct) = transforms.rct_type {
        write_rct_transform(writer, transforms.compact_info.len(), rct)?;
    }
    Ok(())
}
/// Gradient-predictor residuals for one group image, packed to unsigned,
/// in channel-major row-major order (same scheme as `collect_all_residuals`).
fn collect_group_residuals(group_image: &ModularImage) -> Vec<u32> {
    let mut packed = Vec::new();
    for channel in &group_image.channels {
        let (w, h) = (channel.width(), channel.height());
        for y in 0..h {
            for x in 0..w {
                // Missing neighbors collapse to `left` (0 in the first column).
                let left = if x == 0 { 0 } else { channel.get(x - 1, y) };
                let top = if y == 0 { left } else { channel.get(x, y - 1) };
                let topleft = if x == 0 || y == 0 {
                    left
                } else {
                    channel.get(x - 1, y - 1)
                };
                let predicted = predict_gradient(left, top, topleft);
                packed.push(pack_signed(channel.get(x, y) - predicted));
            }
        }
    }
    packed
}
/// Writes a group's modular section using group index 0 and no per-group
/// transforms. See [`write_group_modular_section_idx`] for details.
pub fn write_group_modular_section(
    group_image: &ModularImage,
    state: &GlobalModularState,
    writer: &mut BitWriter,
) -> Result<()> {
    write_group_modular_section_idx(group_image, state, 0, &GroupTransforms::none(), writer)
}
/// Per-group transforms written in a group section's transform list.
#[derive(Clone)]
pub struct GroupTransforms {
    /// Palette ("compact") transforms as (begin_channel, nb_colors) pairs.
    pub compact_info: Vec<(usize, usize)>,
    /// Optional reversible color transform for this group.
    pub rct_type: Option<RctType>,
}
impl GroupTransforms {
    /// An empty transform set: no palettes and no RCT.
    pub fn none() -> Self {
        Self {
            compact_info: Vec::new(),
            rct_type: None,
        }
    }
}
/// Writes one group's modular section: a small header (a 1 bit, then either
/// the weighted-predictor params or another 1 bit), the per-group transform
/// list, then the entropy-coded residuals for every channel using the
/// global entropy state.
///
/// `group_idx` feeds the tree's group-id property when `state` is
/// `AnsWithTree`; the other states ignore it.
pub fn write_group_modular_section_idx(
    group_image: &ModularImage,
    state: &GlobalModularState,
    group_idx: u32,
    transforms: &GroupTransforms,
    writer: &mut BitWriter,
) -> Result<()> {
    crate::trace::debug_eprintln!(
        "GROUP_MODULAR [bit {}]: Starting group section ({}x{}, compact={}, rct={:?})",
        writer.bits_written(),
        group_image.width(),
        group_image.height(),
        transforms.compact_info.len(),
        transforms.rct_type,
    );
    // NOTE(review): the leading 1 bit presumably signals "use the global
    // tree/code" — confirm against the decoder.
    writer.write(1, 1)?;
    match state {
        GlobalModularState::AnsWithTree { wp_params, .. } => {
            // Tree-based coding: the group re-states the WP parameters.
            super::encode::write_wp_header(writer, wp_params)?;
        }
        _ => {
            writer.write(1, 1)?;
        }
    }
    // Per-group transform list: palettes first, then the optional RCT with
    // its begin channel placed past the palette channels.
    let num_transforms =
        transforms.compact_info.len() as u32 + transforms.rct_type.is_some() as u32;
    super::encode::write_num_transforms(writer, num_transforms)?;
    for &(begin_c, nb_colors) in &transforms.compact_info {
        write_palette_transform(writer, begin_c, 1, nb_colors, 0, 0)?;
    }
    if let Some(rct) = transforms.rct_type {
        let rct_begin_c = transforms.compact_info.len();
        write_rct_transform(writer, rct_begin_c, rct)?;
    }
    match state {
        GlobalModularState::Huffman {
            depths,
            codes,
            max_token: _,
        } => {
            // Prefix coding: gradient-predict each pixel, hybrid-uint encode
            // the packed residual, write the prefix code plus extra bits.
            for channel in &group_image.channels {
                let width = channel.width();
                let height = channel.height();
                for y in 0..height {
                    for x in 0..width {
                        let pixel = channel.get(x, y);
                        // Missing neighbors collapse to `left` (0 at x == 0).
                        let left = if x > 0 { channel.get(x - 1, y) } else { 0 };
                        let top = if y > 0 { channel.get(x, y - 1) } else { left };
                        let topleft = if x > 0 && y > 0 {
                            channel.get(x - 1, y - 1)
                        } else {
                            left
                        };
                        let prediction = predict_gradient(left, top, topleft);
                        let residual = pixel - prediction;
                        let packed = pack_signed(residual);
                        let (token, extra_bits, num_extra) = MODULAR_HYBRID_UINT.encode(packed);
                        let depth = depths.get(token as usize).copied().unwrap_or(0);
                        let code = codes.get(token as usize).copied().unwrap_or(0);
                        // Depth 0 means the token has no assigned code.
                        if depth > 0 {
                            writer.write(depth as usize, code as u64)?;
                        }
                        if num_extra > 0 {
                            writer.write(num_extra as usize, extra_bits as u64)?;
                        }
                    }
                }
            }
        }
        GlobalModularState::Ans { code } => {
            // Single-distribution ANS: every residual uses context 0.
            let residuals = collect_group_residuals(group_image);
            let tokens: Vec<AnsToken> = residuals.iter().map(|&r| AnsToken::new(0, r)).collect();
            write_tokens_ans(&tokens, code, None, writer)?;
        }
        GlobalModularState::AnsWithTree {
            code,
            tree,
            wp_params,
        } => {
            // Tree-based ANS: contexts are chosen by walking the MA tree.
            let tokens = super::tree_learn::collect_residuals_with_tree(
                group_image,
                tree,
                group_idx,
                wp_params,
            );
            write_tokens_ans(&tokens, code, None, writer)?;
        }
    }
    // Sections are byte-aligned.
    writer.zero_pad_to_byte();
    crate::trace::debug_eprintln!(
        "GROUP_MODULAR [bit {}]: Group section done",
        writer.bits_written()
    );
    Ok(())
}