use alloc::vec;
use alloc::vec::Vec;
use super::arithmetic_encoder::{ArithmeticEncoder, ArithmeticEncoderContext};
use super::build::SubBandType;
use super::codestream::CodeBlockStyle;
use crate::writer::BitWriter;
// Per-coefficient state flags stored in the padded `states` buffer (one `u8` per sample).

/// The coefficient has become significant (set by `set_significant`).
const SIGNIFICANT: u8 = 1 << 7;
/// At least one magnitude-refinement bit has been coded for the coefficient; selects
/// refinement context 16 in `magnitude_refinement_ctx`.
const MAGNITUDE_REFINED: u8 = 1 << 6;
/// The coefficient was coded by the significance-propagation pass of the current bit-plane,
/// so the refinement and cleanup passes of that bit-plane skip it. Cleared per bit-plane by
/// `clear_coded_in_current_pass`.
const CODED_IN_CURRENT_PASS: u8 = 1 << 5;
/// The coefficient is negative (set once in `prepare_padded_coefficients`).
const NEGATIVE: u8 = 1 << 4;
/// Result of the single-segment classic Tier-1 coder (`encode_code_block_with_style`).
#[derive(Debug)]
pub(crate) struct EncodedCodeBlock {
    /// Coded bytes; empty when every coefficient of the block is zero.
    pub(crate) data: Vec<u8>,
    /// Number of coding passes contained in `data` (0 for an all-zero block).
    pub(crate) num_coding_passes: u8,
    /// Leading all-zero bit-planes that were not coded.
    pub(crate) num_zero_bitplanes: u8,
    /// High-throughput cleanup length; the classic coder in this file always stores 0.
    pub(crate) ht_cleanup_length: u32,
    /// High-throughput refinement length; the classic coder in this file always stores 0.
    pub(crate) ht_refinement_length: u32,
}
/// One codeword segment inside `EncodedCodeBlockWithSegments::data`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct EncodedCodeBlockSegment {
    /// Byte offset of the segment within the block's `data`.
    pub(crate) data_offset: u32,
    /// Length of the segment in bytes.
    pub(crate) data_length: u32,
    /// First coding pass contained in the segment (inclusive).
    pub(crate) start_coding_pass: u8,
    /// Coding pass at which the segment ends (exclusive).
    pub(crate) end_coding_pass: u8,
    /// Estimated distortion reduction contributed by the segment's passes; never below
    /// `f64::EPSILON` (see `segment_distortion_delta`).
    pub(crate) distortion_delta: f64,
    /// `true` for MQ-coded segments, `false` for raw (bypass) segments.
    pub(crate) use_arithmetic: bool,
}
/// Classic Tier-1 output split into codeword segments that share one byte buffer.
#[derive(Debug)]
pub(crate) struct EncodedCodeBlockWithSegments {
    /// Concatenation of all segment payloads, in segment order.
    pub(crate) data: Vec<u8>,
    /// Segment descriptors indexing into `data`.
    pub(crate) segments: Vec<EncodedCodeBlockSegment>,
    /// Total number of coding passes over all segments.
    pub(crate) num_coding_passes: u8,
    /// Leading all-zero bit-planes that were not coded.
    pub(crate) num_zero_bitplanes: u8,
}
/// Describes one segment of a pre-generated Tier-1 token stream consumed by
/// `pack_classic_selective_bypass_tier1_tokens`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ClassicTier1TokenSegment {
    /// Bit offset of the segment's first token in the token buffer.
    pub(crate) token_bit_offset: u32,
    /// Number of token bits; for MQ segments this must be a multiple of 6.
    pub(crate) token_bit_count: u32,
    /// First coding pass of the segment (inclusive).
    pub(crate) start_coding_pass: u8,
    /// End coding pass of the segment (exclusive).
    pub(crate) end_coding_pass: u8,
    /// `true`: 6-bit MQ tokens (decision in bit 5, context label in bits 0..=4);
    /// `false`: raw bits copied verbatim.
    pub(crate) use_arithmetic: bool,
}
/// Re-encodes pre-generated classic Tier-1 decision tokens into codeword segments.
///
/// `token_bytes` is read MSB-first. Each entry of `token_segments` names a bit range of it:
/// - MQ segments (`use_arithmetic`) hold 6-bit tokens: bit 5 is the decision, bits 0..=4 the
///   context label (must be below 19). All MQ segments share one context set, which is reset
///   only once before the first segment.
/// - Raw segments copy `token_bit_count` bits verbatim through a `BitWriter`.
///
/// Distortion deltas cannot be estimated from tokens, so each segment gets `f64::EPSILON`.
///
/// # Errors
/// Fails on an inverted pass range, a pass range beyond `number_of_coding_passes`, an MQ
/// segment not aligned to 6 bits, an out-of-range context label, or a read past the buffer.
pub(crate) fn pack_classic_selective_bypass_tier1_tokens(
    token_bytes: &[u8],
    token_segments: &[ClassicTier1TokenSegment],
    number_of_coding_passes: u8,
    missing_bit_planes: u8,
) -> Result<EncodedCodeBlockWithSegments, &'static str> {
    let mut contexts = [ArithmeticEncoderContext::default(); 19];
    reset_contexts(&mut contexts);
    let mut reader = ClassicTier1TokenReader::new(token_bytes);
    let mut data = Vec::new();
    let mut segments = Vec::with_capacity(token_segments.len());

    for token_segment in token_segments {
        let first_pass = token_segment.start_coding_pass;
        let end_pass = token_segment.end_coding_pass;
        if first_pass > end_pass {
            return Err("classic Tier-1 token segment pass range is invalid");
        }
        if end_pass > number_of_coding_passes {
            return Err("classic Tier-1 token segment exceeds coding passes");
        }
        let bit_offset = usize::try_from(token_segment.token_bit_offset)
            .map_err(|_| "classic Tier-1 token bit offset exceeds usize")?;
        let bit_count = usize::try_from(token_segment.token_bit_count)
            .map_err(|_| "classic Tier-1 token bit count exceeds usize")?;
        reader.seek(bit_offset)?;

        let payload = if token_segment.use_arithmetic {
            if bit_count % 6 != 0 {
                return Err("classic Tier-1 MQ token segment is not aligned to 6-bit symbols");
            }
            let symbols = bit_count / 6;
            let mut mq = ArithmeticEncoder::with_capacity(symbols / 16 + 32);
            for _ in 0..symbols {
                let token = reader.read_bits(6)?;
                let label = (token & 0x1F) as usize;
                let context = contexts
                    .get_mut(label)
                    .ok_or("classic Tier-1 MQ token context is out of range")?;
                mq.encode((token >> 5) & 1, context);
            }
            mq.finish()
        } else {
            let mut raw = BitWriter::new();
            for _ in 0..bit_count {
                raw.write_bit(reader.read_bits(1)?);
            }
            raw.finish()
        };

        push_segment(
            &mut data,
            &mut segments,
            first_pass,
            end_pass,
            payload,
            f64::EPSILON,
            token_segment.use_arithmetic,
        );
    }

    Ok(EncodedCodeBlockWithSegments {
        data,
        segments,
        num_coding_passes: number_of_coding_passes,
        num_zero_bitplanes: missing_bit_planes,
    })
}
/// MSB-first bit cursor over a packed Tier-1 token buffer.
struct ClassicTier1TokenReader<'a> {
    /// Source buffer; stream bit 0 is the most significant bit of `bytes[0]`.
    bytes: &'a [u8],
    /// Absolute index of the next bit to read.
    bit_pos: usize,
}

impl<'a> ClassicTier1TokenReader<'a> {
    fn new(bytes: &'a [u8]) -> Self {
        Self { bytes, bit_pos: 0 }
    }

    /// Number of bits available in the buffer (saturating on overflow).
    fn bit_len(&self) -> usize {
        self.bytes.len().saturating_mul(8)
    }

