use better_io::BetterBufRead;
use crate::bit_reader::BitReaderBuilder;
use crate::bit_writer::BitWriter;
use crate::constants::{
Bitlen, BITS_TO_ENCODE_DICT_LEN, BITS_TO_ENCODE_MODE_VARIANT, BITS_TO_ENCODE_QUANTIZE_K,
MAX_SUPPORTED_PRECISION_BYTES, OVERSHOOT_PADDING,
};
use crate::data_types::float::Float;
use crate::data_types::latent_priv::LatentPriv;
use crate::data_types::{Latent, LatentType};
use crate::errors::{PcoError, PcoResult};
use crate::macros::match_latent_enum;
use crate::metadata::dyn_latent::DynLatent;
use crate::metadata::format_version::FormatVersion;
use crate::metadata::DynLatents;
use crate::metadata::Mode::*;
use std::fmt::Debug;
use std::io::Write;
// Number of bytes requested from the reader builder when `Mode::read_from`
// parses the fixed-size prefix of a mode. It is the sum of:
// * the variant tag, rounded up to whole bytes,
// * `MAX_SUPPORTED_PRECISION_BYTES` (presumably enough for the largest single
//   payload read in that prefix - confirm against `constants`),
// * `OVERSHOOT_PADDING`.
const FIXED_READ_SIZE: usize = BITS_TO_ENCODE_MODE_VARIANT.div_ceil(8) as usize
+ MAX_SUPPORTED_PRECISION_BYTES
+ OVERSHOOT_PADDING;
/// A mode, as serialized by `Mode::write_to` and parsed by `Mode::read_from`.
///
/// On the wire a mode is a variant tag of `BITS_TO_ENCODE_MODE_VARIANT` bits
/// followed by a variant-specific payload. The variant also decides the
/// primary latent type and whether a secondary latent type exists (see
/// `Mode::primary_latent_type` and `Mode::secondary_latent_type`).
///
/// The default mode is `Classic`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[non_exhaustive]
pub enum Mode {
/// Tag `0`; carries no payload. The primary latent type is the number's
/// latent type and there is no secondary latent type.
#[default]
Classic,
/// Tag `1`; carries a `base` stored as one uncompressed latent. Has a
/// secondary latent type equal to the number's latent type.
IntMult(DynLatent),
/// Tag `2`; carries a base stored as one uncompressed latent (built from a
/// float via `to_latent_ordered` in `Mode::float_mult`). Has a secondary
/// latent type equal to the number's latent type.
FloatMult(DynLatent),
/// Tag `3`; carries `k`, serialized in `BITS_TO_ENCODE_QUANTIZE_K` bits.
/// Has a secondary latent type equal to the number's latent type.
FloatQuant(Bitlen),
/// Tag `4`; carries the dictionary values, serialized as a length of
/// `BITS_TO_ENCODE_DICT_LEN` bits, padding to the next byte boundary, and
/// then the uncompressed values. The primary latent type is `U32`
/// (presumably indices into the dictionary - confirm against the callers)
/// and there is no secondary latent type.
Dict(DynLatents),
}
impl Mode {
  /// Parses a `Mode` from the bit stream behind `reader_builder`.
  ///
  /// The encoding is a variant tag of `BITS_TO_ENCODE_MODE_VARIANT` bits
  /// followed by a payload that depends on the tag:
  /// * `0` - `Classic`, no payload
  /// * `1` - `IntMult`, one uncompressed latent of `latent_type`
  /// * `2` - `FloatMult`, one uncompressed latent of `latent_type`
  /// * `3` - `FloatQuant`, `k` in `BITS_TO_ENCODE_QUANTIZE_K` bits
  /// * `4` - `Dict`, a length in `BITS_TO_ENCODE_DICT_LEN` bits, the rest of
  ///   the current byte (which must be empty), then the dictionary values
  ///
  /// # Errors
  ///
  /// Returns a corruption error if the tag is `1` while
  /// `version.used_old_gcds()` holds, or if the tag is none of the above.
  /// Errors from `with_reader`, `drain_empty_byte` and
  /// `read_long_uncompressed_in_place` are propagated.
  pub(crate) fn read_from<R: BetterBufRead>(
    reader_builder: &mut BitReaderBuilder<R>,
    version: &FormatVersion,
    latent_type: LatentType,
  ) -> PcoResult<Self> {
    // NOTE(review): the unsafe reader calls below presumably rely on
    // `with_reader` having made `FIXED_READ_SIZE` bytes available - confirm
    // against `BitReaderBuilder`.
    let mut parsed = reader_builder.with_reader(FIXED_READ_SIZE, |reader| unsafe {
      // Reads a single uncompressed latent whose concrete type is selected at
      // runtime by `latent_type`.
      let read_base = |reader| {
        match_latent_enum!(
          latent_type,
          LatentType<L> => {
            DynLatent::read_uncompressed_from::<L>(reader)
          }
        )
      };

      let tag = reader.read_bitlen(BITS_TO_ENCODE_MODE_VARIANT);
      match tag {
        0 => Ok(Classic),
        1 if version.used_old_gcds() => Err(PcoError::corruption(
          "unable to decompress data from yanked v0.0.0 of pco with different GCD encoding",
        )),
        1 => Ok(IntMult(read_base(reader))),
        2 => Ok(FloatMult(read_base(reader))),
        3 => Ok(FloatQuant(
          reader.read_bitlen(BITS_TO_ENCODE_QUANTIZE_K),
        )),
        4 => {
          let dict_len = reader.read_usize(BITS_TO_ENCODE_DICT_LEN);
          reader.drain_empty_byte("expected zeros between dict mode length and values")?;
          // Only a zero-filled placeholder of the right length is built here;
          // the values themselves are read after the fixed-size window.
          let placeholder = match_latent_enum!(
            latent_type,
            LatentType<L> => { DynLatents::new::<L>(vec![L::ZERO; dict_len]) }
          );
          Ok(Dict(placeholder))
        }
        unknown => Err(PcoError::corruption(format!(
          "unknown mode variant {}",
          unknown
        ))),
      }
    })?;

    // Dictionary values are read through the builder itself rather than
    // inside the `FIXED_READ_SIZE` window used above.
    if let Dict(entries) = &mut parsed {
      entries.read_long_uncompressed_in_place(reader_builder)?;
    }

    Ok(parsed)
  }

