use std::fmt;
use std::path::Path;
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use iqdb_cache::CacheStats;
use iqdb_index::{Index, IndexCore};
use iqdb_persist::{PersistConfig, PersistedIndex};
use iqdb_types::{
DistanceMetric, Filter, Hit, IqdbError, Metadata, SearchParams, Vector, VectorId,
};
use crate::config::IqdbConfig;
use crate::engine::IqdbCore;
use crate::error::{Error, Result};
/// Backing store behind an `Iqdb` handle.
///
/// Both variants ultimately hold an `IqdbCore`; they differ only in whether
/// writes are routed through a `PersistedIndex` wrapper first.
enum Storage {
/// The core is held directly, with no persistence wrapper around it.
Memory(IqdbCore),
/// The core is owned by a `PersistedIndex`, which receives inserts, deletes
/// and checkpoints on its behalf.
Persisted(PersistedIndex<IqdbCore>),
}
impl Storage {
fn core(&self) -> &IqdbCore {
match self {
Self::Memory(core) => core,
Self::Persisted(persisted) => persisted.index(),
}
}
fn core_mut(&mut self) -> &mut IqdbCore {
match self {
Self::Memory(core) => core,
Self::Persisted(persisted) => persisted.index_mut(),
}
}
fn upsert(&mut self, id: VectorId, vector: Arc<[f32]>, meta: Option<Metadata>) -> Result<()> {
match self {
Self::Memory(core) => core.insert(id, vector, meta)?,
Self::Persisted(persisted) => persisted.insert(id, vector, meta)?,
}
Ok(())
}
fn delete(&mut self, id: &VectorId) -> Result<()> {
match self {
Self::Memory(core) => core.delete(id)?,
Self::Persisted(persisted) => persisted.delete(id)?,
}
Ok(())
}
fn checkpoint(&mut self) -> Result<()> {
if let Self::Persisted(persisted) = self {
persisted.checkpoint()?;
}
Ok(())
}
}
/// Database handle: a vector store guarded by a reader/writer lock, so every
/// operation except `close` is available through `&self`.
pub struct Iqdb {
/// Dimensionality fixed when the handle is opened; upserted vectors and
/// search queries are checked against it.
dim: usize,
/// Distance metric fixed when the handle is opened; used to build the
/// parameters of every search.
metric: DistanceMetric,
/// The backing store (in-memory or persisted) behind an `RwLock`.
inner: RwLock<Storage>,
}
impl Iqdb {
    /// Opens an in-memory database using the configuration produced by
    /// `IqdbConfig::new(dim, metric)`.
    ///
    /// # Errors
    ///
    /// Returns `Error::Config` when `dim` is zero, or any error reported
    /// while constructing the core.
    pub fn open_in_memory(dim: usize, metric: DistanceMetric) -> Result<Self> {
        Self::open_in_memory_with(IqdbConfig::new(dim, metric))
    }

    /// Opens an in-memory database from an explicit configuration.
    ///
    /// The durability part of `config` is ignored: nothing is written to disk.
    ///
    /// # Errors
    ///
    /// Returns `Error::Config` when the configured dimension is zero, or any
    /// error reported while constructing the core.
    pub fn open_in_memory_with(config: IqdbConfig) -> Result<Self> {
        let (dim, metric, core_cfg, _durability) = config.into_parts();
        Self::require_nonzero_dim(dim)?;
        let core = IqdbCore::new(dim, metric, core_cfg)?;
        Ok(Self {
            dim,
            metric,
            inner: RwLock::new(Storage::Memory(core)),
        })
    }

    /// Opens (or creates) a persisted database at `path` using the
    /// configuration produced by `IqdbConfig::new(dim, metric)`.
    ///
    /// # Errors
    ///
    /// See [`Iqdb::open_with`].
    pub fn open<P: AsRef<Path>>(path: P, dim: usize, metric: DistanceMetric) -> Result<Self> {
        Self::open_with(path, IqdbConfig::new(dim, metric))
    }

