pco 1.0.0-rc

Good compression for numerical sequences
Documentation
use std::fmt::Debug;

use crate::ans::AnsState;
use crate::bit_reader::BitReader;
use crate::chunk_latent_decompressor::ChunkLatentDecompressor;
use crate::constants::{Bitlen, ANS_INTERLEAVING, FULL_BATCH_N};
use crate::data_types::Latent;
use crate::dyn_latent_slice::DynLatentSlice;
use crate::errors::PcoResult;
use crate::macros::define_latent_enum;
use crate::metadata::delta_encoding::LatentVarDeltaEncoding;
use crate::{bit_reader, delta};

#[inline(never)]
unsafe fn read_offsets<L: Latent, const READ_BYTES: usize>(
  reader: &mut BitReader,
  offset_bits_csum: &[u32],
  offset_bits: &[u32],
  latents: &mut [L],
  n: usize,
) {
  let base_bit_idx = reader.bit_idx();
  let src = reader.src;
  for i in 0..n {
    let offset_bits = offset_bits[i];
    let offset_bits_csum = offset_bits_csum[i];
    let bit_idx = base_bit_idx as Bitlen + offset_bits_csum;
    let byte_idx = bit_idx / 8;
    let bits_past_byte = bit_idx % 8;
    let offset = bit_reader::read_uint_at::<L, READ_BYTES>(
      src,
      byte_idx as usize,
      bits_past_byte,
      offset_bits,
    );

    latents[i] = latents[i].wrapping_add(offset);
  }

  let final_bit_idx = base_bit_idx + offset_bits_csum[n - 1] as usize + offset_bits[n - 1] as usize;
  reader.stale_byte_idx = final_bit_idx / 8;
  reader.bits_past_byte = final_bit_idx as Bitlen % 8;
}

// Here we do something very strange to ensure vectorization of
// decompress_offsets on aarch64: we force specializations to be exported by the
// pco compiled library, as opposed to downstream compilation units. I'm not
// sure why this changes vectorization rules, but it's a significant speedup.
macro_rules! force_export {
  ($name: ident, $l: ty, $rb: literal) => {
    #[used]
    static $name: unsafe fn(&mut BitReader, &[u32], &[u32], &mut [$l], usize) =
      read_offsets::<$l, $rb>;
  };
}
force_export!(_FORCE_EXPORT_U8_4, u8, 4);
force_export!(_FORCE_EXPORT_U16_4, u16, 4);
force_export!(_FORCE_EXPORT_U32_4, u32, 4);
force_export!(_FORCE_EXPORT_U32_8, u32, 8);
force_export!(_FORCE_EXPORT_U64_8, u64, 8);

// this is entirely state - any precomputed information is in the ChunkLatentDecompressor
#[derive(Clone, Debug)]
pub struct PageLatentDecompressor<L: Latent> {
  ans_state_idxs: [AnsState; ANS_INTERLEAVING],
  delta_state: Vec<L>,
  delta_state_pos: usize,
}

impl<L: Latent> PageLatentDecompressor<L> {
  pub fn new(
    ans_final_state_idxs: [AnsState; ANS_INTERLEAVING],
    delta_encoding: &LatentVarDeltaEncoding,
    stored_delta_state: Vec<L>,
  ) -> Self {
    let (working_delta_state, delta_state_pos) =
      delta::new_buffer_and_pos(delta_encoding, stored_delta_state);

    Self {
      ans_state_idxs: ans_final_state_idxs,
      delta_state: working_delta_state,
      delta_state_pos,
    }
  }

  // This implementation handles only a full batch, but is faster.
  #[inline(never)]
  unsafe fn read_full_ans_symbols(
    &mut self,
    reader: &mut BitReader,
    cld: &mut ChunkLatentDecompressor<L>,
  ) {
    // At each iteration, this loads a single u64 and has all ANS decoders
    // read a single symbol from it.
    // Therefore it requires that ANS_INTERLEAVING * MAX_BITS_PER_ANS <= 57.
    // Additionally, we're unpacking all ANS states using the fact that
    // ANS_INTERLEAVING == 4.
    let src = reader.src;
    let mut stale_byte_idx = reader.stale_byte_idx;
    let mut bits_past_byte = reader.bits_past_byte;
    let mut offset_bit_idx = 0;
    let [mut state_idx_0, mut state_idx_1, mut state_idx_2, mut state_idx_3] = self.ans_state_idxs;
    let ans_nodes = cld.decoder.nodes.as_slice();
    let lowers = cld.state_lowers.as_slice();
    for base_i in (0..FULL_BATCH_N).step_by(ANS_INTERLEAVING) {
      stale_byte_idx += bits_past_byte as usize / 8;
      bits_past_byte %= 8;
      let packed = bit_reader::u64_at(src, stale_byte_idx);
      // I hate that I have to do this with a macro, but it gives a serious
      // performance gain. If I use a [AnsState; 4] for the state_idxs instead
      // of separate identifiers, it tries to repeatedly load and write to
      // the array instead of keeping the states in registers.
      macro_rules! handle_single_symbol {
        ($j: expr, $state_idx: ident) => {
          let i = base_i + $j;
          let node = unsafe { ans_nodes.get_unchecked($state_idx as usize) };
          let bits_to_read = node.bits_to_read as Bitlen;
          let ans_val = (packed >> bits_past_byte) as AnsState & ((1 << bits_to_read) - 1);
          let lower = unsafe { *lowers.get_unchecked($state_idx as usize) };
          let offset_bits = node.offset_bits as Bitlen;
          *cld.scratch.offset_bits_csum.get_unchecked_mut(i) = offset_bit_idx;
          *cld.scratch.offset_bits.get_unchecked_mut(i) = offset_bits;
          *cld.scratch.latents.get_unchecked_mut(i) = lower;
          bits_past_byte += bits_to_read;
          offset_bit_idx += offset_bits;
          $state_idx = node.next_state_idx_base as AnsState + ans_val;
        };
      }
      handle_single_symbol!(0, state_idx_0);
      handle_single_symbol!(1, state_idx_1);
      handle_single_symbol!(2, state_idx_2);
      handle_single_symbol!(3, state_idx_3);
    }

    reader.stale_byte_idx = stale_byte_idx;
    reader.bits_past_byte = bits_past_byte;
    self.ans_state_idxs = [state_idx_0, state_idx_1, state_idx_2, state_idx_3];
  }

