#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
use crate::arch;
/// Affine parameters that map a range of `f32` values onto `u8` codes.
///
/// A code `c` in `0..=255` stands for the value `offset + alpha * c / 255`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct QuantizationParams {
    /// Width of the represented range (`max - min`). The constructors in this
    /// impl always produce a value greater than zero.
    pub alpha: f32,
    /// Lower bound of the represented range, i.e. the value encoded as code `0`.
    pub offset: f32,
}

impl QuantizationParams {
    /// Parameters returned whenever there is no usable data to fit.
    const fn fallback() -> Self {
        Self {
            alpha: 1.0,
            offset: 0.0,
        }
    }

    /// Builds parameters covering `[min, max]`.
    ///
    /// If `max - min` is not strictly positive (empty, inverted or NaN range),
    /// `alpha` falls back to `1.0` so that later divisions by `alpha` stay finite.
    #[must_use]
    pub fn from_range(min: f32, max: f32) -> Self {
        let width = max - min;
        Self {
            alpha: if width > 0.0 { width } else { 1.0 },
            offset: min,
        }
    }

    /// Fits parameters to the minimum and maximum of `values`.
    ///
    /// NaN entries never win a comparison and are therefore ignored. An empty
    /// slice, or one containing only NaN, yields `alpha = 1.0, offset = 0.0`,
    /// which is the same result `fit_vectors` gives for such input.
    #[must_use]
    pub fn fit(values: &[f32]) -> Self {
        // Delegating keeps the "nothing usable was seen" handling in one place.
        Self::fit_vectors(&[values])
    }

    /// Fits parameters to the central `quantile` fraction of `values`, trimming
    /// the same number of samples from the bottom and from the top.
    ///
    /// Non-finite values are discarded before ranking. `quantile == 1.0`
    /// behaves exactly like [`Self::fit`]. Empty input (before or after
    /// discarding non-finite values) yields `alpha = 1.0, offset = 0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `quantile` is not in `(0.0, 1.0]`.
    #[must_use]
    pub fn fit_quantile(values: &[f32], quantile: f32) -> Self {
        assert!(
            quantile > 0.0 && quantile <= 1.0,
            "quantile must be in (0.0, 1.0]"
        );
        if values.is_empty() {
            return Self::fallback();
        }
        if quantile >= 1.0 {
            return Self::fit(values);
        }
        let mut sorted: Vec<f32> = values.iter().copied().filter(|v| v.is_finite()).collect();
        if sorted.is_empty() {
            return Self::fallback();
        }
        sorted.sort_unstable_by(f32::total_cmp);

        let len = sorted.len();
        let tail = (1.0 - quantile) / 2.0;
        // Number of samples trimmed from each end. The cap guards against
        // `tail` rounding up to 0.5 for tiny quantiles, which would otherwise
        // push the lower index past the upper one.
        let lo_idx = ((tail * len as f32).floor() as usize).min((len - 1) / 2);
        // Mirror the lower index so both tails lose the same number of samples;
        // this is a 0-based index, not a count.
        let hi_idx = len - 1 - lo_idx;
        Self::from_range(sorted[lo_idx], sorted[hi_idx])
    }

    /// Fits one shared set of parameters to the minimum and maximum over all
    /// components of all `vectors`.
    ///
    /// NaN components are ignored. If no component updates the running bounds
    /// (no vectors, only empty vectors, or only NaN), the result is
    /// `alpha = 1.0, offset = 0.0`.
    #[must_use]
    pub fn fit_vectors(vectors: &[&[f32]]) -> Self {
        let mut min = f32::MAX;
        let mut max = f32::MIN;
        for &val in vectors.iter().copied().flatten() {
            if val < min {
                min = val;
            }
            if val > max {
                max = val;
            }
        }
        // The sentinels are still inverted, so nothing usable was seen.
        if min > max {
            return Self::fallback();
        }
        Self::from_range(min, max)
    }
}
/// A vector stored as one `u8` code per component.
#[derive(Clone, Debug, PartialEq)]
pub struct QuantizedU8 {
    /// The codes, one byte per component.
    data: Vec<u8>,
    /// Number of components; `new` requires it to equal `data.len()`.
    dimension: usize,
}

