use crate::constants::*;
use crate::data_types::LatentType;
use crate::errors::{PcoError, PcoResult};
use crate::DEFAULT_COMPRESSION_LEVEL;
/// Selects the mode the compressor uses for a chunk.
///
/// [`ModeSpec::Auto`] is the default. The `Try*` variants each request one
/// specific mode, some carrying a parameter for it. Whether the compressor may
/// fall back when a requested mode does not suit the data is not visible here —
/// NOTE(review): confirm against the compressor implementation.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum ModeSpec {
    /// Automatic mode selection (the default).
    #[default]
    Auto,
    /// Request the classic mode.
    Classic,
    /// Request float-multiplier mode; the `f64` is presumably the multiplier
    /// base (TODO confirm).
    TryFloatMult(f64),
    /// Request float-quantization mode; the `Bitlen` is presumably a bit count
    /// for the quantization (TODO confirm).
    TryFloatQuant(Bitlen),
    /// Request integer-multiplier mode; the `u64` is presumably the base
    /// (TODO confirm).
    TryIntMult(u64),
    /// Request dictionary mode.
    TryDict,
}
/// Selects the delta encoding the compressor uses for a chunk.
///
/// [`DeltaSpec::Auto`] is the default. The `usize` carried by
/// `TryConsecutive` and `TryConv1` is an order whose upper bound is checked by
/// `ChunkConfig::validate`.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum DeltaSpec {
    /// Automatic delta selection (the default).
    #[default]
    Auto,
    /// Request no delta encoding.
    NoOp,
    /// Request consecutive delta encoding of the given order
    /// (at most `MAX_CONSECUTIVE_DELTA_ORDER`).
    TryConsecutive(usize),
    /// Request lookback delta encoding.
    TryLookback,
    /// Request conv1 delta encoding of the given order
    /// (at most `MAX_CONV1_DELTA_ORDER`).
    TryConv1(usize),
}
/// Configuration for compressing a single chunk.
///
/// Start from [`ChunkConfig::default`] and adjust it with the `with_*` builder
/// methods. The struct is `#[non_exhaustive]`, so code outside this crate
/// cannot build it with a struct literal.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct ChunkConfig {
    /// Compression level. Values above `MAX_COMPRESSION_LEVEL` are rejected
    /// during validation. Defaults to `DEFAULT_COMPRESSION_LEVEL`.
    pub compression_level: usize,
    /// How the compression mode is chosen. Defaults to [`ModeSpec::Auto`].
    pub mode_spec: ModeSpec,
    /// How the delta encoding is chosen. Defaults to [`DeltaSpec::Auto`].
    pub delta_spec: DeltaSpec,
    /// How the chunk's numbers are split into pages. Defaults to
    /// `PagingSpec::EqualPagesUpTo(DEFAULT_MAX_PAGE_N)`.
    pub paging_spec: PagingSpec,
    /// Whether 8-bit latent types may be compressed. Validation rejects
    /// `LatentType::U8` unless this is `true`. Defaults to `false`.
    pub enable_8_bit: bool,
}

impl Default for ChunkConfig {
    fn default() -> Self {
        Self {
            compression_level: DEFAULT_COMPRESSION_LEVEL,
            mode_spec: ModeSpec::default(),
            delta_spec: DeltaSpec::default(),
            // Same value as before (`EqualPagesUpTo(DEFAULT_MAX_PAGE_N)`), but
            // sourced from `PagingSpec`'s own default like the other specs.
            paging_spec: PagingSpec::default(),
            enable_8_bit: false,
        }
    }
}

impl ChunkConfig {
    /// Returns this config with `compression_level` replaced.
    #[must_use]
    pub fn with_compression_level(mut self, level: usize) -> Self {
        self.compression_level = level;
        self
    }

    /// Returns this config with `mode_spec` replaced.
    #[must_use]
    pub fn with_mode_spec(mut self, mode_spec: ModeSpec) -> Self {
        self.mode_spec = mode_spec;
        self
    }

    /// Returns this config with `delta_spec` replaced.
    #[must_use]
    pub fn with_delta_spec(mut self, delta_spec: DeltaSpec) -> Self {
        self.delta_spec = delta_spec;
        self
    }

    /// Returns this config with `paging_spec` replaced.
    #[must_use]
    pub fn with_paging_spec(mut self, paging_spec: PagingSpec) -> Self {
        self.paging_spec = paging_spec;
        self
    }

    /// Returns this config with `enable_8_bit` replaced.
    #[must_use]
    pub fn with_enable_8_bit(mut self, enable: bool) -> Self {
        self.enable_8_bit = enable;
        self
    }

