/// Symmetrically quantizes `values` to signed 8-bit codes.
///
/// The scale is chosen so that the largest *finite* magnitude maps to ±127
/// (`scale = max_abs / 127`). Each value is divided by the scale, rounded to
/// the nearest integer and clamped to `[-127, 127]`. The zero point is always
/// `0`, so a code decodes as `code as f32 * scale`.
///
/// Returns `(codes, scale, zero_point)`, with one code per input value.
///
/// The scale falls back to `1.0` in three cases:
/// - the slice is empty;
/// - every finite value is zero;
/// - `max_abs / 127` underflows to `0.0` (tiny subnormal maxima).
///
/// Non-finite inputs do not influence the scale. `±inf` saturates to `±127`,
/// and `NaN` is encoded as `0`, because Rust's float→int `as` cast maps NaN
/// to 0.
pub fn quantize_int8(values: &[f32]) -> (Vec<i8>, f32, i8) {
    // Only finite values count toward the range. A single ±inf would
    // otherwise make the scale infinite, and `0 * inf = NaN` would then
    // poison every dequantized element.
    let max_abs = values
        .iter()
        .filter(|v| v.is_finite())
        .map(|v| v.abs())
        .fold(0.0f32, f32::max);

    // The scale is 0 for empty or all-zero input. It can also underflow to 0
    // when max_abs is a tiny subnormal. Dividing by 0 would turn every code
    // into ±127 or NaN, so use a neutral scale instead.
    let scale = max_abs / 127.0;
    let scale = if scale > 0.0 { scale } else { 1.0 };

    let quantized = values
        .iter()
        .map(|&v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (quantized, scale, 0)
}
/// Maps 8-bit codes back to `f32` values as `(code - zero_point) * scale`.
///
/// The output has exactly one element per entry of `quantized`, in order.
pub fn dequantize_int8(quantized: &[i8], scale: f32, zero_point: i8) -> Vec<f32> {
    // Widen the zero point once instead of per element.
    let offset = f32::from(zero_point);
    let mut restored = Vec::with_capacity(quantized.len());
    for &code in quantized {
        restored.push((f32::from(code) - offset) * scale);
    }
    restored
}
/// Encodes each `f32` in `values` as a 16-bit half-precision bit pattern
/// using `f32_to_fp16`, preserving order.
pub fn quantize_fp16(values: &[f32]) -> Vec<u16> {
    let mut halves = Vec::with_capacity(values.len());
    for &value in values {
        halves.push(f32_to_fp16(value));
    }
    halves
}
pub fn dequantize_fp16(quantized: &[u16]) -> Vec<f32> {
quantized.iter().map(|&q| fp16_to_f32(q)).collect()
}
/// Encodes `value` as an IEEE 754 binary16 bit pattern, rounding to nearest
/// with ties to even.
///
/// - Magnitudes that round above the largest finite half (65504) become
///   ±infinity. Infinities stay infinities.
/// - Magnitudes below the smallest normal half (2^-14) are encoded as
///   subnormals. Anything that rounds below 2^-24 becomes a signed zero,
///   which includes every f32 zero and subnormal.
/// - NaN stays NaN: the quiet bit is set and the top payload bits are kept.
fn f32_to_fp16(value: f32) -> u16 {
    /// Shifts `x` right by `shift` bits (1..=24), rounding the discarded
    /// bits to nearest with ties to even.
    fn shift_round_even(x: u32, shift: u32) -> u32 {
        let kept = x >> shift;
        let rem = x & ((1 << shift) - 1);
        let half = 1 << (shift - 1);
        if rem > half || (rem == half && kept & 1 == 1) {
            kept + 1
        } else {
            kept
        }
    }

    let bits = value.to_bits();
    let sign = ((bits >> 16) & 0x8000) as u16;
    let exponent = ((bits >> 23) & 0xFF) as i32;
    let mantissa = bits & 0x007F_FFFF;

    // Infinity or NaN. Setting 0x0200 (the half quiet bit) guarantees a NaN
    // stays NaN even when its payload lives only in the discarded low bits.
    if exponent == 0xFF {
        return if mantissa == 0 {
            sign | 0x7C00
        } else {
            sign | 0x7E00 | (mantissa >> 13) as u16
        };
    }

    // Re-bias the exponent from f32 (bias 127) to f16 (bias 15).
    let half_exp = exponent - 127 + 15;

    // Exponent field 31 is reserved for inf/NaN, so 31 and above overflow.
    if half_exp >= 0x1F {
        return sign | 0x7C00;
    }

    if half_exp <= 0 {
        // Below 2^-25 the value rounds to zero even at the smallest subnormal.
        if half_exp < -10 {
            return sign;
        }
        // Subnormal half: code = (1.m * 2^23) >> (14 - half_exp). The
        // implicit leading one must be restored before shifting. Rounding
        // may carry into 0x0400, which is exactly the smallest normal half.
        let with_implicit_one = mantissa | 0x0080_0000;
        let shift = (14 - half_exp) as u32; // 14..=24
        return sign | shift_round_even(with_implicit_one, shift) as u16;
    }

    // Normal half: drop 13 mantissa bits with rounding. A carry out of the
    // mantissa correctly bumps the exponent, and from 30 it reaches 31,
    // which is infinity, as IEEE overflow requires.
    let combined = ((half_exp as u32) << 23) | mantissa;
    sign | shift_round_even(combined, 13) as u16
}
/// Decodes an IEEE 754 binary16 bit pattern into the `f32` it represents.
///
/// Every finite half value is exactly representable as an `f32`, so decoding
/// is lossless. Signed zeros and infinities keep their sign. NaNs decode to a
/// quiet `f32` NaN that keeps the 10-bit half payload in the top mantissa bits.
fn fp16_to_f32(value: u16) -> f32 {
    let sign = u32::from(value >> 15) << 31;
    let exponent = u32::from((value >> 10) & 0x1F);
    let mantissa = u32::from(value & 0x3FF);
    match exponent {
        // Infinity.
        0x1F if mantissa == 0 => f32::from_bits(sign | 0x7F80_0000),
        // NaN: force the f32 quiet bit and keep the payload.
        0x1F => f32::from_bits(sign | 0x7FC0_0000 | (mantissa << 13)),
        // Zero or subnormal. The value is mantissa * 2^-24 with no implicit
        // leading one. Both the u16 -> f32 conversion and the division by
        // 2^24 are exact, and the result is a normal f32.
        0 => {
            let magnitude = f32::from(value & 0x3FF) / 16_777_216.0;
            f32::from_bits(sign | magnitude.to_bits())
        }
        // Normal: re-bias the exponent (15 -> 127) and widen the mantissa.
        _ => f32::from_bits(sign | ((exponent + 127 - 15) << 23) | (mantissa << 13)),
    }
}
/// Quantizes every embedding in `embeddings` with the scheme named by
/// `quantization_type`.
///
/// - `"int8"`: each row goes through `quantize_int8`. The `i8` codes are
///   widened to `i16`, because that is the element type
///   `QuantizedBatch::Int8` stores. The per-row scale and zero point go into
///   parallel vectors.
/// - `"fp16"`: each row goes through `quantize_fp16`.
///
/// # Errors
///
/// For any other string, returns `QuantizationError::UnsupportedType`
/// carrying `quantization_type`. Matching is exact and case-sensitive.
pub fn quantize_batch(
    embeddings: &[Vec<f32>],
    quantization_type: &str,
) -> Result<QuantizedBatch, QuantizationError> {
    if quantization_type == "fp16" {
        let quantized = embeddings
            .iter()
            .map(|row| quantize_fp16(row.as_slice()))
            .collect::<Vec<Vec<u16>>>();
        return Ok(QuantizedBatch::Fp16 { quantized });
    }
    if quantization_type != "int8" {
        return Err(QuantizationError::UnsupportedType(
            quantization_type.to_string(),
        ));
    }

    let rows = embeddings.len();
    let mut quantized: Vec<Vec<i16>> = Vec::with_capacity(rows);
    let mut scales = Vec::with_capacity(rows);
    let mut zero_points = Vec::with_capacity(rows);
    for row in embeddings {
        let (codes, scale, zero_point) = quantize_int8(row.as_slice());
        quantized.push(codes.into_iter().map(i16::from).collect());
        scales.push(scale);
        zero_points.push(zero_point);
    }
    Ok(QuantizedBatch::Int8 {
        quantized,
        scales,
        zero_points,
    })
}
/// Per-embedding quantization results produced by [`quantize_batch`].
///
/// `PartialEq` is derived, consistent with `QuantizationError`, so whole
/// batches can be compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq)]
pub enum QuantizedBatch {
    /// Symmetric 8-bit quantization, with one entry per input embedding in
    /// each vector.
    Int8 {
        /// Codes for each embedding. As produced by [`quantize_batch`], these
        /// are `i8` values widened to `i16`.
        quantized: Vec<Vec<i16>>,
        /// Scale used for each embedding.
        scales: Vec<f32>,
        /// Zero point used for each embedding.
        zero_points: Vec<i8>,
    },
    /// Half-precision encoding: 16-bit bit patterns for each embedding.
    Fp16 {
        quantized: Vec<Vec<u16>>,
    },
}
/// Error returned by [`quantize_batch`].
#[derive(Debug, Clone, PartialEq)]
pub enum QuantizationError {
    /// The requested quantization scheme name is not recognised. It carries
    /// the name exactly as the caller passed it.
    UnsupportedType(String),
}

impl std::fmt::Display for QuantizationError {
    /// Renders a human-readable message naming the offending scheme.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let requested = match self {
            QuantizationError::UnsupportedType(name) => name,
        };
        write!(f, "Unsupported quantization type: {}", requested)
    }
}

