/// Column-major ("vertical") storage for a batch of equal-length `f32` vectors.
///
/// Values for dimension `d` of every vector are stored contiguously
/// (`data[d * num_vectors + i]` is dimension `d` of vector `i`), which makes
/// the per-dimension inner loops of the batch distance kernels below
/// cache-friendly and amenable to auto-vectorization.
#[derive(Clone, Debug, PartialEq)]
pub struct VerticalBatch {
    // Column-major backing buffer; see layout note on the struct doc.
    data: Vec<f32>,
    num_vectors: usize,
    dimension: usize,
}

impl VerticalBatch {
    /// Shared transposition used by [`Self::from_rows`] and
    /// [`Self::from_slices`]; accepts any row-like input via `AsRef<[f32]>`,
    /// removing the previously duplicated construction loop.
    ///
    /// # Panics
    /// Panics if the rows do not all have the same length as the first row.
    fn from_row_like<R: AsRef<[f32]>>(vectors: &[R]) -> Self {
        if vectors.is_empty() {
            return Self {
                data: Vec::new(),
                num_vectors: 0,
                dimension: 0,
            };
        }
        let dimension = vectors[0].as_ref().len();
        let num_vectors = vectors.len();
        let mut data = vec![0.0f32; dimension * num_vectors];
        for (i, row) in vectors.iter().enumerate() {
            let row = row.as_ref();
            assert_eq!(row.len(), dimension, "Inconsistent vector dimension");
            for (d, &val) in row.iter().enumerate() {
                data[d * num_vectors + i] = val;
            }
        }
        Self {
            data,
            num_vectors,
            dimension,
        }
    }

    /// Builds a batch from row-major owned vectors.
    ///
    /// # Panics
    /// Panics if the vectors have inconsistent dimensions.
    pub fn from_rows(vectors: &[Vec<f32>]) -> Self {
        Self::from_row_like(vectors)
    }

    /// Builds a batch from row-major borrowed slices.
    ///
    /// # Panics
    /// Panics if the slices have inconsistent dimensions.
    pub fn from_slices(vectors: &[&[f32]]) -> Self {
        Self::from_row_like(vectors)
    }

    /// Builds a batch from a flat row-major buffer, where
    /// `data[i * dimension + d]` is dimension `d` of vector `i`.
    ///
    /// # Panics
    /// Panics if `data.len() != num_vectors * dimension`.
    pub fn from_flat(data: &[f32], num_vectors: usize, dimension: usize) -> Self {
        assert_eq!(data.len(), num_vectors * dimension);
        let mut vertical = vec![0.0f32; dimension * num_vectors];
        for i in 0..num_vectors {
            for d in 0..dimension {
                vertical[d * num_vectors + i] = data[i * dimension + d];
            }
        }
        Self {
            data: vertical,
            num_vectors,
            dimension,
        }
    }

    /// Returns dimension `dim` of vector `vec_idx`.
    ///
    /// # Panics
    /// Panics (via slice indexing) if either index is out of bounds.
    #[inline]
    pub fn get(&self, dim: usize, vec_idx: usize) -> f32 {
        self.data[dim * self.num_vectors + vec_idx]
    }

    /// Returns the contiguous slice holding every vector's value for one
    /// dimension (length `num_vectors`).
    ///
    /// # Panics
    /// Panics if `dim >= dimension`.
    #[inline]
    pub fn dimension_slice(&self, dim: usize) -> &[f32] {
        let start = dim * self.num_vectors;
        &self.data[start..start + self.num_vectors]
    }

    /// Number of vectors stored in the batch.
    pub fn num_vectors(&self) -> usize {
        self.num_vectors
    }

    /// Dimensionality of every vector in the batch.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Raw column-major backing buffer.
    pub fn data(&self) -> &[f32] {
        &self.data
    }

