use alloc::vec;
use alloc::vec::Vec;
use crate::arithmetic_decoder::{ArithmeticDecoder, Context};
use crate::bitmap::Bitmap;
use crate::decode::generic::{ContextGatherer, decode_bitmap_mmr};
use crate::decode::{AdaptiveTemplatePixel, Template};
use crate::error::Result;
/// Parameters controlling how a gray-scale image is decoded by
/// `decode_gray_scale_image`.
#[derive(Debug, Clone)]
pub(crate) struct GrayScaleParams<'a> {
    /// When `true`, each bitplane is decoded with `decode_bitmap_mmr`;
    /// otherwise the arithmetic decoder is used.
    pub(crate) use_mmr: bool,
    /// Number of bitplanes to decode; each resulting value is built from this
    /// many low-order bits.
    pub(crate) bits_per_pixel: u32,
    /// Image width in pixels.
    pub(crate) width: u32,
    /// Image height in pixels.
    pub(crate) height: u32,
    /// Context template used by the arithmetic path (also selects the
    /// adaptive-template pixel positions). Not read on the MMR path.
    pub(crate) template: Template,
    /// Optional skip mask for the arithmetic path, packed like a bitplane
    /// (`width.div_ceil(32)` words per row, leftmost pixel in the most
    /// significant bit). Pixels whose mask bit is set are not decoded and
    /// are never set in the bitplane. Not read on the MMR path.
    pub(crate) skip_mask: Option<&'a [u32]>,
}
/// Decodes a gray-scale image from `data`.
///
/// Returns one value per pixel in row-major order (`width * height`
/// entries). The low `bits_per_pixel` bits of each value come from the
/// decoded bitplanes. The bitplanes are read with the MMR decoder when
/// `params.use_mmr` is set, and with the arithmetic decoder otherwise.
///
/// # Errors
///
/// Propagates any error reported by the underlying bitplane decoder.
#[inline(always)]
pub(crate) fn decode_gray_scale_image(
    data: &[u8],
    params: &GrayScaleParams<'_>,
) -> Result<Vec<u32>> {
    if !params.use_mmr {
        return decode_arithmetic(data, params);
    }
    decode_mmr(data, params)
}
fn decode_mmr(data: &[u8], params: &GrayScaleParams<'_>) -> Result<Vec<u32>> {
let width = params.width;
let height = params.height;
let bits_per_pixel = params.bits_per_pixel;
let stride = width.div_ceil(32);
let mut offset = 0;
decode_bitplanes(width, height, stride, bits_per_pixel, |_| {
let mut bitplane = Bitmap::new(width, height);
offset += decode_bitmap_mmr(&mut bitplane, &data[offset..])?;
Ok(bitplane.data)
})
}
fn decode_arithmetic(data: &[u8], params: &GrayScaleParams<'_>) -> Result<Vec<u32>> {
let width = params.width;
let height = params.height;
let bits_per_pixel = params.bits_per_pixel;
let stride = width.div_ceil(32);
let skip_mask = params.skip_mask;
let template = params.template;
let at_pixels: Vec<AdaptiveTemplatePixel> = match template {
Template::Template0 => vec![
AdaptiveTemplatePixel { x: 3, y: -1 },
AdaptiveTemplatePixel { x: -3, y: -1 },
AdaptiveTemplatePixel { x: 2, y: -2 },
AdaptiveTemplatePixel { x: -2, y: -2 },
],
Template::Template1 => vec![AdaptiveTemplatePixel { x: 3, y: -1 }],
Template::Template2 | Template::Template3 => {
vec![AdaptiveTemplatePixel { x: 2, y: -1 }]
}
};
let mut decoder = ArithmeticDecoder::new(data);
let mut contexts = vec![Context::default(); 1 << template.context_bits()];
decode_bitplanes(width, height, stride, bits_per_pixel, |_| {
let mut bitplane = Bitmap::new(width, height);
let mut gatherer = ContextGatherer::new(width, height, template, &at_pixels);
for y in 0..height {
gatherer.start_row(&bitplane, y);
for x in 0..width {
if let Some(mask) = skip_mask {
let word_idx = (y * stride + x / 32) as usize;
let bit_pos = 31 - (x % 32);
if (mask[word_idx] >> bit_pos) & 1 != 0 {
let _ = gatherer.gather(&bitplane, x);
gatherer.update_current_row(x, false);
continue;
}
}
let context = gatherer.gather(&bitplane, x);
let pixel = decoder.decode(&mut contexts[context as usize]);
let value = pixel != 0;
bitplane.set_pixel(x, y, value);
gatherer.update_current_row(x, value);
}
}
Ok(bitplane.data)
})
}
/// Decodes `bits_per_pixel` Gray-coded bitplanes and assembles them into a
/// row-major vector of `width * height` gray-scale values.
///
/// `decode_next(j)` is called once per bitplane, for `j` from
/// `bits_per_pixel - 1` (most significant) down to `0`. It must return that
/// plane's packed bitmap words: `stride` 32-bit words per row, with the
/// leftmost pixel of each word in its most significant bit. Every plane after
/// the first is XOR-ed with the previously reconstructed plane to undo the
/// Gray coding, then bit `j` of each value is set from it.
///
/// When `bits_per_pixel` is `0`, no plane is decoded and every value is `0`.
///
/// # Errors
///
/// Propagates the first error returned by `decode_next`.
fn decode_bitplanes<F>(
    width: u32,
    height: u32,
    stride: u32,
    bits_per_pixel: u32,
    mut decode_next: F,
) -> Result<Vec<u32>>
where
    F: FnMut(u32) -> Result<Vec<u32>>,
{
    // Multiply in `usize` so large dimensions cannot overflow `u32`.
    let size = width as usize * height as usize;
    let mut values = vec![0_u32; size];

    // With zero planes every value is 0. Without this guard
    // `bits_per_pixel - 1` would underflow and the decode loop would run
    // (nearly) forever.
    if bits_per_pixel == 0 {
        return Ok(values);
    }
    // Each value is a `u32`, so more than 32 planes cannot be represented.
    debug_assert!(bits_per_pixel <= 32, "at most 32 bitplanes fit in a u32");

    let mut prev_plane: Option<Vec<u32>> = None;
    for j in (0..bits_per_pixel).rev() {
        let mut plane = decode_next(j)?;
        // Gray decoding: plane[j] ^= plane[j + 1] (already reconstructed).
        if let Some(prev) = &prev_plane {
            for (word, prev_word) in plane.iter_mut().zip(prev) {
                *word ^= *prev_word;
            }
        }
        merge_bitplane(&mut values, &plane, width, height, stride, j);
        prev_plane = Some(plane);
    }
    Ok(values)
}

/// ORs `1 << bit` into `values[y * width + x]` for every pixel set in `plane`.
///
/// `plane` is a packed bitmap with `stride` words per row and the leftmost
/// pixel of each word in its most significant bit.
fn merge_bitplane(
    values: &mut [u32],
    plane: &[u32],
    width: u32,
    height: u32,
    stride: u32,
    bit: u32,
) {
    let (width, stride) = (width as usize, stride as usize);
    let mask = 1_u32 << bit;
    for y in 0..height as usize {
        let row_words = &plane[y * stride..];
        let row_values = &mut values[y * width..(y + 1) * width];
        for (x, value) in row_values.iter_mut().enumerate() {
            if (row_words[x / 32] >> (31 - x % 32)) & 1 != 0 {
                *value |= mask;
            }
        }
    }
}