use crate::error::{Error, Result};
use crate::jpeg::byteorder::ReadBytesExt;
use crate::jpeg::marker::Marker;
use crate::jpeg::marker::Marker::{SOF, SOS};
use std::io::{self, Read};
use std::ops::RangeInclusive;
/// A width/height pair; the unit (samples, MCUs or blocks) depends on the
/// field that holds it.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Dimensions {
    /// Horizontal extent.
    pub width: u16,
    /// Vertical extent.
    pub height: u16,
}
/// Coding process selected by the start-of-frame marker.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CodingProcess {
    /// Sequential DCT-based coding (markers `SOF(0 | 1 | 5 | 9 | 13)`).
    DctSequential,
    /// Progressive DCT-based coding (markers `SOF(2 | 6 | 10 | 14)`).
    DctProgressive,
    /// Lossless (predictive) coding (markers `SOF(3 | 7 | 11 | 15)`).
    Lossless,
}
/// Frame-level parameters decoded from a start-of-frame (SOF) header.
#[derive(Clone, Debug)]
pub struct FrameInfo {
    /// `true` only for the `SOF(0)` marker.
    pub is_baseline: bool,
    /// `true` for the markers `SOF(5..=7)` and `SOF(13..=15)`.
    pub is_differential: bool,
    /// Coding process implied by the SOF marker.
    pub coding_process: CodingProcess,
    /// Sample precision in bits, as read from the frame header.
    pub precision: u8,
    /// Image width and height, as read from the frame header.
    pub image_size: Dimensions,
    /// Number of MCUs across and down the image (rounded up).
    pub mcu_size: Dimensions,
    /// Frame components in header order.
    pub components: Vec<Component>,
}
/// Scan-level parameters decoded from a start-of-scan (SOS) header.
///
/// The three index vectors are parallel: entry `i` of each describes the
/// `i`-th component of the scan.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ScanInfo {
    /// Indices into the frame's component list, in scan order.
    pub component_indices: Vec<usize>,
    /// DC Huffman table selector for each scan component.
    pub dc_table_indices: Vec<usize>,
    /// AC Huffman table selector for each scan component.
    pub ac_table_indices: Vec<usize>,
    /// Inclusive range from the Ss and Se bytes of the scan header.
    pub spectral_selection: RangeInclusive<u8>,
    /// High nibble (Ah) of the successive-approximation byte.
    pub successive_approximation_high: u8,
    /// Low nibble (Al) of the successive-approximation byte.
    pub successive_approximation_low: u8,
}
/// One image component as described by the frame header, plus sizes
/// derived from the frame dimensions and sampling factors.
#[derive(Debug, Clone)]
pub struct Component {
    /// Component identifier from the frame header (unique within a frame).
    pub identifier: u8,
    /// Horizontal sampling factor, validated to lie in `1..=4`.
    pub horizontal_sampling_factor: u8,
    /// Vertical sampling factor, validated to lie in `1..=4`.
    pub vertical_sampling_factor: u8,
    /// Quantization table selector, validated to lie in `0..=3`.
    pub quantization_table_index: usize,
    /// Image size scaled by this component's sampling factors relative to the
    /// largest factors in the frame, rounded up.
    pub size: Dimensions,
    /// MCU grid size multiplied by this component's sampling factors.
    pub block_size: Dimensions,
}
fn read_length<R: Read>(reader: &mut R, marker: Marker) -> Result<usize> {
if !marker.has_length() {
return Err(Error::Format("unexpected empty marker"));
}
let length = reader.read_u16_be()? as usize;
if length < 2 {
return Err(Error::Format("encountered invalid length"));
}
Ok(length - 2)
}
/// Discards exactly `length` bytes from `reader`.
///
/// # Errors
///
/// Returns an `UnexpectedEof` I/O error if the reader ends before `length`
/// bytes have been consumed. Other read errors are propagated.
fn skip_bytes<R: Read>(reader: &mut R, length: usize) -> Result<()> {
    let wanted = length as u64;
    let mut limited = reader.by_ref().take(wanted);
    let skipped = io::copy(&mut limited, &mut io::sink())?;
    if skipped >= wanted {
        Ok(())
    } else {
        Err(Error::Io(io::ErrorKind::UnexpectedEof.into()))
    }
}
/// Skips the payload of a marker segment.
///
/// Markers without a length field carry no payload, so nothing is read for
/// them.
///
/// # Errors
///
/// Propagates errors from reading the length field or skipping the payload.
pub fn skip_marker<R: Read>(reader: &mut R, marker: Marker) -> Result<()> {
    if !marker.has_length() {
        return Ok(());
    }
    let payload = read_length(reader, marker)?;
    skip_bytes(reader, payload)
}
pub fn parse_sof<R: Read>(reader: &mut R, marker: Marker) -> Result<FrameInfo> {
let length = read_length(reader, marker)?;
if length <= 6 {
return Err(Error::Format("invalid length in SOF"));
}
let is_baseline = marker == SOF(0);
let is_differential = match marker {
SOF(0..=3 | 9..=11) => false,
SOF(5..=7 | 13..=15) => true,
_ => panic!(),
};
let coding_process = match marker {
SOF(0 | 1 | 5 | 9 | 13) => CodingProcess::DctSequential,
SOF(2 | 6 | 10 | 14) => CodingProcess::DctProgressive,
SOF(3 | 7 | 11 | 15) => CodingProcess::Lossless,
_ => panic!(),
};
let precision = reader.read_u8()?;
match precision {
8 => {},
12 => {
if is_baseline {
return Err(Error::Format("12 bit sample precision is not allowed in baseline"));
}
},
_ => {
if coding_process != CodingProcess::Lossless {
return Err(Error::Format("invalid precision in frame header"))
}
},
}
let height = reader.read_u16_be()?;
let width = reader.read_u16_be()?;
if width == 0 || height == 0 {
return Err(Error::Format("zero size in frame header"));
}
let component_count = reader.read_u8()?;
if component_count == 0 {
return Err(Error::Format("zero component count in frame header"));
}
if coding_process == CodingProcess::DctProgressive && component_count > 4 {
return Err(Error::Format("progressive frame with more than 4 components"));
}
if length != 6 + 3 * component_count as usize {
return Err(Error::Format("invalid length in SOF"));
}
let mut components: Vec<Component> = Vec::with_capacity(component_count as usize);
for _ in 0..component_count {
let identifier = reader.read_u8()?;
if components.iter().any(|c| c.identifier == identifier) {
return Err(Error::Format("duplicate frame component identifier"));
}
let byte = reader.read_u8()?;
let horizontal_sampling_factor = byte >> 4;
let vertical_sampling_factor = byte & 0x0f;
if horizontal_sampling_factor == 0 || horizontal_sampling_factor > 4 {
return Err(Error::Format("invalid horizontal sampling factor"));
}
if vertical_sampling_factor == 0 || vertical_sampling_factor > 4 {
return Err(Error::Format("invalid vertical sampling factor"));
}
let quantization_table_index = reader.read_u8()?;
if quantization_table_index > 3 || (coding_process == CodingProcess::Lossless && quantization_table_index != 0) {
return Err(Error::Format("invalid quantization table index"));
}
components.push(Component {
identifier,
horizontal_sampling_factor,
vertical_sampling_factor,
quantization_table_index: quantization_table_index as usize,
size: Dimensions {width: 0, height: 0},
block_size: Dimensions {width: 0, height: 0},
});
}
let h_max = components.iter().map(|c| c.horizontal_sampling_factor).max().unwrap();
let v_max = components.iter().map(|c| c.vertical_sampling_factor).max().unwrap();
let mcu_size = Dimensions {
width: (f32::from(width) / (f32::from(h_max) * 8.0)).ceil() as u16,
height: (f32::from(height) / (f32::from(v_max) * 8.0)).ceil() as u16,
};
for component in &mut components {
component.size.width = (f32::from(width) * (f32::from(component.horizontal_sampling_factor) / f32::from(h_max))).ceil() as u16;
component.size.height = (f32::from(height) * (f32::from(component.vertical_sampling_factor) / f32::from(v_max))).ceil() as u16;
component.block_size.width = mcu_size.width * u16::from(component.horizontal_sampling_factor);
component.block_size.height = mcu_size.height * u16::from(component.vertical_sampling_factor);
}
Ok(FrameInfo {
is_baseline,
is_differential,
coding_process,
precision,
image_size: Dimensions { width, height },
mcu_size,
components,
})
}
pub fn parse_sos<R: Read>(reader: &mut R, frame: &FrameInfo) -> Result<ScanInfo> {
let length = read_length(reader, SOS)?;
if 0 == length {
return Err(Error::Format("zero length in SOS"));
}
let component_count = reader.read_u8()?;
if component_count == 0 || component_count > 4 {
return Err(Error::Format("invalid component count in scan header"));
}
if length != 4 + 2 * component_count as usize {
return Err(Error::Format("invalid length in SOS"));
}
let mut component_indices = Vec::with_capacity(component_count as usize);
let mut dc_table_indices = Vec::with_capacity(component_count as usize);
let mut ac_table_indices = Vec::with_capacity(component_count as usize);
for _ in 0..component_count {
let identifier = reader.read_u8()?;
let component_index = match frame.components.iter().position(|c| c.identifier == identifier) {
Some(value) => value,
None => return Err(Error::Format("scan component identifier does not match any of the component identifiers defined in the frame")),
};
if component_indices.contains(&component_index) {
return Err(Error::Format("duplicate scan component identifier"));
}
if component_index < *component_indices.iter().max().unwrap_or(&0) {
return Err(Error::Format("the scan component order does not follow the order in the frame header"));
}
let byte = reader.read_u8()?;
let dc_table_index = byte >> 4;
let ac_table_index = byte & 0x0f;
if dc_table_index > 3 || (frame.is_baseline && dc_table_index > 1) {
return Err(Error::Format("invalid dc table index"));
}
if ac_table_index > 3 || (frame.is_baseline && ac_table_index > 1) {
return Err(Error::Format("invalid ac table index"));
}
component_indices.push(component_index);
dc_table_indices.push(dc_table_index as usize);
ac_table_indices.push(ac_table_index as usize);
}
let blocks_per_mcu = component_indices.iter().map(|&i| {
u32::from(frame.components[i].horizontal_sampling_factor) * u32::from(frame.components[i].vertical_sampling_factor)
}).fold(0, ::std::ops::Add::add);
if component_count > 1 && blocks_per_mcu > 10 {
return Err(Error::Format("scan with more than one component and more than 10 blocks per MCU"));
}
let spectral_selection_start = reader.read_u8()?;
let spectral_selection_end = reader.read_u8()?;
let byte = reader.read_u8()?;
let successive_approximation_high = byte >> 4;
let successive_approximation_low = byte & 0x0f;
Ok(ScanInfo {
component_indices,
dc_table_indices,
ac_table_indices,
spectral_selection: spectral_selection_start..=spectral_selection_end,
successive_approximation_high,
successive_approximation_low,
})
}