use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use crate::index::topk::top_k_from_iter_f32;
/// Candidate count at or above which `DenseIndex::query` scores candidates on the
/// rayon thread pool instead of sequentially.
const PARALLEL_THRESHOLD: usize = 1 << 10;
/// Dense embedding index that scores queries by dot product against every candidate row.
///
/// Rows are L2-normalised on insertion and queries are normalised before scoring, so for
/// non-zero vectors the score is the cosine similarity.
///
/// Field names and order form the serde representation; changing them changes the
/// serialized format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DenseIndex {
    /// Row-major storage: row `i` occupies `matrix[i * dim..(i + 1) * dim]`.
matrix: Vec<f32>,
    /// Number of rows currently stored.
n: usize,
    /// Width (number of components) of every stored row.
dim: usize,
}
impl DenseIndex {
    /// Builds an index from `embeddings`, L2-normalising every row.
    ///
    /// The row width (`dim`) is taken from the first embedding, or 0 when `embeddings` is
    /// empty. Later embeddings that are shorter are zero-padded and longer ones are
    /// truncated to that width before normalisation.
    pub fn new(embeddings: Vec<Vec<f32>>) -> Self {
        let n = embeddings.len();
        let dim = embeddings.first().map_or(0, |v| v.len());
        let mut matrix = Vec::with_capacity(n * dim);
        for v in &embeddings {
            Self::push_normalised_row(&mut matrix, v, dim);
        }
        Self { matrix, n, dim }
    }

    /// Number of rows stored in the index.
    pub fn len(&self) -> usize {
        self.n
    }

    /// Returns `true` when the index holds no rows.
    pub fn is_empty(&self) -> bool {
        self.n == 0
    }

    /// Width (number of components) of every stored row.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Returns owned copies of the stored (normalised) rows at `indices`, in the order given.
    ///
    /// Out-of-range indices are skipped, so the result may be shorter than `indices`.
    pub fn extract_rows(&self, indices: &[usize]) -> Vec<Vec<f32>> {
        let mut out = Vec::with_capacity(indices.len());
        out.extend(
            indices
                .iter()
                .filter(|&&i| i < self.n)
                .map(|&i| self.row(i).to_vec()),
        );
        out
    }

    /// Drops every row not listed in `keep_indices`, then appends `new_embeddings`.
    ///
    /// Kept rows are moved to the front in the order of `keep_indices`; their values are not
    /// modified. Each new embedding is zero-padded or truncated to the row width and
    /// L2-normalised, exactly as in [`DenseIndex::new`].
    ///
    /// If the index currently has width 0 (for example it was built from an empty `Vec`),
    /// the width is taken from the first new embedding rather than discarding every
    /// component. Kept rows, which had no components, become all-zero rows of that width.
    ///
    /// # Panics
    ///
    /// Panics if any entry of `keep_indices` is `>= self.len()`. In debug builds it also
    /// panics if `keep_indices` is not strictly ascending, which the in-place compaction
    /// relies on.
    pub fn compact_and_extend(&mut self, keep_indices: &[usize], new_embeddings: Vec<Vec<f32>>) {
        debug_assert!(
            keep_indices.windows(2).all(|w| w[0] < w[1]),
            "keep_indices must be sorted ascending and unique"
        );
        // Without this, an out-of-range index on the `new_pos == old_pos` path (or with a
        // zero width) would skip the copy and leave `n` larger than the matrix.
        assert!(
            keep_indices.iter().all(|&i| i < self.n),
            "keep_indices must all be < len()"
        );

        let old_dim = self.dim;
        let kept = keep_indices.len();
        // Strictly ascending indices guarantee new_pos <= old_pos, so every move reads a
        // row that has not been overwritten yet.
        for (new_pos, &old_pos) in keep_indices.iter().enumerate() {
            if new_pos != old_pos {
                let src = old_pos * old_dim;
                self.matrix.copy_within(src..src + old_dim, new_pos * old_dim);
            }
        }
        self.matrix.truncate(kept * old_dim);

        if old_dim == 0 {
            if let Some(first) = new_embeddings.first() {
                self.dim = first.len();
                // Kept rows had no components; widen them to zero rows, as `new` pads.
                self.matrix.resize(kept * self.dim, 0.0);
            }
        }

        let dim = self.dim;
        self.matrix.reserve(new_embeddings.len() * dim);
        for emb in &new_embeddings {
            Self::push_normalised_row(&mut self.matrix, emb, dim);
        }
        self.n = kept + new_embeddings.len();
    }

