use std::time::Instant;
use serde::{Serialize, de::DeserializeOwned};
use crate::backend::{CpuBackend, RingComputeBackend};
use crate::config::RingDbConfig;
use crate::error::Result;
use crate::payload::{PayloadStore, PayloadStoreBuilder};
use crate::query::{DiskQuery, QueryResult, RangeQuery, RingQuery};
/// Mutable, unsealed database: vectors and their payloads are accumulated with
/// [`RingDb::add_vector`] and then frozen into a [`SealedRingDb`] by [`RingDb::build`].
pub struct RingDb<T = ()> {
    /// Settings for this database; `dims` is enforced on every insert.
    config: RingDbConfig,
    /// Compute backend that receives the accumulated dataset in `build`.
    backend: Box<dyn RingComputeBackend>,
    /// Number of vectors added so far.
    n_vectors: usize,
    /// All added vectors concatenated one after another (`n_vectors * dims` floats).
    vectors: Vec<f32>,
    /// Squared L2 norm of each added vector, one entry per vector.
    norms_sq: Vec<f32>,
    /// Collects one payload per added vector until `build` finalizes it.
    payload_builder: PayloadStoreBuilder<T>,
}
impl<T: Serialize + DeserializeOwned> RingDb<T> {
    /// Creates an empty, unsealed database backed by the CPU compute backend.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating the payload store builder.
    pub fn new(config: RingDbConfig) -> Result<Self> {
        Ok(Self {
            config,
            backend: Box::new(CpuBackend::new()),
            n_vectors: 0,
            vectors: Vec::new(),
            norms_sq: Vec::new(),
            payload_builder: PayloadStoreBuilder::new()?,
        })
    }

    /// Appends one vector together with its payload.
    ///
    /// The vector is appended to the flat vector buffer and its squared L2 norm
    /// is recorded alongside it.
    ///
    /// # Errors
    ///
    /// Returns `RingDbError::DimensionMismatch` if `vector.len()` differs from the
    /// configured dimensionality, or propagates an error from the payload builder.
    /// In either case this method records no vector, no norm, and does not bump
    /// the vector count, so the database stays consistent.
    pub fn add_vector(&mut self, vector: &[f32], payload: T) -> Result<()> {
        let dims = self.config.dims;
        if vector.len() != dims {
            return Err(crate::error::RingDbError::DimensionMismatch {
                expected: dims,
                got: vector.len(),
            });
        }
        // Do the only fallible step first. Previously the vector and norm were
        // pushed before the payload, so a rejected payload left an orphaned row
        // that `build` would upload out of step with the stored payloads.
        self.payload_builder.push(payload)?;
        let norm_sq: f32 = vector.iter().map(|x| x * x).sum();
        self.vectors.extend_from_slice(vector);
        self.norms_sq.push(norm_sq);
        self.n_vectors += 1;
        Ok(())
    }

    /// Uploads the accumulated vectors and norms to the backend, finalizes the
    /// payload store, and returns the read-only [`SealedRingDb`].
    ///
    /// # Errors
    ///
    /// Propagates errors from the backend upload or from finishing the payload store.
    pub fn build(mut self) -> Result<SealedRingDb<T>> {
        let dims = self.config.dims;
        let n_vectors = self.n_vectors;
        self.backend
            .upload_f32_dataset(dims, self.vectors, self.norms_sq)?;
        let payload_store = self.payload_builder.finish()?;
        Ok(SealedRingDb {
            config: self.config,
            backend: self.backend,
            n_vectors,
            payload_store,
        })
    }

    /// Number of vectors added so far.
    pub fn len(&self) -> usize {
        self.n_vectors
    }

    /// `true` if no vectors have been added.
    pub fn is_empty(&self) -> bool {
        self.n_vectors == 0
    }

    /// Configured vector dimensionality.
    pub fn dims(&self) -> usize {
        self.config.dims
    }

