Skip to main content

jxl_encoder/vardct/
encoder.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! Main tiny encoder implementation.
6
7use super::ac_strategy::{AcStrategyMap, adjust_quant_field_with_distance, compute_ac_strategy};
8use super::adaptive_quant::{compute_mask1x1, compute_quant_field_float, quantize_quant_field};
9use super::chroma_from_luma::{CflMap, compute_cfl_map};
10use super::common::*;
11use super::frame::{DistanceParams, write_toc};
12use super::gaborish::gaborish_inverse;
13use super::noise::{denoise_xyb, estimate_noise_params, noise_quality_coef};
14use super::static_codes::{get_ac_entropy_code, get_dc_entropy_code};
15use crate::bit_writer::BitWriter;
16#[cfg(feature = "debug-tokens")]
17use crate::debug_log;
18use crate::entropy_coding::encode::{
19    OwnedAnsEntropyCode, OwnedEntropyCode, write_entropy_code_ans, write_tokens, write_tokens_ans,
20};
21use crate::entropy_coding::token::Token;
22use crate::error::Result;
23use crate::headers::frame_header::FrameHeader;
24
25/// Create an AC strategy map forcing a specific strategy.
26pub(crate) fn force_strategy_map(
27    xsize_blocks: usize,
28    ysize_blocks: usize,
29    raw_strategy: u8,
30) -> AcStrategyMap {
31    AcStrategyMap::force_strategy(xsize_blocks, ysize_blocks, raw_strategy)
32}
33
/// Entropy code that holds either Huffman or ANS code.
///
/// Unifies the three code-building paths (precomputed static prefix codes,
/// dynamically built prefix codes, and ANS distributions) behind one type so
/// section writers can emit headers and tokens without caring which backend
/// was selected.
pub enum BuiltEntropyCode<'a> {
    /// Static Huffman prefix codes (borrowed from precomputed tables).
    StaticHuffman(crate::entropy_coding::encode::EntropyCode<'a>),
    /// Dynamic Huffman prefix codes (owned; built in two-pass mode).
    Huffman(OwnedEntropyCode),
    /// ANS distributions with context map (owned; two-pass mode only).
    Ans(OwnedAnsEntropyCode),
}
43
44impl<'a> BuiltEntropyCode<'a> {
45    /// Write the entropy code header (context map + codes/distributions).
46    pub fn write_header(&self, writer: &mut BitWriter) -> Result<()> {
47        match self {
48            BuiltEntropyCode::StaticHuffman(code) => {
49                crate::entropy_coding::encode::write_entropy_code(code, writer)
50            }
51            BuiltEntropyCode::Huffman(code) => {
52                crate::entropy_coding::encode::write_entropy_code(&code.as_entropy_code(), writer)
53            }
54            BuiltEntropyCode::Ans(code) => write_entropy_code_ans(code, writer),
55        }
56    }
57
58    /// Write tokens using this entropy code.
59    pub fn write_tokens(
60        &self,
61        tokens: &[Token],
62        lz77: Option<&crate::entropy_coding::lz77::Lz77Params>,
63        writer: &mut BitWriter,
64    ) -> Result<()> {
65        match self {
66            BuiltEntropyCode::StaticHuffman(code) => write_tokens(tokens, code, lz77, writer),
67            BuiltEntropyCode::Huffman(code) => {
68                write_tokens(tokens, &code.as_entropy_code(), lz77, writer)
69            }
70            BuiltEntropyCode::Ans(code) => write_tokens_ans(tokens, code, lz77, writer),
71        }
72    }
73
74    /// Get the underlying Huffman code for streaming token writing.
75    ///
76    /// Panics if this is an ANS code (streaming with ANS is not supported).
77    pub fn as_huffman(&self) -> crate::entropy_coding::encode::EntropyCode<'_> {
78        match self {
79            BuiltEntropyCode::StaticHuffman(code) => *code,
80            BuiltEntropyCode::Huffman(code) => code.as_entropy_code(),
81            BuiltEntropyCode::Ans(_) => {
82                panic!("ANS codes cannot be used with streaming encoder")
83            }
84        }
85    }
86
87    #[allow(dead_code)]
88    /// Returns the number of contexts in this entropy code.
89    pub fn num_contexts(&self) -> usize {
90        match self {
91            BuiltEntropyCode::StaticHuffman(code) => code.num_contexts,
92            BuiltEntropyCode::Huffman(code) => code.context_map.len(),
93            BuiltEntropyCode::Ans(code) => code.context_map.len(),
94        }
95    }
96
97    #[allow(dead_code)]
98    /// Returns the number of histograms/prefix codes in this entropy code.
99    pub fn num_histograms(&self) -> usize {
100        match self {
101            BuiltEntropyCode::StaticHuffman(code) => code.num_prefix_codes,
102            BuiltEntropyCode::Huffman(code) => code.prefix_codes.len(),
103            BuiltEntropyCode::Ans(code) => code.histograms.len(),
104        }
105    }
106}
107
/// Output of a VarDCT encode operation.
pub struct VarDctOutput {
    /// Encoded JXL codestream bytes (complete file, starting at the signature).
    pub data: Vec<u8>,
    /// Per-strategy first-block counts, indexed by raw strategy code (0..19).
    /// Populated from the AC strategy histogram; useful for diagnostics.
    pub strategy_counts: [u32; 19],
}
115
/// Tiny JPEG XL encoder.
///
/// This is a simplified VarDCT encoder based on libjxl-tiny that uses:
/// - DCT8, DCT8x16, DCT16x8 transforms (AC strategy selection may also pick
///   larger DCTs such as DCT16x16/DCT32x32 — see `compute_ac_strategy`)
/// - Huffman or ANS entropy coding
/// - Default zig-zag coefficient order
/// - Fixed context tree for DC
pub struct VarDctEncoder {
    /// Target distance (quality). 1.0 = visually lossless; larger is lossier.
    pub distance: f32,
    /// Use dynamic Huffman codes built from actual token frequencies.
    /// When true (default), uses a two-pass mode: collect tokens first, build optimal codes, then write.
    /// When false, uses pre-computed static codes (streaming, single-pass).
    pub optimize_codes: bool,
    /// Use enhanced histogram clustering with pair merge refinement.
    /// Only effective when `optimize_codes` is true.
    ///
    /// Note: The enhanced clustering algorithm was designed for ANS entropy coding
    /// and may not provide benefits (or may slightly increase size) when used with
    /// Huffman coding. This option is experimental.
    pub enhanced_clustering: bool,
    /// Use ANS entropy coding instead of Huffman.
    /// Only effective when `optimize_codes` is true (requires two-pass mode).
    /// ANS typically produces 5-10% smaller files than Huffman.
    pub use_ans: bool,
    /// Enable chroma-from-luma (CfL) optimization.
    /// When true (default), computes per-tile ytox/ytob values via least-squares fitting.
    /// When false, uses ytox=0, ytob=0 (no chroma decorrelation).
    pub cfl_enabled: bool,
    /// Enable adaptive AC strategy selection (DCT8/DCT16x8/DCT8x16).
    /// When true (default), selects the best transform size per 16x16 block region.
    /// When false, uses DCT8 for all blocks.
    pub ac_strategy_enabled: bool,
    /// Enable custom coefficient ordering.
    /// When true (default when optimize_codes is true), reorders AC coefficients
    /// so frequently-zero positions appear last, reducing bitstream size.
    /// Only effective when `optimize_codes` is true (requires two-pass mode).
    pub custom_orders: bool,
    /// Force a specific AC strategy for all blocks (for testing).
    /// When Some(strategy), uses that raw strategy code for all blocks that fit.
    /// None (default) uses normal strategy selection based on `ac_strategy_enabled`.
    pub force_strategy: Option<u8>,
    /// Enable noise synthesis.
    /// When true, estimates noise parameters from the image and encodes them
    /// in the frame header. The decoder regenerates noise during rendering.
    /// Off by default (matching libjxl's default).
    pub enable_noise: bool,
    /// Enable Wiener denoising pre-filter (requires `enable_noise`).
    /// When true, applies a conservative Wiener filter to remove estimated noise
    /// before encoding. The decoder re-adds noise from the encoded parameters.
    /// Provides 1-8% file size savings with near-zero Butteraugli quality impact.
    /// Off by default (libjxl does not have a denoising pre-filter).
    pub enable_denoise: bool,
    /// Enable gaborish inverse pre-filter.
    /// When true (default), applies a 5x5 sharpening kernel to XYB before DCT
    /// and signals gab=1 in the frame header. The decoder applies a 3x3 blur
    /// to compensate, reducing blocking artifacts.
    /// Matches the libjxl VarDCT encoder default.
    pub enable_gaborish: bool,
    /// Enable error diffusion in AC quantization.
    /// When true, spreads quantization error to neighboring coefficients in
    /// zigzag order, helping preserve smooth gradients at high compression.
    /// On by default in the provided constructors (see `Default`/`new`).
    pub error_diffusion: bool,
    /// Enable pixel-domain loss calculation in AC strategy selection.
    /// When true, uses full libjxl's pixel-domain loss model (IDCT error,
    /// per-pixel masking, 8th power norm). This provides better distance
    /// calibration matching cjxl's output.
    /// When false, uses coefficient-domain loss (libjxl-tiny style).
    /// Note: Requires `ac_strategy_enabled` to have any effect.
    pub pixel_domain_loss: bool,
    /// Enable LZ77 backward references in entropy coding.
    /// When true, compresses token streams using LZ77 length+distance tokens.
    /// Only effective with two-pass mode (optimize_codes=true) and ANS (use_ans=true).
    /// Off by default — works for most cases but has known interactions with certain
    /// forced strategy combinations (DCT2x2, IDENTITY) that cause InvalidAnsStream.
    pub enable_lz77: bool,
    /// LZ77 method to use when enable_lz77 is true.
    ///
    /// - `Rle`: Only matches consecutive identical values (fast, limited on photos)
    /// - `Greedy`: Hash chain backward references (slower, 1-3% better on photos)
    ///
    /// Default: `Greedy` (best compression)
    pub lz77_method: crate::entropy_coding::lz77::Lz77Method,
    /// Enable DC tree learning.
    /// When true, learns an optimal context tree for DC coding from image content
    /// instead of using the fixed GRADIENT_CONTEXT_LUT.
    /// **DISABLED/BROKEN**: The learned tree doesn't correctly route AC metadata
    /// samples to contexts 0-10. Fixing requires parsing the static tree structure
    /// and splicing in the learned DC subtree while preserving AC metadata routing.
    /// Expected gain (~1.2% overall) doesn't justify the complexity. See CLAUDE.md.
    pub dc_tree_learning: bool,
    /// Number of butteraugli quantization loop iterations.
    /// When > 0, iteratively refines the per-block quant field using butteraugli
    /// perceptual distance feedback. Each iteration: encode → reconstruct → measure
    /// → adjust quant_field. AC strategy is kept fixed; only quant_field changes.
    ///
    /// libjxl uses 2 iterations at effort 8, 4 at effort 9.
    /// Requires the `butteraugli-loop` feature.
    ///
    /// Default: 0 (disabled)
    #[cfg(feature = "butteraugli-loop")]
    pub butteraugli_iters: u32,
    /// Whether the input has 16-bit samples. When true, the file header signals
    /// bit_depth=16 instead of 8. The actual VarDCT encoding is the same (XYB
    /// is always f32 internally), but the decoder uses this to reconstruct at
    /// the correct output bit depth.
    pub bit_depth_16: bool,
    /// ICC profile to embed in the codestream.
    /// When Some, writes has_icc=1 and encodes the profile after the file header.
    pub icc_profile: Option<Vec<u8>>,
}
228
229impl Default for VarDctEncoder {
230    fn default() -> Self {
231        Self {
232            distance: 1.0,
233            optimize_codes: true,
234            enhanced_clustering: true, // Pair-merge refinement helps ANS (larger header savings)
235            use_ans: true,             // ANS produces 4-10% smaller files than Huffman
236            cfl_enabled: true,
237            ac_strategy_enabled: true,
238            custom_orders: true,
239            force_strategy: None,
240            enable_noise: false,
241            enable_denoise: false,
242            enable_gaborish: true,
243            error_diffusion: true, // libjxl enables at speed_tier <= kSquirrel (effort 7)
244            pixel_domain_loss: true, // Full libjxl pixel-domain loss: +0.2-1.9 SSIM2 at all distances
245            enable_lz77: false,      // LZ77 has known interactions with DCT2x2/IDENTITY strategies
246            lz77_method: crate::entropy_coding::lz77::Lz77Method::Greedy, // Best compression
247            dc_tree_learning: false, // DC tree learning (experimental)
248            #[cfg(feature = "butteraugli-loop")]
249            butteraugli_iters: 0, // Effort-gated: default off (effort 7). Set via LossyConfig.
250            bit_depth_16: false,
251            icc_profile: None,
252        }
253    }
254}
255
256impl VarDctEncoder {
257    /// Create a new tiny encoder with the given distance.
258    pub fn new(distance: f32) -> Self {
259        Self {
260            distance,
261            optimize_codes: true,
262            enhanced_clustering: true, // Pair-merge refinement helps ANS (larger header savings)
263            use_ans: true,             // ANS produces 4-10% smaller files than Huffman
264            cfl_enabled: true,
265            ac_strategy_enabled: true,
266            custom_orders: true,
267            force_strategy: None,
268            enable_noise: false,
269            enable_denoise: false,
270            enable_gaborish: true,
271            error_diffusion: true, // libjxl enables at speed_tier <= kSquirrel (effort 7)
272            pixel_domain_loss: true, // Full libjxl pixel-domain loss: +0.2-1.9 SSIM2
273            enable_lz77: false,    // LZ77 has known interactions with DCT2x2/IDENTITY strategies
274            lz77_method: crate::entropy_coding::lz77::Lz77Method::Greedy, // Best compression
275            dc_tree_learning: false, // DC tree learning (experimental)
276            #[cfg(feature = "butteraugli-loop")]
277            butteraugli_iters: 0, // Effort-gated: default off (effort 7). Set via LossyConfig.
278            bit_depth_16: false,
279            icc_profile: None,
280        }
281    }
282
283    /// Encode an image in linear sRGB format, optionally with an alpha channel.
284    ///
285    /// Input should be 3 channels (RGB) of f32 values in [0, 1] range.
286    /// Values outside [0, 1] are allowed for out-of-gamut colors.
287    ///
288    /// If `alpha` is provided, it must be `width * height` bytes of u8 alpha values.
289    /// Alpha is encoded as a modular extra channel alongside the VarDCT RGB data.
290    pub fn encode(
291        &self,
292        width: usize,
293        height: usize,
294        linear_rgb: &[f32],
295        alpha: Option<&[u8]>,
296    ) -> Result<VarDctOutput> {
297        assert_eq!(linear_rgb.len(), width * height * 3);
298        if let Some(a) = alpha {
299            assert_eq!(a.len(), width * height);
300        }
301
302        // Calculate dimensions
303        let xsize_blocks = div_ceil(width, BLOCK_DIM);
304        let ysize_blocks = div_ceil(height, BLOCK_DIM);
305        let xsize_groups = div_ceil(width, GROUP_DIM);
306        let ysize_groups = div_ceil(height, GROUP_DIM);
307        let xsize_dc_groups = div_ceil(width, DC_GROUP_DIM);
308        let ysize_dc_groups = div_ceil(height, DC_GROUP_DIM);
309        let num_groups = xsize_groups * ysize_groups;
310        let num_dc_groups = xsize_dc_groups * ysize_dc_groups;
311
312        // Number of sections: DC global + DC groups + AC global + AC groups
313        let num_sections = 2 + num_dc_groups + num_groups;
314
315        // Pad to block boundary dimensions
316        let padded_width = xsize_blocks * BLOCK_DIM;
317        let padded_height = ysize_blocks * BLOCK_DIM;
318
319        // Convert to XYB with edge-replicated padding to block boundaries.
320        // This allows SIMD to process full blocks without bounds checking.
321        let (mut xyb_x, mut xyb_y, mut xyb_b) =
322            self.convert_to_xyb_padded(width, height, padded_width, padded_height, linear_rgb);
323
324        // Estimate noise parameters (if enabled).
325        // The decoder adds noise during rendering; the encoder just encodes the params.
326        let noise_params = if self.enable_noise {
327            let quality_coef = noise_quality_coef(self.distance);
328            let params = estimate_noise_params(
329                &xyb_x,
330                &xyb_y,
331                &xyb_b,
332                padded_width,
333                padded_height,
334                quality_coef,
335            );
336
337            // Apply denoising pre-filter if enabled and noise was detected.
338            // Removes estimated noise before encoding so the encoder spends fewer
339            // bits on noise; the decoder re-adds it from the encoded parameters.
340            if self.enable_denoise
341                && let Some(ref p) = params
342            {
343                denoise_xyb(
344                    &mut xyb_x,
345                    &mut xyb_y,
346                    &mut xyb_b,
347                    padded_width,
348                    padded_height,
349                    p,
350                    quality_coef,
351                );
352            }
353
354            params
355        } else {
356            None
357        };
358
359        // Compute pixel chromacity stats BEFORE gaborish (matching libjxl pipeline).
360        // Gaborish sharpening inflates gradients, producing overly aggressive adjustment.
361        let pixel_stats = super::frame::PixelStatsForChromacityAdjustment::calc(
362            &xyb_x,
363            &xyb_y,
364            &xyb_b,
365            padded_width,
366            padded_height,
367        );
368        let chromacity_x = pixel_stats.how_much_is_x_channel_pixelized();
369        let chromacity_b = pixel_stats.how_much_is_b_channel_pixelized();
370
371        // Apply gaborish inverse (5x5 sharpening) before adaptive quant.
372        // The decoder will apply a 3x3 blur to compensate.
373        if self.enable_gaborish {
374            gaborish_inverse(
375                &mut xyb_x,
376                &mut xyb_y,
377                &mut xyb_b,
378                padded_width,
379                padded_height,
380            );
381        }
382
383        // Compute adaptive per-block quantization field and masking.
384        // Pass padded dimensions: XYB buffers have stride=padded_width, and all
385        // modulation/extraction functions index as [py * stride + px].
386        // When gaborish is off, scale distance by 0.62 for the quant field only
387        // (not global_scale/quant_dc). This matches libjxl enc_heuristics.cc:1119.
388        let distance_for_iqf = if self.enable_gaborish {
389            self.distance
390        } else {
391            self.distance * 0.62
392        };
393
394        // Step 1: Compute float quant field (independent of global_scale)
395        let (quant_field_float, masking) = compute_quant_field_float(
396            &xyb_x,
397            &xyb_y,
398            &xyb_b,
399            padded_width,
400            padded_height,
401            xsize_blocks,
402            ysize_blocks,
403            distance_for_iqf,
404        );
405
406        // Step 2: Compute distance params with content-adaptive global_scale.
407        // Uses median and MAD of the quant field to adapt quantization precision
408        // to image content (matches libjxl ComputeGlobalScaleAndQuant).
409        let mut params =
410            DistanceParams::compute_from_quant_field(self.distance, &quant_field_float);
411
412        // Apply pixel-level chromacity adjustments using pre-gaborish stats
413        params.apply_chromacity_adjustment(chromacity_x, chromacity_b);
414
415        // Step 3: Quantize float quant field to raw u8 with adaptive inv_scale
416        let mut quant_field = quantize_quant_field(&quant_field_float, params.inv_scale);
417
418        // Compute per-tile chroma-from-luma map
419        let cfl_map = if self.cfl_enabled {
420            compute_cfl_map(
421                &xyb_x,
422                &xyb_y,
423                &xyb_b,
424                padded_width,
425                padded_height,
426                xsize_blocks,
427                ysize_blocks,
428            )
429        } else {
430            CflMap::zeros(
431                div_ceil(xsize_blocks, TILE_DIM_IN_BLOCKS),
432                div_ceil(ysize_blocks, TILE_DIM_IN_BLOCKS),
433            )
434        };
435
436        // Compute per-pixel mask for pixel-domain loss (full libjxl cost model)
437        // Only compute if AC strategy selection is enabled
438        let mask1x1 = if self.ac_strategy_enabled && self.pixel_domain_loss {
439            Some(compute_mask1x1(&xyb_y, padded_width, padded_height))
440        } else {
441            None
442        };
443
444        // Compute adaptive AC strategy (DCT8/DCT16x8/DCT8x16/DCT16x16/DCT32x32)
445        let ac_strategy = if let Some(forced) = self.force_strategy {
446            // Force a specific strategy for all blocks that fit
447            force_strategy_map(xsize_blocks, ysize_blocks, forced)
448        } else if !self.ac_strategy_enabled {
449            AcStrategyMap::new_dct8(xsize_blocks, ysize_blocks)
450        } else {
451            compute_ac_strategy(
452                &xyb_x,
453                &xyb_y,
454                &xyb_b,
455                padded_width,
456                padded_height,
457                xsize_blocks,
458                ysize_blocks,
459                self.distance,
460                &quant_field_float,
461                &masking,
462                &cfl_map,
463                mask1x1.as_deref(),
464                padded_width,
465            )
466        };
467
468        // Debug: print strategy histogram if enabled
469        #[cfg(feature = "debug-ac-strategy")]
470        {
471            eprintln!(
472                "AC strategy mode: {}",
473                if mask1x1.is_some() {
474                    "pixel-domain"
475                } else {
476                    "coefficient-domain"
477                }
478            );
479            ac_strategy.print_histogram();
480        }
481
482        // Adjust quant field for multi-block transforms.
483        // At low distances uses max, at high distances blends toward mean for better quality.
484        adjust_quant_field_with_distance(&ac_strategy, &mut quant_field, self.distance);
485
486        // Butteraugli quantization loop: iteratively refine quant_field using
487        // perceptual distance feedback. AC strategy is fixed; only quant_field changes.
488        #[cfg(feature = "butteraugli-loop")]
489        if self.butteraugli_iters > 0 {
490            let initial_quant_field = quant_field.clone();
491            self.butteraugli_refine_quant_field(
492                linear_rgb,
493                width,
494                height,
495                &xyb_x,
496                &xyb_y,
497                &xyb_b,
498                padded_width,
499                padded_height,
500                xsize_blocks,
501                ysize_blocks,
502                &params,
503                &mut quant_field,
504                &initial_quant_field,
505                &cfl_map,
506                &ac_strategy,
507            );
508        }
509
510        // Perform DCT and quantization (XYB data is padded to block boundaries)
511        let transform_out = self.transform_and_quantize(
512            &xyb_x,
513            &xyb_y,
514            &xyb_b,
515            padded_width,
516            xsize_blocks,
517            ysize_blocks,
518            &params,
519            &mut quant_field,
520            &cfl_map,
521            &ac_strategy,
522        );
523        let quant_dc = &transform_out.quant_dc;
524        let quant_ac = &transform_out.quant_ac;
525        let nzeros = &transform_out.nzeros;
526        let raw_nzeros = &transform_out.raw_nzeros;
527
528        // Compute per-block EPF sharpness map when EPF is active
529        let sharpness_map = if params.epf_iters > 0 && self.distance >= 0.5 {
530            let mask = mask1x1.unwrap_or_else(|| {
531                super::adaptive_quant::compute_mask1x1(&xyb_y, padded_width, padded_height)
532            });
533            Some(super::epf::compute_epf_sharpness(
534                [&xyb_x, &xyb_y, &xyb_b],
535                quant_dc,
536                quant_ac,
537                &quant_field,
538                &mask,
539                &params,
540                &cfl_map,
541                &ac_strategy,
542                self.enable_gaborish,
543                xsize_blocks,
544                ysize_blocks,
545            ))
546        } else {
547            None
548        };
549
550        // Two-pass mode: collect tokens, build optimal codes, write bitstream
551        if self.optimize_codes {
552            let strategy_counts = ac_strategy.strategy_histogram();
553            let data = self.encode_two_pass(
554                width,
555                height,
556                &params,
557                xsize_blocks,
558                ysize_blocks,
559                xsize_groups,
560                ysize_groups,
561                xsize_dc_groups,
562                ysize_dc_groups,
563                num_groups,
564                num_dc_groups,
565                num_sections,
566                quant_dc,
567                quant_ac,
568                nzeros,
569                raw_nzeros,
570                &quant_field,
571                &cfl_map,
572                &ac_strategy,
573                &noise_params,
574                sharpness_map.as_deref(),
575                alpha,
576            )?;
577            return Ok(VarDctOutput {
578                data,
579                strategy_counts,
580            });
581        }
582
583        // Get static entropy codes (wrapped in BuiltEntropyCode for uniform handling)
584        let dc_code = BuiltEntropyCode::StaticHuffman(get_dc_entropy_code());
585        let ac_code = BuiltEntropyCode::StaticHuffman(get_ac_entropy_code());
586
587        // Create main writer
588        let mut writer = BitWriter::with_capacity(width * height * 4);
589
590        // Write file header (includes JXL signature, ICC, and byte padding)
591        // Streaming path does not support alpha
592        self.write_file_header_and_pad(width, height, false, &mut writer)?;
593        #[cfg(feature = "debug-tokens")]
594        debug_log!(
595            "After file header: bit {} (byte {})",
596            writer.bits_written(),
597            writer.bits_written() / 8
598        );
599
600        // Write frame header
601        {
602            let mut fh = FrameHeader::lossy();
603            fh.x_qm_scale = params.x_qm_scale;
604            fh.b_qm_scale = params.b_qm_scale;
605            fh.epf_iters = params.epf_iters;
606            fh.gaborish = self.enable_gaborish;
607            if noise_params.is_some() {
608                fh.flags |= 0x01; // ENABLE_NOISE
609            }
610            // streaming path: no extra channels
611            fh.write(&mut writer)?;
612        }
613        #[cfg(feature = "debug-tokens")]
614        debug_log!(
615            "After frame header: bit {} (byte {})",
616            writer.bits_written(),
617            writer.bits_written() / 8
618        );
619
620        // For single-group images, combine all sections at the bit level
621        // (no byte padding between sections, only at the end)
622        if num_sections == 4 {
623            // Write sections to individual BitWriters (no padding)
624            let block_ctx_map = super::ac_context::BlockCtxMap::default();
625            let num_blocks = xsize_blocks * ysize_blocks;
626            let mut dc_global = BitWriter::with_capacity(4096);
627            self.write_dc_global(
628                &params,
629                num_dc_groups,
630                &dc_code,
631                &noise_params,
632                None,
633                &block_ctx_map,
634                None, // No learned tree in single-pass mode
635                &mut dc_global,
636            )?;
637
638            // Get borrowed Huffman codes for streaming token writing
639            let dc_huffman = dc_code.as_huffman();
640            let ac_huffman = ac_code.as_huffman();
641
642            let mut dc_group = BitWriter::with_capacity(num_blocks * 10);
643            self.write_dc_group(
644                0,
645                quant_dc,
646                xsize_blocks,
647                ysize_blocks,
648                xsize_dc_groups,
649                &quant_field,
650                &cfl_map,
651                &ac_strategy,
652                None, // no sharpness map in single-pass mode
653                &dc_huffman,
654                &mut dc_group,
655            )?;
656
657            let mut ac_global = BitWriter::with_capacity(4096);
658            self.write_ac_global(num_groups, &ac_code, 0, None, None, &mut ac_global)?;
659
660            let mut ac_group_writer = BitWriter::with_capacity(num_blocks * 100);
661            self.write_ac_group(
662                0,
663                quant_ac,
664                nzeros,
665                raw_nzeros,
666                xsize_blocks,
667                ysize_blocks,
668                xsize_groups,
669                &quant_field,
670                &ac_strategy,
671                &block_ctx_map,
672                &ac_huffman,
673                &mut ac_group_writer,
674            )?;
675
676            #[cfg(feature = "debug-tokens")]
677            {
678                debug_log!(
679                    "Section bit counts: DC_global={}, DC_group={}, AC_global={}, AC_group={}",
680                    dc_global.bits_written(),
681                    dc_group.bits_written(),
682                    ac_global.bits_written(),
683                    ac_group_writer.bits_written()
684                );
685            }
686
687            // Combine at bit level
688            let mut combined = dc_global;
689            #[cfg(feature = "debug-tokens")]
690            debug_log!("After DC_global: {} bits", combined.bits_written());
691            combined.append_unaligned(&dc_group)?;
692            #[cfg(feature = "debug-tokens")]
693            debug_log!("After DC_group: {} bits", combined.bits_written());
694            combined.append_unaligned(&ac_global)?;
695            #[cfg(feature = "debug-tokens")]
696            debug_log!("After AC_global: {} bits", combined.bits_written());
697            combined.append_unaligned(&ac_group_writer)?;
698            #[cfg(feature = "debug-tokens")]
699            debug_log!("After AC_group: {} bits", combined.bits_written());
700            combined.zero_pad_to_byte();
701            let combined_bytes = combined.finish();
702
703            #[cfg(feature = "debug-tokens")]
704            {
705                debug_log!("Combined section size: {} bytes", combined_bytes.len());
706                debug_log!(
707                    "Before TOC: bit {} (byte {})",
708                    writer.bits_written(),
709                    writer.bits_written() / 8
710                );
711            }
712            write_toc(&[combined_bytes.len()], &mut writer)?;
713            #[cfg(feature = "debug-tokens")]
714            debug_log!(
715                "After TOC: bit {} (byte {})",
716                writer.bits_written(),
717                writer.bits_written() / 8
718            );
719            writer.append_bytes(&combined_bytes)?;
720        } else {
721            // Multi-group: use byte-aligned sections
722            let mut sections: Vec<Vec<u8>> = Vec::with_capacity(num_sections);
723            let dc_huffman = dc_code.as_huffman();
724            let ac_huffman = ac_code.as_huffman();
725
726            // DC Global section
727            let block_ctx_map = super::ac_context::BlockCtxMap::default();
728            let mut dc_global = BitWriter::with_capacity(4096);
729            self.write_dc_global(
730                &params,
731                num_dc_groups,
732                &dc_code,
733                &noise_params,
734                None,
735                &block_ctx_map,
736                None, // No learned tree in single-pass mode
737                &mut dc_global,
738            )?;
739            dc_global.zero_pad_to_byte();
740            sections.push(dc_global.finish());
741
742            // DC group sections
743            let blocks_per_dc_group = (256 / 8) * (256 / 8); // 1024 blocks per DC group
744            for dc_group_idx in 0..num_dc_groups {
745                let mut dc_group = BitWriter::with_capacity(blocks_per_dc_group * 10);
746                self.write_dc_group(
747                    dc_group_idx,
748                    quant_dc,
749                    xsize_blocks,
750                    ysize_blocks,
751                    xsize_dc_groups,
752                    &quant_field,
753                    &cfl_map,
754                    &ac_strategy,
755                    None, // no sharpness map in single-pass mode
756                    &dc_huffman,
757                    &mut dc_group,
758                )?;
759                dc_group.zero_pad_to_byte();
760                sections.push(dc_group.finish());
761            }
762
763            // AC Global section
764            let mut ac_global = BitWriter::with_capacity(4096);
765            self.write_ac_global(num_groups, &ac_code, 0, None, None, &mut ac_global)?;
766            ac_global.zero_pad_to_byte();
767            sections.push(ac_global.finish());
768
769            // AC group sections
770            let blocks_per_ac_group = (256 / 8) * (256 / 8); // 1024 blocks per AC group
771            for group_idx in 0..num_groups {
772                let mut ac_group_writer = BitWriter::with_capacity(blocks_per_ac_group * 100);
773                self.write_ac_group(
774                    group_idx,
775                    quant_ac,
776                    nzeros,
777                    raw_nzeros,
778                    xsize_blocks,
779                    ysize_blocks,
780                    xsize_groups,
781                    &quant_field,
782                    &ac_strategy,
783                    &block_ctx_map,
784                    &ac_huffman,
785                    &mut ac_group_writer,
786                )?;
787                ac_group_writer.zero_pad_to_byte();
788                sections.push(ac_group_writer.finish());
789            }
790
791            let section_sizes: Vec<usize> = sections.iter().map(|s| s.len()).collect();
792            write_toc(&section_sizes, &mut writer)?;
793            for section in sections {
794                writer.append_bytes(&section)?;
795            }
796        }
797
798        let strategy_counts = ac_strategy.strategy_histogram();
799        Ok(VarDctOutput {
800            data: writer.finish_with_padding(),
801            strategy_counts,
802        })
803    }
804
    /// Butteraugli quantization loop: iteratively refines per-block quant_field
    /// by measuring perceptual distance (butteraugli) between the original image
    /// and the reconstruction from quantized coefficients.
    ///
    /// Algorithm (libjxl FindBestQuantization):
    /// For each iteration:
    ///   1. transform_and_quantize with current quant_field
    ///   2. reconstruct XYB → apply gab → EPF → XYB-to-linear
    ///   3. butteraugli(original_linear, reconstructed_linear) → per-block distmap
    ///   4. For blocks where distmap > target: increase quant (qf *= distmap/target)
    ///      For blocks where distmap < target: decrease quant (qf *= distmap/target)
    ///   5. Clamp and constrain (don't diverge too far from initial)
    ///
    /// AC strategy is FIXED throughout — only quant_field changes.
    ///
    /// `quant_field` is the in/out per-block field (one entry per 8x8 block).
    /// It is only written back after the last iteration, so every early
    /// `return` in this function leaves it exactly as the caller passed it.
    /// `initial_quant_field` bounds how far refinement may drift from the
    /// starting point. `xyb_x`/`xyb_y`/`xyb_b` are planes sized
    /// `padded_width * padded_height`; `linear_rgb` is the unpadded original
    /// (`width * height * 3`, interleaved).
    #[cfg(feature = "butteraugli-loop")]
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn butteraugli_refine_quant_field(
        &self,
        linear_rgb: &[f32],
        width: usize,
        height: usize,
        xyb_x: &[f32],
        xyb_y: &[f32],
        xyb_b: &[f32],
        padded_width: usize,
        padded_height: usize,
        xsize_blocks: usize,
        ysize_blocks: usize,
        params: &DistanceParams,
        quant_field: &mut [u8],
        initial_quant_field: &[u8],
        cfl_map: &CflMap,
        ac_strategy: &AcStrategyMap,
    ) {
        use super::epf;
        use super::reconstruct::{gab_smooth, reconstruct_xyb, xyb_to_linear_rgb_planar};

        let target_distance = self.distance;
        let num_blocks = xsize_blocks * ysize_blocks;
        let padded_pixels = padded_width * padded_height;

        // Precompute butteraugli reference from original image ONCE.
        // This saves ~40-50% of butteraugli time per iteration by caching
        // the XYB conversion and frequency decomposition of the reference.
        let butteraugli_params = butteraugli::ButteraugliParams::new()
            .with_intensity_target(80.0)
            .with_compute_diffmap(true);
        let reference = match butteraugli::ButteraugliReference::new_linear(
            linear_rgb,
            width,
            height,
            butteraugli_params,
        ) {
            Ok(r) => r,
            Err(_) => return, // Bail on error (e.g., image too small)
        };

        // Work in f32 during the loop for precision (libjxl uses float quant_field).
        // Converting to u8 each iteration loses ~0.5-1.5 per value, accumulating over iters.
        let mut qf_float: Vec<f32> = quant_field.iter().map(|&v| v as f32).collect();
        let initial_qf_float: Vec<f32> = initial_quant_field.iter().map(|&v| v as f32).collect();

        // Compute qf_lower/qf_higher deviation bounds (matching libjxl lines 968-976).
        // These prevent the quant field from diverging too far from the initial field,
        // avoiding oscillation and wild over/under-quantization.
        // Note: .max(1.0) guards against a zero minimum (division below).
        let initial_qf_min = initial_qf_float
            .iter()
            .copied()
            .reduce(f32::min)
            .unwrap_or(1.0)
            .max(1.0);
        let initial_qf_max = initial_qf_float
            .iter()
            .copied()
            .reduce(f32::max)
            .unwrap_or(255.0);
        let initial_qf_ratio = initial_qf_max / initial_qf_min;
        let qf_max_deviation_low = (250.0 / initial_qf_ratio).sqrt();
        let asymmetry = 2.0f32.min(qf_max_deviation_low);
        let qf_lower = (initial_qf_min / (asymmetry * qf_max_deviation_low)).max(1.0);
        let qf_higher = (initial_qf_max * (qf_max_deviation_low / asymmetry)).min(255.0);

        // Pre-allocate buffers reused across butteraugli iterations
        let mut qf_copy = vec![0u8; quant_field.len()];
        // Flat sharpness map (4 = neutral) fed to EPF when it runs below.
        let sharpness = vec![4u8; num_blocks];
        let mut tile_dist = vec![0.0f32; num_blocks];
        // Planar reconstruction buffers (padded dimensions, reused across iterations)
        let mut recon_r = vec![0.0f32; padded_pixels];
        let mut recon_g = vec![0.0f32; padded_pixels];
        let mut recon_b = vec![0.0f32; padded_pixels];
        let mut transform_out = super::transform::TransformOutput::new(xsize_blocks, ysize_blocks);

        for iter in 0..self.butteraugli_iters {
            // Step 1: Quantize with current quant_field (convert float→u8 for quantizer)
            for (dst, &src) in qf_copy.iter_mut().zip(qf_float.iter()) {
                *dst = (src.round() as u8).clamp(1, 255);
            }
            self.transform_and_quantize_into(
                xyb_x,
                xyb_y,
                xyb_b,
                padded_width,
                xsize_blocks,
                ysize_blocks,
                params,
                &mut qf_copy,
                cfl_map,
                ac_strategy,
                &mut transform_out,
            );

            // Step 2: Reconstruct XYB from quantized coefficients
            let mut planes = reconstruct_xyb(
                &transform_out.quant_dc,
                &transform_out.quant_ac,
                params,
                &qf_copy,
                cfl_map,
                ac_strategy,
                xsize_blocks,
                ysize_blocks,
            );

            // Apply gaborish smooth if enabled
            if self.enable_gaborish {
                gab_smooth(&mut planes, padded_width, padded_height);
            }

            // Apply EPF if active
            if params.epf_iters > 0 {
                epf::apply_epf(
                    &mut planes,
                    &qf_copy,
                    &sharpness,
                    params.scale,
                    params.epf_iters,
                    xsize_blocks,
                    ysize_blocks,
                    padded_width,
                    padded_height,
                );
            }

            // Step 3: Convert reconstructed XYB to planar linear RGB (in-place, no interleave)
            xyb_to_linear_rgb_planar(
                &planes[0],
                &planes[1],
                &planes[2],
                &mut recon_r,
                &mut recon_g,
                &mut recon_b,
                padded_pixels,
            );

            // Step 4: Compare against precomputed reference using planar API.
            // Pass padded buffers with stride=padded_width; butteraugli reads only
            // width pixels per row, skipping the padding — no crop copy needed.
            let result =
                match reference.compare_linear_planar(&recon_r, &recon_g, &recon_b, padded_width) {
                    Ok(r) => r,
                    Err(_) => return,
                };

            // Should always be Some since with_compute_diffmap(true) was set above;
            // bail defensively rather than panic.
            let diffmap = match result.diffmap {
                Some(dm) => dm,
                None => return,
            };

            // Step 5: Compute per-block tile distance (16th-power norm, matching libjxl)
            // libjxl uses TileDistMap with 16th-norm and kTileNorm=1.2 scaling
            // NOTE(review): indexing below assumes diffmap is width*height row-major
            // (unpadded image dimensions) — consistent with the reference being
            // built from width/height above; confirm against the butteraugli API.
            const K_TILE_NORM: f32 = 1.2;
            let diffmap_buf = diffmap.buf();
            tile_dist.fill(0.0);
            for by in 0..ysize_blocks {
                for bx in 0..xsize_blocks {
                    // Only visit the top-left block of each (possibly multi-block)
                    // transform; sub-blocks are filled from it below.
                    if !ac_strategy.is_first(bx, by) {
                        continue;
                    }
                    let covered_x = ac_strategy.covered_blocks_x(bx, by);
                    let covered_y = ac_strategy.covered_blocks_y(bx, by);
                    let px_start_x = bx * BLOCK_DIM;
                    let px_start_y = by * BLOCK_DIM;
                    let px_end_x = ((bx + covered_x) * BLOCK_DIM).min(width);
                    let px_end_y = ((by + covered_y) * BLOCK_DIM).min(height);
                    // Blocks fully inside the padding area have no real pixels.
                    if px_start_x >= width || px_start_y >= height {
                        continue;
                    }
                    let mut dist_norm = 0.0f64;
                    let mut pixels = 0.0f64;
                    for py in px_start_y..px_end_y {
                        for px in px_start_x..px_end_x {
                            let v = diffmap_buf[py * width + px] as f64;
                            // v^16 (16th-power norm)
                            let v2 = v * v;
                            let v4 = v2 * v2;
                            let v8 = v4 * v4;
                            let v16 = v8 * v8;
                            dist_norm += v16;
                            pixels += 1.0;
                        }
                    }
                    if pixels == 0.0 {
                        pixels = 1.0;
                    }
                    // x^(1/16) = sqrt(sqrt(sqrt(sqrt(x))))
                    let td = K_TILE_NORM * (dist_norm / pixels).sqrt().sqrt().sqrt().sqrt() as f32;
                    // Fill all sub-blocks of this transform
                    for sy in 0..covered_y {
                        for sx in 0..covered_x {
                            tile_dist[(by + sy) * xsize_blocks + (bx + sx)] = td;
                        }
                    }
                }
            }

            // Step 6: Constrain and adjust quant_field based on tile distances.
            //
            // Convention: higher qf = finer quantization = better quality (same as libjxl).
            // quantize_coeff_ac: val = coef * inv_weight * qac * qm_mul
            // Higher qac (from higher qf) → larger quantized int → more precision.
            //
            // libjxl order: constrain toward initial (kOriginalComparisonRound=1),
            // THEN adjust based on tile distances. Both phases enforce qf_lower/qf_higher.

            // kOriginalComparisonRound = 1: constrain toward initial BEFORE adjustment.
            // Prevents oscillation by keeping qf from diverging too far from initial.
            // This fires exactly once, on the second iteration (iter index 1).
            if iter == 1 {
                const K_INIT_MUL: f64 = 0.6;
                const K_ONE_MINUS_INIT_MUL: f64 = 1.0 - K_INIT_MUL;
                for bi in 0..num_blocks {
                    let init_qf = initial_qf_float[bi] as f64;
                    let cur_qf = qf_float[bi] as f64;
                    let clamp_val = K_ONE_MINUS_INIT_MUL * cur_qf + K_INIT_MUL * init_qf;
                    // Only pull UP toward the blend; values above it are left alone.
                    if cur_qf < clamp_val {
                        qf_float[bi] = (clamp_val as f32).clamp(qf_lower, qf_higher);
                    }
                }
            }

            // Adjust quant_field based on tile distances.
            // ASYMMETRIC: aggressively fix bad blocks, barely touch good blocks.
            // kPow = [0.2, 0.2, 0, 0, ...] — only iters 0-1 touch good blocks, gently.
            let cur_pow: f64 = if iter < 2 {
                0.2 + (target_distance as f64 - 1.0) * 0.0 // kPowMod[0..1] = 0
            } else {
                0.0
            };

            for bi in 0..num_blocks {
                let diff = tile_dist[bi] / target_distance;
                let old_qf = qf_float[bi];

                if diff <= 1.0 {
                    // Quality is good enough — save bits by reducing precision.
                    if cur_pow != 0.0 {
                        // diff < 1 → pow(diff, 0.2) < 1 → qf decreases slightly.
                        qf_float[bi] = old_qf * (diff as f64).powf(cur_pow) as f32;
                    }
                    // cur_pow == 0: don't touch good blocks on later iterations
                } else {
                    // Quality too bad — aggressively improve by increasing qf.
                    qf_float[bi] = old_qf * diff;
                    // Ensure at least 1 integer step change (matching libjxl's rounding check)
                    if qf_float[bi].round() as u8 == old_qf.round() as u8 {
                        qf_float[bi] = old_qf + 1.0;
                    }
                }
                // Enforce deviation bounds after every adjustment (matching libjxl)
                qf_float[bi] = qf_float[bi].clamp(qf_lower, qf_higher);
            }

            // NOTE(review): unconditional progress output to stderr whenever the
            // `butteraugli-loop` feature is on — consider gating behind a verbose
            // flag if this is library-facing.
            eprintln!(
                "  Butteraugli iter {}: score={:.3} (target={:.3})",
                iter, result.score, target_distance,
            );
        }

        // Convert float quant_field back to u8 for final encoding
        for (dst, &src) in quant_field.iter_mut().zip(qf_float.iter()) {
            *dst = (src.round() as u8).clamp(1, 255);
        }
    }
1087
1088    /// Encode with iterative rate control for improved distance targeting.
1089    ///
1090    /// This method:
1091    /// 1. Computes precomputed state (XYB, CfL, masking, AC strategy) once
1092    /// 2. Loops: encode → decode → butteraugli → adjust quant field
1093    /// 3. Returns when converged (within 5% of target) or max iterations reached
1094    ///
1095    /// Typically converges in 2-4 iterations. Each iteration costs ~50% of a
1096    /// full encode since XYB conversion, CfL, masking, and AC strategy are reused.
1097    ///
1098    /// Returns the encoded bytes. Use `encode_with_rate_control_config` for
1099    /// iteration count and custom configuration.
1100    ///
1101    /// Requires the `rate-control` feature.
1102    #[cfg(feature = "rate-control")]
1103    pub fn encode_with_rate_control(
1104        &self,
1105        width: usize,
1106        height: usize,
1107        linear_rgb: &[f32],
1108    ) -> Result<Vec<u8>> {
1109        let config = super::rate_control::RateControlConfig::default();
1110        let (encoded, _iters) =
1111            self.encode_with_rate_control_config(width, height, linear_rgb, &config)?;
1112        Ok(encoded)
1113    }
1114
1115    /// Encode with iterative rate control and custom configuration.
1116    ///
1117    /// Returns `(encoded_bytes, iteration_count)`.
1118    ///
1119    /// Requires the `rate-control` feature.
1120    #[cfg(feature = "rate-control")]
1121    pub fn encode_with_rate_control_config(
1122        &self,
1123        width: usize,
1124        height: usize,
1125        linear_rgb: &[f32],
1126        config: &super::rate_control::RateControlConfig,
1127    ) -> Result<(Vec<u8>, usize)> {
1128        // Compute precomputed state
1129        let precomputed = super::precomputed::EncoderPrecomputed::compute(
1130            width,
1131            height,
1132            linear_rgb,
1133            self.distance,
1134            self.cfl_enabled,
1135            self.ac_strategy_enabled,
1136            self.pixel_domain_loss,
1137            self.enable_noise,
1138            self.enable_denoise,
1139            self.enable_gaborish,
1140            self.force_strategy,
1141        );
1142
1143        // Run rate control loop
1144        super::rate_control::encode_with_rate_control(self, &precomputed, config)
1145    }
1146
    /// Encode from precomputed state with a specific quant field.
    ///
    /// This is the core encoding function used by rate control iterations.
    /// It skips XYB conversion, CfL, masking, and AC strategy computation,
    /// using the values from `precomputed` instead.
    ///
    /// The caller's `quant_field` slice is copied before being adjusted for
    /// multi-block transforms and consumed by quantization, so it is never
    /// mutated. The result is a complete codestream produced via the
    /// two-pass (ANS) encode path.
    ///
    /// Requires the `rate-control` feature.
    #[cfg(feature = "rate-control")]
    pub fn encode_from_precomputed(
        &self,
        precomputed: &super::precomputed::EncoderPrecomputed,
        quant_field: &[u8],
    ) -> Result<Vec<u8>> {
        let width = precomputed.width;
        let height = precomputed.height;
        let xsize_blocks = precomputed.xsize_blocks;
        let ysize_blocks = precomputed.ysize_blocks;
        let padded_width = precomputed.padded_width;

        // Calculate group dimensions (AC groups of GROUP_DIM px, DC groups of
        // DC_GROUP_DIM px). Sections = DC global + AC global + one per group.
        let xsize_groups = div_ceil(width, GROUP_DIM);
        let ysize_groups = div_ceil(height, GROUP_DIM);
        let xsize_dc_groups = div_ceil(width, DC_GROUP_DIM);
        let ysize_dc_groups = div_ceil(height, DC_GROUP_DIM);
        let num_groups = xsize_groups * ysize_groups;
        let num_dc_groups = xsize_dc_groups * ysize_dc_groups;
        let num_sections = 2 + num_dc_groups + num_groups;

        // Copy and adjust quant field for multi-block transforms
        let mut quant_field = quant_field.to_vec();
        adjust_quant_field_with_distance(&precomputed.ac_strategy, &mut quant_field, self.distance);

        // Compute distance params from precomputed quant field
        let mut params =
            DistanceParams::compute_from_quant_field(self.distance, &precomputed.quant_field_float);

        // Apply pixel-level chromacity adjustments using pre-gaborish stats
        params.apply_chromacity_adjustment(
            precomputed.chromacity_x_pixelized,
            precomputed.chromacity_b_pixelized,
        );

        // Perform DCT and quantization using precomputed XYB data
        let transform_out = self.transform_and_quantize(
            &precomputed.xyb_x,
            &precomputed.xyb_y,
            &precomputed.xyb_b,
            padded_width,
            xsize_blocks,
            ysize_blocks,
            &params,
            &mut quant_field,
            &precomputed.cfl_map,
            &precomputed.ac_strategy,
        );
        let quant_dc = &transform_out.quant_dc;
        let quant_ac = &transform_out.quant_ac;
        let nzeros = &transform_out.nzeros;
        let raw_nzeros = &transform_out.raw_nzeros;

        // Use two-pass mode for rate control (required for ANS)
        self.encode_two_pass(
            width,
            height,
            &params,
            xsize_blocks,
            ysize_blocks,
            xsize_groups,
            ysize_groups,
            xsize_dc_groups,
            ysize_dc_groups,
            num_groups,
            num_dc_groups,
            num_sections,
            quant_dc,
            quant_ac,
            nzeros,
            raw_nzeros,
            &quant_field,
            &precomputed.cfl_map,
            &precomputed.ac_strategy,
            &precomputed.noise_params,
            None, // TODO: compute sharpness_map for rate control path
            None, // TODO: thread alpha through butteraugli path
        )
    }
1233}
1234
1235#[cfg(test)]
1236mod tests {
1237    use super::*;
1238
1239    #[test]
1240    fn test_encoder_creation() {
1241        let encoder = VarDctEncoder::new(1.0);
1242        assert_eq!(encoder.distance, 1.0);
1243
1244        let encoder_default = VarDctEncoder::default();
1245        assert_eq!(encoder_default.distance, 1.0);
1246    }
1247
1248    #[test]
1249    fn test_encode_small_image() {
1250        let encoder = VarDctEncoder::new(1.0);
1251
1252        // Create a simple 8x8 red image
1253        let width = 8;
1254        let height = 8;
1255        let mut linear_rgb = vec![0.0f32; width * height * 3];
1256        for y in 0..height {
1257            for x in 0..width {
1258                let idx = (y * width + x) * 3;
1259                linear_rgb[idx] = 1.0; // R
1260                linear_rgb[idx + 1] = 0.0; // G
1261                linear_rgb[idx + 2] = 0.0; // B
1262            }
1263        }
1264
1265        // This should at least not panic - full encoding not yet implemented
1266        let result = encoder.encode(width, height, &linear_rgb, None);
1267        // For now, just check it produces some output
1268        assert!(result.is_ok());
1269        let output = result.unwrap();
1270        assert!(output.data.len() > 2);
1271        assert_eq!(output.data[0], 0xFF);
1272        assert_eq!(output.data[1], 0x0A);
1273    }
1274
1275    #[test]
1276    fn test_convert_to_xyb_padded() {
1277        let encoder = VarDctEncoder::new(1.0);
1278
1279        // Gray pixel (1x1 image -> padded to 8x8)
1280        let linear_rgb = vec![0.5, 0.5, 0.5];
1281        let (x, y, b) = encoder.convert_to_xyb_padded(1, 1, 8, 8, &linear_rgb);
1282
1283        // Padded to 8x8 = 64 pixels
1284        assert_eq!(x.len(), 64);
1285        assert_eq!(y.len(), 64);
1286        assert_eq!(b.len(), 64);
1287
1288        // Gray should have X ≈ 0 (equal L and M)
1289        assert!(x[0].abs() < 0.01, "X should be near zero for gray");
1290        assert!(y[0] > 0.0, "Y should be positive");
1291        assert!(b[0] > 0.0, "B should be positive");
1292
1293        // Edge replication: all padded pixels should match the corner
1294        for i in 0..64 {
1295            assert!((x[i] - x[0]).abs() < 1e-6, "All padded X should match");
1296            assert!((y[i] - y[0]).abs() < 1e-6, "All padded Y should match");
1297            assert!((b[i] - b[0]).abs() < 1e-6, "All padded B should match");
1298        }
1299    }
1300
1301    #[test]
1302    fn test_encode_16x16_red_image() {
1303        // Test a 16x16 pixel image (2x2 blocks) to compare with libjxl-tiny
1304        let encoder = VarDctEncoder::new(1.0);
1305
1306        let width = 16;
1307        let height = 16;
1308        let mut linear_rgb = vec![0.0f32; width * height * 3];
1309        for y in 0..height {
1310            for x in 0..width {
1311                let idx = (y * width + x) * 3;
1312                linear_rgb[idx] = 1.0; // R
1313                linear_rgb[idx + 1] = 0.0; // G
1314                linear_rgb[idx + 2] = 0.0; // B
1315            }
1316        }
1317
1318        let result = encoder.encode(width, height, &linear_rgb, None);
1319        assert!(result.is_ok());
1320        let output = result.unwrap();
1321
1322        eprintln!("Output file size: {} bytes", output.data.len());
1323        eprintln!(
1324            "First 32 bytes: {:02x?}",
1325            &output.data[..32.min(output.data.len())]
1326        );
1327
1328        // Write output to file for comparison
1329        std::fs::write(std::env::temp_dir().join("our_16x16.jxl"), &output.data).unwrap();
1330
1331        // libjxl-tiny produces:
1332        // DC_group: 106 bits (14 bytes)
1333        // Total combined: 1086 bytes
1334        // Total file: 1104 bytes
1335        //
1336        // Our encoder should match these sizes
1337
1338        // Check signature
1339        assert_eq!(output.data[0], 0xFF);
1340        assert_eq!(output.data[1], 0x0A);
1341    }
1342
1343    /// Compute a simple hash of a byte slice for output locking.
1344    fn hash_bytes(bytes: &[u8]) -> u64 {
1345        use std::hash::{Hash, Hasher};
1346        let mut hasher = std::collections::hash_map::DefaultHasher::new();
1347        bytes.hash(&mut hasher);
1348        hasher.finish()
1349    }
1350
1351    /// Hash-locked test for 8x8 gradient image.
1352    /// This test ensures the encoder output doesn't change unexpectedly.
1353    /// x86_64 only: FP rounding differs on other architectures and 32-bit.
1354    #[test]
1355    #[cfg(target_arch = "x86_64")]
1356    fn test_hash_lock_8x8_gradient() {
1357        let encoder = VarDctEncoder::new(1.0);
1358        let width = 8;
1359        let height = 8;
1360        let mut linear_rgb = vec![0.0f32; width * height * 3];
1361
1362        // Simple gradient: R increases with x, G with y
1363        for y in 0..height {
1364            for x in 0..width {
1365                let idx = (y * width + x) * 3;
1366                linear_rgb[idx] = x as f32 / 7.0; // R
1367                linear_rgb[idx + 1] = y as f32 / 7.0; // G
1368                linear_rgb[idx + 2] = 0.5; // B
1369            }
1370        }
1371
1372        let bytes = encoder
1373            .encode(width, height, &linear_rgb, None)
1374            .unwrap()
1375            .data;
1376        let hash = hash_bytes(&bytes);
1377
1378        // Lock the hash - if this changes, the encoding has changed
1379        // Updated: unified file header (Phase 1 — shared FileHeader::write())
1380        const EXPECTED_HASH: u64 = 0x1578d7de62b7489d;
1381        assert_eq!(
1382            hash,
1383            EXPECTED_HASH,
1384            "8x8 gradient hash mismatch: got {:#x}, expected {:#x}. \
1385             Output size: {} bytes. If intentional, update EXPECTED_HASH.",
1386            hash,
1387            EXPECTED_HASH,
1388            bytes.len()
1389        );
1390    }
1391
1392    /// Hash-locked test for 16x16 solid color image.
1393    /// x86_64 only: FP rounding differs on other architectures and 32-bit.
1394    #[test]
1395    #[cfg(target_arch = "x86_64")]
1396    fn test_hash_lock_16x16_solid() {
1397        let encoder = VarDctEncoder::new(1.0);
1398        let width = 16;
1399        let height = 16;
1400        let linear_rgb = vec![0.3f32; width * height * 3]; // gray
1401
1402        let bytes = encoder
1403            .encode(width, height, &linear_rgb, None)
1404            .unwrap()
1405            .data;
1406        let hash = hash_bytes(&bytes);
1407
1408        // Updated: butteraugli default off (effort-gated, VarDctEncoder defaults to 0 iters)
1409        const EXPECTED_HASH: u64 = 0x2cf2e7aae4f14de7;
1410        assert_eq!(
1411            hash,
1412            EXPECTED_HASH,
1413            "16x16 solid hash mismatch: got {:#x}, expected {:#x}. \
1414             Output size: {} bytes. If intentional, update EXPECTED_HASH.",
1415            hash,
1416            EXPECTED_HASH,
1417            bytes.len()
1418        );
1419    }
1420
1421    /// Hash-locked test for 64x64 checkerboard pattern.
1422    /// x86_64 only: FP rounding differs on other architectures and 32-bit.
1423    #[test]
1424    #[cfg(target_arch = "x86_64")]
1425    fn test_hash_lock_64x64_checkerboard() {
1426        let encoder = VarDctEncoder::new(1.0);
1427        let width = 64;
1428        let height = 64;
1429        let mut linear_rgb = vec![0.0f32; width * height * 3];
1430
1431        // 8x8 checkerboard pattern
1432        for y in 0..height {
1433            for x in 0..width {
1434                let idx = (y * width + x) * 3;
1435                let checker = ((x / 8) + (y / 8)) % 2 == 0;
1436                let val = if checker { 0.8 } else { 0.2 };
1437                linear_rgb[idx] = val;
1438                linear_rgb[idx + 1] = val;
1439                linear_rgb[idx + 2] = val;
1440            }
1441        }
1442
1443        let bytes = encoder
1444            .encode(width, height, &linear_rgb, None)
1445            .unwrap()
1446            .data;
1447        let hash = hash_bytes(&bytes);
1448
1449        // Updated: AFV strategies enabled in auto-selection
1450        const EXPECTED_HASH: u64 = 0xd91c3989788e5448;
1451        assert_eq!(
1452            hash,
1453            EXPECTED_HASH,
1454            "64x64 checkerboard hash mismatch: got {:#x}, expected {:#x}. \
1455             Output size: {} bytes. If intentional, update EXPECTED_HASH.",
1456            hash,
1457            EXPECTED_HASH,
1458            bytes.len()
1459        );
1460    }
1461
1462    /// Hash-locked test for non-power-of-two size (tests padding).
1463    /// x86_64 only: FP rounding differs on other architectures and 32-bit.
1464    #[test]
1465    #[cfg(target_arch = "x86_64")]
1466    fn test_hash_lock_13x17_noise() {
1467        let encoder = VarDctEncoder::new(1.0);
1468        let width = 13;
1469        let height = 17;
1470        let mut linear_rgb = vec![0.0f32; width * height * 3];
1471
1472        // Deterministic pseudo-random pattern
1473        let mut seed = 12345u64;
1474        for val in &mut linear_rgb {
1475            seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1);
1476            *val = ((seed >> 32) as f32) / (u32::MAX as f32);
1477        }
1478
1479        let bytes = encoder
1480            .encode(width, height, &linear_rgb, None)
1481            .unwrap()
1482            .data;
1483        let hash = hash_bytes(&bytes);
1484
1485        // Updated: full libjxl adaptive quantization pipeline
1486        const EXPECTED_HASH: u64 = 0x5324c4f675e42ff7;
1487        assert_eq!(
1488            hash,
1489            EXPECTED_HASH,
1490            "13x17 noise hash mismatch: got {:#x}, expected {:#x}. \
1491             Output size: {} bytes. If intentional, update EXPECTED_HASH.",
1492            hash,
1493            EXPECTED_HASH,
1494            bytes.len()
1495        );
1496    }
1497
1498    /// Roundtrip quality test for non-8-aligned dimensions.
1499    ///
1500    /// Encodes a 100x75 gradient, decodes with jxl-oxide, and verifies:
1501    /// 1. Dimensions match
1502    /// 2. Output is a valid JXL file (correct signature, decodable)
1503    ///
1504    /// This catches stride mismatch bugs where padded XYB buffers have
1505    /// stride != width, which corrupts adaptive quant, CfL, and AC strategy.
1506    #[test]
1507    fn test_roundtrip_non_8_aligned() {
1508        for &(w, h) in &[(100, 75), (13, 17), (33, 49), (7, 9)] {
1509            let mut linear_rgb = vec![0.0f32; w * h * 3];
1510
1511            // Smooth gradient (linear RGB)
1512            for y in 0..h {
1513                for x in 0..w {
1514                    let idx = (y * w + x) * 3;
1515                    linear_rgb[idx] = x as f32 / w.max(1) as f32;
1516                    linear_rgb[idx + 1] = y as f32 / h.max(1) as f32;
1517                    linear_rgb[idx + 2] = 0.3;
1518                }
1519            }
1520
1521            let encoder = VarDctEncoder::new(1.0);
1522            let bytes = encoder
1523                .encode(w, h, &linear_rgb, None)
1524                .unwrap_or_else(|e| panic!("encode {}x{} failed: {}", w, h, e))
1525                .data;
1526
1527            // Verify JXL signature
1528            assert_eq!(bytes[0], 0xFF, "{}x{}: bad signature byte 0", w, h);
1529            assert_eq!(bytes[1], 0x0A, "{}x{}: bad signature byte 1", w, h);
1530
1531            // Decode with jxl-oxide and verify dimensions
1532            let image = jxl_oxide::JxlImage::builder()
1533                .read(std::io::Cursor::new(&bytes))
1534                .unwrap_or_else(|e| panic!("jxl-oxide decode {}x{} failed: {}", w, h, e));
1535            assert_eq!(
1536                image.width(),
1537                w as u32,
1538                "{}x{}: decoded width mismatch",
1539                w,
1540                h
1541            );
1542            assert_eq!(
1543                image.height(),
1544                h as u32,
1545                "{}x{}: decoded height mismatch",
1546                w,
1547                h
1548            );
1549
1550            // Render to verify pixel data is valid
1551            let render = image
1552                .render_frame(0)
1553                .unwrap_or_else(|e| panic!("jxl-oxide render {}x{} failed: {}", w, h, e));
1554            let _pixels = render.image_all_channels();
1555        }
1556    }
1557
1558    /// Test DC tree learning produces valid output.
1559    #[test]
1560    fn test_dc_tree_learning() {
1561        let width = 64;
1562        let height = 64;
1563
1564        // Create a gradient image
1565        let mut linear_rgb = vec![0.0f32; width * height * 3];
1566        for y in 0..height {
1567            for x in 0..width {
1568                let idx = (y * width + x) * 3;
1569                linear_rgb[idx] = x as f32 / width as f32;
1570                linear_rgb[idx + 1] = y as f32 / height as f32;
1571                linear_rgb[idx + 2] = 0.5;
1572            }
1573        }
1574
1575        // Encode WITHOUT DC tree learning (baseline) — use ANS
1576        let mut encoder_baseline = VarDctEncoder::new(1.0);
1577        encoder_baseline.dc_tree_learning = false;
1578        let bytes_baseline = encoder_baseline
1579            .encode(width, height, &linear_rgb, None)
1580            .expect("baseline encode failed")
1581            .data;
1582
1583        // Encode WITH DC tree learning — also use ANS
1584        let mut encoder_learned = VarDctEncoder::new(1.0);
1585        encoder_learned.dc_tree_learning = true;
1586        std::fs::write(
1587            std::env::temp_dir().join("dc_baseline_test.jxl"),
1588            &bytes_baseline,
1589        )
1590        .unwrap();
1591        let bytes_learned = encoder_learned
1592            .encode(width, height, &linear_rgb, None)
1593            .expect("learned encode failed")
1594            .data;
1595        std::fs::write(
1596            std::env::temp_dir().join("dc_learned_test.jxl"),
1597            &bytes_learned,
1598        )
1599        .unwrap();
1600
1601        eprintln!(
1602            "DC tree learning: baseline={} bytes, learned={} bytes (delta={:.2}%)",
1603            bytes_baseline.len(),
1604            bytes_learned.len(),
1605            (bytes_learned.len() as f64 / bytes_baseline.len() as f64 - 1.0) * 100.0
1606        );
1607
1608        // Verify both produce valid JXL signature
1609        assert_eq!(bytes_baseline[0], 0xFF);
1610        assert_eq!(bytes_baseline[1], 0x0A);
1611        assert_eq!(bytes_learned[0], 0xFF);
1612        assert_eq!(bytes_learned[1], 0x0A);
1613
1614        // Verify baseline decodes (sanity check)
1615        {
1616            let image = jxl_oxide::JxlImage::builder()
1617                .read(std::io::Cursor::new(&bytes_baseline))
1618                .expect("jxl-oxide parse of baseline failed");
1619            let render = image
1620                .render_frame(0)
1621                .expect("jxl-oxide render of baseline failed");
1622            let _pixels = render.image_all_channels();
1623            eprintln!("Baseline ANS decodes OK ({} bytes)", bytes_baseline.len());
1624        }
1625
1626        // Decode the learned version with jxl-oxide to verify it's valid
1627        let image = jxl_oxide::JxlImage::builder()
1628            .read(std::io::Cursor::new(&bytes_learned))
1629            .expect("jxl-oxide decode of learned version failed");
1630        assert_eq!(image.width(), width as u32);
1631        assert_eq!(image.height(), height as u32);
1632
1633        // Render to verify pixel data is valid
1634        let render = image
1635            .render_frame(0)
1636            .expect("jxl-oxide render of learned version failed");
1637        let _pixels = render.image_all_channels();
1638        eprintln!("Learned ANS decodes OK ({} bytes)", bytes_learned.len());
1639
1640        // Also verify with djxl
1641        std::fs::write(
1642            std::env::temp_dir().join("dc_learned_test.jxl"),
1643            &bytes_learned,
1644        )
1645        .unwrap();
1646    }
1647
1648    /// Test that the butteraugli quantization loop produces valid output.
1649    #[cfg(feature = "butteraugli-loop")]
1650    #[test]
1651    fn test_butteraugli_loop_basic() {
1652        // Create a 64x64 test image with some variation
1653        let width = 64;
1654        let height = 64;
1655        let mut linear_rgb = vec![0.0f32; width * height * 3];
1656        for y in 0..height {
1657            for x in 0..width {
1658                let idx = (y * width + x) * 3;
1659                let fx = x as f32 / width as f32;
1660                let fy = y as f32 / height as f32;
1661                linear_rgb[idx] = fx * 0.8; // R
1662                linear_rgb[idx + 1] = fy * 0.6; // G
1663                linear_rgb[idx + 2] = (1.0 - fx) * 0.4; // B
1664            }
1665        }
1666
1667        // Encode without butteraugli loop
1668        let mut encoder_baseline = VarDctEncoder::new(2.0);
1669        encoder_baseline.butteraugli_iters = 0;
1670        let bytes_baseline = encoder_baseline
1671            .encode(width, height, &linear_rgb, None)
1672            .expect("baseline encode failed")
1673            .data;
1674
1675        // Encode with 2 butteraugli loop iterations
1676        let mut encoder_loop = VarDctEncoder::new(2.0);
1677        encoder_loop.butteraugli_iters = 2;
1678        let bytes_loop = encoder_loop
1679            .encode(width, height, &linear_rgb, None)
1680            .expect("butteraugli loop encode failed")
1681            .data;
1682
1683        // Both should produce valid JXL
1684        assert_eq!(bytes_baseline[0], 0xFF);
1685        assert_eq!(bytes_baseline[1], 0x0A);
1686        assert_eq!(bytes_loop[0], 0xFF);
1687        assert_eq!(bytes_loop[1], 0x0A);
1688
1689        // File sizes should differ (butteraugli loop changes quant field)
1690        eprintln!(
1691            "Baseline: {} bytes, Butteraugli loop (2 iters): {} bytes",
1692            bytes_baseline.len(),
1693            bytes_loop.len()
1694        );
1695
1696        // Verify the butteraugli-loop output decodes correctly
1697        let image = jxl_oxide::JxlImage::builder()
1698            .read(std::io::Cursor::new(&bytes_loop))
1699            .expect("jxl-oxide decode of butteraugli loop output failed");
1700        assert_eq!(image.width(), width as u32);
1701        assert_eq!(image.height(), height as u32);
1702
1703        let render = image
1704            .render_frame(0)
1705            .expect("jxl-oxide render of butteraugli loop output failed");
1706        let _pixels = render.image_all_channels();
1707        eprintln!("Butteraugli loop output decodes OK");
1708    }
1709}