pub mod chunk;
pub mod store;
pub mod embed_registry;
#[cfg(any(feature = "embed-tract", feature = "embed-ort"))]
pub mod embed_support;
#[cfg(feature = "embed-tract")]
pub mod embed; #[cfg(feature = "embed-ort")]
pub mod embed_ort; #[cfg(feature = "embed-ort")]
pub mod cuda;
/// Loads the embedding model for whichever backend this build was compiled with.
///
/// When both `embed-ort` and `embed-tract` are enabled, the ONNX Runtime backend wins.
///
/// # Errors
///
/// Propagates any error returned by the selected backend's `load()`.
#[cfg(any(feature = "embed-tract", feature = "embed-ort"))]
pub fn load_embedder() -> anyhow::Result<Box<dyn store::Embedder>> {
    // Exactly one of these bindings survives cfg-expansion under the outer `any(...)` gate.
    #[cfg(feature = "embed-ort")]
    let embedder = embed_ort::OrtEmbedder::load()?;
    #[cfg(all(feature = "embed-tract", not(feature = "embed-ort")))]
    let embedder = embed::JinaEmbedder::load()?;
    Ok(Box::new(embedder))
}
/// Human-readable name of the embedding backend compiled into this build.
///
/// Uses the same precedence as `load_embedder`: ONNX Runtime when `embed-ort` is
/// enabled, otherwise tract (which the outer cfg guarantees is then enabled).
#[cfg(any(feature = "embed-tract", feature = "embed-ort"))]
pub fn embedder_backend() -> &'static str {
    if cfg!(feature = "embed-ort") {
        "ort (ONNX Runtime, CUDA→CPU)"
    } else {
        "tract-onnx (CPU, pure Rust)"
    }
}
/// Id of the embedding model chosen by `embed_registry::selected()`, or `"<invalid>"`
/// when the registry reports an error.
pub fn selected_model_id() -> &'static str {
    match embed_registry::selected() {
        Ok(model) => model.id,
        Err(_) => "<invalid>",
    }
}
/// Describes the selected embedding model as `"<model_name> (<dim>-dim)"`.
///
/// If the registry lookup fails, its error message is returned verbatim instead.
pub fn selected_model_desc() -> String {
    embed_registry::selected()
        .map(|model| format!("{} ({}-dim)", model.model_name, model.dim))
        .unwrap_or_else(|reason| reason)
}
use std::collections::HashMap;
use std::path::Path;
use anyhow::{bail, ensure, Context, Result};
/// Minimum number of rows each scoring thread should get; `thread_count` stays
/// single-threaded for fewer than twice this many rows.
const MIN_ROWS_PER_THREAD: usize = 1024;
/// Four-byte signature at the start of every serialized `VectorIndex` file.
const MAGIC: &[u8; 4] = b"NVF1";
/// Exact (brute-force) similarity index over fixed-dimension `f32` vectors keyed by `u64` ids.
///
/// `search` scores every stored row against the L2-normalized query; rows inserted via
/// `add` are L2-normalized too, so scores are cosine similarities.
pub struct VectorIndex {
    /// Number of components per vector; `new` and `load` reject zero.
    dim: usize,
    /// External id of each stored row: `ids[i]` owns row `i` of `data`.
    ids: Vec<u64>,
    /// Row-major storage: row `i` is `data[i * dim..(i + 1) * dim]`.
    /// Rows added via `add` are L2-normalized (all-zero rows are stored unchanged).
    data: Vec<f32>,
    /// Reverse lookup from id to its row index in `ids` / `data`.
    pos: HashMap<u64, usize>,
}
impl VectorIndex {
pub fn new(dim: usize) -> Result<Self> {
ensure!(dim != 0, "vector dim must be non-zero");
Ok(Self {
dim,
ids: Vec::new(),
data: Vec::new(),
pos: HashMap::new(),
})
}
pub fn dim(&self) -> usize {
self.dim
}
pub fn len(&self) -> usize {
self.ids.len()
}
pub fn is_empty(&self) -> bool {
self.ids.is_empty()
}
pub fn contains(&self, id: u64) -> bool {
self.pos.contains_key(&id)
}
pub fn add(&mut self, vectors: &[f32], ids: &[u64]) -> Result<()> {
ensure!(
vectors.len() == ids.len() * self.dim,
"vectors len {} != ids len {} * dim {}",
vectors.len(),
ids.len(),
self.dim
);
let mut seen = std::collections::HashSet::with_capacity(ids.len());
for &id in ids {
ensure!(
!self.pos.contains_key(&id) && seen.insert(id),
"duplicate id {id}"
);
}
self.ids.reserve(ids.len());
self.data.reserve(vectors.len());
self.pos.reserve(ids.len());
for (i, &id) in ids.iter().enumerate() {
let row = &vectors[i * self.dim..(i + 1) * self.dim];
let row_idx = self.ids.len();
push_normalized(&mut self.data, row);
self.ids.push(id);
self.pos.insert(id, row_idx);
}
Ok(())
}
pub fn remove(&mut self, id: u64) -> bool {
let Some(idx) = self.pos.remove(&id) else {
return false;
};
let last = self.ids.len() - 1;
let dim = self.dim;
if idx != last {
self.data
.copy_within(last * dim..(last + 1) * dim, idx * dim);
let moved_id = self.ids[last];
self.ids[idx] = moved_id;
self.pos.insert(moved_id, idx);
}
self.ids.pop();
self.data.truncate(last * dim);
true
}
pub fn search(&self, query: &[f32], k: usize) -> Vec<(u64, f32)> {
assert_eq!(
query.len(),
self.dim,
"query dim {} != index dim {}",
query.len(),
self.dim
);
let n = self.ids.len();
let m = k.min(n);
if m == 0 {
return Vec::new();
}
let qn = normalized(query);
let kernel = select_dot_kernel();
let threads = thread_count(n);
let mut merged = if threads <= 1 {
self.score_range(0, n, &qn, kernel, m)
} else {
let chunk = n.div_ceil(threads);
std::thread::scope(|s| {
let mut handles = Vec::with_capacity(threads);
let mut start = 0;
while start < n {
let end = (start + chunk).min(n);
let qn = &qn;
handles.push(s.spawn(move || self.score_range(start, end, qn, kernel, m)));
start = end;
}
let mut out = Vec::with_capacity(handles.len() * m);
for h in handles {
out.extend(h.join().expect("scoring thread panicked"));
}
out
})
};
top_k(&mut merged, m);
merged
}
fn score_range(&self, start: usize, end: usize, qn: &[f32], kernel: DotFn, m: usize) -> Vec<(u64, f32)> {
let mut local: Vec<(u64, f32)> = Vec::with_capacity(end - start);
for idx in start..end {
let row = &self.data[idx * self.dim..(idx + 1) * self.dim];
let score = unsafe { kernel(qn, row) };
local.push((self.ids[idx], score));
}
top_k(&mut local, m);
local
}
pub fn write(&self, path: impl AsRef<Path>) -> Result<()> {
let path = path.as_ref();
let n = self.ids.len();
let mut buf = Vec::with_capacity(16 + n * 8 + self.data.len() * 4);
buf.extend_from_slice(MAGIC);
buf.extend_from_slice(&(self.dim as u32).to_le_bytes());
buf.extend_from_slice(&(n as u64).to_le_bytes());
for &id in &self.ids {
buf.extend_from_slice(&id.to_le_bytes());
}
for &f in &self.data {
buf.extend_from_slice(&f.to_le_bytes());
}
std::fs::write(path, &buf).with_context(|| format!("write vector index {}", path.display()))
}
pub fn load(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let buf =
std::fs::read(path).with_context(|| format!("read vector index {}", path.display()))?;
ensure!(buf.len() >= 16, "vector index too short");
ensure!(&buf[0..4] == MAGIC, "bad vector index magic");
let dim = u32::from_le_bytes(buf[4..8].try_into().unwrap()) as usize;
let n = u64::from_le_bytes(buf[8..16].try_into().unwrap()) as usize;
ensure!(dim != 0, "vector index has zero dim");
let want = 16 + n * 8 + n * dim * 4;
if buf.len() != want {
bail!(
"vector index length {} != expected {want} (dim {dim}, n {n})",
buf.len()
);
}
let mut off = 16;
let mut ids = Vec::with_capacity(n);
let mut pos = HashMap::with_capacity(n);
for row_idx in 0..n {
let id = u64::from_le_bytes(buf[off..off + 8].try_into().unwrap());
off += 8;
ensure!(pos.insert(id, row_idx).is_none(), "duplicate id {id} in file");
ids.push(id);
}
let mut data = Vec::with_capacity(n * dim);
for _ in 0..n * dim {
data.push(f32::from_le_bytes(buf[off..off + 4].try_into().unwrap()));
off += 4;
}
Ok(Self { dim, ids, data, pos })
}
}
/// Name of the dot-product kernel this CPU will use: `"avx512f"`, `"avx2+fma"`, or
/// `"scalar"`.
///
/// Uses the same runtime feature checks, in the same priority order, as
/// `select_dot_kernel`; non-x86_64 targets always report `"scalar"`.
pub fn active_simd() -> &'static str {
    #[cfg(target_arch = "x86_64")]
    {
        let has_avx512 = std::is_x86_feature_detected!("avx512f");
        let has_avx2_fma =
            std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma");
        match (has_avx512, has_avx2_fma) {
            (true, _) => return "avx512f",
            (false, true) => return "avx2+fma",
            (false, false) => {}
        }
    }
    "scalar"
}
/// Signature shared by every dot-product kernel.
///
/// The pointer is `unsafe` because the SIMD kernels require CPU support for their target
/// features and read `b` based on `a.len()`. Callers must only use kernels obtained from
/// `select_dot_kernel` and must pass slices with `b.len() >= a.len()`.
type DotFn = unsafe fn(&[f32], &[f32]) -> f32;
/// Picks the fastest dot-product kernel the running CPU supports.
///
/// Preference order is AVX-512F, then AVX2+FMA, then the portable scalar loop. A SIMD
/// kernel is only returned after runtime detection confirms its features, which is what
/// makes calling the returned `DotFn` sound.
fn select_dot_kernel() -> DotFn {
    #[cfg(target_arch = "x86_64")]
    {
        let wide = std::is_x86_feature_detected!("avx512f");
        if wide {
            return dot_avx512;
        }
        let narrow =
            std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma");
        if narrow {
            return dot_avx2;
        }
    }
    dot_scalar
}
/// Portable dot product of `a` and `b`; stops at the shorter slice.
///
/// # Safety
///
/// Has no safety requirements of its own; it is `unsafe` only so it can be stored as a
/// `DotFn` alongside the SIMD kernels.
unsafe fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b.iter()).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}
/// Dot product of `a` and `b` using AVX-512F fused multiply-add, 16 floats per step.
///
/// Products are accumulated per lane, reduced with `_mm512_reduce_add_ps`, and the last
/// `a.len() % 16` elements are added with scalar code. The summation order differs from
/// `dot_scalar`, so results may differ from it in the low bits.
///
/// # Safety
///
/// - The CPU must support `avx512f` (verify with `is_x86_feature_detected!`).
/// - `b.len()` must be at least `a.len()`: the vector loop reads `b` through raw pointers
///   bounded only by `a.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn dot_avx512(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::*;
let n = a.len();
let mut acc = _mm512_setzero_ps();
let mut i = 0;
// Whole 16-float blocks; `loadu` is an unaligned load, so no alignment is required.
while i + 16 <= n {
let va = _mm512_loadu_ps(a.as_ptr().add(i));
let vb = _mm512_loadu_ps(b.as_ptr().add(i));
acc = _mm512_fmadd_ps(va, vb, acc);
i += 16;
}
// Horizontal sum of the 16 accumulator lanes.
let mut s = _mm512_reduce_add_ps(acc);
// Scalar tail for the remaining n % 16 elements.
while i < n {
s += a[i] * b[i];
i += 1;
}
s
}
/// Dot product of `a` and `b` using AVX2 + FMA, 8 floats per step.
///
/// Products are accumulated per lane, the 8 lanes are spilled to an array and summed,
/// and the last `a.len() % 8` elements are added with scalar code. The summation order
/// differs from `dot_scalar`, so results may differ from it in the low bits.
///
/// # Safety
///
/// - The CPU must support both `avx2` and `fma` (verify with `is_x86_feature_detected!`).
/// - `b.len()` must be at least `a.len()`: the vector loop reads `b` through raw pointers
///   bounded only by `a.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn dot_avx2(a: &[f32], b: &[f32]) -> f32 {
use std::arch::x86_64::*;
let n = a.len();
let mut acc = _mm256_setzero_ps();
let mut i = 0;
// Whole 8-float blocks; `loadu` is an unaligned load, so no alignment is required.
while i + 8 <= n {
let va = _mm256_loadu_ps(a.as_ptr().add(i));
let vb = _mm256_loadu_ps(b.as_ptr().add(i));
acc = _mm256_fmadd_ps(va, vb, acc);
i += 8;
}
// Spill the 8 accumulator lanes and add them sequentially.
let mut tmp = [0f32; 8];
_mm256_storeu_ps(tmp.as_mut_ptr(), acc);
let mut s = tmp.iter().sum::<f32>();
// Scalar tail for the remaining n % 8 elements.
while i < n {
s += a[i] * b[i];
i += 1;
}
s
}
fn thread_count(n: usize) -> usize {
if n < 2 * MIN_ROWS_PER_THREAD {
return 1;
}
let hw = std::thread::available_parallelism()
.map(|x| x.get())
.unwrap_or(1);
hw.min(n / MIN_ROWS_PER_THREAD).max(1)
}
/// Returns `v` scaled to unit L2 norm.
///
/// If the norm is not positive (all zeros, or NaN), `v` is copied unchanged.
fn normalized(v: &[f32]) -> Vec<f32> {
    let sum_sq: f32 = v.iter().map(|x| x * x).sum();
    let len = sum_sq.sqrt();
    match len > 0.0 {
        true => v.iter().map(|&x| x / len).collect(),
        false => v.to_vec(),
    }
}
/// Appends `row` to `data`, scaled to unit L2 norm.
///
/// If the norm is not positive (all zeros, or NaN), the row is appended unchanged.
fn push_normalized(data: &mut Vec<f32>, row: &[f32]) {
    let sum_sq: f32 = row.iter().map(|x| x * x).sum();
    let len = sum_sq.sqrt();
    if len > 0.0 {
        data.extend(row.iter().map(|&x| x / len));
        return;
    }
    data.extend_from_slice(row);
}
/// Keeps the `m` highest-scoring entries of `v` and sorts them by descending score.
///
/// Scores are ordered with `f32::total_cmp`; ties are broken arbitrarily (unstable
/// selection and sort). `m == 0` empties `v` instead of underflowing `m - 1`.
fn top_k(v: &mut Vec<(u64, f32)>, m: usize) {
    if m == 0 {
        v.clear();
        return;
    }
    if v.len() > m {
        // Partition so the best `m` occupy the front, then drop the rest before sorting.
        v.select_nth_unstable_by(m - 1, |a, b| b.1.total_cmp(&a.1));
        v.truncate(m);
    }
    v.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
}
/// Unit tests for `VectorIndex` and its helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `dim`-length one-hot vector with `1.0` at `axis`.
    fn unit(dim: usize, axis: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[axis] = 1.0;
        v
    }

    #[test]
    fn rejects_zero_dim() {
        match VectorIndex::new(0) {
            Ok(_) => panic!("dim 0 should be rejected"),
            Err(e) => assert!(e.to_string().contains("non-zero"), "{e}"),
        }
    }

    #[test]
    fn add_and_search_nearest() {
        let mut idx = VectorIndex::new(8).unwrap();
        idx.add(&unit(8, 0), &[10]).unwrap();
        idx.add(&unit(8, 1), &[20]).unwrap();
        idx.add(&unit(8, 2), &[30]).unwrap();
        assert_eq!(idx.len(), 3);
        assert!(!idx.is_empty());
        let mut q = unit(8, 0);
        // Tilt the query slightly toward axis 1 so id 20 is the clear runner-up.
        q[1] = 0.1;
        let hits = idx.search(&q, 2);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].0, 10, "nearest is the axis-0 vector");
        assert_eq!(hits[1].0, 20, "runner-up is the axis-1 vector");
        assert!(hits[0].1 > hits[1].1, "scores sorted descending");
    }

    #[test]
    fn add_rejects_wrong_buffer_len() {
        let mut idx = VectorIndex::new(8).unwrap();
        let err = idx.add(&[1.0, 2.0, 3.0, 4.0], &[1]).unwrap_err();
        assert!(err.to_string().contains("!= ids len"), "{err}");
    }

    #[test]
    fn add_rejects_duplicate_id() {
        let mut idx = VectorIndex::new(8).unwrap();
        idx.add(&unit(8, 0), &[7]).unwrap();
        // Duplicate against an id already in the index.
        let err = idx.add(&unit(8, 1), &[7]).unwrap_err();
        assert!(err.to_string().contains("duplicate id 7"), "{err}");
        // Duplicate within a single batch.
        let mut two = unit(8, 0);
        two.extend(unit(8, 1));
        let err = idx.add(&two, &[9, 9]).unwrap_err();
        assert!(err.to_string().contains("duplicate id 9"), "{err}");
    }

    #[test]
    fn remove_and_contains() {
        let mut idx = VectorIndex::new(8).unwrap();
        idx.add(&unit(8, 0), &[10]).unwrap();
        idx.add(&unit(8, 1), &[20]).unwrap();
        idx.add(&unit(8, 2), &[30]).unwrap();
        assert!(idx.contains(20));
        assert!(idx.remove(20));
        assert!(!idx.contains(20));
        assert!(!idx.remove(20), "second remove is a no-op");
        assert_eq!(idx.len(), 2);
        // Id 30 was the last row and got moved into the freed slot; it must still resolve.
        let hits = idx.search(&unit(8, 2), 1);
        assert_eq!(hits[0].0, 30);
    }

    #[test]
    fn write_then_load_roundtrips() {
        let mut idx = VectorIndex::new(8).unwrap();
        idx.add(&unit(8, 0), &[10]).unwrap();
        idx.add(&unit(8, 1), &[20]).unwrap();
        idx.add(&unit(8, 2), &[30]).unwrap();
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("basis.nvf");
        idx.write(&path).unwrap();
        let loaded = VectorIndex::load(&path).unwrap();
        assert_eq!(loaded.len(), 3);
        assert_eq!(loaded.dim(), 8);
        let hits = loaded.search(&unit(8, 2), 1);
        assert_eq!(hits[0].0, 30, "nearest to axis-2 query is id 30");
    }

    #[test]
    fn load_rejects_corrupt_header() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("bad.nvf");
        // Exactly header-sized, but with the wrong magic.
        std::fs::write(&path, b"NOPExxxxxxxxxxxx").unwrap();
        assert!(VectorIndex::load(&path).is_err());
    }

    #[test]
    fn high_dim_search_matches_reference() {
        // 768 = 48 * 16, so the SIMD kernels run with no scalar tail.
        let dim = 768;
        let mut idx = VectorIndex::new(dim).unwrap();
        let mk = |seed: f32| -> Vec<f32> { (0..dim).map(|i| (i as f32 * seed).sin()).collect() };
        let a = mk(0.013);
        let b = mk(0.027);
        let c = mk(0.041);
        idx.add(&a, &[1]).unwrap();
        idx.add(&b, &[2]).unwrap();
        idx.add(&c, &[3]).unwrap();
        let hits = idx.search(&b, 1);
        assert_eq!(hits[0].0, 2);
        assert!((hits[0].1 - 1.0).abs() < 1e-3, "score {}", hits[0].1);
    }

    #[test]
    fn parallel_path_finds_exact_match() {
        let dim = 32;
        // Four threads' worth of rows, enough for `thread_count` to split the scan.
        let n = 4 * MIN_ROWS_PER_THREAD;
        let mut idx = VectorIndex::new(dim).unwrap();
        let target_id = 1234u64;
        let mut flat = Vec::with_capacity(n * dim);
        let mut ids = Vec::with_capacity(n);
        for j in 0..n as u64 {
            let axis = if j == target_id { 0 } else { 1 };
            flat.extend(unit(dim, axis));
            ids.push(j);
        }
        idx.add(&flat, &ids).unwrap();
        // `thread_count` is capped by available_parallelism(), so a single-core host can never
        // take the parallel path; only assert it when more than one core is available.
        let cores = std::thread::available_parallelism().map_or(1, |p| p.get());
        if cores > 1 {
            assert!(thread_count(idx.len()) > 1, "test should hit the parallel path");
        }
        let hits = idx.search(&unit(dim, 0), 1);
        assert_eq!(hits[0].0, target_id, "the lone axis-0 vector wins");
    }

    #[test]
    fn active_simd_is_known() {
        let s = active_simd();
        assert!(
            matches!(s, "avx512f" | "avx2+fma" | "scalar"),
            "unexpected kernel {s}"
        );
    }
}