  /// Serializes this mode to `writer`: the variant tag in
  /// `BITS_TO_ENCODE_MODE_VARIANT` bits, then the variant's payload.
  ///
  /// The tags and payloads mirror what `read_from` expects. For `Dict`, the
  /// current byte is finished after the length so that the values start on a
  /// byte boundary.
  ///
  /// # Safety
  ///
  /// The caller must uphold the requirements of the `BitWriter` methods
  /// called here. Those are not visible in this file; presumably the writer
  /// needs enough buffered space for everything written - confirm against
  /// `BitWriter`.
  pub(crate) unsafe fn write_to<W: Write>(&self, writer: &mut BitWriter<W>) {
    match self {
      Classic => {
        writer.write_bitlen(0, BITS_TO_ENCODE_MODE_VARIANT);
      }
      IntMult(base) => {
        writer.write_bitlen(1, BITS_TO_ENCODE_MODE_VARIANT);
        base.write_uncompressed_to(writer);
      }
      FloatMult(base) => {
        writer.write_bitlen(2, BITS_TO_ENCODE_MODE_VARIANT);
        base.write_uncompressed_to(writer);
      }
      FloatQuant(k) => {
        writer.write_bitlen(3, BITS_TO_ENCODE_MODE_VARIANT);
        writer.write_uint(*k, BITS_TO_ENCODE_QUANTIZE_K);
      }
      Dict(entries) => {
        writer.write_bitlen(4, BITS_TO_ENCODE_MODE_VARIANT);
        writer.write_usize(entries.len(), BITS_TO_ENCODE_DICT_LEN);
        writer.finish_byte();
        entries.write_uncompressed_to(writer);
      }
    }
  }

