pub mod chunk;
pub mod store;
pub mod embed_registry;

// Compiled whenever at least one embedding backend is enabled.
#[cfg(any(feature = "embed-tract", feature = "embed-ort"))]
pub mod embed_support;

// tract-onnx backend.
#[cfg(feature = "embed-tract")]
pub mod embed;

// ONNX Runtime backend and the modules compiled alongside it.
#[cfg(feature = "embed-ort")]
pub mod embed_ort;
#[cfg(feature = "embed-ort")]
pub mod cuda;

#[cfg(feature = "embed-ort-rocm")]
pub mod rocm;
/// Loads the embedding model for whichever backend was compiled in.
///
/// `embed-ort` takes precedence; the tract backend is used only when
/// `embed-ort` is disabled.
///
/// # Errors
/// Propagates any error returned by the backend's `load`.
#[cfg(any(feature = "embed-tract", feature = "embed-ort"))]
pub fn load_embedder() -> anyhow::Result<Box<dyn store::Embedder>> {
    #[cfg(feature = "embed-ort")]
    let embedder: Box<dyn store::Embedder> = Box::new(embed_ort::OrtEmbedder::load()?);
    #[cfg(all(feature = "embed-tract", not(feature = "embed-ort")))]
    let embedder: Box<dyn store::Embedder> = Box::new(embed::JinaEmbedder::load()?);
    Ok(embedder)
}
/// Human-readable name of the compiled-in embedding backend.
#[cfg(any(feature = "embed-tract", feature = "embed-ort"))]
pub fn embedder_backend() -> &'static str {
    // `embed-ort` wins when both backends are enabled, same precedence as
    // `load_embedder`; otherwise the outer cfg guarantees `embed-tract`.
    if cfg!(feature = "embed-ort") {
        "ort (ONNX Runtime, CUDA→CPU)"
    } else {
        "tract-onnx (CPU, pure Rust)"
    }
}
/// Id of the model chosen by `embed_registry::selected`, or `"<invalid>"` if
/// the selection is an error.
pub fn selected_model_id() -> &'static str {
    match embed_registry::selected() {
        Ok(model) => model.id,
        Err(_) => "<invalid>",
    }
}
/// Describes the selected model as `"<model_name> (<dim>-dim)"`; when the
/// selection fails, the registry's error message is returned instead.
pub fn selected_model_desc() -> String {
    embed_registry::selected()
        .map(|model| format!("{} ({}-dim)", model.model_name, model.dim))
        .unwrap_or_else(|err| err)
}
use std::collections::HashMap;
use std::path::Path;
use anyhow::{bail, ensure, Context, Result};
/// Minimum number of rows each scoring thread should receive; `thread_count`
/// stays single-threaded below `2 * MIN_ROWS_PER_THREAD` rows.
const MIN_ROWS_PER_THREAD: usize = 1024;
/// Magic bytes at the start of a serialized `VectorIndex` file.
const MAGIC: &[u8; 4] = b"NVF1";
pub struct VectorIndex {
dim: usize,
ids: Vec<u64>,
data: Vec<f32>,
pos: HashMap<u64, usize>,
}
impl VectorIndex {
pub fn new(dim: usize) -> Result<Self> {
ensure!(dim != 0, "vector dim must be non-zero");
Ok(Self {
dim,
ids: Vec::new(),
data: Vec::new(),
pos: HashMap::new(),
})
}
pub fn dim(&self) -> usize {
self.dim
}
pub fn len(&self) -> usize {
self.ids.len()
}
pub fn is_empty(&self) -> bool {
self.ids.is_empty()
}
pub fn contains(&self, id: u64) -> bool {
self.pos.contains_key(&id)
}
pub fn add(&mut self, vectors: &[f32], ids: &[u64]) -> Result<()> {
ensure!(
vectors.len() == ids.len() * self.dim,
"vectors len {} != ids len {} * dim {}",
vectors.len(),
ids.len(),
self.dim
);
let mut seen = std::collections::HashSet::with_capacity(ids.len());
for &id in ids {
ensure!(
!self.pos.contains_key(&id) && seen.insert(id),
"duplicate id {id}"
);
}
self.ids.reserve(ids.len());
self.data.reserve(vectors.len());
self.pos.reserve(ids.len());
for (i, &id) in ids.iter().enumerate() {
let row = &vectors[i * self.dim..(i + 1) * self.dim];
let row_idx = self.ids.len();
push_normalized(&mut self.data, row);
self.ids.push(id);
self.pos.insert(id, row_idx);
}
Ok(())
}
pub fn remove(&mut self, id: u64) -> bool {
let Some(idx) = self.pos.remove(&id) else {
return false;
};
let last = self.ids.len() - 1;
let dim = self.dim;
if idx != last {
self.data
.copy_within(last * dim..(last + 1) * dim, idx * dim);
let moved_id = self.ids[last];
self.ids[idx] = moved_id;
self.pos.insert(moved_id, idx);
}
self.ids.pop();
self.data.truncate(last * dim);
true
}
pub fn search(&self, query: &[f32], k: usize) -> Vec<(u64, f32)> {
assert_eq!(
query.len(),
self.dim,
"query dim {} != index dim {}",
query.len(),
self.dim
);
let n = self.ids.len();
let m = k.min(n);
if m == 0 {
return Vec::new();
}
let qn = normalized(query);
let kernel = select_dot_kernel();
let threads = thread_count(n);
let mut merged = if threads <= 1 {
self.score_range(0, n, &qn, kernel, m)
} else {
let chunk = n.div_ceil(threads);
std::thread::scope(|s| {
let mut handles = Vec::with_capacity(threads);
let mut start = 0;
while start < n {
let end = (start + chunk).min(n);
let qn = &qn;
handles.push(s.spawn(move || self.score_range(start, end, qn, kernel, m)));
start = end;
}
let mut out = Vec::with_capacity(handles.len() * m);
for h in handles {
out.extend(h.join().expect("scoring thread panicked"));
}
out
})
};
top_k(&mut merged, m);
merged
}
pub fn search_i8(&self, query: &[f32], k: usize) -> Vec<(u64, f32)> {
assert_eq!(query.len(), self.dim, "query dim {} != index dim {}", query.len(), self.dim);
let n = self.ids.len();
let m = k.min(n);
if m == 0 {
return Vec::new();
}
let (rows, sums) = self.quantized();
let scores = score_i8_batch(query, &rows, self.dim, &sums);
let mut scored: Vec<(u64, f32)> = self.ids.iter().copied().zip(scores).collect();
top_k(&mut scored, m);
scored
}
pub fn quantized(&self) -> (Vec<i8>, Vec<i32>) {
let n = self.ids.len();
let mut rows = Vec::with_capacity(n * self.dim);
let mut sums = Vec::with_capacity(n);
for idx in 0..n {
let row = &self.data[idx * self.dim..(idx + 1) * self.dim];
sums.push(quantize_i8(row, &mut rows));
}
(rows, sums)
}
fn score_range(&self, start: usize, end: usize, qn: &[f32], kernel: DotFn, m: usize) -> Vec<(u64, f32)> {
let mut local: Vec<(u64, f32)> = Vec::with_capacity(end - start);
for idx in start..end {
let row = &self.data[idx * self.dim..(idx + 1) * self.dim];
let score = unsafe { kernel(qn, row) };
local.push((self.ids[idx], score));
}
top_k(&mut local, m);
local
}
pub fn write(&self, path: impl AsRef<Path>) -> Result<()> {
let path = path.as_ref();
let n = self.ids.len();
let mut buf = Vec::with_capacity(16 + n * 8 + self.data.len() * 4);
buf.extend_from_slice(MAGIC);
buf.extend_from_slice(&(self.dim as u32).to_le_bytes());
buf.extend_from_slice(&(n as u64).to_le_bytes());
for &id in &self.ids {
buf.extend_from_slice(&id.to_le_bytes());
}
for &f in &self.data {
buf.extend_from_slice(&f.to_le_bytes());
}
std::fs::write(path, &buf).with_context(|| format!("write vector index {}", path.display()))
}
pub fn load(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let buf =
std::fs::read(path).with_context(|| format!("read vector index {}", path.display()))?;
ensure!(buf.len() >= 16, "vector index too short");
ensure!(&buf[0..4] == MAGIC, "bad vector index magic");
let dim = u32::from_le_bytes(buf[4..8].try_into().unwrap()) as usize;
let n = u64::from_le_bytes(buf[8..16].try_into().unwrap()) as usize;
ensure!(dim != 0, "vector index has zero dim");
let want = 16 + n * 8 + n * dim * 4;
if buf.len() != want {
bail!(
"vector index length {} != expected {want} (dim {dim}, n {n})",
buf.len()
);
}
let mut off = 16;
let mut ids = Vec::with_capacity(n);
let mut pos = HashMap::with_capacity(n);
for row_idx in 0..n {
let id = u64::from_le_bytes(buf[off..off + 8].try_into().unwrap());
off += 8;
ensure!(pos.insert(id, row_idx).is_none(), "duplicate id {id} in file");
ids.push(id);
}
let mut data = Vec::with_capacity(n * dim);
for _ in 0..n * dim {
data.push(f32::from_le_bytes(buf[off..off + 4].try_into().unwrap()));
off += 4;
}
Ok(Self { dim, ids, data, pos })
}
}
/// Name of the f32 dot-product kernel `select_dot_kernel` will pick on this
/// CPU: `"avx512f"`, `"avx2+fma"` or `"scalar"`.
pub fn active_simd() -> &'static str {
    #[cfg(target_arch = "x86_64")]
    let detected = if std::is_x86_feature_detected!("avx512f") {
        Some("avx512f")
    } else if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
        Some("avx2+fma")
    } else {
        None
    };
    #[cfg(not(target_arch = "x86_64"))]
    let detected: Option<&'static str> = None;
    detected.unwrap_or("scalar")
}
/// Signature shared by the f32 dot-product kernels. It is `unsafe` because the
/// SIMD variants may only run on CPUs with the matching features.
type DotFn = unsafe fn(&[f32], &[f32]) -> f32;

