use std::path::Path;
use std::sync::Arc;
use tokio::task::JoinError;
use crate::config::IqdbConfig;
use crate::error::Result;
use crate::{CacheStats, DistanceMetric, Filter, Hit, Iqdb, Metadata, Vector, VectorId};
/// Asynchronous handle to an [`Iqdb`] database, intended to be used inside a Tokio runtime.
///
/// Potentially blocking operations (opening from a path, reads, writes, searches,
/// optimizing, flushing, closing) are executed on Tokio's blocking thread pool via
/// [`tokio::task::spawn_blocking`] so they do not stall the async executor. Those methods
/// must therefore be awaited from within a Tokio runtime.
///
/// The database is held in an [`Arc`]: cloning an `AsyncIqdb` is cheap, and every clone
/// refers to the same underlying [`Iqdb`].
#[derive(Debug, Clone)]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
pub struct AsyncIqdb {
    /// Shared synchronous database; an `Arc` clone is moved into each blocking task.
    inner: Arc<Iqdb>,
}

impl AsyncIqdb {
    /// Opens an in-memory database via [`Iqdb::open_in_memory`].
    ///
    /// Unlike the path-based constructors, this runs directly on the calling task rather
    /// than on the blocking pool.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Iqdb::open_in_memory`] (for example, a `dim` of `0` is
    /// rejected).
    pub async fn open_in_memory(dim: usize, metric: DistanceMetric) -> Result<Self> {
        Ok(Self {
            inner: Arc::new(Iqdb::open_in_memory(dim, metric)?),
        })
    }

    /// Opens an in-memory database from a full [`IqdbConfig`] via
    /// [`Iqdb::open_in_memory_with`], running directly on the calling task.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Iqdb::open_in_memory_with`].
    pub async fn open_in_memory_with(config: IqdbConfig) -> Result<Self> {
        Ok(Self {
            inner: Arc::new(Iqdb::open_in_memory_with(config)?),
        })
    }

    /// Opens the database at `path` via [`Iqdb::open`] on the blocking pool.
    ///
    /// The path is copied into an owned `PathBuf` so the blocking task can own it.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Iqdb::open`].
    ///
    /// # Panics
    ///
    /// Re-raises any panic that occurs inside [`Iqdb::open`].
    pub async fn open<P: AsRef<Path>>(path: P, dim: usize, metric: DistanceMetric) -> Result<Self> {
        let path = path.as_ref().to_path_buf();
        let inner =
            unwrap_join(tokio::task::spawn_blocking(move || Iqdb::open(path, dim, metric)).await)?;
        Ok(Self {
            inner: Arc::new(inner),
        })
    }

    /// Opens the database at `path` with a full [`IqdbConfig`] via [`Iqdb::open_with`] on
    /// the blocking pool.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Iqdb::open_with`].
    ///
    /// # Panics
    ///
    /// Re-raises any panic that occurs inside [`Iqdb::open_with`].
    pub async fn open_with<P: AsRef<Path>>(path: P, config: IqdbConfig) -> Result<Self> {
        let path = path.as_ref().to_path_buf();
        let inner =
            unwrap_join(tokio::task::spawn_blocking(move || Iqdb::open_with(path, config)).await)?;
        Ok(Self {
            inner: Arc::new(inner),
        })
    }

    /// Runs `op` against the shared database on Tokio's blocking thread pool and returns
    /// its result.
    ///
    /// A panic inside `op` is re-raised on the awaiting task. Every offloaded method below
    /// goes through this helper, so that behavior is uniform.
    async fn run_blocking<T, F>(&self, op: F) -> T
    where
        F: FnOnce(&Iqdb) -> T + Send + 'static,
        T: Send + 'static,
    {
        // Clone the Arc so the 'static blocking task owns its own reference.
        let db = Arc::clone(&self.inner);
        unwrap_join(tokio::task::spawn_blocking(move || op(&db)).await)
    }

