use super::mq_coder::{MqDecoder, MQ_NUM_CONTEXTS};
use super::{Jp2Error, Jp2Result};
/// Significance-coding contexts occupy labels `0..=8`.
const NUM_SIG_CONTEXTS: usize = 9;
/// Sign-coding contexts start directly after the significance contexts.
const SIGN_CTX_BASE: usize = NUM_SIG_CONTEXTS;
/// Number of labels reserved for sign coding (`9..=13`).
const _NUM_SIGN_CONTEXTS: usize = 5;
/// Magnitude-refinement contexts (three labels: base, base + 1, base + 2) follow the sign block.
const MR_CTX_BASE: usize = SIGN_CTX_BASE + _NUM_SIGN_CONTEXTS;
/// Uniform context used for the two run-position bits after a run-length hit.
const UNI_CTX: usize = MR_CTX_BASE + 3;
/// Run-length context used by the cleanup pass.
const RLC_CTX: usize = UNI_CTX + 1;
/// A decoded code-block: signed integer coefficients plus the block dimensions.
#[derive(Debug, Clone)]
pub struct CodeBlock {
    /// Signed coefficients; `decode_code_block` stores them row-major (`row * width + col`).
    pub coeffs: Vec<i32>,
    /// Number of columns.
    pub width: usize,
    /// Number of rows.
    pub height: usize,
}
impl CodeBlock {
    /// Converts the integer coefficients to `f64`, applying the quantizer step.
    ///
    /// When `step_size` equals 1.0 (within `1e-10`) the coefficients are returned unchanged
    /// and `decoded_bit_planes` is ignored. Otherwise every coefficient is multiplied by
    /// `step_size * 2^-decoded_bit_planes`.
    ///
    /// NOTE(review): the unit-step shortcut ignores `decoded_bit_planes`, so results jump
    /// between `step_size == 1.0` and a step size marginally different from 1.0 — confirm
    /// this asymmetry is intended.
    #[must_use]
    pub fn dequantize(&self, step_size: f64, decoded_bit_planes: usize) -> Vec<f64> {
        if (step_size - 1.0).abs() < 1e-10 {
            return self.coeffs.iter().map(|&v| f64::from(v)).collect();
        }
        // Saturate rather than truncate: `as i32` could wrap a huge count to a negative
        // exponent, while 0.5^i32::MAX already underflows to 0.0.
        let exponent = i32::try_from(decoded_bit_planes).unwrap_or(i32::MAX);
        let scale = step_size * 0.5f64.powi(exponent);
        // Scaling the signed value directly is exactly equivalent to scaling the magnitude
        // and restoring the sign, but cannot overflow the way `i32::MIN.abs()` does.
        self.coeffs.iter().map(|&v| f64::from(v) * scale).collect()
    }
}
/// Per-coefficient decoder state shared by the coding passes.
#[derive(Clone, Copy, Default)]
struct CoeffState {
    /// Set once the coefficient has been decoded as significant.
    significant: bool,
    /// Marks a coefficient already handled by a pass since `decode_code_block` last
    /// cleared the flags.
    visited: bool,
    /// Sign flag: `0` means non-negative, any other value means negative.
    sign: i32,
    /// Magnitude bits accumulated so far (bit `p` is set by bit plane `p`).
    magnitude: i32,
}
/// Returns the significance-coding context label (`0..=8`) for the coefficient at
/// (`col`, `row`).
///
/// The label depends on how many of the four horizontal/vertical neighbours are
/// significant (combined) and how many of the four diagonal neighbours are significant.
/// Neighbours outside the `width` x `height` block count as insignificant.
///
/// NOTE(review): ITU-T T.800 Table D.1 selects these labels from the H, V and D counts
/// separately and per subband orientation; this mapping uses the combined H+V count —
/// confirm it matches the encoder in use.
fn significance_context(
    state: &[CoeffState],
    col: usize,
    row: usize,
    width: usize,
    height: usize,
) -> usize {
    let is_sig = |c: usize, r: usize| state[r * width + c].significant;

    // Neighbour coordinates, `None` when they fall outside the block.
    let west = col.checked_sub(1);
    let east = Some(col + 1).filter(|&c| c < width);
    let north = row.checked_sub(1);
    let south = Some(row + 1).filter(|&r| r < height);

    let horizontal = [west, east]
        .iter()
        .flatten()
        .filter(|&&c| is_sig(c, row))
        .count();
    let vertical = [north, south]
        .iter()
        .flatten()
        .filter(|&&r| is_sig(col, r))
        .count();
    let diagonal = [(west, north), (east, north), (west, south), (east, south)]
        .iter()
        .filter_map(|&(c, r)| Some((c?, r?)))
        .filter(|&(c, r)| is_sig(c, r))
        .count();

    match (horizontal + vertical, diagonal) {
        (0, 0) => 0,
        (0, 1) => 1,
        (0, _) => 2,
        (1, 0) => 3,
        (1, 1) => 4,
        (1, _) => 5,
        (2, 0) => 6,
        (2, _) => 7,
        _ => 8,
    }
}
/// Returns the sign-coding context (`SIGN_CTX_BASE + 0..=2`) for the coefficient at
/// (`col`, `row`), together with the bit that callers XOR into the decoded sign bit.
///
/// Each significant neighbour contributes `+1` when non-negative and `-1` when negative.
/// The horizontal pair and the vertical pair are each summed and clamped to `-1..=1`, and
/// the resulting pair selects the context offset and XOR bit.
///
/// NOTE(review): ITU-T T.800 Table D.3 uses five sign contexts (9..=13); this mapping only
/// produces three — confirm it matches the encoder in use.
fn sign_context(
    state: &[CoeffState],
    col: usize,
    row: usize,
    width: usize,
    height: usize,
) -> (usize, u8) {
    let centre = row * width + col;
    // Contribution of one (possibly absent) neighbour.
    let contribution = |neighbour: Option<usize>| -> i32 {
        match neighbour.map(|i| state[i]) {
            Some(n) if n.significant => {
                if n.sign == 0 {
                    1
                } else {
                    -1
                }
            }
            _ => 0,
        }
    };

    // Lazily computed so `centre - 1` / `centre - width` are only evaluated when in range.
    let west = (col > 0).then(|| centre - 1);
    let east = (col + 1 < width).then(|| centre + 1);
    let north = (row > 0).then(|| centre - width);
    let south = (row + 1 < height).then(|| centre + width);

    let horizontal = (contribution(west) + contribution(east)).signum();
    let vertical = (contribution(north) + contribution(south)).signum();

    let (offset, flip): (usize, u8) = match (horizontal, vertical) {
        (0, 0) => (2, 0),
        (1, -1) => (1, 0),
        (-1, 1) => (1, 1),
        (-1, _) | (0, -1) => (0, 1),
        _ => (0, 0),
    };
    (SIGN_CTX_BASE + offset, flip)
}
/// Returns the magnitude-refinement context for the coefficient at (`col`, `row`).
///
/// When `first_mr` is false the context is always `MR_CTX_BASE + 2`. Otherwise it is
/// `MR_CTX_BASE + 1` if any of the four horizontal/vertical neighbours is significant and
/// `MR_CTX_BASE` if none is.
///
/// NOTE(review): ITU-T T.800 Table D.4 also counts diagonal neighbours for this decision;
/// only the four direct neighbours are checked here — confirm against the encoder.
fn mr_context(
    state: &[CoeffState],
    col: usize,
    row: usize,
    width: usize,
    height: usize,
    first_mr: bool,
) -> usize {
    if !first_mr {
        return MR_CTX_BASE + 2;
    }
    let is_sig = |i: usize| state[i].significant;
    let centre = row * width + col;
    let any_sig_neighbour = (col > 0 && is_sig(centre - 1))
        || (col + 1 < width && is_sig(centre + 1))
        || (row > 0 && is_sig(centre - width))
        || (row + 1 < height && is_sig(centre + width));
    MR_CTX_BASE + usize::from(any_sig_neighbour)
}
/// Significance-propagation pass for one bit plane.
///
/// Visits coefficients in raster order. For each coefficient that is insignificant, not
/// yet visited, and has a non-zero significance context, it decodes a significance bit in
/// that context and marks the coefficient visited. A coefficient that becomes significant
/// gets bit `bit_plane` of its magnitude set and its sign decoded. Coefficients with
/// context 0 are skipped here and left unvisited.
///
/// # Errors
///
/// Propagates any error returned by the MQ decoder.
fn significance_propagation_pass(
    mq: &mut MqDecoder,
    state: &mut [CoeffState],
    width: usize,
    height: usize,
    bit_plane: u8,
    _cx: &mut [u8; MQ_NUM_CONTEXTS],
) -> Jp2Result<()> {
    let positions = (0..height).flat_map(|row| (0..width).map(move |col| (row, col)));
    for (row, col) in positions {
        let idx = row * width + col;
        let candidate = !state[idx].significant && !state[idx].visited;
        if !candidate {
            continue;
        }
        let ctx = significance_context(state, col, row, width, height);
        if ctx == 0 {
            continue;
        }
        let sig_bit = mq.decode_bit(ctx)?;
        state[idx].visited = true;
        if sig_bit != 1 {
            continue;
        }
        state[idx].significant = true;
        state[idx].magnitude |= 1 << bit_plane;
        let (sign_ctx, xor_bit) = sign_context(state, col, row, width, height);
        state[idx].sign = i32::from(mq.decode_bit(sign_ctx)? ^ xor_bit);
    }
    Ok(())
}
/// Magnitude-refinement pass for one bit plane.
///
/// For every coefficient that is significant and not marked visited, decodes one
/// refinement bit in the context chosen by `mr_context`. When that bit is 1, it sets bit
/// `bit_plane` of the coefficient's magnitude.
///
/// # Errors
///
/// Propagates any error returned by the MQ decoder.
fn magnitude_refinement_pass(
    mq: &mut MqDecoder,
    state: &mut [CoeffState],
    width: usize,
    height: usize,
    bit_plane: u8,
    first_mr: bool,
    _cx: &mut [u8; MQ_NUM_CONTEXTS],
) -> Jp2Result<()> {
    for row in 0..height {
        for col in 0..width {
            let idx = row * width + col;
            let coeff = state[idx];
            if coeff.significant && !coeff.visited {
                let ctx = mr_context(state, col, row, width, height, first_mr);
                if mq.decode_bit(ctx)? == 1 {
                    state[idx].magnitude |= 1 << bit_plane;
                }
            }
        }
    }
    Ok(())
}
/// Cleanup pass for one bit plane.
///
/// Visits coefficients in raster order and codes every one that is still insignificant and
/// has not already been coded in this bit plane (`visited`). Run-length coding applies when
/// the current coefficient and the three below it are all insignificant, unvisited and
/// have significance context 0. In that case one `RLC_CTX` bit says whether any of the four
/// becomes significant. If one does, two `UNI_CTX` bits give the offset of the first
/// significant coefficient, and the coefficients after it are coded individually.
///
/// # Errors
///
/// Propagates any error returned by the MQ decoder.
fn cleanup_pass(
    mq: &mut MqDecoder,
    state: &mut [CoeffState],
    width: usize,
    height: usize,
    bit_plane: u8,
    _cx: &mut [u8; MQ_NUM_CONTEXTS],
) -> Jp2Result<()> {
    for row in 0..height {
        for col in 0..width {
            let idx = row * width + col;
            // Visited coefficients were coded earlier in this plane. Significant ones are
            // handled by magnitude refinement; decoding a significance bit for them here
            // would consume a symbol and could overwrite their sign.
            if state[idx].visited || state[idx].significant {
                continue;
            }
            let can_rlc = row + 3 < height
                && (0..4).all(|dr| {
                    let r = row + dr;
                    let i = r * width + col;
                    !state[i].significant
                        && !state[i].visited
                        && significance_context(state, col, r, width, height) == 0
                });
            if !can_rlc {
                cleanup_decode_coefficient(mq, state, col, row, width, height, bit_plane)?;
                continue;
            }
            if mq.decode_bit(RLC_CTX)? == 0 {
                // None of the four coefficients becomes significant in this plane.
                for dr in 0..4 {
                    state[(row + dr) * width + col].visited = true;
                }
                continue;
            }
            // Two uniform-context bits give the offset (0..=3) of the first significant one.
            let p0 = mq.decode_bit(UNI_CTX)?;
            let p1 = mq.decode_bit(UNI_CTX)?;
            let first_one = usize::from(p0) * 2 + usize::from(p1);
            for dr in 0..first_one {
                state[(row + dr) * width + col].visited = true;
            }
            let sig_row = row + first_one;
            state[sig_row * width + col].visited = true;
            // Its significance is implied by the run-length code; only the sign is decoded.
            cleanup_set_significant(mq, state, col, sig_row, width, height, bit_plane)?;
            for r in (sig_row + 1)..(row + 4) {
                let i = r * width + col;
                if !state[i].visited && !state[i].significant {
                    cleanup_decode_coefficient(mq, state, col, r, width, height, bit_plane)?;
                }
            }
        }
    }
    Ok(())
}

