use alloc::vec;
use alloc::vec::Vec;
/// Affine parameters for unsigned 8-bit quantization.
///
/// A code `q` in `0..=255` stands for the value `q as f32 * scale + offset`.
#[derive(Debug, Clone, Copy)]
pub struct Int8QuantParams {
    /// Width of one quantization step.
    pub scale: f32,
    /// Value represented by code `0` (the minimum of the data the params were computed from).
    pub offset: f32,
}

impl Int8QuantParams {
    /// Derives parameters that map `[min(data), max(data)]` onto the 256 available codes.
    ///
    /// - Empty input yields `scale = 1.0`, `offset = 0.0`.
    /// - Constant input yields `scale = 1.0` with `offset` set to that constant.
    /// - NaN elements do not influence the bounds (`f32::min`/`f32::max` ignore NaN).
    ///   If *every* element is NaN, `offset` ends up as `f32::MAX`.
    /// - The returned `scale` is always finite and strictly positive, so it is safe
    ///   to divide by.
    pub fn compute(data: &[f32]) -> Self {
        if data.is_empty() {
            return Self {
                scale: 1.0,
                offset: 0.0,
            };
        }
        let min = data.iter().copied().fold(f32::MAX, f32::min);
        let max = data.iter().copied().fold(f32::MIN, f32::max);
        let range = max - min;
        // For finite data spanning more than f32::MAX (e.g. [-3e38, 3e38]) `max - min`
        // overflows to +inf. Dividing each bound first keeps the step finite.
        let step = if range.is_finite() {
            range / 255.0
        } else {
            max / 255.0 - min / 255.0
        };
        // A zero step (constant data, or a subnormal range that underflows when divided)
        // or a non-finite step cannot serve as a divisor: fall back to 1.0.
        let scale = if step > 0.0 && step.is_finite() {
            step
        } else {
            1.0
        };
        Self { scale, offset: min }
    }

    /// Serializes the parameters as 8 bytes: `scale` then `offset`, each a little-endian `f32`.
    pub fn to_bytes(&self) -> [u8; 8] {
        let mut buf = [0u8; 8];
        buf[0..4].copy_from_slice(&self.scale.to_le_bytes());
        buf[4..8].copy_from_slice(&self.offset.to_le_bytes());
        buf
    }

    /// Inverse of [`Int8QuantParams::to_bytes`]: reads `scale` from bytes 0..4 and
    /// `offset` from bytes 4..8, both little-endian `f32`.
    pub fn from_bytes(buf: &[u8; 8]) -> Self {
        Self {
            scale: f32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]),
            offset: f32::from_le_bytes([buf[4], buf[5], buf[6], buf[7]]),
        }
    }
}
pub fn quantize_f32_to_int8(data: &[f32]) -> (Vec<u8>, Int8QuantParams) {
let params = Int8QuantParams::compute(data);
let quantized: Vec<u8> = data
.iter()
.map(|&v| {
let normalized = (v - params.offset) / params.scale;
libm::roundf(normalized).clamp(0.0, 255.0) as u8
})
.collect();
(quantized, params)
}
pub fn dequantize_int8_to_f32(quantized: &[u8], params: &Int8QuantParams) -> Vec<f32> {
quantized
.iter()
.map(|&q| q as f32 * params.scale + params.offset)
.collect()
}
/// Symmetric signed 8-bit quantization.
///
/// The scale is `max(|v|) / 127`, or `1.0` when the data is empty or all zeros.
/// Each value becomes `round(v / scale)` clamped into `-127..=127`, so `-128` is
/// never produced. Returns the codes and the scale needed to dequantize them.
pub fn quantize_f32_to_int8_signed(data: &[f32]) -> (Vec<i8>, f32) {
    // Empty input needs no special case: the fold yields 0.0, which selects scale 1.0.
    let peak = data
        .iter()
        .fold(0.0f32, |peak, &x| peak.max(libm::fabsf(x)));
    let scale = if peak > 0.0 { peak / 127.0 } else { 1.0 };
    let mut codes = Vec::with_capacity(data.len());
    for &value in data {
        let level = libm::roundf(value / scale);
        codes.push(level.clamp(-127.0, 127.0) as i8);
    }
    (codes, scale)
}
/// Reconstructs approximate `f32` values from symmetric signed 8-bit codes by
/// multiplying each code by `scale`.
pub fn dequantize_int8_signed_to_f32(quantized: &[i8], scale: f32) -> Vec<f32> {
    let mut values = Vec::with_capacity(quantized.len());
    values.extend(quantized.iter().map(|&code| f32::from(code) * scale));
    values
}
/// Converts an `f32` to the bit pattern of an IEEE 754 binary16 ("half") value.
///
/// - Finite values are rounded to the nearest representable half, ties to even
///   (the IEEE 754 default), rather than truncated.
/// - Magnitudes that round above the largest finite half (65504) become ±infinity.
/// - Magnitudes below the smallest normal half (2^-14) are encoded as subnormals.
///   Anything that rounds below the smallest subnormal (2^-24) becomes a signed zero.
/// - Infinities map to ±infinity. Every NaN maps to the quiet NaN `0x7E00`,
///   carrying the input's sign bit; the payload is not preserved.
pub fn f32_to_f16(value: f32) -> u16 {
    let bits = value.to_bits();
    // Sign bit moved from position 31 (f32) to position 15 (f16).
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x007F_FFFF;

    if exp == 0xFF {
        return if mantissa == 0 {
            sign | 0x7C00
        } else {
            sign | 0x7E00
        };
    }

    // Re-bias the exponent from f32 (bias 127) to f16 (bias 15).
    let half_exp = exp - 127 + 15;
    if half_exp >= 0x1F {
        return sign | 0x7C00;
    }

    if half_exp <= 0 {
        // Values below 2^-14 are representable only as f16 subnormals.
        // Values below 2^-25 (half the smallest subnormal) always round to zero;
        // this also covers f32 zeros and f32 subnormals.
        if half_exp < -10 {
            return sign;
        }
        // Restore the implicit leading 1, then shift so that one unit equals 2^-24.
        let significand = mantissa | 0x0080_0000;
        let shift = (14 - half_exp) as u32; // 14..=24
        // A round-up to 0x400 correctly yields the smallest normal half.
        return sign | round_shift_right_ties_even(significand, shift) as u16;
    }

    // Normal half: drop the 13 low mantissa bits with rounding. A carry out of the
    // mantissa bumps the exponent; a carry out of exponent 30 yields 0x7C00 (infinity).
    let unrounded = ((half_exp as u32) << 23) | mantissa;
    sign | round_shift_right_ties_even(unrounded, 13) as u16
}