    /// Moves the cursor to `bit_pos`. The end of the buffer itself is a valid position.
    fn seek(&mut self, bit_pos: usize) -> Result<(), &'static str> {
        match bit_pos <= self.bit_len() {
            true => {
                self.bit_pos = bit_pos;
                Ok(())
            }
            false => Err("classic Tier-1 token offset exceeds token buffer"),
        }
    }

    /// Reads `count` bits MSB-first and returns them right-aligned.
    ///
    /// Leaves the cursor unchanged on error. Only the low 32 bits survive if `count > 32`.
    fn read_bits(&mut self, count: u8) -> Result<u32, &'static str> {
        let Some(end) = self.bit_pos.checked_add(usize::from(count)) else {
            return Err("classic Tier-1 token bit range overflows");
        };
        if end > self.bit_len() {
            return Err("classic Tier-1 token read exceeds token buffer");
        }
        let bytes = self.bytes;
        let value = (self.bit_pos..end).fold(0u32, |acc, pos| {
            let bit = (bytes[pos / 8] >> (7 - pos % 8)) & 1;
            (acc << 1) | u32::from(bit)
        });
        self.bit_pos = end;
        Ok(value)
    }
}
/// Zero-coding context labels (0..=8) for LL and LH sub-bands.
///
/// Indexed by the 8-bit neighbour-significance mask written by `set_significant`:
/// bit 0 below, 1 below-right, 2 right, 3 below-left, 4 left, 5 above-right, 6 above,
/// 7 above-left. Presumably transcribed from ITU-T T.800 Table D.1 — verify against it
/// before editing.
#[rustfmt::skip]
const ZERO_CTX_LL_LH: [u8; 256] = [
    0, 3, 1, 3, 5, 7, 6, 7, 1, 3, 2, 3, 6, 7, 6, 7, 5, 7, 6, 7, 8, 8, 8, 8, 6,
    7, 6, 7, 8, 8, 8, 8, 1, 3, 2, 3, 6, 7, 6, 7, 2, 3, 2, 3, 6, 7, 6, 7, 6, 7,
    6, 7, 8, 8, 8, 8, 6, 7, 6, 7, 8, 8, 8, 8, 3, 4, 3, 4, 7, 7, 7, 7, 3, 4, 3,
    4, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 7, 7, 7, 7, 8, 8, 8, 8, 3, 4, 3, 4,
    7, 7, 7, 7, 3, 4, 3, 4, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 7, 7, 7, 7, 8,
    8, 8, 8, 1, 3, 2, 3, 6, 7, 6, 7, 2, 3, 2, 3, 6, 7, 6, 7, 6, 7, 6, 7, 8, 8,
    8, 8, 6, 7, 6, 7, 8, 8, 8, 8, 2, 3, 2, 3, 6, 7, 6, 7, 2, 3, 2, 3, 6, 7, 6,
    7, 6, 7, 6, 7, 8, 8, 8, 8, 6, 7, 6, 7, 8, 8, 8, 8, 3, 4, 3, 4, 7, 7, 7, 7,
    3, 4, 3, 4, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 7, 7, 7, 7, 8, 8, 8, 8, 3,
    4, 3, 4, 7, 7, 7, 7, 3, 4, 3, 4, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 7, 7,
    7, 7, 8, 8, 8, 8,
];
/// Zero-coding context labels (0..=8) for HL sub-bands.
///
/// Same index layout as `ZERO_CTX_LL_LH` (neighbour-significance mask from
/// `set_significant`). Presumably the HL column of ITU-T T.800 Table D.1 — verify before
/// editing.
#[rustfmt::skip]
const ZERO_CTX_HL: [u8; 256] = [
    0, 5, 1, 6, 3, 7, 3, 7, 1, 6, 2, 6, 3, 7, 3, 7, 3, 7, 3, 7, 4, 7, 4, 7, 3,
    7, 3, 7, 4, 7, 4, 7, 1, 6, 2, 6, 3, 7, 3, 7, 2, 6, 2, 6, 3, 7, 3, 7, 3, 7,
    3, 7, 4, 7, 4, 7, 3, 7, 3, 7, 4, 7, 4, 7, 5, 8, 6, 8, 7, 8, 7, 8, 6, 8, 6,
    8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 6, 8, 6, 8,
    7, 8, 7, 8, 6, 8, 6, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7,
    8, 7, 8, 1, 6, 2, 6, 3, 7, 3, 7, 2, 6, 2, 6, 3, 7, 3, 7, 3, 7, 3, 7, 4, 7,
    4, 7, 3, 7, 3, 7, 4, 7, 4, 7, 2, 6, 2, 6, 3, 7, 3, 7, 2, 6, 2, 6, 3, 7, 3,
    7, 3, 7, 3, 7, 4, 7, 4, 7, 3, 7, 3, 7, 4, 7, 4, 7, 6, 8, 6, 8, 7, 8, 7, 8,
    6, 8, 6, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 6,
    8, 6, 8, 7, 8, 7, 8, 6, 8, 6, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8, 7, 8,
    7, 8, 7, 8, 7, 8,
];
/// Zero-coding context labels (0..=8) for HH sub-bands.
///
/// Same index layout as `ZERO_CTX_LL_LH` (neighbour-significance mask from
/// `set_significant`). Presumably the HH column of ITU-T T.800 Table D.1 — verify before
/// editing.
#[rustfmt::skip]
const ZERO_CTX_HH: [u8; 256] = [
    0, 1, 3, 4, 1, 2, 4, 5, 3, 4, 6, 7, 4, 5, 7, 7, 1, 2, 4, 5, 2, 2, 5, 5, 4,
    5, 7, 7, 5, 5, 7, 7, 3, 4, 6, 7, 4, 5, 7, 7, 6, 7, 8, 8, 7, 7, 8, 8, 4, 5,
    7, 7, 5, 5, 7, 7, 7, 7, 8, 8, 7, 7, 8, 8, 1, 2, 4, 5, 2, 2, 5, 5, 4, 5, 7,
    7, 5, 5, 7, 7, 2, 2, 5, 5, 2, 2, 5, 5, 5, 5, 7, 7, 5, 5, 7, 7, 4, 5, 7, 7,
    5, 5, 7, 7, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 7, 7, 5, 5, 7, 7, 7, 7, 8, 8, 7,
    7, 8, 8, 3, 4, 6, 7, 4, 5, 7, 7, 6, 7, 8, 8, 7, 7, 8, 8, 4, 5, 7, 7, 5, 5,
    7, 7, 7, 7, 8, 8, 7, 7, 8, 8, 6, 7, 8, 8, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8,
    8, 7, 7, 8, 8, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 4, 5, 7, 7, 5, 5, 7, 7,
    7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 7, 7, 5, 5, 7, 7, 7, 7, 8, 8, 7, 7, 8, 8, 7,
    7, 8, 8, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 7, 7, 8, 8, 7, 7, 8, 8, 8, 8,
    8, 8, 8, 8, 8, 8,
];
/// Sign-coding lookup: `(context label, XOR bit)` indexed by the `merged` byte built in
/// `encode_sign`.
///
/// For each horizontal/vertical neighbour (below = bits 0/1, right = 2/3, left = 4/5,
/// above = 6/7) the even bit is set when that neighbour is significant and positive, and the
/// odd bit when it is significant and negative. The label (9..=13) selects the MQ context.
/// The XOR bit is applied to the coded sign.
///
/// NOTE(review): the `(0, 0)` entries appear to be index combinations that cannot arise
/// (one neighbour flagged both positive and negative). Confirm against ITU-T T.800
/// Tables D.2/D.3 before relying on this.
#[rustfmt::skip]
const SIGN_CONTEXT_LOOKUP: [(u8, u8); 256] = [
    (9,0), (10,0), (10,1), (0,0), (12,0), (13,0), (11,0), (0,0), (12,1), (11,1),
    (13,1), (0,0), (0,0), (0,0), (0,0), (0,0), (12,0), (13,0), (11,0), (0,0),
    (12,0), (13,0), (11,0), (0,0), (9,0), (10,0), (10,1), (0,0), (0,0), (0,0),
    (0,0), (0,0), (12,1), (11,1), (13,1), (0,0), (9,0), (10,0), (10,1), (0,0),
    (12,1), (11,1), (13,1), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0),
    (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0),
    (0,0), (0,0), (0,0), (10,0), (10,0), (9,0), (0,0), (13,0), (13,0), (12,0),
    (0,0), (11,1), (11,1), (12,1), (0,0), (0,0), (0,0), (0,0), (0,0), (13,0),
    (13,0), (12,0), (0,0), (13,0), (13,0), (12,0), (0,0), (10,0), (10,0), (9,0),
    (0,0), (0,0), (0,0), (0,0), (0,0), (11,1), (11,1), (12,1), (0,0), (10,0),
    (10,0), (9,0), (0,0), (11,1), (11,1), (12,1), (0,0), (0,0), (0,0), (0,0),
    (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0),
    (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (10,1), (9,0), (10,1), (0,0),
    (11,0), (12,0), (11,0), (0,0), (13,1), (12,1), (13,1), (0,0), (0,0), (0,0),
    (0,0), (0,0), (11,0), (12,0), (11,0), (0,0), (11,0), (12,0), (11,0), (0,0),
    (10,1), (9,0), (10,1), (0,0), (0,0), (0,0), (0,0), (0,0), (13,1), (12,1),
    (13,1), (0,0), (10,1), (9,0), (10,1), (0,0), (13,1), (12,1), (13,1), (0,0),
    (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0),
    (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0),
    (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0),
    (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0),
    (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0),
    (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0),
    (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0),
    (0,0), (0,0), (0,0), (0,0), (0,0), (0,0), (0,0),
];
/// Encodes a code-block with the classic (non-HT) Tier-1 coder using
/// `CodeBlockStyle::default()`.
///
/// Thin convenience wrapper around [`encode_code_block_with_style`].
pub(crate) fn encode_code_block(
    coefficients: &[i32],
    width: u32,
    height: u32,
    sub_band_type: SubBandType,
    total_bitplanes: u8,
) -> EncodedCodeBlock {
    let default_style = CodeBlockStyle::default();
    encode_code_block_with_style(
        coefficients,
        width,
        height,
        sub_band_type,
        total_bitplanes,
        &default_style,
    )
}
/// Copies the row-major `w`×`h` coefficients into buffers with a one-sample zero border.
///
/// The row stride is `pw` (callers in this file pass `w + 2`). Returns `(magnitudes, states)`,
/// where `states` starts with only the `NEGATIVE` flag set for negative coefficients.
///
/// # Panics
/// Panics if `coefficients` holds fewer than `w * h` values.
fn prepare_padded_coefficients(
    coefficients: &[i32],
    w: usize,
    h: usize,
    pw: usize,
) -> (Vec<u32>, Vec<u8>) {
    let padded_len = pw * (h + 2);
    let mut magnitudes = vec![0u32; padded_len];
    let mut states = vec![0u8; padded_len];
    for row in 0..h {
        let src_row = row * w;
        // +1 row and +1 column skip the zero border.
        let dst_row = (row + 1) * pw + 1;
        for col in 0..w {
            let value = coefficients[src_row + col];
            let dst = dst_row + col;
            magnitudes[dst] = value.unsigned_abs();
            if value.is_negative() {
                states[dst] = NEGATIVE;
            }
        }
    }
    (magnitudes, states)
}
/// Classic Tier-1 encoding of one code-block into a single MQ-coded byte string.
///
/// Coding starts at the most significant non-zero bit-plane. That bit-plane gets only a
/// cleanup pass; every lower bit-plane gets significance propagation, magnitude refinement
/// and cleanup, in that order. The `reset_context_probabilities`, `segmentation_symbols` and
/// `vertically_causal_context` style flags are honoured. Termination and selective bypass are
/// not read here; `encode_code_block_segments_with_style` handles those.
///
/// `coefficients` is read row-major as `width`×`height` samples.
pub(crate) fn encode_code_block_with_style(
    coefficients: &[i32],
    width: u32,
    height: u32,
    sub_band_type: SubBandType,
    total_bitplanes: u8,
    style: &CodeBlockStyle,
) -> EncodedCodeBlock {
    let w = width as usize;
    let h = height as usize;
    let max_magnitude = coefficients
        .iter()
        .map(|c| c.unsigned_abs())
        .max()
        .unwrap_or(0);
    if max_magnitude == 0 {
        return EncodedCodeBlock {
            data: Vec::new(),
            num_coding_passes: 0,
            num_zero_bitplanes: total_bitplanes,
            ht_cleanup_length: 0,
            ht_refinement_length: 0,
        };
    }
    let num_bitplanes = 32 - max_magnitude.leading_zeros();
    debug_assert!(num_bitplanes as u8 <= total_bitplanes);
    let num_zero_bitplanes = total_bitplanes.saturating_sub(num_bitplanes as u8);

    let pw = w + 2;
    let (magnitudes, mut states) = prepare_padded_coefficients(coefficients, w, h, pw);
    let mut neighbors = vec![0u8; magnitudes.len()];
    let mut encoder =
        ArithmeticEncoder::with_capacity(arithmetic_encoder_capacity(w, h, num_bitplanes as usize));
    let mut contexts = [ArithmeticEncoderContext::default(); 19];
    reset_contexts(&mut contexts);
    let mut coded_indices = Vec::new();
    let mut num_coding_passes = 0u8;

    let top_bitplane = num_bitplanes - 1;
    for bp in (0..num_bitplanes).rev() {
        let bit_mask = 1u32 << bp;

        // Only bit-planes below the first one have significance/refinement passes.
        if bp != top_bitplane {
            significance_propagation_pass(
                &magnitudes,
                &mut states,
                &mut neighbors,
                &mut coded_indices,
                &mut encoder,
                &mut contexts,
                w,
                h,
                pw,
                bit_mask,
                sub_band_type,
                style,
            );
            num_coding_passes += 1;
            if style.reset_context_probabilities {
                reset_contexts(&mut contexts);
            }

            magnitude_refinement_pass(
                &magnitudes,
                &mut states,
                &mut neighbors,
                &mut encoder,
                &mut contexts,
                w,
                h,
                pw,
                bit_mask,
                style,
            );
            num_coding_passes += 1;
            if style.reset_context_probabilities {
                reset_contexts(&mut contexts);
            }
        }

        cleanup_pass(
            &magnitudes,
            &mut states,
            &mut neighbors,
            &mut encoder,
            &mut contexts,
            w,
            h,
            pw,
            bit_mask,
            sub_band_type,
            style,
        );
        if style.segmentation_symbols {
            encode_segmentation_symbols(&mut encoder, &mut contexts);
        }
        num_coding_passes += 1;
        if style.reset_context_probabilities {
            reset_contexts(&mut contexts);
        }

        clear_coded_in_current_pass(&mut states, &mut coded_indices);
    }

    EncodedCodeBlock {
        data: encoder.finish(),
        num_coding_passes,
        num_zero_bitplanes,
        ht_cleanup_length: 0,
        ht_refinement_length: 0,
    }
}
/// Classic Tier-1 encoding of one code-block, split into codeword segments.
///
/// If neither `termination_on_each_pass` nor `selective_arithmetic_coding_bypass` is set, this
/// delegates to [`encode_code_block_with_style`] and reports one MQ segment covering every
/// pass. Otherwise a segment boundary is placed:
/// - after every pass (termination), or
/// - at each MQ/raw switch of the selective-bypass schedule ([`bypass_segment_idx`]).
///
/// In bypass mode, passes beyond the first ten whose index is not a multiple of three
/// (significance propagation and refinement) are written raw.
///
/// Every segment's `distortion_delta` comes from [`segment_distortion_delta`]. It is always
/// measured against the bit-planes actually coded, i.e. from the most significant non-zero
/// bit-plane down.
pub(crate) fn encode_code_block_segments_with_style(
    coefficients: &[i32],
    width: u32,
    height: u32,
    sub_band_type: SubBandType,
    total_bitplanes: u8,
    style: &CodeBlockStyle,
) -> EncodedCodeBlockWithSegments {
    if !style.termination_on_each_pass && !style.selective_arithmetic_coding_bypass {
        let encoded = encode_code_block_with_style(
            coefficients,
            width,
            height,
            sub_band_type,
            total_bitplanes,
            style,
        );
        let segments = if encoded.num_coding_passes == 0 {
            Vec::new()
        } else {
            // The single-segment coder emits `1 + 3 * (n - 1)` passes for `n` coded bit-planes,
            // so `n` is the rounded-up third of the pass count. The distortion model must use
            // `n` (as the multi-segment path below does), not `total_bitplanes`. Counting the
            // skipped leading zero bit-planes would make the final pass look as if low-order
            // bits were still missing. It can also overflow the mask shift once
            // `total_bitplanes - n >= 32`.
            let coded_bitplanes = encoded.num_coding_passes.div_ceil(3);
            vec![EncodedCodeBlockSegment {
                data_offset: 0,
                data_length: u32::try_from(encoded.data.len())
                    .expect("classic code-block payload length fits in u32"),
                start_coding_pass: 0,
                end_coding_pass: encoded.num_coding_passes,
                distortion_delta: segment_distortion_delta(
                    coefficients,
                    0,
                    encoded.num_coding_passes,
                    coded_bitplanes,
                ),
                use_arithmetic: true,
            }]
        };
        return EncodedCodeBlockWithSegments {
            data: encoded.data,
            segments,
            num_coding_passes: encoded.num_coding_passes,
            num_zero_bitplanes: encoded.num_zero_bitplanes,
        };
    }

    let w = width as usize;
    let h = height as usize;
    let max_magnitude = coefficients
        .iter()
        .map(|c| c.unsigned_abs())
        .max()
        .unwrap_or(0);
    if max_magnitude == 0 {
        return EncodedCodeBlockWithSegments {
            data: Vec::new(),
            segments: Vec::new(),
            num_coding_passes: 0,
            num_zero_bitplanes: total_bitplanes,
        };
    }
    let num_bitplanes = 32 - max_magnitude.leading_zeros();
    debug_assert!(num_bitplanes as u8 <= total_bitplanes);
    let num_zero_bitplanes = total_bitplanes.saturating_sub(num_bitplanes as u8);
    let pw = w + 2;
    let (magnitudes, mut states) = prepare_padded_coefficients(coefficients, w, h, pw);
    let mut neighbors = vec![0u8; magnitudes.len()];
    let mut contexts = [ArithmeticEncoderContext::default(); 19];
    reset_contexts(&mut contexts);
    let mut data = Vec::new();
    let mut segments = Vec::new();
    // One cleanup pass for the top bit-plane, then SPP/MRP/cleanup for every lower one.
    let total_passes = 1 + 3 * (num_bitplanes as u8 - 1);
    let mut current_segment_idx = None;
    let mut current_segment_start_pass = 0u8;
    let mut current_use_arithmetic = true;
    let mut arithmetic_encoder: Option<ArithmeticEncoder> = None;
    let mut bypass_writer: Option<BitWriter> = None;
    let mut coded_indices = Vec::new();
    for coding_pass in 0..total_passes {
        let segment_idx = if style.termination_on_each_pass {
            coding_pass
        } else if style.selective_arithmetic_coding_bypass {
            bypass_segment_idx(coding_pass)
        } else {
            0
        };
        // Bypass schedule: the first ten passes and every cleanup pass stay MQ-coded.
        let use_arithmetic = if style.selective_arithmetic_coding_bypass {
            coding_pass <= 9 || coding_pass % 3 == 0
        } else {
            true
        };
        if current_segment_idx != Some(segment_idx) {
            // Close the previous segment (if any) before opening a new coder.
            if let Some(previous_idx) = current_segment_idx {
                if current_use_arithmetic {
                    push_segment(
                        &mut data,
                        &mut segments,
                        current_segment_start_pass,
                        coding_pass,
                        arithmetic_encoder
                            .take()
                            .expect("arithmetic segment encoder exists")
                            .finish(),
                        segment_distortion_delta(
                            coefficients,
                            current_segment_start_pass,
                            coding_pass,
                            num_bitplanes as u8,
                        ),
                        true,
                    );
                } else {
                    push_segment(
                        &mut data,
                        &mut segments,
                        current_segment_start_pass,
                        coding_pass,
                        bypass_writer
                            .take()
                            .expect("bypass segment writer exists")
                            .finish(),
                        segment_distortion_delta(
                            coefficients,
                            current_segment_start_pass,
                            coding_pass,
                            num_bitplanes as u8,
                        ),
                        false,
                    );
                }
                debug_assert!(previous_idx < segment_idx);
            }
            current_segment_idx = Some(segment_idx);
            current_segment_start_pass = coding_pass;
            current_use_arithmetic = use_arithmetic;
            if use_arithmetic {
                arithmetic_encoder = Some(ArithmeticEncoder::new());
                bypass_writer = None;
            } else {
                arithmetic_encoder = None;
                bypass_writer = Some(BitWriter::new());
            }
        }
        // Pass 0 is the top bit-plane; passes 1..=3 the next one, and so on.
        let current_bitplane = usize::from(coding_pass.div_ceil(3));
        let bit_mask = 1u32 << (num_bitplanes as usize - 1 - current_bitplane);
        match coding_pass % 3 {
            0 => {
                let encoder = arithmetic_encoder
                    .as_mut()
                    .expect("cleanup pass uses arithmetic encoder");
                cleanup_pass(
                    &magnitudes,
                    &mut states,
                    &mut neighbors,
                    encoder,
                    &mut contexts,
                    w,
                    h,
                    pw,
                    bit_mask,
                    sub_band_type,
                    style,
                );
                if style.segmentation_symbols {
                    encode_segmentation_symbols(encoder, &mut contexts);
                }
                // Cleanup ends the bit-plane: forget which samples SPP coded.
                clear_coded_in_current_pass(&mut states, &mut coded_indices);
            }
            1 => {
                if current_use_arithmetic {
                    significance_propagation_pass(
                        &magnitudes,
                        &mut states,
                        &mut neighbors,
                        &mut coded_indices,
                        arithmetic_encoder
                            .as_mut()
                            .expect("arithmetic encoder exists for significance pass"),
                        &mut contexts,
                        w,
                        h,
                        pw,
                        bit_mask,
                        sub_band_type,
                        style,
                    );
                } else {
                    significance_propagation_pass_raw(
                        &magnitudes,
                        &mut states,
                        &mut neighbors,
                        &mut coded_indices,
                        bypass_writer
                            .as_mut()
                            .expect("bypass writer exists for significance pass"),
                        w,
                        h,
                        pw,
                        bit_mask,
                        style,
                    );
                }
            }
            2 => {
                if current_use_arithmetic {
                    magnitude_refinement_pass(
                        &magnitudes,
                        &mut states,
                        &mut neighbors,
                        arithmetic_encoder
                            .as_mut()
                            .expect("arithmetic encoder exists for refinement pass"),
                        &mut contexts,
                        w,
                        h,
                        pw,
                        bit_mask,
                        style,
                    );
                } else {
                    magnitude_refinement_pass_raw(
                        &magnitudes,
                        &mut states,
                        &mut neighbors,
                        bypass_writer
                            .as_mut()
                            .expect("bypass writer exists for refinement pass"),
                        w,
                        h,
                        pw,
                        bit_mask,
                        style,
                    );
                }
            }
            _ => unreachable!(),
        }
        if style.reset_context_probabilities {
            reset_contexts(&mut contexts);
        }
    }
    // Flush the last open segment.
    if current_segment_idx.is_some() {
        if current_use_arithmetic {
            push_segment(
                &mut data,
                &mut segments,
                current_segment_start_pass,
                total_passes,
                arithmetic_encoder
                    .take()
                    .expect("final arithmetic segment encoder exists")
                    .finish(),
                segment_distortion_delta(
                    coefficients,
                    current_segment_start_pass,
                    total_passes,
                    num_bitplanes as u8,
                ),
                true,
            );
        } else {
            push_segment(
                &mut data,
                &mut segments,
                current_segment_start_pass,
                total_passes,
                bypass_writer
                    .take()
                    .expect("final bypass segment writer exists")
                    .finish(),
                segment_distortion_delta(
                    coefficients,
                    current_segment_start_pass,
                    total_passes,
                    num_bitplanes as u8,
                ),
                false,
            );
        }
    }
    EncodedCodeBlockWithSegments {
        data,
        segments,
        num_coding_passes: total_passes,
        num_zero_bitplanes,
    }
}
/// Returns all 19 MQ contexts to their initial states.
///
/// Every context starts from `ArithmeticEncoderContext::default()`. Three then get a
/// non-default initial index:
/// - context 0 (zero-coding label for an empty neighbourhood) → 4;
/// - context 17 (run-length decisions in the cleanup pass) → 3;
/// - context 18 (run position bits and segmentation symbols) → 46.
fn reset_contexts(contexts: &mut [ArithmeticEncoderContext; 19]) {
    contexts.fill(ArithmeticEncoderContext::default());
    for (label, initial_index) in [(0usize, 4), (17, 3), (18, 46)] {
        contexts[label].reset_with_index(initial_index);
    }
}
/// Initial byte-capacity guess for an MQ encoder.
///
/// The estimate is one byte per 16 coded bit decisions (samples × bit-planes), but never less
/// than 32. One extra byte is added on top. Multiplication saturates instead of overflowing.
fn arithmetic_encoder_capacity(width: usize, height: usize, bitplanes: usize) -> usize {
    let decisions = width.saturating_mul(height).saturating_mul(bitplanes);
    let estimate = decisions / 16;
    estimate.max(32) + 1
}
/// Appends the segmentation symbol, the decisions 1, 0, 1, 0, all in context 18.
///
/// Callers emit it after each cleanup pass when `segmentation_symbols` is enabled.
fn encode_segmentation_symbols(
    encoder: &mut ArithmeticEncoder,
    contexts: &mut [ArithmeticEncoderContext; 19],
) {
    for symbol in [1, 0, 1, 0] {
        encoder.encode(symbol, &mut contexts[18]);
    }
}
/// Segment index of coding pass `pass_idx` under selective arithmetic-coding bypass.
///
/// Passes 0..=9 share segment 0. After that, each group of three passes (SPP, MRP, cleanup)
/// contributes two segments. The raw SPP+MRP pair shares one; the MQ cleanup gets the next.
#[inline]
fn bypass_segment_idx(pass_idx: u8) -> u8 {
    let Some(offset) = pass_idx.checked_sub(10) else {
        return 0;
    };
    let group = offset / 3;
    let is_cleanup = offset % 3 == 2;
    1 + 2 * group + u8::from(is_cleanup)
}
/// Appends `segment_data` to the shared `data` buffer and records its descriptor.
///
/// # Panics
/// Panics if the current buffer offset or the segment length does not fit in `u32`.
fn push_segment(
    data: &mut Vec<u8>,
    segments: &mut Vec<EncodedCodeBlockSegment>,
    start_coding_pass: u8,
    end_coding_pass: u8,
    segment_data: Vec<u8>,
    distortion_delta: f64,
    use_arithmetic: bool,
) {
    let descriptor = EncodedCodeBlockSegment {
        data_offset: u32::try_from(data.len())
            .expect("classic code-block data offset fits in u32"),
        data_length: u32::try_from(segment_data.len())
            .expect("classic code-block segment length fits in u32"),
        start_coding_pass,
        end_coding_pass,
        distortion_delta,
        use_arithmetic,
    };
    data.extend(segment_data);
    segments.push(descriptor);
}
fn segment_distortion_delta(
coefficients: &[i32],
start_coding_pass: u8,
end_coding_pass: u8,
num_bitplanes: u8,
) -> f64 {
let before =
coefficient_distortion_after_passes(coefficients, start_coding_pass, num_bitplanes);
let after = coefficient_distortion_after_passes(coefficients, end_coding_pass, num_bitplanes);
(before - after).max(f64::EPSILON)
}
/// Sum of squared magnitude errors once `completed_passes` passes are decoded.
///
/// The reconstruction model is [`reconstructed_magnitude_after_passes`]. Signs are ignored.
fn coefficient_distortion_after_passes(
    coefficients: &[i32],
    completed_passes: u8,
    num_bitplanes: u8,
) -> f64 {
    let squared_error = |&coefficient: &i32| -> f64 {
        let original = coefficient.unsigned_abs();
        let kept = reconstructed_magnitude_after_passes(original, completed_passes, num_bitplanes);
        let error = f64::from(original.saturating_sub(kept));
        error * error
    };
    coefficients.iter().map(squared_error).sum()
}
/// Approximate magnitude a decoder holds after `completed_passes` coding passes.
///
/// Model: pass 0 codes the top of `num_bitplanes` bit-planes, and every later pass is
/// credited with the whole bit-plane it belongs to. Bits below the retained bit-planes are
/// dropped (truncation, no mid-point reconstruction).
///
/// Returns 0 when nothing was coded. When more than 32 bit-planes remain uncoded, every bit
/// of a `u32` magnitude lies below the retained range, so the result is also 0. The previous
/// `1 << lower_bits` form overflowed the shift in that case.
fn reconstructed_magnitude_after_passes(
    magnitude: u32,
    completed_passes: u8,
    num_bitplanes: u8,
) -> u32 {
    if magnitude == 0 || completed_passes == 0 || num_bitplanes == 0 {
        return 0;
    }
    let deepest_coded_bitplane = completed_passes
        .saturating_sub(1)
        .div_ceil(3)
        .min(num_bitplanes.saturating_sub(1));
    let retained_bitplanes = deepest_coded_bitplane.saturating_add(1);
    if retained_bitplanes >= num_bitplanes {
        return magnitude;
    }
    let lower_bits = u32::from(num_bitplanes - retained_bitplanes);
    // Clear the `lower_bits` least significant bits. A shift of 32 or more clears everything.
    let mask = u32::MAX.checked_shl(lower_bits).unwrap_or(0);
    magnitude & mask
}
/// Flags sample `idx` as coded in the current bit-plane.
///
/// The index is remembered once so `clear_coded_in_current_pass` can undo the flag cheaply.
fn mark_coded_in_current_pass(idx: usize, states: &mut [u8], coded_indices: &mut Vec<usize>) {
    let state = &mut states[idx];
    if *state & CODED_IN_CURRENT_PASS != 0 {
        return;
    }
    *state |= CODED_IN_CURRENT_PASS;
    coded_indices.push(idx);
}
/// Removes `CODED_IN_CURRENT_PASS` from every recorded sample.
///
/// Empties `coded_indices` while keeping its allocation.
fn clear_coded_in_current_pass(states: &mut [u8], coded_indices: &mut Vec<usize>) {
    coded_indices
        .drain(..)
        .for_each(|idx| states[idx] &= !CODED_IN_CURRENT_PASS);
}
/// MQ-coded significance-propagation pass for one bit-plane (`bit_mask`).
///
/// Only `style.vertically_causal_context` is consulted. It selects the monomorphised
/// implementation.
fn significance_propagation_pass(
    magnitudes: &[u32],
    states: &mut [u8],
    neighbors: &mut [u8],
    coded_indices: &mut Vec<usize>,
    encoder: &mut ArithmeticEncoder,
    contexts: &mut [ArithmeticEncoderContext; 19],
    w: usize,
    h: usize,
    pw: usize,
    bit_mask: u32,
    sub_band_type: SubBandType,
    style: &CodeBlockStyle,
) {
    match style.vertically_causal_context {
        true => significance_propagation_pass_impl::<true>(
            magnitudes, states, neighbors, coded_indices, encoder, contexts, w, h, pw, bit_mask,
            sub_band_type,
        ),
        false => significance_propagation_pass_impl::<false>(
            magnitudes, states, neighbors, coded_indices, encoder, contexts, w, h, pw, bit_mask,
            sub_band_type,
        ),
    }
}
/// Significance-propagation pass body.
///
/// Scans 4-row stripes column by column. Each sample that is still insignificant but has at
/// least one (effective) significant neighbour gets its `bit_mask` bit coded in the
/// zero-coding context. The sample is then flagged as coded in this bit-plane. A 1 bit is
/// followed by the sign, and the sample becomes significant.
fn significance_propagation_pass_impl<const VERTICAL_CAUSAL: bool>(
    magnitudes: &[u32],
    states: &mut [u8],
    neighbors: &mut [u8],
    coded_indices: &mut Vec<usize>,
    encoder: &mut ArithmeticEncoder,
    contexts: &mut [ArithmeticEncoderContext; 19],
    w: usize,
    h: usize,
    pw: usize,
    bit_mask: u32,
    sub_band_type: SubBandType,
) {
    for stripe_top in (0..h).step_by(4) {
        let stripe_end = (stripe_top + 4).min(h);
        for x in 0..w {
            for y in stripe_top..stripe_end {
                let idx = (y + 1) * pw + (x + 1);
                if states[idx] & SIGNIFICANT != 0 {
                    continue;
                }
                let neighbor_sig = effective_neighbor_sig::<VERTICAL_CAUSAL>(neighbors[idx], y, h);
                if neighbor_sig == 0 {
                    continue;
                }
                let bit = u32::from(magnitudes[idx] & bit_mask != 0);
                let label = zero_coding_ctx(neighbor_sig, sub_band_type);
                encoder.encode(bit, &mut contexts[usize::from(label)]);
                mark_coded_in_current_pass(idx, states, coded_indices);
                if bit != 0 {
                    encode_sign::<VERTICAL_CAUSAL>(
                        idx, neighbors, states, encoder, contexts, pw, y, h,
                    );
                    set_significant(idx, states, neighbors, pw);
                }
            }
        }
    }
}
/// Raw (bypass) significance-propagation pass for one bit-plane.
///
/// Dispatches on `style.vertically_causal_context`.
fn significance_propagation_pass_raw(
    magnitudes: &[u32],
    states: &mut [u8],
    neighbors: &mut [u8],
    coded_indices: &mut Vec<usize>,
    writer: &mut BitWriter,
    w: usize,
    h: usize,
    pw: usize,
    bit_mask: u32,
    style: &CodeBlockStyle,
) {
    match style.vertically_causal_context {
        true => significance_propagation_pass_raw_impl::<true>(
            magnitudes, states, neighbors, coded_indices, writer, w, h, pw, bit_mask,
        ),
        false => significance_propagation_pass_raw_impl::<false>(
            magnitudes, states, neighbors, coded_indices, writer, w, h, pw, bit_mask,
        ),
    }
}
/// Raw significance-propagation pass body.
///
/// The sample set matches the MQ version, but bits and signs are written verbatim without
/// contexts.
fn significance_propagation_pass_raw_impl<const VERTICAL_CAUSAL: bool>(
    magnitudes: &[u32],
    states: &mut [u8],
    neighbors: &mut [u8],
    coded_indices: &mut Vec<usize>,
    writer: &mut BitWriter,
    w: usize,
    h: usize,
    pw: usize,
    bit_mask: u32,
) {
    for stripe_top in (0..h).step_by(4) {
        let stripe_end = (stripe_top + 4).min(h);
        for x in 0..w {
            for y in stripe_top..stripe_end {
                let idx = (y + 1) * pw + (x + 1);
                let candidate = states[idx] & SIGNIFICANT == 0
                    && effective_neighbor_sig::<VERTICAL_CAUSAL>(neighbors[idx], y, h) != 0;
                if !candidate {
                    continue;
                }
                let bit = u32::from(magnitudes[idx] & bit_mask != 0);
                writer.write_bit(bit);
                mark_coded_in_current_pass(idx, states, coded_indices);
                if bit != 0 {
                    encode_sign_raw(idx, states, writer);
                    set_significant(idx, states, neighbors, pw);
                }
            }
        }
    }
}
/// MQ-coded magnitude-refinement pass for one bit-plane.
///
/// Dispatches on `style.vertically_causal_context`.
fn magnitude_refinement_pass(
    magnitudes: &[u32],
    states: &mut [u8],
    neighbors: &mut [u8],
    encoder: &mut ArithmeticEncoder,
    contexts: &mut [ArithmeticEncoderContext; 19],
    w: usize,
    h: usize,
    pw: usize,
    bit_mask: u32,
    style: &CodeBlockStyle,
) {
    match style.vertically_causal_context {
        true => magnitude_refinement_pass_impl::<true>(
            magnitudes, states, neighbors, encoder, contexts, w, h, pw, bit_mask,
        ),
        false => magnitude_refinement_pass_impl::<false>(
            magnitudes, states, neighbors, encoder, contexts, w, h, pw, bit_mask,
        ),
    }
}
/// Magnitude-refinement pass body.
///
/// Each sample that was already significant and was not coded by this bit-plane's
/// significance pass has its `bit_mask` bit coded. The context comes from
/// `magnitude_refinement_ctx`. The sample is then flagged `MAGNITUDE_REFINED`.
fn magnitude_refinement_pass_impl<const VERTICAL_CAUSAL: bool>(
    magnitudes: &[u32],
    states: &mut [u8],
    neighbors: &mut [u8],
    encoder: &mut ArithmeticEncoder,
    contexts: &mut [ArithmeticEncoderContext; 19],
    w: usize,
    h: usize,
    pw: usize,
    bit_mask: u32,
) {
    for stripe_top in (0..h).step_by(4) {
        let stripe_end = (stripe_top + 4).min(h);
        for x in 0..w {
            for y in stripe_top..stripe_end {
                let idx = (y + 1) * pw + (x + 1);
                let state = states[idx];
                let refinable =
                    state & SIGNIFICANT != 0 && state & CODED_IN_CURRENT_PASS == 0;
                if !refinable {
                    continue;
                }
                let neighbor_sig = effective_neighbor_sig::<VERTICAL_CAUSAL>(neighbors[idx], y, h);
                let label = magnitude_refinement_ctx(state, neighbor_sig);
                let bit = u32::from(magnitudes[idx] & bit_mask != 0);
                encoder.encode(bit, &mut contexts[usize::from(label)]);
                states[idx] = state | MAGNITUDE_REFINED;
            }
        }
    }
}
/// Raw (bypass) magnitude-refinement pass for one bit-plane.
///
/// Dispatches on `style.vertically_causal_context`.
fn magnitude_refinement_pass_raw(
    magnitudes: &[u32],
    states: &mut [u8],
    neighbors: &mut [u8],
    writer: &mut BitWriter,
    w: usize,
    h: usize,
    pw: usize,
    bit_mask: u32,
    style: &CodeBlockStyle,
) {
    match style.vertically_causal_context {
        true => magnitude_refinement_pass_raw_impl::<true>(
            magnitudes, states, neighbors, writer, w, h, pw, bit_mask,
        ),
        false => magnitude_refinement_pass_raw_impl::<false>(
            magnitudes, states, neighbors, writer, w, h, pw, bit_mask,
        ),
    }
}
/// Raw magnitude-refinement pass body.
///
/// Each sample that is significant and was not coded by this bit-plane's significance pass
/// writes its `bit_mask` bit verbatim. The sample is then flagged `MAGNITUDE_REFINED`.
///
/// Raw bits carry no context, so the neighbourhood is never inspected. The previous,
/// discarded `effective_neighbor_sig` call was dead work and has been removed. The
/// `VERTICAL_CAUSAL` parameter and the neighbour buffer are kept only so the dispatcher's
/// call shape stays unchanged.
fn magnitude_refinement_pass_raw_impl<const VERTICAL_CAUSAL: bool>(
    magnitudes: &[u32],
    states: &mut [u8],
    _neighbors: &mut [u8],
    writer: &mut BitWriter,
    w: usize,
    h: usize,
    pw: usize,
    bit_mask: u32,
) {
    for y_base in (0..h).step_by(4) {
        let y_end = (y_base + 4).min(h);
        for x in 0..w {
            for y in y_base..y_end {
                let idx = (y + 1) * pw + (x + 1);
                let is_significant = states[idx] & SIGNIFICANT != 0;
                let coded_this_pass = states[idx] & CODED_IN_CURRENT_PASS != 0;
                if is_significant && !coded_this_pass {
                    let bit = (magnitudes[idx] & bit_mask != 0) as u32;
                    writer.write_bit(bit);
                    states[idx] |= MAGNITUDE_REFINED;
                }
            }
        }
    }
}
/// MQ-coded cleanup pass for one bit-plane.
///
/// Dispatches on `style.vertically_causal_context`.
fn cleanup_pass(
    magnitudes: &[u32],
    states: &mut [u8],
    neighbors: &mut [u8],
    encoder: &mut ArithmeticEncoder,
    contexts: &mut [ArithmeticEncoderContext; 19],
    w: usize,
    h: usize,
    pw: usize,
    bit_mask: u32,
    sub_band_type: SubBandType,
    style: &CodeBlockStyle,
) {
    match style.vertically_causal_context {
        true => cleanup_pass_impl::<true>(
            magnitudes, states, neighbors, encoder, contexts, w, h, pw, bit_mask, sub_band_type,
        ),
        false => cleanup_pass_impl::<false>(
            magnitudes, states, neighbors, encoder, contexts, w, h, pw, bit_mask, sub_band_type,
        ),
    }
}
/// Cleanup pass body: codes every sample the earlier passes of this bit-plane skipped.
///
/// Columns of each 4-row stripe are visited left to right. Run-length mode applies to a
/// full-height column whose four samples are all insignificant, not yet coded in this
/// bit-plane, and without (effective) significant neighbours:
/// - one decision in context 17 says whether any of the four has its `bit_mask` bit set;
/// - if so, the row of the first such sample follows as two bits (MSB first) in context 18,
///   and then its sign.
///
/// All other samples go through [`cleanup_code_sample`]. In run-length mode that means only
/// the samples below the one just signalled.
fn cleanup_pass_impl<const VERTICAL_CAUSAL: bool>(
    magnitudes: &[u32],
    states: &mut [u8],
    neighbors: &mut [u8],
    encoder: &mut ArithmeticEncoder,
    contexts: &mut [ArithmeticEncoderContext; 19],
    w: usize,
    h: usize,
    pw: usize,
    bit_mask: u32,
    sub_band_type: SubBandType,
) {
    for y_base in (0..h).step_by(4) {
        let y_end = (y_base + 4).min(h);
        for x in 0..w {
            let sample_idx = move |y: usize| (y + 1) * pw + (x + 1);
            // First row of this column that still needs ordinary per-sample coding.
            let mut first_row = y_base;

            let run_length_mode = y_end - y_base == 4
                && (y_base..y_end).all(|y| {
                    let idx = sample_idx(y);
                    states[idx] & (SIGNIFICANT | CODED_IN_CURRENT_PASS) == 0
                        && effective_neighbor_sig::<VERTICAL_CAUSAL>(neighbors[idx], y, h) == 0
                });
            if run_length_mode {
                let first_sig =
                    (y_base..y_end).position(|y| magnitudes[sample_idx(y)] & bit_mask != 0);
                let Some(pos) = first_sig else {
                    // The whole column stays insignificant: a single 0 decision covers it.
                    encoder.encode(0, &mut contexts[17]);
                    continue;
                };
                encoder.encode(1, &mut contexts[17]);
                encoder.encode((pos >> 1) as u32 & 1, &mut contexts[18]);
                encoder.encode(pos as u32 & 1, &mut contexts[18]);
                let y = y_base + pos;
                let idx = sample_idx(y);
                encode_sign::<VERTICAL_CAUSAL>(idx, neighbors, states, encoder, contexts, pw, y, h);
                set_significant(idx, states, neighbors, pw);
                first_row = y + 1;
            }

            for y in first_row..y_end {
                cleanup_code_sample::<VERTICAL_CAUSAL>(
                    sample_idx(y),
                    y,
                    magnitudes,
                    states,
                    neighbors,
                    encoder,
                    contexts,
                    pw,
                    h,
                    bit_mask,
                    sub_band_type,
                );
            }
        }
    }
}

