jpegli/
decode.rs

1//! JPEG decoder implementation.
2//!
3//! This module provides the main decoder interface for reading JPEG images.
4//!
5//! # ICC Profile Support
6//!
7//! The decoder can extract and apply embedded ICC profiles, including XYB profiles
8//! used by jpegli. ICC profile support requires enabling `cms-lcms2` or `cms-moxcms` feature.
9//!
10//! ```ignore
11//! use jpegli::decode::Decoder;
12//!
13//! let decoder = Decoder::new().apply_icc(true);
14//! let decoded = decoder.decode(&jpeg_data)?;
15//! ```
16
17use crate::alloc::{
18    checked_size_2d, try_alloc_dct_blocks, validate_dimensions, DEFAULT_MAX_MEMORY,
19    DEFAULT_MAX_PIXELS,
20};
21use crate::color::{
22    gray_f32_to_gray_f32, gray_f32_to_gray_u8, gray_f32_to_rgb_f32, gray_f32_to_rgb_u8,
23    ycbcr_planes_f32_to_rgb_f32, ycbcr_planes_f32_to_rgb_u8,
24};
25use crate::consts::{
26    DCT_BLOCK_SIZE, DCT_SIZE, JPEG_NATURAL_ORDER, MARKER_APP0, MARKER_COM, MARKER_DHT, MARKER_DQT,
27    MARKER_DRI, MARKER_EOI, MARKER_SOF0, MARKER_SOF1, MARKER_SOF2, MARKER_SOI, MARKER_SOS,
28    MAX_COMPONENTS, MAX_HUFFMAN_TABLES, MAX_QUANT_TABLES,
29};
30use crate::entropy::EntropyDecoder;
31use crate::error::{Error, Result};
32use crate::huffman::HuffmanDecodeTable;
33#[cfg(any(feature = "cms-lcms2", feature = "cms-moxcms"))]
34use crate::icc::apply_icc_transform;
35use crate::icc::{extract_icc_profile, is_xyb_profile};
36use crate::idct::inverse_dct_8x8;
37use crate::quant::{dequantize_block, dequantize_block_with_bias, DequantBiasStats};
38use crate::types::{ColorSpace, Component, Dimensions, JpegMode, PixelFormat};
39
40/// Decoder configuration.
41#[derive(Debug, Clone)]
42pub struct DecoderConfig {
43    /// Output pixel format (None = use source format)
44    pub output_format: Option<PixelFormat>,
45    /// Whether to apply fancy upsampling
46    pub fancy_upsampling: bool,
47    /// Whether to apply block smoothing
48    pub block_smoothing: bool,
49    /// Whether to apply embedded ICC profile (requires cms feature)
50    pub apply_icc: bool,
51    /// Maximum pixels allowed (for DoS protection).
52    /// Default is 100 megapixels. Set to 0 for unlimited.
53    pub max_pixels: u64,
54    /// Maximum total memory for allocations (for DoS protection).
55    /// Default is 512 MB. Set to 0 for unlimited.
56    pub max_memory: usize,
57}
58
59impl Default for DecoderConfig {
60    fn default() -> Self {
61        Self {
62            output_format: None,
63            fancy_upsampling: false,
64            block_smoothing: false,
65            // Apply ICC by default when CMS is available
66            apply_icc: cfg!(any(feature = "cms-lcms2", feature = "cms-moxcms")),
67            max_pixels: DEFAULT_MAX_PIXELS,
68            max_memory: DEFAULT_MAX_MEMORY,
69        }
70    }
71}
72
73/// Information about a decoded JPEG.
74#[derive(Debug, Clone)]
75pub struct JpegInfo {
76    /// Image dimensions
77    pub dimensions: Dimensions,
78    /// Color space
79    pub color_space: ColorSpace,
80    /// Sample precision (8 or 12 bits)
81    pub precision: u8,
82    /// Number of components
83    pub num_components: u8,
84    /// Encoding mode
85    pub mode: JpegMode,
86    /// Whether an ICC profile is embedded
87    pub has_icc_profile: bool,
88    /// Whether the ICC profile is an XYB profile
89    pub is_xyb: bool,
90}
91
92/// JPEG decoder.
93pub struct Decoder {
94    config: DecoderConfig,
95}
96
97impl Decoder {
98    /// Creates a new decoder with default settings.
99    #[must_use]
100    pub fn new() -> Self {
101        Self {
102            config: DecoderConfig::default(),
103        }
104    }
105
106    /// Creates a decoder from configuration.
107    #[must_use]
108    pub fn from_config(config: DecoderConfig) -> Self {
109        Self { config }
110    }
111
112    /// Sets the output pixel format.
113    #[must_use]
114    pub fn output_format(mut self, format: PixelFormat) -> Self {
115        self.config.output_format = Some(format);
116        self
117    }
118
119    /// Enables fancy upsampling.
120    #[must_use]
121    pub fn fancy_upsampling(mut self, enable: bool) -> Self {
122        self.config.fancy_upsampling = enable;
123        self
124    }
125
126    /// Enables block smoothing.
127    #[must_use]
128    pub fn block_smoothing(mut self, enable: bool) -> Self {
129        self.config.block_smoothing = enable;
130        self
131    }
132
133    /// Enables ICC profile application.
134    ///
135    /// When enabled, embedded ICC profiles will be applied to convert
136    /// the image to sRGB. This is required for correct display of
137    /// XYB-encoded images.
138    ///
139    /// Note: Requires `cms-lcms2` or `cms-moxcms` feature to be enabled.
140    /// Without a CMS feature, this setting has no effect.
141    #[must_use]
142    pub fn apply_icc(mut self, enable: bool) -> Self {
143        self.config.apply_icc = enable;
144        self
145    }
146
147    /// Sets the maximum number of pixels allowed (for DoS protection).
148    ///
149    /// Default is 100 megapixels. Set to 0 for unlimited.
150    #[must_use]
151    pub fn max_pixels(mut self, pixels: u64) -> Self {
152        self.config.max_pixels = pixels;
153        self
154    }
155
156    /// Sets the maximum memory allowed for allocations during decoding.
157    ///
158    /// Default is 512 MB. Set to `usize::MAX` for unlimited.
159    /// This prevents memory exhaustion attacks from malicious images.
160    #[must_use]
161    pub fn max_memory(mut self, bytes: usize) -> Self {
162        self.config.max_memory = bytes;
163        self
164    }
165
166    /// Reads JPEG info without decoding.
167    pub fn read_info(&self, data: &[u8]) -> Result<JpegInfo> {
168        let mut parser = JpegParser::new(data, self.config.max_pixels)?;
169        parser.read_header()?;
170        Ok(parser.info())
171    }
172
173    /// Decodes a JPEG image.
174    pub fn decode(&self, data: &[u8]) -> Result<DecodedImage> {
175        let mut parser = JpegParser::new(data, self.config.max_pixels)?;
176        parser.decode()?;
177
178        let info = parser.info();
179        let output_format = self.config.output_format.unwrap_or(PixelFormat::Rgb);
180
181        // Convert to output format
182        // For XYB images, use simple dequantization so ICC profile works correctly
183        let mut pixels = parser.to_pixels(output_format, info.is_xyb)?;
184
185        // Apply ICC profile if enabled and present
186        #[cfg(any(feature = "cms-lcms2", feature = "cms-moxcms"))]
187        if self.config.apply_icc && output_format == PixelFormat::Rgb {
188            if let Some(ref icc_profile) = parser.icc_profile {
189                pixels = apply_icc_transform(
190                    &pixels,
191                    info.dimensions.width as usize,
192                    info.dimensions.height as usize,
193                    icc_profile,
194                )?;
195            }
196        }
197
198        Ok(DecodedImage {
199            width: info.dimensions.width,
200            height: info.dimensions.height,
201            format: output_format,
202            data: pixels,
203        })
204    }
205
206    /// Decodes a JPEG image to 32-bit floating point pixels.
207    ///
208    /// This preserves the full 12-bit internal precision of jpegli's decoder
209    /// without quantization to 8-bit. Values are normalized to range 0.0-1.0.
210    ///
211    /// # Example
212    ///
213    /// ```rust,ignore
214    /// use jpegli::decode::Decoder;
215    ///
216    /// let decoder = Decoder::new();
217    /// let image = decoder.decode_f32(&jpeg_data)?;
218    /// // image.data contains f32 values in range 0.0-1.0
219    /// ```
220    ///
221    /// Note: ICC profile application is not supported for f32 output.
222    /// If you need ICC profile transformation, decode to u8 first.
223    pub fn decode_f32(&self, data: &[u8]) -> Result<DecodedImageF32> {
224        let mut parser = JpegParser::new(data, self.config.max_pixels)?;
225        parser.decode()?;
226
227        let info = parser.info();
228        let output_format = self.config.output_format.unwrap_or(PixelFormat::Rgb);
229
230        // Convert to output format as f32
231        let pixels = parser.to_pixels_f32(output_format, info.is_xyb)?;
232
233        Ok(DecodedImageF32 {
234            width: info.dimensions.width,
235            height: info.dimensions.height,
236            format: output_format,
237            data: pixels,
238        })
239    }
240}
241
242impl Default for Decoder {
243    fn default() -> Self {
244        Self::new()
245    }
246}
247
248/// A decoded image with dimensions and pixel data.
249#[derive(Debug, Clone)]
250#[non_exhaustive]
251pub struct DecodedImage {
252    /// Image width in pixels
253    pub width: u32,
254    /// Image height in pixels
255    pub height: u32,
256    /// Pixel format of the data
257    pub format: PixelFormat,
258    /// Raw pixel data in the specified format
259    pub data: Vec<u8>,
260}
261
262impl DecodedImage {
263    /// Returns the image dimensions as a tuple (width, height).
264    #[must_use]
265    pub fn dimensions(&self) -> (u32, u32) {
266        (self.width, self.height)
267    }
268
269    /// Returns the number of bytes per pixel for this image's format.
270    #[must_use]
271    pub fn bytes_per_pixel(&self) -> usize {
272        self.format.bytes_per_pixel()
273    }
274
275    /// Returns the stride (bytes per row) of the image.
276    #[must_use]
277    pub fn stride(&self) -> usize {
278        self.width as usize * self.bytes_per_pixel()
279    }
280}
281
282/// A decoded image with 32-bit floating point pixel data.
283///
284/// This preserves the full 12-bit internal precision of jpegli's decoder
285/// without quantization to 8-bit. Values are in the range 0.0-1.0.
286///
287/// Use this format when you need:
288/// - Maximum precision for further image processing
289/// - HDR workflows
290/// - Scientific/medical imaging applications
291/// - Input to machine learning models
292#[derive(Debug, Clone)]
293#[non_exhaustive]
294pub struct DecodedImageF32 {
295    /// Image width in pixels
296    pub width: u32,
297    /// Image height in pixels
298    pub height: u32,
299    /// Pixel format of the data
300    pub format: PixelFormat,
301    /// Float pixel data in range 0.0-1.0
302    pub data: Vec<f32>,
303}
304
305impl DecodedImageF32 {
306    /// Returns the image dimensions as a tuple (width, height).
307    #[must_use]
308    pub fn dimensions(&self) -> (u32, u32) {
309        (self.width, self.height)
310    }
311
312    /// Returns the number of channels for this image's format.
313    #[must_use]
314    pub fn channels(&self) -> usize {
315        self.format.num_channels()
316    }
317
318    /// Returns the stride (floats per row) of the image.
319    #[must_use]
320    pub fn stride(&self) -> usize {
321        self.width as usize * self.channels()
322    }
323
324    /// Converts to 8-bit integer format.
325    ///
326    /// Values are scaled from 0.0-1.0 to 0-255 and clamped.
327    #[must_use]
328    pub fn to_u8(&self) -> DecodedImage {
329        let data: Vec<u8> = self
330            .data
331            .iter()
332            .map(|&v| (v * 255.0).round().clamp(0.0, 255.0) as u8)
333            .collect();
334
335        DecodedImage {
336            width: self.width,
337            height: self.height,
338            format: self.format,
339            data,
340        }
341    }
342
343    /// Converts to 16-bit integer format.
344    ///
345    /// Values are scaled from 0.0-1.0 to 0-65535 and clamped.
346    #[must_use]
347    pub fn to_u16(&self) -> Vec<u16> {
348        self.data
349            .iter()
350            .map(|&v| (v * 65535.0).round().clamp(0.0, 65535.0) as u16)
351            .collect()
352    }
353}
354
355/// Internal JPEG parser state.
356struct JpegParser<'a> {
357    data: &'a [u8],
358    position: usize,
359
360    // Frame info
361    width: u32,
362    height: u32,
363    precision: u8,
364    num_components: u8,
365    mode: JpegMode,
366
367    // Component info
368    components: [Component; MAX_COMPONENTS],
369
370    // Tables
371    quant_tables: [Option<[u16; DCT_BLOCK_SIZE]>; MAX_QUANT_TABLES],
372    dc_tables: [Option<HuffmanDecodeTable>; MAX_HUFFMAN_TABLES],
373    ac_tables: [Option<HuffmanDecodeTable>; MAX_HUFFMAN_TABLES],
374
375    // Restart
376    restart_interval: u16,
377
378    // Decoded coefficient data
379    coeffs: Vec<Vec<[i16; DCT_BLOCK_SIZE]>>, // Per component
380
381    // ICC profile (extracted from raw data, not during parsing)
382    icc_profile: Option<Vec<u8>>,
383
384    // Security limits
385    max_pixels: u64,
386}
387
388impl<'a> JpegParser<'a> {
389    fn new(data: &'a [u8], max_pixels: u64) -> Result<Self> {
390        // Check for SOI
391        if data.len() < 2 || data[0] != 0xFF || data[1] != MARKER_SOI {
392            return Err(Error::InvalidJpegData {
393                reason: "missing SOI marker",
394            });
395        }
396
397        // Extract ICC profile from raw data upfront
398        let icc_profile = extract_icc_profile(data);
399
400        Ok(Self {
401            data,
402            position: 2,
403            width: 0,
404            height: 0,
405            precision: 8,
406            num_components: 0,
407            mode: JpegMode::Baseline,
408            components: std::array::from_fn(|_| Component::default()),
409            quant_tables: [None, None, None, None],
410            dc_tables: [None, None, None, None],
411            ac_tables: [None, None, None, None],
412            restart_interval: 0,
413            coeffs: Vec::new(),
414            icc_profile,
415            max_pixels,
416        })
417    }
418
419    fn read_u8(&mut self) -> Result<u8> {
420        if self.position >= self.data.len() {
421            return Err(Error::UnexpectedEof {
422                context: "reading byte",
423            });
424        }
425        let byte = self.data[self.position];
426        self.position += 1;
427        Ok(byte)
428    }
429
430    fn read_u16(&mut self) -> Result<u16> {
431        let high = self.read_u8()? as u16;
432        let low = self.read_u8()? as u16;
433        Ok((high << 8) | low)
434    }
435
436    fn read_marker(&mut self) -> Result<u8> {
437        loop {
438            let byte = self.read_u8()?;
439            if byte != 0xFF {
440                continue;
441            }
442
443            let marker = self.read_u8()?;
444            if marker != 0x00 && marker != 0xFF {
445                return Ok(marker);
446            }
447        }
448    }
449
450    fn read_header(&mut self) -> Result<()> {
451        loop {
452            let marker = self.read_marker()?;
453
454            match marker {
455                MARKER_SOF0 | MARKER_SOF1 => {
456                    self.mode = JpegMode::Baseline;
457                    self.parse_frame_header()?;
458                    return Ok(());
459                }
460                MARKER_SOF2 => {
461                    self.mode = JpegMode::Progressive;
462                    self.parse_frame_header()?;
463                    return Ok(());
464                }
465                MARKER_DQT => self.parse_quant_table()?,
466                MARKER_DHT => self.parse_huffman_table()?,
467                MARKER_DRI => self.parse_restart_interval()?,
468                MARKER_APP0..=0xEF | MARKER_COM => self.skip_segment()?,
469                MARKER_EOI => {
470                    return Err(Error::InvalidJpegData {
471                        reason: "unexpected EOI before frame header",
472                    });
473                }
474                _ => self.skip_segment()?,
475            }
476        }
477    }
478
479    fn parse_frame_header(&mut self) -> Result<()> {
480        let length = self.read_u16()?;
481        if length < 8 {
482            return Err(Error::InvalidJpegData {
483                reason: "frame header too short",
484            });
485        }
486
487        self.precision = self.read_u8()?;
488        // Validate precision: must be 8 for baseline JPEG, 8 or 12 for extended
489        if self.precision != 8 && self.precision != 12 {
490            return Err(Error::InvalidJpegData {
491                reason: "invalid data precision (must be 8 or 12)",
492            });
493        }
494
495        self.height = self.read_u16()? as u32;
496        self.width = self.read_u16()? as u32;
497
498        // Validate dimensions against security limits
499        // max_pixels == 0 means unlimited
500        let effective_max = if self.max_pixels == 0 {
501            u64::MAX
502        } else {
503            self.max_pixels
504        };
505        validate_dimensions(self.width, self.height, effective_max)?;
506
507        self.num_components = self.read_u8()?;
508
509        // Validate num_components
510        if self.num_components == 0 {
511            return Err(Error::InvalidJpegData {
512                reason: "number of components is zero",
513            });
514        }
515        if self.num_components > MAX_COMPONENTS as u8 {
516            return Err(Error::UnsupportedFeature {
517                feature: "more than 4 components",
518            });
519        }
520
521        // Validate marker length matches expected size
522        let expected_length = 8 + 3 * self.num_components as u16;
523        if length != expected_length {
524            return Err(Error::InvalidJpegData {
525                reason: "SOF marker length mismatch",
526            });
527        }
528
529        for i in 0..self.num_components as usize {
530            self.components[i].id = self.read_u8()?;
531            let sampling = self.read_u8()?;
532            let h_samp = sampling >> 4;
533            let v_samp = sampling & 0x0F;
534
535            // Validate sampling factors are non-zero and <= 4
536            if h_samp == 0 || v_samp == 0 {
537                return Err(Error::InvalidJpegData {
538                    reason: "sampling factor is zero",
539                });
540            }
541            if h_samp > 4 || v_samp > 4 {
542                return Err(Error::InvalidJpegData {
543                    reason: "sampling factor exceeds maximum (4)",
544                });
545            }
546
547            self.components[i].h_samp_factor = h_samp;
548            self.components[i].v_samp_factor = v_samp;
549
550            let quant_idx = self.read_u8()?;
551            // Validate quant table index
552            if quant_idx as usize >= MAX_QUANT_TABLES {
553                return Err(Error::InvalidJpegData {
554                    reason: "quantization table index out of range",
555                });
556            }
557            self.components[i].quant_table_idx = quant_idx;
558        }
559
560        Ok(())
561    }
562
563    fn parse_quant_table(&mut self) -> Result<()> {
564        let mut length = self.read_u16()? as i32 - 2;
565
566        while length > 0 {
567            let info = self.read_u8()?;
568            let precision = info >> 4;
569            let table_idx = (info & 0x0F) as usize;
570
571            // Validate precision (0 = 8-bit, 1 = 16-bit)
572            if precision > 1 {
573                return Err(Error::InvalidQuantTable {
574                    table_idx: table_idx as u8,
575                    reason: "invalid precision (must be 0 or 1)",
576                });
577            }
578
579            if table_idx >= MAX_QUANT_TABLES {
580                return Err(Error::InvalidQuantTable {
581                    table_idx: table_idx as u8,
582                    reason: "table index out of range",
583                });
584            }
585
586            // Read values in zigzag order (as stored in JPEG)
587            let mut zigzag_values = [0u16; DCT_BLOCK_SIZE];
588
589            if precision == 0 {
590                // 8-bit values
591                for i in 0..DCT_BLOCK_SIZE {
592                    let val = self.read_u8()? as u16;
593                    if val == 0 {
594                        return Err(Error::InvalidQuantTable {
595                            table_idx: table_idx as u8,
596                            reason: "quantization value is zero",
597                        });
598                    }
599                    zigzag_values[i] = val;
600                }
601                length -= 65;
602            } else {
603                // 16-bit values
604                for i in 0..DCT_BLOCK_SIZE {
605                    let val = self.read_u16()?;
606                    if val == 0 {
607                        return Err(Error::InvalidQuantTable {
608                            table_idx: table_idx as u8,
609                            reason: "quantization value is zero",
610                        });
611                    }
612                    zigzag_values[i] = val;
613                }
614                length -= 129;
615            }
616
617            // Validate DQT marker length consistency
618            if length < 0 {
619                return Err(Error::InvalidJpegData {
620                    reason: "DQT marker length mismatch",
621                });
622            }
623
624            // Convert from zigzag order to natural order for dequantization
625            let mut natural_values = [0u16; DCT_BLOCK_SIZE];
626            for i in 0..DCT_BLOCK_SIZE {
627                natural_values[JPEG_NATURAL_ORDER[i] as usize] = zigzag_values[i];
628            }
629
630            self.quant_tables[table_idx] = Some(natural_values);
631        }
632
633        Ok(())
634    }
635
636    fn parse_huffman_table(&mut self) -> Result<()> {
637        let mut length = self.read_u16()? as i32 - 2;
638
639        while length > 0 {
640            let info = self.read_u8()?;
641            let table_class = info >> 4; // 0 = DC, 1 = AC
642            let table_idx = (info & 0x0F) as usize;
643
644            // Validate table class (must be 0 for DC or 1 for AC)
645            if table_class > 1 {
646                return Err(Error::InvalidHuffmanTable {
647                    table_idx: table_idx as u8,
648                    reason: "invalid table class (must be 0 or 1)",
649                });
650            }
651
652            if table_idx >= MAX_HUFFMAN_TABLES {
653                return Err(Error::InvalidHuffmanTable {
654                    table_idx: table_idx as u8,
655                    reason: "table index out of range",
656                });
657            }
658
659            let mut bits = [0u8; 16];
660            for i in 0..16 {
661                bits[i] = self.read_u8()?;
662            }
663
664            let num_values: usize = bits.iter().map(|&b| b as usize).sum();
665            let mut values = vec![0u8; num_values];
666            for i in 0..num_values {
667                values[i] = self.read_u8()?;
668            }
669
670            length -= 17 + num_values as i32;
671
672            // Validate that we didn't read past the marker length
673            if length < 0 {
674                return Err(Error::InvalidJpegData {
675                    reason: "DHT marker length mismatch",
676                });
677            }
678
679            let table = HuffmanDecodeTable::from_bits_values(&bits, &values)?;
680
681            if table_class == 0 {
682                self.dc_tables[table_idx] = Some(table);
683            } else {
684                self.ac_tables[table_idx] = Some(table);
685            }
686        }
687
688        Ok(())
689    }
690
691    fn parse_restart_interval(&mut self) -> Result<()> {
692        let _length = self.read_u16()?;
693        self.restart_interval = self.read_u16()?;
694        Ok(())
695    }
696
697    fn skip_segment(&mut self) -> Result<()> {
698        let length = self.read_u16()? as usize;
699        if length < 2 {
700            return Err(Error::InvalidJpegData {
701                reason: "segment length too short",
702            });
703        }
704        self.position += length - 2;
705        Ok(())
706    }
707
708    fn decode(&mut self) -> Result<()> {
709        // First read header
710        self.position = 2; // Skip SOI
711        self.read_header()?;
712
713        // Continue parsing until we hit SOS
714        loop {
715            let marker = self.read_marker()?;
716
717            match marker {
718                MARKER_SOS => {
719                    self.parse_scan()?;
720                    // After scan, look for more markers
721                }
722                MARKER_DQT => self.parse_quant_table()?,
723                MARKER_DHT => self.parse_huffman_table()?,
724                MARKER_DRI => self.parse_restart_interval()?,
725                MARKER_EOI => break,
726                MARKER_APP0..=0xEF | MARKER_COM => self.skip_segment()?,
727                _ => self.skip_segment()?,
728            }
729        }
730
731        Ok(())
732    }
733
734    fn parse_scan(&mut self) -> Result<()> {
735        let _length = self.read_u16()?;
736        let num_components = self.read_u8()?;
737
738        // Validate num_components in scan
739        if num_components == 0 {
740            return Err(Error::InvalidJpegData {
741                reason: "SOS num_components is zero",
742            });
743        }
744        if num_components > self.num_components {
745            return Err(Error::InvalidJpegData {
746                reason: "SOS num_components exceeds frame components",
747            });
748        }
749        if num_components > MAX_COMPONENTS as u8 {
750            return Err(Error::InvalidJpegData {
751                reason: "SOS num_components too large",
752            });
753        }
754
755        let mut scan_components = Vec::with_capacity(num_components as usize);
756
757        for _ in 0..num_components {
758            let component_id = self.read_u8()?;
759            let tables = self.read_u8()?;
760            let dc_table = tables >> 4;
761            let ac_table = tables & 0x0F;
762
763            // Validate Huffman table indexes
764            if dc_table as usize >= MAX_HUFFMAN_TABLES {
765                return Err(Error::InvalidJpegData {
766                    reason: "SOS DC Huffman table index out of range",
767                });
768            }
769            if ac_table as usize >= MAX_HUFFMAN_TABLES {
770                return Err(Error::InvalidJpegData {
771                    reason: "SOS AC Huffman table index out of range",
772                });
773            }
774
775            // Find component index
776            let comp_idx = self.components[..self.num_components as usize]
777                .iter()
778                .position(|c| c.id == component_id)
779                .ok_or(Error::InvalidJpegData {
780                    reason: "unknown component in scan",
781                })?;
782
783            scan_components.push((comp_idx, dc_table, ac_table));
784        }
785
786        let ss = self.read_u8()?; // Spectral selection start
787        let se = self.read_u8()?; // Spectral selection end
788        let ah_al = self.read_u8()?;
789        let ah = ah_al >> 4;
790        let al = ah_al & 0x0F;
791
792        // Validate spectral selection (must be 0-63)
793        if ss > 63 {
794            return Err(Error::InvalidJpegData {
795                reason: "SOS Ss (spectral start) out of range",
796            });
797        }
798        if se > 63 {
799            return Err(Error::InvalidJpegData {
800                reason: "SOS Se (spectral end) out of range",
801            });
802        }
803
804        // Decode entropy-coded segment based on mode
805        if self.mode == JpegMode::Progressive {
806            self.decode_progressive_scan(&scan_components, ss, se, ah, al)?;
807        } else {
808            self.decode_scan(&scan_components)?;
809        }
810
811        Ok(())
812    }
813
814    fn decode_scan(&mut self, scan_components: &[(usize, u8, u8)]) -> Result<()> {
815        // Calculate max sampling factors to determine MCU structure
816        let mut max_h_samp = 1u8;
817        let mut max_v_samp = 1u8;
818        for i in 0..self.num_components as usize {
819            max_h_samp = max_h_samp.max(self.components[i].h_samp_factor);
820            max_v_samp = max_v_samp.max(self.components[i].v_samp_factor);
821        }
822
823        // MCU dimensions in pixels
824        let mcu_width = (max_h_samp as usize) * 8;
825        let mcu_height = (max_v_samp as usize) * 8;
826
827        // Number of MCUs
828        let mcu_cols = (self.width as usize + mcu_width - 1) / mcu_width;
829        let mcu_rows = (self.height as usize + mcu_height - 1) / mcu_height;
830
831        // Initialize coefficient storage - size depends on component's sampling factor
832        if self.coeffs.is_empty() {
833            for i in 0..self.num_components as usize {
834                let h_samp = self.components[i].h_samp_factor as usize;
835                let v_samp = self.components[i].v_samp_factor as usize;
836                let comp_blocks_h = checked_size_2d(mcu_cols, h_samp)?;
837                let comp_blocks_v = checked_size_2d(mcu_rows, v_samp)?;
838                let num_blocks = checked_size_2d(comp_blocks_h, comp_blocks_v)?;
839                self.coeffs.push(try_alloc_dct_blocks(
840                    num_blocks,
841                    "allocating DCT coefficients",
842                )?);
843            }
844        }
845
846        // Set up entropy decoder
847        let scan_data = &self.data[self.position..];
848        let mut decoder = EntropyDecoder::new(scan_data);
849
850        for (_comp_idx, dc_table, ac_table) in scan_components {
851            let dc_idx = (*dc_table as usize).min(MAX_HUFFMAN_TABLES - 1);
852            let ac_idx = (*ac_table as usize).min(MAX_HUFFMAN_TABLES - 1);
853            if let Some(table) = &self.dc_tables[dc_idx] {
854                decoder.set_dc_table(dc_idx, table.clone());
855            }
856            if let Some(table) = &self.ac_tables[ac_idx] {
857                decoder.set_ac_table(ac_idx, table.clone());
858            }
859        }
860
861        // Decode MCUs with proper interleaving
862        let mut mcu_count = 0u32;
863        let restart_interval = self.restart_interval as u32;
864        let mut next_restart_num = 0u8;
865
866        for mcu_y in 0..mcu_rows {
867            for mcu_x in 0..mcu_cols {
868                // Check for restart marker
869                if restart_interval > 0 && mcu_count > 0 && mcu_count % restart_interval == 0 {
870                    // Align to byte boundary (discard padding bits)
871                    decoder.align_to_byte();
872                    // Read and verify restart marker
873                    decoder.read_restart_marker(next_restart_num)?;
874                    // Update expected marker number (cycles 0-7)
875                    next_restart_num = (next_restart_num + 1) & 7;
876                    // Reset DC predictors
877                    decoder.reset_dc();
878                }
879
880                // For each component in the scan
881                for (comp_idx, dc_table, ac_table) in scan_components {
882                    let h_samp = self.components[*comp_idx].h_samp_factor as usize;
883                    let v_samp = self.components[*comp_idx].v_samp_factor as usize;
884                    let comp_blocks_h = mcu_cols * h_samp;
885
886                    // Decode all blocks for this component in this MCU
887                    for v in 0..v_samp {
888                        for h in 0..h_samp {
889                            let block_x = mcu_x * h_samp + h;
890                            let block_y = mcu_y * v_samp + v;
891                            let block_idx = block_y * comp_blocks_h + block_x;
892
893                            let coeffs = decoder.decode_block(
894                                *comp_idx,
895                                *dc_table as usize,
896                                *ac_table as usize,
897                            )?;
898                            self.coeffs[*comp_idx][block_idx] = coeffs;
899                        }
900                    }
901                }
902
903                mcu_count += 1;
904            }
905        }
906
907        self.position += decoder.position();
908        Ok(())
909    }
910
911    fn decode_progressive_scan(
912        &mut self,
913        scan_components: &[(usize, u8, u8)],
914        ss: u8,
915        se: u8,
916        ah: u8,
917        al: u8,
918    ) -> Result<()> {
919        // Calculate max sampling factors to determine MCU structure
920        let mut max_h_samp = 1u8;
921        let mut max_v_samp = 1u8;
922        for i in 0..self.num_components as usize {
923            max_h_samp = max_h_samp.max(self.components[i].h_samp_factor);
924            max_v_samp = max_v_samp.max(self.components[i].v_samp_factor);
925        }
926
927        // MCU dimensions in pixels
928        let mcu_width = (max_h_samp as usize) * 8;
929        let mcu_height = (max_v_samp as usize) * 8;
930
931        // Number of MCUs
932        let mcu_cols = (self.width as usize + mcu_width - 1) / mcu_width;
933        let mcu_rows = (self.height as usize + mcu_height - 1) / mcu_height;
934
935        // Initialize coefficient storage if not already done
936        if self.coeffs.is_empty() {
937            for i in 0..self.num_components as usize {
938                let h_samp = self.components[i].h_samp_factor as usize;
939                let v_samp = self.components[i].v_samp_factor as usize;
940                let comp_blocks_h = checked_size_2d(mcu_cols, h_samp)?;
941                let comp_blocks_v = checked_size_2d(mcu_rows, v_samp)?;
942                let num_blocks = checked_size_2d(comp_blocks_h, comp_blocks_v)?;
943                self.coeffs.push(try_alloc_dct_blocks(
944                    num_blocks,
945                    "allocating DCT coefficients",
946                )?);
947            }
948        }
949
950        // Set up entropy decoder
951        let scan_data = &self.data[self.position..];
952        let mut decoder = EntropyDecoder::new(scan_data);
953
954        for (_comp_idx, dc_table, ac_table) in scan_components {
955            let dc_idx = (*dc_table as usize).min(MAX_HUFFMAN_TABLES - 1);
956            let ac_idx = (*ac_table as usize).min(MAX_HUFFMAN_TABLES - 1);
957            if let Some(table) = &self.dc_tables[dc_idx] {
958                decoder.set_dc_table(dc_idx, table.clone());
959            }
960            if let Some(table) = &self.ac_tables[ac_idx] {
961                decoder.set_ac_table(ac_idx, table.clone());
962            }
963        }
964
965        // Determine scan type
966        let is_dc_scan = ss == 0 && se == 0;
967        let is_first_scan = ah == 0;
968
969        // EOB run tracking for AC scans
970        let mut eob_run = 0u16;
971
972        // Restart marker handling
973        let mut mcu_count = 0u32;
974        let restart_interval = self.restart_interval as u32;
975        let mut next_restart_num = 0u8;
976
977        if is_dc_scan {
978            // DC scan (interleaved or single component)
979            for mcu_y in 0..mcu_rows {
980                for mcu_x in 0..mcu_cols {
981                    // Check for restart marker
982                    if restart_interval > 0 && mcu_count > 0 && mcu_count % restart_interval == 0 {
983                        // Align to byte boundary (discard padding bits)
984                        decoder.align_to_byte();
985                        // Read and verify restart marker
986                        decoder.read_restart_marker(next_restart_num)?;
987                        // Update expected marker number (cycles 0-7)
988                        next_restart_num = (next_restart_num + 1) & 7;
989                        // Reset DC predictors
990                        decoder.reset_dc();
991                    }
992
993                    for (comp_idx, dc_table, _ac_table) in scan_components {
994                        let h_samp = self.components[*comp_idx].h_samp_factor as usize;
995                        let v_samp = self.components[*comp_idx].v_samp_factor as usize;
996                        let comp_blocks_h = mcu_cols * h_samp;
997
998                        for v in 0..v_samp {
999                            for h in 0..h_samp {
1000                                let block_x = mcu_x * h_samp + h;
1001                                let block_y = mcu_y * v_samp + v;
1002                                let block_idx = block_y * comp_blocks_h + block_x;
1003
1004                                if is_first_scan {
1005                                    // DC first scan
1006                                    let dc = decoder.decode_dc_first(
1007                                        *comp_idx,
1008                                        *dc_table as usize,
1009                                        al,
1010                                    )?;
1011                                    self.coeffs[*comp_idx][block_idx][0] = dc;
1012                                } else {
1013                                    // DC refinement scan
1014                                    let bit = decoder.decode_dc_refine(al)?;
1015                                    self.coeffs[*comp_idx][block_idx][0] |= bit;
1016                                }
1017                            }
1018                        }
1019                    }
1020
1021                    mcu_count += 1;
1022                }
1023            }
1024        } else {
1025            // AC scan (single component only for progressive)
1026            // Progressive AC scans can only have one component
1027            if scan_components.len() != 1 {
1028                return Err(Error::InvalidJpegData {
1029                    reason: "progressive AC scan must have single component",
1030                });
1031            }
1032
1033            let (comp_idx, _dc_table, ac_table) = scan_components[0];
1034            let h_samp = self.components[comp_idx].h_samp_factor as usize;
1035            let v_samp = self.components[comp_idx].v_samp_factor as usize;
1036            let comp_blocks_h = mcu_cols * h_samp;
1037
1038            // Reset MCU count and restart number for AC scan (each scan has its own restart sequence)
1039            mcu_count = 0;
1040            next_restart_num = 0;
1041
1042            for mcu_y in 0..mcu_rows {
1043                for mcu_x in 0..mcu_cols {
1044                    // Check for restart marker
1045                    if restart_interval > 0 && mcu_count > 0 && mcu_count % restart_interval == 0 {
1046                        // Align to byte boundary (discard padding bits)
1047                        decoder.align_to_byte();
1048                        // Read and verify restart marker
1049                        decoder.read_restart_marker(next_restart_num)?;
1050                        // Update expected marker number (cycles 0-7)
1051                        next_restart_num = (next_restart_num + 1) & 7;
1052                        // Reset DC predictors and EOB run
1053                        decoder.reset_dc();
1054                        eob_run = 0;
1055                    }
1056
1057                    for v in 0..v_samp {
1058                        for h in 0..h_samp {
1059                            let block_x = mcu_x * h_samp + h;
1060                            let block_y = mcu_y * v_samp + v;
1061                            let block_idx = block_y * comp_blocks_h + block_x;
1062
1063                            if is_first_scan {
1064                                // AC first scan
1065                                decoder.decode_ac_first(
1066                                    &mut self.coeffs[comp_idx][block_idx],
1067                                    ac_table as usize,
1068                                    ss,
1069                                    se,
1070                                    al,
1071                                    &mut eob_run,
1072                                )?;
1073                            } else {
1074                                // AC refinement scan
1075                                decoder.decode_ac_refine(
1076                                    &mut self.coeffs[comp_idx][block_idx],
1077                                    ac_table as usize,
1078                                    ss,
1079                                    se,
1080                                    al,
1081                                    &mut eob_run,
1082                                )?;
1083                            }
1084                        }
1085                    }
1086
1087                    mcu_count += 1;
1088                }
1089            }
1090        }
1091
1092        self.position += decoder.position();
1093        Ok(())
1094    }
1095
1096    fn info(&self) -> JpegInfo {
1097        let has_icc = self.icc_profile.is_some();
1098        let is_xyb = self.icc_profile.as_ref().is_some_and(|p| is_xyb_profile(p));
1099
1100        // Determine color space, considering XYB profile
1101        let color_space = if is_xyb {
1102            ColorSpace::Xyb
1103        } else {
1104            match self.num_components {
1105                1 => ColorSpace::Grayscale,
1106                3 => ColorSpace::YCbCr,
1107                4 => ColorSpace::Cmyk,
1108                _ => ColorSpace::Unknown,
1109            }
1110        };
1111
1112        JpegInfo {
1113            dimensions: Dimensions::new(self.width, self.height),
1114            color_space,
1115            precision: self.precision,
1116            num_components: self.num_components,
1117            mode: self.mode,
1118            has_icc_profile: has_icc,
1119            is_xyb,
1120        }
1121    }
1122
1123    fn to_pixels(&self, format: PixelFormat, is_xyb: bool) -> Result<Vec<u8>> {
1124        if self.coeffs.is_empty() {
1125            return Err(Error::InternalError {
1126                reason: "no decoded data",
1127            });
1128        }
1129
1130        let width = self.width as usize;
1131        let height = self.height as usize;
1132
1133        // Calculate max sampling factors
1134        let mut max_h_samp = 1u8;
1135        let mut max_v_samp = 1u8;
1136        for i in 0..self.num_components as usize {
1137            max_h_samp = max_h_samp.max(self.components[i].h_samp_factor);
1138            max_v_samp = max_v_samp.max(self.components[i].v_samp_factor);
1139        }
1140
1141        // MCU dimensions
1142        let mcu_width = (max_h_samp as usize) * 8;
1143        let mcu_height = (max_v_samp as usize) * 8;
1144        let mcu_cols = (width + mcu_width - 1) / mcu_width;
1145        let mcu_rows = (height + mcu_height - 1) / mcu_height;
1146
1147        // Pre-compute component info for efficiency
1148        struct CompInfo {
1149            quant_idx: usize,
1150            h_samp: usize,
1151            v_samp: usize,
1152            comp_blocks_h: usize,
1153            comp_blocks_v: usize,
1154            comp_width: usize,
1155            comp_height: usize,
1156            is_full_res: bool,
1157        }
1158
1159        let mut comp_infos: Vec<CompInfo> = Vec::new();
1160        for comp_idx in 0..self.num_components as usize {
1161            let h_samp = self.components[comp_idx].h_samp_factor as usize;
1162            let v_samp = self.components[comp_idx].v_samp_factor as usize;
1163            let comp_blocks_h = mcu_cols * h_samp;
1164            let comp_blocks_v = mcu_rows * v_samp;
1165            let comp_width = checked_size_2d(comp_blocks_h, 8)?;
1166            let comp_height = checked_size_2d(comp_blocks_v, 8)?;
1167            comp_infos.push(CompInfo {
1168                quant_idx: self.components[comp_idx].quant_table_idx as usize,
1169                h_samp,
1170                v_samp,
1171                comp_blocks_h,
1172                comp_blocks_v,
1173                comp_width,
1174                comp_height,
1175                is_full_res: h_samp == max_h_samp as usize && v_samp == max_v_samp as usize,
1176            });
1177        }
1178
1179        // Initialize bias stats and biases (C++ initializes to 0 via memset)
1180        let mut bias_stats = DequantBiasStats::new(self.num_components as usize);
1181        let mut component_biases: Vec<[f32; DCT_BLOCK_SIZE]> =
1182            vec![[0.0f32; DCT_BLOCK_SIZE]; self.num_components as usize];
1183
1184        // Allocate component planes as f32 (C++ jpegli keeps f32 until final output)
1185        let mut comp_planes_f32: Vec<Vec<f32>> = Vec::new();
1186        for info in &comp_infos {
1187            let comp_plane_size = checked_size_2d(info.comp_width, info.comp_height)?;
1188            comp_planes_f32.push(vec![0.0f32; comp_plane_size]);
1189        }
1190
1191        // Process MCU row by MCU row (matching C++ incremental bias recomputation)
1192        for imcu_row in 0..mcu_rows {
1193            // For each component in this MCU row
1194            for comp_idx in 0..self.num_components as usize {
1195                let info = &comp_infos[comp_idx];
1196                let quant =
1197                    self.quant_tables[info.quant_idx]
1198                        .as_ref()
1199                        .ok_or(Error::InternalError {
1200                            reason: "missing quantization table",
1201                        })?;
1202
1203                // Phase 1: Gather stats for full-res components
1204                if info.is_full_res {
1205                    for iy in 0..info.v_samp {
1206                        let by = imcu_row * info.v_samp + iy;
1207                        if by >= info.comp_blocks_v {
1208                            continue;
1209                        }
1210                        for bx in 0..info.comp_blocks_h {
1211                            let block_idx = by * info.comp_blocks_h + bx;
1212                            if block_idx >= self.coeffs[comp_idx].len() {
1213                                continue;
1214                            }
1215                            let coeffs = &self.coeffs[comp_idx][block_idx];
1216                            let mut natural_coeffs = [0i16; DCT_BLOCK_SIZE];
1217                            for (i, &zi) in JPEG_NATURAL_ORDER[..DCT_BLOCK_SIZE].iter().enumerate()
1218                            {
1219                                natural_coeffs[zi as usize] = coeffs[i];
1220                            }
1221                            bias_stats.gather_block(comp_idx, &natural_coeffs);
1222                        }
1223                    }
1224
1225                    // Phase 2: Recompute biases every 4 MCU rows (matching C++ behavior)
1226                    if imcu_row % 4 == 3 {
1227                        component_biases[comp_idx] = bias_stats.compute_biases(comp_idx);
1228                    }
1229                }
1230
1231                // Phase 3: IDCT for this component in this MCU row
1232                // Store as f32 (C++ jpegli keeps f32 until final output for precision)
1233                let biases = &component_biases[comp_idx];
1234                let comp_plane_f32 = &mut comp_planes_f32[comp_idx];
1235
1236                for iy in 0..info.v_samp {
1237                    let by = imcu_row * info.v_samp + iy;
1238                    if by >= info.comp_blocks_v {
1239                        continue;
1240                    }
1241
1242                    // Pre-compute base y position and check row bounds once
1243                    let base_py = by * DCT_SIZE;
1244                    let rows_to_copy = DCT_SIZE.min(info.comp_height.saturating_sub(base_py));
1245
1246                    for bx in 0..info.comp_blocks_h {
1247                        let block_idx = by * info.comp_blocks_h + bx;
1248                        if block_idx >= self.coeffs[comp_idx].len() {
1249                            continue;
1250                        }
1251                        let coeffs = &self.coeffs[comp_idx][block_idx];
1252
1253                        // Zigzag reorder
1254                        let mut natural_coeffs = [0i16; DCT_BLOCK_SIZE];
1255                        for (i, &zi) in JPEG_NATURAL_ORDER[..DCT_BLOCK_SIZE].iter().enumerate() {
1256                            natural_coeffs[zi as usize] = coeffs[i];
1257                        }
1258
1259                        // Dequantize and IDCT
1260                        let dequant = if is_xyb {
1261                            dequantize_block(&natural_coeffs, quant)
1262                        } else {
1263                            dequantize_block_with_bias(&natural_coeffs, quant, biases)
1264                        };
1265                        let pixels = inverse_dct_8x8(&dequant);
1266
1267                        // Store pixels - use row-based copy for efficiency
1268                        let base_px = bx * DCT_SIZE;
1269                        let cols_to_copy = DCT_SIZE.min(info.comp_width.saturating_sub(base_px));
1270
1271                        if cols_to_copy == DCT_SIZE {
1272                            // Fast path: full 8-pixel row copy
1273                            for y in 0..rows_to_copy {
1274                                let dst_offset = (base_py + y) * info.comp_width + base_px;
1275                                let src_offset = y * DCT_SIZE;
1276                                comp_plane_f32[dst_offset..dst_offset + DCT_SIZE]
1277                                    .copy_from_slice(&pixels[src_offset..src_offset + DCT_SIZE]);
1278                            }
1279                        } else {
1280                            // Slow path: partial row copy (edge blocks)
1281                            for y in 0..rows_to_copy {
1282                                for x in 0..cols_to_copy {
1283                                    comp_plane_f32[(base_py + y) * info.comp_width + base_px + x] =
1284                                        pixels[y * DCT_SIZE + x];
1285                                }
1286                            }
1287                        }
1288                    }
1289                }
1290            }
1291        }
1292
1293        // Upsample if needed - keep as f32 for precision
1294        let output_size = checked_size_2d(width, height)?;
1295        let mut planes_f32: Vec<Vec<f32>> = Vec::new();
1296
1297        for comp_idx in 0..self.num_components as usize {
1298            let info = &comp_infos[comp_idx];
1299            let comp_plane_f32 = &comp_planes_f32[comp_idx];
1300
1301            let plane_f32 =
1302                if info.h_samp < max_h_samp as usize || info.v_samp < max_v_samp as usize {
1303                    let scale_x = max_h_samp as usize / info.h_samp;
1304                    let scale_y = max_v_samp as usize / info.v_samp;
1305                    let mut upsampled = vec![0.0f32; output_size];
1306                    for py in 0..height {
1307                        for px in 0..width {
1308                            let sx = (px / scale_x).min(info.comp_width - 1);
1309                            let sy = (py / scale_y).min(info.comp_height - 1);
1310                            upsampled[py * width + px] = comp_plane_f32[sy * info.comp_width + sx];
1311                        }
1312                    }
1313                    upsampled
1314                } else {
1315                    // Full resolution - just clip to image dimensions
1316                    let mut plane = vec![0.0f32; output_size];
1317                    for py in 0..height {
1318                        for px in 0..width {
1319                            plane[py * width + px] = comp_plane_f32[py * info.comp_width + px];
1320                        }
1321                    }
1322                    plane
1323                };
1324
1325            planes_f32.push(plane_f32);
1326        }
1327
1328        // Convert to output format using batch conversion functions
1329        match (self.num_components, format) {
1330            (1, PixelFormat::Gray) => {
1331                // Grayscale: level shift and convert to u8
1332                let mut output = vec![0u8; output_size];
1333                gray_f32_to_gray_u8(&planes_f32[0], &mut output);
1334                Ok(output)
1335            }
1336            (1, PixelFormat::Rgb) => {
1337                let rgb_size =
1338                    checked_size_2d(width, height).and_then(|s| checked_size_2d(s, 3))?;
1339                let mut rgb = vec![0u8; rgb_size];
1340                gray_f32_to_rgb_u8(&planes_f32[0], &mut rgb);
1341                Ok(rgb)
1342            }
1343            (3, PixelFormat::Rgb) => {
1344                let rgb_size =
1345                    checked_size_2d(width, height).and_then(|s| checked_size_2d(s, 3))?;
1346                let mut rgb = vec![0u8; rgb_size];
1347
1348                if is_xyb {
1349                    // XYB mode: Output raw level-shifted values, NO YCbCr→RGB conversion.
1350                    // The XYB values are stored in YCbCr positions but are NOT YCbCr.
1351                    // The ICC profile transforms these directly to sRGB.
1352                    for i in 0..output_size {
1353                        rgb[i * 3] = (planes_f32[0][i] + 128.0).clamp(0.0, 255.0) as u8;
1354                        rgb[i * 3 + 1] = (planes_f32[1][i] + 128.0).clamp(0.0, 255.0) as u8;
1355                        rgb[i * 3 + 2] = (planes_f32[2][i] + 128.0).clamp(0.0, 255.0) as u8;
1356                    }
1357                } else {
1358                    // YCbCr to RGB conversion using batch function
1359                    ycbcr_planes_f32_to_rgb_u8(
1360                        &planes_f32[0],
1361                        &planes_f32[1],
1362                        &planes_f32[2],
1363                        &mut rgb,
1364                    );
1365                }
1366                Ok(rgb)
1367            }
1368            _ => Err(Error::UnsupportedFeature {
1369                feature: "unsupported color conversion",
1370            }),
1371        }
1372    }
1373
1374    /// Convert decoded coefficients to f32 pixels.
1375    /// Values are normalized to range 0.0-1.0.
1376    fn to_pixels_f32(&self, format: PixelFormat, is_xyb: bool) -> Result<Vec<f32>> {
1377        if self.coeffs.is_empty() {
1378            return Err(Error::InternalError {
1379                reason: "no decoded data",
1380            });
1381        }
1382
1383        let width = self.width as usize;
1384        let height = self.height as usize;
1385
1386        // Calculate max sampling factors
1387        let mut max_h_samp = 1u8;
1388        let mut max_v_samp = 1u8;
1389        for i in 0..self.num_components as usize {
1390            max_h_samp = max_h_samp.max(self.components[i].h_samp_factor);
1391            max_v_samp = max_v_samp.max(self.components[i].v_samp_factor);
1392        }
1393
1394        // MCU dimensions
1395        let mcu_width = (max_h_samp as usize) * 8;
1396        let mcu_height = (max_v_samp as usize) * 8;
1397        let mcu_cols = (width + mcu_width - 1) / mcu_width;
1398        let mcu_rows = (height + mcu_height - 1) / mcu_height;
1399
1400        // Pre-compute component info
1401        struct CompInfo {
1402            quant_idx: usize,
1403            h_samp: usize,
1404            v_samp: usize,
1405            comp_blocks_h: usize,
1406            comp_blocks_v: usize,
1407            comp_width: usize,
1408            comp_height: usize,
1409            is_full_res: bool,
1410        }
1411
1412        let mut comp_infos: Vec<CompInfo> = Vec::new();
1413        for comp_idx in 0..self.num_components as usize {
1414            let h_samp = self.components[comp_idx].h_samp_factor as usize;
1415            let v_samp = self.components[comp_idx].v_samp_factor as usize;
1416            let comp_blocks_h = mcu_cols * h_samp;
1417            let comp_blocks_v = mcu_rows * v_samp;
1418            let comp_width = checked_size_2d(comp_blocks_h, 8)?;
1419            let comp_height = checked_size_2d(comp_blocks_v, 8)?;
1420            comp_infos.push(CompInfo {
1421                quant_idx: self.components[comp_idx].quant_table_idx as usize,
1422                h_samp,
1423                v_samp,
1424                comp_blocks_h,
1425                comp_blocks_v,
1426                comp_width,
1427                comp_height,
1428                is_full_res: h_samp == max_h_samp as usize && v_samp == max_v_samp as usize,
1429            });
1430        }
1431
1432        // Initialize bias stats and biases
1433        let mut bias_stats = DequantBiasStats::new(self.num_components as usize);
1434        let mut component_biases: Vec<[f32; DCT_BLOCK_SIZE]> =
1435            vec![[0.0f32; DCT_BLOCK_SIZE]; self.num_components as usize];
1436
1437        // Allocate component planes as f32
1438        let mut comp_planes_f32: Vec<Vec<f32>> = Vec::new();
1439        for info in &comp_infos {
1440            let comp_plane_size = checked_size_2d(info.comp_width, info.comp_height)?;
1441            comp_planes_f32.push(vec![0.0f32; comp_plane_size]);
1442        }
1443
1444        // Process MCU row by MCU row
1445        for imcu_row in 0..mcu_rows {
1446            for comp_idx in 0..self.num_components as usize {
1447                let info = &comp_infos[comp_idx];
1448                let quant =
1449                    self.quant_tables[info.quant_idx]
1450                        .as_ref()
1451                        .ok_or(Error::InternalError {
1452                            reason: "missing quantization table",
1453                        })?;
1454
1455                // Gather stats for full-res components
1456                if info.is_full_res {
1457                    for iy in 0..info.v_samp {
1458                        let by = imcu_row * info.v_samp + iy;
1459                        if by >= info.comp_blocks_v {
1460                            continue;
1461                        }
1462                        for bx in 0..info.comp_blocks_h {
1463                            let block_idx = by * info.comp_blocks_h + bx;
1464                            if block_idx >= self.coeffs[comp_idx].len() {
1465                                continue;
1466                            }
1467                            let coeffs = &self.coeffs[comp_idx][block_idx];
1468                            let mut natural_coeffs = [0i16; DCT_BLOCK_SIZE];
1469                            for (i, &zi) in JPEG_NATURAL_ORDER[..DCT_BLOCK_SIZE].iter().enumerate()
1470                            {
1471                                natural_coeffs[zi as usize] = coeffs[i];
1472                            }
1473                            bias_stats.gather_block(comp_idx, &natural_coeffs);
1474                        }
1475                    }
1476
1477                    // Recompute biases every 4 MCU rows
1478                    if imcu_row % 4 == 3 {
1479                        component_biases[comp_idx] = bias_stats.compute_biases(comp_idx);
1480                    }
1481                }
1482
1483                // IDCT for this component
1484                let biases = &component_biases[comp_idx];
1485                let comp_plane_f32 = &mut comp_planes_f32[comp_idx];
1486
1487                for iy in 0..info.v_samp {
1488                    let by = imcu_row * info.v_samp + iy;
1489                    if by >= info.comp_blocks_v {
1490                        continue;
1491                    }
1492
1493                    for bx in 0..info.comp_blocks_h {
1494                        let block_idx = by * info.comp_blocks_h + bx;
1495                        if block_idx >= self.coeffs[comp_idx].len() {
1496                            continue;
1497                        }
1498                        let coeffs = &self.coeffs[comp_idx][block_idx];
1499
1500                        let mut natural_coeffs = [0i16; DCT_BLOCK_SIZE];
1501                        for (i, &zi) in JPEG_NATURAL_ORDER[..DCT_BLOCK_SIZE].iter().enumerate() {
1502                            natural_coeffs[zi as usize] = coeffs[i];
1503                        }
1504
1505                        let dequant = if is_xyb {
1506                            dequantize_block(&natural_coeffs, quant)
1507                        } else {
1508                            dequantize_block_with_bias(&natural_coeffs, quant, biases)
1509                        };
1510                        let pixels = inverse_dct_8x8(&dequant);
1511
1512                        for y in 0..DCT_SIZE {
1513                            for x in 0..DCT_SIZE {
1514                                let px = bx * DCT_SIZE + x;
1515                                let py = by * DCT_SIZE + y;
1516                                if px < info.comp_width && py < info.comp_height {
1517                                    comp_plane_f32[py * info.comp_width + px] =
1518                                        pixels[y * DCT_SIZE + x];
1519                                }
1520                            }
1521                        }
1522                    }
1523                }
1524            }
1525        }
1526
1527        // Upsample if needed
1528        let output_size = checked_size_2d(width, height)?;
1529        let mut planes_f32: Vec<Vec<f32>> = Vec::new();
1530
1531        for comp_idx in 0..self.num_components as usize {
1532            let info = &comp_infos[comp_idx];
1533            let comp_plane_f32 = &comp_planes_f32[comp_idx];
1534
1535            let plane_f32 =
1536                if info.h_samp < max_h_samp as usize || info.v_samp < max_v_samp as usize {
1537                    let scale_x = max_h_samp as usize / info.h_samp;
1538                    let scale_y = max_v_samp as usize / info.v_samp;
1539                    let mut upsampled = vec![0.0f32; output_size];
1540                    for py in 0..height {
1541                        for px in 0..width {
1542                            let sx = (px / scale_x).min(info.comp_width - 1);
1543                            let sy = (py / scale_y).min(info.comp_height - 1);
1544                            upsampled[py * width + px] = comp_plane_f32[sy * info.comp_width + sx];
1545                        }
1546                    }
1547                    upsampled
1548                } else {
1549                    let mut plane = vec![0.0f32; output_size];
1550                    for py in 0..height {
1551                        for px in 0..width {
1552                            plane[py * width + px] = comp_plane_f32[py * info.comp_width + px];
1553                        }
1554                    }
1555                    plane
1556                };
1557
1558            planes_f32.push(plane_f32);
1559        }
1560
1561        // Convert to output format as f32 (values normalized to 0.0-1.0)
1562        match (self.num_components, format) {
1563            (1, PixelFormat::Gray) => {
1564                // Grayscale: level shift and normalize to 0.0-1.0
1565                let mut output = vec![0.0f32; output_size];
1566                gray_f32_to_gray_f32(&planes_f32[0], &mut output);
1567                Ok(output)
1568            }
1569            (1, PixelFormat::Rgb) => {
1570                let rgb_size =
1571                    checked_size_2d(width, height).and_then(|s| checked_size_2d(s, 3))?;
1572                let mut rgb = vec![0.0f32; rgb_size];
1573                gray_f32_to_rgb_f32(&planes_f32[0], &mut rgb);
1574                Ok(rgb)
1575            }
1576            (3, PixelFormat::Rgb) => {
1577                let rgb_size =
1578                    checked_size_2d(width, height).and_then(|s| checked_size_2d(s, 3))?;
1579                let mut rgb = vec![0.0f32; rgb_size];
1580
1581                if is_xyb {
1582                    // XYB mode: Output raw level-shifted values, normalized to 0.0-1.0
1583                    for i in 0..output_size {
1584                        rgb[i * 3] = ((planes_f32[0][i] + 128.0) / 255.0).clamp(0.0, 1.0);
1585                        rgb[i * 3 + 1] = ((planes_f32[1][i] + 128.0) / 255.0).clamp(0.0, 1.0);
1586                        rgb[i * 3 + 2] = ((planes_f32[2][i] + 128.0) / 255.0).clamp(0.0, 1.0);
1587                    }
1588                } else {
1589                    // YCbCr to RGB conversion using batch function
1590                    ycbcr_planes_f32_to_rgb_f32(
1591                        &planes_f32[0],
1592                        &planes_f32[1],
1593                        &planes_f32[2],
1594                        &mut rgb,
1595                    );
1596                }
1597                Ok(rgb)
1598            }
1599            _ => Err(Error::UnsupportedFeature {
1600                feature: "unsupported color conversion",
1601            }),
1602        }
1603    }
1604}
1605
1606#[cfg(test)]
1607mod tests {
1608    use super::*;
1609    use crate::encode::Encoder;
1610    use crate::quant::Quality;
1611
1612    #[test]
1613    fn test_decoder_creation() {
1614        let decoder = Decoder::new()
1615            .output_format(PixelFormat::Rgb)
1616            .fancy_upsampling(true);
1617
1618        assert_eq!(decoder.config.output_format, Some(PixelFormat::Rgb));
1619        assert!(decoder.config.fancy_upsampling);
1620    }
1621
1622    #[test]
1623    fn test_encode_decode_roundtrip_gray() {
1624        // Create a simple 8x8 grayscale image
1625        let width = 8;
1626        let height = 8;
1627        let mut input = vec![0u8; width * height];
1628        for y in 0..height {
1629            for x in 0..width {
1630                input[y * width + x] = ((x + y) * 16) as u8;
1631            }
1632        }
1633
1634        // Encode
1635        let encoder = Encoder::new()
1636            .width(width as u32)
1637            .height(height as u32)
1638            .pixel_format(PixelFormat::Gray)
1639            .quality(Quality::from_quality(95.0));
1640
1641        let jpeg = encoder.encode(&input).expect("encoding should succeed");
1642
1643        // Verify JPEG structure
1644        assert_eq!(jpeg[0], 0xFF);
1645        assert_eq!(jpeg[1], 0xD8); // SOI
1646        assert_eq!(jpeg[jpeg.len() - 2], 0xFF);
1647        assert_eq!(jpeg[jpeg.len() - 1], 0xD9); // EOI
1648
1649        // Decode
1650        let decoder = Decoder::new().output_format(PixelFormat::Gray);
1651        let decoded = decoder.decode(&jpeg).expect("decoding should succeed");
1652
1653        assert_eq!(decoded.width, width as u32);
1654        assert_eq!(decoded.height, height as u32);
1655        assert_eq!(decoded.data.len(), width * height);
1656
1657        // Check pixel values are reasonably close (JPEG is lossy)
1658        let mut max_diff = 0i32;
1659        for i in 0..input.len() {
1660            let diff = (input[i] as i32 - decoded.data[i] as i32).abs();
1661            max_diff = max_diff.max(diff);
1662        }
1663        // At quality 95, differences should be small
1664        assert!(max_diff < 20, "max_diff {} too large", max_diff);
1665    }
1666
1667    #[test]
1668    fn test_encode_decode_roundtrip_rgb() {
1669        // Create a simple 16x16 RGB image
1670        let width = 16;
1671        let height = 16;
1672        let mut input = vec![0u8; width * height * 3];
1673        for y in 0..height {
1674            for x in 0..width {
1675                let idx = (y * width + x) * 3;
1676                input[idx] = (x * 16) as u8; // R
1677                input[idx + 1] = (y * 16) as u8; // G
1678                input[idx + 2] = 128; // B
1679            }
1680        }
1681
1682        // Encode
1683        let encoder = Encoder::new()
1684            .width(width as u32)
1685            .height(height as u32)
1686            .pixel_format(PixelFormat::Rgb)
1687            .quality(Quality::from_quality(95.0));
1688
1689        let jpeg = encoder.encode(&input).expect("encoding should succeed");
1690
1691        // Decode
1692        let decoder = Decoder::new().output_format(PixelFormat::Rgb);
1693        let decoded = decoder.decode(&jpeg).expect("decoding should succeed");
1694
1695        assert_eq!(decoded.width, width as u32);
1696        assert_eq!(decoded.height, height as u32);
1697        assert_eq!(decoded.data.len(), width * height * 3);
1698
1699        // Check pixel values are reasonably close
1700        let mut max_diff = 0i32;
1701        for i in 0..input.len() {
1702            let diff = (input[i] as i32 - decoded.data[i] as i32).abs();
1703            max_diff = max_diff.max(diff);
1704        }
1705        // At quality 95, differences should be small
1706        assert!(max_diff < 30, "max_diff {} too large", max_diff);
1707    }
1708
1709    #[test]
1710    fn test_decode_f32_roundtrip() {
1711        // Create a simple 16x16 RGB image
1712        let width = 16;
1713        let height = 16;
1714        let mut input = vec![0u8; width * height * 3];
1715        for y in 0..height {
1716            for x in 0..width {
1717                let idx = (y * width + x) * 3;
1718                input[idx] = (x * 16) as u8; // R
1719                input[idx + 1] = (y * 16) as u8; // G
1720                input[idx + 2] = 128; // B
1721            }
1722        }
1723
1724        // Encode
1725        let encoder = Encoder::new()
1726            .width(width as u32)
1727            .height(height as u32)
1728            .pixel_format(PixelFormat::Rgb)
1729            .quality(Quality::from_quality(95.0));
1730
1731        let jpeg = encoder.encode(&input).expect("encoding should succeed");
1732
1733        // Decode to f32
1734        let decoder = Decoder::new().output_format(PixelFormat::Rgb);
1735        let decoded_f32 = decoder
1736            .decode_f32(&jpeg)
1737            .expect("f32 decoding should succeed");
1738
1739        assert_eq!(decoded_f32.width, width as u32);
1740        assert_eq!(decoded_f32.height, height as u32);
1741        assert_eq!(decoded_f32.data.len(), width * height * 3);
1742
1743        // Verify values are in 0.0-1.0 range
1744        for &v in &decoded_f32.data {
1745            assert!(v >= 0.0 && v <= 1.0, "f32 value {} out of range", v);
1746        }
1747
1748        // Compare with u8 decode - converted f32 should match
1749        let decoded_u8 = decoder.decode(&jpeg).expect("u8 decoding should succeed");
1750        let converted_u8 = decoded_f32.to_u8();
1751
1752        // Values should be very close (within 1 due to rounding)
1753        let mut max_diff = 0i32;
1754        for i in 0..decoded_u8.data.len() {
1755            let diff = (decoded_u8.data[i] as i32 - converted_u8.data[i] as i32).abs();
1756            max_diff = max_diff.max(diff);
1757        }
1758        assert!(
1759            max_diff <= 1,
1760            "f32→u8 conversion differs by {} from direct u8",
1761            max_diff
1762        );
1763    }
1764
1765    #[test]
1766    fn test_decode_f32_precision() {
1767        // Create a gradient image to test precision
1768        let width = 64;
1769        let height = 64;
1770        let mut input = vec![0u8; width * height * 3];
1771        for y in 0..height {
1772            for x in 0..width {
1773                let idx = (y * width + x) * 3;
1774                // Create a smooth gradient
1775                let val = ((x + y) * 2) as u8;
1776                input[idx] = val;
1777                input[idx + 1] = val;
1778                input[idx + 2] = val;
1779            }
1780        }
1781
1782        // Encode at high quality
1783        let encoder = Encoder::new()
1784            .width(width as u32)
1785            .height(height as u32)
1786            .pixel_format(PixelFormat::Rgb)
1787            .quality(Quality::from_quality(98.0));
1788
1789        let jpeg = encoder.encode(&input).expect("encoding should succeed");
1790
1791        // Decode to f32
1792        let decoder = Decoder::new().output_format(PixelFormat::Rgb);
1793        let decoded_f32 = decoder
1794            .decode_f32(&jpeg)
1795            .expect("f32 decoding should succeed");
1796
1797        // Check that f32 values show more precision than just u8/255
1798        // by verifying we have non-quantized intermediate values
1799        let mut found_fractional = false;
1800        for &v in &decoded_f32.data {
1801            let scaled = v * 255.0;
1802            let frac = scaled - scaled.round();
1803            if frac.abs() > 0.001 && frac.abs() < 0.999 {
1804                found_fractional = true;
1805                break;
1806            }
1807        }
1808        // f32 should preserve sub-integer precision
1809        assert!(
1810            found_fractional,
1811            "f32 output should have fractional precision"
1812        );
1813    }
1814}