  /// Returns the latent type of the primary latent variable: `U32` for
  /// `Dict`, and `number_latent_type` for every other mode.
  pub(crate) fn primary_latent_type(&self, number_latent_type: LatentType) -> LatentType {
    match self {
      Dict(_) => LatentType::U32,
      Classic | IntMult(_) | FloatMult(_) | FloatQuant(_) => number_latent_type,
    }
  }

  /// Returns the latent type of the secondary latent variable, if the mode
  /// has one: `Some(number_latent_type)` for `IntMult`, `FloatMult` and
  /// `FloatQuant`, and `None` for `Classic` and `Dict`.
  pub(crate) fn secondary_latent_type(
    &self,
    number_latent_type: LatentType,
  ) -> Option<LatentType> {
    match self {
      IntMult(_) | FloatMult(_) | FloatQuant(_) => Some(number_latent_type),
      Classic | Dict(_) => None,
    }
  }

  /// Builds a `FloatMult` mode from a float `base`, storing it as the latent
  /// produced by `to_latent_ordered`.
  pub(crate) fn float_mult<F: Float>(base: F) -> Self {
    let ordered = base.to_latent_ordered();
    Self::FloatMult(DynLatent::new(ordered))
  }

  /// Builds an `IntMult` mode from a latent `base`.
  pub(crate) fn int_mult<L: Latent>(base: L) -> Self {
    Self::IntMult(DynLatent::new(base))
  }

  /// Returns an upper bound, in bits, on what `write_to` emits for this mode.
  ///
  /// For `Dict` the bound includes 7 bits for the padding that finishing the
  /// byte after the length can add at most.
  pub(crate) fn max_bit_size(&self) -> usize {
    let tag_bits = BITS_TO_ENCODE_MODE_VARIANT as usize;
    match self {
      Classic => tag_bits,
      IntMult(base) | FloatMult(base) => tag_bits + base.exact_bit_size() as usize,
      FloatQuant(_) => tag_bits + BITS_TO_ENCODE_QUANTIZE_K as usize,
      Dict(entries) => {
        tag_bits + BITS_TO_ENCODE_DICT_LEN as usize + 7 + entries.exact_bit_size()
      }
    }
  }
}
#[cfg(test)]
mod tests {
  use super::*;
  use crate::bit_writer::BitWriter;
  use crate::metadata::{DynLatent, DynLatents};

  /// Writes `mode` and checks the number of bits written against
  /// `Mode::max_bit_size`: never more than the bound, and exactly the bound
  /// for every mode except `Dict`.
  fn check_bit_size(mode: Mode) {
    let mut dst = Vec::new();
    let mut writer = BitWriter::new(&mut dst, 100);
    // NOTE(review): presumably the 100 passed to `BitWriter::new` leaves ample
    // room for the small modes written in these tests - confirm against
    // `BitWriter`.
    unsafe { mode.write_to(&mut writer) };

    let written_bits = writer.bit_idx();
    let bound = mode.max_bit_size();
    assert!(written_bits <= bound);

    let is_dict = matches!(mode, Mode::Dict(_));
    if !is_dict {
      assert_eq!(written_bits, bound);
    }
  }

  #[test]
  fn test_bit_size() {
    let modes = [
      Mode::Classic,
      Mode::Dict(DynLatents::new::<u32>(vec![])),
      Mode::Dict(DynLatents::new::<u64>(vec![1, 77, 1111])),
      Mode::IntMult(DynLatent::new(77_u32)),
      Mode::FloatMult(DynLatent::new(77_u32)),
      Mode::FloatQuant(7),
    ];
    for mode in modes {
      check_bit_size(mode);
    }
  }
}