use crate::{Bitdepth, Components, Marker, Predictor};
use alloc::{vec, vec::Vec};
use core::cmp::min;
/// Bit length of every value in `0..=255`: entry `v` is the number of bits
/// needed to represent `v` (so entry 0 is 0, entry 1 is 1, entry 255 is 8).
const NUM_BITS_TBL: [u16; 256] = build_num_bits_tbl();

/// Computes [`NUM_BITS_TBL`] at compile time.
const fn build_num_bits_tbl() -> [u16; 256] {
    let mut table = [0u16; 256];
    // Entry 0 keeps its initial value of 0.
    let mut value = 1usize;
    while value < 256 {
        // A 16-bit value needs `16 - leading_zeros` bits.
        table[value] = 16 - (value as u16).leading_zeros() as u16;
        value += 1;
    }
    table
}
/// Returns the difference category (SSSS) of `diff`: the number of bits
/// needed to represent its magnitude, from 0 (for 0) up to 16 (for `i16::MIN`).
const fn lookup_ssss(diff: i16) -> u16 {
    // Widen first so that `i16::MIN` has a representable magnitude.
    let magnitude = (diff as i32).unsigned_abs() as usize;
    let high_byte = magnitude >> 8;
    if high_byte == 0 {
        NUM_BITS_TBL[magnitude]
    } else {
        // The low byte is fully used; count the bits of the high byte on top.
        8 + NUM_BITS_TBL[high_byte & 0xFF]
    }
}
/// Configuration of a lossless JPEG (SOF3) encoder for one image geometry.
///
/// The input image is a slice of `u16` samples laid out row by row; every row
/// holds `(width + padding) * components` samples.
#[derive(Clone, Copy, Debug)]
pub struct Encoder {
/// Image width in pixels (written to the frame header).
width: usize,
/// Image height in rows (written to the frame header).
height: usize,
/// Number of interleaved components per pixel (1 to 4 are handled).
components: usize,
/// Sample precision in bits, written to the frame header.
bitdepth: u8,
/// Right shift applied to every sample before prediction.
point_transform: u8,
/// Predictor used for every row after the first one.
predictor: Predictor,
/// Extra pixels at the end of each input row that are not encoded.
padding: usize,
}
/// One generated Huffman code: its value and its length.
#[derive(Debug, Default, Clone)]
struct HuffCode {
/// Code value, held in the low `bits` bits.
enc: u16,
/// Code length in bits; a length of 0 terminates the generated code list.
bits: u16,
}
/// Builds a Huffman code table for the 17 difference categories (SSSS 0..=16)
/// of one component from that component's histogram.
///
/// NOTE(review): the steps look like the code-size / code-table generation
/// procedures of ITU-T T.81 Annex K.2 — confirm against the specification.
#[derive(Default, Debug)]
struct HuffTableBuilder {
/// Relative frequency per symbol. The extra last slot (index `CLASSES`) is a
/// reserved pseudo-symbol that `new` sets to 1.0.
freq: [f32; Self::CLASSES + 1],
/// Code length accumulated for each symbol (including the reserved slot).
codesize: [usize; Self::CLASSES + 1],
/// Link to the next symbol of the same merged group, if any.
others: [Option<usize>; Self::CLASSES + 1],
/// Number of codes of each length; the index is the code length
/// (`new` sizes it to 33 entries, i.e. lengths 0..=32).
bits: Vec<u8>,
/// Symbols listed in order of increasing code length.
huffval: [Option<u8>; Self::CLASSES],
/// Generated codes in the same order as `huffval`, terminated by an entry
/// whose `bits` is 0.
huffcode: [HuffCode; Self::CLASSES + 1],
/// For each symbol, its position in `huffval` / `huffcode` (if it has a code).
huffsym: [Option<usize>; Self::CLASSES],
}
impl HuffTableBuilder {
/// Number of real symbols (SSSS categories 0..=16).
const CLASSES: usize = 17;
/// Creates a builder from a symbol `histogram`; every count is divided by
/// `resolution` to obtain the frequency used for the code-size computation.
fn new(histogram: [usize; Self::CLASSES], resolution: f32) -> Self {
let mut ins = Self::default();
// Room for code lengths 0..=32 before they are limited to 16.
ins.bits.resize(33, 0);
for (i, freq) in histogram.iter().map(|f| *f as f32 / resolution).enumerate() {
ins.freq[i] = freq;
}
// Reserved pseudo-symbol; `adjust_bits` later removes one code of the longest
// length again. NOTE(review): presumably this keeps the all-ones code unused
// for real symbols — confirm.
ins.freq[Self::CLASSES] = 1.0;
ins
}
/// Computes a code length for every symbol with a nonzero frequency by
/// repeatedly merging the two least frequent entries.
fn gen_codesizes(&mut self) {
loop {
// Start above any frequency that can occur.
// NOTE(review): assumes all frequencies (and their running sums) stay <= 3.0,
// i.e. the histogram is normalised by `resolution` — confirm with callers.
let mut v1freq: f32 = 3.0; let mut v2freq: f32 = 3.0;
let mut v1: Option<usize> = None;
let mut v2: Option<usize> = None;
// v1: entry with the smallest nonzero frequency; `<=` makes the highest
// index win among equal frequencies.
for (i, f) in self.freq.iter().enumerate().filter(|(_i, f)| **f > 0.0) {
if *f <= v1freq {
v1freq = *f;
v1 = Some(i);
}
}
// v2: the next smallest nonzero frequency, excluding v1.
for (i, f) in self
.freq
.iter()
.enumerate()
.filter(|(i, f)| **f > 0.0 && Some(*i) != v1)
{
if *f <= v2freq {
v2freq = *f;
v2 = Some(i);
}
}
match (&mut v1, &mut v2) {
(Some(v1), Some(v2)) => {
// Merge v2 into v1.
self.freq[*v1] += self.freq[*v2];
self.freq[*v2] = 0.0;
// Every symbol already chained to v1 gets one bit longer;
// afterwards v1 points at the end of its chain.
loop {
self.codesize[*v1] += 1;
if let Some(other) = self.others[*v1] {
*v1 = other
} else {
break;
}
}
// Append v2's chain to the end of v1's chain.
self.others[*v1] = Some(*v2);
// Every symbol chained to v2 gets one bit longer as well.
loop {
self.codesize[*v2] += 1;
if let Some(other) = self.others[*v2] {
*v2 = other;
} else {
break;
}
}
}
_ => {
// Fewer than two nonzero entries remain: done.
break; }
}
}
}
/// Counts how many codes exist for each code length, then limits the
/// lengths via `adjust_bits`.
fn count_bits(&mut self) {
// 18 = CLASSES + 1: all real symbols plus the reserved slot.
for i in 0..18 {
if self.codesize[i] > 0 {
self.bits[self.codesize[i]] += 1;
}
}
self.adjust_bits();
}
/// Fills `huffval` with the real symbols (0..=16) ordered by increasing
/// code size; the reserved slot 17 is not listed.
fn sort_input(&mut self) {
let mut k = 0;
for i in 1..=32 {
for j in 0..=16 {
if self.codesize[j] == i {
self.huffval[k] = Some(j as u8);
k += 1;
}
}
}
}
/// Expands the per-length counts in `bits[1..=16]` into one code-length entry
/// per code in `huffcode`, terminates the list with length 0 and returns the
/// number of codes.
fn gen_size_table(&mut self) -> usize {
let mut k = 0;
let mut i = 1;
while i <= 16 {
let mut j = 1;
while j <= self.bits[i] {
self.huffcode[k].bits = i as u16;
j += 1;
k += 1;
}
i += 1;
}
// Terminator entry.
self.huffcode[k].bits = 0;
k
}
/// Assigns code values to the entries prepared by `gen_size_table`:
/// consecutive values within one length, shifted left whenever the length grows.
fn gen_code_table(&mut self) {
let mut k = 0;
let mut code = 0;
let mut si = self.huffcode[0].bits;
loop {
// Hand out consecutive codes while the length stays the same.
loop {
self.huffcode[k].enc = code;
code += 1;
k += 1;
if self.huffcode[k].bits != si {
break;
}
}
// Length 0 is the terminator written by `gen_size_table`.
if self.huffcode[k].bits == 0 {
break;
}
// Grow the code until it has the next entry's length.
loop {
code <<= 1;
si += 1;
if self.huffcode[k].bits == si {
break;
}
}
}
}
/// Records for every symbol its position in `huffval` (and thereby in
/// `huffcode`). `_lastk` is not used.
fn order_codes(&mut self, _lastk: usize) {
for (i, ssss) in self.huffval.iter().enumerate() {
if let Some(ssss) = ssss {
self.huffsym[*ssss as usize] = Some(i);
}
}
}
/// Limits code lengths to 16 bits by redistributing longer codes, then removes
/// one code of the longest remaining length.
fn adjust_bits(&mut self) {
let mut i = 32;
while i > 16 {
if self.bits[i] > 0 {
// Find the nearest shorter length (at most i - 2) that has a code.
let mut j = i - 2; while self.bits[j] == 0 {
j -= 1;
}
// Take two codes of length i: one becomes length i - 1, and together
// with one code of length j they are replaced by two codes of length j + 1.
self.bits[i] -= 2;
self.bits[i - 1] += 1;
self.bits[j + 1] += 2;
self.bits[j] -= 1;
} else {
i -= 1;
}
}
// Drop one code of the longest length in use.
// NOTE(review): presumably the code belonging to the reserved slot — confirm.
while self.bits[i] == 0 {
i -= 1;
}
self.bits[i] -= 1;
}
/// Runs all generation steps and returns the final table: entry `ssss` holds
/// the code for that category, or an empty `BitArray16` if it has none.
#[allow(clippy::needless_range_loop)]
fn build(mut self) -> [BitArray16; HuffTableBuilder::CLASSES] {
self.gen_codesizes();
self.count_bits();
self.sort_input();
let lastk = self.gen_size_table();
self.gen_code_table();
self.order_codes(lastk);
let mut table = [BitArray16::default(); HuffTableBuilder::CLASSES];
for ssss in 0..=16 {
if let Some(code) = self.huffsym[ssss] {
let enc = &self.huffcode[code];
table[ssss] = BitArray16::from_lsb(enc.bits as usize, enc.enc);
}
}
table
}
}
/// Per-component encoder state: symbol statistics and the derived code table.
#[derive(Default, Clone, Copy, Debug)]
struct ComponentState {
/// Occurrence count of every difference category (SSSS 0..=16).
histogram: [usize; 17],
/// Huffman code for every difference category, indexed by SSSS.
hufftable: [BitArray16; HuffTableBuilder::CLASSES],
}
/// MSB-first bit writer that appends to a byte buffer and inserts a `0x00`
/// stuffing byte after every emitted `0xFF`.
struct BitstreamJPEG<'a> {
    /// Destination buffer.
    inner: &'a mut Vec<u8>,
    /// Byte currently being assembled, filled from the most significant bit.
    next: u8,
    /// Number of bits of `next` that are already occupied (0..=8).
    used: usize,
}

