jpegli/
decode.rs

1//! JPEG decoder implementation.
2//!
3//! This module provides the main decoder interface for reading JPEG images.
4//!
5//! # ICC Profile Support
6//!
//! The decoder can extract and apply embedded ICC profiles, including XYB profiles
//! used by jpegli. ICC profile support requires enabling the `cms-lcms2` or `cms-moxcms` feature.
9//!
10//! ```ignore
11//! use jpegli::decode::Decoder;
12//!
13//! let decoder = Decoder::new().apply_icc(true);
14//! let decoded = decoder.decode(&jpeg_data)?;
15//! ```
16
17use crate::alloc::{
18    checked_size_2d, try_alloc_dct_blocks, try_alloc_zeroed, validate_dimensions,
19    DEFAULT_MAX_MEMORY, DEFAULT_MAX_PIXELS,
20};
21use crate::color;
22use crate::consts::{
23    DCT_BLOCK_SIZE, DCT_SIZE, JPEG_NATURAL_ORDER, MARKER_APP0, MARKER_COM, MARKER_DHT, MARKER_DQT,
24    MARKER_DRI, MARKER_EOI, MARKER_SOF0, MARKER_SOF1, MARKER_SOF2, MARKER_SOI, MARKER_SOS,
25    MAX_COMPONENTS, MAX_HUFFMAN_TABLES, MAX_QUANT_TABLES,
26};
27use crate::entropy::EntropyDecoder;
28use crate::error::{Error, Result};
29use crate::huffman::HuffmanDecodeTable;
30#[cfg(any(feature = "cms-lcms2", feature = "cms-moxcms"))]
31use crate::icc::apply_icc_transform;
32use crate::icc::{extract_icc_profile, is_xyb_profile};
33use crate::idct::inverse_dct_8x8;
34use crate::quant::dequantize_block;
35use crate::types::{ColorSpace, Component, Dimensions, JpegMode, PixelFormat};
36
/// Decoder configuration.
///
/// Usually assembled through [`Decoder`]'s builder methods, or passed directly
/// to [`Decoder::from_config`].
#[derive(Debug, Clone)]
pub struct DecoderConfig {
    /// Output pixel format (None = use source format)
    ///
    /// NOTE(review): `Decoder::decode` currently substitutes `PixelFormat::Rgb`
    /// when this is `None` instead of deriving a format from the source.
    pub output_format: Option<PixelFormat>,
    /// Whether to apply fancy upsampling
    ///
    /// NOTE(review): not read by `Decoder::decode` or by the parser code in this
    /// section — confirm it is wired up elsewhere.
    pub fancy_upsampling: bool,
    /// Whether to apply block smoothing
    ///
    /// NOTE(review): not read by `Decoder::decode` or by the parser code in this
    /// section — confirm it is wired up elsewhere.
    pub block_smoothing: bool,
    /// Whether to apply embedded ICC profile (requires cms feature)
    ///
    /// Defaults to `true` exactly when `cms-lcms2` or `cms-moxcms` is enabled.
    /// `Decoder::decode` only applies the profile when the output format is RGB.
    pub apply_icc: bool,
    /// Maximum pixels allowed (for DoS protection).
    /// Default is 100 megapixels. Set to 0 for unlimited.
    pub max_pixels: u64,
    /// Maximum total memory for allocations (for DoS protection).
    /// Default is 512 MB. Set to 0 for unlimited.
    ///
    /// NOTE(review): `Decoder::max_memory` documents `usize::MAX` (not 0) as
    /// "unlimited", and `Decoder::decode` never forwards this value to the
    /// parser; confirm which convention is intended and where it is enforced.
    pub max_memory: usize,
}
55
56impl Default for DecoderConfig {
57    fn default() -> Self {
58        Self {
59            output_format: None,
60            fancy_upsampling: false,
61            block_smoothing: false,
62            // Apply ICC by default when CMS is available
63            apply_icc: cfg!(any(feature = "cms-lcms2", feature = "cms-moxcms")),
64            max_pixels: DEFAULT_MAX_PIXELS,
65            max_memory: DEFAULT_MAX_MEMORY,
66        }
67    }
68}
69
/// Information about a decoded JPEG.
///
/// Returned by [`Decoder::read_info`]. That method stops parsing at the first
/// frame header (SOF), so no entropy-coded data is decoded to produce it.
#[derive(Debug, Clone)]
pub struct JpegInfo {
    /// Image dimensions
    pub dimensions: Dimensions,
    /// Color space
    ///
    /// NOTE(review): how this is derived (component count, APP markers, ...) is
    /// not visible in this section — confirm before relying on it.
    pub color_space: ColorSpace,
    /// Sample precision (8 or 12 bits)
    pub precision: u8,
    /// Number of components
    /// (presumably 1..=4, as validated by the frame-header parser — confirm)
    pub num_components: u8,
    /// Encoding mode
    pub mode: JpegMode,
    /// Whether an ICC profile is embedded
    pub has_icc_profile: bool,
    /// Whether the ICC profile is an XYB profile
    pub is_xyb: bool,
}
88
/// JPEG decoder.
///
/// A thin wrapper around a [`DecoderConfig`]. Configure it with the builder
/// methods, then call [`Decoder::decode`] or [`Decoder::read_info`]. Both take
/// `&self` and build a fresh internal parser per call, so a single `Decoder`
/// can be reused for many inputs.
pub struct Decoder {
    config: DecoderConfig,
}
93
94impl Decoder {
95    /// Creates a new decoder with default settings.
96    #[must_use]
97    pub fn new() -> Self {
98        Self {
99            config: DecoderConfig::default(),
100        }
101    }
102
103    /// Creates a decoder from configuration.
104    #[must_use]
105    pub fn from_config(config: DecoderConfig) -> Self {
106        Self { config }
107    }
108
109    /// Sets the output pixel format.
110    #[must_use]
111    pub fn output_format(mut self, format: PixelFormat) -> Self {
112        self.config.output_format = Some(format);
113        self
114    }
115
116    /// Enables fancy upsampling.
117    #[must_use]
118    pub fn fancy_upsampling(mut self, enable: bool) -> Self {
119        self.config.fancy_upsampling = enable;
120        self
121    }
122
123    /// Enables block smoothing.
124    #[must_use]
125    pub fn block_smoothing(mut self, enable: bool) -> Self {
126        self.config.block_smoothing = enable;
127        self
128    }
129
130    /// Enables ICC profile application.
131    ///
132    /// When enabled, embedded ICC profiles will be applied to convert
133    /// the image to sRGB. This is required for correct display of
134    /// XYB-encoded images.
135    ///
136    /// Note: Requires `cms-lcms2` or `cms-moxcms` feature to be enabled.
137    /// Without a CMS feature, this setting has no effect.
138    #[must_use]
139    pub fn apply_icc(mut self, enable: bool) -> Self {
140        self.config.apply_icc = enable;
141        self
142    }
143
144    /// Sets the maximum number of pixels allowed (for DoS protection).
145    ///
146    /// Default is 100 megapixels. Set to 0 for unlimited.
147    #[must_use]
148    pub fn max_pixels(mut self, pixels: u64) -> Self {
149        self.config.max_pixels = pixels;
150        self
151    }
152
153    /// Sets the maximum memory allowed for allocations during decoding.
154    ///
155    /// Default is 512 MB. Set to `usize::MAX` for unlimited.
156    /// This prevents memory exhaustion attacks from malicious images.
157    #[must_use]
158    pub fn max_memory(mut self, bytes: usize) -> Self {
159        self.config.max_memory = bytes;
160        self
161    }
162
163    /// Reads JPEG info without decoding.
164    pub fn read_info(&self, data: &[u8]) -> Result<JpegInfo> {
165        let mut parser = JpegParser::new(data, self.config.max_pixels)?;
166        parser.read_header()?;
167        Ok(parser.info())
168    }
169
170    /// Decodes a JPEG image.
171    pub fn decode(&self, data: &[u8]) -> Result<DecodedImage> {
172        let mut parser = JpegParser::new(data, self.config.max_pixels)?;
173        parser.decode()?;
174
175        let info = parser.info();
176        let output_format = self.config.output_format.unwrap_or(PixelFormat::Rgb);
177
178        // Convert to output format
179        let mut pixels = parser.to_pixels(output_format)?;
180
181        // Apply ICC profile if enabled and present
182        #[cfg(any(feature = "cms-lcms2", feature = "cms-moxcms"))]
183        if self.config.apply_icc && output_format == PixelFormat::Rgb {
184            if let Some(ref icc_profile) = parser.icc_profile {
185                pixels = apply_icc_transform(
186                    &pixels,
187                    info.dimensions.width as usize,
188                    info.dimensions.height as usize,
189                    icc_profile,
190                )?;
191            }
192        }
193
194        Ok(DecodedImage {
195            width: info.dimensions.width,
196            height: info.dimensions.height,
197            format: output_format,
198            data: pixels,
199        })
200    }
201}
202
203impl Default for Decoder {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
/// A decoded image with dimensions and pixel data.
///
/// Produced by [`Decoder::decode`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct DecodedImage {
    /// Image width in pixels
    pub width: u32,
    /// Image height in pixels
    pub height: u32,
    /// Pixel format of the data
    pub format: PixelFormat,
    /// Raw pixel data in the specified format
    ///
    /// NOTE(review): `stride()` assumes tightly packed rows of
    /// `width * bytes_per_pixel()` bytes; presumably
    /// `data.len() == stride() * height` — confirm.
    pub data: Vec<u8>,
}
222
223impl DecodedImage {
224    /// Returns the image dimensions as a tuple (width, height).
225    #[must_use]
226    pub fn dimensions(&self) -> (u32, u32) {
227        (self.width, self.height)
228    }
229
230    /// Returns the number of bytes per pixel for this image's format.
231    #[must_use]
232    pub fn bytes_per_pixel(&self) -> usize {
233        self.format.bytes_per_pixel()
234    }
235
236    /// Returns the stride (bytes per row) of the image.
237    #[must_use]
238    pub fn stride(&self) -> usize {
239        self.width as usize * self.bytes_per_pixel()
240    }
241}
242
/// Internal JPEG parser state.
///
/// Borrows the complete input buffer and walks it with a byte cursor
/// (`position`). Frame and table state is filled in as marker segments are
/// parsed; coefficient storage is allocated lazily by the first scan.
struct JpegParser<'a> {
    /// The entire JPEG input; `JpegParser::new` checks that it starts with SOI.
    data: &'a [u8],
    /// Offset into `data` of the next byte to read.
    position: usize,

    // Frame info (filled in by `parse_frame_header`)
    width: u32,
    height: u32,
    /// Sample precision in bits; `parse_frame_header` accepts only 8 or 12.
    precision: u8,
    num_components: u8,
    mode: JpegMode,

    // Component info; only the first `num_components` entries are meaningful.
    components: [Component; MAX_COMPONENTS],

    // Tables
    /// Quantization tables indexed by Tq, stored in natural order
    /// (`parse_quant_table` de-zigzags them before storing).
    quant_tables: [Option<[u16; DCT_BLOCK_SIZE]>; MAX_QUANT_TABLES],
    /// DC Huffman tables indexed by Th.
    dc_tables: [Option<HuffmanDecodeTable>; MAX_HUFFMAN_TABLES],
    /// AC Huffman tables indexed by Th.
    ac_tables: [Option<HuffmanDecodeTable>; MAX_HUFFMAN_TABLES],

    // Restart
    /// Ri value from the most recent DRI segment; 0 until one is parsed.
    restart_interval: u16,

    // Decoded coefficient data: one Vec per component, one 64-entry block per
    // 8x8 block, indexed row-major with `mcu_cols * h_samp_factor` blocks per row.
    coeffs: Vec<Vec<[i16; DCT_BLOCK_SIZE]>>, // Per component

    // ICC profile (extracted from raw data, not during parsing)
    icc_profile: Option<Vec<u8>>,

    // Security limits
    /// Pixel-count limit passed to `validate_dimensions`; 0 means unlimited.
    max_pixels: u64,
}
275
276impl<'a> JpegParser<'a> {
277    fn new(data: &'a [u8], max_pixels: u64) -> Result<Self> {
278        // Check for SOI
279        if data.len() < 2 || data[0] != 0xFF || data[1] != MARKER_SOI {
280            return Err(Error::InvalidJpegData {
281                reason: "missing SOI marker",
282            });
283        }
284
285        // Extract ICC profile from raw data upfront
286        let icc_profile = extract_icc_profile(data);
287
288        Ok(Self {
289            data,
290            position: 2,
291            width: 0,
292            height: 0,
293            precision: 8,
294            num_components: 0,
295            mode: JpegMode::Baseline,
296            components: std::array::from_fn(|_| Component::default()),
297            quant_tables: [None, None, None, None],
298            dc_tables: [None, None, None, None],
299            ac_tables: [None, None, None, None],
300            restart_interval: 0,
301            coeffs: Vec::new(),
302            icc_profile,
303            max_pixels,
304        })
305    }
306
307    fn read_u8(&mut self) -> Result<u8> {
308        if self.position >= self.data.len() {
309            return Err(Error::UnexpectedEof {
310                context: "reading byte",
311            });
312        }
313        let byte = self.data[self.position];
314        self.position += 1;
315        Ok(byte)
316    }
317
318    fn read_u16(&mut self) -> Result<u16> {
319        let high = self.read_u8()? as u16;
320        let low = self.read_u8()? as u16;
321        Ok((high << 8) | low)
322    }
323
324    fn read_marker(&mut self) -> Result<u8> {
325        loop {
326            let byte = self.read_u8()?;
327            if byte != 0xFF {
328                continue;
329            }
330
331            let marker = self.read_u8()?;
332            if marker != 0x00 && marker != 0xFF {
333                return Ok(marker);
334            }
335        }
336    }
337
338    fn read_header(&mut self) -> Result<()> {
339        loop {
340            let marker = self.read_marker()?;
341
342            match marker {
343                MARKER_SOF0 | MARKER_SOF1 => {
344                    self.mode = JpegMode::Baseline;
345                    self.parse_frame_header()?;
346                    return Ok(());
347                }
348                MARKER_SOF2 => {
349                    self.mode = JpegMode::Progressive;
350                    self.parse_frame_header()?;
351                    return Ok(());
352                }
353                MARKER_DQT => self.parse_quant_table()?,
354                MARKER_DHT => self.parse_huffman_table()?,
355                MARKER_DRI => self.parse_restart_interval()?,
356                MARKER_APP0..=0xEF | MARKER_COM => self.skip_segment()?,
357                MARKER_EOI => {
358                    return Err(Error::InvalidJpegData {
359                        reason: "unexpected EOI before frame header",
360                    });
361                }
362                _ => self.skip_segment()?,
363            }
364        }
365    }
366
367    fn parse_frame_header(&mut self) -> Result<()> {
368        let length = self.read_u16()?;
369        if length < 8 {
370            return Err(Error::InvalidJpegData {
371                reason: "frame header too short",
372            });
373        }
374
375        self.precision = self.read_u8()?;
376        // Validate precision: must be 8 for baseline JPEG, 8 or 12 for extended
377        if self.precision != 8 && self.precision != 12 {
378            return Err(Error::InvalidJpegData {
379                reason: "invalid data precision (must be 8 or 12)",
380            });
381        }
382
383        self.height = self.read_u16()? as u32;
384        self.width = self.read_u16()? as u32;
385
386        // Validate dimensions against security limits
387        // max_pixels == 0 means unlimited
388        let effective_max = if self.max_pixels == 0 {
389            u64::MAX
390        } else {
391            self.max_pixels
392        };
393        validate_dimensions(self.width, self.height, effective_max)?;
394
395        self.num_components = self.read_u8()?;
396
397        // Validate num_components
398        if self.num_components == 0 {
399            return Err(Error::InvalidJpegData {
400                reason: "number of components is zero",
401            });
402        }
403        if self.num_components > MAX_COMPONENTS as u8 {
404            return Err(Error::UnsupportedFeature {
405                feature: "more than 4 components",
406            });
407        }
408
409        // Validate marker length matches expected size
410        let expected_length = 8 + 3 * self.num_components as u16;
411        if length != expected_length {
412            return Err(Error::InvalidJpegData {
413                reason: "SOF marker length mismatch",
414            });
415        }
416
417        for i in 0..self.num_components as usize {
418            self.components[i].id = self.read_u8()?;
419            let sampling = self.read_u8()?;
420            let h_samp = sampling >> 4;
421            let v_samp = sampling & 0x0F;
422
423            // Validate sampling factors are non-zero and <= 4
424            if h_samp == 0 || v_samp == 0 {
425                return Err(Error::InvalidJpegData {
426                    reason: "sampling factor is zero",
427                });
428            }
429            if h_samp > 4 || v_samp > 4 {
430                return Err(Error::InvalidJpegData {
431                    reason: "sampling factor exceeds maximum (4)",
432                });
433            }
434
435            self.components[i].h_samp_factor = h_samp;
436            self.components[i].v_samp_factor = v_samp;
437
438            let quant_idx = self.read_u8()?;
439            // Validate quant table index
440            if quant_idx as usize >= MAX_QUANT_TABLES {
441                return Err(Error::InvalidJpegData {
442                    reason: "quantization table index out of range",
443                });
444            }
445            self.components[i].quant_table_idx = quant_idx;
446        }
447
448        Ok(())
449    }
450
451    fn parse_quant_table(&mut self) -> Result<()> {
452        let mut length = self.read_u16()? as i32 - 2;
453
454        while length > 0 {
455            let info = self.read_u8()?;
456            let precision = info >> 4;
457            let table_idx = (info & 0x0F) as usize;
458
459            // Validate precision (0 = 8-bit, 1 = 16-bit)
460            if precision > 1 {
461                return Err(Error::InvalidQuantTable {
462                    table_idx: table_idx as u8,
463                    reason: "invalid precision (must be 0 or 1)",
464                });
465            }
466
467            if table_idx >= MAX_QUANT_TABLES {
468                return Err(Error::InvalidQuantTable {
469                    table_idx: table_idx as u8,
470                    reason: "table index out of range",
471                });
472            }
473
474            // Read values in zigzag order (as stored in JPEG)
475            let mut zigzag_values = [0u16; DCT_BLOCK_SIZE];
476
477            if precision == 0 {
478                // 8-bit values
479                for i in 0..DCT_BLOCK_SIZE {
480                    let val = self.read_u8()? as u16;
481                    if val == 0 {
482                        return Err(Error::InvalidQuantTable {
483                            table_idx: table_idx as u8,
484                            reason: "quantization value is zero",
485                        });
486                    }
487                    zigzag_values[i] = val;
488                }
489                length -= 65;
490            } else {
491                // 16-bit values
492                for i in 0..DCT_BLOCK_SIZE {
493                    let val = self.read_u16()?;
494                    if val == 0 {
495                        return Err(Error::InvalidQuantTable {
496                            table_idx: table_idx as u8,
497                            reason: "quantization value is zero",
498                        });
499                    }
500                    zigzag_values[i] = val;
501                }
502                length -= 129;
503            }
504
505            // Validate DQT marker length consistency
506            if length < 0 {
507                return Err(Error::InvalidJpegData {
508                    reason: "DQT marker length mismatch",
509                });
510            }
511
512            // Convert from zigzag order to natural order for dequantization
513            let mut natural_values = [0u16; DCT_BLOCK_SIZE];
514            for i in 0..DCT_BLOCK_SIZE {
515                natural_values[JPEG_NATURAL_ORDER[i] as usize] = zigzag_values[i];
516            }
517
518            self.quant_tables[table_idx] = Some(natural_values);
519        }
520
521        Ok(())
522    }
523
524    fn parse_huffman_table(&mut self) -> Result<()> {
525        let mut length = self.read_u16()? as i32 - 2;
526
527        while length > 0 {
528            let info = self.read_u8()?;
529            let table_class = info >> 4; // 0 = DC, 1 = AC
530            let table_idx = (info & 0x0F) as usize;
531
532            // Validate table class (must be 0 for DC or 1 for AC)
533            if table_class > 1 {
534                return Err(Error::InvalidHuffmanTable {
535                    table_idx: table_idx as u8,
536                    reason: "invalid table class (must be 0 or 1)",
537                });
538            }
539
540            if table_idx >= MAX_HUFFMAN_TABLES {
541                return Err(Error::InvalidHuffmanTable {
542                    table_idx: table_idx as u8,
543                    reason: "table index out of range",
544                });
545            }
546
547            let mut bits = [0u8; 16];
548            for i in 0..16 {
549                bits[i] = self.read_u8()?;
550            }
551
552            let num_values: usize = bits.iter().map(|&b| b as usize).sum();
553            let mut values = vec![0u8; num_values];
554            for i in 0..num_values {
555                values[i] = self.read_u8()?;
556            }
557
558            length -= 17 + num_values as i32;
559
560            // Validate that we didn't read past the marker length
561            if length < 0 {
562                return Err(Error::InvalidJpegData {
563                    reason: "DHT marker length mismatch",
564                });
565            }
566
567            let table = HuffmanDecodeTable::from_bits_values(&bits, &values)?;
568
569            if table_class == 0 {
570                self.dc_tables[table_idx] = Some(table);
571            } else {
572                self.ac_tables[table_idx] = Some(table);
573            }
574        }
575
576        Ok(())
577    }
578
579    fn parse_restart_interval(&mut self) -> Result<()> {
580        let _length = self.read_u16()?;
581        self.restart_interval = self.read_u16()?;
582        Ok(())
583    }
584
585    fn skip_segment(&mut self) -> Result<()> {
586        let length = self.read_u16()? as usize;
587        if length < 2 {
588            return Err(Error::InvalidJpegData {
589                reason: "segment length too short",
590            });
591        }
592        self.position += length - 2;
593        Ok(())
594    }
595
596    fn decode(&mut self) -> Result<()> {
597        // First read header
598        self.position = 2; // Skip SOI
599        self.read_header()?;
600
601        // Continue parsing until we hit SOS
602        loop {
603            let marker = self.read_marker()?;
604
605            match marker {
606                MARKER_SOS => {
607                    self.parse_scan()?;
608                    // After scan, look for more markers
609                }
610                MARKER_DQT => self.parse_quant_table()?,
611                MARKER_DHT => self.parse_huffman_table()?,
612                MARKER_DRI => self.parse_restart_interval()?,
613                MARKER_EOI => break,
614                MARKER_APP0..=0xEF | MARKER_COM => self.skip_segment()?,
615                _ => self.skip_segment()?,
616            }
617        }
618
619        Ok(())
620    }
621
622    fn parse_scan(&mut self) -> Result<()> {
623        let _length = self.read_u16()?;
624        let num_components = self.read_u8()?;
625
626        // Validate num_components in scan
627        if num_components == 0 {
628            return Err(Error::InvalidJpegData {
629                reason: "SOS num_components is zero",
630            });
631        }
632        if num_components > self.num_components {
633            return Err(Error::InvalidJpegData {
634                reason: "SOS num_components exceeds frame components",
635            });
636        }
637        if num_components > MAX_COMPONENTS as u8 {
638            return Err(Error::InvalidJpegData {
639                reason: "SOS num_components too large",
640            });
641        }
642
643        let mut scan_components = Vec::with_capacity(num_components as usize);
644
645        for _ in 0..num_components {
646            let component_id = self.read_u8()?;
647            let tables = self.read_u8()?;
648            let dc_table = tables >> 4;
649            let ac_table = tables & 0x0F;
650
651            // Validate Huffman table indexes
652            if dc_table as usize >= MAX_HUFFMAN_TABLES {
653                return Err(Error::InvalidJpegData {
654                    reason: "SOS DC Huffman table index out of range",
655                });
656            }
657            if ac_table as usize >= MAX_HUFFMAN_TABLES {
658                return Err(Error::InvalidJpegData {
659                    reason: "SOS AC Huffman table index out of range",
660                });
661            }
662
663            // Find component index
664            let comp_idx = self.components[..self.num_components as usize]
665                .iter()
666                .position(|c| c.id == component_id)
667                .ok_or(Error::InvalidJpegData {
668                    reason: "unknown component in scan",
669                })?;
670
671            scan_components.push((comp_idx, dc_table, ac_table));
672        }
673
674        let ss = self.read_u8()?; // Spectral selection start
675        let se = self.read_u8()?; // Spectral selection end
676        let ah_al = self.read_u8()?;
677        let ah = ah_al >> 4;
678        let al = ah_al & 0x0F;
679
680        // Validate spectral selection (must be 0-63)
681        if ss > 63 {
682            return Err(Error::InvalidJpegData {
683                reason: "SOS Ss (spectral start) out of range",
684            });
685        }
686        if se > 63 {
687            return Err(Error::InvalidJpegData {
688                reason: "SOS Se (spectral end) out of range",
689            });
690        }
691
692        // Decode entropy-coded segment based on mode
693        if self.mode == JpegMode::Progressive {
694            self.decode_progressive_scan(&scan_components, ss, se, ah, al)?;
695        } else {
696            self.decode_scan(&scan_components)?;
697        }
698
699        Ok(())
700    }
701
702    fn decode_scan(&mut self, scan_components: &[(usize, u8, u8)]) -> Result<()> {
703        // Calculate max sampling factors to determine MCU structure
704        let mut max_h_samp = 1u8;
705        let mut max_v_samp = 1u8;
706        for i in 0..self.num_components as usize {
707            max_h_samp = max_h_samp.max(self.components[i].h_samp_factor);
708            max_v_samp = max_v_samp.max(self.components[i].v_samp_factor);
709        }
710
711        // MCU dimensions in pixels
712        let mcu_width = (max_h_samp as usize) * 8;
713        let mcu_height = (max_v_samp as usize) * 8;
714
715        // Number of MCUs
716        let mcu_cols = (self.width as usize + mcu_width - 1) / mcu_width;
717        let mcu_rows = (self.height as usize + mcu_height - 1) / mcu_height;
718
719        // Initialize coefficient storage - size depends on component's sampling factor
720        if self.coeffs.is_empty() {
721            for i in 0..self.num_components as usize {
722                let h_samp = self.components[i].h_samp_factor as usize;
723                let v_samp = self.components[i].v_samp_factor as usize;
724                let comp_blocks_h = checked_size_2d(mcu_cols, h_samp)?;
725                let comp_blocks_v = checked_size_2d(mcu_rows, v_samp)?;
726                let num_blocks = checked_size_2d(comp_blocks_h, comp_blocks_v)?;
727                self.coeffs.push(try_alloc_dct_blocks(
728                    num_blocks,
729                    "allocating DCT coefficients",
730                )?);
731            }
732        }
733
734        // Set up entropy decoder
735        let scan_data = &self.data[self.position..];
736        let mut decoder = EntropyDecoder::new(scan_data);
737
738        for (_comp_idx, dc_table, ac_table) in scan_components {
739            let dc_idx = (*dc_table as usize).min(MAX_HUFFMAN_TABLES - 1);
740            let ac_idx = (*ac_table as usize).min(MAX_HUFFMAN_TABLES - 1);
741            if let Some(table) = &self.dc_tables[dc_idx] {
742                decoder.set_dc_table(dc_idx, table.clone());
743            }
744            if let Some(table) = &self.ac_tables[ac_idx] {
745                decoder.set_ac_table(ac_idx, table.clone());
746            }
747        }
748
749        // Decode MCUs with proper interleaving
750        for mcu_y in 0..mcu_rows {
751            for mcu_x in 0..mcu_cols {
752                // For each component in the scan
753                for (comp_idx, dc_table, ac_table) in scan_components {
754                    let h_samp = self.components[*comp_idx].h_samp_factor as usize;
755                    let v_samp = self.components[*comp_idx].v_samp_factor as usize;
756                    let comp_blocks_h = mcu_cols * h_samp;
757
758                    // Decode all blocks for this component in this MCU
759                    for v in 0..v_samp {
760                        for h in 0..h_samp {
761                            let block_x = mcu_x * h_samp + h;
762                            let block_y = mcu_y * v_samp + v;
763                            let block_idx = block_y * comp_blocks_h + block_x;
764
765                            let coeffs = decoder.decode_block(
766                                *comp_idx,
767                                *dc_table as usize,
768                                *ac_table as usize,
769                            )?;
770                            self.coeffs[*comp_idx][block_idx] = coeffs;
771                        }
772                    }
773                }
774            }
775        }
776
777        self.position += decoder.position();
778        Ok(())
779    }
780
781    fn decode_progressive_scan(
782        &mut self,
783        scan_components: &[(usize, u8, u8)],
784        ss: u8,
785        se: u8,
786        ah: u8,
787        al: u8,
788    ) -> Result<()> {
789        // Calculate max sampling factors to determine MCU structure
790        let mut max_h_samp = 1u8;
791        let mut max_v_samp = 1u8;
792        for i in 0..self.num_components as usize {
793            max_h_samp = max_h_samp.max(self.components[i].h_samp_factor);
794            max_v_samp = max_v_samp.max(self.components[i].v_samp_factor);
795        }
796
797        // MCU dimensions in pixels
798        let mcu_width = (max_h_samp as usize) * 8;
799        let mcu_height = (max_v_samp as usize) * 8;
800
801        // Number of MCUs
802        let mcu_cols = (self.width as usize + mcu_width - 1) / mcu_width;
803        let mcu_rows = (self.height as usize + mcu_height - 1) / mcu_height;
804
805        // Initialize coefficient storage if not already done
806        if self.coeffs.is_empty() {
807            for i in 0..self.num_components as usize {
808                let h_samp = self.components[i].h_samp_factor as usize;
809                let v_samp = self.components[i].v_samp_factor as usize;
810                let comp_blocks_h = checked_size_2d(mcu_cols, h_samp)?;
811                let comp_blocks_v = checked_size_2d(mcu_rows, v_samp)?;
812                let num_blocks = checked_size_2d(comp_blocks_h, comp_blocks_v)?;
813                self.coeffs.push(try_alloc_dct_blocks(
814                    num_blocks,
815                    "allocating DCT coefficients",
816                )?);
817            }
818        }
819
820        // Set up entropy decoder
821        let scan_data = &self.data[self.position..];
822        let mut decoder = EntropyDecoder::new(scan_data);
823
824        for (_comp_idx, dc_table, ac_table) in scan_components {
825            let dc_idx = (*dc_table as usize).min(MAX_HUFFMAN_TABLES - 1);
826            let ac_idx = (*ac_table as usize).min(MAX_HUFFMAN_TABLES - 1);
827            if let Some(table) = &self.dc_tables[dc_idx] {
828                decoder.set_dc_table(dc_idx, table.clone());
829            }
830            if let Some(table) = &self.ac_tables[ac_idx] {
831                decoder.set_ac_table(ac_idx, table.clone());
832            }
833        }
834
835        // Determine scan type
836        let is_dc_scan = ss == 0 && se == 0;
837        let is_first_scan = ah == 0;
838
839        // EOB run tracking for AC scans
840        let mut eob_run = 0u16;
841
842        if is_dc_scan {
843            // DC scan (interleaved or single component)
844            for mcu_y in 0..mcu_rows {
845                for mcu_x in 0..mcu_cols {
846                    for (comp_idx, dc_table, _ac_table) in scan_components {
847                        let h_samp = self.components[*comp_idx].h_samp_factor as usize;
848                        let v_samp = self.components[*comp_idx].v_samp_factor as usize;
849                        let comp_blocks_h = mcu_cols * h_samp;
850
851                        for v in 0..v_samp {
852                            for h in 0..h_samp {
853                                let block_x = mcu_x * h_samp + h;
854                                let block_y = mcu_y * v_samp + v;
855                                let block_idx = block_y * comp_blocks_h + block_x;
856
857                                if is_first_scan {
858                                    // DC first scan
859                                    let dc = decoder.decode_dc_first(
860                                        *comp_idx,
861                                        *dc_table as usize,
862                                        al,
863                                    )?;
864                                    self.coeffs[*comp_idx][block_idx][0] = dc;
865                                } else {
866                                    // DC refinement scan
867                                    let bit = decoder.decode_dc_refine(al)?;
868                                    self.coeffs[*comp_idx][block_idx][0] |= bit;
869                                }
870                            }
871                        }
872                    }
873                }
874            }
875        } else {
876            // AC scan (single component only for progressive)
877            // Progressive AC scans can only have one component
878            if scan_components.len() != 1 {
879                return Err(Error::InvalidJpegData {
880                    reason: "progressive AC scan must have single component",
881                });
882            }
883
884            let (comp_idx, _dc_table, ac_table) = scan_components[0];
885            let h_samp = self.components[comp_idx].h_samp_factor as usize;
886            let v_samp = self.components[comp_idx].v_samp_factor as usize;
887            let comp_blocks_h = mcu_cols * h_samp;
888
889            for mcu_y in 0..mcu_rows {
890                for mcu_x in 0..mcu_cols {
891                    for v in 0..v_samp {
892                        for h in 0..h_samp {
893                            let block_x = mcu_x * h_samp + h;
894                            let block_y = mcu_y * v_samp + v;
895                            let block_idx = block_y * comp_blocks_h + block_x;
896
897                            if is_first_scan {
898                                // AC first scan
899                                decoder.decode_ac_first(
900                                    &mut self.coeffs[comp_idx][block_idx],
901                                    ac_table as usize,
902                                    ss,
903                                    se,
904                                    al,
905                                    &mut eob_run,
906                                )?;
907                            } else {
908                                // AC refinement scan
909                                decoder.decode_ac_refine(
910                                    &mut self.coeffs[comp_idx][block_idx],
911                                    ac_table as usize,
912                                    ss,
913                                    se,
914                                    al,
915                                    &mut eob_run,
916                                )?;
917                            }
918                        }
919                    }
920                }
921            }
922        }
923
924        self.position += decoder.position();
925        Ok(())
926    }
927
928    fn info(&self) -> JpegInfo {
929        let has_icc = self.icc_profile.is_some();
930        let is_xyb = self.icc_profile.as_ref().is_some_and(|p| is_xyb_profile(p));
931
932        // Determine color space, considering XYB profile
933        let color_space = if is_xyb {
934            ColorSpace::Xyb
935        } else {
936            match self.num_components {
937                1 => ColorSpace::Grayscale,
938                3 => ColorSpace::YCbCr,
939                4 => ColorSpace::Cmyk,
940                _ => ColorSpace::Unknown,
941            }
942        };
943
944        JpegInfo {
945            dimensions: Dimensions::new(self.width, self.height),
946            color_space,
947            precision: self.precision,
948            num_components: self.num_components,
949            mode: self.mode,
950            has_icc_profile: has_icc,
951            is_xyb,
952        }
953    }
954
955    fn to_pixels(&self, format: PixelFormat) -> Result<Vec<u8>> {
956        if self.coeffs.is_empty() {
957            return Err(Error::InternalError {
958                reason: "no decoded data",
959            });
960        }
961
962        let width = self.width as usize;
963        let height = self.height as usize;
964
965        // Calculate max sampling factors
966        let mut max_h_samp = 1u8;
967        let mut max_v_samp = 1u8;
968        for i in 0..self.num_components as usize {
969            max_h_samp = max_h_samp.max(self.components[i].h_samp_factor);
970            max_v_samp = max_v_samp.max(self.components[i].v_samp_factor);
971        }
972
973        // MCU dimensions
974        let mcu_width = (max_h_samp as usize) * 8;
975        let mcu_height = (max_v_samp as usize) * 8;
976        let mcu_cols = (width + mcu_width - 1) / mcu_width;
977        let mcu_rows = (height + mcu_height - 1) / mcu_height;
978
979        // Dequantize and IDCT all blocks, then upsample if needed
980        let mut planes: Vec<Vec<u8>> = Vec::new();
981
982        for comp_idx in 0..self.num_components as usize {
983            let quant_idx = self.components[comp_idx].quant_table_idx as usize;
984            let quant = self.quant_tables[quant_idx]
985                .as_ref()
986                .ok_or(Error::InternalError {
987                    reason: "missing quantization table",
988                })?;
989
990            let h_samp = self.components[comp_idx].h_samp_factor as usize;
991            let v_samp = self.components[comp_idx].v_samp_factor as usize;
992            let comp_blocks_h = mcu_cols * h_samp;
993            let comp_blocks_v = mcu_rows * v_samp;
994
995            // Component plane dimensions (may be smaller than full image for subsampled)
996            let comp_width = checked_size_2d(comp_blocks_h, 8)?;
997            let comp_height = checked_size_2d(comp_blocks_v, 8)?;
998
999            let comp_plane_size = checked_size_2d(comp_width, comp_height)?;
1000            let mut comp_plane = try_alloc_zeroed(comp_plane_size, "allocating component plane")?;
1001
1002            for by in 0..comp_blocks_v {
1003                for bx in 0..comp_blocks_h {
1004                    let block_idx = by * comp_blocks_h + bx;
1005                    if block_idx >= self.coeffs[comp_idx].len() {
1006                        continue;
1007                    }
1008                    let coeffs = &self.coeffs[comp_idx][block_idx];
1009
1010                    // Convert to natural order and dequantize
1011                    let mut natural_coeffs = [0i16; DCT_BLOCK_SIZE];
1012                    for (i, &zi) in JPEG_NATURAL_ORDER[..DCT_BLOCK_SIZE].iter().enumerate() {
1013                        natural_coeffs[zi as usize] = coeffs[i];
1014                    }
1015
1016                    let dequant = dequantize_block(&natural_coeffs, quant);
1017                    let pixels = inverse_dct_8x8(&dequant);
1018
1019                    // Copy to component plane with level shift
1020                    for y in 0..DCT_SIZE {
1021                        for x in 0..DCT_SIZE {
1022                            let px = bx * DCT_SIZE + x;
1023                            let py = by * DCT_SIZE + y;
1024                            if px < comp_width && py < comp_height {
1025                                let val = (pixels[y * DCT_SIZE + x] + 128.0).round();
1026                                comp_plane[py * comp_width + px] = val.clamp(0.0, 255.0) as u8;
1027                            }
1028                        }
1029                    }
1030                }
1031            }
1032
1033            // Upsample if this component has lower sampling than max
1034            let output_size = checked_size_2d(width, height)?;
1035            let plane = if h_samp < max_h_samp as usize || v_samp < max_v_samp as usize {
1036                let scale_x = max_h_samp as usize / h_samp;
1037                let scale_y = max_v_samp as usize / v_samp;
1038                let mut upsampled = try_alloc_zeroed(output_size, "allocating upsampled plane")?;
1039                for py in 0..height {
1040                    for px in 0..width {
1041                        let sx = (px / scale_x).min(comp_width - 1);
1042                        let sy = (py / scale_y).min(comp_height - 1);
1043                        upsampled[py * width + px] = comp_plane[sy * comp_width + sx];
1044                    }
1045                }
1046                upsampled
1047            } else {
1048                // Full resolution - just clip to image dimensions
1049                let mut plane = try_alloc_zeroed(output_size, "allocating output plane")?;
1050                for py in 0..height {
1051                    for px in 0..width {
1052                        plane[py * width + px] = comp_plane[py * comp_width + px];
1053                    }
1054                }
1055                plane
1056            };
1057
1058            planes.push(plane);
1059        }
1060
1061        // Convert to output format
1062        match (self.num_components, format) {
1063            (1, PixelFormat::Gray) => Ok(planes[0].clone()),
1064            (1, PixelFormat::Rgb) => {
1065                let rgb_size =
1066                    checked_size_2d(width, height).and_then(|s| checked_size_2d(s, 3))?;
1067                let mut rgb = try_alloc_zeroed(rgb_size, "allocating RGB output")?;
1068                for (i, &y) in planes[0].iter().enumerate() {
1069                    rgb[i * 3] = y;
1070                    rgb[i * 3 + 1] = y;
1071                    rgb[i * 3 + 2] = y;
1072                }
1073                Ok(rgb)
1074            }
1075            (3, PixelFormat::Rgb) => {
1076                color::ycbcr_planes_to_rgb(&planes[0], &planes[1], &planes[2], width, height)
1077            }
1078            _ => Err(Error::UnsupportedFeature {
1079                feature: "unsupported color conversion",
1080            }),
1081        }
1082    }
1083}
1084
/// Builder sanity checks and encode → decode round trips through the public API.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode::Encoder;
    use crate::quant::Quality;

    /// Largest absolute per-sample difference between two buffers, compared pairwise.
    fn max_abs_diff(expected: &[u8], actual: &[u8]) -> i32 {
        expected
            .iter()
            .zip(actual)
            .map(|(&a, &b)| (i32::from(a) - i32::from(b)).abs())
            .max()
            .unwrap_or(0)
    }

    #[test]
    fn test_decoder_creation() {
        let decoder = Decoder::new()
            .output_format(PixelFormat::Rgb)
            .fancy_upsampling(true);

        let cfg = &decoder.config;
        assert_eq!(cfg.output_format, Some(PixelFormat::Rgb));
        assert!(cfg.fancy_upsampling);
    }

    #[test]
    fn test_encode_decode_roundtrip_gray() {
        const W: usize = 8;
        const H: usize = 8;

        // Diagonal gradient: sample(x, y) = (x + y) * 16.
        let input: Vec<u8> = (0..H)
            .flat_map(|y| (0..W).map(move |x| ((x + y) * 16) as u8))
            .collect();

        let jpeg = Encoder::new()
            .width(W as u32)
            .height(H as u32)
            .pixel_format(PixelFormat::Gray)
            .quality(Quality::from_quality(95.0))
            .encode(&input)
            .expect("encoding should succeed");

        // The stream must open with SOI and close with EOI.
        assert_eq!(jpeg[..2], [0xFF, 0xD8]);
        assert_eq!(jpeg[jpeg.len() - 2..], [0xFF, 0xD9]);

        let decoder = Decoder::new().output_format(PixelFormat::Gray);
        let decoded = decoder.decode(&jpeg).expect("decoding should succeed");

        assert_eq!((decoded.width, decoded.height), (W as u32, H as u32));
        assert_eq!(decoded.data.len(), W * H);

        // JPEG is lossy, but quality 95 should keep every sample close to the source.
        let max_diff = max_abs_diff(&input, &decoded.data);
        assert!(max_diff < 20, "max_diff {} too large", max_diff);
    }

    #[test]
    fn test_encode_decode_roundtrip_rgb() {
        const W: usize = 16;
        const H: usize = 16;

        // R ramps along x, G ramps along y, B is constant mid-gray.
        let mut input = vec![0u8; W * H * 3];
        for (i, px) in input.chunks_exact_mut(3).enumerate() {
            let (x, y) = (i % W, i / W);
            px.copy_from_slice(&[(x * 16) as u8, (y * 16) as u8, 128]);
        }

        let jpeg = Encoder::new()
            .width(W as u32)
            .height(H as u32)
            .pixel_format(PixelFormat::Rgb)
            .quality(Quality::from_quality(95.0))
            .encode(&input)
            .expect("encoding should succeed");

        let decoder = Decoder::new().output_format(PixelFormat::Rgb);
        let decoded = decoder.decode(&jpeg).expect("decoding should succeed");

        assert_eq!((decoded.width, decoded.height), (W as u32, H as u32));
        assert_eq!(decoded.data.len(), W * H * 3);

        // Color conversion adds error on top of quantization, hence the looser bound.
        let max_diff = max_abs_diff(&input, &decoded.data);
        assert!(max_diff < 30, "max_diff {} too large", max_diff);
    }
}