/// Picks the fastest f32 dot-product kernel supported by the running CPU,
/// preferring AVX-512F, then AVX2+FMA, then the portable scalar loop.
fn select_dot_kernel() -> DotFn {
    #[cfg(target_arch = "x86_64")]
    let simd: Option<DotFn> = if std::is_x86_feature_detected!("avx512f") {
        Some(dot_avx512 as DotFn)
    } else if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
        Some(dot_avx2 as DotFn)
    } else {
        None
    };
    #[cfg(not(target_arch = "x86_64"))]
    let simd: Option<DotFn> = None;
    simd.unwrap_or(dot_scalar as DotFn)
}
/// Portable f32 dot product over the common prefix of `a` and `b`.
///
/// # Safety
/// Has no preconditions; it is `unsafe` only so it fits the `DotFn` signature
/// shared with the SIMD kernels.
unsafe fn dot_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b.iter()).map(|(&x, &y)| x * y);
    products.sum::<f32>()
}
/// f32 dot product using 16-lane AVX-512 FMA, with a scalar tail.
///
/// Only the first `min(a.len(), b.len())` elements are used. This matches
/// `dot_scalar`'s `zip` semantics and keeps every vector load in bounds.
///
/// # Safety
/// The CPU must support `avx512f` (e.g. checked with `is_x86_feature_detected!`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
unsafe fn dot_avx512(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    // Clamp to the shorter input: the loads below index `b` with offsets
    // derived from `n`, so an unchecked `a.len()` could read past `b`.
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);
    let mut acc = _mm512_setzero_ps();
    let mut i = 0;
    while i + 16 <= n {
        let va = _mm512_loadu_ps(a.as_ptr().add(i));
        let vb = _mm512_loadu_ps(b.as_ptr().add(i));
        acc = _mm512_fmadd_ps(va, vb, acc);
        i += 16;
    }
    let mut s = _mm512_reduce_add_ps(acc);
    // Scalar tail for the final `n % 16` elements.
    for (x, y) in a[i..].iter().zip(&b[i..]) {
        s += x * y;
    }
    s
}
/// f32 dot product using 8-lane AVX2 FMA, with a scalar tail.
///
/// Only the first `min(a.len(), b.len())` elements are used. This matches
/// `dot_scalar`'s `zip` semantics and keeps every vector load in bounds.
///
/// # Safety
/// The CPU must support `avx2` and `fma` (e.g. checked with `is_x86_feature_detected!`).
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn dot_avx2(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    // Clamp to the shorter input so the unaligned loads never run past `b`.
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);
    let mut acc = _mm256_setzero_ps();
    let mut i = 0;
    while i + 8 <= n {
        let va = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i));
        acc = _mm256_fmadd_ps(va, vb, acc);
        i += 8;
    }
    // Horizontal sum of the 8 lanes via a stack spill.
    let mut lanes = [0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut s = lanes.iter().sum::<f32>();
    // Scalar tail for the final `n % 8` elements.
    for (x, y) in a[i..].iter().zip(&b[i..]) {
        s += x * y;
    }
    s
}
/// Scale that maps a component in [-1, 1] onto the int8 range [-127, 127].
const Q: f32 = 127.0;