/// Codes one cleanup-pass sample outside run-length mode.
///
/// The sample is skipped if it is already significant or was coded earlier in this
/// bit-plane. Otherwise its `bit_mask` bit is coded in the zero-coding context. A 1 bit is
/// followed by the sign, and the sample becomes significant.
fn cleanup_code_sample<const VERTICAL_CAUSAL: bool>(
    idx: usize,
    y: usize,
    magnitudes: &[u32],
    states: &mut [u8],
    neighbors: &mut [u8],
    encoder: &mut ArithmeticEncoder,
    contexts: &mut [ArithmeticEncoderContext; 19],
    pw: usize,
    h: usize,
    bit_mask: u32,
    sub_band_type: SubBandType,
) {
    if states[idx] & (SIGNIFICANT | CODED_IN_CURRENT_PASS) != 0 {
        return;
    }
    let ctx_label = zero_coding_ctx(
        effective_neighbor_sig::<VERTICAL_CAUSAL>(neighbors[idx], y, h),
        sub_band_type,
    );
    let bit = (magnitudes[idx] & bit_mask != 0) as u32;
    encoder.encode(bit, &mut contexts[ctx_label as usize]);
    if bit == 1 {
        encode_sign::<VERTICAL_CAUSAL>(idx, neighbors, states, encoder, contexts, pw, y, h);
        set_significant(idx, states, neighbors, pw);
    }
}
/// MQ-codes the sign of sample `idx`, which has just become significant.
///
/// The four horizontal/vertical neighbours are folded into the `SIGN_CONTEXT_LOOKUP` index.
/// Even bits mark a significant positive neighbour, odd bits a significant negative one. In
/// vertically-causal mode the neighbour below is hidden when it lies in the next stripe.
fn encode_sign<const VERTICAL_CAUSAL: bool>(
    idx: usize,
    neighbors: &[u8],
    states: &[u8],
    encoder: &mut ArithmeticEncoder,
    contexts: &mut [ArithmeticEncoderContext; 19],
    pw: usize,
    y: usize,
    h: usize,
) {
    // 1 when the neighbour at `at` is both significant and negative, otherwise 0.
    let negative_neighbor =
        |at: usize| u8::from(states[at] & SIGNIFICANT != 0 && states[at] & NEGATIVE != 0);

    let bottom_hidden = VERTICAL_CAUSAL && neighbor_in_next_stripe(y, h);
    let bottom = if bottom_hidden {
        0
    } else {
        negative_neighbor(idx + pw)
    };
    let sign_bits = (negative_neighbor(idx - pw) << 6)
        | (negative_neighbor(idx - 1) << 4)
        | (negative_neighbor(idx + 1) << 2)
        | bottom;

    // Keep only below (0), right (2), left (4) and above (6).
    let significances =
        effective_neighbor_sig::<VERTICAL_CAUSAL>(neighbors[idx], y, h) & 0b0101_0101;
    let merged = ((significances & sign_bits) << 1) | (significances & !sign_bits);

    let (ctx_label, xor_bit) = SIGN_CONTEXT_LOOKUP[usize::from(merged)];
    let sign_bit = u32::from(states[idx] & NEGATIVE != 0);
    encoder.encode(sign_bit ^ u32::from(xor_bit), &mut contexts[usize::from(ctx_label)]);
}
/// Writes the sign of sample `idx` as one raw bit (1 = negative).
///
/// It is called before `set_significant`, so the sample must not be significant yet.
fn encode_sign_raw(idx: usize, states: &[u8], writer: &mut BitWriter) {
    let state = states[idx];
    let is_significant = state & SIGNIFICANT != 0;
    debug_assert!(!is_significant);
    let negative = state & NEGATIVE != 0;
    writer.write_bit(u32::from(negative));
}
/// True when row `y` is the last row of its 4-row stripe and another row exists below it.
///
/// In that case the sample's lower neighbours belong to the next stripe.
#[inline]
fn neighbor_in_next_stripe(y: usize, height: usize) -> bool {
    let last_row_of_stripe = y % 4 == 3;
    last_row_of_stripe && y + 1 < height
}
/// Neighbour-significance mask as seen by the context models.
///
/// In vertically-causal mode, the below (bit 0), below-right (bit 1) and below-left (bit 3)
/// neighbours are ignored when they lie in the next stripe.
#[inline(always)]
fn effective_neighbor_sig<const VERTICAL_CAUSAL: bool>(
    neighbor_sig: u8,
    y: usize,
    height: usize,
) -> u8 {
    const NEXT_STRIPE_NEIGHBORS: u8 = 0b0000_1011;
    let hide_below = VERTICAL_CAUSAL && neighbor_in_next_stripe(y, height);
    match hide_below {
        true => neighbor_sig & !NEXT_STRIPE_NEIGHBORS,
        false => neighbor_sig,
    }
}
/// Zero-coding context label (0..=8) for a neighbour mask in the given sub-band.
///
/// LL and LH share one table.
#[inline]
fn zero_coding_ctx(neighbor_sig: u8, sub_band_type: SubBandType) -> u8 {
    let table: &[u8; 256] = match sub_band_type {
        SubBandType::LowLow | SubBandType::LowHigh => &ZERO_CTX_LL_LH,
        SubBandType::HighLow => &ZERO_CTX_HL,
        SubBandType::HighHigh => &ZERO_CTX_HH,
    };
    table[usize::from(neighbor_sig)]
}
/// Magnitude-refinement context label.
///
/// - 16 once the sample has been refined before.
/// - 15 on its first refinement when any neighbour is significant.
/// - 14 otherwise.
#[inline]
fn magnitude_refinement_ctx(state: u8, neighbor_sig: u8) -> u8 {
    let already_refined = state & MAGNITUDE_REFINED != 0;
    match (already_refined, neighbor_sig != 0) {
        (true, _) => 16,
        (false, true) => 15,
        (false, false) => 14,
    }
}
/// Marks sample `idx` significant and records it in the masks of its eight neighbours.
///
/// Each neighbour sets the bit naming the direction of `idx` as seen from that neighbour:
/// bit 0 below, 1 below-right, 2 right, 3 below-left, 4 left, 5 above-right, 6 above,
/// 7 above-left. The zero border written by `prepare_padded_coefficients` keeps every index
/// in range.
fn set_significant(idx: usize, states: &mut [u8], neighbors: &mut [u8], pw: usize) {
    states[idx] |= SIGNIFICANT;
    let above = idx - pw;
    let below = idx + pw;
    let updates = [
        (above - 1, 1),
        (above, 0),
        (above + 1, 3),
        (idx - 1, 2),
        (idx + 1, 4),
        (below - 1, 5),
        (below, 6),
        (below + 1, 7),
    ];
    for (cell, bit) in updates {
        neighbors[cell] |= 1u8 << bit;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_encode_all_zeros() {
        let coeffs = vec![0i32; 16];
        let result = encode_code_block(&coeffs, 4, 4, SubBandType::LowLow, 8);
        assert_eq!(result.num_coding_passes, 0);
        assert!(result.data.is_empty());
        assert_eq!(result.num_zero_bitplanes, 8);
    }

    #[test]
    fn test_encode_single_nonzero() {
        let mut coeffs = vec![0i32; 16];
        coeffs[0] = 128;
        let result = encode_code_block(&coeffs, 4, 4, SubBandType::LowLow, 8);
        assert!(result.num_coding_passes > 0);
        assert!(!result.data.is_empty());
        assert_eq!(result.num_zero_bitplanes, 0);
    }

    #[test]
    fn pack_classic_selective_bypass_tokens_matches_scalar_single_cleanup_block() {
        let style = CodeBlockStyle {
            selective_arithmetic_coding_bypass: true,
            reset_context_probabilities: false,
            termination_on_each_pass: false,
            vertically_causal_context: false,
            segmentation_symbols: false,
            high_throughput_block_coding: false,
        };
        let coefficients = [1i32];
        let scalar = encode_code_block_segments_with_style(
            &coefficients,
            1,
            1,
            SubBandType::LowLow,
            1,
            &style,
        );
        // Two 6-bit MQ tokens: (context 0, bit 1) then (context 9, bit 0).
        let token_bytes = pack_mq_test_tokens(&[(0, 1), (9, 0)]);
        let packed = pack_classic_selective_bypass_tier1_tokens(
            &token_bytes,
            &[ClassicTier1TokenSegment {
                token_bit_offset: 0,
                token_bit_count: 12,
                start_coding_pass: 0,
                end_coding_pass: 1,
                use_arithmetic: true,
            }],
            scalar.num_coding_passes,
            scalar.num_zero_bitplanes,
        )
        .expect("tokens pack");
        assert_eq!(packed.data, scalar.data);
        assert_eq!(packed.num_coding_passes, scalar.num_coding_passes);
        assert_eq!(packed.num_zero_bitplanes, scalar.num_zero_bitplanes);
        assert_eq!(packed.segments.len(), scalar.segments.len());
        for (packed_segment, scalar_segment) in packed.segments.iter().zip(&scalar.segments) {
            assert_eq!(packed_segment.data_offset, scalar_segment.data_offset);
            assert_eq!(packed_segment.data_length, scalar_segment.data_length);
            assert_eq!(
                packed_segment.start_coding_pass,
                scalar_segment.start_coding_pass
            );
            assert_eq!(
                packed_segment.end_coding_pass,
                scalar_segment.end_coding_pass
            );
            assert_eq!(packed_segment.use_arithmetic, scalar_segment.use_arithmetic);
        }
    }

    /// Packs `(context, bit)` pairs into a big-endian bit stream of 6-bit
    /// tokens: the low 5 bits hold the context label and bit 5 holds the
    /// symbol. The final partial byte, if any, is zero-padded on the right.
    fn pack_mq_test_tokens(tokens: &[(u8, u8)]) -> Vec<u8> {
        let mut bytes = Vec::new();
        let mut current = 0u8;
        let mut bits = 0u8;
        for &(ctx, bit) in tokens {
            let value = (ctx & 0x1F) | ((bit & 1) << 5);
            for shift in (0..6).rev() {
                current = (current << 1) | ((value >> shift) & 1);
                bits += 1;
                if bits == 8 {
                    bytes.push(current);
                    current = 0;
                    bits = 0;
                }
            }
        }
        if bits != 0 {
            bytes.push(current << (8 - bits));
        }
        bytes
    }

    #[test]
    fn test_encode_various_magnitudes() {
        let coeffs: Vec<i32> = (0..64)
            .map(|x| if x % 3 == 0 { x * 10 } else { -x })
            .collect();
        let result = encode_code_block(&coeffs, 8, 8, SubBandType::HighHigh, 12);
        assert!(result.num_coding_passes > 0);
        assert!(!result.data.is_empty());
    }

    #[test]
    fn test_zero_bitplanes_count() {
        let coeffs = vec![7i32, -3, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        let result = encode_code_block(&coeffs, 4, 4, SubBandType::LowLow, 8);
        assert_eq!(result.num_zero_bitplanes, 5);
    }

    #[test]
    fn padded_coefficient_preparation_stores_sign_in_state_flags() {
        let coeffs = vec![7i32, -3, 0, -9];
        let (magnitudes, states) = prepare_padded_coefficients(&coeffs, 2, 2, 4);
        assert_eq!(magnitudes[5], 7);
        assert_eq!(magnitudes[6], 3);
        assert_eq!(magnitudes[9], 0);
        assert_eq!(magnitudes[10], 9);
        assert_eq!(states[5] & NEGATIVE, 0);
        assert_ne!(states[6] & NEGATIVE, 0);
        assert_eq!(states[9] & NEGATIVE, 0);
        assert_ne!(states[10] & NEGATIVE, 0);
    }

    #[test]
    fn clear_coded_in_current_pass_touches_only_recorded_indices() {
        let mut states = vec![0u8; 8];
        let mut coded_indices = Vec::new();
        mark_coded_in_current_pass(2, &mut states, &mut coded_indices);
        mark_coded_in_current_pass(5, &mut states, &mut coded_indices);
        states[6] = SIGNIFICANT;
        clear_coded_in_current_pass(&mut states, &mut coded_indices);
        assert_eq!(states[2] & CODED_IN_CURRENT_PASS, 0);
        assert_eq!(states[5] & CODED_IN_CURRENT_PASS, 0);
        assert_eq!(states[6], SIGNIFICANT);
        assert!(coded_indices.is_empty());
    }

    #[test]
    fn pcrd_distortion_delta_reflects_residual_error_reduction() {
        let sparse_delta = segment_distortion_delta(&[8], 0, 1, 4);
        let dense_delta = segment_distortion_delta(&[15], 0, 1, 4);
        assert!(
            dense_delta > sparse_delta,
            "coefficients with the same MSB but larger residual error should have larger PCRD distortion reduction"
        );
    }

    #[test]
    fn neighbor_in_next_stripe_only_at_last_row_of_a_stripe() {
        // Rows 3 and 7 close their four-row stripes; the row below only counts
        // when it is still inside the block.
        assert!(neighbor_in_next_stripe(3, 8));
        assert!(neighbor_in_next_stripe(7, 12));
        assert!(!neighbor_in_next_stripe(3, 4));
        assert!(!neighbor_in_next_stripe(2, 8));
        assert!(!neighbor_in_next_stripe(4, 8));
    }

    #[test]
    fn set_significant_marks_all_eight_neighbors() {
        // 3x3 padded grid with row pitch 3; the centre sample is index 4.
        let mut states = vec![0u8; 9];
        let mut neighbors = vec![0u8; 9];
        set_significant(4, &mut states, &mut neighbors, 3);
        let expected_states: [u8; 9] = [0, 0, 0, 0, SIGNIFICANT, 0, 0, 0, 0];
        assert_eq!(states, expected_states);
        let expected_neighbors: [u8; 9] = [
            1 << 1,
            1 << 0,
            1 << 3,
            1 << 2,
            0,
            1 << 4,
            1 << 5,
            1 << 6,
            1 << 7,
        ];
        assert_eq!(neighbors, expected_neighbors);
    }

    #[test]
    fn vertically_causal_mask_hides_only_neighbors_below() {
        let mut states = vec![0u8; 9];
        let mut neighbors = vec![0u8; 9];
        set_significant(4, &mut states, &mut neighbors, 3);
        // The top row's bits describe the centre sample *below* them. When that
        // row is the last of its stripe (y = 3), causal mode must hide them.
        for &sig in &neighbors[..3] {
            assert_eq!(effective_neighbor_sig::<true>(sig, 3, 8), 0);
            assert_eq!(effective_neighbor_sig::<false>(sig, 3, 8), sig);
            assert_eq!(effective_neighbor_sig::<true>(sig, 2, 8), sig);
        }
        // Bits recorded for the centre row and the row below describe samples
        // beside or above them and must survive the causal mask.
        for &sig in &neighbors[3..] {
            assert_eq!(effective_neighbor_sig::<true>(sig, 3, 8), sig);
        }
    }

    #[test]
    fn magnitude_refinement_context_selection() {
        assert_eq!(magnitude_refinement_ctx(0, 0), 14);
        assert_eq!(magnitude_refinement_ctx(0, 1 << 7), 15);
        assert_eq!(magnitude_refinement_ctx(SIGNIFICANT, 0xFF), 15);
        assert_eq!(magnitude_refinement_ctx(MAGNITUDE_REFINED, 0), 16);
        assert_eq!(
            magnitude_refinement_ctx(MAGNITUDE_REFINED | SIGNIFICANT, 0xFF),
            16
        );
    }
}