use super::context::{context_index, update_context, ContextState, NUM_TOTAL_CONTEXTS};
use super::golomb::{
compute_limit, compute_qbpp, decode_golomb_unsigned_limited, map_error_near,
unmap_error_lossless, unmap_error_near, BitReader,
};
use super::markers::{parse_headers, JlsHeaders};
use super::predictor::{predict, quantize_gradient};
use super::run_mode::{
bump_run_index, decrement_run_index, enter_run_near, j_for, run_termination_ctx, threshold_for,
RunState,
};
use super::{JlsError, JlsResult};
/// A fully decoded JPEG-LS image, one sample plane per component.
#[derive(Debug, Clone)]
pub struct DecodedImage {
    /// Image width in samples (columns per row), taken from the frame header.
    pub width: u32,
    /// Image height in rows, taken from the frame header.
    pub height: u32,
    /// Number of components, i.e. the number of planes in `samples`.
    pub num_components: u8,
    /// Sample precision copied from the frame header (presumably bits per sample — confirm
    /// against `markers`).
    pub precision: u8,
    /// One plane per component. Each plane holds `width * height` samples in row-major
    /// order: sample `(row, col)` lives at index `row * width + col`.
    pub samples: Vec<Vec<u16>>,
}
pub struct JpegLsDecoder;
impl JpegLsDecoder {
pub fn new() -> Self {
Self
}
pub fn is_jpegls(data: &[u8]) -> bool {
data.len() >= 4 && data[0] == 0xFF && data[1] == 0xD8 && data[2] == 0xFF && data[3] == 0xF7
}
pub fn decode(data: &[u8]) -> JlsResult<DecodedImage> {
let headers = parse_headers(data)?;
let scan_data = &data[headers.scan_data_start..];
decode_scan(scan_data, &headers)
}
}
impl Default for JpegLsDecoder {
fn default() -> Self {
Self::new()
}
}
/// Per-scan coding parameters shared by every sample decoder in this module.
///
/// Built once in `decode_scan` from the frame, scan and preset headers.
struct ScanDecodeParams {
    /// Maximum sample value; reconstructed samples are clamped to `0..=max_val`.
    max_val: i32,
    /// Near-lossless tolerance; `0` selects the lossless error unmapping.
    near: i32,
    /// Context reset threshold from the presets, forwarded to `update_context`.
    reset: i32,
    /// Golomb code length limit, computed by `compute_limit(max_val)`.
    limit: i32,
    /// Value from `compute_qbpp(max_val)`, forwarded to the limited Golomb decoder.
    qbpp: u8,
    /// First gradient quantisation threshold, passed to `quantize_gradient`.
    t1: i32,
    /// Second gradient quantisation threshold, passed to `quantize_gradient`.
    t2: i32,
    /// Third gradient quantisation threshold, passed to `quantize_gradient`.
    t3: i32,
    /// Row width in samples; also the row stride of every sample plane.
    w: usize,
}
/// Returns the causal neighbours `(Ra, Rb, Rc, Rd)` of the sample at `(row, col)`.
///
/// `out` holds one component's reconstructed samples in row-major order with row
/// stride `w`. Only rows above `row`, and columns left of `col` on `row`, are read.
///
/// Image-edge conventions follow ISO/IEC 14495-1 (ITU-T T.87):
/// * on the first row the (virtual) line above is all zeros, so `Rb = Rc = Rd = 0`,
///   and the very first sample also has `Ra = 0`;
/// * at the start of a later row `Ra = Rb` (the sample directly above), and `Rc` is the
///   `Ra` used for the first sample of the previous row — the sample two rows up, or 0
///   when the previous row is the first row;
/// * at the end of a row `Rd = Rb`.
#[inline]
fn fetch_neighbours(out: &[u16], row: usize, col: usize, w: usize) -> (i32, i32, i32, i32) {
    let sample = |r: usize, x: usize| i32::from(out[r * w + x]);

    if row == 0 {
        // The line above the first line is defined to be all zeros.
        let a = if col > 0 { sample(0, col - 1) } else { 0 };
        return (a, 0, 0, 0);
    }

    let b = sample(row - 1, col);
    let (a, c) = if col > 0 {
        (sample(row, col - 1), sample(row - 1, col - 1))
    } else {
        // Ra of the previous row's first sample was that sample's Rb: two rows up.
        let c = if row >= 2 { sample(row - 2, 0) } else { 0 };
        (b, c)
    };
    let d = if col + 1 < w { sample(row - 1, col + 1) } else { b };
    (a, b, c, d)
}
/// Decodes one sample in regular (non-run) mode and updates its context statistics.
///
/// `a`, `b`, `c`, `d` are the causal neighbours Ra, Rb, Rc, Rd of the sample. Their
/// local gradients are quantised to pick a context and a sign. The context's bias `cx`
/// corrects the `predict` output, and the context's Golomb parameter `k` is used to
/// decode the mapped prediction error.
///
/// # Errors
/// Returns [`JlsError::Truncated`] if the bit stream ends inside the Golomb code.
fn decode_pixel_regular(
    ctx_states: &mut [ContextState],
    reader: &mut BitReader<'_>,
    a: i32,
    b: i32,
    c: i32,
    d: i32,
    p: &ScanDecodeParams,
) -> JlsResult<u16> {
    let q1 = quantize_gradient(d - b, p.t1, p.t2, p.t3);
    let q2 = quantize_gradient(b - c, p.t1, p.t2, p.t3);
    let q3 = quantize_gradient(c - a, p.t1, p.t2, p.t3);
    let (ctx_idx, sign) = context_index(q1, q2, q3);
    let state = &mut ctx_states[ctx_idx];

    // NOTE(review): ISO/IEC 14495-1 adds SIGN * C[Q] to the prediction, whereas this
    // subtracts it. That is only equivalent if `update_context` keeps `cx` negated —
    // confirm against the context module.
    let corrected_px = (predict(a, b, c) - sign * state.cx).clamp(0, p.max_val);
    let k = state.k.max(0);
    let e_mapped =
        decode_golomb_unsigned_limited(reader, k, p.limit, p.qbpp).ok_or(JlsError::Truncated {
            context: "scan data",
        })?;

    // Signed (quantised) error in the context's sign-normalised frame. Both coding
    // modes hand this *signed* value to `update_context`: its bias statistics depend on
    // the error's sign, which the previous `.abs()` in the near-lossless path discarded.
    let err_q = if p.near == 0 {
        unmap_error_lossless(e_mapped)
    } else {
        unmap_error_near(e_mapped, p.near, p.max_val)
    };

    // Dequantise, then undo the encoder's modulo-RANGE error reduction. A value that
    // lands more than NEAR outside [0, MAXVAL] has wrapped and is folded back before
    // the final clamp. Values already inside that window are unaffected.
    let q_step = 2 * p.near + 1;
    let range = (p.max_val + 2 * p.near) / q_step + 1;
    let mut rx = corrected_px + err_q * sign * q_step;
    if rx < -p.near {
        rx += range * q_step;
    } else if rx > p.max_val + p.near {
        rx -= range * q_step;
    }

    update_context(state, err_q, p.near, p.reset, p.max_val);
    Ok(rx.clamp(0, p.max_val) as u16)
}
/// Decodes a run of samples equal to `ra` that starts at `start_col` on `row`.
///
/// Each `1` bit extends the run by `threshold_for(run_index)` samples and bumps the run
/// index. This continues until the row is full or fewer than a full segment remain.
/// In the latter case one more bit decides between "run reaches end of row" (`1`) and
/// an interrupted run (`0`). An interrupted run is completed by
/// [`finish_interrupted_run`].
///
/// Returns the column just past the last sample written.
///
/// # Errors
/// [`JlsError::Truncated`] if the bit stream ends early, or if a coded residual run
/// length would reach or pass the end of the row.
#[allow(clippy::too_many_arguments)] // all decoder state is threaded explicitly
fn decode_run_mode(
    samples: &mut [u16],
    ctx_states: &mut [ContextState],
    run_state: &mut RunState,
    reader: &mut BitReader<'_>,
    row: usize,
    start_col: usize,
    ra: i32,
    p: &ScanDecodeParams,
) -> JlsResult<usize> {
    let w = p.w;
    let line = row * w;
    let run_sample = ra as u16;
    let mut col = start_col;

    // Full-length run segments, each signalled by a single `1` bit.
    loop {
        let thr = threshold_for(run_state.run_index);
        if thr <= 0 || col + thr as usize > w {
            break;
        }
        let bit = reader.read_bit().ok_or(JlsError::Truncated {
            context: "run mode token",
        })?;
        if bit != 1 {
            return finish_interrupted_run(samples, ctx_states, run_state, reader, row, col, ra, p);
        }
        let seg_end = col + thr as usize;
        samples[line + col..line + seg_end].fill(run_sample);
        col = seg_end;
        bump_run_index(run_state);
        if col == w {
            return Ok(col);
        }
    }

    // Less than one full segment remains in the row.
    let bit = reader.read_bit().ok_or(JlsError::Truncated {
        context: "run trailing bit",
    })?;
    if bit == 1 {
        samples[line + col..line + w].fill(run_sample);
        return Ok(w);
    }
    finish_interrupted_run(samples, ctx_states, run_state, reader, row, col, ra, p)
}