    /// Re-assembles vector `vec_idx` into row-major form (allocates).
    pub fn extract_vector(&self, vec_idx: usize) -> Vec<f32> {
        (0..self.dimension).map(|d| self.get(d, vec_idx)).collect()
    }
}
/// Computes the squared L2 distance from `query` to every vector in `batch`.
///
/// Iterates dimension-by-dimension over the vertical layout so each inner
/// loop scans one contiguous column.
///
/// # Panics
/// Panics if `query.len() != batch.dimension`.
#[must_use]
pub fn batch_l2_squared(query: &[f32], batch: &VerticalBatch) -> Vec<f32> {
    assert_eq!(query.len(), batch.dimension);
    let mut acc = vec![0.0f32; batch.num_vectors];
    for (d, &q) in query.iter().enumerate() {
        let column = batch.dimension_slice(d);
        for (sum, &v) in acc.iter_mut().zip(column) {
            let delta = q - v;
            *sum += delta * delta;
        }
    }
    acc
}
/// Computes the dot product of `query` with every vector in `batch`.
///
/// # Panics
/// Panics if `query.len() != batch.dimension`.
#[must_use]
pub fn batch_dot(query: &[f32], batch: &VerticalBatch) -> Vec<f32> {
    assert_eq!(query.len(), batch.dimension);
    let mut acc = vec![0.0f32; batch.num_vectors];
    for (d, &q) in query.iter().enumerate() {
        let column = batch.dimension_slice(d);
        for (sum, &v) in acc.iter_mut().zip(column) {
            *sum += q * v;
        }
    }
    acc
}
/// Squared L2 distances with early pruning: a vector is dropped as soon as
/// its running (partial) distance exceeds `threshold`.
///
/// Returns `(index, squared_distance)` pairs for the survivors only; a
/// survivor's distance is exact. Vectors whose distance equals the threshold
/// exactly are kept (the cutoff is strict `>`).
///
/// # Panics
/// Panics if `query.len() != batch.dimension`.
#[must_use]
pub fn batch_l2_squared_pruning(
    query: &[f32],
    batch: &VerticalBatch,
    threshold: f32,
) -> Vec<(usize, f32)> {
    assert_eq!(query.len(), batch.dimension);
    let n = batch.num_vectors;
    let mut partial = vec![0.0f32; n];
    let mut live = vec![true; n];
    let mut remaining = n;
    for (d, &q) in query.iter().enumerate() {
        // All candidates pruned: the remaining dimensions cannot matter.
        if remaining == 0 {
            break;
        }
        let column = batch.dimension_slice(d);
        for ((dist, flag), &v) in partial.iter_mut().zip(live.iter_mut()).zip(column) {
            if *flag {
                let delta = q - v;
                *dist += delta * delta;
                // Squared distance only grows, so once over it stays over.
                if *dist > threshold {
                    *flag = false;
                    remaining -= 1;
                }
            }
        }
    }
    live.iter()
        .enumerate()
        .filter_map(|(i, &keep)| keep.then(|| (i, partial[i])))
        .collect()
}
/// Result of a k-nearest-neighbor query over a `VerticalBatch`.
#[derive(Clone, Debug, PartialEq)]
pub struct BatchKnnResult {
    // Indices of the matching vectors in the batch, best match first.
    pub indices: Vec<usize>,
    // Parallel to `indices`: score per match. Whether this is a distance
    // (lower = better) or a similarity (higher = better) depends on which
    // search function produced the result.
    pub scores: Vec<f32>,
}
/// Exact k-nearest-neighbor search by squared L2 distance.
///
/// Returns up to `k` `(index, distance)` pairs sorted ascending by distance;
/// `k` is clamped to the batch size, and `k == 0` or an empty batch yields
/// an empty result.
///
/// # Panics
/// Panics if `query.len() != batch.dimension`.
#[must_use]
pub fn batch_knn(query: &[f32], batch: &VerticalBatch, k: usize) -> BatchKnnResult {
    assert_eq!(query.len(), batch.dimension);
    if k == 0 || batch.num_vectors == 0 {
        return BatchKnnResult {
            indices: Vec::new(),
            scores: Vec::new(),
        };
    }
    let top = k.min(batch.num_vectors);
    let mut ranked: Vec<(usize, f32)> = batch_l2_squared(query, batch)
        .into_iter()
        .enumerate()
        .collect();
    ranked.sort_by(|x, y| x.1.total_cmp(&y.1));
    ranked.truncate(top);
    let (indices, scores) = ranked.into_iter().unzip();
    BatchKnnResult { indices, scores }
}
/// Approximate k-nearest-neighbor search (squared L2) with adaptive pruning.
///
/// Accumulates exact partial distances over the first `warmup_dims`
/// dimensions, extrapolates a pruning threshold from the k-th best partial
/// distance, then processes the remaining dimensions while dropping any
/// candidate whose running distance exceeds the threshold. The threshold is
/// re-tightened every 32 dimensions from the surviving candidates.
///
/// NOTE(review): this is a heuristic — the initial threshold is an
/// extrapolation (partial distance scaled by `dimension / warmup_dims`, with
/// 1.5x slack on the pre-prune), so true neighbors can be pruned and the
/// result may contain fewer than `k` entries. Use `batch_knn` when exact
/// results are required.
///
/// # Panics
/// Panics if `query.len() != batch.dimension` or if `warmup_dims == 0`.
#[must_use]
pub fn batch_knn_adaptive(
    query: &[f32],
    batch: &VerticalBatch,
    k: usize,
    warmup_dims: usize,
) -> BatchKnnResult {
    assert_eq!(query.len(), batch.dimension);
    assert!(warmup_dims > 0, "warmup_dims must be > 0");
    if batch.num_vectors == 0 || k == 0 {
        return BatchKnnResult {
            indices: Vec::new(),
            scores: Vec::new(),
        };
    }
    let k = k.min(batch.num_vectors);
    let warmup_dims = warmup_dims.min(batch.dimension);
    let mut distances = vec![0.0f32; batch.num_vectors];
    let mut alive: Vec<bool> = vec![true; batch.num_vectors];
    // Warmup pass: exact partial distances over the first `warmup_dims` dims.
    for (d, &q_d) in query.iter().enumerate().take(warmup_dims) {
        let dim_slice = batch.dimension_slice(d);
        for (dist, &v_d) in distances.iter_mut().zip(dim_slice.iter()) {
            let diff = q_d - v_d;
            *dist += diff * diff;
        }
    }
    let mut partial_indexed: Vec<(usize, f32)> = distances.iter().copied().enumerate().collect();
    partial_indexed.sort_by(|a, b| a.1.total_cmp(&b.1));
    // Extrapolate the k-th smallest partial distance to full dimensionality.
    // (The else branch is defensive: k is already clamped to num_vectors.)
    let mut threshold = if k <= partial_indexed.len() {
        partial_indexed[k - 1].1 * (batch.dimension as f32 / warmup_dims as f32)
    } else {
        f32::MAX
    };
    // Pre-prune candidates whose extrapolated full distance is far over the
    // threshold; the 1.5x slack reduces the risk of dropping true neighbors.
    for (i, &dist) in distances.iter().enumerate() {
        let estimated_full = dist * (batch.dimension as f32 / warmup_dims as f32);
        if estimated_full > threshold * 1.5 {
            alive[i] = false;
        }
    }
    // Scratch buffer reused by the periodic threshold refresh below.
    let mut threshold_buf = vec![0.0f32; batch.num_vectors];
    for (d, &q_d) in query
        .iter()
        .enumerate()
        .skip(warmup_dims)
        .take(batch.dimension - warmup_dims)
    {
        let dim_slice = batch.dimension_slice(d);
        for ((&v_d, dist), is_alive) in dim_slice
            .iter()
            .zip(distances.iter_mut())
            .zip(alive.iter_mut())
        {
            if !*is_alive {
                continue;
            }
            let diff = q_d - v_d;
            *dist += diff * diff;
            // Running distance only grows, so once over threshold the
            // candidate can never come back.
            if *dist > threshold {
                *is_alive = false;
            }
        }
        // Every 32 absolute dimension indices, tighten the threshold to the
        // k-th smallest running distance among the survivors.
        if d % 32 == 0 {
            let buf = &mut threshold_buf[..];
            let mut count = 0;
            for (&is_alive, &dist) in alive.iter().zip(distances.iter()) {
                if is_alive && count < buf.len() {
                    buf[count] = dist;
                    count += 1;
                }
            }
            if count >= k {
                let slice = &mut buf[..count];
                // Partial selection: only the k-th order statistic is needed.
                slice.select_nth_unstable_by(k - 1, |a, b| a.total_cmp(b));
                threshold = slice[k - 1];
            }
        }
    }
    // Survivors carry exact full distances; sort ascending and keep the top k.
    let mut results: Vec<(usize, f32)> = alive
        .iter()
        .enumerate()
        .filter(|(_, &a)| a)
        .map(|(i, _)| (i, distances[i]))
        .collect();
    results.sort_by(|a, b| a.1.total_cmp(&b.1));
    results.truncate(k);
    BatchKnnResult {
        indices: results.iter().map(|(i, _)| *i).collect(),
        scores: results.iter().map(|(_, d)| *d).collect(),
    }
}
/// Computes the population variance (divisor `n`) of each dimension across
/// all vectors in the batch.
///
/// Returns one value per dimension; batches with fewer than two vectors
/// report zero variance everywhere.
#[must_use]
pub fn batch_dimension_variance(batch: &VerticalBatch) -> Vec<f32> {
    if batch.num_vectors <= 1 || batch.dimension == 0 {
        return vec![0.0; batch.dimension];
    }
    let count = batch.num_vectors as f32;
    (0..batch.dimension)
        .map(|d| {
            let column = batch.dimension_slice(d);
            let mean = column.iter().sum::<f32>() / count;
            column
                .iter()
                .map(|&x| {
                    let centered = x - mean;
                    centered * centered
                })
                .sum::<f32>()
                / count
        })
        .collect()
}
/// Returns dimension indices sorted by descending variance.
///
/// The sort is stable, so dimensions with equal variance keep their original
/// relative order.
#[must_use]
fn variance_order(variances: &[f32]) -> Vec<usize> {
    let mut by_variance: Vec<usize> = (0..variances.len()).collect();
    by_variance.sort_by(|&lhs, &rhs| variances[rhs].total_cmp(&variances[lhs]));
    by_variance
}
/// Exact k-nearest-neighbor search (squared L2) that visits dimensions in
/// descending-variance order.
///
/// Produces the same results as [`batch_knn`] (up to float summation order);
/// the reordering exists so high-information dimensions are accumulated
/// first.
///
/// # Panics
/// Panics if `query.len() != batch.dimension`.
#[must_use]
pub fn batch_knn_reordered(query: &[f32], batch: &VerticalBatch, k: usize) -> BatchKnnResult {
    assert_eq!(query.len(), batch.dimension);
    if batch.num_vectors == 0 || k == 0 {
        return BatchKnnResult {
            indices: Vec::new(),
            scores: Vec::new(),
        };
    }
    let top = k.min(batch.num_vectors);
    let order = variance_order(&batch_dimension_variance(batch));
    let mut acc = vec![0.0f32; batch.num_vectors];
    for &d in &order {
        let q = query[d];
        for (sum, &v) in acc.iter_mut().zip(batch.dimension_slice(d)) {
            let delta = q - v;
            *sum += delta * delta;
        }
    }
    let mut ranked: Vec<(usize, f32)> = acc.into_iter().enumerate().collect();
    ranked.sort_by(|x, y| x.1.total_cmp(&y.1));
    ranked.truncate(top);
    let (indices, scores) = ranked.into_iter().unzip();
    BatchKnnResult { indices, scores }
}
/// Computes the L2 norm of every vector in the batch.
///
/// Accumulates squared sums column-by-column over the vertical layout, then
/// takes square roots in a single final pass.
#[must_use]
pub fn batch_norms(batch: &VerticalBatch) -> Vec<f32> {
    let mut sums = vec![0.0f32; batch.num_vectors];
    for d in 0..batch.dimension {
        for (acc, &v) in sums.iter_mut().zip(batch.dimension_slice(d)) {
            *acc += v * v;
        }
    }
    sums.iter_mut().for_each(|s| *s = s.sqrt());
    sums
}
/// Cosine similarity of `query` against every vector, using precomputed
/// per-vector `norms` (see [`batch_norms`]).
///
/// A (near-)zero-norm query or vector yields a similarity of 0.0 rather than
/// dividing by zero.
///
/// # Panics
/// Panics if `norms.len() != batch.num_vectors`, or if
/// `query.len() != batch.dimension` (via [`batch_dot`]).
#[must_use]
pub fn batch_cosine(query: &[f32], batch: &VerticalBatch, norms: &[f32]) -> Vec<f32> {
    assert_eq!(norms.len(), batch.num_vectors);
    let dots = batch_dot(query, batch);
    let query_norm = query.iter().map(|x| x * x).sum::<f32>().sqrt();
    // Degenerate query has no direction: define similarity as 0 everywhere.
    if query_norm < crate::NORM_EPSILON {
        return vec![0.0; batch.num_vectors];
    }
    dots.into_iter()
        .zip(norms)
        .map(|(dot, &norm)| {
            if norm > crate::NORM_EPSILON {
                dot / (query_norm * norm)
            } else {
                0.0
            }
        })
        .collect()
}
/// k-nearest-neighbor search by dot product (higher score = better match).
///
/// Returns up to `k` `(index, dot)` pairs sorted descending by dot product.
///
/// # Panics
/// Panics if `query.len() != batch.dimension`.
#[must_use]
pub fn batch_knn_dot(query: &[f32], batch: &VerticalBatch, k: usize) -> BatchKnnResult {
    assert_eq!(query.len(), batch.dimension);
    if batch.num_vectors == 0 || k == 0 {
        return BatchKnnResult {
            indices: Vec::new(),
            scores: Vec::new(),
        };
    }
    let top = k.min(batch.num_vectors);
    let mut ranked: Vec<(usize, f32)> = batch_dot(query, batch).into_iter().enumerate().collect();
    // Descending: larger dot products rank first.
    ranked.sort_by(|x, y| y.1.total_cmp(&x.1));
    ranked.truncate(top);
    let (indices, scores) = ranked.into_iter().unzip();
    BatchKnnResult { indices, scores }
}
/// k-nearest-neighbor search by cosine similarity (higher score = better).
///
/// Computes per-vector norms internally; zero-norm vectors score 0.0.
///
/// # Panics
/// Panics if `query.len() != batch.dimension`.
#[must_use]
pub fn batch_knn_cosine(query: &[f32], batch: &VerticalBatch, k: usize) -> BatchKnnResult {
    assert_eq!(query.len(), batch.dimension);
    if batch.num_vectors == 0 || k == 0 {
        return BatchKnnResult {
            indices: Vec::new(),
            scores: Vec::new(),
        };
    }
    let top = k.min(batch.num_vectors);
    let norms = batch_norms(batch);
    let mut ranked: Vec<(usize, f32)> = batch_cosine(query, batch, &norms)
        .into_iter()
        .enumerate()
        .collect();
    // Descending: larger similarities rank first.
    ranked.sort_by(|x, y| y.1.total_cmp(&x.1));
    ranked.truncate(top);
    let (indices, scores) = ranked.into_iter().unzip();
    BatchKnnResult { indices, scores }
}
/// k-nearest-neighbor search (squared L2) restricted to vectors whose index
/// passes `predicate`.
///
/// The predicate is evaluated once per index up front; results keep the
/// original batch indices and are sorted ascending by distance. Returns
/// fewer than `k` entries when fewer vectors pass the filter.
///
/// # Panics
/// Panics if `query.len() != batch.dimension`.
#[must_use]
pub fn batch_knn_filtered<F>(
    query: &[f32],
    batch: &VerticalBatch,
    k: usize,
    predicate: F,
) -> BatchKnnResult
where
    F: Fn(usize) -> bool,
{
    assert_eq!(query.len(), batch.dimension);
    if batch.num_vectors == 0 || k == 0 {
        return BatchKnnResult {
            indices: Vec::new(),
            scores: Vec::new(),
        };
    }
    let mask: Vec<bool> = (0..batch.num_vectors).map(&predicate).collect();
    let passing = mask.iter().filter(|&&keep| keep).count();
    if passing == 0 {
        return BatchKnnResult {
            indices: Vec::new(),
            scores: Vec::new(),
        };
    }
    let top = k.min(passing);
    // Filtered-out slots hold f32::MAX as a sentinel; they are dropped below.
    let mut acc: Vec<f32> = mask
        .iter()
        .map(|&keep| if keep { 0.0 } else { f32::MAX })
        .collect();
    for (d, &q) in query.iter().enumerate() {
        let column = batch.dimension_slice(d);
        for ((sum, &v), &keep) in acc.iter_mut().zip(column).zip(&mask) {
            if keep {
                let delta = q - v;
                *sum += delta * delta;
            }
        }
    }
    let mut ranked: Vec<(usize, f32)> = acc
        .into_iter()
        .enumerate()
        .filter(|&(i, _)| mask[i])
        .collect();
    ranked.sort_by(|x, y| x.1.total_cmp(&y.1));
    ranked.truncate(top);
    let (indices, scores) = ranked.into_iter().unzip();
    BatchKnnResult { indices, scores }
}
// Unit tests covering construction, the distance kernels, pruning, and every
// kNN variant defined above.
#[cfg(test)]
mod tests {
    use super::*;

