use alloc::vec::Vec;
use crate::foundation::consts::{
JPEG_NATURAL_ORDER, MARKER_APP0, MARKER_APP2, MARKER_DQT, MARKER_EOI, MARKER_SOF0, MARKER_SOF1,
MARKER_SOF2, MARKER_SOF9, MARKER_SOF10, MARKER_SOI, MARKER_SOS,
};
/// Header metadata collected by [`scan_headers`] from a JPEG byte stream.
///
/// `Default` yields the "nothing seen yet" state: no tables, no frame header,
/// all counters zero and all flags `false`.
#[derive(Debug, Clone, Default)]
pub(crate) struct ScanResult {
    /// Quantization tables indexed by table id (0..=3) as found in DQT segments,
    /// reordered from zigzag to natural order. A later DQT for the same id replaces
    /// an earlier one.
    pub dqt_tables: [Option<DqtEntry>; 4],
    /// Frame header from the last recognised SOF segment, if any.
    pub sof: Option<SofInfo>,
    /// Sum of the symbol counts of every class-1 (AC) Huffman table in DHT segments.
    pub total_ac_symbols: u16,
    /// Number of Huffman tables defined across all DHT segments (one segment can
    /// define several tables).
    pub dht_count: u8,
    /// `true` if an APP0 segment's payload begins with `JFIF\0`.
    pub has_jfif: bool,
    /// `true` if an APP2 segment's payload begins with `ICC_PROFILE\0`.
    pub has_icc_profile: bool,
    /// `true` if an APP14 segment's payload begins with `Adobe`.
    pub has_adobe: bool,
    /// `true` if an APP13 segment's payload begins with `Photoshop 3.0\0`.
    /// NOTE: only that identifier is checked; actual IPTC records are not verified.
    pub has_photoshop_iptc: bool,
    /// Number of SOS markers encountered.
    pub sos_count: u16,
}
/// One quantization table recorded from a DQT segment.
#[derive(Debug, Clone)]
pub(crate) struct DqtEntry {
    /// The 64 quantization values, reordered by `parse_dqt` from the stream's zigzag
    /// order into natural order using `JPEG_NATURAL_ORDER`.
    pub values: [u16; 64],
    /// The high nibble of the table's `Pq/Tq` byte: `0` means the values were stored
    /// as single bytes; `parse_dqt` reads any nonzero value as 16-bit big-endian entries.
    pub precision: u8,
}
/// Frame-header fields recorded from a start-of-frame (SOF) segment.
#[derive(Debug, Clone)]
pub(crate) struct SofInfo {
    /// The SOF marker byte that introduced the segment (one accepted by `is_sof_marker`).
    pub marker: u8,
    /// Image width as stored in the frame header.
    pub width: u16,
    /// Image height as stored in the frame header.
    pub height: u16,
    /// Component count byte from the frame header.
    pub num_components: u8,
    /// Per-component `(id, horizontal sampling, vertical sampling, quantization table
    /// index)` tuples, in the order they appear in the header.
    pub components: Vec<(u8, u8, u8, u8)>,
}
/// Reasons [`scan_headers`] can reject its input.
///
/// The variants carry no data, so the type is `Copy` and can be compared directly
/// (e.g. `assert_eq!(err, ScanError::Truncated)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ScanError {
    /// The input is shorter than the two-byte SOI marker.
    TooShort,
    /// The input does not begin with `0xFF` followed by the SOI marker code.
    NotJpeg,
    /// A segment's length field is missing or invalid, or its contents run past the
    /// available data.
    Truncated,
}
/// Walks the marker segments of a JPEG byte stream and collects header metadata.
///
/// Scanning starts right after the SOI marker. It stops at an EOI marker, at the end of
/// `data`, or when an unrecognised segment has an unusable length. The entropy-coded
/// data following each SOS header is skipped without being decoded.
///
/// # Errors
///
/// - [`ScanError::TooShort`] if `data` is shorter than two bytes.
/// - [`ScanError::NotJpeg`] if `data` does not start with `0xFF` followed by `MARKER_SOI`.
/// - [`ScanError::Truncated`] (propagated) if a DQT, SOF, DHT, APP0, APP2, APP13 or APP14
///   segment is malformed or cut off. SOS and unrecognised segments never error;
///   scanning simply stops.
pub(crate) fn scan_headers(data: &[u8]) -> Result<ScanResult, ScanError> {
    if data.len() < 2 {
        return Err(ScanError::TooShort);
    }
    if data[0] != 0xFF || data[1] != MARKER_SOI {
        return Err(ScanError::NotJpeg);
    }
    let mut result = ScanResult {
        dqt_tables: [const { None }; 4],
        sof: None,
        total_ac_symbols: 0,
        dht_count: 0,
        has_jfif: false,
        has_icc_profile: false,
        has_adobe: false,
        has_photoshop_iptc: false,
        sos_count: 0,
    };
    let mut pos = 2;
    while let Some(marker_pos) = find_marker(data, pos) {
        let Some(&marker) = data.get(marker_pos + 1) else {
            break;
        };
        // `pos` now points at the segment's length field (if the marker has one).
        pos = marker_pos + 2;
        match marker {
            MARKER_EOI => break,
            MARKER_DQT => pos = parse_dqt(data, pos, &mut result)?,
            m if is_sof_marker(m) => pos = parse_sof(data, pos, m, &mut result)?,
            // DHT (define Huffman table).
            0xC4 => pos = parse_dht(data, pos, &mut result)?,
            MARKER_SOS => {
                // Saturate: a crafted stream can contain more than u16::MAX SOS markers,
                // and a plain `+=` would panic in debug builds.
                result.sos_count = result.sos_count.saturating_add(1);
                if pos + 1 >= data.len() {
                    break;
                }
                let header_len = usize::from(read_u16(data, pos));
                pos = skip_entropy_data(data, pos + header_len);
            }
            MARKER_APP0 => pos = parse_app0(data, pos, &mut result)?,
            MARKER_APP2 => pos = parse_app2(data, pos, &mut result)?,
            // APP13.
            0xED => pos = parse_app13(data, pos, &mut result)?,
            // APP14.
            0xEE => pos = parse_app14(data, pos, &mut result)?,
            // Stray restart markers have no length field; nothing to skip.
            0xD0..=0xD7 => {}
            _ => {
                // Unrecognised segment: skip it by its length, or stop scanning
                // (rather than erroring) if the length is unusable.
                if pos + 1 >= data.len() {
                    break;
                }
                let seg_len = usize::from(read_u16(data, pos));
                if seg_len < 2 || pos + seg_len > data.len() {
                    break;
                }
                pos += seg_len;
            }
        }
    }
    Ok(result)
}
/// Returns the index of the `0xFF` byte that introduces the next marker at or after `pos`.
///
/// A position qualifies when it holds `0xFF` and the following byte is neither `0x00`
/// (a stuffed data byte) nor `0xFF` (a fill byte). In a run of fill bytes this is the
/// last `0xFF` of the run, so `index + 1` is always the marker code. Returns `None` if
/// no marker remains or `pos` is past the end of `data`.
fn find_marker(data: &[u8], pos: usize) -> Option<usize> {
    let tail = data.get(pos..)?;
    let offset = tail
        .windows(2)
        .position(|pair| pair[0] == 0xFF && pair[1] != 0xFF && pair[1] != 0x00)?;
    Some(pos + offset)
}
fn is_sof_marker(m: u8) -> bool {
matches!(
m,
MARKER_SOF0 | MARKER_SOF1 | MARKER_SOF2 | MARKER_SOF9 | MARKER_SOF10
)
}
/// Parses a DQT (define quantization table) segment whose length field starts at `pos`.
///
/// A DQT segment may hold several tables. Each table starts with a `Pq/Tq` byte:
/// - the high nibble is the precision: `0` means 64 one-byte values; any other value
///   is read as 64 big-endian 16-bit values;
/// - the low nibble is the table id.
///
/// Tables with ids 0..=3 are stored in `result.dqt_tables`, reordered from zigzag to
/// natural order via `JPEG_NATURAL_ORDER`. Tables with larger ids are skipped.
///
/// Returns the offset just past the segment.
///
/// # Errors
///
/// [`ScanError::Truncated`] if any of these hold:
/// - the length field is missing or below 2;
/// - the segment runs past the end of `data`;
/// - a table's values run past the segment's declared end.
fn parse_dqt(data: &[u8], pos: usize, result: &mut ScanResult) -> Result<usize, ScanError> {
    if pos + 1 >= data.len() {
        return Err(ScanError::Truncated);
    }
    let seg_len = read_u16(data, pos) as usize;
    if seg_len < 2 || pos + seg_len > data.len() {
        return Err(ScanError::Truncated);
    }
    let seg_end = pos + seg_len;
    let mut p = pos + 2;
    // `p < seg_end <= data.len()`, so indexing `data[p]` below is in bounds.
    while p < seg_end {
        let pq_tq = data[p];
        let precision = pq_tq >> 4;
        let table_id = usize::from(pq_tq & 0x0F);
        p += 1;
        let value_bytes = if precision == 0 { 64 } else { 128 };
        // Bound by the segment end, not just the buffer end. Otherwise a too-short
        // length field would make us read the next segment's bytes as table values.
        if p + value_bytes > seg_end {
            return Err(ScanError::Truncated);
        }
        if table_id <= 3 {
            let mut values = [0u16; 64];
            for zigzag_idx in 0..64 {
                let value = if precision == 0 {
                    u16::from(data[p + zigzag_idx])
                } else {
                    read_u16(data, p + zigzag_idx * 2)
                };
                values[JPEG_NATURAL_ORDER[zigzag_idx] as usize] = value;
            }
            result.dqt_tables[table_id] = Some(DqtEntry { values, precision });
        }
        p += value_bytes;
    }
    Ok(seg_end)
}
/// Parses a start-of-frame segment whose length field starts at `pos`.
///
/// Stores the frame's height, width and component table in `result.sof`, replacing
/// any earlier value. `marker` is recorded as-is. The sample-precision byte is read
/// past but not kept.
///
/// Returns the offset just past the segment.
///
/// # Errors
///
/// [`ScanError::Truncated`] if the length field is missing, or if the segment is
/// shorter than the fixed 8-byte header plus three bytes per component, or if it
/// runs past the end of `data`.
fn parse_sof(
    data: &[u8],
    pos: usize,
    marker: u8,
    result: &mut ScanResult,
) -> Result<usize, ScanError> {
    if pos + 1 >= data.len() {
        return Err(ScanError::Truncated);
    }
    let seg_len = usize::from(read_u16(data, pos));
    let seg_end = pos + seg_len;
    if seg_len < 8 || seg_end > data.len() {
        return Err(ScanError::Truncated);
    }

    // Fixed header after the length field: precision (1, ignored), height (2),
    // width (2), component count (1).
    let height = read_u16(data, pos + 3);
    let width = read_u16(data, pos + 5);
    let num_components = data[pos + 7];

    // Followed by three bytes per component: id, sampling factors, table index.
    let table_start = pos + 8;
    let table_end = table_start + 3 * usize::from(num_components);
    if table_end > seg_end || table_end > data.len() {
        return Err(ScanError::Truncated);
    }
    let components = data[table_start..table_end]
        .chunks_exact(3)
        .map(|entry| (entry[0], entry[1] >> 4, entry[1] & 0x0F, entry[2]))
        .collect();

    result.sof = Some(SofInfo {
        marker,
        width,
        height,
        num_components,
        components,
    });
    Ok(seg_end)
}
/// Parses a DHT (define Huffman table) segment whose length field starts at `pos`.
///
/// A DHT segment may define several tables. Each table consists of:
/// - a `Tc/Th` byte whose high nibble is the table class (`1` = AC);
/// - sixteen per-length code counts;
/// - as many symbol bytes as those counts sum to.
///
/// Every table increments `result.dht_count`. AC tables also add their symbol count
/// to `result.total_ac_symbols`. Both counters saturate instead of overflowing on
/// inputs with very many tables.
///
/// Returns the offset just past the segment.
///
/// # Errors
///
/// [`ScanError::Truncated`] if any of these hold:
/// - the length field is missing or below 2;
/// - the segment runs past the end of `data`;
/// - a table's counts or symbols run past the segment's declared end.
fn parse_dht(data: &[u8], pos: usize, result: &mut ScanResult) -> Result<usize, ScanError> {
    if pos + 1 >= data.len() {
        return Err(ScanError::Truncated);
    }
    let seg_len = read_u16(data, pos) as usize;
    if seg_len < 2 || pos + seg_len > data.len() {
        return Err(ScanError::Truncated);
    }
    let seg_end = pos + seg_len;
    let mut p = pos + 2;
    // `p < seg_end <= data.len()`, so indexing `data[p]` below is in bounds.
    while p < seg_end {
        let table_class = data[p] >> 4;
        p += 1;
        // Bound by the segment end so bytes of the following segment are never
        // counted as part of this table.
        if p + 16 > seg_end {
            return Err(ScanError::Truncated);
        }
        // At most 16 * 255 = 4080, so the u16 sum cannot overflow.
        let total_symbols: u16 = data[p..p + 16].iter().map(|&n| u16::from(n)).sum();
        p += 16;
        if table_class == 1 {
            result.total_ac_symbols = result.total_ac_symbols.saturating_add(total_symbols);
        }
        result.dht_count = result.dht_count.saturating_add(1);
        let sym_count = usize::from(total_symbols);
        if p + sym_count > seg_end {
            return Err(ScanError::Truncated);
        }
        p += sym_count;
    }
    Ok(seg_end)
}
/// Validates the length field of the segment starting at `pos` and reports whether the
/// segment's payload begins with `signature`.
///
/// The payload is the bytes after the two-byte length field. A payload shorter than
/// `signature` does not match.
///
/// Returns `(matched, seg_end)`, where `seg_end` is the offset just past the segment.
///
/// # Errors
///
/// [`ScanError::Truncated`] if the length field is missing, is below 2, or the segment
/// extends past the end of `data`.
fn segment_has_signature(
    data: &[u8],
    pos: usize,
    signature: &[u8],
) -> Result<(bool, usize), ScanError> {
    if pos + 1 >= data.len() {
        return Err(ScanError::Truncated);
    }
    let seg_len = read_u16(data, pos) as usize;
    if seg_len < 2 || pos + seg_len > data.len() {
        return Err(ScanError::Truncated);
    }
    let seg_end = pos + seg_len;
    let matched = data[pos + 2..seg_end].starts_with(signature);
    Ok((matched, seg_end))
}

