use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use crate::index::topk::top_k_from_iter_f32;
/// Minimum number of candidate rows before `DenseIndex::query` scores them
/// with rayon instead of a plain serial iterator.
const PARALLEL_THRESHOLD: usize = 1 << 10;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DenseIndex {
matrix: Vec<f32>,
n: usize,
dim: usize,
}
impl DenseIndex {
    /// Builds an index from `embeddings` and stores every row L2-normalised in a
    /// single row-major buffer.
    ///
    /// The dimension comes from the first embedding (0 if there are none).
    /// Shorter rows are zero-padded and longer rows are truncated to that
    /// dimension, so every stored row has exactly `dim` components. All-zero
    /// rows cannot be normalised and are stored as zeros.
    pub fn new(embeddings: Vec<Vec<f32>>) -> Self {
        let n = embeddings.len();
        let dim = embeddings.first().map_or(0, Vec::len);
        let mut matrix = Vec::with_capacity(n * dim);
        for v in &embeddings {
            Self::push_normalised_row(&mut matrix, v, dim);
        }
        Self { matrix, n, dim }
    }

    /// Number of rows in the index.
    pub fn len(&self) -> usize {
        self.n
    }

    /// `true` when the index holds no rows.
    pub fn is_empty(&self) -> bool {
        self.n == 0
    }

    /// Number of components per stored row.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Returns owned copies of the stored (normalised) rows at `indices`, in
    /// the order given.
    ///
    /// Out-of-range indices are skipped silently. The result can therefore be
    /// shorter than `indices`, and its positions need not line up with the input.
    pub fn extract_rows(&self, indices: &[usize]) -> Vec<Vec<f32>> {
        let mut out = Vec::with_capacity(indices.len());
        out.extend(
            indices
                .iter()
                .filter(|&&i| i < self.n)
                .map(|&i| self.row(i).to_vec()),
        );
        out
    }

    /// Borrows row `i` from the row-major buffer. Callers must ensure `i < self.n`.
    #[inline]
    fn row(&self, i: usize) -> &[f32] {
        let start = i * self.dim;
        &self.matrix[start..start + self.dim]
    }

    /// Appends `src` to `out` as one `dim`-length row, truncated or zero-padded
    /// to fit, and L2-normalises only the appended slice.
    ///
    /// Used for both stored rows (`new`) and the query vector (`query`), so the
    /// two are prepared identically without a per-row temporary buffer.
    fn push_normalised_row(out: &mut Vec<f32>, src: &[f32], dim: usize) {
        let start = out.len();
        out.extend_from_slice(&src[..src.len().min(dim)]);
        out.resize(start + dim, 0.0);
        normalise_in_place(&mut out[start..]);
    }

    /// Returns the `k` rows with the highest cosine similarity to `query`.
    ///
    /// `query` is truncated or zero-padded to the index dimension and
    /// L2-normalised, so each score is the dot product of two normalised
    /// vectors. When `selector` is `Some`, only those row indices are scored:
    /// out-of-range indices are ignored, and an index listed more than once is
    /// scored once per occurrence (no de-duplication happens here). Candidate
    /// sets of at least `PARALLEL_THRESHOLD` rows are scored with rayon.
    ///
    /// Returns parallel vectors `(indices, scores)` in the order produced by
    /// `top_k_from_iter_f32` (presumably best-first; the tests expect the best
    /// match at position 0). Both are empty when the index is empty, `k == 0`,
    /// or the selector is empty.
    pub fn query(
        &self,
        query: &[f32],
        k: usize,
        selector: Option<&[usize]>,
    ) -> (Vec<usize>, Vec<f32>) {
        if self.n == 0 || k == 0 {
            return (Vec::new(), Vec::new());
        }
        let mut q = Vec::with_capacity(self.dim);
        Self::push_normalised_row(&mut q, query, self.dim);

        // Single scoring rule shared by every path: (row index, cosine similarity).
        let score = |idx: usize| (idx, dot(self.row(idx), &q));

        let scored: Vec<(usize, f32)> = match selector {
            Some(sel) if sel.is_empty() => return (Vec::new(), Vec::new()),
            Some(sel) => {
                let score_if_valid = |&idx: &usize| (idx < self.n).then(|| score(idx));
                if sel.len() >= PARALLEL_THRESHOLD {
                    sel.par_iter().filter_map(score_if_valid).collect()
                } else {
                    sel.iter().filter_map(score_if_valid).collect()
                }
            }
            None if self.n >= PARALLEL_THRESHOLD => {
                (0..self.n).into_par_iter().map(score).collect()
            }
            None => (0..self.n).map(score).collect(),
        };

        top_k_from_iter_f32(scored, k).into_iter().unzip()
    }