impl QuantizedU8 {
    /// Wraps already-quantized codes.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from `dimension`.
    pub fn new(data: Vec<u8>, dimension: usize) -> Self {
        let len = data.len();
        assert_eq!(
            len, dimension,
            "QuantizedU8: data length {} doesn't match dimension {}",
            len, dimension
        );
        Self { data, dimension }
    }

    /// Borrows the stored codes.
    pub fn data(&self) -> &[u8] {
        self.data.as_slice()
    }

    /// Returns the number of components.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Returns the number of bytes occupied by the codes (one per component).
    #[must_use]
    pub fn memory_bytes(&self) -> usize {
        self.data().len()
    }
}
/// Quantizes `values` to `u8` codes using `params`.
///
/// Each value is shifted by `params.offset`, scaled by `255 / params.alpha`,
/// rounded to the nearest integer and clamped to `0..=255`, so values outside
/// the fitted range saturate at the nearest end code.
#[must_use]
pub fn quantize_u8(values: &[f32], params: &QuantizationParams) -> QuantizedU8 {
    let scale = 255.0 / params.alpha;
    let mut codes = Vec::with_capacity(values.len());
    for &value in values {
        let level = ((value - params.offset) * scale).round();
        // After the clamp the cast only drops the (zero) fractional part.
        codes.push(level.clamp(0.0, 255.0) as u8);
    }
    QuantizedU8 {
        data: codes,
        dimension: values.len(),
    }
}
/// Per-query quantity that can be computed once and reused for every
/// quantized vector scored against the same query.
#[derive(Clone, Copy, Debug)]
pub struct QueryContext {
    /// Sum of all query components.
    pub query_sum: f32,
}