/// Parses an APP0 segment whose length field starts at `pos`.
///
/// Sets `result.has_jfif` when the payload starts with `JFIF\0`; never clears it.
/// Returns the offset just past the segment.
///
/// # Errors
///
/// [`ScanError::Truncated`] if the segment's length field or body is cut off or invalid.
fn parse_app0(data: &[u8], pos: usize, result: &mut ScanResult) -> Result<usize, ScanError> {
    let (matched, seg_end) = segment_has_signature(data, pos, b"JFIF\0")?;
    result.has_jfif |= matched;
    Ok(seg_end)
}

/// Parses an APP2 segment whose length field starts at `pos`.
///
/// Sets `result.has_icc_profile` when the payload starts with `ICC_PROFILE\0`; never
/// clears it. Returns the offset just past the segment.
///
/// # Errors
///
/// [`ScanError::Truncated`] if the segment's length field or body is cut off or invalid.
fn parse_app2(data: &[u8], pos: usize, result: &mut ScanResult) -> Result<usize, ScanError> {
    let (matched, seg_end) = segment_has_signature(data, pos, b"ICC_PROFILE\0")?;
    result.has_icc_profile |= matched;
    Ok(seg_end)
}

/// Parses an APP13 segment whose length field starts at `pos`.
///
/// Sets `result.has_photoshop_iptc` when the payload starts with `Photoshop 3.0\0`;
/// never clears it. Returns the offset just past the segment.
///
/// # Errors
///
/// [`ScanError::Truncated`] if the segment's length field or body is cut off or invalid.
fn parse_app13(data: &[u8], pos: usize, result: &mut ScanResult) -> Result<usize, ScanError> {
    let (matched, seg_end) = segment_has_signature(data, pos, b"Photoshop 3.0\0")?;
    result.has_photoshop_iptc |= matched;
    Ok(seg_end)
}