    /// Checks that this config can be used to compress latents of type
    /// `latent_type`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error if:
    /// * `compression_level` exceeds `MAX_COMPRESSION_LEVEL`;
    /// * a `TryConsecutive` order exceeds `MAX_CONSECUTIVE_DELTA_ORDER`, or a
    ///   `TryConv1` order exceeds `MAX_CONV1_DELTA_ORDER`;
    /// * the paging spec is `EqualPagesUpTo(0)`, which cannot hold any numbers;
    /// * `latent_type` is `LatentType::U8` and `enable_8_bit` is `false`.
    pub(crate) fn validate(&self, latent_type: LatentType) -> PcoResult<()> {
        let compression_level = self.compression_level;
        if compression_level > MAX_COMPRESSION_LEVEL {
            return Err(PcoError::invalid_argument(format!(
                "compression level may not exceed {} (was {})",
                MAX_COMPRESSION_LEVEL, compression_level,
            )));
        }

        match self.delta_spec {
            DeltaSpec::Auto | DeltaSpec::NoOp | DeltaSpec::TryLookback => (),
            DeltaSpec::TryConsecutive(order) => {
                if order > MAX_CONSECUTIVE_DELTA_ORDER {
                    return Err(PcoError::invalid_argument(format!(
                        "consecutive delta order may not exceed {} (was {})",
                        MAX_CONSECUTIVE_DELTA_ORDER, order,
                    )));
                }
            }
            DeltaSpec::TryConv1(order) => {
                if order > MAX_CONV1_DELTA_ORDER {
                    return Err(PcoError::invalid_argument(format!(
                        "conv1 delta order may not exceed {} (was {})",
                        MAX_CONV1_DELTA_ORDER, order,
                    )));
                }
            }
        }

        // A zero page cap can never hold a number. Splitting a non-empty chunk
        // with it would divide by zero, so reject it here rather than later.
        if matches!(self.paging_spec, PagingSpec::EqualPagesUpTo(0)) {
            return Err(PcoError::invalid_argument(
                "paging spec's max page n must be positive (was 0)",
            ));
        }

        if matches!(latent_type, LatentType::U8) && !self.enable_8_bit {
            return Err(PcoError::invalid_argument(
                "compressing 8-bit types with Pco is often a mistake; \
                enable them on the ChunkConfig if you know what you're doing",
            ));
        }
        Ok(())
    }
}
/// Specifies how a chunk's numbers are divided into data pages.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum PagingSpec {
    /// Use as few pages as possible, each holding at most the given number of
    /// numbers. Page sizes differ by at most one, with the larger pages first.
    /// The cap must be positive whenever there is data to page.
    EqualPagesUpTo(usize),
    /// Use exactly these page sizes. They must all be positive and must sum to
    /// the number of numbers in the chunk.
    Exact(Vec<usize>),
}

impl Default for PagingSpec {
    /// `EqualPagesUpTo(DEFAULT_MAX_PAGE_N)`.
    fn default() -> Self {
        Self::EqualPagesUpTo(DEFAULT_MAX_PAGE_N)
    }
}

impl PagingSpec {
    /// Computes the number of numbers in each page for a chunk of `n` numbers.
    ///
    /// For `EqualPagesUpTo(max_page_n)`, `n == 0` yields no pages. Otherwise it
    /// yields `ceil(n / max_page_n)` pages whose sizes differ by at most one,
    /// with the larger pages first. For `Exact`, the given sizes are returned
    /// unchanged once they pass validation.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error if:
    /// * `EqualPagesUpTo(0)` is used with `n > 0`;
    /// * the page sizes overflow `usize` when summed;
    /// * the page sizes do not sum to `n`;
    /// * any page would hold 0 numbers.
    pub fn n_per_page(&self, n: usize) -> PcoResult<Vec<usize>> {
        let n_per_page = match self {
            PagingSpec::EqualPagesUpTo(max_page_n) => {
                let max_page_n = *max_page_n;
                if n == 0 {
                    return Ok(Vec::new());
                }
                // `div_ceil(0)` would panic; report the bad config instead.
                if max_page_n == 0 {
                    return Err(PcoError::invalid_argument(format!(
                        "cannot split {} numbers into pages of at most 0 numbers",
                        n,
                    )));
                }
                let n_pages = n.div_ceil(max_page_n);
                let page_n_low = n / n_pages;
                let page_n_high = page_n_low + 1;
                // The first `r` pages take one extra number so the total is `n`.
                let r = n % n_pages;
                debug_assert!(r == 0 || page_n_high <= max_page_n);
                let mut res = vec![page_n_low; n_pages];
                res[..r].fill(page_n_high);
                res
            }
            PagingSpec::Exact(n_per_page) => n_per_page.to_vec(),
        };

        // Checked sum: a plain `sum()` panics on overflow in debug builds and
        // wraps in release builds, where a wrapped total could falsely match `n`.
        let summed_n = n_per_page
            .iter()
            .try_fold(0_usize, |acc, &page_n| acc.checked_add(page_n))
            .ok_or_else(|| PcoError::invalid_argument("paging spec page sizes overflow usize"))?;
        if summed_n != n {
            return Err(PcoError::invalid_argument(format!(
                "paging spec suggests {} numbers but {} were given",
                summed_n, n,
            )));
        }

        if n_per_page.contains(&0) {
            return Err(PcoError::invalid_argument(
                "cannot write data page of 0 numbers",
            ));
        }

        Ok(n_per_page)
    }
}