/// Completes a run that a `0` bit interrupted at column `col`.
///
/// Reads the `j_for(run_index)`-bit count of run samples still to fill and writes them.
/// It then decodes the run-interruption sample that follows and steps the run index
/// down. Returns the column just past the interruption sample.
#[allow(clippy::too_many_arguments)] // same explicit state as `decode_run_mode`
fn finish_interrupted_run(
    samples: &mut [u16],
    ctx_states: &mut [ContextState],
    run_state: &mut RunState,
    reader: &mut BitReader<'_>,
    row: usize,
    col: usize,
    runval: i32,
    p: &ScanDecodeParams,
) -> JlsResult<usize> {
    let w = p.w;
    let j_bits = j_for(run_state.run_index) as u8;
    let residual = if j_bits == 0 {
        0usize
    } else {
        reader.read_bits(j_bits).ok_or(JlsError::Truncated {
            context: "run residual length",
        })? as usize
    };

    // Validate before writing. An oversized residual must not spill into the next row,
    // nor index past the end of the buffer on the last row.
    let end = col + residual;
    if end >= w {
        return Err(JlsError::Truncated {
            context: "run residual overflows row",
        });
    }
    let line = row * w;
    samples[line + col..line + end].fill(runval as u16);

    // NOTE(review): ISO/IEC 14495-1 codes the interruption sample with the Golomb limit
    // LIMIT - J[RUNindex] - 1 (RUNindex before the decrement below), whereas `p.limit` is
    // used. Change together with the encoder.
    let term_sample =
        decode_run_termination_sample(samples, ctx_states, reader, row, end, runval, p)?;
    samples[line + end] = term_sample;
    decrement_run_index(run_state);
    Ok(end + 1)
}