    // Rows must land in column-major storage: get(dim, vec_idx).
    #[test]
    fn test_vertical_batch_creation() {
        let vectors = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        assert_eq!(batch.num_vectors(), 2);
        assert_eq!(batch.dimension(), 3);
        assert_eq!(batch.get(0, 0), 1.0);
        assert_eq!(batch.get(0, 1), 4.0);
        assert_eq!(batch.get(1, 0), 2.0);
        assert_eq!(batch.get(2, 1), 6.0);
    }

    #[test]
    fn test_batch_l2_squared() {
        let vectors = vec![
            vec![0.0, 0.0, 0.0],
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
        ];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![1.0, 1.0, 0.0];
        let distances = batch_l2_squared(&query, &batch);
        assert!((distances[0] - 2.0).abs() < 1e-6);
        assert!((distances[1] - 1.0).abs() < 1e-6);
        assert!((distances[2] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_batch_dot() {
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![1.0, 2.0];
        let dots = batch_dot(&query, &batch);
        assert!((dots[0] - 1.0).abs() < 1e-6);
        assert!((dots[1] - 2.0).abs() < 1e-6);
        assert!((dots[2] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_batch_knn() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![2.0, 0.0],
            vec![3.0, 0.0],
        ];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![0.5, 0.0];
        let result = batch_knn(&query, &batch, 2);
        assert_eq!(result.indices.len(), 2);
        assert!(result.indices.contains(&0));
        assert!(result.indices.contains(&1));
    }

    // Threshold 2.0 keeps distances 0 and 1, prunes distance 100.
    #[test]
    fn test_batch_pruning() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![10.0, 0.0],
        ];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![0.0, 0.0];
        let survivors = batch_l2_squared_pruning(&query, &batch, 2.0);
        assert_eq!(survivors.len(), 2);
        let indices: Vec<usize> = survivors.iter().map(|(i, _)| *i).collect();
        assert!(indices.contains(&0));
        assert!(indices.contains(&1));
        assert!(!indices.contains(&2));
    }

    #[test]
    fn test_extract_vector() {
        let vectors = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let extracted = batch.extract_vector(1);
        assert_eq!(extracted, vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_batch_cosine() {
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let norms = batch_norms(&batch);
        let query = vec![1.0, 0.0];
        let cosines = batch_cosine(&query, &batch, &norms);
        assert!((cosines[0] - 1.0).abs() < 1e-6);
        assert!(cosines[1].abs() < 1e-6);
        // 45-degree vector: cos = 1/sqrt(2).
        assert!((cosines[2] - std::f32::consts::FRAC_1_SQRT_2).abs() < 0.01);
    }

    #[test]
    fn test_empty_batch() {
        let batch = VerticalBatch::from_rows(&[]);
        assert_eq!(batch.num_vectors(), 0);
        assert_eq!(batch.dimension(), 0);
    }

    #[test]
    fn test_single_vector_batch() {
        let vectors = vec![vec![1.0, 2.0, 3.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        assert_eq!(batch.num_vectors(), 1);
        assert_eq!(batch.dimension(), 3);
        assert_eq!(batch.get(0, 0), 1.0);
        assert_eq!(batch.get(1, 0), 2.0);
        assert_eq!(batch.get(2, 0), 3.0);
        assert_eq!(batch.extract_vector(0), vec![1.0, 2.0, 3.0]);
    }

    // The two constructors must agree on every element.
    #[test]
    fn test_from_flat_matches_from_rows() {
        let vectors = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ];
        let flat: Vec<f32> = vectors.iter().flatten().copied().collect();
        let batch_rows = VerticalBatch::from_rows(&vectors);
        let batch_flat = VerticalBatch::from_flat(&flat, 3, 3);
        for d in 0..3 {
            for v in 0..3 {
                assert_eq!(
                    batch_rows.get(d, v),
                    batch_flat.get(d, v),
                    "mismatch at dim={d}, vec={v}"
                );
            }
        }
    }

    #[test]
    fn test_from_flat_single_vector() {
        let flat = [10.0, 20.0];
        let batch = VerticalBatch::from_flat(&flat, 1, 2);
        assert_eq!(batch.num_vectors(), 1);
        assert_eq!(batch.dimension(), 2);
        assert_eq!(batch.extract_vector(0), vec![10.0, 20.0]);
    }

    #[test]
    fn test_dimension_slice() {
        let vectors = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        assert_eq!(batch.dimension_slice(0), &[1.0, 3.0, 5.0]);
        assert_eq!(batch.dimension_slice(1), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_batch_norms() {
        let vectors = vec![vec![3.0, 4.0], vec![0.0, 0.0], vec![1.0, 0.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let norms = batch_norms(&batch);
        // 3-4-5 triangle, zero vector, unit vector.
        assert!((norms[0] - 5.0).abs() < 1e-6);
        assert!(norms[1].abs() < 1e-6);
        assert!((norms[2] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_batch_l2_squared_exact_match() {
        let vectors = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![3.0, 4.0];
        let distances = batch_l2_squared(&query, &batch);
        assert!(
            distances[1].abs() < 1e-9,
            "exact match should have distance ~0"
        );
        assert!(distances[0] > 0.0);
        assert!(distances[2] > 0.0);
    }

    #[test]
    fn test_batch_dot_zero_query() {
        let vectors = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![0.0, 0.0];
        let dots = batch_dot(&query, &batch);
        assert_eq!(dots, vec![0.0, 0.0]);
    }

    // Zero-norm query is defined to yield all-zero similarities.
    #[test]
    fn test_batch_cosine_zero_query() {
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let norms = batch_norms(&batch);
        let query = vec![0.0, 0.0];
        let cosines = batch_cosine(&query, &batch, &norms);
        assert_eq!(cosines, vec![0.0, 0.0]);
    }

    #[test]
    fn test_batch_cosine_zero_norm_vector() {
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 0.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let norms = batch_norms(&batch);
        let query = vec![1.0, 0.0];
        let cosines = batch_cosine(&query, &batch, &norms);
        assert!((cosines[0] - 1.0).abs() < 1e-6);
        assert_eq!(cosines[1], 0.0);
    }

    #[test]
    fn test_batch_knn_dot_basic() {
        let vectors = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![-1.0, 0.0],
        ];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![1.0, 0.0];
        let result = batch_knn_dot(&query, &batch, 2);
        assert_eq!(result.indices[0], 0);
        assert!((result.scores[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_batch_knn_dot_sorted_descending() {
        let vectors = vec![vec![0.5, 0.5], vec![1.0, 0.0], vec![0.0, 1.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![1.0, 0.0];
        let result = batch_knn_dot(&query, &batch, 3);
        for w in result.scores.windows(2) {
            assert!(w[0] >= w[1], "not sorted descending: {:?}", result.scores);
        }
    }

    // Dimension reordering must not change the (exact) kNN answer.
    #[test]
    fn test_batch_knn_reordered_matches_exact() {
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| (0..16).map(|d| ((i * 7 + d * 3) as f32).sin()).collect())
            .collect();
        let batch = VerticalBatch::from_rows(&vectors);
        let query: Vec<f32> = (0..16).map(|i| (i as f32 * 0.1).cos()).collect();
        let exact = batch_knn(&query, &batch, 5);
        let reordered = batch_knn_reordered(&query, &batch, 5);
        assert_eq!(exact.indices, reordered.indices);
        for (e, r) in exact.scores.iter().zip(&reordered.scores) {
            assert!(
                (e - r).abs() < 1e-4,
                "distance mismatch: exact={e}, reordered={r}"
            );
        }
    }

    #[test]
    fn test_batch_knn_reordered_empty() {
        let batch = VerticalBatch::from_rows(&[]);
        let result = batch_knn_reordered(&[], &batch, 5);
        assert!(result.indices.is_empty());
    }

    #[test]
    fn test_batch_dimension_variance() {
        let vectors = vec![vec![1.0, 0.0], vec![1.0, 5.0], vec![1.0, 10.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let var = batch_dimension_variance(&batch);
        assert!(var[0].abs() < 1e-6, "constant dim should have 0 variance");
        assert!(
            var[1] > 10.0,
            "varying dim should have high variance: {}",
            var[1]
        );
    }

    #[test]
    fn test_from_slices_matches_from_rows() {
        let v0 = [1.0f32, 2.0, 3.0];
        let v1 = [4.0f32, 5.0, 6.0];
        let v2 = [7.0f32, 8.0, 9.0];
        let from_rows = VerticalBatch::from_rows(&[v0.to_vec(), v1.to_vec(), v2.to_vec()]);
        let from_slices = VerticalBatch::from_slices(&[&v0, &v1, &v2]);
        for d in 0..3 {
            for v in 0..3 {
                assert_eq!(
                    from_rows.get(d, v),
                    from_slices.get(d, v),
                    "mismatch at dim={d}, vec={v}"
                );
            }
        }
    }

    #[test]
    fn test_from_slices_empty() {
        let batch = VerticalBatch::from_slices(&[]);
        assert_eq!(batch.num_vectors(), 0);
        assert_eq!(batch.dimension(), 0);
    }

    #[test]
    fn test_batch_knn_cosine_basic() {
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![-1.0, 0.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![1.0, 0.0];
        let result = batch_knn_cosine(&query, &batch, 2);
        assert_eq!(result.indices.len(), 2);
        assert_eq!(result.indices[0], 0);
        assert_eq!(result.indices[1], 1);
        assert!((result.scores[0] - 1.0).abs() < 1e-5);
        assert!(result.scores[1].abs() < 1e-5);
    }

    #[test]
    fn test_batch_knn_cosine_empty() {
        let batch = VerticalBatch::from_rows(&[]);
        let result = batch_knn_cosine(&[], &batch, 5);
        assert!(result.indices.is_empty());
    }

    #[test]
    fn test_batch_knn_cosine_sorted_descending() {
        let vectors = vec![
            vec![0.1, 1.0],
            vec![1.0, 0.0],
            vec![0.5, 0.5],
        ];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![1.0, 0.0];
        let result = batch_knn_cosine(&query, &batch, 3);
        for w in result.scores.windows(2) {
            assert!(
                w[0] >= w[1],
                "cosine kNN not sorted descending: {:?}",
                result.scores
            );
        }
        assert_eq!(result.indices[0], 1);
    }

    // Even-index filter: nearest passing vectors are 0 and 2.
    #[test]
    fn test_filtered_knn_basic() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.1, 0.0],
            vec![10.0, 0.0],
        ];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![0.0, 0.0];
        let result = batch_knn_filtered(&query, &batch, 2, |i| i % 2 == 0);
        assert_eq!(result.indices.len(), 2);
        assert_eq!(result.indices[0], 0);
        assert_eq!(result.indices[1], 2);
    }

    #[test]
    fn test_filtered_knn_none_pass() {
        let vectors = vec![vec![1.0, 0.0], vec![2.0, 0.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![0.0, 0.0];
        let result = batch_knn_filtered(&query, &batch, 2, |_| false);
        assert!(result.indices.is_empty());
    }

    #[test]
    fn test_filtered_knn_all_pass() {
        let vectors = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![2.0, 0.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![0.0, 0.0];
        let filtered = batch_knn_filtered(&query, &batch, 2, |_| true);
        let unfiltered = batch_knn(&query, &batch, 2);
        assert_eq!(filtered.indices, unfiltered.indices);
    }

    #[test]
    fn test_filtered_knn_k_larger_than_passing() {
        let vectors = vec![vec![1.0], vec![2.0], vec![3.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![0.0];
        let result = batch_knn_filtered(&query, &batch, 10, |i| i == 0);
        assert_eq!(result.indices.len(), 1);
        assert_eq!(result.indices[0], 0);
    }

    #[test]
    fn test_filtered_knn_preserves_original_indices() {
        let vectors = vec![
            vec![100.0],
            vec![100.0],
            vec![0.1],
            vec![100.0],
            vec![0.2],
        ];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![0.0];
        let result = batch_knn_filtered(&query, &batch, 2, |i| i == 2 || i == 4);
        assert_eq!(result.indices, vec![2, 4]);
    }

    #[test]
    fn test_batch_knn_k_zero() {
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![1.0, 0.0];
        let result = batch_knn(&query, &batch, 0);
        assert!(result.indices.is_empty());
        assert!(result.scores.is_empty());
    }

    #[test]
    fn test_batch_knn_empty_batch() {
        let batch = VerticalBatch::from_rows(&[]);
        let result = batch_knn(&[], &batch, 5);
        assert!(result.indices.is_empty());
    }

    #[test]
    fn test_batch_knn_k_larger_than_n() {
        let vectors = vec![vec![1.0], vec![2.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![1.5];
        let result = batch_knn(&query, &batch, 10);
        assert_eq!(result.indices.len(), 2);
    }

    #[test]
    fn test_batch_knn_sorted_by_distance() {
        let vectors = vec![
            vec![10.0, 0.0],
            vec![1.0, 0.0],
            vec![5.0, 0.0],
            vec![0.0, 0.0],
        ];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![0.0, 0.0];
        let result = batch_knn(&query, &batch, 4);
        for w in result.scores.windows(2) {
            assert!(w[0] <= w[1], "distances not sorted: {:?}", result.scores);
        }
        assert_eq!(result.indices[0], 3);
    }

    // Cutoff is strict `>`, so only the exact-zero-distance vector survives
    // a 0.0 threshold.
    #[test]
    fn test_pruning_threshold_zero() {
        let vectors = vec![vec![0.0, 0.0], vec![1.0, 0.0], vec![0.0, 1.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![0.0, 0.0];
        let survivors = batch_l2_squared_pruning(&query, &batch, 0.0);
        assert_eq!(survivors.len(), 1);
        assert_eq!(survivors[0].0, 0);
        assert!(survivors[0].1.abs() < 1e-9);
    }

    #[test]
    fn test_pruning_all_survive() {
        let vectors = vec![vec![0.1, 0.0], vec![0.0, 0.1]];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![0.0, 0.0];
        let survivors = batch_l2_squared_pruning(&query, &batch, 100.0);
        assert_eq!(survivors.len(), 2);
    }

    #[test]
    fn test_pruning_none_survive() {
        let vectors = vec![vec![10.0, 0.0], vec![0.0, 10.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![0.0, 0.0];
        let survivors = batch_l2_squared_pruning(&query, &batch, 0.5);
        assert!(survivors.is_empty());
    }

    #[test]
    fn test_batch_knn_adaptive_empty() {
        let batch = VerticalBatch::from_rows(&[]);
        let result = batch_knn_adaptive(&[], &batch, 5, 2);
        assert!(result.indices.is_empty());
    }

    #[test]
    fn test_batch_knn_adaptive_k_zero() {
        let vectors = vec![vec![1.0, 2.0]];
        let batch = VerticalBatch::from_rows(&vectors);
        let result = batch_knn_adaptive(&[1.0, 2.0], &batch, 0, 1);
        assert!(result.indices.is_empty());
    }

    // The adaptive heuristic must still find an unambiguous nearest neighbor.
    #[test]
    fn test_batch_knn_adaptive_finds_nearest() {
        let vectors = vec![
            vec![0.0, 0.0, 0.0, 0.0],
            vec![100.0, 100.0, 100.0, 100.0],
            vec![0.1, 0.1, 0.1, 0.1],
        ];
        let batch = VerticalBatch::from_rows(&vectors);
        let query = vec![0.0, 0.0, 0.0, 0.0];
        let exact = batch_knn(&query, &batch, 1);
        let adaptive = batch_knn_adaptive(&query, &batch, 1, 2);
        assert_eq!(exact.indices[0], 0);
        assert_eq!(adaptive.indices[0], 0);
    }

    #[test]
    fn test_batch_l2_squared_large() {
        let n = 32;
        let dim = 8;
        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|i| (0..dim).map(|d| (i * dim + d) as f32).collect())
            .collect();
        let batch = VerticalBatch::from_rows(&vectors);
        let query: Vec<f32> = vectors[0].clone();
        let distances = batch_l2_squared(&query, &batch);
        assert!(distances[0].abs() < 1e-9, "self-distance should be ~0");
        for (i, &d) in distances.iter().enumerate().skip(1) {
            assert!(
                d > 0.0,
                "distance to vector {i} should be positive, got {d}"
            );
        }
    }

    #[test]
    fn test_extract_all_vectors_roundtrip() {
        let vectors = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ];
        let batch = VerticalBatch::from_rows(&vectors);
        for (i, original) in vectors.iter().enumerate() {
            assert_eq!(
                batch.extract_vector(i),
                *original,
                "roundtrip failed for vector {i}"
            );
        }
    }
}