use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use super::calibration::CalibrationStats;
use super::error::{validate_embedding, QuantizationError};
use super::params::QuantizationParams;
/// An embedding stored as `i8` components together with the parameters
/// needed to map those components back to `f32`.
///
/// The constructors in this module fill `hash` from `values` alone via
/// `compute_hash`; `params` do not contribute to it.
#[derive(Clone, Debug)]
pub struct QuantizedEmbedding {
    /// Quantized components, one per component of the source embedding.
    pub values: Vec<i8>,
    /// Parameters used to quantize `values` and to dequantize them again.
    pub params: QuantizationParams,
    /// 32-byte digest of `values`, as produced by `compute_hash`.
    pub hash: [u8; 32],
}
impl QuantizedEmbedding {
    /// Quantizes `embedding` using parameters derived from `calibration`.
    ///
    /// # Errors
    ///
    /// Returns a [`QuantizationError`] if `embedding` is rejected by
    /// [`validate_embedding`] for `calibration.dims` dimensions, or if
    /// `calibration.to_quant_params()` fails.
    pub fn from_f32(
        embedding: &[f32],
        calibration: &CalibrationStats,
    ) -> Result<Self, QuantizationError> {
        validate_embedding(embedding, calibration.dims)?;
        let params = calibration.to_quant_params()?;
        Ok(Self::quantize_with_params(embedding, params))
    }

    /// Quantizes `embedding` without calibration data, using the largest
    /// absolute component of `embedding` itself as the absmax.
    ///
    /// If every component is zero, an absmax of `1.0` is used instead of
    /// `0.0` (presumably so `QuantizationParams::from_absmax` is never handed
    /// a zero range — confirm against its implementation).
    ///
    /// # Errors
    ///
    /// Returns a [`QuantizationError`] if [`validate_embedding`] rejects
    /// `embedding`, or if `QuantizationParams::from_absmax` fails.
    pub fn from_f32_uncalibrated(embedding: &[f32]) -> Result<Self, QuantizationError> {
        validate_embedding(embedding, embedding.len())?;
        let absmax = embedding.iter().fold(0.0f32, |acc, &v| acc.max(v.abs()));
        // Substitute 1.0 for an all-zero vector so the range is non-degenerate.
        let absmax = if absmax == 0.0 { 1.0 } else { absmax };
        let params = QuantizationParams::from_absmax(absmax, embedding.len())?;
        Ok(Self::quantize_with_params(embedding, params))
    }

    /// Shared tail of both constructors: quantizes each component of
    /// `embedding` with `params`, hashes the result, and assembles `Self`.
    fn quantize_with_params(embedding: &[f32], params: QuantizationParams) -> Self {
        let values: Vec<i8> = embedding.iter().map(|&v| params.quantize_value(v)).collect();
        let hash = compute_hash(&values);
        Self { values, params, hash }
    }

    /// Maps every stored `i8` back to `f32` via
    /// `QuantizationParams::dequantize_value`.
    pub fn dequantize(&self) -> Vec<f32> {
        self.values.iter().map(|&v| self.params.dequantize_value(v)).collect()
    }

    /// Returns `true` when recomputing `compute_hash` over `values` matches
    /// the stored `hash`.
    ///
    /// Only `values` are checked; a change to `params` is not detected.
    pub fn verify_integrity(&self) -> bool {
        compute_hash(&self.values) == self.hash
    }

    /// Number of quantized components.
    pub fn dims(&self) -> usize {
        self.values.len()
    }

    /// Rough in-memory size in bytes: one byte per value, plus the size of
    /// `QuantizationParams`, plus the hash bytes.
    ///
    /// The `Vec` header and any spare capacity are not counted.
    pub fn memory_size(&self) -> usize {
        self.values.len()
            + std::mem::size_of::<QuantizationParams>()
            + std::mem::size_of_val(&self.hash)
    }
}
/// Builds a 32-byte digest of `values` from four chained `DefaultHasher`
/// rounds, stored little-endian in order.
///
/// Round 1 hashes `values` directly. Round 2 hashes round 1's output
/// together with `values.len()`. Rounds 3 and 4 each rehash the previous
/// round's output. Rounds 2–4 are therefore derived only from round 1 and the
/// length, so this is a checksum rather than a cryptographic digest.
///
/// NOTE: `DefaultHasher`'s algorithm is unspecified and may change between
/// Rust releases, so these digests should not be persisted and compared
/// across toolchain upgrades.
pub fn compute_hash(values: &[i8]) -> [u8; 32] {
    // Feeds a fresh `DefaultHasher` via `feed` and returns its 64-bit output.
    fn digest(feed: impl FnOnce(&mut DefaultHasher)) -> u64 {
        let mut state = DefaultHasher::new();
        feed(&mut state);
        state.finish()
    }

    let round1 = digest(|s| values.hash(s));
    let round2 = digest(|s| {
        round1.hash(s);
        values.len().hash(s);
    });
    let round3 = digest(|s| round2.hash(s));
    let round4 = digest(|s| round3.hash(s));

    let mut out = [0u8; 32];
    for (chunk, word) in out
        .chunks_exact_mut(8)
        .zip([round1, round2, round3, round4])
    {
        chunk.copy_from_slice(&word.to_le_bytes());
    }
    out
}