    /// Opens (or creates) a persisted database at `path` from an explicit
    /// configuration.
    ///
    /// The write-ahead log is always enabled; the fsync policy and
    /// compression setting are taken from the durability part of `config`.
    ///
    /// * If `path` already exists, the index is loaded from it and its
    ///   dimension and metric must equal the requested ones. Of the core
    ///   configuration only the cache setting is applied to the loaded index
    ///   (and only when one is configured).
    /// * Otherwise a fresh core is built from the full core configuration and
    ///   wrapped in a new persisted index.
    ///
    /// # Errors
    ///
    /// Returns `Error::Config` when the dimension is zero or when a reopened
    /// database disagrees with the requested dimension/metric, and forwards
    /// any error from loading, creating, or configuring the index.
    pub fn open_with<P: AsRef<Path>>(path: P, config: IqdbConfig) -> Result<Self> {
        let (dim, metric, core_cfg, durability) = config.into_parts();
        Self::require_nonzero_dim(dim)?;
        // Copied out up front because `core_cfg` itself is only consumed on
        // the "create" branch below.
        let cache = core_cfg.cache.clone();
        let path = path.as_ref().to_path_buf();
        let mut persist_cfg = PersistConfig::new(path.clone());
        persist_cfg.wal_enabled = true;
        persist_cfg.fsync_policy = durability.fsync;
        persist_cfg.compression = durability.compression;
        let storage = if path.exists() {
            let mut persisted = PersistedIndex::<IqdbCore>::load(persist_cfg)?;
            {
                // Refuse to hand out a handle whose advertised dim/metric
                // would not match what is actually stored.
                let core = persisted.index();
                if core.dim() != dim || core.metric() != metric {
                    return Err(Error::Config(
                        "reopened database dim/metric does not match the requested values",
                    ));
                }
            }
            if cache.is_some() {
                persisted.index_mut().set_cache(cache)?;
            }
            Storage::Persisted(persisted)
        } else {
            let core = IqdbCore::new(dim, metric, core_cfg)?;
            Storage::Persisted(PersistedIndex::open_with(core, persist_cfg)?)
        };
        Ok(Self {
            dim,
            metric,
            inner: RwLock::new(storage),
        })
    }

    /// Dimensionality this handle was opened with.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Distance metric this handle was opened with.
    #[must_use]
    pub fn metric(&self) -> DistanceMetric {
        self.metric
    }

    /// Number of rows reported by the underlying core.
    #[must_use]
    pub fn len(&self) -> usize {
        self.read().core().len()
    }

