1use super::channel::ModularImage;
11use super::encode::{
12 write_gradient_tree_tokens, write_hybrid_data_histogram, write_palette_transform,
13 write_rct_transform, write_tree_histogram_for_gradient,
14};
15use super::predictor::pack_signed;
16use super::rct::RctType;
17use crate::bit_writer::BitWriter;
18use crate::entropy_coding::encode::{
19 OwnedAnsEntropyCode, build_entropy_code_ans, write_tokens_ans,
20};
21use crate::entropy_coding::hybrid_uint::HybridUintConfig;
22use crate::entropy_coding::token::Token as AnsToken;
23use crate::error::Result;
24
/// Hybrid-uint configuration shared by the simple (tree-less) modular paths.
///
/// Values below `split` (= 2^split_exponent = 16) are coded directly as
/// tokens; larger values keep 2 MSBs in the token (`msb_in_token`) and send
/// the remaining low bits raw (`lsb_in_token = 0` → none of the LSBs live in
/// the token).
const MODULAR_HYBRID_UINT: HybridUintConfig = HybridUintConfig {
    split_exponent: 4,
    split: 16,
    msb_in_token: 2,
    lsb_in_token: 0,
};
32
/// Clamped-gradient prediction: `left + top - topleft`, clamped to the
/// interval spanned by `left` and `top`.
#[inline]
fn predict_gradient(left: i32, top: i32, topleft: i32) -> i32 {
    let (lo, hi) = if left < top { (left, top) } else { (top, left) };
    (left + top - topleft).clamp(lo, hi)
}
42
43pub fn collect_all_residuals(image: &ModularImage) -> (Vec<u32>, u32) {
44 let mut residuals = Vec::new();
45 let mut max_residual: u32 = 0;
46
47 for channel in &image.channels {
48 let width = channel.width();
49 let height = channel.height();
50
51 for y in 0..height {
52 for x in 0..width {
53 let pixel = channel.get(x, y);
54
55 let left = if x > 0 { channel.get(x - 1, y) } else { 0 };
57 let top = if y > 0 { channel.get(x, y - 1) } else { left };
58 let topleft = if x > 0 && y > 0 {
59 channel.get(x - 1, y - 1)
60 } else {
61 left
62 };
63
64 let prediction = predict_gradient(left, top, topleft);
66 let residual = pixel - prediction;
67 let packed = pack_signed(residual);
68
69 residuals.push(packed);
70 max_residual = max_residual.max(packed);
71 }
72 }
73 }
74
75 (residuals, max_residual)
76}
77
78pub fn build_histogram_from_residuals(residuals: &[u32], _max_residual: u32) -> (Vec<u32>, u32) {
81 let mut max_token: u32 = 0;
82 for &r in residuals {
84 let (token, _, _) = MODULAR_HYBRID_UINT.encode(r);
85 max_token = max_token.max(token);
86 }
87 let histogram_size = (max_token + 1) as usize;
89 let mut histogram = vec![0u32; histogram_size];
90 for &r in residuals {
91 let (token, _, _) = MODULAR_HYBRID_UINT.encode(r);
92 histogram[token as usize] += 1;
93 }
94 (histogram, max_token)
95}
96
/// Entropy-coding state established by the global modular section and reused
/// by every per-group section.
pub enum GlobalModularState {
    /// Prefix (Huffman) coding with one shared code over hybrid-uint tokens.
    Huffman {
        /// Code length in bits per token; 0 means the token has no codeword.
        depths: Vec<u8>,
        /// Canonical codeword per token, parallel to `depths`.
        codes: Vec<u16>,
        /// Largest token value covered by `depths`/`codes`.
        max_token: u32,
    },
    /// Single-distribution ANS with the fixed gradient predictor.
    Ans {
        /// The ANS entropy code whose header was written globally.
        code: OwnedAnsEntropyCode,
    },
    /// Multi-context ANS driven by a learned MA tree plus weighted predictor.
    AnsWithTree {
        /// Per-context ANS entropy code (possibly clustered).
        code: OwnedAnsEntropyCode,
        /// Decision tree mapping pixel properties to entropy contexts.
        tree: super::tree::Tree,
        /// Weighted-predictor parameters shared by all groups.
        wp_params: super::predictor::WeightedPredictorParams,
    },
}
124
/// Ceiling of log2(`x`) for `x > 0` (so `ceil_log2_nonzero(1) == 0`).
fn ceil_log2_nonzero(x: u32) -> u32 {
    debug_assert!(x > 0);
    // ceil(log2(x)) equals the bit-width of x - 1: for powers of two the
    // subtraction drops the top bit, otherwise it leaves it in place.
    u32::BITS - (x - 1).leading_zeros()
}
135
/// Writes a minimal entropy-code header for a single-distribution ANS code.
///
/// Only valid when `code` has exactly one histogram (single-leaf tree /
/// one context); multi-context codes must go through the full clustered
/// header path instead.
///
/// # Panics
/// Asserts that `code.histograms.len() == 1`.
pub(super) fn write_ans_modular_header(
    writer: &mut BitWriter,
    code: &OwnedAnsEntropyCode,
) -> Result<()> {
    assert_eq!(
        code.histograms.len(),
        1,
        "modular ANS header only supports single-distribution (single-leaf tree)"
    );

    // NOTE(review): two zero flag bits — presumably LZ77-disabled and
    // use_prefix_code = 0 (i.e. ANS); confirm against the decoder/spec.
    writer.write(1, 0)?;

    writer.write(1, 0)?;

    // log_alpha_size is signalled as a 2-bit offset from 5.
    let las = code.log_alpha_size;
    writer.write(2, (las - 5) as u64)?;

    // Hybrid-uint config: split_exponent first; msb/lsb_in_token follow only
    // when split_exponent != log_alpha_size, each using exactly enough bits
    // to cover its valid range.
    let config = code
        .uint_configs
        .first()
        .copied()
        .unwrap_or(crate::entropy_coding::hybrid_uint::HybridUintConfig::default_config());
    let se_bits = ceil_log2_nonzero(las as u32 + 1);
    writer.write(se_bits as usize, config.split_exponent as u64)?;
    if (config.split_exponent as usize) != las {
        let msb_bits = ceil_log2_nonzero(config.split_exponent + 1);
        writer.write(msb_bits as usize, config.msb_in_token as u64)?;
        let lsb_bits = ceil_log2_nonzero(config.split_exponent - config.msb_in_token + 1);
        writer.write(lsb_bits as usize, config.lsb_in_token as u64)?;
    }

    // Finally the single ANS histogram itself.
    code.histograms[0].write(writer)?;

    Ok(())
}
182
/// Writes a self-contained global modular section using the fixed gradient
/// predictor and a single shared entropy code (ANS when `use_ans`, otherwise
/// Huffman).
///
/// `all_residuals` are the packed residuals for the whole image (only used
/// on the ANS path to build the code); `histogram`/`max_token` feed the
/// Huffman path. Returns the state per-group sections must reuse.
pub fn write_global_modular_section(
    all_residuals: &[u32],
    histogram: &[u32],
    max_token: u32,
    writer: &mut BitWriter,
    use_ans: bool,
    transforms: GlobalTransforms,
) -> Result<GlobalModularState> {
    crate::trace::debug_eprintln!(
        "GLOBAL_MODULAR [bit {}]: Starting global section (ans={})",
        writer.bits_written(),
        use_ans
    );

    // NOTE(review): two set flag bits before the tree — presumably header
    // flags mirroring the group header; confirm against the decoder.
    writer.write(1, 1)?;
    writer.write(1, 1)?;

    // Signal the fixed single-leaf "always gradient predictor" tree.
    let (tree_depths, tree_codes) = write_tree_histogram_for_gradient(writer)?;
    write_gradient_tree_tokens(writer, &tree_depths, &tree_codes)?;

    if use_ans {
        // Single-leaf tree: every residual lives in context 0.
        let tokens: Vec<AnsToken> = all_residuals.iter().map(|&r| AnsToken::new(0, r)).collect();
        let code = build_entropy_code_ans(&tokens, 1);
        write_ans_modular_header(writer, &code)?;

        writer.write(1, 1)?;
        writer.write(1, 1)?;
        write_global_transforms_full(writer, &transforms)?;

        writer.zero_pad_to_byte();
        crate::trace::debug_eprintln!(
            "GLOBAL_MODULAR [bit {}]: Global section done (ANS)",
            writer.bits_written()
        );

        Ok(GlobalModularState::Ans { code })
    } else {
        let (depths, codes) = write_hybrid_data_histogram(writer, histogram, max_token)?;

        writer.write(1, 1)?;
        writer.write(1, 1)?;
        write_global_transforms_full(writer, &transforms)?;

        writer.zero_pad_to_byte();
        crate::trace::debug_eprintln!(
            "GLOBAL_MODULAR [bit {}]: Global section done (Huffman)",
            writer.bits_written()
        );

        Ok(GlobalModularState::Huffman {
            depths,
            codes,
            max_token,
        })
    }
}
261
/// Writes the global modular section with a learned MA tree, using the
/// default DC quantization.
///
/// Thin wrapper over [`write_global_modular_section_with_tree_dc_quant`]
/// with `dc_quant_custom = None`; see that function for the full pipeline.
pub fn write_global_modular_section_with_tree(
    images: &[ModularImage],
    writer: &mut BitWriter,
    profile: &crate::effort::EffortProfile,
    transforms: GlobalTransforms,
    use_lz77: bool,
    lz77_method: crate::entropy_coding::lz77::Lz77Method,
    meta_image: Option<&ModularImage>,
) -> Result<GlobalModularState> {
    write_global_modular_section_with_tree_dc_quant(
        images,
        writer,
        profile,
        transforms,
        use_lz77,
        lz77_method,
        None,
        meta_image,
    )
}
291
292#[allow(clippy::too_many_arguments)]
294pub(crate) fn write_global_modular_section_with_tree_dc_quant(
295 images: &[ModularImage],
296 writer: &mut BitWriter,
297 profile: &crate::effort::EffortProfile,
298 transforms: GlobalTransforms,
299 use_lz77: bool,
300 lz77_method: crate::entropy_coding::lz77::Lz77Method,
301 dc_quant_custom: Option<[f32; 3]>,
302 meta_image: Option<&ModularImage>,
303) -> Result<GlobalModularState> {
304 use super::encode::write_tree;
305 use super::encode::write_wp_header;
306 use super::predictor::WeightedPredictorParams;
307 use super::tree::count_contexts;
308 use super::tree_learn::{
309 TreeLearningParams, TreeSamples, collect_residuals_with_tree, compute_best_tree,
310 compute_gather_stride_from_profile, gather_samples_strided, max_ref_channels,
311 };
312 use crate::entropy_coding::encode::build_entropy_code_ans_with_options;
313 use crate::entropy_coding::encode::write_entropy_code_ans;
314 use crate::entropy_coding::lz77::write_lz77_header;
315
316 let all_channels: Vec<&super::channel::Channel> = meta_image
318 .into_iter()
319 .chain(images.iter())
320 .flat_map(|img| img.channels.iter())
321 .collect();
322 let wp_params = if profile.wp_num_param_sets > 0 {
323 let channels_for_wp: Vec<super::channel::Channel> =
325 all_channels.iter().map(|c| (*c).clone()).collect();
326 super::predictor::find_best_wp_params(&channels_for_wp, profile.wp_num_param_sets)
327 } else {
328 WeightedPredictorParams::default()
329 };
330
331 let total_pixels: usize = meta_image
333 .into_iter()
334 .chain(images.iter())
335 .flat_map(|img| img.channels.iter())
336 .map(|ch| ch.width() * ch.height())
337 .sum();
338 let stride = compute_gather_stride_from_profile(total_pixels, profile);
339 let num_refs = {
341 let mut mr = 0;
342 if let Some(meta) = meta_image {
343 mr = mr.max(max_ref_channels(meta));
344 }
345 for img in images.iter() {
346 mr = mr.max(max_ref_channels(img));
347 }
348 mr
349 };
350 let mut samples = TreeSamples::new_with_ref_channels(num_refs);
351 if let Some(meta) = meta_image {
353 gather_samples_strided(&mut samples, meta, 0, 0, stride, &wp_params);
354 }
355 let per_group_id_offset = if meta_image.is_some() { 1u32 } else { 0u32 };
364 for (group_idx, group_image) in images.iter().enumerate() {
365 gather_samples_strided(
366 &mut samples,
367 group_image,
368 group_idx as u32 + per_group_id_offset,
369 0,
370 stride,
371 &wp_params,
372 );
373 }
374
375 let pixel_fraction = if total_pixels > 0 {
377 samples.num_samples as f64 / total_pixels as f64
378 } else {
379 1.0
380 };
381 let params = TreeLearningParams::from_profile(profile)
382 .with_ref_properties(num_refs, profile.effort)
383 .with_pixel_fraction(pixel_fraction)
384 .with_total_pixels(total_pixels);
385 let tree = compute_best_tree(&mut samples, ¶ms);
386 let num_contexts = count_contexts(&tree) as usize;
387
388 crate::trace::debug_eprintln!(
389 "GLOBAL_MODULAR_TREE: {} nodes, {} leaves/contexts from {} samples \
390 (pixel_fraction={:.3}, threshold={:.1}*{:.3}={:.1})",
391 tree.len(),
392 num_contexts,
393 samples.num_samples,
394 pixel_fraction,
395 params.split_threshold,
396 pixel_fraction * 0.9 + 0.1,
397 params.split_threshold * (pixel_fraction * 0.9 + 0.1),
398 );
399
400 let mut all_tokens = Vec::new();
402 let nb_meta_tokens = if let Some(meta) = meta_image {
404 let meta_tokens = collect_residuals_with_tree(meta, &tree, 0, &wp_params);
405 let n = meta_tokens.len();
406 all_tokens.extend(meta_tokens);
407 n
408 } else {
409 0
410 };
411 for (group_idx, group_image) in images.iter().enumerate() {
413 let group_tokens = collect_residuals_with_tree(
414 group_image,
415 &tree,
416 group_idx as u32 + per_group_id_offset,
417 &wp_params,
418 );
419 all_tokens.extend(group_tokens);
420 }
421
422 let _ = (use_lz77, lz77_method); let lz77_params: Option<crate::entropy_coding::lz77::Lz77Params> = None;
429 let ans_num_contexts = if lz77_params.is_some() {
430 num_contexts + 1
431 } else {
432 num_contexts
433 };
434
435 let code = build_entropy_code_ans_with_options(
437 &all_tokens,
438 ans_num_contexts,
439 true, true, lz77_params.as_ref(),
442 Some(total_pixels),
443 );
444
445 eprintln!(
446 "DIAG tree: {} nodes, {} contexts, {} samples, {} total_tokens, \
447 max_nodes={}, threshold={:.1}, pixel_frac={:.3}",
448 tree.len(),
449 num_contexts,
450 samples.num_samples,
451 all_tokens.len(),
452 params.max_nodes,
453 params.split_threshold,
454 pixel_fraction,
455 );
456 eprintln!(
457 "DIAG code: {} histograms (from {} contexts), rct={:?}, compact={}",
458 code.histograms.len(),
459 ans_num_contexts,
460 transforms.rct_type,
461 transforms.compact_info.len(),
462 );
463
464 let bits_before = writer.bits_written();
466 crate::f16::write_lf_quant(writer, dc_quant_custom)?;
467 writer.write(1, 1)?;
469
470 let bits_before_tree = writer.bits_written();
472 write_tree(writer, &tree)?;
473 let tree_bits = writer.bits_written() - bits_before_tree;
474
475 let bits_before_histo = writer.bits_written();
477 if ans_num_contexts > 1 {
478 write_lz77_header(lz77_params.as_ref(), writer)?;
479 write_entropy_code_ans(&code, writer)?;
480 } else {
481 write_ans_modular_header(writer, &code)?;
482 }
483 let histo_bits = writer.bits_written() - bits_before_histo;
484
485 writer.write(1, 1)?; write_wp_header(writer, &wp_params)?;
488 write_global_transforms_full(writer, &transforms)?;
489
490 if nb_meta_tokens > 0 {
493 let meta_token_slice = &all_tokens[..nb_meta_tokens];
494 write_tokens_ans(meta_token_slice, &code, None, writer)?;
495 }
496
497 let total_lf_global_bits = writer.bits_written() - bits_before;
498 eprintln!(
499 "DIAG LfGlobal: tree={} bits ({} B), histo={} bits ({} B), \
500 meta_tokens={}, total={} bits ({} B)",
501 tree_bits,
502 tree_bits / 8,
503 histo_bits,
504 histo_bits / 8,
505 nb_meta_tokens,
506 total_lf_global_bits,
507 total_lf_global_bits / 8,
508 );
509
510 writer.zero_pad_to_byte();
511
512 Ok(GlobalModularState::AnsWithTree {
513 code,
514 tree,
515 wp_params,
516 })
517}
518
/// Transforms signalled once in the global modular header.
pub struct GlobalTransforms {
    /// Palette/compaction entries as `(begin_c, nb_colors)` pairs, written
    /// in order before the optional RCT.
    pub compact_info: Vec<(usize, usize)>,
    /// Optional reversible color transform, applied after the palettes.
    pub rct_type: Option<RctType>,
}
526
impl GlobalTransforms {
    /// Convenience constructor: an RCT-only (no palette) transform set.
    pub fn rct_only(rct_type: Option<RctType>) -> Self {
        Self {
            compact_info: Vec::new(),
            rct_type,
        }
    }
}
535
536fn write_global_transforms_full(
541 writer: &mut BitWriter,
542 transforms: &GlobalTransforms,
543) -> Result<()> {
544 let num_transforms =
545 transforms.compact_info.len() as u32 + transforms.rct_type.is_some() as u32;
546 super::encode::write_num_transforms(writer, num_transforms)?;
547
548 for &(begin_c, nb_colors) in &transforms.compact_info {
550 write_palette_transform(writer, begin_c, 1, nb_colors, 0, 0)?;
551 }
552 if let Some(rct) = transforms.rct_type {
554 let rct_begin_c = transforms.compact_info.len();
555 write_rct_transform(writer, rct_begin_c, rct)?;
556 }
557 Ok(())
558}
559
560fn collect_group_residuals(group_image: &ModularImage) -> Vec<u32> {
562 let mut residuals = Vec::new();
563 for channel in &group_image.channels {
564 let width = channel.width();
565 let height = channel.height();
566 for y in 0..height {
567 for x in 0..width {
568 let pixel = channel.get(x, y);
569 let left = if x > 0 { channel.get(x - 1, y) } else { 0 };
570 let top = if y > 0 { channel.get(x, y - 1) } else { left };
571 let topleft = if x > 0 && y > 0 {
572 channel.get(x - 1, y - 1)
573 } else {
574 left
575 };
576 let prediction = predict_gradient(left, top, topleft);
577 let residual = pixel - prediction;
578 residuals.push(pack_signed(residual));
579 }
580 }
581 }
582 residuals
583}
584
/// Writes a per-group modular section for group 0 with no local transforms.
///
/// Thin wrapper over [`write_group_modular_section_idx`].
pub fn write_group_modular_section(
    group_image: &ModularImage,
    state: &GlobalModularState,
    writer: &mut BitWriter,
) -> Result<()> {
    write_group_modular_section_idx(group_image, state, 0, &GroupTransforms::none(), writer)
}
599
/// Transforms signalled locally in a single per-group modular header.
#[derive(Clone)]
pub struct GroupTransforms {
    /// Palette/compaction entries as `(begin_c, nb_colors)` pairs.
    pub compact_info: Vec<(usize, usize)>,
    /// Optional reversible color transform, applied after the palettes.
    pub rct_type: Option<RctType>,
}
614
impl GroupTransforms {
    /// An empty transform set: no palettes, no RCT.
    pub fn none() -> Self {
        Self {
            compact_info: Vec::new(),
            rct_type: None,
        }
    }
}
623
/// Writes one per-group modular section, reusing the entropy state
/// established by the global section.
///
/// `group_idx` is the stream id used by tree-based residual collection (it
/// must match the id used when the global tokens were gathered);
/// `transforms` are this group's local palette/RCT transforms.
pub fn write_group_modular_section_idx(
    group_image: &ModularImage,
    state: &GlobalModularState,
    group_idx: u32,
    transforms: &GroupTransforms,
    writer: &mut BitWriter,
) -> Result<()> {
    crate::trace::debug_eprintln!(
        "GROUP_MODULAR [bit {}]: Starting group section ({}x{}, compact={}, rct={:?})",
        writer.bits_written(),
        group_image.width(),
        group_image.height(),
        transforms.compact_info.len(),
        transforms.rct_type,
    );

    // Group header flag bit (NOTE(review): presumably "use global tree").
    writer.write(1, 1)?;
    match state {
        GlobalModularState::AnsWithTree { wp_params, .. } => {
            // Tree-based coding: the WP header must match the global one.
            super::encode::write_wp_header(writer, wp_params)?;
        }
        _ => {
            // Gradient-only paths: single flag bit instead of a WP header.
            writer.write(1, 1)?;
        }
    }
    // Local transforms: palettes first, then an optional RCT whose begin
    // channel sits just past the palettized channels.
    let num_transforms =
        transforms.compact_info.len() as u32 + transforms.rct_type.is_some() as u32;
    super::encode::write_num_transforms(writer, num_transforms)?;
    for &(begin_c, nb_colors) in &transforms.compact_info {
        write_palette_transform(writer, begin_c, 1, nb_colors, 0, 0)?;
    }
    if let Some(rct) = transforms.rct_type {
        let rct_begin_c = transforms.compact_info.len();
        write_rct_transform(writer, rct_begin_c, rct)?;
    }

    match state {
        GlobalModularState::Huffman {
            depths,
            codes,
            max_token: _,
        } => {
            // Re-derive every residual with the fixed gradient predictor and
            // emit its prefix codeword plus raw extra bits inline.
            for channel in &group_image.channels {
                let width = channel.width();
                let height = channel.height();
                for y in 0..height {
                    for x in 0..width {
                        let pixel = channel.get(x, y);
                        let left = if x > 0 { channel.get(x - 1, y) } else { 0 };
                        let top = if y > 0 { channel.get(x, y - 1) } else { left };
                        let topleft = if x > 0 && y > 0 {
                            channel.get(x - 1, y - 1)
                        } else {
                            left
                        };
                        let prediction = predict_gradient(left, top, topleft);
                        let residual = pixel - prediction;
                        let packed = pack_signed(residual);

                        let (token, extra_bits, num_extra) = MODULAR_HYBRID_UINT.encode(packed);
                        let depth = depths.get(token as usize).copied().unwrap_or(0);
                        let code = codes.get(token as usize).copied().unwrap_or(0);
                        // NOTE(review): depth == 0 silently writes nothing —
                        // assumes such tokens never occur in valid data.
                        if depth > 0 {
                            writer.write(depth as usize, code as u64)?;
                        }
                        if num_extra > 0 {
                            writer.write(num_extra as usize, extra_bits as u64)?;
                        }
                    }
                }
            }
        }
        GlobalModularState::Ans { code } => {
            // Single-context ANS: all residuals share context 0.
            let residuals = collect_group_residuals(group_image);
            let tokens: Vec<AnsToken> = residuals.iter().map(|&r| AnsToken::new(0, r)).collect();
            write_tokens_ans(&tokens, code, None, writer)?;
        }
        GlobalModularState::AnsWithTree {
            code,
            tree,
            wp_params,
        } => {
            // Tree-based ANS: contexts come from re-running the MA tree on
            // this group's pixels under the shared WP parameters.
            let tokens = super::tree_learn::collect_residuals_with_tree(
                group_image,
                tree,
                group_idx,
                wp_params,
            );
            write_tokens_ans(&tokens, code, None, writer)?;
        }
    }

    writer.zero_pad_to_byte();
    crate::trace::debug_eprintln!(
        "GROUP_MODULAR [bit {}]: Group section done",
        writer.bits_written()
    );

    Ok(())
}