#![allow(clippy::unwrap_used)]
use std::sync::Arc;
use iqdb_index::{Index, IndexCore, IndexStats};
use iqdb_types::{DistanceMetric, Hit, IqdbError, Metadata, Result, SearchParams, VectorId};
/// One stored entry: the caller-supplied id, the vector, and optional metadata.
type Row = (VectorId, Arc<[f32]>, Option<Metadata>);

/// Brute-force, in-memory vector store that backs every mock index family in
/// this file. Rows live in a plain `Vec` and every lookup is a linear scan.
struct Scan {
    /// Exact length every inserted vector and every query must have.
    dim: usize,
    /// Metric fixed at construction; a search must request this same metric.
    metric: DistanceMetric,
    /// Stored entries, kept in insertion order.
    rows: Vec<Row>,
}
impl Scan {
    /// Creates an empty store for `dim`-length vectors searched with `metric`.
    ///
    /// # Errors
    /// `IqdbError::InvalidConfig` when `dim` is zero.
    fn new(dim: usize, metric: DistanceMetric) -> Result<Self> {
        match dim {
            0 => Err(IqdbError::InvalidConfig {
                reason: "dim must be greater than zero",
            }),
            _ => Ok(Self {
                dim,
                metric,
                rows: Vec::new(),
            }),
        }
    }

    /// Appends `(id, v, m)` to the end of the store.
    ///
    /// # Errors
    /// - `IqdbError::DimensionMismatch` when `v.len() != self.dim`
    ///   (checked first).
    /// - `IqdbError::Duplicate` when `id` is already stored.
    fn insert(&mut self, id: VectorId, v: Arc<[f32]>, m: Option<Metadata>) -> Result<()> {
        let found = v.len();
        if found != self.dim {
            return Err(IqdbError::DimensionMismatch {
                expected: self.dim,
                found,
            });
        }
        let already_present = self.rows.iter().any(|(existing, ..)| *existing == id);
        if already_present {
            return Err(IqdbError::Duplicate);
        }
        self.rows.push((id, v, m));
        Ok(())
    }

    /// Removes the row stored under `id`. The remaining rows keep their
    /// relative order.
    ///
    /// # Errors
    /// `IqdbError::NotFound` when no row carries `id`.
    fn delete(&mut self, id: &VectorId) -> Result<()> {
        let slot = self
            .rows
            .iter()
            .position(|(existing, ..)| existing == id)
            .ok_or(IqdbError::NotFound)?;
        let _ = self.rows.remove(slot);
        Ok(())
    }

    /// Returns up to `params.k` hits, nearest first.
    ///
    /// The distance is always the squared Euclidean distance (the sum of
    /// squared component differences). It is computed the same way whatever
    /// `self.metric` is, so e.g. a `Cosine` store also reports squared L2.
    /// The sort is stable, so equal distances stay in insertion order.
    ///
    /// # Errors
    /// - `IqdbError::DimensionMismatch` when `query.len() != self.dim`
    ///   (checked first).
    /// - `IqdbError::InvalidMetric` when `params.metric` differs from the
    ///   store's metric.
    fn search(&self, query: &[f32], params: &SearchParams) -> Result<Vec<Hit>> {
        /// Sum of squared component-wise differences `(a[i] - b[i])^2`.
        fn squared_l2(a: &[f32], b: &[f32]) -> f32 {
            a.iter().zip(b).map(|(x, y)| (x - y).powi(2)).sum()
        }

        if query.len() != self.dim {
            return Err(IqdbError::DimensionMismatch {
                expected: self.dim,
                found: query.len(),
            });
        }
        if params.metric != self.metric {
            return Err(IqdbError::InvalidMetric);
        }

        let mut hits = Vec::with_capacity(self.rows.len());
        for (id, vector, metadata) in &self.rows {
            hits.push(Hit {
                id: id.clone(),
                distance: squared_l2(query, vector),
                metadata: metadata.clone(),
            });
        }
        // Stable sort on purpose: ties keep their insertion order.
        hits.sort_by(|lhs, rhs| lhs.distance.total_cmp(&rhs.distance));
        hits.truncate(params.k);
        Ok(hits)
    }
}
/// Configuration for the flat mock index. The flat family has no tunables, so
/// this is a field-less unit struct.
#[derive(Clone, Default)]
struct FlatConfig;
/// Mock flat (exhaustive) index: a thin newtype over the brute-force `Scan`.
struct FlatLike(Scan);
/// `IndexCore` for the flat mock. Each operation is forwarded to the wrapped
/// `Scan`, and `stats` reports the index type as `"flat"`.
impl IndexCore for FlatLike {
    fn insert(&mut self, id: VectorId, v: Arc<[f32]>, m: Option<Metadata>) -> Result<()> {
        let Self(scan) = self;
        scan.insert(id, v, m)
    }

