use std::cmp::min;
use std::ops::Range;
use better_io::BetterBufRead;
use crate::bit_reader::BitReaderBuilder;
use crate::chunk_latent_decompressor::DynChunkLatentDecompressor;
use crate::constants::{FULL_BATCH_N, MAX_BATCH_LATENT_VAR_SIZE, OVERSHOOT_PADDING};
use crate::data_types::number_priv::NumberPriv;
use crate::data_types::Number;
use crate::dyn_slices::{DynLatentSlice, DynNumberSliceMut};
use crate::errors::{PcoError, PcoResult};
use crate::macros::{match_latent_enum, match_number_enum};
use crate::metadata::page::PageMeta;
use crate::metadata::per_latent_var::PerLatentVar;
use crate::page_latent_decompressor::{DynPageLatentDecompressor, PageLatentDecompressor};
use crate::progress::Progress;
use crate::wrapped::chunk_decompressor::ChunkDecompressorInner;
use crate::wrapped::ChunkDecompressor;
pub(crate) struct PageDecompressorState<R: BetterBufRead> {
reader_builder: BitReaderBuilder<R>,
n_remaining: usize,
latent_decompressors: PerLatentVar<DynPageLatentDecompressor>,
}
pub struct PageDecompressor<'a, T: Number, R: BetterBufRead> {
cd: &'a mut ChunkDecompressor<T>,
state: PageDecompressorState<R>,
}
fn make_latent_decompressors(
cd: &ChunkDecompressorInner,
page_meta: &PageMeta,
n: usize,
) -> PcoResult<PerLatentVar<DynPageLatentDecompressor>> {
let n_in_body = n.saturating_sub(cd.n_latents_per_delta_state());
cd.per_latent_var
.as_ref()
.zip_exact(page_meta.per_latent_var.as_ref())
.map_result(|_key, (dyn_cld, page_latent_var_meta)| {
let state = match_latent_enum!(
&dyn_cld,
DynChunkLatentDecompressor<L>(cld) => {
let delta_state = page_latent_var_meta
.delta_state
.downcast_ref::<L>()
.unwrap()
.clone();
if cld.n_bins == 0 && n_in_body > 0 {
return Err(PcoError::corruption(format!(
"unable to decompress chunk with no bins and {} latents",
n_in_body
)));
}
let pld = PageLatentDecompressor::new(
page_latent_var_meta.ans_final_state_idxs,
&cld.delta_encoding,
delta_state,
);
DynPageLatentDecompressor::new(pld)
}
);
Ok(state)
})
}
impl<R: BetterBufRead> PageDecompressorState<R> {
pub(crate) fn new(src: R, cd: &ChunkDecompressorInner, n: usize) -> PcoResult<Self> {
let mut reader_builder = BitReaderBuilder::new(src);
let page_meta = reader_builder.with_reader(
cd.meta.exact_page_meta_size() + OVERSHOOT_PADDING,
|reader| unsafe { PageMeta::read_from(reader, &cd.meta) },
)?;
let latent_decompressors = make_latent_decompressors(cd, &page_meta, n)?;
Ok(Self {
reader_builder,
n_remaining: n,
latent_decompressors,
})
}
}
fn read_primary_or_secondary<'a, R: BetterBufRead>(
reader_builder: &mut BitReaderBuilder<R>,
delta_latents: Option<DynLatentSlice>,
n_remaining: usize,
dyn_cld: &'a mut DynChunkLatentDecompressor,
dyn_pld: &'a mut DynPageLatentDecompressor,
) -> PcoResult<DynLatentSlice<'a>> {
reader_builder.with_reader(MAX_BATCH_LATENT_VAR_SIZE, |reader| unsafe {
match_latent_enum!(
dyn_pld,
DynPageLatentDecompressor<L>(pld) => {
let cld = dyn_cld.downcast_mut::<L>().unwrap();
pld.read_batch(
reader,
delta_latents,
n_remaining,
cld,
)
}
)
})?;
Ok(dyn_cld.latents())
}
impl<R: BetterBufRead> PageDecompressorState<R> {
fn read_batch(
&mut self,
cd: &mut ChunkDecompressorInner,
range: Range<usize>,
dst: &mut DynNumberSliceMut,
) -> PcoResult<()> {
let batch_n = range.len();
let n_remaining = self.n_remaining;
if let Some(dyn_pld) = self.latent_decompressors.delta.as_mut() {
let limit = min(
n_remaining.saturating_sub(cd.n_latents_per_delta_state()),
batch_n,
);
self
.reader_builder
.with_reader(MAX_BATCH_LATENT_VAR_SIZE, |reader| unsafe {
match_latent_enum!(
dyn_pld,
DynPageLatentDecompressor<L>(pld) => {
pld.read_batch_pre_delta(
reader,
limit,
cd.per_latent_var.delta.as_mut().unwrap().downcast_mut::<L>().unwrap(),
)
}
);
Ok(())
})?;
}
let primary = read_primary_or_secondary(
&mut self.reader_builder,
cd.per_latent_var.delta.as_mut().map(|cld| cld.latents()),
n_remaining,
&mut cd.per_latent_var.primary,
&mut self.latent_decompressors.primary,
)?;
let secondary = match self.latent_decompressors.secondary.as_mut() {
Some(dyn_pld) => Some(read_primary_or_secondary(
&mut self.reader_builder,
cd.per_latent_var.delta.as_mut().map(|cld| cld.latents()),
n_remaining,
cd.per_latent_var.secondary.as_mut().unwrap(),
dyn_pld,
)?),
None => None,
};
match_number_enum!(
dst,
DynNumberSliceMut<T>(dst) => {
T::join_latents(
&cd.meta.mode,
primary,
secondary,
&mut dst[range],
)?;
}
);
self.n_remaining -= batch_n;
if self.n_remaining == 0 {
self.reader_builder.with_reader(1, |reader| {
reader.drain_empty_byte("expected trailing bits at end of page to be empty")
})?;
}
Ok(())
}
pub fn read(
&mut self,
cd: &mut ChunkDecompressorInner,
mut dst: DynNumberSliceMut,
) -> PcoResult<Progress> {
let n_remaining = self.n_remaining;
let dst_len = dst.len();
if !dst_len.is_multiple_of(FULL_BATCH_N) && dst_len < n_remaining {
return Err(PcoError::invalid_argument(format!(
"num_dst's length must either be a multiple of {} or be \
at least the count of numbers remaining ({} < {})",
FULL_BATCH_N, dst_len, n_remaining,
)));
}
let n_to_process = min(dst_len, self.n_remaining);
let mut n_processed = 0;
while n_processed < n_to_process {
let dst_batch_end = min(n_processed + FULL_BATCH_N, n_to_process);
self.read_batch(cd, n_processed..dst_batch_end, &mut dst)?;
n_processed = dst_batch_end;
}
Ok(Progress {
n_processed,
finished: self.n_remaining == 0,
})
}
pub fn into_src(self) -> R {
self.reader_builder.into_inner()
}
}
impl<'a, T: Number, R: BetterBufRead> PageDecompressor<'a, T, R> {
#[inline(never)]
pub(crate) fn new(src: R, cd: &'a mut ChunkDecompressor<T>, n: usize) -> PcoResult<Self> {
let state = PageDecompressorState::new(src, &cd.inner, n)?;
Ok(Self { cd, state })
}
pub fn read(&mut self, dst: &mut [T]) -> PcoResult<Progress> {
self.state.read(
&mut self.cd.inner,
DynNumberSliceMut::new(dst),
)
}
pub fn into_src(self) -> R {
self.state.into_src()
}
}