use std::path::Path;
use std::time::Instant;
use crate::BackendPreference;
use crate::backend::{CpuBackend, RingComputeBackend};
use crate::config::RingDbConfig;
use crate::error::{Result, RingDbError};
use crate::payload::{OwnedPayloadStore, Payload, PayloadBuilderOps, RefPayloadStore};
use crate::persist::{read_f32_file, read_meta, write_f32_file, write_meta};
use crate::query::Hit;
use crate::query::{DiskIntersectionQuery, DiskQuery, QueryResult, RangeQuery, RingQuery};
/// Converts raw backend responses into public [`Hit`]s.
///
/// Each response contributes its `id` and `dist_sq`; the output keeps the
/// same order as `responses`.
fn into_hits(responses: Vec<crate::backend::QueryResponse>) -> Vec<Hit> {
    responses
        .into_iter()
        .map(|crate::backend::QueryResponse { id, dist_sq, .. }| Hit { id, dist_sq })
        .collect()
}
/// In-memory builder for a ring database.
///
/// Vectors (each with one payload) are appended with [`RingDb::add_vector`];
/// the collection is then turned into a queryable [`SealedRingDb`] with
/// [`RingDb::build`].
pub struct RingDb<T: Payload = ()> {
    /// Configuration passed to [`RingDb::new`]; `dims` is the required length
    /// of every added vector.
    config: RingDbConfig,
    /// Compute backend chosen from `config.backend_preference`; it receives the
    /// vectors and norms when the database is built.
    backend: Box<dyn RingComputeBackend>,
    /// Number of vectors added so far.
    n_vectors: usize,
    /// All added vectors, concatenated in insertion order (`dims` values each).
    vectors: Vec<f32>,
    /// Squared L2 norm (sum of squared components) of each vector, in
    /// insertion order.
    norms_sq: Vec<f32>,
    /// Accumulates one payload per added vector until `build` finishes it into
    /// a payload store.
    payload_builder: T::Builder,
}
impl<T: Payload> RingDb<T> {
    /// Creates an empty database for the given configuration.
    ///
    /// The compute backend is selected from `config.backend_preference`.
    ///
    /// # Errors
    ///
    /// Propagates any error from creating the payload builder
    /// (`T::make_builder`).
    pub fn new(config: RingDbConfig) -> Result<Self> {
        let backend: Box<dyn RingComputeBackend> = match config.backend_preference {
            BackendPreference::Cpu => Box::new(CpuBackend::new()),
        };
        Ok(Self {
            config,
            backend,
            n_vectors: 0,
            vectors: Vec::new(),
            norms_sq: Vec::new(),
            payload_builder: T::make_builder()?,
        })
    }

    /// Appends one vector together with its payload.
    ///
    /// The vector's squared L2 norm is computed here and stored alongside it.
    ///
    /// # Errors
    ///
    /// * [`RingDbError::DimensionMismatch`] if `vector.len()` differs from the
    ///   configured `dims`.
    /// * Any error from pushing the payload into the payload builder. In that
    ///   case the vector buffers and the vector count are left untouched.
    pub fn add_vector(&mut self, vector: &[f32], payload: T) -> Result<()> {
        let dims = self.config.dims;
        if vector.len() != dims {
            return Err(RingDbError::DimensionMismatch {
                expected: dims,
                got: vector.len(),
            });
        }
        // Push the payload first: it is the only fallible step, so a failure
        // here cannot leave `vectors`/`norms_sq` one entry ahead of
        // `n_vectors` (which would later be persisted as a corrupt dataset).
        // NOTE(review): assumes the builder does not keep a partial entry when
        // `push` returns an error.
        self.payload_builder.push(payload)?;
        let norm_sq: f32 = vector.iter().map(|x| x * x).sum();
        self.norms_sq.push(norm_sq);
        self.vectors.extend_from_slice(vector);
        self.n_vectors += 1;
        Ok(())
    }