// Marker impl: relies on the `Debug` and `Display` impls above and adds no
// source/cause chain.
impl std::error::Error for QuantizationError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mixed-sign values shared by the single-vector round-trip tests.
    fn sample_values() -> Vec<f32> {
        vec![0.1, 0.2, -0.3, 0.4, -0.5]
    }

    /// Two three-element embeddings shared by the batch tests.
    fn sample_batch() -> Vec<Vec<f32>> {
        vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]
    }

    #[test]
    fn test_int8_quantization_roundtrip() {
        let input = sample_values();
        let (codes, scale, zero_point) = quantize_int8(&input);
        let restored = dequantize_int8(&codes, scale, zero_point);
        // The absolute reconstruction error must stay below one quantization step.
        for (expected, actual) in input.iter().zip(&restored) {
            assert!(
                (expected - actual).abs() < scale,
                "Quantization error too large: {} vs {}",
                expected,
                actual
            );
        }
    }

    #[test]
    fn test_fp16_quantization_roundtrip() {
        let input = sample_values();
        let restored = dequantize_fp16(&quantize_fp16(&input));
        // Compare relative error, guarding the denominator against zero inputs.
        for (expected, actual) in input.iter().zip(&restored) {
            let denominator = expected.abs().max(1e-6);
            let relative_error = (expected - actual).abs() / denominator;
            assert!(
                relative_error < 0.01,
                "FP16 quantization error too large: {} vs {} (rel error: {})",
                expected,
                actual,
                relative_error
            );
        }
    }

    #[test]
    fn test_quantize_batch_int8() {
        let batch = quantize_batch(&sample_batch(), "int8").unwrap();
        if let QuantizedBatch::Int8 {
            quantized,
            scales,
            zero_points,
        } = batch
        {
            assert_eq!(quantized.len(), 2);
            assert_eq!(scales.len(), 2);
            assert_eq!(zero_points.len(), 2);
        } else {
            panic!("Expected Int8 batch");
        }
    }

    #[test]
    fn test_quantize_batch_fp16() {
        let batch = quantize_batch(&sample_batch(), "fp16").unwrap();
        if let QuantizedBatch::Fp16 { quantized } = batch {
            assert_eq!(quantized.len(), 2);
        } else {
            panic!("Expected Fp16 batch");
        }
    }

    #[test]
    fn test_quantize_batch_unsupported() {
        let embeddings: Vec<Vec<f32>> = vec![vec![0.1, 0.2]];
        assert!(quantize_batch(&embeddings, "invalid").is_err());
    }
}