Skip to main content

jxl_encoder/modular/
frame.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! Frame encoder - assembles complete JXL frames.
6
7use super::channel::ModularImage;
8use super::encode::{
9    build_histogram_from_residuals, collect_all_residuals, write_global_modular_section,
10    write_group_modular_section, write_improved_modular_stream, write_modular_stream_with_tree,
11};
12use super::section::write_global_modular_section_with_tree;
13use crate::GROUP_DIM;
14use crate::bit_writer::BitWriter;
15use crate::error::Result;
16use crate::headers::ColorEncoding;
17use crate::headers::frame_header::{BlendMode, FrameCrop, FrameHeader};
18
/// Options for frame encoding.
///
/// Consumed by `FrameEncoder`; `Default` yields a lossless, non-animated,
/// full-canvas final frame.
#[derive(Debug, Clone)]
pub struct FrameEncoderOptions {
    /// Use modular mode (lossless).
    ///
    /// NOTE(review): not consulted by `FrameEncoder` in this file.
    pub use_modular: bool,
    /// Effort level (1-10, higher = better compression, slower).
    ///
    /// NOTE(review): not consulted by `FrameEncoder` in this file.
    pub effort: u8,
    /// Use ANS entropy coding instead of Huffman for modular.
    pub use_ans: bool,
    /// Use content-adaptive MA tree learning for modular encoding.
    ///
    /// `FrameEncoder` only takes the tree-learning path when `use_ans` is
    /// also set; otherwise the standard path is used.
    pub use_tree_learning: bool,
    /// Use squeeze (Haar wavelet) transform for modular encoding.
    ///
    /// Checked before `use_tree_learning` on the single-group path. The
    /// multi-group path does not consult this flag.
    pub use_squeeze: bool,
    /// Whether this frame is part of an animation (enables duration field in header).
    ///
    /// Non-last animation frames are additionally saved to reference slot 1.
    pub have_animation: bool,
    /// Duration of this frame in ticks (only used when have_animation is true).
    pub duration: u32,
    /// Whether this is the last frame in the image/animation.
    pub is_last: bool,
    /// Optional crop rectangle for this frame (None = full frame).
    ///
    /// When set, the frame header takes its origin and size from the crop and
    /// blends with `BlendMode::Replace` against blend source 1.
    pub crop: Option<FrameCrop>,
}
41
42impl Default for FrameEncoderOptions {
43    fn default() -> Self {
44        Self {
45            use_modular: true, // Default to lossless
46            effort: 7,
47            use_ans: false,
48            use_tree_learning: false,
49            use_squeeze: false,
50            have_animation: false,
51            duration: 0,
52            is_last: true,
53            crop: None,
54        }
55    }
56}
57
58/// Encodes a single frame.
59pub struct FrameEncoder {
60    /// Encoding options.
61    #[allow(dead_code)]
62    options: FrameEncoderOptions,
63    /// Image width.
64    width: usize,
65    /// Image height.
66    height: usize,
67    #[allow(dead_code)]
68    /// Number of extra channels (e.g., 1 for alpha).
69    num_extra_channels: usize,
70}
71
72impl FrameEncoder {
73    /// Creates a new frame encoder.
74    pub fn new(width: usize, height: usize, options: FrameEncoderOptions) -> Self {
75        Self {
76            options,
77            width,
78            height,
79            num_extra_channels: 0,
80        }
81    }
82
83    /// Creates a new frame encoder with extra channel support.
84    pub fn new_with_extra_channels(
85        width: usize,
86        height: usize,
87        options: FrameEncoderOptions,
88        num_extra_channels: usize,
89    ) -> Self {
90        Self {
91            options,
92            width,
93            height,
94            num_extra_channels,
95        }
96    }
97
98    /// Encodes a modular image into a frame.
99    pub fn encode_modular(
100        &self,
101        image: &ModularImage,
102        _color_encoding: &ColorEncoding,
103        writer: &mut BitWriter,
104    ) -> Result<()> {
105        // Compute num_extra_channels from image
106        let num_extra_channels = if image.has_alpha { 1 } else { 0 };
107
108        // Write frame header using unified FrameHeader
109        {
110            let mut fh = FrameHeader::lossless();
111            fh.ec_upsampling = vec![1; num_extra_channels];
112            fh.ec_blend_modes = vec![BlendMode::Replace; num_extra_channels];
113            fh.have_animation = self.options.have_animation;
114            fh.duration = self.options.duration;
115            fh.is_last = self.options.is_last;
116            if let Some(ref crop) = self.options.crop {
117                fh.x0 = crop.x0;
118                fh.y0 = crop.y0;
119                fh.width = crop.width;
120                fh.height = crop.height;
121                fh.blend_mode = BlendMode::Replace;
122                fh.blend_source = 1;
123            }
124            // For animation, save non-last frames to reference slot 1
125            // so crop frames can composite onto the previous canvas.
126            if self.options.have_animation && !self.options.is_last {
127                fh.save_as_reference = 1;
128            }
129            fh.write(writer)?;
130        }
131
132        let num_groups = self.num_groups();
133
134        if num_groups == 1 {
135            // Single group: all sections combined into one TOC entry
136            let mut section_writer = BitWriter::new();
137
138            if self.options.use_squeeze {
139                super::encode::write_modular_stream_with_squeeze(
140                    image,
141                    &mut section_writer,
142                    self.options.use_ans,
143                )?;
144            } else if self.options.use_tree_learning && self.options.use_ans {
145                write_modular_stream_with_tree(
146                    image,
147                    &mut section_writer,
148                    256,                       // max_nodes
149                    1.0,                       // split_threshold
150                    image.channels.len() >= 3, // RCT for RGB
151                )?;
152            } else {
153                write_improved_modular_stream(image, &mut section_writer, self.options.use_ans)?;
154            }
155
156            let section_data = section_writer.finish();
157            let section_size = section_data.len();
158
159            crate::trace::debug_eprintln!("FRAME_ENCODER: section_size = {} bytes", section_size);
160
161            // Write TOC
162            self.write_toc(writer, section_size)?;
163
164            // Append section data (already byte-aligned)
165            for byte in section_data {
166                writer.write_u8(byte)?;
167            }
168        } else {
169            // Multi-group: separate TOC entries for global and each group
170            self.encode_modular_multi_group(image, writer)?;
171        }
172
173        Ok(())
174    }
175
176    /// Encodes a modular image using multi-group format (>256x256 images).
177    ///
178    /// For multi-group frames, the JXL spec requires this TOC structure:
179    /// - Section 0: LfGlobal (dc_quant + tree + histograms)
180    /// - Section 1: HfGlobal (empty for modular encoding)
181    /// - Section 2..2+num_lf_groups: LfGroup (empty for modular encoding)
182    /// - Section 2+num_lf_groups..: PassGroup (GroupHeader + pixel data per 256x256 region)
183    fn encode_modular_multi_group(
184        &self,
185        image: &ModularImage,
186        writer: &mut BitWriter,
187    ) -> Result<()> {
188        let num_groups = self.num_groups();
189        let num_lf_groups = self.num_lf_groups();
190        let num_passes = 1;
191
192        crate::trace::debug_eprintln!(
193            "MULTI_GROUP: Encoding {}x{} image with {} groups, {} lf_groups",
194            self.width,
195            self.height,
196            num_groups,
197            num_lf_groups
198        );
199
200        // Step 1: Extract each group image
201        let mut group_images: Vec<ModularImage> = Vec::with_capacity(num_groups);
202        for group_idx in 0..num_groups {
203            let (x_start, y_start, x_end, y_end) = self.group_bounds(group_idx);
204            let group_image = image.extract_region(x_start, y_start, x_end, y_end)?;
205            group_images.push(group_image);
206        }
207
208        // Step 2: Write LfGlobal section (tree + histogram)
209        let mut lf_global_writer = BitWriter::new();
210        let global_state = if self.options.use_tree_learning && self.options.use_ans {
211            // Tree learning path: gather samples, learn tree, build multi-context ANS
212            write_global_modular_section_with_tree(
213                &group_images,
214                &mut lf_global_writer,
215                256, // max_nodes
216                1.0, // split_threshold
217            )?
218        } else {
219            // Standard path: collect residuals with gradient predictor
220            let mut all_residuals = Vec::new();
221            let mut max_residual: u32 = 0;
222            for group_image in &group_images {
223                let (group_residuals, group_max) = collect_all_residuals(group_image);
224                all_residuals.extend(group_residuals);
225                max_residual = max_residual.max(group_max);
226            }
227            let (histogram, max_token) =
228                build_histogram_from_residuals(&all_residuals, max_residual);
229
230            crate::trace::debug_eprintln!(
231                "MULTI_GROUP: {} total residuals, max_raw={}, max_token={}, {} unique tokens",
232                all_residuals.len(),
233                max_residual,
234                max_token,
235                histogram.iter().filter(|&&c| c > 0).count()
236            );
237
238            write_global_modular_section(
239                &all_residuals,
240                &histogram,
241                max_token,
242                &mut lf_global_writer,
243                self.options.use_ans,
244            )?
245        };
246        let lf_global_data = lf_global_writer.finish();
247
248        crate::trace::debug_eprintln!(
249            "MULTI_GROUP: LfGlobal section = {} bytes",
250            lf_global_data.len()
251        );
252
253        // Step 3: HfGlobal is empty for modular encoding (0 bytes)
254        let hf_global_data: Vec<u8> = Vec::new();
255        crate::trace::debug_eprintln!(
256            "MULTI_GROUP: HfGlobal section = 0 bytes (empty for modular)"
257        );
258
259        // Step 4: LfGroup sections are empty for modular encoding
260        let lf_group_data: Vec<Vec<u8>> = (0..num_lf_groups).map(|_| Vec::new()).collect();
261        crate::trace::debug_eprintln!(
262            "MULTI_GROUP: {} LfGroup sections = 0 bytes each (empty for modular)",
263            num_lf_groups
264        );
265
266        // Step 5: Write each PassGroup's data (GroupHeader + pixel data)
267        // Use the pre-extracted group_images to ensure residual consistency
268        let mut pass_group_data: Vec<Vec<u8>> = Vec::with_capacity(num_groups * num_passes);
269        for (group_idx, group_image) in group_images.iter().enumerate() {
270            for _pass in 0..num_passes {
271                let (_x_start, _y_start, _x_end, _y_end) = self.group_bounds(group_idx);
272
273                crate::trace::debug_eprintln!(
274                    "MULTI_GROUP: Group {} bounds ({}, {}) - ({}, {}), size {}x{}",
275                    group_idx,
276                    _x_start,
277                    _y_start,
278                    _x_end,
279                    _y_end,
280                    group_image.width(),
281                    group_image.height()
282                );
283
284                let mut group_writer = BitWriter::new();
285                write_group_modular_section(group_image, &global_state, &mut group_writer)?;
286                pass_group_data.push(group_writer.finish());
287
288                crate::trace::debug_eprintln!(
289                    "MULTI_GROUP: PassGroup {} section = {} bytes",
290                    group_idx,
291                    pass_group_data.last().unwrap().len()
292                );
293            }
294        }
295
296        // Step 6: Collect all section sizes in correct order and write TOC
297        // JXL spec order: LfGlobal, LfGroup[0..num_lf_groups], HfGlobal, PassGroup[0..num_groups*num_passes]
298        // Note: LfGroup comes BEFORE HfGlobal!
299        let mut section_sizes = Vec::with_capacity(2 + num_lf_groups + num_groups * num_passes);
300        section_sizes.push(lf_global_data.len());
301        for data in &lf_group_data {
302            section_sizes.push(data.len());
303        }
304        section_sizes.push(hf_global_data.len());
305        for data in &pass_group_data {
306            section_sizes.push(data.len());
307        }
308
309        crate::trace::debug_eprintln!(
310            "MULTI_GROUP: {} total sections, sizes = {:?}",
311            section_sizes.len(),
312            section_sizes
313        );
314
315        self.write_toc_multi(writer, &section_sizes)?;
316
317        // Step 7: Append all section data in same order
318        for byte in lf_global_data {
319            writer.write_u8(byte)?;
320        }
321        for data in lf_group_data {
322            for byte in data {
323                writer.write_u8(byte)?;
324            }
325        }
326        for byte in hf_global_data {
327            writer.write_u8(byte)?;
328        }
329        for data in pass_group_data {
330            for byte in data {
331                writer.write_u8(byte)?;
332            }
333        }
334
335        Ok(())
336    }
337
338    /// Writes the table of contents with a single section.
339    fn write_toc(&self, writer: &mut BitWriter, section_size: usize) -> Result<()> {
340        self.write_toc_multi(writer, &[section_size])
341    }
342
343    /// Writes the table of contents with multiple sections.
344    fn write_toc_multi(&self, writer: &mut BitWriter, section_sizes: &[usize]) -> Result<()> {
345        crate::trace::debug_eprintln!("TOC [bit {}]: Writing permuted = 0", writer.bits_written());
346        // permuted = false
347        writer.write(1, 0)?;
348
349        crate::trace::debug_eprintln!(
350            "TOC [bit {}]: After permuted, byte aligning",
351            writer.bits_written()
352        );
353        // Byte align before TOC entries (permutation reads, then aligns)
354        writer.zero_pad_to_byte();
355
356        // Write TOC entries using u2S(Bits(10), Bits(14)+1024, Bits(22)+17408, Bits(30)+4211712)
357        #[allow(clippy::unused_enumerate_index)]
358        for (_i, &size) in section_sizes.iter().enumerate() {
359            crate::trace::debug_eprintln!(
360                "TOC [bit {}]: Writing entry {} size={}",
361                writer.bits_written(),
362                _i,
363                size
364            );
365            self.write_toc_entry(writer, size as u32)?;
366        }
367        crate::trace::debug_eprintln!("TOC [bit {}]: After TOC entries", writer.bits_written());
368
369        // Byte align after TOC entries
370        writer.zero_pad_to_byte();
371
372        Ok(())
373    }
374
375    /// Writes a single TOC entry.
376    fn write_toc_entry(&self, writer: &mut BitWriter, size: u32) -> Result<()> {
377        // u2S(Bits(10), Bits(14)+1024, Bits(22)+17408, Bits(30)+4211712)
378        if size < 1024 {
379            writer.write(2, 0)?; // selector 0
380            writer.write(10, size as u64)?;
381        } else if size < 17408 {
382            writer.write(2, 1)?; // selector 1
383            writer.write(14, (size - 1024) as u64)?;
384        } else if size < 4211712 {
385            writer.write(2, 2)?; // selector 2
386            writer.write(22, (size - 17408) as u64)?;
387        } else {
388            writer.write(2, 3)?; // selector 3
389            writer.write(30, (size - 4211712) as u64)?;
390        }
391        Ok(())
392    }
393
394    /// Returns the number of groups in this frame.
395    pub fn num_groups(&self) -> usize {
396        let num_groups_x = self.width.div_ceil(GROUP_DIM);
397        let num_groups_y = self.height.div_ceil(GROUP_DIM);
398        num_groups_x * num_groups_y
399    }
400
401    /// Returns the number of groups in X direction.
402    pub fn num_groups_x(&self) -> usize {
403        self.width.div_ceil(GROUP_DIM)
404    }
405
406    /// Returns the number of groups in Y direction.
407    pub fn num_groups_y(&self) -> usize {
408        self.height.div_ceil(GROUP_DIM)
409    }
410
411    /// Returns the number of LF groups (DC groups).
412    /// LF groups are 8x the size of regular groups (2048x2048 pixels).
413    pub fn num_lf_groups(&self) -> usize {
414        let lf_group_dim = GROUP_DIM * 8; // 2048
415        let lf_groups_x = self.width.div_ceil(lf_group_dim);
416        let lf_groups_y = self.height.div_ceil(lf_group_dim);
417        lf_groups_x * lf_groups_y
418    }
419
420    /// Returns the number of TOC entries for this frame.
421    /// Single group: 1 entry
422    /// Multi-group: 2 + num_lf_groups + num_groups * num_passes
423    pub fn num_toc_entries(&self, num_passes: usize) -> usize {
424        let num_groups = self.num_groups();
425        if num_groups == 1 && num_passes == 1 {
426            1
427        } else {
428            2 + self.num_lf_groups() + num_groups * num_passes
429        }
430    }
431
432    /// Get the pixel bounds for a group.
433    /// Returns (x_start, y_start, x_end, y_end).
434    pub fn group_bounds(&self, group_idx: usize) -> (usize, usize, usize, usize) {
435        let num_groups_x = self.num_groups_x();
436        let gx = group_idx % num_groups_x;
437        let gy = group_idx / num_groups_x;
438
439        let x_start = gx * GROUP_DIM;
440        let y_start = gy * GROUP_DIM;
441        let x_end = (x_start + GROUP_DIM).min(self.width);
442        let y_end = (y_start + GROUP_DIM).min(self.height);
443
444        (x_start, y_start, x_end, y_end)
445    }
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an encoder with default options for a frame of the given size.
    fn default_encoder(width: usize, height: usize) -> FrameEncoder {
        FrameEncoder::new(width, height, FrameEncoderOptions::default())
    }

    /// A frame of exactly one group dimension per side is a single group.
    #[test]
    fn test_frame_encoder_creation() {
        assert_eq!(default_encoder(256, 256).num_groups(), 1);
    }

    /// 512x512 splits into a 2x2 grid of groups inside a single LF group.
    #[test]
    fn test_frame_encoder_multi_group() {
        let encoder = default_encoder(512, 512);
        assert_eq!(encoder.num_groups_x(), 2);
        assert_eq!(encoder.num_groups_y(), 2);
        assert_eq!(encoder.num_groups(), 4); // 2x2 groups
        assert_eq!(encoder.num_lf_groups(), 1); // 512 < 2048
    }

    /// Groups are laid out row by row: top-left, top-right, bottom-left,
    /// bottom-right.
    #[test]
    fn test_group_bounds() {
        let encoder = default_encoder(512, 512);

        let expected = [
            (0, 0, 256, 256),     // Group 0: top-left
            (256, 0, 512, 256),   // Group 1: top-right
            (0, 256, 256, 512),   // Group 2: bottom-left
            (256, 256, 512, 512), // Group 3: bottom-right
        ];

        for (group_idx, bounds) in expected.into_iter().enumerate() {
            assert_eq!(encoder.group_bounds(group_idx), bounds);
        }
    }

    /// Edge groups are clamped to the frame size.
    #[test]
    fn test_group_bounds_partial() {
        // 300x200 image: 2x1 groups, second group is partial
        let encoder = default_encoder(300, 200);
        assert_eq!(encoder.num_groups(), 2); // 2x1

        assert_eq!(encoder.group_bounds(0), (0, 0, 256, 200));
        assert_eq!(encoder.group_bounds(1), (256, 0, 300, 200)); // Clamped to image bounds
    }

    /// TOC entry counts for the single-group and multi-group layouts.
    #[test]
    fn test_num_toc_entries() {
        // Single group, single pass
        let single = default_encoder(256, 256);
        assert_eq!(single.num_toc_entries(1), 1);

        let quad = default_encoder(512, 512);
        // 4 groups, single pass: 2 + 1 + 4 = 7
        assert_eq!(quad.num_toc_entries(1), 7);
        // 4 groups, 2 passes: 2 + 1 + 8 = 11
        assert_eq!(quad.num_toc_entries(2), 11);
    }

    /// A 300x300 RGB gradient goes through the multi-group path and comes out
    /// smaller than the raw pixel data.
    #[test]
    fn test_encode_multi_group_image() {
        // 300x300 RGB image - requires 2x2 = 4 groups.
        // Smooth gradient for good compression.
        let data: Vec<u8> = (0..300u32)
            .flat_map(|y| {
                (0..300u32).flat_map(move |x| {
                    [
                        ((x + y) % 256) as u8, // R
                        ((x * 2) % 256) as u8, // G
                        ((y * 2) % 256) as u8, // B
                    ]
                })
            })
            .collect();

        let image = ModularImage::from_rgb8(&data, 300, 300).unwrap();

        let encoder = default_encoder(300, 300);
        assert_eq!(encoder.num_groups(), 4); // 2x2 groups

        let color_encoding = ColorEncoding::srgb();
        let mut writer = BitWriter::new();

        encoder
            .encode_modular(&image, &color_encoding, &mut writer)
            .unwrap();

        let bytes = writer.finish_with_padding();
        crate::trace::debug_eprintln!("Multi-group modular: {} bytes", bytes.len());
        assert!(!bytes.is_empty());
        // Should have reasonable size (not huge, not tiny)
        assert!(bytes.len() > 100); // Has content
        assert!(bytes.len() < 300 * 300 * 3); // Better than raw
    }

    /// A tiny checkerboard goes through the single-group path and produces output.
    #[test]
    fn test_encode_small_image() {
        // 4x4 RGB image with only 4 unique values (max for simple Huffman)
        // Pattern: checkerboard of two colors, identical in R, G and B.
        let data: Vec<u8> = (0..4u32)
            .flat_map(|y| (0..4u32).map(move |x| if (x + y) % 2 == 0 { 0u8 } else { 128u8 }))
            .flat_map(|v| [v; 3])
            .collect();

        let image = ModularImage::from_rgb8(&data, 4, 4).unwrap();

        let encoder = default_encoder(4, 4);
        let color_encoding = ColorEncoding::srgb();
        let mut writer = BitWriter::new();

        encoder
            .encode_modular(&image, &color_encoding, &mut writer)
            .unwrap();

        let bytes = writer.finish_with_padding();
        assert!(!bytes.is_empty());
    }
}