    /// Whether the underlying core reports itself as empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.read().core().is_empty()
    }

    /// Stores `vector` (with optional `metadata`) under `id`.
    ///
    /// NOTE(review): the name implies that an existing row with the same id
    /// is replaced, but that depends on the backend's `insert` — confirm.
    ///
    /// # Errors
    ///
    /// Returns `IqdbError::DimensionMismatch` (wrapped in `Error::Index`)
    /// when the vector's dimension differs from [`Iqdb::dim`], and forwards
    /// any error from the backend.
    pub fn upsert(&self, id: VectorId, vector: Vector, metadata: Option<Metadata>) -> Result<()> {
        // Validate before taking the write lock so bad input never blocks readers.
        self.check_dim(vector.dim())?;
        let arc: Arc<[f32]> = Arc::from(vector.into_inner().into_boxed_slice());
        self.write().upsert(id, arc, metadata)
    }

    /// Looks up the row stored under `id`.
    ///
    /// Returns `Ok(None)` when the core has no such row; otherwise returns a
    /// freshly copied `Vector` together with the row's metadata.
    ///
    /// # Errors
    ///
    /// Forwards any error from `Vector::new` on the copied components.
    pub fn get(&self, id: &VectorId) -> Result<Option<(Vector, Option<Metadata>)>> {
        match self.read().core().get_row(id) {
            None => Ok(None),
            Some((vector, meta)) => {
                let vector = Vector::new(vector.as_ref().to_vec())?;
                Ok(Some((vector, meta)))
            }
        }
    }

    /// Deletes the row stored under `id`.
    ///
    /// Returns `true` when the id was present immediately before the delete.
    /// The presence check and the delete run under one write guard, so no
    /// other writer can slip in between them.
    ///
    /// # Errors
    ///
    /// Forwards any error from the backend's delete.
    pub fn delete(&self, id: &VectorId) -> Result<bool> {
        let mut guard = self.write();
        let existed = guard.core().contains(id);
        guard.delete(id)?;
        Ok(existed)
    }

    /// Searches for up to `k` hits for `query` with no filter.
    ///
    /// # Errors
    ///
    /// Returns `IqdbError::DimensionMismatch` (wrapped in `Error::Index`)
    /// when the query's dimension differs from [`Iqdb::dim`], and forwards
    /// any error from the core.
    pub fn search(&self, query: &Vector, k: usize) -> Result<Vec<Hit>> {
        self.search_inner(query, k, None)
    }

    /// Searches for up to `k` hits for `query`, passing `filter` to the core.
    ///
    /// # Errors
    ///
    /// Same as [`Iqdb::search`].
    pub fn search_with(&self, query: &Vector, k: usize, filter: Filter) -> Result<Vec<Hit>> {
        self.search_inner(query, k, Some(filter))
    }

    /// Runs [`Iqdb::search`] for every query in order and collects the
    /// per-query results.
    ///
    /// Each query acquires the lock independently, so writes from other
    /// threads may be observed between queries.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by any query.
    pub fn search_batch(&self, queries: &[Vector], k: usize) -> Result<Vec<Vec<Hit>>> {
        queries.iter().map(|q| self.search(q, k)).collect()
    }

    /// Runs [`Iqdb::search_with`] for every query in order, using a clone of
    /// `filter` for each one.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by any query.
    pub fn search_batch_with(
        &self,
        queries: &[Vector],
        k: usize,
        filter: Filter,
    ) -> Result<Vec<Vec<Hit>>> {
        queries
            .iter()
            .map(|q| self.search_with(q, k, filter.clone()))
            .collect()
    }

    /// Calls the core's `optimize` while holding the write lock.
    ///
    /// # Errors
    ///
    /// Forwards any error from the core.
    pub fn optimize(&self) -> Result<()> {
        self.write().core_mut().optimize().map_err(Error::from)
    }

    /// Cache statistics as reported by the core, if it reports any.
    #[must_use]
    pub fn cache_stats(&self) -> Option<CacheStats> {
        self.read().core().cache_stats()
    }

    /// Checkpoints the backing store under the write lock.
    ///
    /// For an in-memory database there is nothing to checkpoint and this
    /// simply returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Forwards any error from the persisted index's checkpoint.
    pub fn flush(&self) -> Result<()> {
        self.write().checkpoint()
    }

    /// Consumes the handle and performs a final checkpoint.
    ///
    /// A poisoned lock is recovered rather than treated as an error, so the
    /// checkpoint is still attempted after a panic in another thread.
    ///
    /// # Errors
    ///
    /// Forwards any error from the persisted index's checkpoint.
    pub fn close(self) -> Result<()> {
        let mut storage = self
            .inner
            .into_inner()
            .unwrap_or_else(|poison| poison.into_inner());
        storage.checkpoint()
    }

    /// Rejects a zero dimension with `Error::Config`.
    fn require_nonzero_dim(dim: usize) -> Result<()> {
        if dim == 0 {
            Err(Error::Config("dim must be non-zero"))
        } else {
            Ok(())
        }
    }

    /// Checks `found` against the handle's dimension, returning
    /// `IqdbError::DimensionMismatch` (wrapped in `Error::Index`) on mismatch.
    fn check_dim(&self, found: usize) -> Result<()> {
        if found == self.dim {
            Ok(())
        } else {
            Err(Error::Index(IqdbError::DimensionMismatch {
                expected: self.dim,
                found,
            }))
        }
    }

    /// Builds search parameters for `k` and the handle's metric, overriding
    /// only the `filter` field of `SearchParams::new`.
    fn params(&self, k: usize, filter: Option<Filter>) -> SearchParams {
        SearchParams {
            filter,
            ..SearchParams::new(k, self.metric)
        }
    }

    /// Shared implementation of [`Iqdb::search`] and [`Iqdb::search_with`].
    ///
    /// Tries to answer under the shared read lock first. Only when the core
    /// reports `needs_materialization()` does it fall back to the write lock,
    /// call `ensure_ready()` (presumably building whatever was missing —
    /// confirm against the core), and then search.
    ///
    /// An empty core yields an empty result without calling the core's
    /// search.
    fn search_inner(&self, query: &Vector, k: usize, filter: Option<Filter>) -> Result<Vec<Hit>> {
        self.check_dim(query.dim())?;
        {
            let guard = self.read();
            if guard.core().is_empty() {
                return Ok(Vec::new());
            }
            if !guard.core().needs_materialization() {
                let params = self.params(k, filter);
                return guard
                    .core()
                    .search(query.as_slice(), &params)
                    .map_err(Error::from);
            }
            // The read guard is released at the end of this scope; it must be
            // gone before the write lock below is requested.
        }
        let mut guard = self.write();
        // Another thread may have modified the store between the two locks,
        // so emptiness is checked again.
        if guard.core().is_empty() {
            return Ok(Vec::new());
        }
        guard.core_mut().ensure_ready()?;
        let params = self.params(k, filter);
        guard
            .core()
            .search(query.as_slice(), &params)
            .map_err(Error::from)
    }

    /// Acquires the read lock, recovering the guard if the lock is poisoned.
    fn read(&self) -> RwLockReadGuard<'_, Storage> {
        self.inner
            .read()
            .unwrap_or_else(|poison| poison.into_inner())
    }

    /// Acquires the write lock, recovering the guard if the lock is poisoned.
    fn write(&self) -> RwLockWriteGuard<'_, Storage> {
        self.inner
            .write()
            .unwrap_or_else(|poison| poison.into_inner())
    }
}
impl fmt::Debug for Iqdb {
    /// Formats the handle as `Iqdb { dim, metric, .. }`; the locked storage
    /// is deliberately left out of the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Iqdb");
        out.field("dim", &self.dim);
        out.field("metric", &self.metric);
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{HnswConfig, IndexKind, IvfConfig};

