1use alloc::borrow::ToOwned;
2use alloc::{format, vec};
3use alloc::vec::Vec;
4use core::ops::{self, Range};
5use std::io::{self, Read};
6use crate::{read_u16_from_be, read_u8};
7use crate::error::{Error, Result, UnsupportedFeature};
8use crate::huffman::{HuffmanTable, HuffmanTableClass};
9use crate::marker::Marker;
10use crate::marker::Marker::*;
11
/// A pair of 16-bit extents.
///
/// The unit depends on where the value is stored: `FrameInfo::image_size` holds the width and
/// height read from the frame header, while `FrameInfo::mcu_size` and `Component::block_size`
/// hold counts derived from them by `update_component_sizes`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Dimensions {
    /// Horizontal extent.
    pub width: u16,
    /// Vertical extent.
    pub height: u16,
}
17
/// Entropy coder used by a frame, as implied by its SOF marker.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum EntropyCoding {
    /// Selected by `parse_sof` for `SOF(0..=3)` and `SOF(5..=7)`.
    Huffman,
    /// Selected by `parse_sof` for `SOF(9..=11)` and `SOF(13..=15)`.
    Arithmetic,
}
23
/// Coding process of a frame, as implied by its SOF marker.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CodingProcess {
    /// Sequential DCT: `SOF(0)`, `SOF(1)`, `SOF(5)`, `SOF(9)` and `SOF(13)`.
    DctSequential,
    /// Progressive DCT: `SOF(2)`, `SOF(6)`, `SOF(10)` and `SOF(14)`.
    DctProgressive,
    /// Lossless: `SOF(3)`, `SOF(7)`, `SOF(11)` and `SOF(15)`.
    Lossless,
}
34
/// Predictor selected by a scan header.
///
/// In a lossless scan `parse_sos` maps the selection values `0..=7` to the variants in
/// declaration order; every other coding process uses [`Predictor::NoPrediction`].
///
/// NOTE(review): the names follow the neighbouring samples `Ra`, `Rb` and `Rc` of ITU-T T.81
/// Table H.1; the exact formula behind each combined variant is not visible in this file and
/// should be confirmed against the lossless decoder.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Predictor {
    /// Selection value 0.
    NoPrediction,
    /// Selection value 1.
    Ra,
    /// Selection value 2.
    Rb,
    /// Selection value 3.
    Rc,
    /// Selection value 4.
    RaRbRc1,
    /// Selection value 5.
    RaRbRc2,
    /// Selection value 6.
    RaRbRc3,
    /// Selection value 7.
    RaRb,
}
47
48
49#[derive(Clone)]
50pub struct FrameInfo {
51 pub is_baseline: bool,
52 pub is_differential: bool,
53 pub coding_process: CodingProcess,
54 pub entropy_coding: EntropyCoding,
55 pub precision: u8,
56
57 pub image_size: Dimensions,
58 pub output_size: Dimensions,
59 pub mcu_size: Dimensions,
60 pub components: Vec<Component>,
61}
62
63#[derive(Debug)]
64pub struct ScanInfo {
65 pub component_indices: Vec<usize>,
66 pub dc_table_indices: Vec<usize>,
67 pub ac_table_indices: Vec<usize>,
68
69 pub spectral_selection: Range<u8>,
70 pub predictor_selection: Predictor, pub successive_approximation_high: u8,
72 pub successive_approximation_low: u8,
73 pub point_transform: u8, }
75
76#[derive(Clone, Debug)]
77pub struct Component {
78 pub identifier: u8,
79
80 pub horizontal_sampling_factor: u8,
81 pub vertical_sampling_factor: u8,
82
83 pub quantization_table_index: usize,
84
85 pub dct_scale: usize,
86
87 pub size: Dimensions,
88 pub block_size: Dimensions,
89}
90
91#[derive(Debug)]
92pub enum AppData {
93 Adobe(AdobeColorTransform),
94 Jfif,
95 Avi1,
96 Icc(IccChunk),
97 Exif(Vec<u8>),
98 Xmp(Vec<u8>),
99 Psir(Vec<u8>),
100}
101
/// Color transform announced by the last byte of a 12-byte Adobe APP14 header.
///
/// NOTE(review): what `Unknown` implies for decoding (no transform vs. a heuristic) is not
/// visible in this file; confirm against the color conversion code.
// `YCCK` is kept in upper case, hence the lint exemption.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AdobeColorTransform {
    /// Transform byte 0.
    Unknown,
    /// Transform byte 1.
    YCbCr,
    /// Transform byte 2.
    YCCK,
}
/// One piece of an ICC profile carried by an APP2 `ICC_PROFILE\0` segment.
#[derive(Debug)]
pub struct IccChunk {
    /// Byte 13 of the segment header.
    /// NOTE(review): presumably the total number of chunks the profile is split into.
    pub num_markers: u8,
    /// Byte 12 of the segment header.
    /// NOTE(review): presumably this chunk's position within the profile.
    pub seq_no: u8,
    /// All bytes of the segment after the 14-byte header.
    pub data: Vec<u8>,
}
118
119impl FrameInfo {
120 pub(crate) fn update_idct_size(&mut self, idct_size: usize) -> Result<()> {
121 for component in &mut self.components {
122 component.dct_scale = idct_size;
123 }
124
125 update_component_sizes(self.image_size, &mut self.components)?;
126
127 self.output_size = Dimensions {
128 width: (self.image_size.width as f32 * idct_size as f32 / 8.0).ceil() as u16,
129 height: (self.image_size.height as f32 * idct_size as f32 / 8.0).ceil() as u16
130 };
131
132 Ok(())
133 }
134}
135
136fn read_length<R: Read>(reader: &mut R, marker: Marker) -> Result<usize> {
137 assert!(marker.has_length());
138
139 let length = usize::from(read_u16_from_be(reader)?);
141
142 if length < 2 {
143 return Err(Error::Format(format!("encountered {:?} with invalid length {}", marker, length)));
144 }
145
146 Ok(length - 2)
147}
148
149fn skip_bytes<R: Read>(reader: &mut R, length: usize) -> Result<()> {
150 let length = length as u64;
151 let to_skip = &mut reader.by_ref().take(length);
152 let copied = io::copy(to_skip, &mut io::sink())?;
153 if copied < length {
154 Err(Error::Io(io::ErrorKind::UnexpectedEof.into()))
155 } else {
156 Ok(())
157 }
158}
159
160pub fn parse_sof<R: Read>(reader: &mut R, marker: Marker) -> Result<FrameInfo> {
162 let length = read_length(reader, marker)?;
163
164 if length <= 6 {
165 return Err(Error::Format("invalid length in SOF".to_owned()));
166 }
167
168 let is_baseline = marker == SOF(0);
169 let is_differential = match marker {
170 SOF(0 ..= 3) | SOF(9 ..= 11) => false,
171 SOF(5 ..= 7) | SOF(13 ..= 15) => true,
172 _ => panic!(),
173 };
174 let coding_process = match marker {
175 SOF(0) | SOF(1) | SOF(5) | SOF(9) | SOF(13) => CodingProcess::DctSequential,
176 SOF(2) | SOF(6) | SOF(10) | SOF(14) => CodingProcess::DctProgressive,
177 SOF(3) | SOF(7) | SOF(11) | SOF(15) => CodingProcess::Lossless,
178 _ => panic!(),
179 };
180 let entropy_coding = match marker {
181 SOF(0 ..= 3) | SOF(5 ..= 7) => EntropyCoding::Huffman,
182 SOF(9 ..= 11) | SOF(13 ..= 15) => EntropyCoding::Arithmetic,
183 _ => panic!(),
184 };
185
186 let precision = read_u8(reader)?;
187
188 match precision {
189 8 => {},
190 12 => {
191 if is_baseline {
192 return Err(Error::Format("12 bit sample precision is not allowed in baseline".to_owned()));
193 }
194 },
195 _ => {
196 if coding_process != CodingProcess::Lossless || precision > 16 {
197 return Err(Error::Format(format!("invalid precision {} in frame header", precision)))
198 }
199 },
200 }
201
202 let height = read_u16_from_be(reader)?;
203 let width = read_u16_from_be(reader)?;
204
205 if height == 0 {
209 return Err(Error::Unsupported(UnsupportedFeature::DNL));
210 }
211
212 if width == 0 {
213 return Err(Error::Format("zero width in frame header".to_owned()));
214 }
215
216 let component_count = read_u8(reader)?;
217
218 if component_count == 0 {
219 return Err(Error::Format("zero component count in frame header".to_owned()));
220 }
221 if coding_process == CodingProcess::DctProgressive && component_count > 4 {
222 return Err(Error::Format("progressive frame with more than 4 components".to_owned()));
223 }
224
225 if length != 6 + 3 * component_count as usize {
226 return Err(Error::Format("invalid length in SOF".to_owned()));
227 }
228
229 let mut components: Vec<Component> = Vec::with_capacity(component_count as usize);
230
231 for _ in 0 .. component_count {
232 let identifier = read_u8(reader)?;
233
234 if components.iter().any(|c| c.identifier == identifier) {
236 return Err(Error::Format(format!("duplicate frame component identifier {}", identifier)));
237 }
238
239 let byte = read_u8(reader)?;
240 let horizontal_sampling_factor = byte >> 4;
241 let vertical_sampling_factor = byte & 0x0f;
242
243 if horizontal_sampling_factor == 0 || horizontal_sampling_factor > 4 {
244 return Err(Error::Format(format!("invalid horizontal sampling factor {}", horizontal_sampling_factor)));
245 }
246 if vertical_sampling_factor == 0 || vertical_sampling_factor > 4 {
247 return Err(Error::Format(format!("invalid vertical sampling factor {}", vertical_sampling_factor)));
248 }
249
250 let quantization_table_index = read_u8(reader)?;
251
252 if quantization_table_index > 3 || (coding_process == CodingProcess::Lossless && quantization_table_index != 0) {
253 return Err(Error::Format(format!("invalid quantization table index {}", quantization_table_index)));
254 }
255
256 components.push(Component {
257 identifier,
258 horizontal_sampling_factor,
259 vertical_sampling_factor,
260 quantization_table_index: quantization_table_index as usize,
261 dct_scale: 8,
262 size: Dimensions {width: 0, height: 0},
263 block_size: Dimensions {width: 0, height: 0},
264 });
265 }
266
267 let mcu_size = update_component_sizes(Dimensions { width, height }, &mut components)?;
268
269 Ok(FrameInfo {
270 is_baseline,
271 is_differential,
272 coding_process,
273 entropy_coding,
274 precision,
275 image_size: Dimensions { width, height },
276 output_size: Dimensions { width, height },
277 mcu_size,
278 components,
279 })
280}
281
282fn ceil_div(x: u32, y: u32) -> Result<u16> {
284 if x == 0 || y == 0 {
285 return Err(Error::Format("invalid dimensions".to_owned()));
288 }
289 Ok((1 + ((x - 1) / y)) as u16)
290}
291
292fn update_component_sizes(size: Dimensions, components: &mut [Component]) -> Result<Dimensions> {
293 let h_max = components.iter().map(|c| c.horizontal_sampling_factor).max().unwrap() as u32;
294 let v_max = components.iter().map(|c| c.vertical_sampling_factor).max().unwrap() as u32;
295
296 let mcu_size = Dimensions {
297 width: ceil_div(size.width as u32, h_max * 8)?,
298 height: ceil_div(size.height as u32, v_max * 8)?,
299 };
300
301 for component in components {
302 component.size.width = ceil_div(size.width as u32 * component.horizontal_sampling_factor as u32 * component.dct_scale as u32, h_max * 8)?;
303 component.size.height = ceil_div(size.height as u32 * component.vertical_sampling_factor as u32 * component.dct_scale as u32, v_max * 8)?;
304
305 component.block_size.width = mcu_size.width * component.horizontal_sampling_factor as u16;
306 component.block_size.height = mcu_size.height * component.vertical_sampling_factor as u16;
307 }
308
309 Ok(mcu_size)
310}
311
/// Checks the sizes derived for a single 2x2-sampled component of an 800x280 image.
#[test]
fn test_update_component_sizes() {
    let unset = Dimensions { width: 0, height: 0 };
    let image = Dimensions { width: 800, height: 280 };
    let mut components = [Component {
        identifier: 1,
        horizontal_sampling_factor: 2,
        vertical_sampling_factor: 2,
        quantization_table_index: 0,
        dct_scale: 8,
        size: unset,
        block_size: unset,
    }];

    let mcu = update_component_sizes(image, &mut components).unwrap();

    // 800 / 16 = 50 and ceil(280 / 16) = 18; a 2x2 component has twice as many blocks.
    assert_eq!(mcu, Dimensions { width: 50, height: 18 });
    let component = &components[0];
    assert_eq!(component.block_size, Dimensions { width: 100, height: 36 });
    assert_eq!(component.size, Dimensions { width: 800, height: 280 });
}
330
331pub fn parse_sos<R: Read>(reader: &mut R, frame: &FrameInfo) -> Result<ScanInfo> {
333 let length = read_length(reader, SOS)?;
334 if 0 == length {
335 return Err(Error::Format("zero length in SOS".to_owned()));
336 }
337
338 let component_count = read_u8(reader)?;
339
340 if component_count == 0 || component_count > 4 {
341 return Err(Error::Format(format!("invalid component count {} in scan header", component_count)));
342 }
343
344 if length != 4 + 2 * component_count as usize {
345 return Err(Error::Format("invalid length in SOS".to_owned()));
346 }
347
348 let mut component_indices = Vec::with_capacity(component_count as usize);
349 let mut dc_table_indices = Vec::with_capacity(component_count as usize);
350 let mut ac_table_indices = Vec::with_capacity(component_count as usize);
351
352 for _ in 0 .. component_count {
353 let identifier = read_u8(reader)?;
354
355 let component_index = match frame.components.iter().position(|c| c.identifier == identifier) {
356 Some(value) => value,
357 None => return Err(Error::Format(format!("scan component identifier {} does not match any of the component identifiers defined in the frame", identifier))),
358 };
359
360 if component_indices.contains(&component_index) {
362 return Err(Error::Format(format!("duplicate scan component identifier {}", identifier)));
363 }
364
365 if component_index < *component_indices.iter().max().unwrap_or(&0) {
367 return Err(Error::Format("the scan component order does not follow the order in the frame header".to_owned()));
368 }
369
370 let byte = read_u8(reader)?;
371 let dc_table_index = byte >> 4;
372 let ac_table_index = byte & 0x0f;
373
374 if dc_table_index > 3 || (frame.is_baseline && dc_table_index > 1) {
375 return Err(Error::Format(format!("invalid dc table index {}", dc_table_index)));
376 }
377 if ac_table_index > 3 || (frame.is_baseline && ac_table_index > 1) {
378 return Err(Error::Format(format!("invalid ac table index {}", ac_table_index)));
379 }
380
381 component_indices.push(component_index);
382 dc_table_indices.push(dc_table_index as usize);
383 ac_table_indices.push(ac_table_index as usize);
384 }
385
386 let blocks_per_mcu = component_indices.iter().map(|&i| {
387 frame.components[i].horizontal_sampling_factor as u32 * frame.components[i].vertical_sampling_factor as u32
388 }).fold(0, ops::Add::add);
389
390 if component_count > 1 && blocks_per_mcu > 10 {
391 return Err(Error::Format("scan with more than one component and more than 10 blocks per MCU".to_owned()));
392 }
393
394 let spectral_selection_start = read_u8(reader)?;
396 let mut spectral_selection_end = read_u8(reader)?;
398
399 let byte = read_u8(reader)?;
400 let successive_approximation_high = byte >> 4;
401 let successive_approximation_low = byte & 0x0f;
402
403 let predictor_selection;
406 let point_transform = successive_approximation_low;
407
408 if frame.coding_process == CodingProcess::DctProgressive {
409 predictor_selection = Predictor::NoPrediction;
410 if spectral_selection_end > 63 || spectral_selection_start > spectral_selection_end ||
411 (spectral_selection_start == 0 && spectral_selection_end != 0) {
412 return Err(Error::Format(format!("invalid spectral selection parameters: ss={}, se={}", spectral_selection_start, spectral_selection_end)));
413 }
414 if spectral_selection_start != 0 && component_count != 1 {
415 return Err(Error::Format("spectral selection scan with AC coefficients can't have more than one component".to_owned()));
416 }
417
418 if successive_approximation_high > 13 || successive_approximation_low > 13 {
419 return Err(Error::Format(format!("invalid successive approximation parameters: ah={}, al={}", successive_approximation_high, successive_approximation_low)));
420 }
421
422 if successive_approximation_high != 0 && successive_approximation_high != successive_approximation_low + 1 {
426 return Err(Error::Format("successive approximation scan with more than one bit of improvement".to_owned()));
427 }
428 }
429 else if frame.coding_process == CodingProcess::Lossless {
430 if spectral_selection_end != 0 {
431 return Err(Error::Format("spectral selection end shall be zero in lossless scan".to_owned()));
432 }
433 if successive_approximation_high != 0 {
434 return Err(Error::Format("successive approximation high shall be zero in lossless scan".to_owned()));
435 }
436 predictor_selection = match spectral_selection_start {
437 0 => Predictor::NoPrediction,
438 1 => Predictor::Ra,
439 2 => Predictor::Rb,
440 3 => Predictor::Rc,
441 4 => Predictor::RaRbRc1,
442 5 => Predictor::RaRbRc2,
443 6 => Predictor::RaRbRc3,
444 7 => Predictor::RaRb,
445 _ => {
446 return Err(Error::Format(format!("invalid predictor selection value: {}", spectral_selection_start)));
447 },
448 };
449 }
450 else {
451 predictor_selection = Predictor::NoPrediction;
452 if spectral_selection_end == 0 {
453 spectral_selection_end = 63;
454 }
455 if spectral_selection_start != 0 || spectral_selection_end != 63 {
456 return Err(Error::Format("spectral selection is not allowed in non-progressive scan".to_owned()));
457 }
458 if successive_approximation_high != 0 || successive_approximation_low != 0 {
459 return Err(Error::Format("successive approximation is not allowed in non-progressive scan".to_owned()));
460 }
461 }
462
463 Ok(ScanInfo {
464 component_indices,
465 dc_table_indices,
466 ac_table_indices,
467 spectral_selection: Range {
468 start: spectral_selection_start,
469 end: spectral_selection_end + 1,
470 },
471 predictor_selection,
472 successive_approximation_high,
473 successive_approximation_low,
474 point_transform,
475 })
476}
477
478pub fn parse_dqt<R: Read>(reader: &mut R) -> Result<[Option<[u16; 64]>; 4]> {
480 let mut length = read_length(reader, DQT)?;
481 let mut tables = [None; 4];
482
483 while length > 0 {
485 let byte = read_u8(reader)?;
486 let precision = (byte >> 4) as usize;
487 let index = (byte & 0x0f) as usize;
488
489 if precision > 1 {
498 return Err(Error::Format(format!("invalid precision {} in DQT", precision)));
499 }
500 if index > 3 {
501 return Err(Error::Format(format!("invalid destination identifier {} in DQT", index)));
502 }
503 if length < 65 + 64 * precision {
504 return Err(Error::Format("invalid length in DQT".to_owned()));
505 }
506
507 let mut table = [0u16; 64];
508
509 for item in table.iter_mut() {
510 *item = match precision {
511 0 => u16::from(read_u8(reader)?),
512 1 => read_u16_from_be(reader)?,
513 _ => unreachable!(),
514 };
515 }
516
517 if table.iter().any(|&val| val == 0) {
518 return Err(Error::Format("quantization table contains element with a zero value".to_owned()));
519 }
520
521 tables[index] = Some(table);
522 length -= 65 + 64 * precision;
523 }
524
525 Ok(tables)
526}
527
528#[allow(clippy::type_complexity)]
530pub fn parse_dht<R: Read>(reader: &mut R, is_baseline: Option<bool>) -> Result<(Vec<Option<HuffmanTable>>, Vec<Option<HuffmanTable>>)> {
531 let mut length = read_length(reader, DHT)?;
532 let mut dc_tables = vec![None, None, None, None];
533 let mut ac_tables = vec![None, None, None, None];
534
535 while length > 17 {
537 let byte = read_u8(reader)?;
538 let class = byte >> 4;
539 let index = (byte & 0x0f) as usize;
540
541 if class != 0 && class != 1 {
542 return Err(Error::Format(format!("invalid class {} in DHT", class)));
543 }
544 if is_baseline == Some(true) && index > 1 {
545 return Err(Error::Format("a maximum of two huffman tables per class are allowed in baseline".to_owned()));
546 }
547 if index > 3 {
548 return Err(Error::Format(format!("invalid destination identifier {} in DHT", index)));
549 }
550
551 let mut counts = [0u8; 16];
552 reader.read_exact(&mut counts)?;
553
554 let size = counts.iter().map(|&val| val as usize).fold(0, ops::Add::add);
555
556 if size == 0 {
557 return Err(Error::Format("encountered table with zero length in DHT".to_owned()));
558 }
559 else if size > 256 {
560 return Err(Error::Format("encountered table with excessive length in DHT".to_owned()));
561 }
562 else if size > length - 17 {
563 return Err(Error::Format("invalid length in DHT".to_owned()));
564 }
565
566 let mut values = vec![0u8; size];
567 reader.read_exact(&mut values)?;
568
569 match class {
570 0 => dc_tables[index] = Some(HuffmanTable::new(&counts, &values, HuffmanTableClass::DC)?),
571 1 => ac_tables[index] = Some(HuffmanTable::new(&counts, &values, HuffmanTableClass::AC)?),
572 _ => unreachable!(),
573 }
574
575 length -= 17 + size;
576 }
577
578 if length != 0 {
579 return Err(Error::Format("invalid length in DHT".to_owned()));
580 }
581
582 Ok((dc_tables, ac_tables))
583}
584
585pub fn parse_dri<R: Read>(reader: &mut R) -> Result<u16> {
587 let length = read_length(reader, DRI)?;
588
589 if length != 2 {
590 return Err(Error::Format("DRI with invalid length".to_owned()));
591 }
592
593 Ok(read_u16_from_be(reader)?)
594}
595
596pub fn parse_com<R: Read>(reader: &mut R) -> Result<Vec<u8>> {
598 let length = read_length(reader, COM)?;
599 let mut buffer = vec![0u8; length];
600
601 reader.read_exact(&mut buffer)?;
602
603 Ok(buffer)
604}
605
606pub fn parse_app<R: Read>(reader: &mut R, marker: Marker) -> Result<Option<AppData>> {
608 let length = read_length(reader, marker)?;
609 let mut bytes_read = 0;
610 let mut result = None;
611
612 match marker {
613 APP(0) => {
614 if length >= 5 {
615 let mut buffer = [0u8; 5];
616 reader.read_exact(&mut buffer)?;
617 bytes_read = buffer.len();
618
619 if buffer[0..5] == *b"JFIF\0" {
621 result = Some(AppData::Jfif);
622 } else if buffer[0..5] == *b"AVI1\0" {
624 result = Some(AppData::Avi1);
625 }
626 }
627 }
628 APP(1) => {
629 let mut buffer = vec![0u8; length];
630 reader.read_exact(&mut buffer)?;
631 bytes_read = buffer.len();
632
633 if length >= 6 && buffer[0..6] == *b"Exif\x00\x00" {
636 result = Some(AppData::Exif(buffer[6..].to_vec()));
637 }
638 else if length >= 29 && buffer[0..29] == *b"http://ns.adobe.com/xap/1.0/\0" {
641 result = Some(AppData::Xmp(buffer[29..].to_vec()));
642 }
643 }
644 APP(2) => {
645 if length > 14 {
646 let mut buffer = [0u8; 14];
647 reader.read_exact(&mut buffer)?;
648 bytes_read = buffer.len();
649
650 if buffer[0..12] == *b"ICC_PROFILE\0" {
653 let mut data = vec![0; length - bytes_read];
654 reader.read_exact(&mut data)?;
655 bytes_read += data.len();
656 result = Some(AppData::Icc(IccChunk {
657 seq_no: buffer[12],
658 num_markers: buffer[13],
659 data,
660 }));
661 }
662 }
663 }
664 APP(13) => {
665 if length >= 14 {
666 let mut buffer = [0u8; 14];
667 reader.read_exact(&mut buffer)?;
668 bytes_read = buffer.len();
669
670 if buffer[0..14] == *b"Photoshop 3.0\0" {
673 let mut data = vec![0; length - bytes_read];
674 reader.read_exact(&mut data)?;
675 bytes_read += data.len();
676 result = Some(AppData::Psir(data));
677 }
678 }
679 }
680 APP(14) => {
681 if length >= 12 {
682 let mut buffer = [0u8; 12];
683 reader.read_exact(&mut buffer)?;
684 bytes_read = buffer.len();
685
686 if buffer[0 .. 6] == *b"Adobe\0" {
688 let color_transform = match buffer[11] {
689 0 => AdobeColorTransform::Unknown,
690 1 => AdobeColorTransform::YCbCr,
691 2 => AdobeColorTransform::YCCK,
692 _ => return Err(Error::Format("invalid color transform in adobe app segment".to_owned())),
693 };
694
695 result = Some(AppData::Adobe(color_transform));
696 }
697 }
698 },
699 _ => {},
700 }
701
702 skip_bytes(reader, length - bytes_read)?;
703 Ok(result)
704}