/// Decodes the run-interruption sample at `(row, col)` that ends a run of `runval`.
///
/// Rb is the sample directly above, or `runval` on the first row. The context is
/// `run_termination_ctx(runval, rb)`. The error sign is `-1` when Rb is below the run
/// value and `+1` otherwise. The prediction is `runval` corrected by the context bias.
///
/// # Errors
/// Returns [`JlsError::Truncated`] if the bit stream ends inside the Golomb code.
fn decode_run_termination_sample(
    samples: &[u16],
    ctx_states: &mut [ContextState],
    reader: &mut BitReader<'_>,
    row: usize,
    col: usize,
    runval: i32,
    p: &ScanDecodeParams,
) -> JlsResult<u16> {
    // NOTE(review): ISO/IEC 14495-1 uses Rb = 0 on the first row. Its run-interruption
    // coding (RItype, the Nn counter, TEMP-adjusted error mapping, and predicting Rb when
    // |Ra - Rb| > NEAR) also differs from the regular-mode machinery reused here.
    // Confirm against the encoder before changing.
    let rb = if row > 0 {
        i32::from(samples[(row - 1) * p.w + col])
    } else {
        runval
    };
    let state = &mut ctx_states[run_termination_ctx(runval, rb)];
    let sign: i32 = if rb < runval { -1 } else { 1 };
    let corrected_px = (runval - sign * state.cx).clamp(0, p.max_val);
    let k = state.k.max(0);
    let e_mapped =
        decode_golomb_unsigned_limited(reader, k, p.limit, p.qbpp).ok_or(JlsError::Truncated {
            context: "run termination sample",
        })?;

    // Signed error for both coding modes: the context's bias statistics need its sign.
    let err_q = if p.near == 0 {
        unmap_error_lossless(e_mapped)
    } else {
        unmap_error_near(e_mapped, p.near, p.max_val)
    };

    // Dequantise and undo a modulo-RANGE wrap before clamping into [0, MAXVAL].
    let q_step = 2 * p.near + 1;
    let range = (p.max_val + 2 * p.near) / q_step + 1;
    let mut rx = corrected_px + err_q * sign * q_step;
    if rx < -p.near {
        rx += range * q_step;
    } else if rx > p.max_val + p.near {
        rx -= range * q_step;
    }

    update_context(state, err_q, p.near, p.reset, p.max_val);
    Ok(rx.clamp(0, p.max_val) as u16)
}
/// Decodes row `row` of one component, switching between regular and run mode.
///
/// The run state is reset via `reset_at_line_start` before the first sample. At each
/// position the gradients `Rd - Rb`, `Rb - Rc` and `Rc - Ra` are handed to
/// `enter_run_near`. When it accepts them, the run decoder consumes as many samples as
/// the stream describes and reports where to resume. Otherwise a single sample is
/// decoded in regular mode.
fn decode_row_with_run_mode(
    samples: &mut [u16],
    ctx_states: &mut [ContextState],
    run_state: &mut RunState,
    reader: &mut BitReader<'_>,
    row: usize,
    p: &ScanDecodeParams,
) -> JlsResult<()> {
    run_state.reset_at_line_start();
    let width = p.w;
    let mut x = 0;
    while x < width {
        let (ra, rb, rc, rd) = fetch_neighbours(samples, row, x, width);
        x = if enter_run_near(rd - rb, rb - rc, rc - ra, p.near) {
            decode_run_mode(samples, ctx_states, run_state, reader, row, x, ra, p)?
        } else {
            let value = decode_pixel_regular(ctx_states, reader, ra, rb, rc, rd, p)?;
            samples[row * width + x] = value;
            x + 1
        };
    }
    Ok(())
}
/// Decodes the entropy-coded segment `scan_data` into one sample plane per component.
///
/// The interleave mode `headers.scan.ilv` fixes the order in which the components'
/// data appear in the bit stream:
/// * `0` – every row of component 0, then every row of component 1, and so on;
/// * `1` – row by row, one row of each component in turn;
/// * `2` – sample by sample, one sample of each component in turn.
///
/// Each component keeps its own context table and run state.
///
/// # Errors
/// * [`JlsError::Truncated`] if the bit stream ends early or a run overflows a row.
/// * [`JlsError::Unsupported`] for an ILV value other than 0, 1 or 2, or when
///   `width * height` does not fit in `usize`.
fn decode_scan(scan_data: &[u8], headers: &JlsHeaders) -> JlsResult<DecodedImage> {
    let w = headers.frame.width as usize;
    let h = headers.frame.height as usize;
    let nc = headers.frame.num_components as usize;
    let max_val = headers.presets.max_val as i32;
    let params = ScanDecodeParams {
        max_val,
        near: headers.scan.near as i32,
        reset: headers.presets.reset as i32,
        limit: compute_limit(max_val),
        qbpp: compute_qbpp(max_val),
        t1: headers.presets.t1,
        t2: headers.presets.t2,
        t3: headers.presets.t3,
        w,
    };

    // Guard the plane size so hostile dimensions cannot wrap the allocation length on
    // targets where `usize` is narrower than the product.
    let plane_len = w.checked_mul(h).ok_or_else(|| {
        JlsError::Unsupported(format!(
            "image dimensions {w}x{h} overflow the sample buffer size"
        ))
    })?;

    let mut reader = BitReader::new(scan_data);
    let mut all_samples: Vec<Vec<u16>> = (0..nc).map(|_| vec![0u16; plane_len]).collect();
    let mut all_ctx: Vec<Vec<ContextState>> = (0..nc)
        .map(|_| vec![ContextState::default(); NUM_TOTAL_CONTEXTS])
        .collect();
    let mut run_states: Vec<RunState> = (0..nc).map(|_| RunState::new()).collect();

    match headers.scan.ilv {
        0 => {
            for comp in 0..nc {
                for row in 0..h {
                    decode_row_with_run_mode(
                        &mut all_samples[comp],
                        &mut all_ctx[comp],
                        &mut run_states[comp],
                        &mut reader,
                        row,
                        &params,
                    )?;
                }
            }
        }
        1 => {
            for row in 0..h {
                for comp in 0..nc {
                    decode_row_with_run_mode(
                        &mut all_samples[comp],
                        &mut all_ctx[comp],
                        &mut run_states[comp],
                        &mut reader,
                        row,
                        &params,
                    )?;
                }
            }
        }
        2 => {
            // NOTE(review): this path never enters run mode; ISO/IEC 14495-1 also uses
            // run mode in sample-interleaved scans — confirm against the encoder.
            for row in 0..h {
                for col in 0..w {
                    for comp in 0..nc {
                        let (a, b, c, d) = fetch_neighbours(&all_samples[comp], row, col, w);
                        let rx = decode_pixel_regular(
                            &mut all_ctx[comp],
                            &mut reader,
                            a,
                            b,
                            c,
                            d,
                            &params,
                        )?;
                        all_samples[comp][row * w + col] = rx;
                    }
                }
            }
        }
        other => {
            return Err(JlsError::Unsupported(format!(
                "ILV mode {other} is not defined in ISO 14495-1"
            )));
        }
    }

    Ok(DecodedImage {
        width: w as u32,
        height: h as u32,
        num_components: nc as u8,
        precision: headers.frame.precision,
        samples: all_samples,
    })
}
pub use super::golomb::map_error_near as golomb_map_error_near;