use crate::backend::{QueryResponse, RingComputeBackend};
use crate::error::{Result, RingDbError};
use rayon::iter::IndexedParallelIterator;
use rayon::iter::ParallelIterator;
use rayon::prelude::{IntoParallelRefIterator, ParallelSlice};
#[inline(always)]
fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len());
let (mut a0, mut a1, mut a2, mut a3) = (0.0f32, 0.0f32, 0.0f32, 0.0f32);
let chunks_a = a.chunks_exact(4);
let chunks_b = b.chunks_exact(4);
let rem_a = chunks_a.remainder();
let rem_b = chunks_b.remainder();
for (ca, cb) in chunks_a.zip(chunks_b) {
a0 = ca[0].mul_add(cb[0], a0);
a1 = ca[1].mul_add(cb[1], a1);
a2 = ca[2].mul_add(cb[2], a2);
a3 = ca[3].mul_add(cb[3], a3);
}
let mut sum = (a0 + a1) + (a2 + a3);
for (a, b) in rem_a.iter().zip(rem_b.iter()) {
sum += a * b;
}
sum
}
#[inline(always)]
fn norm_sq_f32(v: &[f32]) -> f32 {
let (mut a0, mut a1, mut a2, mut a3) = (0.0f32, 0.0f32, 0.0f32, 0.0f32);
let chunks = v.chunks_exact(4);
let rem = chunks.remainder();
for c in chunks {
a0 = c[0].mul_add(c[0], a0);
a1 = c[1].mul_add(c[1], a1);
a2 = c[2].mul_add(c[2], a2);
a3 = c[3].mul_add(c[3], a3);
}
let mut sum = (a0 + a1) + (a2 + a3);
for x in rem {
sum += x * x;
}
sum
}
pub struct CpuBackend {
dims: usize,
n_vectors: usize,
vectors: Vec<f32>, norms_sq: Vec<f32>, }
impl CpuBackend {
pub fn new() -> Self {
Self {
dims: 0,
n_vectors: 0,
vectors: Vec::new(),
norms_sq: Vec::new(),
}
}
}
impl Default for CpuBackend {
fn default() -> Self {
Self::new()
}
}
impl RingComputeBackend for CpuBackend {
fn name(&self) -> &'static str {
"cpu"
}
fn upload_f32_dataset(
&mut self,
dims: usize,
vectors: Vec<f32>,
norms_sq: Vec<f32>,
) -> Result<()> {
self.dims = dims;
self.n_vectors = norms_sq.len();
self.vectors = vectors;
self.norms_sq = norms_sq;
Ok(())
}
fn ring_query_f32(
&self,
dims: usize,
query: &[f32],
d_min: f32,
d_max: f32,
) -> Result<Vec<QueryResponse>> {
if self.n_vectors == 0 {
return Ok(Vec::new());
}
if query.len() != dims {
return Err(RingDbError::DimensionMismatch {
expected: dims,
got: query.len(),
});
}
let norm_sq_q = norm_sq_f32(query);
let lower_sq = d_min * d_min;
let upper_sq = d_max * d_max;
Ok(self
.vectors
.par_chunks_exact(dims)
.zip(self.norms_sq.par_iter())
.enumerate()
.filter_map(|(i, (row, &norm_sq_i))| {
let dot = dot_f32(row, query);
let dist_sq = norm_sq_i + norm_sq_q - 2.0 * dot;
(dist_sq >= lower_sq && dist_sq <= upper_sq).then_some(QueryResponse {
id: i as u32,
dist_sq,
})
})
.collect())
}
fn disk_query_f32(&self, dims: usize, query: &[f32], d_max: f32) -> Result<Vec<QueryResponse>> {
if self.n_vectors == 0 {
return Ok(Vec::new());
}
if query.len() != dims {
return Err(RingDbError::DimensionMismatch {
expected: dims,
got: query.len(),
});
}
let norm_sq_q = norm_sq_f32(query);
let upper_sq = d_max * d_max;
Ok(self
.vectors
.par_chunks_exact(dims)
.zip(self.norms_sq.par_iter())
.enumerate()
.filter_map(|(i, (row, &norm_sq_i))| {
let dot = dot_f32(row, query);
let dist_sq = norm_sq_i + norm_sq_q - 2.0 * dot;
(dist_sq <= upper_sq).then_some(QueryResponse {
id: i as u32,
dist_sq,
})
})
.collect())
}
fn disk_intersection_query_f32(
&self,
dims: usize,
disks: &[(&[f32], f32)],
) -> Result<Vec<QueryResponse>> {
if self.n_vectors == 0 || disks.is_empty() {
return Ok(Vec::new());
}
for &(query, _) in disks {
if query.len() != dims {
return Err(RingDbError::DimensionMismatch {
expected: dims,
got: query.len(),
});
}
}
let prepared: Vec<(&[f32], f32, f32)> = disks
.iter()
.map(|&(query, d_max)| (query, norm_sq_f32(query), d_max * d_max))
.collect();
Ok(self
.vectors
.par_chunks_exact(dims)
.zip(self.norms_sq.par_iter())
.enumerate()
.filter_map(|(i, (row, &norm_sq_i))| {
let mut first_dist_sq = 0.0;
for (disk_i, &(query, norm_sq_q, upper_sq)) in prepared.iter().enumerate() {
let dot = dot_f32(row, query);
let dist_sq = norm_sq_i + norm_sq_q - 2.0 * dot;
if dist_sq > upper_sq {
return None;
}
if disk_i == 0 {
first_dist_sq = dist_sq;
}
}
Some(QueryResponse {
id: i as u32,
dist_sq: first_dist_sq,
})
})
.collect())
}
}