use std::fmt::Debug;
use crate::ans::AnsState;
use crate::bit_reader::BitReader;
use crate::chunk_latent_decompressor::ChunkLatentDecompressor;
use crate::constants::{Bitlen, ANS_INTERLEAVING, FULL_BATCH_N};
use crate::data_types::Latent;
use crate::dyn_latent_slice::DynLatentSlice;
use crate::errors::PcoResult;
use crate::macros::define_latent_enum;
use crate::metadata::delta_encoding::LatentVarDeltaEncoding;
use crate::{bit_reader, delta};
/// Reads `n` variable-width offsets from `reader` and adds each one onto the
/// matching entry of `latents`.
///
/// For every `i < n`, the offset is `offset_bits[i]` bits wide and begins
/// `offset_bits_csum[i]` bits after the reader's current bit position. The
/// decoded value is combined with `latents[i]` via `wrapping_add`, so the
/// caller is expected to have pre-loaded `latents` with the base values.
/// Afterwards the reader is advanced to the end of the last offset, i.e.
/// `offset_bits_csum[n - 1] + offset_bits[n - 1]` bits past where it started.
///
/// `READ_BYTES` is forwarded unchanged to [`bit_reader::read_uint_at`].
///
/// When `n == 0` this is a no-op and the reader is left untouched.
///
/// # Panics
/// Panics if any of the three slices holds fewer than `n` elements.
///
/// # Safety
/// The caller must uphold whatever `bit_reader::read_uint_at` requires for
/// every read performed here.
/// NOTE(review): presumably `reader.src` must remain readable for
/// `READ_BYTES` bytes past each computed byte index — confirm against
/// `bit_reader`.
#[inline(never)]
unsafe fn read_offsets<L: Latent, const READ_BYTES: usize>(
  reader: &mut BitReader,
  offset_bits_csum: &[u32],
  offset_bits: &[u32],
  latents: &mut [L],
  n: usize,
) {
  // Without this guard the `n - 1` below would underflow.
  if n == 0 {
    return;
  }

  // Slice once so the loop needs no per-element bounds checks and a bad `n`
  // fails before anything is modified.
  let offset_bits_csum = &offset_bits_csum[..n];
  let offset_bits = &offset_bits[..n];
  let latents = &mut latents[..n];

  let base_bit_idx = reader.bit_idx();
  let src = reader.src;
  for ((latent, &n_bits), &bits_before) in latents.iter_mut().zip(offset_bits).zip(offset_bits_csum)
  {
    // Stay in `usize`: narrowing the absolute bit index to `Bitlen` would wrap
    // on large buffers and disagree with `final_bit_idx` computed below.
    let bit_idx = base_bit_idx + bits_before as usize;
    let offset = bit_reader::read_uint_at::<L, READ_BYTES>(
      src,
      bit_idx / 8,
      (bit_idx % 8) as Bitlen,
      n_bits,
    );
    *latent = (*latent).wrapping_add(offset);
  }

  let final_bit_idx = base_bit_idx + offset_bits_csum[n - 1] as usize + offset_bits[n - 1] as usize;
  reader.stale_byte_idx = final_bit_idx / 8;
  reader.bits_past_byte = (final_bit_idx % 8) as Bitlen;
}
// Declares a `#[used]` static holding a function pointer to one concrete
// instantiation of `read_offsets`, so that instantiation is always emitted
// even if nothing else references it.
// NOTE(review): presumably this exists so the generated code for each
// specialization can be inspected or benchmarked — confirm intent.
//
// Arguments: static name, latent type, `READ_BYTES` value.
macro_rules! force_export {
  ($static_name: ident, $latent: ty, $read_bytes: literal) => {
    #[used]
    static $static_name: unsafe fn(&mut BitReader, &[u32], &[u32], &mut [$latent], usize) =
      read_offsets::<$latent, $read_bytes>;
  };
}

