use std::cmp::Ordering;
#[derive(Clone, Debug)]
pub struct Int8Vector {
data: Vec<i8>,
scale: f32,
norm: f32,
}
impl Int8Vector {
pub fn from_f32(values: &[f32]) -> Self {
if values.is_empty() {
return Self {
data: Vec::new(),
scale: 1.0,
norm: 0.0,
};
}
let max_abs = values
.iter()
.map(|v| v.abs())
.max_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal))
.unwrap_or(1.0);
let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
let data: Vec<i8> = values
.iter()
.map(|&v| {
let quantized = (v / scale).round();
quantized.clamp(-127.0, 127.0) as i8
})
.collect();
let norm = values.iter().map(|v| v * v).sum::<f32>().sqrt();
Self { data, scale, norm }
}
pub fn from_f32_with_scale(values: &[f32], scale: f32) -> Self {
let data: Vec<i8> = values
.iter()
.map(|&v| {
let quantized = (v / scale).round();
quantized.clamp(-127.0, 127.0) as i8
})
.collect();
let norm = values.iter().map(|v| v * v).sum::<f32>().sqrt();
Self { data, scale, norm }
}
pub fn from_raw(data: Vec<i8>, scale: f32, norm: f32) -> Self {
Self { data, scale, norm }
}
#[inline]
pub fn dim(&self) -> usize {
self.data.len()
}
#[inline]
pub fn data(&self) -> &[i8] {
&self.data
}
#[inline]
pub fn scale(&self) -> f32 {
self.scale
}
#[inline]
pub fn size_bytes(&self) -> usize {
self.data.len() + 8 }
pub fn to_f32(&self) -> Vec<f32> {
self.data.iter().map(|&v| v as f32 * self.scale).collect()
}
#[inline]
pub fn dot_product(&self, other: &Self) -> f32 {
debug_assert_eq!(self.data.len(), other.data.len(), "Dimensions must match");
let raw_dot = dot_product_i8_simd(&self.data, &other.data);
raw_dot as f32 * self.scale * other.scale
}
#[inline]
pub fn dot_product_f32(&self, query: &[f32]) -> f32 {
debug_assert_eq!(self.data.len(), query.len(), "Dimensions must match");
dot_product_i8_f32_simd(&self.data, query) * self.scale
}
#[inline]
pub fn l2_squared(&self, other: &Self) -> f32 {
debug_assert_eq!(self.data.len(), other.data.len(), "Dimensions must match");
let raw_dist = l2_squared_i8_simd(&self.data, &other.data);
raw_dist as f32 * self.scale * other.scale
}
#[inline]
pub fn cosine_distance(&self, other: &Self) -> f32 {
let dot = self.dot_product(other);
let denom = self.norm * other.norm;
if denom > 0.0 {
1.0 - (dot / denom)
} else {
1.0
}
}
}
/// Exact integer dot product of two `i8` slices, accumulated in `i32`.
///
/// On `x86_64` this dispatches at runtime to an AVX2 or SSE4.1 kernel when
/// the CPU supports one. Otherwise it falls back to a portable scalar loop.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`. This is a hard assertion rather than a
/// `debug_assert!` because the SIMD kernels load `b` at offsets derived from
/// `a.len()`. Without the check, a shorter `b` would be read out of bounds in
/// release builds.
#[inline]
pub fn dot_product_i8_simd(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len(), "Vectors must have same length");
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 support was just detected, and the assertion above
            // guarantees `b.len() == a.len()`, the kernel's only other
            // requirement.
            return unsafe { dot_product_i8_avx2(a, b) };
        }
        if is_x86_feature_detected!("sse4.1") {
            // SAFETY: SSE4.1 support was just detected; lengths checked above.
            return unsafe { dot_product_i8_sse4(a, b) };
        }
    }
    dot_product_i8_scalar(a, b)
}