    /// Inserts or replaces the entry `id` via [`Iqdb::upsert`] on the blocking pool.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Iqdb::upsert`].
    pub async fn upsert(
        &self,
        id: VectorId,
        vector: Vector,
        metadata: Option<Metadata>,
    ) -> Result<()> {
        self.run_blocking(move |db| db.upsert(id, vector, metadata))
            .await
    }

    /// Fetches the vector and optional metadata stored under `id` via [`Iqdb::get`] on the
    /// blocking pool.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Iqdb::get`].
    pub async fn get(&self, id: VectorId) -> Result<Option<(Vector, Option<Metadata>)>> {
        self.run_blocking(move |db| db.get(&id)).await
    }

    /// Deletes `id` via [`Iqdb::delete`] on the blocking pool.
    ///
    /// Returns that call's `bool` result, which is `true` after removing an existing entry.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Iqdb::delete`].
    pub async fn delete(&self, id: VectorId) -> Result<bool> {
        self.run_blocking(move |db| db.delete(&id)).await
    }

    /// Searches for the hits closest to `query`, requesting `k` results from
    /// [`Iqdb::search`] on the blocking pool.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Iqdb::search`].
    pub async fn search(&self, query: Vector, k: usize) -> Result<Vec<Hit>> {
        self.run_blocking(move |db| db.search(&query, k)).await
    }

    /// Like [`AsyncIqdb::search`], but also passes `filter` to [`Iqdb::search_with`].
    ///
    /// # Errors
    ///
    /// Returns any error from [`Iqdb::search_with`].
    pub async fn search_with(&self, query: Vector, k: usize, filter: Filter) -> Result<Vec<Hit>> {
        self.run_blocking(move |db| db.search_with(&query, k, filter))
            .await
    }

    /// Runs [`Iqdb::search_batch`] for all `queries` in a single blocking task.
    ///
    /// The outer `Vec` presumably has one entry per query, in order; this follows
    /// [`Iqdb::search_batch`].
    ///
    /// # Errors
    ///
    /// Returns any error from [`Iqdb::search_batch`].
    pub async fn search_batch(&self, queries: Vec<Vector>, k: usize) -> Result<Vec<Vec<Hit>>> {
        self.run_blocking(move |db| db.search_batch(&queries, k))
            .await
    }

    /// Like [`AsyncIqdb::search_batch`], but also passes `filter` to
    /// [`Iqdb::search_batch_with`].
    ///
    /// # Errors
    ///
    /// Returns any error from [`Iqdb::search_batch_with`].
    pub async fn search_batch_with(
        &self,
        queries: Vec<Vector>,
        k: usize,
        filter: Filter,
    ) -> Result<Vec<Vec<Hit>>> {
        self.run_blocking(move |db| db.search_batch_with(&queries, k, filter))
            .await
    }

    /// Runs [`Iqdb::optimize`] on the blocking pool.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Iqdb::optimize`].
    pub async fn optimize(&self) -> Result<()> {
        self.run_blocking(|db| db.optimize()).await
    }

    /// Runs [`Iqdb::flush`] on the blocking pool.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Iqdb::flush`].
    pub async fn flush(&self) -> Result<()> {
        self.run_blocking(|db| db.flush()).await
    }

    /// Consumes this handle and shuts down the database on the blocking pool.
    ///
    /// If this is the last handle, the database is taken out of the `Arc` and
    /// [`Iqdb::close`] is called. If other clones are still alive, the database cannot be
    /// closed, so this only calls [`Iqdb::flush`] and leaves the database open for those
    /// clones.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Iqdb::close`] or [`Iqdb::flush`] respectively.
    pub async fn close(self) -> Result<()> {
        let db = self.inner;
        unwrap_join(
            tokio::task::spawn_blocking(move || match Arc::try_unwrap(db) {
                Ok(handle) => handle.close(),
                // Other AsyncIqdb clones still share the database: persist, don't close.
                Err(shared) => shared.flush(),
            })
            .await,
        )
    }