    /// Name reported by the compute backend.
    pub fn backend_name(&self) -> &str {
        self.backend.name()
    }
}
/// Read-only database produced by [`RingDb::build`]: answers ring, range and
/// disk queries through its compute backend and serves the stored payloads.
pub struct SealedRingDb<T = ()> {
    /// Settings carried over from the unsealed builder.
    config: RingDbConfig,
    /// Backend that holds the dataset uploaded at build time.
    backend: Box<dyn RingComputeBackend>,
    /// Number of vectors uploaded at build time.
    n_vectors: usize,
    /// Finalized payload storage, queried by id.
    payload_store: PayloadStore<T>,
}
impl<T: Serialize + DeserializeOwned> SealedRingDb<T> {
    /// Ring query: matches whose distance to `q.query` lies in the band
    /// `[max(q.d - q.lambda, 0), q.d + q.lambda]`.
    ///
    /// The band is handed to the backend's `ring_query_f32`; whether its bounds
    /// are inclusive is decided by the backend.
    ///
    /// # Errors
    ///
    /// Returns `RingDbError::DimensionMismatch` if the query length differs from
    /// the configured dimensionality, or propagates a backend error.
    pub fn query(&self, q: &RingQuery<'_>) -> Result<QueryResult> {
        // Distances are non-negative, so a band reaching below zero is clamped.
        let d_min = (q.d - q.lambda).max(0.0);
        let d_max = q.d + q.lambda;
        self.timed_ring_query(q.query, d_min, d_max)
    }

    /// Range query: matches whose distance to `q.query` lies in `[q.d_min, q.d_max]`.
    ///
    /// The bounds are passed through to the backend unchanged.
    ///
    /// # Errors
    ///
    /// Same as [`SealedRingDb::query`].
    pub fn query_range(&self, q: &RangeQuery<'_>) -> Result<QueryResult> {
        self.timed_ring_query(q.query, q.d_min, q.d_max)
    }

    /// Disk query: matches whose distance to `q.query` lies in `[0, q.d_max]`.
    ///
    /// # Errors
    ///
    /// Same as [`SealedRingDb::query`].
    pub fn query_disk(&self, q: &DiskQuery<'_>) -> Result<QueryResult> {
        self.timed_ring_query(q.query, 0.0, q.d_max)
    }

    /// Fetches the payload stored for `id`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the payload store.
    pub fn fetch_payload(&self, id: u32) -> Result<T> {
        self.payload_store.fetch(id)
    }

    /// Fetches the payloads stored for each id in `ids`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the payload store.
    pub fn fetch_payloads(&self, ids: &[u32]) -> Result<Vec<T>> {
        self.payload_store.fetch_many(ids)
    }

    /// Number of vectors in the sealed dataset.
    pub fn len(&self) -> usize {
        self.n_vectors
    }

    /// `true` if the sealed dataset contains no vectors.
    pub fn is_empty(&self) -> bool {
        self.n_vectors == 0
    }

    /// Configured vector dimensionality.
    pub fn dims(&self) -> usize {
        self.config.dims
    }

    /// Name reported by the compute backend.
    pub fn backend_name(&self) -> &str {
        self.backend.name()
    }

    /// Returns `DimensionMismatch` unless `query` has exactly `dims` components.
    fn check_query_dims(&self, query: &[f32]) -> Result<()> {
        let dims = self.config.dims;
        if query.len() != dims {
            return Err(crate::error::RingDbError::DimensionMismatch {
                expected: dims,
                got: query.len(),
            });
        }
        Ok(())
    }

    /// Validates `query`, runs the backend scan over `[d_min, d_max]`, and times it.
    ///
    /// Only the backend call is timed, not the validation.
    fn timed_ring_query(&self, query: &[f32], d_min: f32, d_max: f32) -> Result<QueryResult> {
        self.check_query_dims(query)?;
        let started = Instant::now();
        let ids = self
            .backend
            .ring_query_f32(self.config.dims, query, d_min, d_max)?;
        let elapsed = started.elapsed();
        Ok(QueryResult {
            ids,
            backend_used: self.backend.name(),
            elapsed,
        })
    }
}