/// Portable reference implementation of the `i8` dot product.
///
/// Pairs beyond the shorter slice are ignored.
#[inline]
fn dot_product_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    a.iter()
        .zip(b)
        .map(|(&x, &y)| i32::from(x) * i32::from(y))
        .sum()
}

/// AVX2 kernel: processes 32 bytes per iteration by sign-extending to `i16`
/// and using `vpmaddwd` to multiply and pairwise-add into `i32` lanes.
///
/// # Safety
///
/// The CPU must support AVX2, and `b.len() >= a.len()` must hold. Loads from
/// `b` use the same offsets as loads from `a`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn dot_product_i8_avx2(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::x86_64::*;
    debug_assert!(b.len() >= a.len());
    let len = a.len();
    let mut sum = _mm256_setzero_si256();
    let chunks = len / 32;
    for i in 0..chunks {
        let idx = i * 32;
        // SAFETY: `idx + 32 <= chunks * 32 <= len <= b.len()`; loads are unaligned.
        let va = _mm256_loadu_si256(a.as_ptr().add(idx) as *const __m256i);
        let vb = _mm256_loadu_si256(b.as_ptr().add(idx) as *const __m256i);
        let va_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(va));
        let va_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1));
        let vb_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vb));
        let vb_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1));
        let prod_lo = _mm256_madd_epi16(va_lo, vb_lo);
        let prod_hi = _mm256_madd_epi16(va_hi, vb_hi);
        sum = _mm256_add_epi32(sum, prod_lo);
        sum = _mm256_add_epi32(sum, prod_hi);
    }
    // Horizontal reduction: 8 lanes -> 4 -> 2 -> 1.
    let sum128 = _mm_add_epi32(
        _mm256_castsi256_si128(sum),
        _mm256_extracti128_si256(sum, 1),
    );
    let sum64 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
    let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
    let mut result = _mm_cvtsi128_si32(sum32);
    // Scalar tail for the final `len % 32` elements.
    for i in (chunks * 32)..len {
        result += (a[i] as i32) * (b[i] as i32);
    }
    result
}

/// SSE4.1 kernel: processes 16 bytes per iteration by sign-extending each
/// 8-byte half to `i16` and using `pmaddwd`.
///
/// # Safety
///
/// The CPU must support SSE4.1, and `b.len() >= a.len()` must hold.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
#[inline]
unsafe fn dot_product_i8_sse4(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::x86_64::*;
    debug_assert!(b.len() >= a.len());
    let len = a.len();
    let mut sum = _mm_setzero_si128();
    let chunks = len / 16;
    for i in 0..chunks {
        let idx = i * 16;
        // SAFETY: `idx + 16 <= chunks * 16 <= len <= b.len()`; loads are unaligned.
        let va = _mm_loadu_si128(a.as_ptr().add(idx) as *const __m128i);
        let vb = _mm_loadu_si128(b.as_ptr().add(idx) as *const __m128i);
        let va_lo = _mm_cvtepi8_epi16(va);
        let va_hi = _mm_cvtepi8_epi16(_mm_srli_si128(va, 8));
        let vb_lo = _mm_cvtepi8_epi16(vb);
        let vb_hi = _mm_cvtepi8_epi16(_mm_srli_si128(vb, 8));
        let prod_lo = _mm_madd_epi16(va_lo, vb_lo);
        let prod_hi = _mm_madd_epi16(va_hi, vb_hi);
        sum = _mm_add_epi32(sum, prod_lo);
        sum = _mm_add_epi32(sum, prod_hi);
    }
    // Horizontal reduction: 4 lanes -> 2 -> 1.
    let sum64 = _mm_add_epi32(sum, _mm_srli_si128(sum, 8));
    let sum32 = _mm_add_epi32(sum64, _mm_srli_si128(sum64, 4));
    let mut result = _mm_cvtsi128_si32(sum32);
    // Scalar tail for the final `len % 16` elements.
    for i in (chunks * 16)..len {
        result += (a[i] as i32) * (b[i] as i32);
    }
    result
}
/// Dot product of an `i8` slice with an `f32` slice, computed in `f32`.
///
/// On `x86_64` CPUs with both AVX2 and FMA this uses a vectorized kernel.
/// Otherwise it uses a scalar loop. The two paths add terms in different
/// orders, so results may differ in the last bits.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`. The SIMD kernel loads `b` at offsets
/// derived from `a.len()`, so the check must hold in release builds too.
#[inline]
pub fn dot_product_i8_f32_simd(a: &[i8], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "Vectors must have same length");
    #[cfg(target_arch = "x86_64")]
    {
        // The kernel uses `_mm256_fmadd_ps`, which is an FMA instruction, not
        // AVX2. Both features must be present.
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            // SAFETY: both required CPU features were just detected, and the
            // lengths are equal per the assertion above.
            return unsafe { dot_product_i8_f32_avx2(a, b) };
        }
    }
    dot_product_i8_f32_scalar(a, b)
}