    /// Builds a two-component vector, panicking if `Vector::new` rejects it.
    fn vec2(a: f32, b: f32) -> Vector {
        Vector::new(vec![a, b]).unwrap()
    }

    /// Shorthand for turning a raw integer into a `VectorId`.
    fn vid(raw: u64) -> VectorId {
        VectorId::from(raw)
    }

    /// Compile-time check: the handle can be shared and sent across threads.
    #[test]
    fn handle_is_send_and_sync() {
        fn requires_send_sync<T: Send + Sync>() {}
        requires_send_sync::<Iqdb>();
    }

    /// A zero dimension is a configuration error through both constructors.
    #[test]
    fn open_in_memory_rejects_zero_dim() {
        let direct = Iqdb::open_in_memory(0, DistanceMetric::Cosine).unwrap_err();
        assert!(matches!(direct, Error::Config(_)), "got {direct:?}");

        let zero_dim_cfg = IqdbConfig::new(0, DistanceMetric::Cosine);
        let via_config = Iqdb::open_in_memory_with(zero_dim_cfg).unwrap_err();
        assert!(matches!(via_config, Error::Config(_)), "got {via_config:?}");
    }

    /// Upsert, get and delete behave consistently on an in-memory handle.
    #[test]
    fn crud_round_trip_in_memory() {
        let db = Iqdb::open_in_memory(2, DistanceMetric::Euclidean).unwrap();
        assert!(db.is_empty());

        db.upsert(vid(1), vec2(0.0, 0.0), None).unwrap();
        db.upsert(vid(2), vec2(3.0, 4.0), None).unwrap();
        assert_eq!(db.len(), 2);

        let second = vid(2);
        let (got, _) = db.get(&second).unwrap().unwrap();
        assert_eq!(got.as_slice(), &[3.0, 4.0]);

        // First delete removes the row; the repeat reports it as absent.
        assert!(db.delete(&second).unwrap());
        assert!(!db.delete(&second).unwrap());
        assert_eq!(db.len(), 1);
    }