    /// Seals the database and makes it queryable.
    ///
    /// If `config.persist_dir` is set, the directory is created and
    /// `meta.bin`, `vectors.bin`, `norms_sq.bin`, `payloads.bin` and
    /// `offsets.bin` are written into it; otherwise the payload store is
    /// finished in memory. The vectors and norms are then uploaded to the
    /// backend.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from persisting, errors from finishing the
    /// payload store, and errors from the backend upload.
    pub fn build(self) -> Result<SealedRingDb<T>> {
        let RingDb {
            config,
            mut backend,
            vectors,
            norms_sq,
            payload_builder,
            n_vectors,
        } = self;
        let dims = config.dims;
        let payload_store = match config.persist_dir.as_deref() {
            Some(dir) => {
                std::fs::create_dir_all(dir)?;
                write_meta(&dir.join("meta.bin"), dims, n_vectors)?;
                write_f32_file(&dir.join("vectors.bin"), &vectors)?;
                write_f32_file(&dir.join("norms_sq.bin"), &norms_sq)?;
                payload_builder
                    .finish_persisted(&dir.join("payloads.bin"), &dir.join("offsets.bin"))?
            }
            None => payload_builder.finish()?,
        };
        // The vectors are moved into the backend; the sealed handle keeps only
        // the count.
        backend.upload_f32_dataset(dims, vectors, norms_sq)?;
        Ok(SealedRingDb {
            config,
            backend,
            n_vectors,
            payload_store,
        })
    }

    /// Reopens a database previously persisted by [`RingDb::build`] into `dir`.
    ///
    /// Reads `meta.bin`, `vectors.bin` and `norms_sq.bin`, validates their
    /// sizes against each other, uploads the data to the backend selected by
    /// `backend_preference`, and loads the payload store from `dir`.
    ///
    /// # Errors
    ///
    /// * I/O errors while reading the files.
    /// * [`RingDbError::Corrupt`] if the element counts in `vectors.bin` or
    ///   `norms_sq.bin` disagree with `meta.bin`, or if the metadata describes
    ///   a dataset too large to address.
    /// * Errors from the backend upload or from loading the payload store.
    pub fn load(
        dir: &Path,
        backend_preference: crate::config::BackendPreference,
    ) -> Result<SealedRingDb<T>> {
        let (dims, n_vectors) = read_meta(&dir.join("meta.bin"))?;
        let vectors = read_f32_file(&dir.join("vectors.bin"))?;
        let norms_sq = read_f32_file(&dir.join("norms_sq.bin"))?;
        // The metadata comes from disk and may be corrupt, so guard the
        // multiplication instead of letting it overflow (panic in debug builds,
        // silent wrap-around in release builds).
        let expected = n_vectors.checked_mul(dims).ok_or_else(|| {
            RingDbError::Corrupt(format!(
                "meta.bin declares {} vectors of {} dims, which overflows usize",
                n_vectors, dims,
            ))
        })?;
        if vectors.len() != expected {
            return Err(RingDbError::Corrupt(format!(
                "vectors.bin has {} f32 values, expected {}",
                vectors.len(),
                expected,
            )));
        }
        if norms_sq.len() != n_vectors {
            return Err(RingDbError::Corrupt(format!(
                "norms_sq.bin has {} f32 values, expected {}",
                norms_sq.len(),
                n_vectors,
            )));
        }
        let mut backend: Box<dyn RingComputeBackend> = match backend_preference {
            crate::config::BackendPreference::Cpu => Box::new(CpuBackend::new()),
        };
        backend.upload_f32_dataset(dims, vectors, norms_sq)?;
        let payload_store = T::load_store(dir)?;
        Ok(SealedRingDb {
            config: RingDbConfig::new(dims)
                .with_persist_dir(dir)
                .with_backend_preference(backend_preference),
            backend,
            n_vectors,
            payload_store,
        })
    }

    /// Number of vectors added so far.
    pub fn len(&self) -> usize {
        self.n_vectors
    }

    /// `true` if no vector has been added yet.
    pub fn is_empty(&self) -> bool {
        self.n_vectors == 0
    }

    /// Required length of every vector, taken from the configuration.
    pub fn dims(&self) -> usize {
        self.config.dims
    }