/// Portable reference implementation; pairs beyond the shorter slice are ignored.
#[inline]
fn dot_product_i8_f32_scalar(a: &[i8], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .fold(0.0f32, |acc, (&x, &y)| acc + (x as f32) * y)
}

/// AVX2+FMA kernel.
///
/// Each iteration widens 8 `i8`s to `f32` and fuses the multiply-add into an
/// 8-lane accumulator.
///
/// # Safety
///
/// The CPU must support AVX2 and FMA, and `b.len() >= a.len()` must hold.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
#[inline]
unsafe fn dot_product_i8_f32_avx2(a: &[i8], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    debug_assert!(b.len() >= a.len());
    let len = a.len();
    let mut sum = _mm256_setzero_ps();
    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        // SAFETY: `idx + 8 <= chunks * 8 <= len <= b.len()`. `_mm_loadl_epi64`
        // reads exactly 8 bytes and `_mm256_loadu_ps` reads 8 floats, both unaligned.
        let va_i8 = _mm_loadl_epi64(a.as_ptr().add(idx) as *const __m128i);
        let va_i16 = _mm_cvtepi8_epi16(va_i8);
        let va_i32 = _mm256_cvtepi16_epi32(va_i16);
        let va_f32 = _mm256_cvtepi32_ps(va_i32);
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        sum = _mm256_fmadd_ps(va_f32, vb, sum);
    }
    // Horizontal reduction: 8 lanes -> 4 -> 2 -> 1.
    let sum128 = _mm_add_ps(_mm256_castps256_ps128(sum), _mm256_extractf128_ps(sum, 1));
    let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
    let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
    let mut result = _mm_cvtss_f32(sum32);
    // Scalar tail for the final `len % 8` elements.
    for i in (chunks * 8)..len {
        result += (a[i] as f32) * b[i];
    }
    result
}
/// Squared Euclidean distance between two `i8` slices, accumulated exactly in `i32`.
///
/// Despite its name, this is a plain portable implementation with no explicit
/// SIMD intrinsics. A length mismatch only trips the debug assertion. Pairs
/// beyond the shorter slice are otherwise ignored.
#[inline]
pub fn l2_squared_i8_simd(a: &[i8], b: &[i8]) -> i32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let diff = i32::from(lhs) - i32::from(rhs);
            diff * diff
        })
        .sum()
}
/// An append-only store of quantized vectors that all share one dimension.
///
/// Components are packed contiguously in `vectors`: vector `i` occupies
/// `vectors[i * dim..(i + 1) * dim]`. Per-vector scales and norms are kept in
/// parallel `Vec`s indexed by the same `i`.
#[derive(Clone, Debug)]
pub struct Int8Index {
    vectors: Vec<i8>,
    scales: Vec<f32>,
    norms: Vec<f32>,
    dim: usize,
    n_vectors: usize,
}

