use ndarray::{Array1, Array2, Axis};
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use serde::{Deserialize, Serialize};
use crate::ranking::truncate_top_k;
/// Dense vector index that stores one `f32` vector per matrix row.
///
/// [`DenseIndex::new`] rescales every row to unit Euclidean length (rows with
/// a norm at or below `1e-8` are left as they are), so a dot product against a
/// stored row equals the cosine similarity whenever the query vector is itself
/// unit-length.
///
/// The type derives `Serialize`/`Deserialize`; a deserialized value keeps the
/// matrix exactly as it was stored and does not pass through `new` again.
/// NOTE(review): presumably only indexes produced by `new` are ever
/// serialized — confirm against callers.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DenseIndex {
/// Stored vectors: axis 0 enumerates entries, axis 1 the vector components.
vectors: Array2<f32>,
}
impl DenseIndex {
    /// Builds an index from a matrix holding one vector per row.
    ///
    /// Each row is scaled in place to unit Euclidean length. A row whose norm
    /// is not greater than `1e-8` (which includes a NaN norm) is kept
    /// untouched, so near-zero vectors are never divided by a vanishing value.
    pub fn new(mut vectors: Array2<f32>) -> Self {
        vectors.axis_iter_mut(Axis(0)).for_each(|mut row| {
            let squared_sum = row.iter().map(|x| x * x).sum::<f32>();
            let length = squared_sum.sqrt();
            if length > 1e-8 {
                row.mapv_inplace(|x| x / length);
            }
        });
        Self { vectors }
    }

    /// Number of vectors (matrix rows) held by the index.
    pub fn len(&self) -> usize {
        self.vectors.nrows()
    }

    /// Number of components in each stored vector (matrix columns).
    pub fn dim(&self) -> usize {
        self.vectors.ncols()
    }

    /// Returns `true` when the index holds no vectors.
    pub fn is_empty(&self) -> bool {
        self.vectors.nrows() == 0
    }

    /// Scores stored vectors against `vector` and returns up to `k` hits as
    /// `(row index, score)` pairs.
    ///
    /// The score is the dot product between the stored (normalized) row and
    /// `vector` exactly as given; `vector` is not normalized here. If
    /// `vector` and the rows differ in length, only the leading components
    /// common to both contribute, because the pairing stops at the shorter
    /// of the two.
    ///
    /// * `k` - maximum number of hits; `0` yields an empty result.
    /// * `selector` - when `Some`, only the listed row indices are scored
    ///   (each occurrence is scored separately); an empty slice yields an
    ///   empty result. When `None`, every row is scored.
    ///
    /// Scoring runs on rayon parallel iterators. Selection and ordering of
    /// the final hits is delegated to `truncate_top_k`, which presumably
    /// keeps the `k` best-scoring entries — confirm against its definition.
    ///
    /// # Panics
    ///
    /// Panics if `selector` contains an index that is not smaller than
    /// [`DenseIndex::len`], since the row lookup is bounds-checked.
    pub fn query(
        &self,
        vector: &Array1<f32>,
        k: usize,
        selector: Option<&[usize]>,
    ) -> Vec<(usize, f32)> {
        let nothing_to_rank =
            k == 0 || self.is_empty() || matches!(selector, Some(ids) if ids.is_empty());
        if nothing_to_rank {
            return Vec::new();
        }

        // Partial result buffers are cut back to `k` entries once they grow
        // past this size, which bounds memory during a full scan.
        let spill_limit = k.saturating_mul(2).max(32);
        let prune = |mut partial: Vec<(usize, f32)>| {
            if partial.len() > spill_limit {
                truncate_top_k(&mut partial, k);
            }
            partial
        };

        let mut ranked: Vec<(usize, f32)> = if let Some(ids) = selector {
            // Restricted search: score just the requested rows.
            ids.par_iter()
                .map(|&idx| self.score_row(idx, vector))
                .collect()
        } else {
            // Full scan: each rayon task keeps a pruned local buffer and the
            // buffers are merged (and pruned again) pairwise.
            (0..self.len())
                .into_par_iter()
                .fold(Vec::new, |mut partial, idx| {
                    partial.push(self.score_row(idx, vector));
                    prune(partial)
                })
                .reduce(Vec::new, |mut merged, other| {
                    merged.extend(other);
                    prune(merged)
                })
        };

        truncate_top_k(&mut ranked, k);
        ranked
    }

    /// Pairs `idx` with the dot product of stored row `idx` and `vector`.
    fn score_row(&self, idx: usize, vector: &Array1<f32>) -> (usize, f32) {
        let stored = self.vectors.row(idx);
        let dot = stored
            .iter()
            .zip(vector.iter())
            .map(|(a, b)| a * b)
            .sum::<f32>();
        (idx, dot)
    }
}
#[cfg(test)]
mod tests {
    use super::DenseIndex;
    use ndarray::array;

    /// With the selector limited to rows 1 and 2 and `k = 1`, the single hit
    /// must be row 1, whose direction is closest to the query.
    #[test]
    fn query_respects_selector_and_top_k_order() {
        let rows = array![[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]];
        let index = DenseIndex::new(rows);
        let probe = array![1.0, 0.0];
        let allowed = [1usize, 2];

        let hits = index.query(&probe, 1, Some(&allowed[..]));

        assert_eq!(hits.len(), 1);
        let (row, score) = hits[0];
        assert_eq!(row, 1);
        assert!(score > 0.9);
    }

    /// An empty selector means no row is eligible, so nothing is returned.
    #[test]
    fn query_with_empty_selector_returns_no_candidates() {
        let rows = array![[1.0, 0.0], [0.0, 1.0]];
        let index = DenseIndex::new(rows);
        let probe = array![1.0, 0.0];
        let nothing: &[usize] = &[];

        let hits = index.query(&probe, 10, Some(nothing));

        assert!(hits.is_empty());
    }
}