  // This implementation handles arbitrary batch size and looks simpler, but is
  // slower, so we only use it at the end of the page.
  #[inline(never)]
  unsafe fn read_ans_symbols(
    &mut self,
    reader: &mut BitReader,
    batch_n: usize,
    cld: &mut ChunkLatentDecompressor<L>,
  ) {
    let src = reader.src;
    let mut stale_byte_idx = reader.stale_byte_idx;
    let mut bits_past_byte = reader.bits_past_byte;
    let mut offset_bit_idx = 0;
    let mut state_idxs = self.ans_state_idxs;
    for i in 0..batch_n {
      let j = i % ANS_INTERLEAVING;
      let state_idx = state_idxs[j] as usize;
      stale_byte_idx += bits_past_byte as usize / 8;
      bits_past_byte %= 8;
      let packed = bit_reader::u64_at(src, stale_byte_idx);
      let node = unsafe { cld.decoder.nodes.get_unchecked(state_idx) };
      let bits_to_read = node.bits_to_read as Bitlen;
      let ans_val = (packed >> bits_past_byte) as AnsState & ((1 << bits_to_read) - 1);
      let lower = unsafe { *cld.state_lowers.get_unchecked(state_idx) };
      let offset_bits = node.offset_bits as Bitlen;
      *cld.scratch.offset_bits_csum.get_unchecked_mut(i) = offset_bit_idx;
      *cld.scratch.offset_bits.get_unchecked_mut(i) = offset_bits;
      *cld.scratch.latents.get_unchecked_mut(i) = lower;
      bits_past_byte += bits_to_read;
      offset_bit_idx += offset_bits;
      state_idxs[j] = node.next_state_idx_base as AnsState + ans_val;
    }

    reader.stale_byte_idx = stale_byte_idx;
    reader.bits_past_byte = bits_past_byte;
    self.ans_state_idxs = state_idxs;
  }

  // If hits a corruption, it returns an error and leaves reader and self unchanged.
  // May contaminate dst.
  pub unsafe fn read_batch_pre_delta(
    &mut self,
    reader: &mut BitReader,
    batch_n: usize,
    cld: &mut ChunkLatentDecompressor<L>,
  ) {
    if batch_n == 0 {
      return;
    }

    assert!(batch_n <= FULL_BATCH_N);
    if cld.n_bins > 1 {
      if batch_n == FULL_BATCH_N {
        self.read_full_ans_symbols(reader, cld);
      } else {
        self.read_ans_symbols(reader, batch_n, cld);
      }
    } else {
      cld.scratch.latents[..batch_n].fill(cld.state_lowers[0]);
    }

    // We want to read the offsets for each latent type as fast as possible.
    // Depending on the number of bits per offset, we can read them in
    // different chunk sizes. We use the smallest chunk size that can hold
    // the maximum possible offset.
    // The matching is intentionally verbose to make it clear how different
    // latent types are handled.
    // Note: Providing a 2 byte read appears to degrade performance for 16-bit
    // latents.
    macro_rules! specialized_read_offsets {
      ($rb: literal) => {
        read_offsets::<L, $rb>(
          reader,
          &cld.scratch.offset_bits_csum.0,
          &cld.scratch.offset_bits.0,
          &mut cld.scratch.latents.0,
          batch_n,
        )
      };
    }
    match (cld.bytes_per_offset, L::BITS) {
      (0, _) => (),
      (1..=4, 8) => specialized_read_offsets!(4),
      (1..=4, 16) => specialized_read_offsets!(4),
      (1..=4, 32) => specialized_read_offsets!(4),
      (5..=8, 32) => specialized_read_offsets!(8),
      (1..=8, 64) => specialized_read_offsets!(8),
      (9..=15, 64) => specialized_read_offsets!(15),
      _ => panic!(
        "[PageLatentDecompressor] {} byte read not supported for {}-bit Latents",
        cld.bytes_per_offset,
        L::BITS
      ),
    }
  }

  pub unsafe fn read_batch(
    &mut self,
    reader: &mut BitReader,
    delta_latents: Option<DynLatentSlice>,
    n_remaining_in_page: usize,
    cld: &mut ChunkLatentDecompressor<L>,
  ) -> PcoResult<()> {
    let n_remaining_pre_delta =
      n_remaining_in_page.saturating_sub(cld.delta_encoding.n_latents_per_state());
    let pre_delta_len = FULL_BATCH_N.min(n_remaining_pre_delta);
    self.read_batch_pre_delta(reader, pre_delta_len, cld);
    let dst = &mut cld.scratch.latents[..n_remaining_in_page.min(FULL_BATCH_N)];

    delta::decode_in_place(
      &cld.delta_encoding,
      delta_latents,
      &mut self.delta_state_pos,
      &mut self.delta_state,
      dst,
    )
  }
}

define_latent_enum!(
  #[derive()]
  pub DynPageLatentDecompressor(PageLatentDecompressor)
);