Skip to main content

ringdb/
engine.rs

1use std::path::Path;
2use std::time::Instant;
3
4use crate::BackendPreference;
5use crate::backend::{CpuBackend, RingComputeBackend};
6use crate::config::RingDbConfig;
7use crate::error::{Result, RingDbError};
8use crate::payload::{OwnedPayloadStore, Payload, PayloadBuilderOps, RefPayloadStore};
9use crate::persist::{read_f32_file, read_meta, write_f32_file, write_meta};
10use crate::query::Hit;
11use crate::query::{DiskIntersectionQuery, DiskQuery, QueryResult, RangeQuery, RingQuery};
12
13fn into_hits(responses: Vec<crate::backend::QueryResponse>) -> Vec<Hit> {
14    responses
15        .into_iter()
16        .map(|r| Hit {
17            id: r.id,
18            dist_sq: r.dist_sq,
19        })
20        .collect()
21}
22
23// ─── RingDb (builder) ────────────────────────────────────────────────────────
24
/// Builder for a ring-query vector database.
///
/// Insert vectors with their associated payloads via
/// [`add_vector()`](Self::add_vector), then call [`build()`](Self::build)
/// to obtain a [`SealedRingDb`], which is the type that answers queries.
///
/// `T` must implement [`Payload`], which is derived with `#[derive(Payload)]`.
/// Use `T = ()` when no payload is needed.
///
/// # Example — no payload
///
/// ```
/// use ringdb::{RingDb, RingDbConfig, RingQuery};
///
/// let mut db = RingDb::new(RingDbConfig::new(4)).unwrap();
/// db.add_vector(&[1.0, 0.0, 0.0, 0.0], ()).unwrap();
/// db.add_vector(&[0.0, 1.0, 0.0, 0.0], ()).unwrap();
///
/// let db = db.build().unwrap();
/// let result = db.query(&RingQuery { query: &[1.0f32, 0.0, 0.0, 0.0], d: 1.0, lambda: 0.1 }).unwrap();
/// println!("hits: {:?}", result.ids());
/// ```
pub struct RingDb<T: Payload = ()> {
    /// Configuration passed to [`RingDb::new`]; `config.dims` is the length
    /// every inserted vector must have.
    config: RingDbConfig,
    /// Compute backend selected from `config.backend_preference`; the staged
    /// vectors and norms are uploaded to it in [`build()`](Self::build).
    backend: Box<dyn RingComputeBackend>,
    /// Number of vectors staged so far (vectors receive sequential IDs from 0).
    n_vectors: usize,

    /// Staging buffer: f32 vectors, row-major, `n_vectors × dims`.
    vectors: Vec<f32>,

    /// Staging buffer: per-vector squared L2 norm.
    norms_sq: Vec<f32>,

