use super::ModeAndLatents;
use crate::constants::Bitlen;
use crate::data_types::{Latent, LatentPriv, Number, NumberPriv};
use crate::describers::LatentDescriber;
use crate::dyn_slices::DynLatentSlice;
use crate::errors::{PcoError, PcoResult};
use crate::metadata::per_latent_var::PerLatentVar;
use crate::metadata::{ChunkMeta, DynLatent, Mode};
use crate::mode::{classic, dict, int_mult};
use crate::{describers, ChunkConfig, ModeSpec};
pub fn choose_mode_and_split_latents<T: Number>(
nums: &[T],
config: &ChunkConfig,
) -> PcoResult<ModeAndLatents> {
match config.mode_spec {
ModeSpec::Auto => {
if let Some(base) = int_mult::choose_base(nums) {
let mode = Mode::int_mult(base);
let latents = int_mult::split_latents(nums, base);
Ok((mode, latents))
} else {
Ok((Mode::Classic, classic::split_latents(nums)))
}
}
ModeSpec::Classic => Ok((Mode::Classic, classic::split_latents(nums))),
ModeSpec::TryFloatMult(_) | ModeSpec::TryFloatQuant(_) => Err(PcoError::invalid_argument(
"unable to use float mode for ints",
)),
ModeSpec::TryIntMult(base_u64) => {
let base = T::L::from_u64(base_u64);
let mode = Mode::IntMult(DynLatent::new(base));
let latents = int_mult::split_latents(nums, base);
Ok((mode, latents))
}
ModeSpec::TryDict => dict::configure_and_split_latents(nums),
}
}
pub fn join_latents<T: Number>(
mode: &Mode,
primary: DynLatentSlice,
secondary: Option<DynLatentSlice>,
dst: &mut [T],
) -> PcoResult<()> {
match mode {
Mode::Classic => classic::join_latents(primary, dst),
Mode::Dict(dict) => dict::join_latents(dict, primary, dst),
Mode::IntMult(base) => int_mult::join_latents(*base, primary, secondary, dst),
Mode::FloatMult(_) | Mode::FloatQuant(_) => {
unreachable!("impossible mode for unsigned ints")
}
}
}
pub fn mode_is_valid<L: Latent>(mode: &Mode) -> bool {
match mode {
Mode::Classic | Mode::Dict(_) => true,
Mode::FloatMult(_) | Mode::FloatQuant(_) => false,
Mode::IntMult(base) => *base.downcast_ref::<L>().unwrap() > L::ZERO,
}
}
macro_rules! impl_latent {
($t: ty, $conv: ty) => {
impl LatentPriv for $t {
const ZERO: Self = 0;
const ONE: Self = 1;
const MID: Self = 1 << (Self::BITS - 1);
const MAX: Self = Self::MAX;
const BITS: Bitlen = Self::BITS as Bitlen;
type Conv = $conv;
#[inline]
fn from_u32(x: u32) -> Self {
x as Self
}
#[inline]
fn from_u64(x: u64) -> Self {
x as Self
}
#[inline]
fn leading_zeros(self) -> Bitlen {
self.leading_zeros() as Bitlen
}
#[inline]
fn to_u64(self) -> u64 {
self as u64
}
#[inline]
fn from_conv(x: Self::Conv) -> Self {
x as Self
}
#[inline]
fn to_conv(self) -> Self::Conv {
self as Self::Conv
}
#[inline]
fn wrapping_add(self, other: Self) -> Self {
self.wrapping_add(other)
}
#[inline]
fn wrapping_sub(self, other: Self) -> Self {
self.wrapping_sub(other)
}
}
};
}
impl_latent!(u8, i16);
impl_latent!(u16, i32);
impl_latent!(u32, i64);
impl_latent!(u64, i64);
macro_rules! impl_unsigned_number {
($t: ty, $header_byte: expr) => {
impl NumberPriv for $t {
const NUMBER_TYPE_BYTE: u8 = $header_byte;
type L = Self;
fn mode_is_valid(mode: &Mode) -> bool {
mode_is_valid::<Self::L>(mode)
}
fn choose_mode_and_split_latents(
nums: &[Self],
config: &ChunkConfig,
) -> PcoResult<ModeAndLatents> {
choose_mode_and_split_latents(nums, config)
}
#[inline]
fn from_latent_ordered(l: Self::L) -> Self {
l
}
#[inline]
fn to_latent_ordered(self) -> Self::L {
self
}
fn join_latents(
mode: &Mode,
primary: DynLatentSlice,
secondary: Option<DynLatentSlice>,
dst: &mut [Self],
) -> PcoResult<()> {
join_latents(mode, primary, secondary, dst)
}
}
impl Number for $t {
fn get_latent_describers(meta: &ChunkMeta) -> PerLatentVar<LatentDescriber> {
describers::match_classic_mode::<Self>(meta, "")
.or_else(|| describers::match_int_modes::<Self>(meta, false))
.expect("invalid mode for unsigned type")
}
}
};
}
impl_unsigned_number!(u32, 1);
impl_unsigned_number!(u64, 2);
impl_unsigned_number!(u16, 7);
impl_unsigned_number!(u8, 10);
#[cfg(test)]
mod tests {
use super::*;
use crate::metadata::{DynLatents, Mode};
#[test]
fn test_mode_validation() {
assert!(u32::mode_is_valid(&Mode::Classic));
assert!(u32::mode_is_valid(&Mode::Dict(
DynLatents::new(vec![1_u32, 3])
)));
for base in [1_u32, 77, u32::MAX] {
assert!(u32::mode_is_valid(&Mode::int_mult(base)))
}
assert!(!u32::mode_is_valid(&Mode::int_mult(0_u32)));
assert!(!u32::mode_is_valid(&Mode::FloatQuant(3)));
}
}