impl Int8Index {
    /// Creates an empty index for vectors with `dim` components.
    pub fn new(dim: usize) -> Self {
        Self::with_capacity(dim, 0)
    }

    /// Creates an empty index with room for `capacity` vectors before reallocating.
    pub fn with_capacity(dim: usize, capacity: usize) -> Self {
        Self {
            vectors: Vec::with_capacity(capacity * dim),
            scales: Vec::with_capacity(capacity),
            norms: Vec::with_capacity(capacity),
            dim,
            n_vectors: 0,
        }
    }

    /// Appends an already-quantized vector.
    ///
    /// # Panics
    ///
    /// Panics if `vector.dim()` differs from the index dimension.
    pub fn add(&mut self, vector: &Int8Vector) {
        // Checked in release builds too: a wrong-length vector would shift the
        // slice boundaries of every vector added after it, silently corrupting
        // the packed buffer.
        assert_eq!(vector.dim(), self.dim, "Dimension mismatch");
        self.vectors.extend_from_slice(&vector.data);
        self.scales.push(vector.scale);
        self.norms.push(vector.norm);
        self.n_vectors += 1;
    }

    /// Quantizes `vector` with [`Int8Vector::from_f32`] and appends it.
    ///
    /// # Panics
    ///
    /// Panics if `vector.len()` differs from the index dimension.
    pub fn add_f32(&mut self, vector: &[f32]) {
        let int8 = Int8Vector::from_f32(vector);
        self.add(&int8);
    }

    /// Number of stored vectors.
    #[inline]
    pub fn len(&self) -> usize {
        self.n_vectors
    }