    fn delete(&mut self, id: &VectorId) -> Result<()> {
        let Self(scan) = self;
        scan.delete(id)
    }

    fn search(&self, q: &[f32], p: &SearchParams) -> Result<Vec<Hit>> {
        let Self(scan) = self;
        scan.search(q, p)
    }

    fn len(&self) -> usize {
        let Self(scan) = self;
        scan.rows.len()
    }

    fn is_empty(&self) -> bool {
        let Self(scan) = self;
        scan.rows.is_empty()
    }

    fn dim(&self) -> usize {
        let Self(scan) = self;
        scan.dim
    }

    fn metric(&self) -> DistanceMetric {
        let Self(scan) = self;
        scan.metric
    }

    /// No-op: the mock keeps everything in memory and has nothing to write.
    fn flush(&mut self) -> Result<()> {
        Ok(())
    }

    fn stats(&self) -> IndexStats {
        let Self(scan) = self;
        let n_vectors = scan.rows.len();
        IndexStats {
            n_vectors,
            index_type: "flat",
            ..IndexStats::default()
        }
    }
}
impl Index for FlatLike {
    type Config = FlatConfig;

    /// Builds an empty flat mock. `FlatConfig` carries no settings, so the only
    /// validation is `Scan::new`'s rejection of a zero `dim`.
    fn new(dim: usize, metric: DistanceMetric, _config: Self::Config) -> Result<Self> {
        Scan::new(dim, metric).map(Self)
    }
}
/// Tunables accepted by the HNSW-like mock. Both must be non-zero.
#[derive(Clone)]
struct HnswConfig {
    /// Stored on the index after validation.
    m: usize,
    /// Only validated; the mock does not keep it.
    ef_construction: usize,
}

impl Default for HnswConfig {
    /// Returns `m = 16` and `ef_construction = 200`.
    fn default() -> Self {
        const DEFAULT_M: usize = 16;
        const DEFAULT_EF_CONSTRUCTION: usize = 200;
        Self {
            m: DEFAULT_M,
            ef_construction: DEFAULT_EF_CONSTRUCTION,
        }
    }
}
/// Mock HNSW index: every query is answered by a brute-force `Scan`. The
/// accepted `m` is kept so tests can check the config reached the index.
struct HnswLike {
    /// Backing brute-force store.
    scan: Scan,
    /// `HnswConfig::m` as accepted at construction.
    m: usize,
}
impl IndexCore for HnswLike {
fn insert(&mut self, id: VectorId, v: Arc<[f32]>, m: Option<Metadata>) -> Result<()> {
self.scan.insert(id, v, m)
}
fn delete(&mut self, id: &VectorId) -> Result<()> {
self.scan.delete(id)
}
fn search(&self, q: &[f32], p: &SearchParams) -> Result<Vec<Hit>> {
self.scan.search(q, p)
}
fn len(&self) -> usize {
self.scan.rows.len()
}
fn is_empty(&self) -> bool {
self.scan.rows.is_empty()
}
fn dim(&self) -> usize {
self.scan.dim
}
fn metric(&self) -> DistanceMetric {
self.scan.metric
}
fn flush(&mut self) -> Result<()> {
Ok(())
}
fn stats(&self) -> IndexStats {
IndexStats {
n_vectors: self.scan.rows.len(),
index_type: "hnsw",
..IndexStats::default()
}
}
}
impl Index for HnswLike {
type Config = HnswConfig;
fn new(dim: usize, metric: DistanceMetric, config: Self::Config) -> Result<Self> {
if config.m == 0 || config.ef_construction == 0 {
return Err(IqdbError::InvalidConfig {
reason: "HNSW m and ef_construction must be greater than zero",
});
}
Ok(Self {
scan: Scan::new(dim, metric)?,
m: config.m,
})
}
}
/// Tunables accepted by the IVF-like mock. Both must be non-zero.
#[derive(Clone)]
struct IvfConfig {
    /// Stored on the index after validation.
    n_clusters: usize,
    /// Only validated; the mock does not keep it.
    n_probes: usize,
}