// One entry per (latent type, READ_BYTES) pair pinned by this file.
force_export!(_FORCE_EXPORT_U8_4, u8, 4);
force_export!(_FORCE_EXPORT_U16_4, u16, 4);
force_export!(_FORCE_EXPORT_U32_4, u32, 4);
force_export!(_FORCE_EXPORT_U32_8, u32, 8);
force_export!(_FORCE_EXPORT_U64_8, u64, 8);
/// Decoding state for one latent variable of type `L` that is carried from
/// batch to batch: the interleaved ANS state indices and the working delta
/// state.
#[derive(Clone, Debug)]
pub struct PageLatentDecompressor<L: Latent> {
/// Current ANS state index for each of the `ANS_INTERLEAVING` interleaved
/// streams. Latent `i` of a batch uses slot `i % ANS_INTERLEAVING`; the
/// values are written back after every batch of symbols is read.
ans_state_idxs: [AnsState; ANS_INTERLEAVING],
/// Working delta buffer, built by `delta::new_buffer_and_pos` and handed
/// mutably to `delta::decode_in_place` on every batch.
delta_state: Vec<L>,
/// Position value paired with `delta_state`; also produced by
/// `delta::new_buffer_and_pos` and updated by `delta::decode_in_place`.
delta_state_pos: usize,
}
impl<L: Latent> PageLatentDecompressor<L> {
/// Builds the decompressor state for a page.
///
/// * `ans_final_state_idxs` - initial value of each interleaved ANS state
///   index.
/// * `delta_encoding` - delta encoding of this latent variable; together with
///   `stored_delta_state` it is passed to `delta::new_buffer_and_pos` to
///   build the working delta buffer and its starting position.
/// * `stored_delta_state` - delta state supplied by the caller.
pub fn new(
  ans_final_state_idxs: [AnsState; ANS_INTERLEAVING],
  delta_encoding: &LatentVarDeltaEncoding,
  stored_delta_state: Vec<L>,
) -> Self {
  let (delta_state, delta_state_pos) =
    delta::new_buffer_and_pos(delta_encoding, stored_delta_state);

  Self {
    ans_state_idxs: ans_final_state_idxs,
    delta_state,
    delta_state_pos,
  }
}
/// Decodes the ANS symbols for exactly one full batch of `FULL_BATCH_N`
/// latents.
///
/// For each latent `i` in the batch this writes into `cld.scratch`:
/// * `latents[i]` - the lower bound looked up from `cld.state_lowers`,
/// * `offset_bits[i]` - the bit width of that latent's offset,
/// * `offset_bits_csum[i]` - the total offset bits of all earlier latents in
///   the batch.
///
/// On return the reader position and `self.ans_state_idxs` reflect the ANS
/// bits consumed. This is the same work as `read_ans_symbols`, but with the
/// four interleaved states held in separate locals and the per-symbol step
/// expanded four times per loop iteration.
///
/// # Safety
/// All table and scratch accesses are unchecked. The caller must ensure that
/// every ANS state index reached is in bounds for both `cld.decoder.nodes`
/// and `cld.state_lowers`, and that the three scratch buffers hold at least
/// `FULL_BATCH_N` entries.
/// NOTE(review): `bit_reader::u64_at` presumably also needs 8 readable bytes
/// at each byte index used here — confirm against `bit_reader`.
#[inline(never)]
unsafe fn read_full_ans_symbols(
&mut self,
reader: &mut BitReader,
cld: &mut ChunkLatentDecompressor<L>,
) {
// Work on local copies of the reader position; written back at the end.
let src = reader.src;
let mut stale_byte_idx = reader.stale_byte_idx;
let mut bits_past_byte = reader.bits_past_byte;
// Running total of offset bits for the latents handled so far.
let mut offset_bit_idx = 0;
// This destructuring only compiles when ANS_INTERLEAVING == 4.
let [mut state_idx_0, mut state_idx_1, mut state_idx_2, mut state_idx_3] = self.ans_state_idxs;
let ans_nodes = cld.decoder.nodes.as_slice();
let lowers = cld.state_lowers.as_slice();
for base_i in (0..FULL_BATCH_N).step_by(ANS_INTERLEAVING) {
// Fold whole bytes out of the bit offset once per group, then load a single
// u64 that all four symbols of the group read from.
// NOTE(review): assumes four symbols' `bits_to_read` plus up to 7 leftover
// bits always fit in 64 bits — confirm against the maximum ANS table size.
stale_byte_idx += bits_past_byte as usize / 8;
bits_past_byte %= 8;
let packed = bit_reader::u64_at(src, stale_byte_idx);
// Per-symbol step: `$j` is the lane within the group and `$state_idx` is
// that lane's mutable state local.
macro_rules! handle_single_symbol {
($j: expr, $state_idx: ident) => {
let i = base_i + $j;
let node = unsafe { ans_nodes.get_unchecked($state_idx as usize) };
let bits_to_read = node.bits_to_read as Bitlen;
// Extract the next `bits_to_read` bits of the shared word.
let ans_val = (packed >> bits_past_byte) as AnsState & ((1 << bits_to_read) - 1);
let lower = unsafe { *lowers.get_unchecked($state_idx as usize) };
let offset_bits = node.offset_bits as Bitlen;
*cld.scratch.offset_bits_csum.get_unchecked_mut(i) = offset_bit_idx;
*cld.scratch.offset_bits.get_unchecked_mut(i) = offset_bits;
*cld.scratch.latents.get_unchecked_mut(i) = lower;
bits_past_byte += bits_to_read;
offset_bit_idx += offset_bits;
// Advance this lane to its next ANS state.
$state_idx = node.next_state_idx_base as AnsState + ans_val;
};
}
handle_single_symbol!(0, state_idx_0);
handle_single_symbol!(1, state_idx_1);
handle_single_symbol!(2, state_idx_2);
handle_single_symbol!(3, state_idx_3);
}
// Publish the advanced reader position and ANS states.
reader.stale_byte_idx = stale_byte_idx;
reader.bits_past_byte = bits_past_byte;
self.ans_state_idxs = [state_idx_0, state_idx_1, state_idx_2, state_idx_3];
}
/// Decodes the ANS symbols for the first `batch_n` latents of a batch.
///
/// For each latent `i < batch_n` this writes into `cld.scratch`:
/// * `latents[i]` - the lower bound looked up from `cld.state_lowers`,
/// * `offset_bits[i]` - the bit width of that latent's offset,
/// * `offset_bits_csum[i]` - the total offset bits of all earlier latents in
///   the batch.
///
/// Latent `i` uses ANS state slot `i % ANS_INTERLEAVING`. On return the
/// reader position and `self.ans_state_idxs` reflect the ANS bits consumed.
///
/// # Safety
/// All table and scratch accesses are unchecked. The caller must ensure that
/// every ANS state index reached is in bounds for both `cld.decoder.nodes`
/// and `cld.state_lowers`, and that the three scratch buffers hold at least
/// `batch_n` entries.
/// NOTE(review): `bit_reader::u64_at` presumably also needs 8 readable bytes
/// at each byte index used here — confirm against `bit_reader`.
#[inline(never)]
unsafe fn read_ans_symbols(
  &mut self,
  reader: &mut BitReader,
  batch_n: usize,
  cld: &mut ChunkLatentDecompressor<L>,
) {
  let src = reader.src;
  let mut byte_idx = reader.stale_byte_idx;
  let mut bit_shift = reader.bits_past_byte;
  let mut offset_bits_so_far: Bitlen = 0;
  let mut lanes = self.ans_state_idxs;

  // `lane` walks 0, 1, .., ANS_INTERLEAVING - 1, 0, 1, .. alongside `i`.
  for (i, lane) in (0..ANS_INTERLEAVING).cycle().take(batch_n).enumerate() {
    // Fold whole bytes out of the bit offset before loading the next word.
    byte_idx += bit_shift as usize / 8;
    bit_shift %= 8;
    let word = bit_reader::u64_at(src, byte_idx);

    let state = lanes[lane] as usize;
    let node = unsafe { cld.decoder.nodes.get_unchecked(state) };
    let n_ans_bits = node.bits_to_read as Bitlen;
    let n_offset_bits = node.offset_bits as Bitlen;

    // The next `n_ans_bits` bits of the word select the following state.
    let ans_mask: AnsState = (1 << n_ans_bits) - 1;
    let ans_val = (word >> bit_shift) as AnsState & ans_mask;

    unsafe {
      *cld.scratch.latents.get_unchecked_mut(i) = *cld.state_lowers.get_unchecked(state);
      *cld.scratch.offset_bits.get_unchecked_mut(i) = n_offset_bits;
      *cld.scratch.offset_bits_csum.get_unchecked_mut(i) = offset_bits_so_far;
    }

    lanes[lane] = node.next_state_idx_base as AnsState + ans_val;
    bit_shift += n_ans_bits;
    offset_bits_so_far += n_offset_bits;
  }

  // Publish the advanced reader position and ANS states.
  self.ans_state_idxs = lanes;
  reader.stale_byte_idx = byte_idx;
  reader.bits_past_byte = bit_shift;
}
pub unsafe fn read_batch_pre_delta(
&mut self,
reader: &mut BitReader,
batch_n: usize,
cld: &mut ChunkLatentDecompressor<L>,
) {
if batch_n == 0 {
return;
}
assert!(batch_n <= FULL_BATCH_N);
if cld.n_bins > 1 {
if batch_n == FULL_BATCH_N {
self.read_full_ans_symbols(reader, cld);
} else {
self.read_ans_symbols(reader, batch_n, cld);
}
} else {
cld.scratch.latents[..batch_n].fill(cld.state_lowers[0]);
}
macro_rules! specialized_read_offsets {
($rb: literal) => {
read_offsets::<L, $rb>(
reader,
&cld.scratch.offset_bits_csum.0,
&cld.scratch.offset_bits.0,
&mut cld.scratch.latents.0,
batch_n,
)
};
}
match (cld.bytes_per_offset, L::BITS) {
(0, _) => (),
(1..=4, 8) => specialized_read_offsets!(4),
(1..=4, 16) => specialized_read_offsets!(4),
(1..=4, 32) => specialized_read_offsets!(4),
(5..=8, 32) => specialized_read_offsets!(8),
(1..=8, 64) => specialized_read_offsets!(8),
(9..=15, 64) => specialized_read_offsets!(15),
_ => panic!(
"[PageLatentDecompressor] {} byte read not supported for {}-bit Latents",
cld.bytes_per_offset,
L::BITS
),
}
}
/// Decodes the next batch of latents into `cld.scratch.latents`, including
/// delta decoding.
///
/// The number of latents actually read from `reader` is
/// `n_remaining_in_page` minus `cld.delta_encoding.n_latents_per_state()`
/// (saturating at 0), capped at `FULL_BATCH_N`. Delta decoding is then
/// applied in place to the first `min(n_remaining_in_page, FULL_BATCH_N)`
/// scratch latents.
///
/// * `delta_latents` - forwarded unchanged to `delta::decode_in_place`.
/// * `n_remaining_in_page` - count of latents not yet produced for the page.
///
/// Returns whatever `delta::decode_in_place` returns.
///
/// # Safety
/// Same requirements as [`Self::read_batch_pre_delta`], which this calls.
pub unsafe fn read_batch(
  &mut self,
  reader: &mut BitReader,
  delta_latents: Option<DynLatentSlice>,
  n_remaining_in_page: usize,
  cld: &mut ChunkLatentDecompressor<L>,
) -> PcoResult<()> {
  let n_latents_in_state = cld.delta_encoding.n_latents_per_state();
  let pre_delta_len = n_remaining_in_page
    .saturating_sub(n_latents_in_state)
    .min(FULL_BATCH_N);
  self.read_batch_pre_delta(reader, pre_delta_len, cld);

  let batch_len = FULL_BATCH_N.min(n_remaining_in_page);
  delta::decode_in_place(
    &cld.delta_encoding,
    delta_latents,
    &mut self.delta_state_pos,
    &mut self.delta_state,
    &mut cld.scratch.latents[..batch_len],
  )
}
}
// Generates the public `DynPageLatentDecompressor` type from
// `PageLatentDecompressor`, passing an empty derive list.
// NOTE(review): presumably this is an enum with one variant per supported
// latent type; its exact shape is decided by `define_latent_enum!`, which is
// defined in `crate::macros` — confirm there.
define_latent_enum!(
#[derive()]
pub DynPageLatentDecompressor(PageLatentDecompressor)
);