    /// Returns `true` if no vectors have been added.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.n_vectors == 0
    }

    /// Bytes occupied by stored payloads: components plus per-vector scale
    /// and norm. Spare `Vec` capacity and the struct itself are not counted.
    pub fn memory_bytes(&self) -> usize {
        let f32_size = std::mem::size_of::<f32>();
        self.vectors.len() + (self.scales.len() + self.norms.len()) * f32_size
    }

    /// Returns an owned copy of vector `idx`, or `None` if out of range.
    pub fn get(&self, idx: usize) -> Option<Int8Vector> {
        if idx >= self.n_vectors {
            return None;
        }
        Some(Int8Vector::from_raw(
            self.get_data(idx).to_vec(),
            self.scales[idx],
            self.norms[idx],
        ))
    }

    /// Borrowed quantized components of vector `idx`.
    ///
    /// # Panics
    ///
    /// Panics if the slice `idx * dim..(idx + 1) * dim` lies outside the
    /// packed buffer.
    #[inline]
    pub fn get_data(&self, idx: usize) -> &[i8] {
        let start = idx * self.dim;
        let end = start + self.dim;
        &self.vectors[start..end]
    }

    /// Dot product between stored vector `idx` (dequantized) and an `f32` query.
    ///
    /// `query` should have `dim` components.
    ///
    /// # Panics
    ///
    /// Panics if `idx >= self.len()`.
    #[inline]
    pub fn dot_product_f32(&self, idx: usize, query: &[f32]) -> f32 {
        let data = self.get_data(idx);
        let scale = self.scales[idx];
        dot_product_i8_f32_simd(data, query) * scale
    }

    /// Re-ranks candidate ids by their int8×f32 dot product with `query`.
    ///
    /// The second element of each candidate tuple is ignored, and candidates
    /// whose index is out of range are dropped. Each result is
    /// `(idx, -dot)`, sorted ascending, so the highest dot product comes first.
    /// The sort is stable, so ties keep their candidate order.
    pub fn rescore_candidates(
        &self,
        candidates: &[(usize, u32)],
        query: &[f32],
    ) -> Vec<(usize, f32)> {
        let mut results: Vec<(usize, f32)> = candidates
            .iter()
            .filter(|&&(idx, _)| idx < self.n_vectors)
            .map(|&(idx, _)| (idx, -self.dot_product_f32(idx, query)))
            .collect();
        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
        results
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_int8_quantization() {
        let values = [1.0_f32, -1.0, 0.5, -0.5, 0.0];
        let int8 = Int8Vector::from_f32(&values);
        // The largest magnitude maps to ±127 and zero stays zero.
        assert_eq!(int8.data[0], 127);
        assert_eq!(int8.data[1], -127);
        assert_eq!(int8.data[4], 0);
    }

    #[test]
    fn test_dot_product_identical() {
        let v1 = Int8Vector::from_f32(&[1.0, 2.0, 3.0, 4.0]);
        let v2 = Int8Vector::from_f32(&[1.0, 2.0, 3.0, 4.0]);
        let dot = v1.dot_product(&v2);
        let expected = 1.0 + 4.0 + 9.0 + 16.0;
        // Quantization error keeps the result close to, not exactly, 30.
        assert!((dot - expected).abs() < 1.0);
    }

    #[test]
    fn test_dot_product_f32() {
        let int8 = Int8Vector::from_f32(&[1.0, 0.0, -1.0, 0.5]);
        let query = [1.0_f32; 4];
        let dot = int8.dot_product_f32(&query);
        assert!((dot - 0.5).abs() < 0.1);
    }

    #[test]
    fn test_compression_ratio() {
        let fp32_size = 1024 * 4;
        let int8 = Int8Vector::from_f32(&[1.0; 1024]);
        let int8_size = int8.size_bytes();
        // 1024 one-byte components + 4-byte scale + 4-byte norm.
        assert_eq!(int8_size, 1032);
        assert!(fp32_size / int8_size >= 3);
    }

    #[test]
    fn test_index_rescore() {
        let mut index = Int8Index::new(4);
        index.add_f32(&[1.0, 0.0, 0.0, 0.0]);
        index.add_f32(&[0.0, 1.0, 0.0, 0.0]);
        index.add_f32(&[0.0, 0.0, 1.0, 0.0]);
        let query = [1.0_f32, 0.0, 0.0, 0.0];
        let binary_candidates = [(0, 10), (1, 20), (2, 30)];
        let rescored = index.rescore_candidates(&binary_candidates, &query);
        assert_eq!(rescored[0].0, 0);
    }

    #[test]
    fn test_index_get_round_trip() {
        let mut index = Int8Index::new(4);
        index.add_f32(&[1.0, 0.0, 0.0, 0.0]);
        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
        assert_eq!(index.get(0).unwrap().data(), &[127, 0, 0, 0]);
        assert!(index.get(1).is_none());
        assert_eq!(index.memory_bytes(), 4 + 4 + 4);
    }

    #[test]
    fn test_simd_vs_scalar() {
        // Lengths that are not multiples of 8/16/32 exercise the scalar tail
        // loops inside the SIMD kernels, not just their vector bodies.
        for len in [0usize, 1, 7, 8, 9, 15, 16, 17, 31, 32, 33, 127, 128, 129] {
            let a: Vec<i8> = (0..len)
                .map(|i| ((i * 37 % 255) as i32 - 127) as i8)
                .collect();
            let b: Vec<i8> = (0..len)
                .map(|i| ((i * 91 % 255) as i32 - 127) as i8)
                .collect();
            assert_eq!(
                dot_product_i8_scalar(&a, &b),
                dot_product_i8_simd(&a, &b),
                "i8 dot, len {}",
                len
            );

            let q: Vec<f32> = (0..len).map(|i| i as f32 * 0.25 - 3.0).collect();
            let scalar = dot_product_i8_f32_scalar(&a, &q);
            let simd = dot_product_i8_f32_simd(&a, &q);
            assert!(
                (scalar - simd).abs() <= 1e-3 * scalar.abs().max(1.0),
                "i8×f32 dot, len {}: {} vs {}",
                len,
                scalar,
                simd
            );
        }
    }
}