/// Cleanup-pass helper: decodes the significance bit of the coefficient at (`col`, `row`)
/// in its significance context, marks it visited, and records it as significant (decoding
/// its sign) when the bit is 1.
fn cleanup_decode_coefficient(
    mq: &mut MqDecoder,
    state: &mut [CoeffState],
    col: usize,
    row: usize,
    width: usize,
    height: usize,
    bit_plane: u8,
) -> Jp2Result<()> {
    let sig_ctx = significance_context(state, col, row, width, height);
    let sig_bit = mq.decode_bit(sig_ctx)?;
    state[row * width + col].visited = true;
    if sig_bit == 1 {
        cleanup_set_significant(mq, state, col, row, width, height, bit_plane)?;
    }
    Ok(())
}

/// Cleanup-pass helper: marks the coefficient at (`col`, `row`) significant, sets bit
/// `bit_plane` of its magnitude and decodes its sign in the neighbour-dependent sign context.
fn cleanup_set_significant(
    mq: &mut MqDecoder,
    state: &mut [CoeffState],
    col: usize,
    row: usize,
    width: usize,
    height: usize,
    bit_plane: u8,
) -> Jp2Result<()> {
    let idx = row * width + col;
    state[idx].significant = true;
    state[idx].magnitude |= 1 << bit_plane;
    let (sign_ctx, xor_bit) = sign_context(state, col, row, width, height);
    let sign_coded = mq.decode_bit(sign_ctx)?;
    state[idx].sign = i32::from(sign_coded ^ xor_bit);
    Ok(())
}
/// Decodes the bit-plane-coded coefficients of one code-block from `data`.
///
/// The most significant plane (`num_bit_planes - 1`) is decoded with a cleanup pass only.
/// Every lower plane runs the significance-propagation, magnitude-refinement and cleanup
/// passes in that order. Coefficients are returned row-major (`row * width + col`), and are
/// negated when the decoded sign flag is set. With `num_bit_planes == 0` an all-zero block
/// is returned without reading `data`.
///
/// # Errors
///
/// Returns [`Jp2Error::InternalError`] in these cases:
/// - `width` or `height` is zero;
/// - `width * height` overflows `usize`;
/// - `num_bit_planes` exceeds 31 (magnitudes are accumulated in an `i32`).
///
/// Errors from the MQ decoder are propagated unchanged.
pub fn decode_code_block(
    data: &[u8],
    width: usize,
    height: usize,
    num_bit_planes: u8,
) -> Jp2Result<CodeBlock> {
    // Bit plane `p` sets bit `1 << p` of an `i32` magnitude; planes 0..=30 keep it positive
    // and avoid a shift overflow.
    const MAX_BIT_PLANES: u8 = 31;

    if width == 0 || height == 0 {
        return Err(Jp2Error::InternalError(
            "code-block dimensions must be non-zero".to_string(),
        ));
    }
    let len = width.checked_mul(height).ok_or_else(|| {
        Jp2Error::InternalError("code-block dimensions overflow usize".to_string())
    })?;
    if num_bit_planes == 0 {
        return Ok(CodeBlock {
            coeffs: vec![0; len],
            width,
            height,
        });
    }
    if num_bit_planes > MAX_BIT_PLANES {
        return Err(Jp2Error::InternalError(format!(
            "code-block has {} bit planes; at most {} are supported",
            num_bit_planes, MAX_BIT_PLANES
        )));
    }

    let mut mq = MqDecoder::new(data);
    let mut cx = [0u8; MQ_NUM_CONTEXTS];
    let mut state = vec![CoeffState::default(); len];
    for bp_idx in 0..num_bit_planes {
        let bit_plane = num_bit_planes - 1 - bp_idx;
        // `visited` records which coefficients have already been coded in this bit plane.
        // It is cleared once per plane and must persist across the passes. Magnitude
        // refinement has to skip coefficients that only became significant in this plane's
        // significance-propagation pass. Cleanup must not re-code coefficients that
        // significance propagation already handled.
        for s in state.iter_mut() {
            s.visited = false;
        }
        if bp_idx > 0 {
            significance_propagation_pass(&mut mq, &mut state, width, height, bit_plane, &mut cx)?;
            // NOTE(review): JPEG 2000 chooses the first-refinement context per coefficient
            // (the first time *that* coefficient is refined); here it is chosen per plane.
            let first_mr = bp_idx == 1;
            magnitude_refinement_pass(
                &mut mq, &mut state, width, height, bit_plane, first_mr, &mut cx,
            )?;
        }
        cleanup_pass(&mut mq, &mut state, width, height, bit_plane, &mut cx)?;
    }

    let coeffs: Vec<i32> = state
        .iter()
        .map(|s| if s.sign == 0 { s.magnitude } else { -s.magnitude })
        .collect();
    Ok(CodeBlock {
        coeffs,
        width,
        height,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `width * height` grid of insignificant coefficients.
    fn blank(width: usize, height: usize) -> Vec<CoeffState> {
        vec![CoeffState::default(); width * height]
    }

    #[test]
    fn decode_zero_bit_planes_returns_zeros() {
        let data = vec![0u8; 16];
        let block = decode_code_block(&data, 4, 4, 0).expect("decode");
        assert_eq!(block.coeffs.len(), 16);
        assert!(block.coeffs.iter().all(|&c| c == 0));
    }

    #[test]
    fn decode_code_block_runs_without_panic() {
        // Arbitrary bytes are not a real codestream, so an `Err` from the MQ decoder is
        // acceptable. The test only requires that decoding does not panic and that a
        // successful result has the requested shape.
        let data: Vec<u8> = (0u8..=255).collect();
        if let Ok(block) = decode_code_block(&data, 8, 8, 4) {
            assert_eq!(block.width, 8);
            assert_eq!(block.height, 8);
            assert_eq!(block.coeffs.len(), 64);
        }
    }

    #[test]
    fn zero_dimension_returns_error() {
        let data = vec![0u8; 8];
        assert!(decode_code_block(&data, 0, 4, 1).is_err());
        assert!(decode_code_block(&data, 4, 0, 1).is_err());
    }

    #[test]
    fn significance_context_zero_for_isolated() {
        let state = blank(4, 4);
        let ctx = significance_context(&state, 1, 1, 4, 4);
        assert_eq!(ctx, 0);
    }

    #[test]
    fn significance_context_counts_neighbours() {
        // Pins the current (H+V, D) mapping.
        let mut st = blank(3, 3);
        st[3].significant = true; // left of (1,1)
        assert_eq!(significance_context(&st, 1, 1, 3, 3), 3);
        st[0].significant = true; // diagonal
        assert_eq!(significance_context(&st, 1, 1, 3, 3), 4);
        st[1].significant = true; // above
        assert_eq!(significance_context(&st, 1, 1, 3, 3), 7);
        st[5].significant = true; // right
        assert_eq!(significance_context(&st, 1, 1, 3, 3), 8);
    }

    #[test]
    fn sign_context_follows_neighbour_signs() {
        let mut st = blank(3, 3);
        assert_eq!(sign_context(&st, 1, 1, 3, 3), (SIGN_CTX_BASE + 2, 0));
        st[3].significant = true; // positive left neighbour
        assert_eq!(sign_context(&st, 1, 1, 3, 3), (SIGN_CTX_BASE, 0));
        st[3].sign = 1; // now negative
        assert_eq!(sign_context(&st, 1, 1, 3, 3), (SIGN_CTX_BASE, 1));
        st[1].significant = true; // positive neighbour above
        assert_eq!(sign_context(&st, 1, 1, 3, 3), (SIGN_CTX_BASE + 1, 1));
    }

    #[test]
    fn mr_context_selects_first_and_later_refinement() {
        let mut st = blank(3, 3);
        assert_eq!(mr_context(&st, 1, 1, 3, 3, false), MR_CTX_BASE + 2);
        assert_eq!(mr_context(&st, 1, 1, 3, 3, true), MR_CTX_BASE);
        st[7].significant = true; // below
        assert_eq!(mr_context(&st, 1, 1, 3, 3, true), MR_CTX_BASE + 1);
    }

    #[test]
    fn dequantize_scales_coefficients() {
        let block = CodeBlock {
            coeffs: vec![5, -3, 0],
            width: 3,
            height: 1,
        };
        assert_eq!(block.dequantize(1.0, 7), vec![5.0, -3.0, 0.0]);
        // scale = 3.0 * 0.5 = 1.5
        assert_eq!(block.dequantize(3.0, 1), vec![7.5, -4.5, 0.0]);
    }

    #[test]
    fn num_sig_contexts_is_nine() {
        assert_eq!(NUM_SIG_CONTEXTS, 9);
    }
}