use std::path::Path;
use crate::backend::Backend;
use crate::error::Result;
use crate::file_store::FileStore;
use crate::record::{Record, RecordId};
use crate::search::{flat_search, SearchResult};
use crate::store::MemoryStore;
use crate::vector::{DistanceMetric, Vector};
/// Handle to a database of [`Record`]s.
///
/// A handle owns exactly one storage [`Backend`]: a file-backed [`FileStore`]
/// (created by [`Iqdb::open`]) or a [`MemoryStore`] (created by
/// [`Iqdb::open_in_memory`]). Every operation except [`Iqdb::close`] takes
/// `&self`, and the type is `Send + Sync`.
#[derive(Debug)]
pub struct Iqdb {
    // Storage engine that every public method forwards to.
    backend: Backend,
}

impl Iqdb {
    /// Opens a file-backed database at `path`.
    ///
    /// # Errors
    ///
    /// Returns any error produced while opening the underlying [`FileStore`].
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        let backend = Backend::File(FileStore::open(path.as_ref())?);
        Ok(Self { backend })
    }

    /// Creates a new, empty database backed by a [`MemoryStore`].
    #[must_use]
    pub fn open_in_memory() -> Self {
        let backend = Backend::Memory(MemoryStore::new());
        Self { backend }
    }

    /// Stores `record`, replacing any record that already has the same id.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the backend.
    pub fn upsert(&self, record: Record) -> Result<()> {
        self.backend.upsert(record)
    }

    /// Looks up the record stored under `id`.
    ///
    /// Yields `Ok(None)` when no record with that id exists.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the backend.
    pub fn get(&self, id: RecordId) -> Result<Option<Record>> {
        self.backend.get(id)
    }

    /// Removes the record stored under `id`.
    ///
    /// Yields `Ok(true)` if a record was removed and `Ok(false)` if there was
    /// nothing to remove.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the backend.
    pub fn delete(&self, id: RecordId) -> Result<bool> {
        self.backend.delete(id)
    }

    /// Number of records currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.backend.len()
    }

    /// `true` when the database holds no records.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.backend.is_empty()
    }

    /// Searches the stored records for matches to `query`.
    ///
    /// Identical to [`Iqdb::search_with`] with a filter that accepts every
    /// record; `k` and `metric` are forwarded unchanged to [`flat_search`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`flat_search`].
    pub fn search(
        &self,
        query: &Vector,
        k: usize,
        metric: DistanceMetric,
    ) -> Result<Vec<SearchResult>> {
        self.search_with(query, k, metric, |_| true)
    }

    /// Like [`Iqdb::search`], but also hands `filter` to [`flat_search`] as a
    /// per-record predicate.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`flat_search`].
    pub fn search_with<F>(
        &self,
        query: &Vector,
        k: usize,
        metric: DistanceMetric,
        filter: F,
    ) -> Result<Vec<SearchResult>>
    where
        F: Fn(&Record) -> bool,
    {
        flat_search(&self.backend, query, k, metric, filter)
    }

    /// Runs one unfiltered search per entry of `queries`.
    ///
    /// The outer vector holds one result list per query, in the same order as
    /// `queries`.
    ///
    /// # Errors
    ///
    /// Returns the error of the first query whose search fails; later queries
    /// are not searched.
    pub fn search_batch(
        &self,
        queries: &[Vector],
        k: usize,
        metric: DistanceMetric,
    ) -> Result<Vec<Vec<SearchResult>>> {
        self.search_batch_with(queries, k, metric, |_| true)
    }

    /// Runs one filtered search per entry of `queries`, sharing `filter` across
    /// all of them.
    ///
    /// The outer vector holds one result list per query, in the same order as
    /// `queries`.
    ///
    /// # Errors
    ///
    /// Returns the error of the first query whose search fails; later queries
    /// are not searched.
    pub fn search_batch_with<F>(
        &self,
        queries: &[Vector],
        k: usize,
        metric: DistanceMetric,
        filter: F,
    ) -> Result<Vec<Vec<SearchResult>>>
    where
        F: Fn(&Record) -> bool,
    {
        // Collecting into `Result` stops at the first `Err`, matching a `?` loop.
        queries
            .iter()
            .map(|query| flat_search(&self.backend, query, k, metric, &filter))
            .collect()
    }

    /// Asks the backend to flush its state.
    ///
    /// On an in-memory handle this returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the backend.
    pub fn flush(&self) -> Result<()> {
        self.backend.flush()
    }

    /// Consumes the handle and closes its backend.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the backend while closing.
    pub fn close(self) -> Result<()> {
        self.backend.close()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        payload::{Payload, PayloadValue},
        vector::Vector,
    };

    /// Builds a three-component vector; the fallible constructor failing on
    /// these fixed inputs would be a bug in the test itself.
    fn vec3(a: f32, b: f32, c: f32) -> Vector {
        Vector::new(vec![a, b, c]).unwrap()
    }

    /// Stores `record` in `db`, failing the test if the upsert is rejected.
    fn put(db: &Iqdb, record: Record) {
        db.upsert(record).unwrap();
    }

    #[test]
    fn open_in_memory_returns_empty_handle() {
        let db = Iqdb::open_in_memory();
        assert_eq!(db.len(), 0);
        assert!(db.is_empty());
    }

    #[test]
    fn flush_on_in_memory_is_noop_ok() {
        let db = Iqdb::open_in_memory();
        db.flush().expect("flushing an in-memory handle succeeds");
    }

    #[test]
    fn close_on_in_memory_succeeds() {
        let db = Iqdb::open_in_memory();
        db.close().expect("closing an in-memory handle succeeds");
    }

    #[test]
    fn upsert_get_round_trip() {
        let db = Iqdb::open_in_memory();
        put(&db, Record::new(RecordId::new(1), vec3(0.1, 0.2, 0.3)));

        let hit = db.get(RecordId::new(1)).unwrap().expect("present");
        assert_eq!(hit.id().get(), 1);
        assert_eq!(hit.vector().as_slice(), &[0.1, 0.2, 0.3]);
    }

    #[test]
    fn upsert_replaces_existing_record() {
        let db = Iqdb::open_in_memory();
        // Same id twice: the second write must win and the count must stay at one.
        for (a, b, c) in [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0)] {
            put(&db, Record::new(RecordId::new(1), vec3(a, b, c)));
        }

        assert_eq!(db.len(), 1);
        let hit = db.get(RecordId::new(1)).unwrap().expect("present");
        assert_eq!(hit.vector().as_slice(), &[0.0, 1.0, 0.0]);
    }

    #[test]
    fn get_returns_none_for_missing_id() {
        let missing = Iqdb::open_in_memory().get(RecordId::new(99)).unwrap();
        assert!(missing.is_none());
    }

    #[test]
    fn delete_returns_true_only_when_removed() {
        let db = Iqdb::open_in_memory();
        put(&db, Record::new(RecordId::new(1), vec3(1.0, 0.0, 0.0)));

        let removed_first = db.delete(RecordId::new(1)).unwrap();
        assert!(removed_first);
        let removed_again = db.delete(RecordId::new(1)).unwrap();
        assert!(!removed_again);
    }

    #[test]
    fn payload_round_trips_through_upsert_and_get() {
        let db = Iqdb::open_in_memory();
        let mut payload = Payload::new();
        payload.insert("kind", "doc");
        payload.insert("score", 0.97_f64);
        put(
            &db,
            Record::with_payload(RecordId::new(7), vec3(1.0, 2.0, 3.0), payload),
        );

        let hit = db.get(RecordId::new(7)).unwrap().expect("present");
        let stored = hit.payload().expect("payload present");
        assert_eq!(
            stored.get("kind").and_then(PayloadValue::as_text),
            Some("doc")
        );
        let score = stored
            .get("score")
            .and_then(PayloadValue::as_float)
            .expect("score present");
        assert!((score - 0.97).abs() < 1e-9);
    }

    #[test]
    fn handle_is_send_and_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<Iqdb>();
    }
}