/// Computes the [`QueryContext`] (the component sum) for `query`.
#[must_use]
pub fn query_context(query: &[f32]) -> QueryContext {
    let query_sum: f32 = query.iter().sum();
    QueryContext { query_sum }
}
/// Approximate dot product between a full-precision `query` and a quantized
/// vector, without dequantizing the vector first.
///
/// Computes the query context for `query` and delegates to
/// `asymmetric_dot_u8_precomputed`.
///
/// # Panics
///
/// Panics if `query.len()` differs from the dimension of `quantized`.
#[must_use]
#[allow(unsafe_code)]
pub fn asymmetric_dot_u8(
    query: &[f32],
    quantized: &QuantizedU8,
    params: &QuantizationParams,
) -> f32 {
    let (query_len, stored_len) = (query.len(), quantized.dimension);
    assert_eq!(
        query_len, stored_len,
        "asymmetric_dot_u8: dimension mismatch ({} vs {})",
        query_len, stored_len
    );
    asymmetric_dot_u8_precomputed(query, quantized, params, &query_context(query))
}
/// Approximate dot product between `query` and a quantized vector, using a
/// query context supplied by the caller.
///
/// With codes `c_i`, the result is
/// `(alpha / 255) * sum(query_i * c_i) + offset * ctx.query_sum`, which is the
/// dot product of `query` with the values `offset + alpha * c_i / 255`.
/// `ctx` is expected to come from `query_context(query)`; this is not checked.
///
/// # Panics
///
/// Panics if `query.len()` differs from the dimension of `quantized`.
#[must_use]
#[allow(unsafe_code)]
pub fn asymmetric_dot_u8_precomputed(
    query: &[f32],
    quantized: &QuantizedU8,
    params: &QuantizationParams,
    ctx: &QueryContext,
) -> f32 {
    let (query_len, stored_len) = (query.len(), quantized.dimension);
    assert_eq!(
        query_len, stored_len,
        "asymmetric_dot_u8_precomputed: dimension mismatch ({} vs {})",
        query_len, stored_len
    );
    let step = params.alpha / 255.0;
    let code_dot = mixed_dot_u8_f32(query, &quantized.data);
    let bias = params.offset * ctx.query_sum;
    step * code_dot + bias
}
/// Dot product of an `f32` slice with a `u8` slice, dispatching to an
/// architecture-specific kernel when one is selected below.
///
/// On x86_64 the AVX2 kernel is used when at least 16 element pairs are
/// available and both `avx2` and `fma` are detected at runtime. On aarch64 the
/// NEON kernel is used for 16 or more pairs. Every other case goes through
/// `mixed_dot_u8_f32_portable`.
#[inline]
#[allow(unsafe_code)]
fn mixed_dot_u8_f32(a: &[f32], b: &[u8]) -> f32 {
// Number of element pairs. Only the SIMD dispatch reads it, hence the cfg gate
// (it would be an unused variable on other targets).
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
let n = a.len().min(b.len());
#[cfg(target_arch = "x86_64")]
{
if n >= 16 && is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
// SAFETY (review): the runtime checks above confirm AVX2 and FMA support.
// The kernel's remaining preconditions are not visible here; note that the
// full slices are passed, not slices truncated to `n` — confirm against
// arch::x86_64::dot_u8_f32_avx2.
return unsafe { arch::x86_64::dot_u8_f32_avx2(a, b) };
}
}
#[cfg(target_arch = "aarch64")]
{
if n >= 16 {
// NOTE(review): no runtime feature check on this path — presumably NEON is
// treated as baseline on aarch64; confirm against the safety contract of
// arch::aarch64::dot_u8_f32_neon.
return unsafe { arch::aarch64::dot_u8_f32_neon(a, b) };
}
}
// Fallback for short inputs, missing CPU features, or other architectures.
// NOTE(review): both early returns above are conditional, so it is unclear
// whether this lint allowance is still needed — confirm.
#[allow(unreachable_code)]
mixed_dot_u8_f32_portable(a, b)
}
/// Scalar dot product of an `f32` slice with a `u8` slice.
///
/// Pairs are formed up to the shorter slice; any extra elements are ignored.
#[inline]
fn mixed_dot_u8_f32_portable(a: &[f32], b: &[u8]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(weight, code)| weight * f32::from(*code))
        .sum()
}
/// Scores every vector in `corpus` against `query` and returns the `k` best
/// as `(corpus index, score)` pairs, highest score first.
///
/// At most `corpus.len()` pairs are returned; an empty corpus or `k == 0`
/// yields an empty vector. Equal scores keep their corpus order because the
/// sort is stable.
///
/// # Panics
///
/// Panics if any corpus vector's dimension differs from `query.len()`.
#[must_use]
pub fn batch_knn_u8(
    query: &[f32],
    corpus: &[QuantizedU8],
    params: &QuantizationParams,
    k: usize,
) -> Vec<(usize, f32)> {
    let keep = k.min(corpus.len());
    if keep == 0 {
        return Vec::new();
    }
    // The query sum is shared by every candidate, so compute it once.
    let ctx = query_context(query);
    let mut ranked = Vec::with_capacity(corpus.len());
    for (index, candidate) in corpus.iter().enumerate() {
        let score = asymmetric_dot_u8_precomputed(query, candidate, params, &ctx);
        ranked.push((index, score));
    }
    // Descending by score: the comparison operands are swapped.
    ranked.sort_by(|lhs, rhs| rhs.1.total_cmp(&lhs.1));
    ranked.truncate(keep);
    ranked
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every value must dequantize to within one quantization step of itself.
    #[test]
    fn test_quantize_roundtrip() {
        let values = [0.0f32, 0.5, 1.0, -1.0, 0.25];
        let params = QuantizationParams::fit(&values);
        let quantized = quantize_u8(&values, &params);
        assert_eq!(quantized.dimension(), values.len());
        for (i, &original) in values.iter().enumerate() {
            let dequant = params.alpha * (quantized.data()[i] as f32 / 255.0) + params.offset;
            let error = (original - dequant).abs();
            assert!(
                error < params.alpha / 255.0 + 1e-6,
                "roundtrip error too large at {i}: original={original}, dequant={dequant}, error={error}"
            );
        }
    }

    /// The fitted minimum maps to code 0, the maximum to 255, the midpoint to ~128.
    #[test]
    fn test_quantize_range() {
        let values = [-1.0f32, 0.0, 1.0];
        let params = QuantizationParams::fit(&values);
        let q = quantize_u8(&values, &params);
        assert_eq!(q.data()[0], 0);
        assert_eq!(q.data()[2], 255);
        assert!((q.data()[1] as i32 - 128).abs() <= 1);
    }

    /// The asymmetric dot product stays within the accumulated quantization error.
    #[test]
    fn test_asymmetric_dot_matches_exact() {
        let doc = [1.0f32, 2.0, 3.0, 4.0];
        let query = [0.5f32, 0.5, 0.5, 0.5];
        let exact_dot: f32 = doc.iter().zip(&query).map(|(d, q)| d * q).sum();
        let params = QuantizationParams::fit(&doc);
        let quantized = quantize_u8(&doc, &params);
        let approx_dot = asymmetric_dot_u8(&query, &quantized, &params);
        let error = (exact_dot - approx_dot).abs();
        let tolerance = params.alpha / 255.0 * doc.len() as f32;
        assert!(
            error < tolerance,
            "asymmetric dot error too large: exact={exact_dot}, approx={approx_dot}, error={error}, tolerance={tolerance}"
        );
    }

    /// Supplying a precomputed context must not change the result.
    #[test]
    fn test_precomputed_matches_direct() {
        let doc = [1.0f32, 2.0, 3.0];
        let query = [0.5f32, 1.0, 1.5];
        let params = QuantizationParams::fit(&doc);
        let quantized = quantize_u8(&doc, &params);
        let direct = asymmetric_dot_u8(&query, &quantized, &params);
        let ctx = query_context(&query);
        let precomputed = asymmetric_dot_u8_precomputed(&query, &quantized, &params, &ctx);
        assert!(
            (direct - precomputed).abs() < 1e-6,
            "precomputed mismatch: direct={direct}, precomputed={precomputed}"
        );
    }

    /// Empty input quantizes to an empty vector.
    #[test]
    fn test_quantize_empty() {
        let params = QuantizationParams::fit(&[]);
        let q = quantize_u8(&[], &params);
        assert_eq!(q.dimension(), 0);
        assert_eq!(q.memory_bytes(), 0);
    }

    /// A constant input (zero-width range) must still quantize without panicking.
    #[test]
    fn test_quantize_constant() {
        let values = [5.0f32; 10];
        let params = QuantizationParams::fit(&values);
        let q = quantize_u8(&values, &params);
        assert_eq!(q.dimension(), 10);
    }

    #[test]
    fn test_params_from_range() {
        let params = QuantizationParams::from_range(-1.0, 1.0);
        assert!((params.alpha - 2.0).abs() < 1e-6);
        assert!((params.offset - (-1.0)).abs() < 1e-6);
    }

    /// Parameters fitted across several vectors span their global min and max.
    #[test]
    fn test_fit_vectors() {
        let v1 = [0.0f32, 1.0];
        let v2 = [-1.0f32, 2.0];
        let params = QuantizationParams::fit_vectors(&[&v1, &v2]);
        assert!((params.offset - (-1.0)).abs() < 1e-6);
        assert!((params.alpha - 3.0).abs() < 1e-6);
    }

    /// One byte is stored per component.
    #[test]
    fn test_memory_bytes() {
        let params = QuantizationParams::from_range(0.0, 1.0);
        let q = quantize_u8(&[0.5; 768], &params);
        assert_eq!(q.memory_bytes(), 768);
    }

    /// Same accuracy check as above on a vector long enough to take a SIMD path.
    #[test]
    fn test_asymmetric_dot_large() {
        let dim = 128;
        let doc: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect();
        let query: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.3).cos()).collect();
        let exact_dot: f32 = doc.iter().zip(&query).map(|(d, q)| d * q).sum();
        let params = QuantizationParams::fit(&doc);
        let quantized = quantize_u8(&doc, &params);
        let approx_dot = asymmetric_dot_u8(&query, &quantized, &params);
        let abs_error = (exact_dot - approx_dot).abs();
        let tolerance = params.alpha / 255.0 * (dim as f32).sqrt() + 0.1;
        assert!(
            abs_error < tolerance,
            "dim={dim}: exact={exact_dot}, approx={approx_dot}, abs_error={abs_error}, tolerance={tolerance}"
        );
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_asymmetric_dot_dimension_mismatch() {
        let params = QuantizationParams::from_range(0.0, 1.0);
        let q = quantize_u8(&[0.5, 0.5], &params);
        let _ = asymmetric_dot_u8(&[1.0, 2.0, 3.0], &q, &params);
    }

    /// Quantile fitting must ignore the two extreme outliers.
    #[test]
    fn test_fit_quantile_clips_outliers() {
        let mut values: Vec<f32> = (0..98).map(|i| (i as f32 / 49.0) - 1.0).collect();
        values.push(100.0);
        values.push(-100.0);
        let full = QuantizationParams::fit(&values);
        let clipped = QuantizationParams::fit_quantile(&values, 0.95);
        assert!(
            clipped.alpha < full.alpha,
            "clipped alpha {} should be < full alpha {}",
            clipped.alpha,
            full.alpha
        );
        assert!(
            clipped.alpha < 10.0,
            "clipped alpha should be small, got {}",
            clipped.alpha
        );
    }

    /// The top-k results are sorted by descending score.
    #[test]
    fn test_batch_knn_u8() {
        let params = QuantizationParams::from_range(-1.0, 1.0);
        let corpus: Vec<QuantizedU8> = vec![
            quantize_u8(&[1.0, 0.0, 0.0], &params),
            quantize_u8(&[0.0, 1.0, 0.0], &params),
            quantize_u8(&[-1.0, 0.0, 0.0], &params),
            quantize_u8(&[0.7, 0.7, 0.0], &params),
        ];
        let query = [1.0f32, 0.0, 0.0];
        let results = batch_knn_u8(&query, &corpus, &params, 2);
        assert_eq!(results.len(), 2);
        assert!(results[0].0 == 0 || results[0].0 == 3);
        assert!(results[0].1 >= results[1].1);
    }

    #[test]
    fn test_batch_knn_u8_empty() {
        let params = QuantizationParams::from_range(0.0, 1.0);
        let results = batch_knn_u8(&[1.0], &[], &params, 5);
        assert!(results.is_empty());
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(300))]

        /// For any dimension, the asymmetric dot product stays within a
        /// tolerance derived from the quantization step and the query norm.
        #[test]
        fn asymmetric_dot_approximates_exact(
            dim in 1..200usize
        ) {
            let doc: Vec<f32> = (0..dim).map(|i| ((i * 7) as f32).sin()).collect();
            let query: Vec<f32> = (0..dim).map(|i| ((i * 11) as f32).cos()).collect();
            let exact: f32 = doc.iter().zip(&query).map(|(d, q)| d * q).sum();
            let params = QuantizationParams::fit(&doc);
            let quantized = quantize_u8(&doc, &params);
            let approx = asymmetric_dot_u8(&query, &quantized, &params);
            let tolerance = params.alpha / 255.0 * (dim as f32).sqrt()
                * query.iter().map(|x| x * x).sum::<f32>().sqrt()
                + 0.1;
            prop_assert!(
                (exact - approx).abs() < tolerance,
                "dim={}: exact={}, approx={}, error={}, tolerance={}",
                dim, exact, approx, (exact - approx).abs(), tolerance
            );
        }

        /// Passing a precomputed context must give the same score as the
        /// convenience entry point.
        #[test]
        fn precomputed_equals_direct(
            dim in 1..100usize
        ) {
            let doc: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
            let query: Vec<f32> = (0..dim).map(|i| (i as f32).sin()).collect();
            let params = QuantizationParams::fit(&doc);
            let quantized = quantize_u8(&doc, &params);
            let direct = asymmetric_dot_u8(&query, &quantized, &params);
            let ctx = query_context(&query);
            let precomputed = asymmetric_dot_u8_precomputed(&query, &quantized, &params, &ctx);
            prop_assert!(
                (direct - precomputed).abs() < 1e-5,
                "direct={}, precomputed={}", direct, precomputed
            );
        }
    }
}