    /// A vector of the wrong dimension is rejected with both sizes reported.
    #[test]
    fn upsert_rejects_wrong_dimension() {
        let db = Iqdb::open_in_memory(3, DistanceMetric::Cosine).unwrap();
        let err = db.upsert(vid(1), vec2(1.0, 0.0), None).unwrap_err();
        assert!(matches!(
            err,
            Error::Index(IqdbError::DimensionMismatch {
                expected: 3,
                found: 2
            })
        ));
    }

    /// Searching an empty database yields no hits rather than an error.
    #[test]
    fn search_on_empty_returns_empty() {
        let db = Iqdb::open_in_memory(2, DistanceMetric::Cosine).unwrap();
        let hits = db.search(&vec2(1.0, 0.0), 5).unwrap();
        assert!(hits.is_empty());
    }

    /// Hits come back ordered from nearest to farthest.
    #[test]
    fn search_orders_nearest_first() {
        let db = Iqdb::open_in_memory(2, DistanceMetric::Euclidean).unwrap();
        db.upsert(vid(1), vec2(0.0, 0.0), None).unwrap();
        db.upsert(vid(2), vec2(10.0, 10.0), None).unwrap();

        let query = vec2(0.1, 0.1);
        let hits = db.search(&query, 2).unwrap();
        assert_eq!(hits[0].id, vid(1));
        assert_eq!(hits[1].id, vid(2));
    }

    /// An IVF-backed handle can be searched right after the upserts.
    #[test]
    fn ivf_handle_searches_after_lazy_materialization() {
        let ivf = IvfConfig::default()
            .with_n_clusters(2)
            .with_n_probes(2)
            .with_training_sample_size(64)
            .with_seed(7);
        let cfg = IqdbConfig::new(2, DistanceMetric::Euclidean).index(IndexKind::Ivf(ivf));
        let db = Iqdb::open_in_memory_with(cfg).unwrap();

        let points: [[f32; 2]; 4] = [[0.0, 0.0], [0.1, -0.1], [10.0, 10.0], [9.9, 10.1]];
        // Ids are the positions 0, 1, 2, 3 of the points above.
        for (raw_id, point) in (0u64..).zip(points.iter()) {
            let vector = Vector::new(point.to_vec()).unwrap();
            db.upsert(vid(raw_id), vector, None).unwrap();
        }

        let hits = db.search(&vec2(0.0, 0.0), 1).unwrap();
        assert_eq!(hits[0].id, vid(0));
    }

    /// An HNSW-backed handle returns the nearest stored vector.
    #[test]
    fn hnsw_handle_round_trip() {
        let kind = IndexKind::Hnsw(HnswConfig::default());
        let cfg = IqdbConfig::new(2, DistanceMetric::Euclidean).index(kind);
        let db = Iqdb::open_in_memory_with(cfg).unwrap();

        db.upsert(vid(1), vec2(0.0, 0.0), None).unwrap();
        db.upsert(vid(2), vec2(5.0, 5.0), None).unwrap();

        let hits = db.search(&vec2(0.0, 0.0), 1).unwrap();
        assert_eq!(hits[0].id, vid(1));
    }
}