/// Computes `value >> shift` rounded to nearest, ties to even. `shift` must be in `1..32`.
fn round_shift_right_ties_even(value: u32, shift: u32) -> u32 {
    let truncated = value >> shift;
    let remainder = value & ((1u32 << shift) - 1);
    let halfway = 1u32 << (shift - 1);
    if remainder > halfway || (remainder == halfway && (truncated & 1) == 1) {
        truncated + 1
    } else {
        truncated
    }
}
/// Converts IEEE 754 binary16 bits to the `f32` with exactly the same value.
///
/// Every finite half, including subnormals and signed zeros, is representable in
/// `f32`, so the conversion is exact. Infinities keep their sign. Every NaN pattern
/// decodes to [`f32::NAN`]; sign and payload are not preserved.
pub fn f16_to_f32(value: u16) -> f32 {
    // Smallest f16 subnormal, 2^-24: a subnormal half equals `mantissa * 2^-24`.
    const SUBNORMAL_STEP: f32 = 1.0 / 16_777_216.0;

    let negative = (value & 0x8000) != 0;
    let exp = u32::from((value >> 10) & 0x1F);
    let mantissa = u32::from(value & 0x03FF);

    match exp {
        0x1F if mantissa != 0 => f32::NAN,
        0x1F => {
            if negative {
                f32::NEG_INFINITY
            } else {
                f32::INFINITY
            }
        }
        0 => {
            // Zero or subnormal. mantissa < 2^10, so the cast and the product are exact.
            // Negating afterwards keeps -0.0 for the 0x8000 pattern.
            let magnitude = mantissa as f32 * SUBNORMAL_STEP;
            if negative {
                -magnitude
            } else {
                magnitude
            }
        }
        _ => {
            // Normal: re-bias the exponent from 15 to 127 (+112) and widen the mantissa.
            let sign = u32::from(negative) << 31;
            f32::from_bits(sign | ((exp + 112) << 23) | (mantissa << 13))
        }
    }
}
pub fn quantize_f32_to_f16(data: &[f32]) -> Vec<u16> {
data.iter().map(|&v| f32_to_f16(v)).collect()
}
pub fn dequantize_f16_to_f32(quantized: &[u16]) -> Vec<f32> {
quantized.iter().map(|&v| f16_to_f32(v)).collect()
}
/// Serializes half-precision bit patterns as consecutive little-endian byte pairs
/// (2 bytes per element).
pub fn f16_to_bytes(data: &[u16]) -> Vec<u8> {
    data.iter().flat_map(|half| half.to_le_bytes()).collect()
}
/// Parses little-endian byte pairs back into half-precision bit patterns.
///
/// A trailing odd byte, if present, is silently ignored.
pub fn bytes_to_f16(bytes: &[u8]) -> Vec<u16> {
    let mut halves = Vec::with_capacity(bytes.len() / 2);
    for pair in bytes.chunks_exact(2) {
        let low = u16::from(pair[0]);
        let high = u16::from(pair[1]);
        halves.push(low | (high << 8));
    }
    halves
}
/// Packs the sign of each value into one bit: set when `v > 0.0`, clear otherwise
/// (including zero and NaN).
///
/// Element `i` lands in byte `i / 8` at bit `i % 8` (least-significant bit first).
/// Unused bits in the last byte stay zero. The output length is `ceil(len / 8)`.
pub fn quantize_f32_to_binary(data: &[f32]) -> Vec<u8> {
    data.chunks(8)
        .map(|group| {
            group
                .iter()
                .enumerate()
                .filter(|&(_, &v)| v > 0.0)
                .fold(0u8, |byte, (bit, _)| byte | (1 << bit))
        })
        .collect()
}
/// Approximates cosine similarity from two bit-packed sign vectors as
/// `1 - 2 * hamming / dimensions`.
///
/// Assumes `a` and `b` were produced by [`quantize_f32_to_binary`] from vectors of
/// `dimensions` elements, and that `hamming_distance` counts differing bits
/// (TODO confirm against `super::distance`). Identical inputs give `1.0`; fully
/// opposite inputs give `-1.0`. With `dimensions == 0` the division yields NaN
/// (equal inputs) or `-inf`.
pub fn binary_cosine_approx(a: &[u8], b: &[u8], dimensions: usize) -> f32 {
    let differing_bits = super::distance::hamming_distance(a, b) as f32;
    let total_bits = dimensions as f32;
    1.0 - (2.0 * differing_bits) / total_bits
}
/// Configuration for product quantization (PQ).
#[derive(Debug, Clone)]
pub struct PqParams {
    /// Number of subspaces the vector is split into.
    pub m: usize,
    /// Dimensions per subspace (`dimensions / m`, truncated toward zero).
    pub dims_per_subspace: usize,
    /// Number of centroids per subspace codebook; [`PqParams::new`] sets 256.
    pub k: usize,
    /// Per-subspace codebooks. [`PqParams::new`] leaves this empty; presumably it
    /// is filled by a training step elsewhere and indexed
    /// `[subspace][centroid][dim]` — TODO confirm.
    pub codebooks: Vec<Vec<Vec<f32>>>,
}

