1use crate::alloc::{
18 checked_size_2d, try_alloc_dct_blocks, try_alloc_zeroed, validate_dimensions,
19 DEFAULT_MAX_MEMORY, DEFAULT_MAX_PIXELS,
20};
21use crate::consts::{
22 DCT_BLOCK_SIZE, DCT_SIZE, JPEG_NATURAL_ORDER, MARKER_APP0, MARKER_COM, MARKER_DHT, MARKER_DQT,
23 MARKER_DRI, MARKER_EOI, MARKER_SOF0, MARKER_SOF1, MARKER_SOF2, MARKER_SOI, MARKER_SOS,
24 MAX_COMPONENTS, MAX_HUFFMAN_TABLES, MAX_QUANT_TABLES,
25};
26use crate::entropy::EntropyDecoder;
27use crate::error::{Error, Result};
28use crate::huffman::HuffmanDecodeTable;
29#[cfg(any(feature = "cms-lcms2", feature = "cms-moxcms"))]
30use crate::icc::apply_icc_transform;
31use crate::icc::{extract_icc_profile, is_xyb_profile};
32use crate::idct::inverse_dct_8x8;
33use crate::quant::{dequantize_block_with_bias, DequantBiasStats};
34use crate::types::{ColorSpace, Component, Dimensions, JpegMode, PixelFormat};
35
36#[derive(Debug, Clone)]
38pub struct DecoderConfig {
39 pub output_format: Option<PixelFormat>,
41 pub fancy_upsampling: bool,
43 pub block_smoothing: bool,
45 pub apply_icc: bool,
47 pub max_pixels: u64,
50 pub max_memory: usize,
53}
54
55impl Default for DecoderConfig {
56 fn default() -> Self {
57 Self {
58 output_format: None,
59 fancy_upsampling: false,
60 block_smoothing: false,
61 apply_icc: cfg!(any(feature = "cms-lcms2", feature = "cms-moxcms")),
63 max_pixels: DEFAULT_MAX_PIXELS,
64 max_memory: DEFAULT_MAX_MEMORY,
65 }
66 }
67}
68
69#[derive(Debug, Clone)]
71pub struct JpegInfo {
72 pub dimensions: Dimensions,
74 pub color_space: ColorSpace,
76 pub precision: u8,
78 pub num_components: u8,
80 pub mode: JpegMode,
82 pub has_icc_profile: bool,
84 pub is_xyb: bool,
86}
87
88pub struct Decoder {
90 config: DecoderConfig,
91}
92
93impl Decoder {
94 #[must_use]
96 pub fn new() -> Self {
97 Self {
98 config: DecoderConfig::default(),
99 }
100 }
101
102 #[must_use]
104 pub fn from_config(config: DecoderConfig) -> Self {
105 Self { config }
106 }
107
108 #[must_use]
110 pub fn output_format(mut self, format: PixelFormat) -> Self {
111 self.config.output_format = Some(format);
112 self
113 }
114
115 #[must_use]
117 pub fn fancy_upsampling(mut self, enable: bool) -> Self {
118 self.config.fancy_upsampling = enable;
119 self
120 }
121
122 #[must_use]
124 pub fn block_smoothing(mut self, enable: bool) -> Self {
125 self.config.block_smoothing = enable;
126 self
127 }
128
129 #[must_use]
138 pub fn apply_icc(mut self, enable: bool) -> Self {
139 self.config.apply_icc = enable;
140 self
141 }
142
143 #[must_use]
147 pub fn max_pixels(mut self, pixels: u64) -> Self {
148 self.config.max_pixels = pixels;
149 self
150 }
151
152 #[must_use]
157 pub fn max_memory(mut self, bytes: usize) -> Self {
158 self.config.max_memory = bytes;
159 self
160 }
161
162 pub fn read_info(&self, data: &[u8]) -> Result<JpegInfo> {
164 let mut parser = JpegParser::new(data, self.config.max_pixels)?;
165 parser.read_header()?;
166 Ok(parser.info())
167 }
168
169 pub fn decode(&self, data: &[u8]) -> Result<DecodedImage> {
171 let mut parser = JpegParser::new(data, self.config.max_pixels)?;
172 parser.decode()?;
173
174 let info = parser.info();
175 let output_format = self.config.output_format.unwrap_or(PixelFormat::Rgb);
176
177 let mut pixels = parser.to_pixels(output_format)?;
179
180 #[cfg(any(feature = "cms-lcms2", feature = "cms-moxcms"))]
182 if self.config.apply_icc && output_format == PixelFormat::Rgb {
183 if let Some(ref icc_profile) = parser.icc_profile {
184 pixels = apply_icc_transform(
185 &pixels,
186 info.dimensions.width as usize,
187 info.dimensions.height as usize,
188 icc_profile,
189 )?;
190 }
191 }
192
193 Ok(DecodedImage {
194 width: info.dimensions.width,
195 height: info.dimensions.height,
196 format: output_format,
197 data: pixels,
198 })
199 }
200}
201
202impl Default for Decoder {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208#[derive(Debug, Clone)]
210#[non_exhaustive]
211pub struct DecodedImage {
212 pub width: u32,
214 pub height: u32,
216 pub format: PixelFormat,
218 pub data: Vec<u8>,
220}
221
222impl DecodedImage {
223 #[must_use]
225 pub fn dimensions(&self) -> (u32, u32) {
226 (self.width, self.height)
227 }
228
229 #[must_use]
231 pub fn bytes_per_pixel(&self) -> usize {
232 self.format.bytes_per_pixel()
233 }
234
235 #[must_use]
237 pub fn stride(&self) -> usize {
238 self.width as usize * self.bytes_per_pixel()
239 }
240}
241
242struct JpegParser<'a> {
244 data: &'a [u8],
245 position: usize,
246
247 width: u32,
249 height: u32,
250 precision: u8,
251 num_components: u8,
252 mode: JpegMode,
253
254 components: [Component; MAX_COMPONENTS],
256
257 quant_tables: [Option<[u16; DCT_BLOCK_SIZE]>; MAX_QUANT_TABLES],
259 dc_tables: [Option<HuffmanDecodeTable>; MAX_HUFFMAN_TABLES],
260 ac_tables: [Option<HuffmanDecodeTable>; MAX_HUFFMAN_TABLES],
261
262 restart_interval: u16,
264
265 coeffs: Vec<Vec<[i16; DCT_BLOCK_SIZE]>>, icc_profile: Option<Vec<u8>>,
270
271 max_pixels: u64,
273}
274
275impl<'a> JpegParser<'a> {
276 fn new(data: &'a [u8], max_pixels: u64) -> Result<Self> {
277 if data.len() < 2 || data[0] != 0xFF || data[1] != MARKER_SOI {
279 return Err(Error::InvalidJpegData {
280 reason: "missing SOI marker",
281 });
282 }
283
284 let icc_profile = extract_icc_profile(data);
286
287 Ok(Self {
288 data,
289 position: 2,
290 width: 0,
291 height: 0,
292 precision: 8,
293 num_components: 0,
294 mode: JpegMode::Baseline,
295 components: std::array::from_fn(|_| Component::default()),
296 quant_tables: [None, None, None, None],
297 dc_tables: [None, None, None, None],
298 ac_tables: [None, None, None, None],
299 restart_interval: 0,
300 coeffs: Vec::new(),
301 icc_profile,
302 max_pixels,
303 })
304 }
305
306 fn read_u8(&mut self) -> Result<u8> {
307 if self.position >= self.data.len() {
308 return Err(Error::UnexpectedEof {
309 context: "reading byte",
310 });
311 }
312 let byte = self.data[self.position];
313 self.position += 1;
314 Ok(byte)
315 }
316
317 fn read_u16(&mut self) -> Result<u16> {
318 let high = self.read_u8()? as u16;
319 let low = self.read_u8()? as u16;
320 Ok((high << 8) | low)
321 }
322
323 fn read_marker(&mut self) -> Result<u8> {
324 loop {
325 let byte = self.read_u8()?;
326 if byte != 0xFF {
327 continue;
328 }
329
330 let marker = self.read_u8()?;
331 if marker != 0x00 && marker != 0xFF {
332 return Ok(marker);
333 }
334 }
335 }
336
337 fn read_header(&mut self) -> Result<()> {
338 loop {
339 let marker = self.read_marker()?;
340
341 match marker {
342 MARKER_SOF0 | MARKER_SOF1 => {
343 self.mode = JpegMode::Baseline;
344 self.parse_frame_header()?;
345 return Ok(());
346 }
347 MARKER_SOF2 => {
348 self.mode = JpegMode::Progressive;
349 self.parse_frame_header()?;
350 return Ok(());
351 }
352 MARKER_DQT => self.parse_quant_table()?,
353 MARKER_DHT => self.parse_huffman_table()?,
354 MARKER_DRI => self.parse_restart_interval()?,
355 MARKER_APP0..=0xEF | MARKER_COM => self.skip_segment()?,
356 MARKER_EOI => {
357 return Err(Error::InvalidJpegData {
358 reason: "unexpected EOI before frame header",
359 });
360 }
361 _ => self.skip_segment()?,
362 }
363 }
364 }
365
366 fn parse_frame_header(&mut self) -> Result<()> {
367 let length = self.read_u16()?;
368 if length < 8 {
369 return Err(Error::InvalidJpegData {
370 reason: "frame header too short",
371 });
372 }
373
374 self.precision = self.read_u8()?;
375 if self.precision != 8 && self.precision != 12 {
377 return Err(Error::InvalidJpegData {
378 reason: "invalid data precision (must be 8 or 12)",
379 });
380 }
381
382 self.height = self.read_u16()? as u32;
383 self.width = self.read_u16()? as u32;
384
385 let effective_max = if self.max_pixels == 0 {
388 u64::MAX
389 } else {
390 self.max_pixels
391 };
392 validate_dimensions(self.width, self.height, effective_max)?;
393
394 self.num_components = self.read_u8()?;
395
396 if self.num_components == 0 {
398 return Err(Error::InvalidJpegData {
399 reason: "number of components is zero",
400 });
401 }
402 if self.num_components > MAX_COMPONENTS as u8 {
403 return Err(Error::UnsupportedFeature {
404 feature: "more than 4 components",
405 });
406 }
407
408 let expected_length = 8 + 3 * self.num_components as u16;
410 if length != expected_length {
411 return Err(Error::InvalidJpegData {
412 reason: "SOF marker length mismatch",
413 });
414 }
415
416 for i in 0..self.num_components as usize {
417 self.components[i].id = self.read_u8()?;
418 let sampling = self.read_u8()?;
419 let h_samp = sampling >> 4;
420 let v_samp = sampling & 0x0F;
421
422 if h_samp == 0 || v_samp == 0 {
424 return Err(Error::InvalidJpegData {
425 reason: "sampling factor is zero",
426 });
427 }
428 if h_samp > 4 || v_samp > 4 {
429 return Err(Error::InvalidJpegData {
430 reason: "sampling factor exceeds maximum (4)",
431 });
432 }
433
434 self.components[i].h_samp_factor = h_samp;
435 self.components[i].v_samp_factor = v_samp;
436
437 let quant_idx = self.read_u8()?;
438 if quant_idx as usize >= MAX_QUANT_TABLES {
440 return Err(Error::InvalidJpegData {
441 reason: "quantization table index out of range",
442 });
443 }
444 self.components[i].quant_table_idx = quant_idx;
445 }
446
447 Ok(())
448 }
449
450 fn parse_quant_table(&mut self) -> Result<()> {
451 let mut length = self.read_u16()? as i32 - 2;
452
453 while length > 0 {
454 let info = self.read_u8()?;
455 let precision = info >> 4;
456 let table_idx = (info & 0x0F) as usize;
457
458 if precision > 1 {
460 return Err(Error::InvalidQuantTable {
461 table_idx: table_idx as u8,
462 reason: "invalid precision (must be 0 or 1)",
463 });
464 }
465
466 if table_idx >= MAX_QUANT_TABLES {
467 return Err(Error::InvalidQuantTable {
468 table_idx: table_idx as u8,
469 reason: "table index out of range",
470 });
471 }
472
473 let mut zigzag_values = [0u16; DCT_BLOCK_SIZE];
475
476 if precision == 0 {
477 for i in 0..DCT_BLOCK_SIZE {
479 let val = self.read_u8()? as u16;
480 if val == 0 {
481 return Err(Error::InvalidQuantTable {
482 table_idx: table_idx as u8,
483 reason: "quantization value is zero",
484 });
485 }
486 zigzag_values[i] = val;
487 }
488 length -= 65;
489 } else {
490 for i in 0..DCT_BLOCK_SIZE {
492 let val = self.read_u16()?;
493 if val == 0 {
494 return Err(Error::InvalidQuantTable {
495 table_idx: table_idx as u8,
496 reason: "quantization value is zero",
497 });
498 }
499 zigzag_values[i] = val;
500 }
501 length -= 129;
502 }
503
504 if length < 0 {
506 return Err(Error::InvalidJpegData {
507 reason: "DQT marker length mismatch",
508 });
509 }
510
511 let mut natural_values = [0u16; DCT_BLOCK_SIZE];
513 for i in 0..DCT_BLOCK_SIZE {
514 natural_values[JPEG_NATURAL_ORDER[i] as usize] = zigzag_values[i];
515 }
516
517 self.quant_tables[table_idx] = Some(natural_values);
518 }
519
520 Ok(())
521 }
522
523 fn parse_huffman_table(&mut self) -> Result<()> {
524 let mut length = self.read_u16()? as i32 - 2;
525
526 while length > 0 {
527 let info = self.read_u8()?;
528 let table_class = info >> 4; let table_idx = (info & 0x0F) as usize;
530
531 if table_class > 1 {
533 return Err(Error::InvalidHuffmanTable {
534 table_idx: table_idx as u8,
535 reason: "invalid table class (must be 0 or 1)",
536 });
537 }
538
539 if table_idx >= MAX_HUFFMAN_TABLES {
540 return Err(Error::InvalidHuffmanTable {
541 table_idx: table_idx as u8,
542 reason: "table index out of range",
543 });
544 }
545
546 let mut bits = [0u8; 16];
547 for i in 0..16 {
548 bits[i] = self.read_u8()?;
549 }
550
551 let num_values: usize = bits.iter().map(|&b| b as usize).sum();
552 let mut values = vec![0u8; num_values];
553 for i in 0..num_values {
554 values[i] = self.read_u8()?;
555 }
556
557 length -= 17 + num_values as i32;
558
559 if length < 0 {
561 return Err(Error::InvalidJpegData {
562 reason: "DHT marker length mismatch",
563 });
564 }
565
566 let table = HuffmanDecodeTable::from_bits_values(&bits, &values)?;
567
568 if table_class == 0 {
569 self.dc_tables[table_idx] = Some(table);
570 } else {
571 self.ac_tables[table_idx] = Some(table);
572 }
573 }
574
575 Ok(())
576 }
577
578 fn parse_restart_interval(&mut self) -> Result<()> {
579 let _length = self.read_u16()?;
580 self.restart_interval = self.read_u16()?;
581 Ok(())
582 }
583
584 fn skip_segment(&mut self) -> Result<()> {
585 let length = self.read_u16()? as usize;
586 if length < 2 {
587 return Err(Error::InvalidJpegData {
588 reason: "segment length too short",
589 });
590 }
591 self.position += length - 2;
592 Ok(())
593 }
594
595 fn decode(&mut self) -> Result<()> {
596 self.position = 2; self.read_header()?;
599
600 loop {
602 let marker = self.read_marker()?;
603
604 match marker {
605 MARKER_SOS => {
606 self.parse_scan()?;
607 }
609 MARKER_DQT => self.parse_quant_table()?,
610 MARKER_DHT => self.parse_huffman_table()?,
611 MARKER_DRI => self.parse_restart_interval()?,
612 MARKER_EOI => break,
613 MARKER_APP0..=0xEF | MARKER_COM => self.skip_segment()?,
614 _ => self.skip_segment()?,
615 }
616 }
617
618 Ok(())
619 }
620
621 fn parse_scan(&mut self) -> Result<()> {
622 let _length = self.read_u16()?;
623 let num_components = self.read_u8()?;
624
625 if num_components == 0 {
627 return Err(Error::InvalidJpegData {
628 reason: "SOS num_components is zero",
629 });
630 }
631 if num_components > self.num_components {
632 return Err(Error::InvalidJpegData {
633 reason: "SOS num_components exceeds frame components",
634 });
635 }
636 if num_components > MAX_COMPONENTS as u8 {
637 return Err(Error::InvalidJpegData {
638 reason: "SOS num_components too large",
639 });
640 }
641
642 let mut scan_components = Vec::with_capacity(num_components as usize);
643
644 for _ in 0..num_components {
645 let component_id = self.read_u8()?;
646 let tables = self.read_u8()?;
647 let dc_table = tables >> 4;
648 let ac_table = tables & 0x0F;
649
650 if dc_table as usize >= MAX_HUFFMAN_TABLES {
652 return Err(Error::InvalidJpegData {
653 reason: "SOS DC Huffman table index out of range",
654 });
655 }
656 if ac_table as usize >= MAX_HUFFMAN_TABLES {
657 return Err(Error::InvalidJpegData {
658 reason: "SOS AC Huffman table index out of range",
659 });
660 }
661
662 let comp_idx = self.components[..self.num_components as usize]
664 .iter()
665 .position(|c| c.id == component_id)
666 .ok_or(Error::InvalidJpegData {
667 reason: "unknown component in scan",
668 })?;
669
670 scan_components.push((comp_idx, dc_table, ac_table));
671 }
672
673 let ss = self.read_u8()?; let se = self.read_u8()?; let ah_al = self.read_u8()?;
676 let ah = ah_al >> 4;
677 let al = ah_al & 0x0F;
678
679 if ss > 63 {
681 return Err(Error::InvalidJpegData {
682 reason: "SOS Ss (spectral start) out of range",
683 });
684 }
685 if se > 63 {
686 return Err(Error::InvalidJpegData {
687 reason: "SOS Se (spectral end) out of range",
688 });
689 }
690
691 if self.mode == JpegMode::Progressive {
693 self.decode_progressive_scan(&scan_components, ss, se, ah, al)?;
694 } else {
695 self.decode_scan(&scan_components)?;
696 }
697
698 Ok(())
699 }
700
701 fn decode_scan(&mut self, scan_components: &[(usize, u8, u8)]) -> Result<()> {
702 let mut max_h_samp = 1u8;
704 let mut max_v_samp = 1u8;
705 for i in 0..self.num_components as usize {
706 max_h_samp = max_h_samp.max(self.components[i].h_samp_factor);
707 max_v_samp = max_v_samp.max(self.components[i].v_samp_factor);
708 }
709
710 let mcu_width = (max_h_samp as usize) * 8;
712 let mcu_height = (max_v_samp as usize) * 8;
713
714 let mcu_cols = (self.width as usize + mcu_width - 1) / mcu_width;
716 let mcu_rows = (self.height as usize + mcu_height - 1) / mcu_height;
717
718 if self.coeffs.is_empty() {
720 for i in 0..self.num_components as usize {
721 let h_samp = self.components[i].h_samp_factor as usize;
722 let v_samp = self.components[i].v_samp_factor as usize;
723 let comp_blocks_h = checked_size_2d(mcu_cols, h_samp)?;
724 let comp_blocks_v = checked_size_2d(mcu_rows, v_samp)?;
725 let num_blocks = checked_size_2d(comp_blocks_h, comp_blocks_v)?;
726 self.coeffs.push(try_alloc_dct_blocks(
727 num_blocks,
728 "allocating DCT coefficients",
729 )?);
730 }
731 }
732
733 let scan_data = &self.data[self.position..];
735 let mut decoder = EntropyDecoder::new(scan_data);
736
737 for (_comp_idx, dc_table, ac_table) in scan_components {
738 let dc_idx = (*dc_table as usize).min(MAX_HUFFMAN_TABLES - 1);
739 let ac_idx = (*ac_table as usize).min(MAX_HUFFMAN_TABLES - 1);
740 if let Some(table) = &self.dc_tables[dc_idx] {
741 decoder.set_dc_table(dc_idx, table.clone());
742 }
743 if let Some(table) = &self.ac_tables[ac_idx] {
744 decoder.set_ac_table(ac_idx, table.clone());
745 }
746 }
747
748 for mcu_y in 0..mcu_rows {
750 for mcu_x in 0..mcu_cols {
751 for (comp_idx, dc_table, ac_table) in scan_components {
753 let h_samp = self.components[*comp_idx].h_samp_factor as usize;
754 let v_samp = self.components[*comp_idx].v_samp_factor as usize;
755 let comp_blocks_h = mcu_cols * h_samp;
756
757 for v in 0..v_samp {
759 for h in 0..h_samp {
760 let block_x = mcu_x * h_samp + h;
761 let block_y = mcu_y * v_samp + v;
762 let block_idx = block_y * comp_blocks_h + block_x;
763
764 let coeffs = decoder.decode_block(
765 *comp_idx,
766 *dc_table as usize,
767 *ac_table as usize,
768 )?;
769 self.coeffs[*comp_idx][block_idx] = coeffs;
770 }
771 }
772 }
773 }
774 }
775
776 self.position += decoder.position();
777 Ok(())
778 }
779
780 fn decode_progressive_scan(
781 &mut self,
782 scan_components: &[(usize, u8, u8)],
783 ss: u8,
784 se: u8,
785 ah: u8,
786 al: u8,
787 ) -> Result<()> {
788 let mut max_h_samp = 1u8;
790 let mut max_v_samp = 1u8;
791 for i in 0..self.num_components as usize {
792 max_h_samp = max_h_samp.max(self.components[i].h_samp_factor);
793 max_v_samp = max_v_samp.max(self.components[i].v_samp_factor);
794 }
795
796 let mcu_width = (max_h_samp as usize) * 8;
798 let mcu_height = (max_v_samp as usize) * 8;
799
800 let mcu_cols = (self.width as usize + mcu_width - 1) / mcu_width;
802 let mcu_rows = (self.height as usize + mcu_height - 1) / mcu_height;
803
804 if self.coeffs.is_empty() {
806 for i in 0..self.num_components as usize {
807 let h_samp = self.components[i].h_samp_factor as usize;
808 let v_samp = self.components[i].v_samp_factor as usize;
809 let comp_blocks_h = checked_size_2d(mcu_cols, h_samp)?;
810 let comp_blocks_v = checked_size_2d(mcu_rows, v_samp)?;
811 let num_blocks = checked_size_2d(comp_blocks_h, comp_blocks_v)?;
812 self.coeffs.push(try_alloc_dct_blocks(
813 num_blocks,
814 "allocating DCT coefficients",
815 )?);
816 }
817 }
818
819 let scan_data = &self.data[self.position..];
821 let mut decoder = EntropyDecoder::new(scan_data);
822
823 for (_comp_idx, dc_table, ac_table) in scan_components {
824 let dc_idx = (*dc_table as usize).min(MAX_HUFFMAN_TABLES - 1);
825 let ac_idx = (*ac_table as usize).min(MAX_HUFFMAN_TABLES - 1);
826 if let Some(table) = &self.dc_tables[dc_idx] {
827 decoder.set_dc_table(dc_idx, table.clone());
828 }
829 if let Some(table) = &self.ac_tables[ac_idx] {
830 decoder.set_ac_table(ac_idx, table.clone());
831 }
832 }
833
834 let is_dc_scan = ss == 0 && se == 0;
836 let is_first_scan = ah == 0;
837
838 let mut eob_run = 0u16;
840
841 if is_dc_scan {
842 for mcu_y in 0..mcu_rows {
844 for mcu_x in 0..mcu_cols {
845 for (comp_idx, dc_table, _ac_table) in scan_components {
846 let h_samp = self.components[*comp_idx].h_samp_factor as usize;
847 let v_samp = self.components[*comp_idx].v_samp_factor as usize;
848 let comp_blocks_h = mcu_cols * h_samp;
849
850 for v in 0..v_samp {
851 for h in 0..h_samp {
852 let block_x = mcu_x * h_samp + h;
853 let block_y = mcu_y * v_samp + v;
854 let block_idx = block_y * comp_blocks_h + block_x;
855
856 if is_first_scan {
857 let dc = decoder.decode_dc_first(
859 *comp_idx,
860 *dc_table as usize,
861 al,
862 )?;
863 self.coeffs[*comp_idx][block_idx][0] = dc;
864 } else {
865 let bit = decoder.decode_dc_refine(al)?;
867 self.coeffs[*comp_idx][block_idx][0] |= bit;
868 }
869 }
870 }
871 }
872 }
873 }
874 } else {
875 if scan_components.len() != 1 {
878 return Err(Error::InvalidJpegData {
879 reason: "progressive AC scan must have single component",
880 });
881 }
882
883 let (comp_idx, _dc_table, ac_table) = scan_components[0];
884 let h_samp = self.components[comp_idx].h_samp_factor as usize;
885 let v_samp = self.components[comp_idx].v_samp_factor as usize;
886 let comp_blocks_h = mcu_cols * h_samp;
887
888 for mcu_y in 0..mcu_rows {
889 for mcu_x in 0..mcu_cols {
890 for v in 0..v_samp {
891 for h in 0..h_samp {
892 let block_x = mcu_x * h_samp + h;
893 let block_y = mcu_y * v_samp + v;
894 let block_idx = block_y * comp_blocks_h + block_x;
895
896 if is_first_scan {
897 decoder.decode_ac_first(
899 &mut self.coeffs[comp_idx][block_idx],
900 ac_table as usize,
901 ss,
902 se,
903 al,
904 &mut eob_run,
905 )?;
906 } else {
907 decoder.decode_ac_refine(
909 &mut self.coeffs[comp_idx][block_idx],
910 ac_table as usize,
911 ss,
912 se,
913 al,
914 &mut eob_run,
915 )?;
916 }
917 }
918 }
919 }
920 }
921 }
922
923 self.position += decoder.position();
924 Ok(())
925 }
926
927 fn info(&self) -> JpegInfo {
928 let has_icc = self.icc_profile.is_some();
929 let is_xyb = self.icc_profile.as_ref().is_some_and(|p| is_xyb_profile(p));
930
931 let color_space = if is_xyb {
933 ColorSpace::Xyb
934 } else {
935 match self.num_components {
936 1 => ColorSpace::Grayscale,
937 3 => ColorSpace::YCbCr,
938 4 => ColorSpace::Cmyk,
939 _ => ColorSpace::Unknown,
940 }
941 };
942
943 JpegInfo {
944 dimensions: Dimensions::new(self.width, self.height),
945 color_space,
946 precision: self.precision,
947 num_components: self.num_components,
948 mode: self.mode,
949 has_icc_profile: has_icc,
950 is_xyb,
951 }
952 }
953
954 fn to_pixels(&self, format: PixelFormat) -> Result<Vec<u8>> {
955 if self.coeffs.is_empty() {
956 return Err(Error::InternalError {
957 reason: "no decoded data",
958 });
959 }
960
961 let width = self.width as usize;
962 let height = self.height as usize;
963
964 let mut max_h_samp = 1u8;
966 let mut max_v_samp = 1u8;
967 for i in 0..self.num_components as usize {
968 max_h_samp = max_h_samp.max(self.components[i].h_samp_factor);
969 max_v_samp = max_v_samp.max(self.components[i].v_samp_factor);
970 }
971
972 let mcu_width = (max_h_samp as usize) * 8;
974 let mcu_height = (max_v_samp as usize) * 8;
975 let mcu_cols = (width + mcu_width - 1) / mcu_width;
976 let mcu_rows = (height + mcu_height - 1) / mcu_height;
977
978 struct CompInfo {
980 quant_idx: usize,
981 h_samp: usize,
982 v_samp: usize,
983 comp_blocks_h: usize,
984 comp_blocks_v: usize,
985 comp_width: usize,
986 comp_height: usize,
987 is_full_res: bool,
988 }
989
990 let mut comp_infos: Vec<CompInfo> = Vec::new();
991 for comp_idx in 0..self.num_components as usize {
992 let h_samp = self.components[comp_idx].h_samp_factor as usize;
993 let v_samp = self.components[comp_idx].v_samp_factor as usize;
994 let comp_blocks_h = mcu_cols * h_samp;
995 let comp_blocks_v = mcu_rows * v_samp;
996 let comp_width = checked_size_2d(comp_blocks_h, 8)?;
997 let comp_height = checked_size_2d(comp_blocks_v, 8)?;
998 comp_infos.push(CompInfo {
999 quant_idx: self.components[comp_idx].quant_table_idx as usize,
1000 h_samp,
1001 v_samp,
1002 comp_blocks_h,
1003 comp_blocks_v,
1004 comp_width,
1005 comp_height,
1006 is_full_res: h_samp == max_h_samp as usize && v_samp == max_v_samp as usize,
1007 });
1008 }
1009
1010 let mut bias_stats = DequantBiasStats::new(self.num_components as usize);
1012 let mut component_biases: Vec<[f32; DCT_BLOCK_SIZE]> =
1013 vec![[0.0f32; DCT_BLOCK_SIZE]; self.num_components as usize];
1014
1015 let mut comp_planes_f32: Vec<Vec<f32>> = Vec::new();
1017 for info in &comp_infos {
1018 let comp_plane_size = checked_size_2d(info.comp_width, info.comp_height)?;
1019 comp_planes_f32.push(vec![0.0f32; comp_plane_size]);
1020 }
1021
1022 for imcu_row in 0..mcu_rows {
1024 for comp_idx in 0..self.num_components as usize {
1026 let info = &comp_infos[comp_idx];
1027 let quant =
1028 self.quant_tables[info.quant_idx]
1029 .as_ref()
1030 .ok_or(Error::InternalError {
1031 reason: "missing quantization table",
1032 })?;
1033
1034 if info.is_full_res {
1036 for iy in 0..info.v_samp {
1037 let by = imcu_row * info.v_samp + iy;
1038 if by >= info.comp_blocks_v {
1039 continue;
1040 }
1041 for bx in 0..info.comp_blocks_h {
1042 let block_idx = by * info.comp_blocks_h + bx;
1043 if block_idx >= self.coeffs[comp_idx].len() {
1044 continue;
1045 }
1046 let coeffs = &self.coeffs[comp_idx][block_idx];
1047 let mut natural_coeffs = [0i16; DCT_BLOCK_SIZE];
1048 for (i, &zi) in JPEG_NATURAL_ORDER[..DCT_BLOCK_SIZE].iter().enumerate()
1049 {
1050 natural_coeffs[zi as usize] = coeffs[i];
1051 }
1052 bias_stats.gather_block(comp_idx, &natural_coeffs);
1053 }
1054 }
1055
1056 if imcu_row % 4 == 3 {
1058 component_biases[comp_idx] = bias_stats.compute_biases(comp_idx);
1059 }
1060 }
1061
1062 let biases = &component_biases[comp_idx];
1065 let comp_plane_f32 = &mut comp_planes_f32[comp_idx];
1066
1067 for iy in 0..info.v_samp {
1068 let by = imcu_row * info.v_samp + iy;
1069 if by >= info.comp_blocks_v {
1070 continue;
1071 }
1072
1073 for bx in 0..info.comp_blocks_h {
1074 let block_idx = by * info.comp_blocks_h + bx;
1075 if block_idx >= self.coeffs[comp_idx].len() {
1076 continue;
1077 }
1078 let coeffs = &self.coeffs[comp_idx][block_idx];
1079
1080 let mut natural_coeffs = [0i16; DCT_BLOCK_SIZE];
1081 for (i, &zi) in JPEG_NATURAL_ORDER[..DCT_BLOCK_SIZE].iter().enumerate() {
1082 natural_coeffs[zi as usize] = coeffs[i];
1083 }
1084
1085 let dequant = dequantize_block_with_bias(&natural_coeffs, quant, biases);
1086 let pixels = inverse_dct_8x8(&dequant);
1087
1088 for y in 0..DCT_SIZE {
1090 for x in 0..DCT_SIZE {
1091 let px = bx * DCT_SIZE + x;
1092 let py = by * DCT_SIZE + y;
1093 if px < info.comp_width && py < info.comp_height {
1094 comp_plane_f32[py * info.comp_width + px] =
1095 pixels[y * DCT_SIZE + x];
1096 }
1097 }
1098 }
1099 }
1100 }
1101 }
1102 }
1103
1104 let output_size = checked_size_2d(width, height)?;
1106 let mut planes_f32: Vec<Vec<f32>> = Vec::new();
1107
1108 for comp_idx in 0..self.num_components as usize {
1109 let info = &comp_infos[comp_idx];
1110 let comp_plane_f32 = &comp_planes_f32[comp_idx];
1111
1112 let plane_f32 =
1113 if info.h_samp < max_h_samp as usize || info.v_samp < max_v_samp as usize {
1114 let scale_x = max_h_samp as usize / info.h_samp;
1115 let scale_y = max_v_samp as usize / info.v_samp;
1116 let mut upsampled = vec![0.0f32; output_size];
1117 for py in 0..height {
1118 for px in 0..width {
1119 let sx = (px / scale_x).min(info.comp_width - 1);
1120 let sy = (py / scale_y).min(info.comp_height - 1);
1121 upsampled[py * width + px] = comp_plane_f32[sy * info.comp_width + sx];
1122 }
1123 }
1124 upsampled
1125 } else {
1126 let mut plane = vec![0.0f32; output_size];
1128 for py in 0..height {
1129 for px in 0..width {
1130 plane[py * width + px] = comp_plane_f32[py * info.comp_width + px];
1131 }
1132 }
1133 plane
1134 };
1135
1136 planes_f32.push(plane_f32);
1137 }
1138
1139 match (self.num_components, format) {
1141 (1, PixelFormat::Gray) => {
1142 let mut output = try_alloc_zeroed(output_size, "allocating gray output")?;
1144 for (i, &y) in planes_f32[0].iter().enumerate() {
1145 output[i] = (y + 128.0).round().clamp(0.0, 255.0) as u8;
1146 }
1147 Ok(output)
1148 }
1149 (1, PixelFormat::Rgb) => {
1150 let rgb_size =
1151 checked_size_2d(width, height).and_then(|s| checked_size_2d(s, 3))?;
1152 let mut rgb = try_alloc_zeroed(rgb_size, "allocating RGB output")?;
1153 for (i, &y) in planes_f32[0].iter().enumerate() {
1154 let val = (y + 128.0).round().clamp(0.0, 255.0) as u8;
1155 rgb[i * 3] = val;
1156 rgb[i * 3 + 1] = val;
1157 rgb[i * 3 + 2] = val;
1158 }
1159 Ok(rgb)
1160 }
1161 (3, PixelFormat::Rgb) => {
1162 let rgb_size =
1164 checked_size_2d(width, height).and_then(|s| checked_size_2d(s, 3))?;
1165 let mut rgb = try_alloc_zeroed(rgb_size, "allocating RGB output")?;
1166
1167 for i in 0..output_size {
1168 let y = planes_f32[0][i];
1170 let cb = planes_f32[1][i]; let cr = planes_f32[2][i]; let r = y + 1.402 * cr;
1178 let g = y - 0.344136 * cb - 0.714136 * cr;
1179 let b = y + 1.772 * cb;
1180
1181 rgb[i * 3] = (r + 128.0).round().clamp(0.0, 255.0) as u8;
1183 rgb[i * 3 + 1] = (g + 128.0).round().clamp(0.0, 255.0) as u8;
1184 rgb[i * 3 + 2] = (b + 128.0).round().clamp(0.0, 255.0) as u8;
1185 }
1186 Ok(rgb)
1187 }
1188 _ => Err(Error::UnsupportedFeature {
1189 feature: "unsupported color conversion",
1190 }),
1191 }
1192 }
1193}
1194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode::Encoder;
    use crate::quant::Quality;

    #[test]
    fn test_decoder_creation() {
        let decoder = Decoder::new()
            .output_format(PixelFormat::Rgb)
            .fancy_upsampling(true);

        assert_eq!(decoder.config.output_format, Some(PixelFormat::Rgb));
        assert!(decoder.config.fancy_upsampling);
    }

    #[test]
    fn test_encode_decode_roundtrip_gray() {
        let width = 8;
        let height = 8;
        let mut input = vec![0u8; width * height];
        for y in 0..height {
            for x in 0..width {
                input[y * width + x] = ((x + y) * 16) as u8;
            }
        }

        let encoder = Encoder::new()
            .width(width as u32)
            .height(height as u32)
            .pixel_format(PixelFormat::Gray)
            .quality(Quality::from_quality(95.0));

        let jpeg = encoder.encode(&input).expect("encoding should succeed");

        // Stream starts with SOI and ends with EOI.
        assert_eq!(jpeg[0], 0xFF);
        assert_eq!(jpeg[1], 0xD8);
        assert_eq!(jpeg[jpeg.len() - 2], 0xFF);
        assert_eq!(jpeg[jpeg.len() - 1], 0xD9);

        let decoder = Decoder::new().output_format(PixelFormat::Gray);
        let decoded = decoder.decode(&jpeg).expect("decoding should succeed");

        assert_eq!(decoded.width, width as u32);
        assert_eq!(decoded.height, height as u32);
        assert_eq!(decoded.data.len(), width * height);

        let max_diff = input
            .iter()
            .zip(&decoded.data)
            .map(|(&a, &b)| (i32::from(a) - i32::from(b)).abs())
            .max()
            .unwrap_or(0);
        assert!(max_diff < 20, "max_diff {} too large", max_diff);
    }

    #[test]
    fn test_encode_decode_roundtrip_rgb() {
        let width = 16;
        let height = 16;
        let mut input = vec![0u8; width * height * 3];
        for y in 0..height {
            for x in 0..width {
                let idx = (y * width + x) * 3;
                input[idx] = (x * 16) as u8;
                input[idx + 1] = (y * 16) as u8;
                input[idx + 2] = 128;
            }
        }

        let encoder = Encoder::new()
            .width(width as u32)
            .height(height as u32)
            .pixel_format(PixelFormat::Rgb)
            .quality(Quality::from_quality(95.0));

        let jpeg = encoder.encode(&input).expect("encoding should succeed");

        let decoder = Decoder::new().output_format(PixelFormat::Rgb);
        let decoded = decoder.decode(&jpeg).expect("decoding should succeed");

        assert_eq!(decoded.width, width as u32);
        assert_eq!(decoded.height, height as u32);
        assert_eq!(decoded.data.len(), width * height * 3);

        let max_diff = input
            .iter()
            .zip(&decoded.data)
            .map(|(&a, &b)| (i32::from(a) - i32::from(b)).abs())
            .max()
            .unwrap_or(0);
        assert!(max_diff < 30, "max_diff {} too large", max_diff);
    }

    #[test]
    fn test_read_info_reports_frame_header() {
        let input = vec![128u8; 8 * 8];
        let jpeg = Encoder::new()
            .width(8)
            .height(8)
            .pixel_format(PixelFormat::Gray)
            .quality(Quality::from_quality(90.0))
            .encode(&input)
            .expect("encoding should succeed");

        let info = Decoder::new()
            .read_info(&jpeg)
            .expect("header should parse");
        assert_eq!(info.dimensions.width, 8);
        assert_eq!(info.dimensions.height, 8);
        assert_eq!(info.num_components, 1);
        assert_eq!(info.precision, 8);
    }

    #[test]
    fn test_rejects_stream_without_soi() {
        let decoder = Decoder::new();
        assert!(decoder.read_info(&[]).is_err());
        assert!(decoder.read_info(&[0xFF, 0xD9]).is_err());
        assert!(decoder.decode(&[0x00, 0x01, 0x02]).is_err());
    }

    #[test]
    fn test_rejects_lossless_frame() {
        // SOI, a well-formed SOF3 (lossless) header for an 8x8 single-component
        // image, then EOI.
        let data = [
            0xFF, 0xD8, 0xFF, 0xC3, 0x00, 0x0B, 0x08, 0x00, 0x08, 0x00, 0x08, 0x01, 0x01, 0x11,
            0x00, 0xFF, 0xD9,
        ];
        assert!(Decoder::new().read_info(&data).is_err());
        assert!(Decoder::new().decode(&data).is_err());
    }
}