impl Default for IvfConfig {
    /// Returns `n_clusters = 100` and `n_probes = 8`.
    fn default() -> Self {
        const DEFAULT_N_CLUSTERS: usize = 100;
        const DEFAULT_N_PROBES: usize = 8;
        Self {
            n_clusters: DEFAULT_N_CLUSTERS,
            n_probes: DEFAULT_N_PROBES,
        }
    }
}
/// Mock IVF index: every query is answered by a brute-force `Scan`. The
/// accepted cluster count is kept so tests can check the config reached the
/// index.
struct IvfLike {
    /// Backing brute-force store.
    scan: Scan,
    /// `IvfConfig::n_clusters` as accepted at construction.
    n_clusters: usize,
}
/// `IndexCore` for the IVF mock. Each operation is delegated to the inner
/// `Scan`, and `stats` reports the index type as `"ivf"`.
impl IndexCore for IvfLike {
    fn insert(&mut self, id: VectorId, v: Arc<[f32]>, m: Option<Metadata>) -> Result<()> {
        let scan = &mut self.scan;
        scan.insert(id, v, m)
    }

    fn delete(&mut self, id: &VectorId) -> Result<()> {
        let scan = &mut self.scan;
        scan.delete(id)
    }

    fn search(&self, q: &[f32], p: &SearchParams) -> Result<Vec<Hit>> {
        let scan = &self.scan;
        scan.search(q, p)
    }

    fn len(&self) -> usize {
        let rows = &self.scan.rows;
        rows.len()
    }

    fn is_empty(&self) -> bool {
        let rows = &self.scan.rows;
        rows.is_empty()
    }

    fn dim(&self) -> usize {
        let scan = &self.scan;
        scan.dim
    }

    fn metric(&self) -> DistanceMetric {
        let scan = &self.scan;
        scan.metric
    }

    /// No-op: the mock keeps everything in memory and has nothing to write.
    fn flush(&mut self) -> Result<()> {
        Ok(())
    }

