use std::path::Path;
use crate::error::{Error, Result};
use crate::record::{Record, RecordId};
use crate::search::{flat_search, SearchResult};
use crate::store::MemoryStore;
use crate::vector::{DistanceMetric, Vector};
/// Database handle that owns a single [`MemoryStore`].
///
/// Every record and search operation exposed on the handle is forwarded to
/// that store.
#[derive(Debug)]
pub struct Iqdb {
    /// Backing store that every method on the handle delegates to.
    store: MemoryStore,
}
impl Iqdb {
pub fn open<P: AsRef<Path>>(_path: P) -> Result<Self> {
Err(Error::NotImplemented)
}
#[must_use]
pub fn open_in_memory() -> Self {
Self {
store: MemoryStore::new(),
}
}
pub fn upsert(&self, record: Record) -> Result<()> {
self.store.upsert(record)
}
pub fn get(&self, id: RecordId) -> Result<Option<Record>> {
self.store.get(id)
}
pub fn delete(&self, id: RecordId) -> Result<bool> {
self.store.delete(id)
}
#[must_use]
pub fn len(&self) -> usize {
self.store.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.store.is_empty()
}
pub fn search(
&self,
query: &Vector,
k: usize,
metric: DistanceMetric,
) -> Result<Vec<SearchResult>> {
flat_search(&self.store, query, k, metric, |_| true)
}
pub fn search_with<F>(
&self,
query: &Vector,
k: usize,
metric: DistanceMetric,
filter: F,
) -> Result<Vec<SearchResult>>
where
F: Fn(&Record) -> bool,
{
flat_search(&self.store, query, k, metric, filter)
}
pub fn search_batch(
&self,
queries: &[Vector],
k: usize,
metric: DistanceMetric,
) -> Result<Vec<Vec<SearchResult>>> {
self.search_batch_with(queries, k, metric, |_| true)
}
pub fn search_batch_with<F>(
&self,
queries: &[Vector],
k: usize,
metric: DistanceMetric,
filter: F,
) -> Result<Vec<Vec<SearchResult>>>
where
F: Fn(&Record) -> bool,
{
let mut out = Vec::with_capacity(queries.len());
for query in queries {
out.push(flat_search(&self.store, query, k, metric, &filter)?);
}
Ok(out)
}
pub fn flush(&self) -> Result<()> {
Err(Error::NotImplemented)
}
pub fn close(self) -> Result<()> {
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::payload::Payload;
    use crate::vector::Vector;

    /// Builds a three-component vector, panicking if `Vector::new` rejects it.
    fn vec3(a: f32, b: f32, c: f32) -> Vector {
        let components = vec![a, b, c];
        Vector::new(components).unwrap()
    }

    // A fresh in-memory handle holds no records.
    #[test]
    fn open_in_memory_returns_empty_handle() {
        let handle = Iqdb::open_in_memory();
        assert!(handle.is_empty());
        assert_eq!(handle.len(), 0);
    }

    // Path-based opening is still a stub.
    #[test]
    fn open_returns_not_implemented() {
        let outcome = Iqdb::open("/tmp/iqdb-test");
        assert!(matches!(outcome, Err(Error::NotImplemented)));
    }

    // Flushing is still a stub, even for an in-memory handle.
    #[test]
    fn flush_returns_not_implemented() {
        let handle = Iqdb::open_in_memory();
        let outcome = handle.flush();
        assert!(matches!(outcome, Err(Error::NotImplemented)));
    }

    // Closing consumes the handle and reports success.
    #[test]
    fn close_succeeds() {
        let handle = Iqdb::open_in_memory();
        let outcome = handle.close();
        assert!(outcome.is_ok());
    }

    // A record written with `upsert` comes back unchanged from `get`.
    #[test]
    fn upsert_get_round_trip() {
        let handle = Iqdb::open_in_memory();
        let record = Record::new(RecordId::new(1), vec3(0.1, 0.2, 0.3));
        handle.upsert(record).unwrap();

        let stored = handle.get(RecordId::new(1)).unwrap().expect("present");
        assert_eq!(stored.id().get(), 1);
        assert_eq!(stored.vector().as_slice(), &[0.1, 0.2, 0.3]);
    }

    // Upserting an existing id overwrites the record instead of adding one.
    #[test]
    fn upsert_replaces_existing_record() {
        let handle = Iqdb::open_in_memory();
        let original = Record::new(RecordId::new(1), vec3(1.0, 0.0, 0.0));
        handle.upsert(original).unwrap();
        let replacement = Record::new(RecordId::new(1), vec3(0.0, 1.0, 0.0));
        handle.upsert(replacement).unwrap();

        assert_eq!(handle.len(), 1);
        let stored = handle.get(RecordId::new(1)).unwrap().expect("present");
        assert_eq!(stored.vector().as_slice(), &[0.0, 1.0, 0.0]);
    }

    // Looking up an id that was never written yields `None`, not an error.
    #[test]
    fn get_returns_none_for_missing_id() {
        let handle = Iqdb::open_in_memory();
        let lookup = handle.get(RecordId::new(99)).unwrap();
        assert!(lookup.is_none());
    }

    // The first delete removes the record; the second finds nothing to remove.
    #[test]
    fn delete_returns_true_only_when_removed() {
        let handle = Iqdb::open_in_memory();
        let record = Record::new(RecordId::new(1), vec3(1.0, 0.0, 0.0));
        handle.upsert(record).unwrap();

        let removed_first = handle.delete(RecordId::new(1)).unwrap();
        assert!(removed_first);
        let removed_again = handle.delete(RecordId::new(1)).unwrap();
        assert!(!removed_again);
    }

    // Text and float payload entries survive an upsert/get cycle.
    #[test]
    fn payload_round_trips_through_upsert_and_get() {
        let handle = Iqdb::open_in_memory();

        let mut fields = Payload::new();
        fields.insert("kind", "doc");
        fields.insert("score", 0.97_f64);
        let record = Record::with_payload(RecordId::new(7), vec3(1.0, 2.0, 3.0), fields);
        handle.upsert(record).unwrap();

        let stored = handle.get(RecordId::new(7)).unwrap().expect("present");
        let stored_fields = stored.payload().expect("payload present");

        let kind = stored_fields
            .get("kind")
            .and_then(crate::payload::PayloadValue::as_text);
        assert_eq!(kind, Some("doc"));

        // Compare the float with a tolerance rather than for exact equality.
        let score = stored_fields
            .get("score")
            .and_then(crate::payload::PayloadValue::as_float);
        let score_matches = score
            .map(|value| (value - 0.97).abs() < 1e-9)
            .unwrap_or(false);
        assert!(score_matches);
    }

    // Compile-time check: the handle can be shared and sent across threads.
    #[test]
    fn handle_is_send_and_sync() {
        fn require_send_and_sync<T>()
        where
            T: Send + Sync,
        {
        }
        require_send_and_sync::<Iqdb>();
    }
}