impl<'a> BitstreamJPEG<'a> {
    /// Creates a writer with an empty pending byte.
    #[inline]
    fn new(inner: &'a mut Vec<u8>) -> Self {
        let (next, used) = (0, 0);
        Self { inner, next, used }
    }

    /// Appends the low `bits` bits of `value`, most significant bit first.
    #[inline]
    fn write(&mut self, mut bits: usize, value: u64) {
        while bits != 0 {
            // A full pending byte is only emitted once more bits arrive.
            if self.used == 8 {
                self.internal_flush();
            }
            let room = 8 - self.used;
            let chunk = min(bits, room);
            bits -= chunk;
            // Cut the next `chunk` bits out of `value`, highest first.
            let mask = (1u64 << chunk) - 1;
            let piece = ((value >> bits) & mask) as u8;
            self.next |= piece << (room - chunk);
            self.used += chunk;
        }
    }

    /// Emits the pending byte (plus a stuffing byte after `0xFF`) and resets
    /// the pending state.
    #[inline]
    fn internal_flush(&mut self) {
        let byte = self.next;
        self.next = 0;
        self.used = 0;
        self.inner.push(byte);
        if byte == 0xFF {
            self.inner.push(0x00);
        }
    }

    /// Emits a partially filled pending byte, padded with zero bits.
    /// Does nothing when no bits are pending.
    fn flush(&mut self) {
        if self.used != 0 {
            self.internal_flush();
        }
    }
}
impl Encoder {
pub fn new(
width: u16,
height: u16,
components: Components,
bitdepth: Bitdepth,
predictor: Predictor,
point_transform: u8,
padding: usize,
) -> Self {
let width = width as usize;
let height = height as usize;
Self {
width,
height,
components: components as usize,
bitdepth: bitdepth as u8,
point_transform,
predictor,
padding,
}
}
pub fn encode(&self, image: &[u16]) -> Option<Vec<u8>> {
if image.len() < self.height * ((self.width + self.padding) * self.components) {
return None;
}
let mut encoded = Vec::with_capacity(2 * self.width * self.height * self.components);
let (mut comp_state, cache) = self.scan_frequency(image);
for comp in comp_state.iter_mut().take(self.components) {
self.build_hufftable(comp);
}
self.write_header(&comp_state, &mut encoded);
self.write_body(&comp_state, &cache, &mut encoded);
self.write_post(&mut encoded);
Some(encoded)
}
fn scan_frequency(&self, image: &[u16]) -> ([ComponentState; 4], Vec<i16>) {
let mut comp_state = [ComponentState::default(); 4];
let mut cache = vec![0; self.width * self.height * self.components];
let rowsize = self.width * self.components;
let linesize = (self.width + self.padding) * self.components;
let mut row_prev = &image[0..];
let mut row_curr = &image[0..];
let mut diffs = vec![0_i16; linesize];
macro_rules! match_predictor {
($comp:expr, $pred:expr) => {
match $pred {
Predictor::P1 => ljpeg92_diff::<$comp, 1>(
row_prev,
row_curr,
&mut diffs,
linesize,
self.point_transform,
self.bitdepth,
),
Predictor::P2 => ljpeg92_diff::<$comp, 2>(
row_prev,
row_curr,
&mut diffs,
linesize,
self.point_transform,
self.bitdepth,
),
Predictor::P3 => ljpeg92_diff::<$comp, 3>(
row_prev,
row_curr,
&mut diffs,
linesize,
self.point_transform,
self.bitdepth,
),
Predictor::P4 => ljpeg92_diff::<$comp, 4>(
row_prev,
row_curr,
&mut diffs,
linesize,
self.point_transform,
self.bitdepth,
),
Predictor::P5 => ljpeg92_diff::<$comp, 5>(
row_prev,
row_curr,
&mut diffs,
linesize,
self.point_transform,
self.bitdepth,
),
Predictor::P6 => ljpeg92_diff::<$comp, 6>(
row_prev,
row_curr,
&mut diffs,
linesize,
self.point_transform,
self.bitdepth,
),
Predictor::P7 => ljpeg92_diff::<$comp, 7>(
row_prev,
row_curr,
&mut diffs,
linesize,
self.point_transform,
self.bitdepth,
),
}
};
}
for row in 0..self.height {
match self.components {
1 => match_predictor!(1, self.predictor),
2 => match_predictor!(2, self.predictor),
3 => match_predictor!(3, self.predictor),
4 => match_predictor!(4, self.predictor),
_ => unreachable!(),
}
cache[row * rowsize..row * rowsize + rowsize].copy_from_slice(&diffs[..rowsize]);
for (i, diff) in diffs.iter().take(rowsize).enumerate() {
let comp = i % self.components;
let ssss = lookup_ssss(*diff);
comp_state[comp].histogram[ssss as usize] += 1;
}
row_prev = row_curr;
row_curr = &row_curr[linesize..];
}
(comp_state, cache)
}
#[inline]
fn build_hufftable(&self, comp: &mut ComponentState) {
let huffgen = HuffTableBuilder::new(comp.histogram, (self.width * self.height) as f32);
let table = huffgen.build();
comp.hufftable = table;
}
#[inline]
fn write_header(&self, comp_state: &[ComponentState; 4], encoded: &mut Vec<u8>) {
write_marker(encoded, Marker::SOI);
write_marker(encoded, Marker::SOF3);
write_u16(encoded, 2 + 6 + self.components as u16 * 3); encoded.push(self.bitdepth); write_u16(encoded, self.height as u16);
write_u16(encoded, self.width as u16);
encoded.push(self.components as u8); for c in 0..self.components {
encoded.push(c as u8); encoded.push(0x11); encoded.push(0); }
for (i, comp) in comp_state.iter().enumerate().take(self.components) {
write_marker(encoded, Marker::DHT);
let bit_sum: u16 = comp.hufftable.iter().filter(|e| !e.is_empty()).count() as u16;
write_u16(encoded, 2 + (1 + 16) + bit_sum); encoded.push(i as u8);
for bit_len in 1..=16 {
let count = comp
.hufftable
.iter()
.filter(|entry| entry.len() == bit_len)
.count();
encoded.push(count as u8);
}
for bit_len in 1..=16 {
let mut codes: Vec<(u16, BitArray16)> = comp
.hufftable
.iter()
.enumerate()
.filter(|(_, code)| code.len() == bit_len)
.map(|(ssss, code)| (ssss as u16, *code))
.collect();
codes.sort_by(|a, b| a.1.cmp(&b.1));
for (ssss, _) in codes.iter() {
encoded.push(*ssss as u8);
}
}
}
write_marker(encoded, Marker::SOS);
write_u16(encoded, 0x0006 + (self.components as u16 * 2)); encoded.push(self.components as u8); for c in 0..self.components {
encoded.push(c as u8); encoded.push((c as u8) << 4); }
encoded.push(self.predictor as u8); encoded.push(0); debug_assert!(self.point_transform <= 15);
encoded.push(self.point_transform & 0xF); }
#[inline]
fn write_post(&self, encoded: &mut Vec<u8>) {
write_marker(encoded, Marker::EOI);
}
#[inline]
fn write_body(&self, comp_state: &[ComponentState; 4], cache: &[i16], encoded: &mut Vec<u8>) {
let mut bitstream = BitstreamJPEG::new(encoded);
for (i, diff) in cache.iter().enumerate() {
let comp = i % self.components;
let ssss = lookup_ssss(*diff);
let enc = comp_state[comp].hufftable[ssss as usize];
let (bits, value) = (enc.len(), enc.get_lsb() as u64);
debug_assert!(bits > 0);
bitstream.write(bits, value);
debug_assert!(ssss <= 16);
if (ssss & 15) != 0 {
let diff = if *diff < 0 {
*diff as i32 - 1
} else {
*diff as i32
};
bitstream.write(ssss as usize, (diff & (0x0FFFF >> (16 - ssss))) as u64);
}
}
bitstream.flush();
}
}
/// Computes the prediction differences of one image row.
///
/// * `row_prev` / `row_curr` – the previous and the current row; both must
///   hold at least `linesize` samples. The first image row is signalled by
///   passing slices that start at the same address.
/// * `diffs` – output, at least `linesize` entries; entry `i` receives
///   `sample - prediction` truncated to 16 bits.
/// * `linesize` – number of samples to process (a multiple of `NCOMP`).
/// * `point_transform` – right shift applied to every sample before use.
/// * `bitdepth` – sample precision; the very first pixel of the image is
///   predicted as `1 << (bitdepth - point_transform - 1)`.
///
/// `NCOMP` is the number of interleaved components, `PX` the predictor
/// (1..=7) used for all rows but the first. On the first row every pixel
/// after the first is predicted from its left neighbour; on later rows the
/// first pixel is predicted from the sample above it.
///
/// A row without any samples is a no-op.
///
/// # Panics
/// Panics if one of the slices is shorter than `linesize`. In debug builds
/// it also panics when a sample exceeds the range implied by `bitdepth`.
#[allow(clippy::needless_range_loop)]
fn ljpeg92_diff<const NCOMP: usize, const PX: u8>(
    row_prev: &[u16],
    row_curr: &[u16],
    diffs: &mut [i16],
    linesize: usize,
    point_transform: u8,
    bitdepth: u8,
) {
    debug_assert_eq!(linesize % NCOMP, 0);
    let pixels = linesize / NCOMP;
    let samplecnt = pixels * NCOMP;
    // Narrow all three slices once; every index used below is < samplecnt,
    // which is what the unchecked accesses in `pred_*` rely on.
    let row_prev = &row_prev[..samplecnt];
    let row_curr = &row_curr[..samplecnt];
    let diffs = &mut diffs[..samplecnt];

    // Without a single pixel there is no first-pixel sample to read; the
    // first-pixel loops below would otherwise access index 0 of empty slices.
    if samplecnt == 0 {
        return;
    }

    #[cfg(debug_assertions)]
    row_curr.iter().for_each(|sample| {
        let max_value = ((1u32 << (bitdepth - point_transform)) - 1) as u16;
        if (*sample >> point_transform) > max_value {
            panic!(
                "Sample overflow, sample is {:#x} but max value is {:#x}",
                sample, max_value
            );
        }
    });

    if row_curr.as_ptr() == row_prev.as_ptr() {
        // First row: there is no row above.
        for comp in 0..NCOMP {
            // The very first pixel is predicted as half of the sample range.
            let px = (1u16 << (bitdepth - point_transform - 1)) as i32;
            let sample = pred_x::<NCOMP>(row_prev, row_curr, comp, point_transform);
            diffs[comp] = (sample - px) as i16;
        }
        for idx in NCOMP..samplecnt {
            // The remaining pixels are predicted from the left neighbour.
            let px = pred_a::<NCOMP>(row_prev, row_curr, idx, point_transform);
            let sample = pred_x::<NCOMP>(row_prev, row_curr, idx, point_transform);
            diffs[idx] = (sample - px) as i16;
        }
    } else {
        for comp in 0..NCOMP {
            // The first pixel of a row has no left neighbour: use the sample above.
            let px = pred_b::<NCOMP>(row_prev, row_curr, comp, point_transform);
            let sample = pred_x::<NCOMP>(row_prev, row_curr, comp, point_transform);
            diffs[comp] = (sample - px) as i16;
        }
        // Ra = left, Rb = above, Rc = above-left.
        let predictor = match PX {
            1 => pred_a::<NCOMP>,
            2 => pred_b::<NCOMP>,
            3 => pred_c::<NCOMP>,
            4 => |prev: &[u16], curr: &[u16], idx: usize, pt: u8| -> i32 {
                let ra = pred_a::<NCOMP>(prev, curr, idx, pt);
                let rb = pred_b::<NCOMP>(prev, curr, idx, pt);
                let rc = pred_c::<NCOMP>(prev, curr, idx, pt);
                ra + rb - rc
            },
            5 => |prev: &[u16], curr: &[u16], idx: usize, pt: u8| -> i32 {
                let ra = pred_a::<NCOMP>(prev, curr, idx, pt);
                let rb = pred_b::<NCOMP>(prev, curr, idx, pt);
                let rc = pred_c::<NCOMP>(prev, curr, idx, pt);
                ra + ((rb - rc) >> 1)
            },
            6 => |prev: &[u16], curr: &[u16], idx: usize, pt: u8| -> i32 {
                let ra = pred_a::<NCOMP>(prev, curr, idx, pt);
                let rb = pred_b::<NCOMP>(prev, curr, idx, pt);
                let rc = pred_c::<NCOMP>(prev, curr, idx, pt);
                rb + ((ra - rc) >> 1)
            },
            7 => |prev: &[u16], curr: &[u16], idx: usize, pt: u8| -> i32 {
                let ra = pred_a::<NCOMP>(prev, curr, idx, pt);
                let rb = pred_b::<NCOMP>(prev, curr, idx, pt);
                (ra + rb) >> 1
            },
            _ => unreachable!(),
        };
        for idx in NCOMP..samplecnt {
            let px = predictor(row_prev, row_curr, idx, point_transform);
            let sample = pred_x::<NCOMP>(row_prev, row_curr, idx, point_transform);
            diffs[idx] = (sample - px) as i16;
        }
    }
}
/// Returns the current sample X (`curr[idx]`) after the point transform.
///
/// The caller must guarantee `idx < curr.len()`; this is only checked in
/// debug builds.
#[inline(always)]
fn pred_x<const NCOMP: usize>(_prev: &[u16], curr: &[u16], idx: usize, point_transform: u8) -> i32 {
    debug_assert!(idx < curr.len());
    // SAFETY: the caller guarantees `idx < curr.len()` (asserted above in
    // debug builds).
    unsafe { (curr.get_unchecked(idx) >> point_transform) as i32 }
}
/// Returns the left neighbour Ra (`curr[idx - NCOMP]`) after the point transform.
///
/// The caller must guarantee `NCOMP <= idx < curr.len()`; this is only
/// checked in debug builds.
#[inline(always)]
fn pred_a<const NCOMP: usize>(_prev: &[u16], curr: &[u16], idx: usize, point_transform: u8) -> i32 {
    debug_assert!(idx >= NCOMP && idx < curr.len());
    // SAFETY: the caller guarantees `NCOMP <= idx < curr.len()` (asserted
    // above in debug builds), so `idx - NCOMP` is a valid index.
    unsafe { (curr.get_unchecked(idx - NCOMP) >> point_transform) as i32 }
}
/// Returns the sample above Rb (`prev[idx]`) after the point transform.
///
/// The caller must guarantee `idx < prev.len()`; this is only checked in
/// debug builds.
#[inline(always)]
fn pred_b<const NCOMP: usize>(prev: &[u16], _curr: &[u16], idx: usize, point_transform: u8) -> i32 {
    debug_assert!(idx < prev.len());
    // SAFETY: the caller guarantees `idx < prev.len()` (asserted above in
    // debug builds).
    unsafe { (prev.get_unchecked(idx) >> point_transform) as i32 }
}
/// Returns the above-left neighbour Rc (`prev[idx - NCOMP]`) after the point
/// transform.
///
/// The caller must guarantee `NCOMP <= idx < prev.len()`; this is only
/// checked in debug builds.
#[inline(always)]
fn pred_c<const NCOMP: usize>(prev: &[u16], _curr: &[u16], idx: usize, point_transform: u8) -> i32 {
    debug_assert!(idx >= NCOMP && idx < prev.len());
    // SAFETY: the caller guarantees `NCOMP <= idx < prev.len()` (asserted
    // above in debug builds), so `idx - NCOMP` is a valid index.
    unsafe { (prev.get_unchecked(idx - NCOMP) >> point_transform) as i32 }
}
/// Appends `n` to `buf` in big-endian byte order (high byte first).
#[inline(always)]
fn write_u16(buf: &mut Vec<u8>, n: u16) {
    let [high, low] = n.to_be_bytes();
    buf.push(high);
    buf.push(low);
}
/// Appends the two-byte marker `m` to `buf`: the `0xff` prefix followed by
/// the marker code.
#[inline(always)]
fn write_marker(buf: &mut Vec<u8>, m: Marker) {
    let code = m as u8;
    buf.push(0xff);
    buf.push(code);
}
/// A sequence of up to 16 bits, stored left-aligned in a `u16`.
///
/// The derived ordering compares the left-aligned bits first and the length
/// second, so codes of equal length sort by code value.
#[derive(Debug, Clone, Copy, Default, Eq, Ord, PartialEq, PartialOrd)]
struct BitArray16 {
    /// The bits, aligned to the most significant end.
    storage: u16,
    /// Number of valid bits (0..=16).
    nbits: usize,
}