    /// Name reported by the selected compute backend.
    pub fn backend_name(&self) -> &str {
        self.backend.name()
    }
}
/// Queryable ring database produced by [`RingDb::build`] or [`RingDb::load`].
///
/// The vectors and their norms are handed to the backend during build/load and
/// are not kept here; this struct holds the configuration, the vector count and
/// the payload store.
pub struct SealedRingDb<T: Payload = ()> {
    /// Configuration; `dims` is used to validate query vector lengths.
    config: RingDbConfig,
    /// Backend holding the uploaded dataset; it answers every query.
    backend: Box<dyn RingComputeBackend>,
    /// Number of vectors in the dataset.
    n_vectors: usize,
    /// Payload store, looked up by `u32` vector id.
    payload_store: T::Store,
}
impl<T: Payload> SealedRingDb<T> {
pub fn query(&self, q: &RingQuery<'_>) -> Result<QueryResult> {
let dims = self.config.dims;
if q.query.len() != dims {
return Err(RingDbError::DimensionMismatch {
expected: dims,
got: q.query.len(),
});
}
let t = Instant::now();
let hits = into_hits(self.backend.ring_query_f32(
dims,
q.query,
(q.d - q.lambda).max(0.0),
q.d + q.lambda,
)?);
Ok(QueryResult {
hits,
backend_used: self.backend.name(),
elapsed: t.elapsed(),
})
}
pub fn query_range(&self, q: &RangeQuery<'_>) -> Result<QueryResult> {
let dims = self.config.dims;
if q.query.len() != dims {
return Err(RingDbError::DimensionMismatch {
expected: dims,
got: q.query.len(),
});
}
let t = Instant::now();
let hits = into_hits(
self.backend
.ring_query_f32(dims, q.query, q.d_min, q.d_max)?,
);
Ok(QueryResult {
hits,
backend_used: self.backend.name(),
elapsed: t.elapsed(),
})
}
pub fn query_disk(&self, q: &DiskQuery<'_>) -> Result<QueryResult> {
let dims = self.config.dims;
if q.query.len() != dims {
return Err(RingDbError::DimensionMismatch {
expected: dims,
got: q.query.len(),
});
}
let t = Instant::now();
let hits = into_hits(self.backend.disk_query_f32(dims, q.query, q.d_max)?);
Ok(QueryResult {
hits,
backend_used: self.backend.name(),
elapsed: t.elapsed(),
})
}
pub fn query_disk_intersection(&self, q: &DiskIntersectionQuery<'_>) -> Result<QueryResult> {
let dims = self.config.dims;
if q.disks.is_empty() {
return Err(RingDbError::InvalidQuery(
"disk intersection requires at least one disk".to_string(),
));
}
for disk in q.disks {
if disk.query.len() != dims {
return Err(RingDbError::DimensionMismatch {
expected: dims,
got: disk.query.len(),
});
}
}
let disks: Vec<(&[f32], f32)> = q
.disks
.iter()
.map(|disk| (disk.query, disk.d_max))
.collect();
let t = Instant::now();
let hits = into_hits(self.backend.disk_intersection_query_f32(dims, &disks)?);
Ok(QueryResult {
hits,
backend_used: self.backend.name(),
elapsed: t.elapsed(),
})
}
pub fn fetch_payload(&self, id: u32) -> Result<T> {
self.payload_store.fetch_owned(id)
}
pub fn fetch_payloads(&self, ids: &[u32]) -> Result<Vec<T>> {
self.payload_store.fetch_many_owned(ids)
}
pub fn len(&self) -> usize {
self.n_vectors
}
pub fn is_empty(&self) -> bool {
self.n_vectors == 0
}
pub fn dims(&self) -> usize {
self.config.dims
}
pub fn backend_name(&self) -> &str {
self.backend.name()
}
}
/// Zero-copy payload access, available only when the payload store can hand
/// out references (`T::Store: RefPayloadStore<T>`).
impl<T: Payload> SealedRingDb<T>
where
    T::Store: RefPayloadStore<T>,
{
    /// Borrows the payload of vector `id` straight from the store.
    ///
    /// NOTE(review): behaviour for an unknown `id` is defined by the store's
    /// `fetch_ref` (it returns `&T`, not a `Result`) — confirm whether it panics.
    pub fn fetch_pod(&self, id: u32) -> &T {
        let store = &self.payload_store;
        store.fetch_ref(id)
    }

    /// Borrows the payloads of every id in `ids`, in the order the store's
    /// `fetch_many_ref` returns them.
    pub fn fetch_pods<'a>(&'a self, ids: &[u32]) -> Vec<&'a T> {
        let store = &self.payload_store;
        store.fetch_many_ref(ids)
    }
}