// File: jxl_encoder/modular/frame.rs
1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! Frame encoder - assembles complete JXL frames.
6
7use super::channel::{Channel, ModularImage};
8use super::encode::{
9    build_histogram_from_residuals, collect_all_residuals, select_best_rct_at,
10    write_global_modular_section, write_group_modular_section_idx, write_improved_modular_stream,
11    write_modular_stream_with_tree,
12};
13use super::palette::{CHANNEL_COLORS_PERCENT, analyze_channel_compact};
14use super::section::write_global_modular_section_with_tree;
15use crate::GROUP_DIM;
16use crate::bit_writer::BitWriter;
17use crate::entropy_coding::lz77::Lz77Method;
18use crate::error::Result;
19use crate::headers::ColorEncoding;
20use crate::headers::frame_header::{BlendMode, FrameCrop, FrameHeader};
21
22/// Options for frame encoding.
23#[derive(Debug, Clone)]
24pub struct FrameEncoderOptions {
25    /// Use modular mode (lossless).
26    pub use_modular: bool,
27    /// Effort level (1-10, higher = better compression, slower).
28    pub effort: u8,
29    /// Use ANS entropy coding instead of Huffman for modular.
30    pub use_ans: bool,
31    /// Use content-adaptive MA tree learning for modular encoding.
32    pub use_tree_learning: bool,
33    /// Use squeeze (Haar wavelet) transform for modular encoding.
34    pub use_squeeze: bool,
35    /// Enable LZ77 compression on modular token streams.
36    pub enable_lz77: bool,
37    /// LZ77 method to use when enable_lz77 is true.
38    pub lz77_method: Lz77Method,
39    /// Use lossy delta palette for near-lossless modular encoding.
40    pub lossy_palette: bool,
41    /// Encoder mode: Reference (match libjxl) or Experimental (own improvements).
42    pub encoder_mode: crate::api::EncoderMode,
43    /// Effort profile with all effort-derived parameters.
44    pub profile: crate::effort::EffortProfile,
45    /// Whether this frame is part of an animation (enables duration field in header).
46    pub have_animation: bool,
47    /// Duration of this frame in ticks (only used when have_animation is true).
48    pub duration: u32,
49    /// Whether this is the last frame in the image/animation.
50    pub is_last: bool,
51    /// Optional crop rectangle for this frame (None = full frame).
52    pub crop: Option<FrameCrop>,
53    /// Skip RCT even for 3-channel images (e.g., XYB channels already decorrelated).
54    pub skip_rct: bool,
55}
56
57impl Default for FrameEncoderOptions {
58    fn default() -> Self {
59        Self {
60            use_modular: true, // Default to lossless
61            effort: 7,
62            use_ans: false,
63            use_tree_learning: false,
64            use_squeeze: false,
65            enable_lz77: false,
66            lz77_method: Lz77Method::Rle,
67            lossy_palette: false,
68            encoder_mode: crate::api::EncoderMode::Reference,
69            profile: crate::effort::EffortProfile::lossless(7, crate::api::EncoderMode::Reference),
70            have_animation: false,
71            duration: 0,
72            is_last: true,
73            crop: None,
74            skip_rct: false,
75        }
76    }
77}
78
79/// Encodes a single frame.
80pub struct FrameEncoder {
81    /// Encoding options.
82    #[allow(dead_code)]
83    options: FrameEncoderOptions,
84    /// Image width.
85    width: usize,
86    /// Image height.
87    height: usize,
88    #[allow(dead_code)]
89    /// Number of extra channels (e.g., 1 for alpha).
90    num_extra_channels: usize,
91}
92
impl FrameEncoder {
    /// Creates a new frame encoder with no extra channels.
    pub fn new(width: usize, height: usize, options: FrameEncoderOptions) -> Self {
        Self {
            options,
            width,
            height,
            num_extra_channels: 0,
        }
    }

104    /// Creates a new frame encoder with extra channel support.
105    pub fn new_with_extra_channels(
106        width: usize,
107        height: usize,
108        options: FrameEncoderOptions,
109        num_extra_channels: usize,
110    ) -> Self {
111        Self {
112            options,
113            width,
114            height,
115            num_extra_channels,
116        }
117    }
118
119    /// Encodes a modular image into a frame with optional patches.
120    ///
121    /// When patches are provided, sets PATCHES_FLAG in the frame header and
122    /// writes the patches section at the start of LfGlobal data.
123    pub(crate) fn encode_modular_with_patches(
124        &self,
125        image: &ModularImage,
126        color_encoding: &ColorEncoding,
127        writer: &mut BitWriter,
128        patches: Option<&crate::vardct::patches::PatchesData>,
129    ) -> Result<()> {
130        if patches.is_none() {
131            return self.encode_modular(image, color_encoding, writer);
132        }
133        let patches = patches.unwrap();
134
135        // Compute num_extra_channels from image
136        let num_extra_channels = if image.has_alpha { 1 } else { 0 };
137
138        // Write frame header with PATCHES_FLAG
139        {
140            use crate::headers::frame_header::PATCHES_FLAG;
141            let mut fh = FrameHeader::lossless();
142            fh.flags |= PATCHES_FLAG;
143            fh.ec_upsampling = vec![1; num_extra_channels];
144            fh.ec_blend_modes = vec![BlendMode::Replace; num_extra_channels];
145            fh.have_animation = self.options.have_animation;
146            fh.duration = self.options.duration;
147            fh.is_last = self.options.is_last;
148            if let Some(ref crop) = self.options.crop {
149                fh.x0 = crop.x0;
150                fh.y0 = crop.y0;
151                fh.width = crop.width;
152                fh.height = crop.height;
153                fh.blend_mode = BlendMode::Replace;
154                fh.blend_source = 1;
155            }
156            if self.options.have_animation && !self.options.is_last {
157                fh.save_as_reference = 1;
158            }
159            fh.write(writer)?;
160        }
161
162        let num_groups = self.num_groups();
163
164        if num_groups == 1 {
165            // Single group: combine patches section + modular data into one TOC entry
166            let mut section_writer = BitWriter::new();
167
168            // Write patches section first (within the single TOC section)
169            crate::vardct::patches::encode_patches_section(
170                patches,
171                self.options.use_ans,
172                &mut section_writer,
173            )?;
174
175            // Then write modular data (same logic as encode_modular)
176            let has_squeeze = self.options.use_squeeze
177                && !super::squeeze::default_squeeze_params(image).is_empty();
178
179            if self.options.lossy_palette && image.channels.len() >= 3 {
180                let max_colors = 1usize << image.bit_depth.min(12);
181                super::encode::write_modular_stream_with_lossy_palette(
182                    image,
183                    &mut section_writer,
184                    self.options.use_ans,
185                    0,
186                    image.channels.len().min(3),
187                    max_colors,
188                )?;
189            } else if has_squeeze && self.options.use_tree_learning && self.options.use_ans {
190                super::encode::write_modular_stream_with_squeeze_and_tree(
191                    image,
192                    &mut section_writer,
193                    &self.options.profile,
194                    self.options.enable_lz77,
195                    self.options.lz77_method,
196                )?;
197            } else if has_squeeze {
198                super::encode::write_modular_stream_with_squeeze(
199                    image,
200                    &mut section_writer,
201                    self.options.use_ans,
202                )?;
203            } else if self.options.use_tree_learning && self.options.use_ans {
204                // Tree learning: handles palette internally when beneficial
205                write_modular_stream_with_tree(
206                    image,
207                    &mut section_writer,
208                    &self.options.profile,
209                    image.channels.len() >= 3,
210                    self.options.enable_lz77,
211                    self.options.lz77_method,
212                )?;
213            } else if image.channels.len() >= 3 {
214                super::encode::write_modular_stream_with_rct(
215                    image,
216                    &mut section_writer,
217                    self.options.use_ans,
218                )?;
219            } else {
220                write_improved_modular_stream(image, &mut section_writer, self.options.use_ans)?;
221            }
222
223            let section_data = section_writer.finish();
224            self.write_toc(writer, section_data.len())?;
225            for byte in section_data {
226                writer.write_u8(byte)?;
227            }
228        } else {
229            // Multi-group with patches: patches section goes into LfGlobal.
230            // Squeeze + patches is not yet supported; use non-squeeze multi-group path.
231            self.encode_modular_multi_group_inner(image, writer, Some(patches))?;
232        }
233
234        Ok(())
235    }
236
237    /// Encodes a modular image into a frame.
238    pub fn encode_modular(
239        &self,
240        image: &ModularImage,
241        _color_encoding: &ColorEncoding,
242        writer: &mut BitWriter,
243    ) -> Result<()> {
244        // Compute num_extra_channels from image
245        let num_extra_channels = if image.has_alpha { 1 } else { 0 };
246
247        // Write frame header using unified FrameHeader
248        {
249            let mut fh = FrameHeader::lossless();
250            fh.ec_upsampling = vec![1; num_extra_channels];
251            fh.ec_blend_modes = vec![BlendMode::Replace; num_extra_channels];
252            fh.have_animation = self.options.have_animation;
253            fh.duration = self.options.duration;
254            fh.is_last = self.options.is_last;
255            if let Some(ref crop) = self.options.crop {
256                fh.x0 = crop.x0;
257                fh.y0 = crop.y0;
258                fh.width = crop.width;
259                fh.height = crop.height;
260                fh.blend_mode = BlendMode::Replace;
261                fh.blend_source = 1;
262            }
263            // For animation, save non-last frames to reference slot 1
264            // so crop frames can composite onto the previous canvas.
265            if self.options.have_animation && !self.options.is_last {
266                fh.save_as_reference = 1;
267            }
268            fh.write(writer)?;
269        }
270
271        let num_groups = self.num_groups();
272
273        if num_groups == 1 {
274            // Single group: all sections combined into one TOC entry
275            let mut section_writer = BitWriter::new();
276            let has_squeeze = self.options.use_squeeze
277                && !super::squeeze::default_squeeze_params(image).is_empty();
278
279            if self.options.lossy_palette && image.channels.len() >= 3 {
280                // Lossy delta palette: near-lossless with error diffusion
281                let max_colors = 1usize << image.bit_depth.min(12);
282                super::encode::write_modular_stream_with_lossy_palette(
283                    image,
284                    &mut section_writer,
285                    self.options.use_ans,
286                    0,
287                    image.channels.len().min(3),
288                    max_colors,
289                )?;
290            } else if has_squeeze && self.options.use_tree_learning && self.options.use_ans {
291                // Combined squeeze + tree learning: best compression
292                super::encode::write_modular_stream_with_squeeze_and_tree(
293                    image,
294                    &mut section_writer,
295                    &self.options.profile,
296                    self.options.enable_lz77,
297                    self.options.lz77_method,
298                )?;
299            } else if has_squeeze {
300                // Squeeze without tree learning (lower effort levels)
301                super::encode::write_modular_stream_with_squeeze(
302                    image,
303                    &mut section_writer,
304                    self.options.use_ans,
305                )?;
306            } else if self.options.use_tree_learning && self.options.use_ans {
307                // Tree learning: handles palette internally when beneficial
308                write_modular_stream_with_tree(
309                    image,
310                    &mut section_writer,
311                    &self.options.profile,     // effort-dependent tree params
312                    image.channels.len() >= 3, // RCT for RGB
313                    self.options.enable_lz77,
314                    self.options.lz77_method,
315                )?;
316            } else if image.channels.len() >= 3 {
317                super::encode::write_modular_stream_with_rct(
318                    image,
319                    &mut section_writer,
320                    self.options.use_ans,
321                )?;
322            } else {
323                write_improved_modular_stream(image, &mut section_writer, self.options.use_ans)?;
324            }
325
326            let section_data = section_writer.finish();
327            let section_size = section_data.len();
328
329            crate::trace::debug_eprintln!("FRAME_ENCODER: section_size = {} bytes", section_size);
330
331            // Write TOC
332            self.write_toc(writer, section_size)?;
333
334            // Append section data (already byte-aligned)
335            for byte in section_data {
336                writer.write_u8(byte)?;
337            }
338        } else if self.options.lossy_palette && image.channels.len() >= 3 {
339            // Multi-group lossy palette: palette meta in LfGlobal, index across groups
340            self.encode_modular_multi_group_lossy_palette(image, writer)?;
341        } else if self.options.use_squeeze
342            && !super::squeeze::default_squeeze_params(image).is_empty()
343        {
344            if self.options.use_tree_learning && self.options.use_ans {
345                // Multi-group with squeeze + tree learning: best compression
346                self.encode_modular_multi_group_squeeze_with_tree(image, writer)?;
347            } else {
348                // Multi-group with squeeze: gradient predictor, single context
349                self.encode_modular_multi_group_squeeze(image, writer)?;
350            }
351        } else {
352            // Multi-group: separate TOC entries for global and each group
353            self.encode_modular_multi_group(image, writer)?;
354        }
355
356        Ok(())
357    }
358
359    /// Encodes a modular image using multi-group format (>256x256 images).
360    ///
361    /// For multi-group frames, the JXL spec requires this TOC structure:
362    /// - Section 0: LfGlobal (dc_quant + tree + histograms)
363    /// - Section 1: HfGlobal (empty for modular encoding)
364    /// - Section 2..2+num_lf_groups: LfGroup (empty for modular encoding)
365    /// - Section 2+num_lf_groups..: PassGroup (GroupHeader + pixel data per 256x256 region)
366    fn encode_modular_multi_group(
367        &self,
368        image: &ModularImage,
369        writer: &mut BitWriter,
370    ) -> Result<()> {
371        self.encode_modular_multi_group_inner(image, writer, None)
372    }
373
374    /// Inner multi-group encoder that accepts optional patches.
375    /// When patches are provided, writes patches section at the start of LfGlobal.
376    fn encode_modular_multi_group_inner(
377        &self,
378        image: &ModularImage,
379        writer: &mut BitWriter,
380        patches: Option<&crate::vardct::patches::PatchesData>,
381    ) -> Result<()> {
382        let num_groups = self.num_groups();
383        let num_lf_groups = self.num_lf_groups();
384        let num_passes = 1;
385
386        crate::trace::debug_eprintln!(
387            "MULTI_GROUP: Encoding {}x{} image with {} groups, {} lf_groups",
388            self.width,
389            self.height,
390            num_groups,
391            num_lf_groups
392        );
393
394        // Step 0: ChannelCompact + RCT + split into meta-image / per-group index images.
395        //
396        // ChannelCompact (per-channel palette) dramatically reduces bit depth for
397        // screenshots with sparse per-channel values (e.g., R uses 30/256 values).
398        // Applied BEFORE RCT because RCT spreads values, negating compaction benefit.
399        //
400        // The key insight: palette meta-channels (small, e.g. 30×1) stay in the global
401        // section, while index channels (image-sized) get extract_region per-group.
402        // This avoids the root cause of previous failures: extract_region corrupting
403        // tiny meta-channels by forcing them to group dimensions.
404
405        // Step 0a: ChannelCompact on raw image (before RCT)
406        // Only try ChannelCompact when tree learning + ANS are enabled (the global
407        // meta-channel path requires the AnsWithTree codepath in section.rs).
408        let has_rct = !self.options.skip_rct && image.channels.len() >= 3;
409        let num_color_channels = if has_rct {
410            3
411        } else {
412            image.channels.len().min(3)
413        };
414        let try_compact = self.options.use_tree_learning && self.options.use_ans;
415
416        let compact_analyses: Vec<(usize, super::palette::PaletteAnalysis)> = if try_compact {
417            // For multi-group, compact overhead is higher (meta-channels in global section,
418            // tree quality dilution across many groups). Require density <= 50%
419            // (i.e. range >= 2x unique), which means >= 1 bit/pixel entropy savings.
420            // Below this threshold, savings are eaten by per-group overhead.
421            (0..num_color_channels)
422                .filter_map(|ch_idx| {
423                    let analysis =
424                        analyze_channel_compact(&image.channels[ch_idx], CHANNEL_COLORS_PERCENT)?;
425                    // Reject if unique values use >50% of the range (< 1 bit/pixel savings)
426                    let ch = &image.channels[ch_idx];
427                    let mut min_v = i32::MAX;
428                    let mut max_v = i32::MIN;
429                    for y in 0..ch.height() {
430                        for x in 0..ch.width() {
431                            let v = ch.get(x, y);
432                            min_v = min_v.min(v);
433                            max_v = max_v.max(v);
434                        }
435                    }
436                    let range = (max_v as i64 - min_v as i64 + 1).max(1) as f64;
437                    let density = analysis.num_colors as f64 / range;
438                    crate::trace::debug_eprintln!(
439                        "COMPACT_FILTER: ch={} unique={} range={:.0} density={:.3}",
440                        ch_idx,
441                        analysis.num_colors,
442                        range,
443                        density
444                    );
445                    if density > 0.5 {
446                        return None;
447                    }
448                    Some((ch_idx, analysis))
449                })
450                .collect()
451        } else {
452            Vec::new()
453        };
454
455        let (meta_image, source_image_owned, compact_info, rct_type);
456        if !compact_analyses.is_empty() {
457            // Build palette meta-channels + index channels.
458            // Layout: [pal_N-1, ..., pal_0, idx_0, ch_1, idx_2, ...extra]
459            // (palettes reversed for decoder MetaPalette insertion order)
460            let mut palettes: Vec<Channel> = Vec::new();
461            let mut non_meta: Vec<Channel> = Vec::new();
462            let mut info: Vec<(usize, usize)> = Vec::new();
463            let mut nb_meta = 0usize;
464
465            for (orig_idx, ch) in image.channels.iter().enumerate() {
466                if let Some((_, analysis)) =
467                    compact_analyses.iter().find(|(idx, _)| *idx == orig_idx)
468                {
469                    // Create palette meta-channel (nb_colors wide, 1 high)
470                    let mut pal_ch = Channel::new(analysis.num_colors, 1)?;
471                    for (i, color) in analysis.palette.iter().enumerate() {
472                        pal_ch.set(i, 0, color[0]);
473                    }
474                    palettes.push(pal_ch);
475
476                    // Create index channel (same dimensions as original)
477                    let mut idx_ch = Channel::new(ch.width(), ch.height())?;
478                    for y in 0..ch.height() {
479                        for x in 0..ch.width() {
480                            let val = ch.get(x, y);
481                            let index = analysis.color_to_index[&vec![val]];
482                            idx_ch.set(x, y, index);
483                        }
484                    }
485                    non_meta.push(idx_ch);
486
487                    // begin_c for the transform descriptor
488                    let begin_c = orig_idx + nb_meta;
489                    info.push((begin_c, analysis.num_colors));
490                    nb_meta += 1;
491                } else {
492                    non_meta.push(ch.clone());
493                }
494            }
495
496            // Decoder's MetaPalette inserts each palette at position 0,
497            // so earlier palettes end up deeper. Reverse to match.
498            palettes.reverse();
499
500            crate::trace::debug_eprintln!(
501                "MULTI_GROUP_COMPACT: {} channels compacted, {} meta + {} non-meta, info={:?}",
502                compact_analyses.len(),
503                nb_meta,
504                non_meta.len(),
505                info,
506            );
507
508            // Split: meta-channels stay whole in global section
509            let mut meta_img = image.clone();
510            meta_img.channels = palettes;
511
512            // Non-meta channels: index + non-compacted + extra (full-size, to be split)
513            let mut work = image.clone();
514            work.channels = non_meta;
515
516            // Step 0b: RCT on index channels (starting at position 0 in non-meta)
517            if has_rct && work.channels.len() >= 3 {
518                let (selected_rct, transformed) =
519                    select_best_rct_at(&work, 0, self.options.profile.nb_rcts_to_try);
520                rct_type = Some(selected_rct);
521                work = transformed;
522            } else {
523                rct_type = None;
524            }
525
526            meta_image = Some(meta_img);
527            source_image_owned = work;
528            compact_info = info;
529        } else {
530            // No ChannelCompact — standard RCT-only path
531            if has_rct {
532                let (selected_rct, rct_image) =
533                    super::encode::select_best_rct(image, self.options.profile.nb_rcts_to_try);
534                rct_type = Some(selected_rct);
535                source_image_owned = rct_image;
536            } else {
537                rct_type = None;
538                source_image_owned = image.clone();
539            }
540            meta_image = None;
541            compact_info = Vec::new();
542        };
543
544        let global_transforms = super::section::GlobalTransforms {
545            compact_info,
546            rct_type,
547        };
548
549        // Step 1: Extract each group image from the index/non-meta channels only.
550        // Meta-channels (palettes) are NOT split — they go whole in the global section.
551        let mut group_images: Vec<ModularImage> = Vec::with_capacity(num_groups);
552        let group_transforms: Vec<super::section::GroupTransforms> =
553            vec![super::section::GroupTransforms::none(); num_groups];
554        for group_idx in 0..num_groups {
555            let (x_start, y_start, x_end, y_end) = self.group_bounds(group_idx);
556            let group_image = source_image_owned.extract_region(x_start, y_start, x_end, y_end)?;
557            group_images.push(group_image);
558        }
559
560        // Step 2: Write LfGlobal section (patches + tree + histogram)
561        let mut lf_global_writer = BitWriter::new();
562
563        // If patches are provided, write patches section first in LfGlobal
564        if let Some(pd) = patches {
565            crate::vardct::patches::encode_patches_section(
566                pd,
567                self.options.use_ans,
568                &mut lf_global_writer,
569            )?;
570        }
571
572        let global_state = if self.options.use_tree_learning && self.options.use_ans {
573            // Tree learning path: gather samples, learn tree, build multi-context ANS
574            write_global_modular_section_with_tree(
575                &group_images,
576                &mut lf_global_writer,
577                &self.options.profile, // effort-dependent tree params
578                global_transforms,
579                self.options.enable_lz77,
580                self.options.lz77_method,
581                meta_image.as_ref(),
582            )?
583        } else {
584            // Standard path: collect residuals with gradient predictor
585            let mut all_residuals = Vec::new();
586            let mut max_residual: u32 = 0;
587            for group_image in &group_images {
588                let (group_residuals, group_max) = collect_all_residuals(group_image);
589                all_residuals.extend(group_residuals);
590                max_residual = max_residual.max(group_max);
591            }
592            let (histogram, max_token) =
593                build_histogram_from_residuals(&all_residuals, max_residual);
594
595            crate::trace::debug_eprintln!(
596                "MULTI_GROUP: {} total residuals, max_raw={}, max_token={}, {} unique tokens",
597                all_residuals.len(),
598                max_residual,
599                max_token,
600                histogram.iter().filter(|&&c| c > 0).count()
601            );
602
603            write_global_modular_section(
604                &all_residuals,
605                &histogram,
606                max_token,
607                &mut lf_global_writer,
608                self.options.use_ans,
609                global_transforms,
610            )?
611        };
612        let lf_global_data = lf_global_writer.finish();
613
614        crate::trace::debug_eprintln!(
615            "MULTI_GROUP: LfGlobal section = {} bytes",
616            lf_global_data.len()
617        );
618
619        // Step 3: HfGlobal is empty for modular encoding (0 bytes)
620        let hf_global_data: Vec<u8> = Vec::new();
621        crate::trace::debug_eprintln!(
622            "MULTI_GROUP: HfGlobal section = 0 bytes (empty for modular)"
623        );
624
625        // Step 4: LfGroup sections are empty for modular encoding
626        let lf_group_data: Vec<Vec<u8>> = (0..num_lf_groups).map(|_| Vec::new()).collect();
627        crate::trace::debug_eprintln!(
628            "MULTI_GROUP: {} LfGroup sections = 0 bytes each (empty for modular)",
629            num_lf_groups
630        );
631
632        // Step 5: Write each PassGroup's data (GroupHeader + pixel data)
633        // Use the pre-extracted group_images to ensure residual consistency
634        //
635        // When ChannelCompact meta-channels exist in the global section (group_id=0),
636        // per-group channels use group_id = 1 + group_idx to avoid collision.
637        // This must match the offset used during tree learning in section.rs.
638        let per_group_id_offset: u32 = if meta_image.is_some() { 1 } else { 0 };
639        // PassGroup sections — parallelizable (each group writes to its own BitWriter)
640        let pass_group_data: Vec<Vec<u8>> =
641            crate::parallel::parallel_map_result(num_groups * num_passes, |flat_idx| {
642                let group_idx = flat_idx / num_passes;
643                let group_image = &group_images[group_idx];
644
645                let mut group_writer = BitWriter::new();
646                write_group_modular_section_idx(
647                    group_image,
648                    &global_state,
649                    group_idx as u32 + per_group_id_offset,
650                    &group_transforms[group_idx],
651                    &mut group_writer,
652                )?;
653
654                crate::trace::debug_eprintln!(
655                    "MULTI_GROUP: PassGroup {} section = {} bytes",
656                    group_idx,
657                    group_writer.bits_written() / 8,
658                );
659                Ok(group_writer.finish())
660            })?;
661
662        // Step 6: Collect all section sizes in correct order and write TOC
663        // JXL spec order: LfGlobal, LfGroup[0..num_lf_groups], HfGlobal, PassGroup[0..num_groups*num_passes]
664        // Note: LfGroup comes BEFORE HfGlobal!
665        let mut section_sizes = Vec::with_capacity(2 + num_lf_groups + num_groups * num_passes);
666        section_sizes.push(lf_global_data.len());
667        for data in &lf_group_data {
668            section_sizes.push(data.len());
669        }
670        section_sizes.push(hf_global_data.len());
671        for data in &pass_group_data {
672            section_sizes.push(data.len());
673        }
674
675        crate::trace::debug_eprintln!(
676            "MULTI_GROUP: {} total sections, sizes = {:?}",
677            section_sizes.len(),
678            section_sizes
679        );
680
681        self.write_toc_multi(writer, &section_sizes)?;
682
683        // Step 7: Append all section data in same order
684        for byte in lf_global_data {
685            writer.write_u8(byte)?;
686        }
687        for data in lf_group_data {
688            for byte in data {
689                writer.write_u8(byte)?;
690            }
691        }
692        for byte in hf_global_data {
693            writer.write_u8(byte)?;
694        }
695        for data in pass_group_data {
696            for byte in data {
697                writer.write_u8(byte)?;
698            }
699        }
700
701        Ok(())
702    }
703
    /// Encodes a modular image using multi-group format with lossy (delta) palette.
    ///
    /// After applying the lossy palette transform, the channel layout is:
    /// - Channel 0: palette meta-channel (width=total_size, height=num_c) — SMALL
    /// - Channel 1: index channel (width=image_width, height=image_height) — LARGE
    /// - Channel 2+: optional alpha/extra channels
    ///
    /// The palette meta-channel goes into LfGlobal (alongside the tree + histogram).
    /// The index and extra channels are split across PassGroups by 256x256 regions.
    /// The palette transform descriptor is written in the LfGlobal GroupHeader.
    ///
    /// If `apply_lossy_palette` decides the transform is not beneficial, this
    /// falls back to `encode_modular_multi_group_inner` on the original image.
    ///
    /// All sections share one entropy code (Huffman or ANS, per
    /// `self.options.use_ans`) built from the residuals of every section; the
    /// code is written once in LfGlobal and reused with `use_global_tree = true`.
    fn encode_modular_multi_group_lossy_palette(
        &self,
        image: &ModularImage,
        writer: &mut BitWriter,
    ) -> Result<()> {
        use super::encode::{
            write_gradient_tree_tokens, write_hybrid_data_histogram,
            write_tree_histogram_for_gradient,
        };
        use super::encode_transforms::write_palette_transform;
        use super::predictor::pack_signed;
        use crate::entropy_coding::encode::{build_entropy_code_ans, write_tokens_ans};
        use crate::entropy_coding::hybrid_uint::HybridUintConfig;
        use crate::entropy_coding::token::Token as AnsToken;

        // Hybrid-uint split used for every modular residual token in this frame.
        const MODULAR_HYBRID_UINT: HybridUintConfig = HybridUintConfig {
            split_exponent: 4,
            split: 16,
            msb_in_token: 2,
            lsb_in_token: 0,
        };

        let num_groups = self.num_groups();
        let num_lf_groups = self.num_lf_groups();

        // Step 1: Apply lossy palette to full image
        // Palette size is capped at 2^min(bit_depth, 12) entries.
        let mut transformed = image.clone();
        let max_colors = 1usize << image.bit_depth.min(12);
        let num_c = image.channels.len().min(3);
        let result = super::palette::apply_lossy_palette(&mut transformed, 0, num_c, max_colors);
        let result = match result {
            Some(r) => r,
            None => {
                // Lossy palette not beneficial, fall back to standard multi-group
                return self.encode_modular_multi_group_inner(image, writer, None);
            }
        };

        crate::trace::debug_eprintln!(
            "LOSSY_PALETTE_MULTI: {} colors + {} deltas, predictor={}, {} → {} channels, {}x{}",
            result.nb_colors,
            result.nb_deltas,
            result.predictor,
            image.channels.len(),
            transformed.channels.len(),
            self.width,
            self.height,
        );

        // After palette: transformed.channels = [palette_meta, index, ...extra]
        // Separate palette_meta (small, global) from spatial channels (split across groups)
        let palette_meta = transformed.channels[0].clone();

        // Build a ModularImage of only the spatial channels (index + alpha)
        let spatial_image = ModularImage {
            channels: transformed.channels[1..].to_vec(),
            bit_depth: transformed.bit_depth,
            is_grayscale: transformed.is_grayscale,
            has_alpha: transformed.has_alpha,
        };

        // Step 2: Extract group images from spatial channels only
        let mut group_images: Vec<ModularImage> = Vec::with_capacity(num_groups);
        for group_idx in 0..num_groups {
            let (x_start, y_start, x_end, y_end) = self.group_bounds(group_idx);
            let group_image = spatial_image.extract_region(x_start, y_start, x_end, y_end)?;
            group_images.push(group_image);
        }

        // Step 3: Collect ALL residuals (palette_meta + all groups) for histogram
        // Clamped-gradient prediction: left + top - topleft, clamped to [min, max]
        // of (left, top). Must match the tree written by
        // write_tree_histogram_for_gradient / write_gradient_tree_tokens below.
        let predict_gradient = |left: i32, top: i32, topleft: i32| -> i32 {
            let grad = left + top - topleft;
            grad.clamp(left.min(top), left.max(top))
        };

        // Residual per pixel = pack_signed(pixel - prediction), scanned row-major.
        // Out-of-bounds neighbors fall back to 0 (left at x=0) or to `left`
        // (top/topleft on the first row/column), mirroring the decoder's edge rules.
        let collect_channel_residuals = |channel: &super::channel::Channel| -> Vec<u32> {
            let w = channel.width();
            let h = channel.height();
            let mut residuals = Vec::with_capacity(w * h);
            for y in 0..h {
                for x in 0..w {
                    let pixel = channel.get(x, y);
                    let left = if x > 0 { channel.get(x - 1, y) } else { 0 };
                    let top = if y > 0 { channel.get(x, y - 1) } else { left };
                    let topleft = if x > 0 && y > 0 {
                        channel.get(x - 1, y - 1)
                    } else {
                        left
                    };
                    let prediction = predict_gradient(left, top, topleft);
                    residuals.push(pack_signed(pixel - prediction));
                }
            }
            residuals
        };

        // Palette meta-channel residuals (goes to LfGlobal)
        let palette_residuals = collect_channel_residuals(&palette_meta);

        // All residuals: palette_meta + all group spatial channels
        // (cloned so palette_residuals can still be encoded separately below)
        let mut all_residuals = palette_residuals.clone();
        for group_image in &group_images {
            for channel in &group_image.channels {
                all_residuals.extend(collect_channel_residuals(channel));
            }
        }

        // Step 4: Build histogram and entropy codes
        // max_token bounds the histogram size; only consumed on the Huffman path.
        let mut max_token: u32 = 0;
        for &r in &all_residuals {
            let (token, _, _) = MODULAR_HYBRID_UINT.encode(r);
            max_token = max_token.max(token);
        }

        // Step 5: Write LfGlobal section
        let mut lf_global_writer = BitWriter::new();

        // dc_quant.all_default = true
        lf_global_writer.write(1, 1)?;
        // has_tree = true
        lf_global_writer.write(1, 1)?;

        // Tree histogram + tokens (gradient predictor)
        let (tree_depths, tree_codes) = write_tree_histogram_for_gradient(&mut lf_global_writer)?;
        write_gradient_tree_tokens(&mut lf_global_writer, &tree_depths, &tree_codes)?;

        // Build entropy coding state
        let use_ans = self.options.use_ans;

        // Holds whichever entropy code was chosen so encode_residuals (below)
        // can emit residuals the same way in every section.
        enum EntropyState {
            Huffman {
                depths: Vec<u8>,
                codes: Vec<u16>,
            },
            Ans {
                code: crate::entropy_coding::encode::OwnedAnsEntropyCode,
            },
        }

        let entropy_state = if use_ans {
            // Single-context ANS: all residuals feed context 0.
            let tokens: Vec<AnsToken> =
                all_residuals.iter().map(|&r| AnsToken::new(0, r)).collect();
            let code = build_entropy_code_ans(&tokens, 1);
            super::section::write_ans_modular_header(&mut lf_global_writer, &code)?;
            EntropyState::Ans { code }
        } else {
            let histogram_size = (max_token + 1) as usize;
            let mut histogram = vec![0u32; histogram_size];
            for &r in &all_residuals {
                let (token, _, _) = MODULAR_HYBRID_UINT.encode(r);
                histogram[token as usize] += 1;
            }
            let (depths, codes) =
                write_hybrid_data_histogram(&mut lf_global_writer, &histogram, max_token)?;
            EntropyState::Huffman { depths, codes }
        };

        // GroupHeader with palette transform
        lf_global_writer.write(1, 1)?; // use_global_tree = true
        lf_global_writer.write(1, 1)?; // wp_params.default_wp = true
        lf_global_writer.write(2, 1)?; // nb_transforms = 1
        write_palette_transform(
            &mut lf_global_writer,
            0,
            num_c,
            result.nb_colors,
            result.nb_deltas,
            result.predictor,
        )?;

        // Encode palette_meta residuals in LfGlobal.
        // Shared emitter for all sections: Huffman writes (code, extra bits) per
        // residual; ANS tokenizes the whole slice and flushes one ANS stream
        // (one ANS state per section).
        let encode_residuals =
            |residuals: &[u32], writer: &mut BitWriter, state: &EntropyState| -> Result<()> {
                match state {
                    EntropyState::Huffman { depths, codes } => {
                        for &r in residuals {
                            let (token, extra_bits, num_extra) = MODULAR_HYBRID_UINT.encode(r);
                            let depth = depths.get(token as usize).copied().unwrap_or(0);
                            let code = codes.get(token as usize).copied().unwrap_or(0);
                            if depth > 0 {
                                writer.write(depth as usize, code as u64)?;
                            }
                            if num_extra > 0 {
                                writer.write(num_extra as usize, extra_bits as u64)?;
                            }
                        }
                    }
                    EntropyState::Ans { code } => {
                        let tokens: Vec<AnsToken> =
                            residuals.iter().map(|&r| AnsToken::new(0, r)).collect();
                        write_tokens_ans(&tokens, code, None, writer)?;
                    }
                }
                Ok(())
            };

        encode_residuals(&palette_residuals, &mut lf_global_writer, &entropy_state)?;

        // Sections are byte-aligned in the codestream.
        lf_global_writer.zero_pad_to_byte();
        let lf_global_data = lf_global_writer.finish();

        crate::trace::debug_eprintln!(
            "LOSSY_PALETTE_MULTI: LfGlobal = {} bytes (palette_meta {}x{})",
            lf_global_data.len(),
            palette_meta.width(),
            palette_meta.height(),
        );

        // Step 6: HfGlobal is empty for modular
        let hf_global_data: Vec<u8> = Vec::new();

        // Step 7: LfGroup sections are empty for modular
        let lf_group_data: Vec<Vec<u8>> = (0..num_lf_groups).map(|_| Vec::new()).collect();

        // Step 8: Write each PassGroup's data — parallelizable
        // (sections are independent: each writes its own GroupHeader + residuals)
        let pass_group_data: Vec<Vec<u8>> =
            crate::parallel::parallel_map_result(num_groups, |g| {
                let group_image = &group_images[g];
                let mut group_writer = BitWriter::new();

                // GroupHeader
                group_writer.write(1, 1)?; // use_global_tree = true
                group_writer.write(1, 1)?; // wp_params.default_wp = true
                group_writer.write(2, 0)?; // nb_transforms = 0

                // Collect and encode spatial channel residuals for this group
                let mut section_residuals: Vec<u32> = Vec::new();
                for channel in &group_image.channels {
                    section_residuals.extend(collect_channel_residuals(channel));
                }
                encode_residuals(&section_residuals, &mut group_writer, &entropy_state)?;

                group_writer.zero_pad_to_byte();
                let data = group_writer.finish();
                crate::trace::debug_eprintln!(
                    "LOSSY_PALETTE_MULTI: PassGroup[{}] = {} bytes",
                    g,
                    data.len(),
                );
                Ok(data)
            })?;

        // Step 9: Assemble TOC and sections
        // Section order: LfGlobal, LfGroup[0..n], HfGlobal, PassGroup[0..m]
        // (must match the order the bytes are appended below)
        let mut section_sizes = Vec::with_capacity(2 + num_lf_groups + num_groups);
        section_sizes.push(lf_global_data.len());
        for data in &lf_group_data {
            section_sizes.push(data.len());
        }
        section_sizes.push(hf_global_data.len());
        for data in &pass_group_data {
            section_sizes.push(data.len());
        }

        self.write_toc_multi(writer, &section_sizes)?;

        // Write all section data in same order
        for byte in lf_global_data {
            writer.write_u8(byte)?;
        }
        for data in lf_group_data {
            for byte in data {
                writer.write_u8(byte)?;
            }
        }
        for byte in hf_global_data {
            writer.write_u8(byte)?;
        }
        for data in pass_group_data {
            for byte in data {
                writer.write_u8(byte)?;
            }
        }

        Ok(())
    }
990
    /// Encodes a modular image using multi-group format with squeeze (Haar wavelet) transform.
    ///
    /// After squeeze, channels are partitioned by resolution:
    /// - **LfGlobal**: channels small enough to fit in GROUP_DIM (tree + histogram + data)
    /// - **LfGroup**: channels with min(hshift, vshift) >= 3 (DC-group-sized regions)
    /// - **PassGroup**: channels with min(hshift, vshift) < 3 (group-sized regions)
    ///
    /// For images with >= 3 channels a YCoCg RCT is applied before squeezing,
    /// and both transforms are signalled in the LfGlobal GroupHeader. A single
    /// entropy code (Huffman or ANS, per `self.options.use_ans`) is built from
    /// the residuals of every section and shared via `use_global_tree = true`.
    fn encode_modular_multi_group_squeeze(
        &self,
        image: &ModularImage,
        writer: &mut BitWriter,
    ) -> Result<()> {
        use super::encode::{
            write_gradient_tree_tokens, write_rct_transform, write_squeeze_transform,
            write_tree_histogram_for_gradient,
        };
        use super::predictor::pack_signed;
        use super::rct::{RctType, forward_rct};
        use super::squeeze::{apply_squeeze, default_squeeze_params};
        use crate::entropy_coding::encode::{build_entropy_code_ans, write_tokens_ans};
        use crate::entropy_coding::hybrid_uint::HybridUintConfig;
        use crate::entropy_coding::token::Token as AnsToken;

        // Hybrid-uint split used for every modular residual token in this frame.
        const MODULAR_HYBRID_UINT: HybridUintConfig = HybridUintConfig {
            split_exponent: 4,
            split: 16,
            msb_in_token: 2,
            lsb_in_token: 0,
        };

        let num_groups = self.num_groups();
        let num_lf_groups = self.num_lf_groups();
        let lf_group_dim = GROUP_DIM * 8; // 2048

        // Step 1: Apply RCT (YCoCg) before squeeze for RGB images, then squeeze
        let squeeze_params = default_squeeze_params(image);
        let mut squeezed = image.clone();
        let has_rct = squeezed.channels.len() >= 3;
        if has_rct {
            forward_rct(&mut squeezed.channels, 0, RctType::YCOCG)?;
        }
        apply_squeeze(&mut squeezed, &squeeze_params)?;

        // Diagnostic dump of the post-squeeze channel layout (test builds only).
        #[cfg(test)]
        {
            eprintln!(
                "SQUEEZE_MULTI: {} steps, {} → {} channels, image {}x{}",
                squeeze_params.len(),
                image.channels.len(),
                squeezed.channels.len(),
                self.width,
                self.height,
            );
            for (i, ch) in squeezed.channels.iter().enumerate() {
                eprintln!(
                    "  ch[{}]: {}x{} hshift={} vshift={} min_shift={}",
                    i,
                    ch.width(),
                    ch.height(),
                    ch.hshift,
                    ch.vshift,
                    ch.hshift.min(ch.vshift),
                );
            }
        }

        // Step 2: Partition channels by size/shift
        // Global channels: both dimensions <= GROUP_DIM
        // NOTE: this assumes squeeze orders channels small-to-large, so the first
        // oversized channel marks the cutoff; everything before it goes to LfGlobal.
        let global_cutoff = squeezed
            .channels
            .iter()
            .position(|c| c.width() > GROUP_DIM || c.height() > GROUP_DIM)
            .unwrap_or(squeezed.channels.len());

        crate::trace::debug_eprintln!(
            "SQUEEZE_MULTI: {} global channels (<={}x{}), {} group channels",
            global_cutoff,
            GROUP_DIM,
            GROUP_DIM,
            squeezed.channels.len() - global_cutoff,
        );

        // Classify non-global channels by shift bracket
        // LfGroup: min(hshift, vshift) >= 3
        // PassGroup: min(hshift, vshift) < 3
        let mut lf_channel_indices: Vec<usize> = Vec::new();
        let mut pass_channel_indices: Vec<usize> = Vec::new();
        for i in global_cutoff..squeezed.channels.len() {
            let ch = &squeezed.channels[i];
            let min_shift = ch.hshift.min(ch.vshift);
            if min_shift >= 3 {
                lf_channel_indices.push(i);
            } else {
                pass_channel_indices.push(i);
            }
        }

        #[cfg(test)]
        eprintln!(
            "SQUEEZE_MULTI: {} global, {} LfGroup (shift>=3), {} PassGroup (shift<3) channels",
            global_cutoff,
            lf_channel_indices.len(),
            pass_channel_indices.len(),
        );

        // Step 3: Collect residuals from ALL channels for histogram building
        // Clamped-gradient prediction: left + top - topleft, clamped to [min, max]
        // of (left, top). Must match the tree written by
        // write_tree_histogram_for_gradient / write_gradient_tree_tokens below.
        let predict_gradient = |left: i32, top: i32, topleft: i32| -> i32 {
            let grad = left + top - topleft;
            grad.clamp(left.min(top), left.max(top))
        };

        // Residual per pixel = pack_signed(pixel - prediction), scanned row-major.
        // Out-of-bounds neighbors fall back to 0 (left at x=0) or to `left`
        // (top/topleft on the first row/column), mirroring the decoder's edge rules.
        let collect_channel_residuals = |channel: &super::channel::Channel| -> Vec<u32> {
            let w = channel.width();
            let h = channel.height();
            let mut residuals = Vec::with_capacity(w * h);
            for y in 0..h {
                for x in 0..w {
                    let pixel = channel.get(x, y);
                    let left = if x > 0 { channel.get(x - 1, y) } else { 0 };
                    let top = if y > 0 { channel.get(x, y - 1) } else { left };
                    let topleft = if x > 0 && y > 0 {
                        channel.get(x - 1, y - 1)
                    } else {
                        left
                    };
                    let prediction = predict_gradient(left, top, topleft);
                    residuals.push(pack_signed(pixel - prediction));
                }
            }
            residuals
        };

        // 3a: Global channel residuals (full channels)
        let mut all_residuals: Vec<u32> = Vec::new();
        for i in 0..global_cutoff {
            all_residuals.extend(collect_channel_residuals(&squeezed.channels[i]));
        }

        // 3b: LfGroup channel residuals (cropped to each DC group rect)
        // Use extract_grid_cell matching decoder's get_grid_rect: computes regions
        // in channel space via grid_dim = (group_dim >> hshift, group_dim >> vshift).
        let num_lf_groups_x = self.width.div_ceil(lf_group_dim);
        let mut lf_group_channel_data: Vec<Vec<Vec<u32>>> = vec![Vec::new(); num_lf_groups]; // [lf_group_idx][channel_within_group] = residuals
        for &ch_idx in &lf_channel_indices {
            let ch = &squeezed.channels[ch_idx];
            for (lg, lg_channels) in lf_group_channel_data
                .iter_mut()
                .enumerate()
                .take(num_lf_groups)
            {
                // LF groups are laid out in row-major order over the image.
                let lg_x = lg % num_lf_groups_x;
                let lg_y = lg / num_lf_groups_x;
                // None means this channel has no pixels in that grid cell.
                if let Some(cropped) = ch.extract_grid_cell(lg_x, lg_y, lf_group_dim) {
                    let residuals = collect_channel_residuals(&cropped);
                    all_residuals.extend(&residuals);
                    lg_channels.push(residuals);
                }
            }
        }

        // 3c: PassGroup channel residuals (cropped to each group rect)
        // Use extract_grid_cell matching decoder's get_grid_rect logic.
        let num_groups_x = self.num_groups_x();
        let mut pass_group_channel_data: Vec<Vec<Vec<u32>>> = vec![Vec::new(); num_groups]; // [group_idx][channel_within_group] = residuals
        for &ch_idx in &pass_channel_indices {
            let ch = &squeezed.channels[ch_idx];
            for (g, g_channels) in pass_group_channel_data
                .iter_mut()
                .enumerate()
                .take(num_groups)
            {
                let gx = g % num_groups_x;
                let gy = g / num_groups_x;
                if let Some(cropped) = ch.extract_grid_cell(gx, gy, GROUP_DIM) {
                    let residuals = collect_channel_residuals(&cropped);
                    all_residuals.extend(&residuals);
                    g_channels.push(residuals);
                }
            }
        }

        // Step 4: Build histogram and entropy codes
        // max_token bounds the histogram size; only consumed on the Huffman path.
        let mut max_token: u32 = 0;
        for &r in &all_residuals {
            let (token, _, _) = MODULAR_HYBRID_UINT.encode(r);
            max_token = max_token.max(token);
        }

        // Step 5: Write LfGlobal section
        let mut lf_global_writer = BitWriter::new();

        // dc_quant.all_default = true
        lf_global_writer.write(1, 1)?;
        // has_tree = true
        lf_global_writer.write(1, 1)?;

        // Tree histogram + tokens (gradient predictor)
        let (tree_depths, tree_codes) = write_tree_histogram_for_gradient(&mut lf_global_writer)?;
        write_gradient_tree_tokens(&mut lf_global_writer, &tree_depths, &tree_codes)?;

        // Data histogram (Huffman or ANS) — covers ALL channels across ALL sections
        let use_ans = self.options.use_ans;

        // Build the entropy coding state; holds whichever code was chosen so
        // encode_residuals (below) can emit residuals the same way in every section.
        enum EntropyState {
            Huffman {
                depths: Vec<u8>,
                codes: Vec<u16>,
            },
            Ans {
                code: crate::entropy_coding::encode::OwnedAnsEntropyCode,
            },
        }

        let entropy_state = if use_ans {
            // Single-context ANS: all residuals feed context 0.
            let tokens: Vec<AnsToken> =
                all_residuals.iter().map(|&r| AnsToken::new(0, r)).collect();
            let code = build_entropy_code_ans(&tokens, 1);
            super::section::write_ans_modular_header(&mut lf_global_writer, &code)?;
            EntropyState::Ans { code }
        } else {
            let histogram_size = (max_token + 1) as usize;
            let mut histogram = vec![0u32; histogram_size];
            for &r in &all_residuals {
                let (token, _, _) = MODULAR_HYBRID_UINT.encode(r);
                histogram[token as usize] += 1;
            }
            let (depths, codes) = super::encode::write_hybrid_data_histogram(
                &mut lf_global_writer,
                &histogram,
                max_token,
            )?;
            EntropyState::Huffman { depths, codes }
        };

        // GroupHeader for global modular stream — includes RCT (if RGB) + squeeze transform
        lf_global_writer.write(1, 1)?; // use_global_tree = true
        lf_global_writer.write(1, 1)?; // wp_params.default_wp = true
        if has_rct {
            // nb_transforms = 2: U32 BitsOffset(4,2), offset=0
            lf_global_writer.write(2, 2)?;
            lf_global_writer.write(4, 0)?;
            write_rct_transform(&mut lf_global_writer, 0, RctType::YCOCG)?;
            write_squeeze_transform(&mut lf_global_writer, &squeeze_params)?;
        } else {
            lf_global_writer.write(2, 1)?; // nb_transforms = 1
            write_squeeze_transform(&mut lf_global_writer, &squeeze_params)?;
        }

        // Encode global channel data (small channels that fit within GROUP_DIM).
        // Shared emitter for all sections: Huffman writes (code, extra bits) per
        // residual; ANS tokenizes the whole slice and flushes one ANS stream
        // (one ANS state per section).
        let encode_residuals =
            |residuals: &[u32], writer: &mut BitWriter, state: &EntropyState| -> Result<()> {
                match state {
                    EntropyState::Huffman { depths, codes } => {
                        for &r in residuals {
                            let (token, extra_bits, num_extra) = MODULAR_HYBRID_UINT.encode(r);
                            let depth = depths.get(token as usize).copied().unwrap_or(0);
                            let code = codes.get(token as usize).copied().unwrap_or(0);
                            if depth > 0 {
                                writer.write(depth as usize, code as u64)?;
                            }
                            if num_extra > 0 {
                                writer.write(num_extra as usize, extra_bits as u64)?;
                            }
                        }
                    }
                    EntropyState::Ans { code } => {
                        let tokens: Vec<AnsToken> =
                            residuals.iter().map(|&r| AnsToken::new(0, r)).collect();
                        write_tokens_ans(&tokens, code, None, writer)?;
                    }
                }
                Ok(())
            };

        // Write global channel residuals
        let mut global_residuals: Vec<u32> = Vec::new();
        for i in 0..global_cutoff {
            global_residuals.extend(collect_channel_residuals(&squeezed.channels[i]));
        }
        encode_residuals(&global_residuals, &mut lf_global_writer, &entropy_state)?;

        // Sections are byte-aligned in the codestream.
        lf_global_writer.zero_pad_to_byte();
        let lf_global_data = lf_global_writer.finish();

        crate::trace::debug_eprintln!(
            "SQUEEZE_MULTI: LfGlobal = {} bytes ({} global channels)",
            lf_global_data.len(),
            global_cutoff,
        );

        // Step 6: Write LfGroup sections
        let mut lf_group_data: Vec<Vec<u8>> = Vec::with_capacity(num_lf_groups);
        for (_lg, lg_channels) in lf_group_channel_data.iter().enumerate().take(num_lf_groups) {
            let mut lg_writer = BitWriter::new();

            if lg_channels.is_empty() {
                // Empty LfGroup (no channels assigned)
                lf_group_data.push(lg_writer.finish());
                continue;
            }

            // GroupHeader
            lg_writer.write(1, 1)?; // use_global_tree = true
            lg_writer.write(1, 1)?; // wp_params.default_wp = true
            lg_writer.write(2, 0)?; // nb_transforms = 0

            // Concatenate all channel residuals for this section, then encode once.
            // ANS requires a single encoder per section (one ANS state per section).
            let mut section_residuals: Vec<u32> = Vec::new();
            for channel_residuals in lg_channels {
                section_residuals.extend(channel_residuals);
            }
            encode_residuals(&section_residuals, &mut lg_writer, &entropy_state)?;

            lg_writer.zero_pad_to_byte();
            let data = lg_writer.finish();
            crate::trace::debug_eprintln!(
                "SQUEEZE_MULTI: LfGroup[{}] = {} bytes ({} channels)",
                _lg,
                data.len(),
                lg_channels.len(),
            );
            lf_group_data.push(data);
        }

        // Step 7: HfGlobal is empty for modular
        let hf_global_data: Vec<u8> = Vec::new();

        // Step 8: Write PassGroup sections — parallelizable
        // (sections are independent: each writes its own GroupHeader + residuals)
        let pass_group_data: Vec<Vec<u8>> =
            crate::parallel::parallel_map_result(num_groups, |g| {
                let g_channels = &pass_group_channel_data[g];
                let mut pg_writer = BitWriter::new();

                if g_channels.is_empty() {
                    // Empty PassGroup (no channels assigned)
                    return Ok(pg_writer.finish());
                }

                // GroupHeader
                pg_writer.write(1, 1)?; // use_global_tree = true
                pg_writer.write(1, 1)?; // wp_params.default_wp = true
                pg_writer.write(2, 0)?; // nb_transforms = 0

                // Concatenate all channel residuals for this section, then encode once.
                let mut section_residuals: Vec<u32> = Vec::new();
                for channel_residuals in g_channels {
                    section_residuals.extend(channel_residuals);
                }
                encode_residuals(&section_residuals, &mut pg_writer, &entropy_state)?;

                pg_writer.zero_pad_to_byte();
                let data = pg_writer.finish();
                crate::trace::debug_eprintln!(
                    "SQUEEZE_MULTI: PassGroup[{}] = {} bytes ({} channels)",
                    g,
                    data.len(),
                    g_channels.len(),
                );
                Ok(data)
            })?;

        // Step 9: Assemble TOC and sections
        // Section order: LfGlobal, LfGroup[0..n], HfGlobal, PassGroup[0..m]
        // (must match the order the bytes are appended below)
        let mut section_sizes = Vec::with_capacity(2 + num_lf_groups + num_groups);
        section_sizes.push(lf_global_data.len());
        for data in &lf_group_data {
            section_sizes.push(data.len());
        }
        section_sizes.push(hf_global_data.len());
        for data in &pass_group_data {
            section_sizes.push(data.len());
        }

        #[cfg(test)]
        eprintln!(
            "SQUEEZE_MULTI: {} sections, sizes = {:?}",
            section_sizes.len(),
            section_sizes,
        );

        self.write_toc_multi(writer, &section_sizes)?;

        // Write all section data
        for byte in lf_global_data {
            writer.write_u8(byte)?;
        }
        for data in lf_group_data {
            for byte in data {
                writer.write_u8(byte)?;
            }
        }
        for byte in hf_global_data {
            writer.write_u8(byte)?;
        }
        for data in pass_group_data {
            for byte in data {
                writer.write_u8(byte)?;
            }
        }

        Ok(())
    }
1395
1396    /// Encodes a multi-group modular image with squeeze + tree learning.
1397    ///
1398    /// This combines the Haar wavelet (squeeze) transform with learned MA tree
1399    /// for multi-context ANS encoding across all sections. The tree is learned
1400    /// from the full squeezed image and shared across all sections.
1401    ///
1402    /// Pipeline: RCT -> squeeze -> partition channels -> gather samples ->
1403    /// learn tree -> collect residuals per section -> multi-context ANS
    fn encode_modular_multi_group_squeeze_with_tree(
        &self,
        image: &ModularImage,
        writer: &mut BitWriter,
    ) -> Result<()> {
        // Multi-group modular encode with the squeeze transform and a single
        // learned MA tree shared by every section:
        //   1.  RCT (YCoCg, if RGB) then squeeze the full image.
        //   2.  Partition squeezed channels into Global / LfGroup / PassGroup sections.
        //   3.  Gather subsampled tree-training samples from every section.
        //   4.  Learn one MA tree from those samples.
        //   5.  Tokenize each section's residuals with that tree (+ optional per-section LZ77).
        //   6.  Build one ANS code over all tokens.
        //   7-11. Emit LfGlobal, LfGroups, (empty) HfGlobal, PassGroups, then TOC + data.
        use super::encode::{write_rct_transform, write_squeeze_transform, write_tree};
        use super::rct::{RctType, forward_rct};
        use super::squeeze::{apply_squeeze, default_squeeze_params};
        use super::tree::count_contexts;
        use super::tree_learn::{
            TreeLearningParams, TreeSamples, collect_residuals_with_tree, compute_best_tree,
            compute_gather_stride_from_profile, gather_samples_strided,
        };
        use crate::entropy_coding::encode::build_entropy_code_ans_with_options;
        use crate::entropy_coding::encode::{write_entropy_code_ans, write_tokens_ans};
        use crate::entropy_coding::token::Token as AnsToken;

        let num_groups = self.num_groups();
        let num_lf_groups = self.num_lf_groups();
        let lf_group_dim = GROUP_DIM * 8; // 2048

        // Step 1: Apply RCT (YCoCg) before squeeze for RGB images, then squeeze.
        // The squeeze params are derived from the ORIGINAL image geometry.
        let squeeze_params = default_squeeze_params(image);
        let mut squeezed = image.clone();
        let has_rct = squeezed.channels.len() >= 3;
        if has_rct {
            forward_rct(&mut squeezed.channels, 0, RctType::YCOCG)?;
        }
        apply_squeeze(&mut squeezed, &squeeze_params)?;

        crate::trace::debug_eprintln!(
            "SQUEEZE_TREE_MULTI: {} steps, {} → {} channels, image {}x{}",
            squeeze_params.len(),
            image.channels.len(),
            squeezed.channels.len(),
            self.width,
            self.height,
        );

        // Step 2: Partition channels by size/shift.
        // Channels small enough to fit in one group (<= GROUP_DIM in both dims)
        // go into the global (LfGlobal) stream; the squeeze transform orders
        // channels small-to-large, so these form a prefix up to `global_cutoff`.
        let global_cutoff = squeezed
            .channels
            .iter()
            .position(|c| c.width() > GROUP_DIM || c.height() > GROUP_DIM)
            .unwrap_or(squeezed.channels.len());

        // Remaining (large) channels are split by shift: min(hshift, vshift) >= 3
        // means the channel is coarse enough to be coded in LfGroup sections;
        // everything else is coded per PassGroup.
        let mut lf_channel_indices: Vec<usize> = Vec::new();
        let mut pass_channel_indices: Vec<usize> = Vec::new();
        for i in global_cutoff..squeezed.channels.len() {
            let ch = &squeezed.channels[i];
            let min_shift = ch.hshift.min(ch.vshift);
            if min_shift >= 3 {
                lf_channel_indices.push(i);
            } else {
                pass_channel_indices.push(i);
            }
        }

        crate::trace::debug_eprintln!(
            "SQUEEZE_TREE_MULTI: {} global, {} LfGroup, {} PassGroup channels",
            global_cutoff,
            lf_channel_indices.len(),
            pass_channel_indices.len(),
        );

        // Step 3: Build sub-images for each section and gather samples.
        // Compute stride from total pixel count for subsampling the training set.
        let total_pixels: usize = squeezed
            .channels
            .iter()
            .map(|ch| ch.width() * ch.height())
            .sum();
        let stride = compute_gather_stride_from_profile(total_pixels, &self.options.profile);

        // Find best WP parameters (effort-dependent search); the same params are
        // used for training, residual collection, and every written GroupHeader.
        let wp_params = if self.options.profile.wp_num_param_sets > 0 {
            super::predictor::find_best_wp_params(
                &squeezed.channels,
                self.options.profile.wp_num_param_sets,
            )
        } else {
            super::predictor::WeightedPredictorParams::default()
        };

        let mut samples = TreeSamples::new();

        // Compute stream_id values matching the decoder's ModularStreamId formula.
        // The decoder assigns stream_id = property[1] during tree traversal:
        //   GlobalData:    0
        //   ModularLF(g):  1 + num_lf_groups + g
        //   ModularHF(p,g): 1 + 3*num_lf_groups + NUM_QUANT_TABLES + num_groups*p + g
        // These must match between encoder (tree training + residual collection) and decoder.
        const NUM_QUANT_TABLES: usize = 17;
        let stream_id_lf_base = 1 + num_lf_groups;
        let stream_id_hf_base = 1 + 3 * num_lf_groups + NUM_QUANT_TABLES;

        // 3a: Global channels (full, no cropping needed)
        let global_sub = ModularImage {
            channels: squeezed.channels[..global_cutoff].to_vec(),
            bit_depth: squeezed.bit_depth,
            is_grayscale: squeezed.is_grayscale,
            has_alpha: false,
        };
        // group_id=0 for global section, channel_offset=0
        gather_samples_strided(&mut samples, &global_sub, 0, 0, stride, &wp_params);

        // 3b: LfGroup channels — crop each LfGroup channel to each LfGroup rect.
        // `extract_grid_cell` returns None when the cell is outside the channel.
        let num_lf_groups_x = self.width.div_ceil(lf_group_dim);
        // Store cropped sub-images: [lf_group_idx] = Vec<Channel>
        let mut lf_group_sub_images: Vec<Vec<super::channel::Channel>> =
            vec![Vec::new(); num_lf_groups];
        for &ch_idx in &lf_channel_indices {
            let ch = &squeezed.channels[ch_idx];
            for (lg, lg_channels) in lf_group_sub_images.iter_mut().enumerate() {
                let lg_x = lg % num_lf_groups_x;
                let lg_y = lg / num_lf_groups_x;
                if let Some(cropped) = ch.extract_grid_cell(lg_x, lg_y, lf_group_dim) {
                    lg_channels.push(cropped);
                }
            }
        }
        // Gather samples from LfGroup sub-images.
        //
        // Channel indexing note: the decoder treats each section (LfGlobal, each
        // LfGroup, each PassGroup) as its own modular sub-image whose channels are
        // numbered locally from 0. The shared tree's property[0] (channel index)
        // therefore sees LOCAL indices within each section, so we gather with
        // channel_offset = 0 and rely on the per-section stream_id (property[1],
        // computed above) to disambiguate sections. This matches libjxl's design:
        // splits like "channel 0 vs channel 1" transfer well across sections
        // because LfGroup/PassGroup sections typically hold one channel per
        // color component.
        for (lg, lg_channels) in lf_group_sub_images.iter().enumerate() {
            if lg_channels.is_empty() {
                continue;
            }
            let sub_image = ModularImage {
                channels: lg_channels.clone(),
                bit_depth: squeezed.bit_depth,
                is_grayscale: squeezed.is_grayscale,
                has_alpha: false,
            };
            // stream_id for LfGroup: ModularLF(lg) = 1 + num_lf_groups + lg
            gather_samples_strided(
                &mut samples,
                &sub_image,
                (stream_id_lf_base + lg) as u32,
                0,
                stride,
                &wp_params,
            );
        }

        // 3c: PassGroup channels — crop to each group rect
        let num_groups_x = self.num_groups_x();
        let mut pass_group_sub_images: Vec<Vec<super::channel::Channel>> =
            vec![Vec::new(); num_groups];
        for &ch_idx in &pass_channel_indices {
            let ch = &squeezed.channels[ch_idx];
            for (g, g_channels) in pass_group_sub_images.iter_mut().enumerate() {
                let gx = g % num_groups_x;
                let gy = g / num_groups_x;
                if let Some(cropped) = ch.extract_grid_cell(gx, gy, GROUP_DIM) {
                    g_channels.push(cropped);
                }
            }
        }
        // Gather samples from PassGroup sub-images (same local-index convention as 3b)
        for (g, g_channels) in pass_group_sub_images.iter().enumerate() {
            if g_channels.is_empty() {
                continue;
            }
            let sub_image = ModularImage {
                channels: g_channels.clone(),
                bit_depth: squeezed.bit_depth,
                is_grayscale: squeezed.is_grayscale,
                has_alpha: false,
            };
            // stream_id for PassGroup: ModularHF(pass=0, group=g) = 1 + 3*num_lf_groups + 17 + g
            gather_samples_strided(
                &mut samples,
                &sub_image,
                (stream_id_hf_base + g) as u32,
                0,
                stride,
                &wp_params,
            );
        }

        // Step 4: Learn tree. pixel_fraction tells the learner how aggressively
        // the samples were subsampled relative to the full image.
        let pixel_fraction = if total_pixels > 0 {
            samples.num_samples as f64 / total_pixels as f64
        } else {
            1.0
        };
        let tree_params = TreeLearningParams::from_profile(&self.options.profile)
            .with_pixel_fraction(pixel_fraction)
            .with_total_pixels(total_pixels);
        let tree = compute_best_tree(&mut samples, &tree_params);
        let num_contexts = count_contexts(&tree) as usize;

        crate::trace::debug_eprintln!(
            "SQUEEZE_TREE_MULTI: {} tree nodes, {} contexts from {} samples (pf={:.3})",
            tree.len(),
            num_contexts,
            samples.num_samples,
            pixel_fraction,
        );

        // Step 5: Collect residuals per section with the learned tree,
        // using the same stream_id values as during training.
        // Global section tokens
        let mut global_tokens = collect_residuals_with_tree(&global_sub, &tree, 0, &wp_params);

        // LfGroup section tokens (empty Vec placeholder keeps indices aligned)
        let mut lf_group_tokens: Vec<Vec<AnsToken>> = Vec::with_capacity(num_lf_groups);
        for (lg, lg_channels) in lf_group_sub_images.iter().enumerate() {
            if lg_channels.is_empty() {
                lf_group_tokens.push(Vec::new());
                continue;
            }
            let sub_image = ModularImage {
                channels: lg_channels.clone(),
                bit_depth: squeezed.bit_depth,
                is_grayscale: squeezed.is_grayscale,
                has_alpha: false,
            };
            let tokens = collect_residuals_with_tree(
                &sub_image,
                &tree,
                (stream_id_lf_base + lg) as u32,
                &wp_params,
            );
            lf_group_tokens.push(tokens);
        }

        // PassGroup section tokens
        let mut pass_group_tokens: Vec<Vec<AnsToken>> = Vec::with_capacity(num_groups);
        for (g, g_channels) in pass_group_sub_images.iter().enumerate() {
            if g_channels.is_empty() {
                pass_group_tokens.push(Vec::new());
                continue;
            }
            let sub_image = ModularImage {
                channels: g_channels.clone(),
                bit_depth: squeezed.bit_depth,
                is_grayscale: squeezed.is_grayscale,
                has_alpha: false,
            };
            let tokens = collect_residuals_with_tree(
                &sub_image,
                &tree,
                (stream_id_hf_base + g) as u32,
                &wp_params,
            );
            pass_group_tokens.push(tokens);
        }

        // Step 5b: Optionally apply LZ77 to each section's tokens independently
        // IMPORTANT: dist_multiplier must be computed PER-SECTION from that section's
        // channel widths, because the decoder creates a fresh LZ77 state per section
        // with dist_multiplier = max(section_channel_widths).
        let use_lz77 = self.options.enable_lz77;
        let lz77_method = self.options.lz77_method;
        let lz77_params = if use_lz77 {
            use crate::entropy_coding::lz77::apply_lz77;

            // Helper: apply LZ77 to one section, keeping the originals on failure
            // (apply_lz77 returns None when LZ77 does not pay off).
            let try_lz77 = |tokens: &[AnsToken], dist_multiplier: i32| -> Vec<AnsToken> {
                if tokens.is_empty() {
                    return tokens.to_vec();
                }
                match apply_lz77(tokens, num_contexts, false, lz77_method, dist_multiplier) {
                    Some((lz77_tokens, _)) => lz77_tokens,
                    None => tokens.to_vec(),
                }
            };

            // Global section: dist_multiplier from global channels
            let global_dm = squeezed.channels[..global_cutoff]
                .iter()
                .map(|c| c.width())
                .max()
                .unwrap_or(0) as i32;
            global_tokens = try_lz77(&global_tokens, global_dm);

            // LfGroup sections: dist_multiplier from each LfGroup's channels
            for (lg, lg_tokens) in lf_group_tokens.iter_mut().enumerate() {
                let dm = lf_group_sub_images[lg]
                    .iter()
                    .map(|c| c.width())
                    .max()
                    .unwrap_or(0) as i32;
                *lg_tokens = try_lz77(lg_tokens, dm);
            }

            // PassGroup sections: dist_multiplier from each PassGroup's channels
            for (g, pg_tokens) in pass_group_tokens.iter_mut().enumerate() {
                let dm = pass_group_sub_images[g]
                    .iter()
                    .map(|c| c.width())
                    .max()
                    .unwrap_or(0) as i32;
                *pg_tokens = try_lz77(pg_tokens, dm);
            }

            // Only signal LZ77 in the header if at least one section actually
            // ended up containing a length token.
            let has_lz77 = global_tokens.iter().any(|t| t.is_lz77_length())
                || lf_group_tokens
                    .iter()
                    .any(|ts| ts.iter().any(|t| t.is_lz77_length()))
                || pass_group_tokens
                    .iter()
                    .any(|ts| ts.iter().any(|t| t.is_lz77_length()));

            if has_lz77 {
                let mut params = crate::entropy_coding::lz77::Lz77Params::new(num_contexts, false);
                params.enabled = true;
                Some(params)
            } else {
                None
            }
        } else {
            None
        };
        // Enabling LZ77 adds one extra (distance) context to the ANS alphabet.
        let ans_num_contexts = if lz77_params.is_some() {
            num_contexts + 1
        } else {
            num_contexts
        };

        // Step 6: Build ANS codes from ALL tokens so every section shares one code.
        let mut all_tokens: Vec<AnsToken> = Vec::new();
        all_tokens.extend(&global_tokens);
        for lg_tokens in &lf_group_tokens {
            all_tokens.extend(lg_tokens);
        }
        for pg_tokens in &pass_group_tokens {
            all_tokens.extend(pg_tokens);
        }
        let code = build_entropy_code_ans_with_options(
            &all_tokens,
            ans_num_contexts,
            true, // enhanced clustering (pair-merge refinement)
            true, // optimize uint configs
            lz77_params.as_ref(),
            Some(total_pixels),
        );

        // Step 7: Write LfGlobal section
        let mut lf_global_writer = BitWriter::new();

        // dc_quant.all_default = true
        lf_global_writer.write(1, 1)?;
        // has_tree = true
        lf_global_writer.write(1, 1)?;

        // Write the learned tree
        write_tree(&mut lf_global_writer, &tree)?;

        // Write LZ77 header + ANS histogram
        if ans_num_contexts > 1 {
            crate::entropy_coding::lz77::write_lz77_header(
                lz77_params.as_ref(),
                &mut lf_global_writer,
            )?;
            write_entropy_code_ans(&code, &mut lf_global_writer)?;
        } else {
            super::section::write_ans_modular_header(&mut lf_global_writer, &code)?;
        }

        // GroupHeader for global modular stream — includes RCT (if RGB) + squeeze transform
        lf_global_writer.write(1, 1)?; // use_global_tree = true
        super::encode::write_wp_header(&mut lf_global_writer, &wp_params)?;
        if has_rct {
            // nb_transforms = 2: U32 BitsOffset(4,2), offset=0
            lf_global_writer.write(2, 2)?;
            lf_global_writer.write(4, 0)?;
            write_rct_transform(&mut lf_global_writer, 0, RctType::YCOCG)?;
            write_squeeze_transform(&mut lf_global_writer, &squeeze_params)?;
        } else {
            lf_global_writer.write(2, 1)?; // nb_transforms = 1
            write_squeeze_transform(&mut lf_global_writer, &squeeze_params)?;
        }

        // Write global channel tokens
        write_tokens_ans(
            &global_tokens,
            &code,
            lz77_params.as_ref(),
            &mut lf_global_writer,
        )?;

        lf_global_writer.zero_pad_to_byte();
        let lf_global_data = lf_global_writer.finish();

        crate::trace::debug_eprintln!(
            "SQUEEZE_TREE_MULTI: LfGlobal = {} bytes ({} global channels, {} contexts)",
            lf_global_data.len(),
            global_cutoff,
            num_contexts,
        );

        // Step 8: Write LfGroup sections (empty token lists become empty sections)
        let mut lf_group_data: Vec<Vec<u8>> = Vec::with_capacity(num_lf_groups);
        for lg_tokens in &lf_group_tokens {
            let mut lg_writer = BitWriter::new();

            if lg_tokens.is_empty() {
                lf_group_data.push(lg_writer.finish());
                continue;
            }

            // GroupHeader
            lg_writer.write(1, 1)?; // use_global_tree = true
            super::encode::write_wp_header(&mut lg_writer, &wp_params)?;
            lg_writer.write(2, 0)?; // nb_transforms = 0

            write_tokens_ans(lg_tokens, &code, lz77_params.as_ref(), &mut lg_writer)?;

            lg_writer.zero_pad_to_byte();
            let data = lg_writer.finish();
            crate::trace::debug_eprintln!(
                "SQUEEZE_TREE_MULTI: LfGroup = {} bytes ({} tokens)",
                data.len(),
                lg_tokens.len(),
            );
            lf_group_data.push(data);
        }

        // Step 9: HfGlobal is empty for modular
        let hf_global_data: Vec<u8> = Vec::new();

        // Step 10: Write PassGroup sections — independent per group, so parallelizable
        let pass_group_data: Vec<Vec<u8>> =
            crate::parallel::parallel_map_result(num_groups, |g| {
                let pg_tokens = &pass_group_tokens[g];
                let mut pg_writer = BitWriter::new();

                if pg_tokens.is_empty() {
                    return Ok(pg_writer.finish());
                }

                // GroupHeader
                pg_writer.write(1, 1)?; // use_global_tree = true
                super::encode::write_wp_header(&mut pg_writer, &wp_params)?;
                pg_writer.write(2, 0)?; // nb_transforms = 0

                write_tokens_ans(pg_tokens, &code, lz77_params.as_ref(), &mut pg_writer)?;

                pg_writer.zero_pad_to_byte();
                let data = pg_writer.finish();
                crate::trace::debug_eprintln!(
                    "SQUEEZE_TREE_MULTI: PassGroup = {} bytes ({} tokens)",
                    data.len(),
                    pg_tokens.len(),
                );
                Ok(data)
            })?;

        // Step 11: Assemble TOC and sections.
        // Section order: LfGlobal, LfGroups, HfGlobal, PassGroups.
        let mut section_sizes = Vec::with_capacity(2 + num_lf_groups + num_groups);
        section_sizes.push(lf_global_data.len());
        for data in &lf_group_data {
            section_sizes.push(data.len());
        }
        section_sizes.push(hf_global_data.len());
        for data in &pass_group_data {
            section_sizes.push(data.len());
        }

        self.write_toc_multi(writer, &section_sizes)?;

        // Write all section data in the same order as the TOC entries
        for byte in lf_global_data {
            writer.write_u8(byte)?;
        }
        for data in lf_group_data {
            for byte in data {
                writer.write_u8(byte)?;
            }
        }
        for byte in hf_global_data {
            writer.write_u8(byte)?;
        }
        for data in pass_group_data {
            for byte in data {
                writer.write_u8(byte)?;
            }
        }

        Ok(())
    }
1921
1922    /// Encode modular image body (TOC + sections) without writing a frame header.
1923    ///
1924    /// Caller is responsible for writing the frame header before calling this.
1925    /// This enables encoding reference frames with custom frame headers (e.g.,
1926    /// `FrameType::ReferenceOnly`, `save_before_ct=true`) while getting full
1927    /// FrameEncoder features (RCT, multi-group, histogram optimization, ANS).
1928    pub(crate) fn encode_modular_body(
1929        &self,
1930        image: &ModularImage,
1931        writer: &mut BitWriter,
1932    ) -> Result<()> {
1933        let num_groups = self.num_groups();
1934
1935        if num_groups == 1 {
1936            // Single group: all sections combined into one TOC entry
1937            let mut section_writer = BitWriter::new();
1938
1939            let use_rct = image.channels.len() >= 3 && !self.options.skip_rct;
1940            if self.options.use_tree_learning && self.options.use_ans {
1941                write_modular_stream_with_tree(
1942                    image,
1943                    &mut section_writer,
1944                    &self.options.profile,
1945                    use_rct,
1946                    self.options.enable_lz77,
1947                    self.options.lz77_method,
1948                )?;
1949            } else if use_rct {
1950                super::encode::write_modular_stream_with_rct(
1951                    image,
1952                    &mut section_writer,
1953                    self.options.use_ans,
1954                )?;
1955            } else {
1956                write_improved_modular_stream(image, &mut section_writer, self.options.use_ans)?;
1957            }
1958
1959            let section_data = section_writer.finish();
1960            self.write_toc(writer, section_data.len())?;
1961            for byte in section_data {
1962                writer.write_u8(byte)?;
1963            }
1964        } else {
1965            // Multi-group: use the standard multi-group encoder (no patches in body)
1966            self.encode_modular_multi_group_inner(image, writer, None)?;
1967        }
1968
1969        Ok(())
1970    }
1971
1972    /// Writes the table of contents with a single section.
1973    fn write_toc(&self, writer: &mut BitWriter, section_size: usize) -> Result<()> {
1974        self.write_toc_multi(writer, &[section_size])
1975    }
1976
    /// Writes the table of contents with multiple sections.
    ///
    /// Layout: `permuted` flag (always 0 here), byte alignment, one u2S-coded
    /// size per section, then byte alignment again so section data starts on a
    /// byte boundary.
    fn write_toc_multi(&self, writer: &mut BitWriter, section_sizes: &[usize]) -> Result<()> {
        crate::trace::debug_eprintln!("TOC [bit {}]: Writing permuted = 0", writer.bits_written());
        // permuted = false
        writer.write(1, 0)?;

        crate::trace::debug_eprintln!(
            "TOC [bit {}]: After permuted, byte aligning",
            writer.bits_written()
        );
        // Byte align before TOC entries (permutation reads, then aligns)
        writer.zero_pad_to_byte();

        // Write TOC entries using u2S(Bits(10), Bits(14)+1024, Bits(22)+17408, Bits(30)+4211712)
        // `_i` is only consumed by the (possibly compiled-out) debug macro.
        #[allow(clippy::unused_enumerate_index)]
        for (_i, &size) in section_sizes.iter().enumerate() {
            crate::trace::debug_eprintln!(
                "TOC [bit {}]: Writing entry {} size={}",
                writer.bits_written(),
                _i,
                size
            );
            self.write_toc_entry(writer, size as u32)?;
        }
        crate::trace::debug_eprintln!("TOC [bit {}]: After TOC entries", writer.bits_written());

        // Byte align after TOC entries
        writer.zero_pad_to_byte();

        Ok(())
    }
2008
2009    /// Writes a single TOC entry.
2010    fn write_toc_entry(&self, writer: &mut BitWriter, size: u32) -> Result<()> {
2011        // u2S(Bits(10), Bits(14)+1024, Bits(22)+17408, Bits(30)+4211712)
2012        if size < 1024 {
2013            writer.write(2, 0)?; // selector 0
2014            writer.write(10, size as u64)?;
2015        } else if size < 17408 {
2016            writer.write(2, 1)?; // selector 1
2017            writer.write(14, (size - 1024) as u64)?;
2018        } else if size < 4211712 {
2019            writer.write(2, 2)?; // selector 2
2020            writer.write(22, (size - 17408) as u64)?;
2021        } else {
2022            writer.write(2, 3)?; // selector 3
2023            writer.write(30, (size - 4211712) as u64)?;
2024        }
2025        Ok(())
2026    }
2027
2028    /// Returns the number of groups in this frame.
2029    pub fn num_groups(&self) -> usize {
2030        let num_groups_x = self.width.div_ceil(GROUP_DIM);
2031        let num_groups_y = self.height.div_ceil(GROUP_DIM);
2032        num_groups_x * num_groups_y
2033    }
2034
2035    /// Returns the number of groups in X direction.
2036    pub fn num_groups_x(&self) -> usize {
2037        self.width.div_ceil(GROUP_DIM)
2038    }
2039
2040    /// Returns the number of groups in Y direction.
2041    pub fn num_groups_y(&self) -> usize {
2042        self.height.div_ceil(GROUP_DIM)
2043    }
2044
2045    /// Returns the number of LF groups (DC groups).
2046    /// LF groups are 8x the size of regular groups (2048x2048 pixels).
2047    pub fn num_lf_groups(&self) -> usize {
2048        let lf_group_dim = GROUP_DIM * 8; // 2048
2049        let lf_groups_x = self.width.div_ceil(lf_group_dim);
2050        let lf_groups_y = self.height.div_ceil(lf_group_dim);
2051        lf_groups_x * lf_groups_y
2052    }
2053
2054    /// Returns the number of TOC entries for this frame.
2055    /// Single group: 1 entry
2056    /// Multi-group: 2 + num_lf_groups + num_groups * num_passes
2057    pub fn num_toc_entries(&self, num_passes: usize) -> usize {
2058        let num_groups = self.num_groups();
2059        if num_groups == 1 && num_passes == 1 {
2060            1
2061        } else {
2062            2 + self.num_lf_groups() + num_groups * num_passes
2063        }
2064    }
2065
2066    /// Get the pixel bounds for a group.
2067    /// Returns (x_start, y_start, x_end, y_end).
2068    pub fn group_bounds(&self, group_idx: usize) -> (usize, usize, usize, usize) {
2069        let num_groups_x = self.num_groups_x();
2070        let gx = group_idx % num_groups_x;
2071        let gy = group_idx / num_groups_x;
2072
2073        let x_start = gx * GROUP_DIM;
2074        let y_start = gy * GROUP_DIM;
2075        let x_end = (x_start + GROUP_DIM).min(self.width);
2076        let y_end = (y_start + GROUP_DIM).min(self.height);
2077
2078        (x_start, y_start, x_end, y_end)
2079    }
2080}
2081
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_frame_encoder_creation() {
        // A 256x256 image fits in exactly one group.
        let enc = FrameEncoder::new(256, 256, FrameEncoderOptions::default());
        assert_eq!(enc.num_groups(), 1);
    }

    #[test]
    fn test_frame_encoder_multi_group() {
        let enc = FrameEncoder::new(512, 512, FrameEncoderOptions::default());
        assert_eq!(enc.num_groups_x(), 2);
        assert_eq!(enc.num_groups_y(), 2);
        assert_eq!(enc.num_groups(), 4); // 2x2 grid of groups
        assert_eq!(enc.num_lf_groups(), 1); // 512 < 2048, one LF group
    }

    #[test]
    fn test_group_bounds() {
        let enc = FrameEncoder::new(512, 512, FrameEncoderOptions::default());
        // Row-major order: top-left, top-right, bottom-left, bottom-right.
        let expected = [
            (0, 0, 256, 256),
            (256, 0, 512, 256),
            (0, 256, 256, 512),
            (256, 256, 512, 512),
        ];
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(enc.group_bounds(idx), *want);
        }
    }

    #[test]
    fn test_group_bounds_partial() {
        // 300x200 image: 2x1 groups; the right group is clamped to image bounds.
        let enc = FrameEncoder::new(300, 200, FrameEncoderOptions::default());
        assert_eq!(enc.num_groups(), 2);
        assert_eq!(enc.group_bounds(0), (0, 0, 256, 200));
        assert_eq!(enc.group_bounds(1), (256, 0, 300, 200));
    }

    #[test]
    fn test_num_toc_entries() {
        // Single group, single pass: one combined entry.
        let enc = FrameEncoder::new(256, 256, FrameEncoderOptions::default());
        assert_eq!(enc.num_toc_entries(1), 1);

        // 4 groups: LfGlobal + HfGlobal (2) + 1 LF group + groups * passes.
        let enc = FrameEncoder::new(512, 512, FrameEncoderOptions::default());
        assert_eq!(enc.num_toc_entries(1), 7); // 2 + 1 + 4
        assert_eq!(enc.num_toc_entries(2), 11); // 2 + 1 + 8
    }

    #[test]
    fn test_encode_multi_group_image() {
        // 300x300 RGB gradient — spans a 2x2 group grid and compresses well.
        let data: Vec<u8> = (0..300usize)
            .flat_map(|y| {
                (0..300usize).flat_map(move |x| {
                    [
                        ((x + y) % 256) as u8, // R
                        ((x * 2) % 256) as u8, // G
                        ((y * 2) % 256) as u8, // B
                    ]
                })
            })
            .collect();

        let image = ModularImage::from_rgb8(&data, 300, 300).unwrap();
        let encoder = FrameEncoder::new(300, 300, FrameEncoderOptions::default());
        assert_eq!(encoder.num_groups(), 4); // 2x2 groups

        let mut writer = BitWriter::new();
        encoder
            .encode_modular(&image, &ColorEncoding::srgb(), &mut writer)
            .unwrap();

        let bytes = writer.finish_with_padding();
        crate::trace::debug_eprintln!("Multi-group modular: {} bytes", bytes.len());
        assert!(!bytes.is_empty());
        assert!(bytes.len() > 100); // non-trivial payload
        assert!(bytes.len() < 300 * 300 * 3); // smaller than raw RGB
    }

    #[test]
    fn test_encode_small_image() {
        // 4x4 RGB checkerboard of two gray levels — only a few unique symbols.
        let data: Vec<u8> = (0..4usize)
            .flat_map(|y| {
                (0..4usize).flat_map(move |x| {
                    let v = if (x + y) % 2 == 0 { 0u8 } else { 128u8 };
                    [v, v, v] // R, G, B
                })
            })
            .collect();

        let image = ModularImage::from_rgb8(&data, 4, 4).unwrap();
        let encoder = FrameEncoder::new(4, 4, FrameEncoderOptions::default());

        let mut writer = BitWriter::new();
        encoder
            .encode_modular(&image, &ColorEncoding::srgb(), &mut writer)
            .unwrap();

        assert!(!writer.finish_with_padding().is_empty());
    }
}