/// Parses an APP14 segment whose length field starts at `pos`.
///
/// Sets `result.has_adobe` when the payload starts with `Adobe`; never clears it.
/// Returns the offset just past the segment.
///
/// # Errors
///
/// [`ScanError::Truncated`] if the segment's length field or body is cut off or invalid.
fn parse_app14(data: &[u8], pos: usize, result: &mut ScanResult) -> Result<usize, ScanError> {
    let (matched, seg_end) = segment_has_signature(data, pos, b"Adobe")?;
    result.has_adobe |= matched;
    Ok(seg_end)
}
/// Advances over entropy-coded scan data starting at `pos`.
///
/// Stuffed bytes (`FF 00`) and restart markers (`FF D0`..=`FF D7`) are treated as part
/// of the data. Returns one of:
/// - the index of the `0xFF` that begins the next other marker;
/// - the index of a lone trailing `0xFF`;
/// - the end of the scanned range if no marker follows (`pos` itself if it already
///   lies at or past `data.len()`).
fn skip_entropy_data(data: &[u8], mut pos: usize) -> usize {
    while let Some(&byte) = data.get(pos) {
        if byte != 0xFF {
            pos += 1;
            continue;
        }
        match data.get(pos + 1) {
            Some(0x00 | 0xD0..=0xD7) => pos += 2,
            _ => return pos,
        }
    }
    pos
}
/// Reads the big-endian `u16` stored at `data[pos]` and `data[pos + 1]`.
///
/// # Panics
///
/// Panics if `pos + 1` is out of bounds. Callers check the slice length first.
#[inline]
fn read_u16(data: &[u8], pos: usize) -> u16 {
    u16::from_be_bytes([data[pos], data[pos + 1]])
}