// jxl_encoder/modular/section.rs
1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! Modular section encoding for multi-group images.
6//!
7//! Handles GlobalModularState and section writing for large images that
8//! are split into multiple groups.
9
10use super::channel::ModularImage;
11use super::encode::{
12    write_gradient_tree_tokens, write_hybrid_data_histogram, write_palette_transform,
13    write_rct_transform, write_tree_histogram_for_gradient,
14};
15use super::predictor::pack_signed;
16use super::rct::RctType;
17use crate::bit_writer::BitWriter;
18use crate::entropy_coding::encode::{
19    OwnedAnsEntropyCode, build_entropy_code_ans, write_tokens_ans,
20};
21use crate::entropy_coding::hybrid_uint::HybridUintConfig;
22use crate::entropy_coding::token::Token as AnsToken;
23use crate::error::Result;
24
/// Default HybridUint config for modular data: split_exponent=4, msb_in_token=2, lsb_in_token=0.
///
/// Used by [`build_histogram_from_residuals`] and the Huffman group path to
/// tokenize packed residuals.
const MODULAR_HYBRID_UINT: HybridUintConfig = HybridUintConfig {
    split_exponent: 4,
    split: 16, // 1 << 4 — must stay equal to `1 << split_exponent`
    msb_in_token: 2,
    lsb_in_token: 0,
};
32
/// Gradient prediction (ClampedGradient, predictor 5).
///
/// Computes `left + top - topleft` and clamps the result into the closed
/// interval spanned by `left` and `top`.
#[inline]
fn predict_gradient(left: i32, top: i32, topleft: i32) -> i32 {
    let (lo, hi) = if left <= top { (left, top) } else { (top, left) };
    (left + top - topleft).clamp(lo, hi)
}
42
43pub fn collect_all_residuals(image: &ModularImage) -> (Vec<u32>, u32) {
44    let mut residuals = Vec::new();
45    let mut max_residual: u32 = 0;
46
47    for channel in &image.channels {
48        let width = channel.width();
49        let height = channel.height();
50
51        for y in 0..height {
52            for x in 0..width {
53                let pixel = channel.get(x, y);
54
55                // Get neighbors (matching JXL decoder)
56                let left = if x > 0 { channel.get(x - 1, y) } else { 0 };
57                let top = if y > 0 { channel.get(x, y - 1) } else { left };
58                let topleft = if x > 0 && y > 0 {
59                    channel.get(x - 1, y - 1)
60                } else {
61                    left
62                };
63
64                // Predict using ClampedGradient (predictor 5)
65                let prediction = predict_gradient(left, top, topleft);
66                let residual = pixel - prediction;
67                let packed = pack_signed(residual);
68
69                residuals.push(packed);
70                max_residual = max_residual.max(packed);
71            }
72        }
73    }
74
75    (residuals, max_residual)
76}
77
78/// Builds a histogram from residuals, encoding through HybridUint {4,2,0}.
79/// Returns (histogram_on_tokens, max_token).
80pub fn build_histogram_from_residuals(residuals: &[u32], _max_residual: u32) -> (Vec<u32>, u32) {
81    let mut max_token: u32 = 0;
82    // First pass: find max token
83    for &r in residuals {
84        let (token, _, _) = MODULAR_HYBRID_UINT.encode(r);
85        max_token = max_token.max(token);
86    }
87    // Second pass: build histogram on tokens
88    let histogram_size = (max_token + 1) as usize;
89    let mut histogram = vec![0u32; histogram_size];
90    for &r in residuals {
91        let (token, _, _) = MODULAR_HYBRID_UINT.encode(r);
92        histogram[token as usize] += 1;
93    }
94    (histogram, max_token)
95}
96
/// Result of writing the global modular section.
/// Contains the entropy codes needed to encode pixel data in group sections
/// (see [`write_group_modular_section_idx`], which matches on this state).
pub enum GlobalModularState {
    /// Huffman entropy coding state.
    Huffman {
        /// Huffman bit depths for each HybridUint token.
        depths: Vec<u8>,
        /// Huffman codes for each HybridUint token.
        codes: Vec<u16>,
        /// Maximum HybridUint token value.
        max_token: u32,
    },
    /// ANS entropy coding state (single-context gradient tree).
    Ans {
        /// The ANS entropy code (distributions, context map, etc.)
        code: OwnedAnsEntropyCode,
    },
    /// ANS entropy coding with learned MA tree (multi-context).
    AnsWithTree {
        /// The ANS entropy code (multiple distributions, context map).
        code: OwnedAnsEntropyCode,
        /// The learned MA tree for per-pixel predictor/context selection.
        tree: super::tree::Tree,
        /// WP parameters used during tree learning and residual collection.
        /// Group sections must re-use these so encoder and decoder agree.
        wp_params: super::predictor::WeightedPredictorParams,
    },
}
124
/// CeilLog2Nonzero matching the JXL spec: smallest `k` with `2^k >= x`.
///
/// Caller must pass `x > 0` (enforced by a debug assertion).
fn ceil_log2_nonzero(x: u32) -> u32 {
    debug_assert!(x > 0);
    // ceil(log2(x)) == bit-length of (x - 1); the x == 1 case is handled
    // separately so the subtraction never underflows.
    if x == 1 { 0 } else { 32 - (x - 1).leading_zeros() }
}
135
136/// Write ANS data histogram header for a single-context modular stream.
137///
138/// For modular with a single-leaf MA tree (num_dist=1), the context map is NOT written.
139/// Layout: lz77.enabled=0 + use_prefix_code=0 + log_alpha_size + HybridUint config + ANS distribution
140pub(super) fn write_ans_modular_header(
141    writer: &mut BitWriter,
142    code: &OwnedAnsEntropyCode,
143) -> Result<()> {
144    assert_eq!(
145        code.histograms.len(),
146        1,
147        "modular ANS header only supports single-distribution (single-leaf tree)"
148    );
149
150    // lz77.enabled = 0
151    writer.write(1, 0)?;
152
153    // NO context map for num_dist=1
154
155    // use_prefix_code = 0 (ANS, not Huffman)
156    writer.write(1, 0)?;
157
158    // log_alpha_size - 5 (2 bits)
159    let las = code.log_alpha_size;
160    writer.write(2, (las - 5) as u64)?;
161
162    // HybridUint config (per-histogram optimized, or default {4,2,0})
163    let config = code
164        .uint_configs
165        .first()
166        .copied()
167        .unwrap_or(crate::entropy_coding::hybrid_uint::HybridUintConfig::default_config());
168    let se_bits = ceil_log2_nonzero(las as u32 + 1);
169    writer.write(se_bits as usize, config.split_exponent as u64)?;
170    if (config.split_exponent as usize) != las {
171        let msb_bits = ceil_log2_nonzero(config.split_exponent + 1);
172        writer.write(msb_bits as usize, config.msb_in_token as u64)?;
173        let lsb_bits = ceil_log2_nonzero(config.split_exponent - config.msb_in_token + 1);
174        writer.write(lsb_bits as usize, config.lsb_in_token as u64)?;
175    }
176
177    // Write the single ANS distribution
178    code.histograms[0].write(writer)?;
179
180    Ok(())
181}
182
183/// Writes the global modular section (tree + histogram) for multi-group encoding.
184///
185/// This writes:
186/// - dc_quant.all_default = 1
187/// - has_tree = 1
188/// - Tree histogram and tokens (Gradient predictor)
189/// - Data histogram with HybridUint {4,2,0} (Huffman or ANS)
190///
191/// `all_residuals` are the raw packed residuals from all groups (needed for ANS histogram building).
192/// `histogram` and `max_token` are built from HybridUint-encoded tokens (not raw residuals).
193/// Returns the entropy coding state needed to encode pixel data in group sections.
194pub fn write_global_modular_section(
195    all_residuals: &[u32],
196    histogram: &[u32],
197    max_token: u32,
198    writer: &mut BitWriter,
199    use_ans: bool,
200    transforms: GlobalTransforms,
201) -> Result<GlobalModularState> {
202    crate::trace::debug_eprintln!(
203        "GLOBAL_MODULAR [bit {}]: Starting global section (ans={})",
204        writer.bits_written(),
205        use_ans
206    );
207
208    // dc_quant.all_default = true
209    writer.write(1, 1)?;
210    // has_tree = true
211    writer.write(1, 1)?;
212
213    // Tree histogram (supports symbols 0-5 for Gradient predictor)
214    let (tree_depths, tree_codes) = write_tree_histogram_for_gradient(writer)?;
215    write_gradient_tree_tokens(writer, &tree_depths, &tree_codes)?;
216
217    if use_ans {
218        // Build ANS code from all residuals across all groups
219        let tokens: Vec<AnsToken> = all_residuals.iter().map(|&r| AnsToken::new(0, r)).collect();
220        let code = build_entropy_code_ans(&tokens, 1); // 1 context for single-leaf tree
221
222        // Write ANS data header (distribution + config)
223        write_ans_modular_header(writer, &code)?;
224
225        // Write GlobalModular's ModularHeader
226        writer.write(1, 1)?; // use_global_tree = true
227        writer.write(1, 1)?; // wp_params.default_wp = true
228        write_global_transforms_full(writer, &transforms)?;
229
230        // Byte-align at end of global section
231        writer.zero_pad_to_byte();
232        crate::trace::debug_eprintln!(
233            "GLOBAL_MODULAR [bit {}]: Global section done (ANS)",
234            writer.bits_written()
235        );
236
237        Ok(GlobalModularState::Ans { code })
238    } else {
239        // Data histogram with HybridUint {4,2,0} + Huffman
240        let (depths, codes) = write_hybrid_data_histogram(writer, histogram, max_token)?;
241
242        // Write GlobalModular's ModularHeader
243        writer.write(1, 1)?; // use_global_tree = true
244        writer.write(1, 1)?; // wp_params.default_wp = true
245        write_global_transforms_full(writer, &transforms)?;
246
247        // Byte-align at end of global section
248        writer.zero_pad_to_byte();
249        crate::trace::debug_eprintln!(
250            "GLOBAL_MODULAR [bit {}]: Global section done (Huffman)",
251            writer.bits_written()
252        );
253
254        Ok(GlobalModularState::Huffman {
255            depths,
256            codes,
257            max_token,
258        })
259    }
260}
261
262/// Writes the global modular section with a learned MA tree for multi-group encoding.
263///
264/// This writes:
265/// - dc_quant (all_default=1, or custom if dc_quant_custom is Some)
266/// - has_tree = 1
267/// - Learned tree (write_tree)
268/// - lz77.enabled = 0
269/// - Multi-context ANS data histogram (write_entropy_code_ans)
270/// - GroupHeader (use_global_tree=1, wp_header.all_default=1, num_transforms=0)
271pub fn write_global_modular_section_with_tree(
272    images: &[ModularImage],
273    writer: &mut BitWriter,
274    profile: &crate::effort::EffortProfile,
275    transforms: GlobalTransforms,
276    use_lz77: bool,
277    lz77_method: crate::entropy_coding::lz77::Lz77Method,
278    meta_image: Option<&ModularImage>,
279) -> Result<GlobalModularState> {
280    write_global_modular_section_with_tree_dc_quant(
281        images,
282        writer,
283        profile,
284        transforms,
285        use_lz77,
286        lz77_method,
287        None,
288        meta_image,
289    )
290}
291
292/// Like [`write_global_modular_section_with_tree`] but with custom dc_quant for LfFrame.
293#[allow(clippy::too_many_arguments)]
294pub(crate) fn write_global_modular_section_with_tree_dc_quant(
295    images: &[ModularImage],
296    writer: &mut BitWriter,
297    profile: &crate::effort::EffortProfile,
298    transforms: GlobalTransforms,
299    use_lz77: bool,
300    lz77_method: crate::entropy_coding::lz77::Lz77Method,
301    dc_quant_custom: Option<[f32; 3]>,
302    meta_image: Option<&ModularImage>,
303) -> Result<GlobalModularState> {
304    use super::encode::write_tree;
305    use super::encode::write_wp_header;
306    use super::predictor::WeightedPredictorParams;
307    use super::tree::count_contexts;
308    use super::tree_learn::{
309        TreeLearningParams, TreeSamples, collect_residuals_with_tree, compute_best_tree,
310        compute_gather_stride_from_profile, gather_samples_strided, max_ref_channels,
311    };
312    use crate::entropy_coding::encode::build_entropy_code_ans_with_options;
313    use crate::entropy_coding::encode::write_entropy_code_ans;
314    use crate::entropy_coding::lz77::write_lz77_header;
315
316    // Step 0: Find best WP parameters (effort-dependent search)
317    let all_channels: Vec<&super::channel::Channel> = meta_image
318        .into_iter()
319        .chain(images.iter())
320        .flat_map(|img| img.channels.iter())
321        .collect();
322    let wp_params = if profile.wp_num_param_sets > 0 {
323        // Collect channel references for cost estimation
324        let channels_for_wp: Vec<super::channel::Channel> =
325            all_channels.iter().map(|c| (*c).clone()).collect();
326        super::predictor::find_best_wp_params(&channels_for_wp, profile.wp_num_param_sets)
327    } else {
328        WeightedPredictorParams::default()
329    };
330
331    // Step 1: Gather samples from all groups (with subsampling for large images)
332    let total_pixels: usize = meta_image
333        .into_iter()
334        .chain(images.iter())
335        .flat_map(|img| img.channels.iter())
336        .map(|ch| ch.width() * ch.height())
337        .sum();
338    let stride = compute_gather_stride_from_profile(total_pixels, profile);
339    // Compute max ref channels across all images for cross-channel prediction
340    let num_refs = {
341        let mut mr = 0;
342        if let Some(meta) = meta_image {
343            mr = mr.max(max_ref_channels(meta));
344        }
345        for img in images.iter() {
346            mr = mr.max(max_ref_channels(img));
347        }
348        mr
349    };
350    let mut samples = TreeSamples::new_with_ref_channels(num_refs);
351    // Gather meta-channel samples first (channel_offset=0, group_id=0)
352    if let Some(meta) = meta_image {
353        gather_samples_strided(&mut samples, meta, 0, 0, stride, &wp_params);
354    }
355    // Gather per-group samples (channel_offset=0: per-group images use 0-based
356    // channel indices, matching the decoder which builds per-group images with
357    // only the non-meta channels. The tree distinguishes meta from per-group
358    // via group_id property, not channel_idx offset.)
359    //
360    // When meta-channels exist in the global section (group_id=0), per-group
361    // channels use group_id = 1 + group_idx to avoid collision. This lets the
362    // tree split on group_id > 0 to separate meta from per-group data.
363    let per_group_id_offset = if meta_image.is_some() { 1u32 } else { 0u32 };
364    for (group_idx, group_image) in images.iter().enumerate() {
365        gather_samples_strided(
366            &mut samples,
367            group_image,
368            group_idx as u32 + per_group_id_offset,
369            0,
370            stride,
371            &wp_params,
372        );
373    }
374
375    // Step 2: Learn tree with effort-dependent parameters
376    let pixel_fraction = if total_pixels > 0 {
377        samples.num_samples as f64 / total_pixels as f64
378    } else {
379        1.0
380    };
381    let params = TreeLearningParams::from_profile(profile)
382        .with_ref_properties(num_refs, profile.effort)
383        .with_pixel_fraction(pixel_fraction)
384        .with_total_pixels(total_pixels);
385    let tree = compute_best_tree(&mut samples, &params);
386    let num_contexts = count_contexts(&tree) as usize;
387
388    crate::trace::debug_eprintln!(
389        "GLOBAL_MODULAR_TREE: {} nodes, {} leaves/contexts from {} samples \
390         (pixel_fraction={:.3}, threshold={:.1}*{:.3}={:.1})",
391        tree.len(),
392        num_contexts,
393        samples.num_samples,
394        pixel_fraction,
395        params.split_threshold,
396        pixel_fraction * 0.9 + 0.1,
397        params.split_threshold * (pixel_fraction * 0.9 + 0.1),
398    );
399
400    // Step 3: Collect residuals from all groups with tree
401    let mut all_tokens = Vec::new();
402    // Collect meta-channel residuals first (channel_offset=0, group_id=0)
403    let nb_meta_tokens = if let Some(meta) = meta_image {
404        let meta_tokens = collect_residuals_with_tree(meta, &tree, 0, &wp_params);
405        let n = meta_tokens.len();
406        all_tokens.extend(meta_tokens);
407        n
408    } else {
409        0
410    };
411    // Collect per-group residuals (channel_offset=0, group_id offset matches gather above)
412    for (group_idx, group_image) in images.iter().enumerate() {
413        let group_tokens = collect_residuals_with_tree(
414            group_image,
415            &tree,
416            group_idx as u32 + per_group_id_offset,
417            &wp_params,
418        );
419        all_tokens.extend(group_tokens);
420    }
421
422    // Note: LZ77 is NOT applied in this path. The per-group sections
423    // (write_group_modular_section) re-collect tokens independently without LZ77.
424    // Applying LZ77 to the combined stream would cause a histogram mismatch because
425    // the ANS code would include LZ77 symbols that per-group sections don't emit.
426    // The squeeze multi-group path (frame.rs) handles LZ77 correctly per-section.
427    let _ = (use_lz77, lz77_method); // suppress unused warnings
428    let lz77_params: Option<crate::entropy_coding::lz77::Lz77Params> = None;
429    let ans_num_contexts = if lz77_params.is_some() {
430        num_contexts + 1
431    } else {
432        num_contexts
433    };
434
435    // Step 4: Build multi-context ANS code with enhanced clustering
436    let code = build_entropy_code_ans_with_options(
437        &all_tokens,
438        ans_num_contexts,
439        true, // enhanced clustering (pair-merge refinement)
440        true, // optimize uint configs
441        lz77_params.as_ref(),
442        Some(total_pixels),
443    );
444
445    eprintln!(
446        "DIAG tree: {} nodes, {} contexts, {} samples, {} total_tokens, \
447         max_nodes={}, threshold={:.1}, pixel_frac={:.3}",
448        tree.len(),
449        num_contexts,
450        samples.num_samples,
451        all_tokens.len(),
452        params.max_nodes,
453        params.split_threshold,
454        pixel_fraction,
455    );
456    eprintln!(
457        "DIAG code: {} histograms (from {} contexts), rct={:?}, compact={}",
458        code.histograms.len(),
459        ans_num_contexts,
460        transforms.rct_type,
461        transforms.compact_info.len(),
462    );
463
464    // Step 5: Write bitstream
465    let bits_before = writer.bits_written();
466    crate::f16::write_lf_quant(writer, dc_quant_custom)?;
467    // has_tree = true
468    writer.write(1, 1)?;
469
470    // Write the learned tree
471    let bits_before_tree = writer.bits_written();
472    write_tree(writer, &tree)?;
473    let tree_bits = writer.bits_written() - bits_before_tree;
474
475    // Write LZ77 header + ANS data histogram.
476    let bits_before_histo = writer.bits_written();
477    if ans_num_contexts > 1 {
478        write_lz77_header(lz77_params.as_ref(), writer)?;
479        write_entropy_code_ans(&code, writer)?;
480    } else {
481        write_ans_modular_header(writer, &code)?;
482    }
483    let histo_bits = writer.bits_written() - bits_before_histo;
484
485    // GroupHeader (global modular group)
486    writer.write(1, 1)?; // use_global_tree = true
487    write_wp_header(writer, &wp_params)?;
488    write_global_transforms_full(writer, &transforms)?;
489
490    // Write meta-channel tokens (palette data) in the global section, after GroupHeader.
491    // These are part of the global modular image — they stay whole (not split across groups).
492    if nb_meta_tokens > 0 {
493        let meta_token_slice = &all_tokens[..nb_meta_tokens];
494        write_tokens_ans(meta_token_slice, &code, None, writer)?;
495    }
496
497    let total_lf_global_bits = writer.bits_written() - bits_before;
498    eprintln!(
499        "DIAG LfGlobal: tree={} bits ({} B), histo={} bits ({} B), \
500         meta_tokens={}, total={} bits ({} B)",
501        tree_bits,
502        tree_bits / 8,
503        histo_bits,
504        histo_bits / 8,
505        nb_meta_tokens,
506        total_lf_global_bits,
507        total_lf_global_bits / 8,
508    );
509
510    writer.zero_pad_to_byte();
511
512    Ok(GlobalModularState::AnsWithTree {
513        code,
514        tree,
515        wp_params,
516    })
517}
518
/// Info about global transforms to write in the LfGlobal GroupHeader.
///
/// Consumed by `write_global_transforms_full`, which writes ChannelCompact
/// entries first and then the optional RCT.
pub struct GlobalTransforms {
    /// Per-channel ChannelCompact transforms: (begin_c, nb_colors).
    pub compact_info: Vec<(usize, usize)>,
    /// Optional RCT type (begin_c is adjusted for ChannelCompact meta channels).
    pub rct_type: Option<RctType>,
}
526
impl GlobalTransforms {
    /// Convenience constructor: no ChannelCompact transforms, only an
    /// optional RCT.
    pub fn rct_only(rct_type: Option<RctType>) -> Self {
        Self {
            compact_info: Vec::new(),
            rct_type,
        }
    }
}
535
536/// Write num_transforms + transform descriptors for the global GroupHeader.
537///
538/// When `compact_info` is present, writes ChannelCompact (kPalette with num_c=1)
539/// transforms first, then RCT with begin_c shifted by the number of compact meta channels.
540fn write_global_transforms_full(
541    writer: &mut BitWriter,
542    transforms: &GlobalTransforms,
543) -> Result<()> {
544    let num_transforms =
545        transforms.compact_info.len() as u32 + transforms.rct_type.is_some() as u32;
546    super::encode::write_num_transforms(writer, num_transforms)?;
547
548    // ChannelCompact transforms first (per-channel palette, num_c=1)
549    for &(begin_c, nb_colors) in &transforms.compact_info {
550        write_palette_transform(writer, begin_c, 1, nb_colors, 0, 0)?;
551    }
552    // RCT (begin_c adjusted for ChannelCompact meta channels)
553    if let Some(rct) = transforms.rct_type {
554        let rct_begin_c = transforms.compact_info.len();
555        write_rct_transform(writer, rct_begin_c, rct)?;
556    }
557    Ok(())
558}
559
560/// Collect packed residuals from a group image using gradient prediction.
561fn collect_group_residuals(group_image: &ModularImage) -> Vec<u32> {
562    let mut residuals = Vec::new();
563    for channel in &group_image.channels {
564        let width = channel.width();
565        let height = channel.height();
566        for y in 0..height {
567            for x in 0..width {
568                let pixel = channel.get(x, y);
569                let left = if x > 0 { channel.get(x - 1, y) } else { 0 };
570                let top = if y > 0 { channel.get(x, y - 1) } else { left };
571                let topleft = if x > 0 && y > 0 {
572                    channel.get(x - 1, y - 1)
573                } else {
574                    left
575                };
576                let prediction = predict_gradient(left, top, topleft);
577                let residual = pixel - prediction;
578                residuals.push(pack_signed(residual));
579            }
580        }
581    }
582    residuals
583}
584
585/// Writes a group's data section for multi-group modular encoding.
586///
587/// This writes:
588/// - GroupHeader (use_global_tree=1, wp_header.all_default=1, num_transforms=0)
589/// - Encoded pixel residuals using HybridUint {4,2,0} + global entropy codes
590///
591/// The `group_image` should be the extracted region for this group.
592pub fn write_group_modular_section(
593    group_image: &ModularImage,
594    state: &GlobalModularState,
595    writer: &mut BitWriter,
596) -> Result<()> {
597    write_group_modular_section_idx(group_image, state, 0, &GroupTransforms::none(), writer)
598}
599
// NOTE(review): the doc text below describes `write_group_modular_section_idx`
// (defined further down), not the `GroupTransforms` struct that immediately
// follows — kept as plain comments so rustdoc does not attach it to the struct.
//
// Like `write_group_modular_section` but with an explicit group index
// for tree property 1 (group_id). Required when the learned tree splits on group_id.
//
// `rct_type`: Optional per-group RCT transform to write in this group's GroupHeader.
// When `Some`, the group data is assumed to be already RCT-transformed and the
// decoder will apply inverse RCT when decoding this group.
/// Per-group transform info for ChannelCompact + RCT.
///
/// Written into a group's GroupHeader by [`write_group_modular_section_idx`].
#[derive(Clone)]
pub struct GroupTransforms {
    /// Per-channel ChannelCompact transforms: (begin_c, nb_colors).
    pub compact_info: Vec<(usize, usize)>,
    /// Optional RCT type (begin_c is adjusted for ChannelCompact meta channels).
    pub rct_type: Option<RctType>,
}
614
615impl GroupTransforms {
616    pub fn none() -> Self {
617        Self {
618            compact_info: Vec::new(),
619            rct_type: None,
620        }
621    }
622}
623
/// Like [`write_group_modular_section`] but with an explicit group index
/// for tree property 1 (group_id). Required when the learned tree splits on group_id.
///
/// Writes this group's GroupHeader (use_global_tree=1, WP params matching the
/// global state, per-group transforms), then the entropy-coded residuals using
/// the global entropy `state`, and byte-aligns the writer at the end.
///
/// `transforms`: per-group ChannelCompact transforms plus an optional RCT.
/// When the RCT is `Some`, the group data is assumed to be already
/// RCT-transformed and the decoder will apply inverse RCT when decoding this group.
pub fn write_group_modular_section_idx(
    group_image: &ModularImage,
    state: &GlobalModularState,
    group_idx: u32,
    transforms: &GroupTransforms,
    writer: &mut BitWriter,
) -> Result<()> {
    crate::trace::debug_eprintln!(
        "GROUP_MODULAR [bit {}]: Starting group section ({}x{}, compact={}, rct={:?})",
        writer.bits_written(),
        group_image.width(),
        group_image.height(),
        transforms.compact_info.len(),
        transforms.rct_type,
    );

    // GroupHeader
    writer.write(1, 1)?; // use_global_tree = true
    // Write WP params matching the global section's params — encoder and
    // decoder must agree on the WP configuration used for residuals.
    match state {
        GlobalModularState::AnsWithTree { wp_params, .. } => {
            super::encode::write_wp_header(writer, wp_params)?;
        }
        _ => {
            writer.write(1, 1)?; // wp_params.default_wp = true
        }
    }
    // Per-group transforms: ChannelCompact(s) + optional RCT
    let num_transforms =
        transforms.compact_info.len() as u32 + transforms.rct_type.is_some() as u32;
    super::encode::write_num_transforms(writer, num_transforms)?;
    for &(begin_c, nb_colors) in &transforms.compact_info {
        write_palette_transform(writer, begin_c, 1, nb_colors, 0, 0)?;
    }
    if let Some(rct) = transforms.rct_type {
        // RCT begin_c is shifted past the ChannelCompact meta channels.
        let rct_begin_c = transforms.compact_info.len();
        write_rct_transform(writer, rct_begin_c, rct)?;
    }

    // Encode the group's residuals with whichever entropy state the global
    // section produced.
    match state {
        GlobalModularState::Huffman {
            depths,
            codes,
            max_token: _,
        } => {
            // Encode residuals with HybridUint {4,2,0} + Huffman
            for channel in &group_image.channels {
                let width = channel.width();
                let height = channel.height();
                for y in 0..height {
                    for x in 0..width {
                        let pixel = channel.get(x, y);
                        let left = if x > 0 { channel.get(x - 1, y) } else { 0 };
                        let top = if y > 0 { channel.get(x, y - 1) } else { left };
                        let topleft = if x > 0 && y > 0 {
                            channel.get(x - 1, y - 1)
                        } else {
                            left
                        };
                        let prediction = predict_gradient(left, top, topleft);
                        let residual = pixel - prediction;
                        let packed = pack_signed(residual);

                        let (token, extra_bits, num_extra) = MODULAR_HYBRID_UINT.encode(packed);
                        // depth == 0 means the symbol never occurred in the global
                        // histogram; no code bits are emitted for it.
                        let depth = depths.get(token as usize).copied().unwrap_or(0);
                        let code = codes.get(token as usize).copied().unwrap_or(0);
                        if depth > 0 {
                            writer.write(depth as usize, code as u64)?;
                        }
                        if num_extra > 0 {
                            writer.write(num_extra as usize, extra_bits as u64)?;
                        }
                    }
                }
            }
        }
        GlobalModularState::Ans { code } => {
            // Collect residuals for this group and encode with ANS
            let residuals = collect_group_residuals(group_image);
            let tokens: Vec<AnsToken> = residuals.iter().map(|&r| AnsToken::new(0, r)).collect();
            write_tokens_ans(&tokens, code, None, writer)?;
        }
        GlobalModularState::AnsWithTree {
            code,
            tree,
            wp_params,
        } => {
            // Collect residuals using the learned tree (multi-context).
            // Per-group images use 0-based channel indices (matching the decoder,
            // which builds per-group images with only non-meta channels).
            let tokens = super::tree_learn::collect_residuals_with_tree(
                group_image,
                tree,
                group_idx,
                wp_params,
            );
            write_tokens_ans(&tokens, code, None, writer)?;
        }
    }

    // Byte-align at end of group section
    writer.zero_pad_to_byte();
    crate::trace::debug_eprintln!(
        "GROUP_MODULAR [bit {}]: Group section done",
        writer.bits_written()
    );

    Ok(())
}