    /// Returns [`Iqdb::dim`] of the shared database (called synchronously, not offloaded).
    #[must_use]
    pub fn dim(&self) -> usize {
        self.inner.dim()
    }

    /// Returns [`Iqdb::metric`] of the shared database (called synchronously, not offloaded).
    #[must_use]
    pub fn metric(&self) -> DistanceMetric {
        self.inner.metric()
    }

    /// Returns [`Iqdb::len`] of the shared database (called synchronously, not offloaded).
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns [`Iqdb::is_empty`] of the shared database (called synchronously, not
    /// offloaded).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns [`Iqdb::cache_stats`] of the shared database (called synchronously, not
    /// offloaded).
    #[must_use]
    pub fn cache_stats(&self) -> Option<CacheStats> {
        self.inner.cache_stats()
    }
}
/// Unwraps the result of awaiting a `spawn_blocking` join handle.
///
/// On success the task's value is returned. If the blocking closure panicked, the original
/// panic payload is re-raised on the awaiting task with [`std::panic::resume_unwind`]. The
/// caller therefore observes the same panic it would have seen had the closure run inline.
///
/// # Panics
///
/// Re-raises the task's own panic. If the task was cancelled instead of panicking (for
/// example, the runtime shut down before the blocking task started), there is no payload to
/// resume. In that case this panics with a message describing the `JoinError`, rather than
/// the opaque panic `JoinError::into_panic` would raise.
fn unwrap_join<T>(res: std::result::Result<T, JoinError>) -> T {
    match res {
        Ok(value) => value,
        Err(join) => match join.try_into_panic() {
            Ok(payload) => std::panic::resume_unwind(payload),
            Err(join) => panic!("blocking iqdb task did not complete: {join}"),
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a two-component test vector, panicking if `Vector::new` rejects it.
    fn vec2(a: f32, b: f32) -> Vector {
        Vector::new(vec![a, b]).unwrap()
    }

    /// Shorthand for constructing a numeric [`VectorId`].
    fn vid(n: u64) -> VectorId {
        VectorId::from(n)
    }

    /// Exercises upsert, get, search, delete and close through the async facade.
    #[tokio::test]
    async fn async_crud_round_trip() {
        let store = AsyncIqdb::open_in_memory(2, DistanceMetric::Euclidean)
            .await
            .unwrap();
        assert!(store.is_empty());

        for (n, point) in [(1u64, vec2(0.0, 0.0)), (2u64, vec2(3.0, 4.0))] {
            store.upsert(vid(n), point, None).await.unwrap();
        }
        assert_eq!(store.len(), 2);

        let (stored, _meta) = store.get(vid(2)).await.unwrap().unwrap();
        assert_eq!(stored.as_slice(), &[3.0, 4.0]);

        let nearest = store.search(vec2(0.1, 0.1), 1).await.unwrap();
        assert_eq!(nearest[0].id, vid(1));

        assert!(store.delete(vid(1)).await.unwrap());
        assert_eq!(store.len(), 1);

        store.close().await.unwrap();
    }

    /// The handle must be shareable across tasks, and clones must see each other's writes.
    #[tokio::test]
    async fn async_handle_is_send_sync_and_cloneable() {
        fn requires_send_sync<T: Send + Sync>() {}
        requires_send_sync::<AsyncIqdb>();

        let primary = AsyncIqdb::open_in_memory(2, DistanceMetric::Cosine)
            .await
            .unwrap();
        let twin = primary.clone();

        primary
            .upsert(vid(1), vec2(1.0, 0.0), None)
            .await
            .unwrap();

        // Both handles share one Arc'd database, so the write is visible through the clone.
        assert_eq!(twin.len(), 1);
    }

    /// A zero-dimensional database is refused at open time.
    #[tokio::test]
    async fn async_rejects_zero_dim() {
        let outcome = AsyncIqdb::open_in_memory(0, DistanceMetric::Cosine).await;
        assert!(outcome.is_err());
    }
}