use alloc::format;
use alloc::sync::Arc;
use sha2::{Digest, Sha256};
use crate::errors::CodecError;
use crate::types::ConfigHash;
/// Bit widths accepted by [`CodecConfig::new`].
///
/// Any other value makes `new` fail with `CodecError::UnsupportedBitWidth`.
/// Because the largest entry is 8, `1u32 << bit_width` in
/// [`CodecConfig::num_codebook_entries`] cannot overflow.
pub const SUPPORTED_BIT_WIDTHS: &[u8] = &[2, 4, 8];
/// Validated, immutable codec parameters plus a cached content hash.
///
/// Instances are produced by [`CodecConfig::new`]. It rejects bit widths outside
/// [`SUPPORTED_BIT_WIDTHS`] and a zero `dimension`, so every value of this type
/// upholds those invariants.
///
/// Keep the field declaration order unchanged: it determines the field order
/// of the derived `Debug` output.
#[derive(Clone, Debug)]
pub struct CodecConfig {
/// Bit width; always one of [`SUPPORTED_BIT_WIDTHS`]. The codebook size is
/// `2^bit_width` (see [`CodecConfig::num_codebook_entries`]).
bit_width: u8,
/// Caller-supplied seed, folded into `config_hash`.
/// NOTE(review): presumably seeds a deterministic generator elsewhere — confirm.
seed: u64,
/// Dimension; guaranteed non-zero by [`CodecConfig::new`].
dimension: u32,
/// Flag recorded in the configuration and folded into `config_hash`.
residual_enabled: bool,
/// Hex-encoded SHA-256 digest of the canonical description of the four fields
/// above. It is computed once in `new` and is fully derived from those fields.
config_hash: ConfigHash,
}
impl CodecConfig {
    /// Validates the parameters and builds a configuration with its content hash
    /// precomputed.
    ///
    /// # Errors
    ///
    /// - `CodecError::UnsupportedBitWidth` if `bit_width` is not listed in
    ///   [`SUPPORTED_BIT_WIDTHS`]. This check takes precedence over the next one.
    /// - `CodecError::InvalidDimension` if `dimension` is zero.
    pub fn new(
        bit_width: u8,
        seed: u64,
        dimension: u32,
        residual_enabled: bool,
    ) -> Result<Self, CodecError> {
        let width_supported = SUPPORTED_BIT_WIDTHS.contains(&bit_width);
        match (width_supported, dimension) {
            (false, _) => Err(CodecError::UnsupportedBitWidth { got: bit_width }),
            (true, 0) => Err(CodecError::InvalidDimension { got: 0 }),
            (true, _) => Ok(Self {
                config_hash: compute_config_hash(bit_width, seed, dimension, residual_enabled),
                bit_width,
                seed,
                dimension,
                residual_enabled,
            }),
        }
    }

    /// Bit width this configuration was validated with (one of
    /// [`SUPPORTED_BIT_WIDTHS`]).
    #[inline]
    pub const fn bit_width(&self) -> u8 {
        self.bit_width
    }

    /// Seed stored in this configuration.
    #[inline]
    pub const fn seed(&self) -> u64 {
        self.seed
    }

    /// Dimension stored in this configuration; never zero.
    #[inline]
    pub const fn dimension(&self) -> u32 {
        self.dimension
    }

    /// Whether the residual flag was set at construction.
    #[inline]
    pub const fn residual_enabled(&self) -> bool {
        self.residual_enabled
    }

    /// Number of codebook entries, i.e. `2^bit_width`.
    ///
    /// For the supported widths this is 4, 16 or 256. The shift cannot overflow
    /// because `new` only admits widths up to 8.
    #[inline]
    pub const fn num_codebook_entries(&self) -> u32 {
        1u32 << self.bit_width
    }

    /// Cached hex-encoded SHA-256 digest of this configuration's canonical form.
    #[inline]
    pub const fn config_hash(&self) -> &ConfigHash {
        &self.config_hash
    }
}
/// Two configurations are equal when their four input parameters match.
///
/// `config_hash` is not compared: it is a pure function of those parameters,
/// so comparing it would be redundant.
impl PartialEq for CodecConfig {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (self.bit_width, self.seed, self.dimension, self.residual_enabled);
        let rhs = (other.bit_width, other.seed, other.dimension, other.residual_enabled);
        lhs == rhs
    }
}

// Every compared field has total equality (integers and bool), so `Eq` holds.
impl Eq for CodecConfig {}
/// Hashes the same four fields that `PartialEq` compares, keeping the
/// `Hash`/`Eq` contract intact. The derived `config_hash` is skipped.
impl core::hash::Hash for CodecConfig {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        // A tuple's `Hash` feeds each element to the hasher in order, so this
        // produces the same byte stream as hashing the fields one by one.
        let key = (self.bit_width, self.seed, self.dimension, self.residual_enabled);
        core::hash::Hash::hash(&key, state);
    }
}
/// Computes the content hash for a set of codec parameters.
///
/// The parameters are rendered into a fixed canonical string, which is then
/// SHA-256 hashed. The digest is returned as lowercase hex wrapped in an `Arc`.
/// The exact text of the canonical string (including the `True`/`False`
/// spelling of the flag) determines the hash. Any change to it changes every
/// config hash.
/// NOTE(review): the capitalised booleans presumably mirror another
/// implementation's canonical form — confirm before altering.
pub(crate) fn compute_config_hash(
    bit_width: u8,
    seed: u64,
    dimension: u32,
    residual_enabled: bool,
) -> ConfigHash {
    let residual_flag = match residual_enabled {
        true => "True",
        false => "False",
    };
    let canonical = format!(
        "CodecConfig(bit_width={},seed={},dimension={},residual_enabled={})",
        bit_width, seed, dimension, residual_flag,
    );

    let mut hasher = Sha256::new();
    hasher.update(canonical.as_bytes());
    let hex_digest = hex::encode(hasher.finalize());

    Arc::from(hex_digest.as_str())
}