/// Quantizes `v` to int8 codes (`round(x * Q)` clamped to ±127), appends them
/// to `out`, and returns the sum of the appended codes.
///
/// The returned sum is what the VNNI path uses to undo its query bias. NaN
/// components quantize to 0 (saturating `as` cast).
fn quantize_i8(v: &[f32], out: &mut Vec<i8>) -> i32 {
    v.iter().fold(0i32, |total, &x| {
        let code = (x * Q).round().clamp(-127.0, 127.0) as i32;
        out.push(code as i8);
        total + code
    })
}

/// Portable int8 dot product rescaled by `1 / Q²` back to the float domain.
/// Extra elements of the longer slice are ignored.
fn dot_i8_scalar(q: &[i8], r: &[i8]) -> f32 {
    let raw: i32 = q
        .iter()
        .zip(r)
        .map(|(&a, &b)| i32::from(a) * i32::from(b))
        .sum();
    raw as f32 / (Q * Q)
}
/// int8 dot product via AVX-512 VNNI, rescaled by `1 / Q²` to a cosine estimate.
///
/// `_mm512_dpbusd_epi32` multiplies unsigned bytes by signed bytes, so the query
/// arrives pre-biased (`q_biased[i] = q[i] + 128`). The bias is removed with
/// `Σ (q + 128)·r = Σ q·r + 128·Σ r`, which is why the caller passes
/// `row_sum = Σ row[i]`.
///
/// Only the first `min(q_biased.len(), row.len())` elements are read, so the
/// loads stay in bounds. Callers must pass equal lengths for the bias
/// correction to be meaningful; this is checked in debug builds.
///
/// # Safety
/// The CPU must support `avx512f`, `avx512vnni` and `avx512bw`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f,avx512vnni,avx512bw")]
unsafe fn dot_i8_vnni(q_biased: &[u8], row: &[i8], row_sum: i32) -> f32 {
    use std::arch::x86_64::*;
    debug_assert_eq!(q_biased.len(), row.len(), "query/row length mismatch");
    // Clamp so a short query can never be read out of bounds.
    let n = row.len().min(q_biased.len());
    let mut acc = _mm512_setzero_si512();
    let mut i = 0;
    while i + 64 <= n {
        let vu = _mm512_loadu_si512(q_biased.as_ptr().add(i) as *const _);
        let vi = _mm512_loadu_si512(row.as_ptr().add(i) as *const _);
        acc = _mm512_dpbusd_epi32(acc, vu, vi);
        i += 64;
    }
    let mut biased = _mm512_reduce_add_epi32(acc);
    // Scalar tail for the final `n % 64` elements.
    while i < n {
        biased += (q_biased[i] as i32) * (row[i] as i32);
        i += 1;
    }
    ((biased - 128 * row_sum) as f32) / (Q * Q)
}
/// Returns `true` when the AVX-512 VNNI int8 kernel can run on this CPU
/// (requires `avx512f`, `avx512vnni` and `avx512bw`); always `false` off x86-64.
pub fn vnni_available() -> bool {
    #[cfg(target_arch = "x86_64")]
    {
        std::is_x86_feature_detected!("avx512f")
            && std::is_x86_feature_detected!("avx512vnni")
            && std::is_x86_feature_detected!("avx512bw")
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
}
/// Scores every int8-quantized row against `query` and returns one estimated
/// cosine similarity per row, in row order.
///
/// `rows` is a row-major buffer of `row_sums.len()` rows of `dim` codes, and
/// `row_sums[i]` is the sum of row `i`'s codes; both are as returned by
/// `VectorIndex::quantized`. The query is L2-normalized and quantized here.
/// The AVX-512 VNNI kernel is used when available, otherwise a scalar loop.
///
/// # Panics
/// Panics if `query.len() != dim`, or if `rows` holds fewer than
/// `row_sums.len() * dim` codes.
pub fn score_i8_batch(query: &[f32], rows: &[i8], dim: usize, row_sums: &[i32]) -> Vec<f32> {
    // Hard check, not debug-only: the SIMD kernel is handed `dim`-byte rows and
    // must never be paired with a shorter query. This is a safe public fn, so a
    // bad argument must panic rather than risk an out-of-bounds read.
    assert_eq!(
        query.len(),
        dim,
        "query dim {} != index dim {}",
        query.len(),
        dim
    );
    let n = row_sums.len();
    debug_assert_eq!(rows.len(), n * dim);
    let qn = normalized(query);
    let mut q_i8 = Vec::with_capacity(dim);
    quantize_i8(&qn, &mut q_i8);
    if vnni_available() {
        // VNNI multiplies u8 x i8, so shift the query into 1..=255.
        let q_biased: Vec<u8> = q_i8.iter().map(|&x| (x as i16 + 128) as u8).collect();
        (0..n)
            .map(|i| {
                let row = &rows[i * dim..(i + 1) * dim];
                // SAFETY: vnni_available() confirmed the required CPU features,
                // and `q_biased` and `row` both have length `dim`.
                unsafe { dot_i8_vnni(&q_biased, row, row_sums[i]) }
            })
            .collect()
    } else {
        (0..n)
            .map(|i| {
                let row = &rows[i * dim..(i + 1) * dim];
                dot_i8_scalar(&q_i8, row)
            })
            .collect()
    }
}
/// Timing and accuracy of one kernel, as measured by `bench_kernels`.
#[derive(Debug, Clone)]
pub struct KernelTiming {
    /// Kernel label, e.g. `"scalar"`, `"simd (avx2+fma)"` or `"int8 (vnni)"`.
    pub name: String,
    /// Mean wall time of one pass over all rows, in microseconds.
    pub micros: u128,
    /// Rows scored per microsecond (= millions of dot products per second).
    pub mdps: f64,
    /// Largest absolute deviation from the scalar f32 reference scores.
    pub max_err: f32,
}

/// Output of `bench_kernels`: the workload shape plus one timing per kernel.
#[derive(Debug, Clone)]
pub struct BenchReport {
    /// Rows scored per pass.
    pub n: usize,
    /// Components per row.
    pub dim: usize,
    /// Name of the runtime-selected f32 SIMD kernel.
    pub simd_kernel: &'static str,
    /// Per-kernel measurements.
    pub timings: Vec<KernelTiming>,
}

impl BenchReport {
    /// Scalar time divided by the SIMD kernel's time; 1.0 if not measurable.
    pub fn simd_speedup(&self) -> f64 {
        self.speedup_over_scalar("simd")
    }

    /// Scalar time divided by the int8 kernel's time; 1.0 if not measurable.
    pub fn int8_speedup(&self) -> f64 {
        self.speedup_over_scalar("int8")
    }

    /// Ratio of the `"scalar"` timing to the first timing whose name starts
    /// with `prefix`. Returns 1.0 when either timing is missing or the
    /// candidate reported 0 µs (avoids a division by zero).
    fn speedup_over_scalar(&self, prefix: &str) -> f64 {
        let baseline = self.timings.iter().find(|t| t.name == "scalar");
        let candidate = self.timings.iter().find(|t| t.name.starts_with(prefix));
        let (Some(base), Some(cand)) = (baseline, candidate) else {
            return 1.0;
        };
        if cand.micros == 0 {
            return 1.0;
        }
        base.micros as f64 / cand.micros as f64
    }
}
/// Micro-benchmarks the scalar f32, runtime-selected SIMD f32 and int8 scoring
/// kernels over `n` synthetic unit-length rows of `dim` components.
///
/// Each kernel is timed over `iters` passes (at least one). The report gives
/// mean microseconds per pass and `mdps` (rows per microsecond). `max_err` is
/// the largest absolute difference from the scalar f32 reference. For the SIMD
/// kernel it is measured in a separate, untimed pass, so the bookkeeping does
/// not skew its timing relative to the scalar baseline.
pub fn bench_kernels(n: usize, dim: usize, iters: usize) -> BenchReport {
    use std::hint::black_box;

    /// Runs `pass` `iters` times and returns the mean wall time in microseconds.
    fn mean_micros(iters: usize, mut pass: impl FnMut()) -> u128 {
        let start = std::time::Instant::now();
        for _ in 0..iters {
            pass();
        }
        start.elapsed().as_micros() / iters as u128
    }

    let iters = iters.max(1);
    let mk = |seed: f32| -> Vec<f32> {
        let v: Vec<f32> = (0..dim).map(|i| ((i as f32 + 1.0) * seed).sin()).collect();
        normalized(&v)
    };
    let mut data = Vec::with_capacity(n * dim);
    for r in 0..n {
        data.extend(mk(0.001 + r as f32 * 0.0003));
    }
    let query = mk(0.737);
    let qn = normalized(&query);
    let row = |r: usize| &data[r * dim..(r + 1) * dim];
    // SAFETY: dot_scalar has no preconditions.
    let reference: Vec<f32> = (0..n).map(|r| unsafe { dot_scalar(&qn, row(r)) }).collect();
    let dps = |micros: u128| -> f64 {
        if micros == 0 {
            0.0
        } else {
            (n as f64) / (micros as f64)
        }
    };
    let mut timings = Vec::with_capacity(3);

    // Scalar f32 baseline.
    let micros = mean_micros(iters, || {
        // black_box the query every pass so no work can be hoisted across passes.
        let q = black_box(qn.as_slice());
        for r in 0..n {
            // SAFETY: dot_scalar has no preconditions.
            black_box(unsafe { dot_scalar(q, row(r)) });
        }
    });
    timings.push(KernelTiming { name: "scalar".into(), micros, mdps: dps(micros), max_err: 0.0 });

    // Best f32 SIMD kernel for this CPU: timed passes first, error check after.
    let kernel = select_dot_kernel();
    let simd_kernel = active_simd();
    let micros = mean_micros(iters, || {
        let q = black_box(qn.as_slice());
        for r in 0..n {
            // SAFETY: select_dot_kernel only returns kernels whose CPU features
            // were detected at runtime; the query and every row are `dim` long.
            black_box(unsafe { kernel(q, row(r)) });
        }
    });
    let simd_err = (0..n)
        .map(|r| {
            // SAFETY: same kernel and same `dim`-length inputs as the timed loop.
            let s = unsafe { kernel(&qn, row(r)) };
            (s - reference[r]).abs()
        })
        .fold(0f32, f32::max);
    timings.push(KernelTiming {
        name: format!("simd ({simd_kernel})"),
        micros,
        mdps: dps(micros),
        max_err: simd_err,
    });

    // int8: rows are quantized once up front. Each timed pass quantizes the
    // query and scores every row, as a real `score_i8_batch` call would.
    let mut rows = Vec::with_capacity(n * dim);
    let sums: Vec<i32> = (0..n).map(|r| quantize_i8(row(r), &mut rows)).collect();
    let i8_kernel = if vnni_available() { "vnni" } else { "scalar" };
    let mut i8_scores = Vec::new();
    let micros = mean_micros(iters, || {
        i8_scores = score_i8_batch(black_box(query.as_slice()), &rows, dim, &sums);
        black_box(&i8_scores);
    });
    let i8_err = i8_scores
        .iter()
        .zip(&reference)
        .map(|(a, b)| (a - b).abs())
        .fold(0f32, f32::max);
    timings.push(KernelTiming {
        name: format!("int8 ({i8_kernel})"),
        micros,
        mdps: dps(micros),
        max_err: i8_err,
    });
    BenchReport { n, dim, simd_kernel, timings }
}
fn thread_count(n: usize) -> usize {
if n < 2 * MIN_ROWS_PER_THREAD {
return 1;
}
let hw = std::thread::available_parallelism()
.map(|x| x.get())
.unwrap_or(1);
hw.min(n / MIN_ROWS_PER_THREAD).max(1)
}
/// Returns an L2-normalized copy of `v`.
///
/// If the norm is not strictly positive (all zeros, empty, or NaN), `v` is
/// copied through unchanged.
fn normalized(v: &[f32]) -> Vec<f32> {
    let sum_sq: f32 = v.iter().map(|&x| x * x).sum();
    let norm = sum_sq.sqrt();
    if norm > 0.0 {
        return v.iter().map(|&x| x / norm).collect();
    }
    // Nothing sensible to divide by.
    v.to_vec()
}
/// Appends the L2-normalized form of `row` to `data` without an intermediate
/// allocation. Rows whose norm is not strictly positive are appended as-is.
fn push_normalized(data: &mut Vec<f32>, row: &[f32]) {
    let norm = row.iter().fold(0.0f32, |acc, &x| acc + x * x).sqrt();
    if norm > 0.0 {
        for &x in row {
            data.push(x / norm);
        }
        return;
    }
    data.extend_from_slice(row);
}
/// Keeps the `m` highest-scoring entries of `v`, sorted by descending score.
///
/// `select_nth_unstable_by` partitions in O(len) so that only the survivors are
/// sorted. Scores are compared with `f32::total_cmp`, so the order is total even
/// with NaN present. Ties are broken arbitrarily. `m == 0` empties `v`; without
/// this guard, `m - 1` would underflow.
fn top_k(v: &mut Vec<(u64, f32)>, m: usize) {
    if m == 0 {
        v.clear();
        return;
    }
    if v.len() > m {
        v.select_nth_unstable_by(m - 1, |a, b| b.1.total_cmp(&a.1));
        v.truncate(m);
    }
    v.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One-hot vector of length `dim` with 1.0 at `axis`.
    fn unit(dim: usize, axis: usize) -> Vec<f32> {
        let mut v = vec![0.0f32; dim];
        v[axis] = 1.0;
        v
    }

    #[test]
    fn rejects_zero_dim() {
        match VectorIndex::new(0) {
            Ok(_) => panic!("dim 0 should be rejected"),
            Err(e) => assert!(e.to_string().contains("non-zero"), "{e}"),
        }
    }

    #[test]
    fn add_and_search_nearest() {
        let mut idx = VectorIndex::new(8).unwrap();
        idx.add(&unit(8, 0), &[10]).unwrap();
        idx.add(&unit(8, 1), &[20]).unwrap();
        idx.add(&unit(8, 2), &[30]).unwrap();
        assert_eq!(idx.len(), 3);
        assert!(!idx.is_empty());
        // Mostly axis 0 with a small axis-1 component: expect 10, then 20.
        let mut q = unit(8, 0);
        q[1] = 0.1;
        let hits = idx.search(&q, 2);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].0, 10, "nearest is the axis-0 vector");
        assert_eq!(hits[1].0, 20, "runner-up is the axis-1 vector");
        assert!(hits[0].1 > hits[1].1, "scores sorted descending");
    }

    #[test]
    fn add_rejects_wrong_buffer_len() {
        let mut idx = VectorIndex::new(8).unwrap();
        let err = idx.add(&[1.0, 2.0, 3.0, 4.0], &[1]).unwrap_err();
        assert!(err.to_string().contains("!= ids len"), "{err}");
    }

    #[test]
    fn add_rejects_duplicate_id() {
        let mut idx = VectorIndex::new(8).unwrap();
        idx.add(&unit(8, 0), &[7]).unwrap();
        // Id already stored in the index.
        let err = idx.add(&unit(8, 1), &[7]).unwrap_err();
        assert!(err.to_string().contains("duplicate id 7"), "{err}");
        // Id repeated within a single batch.
        let mut two = unit(8, 0);
        two.extend(unit(8, 1));
        let err = idx.add(&two, &[9, 9]).unwrap_err();
        assert!(err.to_string().contains("duplicate id 9"), "{err}");
    }

    #[test]
    fn remove_and_contains() {
        let mut idx = VectorIndex::new(8).unwrap();
        idx.add(&unit(8, 0), &[10]).unwrap();
        idx.add(&unit(8, 1), &[20]).unwrap();
        idx.add(&unit(8, 2), &[30]).unwrap();
        assert!(idx.contains(20));
        assert!(idx.remove(20));
        assert!(!idx.contains(20));
        assert!(!idx.remove(20), "second remove is a no-op");
        assert_eq!(idx.len(), 2);
        // Id 30 was swapped into the removed slot and must still be found.
        let hits = idx.search(&unit(8, 2), 1);
        assert_eq!(hits[0].0, 30);
    }

    #[test]
    fn write_then_load_roundtrips() {
        let mut idx = VectorIndex::new(8).unwrap();
        idx.add(&unit(8, 0), &[10]).unwrap();
        idx.add(&unit(8, 1), &[20]).unwrap();
        idx.add(&unit(8, 2), &[30]).unwrap();
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("basis.nvf");
        idx.write(&path).unwrap();
        let loaded = VectorIndex::load(&path).unwrap();
        assert_eq!(loaded.len(), 3);
        assert_eq!(loaded.dim(), 8);
        let hits = loaded.search(&unit(8, 2), 1);
        assert_eq!(hits[0].0, 30, "nearest to axis-2 query is id 30");
    }

    #[test]
    fn load_rejects_corrupt_header() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("bad.nvf");
        std::fs::write(&path, b"NOPExxxxxxxxxxxx").unwrap();
        assert!(VectorIndex::load(&path).is_err());
    }

    #[test]
    fn high_dim_search_matches_reference() {
        let dim = 768;
        let mut idx = VectorIndex::new(dim).unwrap();
        let mk = |seed: f32| -> Vec<f32> { (0..dim).map(|i| (i as f32 * seed).sin()).collect() };
        let a = mk(0.013);
        let b = mk(0.027);
        let c = mk(0.041);
        idx.add(&a, &[1]).unwrap();
        idx.add(&b, &[2]).unwrap();
        idx.add(&c, &[3]).unwrap();
        // A stored vector queried with itself has cosine ~1.
        let hits = idx.search(&b, 1);
        assert_eq!(hits[0].0, 2);
        assert!((hits[0].1 - 1.0).abs() < 1e-3, "score {}", hits[0].1);
    }

    #[test]
    fn parallel_path_finds_exact_match() {
        let dim = 32;
        // Enough rows for thread_count to choose more than one thread.
        let n = 4 * MIN_ROWS_PER_THREAD;
        let mut idx = VectorIndex::new(dim).unwrap();
        let target_id = 1234u64;
        let mut flat = Vec::with_capacity(n * dim);
        let mut ids = Vec::with_capacity(n);
        for j in 0..n as u64 {
            let axis = if j == target_id { 0 } else { 1 };
            flat.extend(unit(dim, axis));
            ids.push(j);
        }
        idx.add(&flat, &ids).unwrap();
        // thread_count is capped by available_parallelism, so a single-CPU
        // runner can never take the parallel path; only assert it when possible.
        let hw = std::thread::available_parallelism()
            .map(|x| x.get())
            .unwrap_or(1);
        if hw > 1 {
            assert!(thread_count(idx.len()) > 1, "test should hit the parallel path");
        }
        let hits = idx.search(&unit(dim, 0), 1);
        assert_eq!(hits[0].0, target_id, "the lone axis-0 vector wins");
    }

    #[test]
    fn active_simd_is_known() {
        let s = active_simd();
        assert!(
            matches!(s, "avx512f" | "avx2+fma" | "scalar"),
            "unexpected kernel {s}"
        );
    }

    #[test]
    fn simd_kernel_matches_scalar_dot() {
        let dim = 768;
        let a: Vec<f32> = (0..dim).map(|i| ((i as f32 + 1.0) * 0.013).sin()).collect();
        let b: Vec<f32> = (0..dim).map(|i| ((i as f32 + 1.0) * 0.027).cos()).collect();
        let an = normalized(&a);
        let bn = normalized(&b);
        let scalar = unsafe { dot_scalar(&an, &bn) };
        let kernel = select_dot_kernel();
        let simd = unsafe { kernel(&an, &bn) };
        assert!(
            (scalar - simd).abs() < 1e-5,
            "SIMD {} dot {simd} != scalar {scalar}",
            active_simd()
        );
    }

    #[test]
    fn int8_matches_f32_within_tolerance() {
        let dim = 768;
        let n = 200;
        let mut idx = VectorIndex::new(dim).unwrap();
        let mk = |seed: f32| -> Vec<f32> {
            (0..dim).map(|i| ((i as f32 + 1.0) * seed).sin()).collect()
        };
        let mut flat = Vec::with_capacity(n * dim);
        let mut ids = Vec::with_capacity(n);
        for r in 0..n as u64 {
            flat.extend(mk(0.005 + r as f32 * 0.0007));
            ids.push(r);
        }
        idx.add(&flat, &ids).unwrap();
        // Query identical to row 42.
        let query = mk(0.005 + 42.0 * 0.0007);
        let f32_hits = idx.search(&query, 5);
        let i8_hits = idx.search_i8(&query, 5);
        assert_eq!(f32_hits[0].0, 42, "f32 top-1 should be the matching row");
        assert_eq!(i8_hits[0].0, f32_hits[0].0, "int8 top-1 disagrees with f32");
        use std::collections::HashMap;
        let f32_map: HashMap<u64, f32> = f32_hits.iter().copied().collect();
        for (id, s8) in &i8_hits {
            if let Some(s32) = f32_map.get(id) {
                assert!(
                    (s8 - s32).abs() < 4e-2,
                    "int8 cosine {s8} vs f32 {s32} for id {id} exceeds tolerance"
                );
            }
        }
    }

    #[test]
    fn bench_kernels_reports_real_numbers() {
        let rep = bench_kernels(2000, 768, 3);
        assert_eq!(rep.timings.len(), 3, "expected scalar+simd+int8 timings");
        assert!(rep.timings.iter().any(|t| t.name == "scalar"));
        assert!(rep.timings.iter().any(|t| t.name.starts_with("simd")));
        assert!(rep.timings.iter().any(|t| t.name.starts_with("int8")));
        for t in &rep.timings {
            assert!(t.micros > 0, "kernel {} reported 0µs", t.name);
            assert!(t.mdps > 0.0, "kernel {} reported 0 Mdps", t.name);
        }
        let simd = rep.timings.iter().find(|t| t.name.starts_with("simd")).unwrap();
        assert!(simd.max_err < 1e-4, "simd error {} too high", simd.max_err);
        let i8 = rep.timings.iter().find(|t| t.name.starts_with("int8")).unwrap();
        assert!(i8.max_err < 4e-2, "int8 error {} too high", i8.max_err);
        assert!(rep.simd_speedup() > 0.5, "implausible simd speedup {}", rep.simd_speedup());
    }
}