use crate::index::{AnnIndex, RabitqPlusIndex, SearchResult};
use crate::RabitqError;
/// Static description of what a kernel backend supports.
///
/// Derives `PartialEq`/`Eq`/`Hash` in addition to `Debug`/`Clone` so that
/// capability sets can be compared directly (e.g. in tests or when choosing a
/// kernel) and used as hash-map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct KernelCaps {
    /// Name of the accelerator backend (`"cpu"` for [`KernelCaps::cpu_default`]).
    pub accelerator: &'static str,
    /// Smallest batch size the kernel expects — presumably counted in queries
    /// per scan; TODO confirm with callers.
    pub min_batch: usize,
    /// Largest vector dimensionality accepted (presumably); `usize::MAX`
    /// means effectively unbounded.
    pub max_dim: usize,
    /// Whether the kernel advertises deterministic results — TODO confirm
    /// exactly what callers rely on here.
    pub deterministic: bool,
}

impl KernelCaps {
    /// Capabilities of the CPU path: accelerator `"cpu"`, `min_batch` 1,
    /// `max_dim` `usize::MAX`, `deterministic` true.
    ///
    /// Declared `const` so it can initialise constants and statics.
    pub const fn cpu_default() -> Self {
        Self {
            accelerator: "cpu",
            min_batch: 1,
            max_dim: usize::MAX,
            deterministic: true,
        }
    }
}
pub struct ScanRequest<'a> {
pub index: &'a RabitqPlusIndex,
pub queries: &'a [Vec<f32>],
pub k: usize,
pub rerank_factor: Option<usize>,
}
/// Result of a kernel scan: one inner `Vec<SearchResult>` per query. The CPU
/// kernel emits these in the same order as the request's `queries`.
pub type ScanResponse = Vec<Vec<SearchResult>>;
/// A search backend that executes a batched `ScanRequest` against the
/// request's index.
///
/// The `Send + Sync` supertraits let a kernel be shared across threads
/// (e.g. behind an `Arc<dyn VectorKernel>`).
pub trait VectorKernel: Send + Sync {
    /// Short identifier for this kernel (the CPU kernel returns `"cpu"`).
    fn id(&self) -> &str;
    /// Capabilities this kernel advertises.
    fn caps(&self) -> KernelCaps;
    /// Runs every query in `req` and returns the hits for each.
    ///
    /// # Errors
    /// Returns a `RabitqError` when a search fails. In the CPU kernel this is
    /// the error propagated from the index; other implementations may add
    /// their own failure modes (TODO confirm per implementation).
    fn scan(&self, req: ScanRequest<'_>) -> Result<ScanResponse, RabitqError>;
}
/// Reference kernel that answers each query by calling the index's own
/// search routines on the current thread.
///
/// The type carries no state: it is zero-sized and freely copyable.
#[derive(Debug, Default, Clone, Copy)]
pub struct CpuKernel;

impl CpuKernel {
    /// Returns the (stateless) CPU kernel; usable in `const` contexts.
    pub const fn new() -> Self {
        CpuKernel
    }
}
impl VectorKernel for CpuKernel {
fn id(&self) -> &str {
"cpu"
}
fn caps(&self) -> KernelCaps {
KernelCaps::cpu_default()
}
fn scan(&self, req: ScanRequest<'_>) -> Result<ScanResponse, RabitqError> {
let mut out = Vec::with_capacity(req.queries.len());
for q in req.queries {
let hits = match req.rerank_factor {
None => req.index.search(q, req.k)?,
Some(rf) => req.index.search_with_rerank(q, req.k, rf)?,
};
out.push(hits);
}
Ok(out)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small index of 16 vectors of length 8, where vector `i` is
    /// `[i, i+1, ..., i+7]`.
    fn tiny_index() -> RabitqPlusIndex {
        let d = 8;
        let mut idx = RabitqPlusIndex::new(d, 42, 5);
        for i in 0..16 {
            let v: Vec<f32> = (0..d).map(|j| (i + j) as f32).collect();
            idx.add(i, v).unwrap();
        }
        idx
    }

    /// Asserts that two hit lists agree on length and on ids (in order), and
    /// that their scores agree up to float noise.
    fn assert_same_hits(got: &[SearchResult], want: &[SearchResult]) {
        assert_eq!(got.len(), want.len());
        for (a, b) in got.iter().zip(want) {
            assert_eq!(a.id, b.id);
            assert!((a.score - b.score).abs() < 1e-5);
        }
    }

    #[test]
    fn cpu_kernel_matches_direct_search() {
        let idx = tiny_index();
        let kernel = CpuKernel::new();
        let q: Vec<f32> = vec![2.0; 8];
        let direct = idx.search(&q, 4).unwrap();
        let batched = kernel
            .scan(ScanRequest {
                index: &idx,
                queries: std::slice::from_ref(&q),
                k: 4,
                rerank_factor: None,
            })
            .unwrap();
        assert_eq!(batched.len(), 1);
        assert_same_hits(&batched[0], &direct);
    }

    #[test]
    fn cpu_kernel_respects_rerank_override() {
        let idx = tiny_index();
        let kernel = CpuKernel::new();
        let q: Vec<f32> = vec![2.0; 8];
        let out = kernel
            .scan(ScanRequest {
                index: &idx,
                queries: &[q.clone(), q.clone()],
                k: 3,
                rerank_factor: Some(2),
            })
            .unwrap();
        assert_eq!(out.len(), 2, "one result vec per input query");
        // Checking length and order alone would still pass if the kernel dropped
        // the override, so compare against a direct reranked search as well.
        let direct = idx.search_with_rerank(&q, 3, 2).unwrap();
        for v in &out {
            assert_same_hits(v, &direct);
            // Assumes the index returns hits in ascending score order.
            for w in v.windows(2) {
                assert!(w[0].score <= w[1].score, "hits must be sorted");
            }
        }
    }

    #[test]
    fn cpu_caps_are_deterministic_and_unbounded() {
        let kernel = CpuKernel::new();
        let c = kernel.caps();
        assert_eq!(c.accelerator, "cpu");
        assert_eq!(kernel.id(), c.accelerator);
        assert_eq!(c.min_batch, 1);
        assert_eq!(c.max_dim, usize::MAX);
        assert!(c.deterministic);
    }
}