    fn stats(&self) -> IndexStats {
        let rows = &self.scan.rows;
        IndexStats {
            n_vectors: rows.len(),
            index_type: "ivf",
            ..IndexStats::default()
        }
    }
}
impl Index for IvfLike {
type Config = IvfConfig;
fn new(dim: usize, metric: DistanceMetric, config: Self::Config) -> Result<Self> {
if config.n_clusters == 0 || config.n_probes == 0 {
return Err(IqdbError::InvalidConfig {
reason: "IVF n_clusters and n_probes must be greater than zero",
});
}
Ok(Self {
scan: Scan::new(dim, metric)?,
n_clusters: config.n_clusters,
})
}
}
/// Builds the shared 2-D vector `[x, y]` used as test data.
fn vec2(x: f32, y: f32) -> Arc<[f32]> {
    // `Arc<[f32; 2]>` unsize-coerces to `Arc<[f32]>` at the return site.
    Arc::new([x, y])
}
/// Each family accepts its own config type, and the tunables it keeps are
/// visible on the constructed index.
#[test]
fn each_family_constructs_with_its_own_config() {
    // Flat has no tunables; its config is a unit struct.
    let flat = FlatLike::new(2, DistanceMetric::Euclidean, FlatConfig).unwrap();
    assert_eq!(flat.stats().index_type, "flat");

    let hnsw_config = HnswConfig {
        m: 32,
        ef_construction: 128,
    };
    let hnsw = HnswLike::new(2, DistanceMetric::Cosine, hnsw_config).unwrap();
    assert_eq!(hnsw.stats().index_type, "hnsw");
    assert_eq!(hnsw.m, 32);

    let ivf_config = IvfConfig {
        n_clusters: 64,
        n_probes: 4,
    };
    let ivf = IvfLike::new(2, DistanceMetric::Euclidean, ivf_config).unwrap();
    assert_eq!(ivf.stats().index_type, "ivf");
    assert_eq!(ivf.n_clusters, 64);
}
/// Every constructor rejects degenerate settings with `IqdbError::InvalidConfig`.
///
/// Each zero-valued tunable is tried on its own. Otherwise a validator that
/// inspects only one side of its `||` check would still pass. A zero `dim`
/// is tried for every family, not just the flat one.
#[test]
fn each_family_rejects_its_own_invalid_config() {
    let bad_hnsw = [
        HnswConfig {
            m: 0,
            ef_construction: 1,
        },
        HnswConfig {
            m: 1,
            ef_construction: 0,
        },
    ];
    for config in bad_hnsw {
        let (m, ef_construction) = (config.m, config.ef_construction);
        assert!(
            matches!(
                HnswLike::new(2, DistanceMetric::Cosine, config),
                Err(IqdbError::InvalidConfig { .. })
            ),
            "HNSW accepted m={}, ef_construction={}",
            m,
            ef_construction
        );
    }

    let bad_ivf = [
        IvfConfig {
            n_clusters: 0,
            n_probes: 1,
        },
        IvfConfig {
            n_clusters: 1,
            n_probes: 0,
        },
    ];
    for config in bad_ivf {
        let (n_clusters, n_probes) = (config.n_clusters, config.n_probes);
        assert!(
            matches!(
                IvfLike::new(2, DistanceMetric::Euclidean, config),
                Err(IqdbError::InvalidConfig { .. })
            ),
            "IVF accepted n_clusters={}, n_probes={}",
            n_clusters,
            n_probes
        );
    }

    // A zero dimension is invalid for every family, even with a valid config.
    assert!(matches!(
        FlatLike::new(0, DistanceMetric::Euclidean, FlatConfig),
        Err(IqdbError::InvalidConfig { .. })
    ));
    assert!(matches!(
        HnswLike::new(0, DistanceMetric::Euclidean, HnswConfig::default()),
        Err(IqdbError::InvalidConfig { .. })
    ));
    assert!(matches!(
        IvfLike::new(0, DistanceMetric::Euclidean, IvfConfig::default()),
        Err(IqdbError::InvalidConfig { .. })
    ));
}
/// One index of each family is boxed as `dyn IndexCore` in a single `Vec`.
/// Each then goes through the same insert-batch / delete / search / flush
/// sequence, and `stats` still reports each family's own type.
#[test]
fn all_three_families_coexist_behind_one_trait_object() {
    let flat = FlatLike::new(2, DistanceMetric::Euclidean, FlatConfig).unwrap();
    let hnsw = HnswLike::new(2, DistanceMetric::Euclidean, HnswConfig::default()).unwrap();
    let ivf = IvfLike::new(2, DistanceMetric::Euclidean, IvfConfig::default()).unwrap();
    let mut engine: Vec<Box<dyn IndexCore>> = vec![Box::new(flat), Box::new(hnsw), Box::new(ivf)];

    let origin = [0.0_f32, 0.0];
    let params = SearchParams::new(5, DistanceMetric::Euclidean);
    for index in engine.iter_mut() {
        let batch = vec![
            (VectorId::from(1u64), vec2(0.0, 0.0), None),
            (VectorId::from(2u64), vec2(9.0, 9.0), None),
        ];
        index.insert_batch(batch).unwrap();
        index.delete(&VectorId::from(2u64)).unwrap();

        // Only id 1 should survive the delete.
        let hits = index.search(&origin, &params).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, VectorId::U64(1));

        index.flush().unwrap();
    }

    let kinds: Vec<&str> = engine
        .iter()
        .map(|index| index.stats().index_type)
        .collect();
    assert_eq!(kinds, ["flat", "hnsw", "ivf"]);
}