impl PqParams {
    /// Creates untrained PQ parameters for `dimensions`-wide vectors split into `m` subspaces.
    ///
    /// If `dimensions` is not a multiple of `m`, the remainder is dropped from
    /// `dims_per_subspace`.
    ///
    /// # Panics
    ///
    /// Panics (division by zero) when `m == 0`.
    pub fn new(dimensions: usize, m: usize) -> Self {
        let codebooks = Vec::new();
        Self {
            m,
            dims_per_subspace: dimensions / m,
            k: 256,
            codebooks,
        }
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the quantization helpers in the parent module.
    use super::*;

    /// Tolerance for int8 round-trips on data in roughly [-1, 1].
    const EPSILON: f32 = 0.1;

    #[test]
    fn test_int8_quantization_roundtrip() {
        let data = vec![0.0, 0.5, 1.0, -0.5, -1.0, 0.25, 0.75];
        let (quantized, params) = quantize_f32_to_int8(&data);
        let recovered = dequantize_int8_to_f32(&quantized, &params);
        assert_eq!(data.len(), recovered.len());
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!(
                (orig - rec).abs() < EPSILON,
                "Original: {}, Recovered: {}",
                orig,
                rec
            );
        }
    }

    #[test]
    fn test_int8_signed_quantization_roundtrip() {
        let data = vec![0.0, 0.5, 1.0, -0.5, -1.0, 0.25, -0.75];
        let (quantized, scale) = quantize_f32_to_int8_signed(&data);
        let recovered = dequantize_int8_signed_to_f32(&quantized, scale);
        assert_eq!(data.len(), recovered.len());
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!(
                (orig - rec).abs() < EPSILON,
                "Original: {}, Recovered: {}",
                orig,
                rec
            );
        }
    }

    #[test]
    fn test_int8_quant_params_serialization() {
        let params = Int8QuantParams {
            scale: 0.125,
            offset: -1.0,
        };
        let bytes = params.to_bytes();
        let restored = Int8QuantParams::from_bytes(&bytes);
        assert_eq!(params.scale, restored.scale);
        assert_eq!(params.offset, restored.offset);
    }

    #[test]
    fn test_f16_conversion_basic() {
        let values = [0.0f32, 1.0, -1.0, 0.5, 2.0, 0.25];
        for &v in &values {
            let f16 = f32_to_f16(v);
            let back = f16_to_f32(f16);
            assert!(
                (v - back).abs() < 0.01,
                "Value: {}, F16: {}, Back: {}",
                v,
                f16,
                back
            );
        }
    }

    #[test]
    fn test_f16_infinity() {
        let inf_f16 = f32_to_f16(f32::INFINITY);
        let neg_inf_f16 = f32_to_f16(f32::NEG_INFINITY);
        assert!(f16_to_f32(inf_f16).is_infinite());
        assert!(f16_to_f32(inf_f16) > 0.0);
        assert!(f16_to_f32(neg_inf_f16).is_infinite());
        assert!(f16_to_f32(neg_inf_f16) < 0.0);
    }

    #[test]
    fn test_f16_zero() {
        let zero_f16 = f32_to_f16(0.0);
        assert_eq!(f16_to_f32(zero_f16), 0.0);
    }

    #[test]
    fn test_f16_vector_roundtrip() {
        let data = vec![0.1, 0.2, 0.3, -0.1, -0.2, 0.5, 1.0, -1.0];
        let quantized = quantize_f32_to_f16(&data);
        let recovered = dequantize_f16_to_f32(&quantized);
        assert_eq!(data.len(), recovered.len());
        for (orig, rec) in data.iter().zip(recovered.iter()) {
            assert!(
                (orig - rec).abs() < 0.01,
                "Original: {}, Recovered: {}",
                orig,
                rec
            );
        }
    }

    #[test]
    fn test_f16_bytes_serialization() {
        let f16_vec: Vec<u16> = vec![0x3C00, 0x4000, 0xC000];
        let bytes = f16_to_bytes(&f16_vec);
        assert_eq!(bytes.len(), 6);
        let restored = bytes_to_f16(&bytes);
        assert_eq!(f16_vec, restored);
    }

    #[test]
    fn test_binary_quantization() {
        let data = vec![0.5, -0.5, 0.1, -0.1, 0.0, 0.9, -0.9, 0.3];
        let binary = quantize_f32_to_binary(&data);
        assert_eq!(binary.len(), 1);
        // Bits set (LSB first) at the strictly positive indices 0, 2, 5, 7.
        assert_eq!(binary[0], 0b10100101);
    }

    #[test]
    fn test_binary_hamming() {
        let a = vec![1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0];
        let b = vec![1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0];
        let bin_a = quantize_f32_to_binary(&a);
        let bin_b = quantize_f32_to_binary(&b);
        let hamming = super::super::distance::hamming_distance(&bin_a, &bin_b);
        assert_eq!(hamming, 4);
    }

    #[test]
    fn test_binary_cosine_approx() {
        let a = vec![1.0; 64];
        let b = vec![1.0; 64];
        let bin_a = quantize_f32_to_binary(&a);
        let bin_b = quantize_f32_to_binary(&b);
        let similarity = binary_cosine_approx(&bin_a, &bin_b, 64);
        assert!((similarity - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_quantization_preserves_order() {
        let query = vec![0.5, 0.5, 0.0, 0.0];
        let v1 = vec![0.6, 0.4, 0.0, 0.0];
        let v2 = vec![0.0, 0.0, 0.5, 0.5];
        let dp1: f32 = query.iter().zip(&v1).map(|(a, b)| a * b).sum();
        let dp2: f32 = query.iter().zip(&v2).map(|(a, b)| a * b).sum();
        let (q_query, params) = quantize_f32_to_int8(&query);
        let (q_v1, _) = quantize_f32_to_int8(&v1);
        let (q_v2, _) = quantize_f32_to_int8(&v2);
        let dq_query = dequantize_int8_to_f32(&q_query, &params);
        let dq_v1 = dequantize_int8_to_f32(&q_v1, &Int8QuantParams::compute(&v1));
        let dq_v2 = dequantize_int8_to_f32(&q_v2, &Int8QuantParams::compute(&v2));
        let dq_dp1: f32 = dq_query.iter().zip(&dq_v1).map(|(a, b)| a * b).sum();
        let dq_dp2: f32 = dq_query.iter().zip(&dq_v2).map(|(a, b)| a * b).sum();
        assert!(
            dp1 > dp2,
            "Original order: dp1={} should be > dp2={}",
            dp1,
            dp2
        );
        assert!(
            dq_dp1 > dq_dp2,
            "Quantized order: dp1={} should be > dp2={}",
            dq_dp1,
            dq_dp2
        );
    }
}