    /// Borrows row `i`. Callers must ensure `i < self.len()`.
    #[inline]
    fn row(&self, i: usize) -> &[f32] {
        let start = i * self.dim;
        &self.matrix[start..start + self.dim]
    }

    /// Appends `v` to `matrix` as one row of exactly `dim` values (truncated or
    /// zero-padded) and L2-normalises that row in place.
    fn push_normalised_row(matrix: &mut Vec<f32>, v: &[f32], dim: usize) {
        let start = matrix.len();
        let copy = v.len().min(dim);
        matrix.extend_from_slice(&v[..copy]);
        matrix.resize(start + dim, 0.0);
        normalise_in_place(&mut matrix[start..]);
    }

    /// Returns the `k` best-scoring rows for `query` by cosine similarity.
    ///
    /// `query` is zero-padded or truncated to the row width and L2-normalised first, so each
    /// score is a dot product of normalised vectors (0 for an all-zero query). When
    /// `selector` is `Some`, only the listed row indices are scored, and out-of-range entries
    /// are ignored. Otherwise every row is scored. Scoring runs on the rayon pool once the
    /// candidate count reaches `PARALLEL_THRESHOLD`.
    ///
    /// Returns row indices and their scores as parallel vectors, in the order produced by
    /// `top_k_from_iter_f32`. The tests in this file expect highest-score-first.
    pub fn query(
        &self,
        query: &[f32],
        k: usize,
        selector: Option<&[usize]>,
    ) -> (Vec<usize>, Vec<f32>) {
        let _span = tracing::trace_span!("dense.query", n = self.n, k, dim = self.dim).entered();
        if self.n == 0 || k == 0 {
            return (Vec::new(), Vec::new());
        }
        let n_candidates = selector.map_or(self.n, |sel| sel.len());
        if n_candidates == 0 {
            return (Vec::new(), Vec::new());
        }

        let mut q = vec![0.0f32; self.dim];
        let copy = query.len().min(self.dim);
        q[..copy].copy_from_slice(&query[..copy]);
        normalise_in_place(&mut q);

        let score = |idx: usize| (idx, dot(self.row(idx), &q));
        let scored: Vec<(usize, f32)> = match (selector, n_candidates >= PARALLEL_THRESHOLD) {
            (Some(sel), true) => sel
                .par_iter()
                .filter(|&&idx| idx < self.n)
                .map(|&idx| score(idx))
                .collect(),
            (Some(sel), false) => sel
                .iter()
                .filter(|&&idx| idx < self.n)
                .map(|&idx| score(idx))
                .collect(),
            (None, true) => (0..self.n).into_par_iter().map(score).collect(),
            (None, false) => (0..self.n).map(score).collect(),
        };

        top_k_from_iter_f32(scored, k).into_iter().unzip()
    }

