use super::vector::{Int8QuantizedVector, Int8QuantizedVectorMetadata};
pub trait Quantize {
type Output;
fn quantize(&self) -> Self::Output;
fn quantize_with_metadata(&self, metadata: &Int8QuantizedVectorMetadata) -> Self::Output;
}
impl Quantize for [f32] {
type Output = Int8QuantizedVector;
fn quantize(&self) -> Self::Output {
let metadata = Int8QuantizedVectorMetadata::from_vector(self);
self.quantize_with_metadata(&metadata)
}
fn quantize_with_metadata(&self, metadata: &Int8QuantizedVectorMetadata) -> Self::Output {
let data: Vec<i8> = self
.iter()
.map(|&v| {
let scaled = v * metadata.scale + metadata.bias;
scaled.round().clamp(-128.0, 127.0) as i8
})
.collect();
Int8QuantizedVector::new(data, *metadata, self.len())
}
}
impl Quantize for Vec<f32> {
type Output = Int8QuantizedVector;
fn quantize(&self) -> Self::Output {
self.as_slice().quantize()
}
fn quantize_with_metadata(&self, metadata: &Int8QuantizedVectorMetadata) -> Self::Output {
self.as_slice().quantize_with_metadata(metadata)
}
}
#[inline]
pub fn dequantize_value(quantized: i8, metadata: &Int8QuantizedVectorMetadata) -> f32 {
(quantized as f32 - metadata.bias) / metadata.scale
}
pub trait Dequantize {
fn dequantize(&self) -> Vec<f32>;
}
impl Dequantize for Int8QuantizedVector {
fn dequantize(&self) -> Vec<f32> {
self.as_slice()
.iter()
.take(self.dimension)
.map(|&q| dequantize_value(q, &self.metadata))
.collect()
}
}
pub fn quantization_error(original: &[f32], dequantized: &[f32]) -> f32 {
assert_eq!(original.len(), dequantized.len());
let sum_sq_diff: f32 = original
.iter()
.zip(dequantized.iter())
.map(|(&o, &d)| (o - d) * (o - d))
.sum();
(sum_sq_diff / original.len() as f32).sqrt()
}
pub fn batch_quantize(
vectors: &[Vec<f32>],
) -> Option<(Vec<Int8QuantizedVector>, Int8QuantizedVectorMetadata)> {
if vectors.is_empty() {
return None;
}
let (global_min, global_max, _, _) = vectors.iter().fold(
(f32::INFINITY, f32::NEG_INFINITY, 0.0f32, 0.0f32),
|(min, max, sum, sq_sum), vec| {
vec.iter()
.fold((min, max, sum, sq_sum), |(min, max, sum, sq_sum), &v| {
(min.min(v), max.max(v), sum + v, sq_sum + v * v)
})
},
);
let scale = 254.0 / (global_max - global_min).max(1e-9);
let bias = -global_min * scale - 127.0;
let total_sum: f32 = vectors.iter().map(|v| v.iter().sum::<f32>()).sum();
let total_sq_sum: f32 = vectors
.iter()
.map(|v| v.iter().map(|&x| x * x).sum::<f32>())
.sum();
let n = vectors.len() as f32;
let metadata = Int8QuantizedVectorMetadata::new(scale, bias, total_sum / n, total_sq_sum / n);
let quantized: Vec<Int8QuantizedVector> = vectors
.iter()
.map(|v| v.quantize_with_metadata(&metadata))
.collect();
Some((quantized, metadata))
}