jpegli/
decode.rs

1//! JPEG decoder implementation.
2//!
3//! This module provides the main decoder interface for reading JPEG images.
4//!
5//! # ICC Profile Support
6//!
//! The decoder can extract and apply embedded ICC profiles, including the XYB profiles
//! used by jpegli. ICC profile support requires enabling the `cms-lcms2` or `cms-moxcms` feature.
9//!
10//! ```ignore
11//! use jpegli::decode::Decoder;
12//!
13//! let decoder = Decoder::new().apply_icc(true);
14//! let decoded = decoder.decode(&jpeg_data)?;
15//! ```
16
17use crate::alloc::{
18    checked_size_2d, try_alloc_dct_blocks, try_alloc_zeroed, validate_dimensions,
19    DEFAULT_MAX_MEMORY, DEFAULT_MAX_PIXELS,
20};
21use crate::consts::{
22    DCT_BLOCK_SIZE, DCT_SIZE, JPEG_NATURAL_ORDER, MARKER_APP0, MARKER_COM, MARKER_DHT, MARKER_DQT,
23    MARKER_DRI, MARKER_EOI, MARKER_SOF0, MARKER_SOF1, MARKER_SOF2, MARKER_SOI, MARKER_SOS,
24    MAX_COMPONENTS, MAX_HUFFMAN_TABLES, MAX_QUANT_TABLES,
25};
26use crate::entropy::EntropyDecoder;
27use crate::error::{Error, Result};
28use crate::huffman::HuffmanDecodeTable;
29#[cfg(any(feature = "cms-lcms2", feature = "cms-moxcms"))]
30use crate::icc::apply_icc_transform;
31use crate::icc::{extract_icc_profile, is_xyb_profile};
32use crate::idct::inverse_dct_8x8;
33use crate::quant::{dequantize_block_with_bias, DequantBiasStats};
34use crate::types::{ColorSpace, Component, Dimensions, JpegMode, PixelFormat};
35
/// Decoder configuration.
///
/// Holds the options stored inside a [`Decoder`]. Obtain the stock values via
/// [`DecoderConfig::default`] and override individual fields, or use the
/// builder-style setters on [`Decoder`].
#[derive(Debug, Clone)]
pub struct DecoderConfig {
    /// Requested output pixel format.
    ///
    /// When this is `None`, `Decoder::decode` falls back to `PixelFormat::Rgb`.
    pub output_format: Option<PixelFormat>,
    /// Whether to apply fancy upsampling.
    pub fancy_upsampling: bool,
    /// Whether to apply block smoothing.
    pub block_smoothing: bool,
    /// Whether to apply an embedded ICC profile.
    ///
    /// Only takes effect when the `cms-lcms2` or `cms-moxcms` feature is enabled;
    /// the ICC step in `Decoder::decode` is compiled out otherwise.
    pub apply_icc: bool,
    /// Maximum pixels allowed (for DoS protection).
    /// Default is 100 megapixels. Set to 0 for unlimited.
    pub max_pixels: u64,
    /// Maximum total memory for allocations (for DoS protection).
    /// Default is 512 MB. Set to 0 for unlimited.
    ///
    /// NOTE(review): the `Decoder::max_memory` setter documents `usize::MAX`
    /// as the "unlimited" value instead of 0 — confirm which convention the
    /// allocation helpers actually honour.
    pub max_memory: usize,
}
54
55impl Default for DecoderConfig {
56    fn default() -> Self {
57        Self {
58            output_format: None,
59            fancy_upsampling: false,
60            block_smoothing: false,
61            // Apply ICC by default when CMS is available
62            apply_icc: cfg!(any(feature = "cms-lcms2", feature = "cms-moxcms")),
63            max_pixels: DEFAULT_MAX_PIXELS,
64            max_memory: DEFAULT_MAX_MEMORY,
65        }
66    }
67}
68
/// Information about a decoded JPEG.
///
/// Returned by `Decoder::read_info`, which parses the stream only up to the
/// frame header.
#[derive(Debug, Clone)]
pub struct JpegInfo {
    /// Image dimensions
    pub dimensions: Dimensions,
    /// Color space
    pub color_space: ColorSpace,
    /// Sample precision (8 or 12 bits); the frame-header parser rejects any other value.
    pub precision: u8,
    /// Number of components declared in the frame header
    pub num_components: u8,
    /// Encoding mode
    pub mode: JpegMode,
    /// Whether an ICC profile is embedded
    pub has_icc_profile: bool,
    /// Whether the ICC profile is an XYB profile
    pub is_xyb: bool,
}
87
/// JPEG decoder.
///
/// A thin, reusable handle around a [`DecoderConfig`]; each call to
/// `read_info` or `decode` creates its own parser over the supplied bytes.
pub struct Decoder {
    /// Options applied to every decode performed through this handle.
    config: DecoderConfig,
}
92
93impl Decoder {
94    /// Creates a new decoder with default settings.
95    #[must_use]
96    pub fn new() -> Self {
97        Self {
98            config: DecoderConfig::default(),
99        }
100    }
101
102    /// Creates a decoder from configuration.
103    #[must_use]
104    pub fn from_config(config: DecoderConfig) -> Self {
105        Self { config }
106    }
107
108    /// Sets the output pixel format.
109    #[must_use]
110    pub fn output_format(mut self, format: PixelFormat) -> Self {
111        self.config.output_format = Some(format);
112        self
113    }
114
115    /// Enables fancy upsampling.
116    #[must_use]
117    pub fn fancy_upsampling(mut self, enable: bool) -> Self {
118        self.config.fancy_upsampling = enable;
119        self
120    }
121
122    /// Enables block smoothing.
123    #[must_use]
124    pub fn block_smoothing(mut self, enable: bool) -> Self {
125        self.config.block_smoothing = enable;
126        self
127    }
128
129    /// Enables ICC profile application.
130    ///
131    /// When enabled, embedded ICC profiles will be applied to convert
132    /// the image to sRGB. This is required for correct display of
133    /// XYB-encoded images.
134    ///
135    /// Note: Requires `cms-lcms2` or `cms-moxcms` feature to be enabled.
136    /// Without a CMS feature, this setting has no effect.
137    #[must_use]
138    pub fn apply_icc(mut self, enable: bool) -> Self {
139        self.config.apply_icc = enable;
140        self
141    }
142
143    /// Sets the maximum number of pixels allowed (for DoS protection).
144    ///
145    /// Default is 100 megapixels. Set to 0 for unlimited.
146    #[must_use]
147    pub fn max_pixels(mut self, pixels: u64) -> Self {
148        self.config.max_pixels = pixels;
149        self
150    }
151
152    /// Sets the maximum memory allowed for allocations during decoding.
153    ///
154    /// Default is 512 MB. Set to `usize::MAX` for unlimited.
155    /// This prevents memory exhaustion attacks from malicious images.
156    #[must_use]
157    pub fn max_memory(mut self, bytes: usize) -> Self {
158        self.config.max_memory = bytes;
159        self
160    }
161
162    /// Reads JPEG info without decoding.
163    pub fn read_info(&self, data: &[u8]) -> Result<JpegInfo> {
164        let mut parser = JpegParser::new(data, self.config.max_pixels)?;
165        parser.read_header()?;
166        Ok(parser.info())
167    }
168
169    /// Decodes a JPEG image.
170    pub fn decode(&self, data: &[u8]) -> Result<DecodedImage> {
171        let mut parser = JpegParser::new(data, self.config.max_pixels)?;
172        parser.decode()?;
173
174        let info = parser.info();
175        let output_format = self.config.output_format.unwrap_or(PixelFormat::Rgb);
176
177        // Convert to output format
178        let mut pixels = parser.to_pixels(output_format)?;
179
180        // Apply ICC profile if enabled and present
181        #[cfg(any(feature = "cms-lcms2", feature = "cms-moxcms"))]
182        if self.config.apply_icc && output_format == PixelFormat::Rgb {
183            if let Some(ref icc_profile) = parser.icc_profile {
184                pixels = apply_icc_transform(
185                    &pixels,
186                    info.dimensions.width as usize,
187                    info.dimensions.height as usize,
188                    icc_profile,
189                )?;
190            }
191        }
192
193        Ok(DecodedImage {
194            width: info.dimensions.width,
195            height: info.dimensions.height,
196            format: output_format,
197            data: pixels,
198        })
199    }
200}
201
202impl Default for Decoder {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
/// A decoded image with dimensions and pixel data.
///
/// Produced by `Decoder::decode`. The struct is `#[non_exhaustive]`, so code
/// outside this crate cannot build it with a struct literal or match it
/// exhaustively.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct DecodedImage {
    /// Image width in pixels
    pub width: u32,
    /// Image height in pixels
    pub height: u32,
    /// Pixel format of the data
    pub format: PixelFormat,
    /// Raw pixel data in the specified format
    // presumably tightly packed rows of `stride()` bytes each — confirm against `to_pixels`
    pub data: Vec<u8>,
}
221
222impl DecodedImage {
223    /// Returns the image dimensions as a tuple (width, height).
224    #[must_use]
225    pub fn dimensions(&self) -> (u32, u32) {
226        (self.width, self.height)
227    }
228
229    /// Returns the number of bytes per pixel for this image's format.
230    #[must_use]
231    pub fn bytes_per_pixel(&self) -> usize {
232        self.format.bytes_per_pixel()
233    }
234
235    /// Returns the stride (bytes per row) of the image.
236    #[must_use]
237    pub fn stride(&self) -> usize {
238        self.width as usize * self.bytes_per_pixel()
239    }
240}
241
/// Internal JPEG parser state.
///
/// Borrows the encoded bytes and accumulates everything learned while walking
/// the marker segments: frame geometry, component descriptions, quantization
/// and Huffman tables, and the decoded DCT coefficients.
struct JpegParser<'a> {
    /// The complete encoded JPEG byte stream.
    data: &'a [u8],
    /// Offset into `data` of the next byte to read.
    position: usize,

    // Frame info (filled in by `parse_frame_header`)
    width: u32,
    height: u32,
    /// Sample precision in bits; starts at 8 and is overwritten by the frame header.
    precision: u8,
    /// Number of components declared by the frame header (0 until one is parsed).
    num_components: u8,
    mode: JpegMode,

    // Component info; only the first `num_components` entries are meaningful
    components: [Component; MAX_COMPONENTS],

    // Tables; `None` until the corresponding DQT / DHT segment has been parsed.
    // Quantization tables are stored in natural (not zigzag) order.
    quant_tables: [Option<[u16; DCT_BLOCK_SIZE]>; MAX_QUANT_TABLES],
    dc_tables: [Option<HuffmanDecodeTable>; MAX_HUFFMAN_TABLES],
    ac_tables: [Option<HuffmanDecodeTable>; MAX_HUFFMAN_TABLES],

    // Restart interval from the most recent DRI segment (0 initially)
    restart_interval: u16,

    // Decoded coefficient data: one vector of DCT blocks per component,
    // allocated lazily by the first scan that is decoded
    coeffs: Vec<Vec<[i16; DCT_BLOCK_SIZE]>>, // Per component

    // ICC profile (extracted from raw data, not during parsing)
    icc_profile: Option<Vec<u8>>,

    // Security limits: maximum pixel count, where 0 means unlimited
    max_pixels: u64,
}
274
275impl<'a> JpegParser<'a> {
276    fn new(data: &'a [u8], max_pixels: u64) -> Result<Self> {
277        // Check for SOI
278        if data.len() < 2 || data[0] != 0xFF || data[1] != MARKER_SOI {
279            return Err(Error::InvalidJpegData {
280                reason: "missing SOI marker",
281            });
282        }
283
284        // Extract ICC profile from raw data upfront
285        let icc_profile = extract_icc_profile(data);
286
287        Ok(Self {
288            data,
289            position: 2,
290            width: 0,
291            height: 0,
292            precision: 8,
293            num_components: 0,
294            mode: JpegMode::Baseline,
295            components: std::array::from_fn(|_| Component::default()),
296            quant_tables: [None, None, None, None],
297            dc_tables: [None, None, None, None],
298            ac_tables: [None, None, None, None],
299            restart_interval: 0,
300            coeffs: Vec::new(),
301            icc_profile,
302            max_pixels,
303        })
304    }
305
306    fn read_u8(&mut self) -> Result<u8> {
307        if self.position >= self.data.len() {
308            return Err(Error::UnexpectedEof {
309                context: "reading byte",
310            });
311        }
312        let byte = self.data[self.position];
313        self.position += 1;
314        Ok(byte)
315    }
316
317    fn read_u16(&mut self) -> Result<u16> {
318        let high = self.read_u8()? as u16;
319        let low = self.read_u8()? as u16;
320        Ok((high << 8) | low)
321    }
322
323    fn read_marker(&mut self) -> Result<u8> {
324        loop {
325            let byte = self.read_u8()?;
326            if byte != 0xFF {
327                continue;
328            }
329
330            let marker = self.read_u8()?;
331            if marker != 0x00 && marker != 0xFF {
332                return Ok(marker);
333            }
334        }
335    }
336
337    fn read_header(&mut self) -> Result<()> {
338        loop {
339            let marker = self.read_marker()?;
340
341            match marker {
342                MARKER_SOF0 | MARKER_SOF1 => {
343                    self.mode = JpegMode::Baseline;
344                    self.parse_frame_header()?;
345                    return Ok(());
346                }
347                MARKER_SOF2 => {
348                    self.mode = JpegMode::Progressive;
349                    self.parse_frame_header()?;
350                    return Ok(());
351                }
352                MARKER_DQT => self.parse_quant_table()?,
353                MARKER_DHT => self.parse_huffman_table()?,
354                MARKER_DRI => self.parse_restart_interval()?,
355                MARKER_APP0..=0xEF | MARKER_COM => self.skip_segment()?,
356                MARKER_EOI => {
357                    return Err(Error::InvalidJpegData {
358                        reason: "unexpected EOI before frame header",
359                    });
360                }
361                _ => self.skip_segment()?,
362            }
363        }
364    }
365
366    fn parse_frame_header(&mut self) -> Result<()> {
367        let length = self.read_u16()?;
368        if length < 8 {
369            return Err(Error::InvalidJpegData {
370                reason: "frame header too short",
371            });
372        }
373
374        self.precision = self.read_u8()?;
375        // Validate precision: must be 8 for baseline JPEG, 8 or 12 for extended
376        if self.precision != 8 && self.precision != 12 {
377            return Err(Error::InvalidJpegData {
378                reason: "invalid data precision (must be 8 or 12)",
379            });
380        }
381
382        self.height = self.read_u16()? as u32;
383        self.width = self.read_u16()? as u32;
384
385        // Validate dimensions against security limits
386        // max_pixels == 0 means unlimited
387        let effective_max = if self.max_pixels == 0 {
388            u64::MAX
389        } else {
390            self.max_pixels
391        };
392        validate_dimensions(self.width, self.height, effective_max)?;
393
394        self.num_components = self.read_u8()?;
395
396        // Validate num_components
397        if self.num_components == 0 {
398            return Err(Error::InvalidJpegData {
399                reason: "number of components is zero",
400            });
401        }
402        if self.num_components > MAX_COMPONENTS as u8 {
403            return Err(Error::UnsupportedFeature {
404                feature: "more than 4 components",
405            });
406        }
407
408        // Validate marker length matches expected size
409        let expected_length = 8 + 3 * self.num_components as u16;
410        if length != expected_length {
411            return Err(Error::InvalidJpegData {
412                reason: "SOF marker length mismatch",
413            });
414        }
415
416        for i in 0..self.num_components as usize {
417            self.components[i].id = self.read_u8()?;
418            let sampling = self.read_u8()?;
419            let h_samp = sampling >> 4;
420            let v_samp = sampling & 0x0F;
421
422            // Validate sampling factors are non-zero and <= 4
423            if h_samp == 0 || v_samp == 0 {
424                return Err(Error::InvalidJpegData {
425                    reason: "sampling factor is zero",
426                });
427            }
428            if h_samp > 4 || v_samp > 4 {
429                return Err(Error::InvalidJpegData {
430                    reason: "sampling factor exceeds maximum (4)",
431                });
432            }
433
434            self.components[i].h_samp_factor = h_samp;
435            self.components[i].v_samp_factor = v_samp;
436
437            let quant_idx = self.read_u8()?;
438            // Validate quant table index
439            if quant_idx as usize >= MAX_QUANT_TABLES {
440                return Err(Error::InvalidJpegData {
441                    reason: "quantization table index out of range",
442                });
443            }
444            self.components[i].quant_table_idx = quant_idx;
445        }
446
447        Ok(())
448    }
449
450    fn parse_quant_table(&mut self) -> Result<()> {
451        let mut length = self.read_u16()? as i32 - 2;
452
453        while length > 0 {
454            let info = self.read_u8()?;
455            let precision = info >> 4;
456            let table_idx = (info & 0x0F) as usize;
457
458            // Validate precision (0 = 8-bit, 1 = 16-bit)
459            if precision > 1 {
460                return Err(Error::InvalidQuantTable {
461                    table_idx: table_idx as u8,
462                    reason: "invalid precision (must be 0 or 1)",
463                });
464            }
465
466            if table_idx >= MAX_QUANT_TABLES {
467                return Err(Error::InvalidQuantTable {
468                    table_idx: table_idx as u8,
469                    reason: "table index out of range",
470                });
471            }
472
473            // Read values in zigzag order (as stored in JPEG)
474            let mut zigzag_values = [0u16; DCT_BLOCK_SIZE];
475
476            if precision == 0 {
477                // 8-bit values
478                for i in 0..DCT_BLOCK_SIZE {
479                    let val = self.read_u8()? as u16;
480                    if val == 0 {
481                        return Err(Error::InvalidQuantTable {
482                            table_idx: table_idx as u8,
483                            reason: "quantization value is zero",
484                        });
485                    }
486                    zigzag_values[i] = val;
487                }
488                length -= 65;
489            } else {
490                // 16-bit values
491                for i in 0..DCT_BLOCK_SIZE {
492                    let val = self.read_u16()?;
493                    if val == 0 {
494                        return Err(Error::InvalidQuantTable {
495                            table_idx: table_idx as u8,
496                            reason: "quantization value is zero",
497                        });
498                    }
499                    zigzag_values[i] = val;
500                }
501                length -= 129;
502            }
503
504            // Validate DQT marker length consistency
505            if length < 0 {
506                return Err(Error::InvalidJpegData {
507                    reason: "DQT marker length mismatch",
508                });
509            }
510
511            // Convert from zigzag order to natural order for dequantization
512            let mut natural_values = [0u16; DCT_BLOCK_SIZE];
513            for i in 0..DCT_BLOCK_SIZE {
514                natural_values[JPEG_NATURAL_ORDER[i] as usize] = zigzag_values[i];
515            }
516
517            self.quant_tables[table_idx] = Some(natural_values);
518        }
519
520        Ok(())
521    }
522
523    fn parse_huffman_table(&mut self) -> Result<()> {
524        let mut length = self.read_u16()? as i32 - 2;
525
526        while length > 0 {
527            let info = self.read_u8()?;
528            let table_class = info >> 4; // 0 = DC, 1 = AC
529            let table_idx = (info & 0x0F) as usize;
530
531            // Validate table class (must be 0 for DC or 1 for AC)
532            if table_class > 1 {
533                return Err(Error::InvalidHuffmanTable {
534                    table_idx: table_idx as u8,
535                    reason: "invalid table class (must be 0 or 1)",
536                });
537            }
538
539            if table_idx >= MAX_HUFFMAN_TABLES {
540                return Err(Error::InvalidHuffmanTable {
541                    table_idx: table_idx as u8,
542                    reason: "table index out of range",
543                });
544            }
545
546            let mut bits = [0u8; 16];
547            for i in 0..16 {
548                bits[i] = self.read_u8()?;
549            }
550
551            let num_values: usize = bits.iter().map(|&b| b as usize).sum();
552            let mut values = vec![0u8; num_values];
553            for i in 0..num_values {
554                values[i] = self.read_u8()?;
555            }
556
557            length -= 17 + num_values as i32;
558
559            // Validate that we didn't read past the marker length
560            if length < 0 {
561                return Err(Error::InvalidJpegData {
562                    reason: "DHT marker length mismatch",
563                });
564            }
565
566            let table = HuffmanDecodeTable::from_bits_values(&bits, &values)?;
567
568            if table_class == 0 {
569                self.dc_tables[table_idx] = Some(table);
570            } else {
571                self.ac_tables[table_idx] = Some(table);
572            }
573        }
574
575        Ok(())
576    }
577
578    fn parse_restart_interval(&mut self) -> Result<()> {
579        let _length = self.read_u16()?;
580        self.restart_interval = self.read_u16()?;
581        Ok(())
582    }
583
584    fn skip_segment(&mut self) -> Result<()> {
585        let length = self.read_u16()? as usize;
586        if length < 2 {
587            return Err(Error::InvalidJpegData {
588                reason: "segment length too short",
589            });
590        }
591        self.position += length - 2;
592        Ok(())
593    }
594
595    fn decode(&mut self) -> Result<()> {
596        // First read header
597        self.position = 2; // Skip SOI
598        self.read_header()?;
599
600        // Continue parsing until we hit SOS
601        loop {
602            let marker = self.read_marker()?;
603
604            match marker {
605                MARKER_SOS => {
606                    self.parse_scan()?;
607                    // After scan, look for more markers
608                }
609                MARKER_DQT => self.parse_quant_table()?,
610                MARKER_DHT => self.parse_huffman_table()?,
611                MARKER_DRI => self.parse_restart_interval()?,
612                MARKER_EOI => break,
613                MARKER_APP0..=0xEF | MARKER_COM => self.skip_segment()?,
614                _ => self.skip_segment()?,
615            }
616        }
617
618        Ok(())
619    }
620
621    fn parse_scan(&mut self) -> Result<()> {
622        let _length = self.read_u16()?;
623        let num_components = self.read_u8()?;
624
625        // Validate num_components in scan
626        if num_components == 0 {
627            return Err(Error::InvalidJpegData {
628                reason: "SOS num_components is zero",
629            });
630        }
631        if num_components > self.num_components {
632            return Err(Error::InvalidJpegData {
633                reason: "SOS num_components exceeds frame components",
634            });
635        }
636        if num_components > MAX_COMPONENTS as u8 {
637            return Err(Error::InvalidJpegData {
638                reason: "SOS num_components too large",
639            });
640        }
641
642        let mut scan_components = Vec::with_capacity(num_components as usize);
643
644        for _ in 0..num_components {
645            let component_id = self.read_u8()?;
646            let tables = self.read_u8()?;
647            let dc_table = tables >> 4;
648            let ac_table = tables & 0x0F;
649
650            // Validate Huffman table indexes
651            if dc_table as usize >= MAX_HUFFMAN_TABLES {
652                return Err(Error::InvalidJpegData {
653                    reason: "SOS DC Huffman table index out of range",
654                });
655            }
656            if ac_table as usize >= MAX_HUFFMAN_TABLES {
657                return Err(Error::InvalidJpegData {
658                    reason: "SOS AC Huffman table index out of range",
659                });
660            }
661
662            // Find component index
663            let comp_idx = self.components[..self.num_components as usize]
664                .iter()
665                .position(|c| c.id == component_id)
666                .ok_or(Error::InvalidJpegData {
667                    reason: "unknown component in scan",
668                })?;
669
670            scan_components.push((comp_idx, dc_table, ac_table));
671        }
672
673        let ss = self.read_u8()?; // Spectral selection start
674        let se = self.read_u8()?; // Spectral selection end
675        let ah_al = self.read_u8()?;
676        let ah = ah_al >> 4;
677        let al = ah_al & 0x0F;
678
679        // Validate spectral selection (must be 0-63)
680        if ss > 63 {
681            return Err(Error::InvalidJpegData {
682                reason: "SOS Ss (spectral start) out of range",
683            });
684        }
685        if se > 63 {
686            return Err(Error::InvalidJpegData {
687                reason: "SOS Se (spectral end) out of range",
688            });
689        }
690
691        // Decode entropy-coded segment based on mode
692        if self.mode == JpegMode::Progressive {
693            self.decode_progressive_scan(&scan_components, ss, se, ah, al)?;
694        } else {
695            self.decode_scan(&scan_components)?;
696        }
697
698        Ok(())
699    }
700
    /// Decodes the entropy-coded data of a sequential (non-progressive) scan.
    ///
    /// `scan_components` holds one `(component index, DC table selector,
    /// AC table selector)` tuple per component in the scan. Coefficient
    /// storage for every frame component is allocated on the first call.
    /// Blocks are then decoded MCU by MCU and written into `self.coeffs`,
    /// and `self.position` is advanced by the number of bytes the entropy
    /// decoder reports as consumed.
    ///
    /// # Errors
    ///
    /// Propagates size-computation and allocation failures, and any error
    /// returned by `EntropyDecoder::decode_block`.
    ///
    /// NOTE(review): `self.restart_interval` is not consulted here and is not
    /// handed to the entropy decoder — presumably RSTn markers are handled
    /// inside `EntropyDecoder`; confirm.
    ///
    /// NOTE(review): the MCU grid and per-MCU block counts always come from
    /// the frame-wide maximum sampling factors, even when the scan holds a
    /// single component — verify this is intended for non-interleaved scans
    /// of subsampled images.
    fn decode_scan(&mut self, scan_components: &[(usize, u8, u8)]) -> Result<()> {
        // Calculate max sampling factors to determine MCU structure
        let mut max_h_samp = 1u8;
        let mut max_v_samp = 1u8;
        for i in 0..self.num_components as usize {
            max_h_samp = max_h_samp.max(self.components[i].h_samp_factor);
            max_v_samp = max_v_samp.max(self.components[i].v_samp_factor);
        }

        // MCU dimensions in pixels
        let mcu_width = (max_h_samp as usize) * 8;
        let mcu_height = (max_v_samp as usize) * 8;

        // Number of MCUs (rounded up so partial MCUs at the edges are covered)
        let mcu_cols = (self.width as usize + mcu_width - 1) / mcu_width;
        let mcu_rows = (self.height as usize + mcu_height - 1) / mcu_height;

        // Initialize coefficient storage - size depends on component's sampling factor.
        // Only done once: later scans reuse the buffers allocated by the first one.
        if self.coeffs.is_empty() {
            for i in 0..self.num_components as usize {
                let h_samp = self.components[i].h_samp_factor as usize;
                let v_samp = self.components[i].v_samp_factor as usize;
                // presumably `checked_size_2d` is an overflow-checked multiply — confirm
                let comp_blocks_h = checked_size_2d(mcu_cols, h_samp)?;
                let comp_blocks_v = checked_size_2d(mcu_rows, v_samp)?;
                let num_blocks = checked_size_2d(comp_blocks_h, comp_blocks_v)?;
                self.coeffs.push(try_alloc_dct_blocks(
                    num_blocks,
                    "allocating DCT coefficients",
                )?);
            }
        }

        // Set up entropy decoder over everything from the current position onward
        let scan_data = &self.data[self.position..];
        let mut decoder = EntropyDecoder::new(scan_data);

        // Hand the decoder a copy of each Huffman table the scan refers to.
        // Tables that were never defined are silently left unset here.
        for (_comp_idx, dc_table, ac_table) in scan_components {
            // The selectors were range-checked in `parse_scan`; the clamp is a second guard.
            let dc_idx = (*dc_table as usize).min(MAX_HUFFMAN_TABLES - 1);
            let ac_idx = (*ac_table as usize).min(MAX_HUFFMAN_TABLES - 1);
            if let Some(table) = &self.dc_tables[dc_idx] {
                decoder.set_dc_table(dc_idx, table.clone());
            }
            if let Some(table) = &self.ac_tables[ac_idx] {
                decoder.set_ac_table(ac_idx, table.clone());
            }
        }

        // Decode MCUs with proper interleaving
        for mcu_y in 0..mcu_rows {
            for mcu_x in 0..mcu_cols {
                // For each component in the scan
                for (comp_idx, dc_table, ac_table) in scan_components {
                    let h_samp = self.components[*comp_idx].h_samp_factor as usize;
                    let v_samp = self.components[*comp_idx].v_samp_factor as usize;
                    // Width of this component's block grid, in blocks
                    let comp_blocks_h = mcu_cols * h_samp;

                    // Decode all blocks for this component in this MCU
                    for v in 0..v_samp {
                        for h in 0..h_samp {
                            // Row-major index of the block within the component's grid
                            let block_x = mcu_x * h_samp + h;
                            let block_y = mcu_y * v_samp + v;
                            let block_idx = block_y * comp_blocks_h + block_x;

                            let coeffs = decoder.decode_block(
                                *comp_idx,
                                *dc_table as usize,
                                *ac_table as usize,
                            )?;
                            self.coeffs[*comp_idx][block_idx] = coeffs;
                        }
                    }
                }
            }
        }

        // Resume marker parsing after the bytes the entropy decoder consumed
        self.position += decoder.position();
        Ok(())
    }
779
780    fn decode_progressive_scan(
781        &mut self,
782        scan_components: &[(usize, u8, u8)],
783        ss: u8,
784        se: u8,
785        ah: u8,
786        al: u8,
787    ) -> Result<()> {
788        // Calculate max sampling factors to determine MCU structure
789        let mut max_h_samp = 1u8;
790        let mut max_v_samp = 1u8;
791        for i in 0..self.num_components as usize {
792            max_h_samp = max_h_samp.max(self.components[i].h_samp_factor);
793            max_v_samp = max_v_samp.max(self.components[i].v_samp_factor);
794        }
795
796        // MCU dimensions in pixels
797        let mcu_width = (max_h_samp as usize) * 8;
798        let mcu_height = (max_v_samp as usize) * 8;
799
800        // Number of MCUs
801        let mcu_cols = (self.width as usize + mcu_width - 1) / mcu_width;
802        let mcu_rows = (self.height as usize + mcu_height - 1) / mcu_height;
803
804        // Initialize coefficient storage if not already done
805        if self.coeffs.is_empty() {
806            for i in 0..self.num_components as usize {
807                let h_samp = self.components[i].h_samp_factor as usize;
808                let v_samp = self.components[i].v_samp_factor as usize;
809                let comp_blocks_h = checked_size_2d(mcu_cols, h_samp)?;
810                let comp_blocks_v = checked_size_2d(mcu_rows, v_samp)?;
811                let num_blocks = checked_size_2d(comp_blocks_h, comp_blocks_v)?;
812                self.coeffs.push(try_alloc_dct_blocks(
813                    num_blocks,
814                    "allocating DCT coefficients",
815                )?);
816            }
817        }
818
819        // Set up entropy decoder
820        let scan_data = &self.data[self.position..];
821        let mut decoder = EntropyDecoder::new(scan_data);
822
823        for (_comp_idx, dc_table, ac_table) in scan_components {
824            let dc_idx = (*dc_table as usize).min(MAX_HUFFMAN_TABLES - 1);
825            let ac_idx = (*ac_table as usize).min(MAX_HUFFMAN_TABLES - 1);
826            if let Some(table) = &self.dc_tables[dc_idx] {
827                decoder.set_dc_table(dc_idx, table.clone());
828            }
829            if let Some(table) = &self.ac_tables[ac_idx] {
830                decoder.set_ac_table(ac_idx, table.clone());
831            }
832        }
833
834        // Determine scan type
835        let is_dc_scan = ss == 0 && se == 0;
836        let is_first_scan = ah == 0;
837
838        // EOB run tracking for AC scans
839        let mut eob_run = 0u16;
840
841        if is_dc_scan {
842            // DC scan (interleaved or single component)
843            for mcu_y in 0..mcu_rows {
844                for mcu_x in 0..mcu_cols {
845                    for (comp_idx, dc_table, _ac_table) in scan_components {
846                        let h_samp = self.components[*comp_idx].h_samp_factor as usize;
847                        let v_samp = self.components[*comp_idx].v_samp_factor as usize;
848                        let comp_blocks_h = mcu_cols * h_samp;
849
850                        for v in 0..v_samp {
851                            for h in 0..h_samp {
852                                let block_x = mcu_x * h_samp + h;
853                                let block_y = mcu_y * v_samp + v;
854                                let block_idx = block_y * comp_blocks_h + block_x;
855
856                                if is_first_scan {
857                                    // DC first scan
858                                    let dc = decoder.decode_dc_first(
859                                        *comp_idx,
860                                        *dc_table as usize,
861                                        al,
862                                    )?;
863                                    self.coeffs[*comp_idx][block_idx][0] = dc;
864                                } else {
865                                    // DC refinement scan
866                                    let bit = decoder.decode_dc_refine(al)?;
867                                    self.coeffs[*comp_idx][block_idx][0] |= bit;
868                                }
869                            }
870                        }
871                    }
872                }
873            }
874        } else {
875            // AC scan (single component only for progressive)
876            // Progressive AC scans can only have one component
877            if scan_components.len() != 1 {
878                return Err(Error::InvalidJpegData {
879                    reason: "progressive AC scan must have single component",
880                });
881            }
882
883            let (comp_idx, _dc_table, ac_table) = scan_components[0];
884            let h_samp = self.components[comp_idx].h_samp_factor as usize;
885            let v_samp = self.components[comp_idx].v_samp_factor as usize;
886            let comp_blocks_h = mcu_cols * h_samp;
887
888            for mcu_y in 0..mcu_rows {
889                for mcu_x in 0..mcu_cols {
890                    for v in 0..v_samp {
891                        for h in 0..h_samp {
892                            let block_x = mcu_x * h_samp + h;
893                            let block_y = mcu_y * v_samp + v;
894                            let block_idx = block_y * comp_blocks_h + block_x;
895
896                            if is_first_scan {
897                                // AC first scan
898                                decoder.decode_ac_first(
899                                    &mut self.coeffs[comp_idx][block_idx],
900                                    ac_table as usize,
901                                    ss,
902                                    se,
903                                    al,
904                                    &mut eob_run,
905                                )?;
906                            } else {
907                                // AC refinement scan
908                                decoder.decode_ac_refine(
909                                    &mut self.coeffs[comp_idx][block_idx],
910                                    ac_table as usize,
911                                    ss,
912                                    se,
913                                    al,
914                                    &mut eob_run,
915                                )?;
916                            }
917                        }
918                    }
919                }
920            }
921        }
922
923        self.position += decoder.position();
924        Ok(())
925    }
926
927    fn info(&self) -> JpegInfo {
928        let has_icc = self.icc_profile.is_some();
929        let is_xyb = self.icc_profile.as_ref().is_some_and(|p| is_xyb_profile(p));
930
931        // Determine color space, considering XYB profile
932        let color_space = if is_xyb {
933            ColorSpace::Xyb
934        } else {
935            match self.num_components {
936                1 => ColorSpace::Grayscale,
937                3 => ColorSpace::YCbCr,
938                4 => ColorSpace::Cmyk,
939                _ => ColorSpace::Unknown,
940            }
941        };
942
943        JpegInfo {
944            dimensions: Dimensions::new(self.width, self.height),
945            color_space,
946            precision: self.precision,
947            num_components: self.num_components,
948            mode: self.mode,
949            has_icc_profile: has_icc,
950            is_xyb,
951        }
952    }
953
954    fn to_pixels(&self, format: PixelFormat) -> Result<Vec<u8>> {
955        if self.coeffs.is_empty() {
956            return Err(Error::InternalError {
957                reason: "no decoded data",
958            });
959        }
960
961        let width = self.width as usize;
962        let height = self.height as usize;
963
964        // Calculate max sampling factors
965        let mut max_h_samp = 1u8;
966        let mut max_v_samp = 1u8;
967        for i in 0..self.num_components as usize {
968            max_h_samp = max_h_samp.max(self.components[i].h_samp_factor);
969            max_v_samp = max_v_samp.max(self.components[i].v_samp_factor);
970        }
971
972        // MCU dimensions
973        let mcu_width = (max_h_samp as usize) * 8;
974        let mcu_height = (max_v_samp as usize) * 8;
975        let mcu_cols = (width + mcu_width - 1) / mcu_width;
976        let mcu_rows = (height + mcu_height - 1) / mcu_height;
977
978        // Pre-compute component info for efficiency
979        struct CompInfo {
980            quant_idx: usize,
981            h_samp: usize,
982            v_samp: usize,
983            comp_blocks_h: usize,
984            comp_blocks_v: usize,
985            comp_width: usize,
986            comp_height: usize,
987            is_full_res: bool,
988        }
989
990        let mut comp_infos: Vec<CompInfo> = Vec::new();
991        for comp_idx in 0..self.num_components as usize {
992            let h_samp = self.components[comp_idx].h_samp_factor as usize;
993            let v_samp = self.components[comp_idx].v_samp_factor as usize;
994            let comp_blocks_h = mcu_cols * h_samp;
995            let comp_blocks_v = mcu_rows * v_samp;
996            let comp_width = checked_size_2d(comp_blocks_h, 8)?;
997            let comp_height = checked_size_2d(comp_blocks_v, 8)?;
998            comp_infos.push(CompInfo {
999                quant_idx: self.components[comp_idx].quant_table_idx as usize,
1000                h_samp,
1001                v_samp,
1002                comp_blocks_h,
1003                comp_blocks_v,
1004                comp_width,
1005                comp_height,
1006                is_full_res: h_samp == max_h_samp as usize && v_samp == max_v_samp as usize,
1007            });
1008        }
1009
1010        // Initialize bias stats and biases (C++ initializes to 0 via memset)
1011        let mut bias_stats = DequantBiasStats::new(self.num_components as usize);
1012        let mut component_biases: Vec<[f32; DCT_BLOCK_SIZE]> =
1013            vec![[0.0f32; DCT_BLOCK_SIZE]; self.num_components as usize];
1014
1015        // Allocate component planes as f32 (C++ jpegli keeps f32 until final output)
1016        let mut comp_planes_f32: Vec<Vec<f32>> = Vec::new();
1017        for info in &comp_infos {
1018            let comp_plane_size = checked_size_2d(info.comp_width, info.comp_height)?;
1019            comp_planes_f32.push(vec![0.0f32; comp_plane_size]);
1020        }
1021
1022        // Process MCU row by MCU row (matching C++ incremental bias recomputation)
1023        for imcu_row in 0..mcu_rows {
1024            // For each component in this MCU row
1025            for comp_idx in 0..self.num_components as usize {
1026                let info = &comp_infos[comp_idx];
1027                let quant =
1028                    self.quant_tables[info.quant_idx]
1029                        .as_ref()
1030                        .ok_or(Error::InternalError {
1031                            reason: "missing quantization table",
1032                        })?;
1033
1034                // Phase 1: Gather stats for full-res components
1035                if info.is_full_res {
1036                    for iy in 0..info.v_samp {
1037                        let by = imcu_row * info.v_samp + iy;
1038                        if by >= info.comp_blocks_v {
1039                            continue;
1040                        }
1041                        for bx in 0..info.comp_blocks_h {
1042                            let block_idx = by * info.comp_blocks_h + bx;
1043                            if block_idx >= self.coeffs[comp_idx].len() {
1044                                continue;
1045                            }
1046                            let coeffs = &self.coeffs[comp_idx][block_idx];
1047                            let mut natural_coeffs = [0i16; DCT_BLOCK_SIZE];
1048                            for (i, &zi) in JPEG_NATURAL_ORDER[..DCT_BLOCK_SIZE].iter().enumerate()
1049                            {
1050                                natural_coeffs[zi as usize] = coeffs[i];
1051                            }
1052                            bias_stats.gather_block(comp_idx, &natural_coeffs);
1053                        }
1054                    }
1055
1056                    // Phase 2: Recompute biases every 4 MCU rows (matching C++ behavior)
1057                    if imcu_row % 4 == 3 {
1058                        component_biases[comp_idx] = bias_stats.compute_biases(comp_idx);
1059                    }
1060                }
1061
1062                // Phase 3: IDCT for this component in this MCU row
1063                // Store as f32 (C++ jpegli keeps f32 until final output for precision)
1064                let biases = &component_biases[comp_idx];
1065                let comp_plane_f32 = &mut comp_planes_f32[comp_idx];
1066
1067                for iy in 0..info.v_samp {
1068                    let by = imcu_row * info.v_samp + iy;
1069                    if by >= info.comp_blocks_v {
1070                        continue;
1071                    }
1072
1073                    for bx in 0..info.comp_blocks_h {
1074                        let block_idx = by * info.comp_blocks_h + bx;
1075                        if block_idx >= self.coeffs[comp_idx].len() {
1076                            continue;
1077                        }
1078                        let coeffs = &self.coeffs[comp_idx][block_idx];
1079
1080                        let mut natural_coeffs = [0i16; DCT_BLOCK_SIZE];
1081                        for (i, &zi) in JPEG_NATURAL_ORDER[..DCT_BLOCK_SIZE].iter().enumerate() {
1082                            natural_coeffs[zi as usize] = coeffs[i];
1083                        }
1084
1085                        let dequant = dequantize_block_with_bias(&natural_coeffs, quant, biases);
1086                        let pixels = inverse_dct_8x8(&dequant);
1087
1088                        // Store as f32, NO level shift yet (matches C++ which adds 128/255 after color transform)
1089                        for y in 0..DCT_SIZE {
1090                            for x in 0..DCT_SIZE {
1091                                let px = bx * DCT_SIZE + x;
1092                                let py = by * DCT_SIZE + y;
1093                                if px < info.comp_width && py < info.comp_height {
1094                                    comp_plane_f32[py * info.comp_width + px] =
1095                                        pixels[y * DCT_SIZE + x];
1096                                }
1097                            }
1098                        }
1099                    }
1100                }
1101            }
1102        }
1103
1104        // Upsample if needed - keep as f32 for precision
1105        let output_size = checked_size_2d(width, height)?;
1106        let mut planes_f32: Vec<Vec<f32>> = Vec::new();
1107
1108        for comp_idx in 0..self.num_components as usize {
1109            let info = &comp_infos[comp_idx];
1110            let comp_plane_f32 = &comp_planes_f32[comp_idx];
1111
1112            let plane_f32 =
1113                if info.h_samp < max_h_samp as usize || info.v_samp < max_v_samp as usize {
1114                    let scale_x = max_h_samp as usize / info.h_samp;
1115                    let scale_y = max_v_samp as usize / info.v_samp;
1116                    let mut upsampled = vec![0.0f32; output_size];
1117                    for py in 0..height {
1118                        for px in 0..width {
1119                            let sx = (px / scale_x).min(info.comp_width - 1);
1120                            let sy = (py / scale_y).min(info.comp_height - 1);
1121                            upsampled[py * width + px] = comp_plane_f32[sy * info.comp_width + sx];
1122                        }
1123                    }
1124                    upsampled
1125                } else {
1126                    // Full resolution - just clip to image dimensions
1127                    let mut plane = vec![0.0f32; output_size];
1128                    for py in 0..height {
1129                        for px in 0..width {
1130                            plane[py * width + px] = comp_plane_f32[py * info.comp_width + px];
1131                        }
1132                    }
1133                    plane
1134                };
1135
1136            planes_f32.push(plane_f32);
1137        }
1138
1139        // Convert to output format - do color conversion in f32, then convert to u8
1140        match (self.num_components, format) {
1141            (1, PixelFormat::Gray) => {
1142                // Grayscale: level shift and convert to u8
1143                let mut output = try_alloc_zeroed(output_size, "allocating gray output")?;
1144                for (i, &y) in planes_f32[0].iter().enumerate() {
1145                    output[i] = (y + 128.0).round().clamp(0.0, 255.0) as u8;
1146                }
1147                Ok(output)
1148            }
1149            (1, PixelFormat::Rgb) => {
1150                let rgb_size =
1151                    checked_size_2d(width, height).and_then(|s| checked_size_2d(s, 3))?;
1152                let mut rgb = try_alloc_zeroed(rgb_size, "allocating RGB output")?;
1153                for (i, &y) in planes_f32[0].iter().enumerate() {
1154                    let val = (y + 128.0).round().clamp(0.0, 255.0) as u8;
1155                    rgb[i * 3] = val;
1156                    rgb[i * 3 + 1] = val;
1157                    rgb[i * 3 + 2] = val;
1158                }
1159                Ok(rgb)
1160            }
1161            (3, PixelFormat::Rgb) => {
1162                // YCbCr to RGB conversion in f32, then convert to u8
1163                let rgb_size =
1164                    checked_size_2d(width, height).and_then(|s| checked_size_2d(s, 3))?;
1165                let mut rgb = try_alloc_zeroed(rgb_size, "allocating RGB output")?;
1166
1167                for i in 0..output_size {
1168                    // Get YCbCr values (still centered around 0 for Y, 0 for Cb/Cr)
1169                    let y = planes_f32[0][i];
1170                    let cb = planes_f32[1][i]; // Cb is centered around 0 (not 128)
1171                    let cr = planes_f32[2][i]; // Cr is centered around 0 (not 128)
1172
1173                    // YCbCr to RGB conversion (BT.601)
1174                    // R = Y + 1.402 * Cr
1175                    // G = Y - 0.344136 * Cb - 0.714136 * Cr
1176                    // B = Y + 1.772 * Cb
1177                    let r = y + 1.402 * cr;
1178                    let g = y - 0.344136 * cb - 0.714136 * cr;
1179                    let b = y + 1.772 * cb;
1180
1181                    // Level shift (+128) and convert to u8
1182                    rgb[i * 3] = (r + 128.0).round().clamp(0.0, 255.0) as u8;
1183                    rgb[i * 3 + 1] = (g + 128.0).round().clamp(0.0, 255.0) as u8;
1184                    rgb[i * 3 + 2] = (b + 128.0).round().clamp(0.0, 255.0) as u8;
1185                }
1186                Ok(rgb)
1187            }
1188            _ => Err(Error::UnsupportedFeature {
1189                feature: "unsupported color conversion",
1190            }),
1191        }
1192    }
1193}
1194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode::Encoder;
    use crate::quant::Quality;

    /// Largest absolute per-byte difference between two buffers
    /// (0 when there is nothing to compare).
    fn max_abs_diff(expected: &[u8], actual: &[u8]) -> i32 {
        expected
            .iter()
            .zip(actual)
            .map(|(&e, &a)| (i32::from(e) - i32::from(a)).abs())
            .max()
            .unwrap_or(0)
    }

    /// Builder methods must be reflected in the decoder's configuration.
    #[test]
    fn test_decoder_creation() {
        let decoder = Decoder::new()
            .output_format(PixelFormat::Rgb)
            .fancy_upsampling(true);
        let config = &decoder.config;

        assert_eq!(config.output_format, Some(PixelFormat::Rgb));
        assert!(config.fancy_upsampling);
    }

    /// An 8x8 grayscale gradient survives an encode/decode round trip with
    /// only small per-pixel error.
    #[test]
    fn test_encode_decode_roundtrip_gray() {
        let (width, height) = (8usize, 8usize);

        // Diagonal gradient: value grows with x + y.
        let input: Vec<u8> = (0..height)
            .flat_map(|y| (0..width).map(move |x| ((x + y) * 16) as u8))
            .collect();

        let jpeg = Encoder::new()
            .width(width as u32)
            .height(height as u32)
            .pixel_format(PixelFormat::Gray)
            .quality(Quality::from_quality(95.0))
            .encode(&input)
            .expect("encoding should succeed");

        // The stream must open with SOI and close with EOI.
        assert!(jpeg.starts_with(&[0xFF, 0xD8]), "missing SOI marker");
        assert!(jpeg.ends_with(&[0xFF, 0xD9]), "missing EOI marker");

        let decoded = Decoder::new()
            .output_format(PixelFormat::Gray)
            .decode(&jpeg)
            .expect("decoding should succeed");

        assert_eq!(decoded.width, width as u32);
        assert_eq!(decoded.height, height as u32);
        assert_eq!(decoded.data.len(), width * height);

        // JPEG is lossy, but at quality 95 the error should stay small.
        let max_diff = max_abs_diff(&input, &decoded.data);
        assert!(max_diff < 20, "max_diff {} too large", max_diff);
    }

    /// A 16x16 RGB gradient survives an encode/decode round trip with only
    /// small per-channel error.
    #[test]
    fn test_encode_decode_roundtrip_rgb() {
        let (width, height) = (16usize, 16usize);

        // R ramps along x, G ramps along y, B is constant mid-gray.
        let mut input: Vec<u8> = Vec::with_capacity(width * height * 3);
        for y in 0..height {
            for x in 0..width {
                input.extend_from_slice(&[(x * 16) as u8, (y * 16) as u8, 128]);
            }
        }

        let jpeg = Encoder::new()
            .width(width as u32)
            .height(height as u32)
            .pixel_format(PixelFormat::Rgb)
            .quality(Quality::from_quality(95.0))
            .encode(&input)
            .expect("encoding should succeed");

        let decoded = Decoder::new()
            .output_format(PixelFormat::Rgb)
            .decode(&jpeg)
            .expect("decoding should succeed");

        assert_eq!(decoded.width, width as u32);
        assert_eq!(decoded.height, height as u32);
        assert_eq!(decoded.data.len(), width * height * 3);

        // At quality 95, differences should be small.
        let max_diff = max_abs_diff(&input, &decoded.data);
        assert!(max_diff < 30, "max_diff {} too large", max_diff);
    }
}