    /// Runs [`query`](Self::query) for every vector in `queries` in parallel,
    /// with the same `k` and `selector`. Results come back in the same order
    /// as `queries`.
    pub fn query_batch(
        &self,
        queries: &[Vec<f32>],
        k: usize,
        selector: Option<&[usize]>,
    ) -> Vec<(Vec<usize>, Vec<f32>)> {
        queries
            .par_iter()
            .map(|q| self.query(q, k, selector))
            .collect()
    }
}
/// L2-normalises `v` in place so that it has unit Euclidean length.
///
/// The squared norm is accumulated in `f64`. Squaring `f32` components can
/// overflow to infinity (|x| ≳ 1.8e19) or underflow to zero (|x| ≲ 1e-23)
/// in single precision. Overflow would turn a large vector into zeros/NaN,
/// and underflow would leave a tiny vector un-normalised. `f64` has ample
/// range for the square of any finite `f32`.
///
/// All-zero vectors, empty slices, and vectors whose squared norm is NaN are
/// left unchanged.
#[inline]
fn normalise_in_place(v: &mut [f32]) {
    let sum_sq: f64 = v.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
    if sum_sq > 0.0 {
        let inv = sum_sq.sqrt().recip();
        for x in v.iter_mut() {
            *x = (f64::from(*x) * inv) as f32;
        }
    }
}
/// Dot product of two equal-length `f32` slices.
///
/// Works eight lanes per step: each group of eight products is summed left to
/// right and then added to the running total. The leftover tail is handled one
/// element at a time. Lengths are compared only in debug builds. In release, a
/// `b` shorter than `a` panics, and any extra trailing elements of `b` are ignored.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Re-slicing keeps the "short `b` panics" behaviour and lets LLVM see equal lengths.
    let b = &b[..a.len()];

    let lhs = a.chunks_exact(8);
    let rhs = b.chunks_exact(8);
    let (lhs_tail, rhs_tail) = (lhs.remainder(), rhs.remainder());

    let mut total = 0.0f32;
    for (x, y) in lhs.zip(rhs) {
        total += x[0] * y[0]
            + x[1] * y[1]
            + x[2] * y[2]
            + x[3] * y[3]
            + x[4] * y[4]
            + x[5] * y[5]
            + x[6] * y[6]
            + x[7] * y[7];
    }
    for (x, y) in lhs_tail.iter().zip(rhs_tail) {
        total += x * y;
    }
    total
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_index() {
        let index = DenseIndex::new(vec![]);
        assert!(index.is_empty());
        assert_eq!(index.dim(), 0);
        let (indices, scores) = index.query(&[1.0, 0.0, 0.0], 5, None);
        assert!(indices.is_empty());
        assert!(scores.is_empty());
    }

    #[test]
    fn test_cosine_search() {
        let embeddings = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.9, 0.1, 0.0],
        ];
        let index = DenseIndex::new(embeddings);
        let (indices, scores) = index.query(&[1.0, 0.0, 0.0], 2, None);
        assert_eq!(indices.len(), 2);
        assert_eq!(indices[0], 0);
        assert!((scores[0] - 1.0).abs() < 1e-4);
        assert_eq!(indices[1], 2);
    }

    #[test]
    fn test_with_selector() {
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
        let index = DenseIndex::new(embeddings);
        let (indices, _) = index.query(&[0.0, 1.0], 2, Some(&[1, 2]));
        assert_eq!(indices[0], 1);
        // Row 0 was not selected, so it must never be returned.
        assert!(!indices.contains(&0));
    }

    #[test]
    fn test_zero_k_and_empty_selector_return_nothing() {
        let index = DenseIndex::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
        let (indices, scores) = index.query(&[1.0, 0.0], 0, None);
        assert!(indices.is_empty() && scores.is_empty());
        let empty: &[usize] = &[];
        let (indices, scores) = index.query(&[1.0, 0.0], 3, Some(empty));
        assert!(indices.is_empty() && scores.is_empty());
    }

    #[test]
    fn test_selector_ignores_out_of_range_indices() {
        let index = DenseIndex::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
        let (indices, _) = index.query(&[1.0, 0.0], 5, Some(&[0, 99]));
        assert_eq!(indices.first(), Some(&0));
        assert!(!indices.contains(&99));
    }

    #[test]
    fn test_ragged_rows_are_padded_or_truncated() {
        let index = DenseIndex::new(vec![vec![1.0, 0.0], vec![0.0], vec![0.0, 2.0, 5.0]]);
        assert_eq!(index.len(), 3);
        assert_eq!(index.dim(), 2);
        // Row 1 is zero-padded (and stays zero); row 2 is truncated then normalised.
        // Index 7 is out of range and skipped.
        let rows = index.extract_rows(&[0, 1, 2, 7]);
        assert_eq!(rows, vec![vec![1.0, 0.0], vec![0.0, 0.0], vec![0.0, 1.0]]);
    }

    #[test]
    fn test_parallel_and_serial_paths_agree() {
        // More rows than PARALLEL_THRESHOLD, so the unfiltered query takes the
        // rayon path while the short selector below is scored serially.
        let n = 2000usize;
        let embeddings: Vec<Vec<f32>> = (0..n)
            .map(|i| vec![i as f32 / n as f32, 1.0])
            .collect();
        let index = DenseIndex::new(embeddings);
        assert!(index.len() >= PARALLEL_THRESHOLD);

        let (parallel, _) = index.query(&[1.0, 0.0], 3, None);
        assert_eq!(parallel, vec![1999, 1998, 1997]);

        let (serial, _) = index.query(&[1.0, 0.0], 3, Some(&[5, 1997, 1998, 1999]));
        assert_eq!(serial, parallel);
    }

    #[test]
    fn test_query_batch_matches_individual_queries() {
        let index = DenseIndex::new(vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.6, 0.8]]);
        let queries = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let batch = index.query_batch(&queries, 2, None);
        let single: Vec<_> = queries.iter().map(|q| index.query(q, 2, None)).collect();
        assert_eq!(batch, single);
    }
}