use crate::VectorType;
/// A pool of vectors quantized to `i8`, with one affine mapping per dimension.
///
/// The field descriptions below reflect what `Int8Pool::from_f32_vectors`
/// produces; the fields are public, so code that builds the struct by hand is
/// responsible for keeping them consistent with each other.
pub struct Int8Pool {
/// Smallest value seen in each dimension while building the pool
/// (one entry per dimension; empty for an empty pool).
pub mins: Vec<f32>,
/// Per-dimension multiplier applied to `value - min` before quantizing:
/// `254.0 / (max - min)`, or `1.0` when that range is below `1e-12`
/// (one entry per dimension; empty for an empty pool).
pub scales: Vec<f32>,
/// Quantized components of all vectors stored back to back: vector `i`
/// occupies `data[i * dim..i * dim + dim]`.
pub data: Vec<i8>,
/// Number of components per vector.
pub dim: usize,
/// Number of complete vectors stored in `data`.
pub count: usize,
}
impl Int8Pool {
    /// Builds a pool by quantizing flat, row-major `f32` vectors to `i8`.
    ///
    /// `flat_vectors` is read as consecutive vectors of `dim` components each;
    /// trailing components that do not fill a whole vector are ignored. For
    /// every dimension the observed `[min, max]` interval is mapped linearly
    /// onto `[-127, 127]`. A dimension whose range is below `1e-12` uses a
    /// scale of `1.0` instead of dividing by a (near-)zero range.
    ///
    /// Returns an empty pool (`count == 0`, empty `mins`/`scales`/`data`, `dim`
    /// kept as passed) when `dim` is zero or when `flat_vectors` holds fewer
    /// than `dim` components.
    pub fn from_f32_vectors(flat_vectors: &[f32], dim: usize) -> Self {
        // `checked_div` turns the `dim == 0` case into "no vectors" instead of
        // a division-by-zero panic.
        let count = flat_vectors.len().checked_div(dim).unwrap_or(0);
        if count == 0 {
            return Self {
                mins: Vec::new(),
                scales: Vec::new(),
                data: Vec::new(),
                dim,
                count: 0,
            };
        }

        // Pass 1: per-dimension minimum and maximum over all complete vectors.
        let mut mins = vec![f32::MAX; dim];
        let mut maxs = vec![f32::MIN; dim];
        for vector in flat_vectors.chunks_exact(dim) {
            for ((lo, hi), &val) in mins.iter_mut().zip(maxs.iter_mut()).zip(vector) {
                if val < *lo {
                    *lo = val;
                }
                if val > *hi {
                    *hi = val;
                }
            }
        }

        // 254 quantization steps span the range so that min -> -127 and max -> 127.
        let scales: Vec<f32> = mins
            .iter()
            .zip(&maxs)
            .map(|(&lo, &hi)| {
                let range = hi - lo;
                if range < 1e-12 {
                    1.0
                } else {
                    254.0 / range
                }
            })
            .collect();

        // Pass 2: quantize every component with its dimension's parameters.
        let mut data = Vec::with_capacity(count * dim);
        for vector in flat_vectors.chunks_exact(dim) {
            data.extend(
                vector
                    .iter()
                    .zip(&mins)
                    .zip(&scales)
                    .map(|((&val, &lo), &scale)| Self::quantize_component(val, lo, scale)),
            );
        }

        Self {
            mins,
            scales,
            data,
            dim,
            count,
        }
    }

    /// Builds a pool from vectors of any `VectorType` by converting every
    /// component with `to_f32` and delegating to [`Int8Pool::from_f32_vectors`].
    pub fn from_generic_vectors<T: VectorType>(flat_vectors: &[T], dim: usize) -> Self {
        let f32_buf: Vec<f32> = flat_vectors.iter().map(|x| x.to_f32()).collect();
        Self::from_f32_vectors(&f32_buf, dim)
    }

    /// Quantizes `query` with this pool's per-dimension `mins` and `scales`.
    ///
    /// Component `d` of the query uses `mins[d]` and `scales[d]`; values that
    /// fall outside the range seen at build time are clamped to `[-127, 127]`.
    /// The result has the same length as `query`.
    ///
    /// # Panics
    ///
    /// Panics if `query` has more components than `mins` or `scales` have
    /// entries (this includes any non-empty query on an empty pool).
    #[inline]
    pub fn quantize_query<T: VectorType>(&self, query: &[T]) -> Vec<i8> {
        // Slicing up front performs the length check once and keeps the
        // panic-on-overlong-query behavior of per-element indexing.
        let mins = &self.mins[..query.len()];
        let scales = &self.scales[..query.len()];
        query
            .iter()
            .zip(mins)
            .zip(scales)
            .map(|((val, &min), &scale)| Self::quantize_component(val.to_f32(), min, scale))
            .collect()
    }

    /// Returns the integer dot product of stored vector `index` with the first
    /// `dim` entries of `query_i8`, accumulated in `i32`.
    ///
    /// # Panics
    ///
    /// Panics if the stored vector's range `index * dim..index * dim + dim`
    /// lies outside `data`, or if `query_i8` has fewer than `dim` entries.
    #[inline]
    pub fn dot_score(&self, index: usize, query_i8: &[i8]) -> i32 {
        let start = index * self.dim;
        let stored = &self.data[start..start + self.dim];
        // Equal-length slices let the zipped loop run without per-element
        // bounds checks.
        let query = &query_i8[..self.dim];
        stored
            .iter()
            .zip(query)
            .map(|(&a, &b)| i32::from(a) * i32::from(b))
            .sum()
    }

    /// Returns `true` when `index` refers to a stored vector (`index < count`).
    #[inline]
    pub fn is_valid_index(&self, index: usize) -> bool {
        index < self.count
    }

    /// Maps one component to `i8`: `(value - min) * scale - 127`, rounded and
    /// clamped to `[-127, 127]`. A NaN intermediate result becomes `0`, because
    /// float-to-int `as` casts convert NaN to zero.
    #[inline]
    fn quantize_component(value: f32, min: f32, scale: f32) -> i8 {
        ((value - min) * scale - 127.0).round().clamp(-127.0, 127.0) as i8
    }
}