impl BitArray16 {
    /// Returns the number of valid bits.
    fn len(&self) -> usize {
        self.nbits
    }

    /// Returns `true` if the array holds no bits.
    fn is_empty(&self) -> bool {
        self.nbits == 0
    }

    /// Returns the bits right-aligned in the low `len()` bits.
    ///
    /// An empty array yields 0.
    fn get_lsb(&self) -> u16 {
        // Shifting a `u16` by 16 overflows, so the empty (default) array
        // needs its own branch.
        if self.nbits == 0 {
            0
        } else {
            self.storage >> (16 - self.nbits)
        }
    }

    /// Creates an array from the low `nbits` bits of `value`.
    ///
    /// `nbits` must not exceed 16. With `nbits == 0` the result is the empty
    /// array and `value` is ignored.
    fn from_lsb(nbits: usize, value: u16) -> Self {
        // Same overflow guard as in `get_lsb`; it also keeps stale bits out
        // of `storage` so that empty arrays compare equal.
        let storage = if nbits == 0 { 0 } else { value << (16 - nbits) };
        Self { storage, nbits }
    }
}
#[cfg(test)]
mod tests {
    use super::BitstreamJPEG;
    use alloc::vec::Vec;

    /// Checks bit packing across byte boundaries, zero padding on flush and
    /// the `0x00` stuffing byte after every `0xFF`.
    #[test]
    fn bitstream_test() {
        let mut buf = Vec::new();
        let mut bs = BitstreamJPEG::new(&mut buf);

        // A single bit, padded with zeros on flush.
        bs.write(1, 0b1);
        bs.flush();

        // 1 + 3 + 4 bits fill one byte exactly; the 2 extra bits start the next.
        bs.write(1, 0b0);
        bs.write(3, 0b101);
        bs.write(4, 0b11111101);
        bs.write(2, 0b101);
        bs.flush();

        // Two 0xFF bytes, each followed by a stuffing byte.
        bs.write(16, 0b1111111111111111);
        bs.flush();

        bs.write(16, 0b0);
        bs.flush();

        let expected: [u8; 8] = [
            0b10000000, 0b01011101, 0b01000000, 0xFF, 0x00, 0xFF, 0x00, 0x00,
        ];
        for (idx, want) in expected.iter().enumerate() {
            assert_eq!(buf[idx], *want);
        }
    }
}