    /// Concrete builder — `SerdeStoreBuilder<T>` or `PodStoreBuilder<T>`,
    /// determined at construction time by `T::make_builder()`.
    /// No heap indirection; lives directly in the struct.
    payload_builder: T::Builder,
}
63
64impl<T: Payload> RingDb<T> {
65    /// Create a new empty `RingDb`.
66    ///
67    /// The storage strategy (Serde or Pod) is determined entirely by `T`'s
68    /// `#[derive(Payload)]` — no second constructor needed.
69    ///
70    /// # Example — with Serde payload
71    ///
72    /// ```
73    /// use ringdb::{RingDb, RingDbConfig, RingQuery, Payload};
74    /// use serde::{Serialize, Deserialize};
75    ///
76    /// #[derive(Serialize, Deserialize, Payload)]
77    /// struct Meta { label: String }
78    ///
79    /// let mut db: RingDb<Meta> = RingDb::new(RingDbConfig::new(2)).unwrap();
80    /// db.add_vector(&[1.0, 0.0], Meta { label: "dog".into() }).unwrap();
81    /// db.add_vector(&[0.0, 1.0], Meta { label: "cat".into() }).unwrap();
82    ///
83    /// let db = db.build().unwrap();
84    /// let result = db.query(&RingQuery { query: &[1.0f32, 0.0], d: 1.0, lambda: 0.1 }).unwrap();
85    /// let payloads = db.fetch_payloads(&result.ids()).unwrap();
86    /// ```
87    pub fn new(config: RingDbConfig) -> Result<Self> {
88        let backend = match config.backend_preference {
89            BackendPreference::Cpu => Box::new(CpuBackend::new()),
90        };
91        Ok(Self {
92            config,
93            backend,
94            n_vectors: 0,
95            vectors: Vec::new(),
96            norms_sq: Vec::new(),
97            payload_builder: T::make_builder()?,
98        })
99    }
100
101    /// Insert a single vector and its associated payload.
102    ///
103    /// Vectors are assigned sequential IDs starting from 0.
104    /// The slice length must equal `dims`.
105    pub fn add_vector(&mut self, vector: &[f32], payload: T) -> Result<()> {
106        let dims = self.config.dims;
107        if vector.len() != dims {
108            return Err(RingDbError::DimensionMismatch {
109                expected: dims,
110                got: vector.len(),
111            });
112        }
113        let norm_sq: f32 = vector.iter().map(|x| x * x).sum();
114        self.norms_sq.push(norm_sq);
115        self.vectors.extend_from_slice(vector);
116        self.payload_builder.push(payload)?;
117        self.n_vectors += 1;
118        Ok(())
119    }
120
121    /// Seal the database.
122    ///
123    /// Transfers vectors to the compute backend and flushes the payload builder
124    /// to its mmap. If [`RingDbConfig::persist_dir`] is set, all data is also
125    /// written to disk (reload with [`RingDb::load`]).
126    pub fn build(self) -> Result<SealedRingDb<T>> {
127        let RingDb {
128            config,
129            mut backend,
130            vectors,
131            norms_sq,
132            payload_builder,
133            n_vectors,
134        } = self;
135        let dims = config.dims;
136
137        if let Some(dir) = config.persist_dir.clone() {
138            std::fs::create_dir_all(&dir)?;
139            write_meta(&dir.join("meta.bin"), dims, n_vectors)?;
140            write_f32_file(&dir.join("vectors.bin"), &vectors)?;
141            write_f32_file(&dir.join("norms_sq.bin"), &norms_sq)?;
142            let payload_store = payload_builder
143                .finish_persisted(&dir.join("payloads.bin"), &dir.join("offsets.bin"))?;
144            backend.upload_f32_dataset(dims, vectors, norms_sq)?;
145            return Ok(SealedRingDb {
146                config,
147                backend,
148                n_vectors,
149                payload_store,
150            });
151        }
152
153        backend.upload_f32_dataset(dims, vectors, norms_sq)?;
154        let payload_store = payload_builder.finish()?;
155        Ok(SealedRingDb {
156            config,
157            backend,
158            n_vectors,
159            payload_store,
160        })
161    }
162
163    /// Reconstruct a [`SealedRingDb`] from a directory previously written by
164    /// [`build()`](Self::build) with a persist dir configured.
165    ///
166    /// The correct store variant is selected automatically based on `T`'s
167    /// `Payload` impl — no separate `load_pod` method needed.
168    ///
169    /// # Example
170    ///
171    /// ```no_run
172    /// use ringdb::{RingDb, RingDbConfig, BackendPreference};
173    /// use std::path::Path;
174    ///
175    /// // --- save ---
176    /// let mut db = RingDb::<()>::new(RingDbConfig::new(4).with_persist_dir("/tmp/mydb")).unwrap();
177    /// db.add_vector(&[1.0, 0.0, 0.0, 0.0], ()).unwrap();
178    /// let _sealed = db.build().unwrap();
179    ///
180    /// // --- load ---
181    /// let loaded = RingDb::<()>::load(Path::new("/tmp/mydb"), BackendPreference::Cpu).unwrap();
182    /// ```
183    pub fn load(
184        dir: &Path,
185        backend_preference: crate::config::BackendPreference,
186    ) -> Result<SealedRingDb<T>> {
187        let (dims, n_vectors) = read_meta(&dir.join("meta.bin"))?;
188        let vectors = read_f32_file(&dir.join("vectors.bin"))?;
189        let norms_sq = read_f32_file(&dir.join("norms_sq.bin"))?;
190
191        let expected = n_vectors * dims;
192        if vectors.len() != expected {
193            return Err(RingDbError::Corrupt(format!(
194                "vectors.bin has {} f32 values, expected {}",
195                vectors.len(),
196                expected,
197            )));
198        }
199        if norms_sq.len() != n_vectors {
200            return Err(RingDbError::Corrupt(format!(
201                "norms_sq.bin has {} f32 values, expected {}",
202                norms_sq.len(),
203                n_vectors,
204            )));
205        }
206
207        let mut backend: Box<dyn RingComputeBackend> = match backend_preference {
208            crate::config::BackendPreference::Cpu => Box::new(CpuBackend::new()),
209        };
210        backend.upload_f32_dataset(dims, vectors, norms_sq)?;
211
212        let payload_store = T::load_store(dir)?;
213
214        Ok(SealedRingDb {
215            config: RingDbConfig::new(dims)
216                .with_persist_dir(dir)
217                .with_backend_preference(backend_preference),
218            backend,
219            n_vectors,
220            payload_store,
221        })
222    }
223
224    /// Number of vectors currently staged.
225    pub fn len(&self) -> usize {
226        self.n_vectors
227    }
228
229    /// Returns `true` if no vectors have been inserted.
230    pub fn is_empty(&self) -> bool {
231        self.n_vectors == 0
232    }
233
234    /// Number of dimensions per vector.
235    pub fn dims(&self) -> usize {
236        self.config.dims
237    }
238
239    /// Name of the backend currently in use.
240    pub fn backend_name(&self) -> &str {
241        self.backend.name()
242    }
243}
244
245// ─── SealedRingDb ────────────────────────────────────────────────────────────
246
/// Sealed (immutable) ring-query database.
///
/// Obtained by calling [`RingDb::build()`] or [`RingDb::load()`].
///
/// The hot side (vectors + norms) is owned by the compute backend.
/// The cold side (payloads) lives in a file-backed mmap via `T::Store`.
pub struct SealedRingDb<T: Payload = ()> {
    /// Configuration carried over from the builder, or rebuilt from `meta.bin`
    /// by [`RingDb::load()`]; `config.dims` validates query vectors.
    config: RingDbConfig,
    /// Backend that received the vectors and norms and answers every query.
    backend: Box<dyn RingComputeBackend>,
    /// Number of stored vectors.
    n_vectors: usize,
    /// Payload store produced by the builder's `finish`/`finish_persisted`
    /// or by `T::load_store`.
    payload_store: T::Store,
}
259
260impl<T: Payload> SealedRingDb<T> {
261    // ── Query methods ─────────────────────────────────────────────────────────
262
263    /// Execute a ring query and return matching vector IDs.
264    pub fn query(&self, q: &RingQuery<'_>) -> Result<QueryResult> {
265        let dims = self.config.dims;
266        if q.query.len() != dims {
267            return Err(RingDbError::DimensionMismatch {
268                expected: dims,
269                got: q.query.len(),
270            });
271        }
272        let t = Instant::now();
273        let hits = into_hits(self.backend.ring_query_f32(
274            dims,
275            q.query,
276            (q.d - q.lambda).max(0.0),
277            q.d + q.lambda,
278        )?);
279        Ok(QueryResult {
280            hits,
281            backend_used: self.backend.name(),
282            elapsed: t.elapsed(),
283        })
284    }
285
286    /// Execute a range query: all vectors with distance in `[d_min, d_max]`.
287    pub fn query_range(&self, q: &RangeQuery<'_>) -> Result<QueryResult> {
288        let dims = self.config.dims;
289        if q.query.len() != dims {
290            return Err(RingDbError::DimensionMismatch {
291                expected: dims,
292                got: q.query.len(),
293            });
294        }
295        let t = Instant::now();
296        let hits = into_hits(
297            self.backend
298                .ring_query_f32(dims, q.query, q.d_min, q.d_max)?,
299        );
300        Ok(QueryResult {
301            hits,
302            backend_used: self.backend.name(),
303            elapsed: t.elapsed(),
304        })
305    }
306
307    /// Execute a disk query: all vectors within radius `d_max`.
308    pub fn query_disk(&self, q: &DiskQuery<'_>) -> Result<QueryResult> {
309        let dims = self.config.dims;
310        if q.query.len() != dims {
311            return Err(RingDbError::DimensionMismatch {
312                expected: dims,
313                got: q.query.len(),
314            });
315        }
316        let t = Instant::now();
317        let hits = into_hits(self.backend.disk_query_f32(dims, q.query, q.d_max)?);
318        Ok(QueryResult {
319            hits,
320            backend_used: self.backend.name(),
321            elapsed: t.elapsed(),
322        })
323    }
324
325    /// Execute a disk-intersection query: all vectors inside every disk.
326    ///
327    /// Returned hit distances are measured against the first disk.
328    pub fn query_disk_intersection(&self, q: &DiskIntersectionQuery<'_>) -> Result<QueryResult> {
329        let dims = self.config.dims;
330        if q.disks.is_empty() {
331            return Err(RingDbError::InvalidQuery(
332                "disk intersection requires at least one disk".to_string(),
333            ));
334        }
335        for disk in q.disks {
336            if disk.query.len() != dims {
337                return Err(RingDbError::DimensionMismatch {
338                    expected: dims,
339                    got: disk.query.len(),
340                });
341            }
342        }
343
344        let disks: Vec<(&[f32], f32)> = q
345            .disks
346            .iter()
347            .map(|disk| (disk.query, disk.d_max))
348            .collect();
349        let t = Instant::now();
350        let hits = into_hits(self.backend.disk_intersection_query_f32(dims, &disks)?);
351        Ok(QueryResult {
352            hits,
353            backend_used: self.backend.name(),
354            elapsed: t.elapsed(),
355        })
356    }
357
358    // ── Serde payload fetch ───────────────────────────────────────────────────
359
360    /// Fetch and deserialize the payload for a single vector ID.
361    pub fn fetch_payload(&self, id: u32) -> Result<T> {
362        self.payload_store.fetch_owned(id)
363    }
364
365    /// Fetch and deserialize payloads for a slice of IDs, in order.
366    pub fn fetch_payloads(&self, ids: &[u32]) -> Result<Vec<T>> {
367        self.payload_store.fetch_many_owned(ids)
368    }
369
370    // ── Accessors ─────────────────────────────────────────────────────────────
371
372    /// Number of vectors stored.
373    pub fn len(&self) -> usize {
374        self.n_vectors
375    }
376
377    /// Returns `true` if the database contains no vectors.
378    pub fn is_empty(&self) -> bool {
379        self.n_vectors == 0
380    }
381
382    /// Number of dimensions per vector.
383    pub fn dims(&self) -> usize {
384        self.config.dims
385    }
386
387    /// Name of the backend currently in use.
388    pub fn backend_name(&self) -> &str {
389        self.backend.name()
390    }
391}
392
393// ── Pod fetch — only when T::Store: RefPayloadStore<T> ───────────────────────
394//
395// This impl block is only available for types whose `#[derive(Payload)]`
396// chose `storage = "pod"`. For Serde types, `T::Store = SerdeStore<T>` which
397// does NOT implement `RefPayloadStore<T>`, so these methods simply don't exist.
398// The compiler enforces this statically — no runtime error possible.
399
400impl<T: Payload> SealedRingDb<T>
401where
402    T::Store: RefPayloadStore<T>,
403{
404    /// Fetch a zero-copy reference to the payload for a single vector ID.
405    ///
406    /// Returns a `&T` pointing directly into the mmap — O(1), no allocation,
407    /// no deserialization. Only available for `#[payload(storage = "pod")]` types.
408    pub fn fetch_pod(&self, id: u32) -> &T {
409        self.payload_store.fetch_ref(id)
410    }
411
412    /// Fetch zero-copy references to payloads for a slice of IDs, in order.
413    pub fn fetch_pods<'a>(&'a self, ids: &[u32]) -> Vec<&'a T> {
414        self.payload_store.fetch_many_ref(ids)
415    }
416}