    /// Runs [`DenseIndex::query`] for every entry of `queries` in parallel.
    ///
    /// Results are returned in the same order as `queries`.
    pub fn query_batch(
        &self,
        queries: &[Vec<f32>],
        k: usize,
        selector: Option<&[usize]>,
    ) -> Vec<(Vec<usize>, Vec<f32>)> {
        queries
            .par_iter()
            .map(|q| self.query(q, k, selector))
            .collect()
    }
}
/// Rescales `v` in place to unit Euclidean length.
///
/// Slices whose squared norm is not strictly positive (empty, all zeros, or a NaN norm)
/// are left unchanged.
#[inline]
fn normalise_in_place(v: &mut [f32]) {
    let norm_sq = v.iter().fold(0.0f32, |total, &component| total + component * component);
    if norm_sq > 0.0 {
        let scale = norm_sq.sqrt().recip();
        v.iter_mut().for_each(|component| *component *= scale);
    }
}
/// Inner product of two equal-length slices.
///
/// Products are summed in groups of eight, and each group is totalled left-to-right before
/// it is added to the running sum. The leftover tail of fewer than eight elements is then
/// added one product at a time.
#[inline]
fn dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Only the first `a.len()` entries of `b` take part; a shorter `b` panics here.
    let b = &b[..a.len()];
    let (groups_a, groups_b) = (a.chunks_exact(8), b.chunks_exact(8));
    let (tail_a, tail_b) = (groups_a.remainder(), groups_b.remainder());
    let grouped = groups_a.zip(groups_b).fold(0.0f32, |sum, (x, y)| {
        sum + (x[0] * y[0]
            + x[1] * y[1]
            + x[2] * y[2]
            + x[3] * y[3]
            + x[4] * y[4]
            + x[5] * y[5]
            + x[6] * y[6]
            + x[7] * y[7])
    });
    tail_a
        .iter()
        .zip(tail_b)
        .fold(grouped, |sum, (x, y)| sum + x * y)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_index() {
        let index = DenseIndex::new(vec![]);
        let (indices, _) = index.query(&[1.0, 0.0, 0.0], 5, None);
        assert!(indices.is_empty());
    }

    #[test]
    fn test_cosine_search() {
        let embeddings = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0], vec![0.9, 0.1, 0.0]];
        let index = DenseIndex::new(embeddings);
        let (indices, scores) = index.query(&[1.0, 0.0, 0.0], 2, None);
        assert_eq!(indices.len(), 2);
        assert_eq!(indices[0], 0);
        assert!((scores[0] - 1.0).abs() < 1e-4);
        assert_eq!(indices[1], 2);
    }

    #[test]
    fn test_with_selector() {
        let embeddings = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
        let index = DenseIndex::new(embeddings);
        let (indices, _) = index.query(&[0.0, 1.0], 2, Some(&[1, 2]));
        assert_eq!(indices[0], 1);
    }

    #[test]
    fn selector_ignores_out_of_range_indices() {
        let index = DenseIndex::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
        let (indices, scores) = index.query(&[1.0, 0.0], 5, Some(&[7, 0, 99]));
        assert_eq!(indices, vec![0]);
        assert!((scores[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn new_pads_short_rows_and_truncates_long_rows() {
        let index = DenseIndex::new(vec![
            vec![1.0, 2.0, 3.0],
            vec![0.5],
            vec![0.0, 0.0, 2.0, 9.0],
        ]);
        assert_eq!(index.len(), 3);
        assert_eq!(index.dim(), 3);
        assert!(!index.is_empty());
        assert_eq!(index.row(1), &[1.0f32, 0.0, 0.0][..]);
        assert_eq!(index.row(2), &[0.0f32, 0.0, 1.0][..]);
    }

    #[test]
    fn extract_rows_skips_out_of_range() {
        let index = DenseIndex::new(vec![vec![2.0, 0.0], vec![0.0, 4.0]]);
        let rows = index.extract_rows(&[1, 5, 0]);
        assert_eq!(rows, vec![vec![0.0f32, 1.0], vec![1.0f32, 0.0]]);
    }

    #[test]
    fn parallel_path_ranks_like_serial_path() {
        // Enough rows to reach PARALLEL_THRESHOLD so the rayon branches are exercised.
        let n = PARALLEL_THRESHOLD + 10;
        // Row i scores 1 / sqrt(1 + i^2) against [1, 0]: strictly decreasing, no ties.
        let embeddings: Vec<Vec<f32>> = (0..n).map(|i| vec![1.0, i as f32]).collect();
        let index = DenseIndex::new(embeddings);

        let (full, _) = index.query(&[1.0, 0.0], 3, None);
        assert_eq!(full, vec![0, 1, 2]);

        let selector: Vec<usize> = (0..n).rev().collect();
        let (selected, _) = index.query(&[1.0, 0.0], 3, Some(&selector));
        assert_eq!(selected, vec![0, 1, 2]);

        let (serial, _) = index.query(&[1.0, 0.0], 3, Some(&[2, 0, 1]));
        assert_eq!(serial, full);
    }

    #[test]
    fn compact_and_extend_preserves_kept_rows() {
        let mut index = DenseIndex::new(vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![0.7, 0.7],
            vec![0.5, 0.5],
        ]);
        let kept_row0_before = index.row(0).to_vec();
        let kept_row2_before = index.row(2).to_vec();
        index.compact_and_extend(&[0, 2], vec![vec![1.0, 1.0]]);
        assert_eq!(index.len(), 3);
        assert_eq!(index.row(0), kept_row0_before.as_slice());
        assert_eq!(index.row(1), kept_row2_before.as_slice());
        let row2 = index.row(2);
        let norm_sq: f32 = row2.iter().map(|x| x * x).sum();
        assert!((norm_sq - 1.0).abs() < 1e-5, "appended row not normalised");
    }

    #[test]
    fn compact_and_extend_full_drop() {
        let mut index = DenseIndex::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
        index.compact_and_extend(&[], vec![vec![1.0, 0.0]]);
        assert_eq!(index.len(), 1);
        assert_eq!(index.row(0), &[1.0f32, 0.0][..]);
    }

    #[test]
    fn compact_and_extend_no_new() {
        let mut index = DenseIndex::new(vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![0.7, 0.7],
        ]);
        let kept2 = index.row(2).to_vec();
        index.compact_and_extend(&[0, 2], vec![]);
        assert_eq!(index.len(), 2);
        assert_eq!(index.row(1), kept2.as_slice());
    }
}