use crate::codec::{
codebook::Codebook,
codec_config::CodecConfig,
compressed_vector::CompressedVector,
parallelism::Parallelism,
prepared::PreparedCodec,
residual::{apply_residual_into, compute_residual},
rotation_matrix::RotationMatrix,
};
use crate::errors::CodecError;
use alloc::{vec, vec::Vec};
/// Stateless entry point for compressing and decompressing vectors.
///
/// All configuration lives in [`CodecConfig`], [`Codebook`], and
/// [`PreparedCodec`]; the struct itself carries no data, so it is `Copy`
/// and free to construct.
#[derive(Default, Debug, Clone, Copy)]
pub struct Codec;
impl Codec {
    /// Creates a new codec. The codec is stateless, so this is free.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }

    /// Compresses a single `vector` according to `config`, quantizing with
    /// `codebook`.
    ///
    /// The rotation matrix is rebuilt from `config` on every call; callers on
    /// a hot path should use [`Self::compress_prepared`] instead.
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::DimensionMismatch`] when `vector.len()` differs
    /// from the configured dimension, [`CodecError::CodebookIncompatible`]
    /// when the codebook's bit width differs from the config's, or any error
    /// propagated from rotation or quantization.
    pub fn compress(
        &self,
        vector: &[f32],
        config: &CodecConfig,
        codebook: &Codebook,
    ) -> Result<CompressedVector, CodecError> {
        let dim = config.dimension() as usize;
        if vector.len() != dim {
            // Truncating the reported length is acceptable: the value is
            // diagnostic only and the call has already failed.
            #[allow(clippy::cast_possible_truncation)]
            let got = vector.len() as u32;
            return Err(CodecError::DimensionMismatch {
                expected: config.dimension(),
                got,
            });
        }
        if codebook.bit_width() != config.bit_width() {
            return Err(CodecError::CodebookIncompatible {
                expected: config.bit_width(),
                got: codebook.bit_width(),
            });
        }
        let rotation = RotationMatrix::from_config(config);
        Self::compress_with_rotation(vector, config, codebook, &rotation)
    }

    /// Compresses `vector` using a [`PreparedCodec`], reusing its cached
    /// rotation matrix instead of rebuilding it from the config.
    ///
    /// Codebook/config bit-width compatibility is not re-checked here; it is
    /// presumed validated when `prepared` was built — TODO confirm against
    /// `PreparedCodec`'s constructor.
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::DimensionMismatch`] when `vector.len()` differs
    /// from the prepared config's dimension, or any error propagated from
    /// rotation or quantization.
    pub fn compress_prepared(
        &self,
        vector: &[f32],
        prepared: &PreparedCodec,
    ) -> Result<CompressedVector, CodecError> {
        let dim = prepared.config().dimension() as usize;
        if vector.len() != dim {
            #[allow(clippy::cast_possible_truncation)]
            let got = vector.len() as u32;
            return Err(CodecError::DimensionMismatch {
                expected: prepared.config().dimension(),
                got,
            });
        }
        Self::compress_with_rotation(
            vector,
            prepared.config(),
            prepared.codebook(),
            prepared.rotation(),
        )
    }

    /// Shared compression pipeline: rotate, quantize, and (optionally)
    /// capture the quantization residual.
    fn compress_with_rotation(
        vector: &[f32],
        config: &CodecConfig,
        codebook: &Codebook,
        rotation: &RotationMatrix,
    ) -> Result<CompressedVector, CodecError> {
        let dim = config.dimension() as usize;
        let mut rotated = vec![0.0_f32; dim];
        rotation.apply_into(vector, &mut rotated)?;
        let mut indices = vec![0_u8; dim];
        codebook.quantize_into(&rotated, &mut indices)?;
        // The residual is computed in rotated space so that decompression can
        // add it back before applying the inverse rotation
        // (see `decompress_into_with_rotation`).
        let residual = if config.residual_enabled() {
            let mut reconstructed = vec![0.0_f32; dim];
            codebook.dequantize_into(&indices, &mut reconstructed)?;
            Some(compute_residual(&rotated, &reconstructed).into_boxed_slice())
        } else {
            None
        };
        CompressedVector::new(
            indices.into_boxed_slice(),
            residual,
            config.config_hash().clone(),
            config.dimension(),
            config.bit_width(),
        )
    }

    /// Decompresses `compressed` into a freshly allocated vector of
    /// `config.dimension()` floats.
    ///
    /// # Errors
    ///
    /// See [`Self::decompress_into`].
    pub fn decompress(
        &self,
        compressed: &CompressedVector,
        config: &CodecConfig,
        codebook: &Codebook,
    ) -> Result<Vec<f32>, CodecError> {
        let mut out = vec![0.0_f32; config.dimension() as usize];
        self.decompress_into(compressed, config, codebook, &mut out)?;
        Ok(out)
    }

    /// Decompresses `compressed` into `output`, validating that the vector,
    /// config, codebook, and output buffer are all mutually consistent.
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::ConfigMismatch`] when the vector was compressed
    /// under a different config hash, [`CodecError::CodebookIncompatible`]
    /// when the vector's or codebook's bit width differs from the config's,
    /// [`CodecError::DimensionMismatch`] when `output.len()` differs from the
    /// configured dimension, or any error propagated from dequantization or
    /// rotation.
    pub fn decompress_into(
        &self,
        compressed: &CompressedVector,
        config: &CodecConfig,
        codebook: &Codebook,
        output: &mut [f32],
    ) -> Result<(), CodecError> {
        if compressed.config_hash() != config.config_hash() {
            return Err(CodecError::ConfigMismatch {
                expected: config.config_hash().clone(),
                got: compressed.config_hash().clone(),
            });
        }
        if compressed.bit_width() != config.bit_width() {
            return Err(CodecError::CodebookIncompatible {
                expected: config.bit_width(),
                got: compressed.bit_width(),
            });
        }
        if codebook.bit_width() != config.bit_width() {
            return Err(CodecError::CodebookIncompatible {
                expected: config.bit_width(),
                got: codebook.bit_width(),
            });
        }
        if output.len() != config.dimension() as usize {
            #[allow(clippy::cast_possible_truncation)]
            let got = output.len() as u32;
            return Err(CodecError::DimensionMismatch {
                expected: config.dimension(),
                got,
            });
        }
        let rotation = RotationMatrix::from_config(config);
        Self::decompress_into_with_rotation(compressed, codebook, &rotation, output)
    }

    /// Decompresses `cv` into `out` using a [`PreparedCodec`]'s cached
    /// rotation matrix and codebook.
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::ConfigMismatch`] when `cv` was compressed under
    /// a different config hash, [`CodecError::DimensionMismatch`] when
    /// `out.len()` differs from the configured dimension, or any error
    /// propagated from dequantization or rotation.
    pub fn decompress_prepared_into(
        &self,
        cv: &CompressedVector,
        prepared: &PreparedCodec,
        out: &mut [f32],
    ) -> Result<(), CodecError> {
        if cv.config_hash() != prepared.config().config_hash() {
            return Err(CodecError::ConfigMismatch {
                expected: prepared.config().config_hash().clone(),
                got: cv.config_hash().clone(),
            });
        }
        if out.len() != prepared.config().dimension() as usize {
            #[allow(clippy::cast_possible_truncation)]
            let got = out.len() as u32;
            return Err(CodecError::DimensionMismatch {
                expected: prepared.config().dimension(),
                got,
            });
        }
        Self::decompress_into_with_rotation(cv, prepared.codebook(), prepared.rotation(), out)
    }

    /// Shared decompression pipeline: dequantize, re-apply the stored
    /// residual (if any), then invert the rotation.
    fn decompress_into_with_rotation(
        compressed: &CompressedVector,
        codebook: &Codebook,
        rotation: &RotationMatrix,
        output: &mut [f32],
    ) -> Result<(), CodecError> {
        let mut rotated = vec![0.0_f32; output.len()];
        codebook.dequantize_into(compressed.indices(), &mut rotated)?;
        if let Some(residual) = compressed.residual() {
            apply_residual_into(&mut rotated, residual)?;
        }
        rotation.apply_inverse_into(&rotated, output)
    }

    /// Compresses a row-major `rows x cols` batch serially.
    ///
    /// # Errors
    ///
    /// See [`Self::compress_batch_with`].
    pub fn compress_batch(
        &self,
        vectors: &[f32],
        rows: usize,
        cols: usize,
        config: &CodecConfig,
        codebook: &Codebook,
    ) -> Result<Vec<CompressedVector>, CodecError> {
        self.compress_batch_with(vectors, rows, cols, config, codebook, Parallelism::Serial)
    }

    /// Compresses a row-major `rows x cols` batch with the requested
    /// `parallelism` (ignored when built without the `std` feature).
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::DimensionMismatch`] when `cols` differs from the
    /// configured dimension, [`CodecError::LengthMismatch`] when
    /// `vectors.len() != rows * cols` (or that product overflows `usize`), or
    /// any error from compressing an individual row.
    pub fn compress_batch_with(
        &self,
        vectors: &[f32],
        rows: usize,
        cols: usize,
        config: &CodecConfig,
        codebook: &Codebook,
        parallelism: Parallelism,
    ) -> Result<Vec<CompressedVector>, CodecError> {
        if cols != config.dimension() as usize {
            #[allow(clippy::cast_possible_truncation)]
            let got = cols as u32;
            return Err(CodecError::DimensionMismatch {
                expected: config.dimension(),
                got,
            });
        }
        // `usize::MAX` as the reported "expected" length signals that
        // `rows * cols` overflowed.
        let expected_len = rows.checked_mul(cols).ok_or(CodecError::LengthMismatch {
            left: vectors.len(),
            right: usize::MAX,
        })?;
        if vectors.len() != expected_len {
            return Err(CodecError::LengthMismatch {
                left: vectors.len(),
                right: expected_len,
            });
        }
        #[cfg(feature = "std")]
        {
            crate::codec::batch::compress_batch_parallel(
                vectors,
                rows,
                cols,
                config,
                codebook,
                parallelism,
            )
        }
        #[cfg(not(feature = "std"))]
        {
            let _ = parallelism;
            let mut out = Vec::with_capacity(rows);
            // In-bounds by construction: `vectors.len() == rows * cols`.
            #[allow(clippy::indexing_slicing)]
            for row in 0..rows {
                let start = row * cols;
                out.push(self.compress(&vectors[start..start + cols], config, codebook)?);
            }
            Ok(out)
        }
    }

    /// Decompresses a batch into a row-major `output` buffer of exactly
    /// `compressed.len() * config.dimension()` floats.
    ///
    /// # Errors
    ///
    /// Returns [`CodecError::LengthMismatch`] when `output.len()` differs
    /// from the required size (or that size overflows `usize`), or any error
    /// from [`Self::decompress_into`] on an individual row.
    pub fn decompress_batch_into(
        &self,
        compressed: &[CompressedVector],
        config: &CodecConfig,
        codebook: &Codebook,
        output: &mut [f32],
    ) -> Result<(), CodecError> {
        let cols = config.dimension() as usize;
        // Guard the product against overflow, mirroring `compress_batch_with`;
        // an unchecked multiply would wrap in release builds and panic in
        // debug builds.
        let needed = compressed
            .len()
            .checked_mul(cols)
            .ok_or(CodecError::LengthMismatch {
                left: output.len(),
                right: usize::MAX,
            })?;
        if output.len() != needed {
            return Err(CodecError::LengthMismatch {
                left: output.len(),
                right: needed,
            });
        }
        // In-bounds by construction: `output.len() == compressed.len() * cols`.
        #[allow(clippy::indexing_slicing)]
        for (row, cv) in compressed.iter().enumerate() {
            let start = row * cols;
            self.decompress_into(cv, config, codebook, &mut output[start..start + cols])?;
        }
        Ok(())
    }
}
/// Minimum number of rows at which `compress_batch_gpu_with` routes the
/// batch to the GPU backend instead of the CPU path.
pub const GPU_BATCH_THRESHOLD: usize = 512;
/// Abstraction over a compute device capable of batch compression.
pub trait GpuComputeBackend {
    // Backend-specific error type; must be convertible into the crate's
    // `CodecError` so callers see a uniform error surface.
    type Error: core::fmt::Debug + Into<crate::errors::CodecError>;
    /// Performs backend-specific preparation of `prepared` (called before
    /// `compress_batch` — see `Codec::compress_batch_gpu_with`).
    fn prepare_for_device(&mut self, prepared: &mut PreparedCodec) -> Result<(), Self::Error>;
    /// Compresses a row-major `rows x cols` batch of vectors on the device,
    /// returning one `CompressedVector` per row.
    fn compress_batch(
        &mut self,
        input: &[f32],
        rows: usize,
        cols: usize,
        prepared: &PreparedCodec,
    ) -> Result<alloc::vec::Vec<CompressedVector>, Self::Error>;
}
#[cfg(feature = "gpu-wgpu")]
impl Codec {
    /// Compresses a row-major `rows x cols` batch, offloading to `backend`
    /// when the batch reaches [`GPU_BATCH_THRESHOLD`] rows and otherwise
    /// falling back to the CPU path with the given `parallelism`.
    pub fn compress_batch_gpu_with<B>(
        &self,
        vectors: &[f32],
        rows: usize,
        cols: usize,
        prepared: &mut PreparedCodec,
        backend: &mut B,
        parallelism: Parallelism,
    ) -> Result<Vec<CompressedVector>, CodecError>
    where
        B: GpuComputeBackend,
    {
        // Small batches: stay on the CPU.
        if rows < GPU_BATCH_THRESHOLD {
            return self.compress_batch_with(
                vectors,
                rows,
                cols,
                prepared.config(),
                prepared.codebook(),
                parallelism,
            );
        }
        // Large batches: set up device state, then compress on the backend.
        backend.prepare_for_device(prepared).map_err(Into::into)?;
        let result = backend.compress_batch(vectors, rows, cols, prepared);
        result.map_err(Into::into)
    }
}
/// Convenience wrapper: compresses `vector` with a throwaway [`Codec`].
pub fn compress(
    vector: &[f32],
    config: &CodecConfig,
    codebook: &Codebook,
) -> Result<CompressedVector, CodecError> {
    let codec = Codec::new();
    codec.compress(vector, config, codebook)
}
/// Convenience wrapper: decompresses `compressed` with a throwaway [`Codec`].
pub fn decompress(
    compressed: &CompressedVector,
    config: &CodecConfig,
    codebook: &Codebook,
) -> Result<Vec<f32>, CodecError> {
    let codec = Codec::new();
    codec.decompress(compressed, config, codebook)
}