// polarisdb_core/collection.rs
//! Persistent vector collection with crash-safe storage.
//!
//! A `Collection` combines the in-memory index with durable storage:
//! - WAL for atomic operations
//! - Append-only data file for vectors
//! - Automatic recovery on open

use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};

use parking_lot::RwLock;

use crate::distance::DistanceMetric;
use crate::error::{Error, Result};
use crate::filter::Filter;
use crate::index::brute_force::{BruteForceIndex, SearchResult};
use crate::payload::Payload;
use crate::storage::data_file::DataFile;
use crate::storage::wal::{SyncMode, Wal, WalEntry, WalEntryKind};
use crate::vector::VectorId;

/// Configuration for a collection.
///
/// Built with [`CollectionConfig::new`]; the WAL sync mode defaults to
/// `SyncMode::Batched` and can be overridden with `with_sync_mode`.
#[derive(Debug, Clone)]
pub struct CollectionConfig {
    /// Dimensionality of vectors; must match the stored collection on reopen.
    pub dimension: usize,
    /// Distance metric used for similarity search.
    pub metric: DistanceMetric,
    /// WAL sync mode.
    pub sync_mode: SyncMode,
}

34impl CollectionConfig {
35    /// Creates a new config with the given dimension and metric.
36    pub fn new(dimension: usize, metric: DistanceMetric) -> Self {
37        Self {
38            dimension,
39            metric,
40            sync_mode: SyncMode::Batched,
41        }
42    }
43
44    /// Sets the sync mode. Chainable.
45    pub fn with_sync_mode(mut self, mode: SyncMode) -> Self {
46        self.sync_mode = mode;
47        self
48    }
49}
50
/// Metadata for the collection, persisted to disk as `meta.json`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct CollectionMeta {
    // Vector dimensionality; checked against the config on open.
    dimension: usize,
    // Debug-formatted name of the distance metric.
    metric: String,
    // Number of live vectors at the time of the last save.
    vector_count: u64,
    // Next auto-generated vector ID.
    next_id: u64,
}

/// A persistent vector collection.
///
/// The collection provides durable storage with automatic crash recovery.
/// All mutations are first written to a WAL, then applied to the in-memory
/// index. On open, the WAL is replayed to restore state.
///
/// # Example
///
/// ```no_run
/// use polarisdb_core::{Collection, CollectionConfig, DistanceMetric, Payload};
///
/// // Open or create a collection
/// let config = CollectionConfig::new(384, DistanceMetric::Cosine);
/// let collection = Collection::open_or_create("./my_vectors", config).unwrap();
///
/// // Insert vectors
/// collection.insert(1, vec![0.1; 384], Payload::new()).unwrap();
///
/// // Search
/// let query = vec![0.1; 384];
/// let results = collection.search(&query, 10, None);
///
/// // Flush to ensure durability
/// collection.flush().unwrap();
/// ```
pub struct Collection {
    /// Path to the collection directory (holds meta.json, wal.log, data.pdb).
    path: PathBuf,
    /// Configuration supplied at open time.
    config: CollectionConfig,
    /// In-memory index; serves all reads and searches.
    index: RwLock<BruteForceIndex>,
    /// Write-ahead log; every mutation is appended here first.
    wal: RwLock<Wal>,
    /// Append-only data file for vector storage.
    data_file: RwLock<DataFile>,
    /// Mapping from vector ID to its record's offset in the data file.
    offsets: RwLock<HashMap<VectorId, u64>>,
    /// Next auto-generated ID (used by `insert_auto`).
    next_id: RwLock<u64>,
}

impl Collection {
    /// Opens an existing collection or creates a new one.
    ///
    /// On open, persisted metadata is validated against `config` (dimension
    /// and metric must both match), then in-memory state is recovered from
    /// the data file and the WAL.
    ///
    /// # Errors
    /// Returns an error if the directory cannot be created, metadata cannot
    /// be read or parsed, `config` conflicts with the stored metadata, or
    /// the WAL/data file cannot be opened.
    pub fn open_or_create<P: AsRef<Path>>(path: P, config: CollectionConfig) -> Result<Self> {
        let path = path.as_ref().to_path_buf();

        // Create directory if needed
        if !path.exists() {
            fs::create_dir_all(&path)
                .map_err(|e| Error::CollectionError(format!("create dir failed: {}", e)))?;
        }

        let meta_path = path.join("meta.json");
        let wal_path = path.join("wal.log");
        let data_path = path.join("data.pdb");

        // Load or create metadata
        let (meta, is_new) = if meta_path.exists() {
            let content = fs::read_to_string(&meta_path)
                .map_err(|e| Error::CollectionError(format!("read meta failed: {}", e)))?;
            let meta: CollectionMeta = serde_json::from_str(&content)
                .map_err(|e| Error::CollectionError(format!("parse meta failed: {}", e)))?;
            (meta, false)
        } else {
            let meta = CollectionMeta {
                dimension: config.dimension,
                metric: format!("{:?}", config.metric),
                vector_count: 0,
                next_id: 1,
            };
            (meta, true)
        };

        // Verify dimension matches
        if !is_new && meta.dimension != config.dimension {
            return Err(Error::CollectionError(format!(
                "dimension mismatch: collection has {}, config has {}",
                meta.dimension, config.dimension
            )));
        }

        // Verify metric matches. Searching with a different metric than the
        // one the collection was built with would silently return wrong
        // results, so fail fast instead (previously only dimension was
        // checked).
        if !is_new && meta.metric != format!("{:?}", config.metric) {
            return Err(Error::CollectionError(format!(
                "metric mismatch: collection has {}, config has {:?}",
                meta.metric, config.metric
            )));
        }

        // Open WAL and data file
        let wal = Wal::open(&wal_path, config.sync_mode)?;
        let data_file = DataFile::open(&data_path)?;

        // Create in-memory index
        let index = BruteForceIndex::new(config.metric, config.dimension);

        let collection = Self {
            path: path.clone(),
            config,
            index: RwLock::new(index),
            wal: RwLock::new(wal),
            data_file: RwLock::new(data_file),
            offsets: RwLock::new(HashMap::new()),
            // Seed from persisted metadata; recovery only moves it forward.
            next_id: RwLock::new(meta.next_id),
        };

        // Recover from data file and WAL
        collection.recover()?;

        // Save metadata if new
        if is_new {
            collection.save_meta()?;
        }

        Ok(collection)
    }

170    /// Recovers state by reading the data file and replaying the WAL.
171    fn recover(&self) -> Result<()> {
172        // First, load all active records from data file
173        let records = {
174            let df = self.data_file.read();
175            df.iter_active()?
176        };
177
178        {
179            let mut index = self.index.write();
180            let mut offsets = self.offsets.write();
181            let mut max_id = 0u64;
182
183            for record in records {
184                let _ = index.insert(record.id, record.vector.clone(), record.payload.clone());
185                offsets.insert(record.id, record.offset);
186                max_id = max_id.max(record.id);
187            }
188
189            *self.next_id.write() = max_id + 1;
190        }
191
192        // Then replay WAL (may have newer operations)
193        let wal_path = self.path.join("wal.log");
194        let entries = Wal::read_all(&wal_path)?;
195
196        for entry in entries {
197            match entry.kind {
198                WalEntryKind::Insert => {
199                    self.apply_insert_no_wal(entry.id, entry.vector, entry.payload)?;
200                }
201                WalEntryKind::Update => {
202                    self.apply_update_no_wal(entry.id, entry.vector, entry.payload)?;
203                }
204                WalEntryKind::Delete => {
205                    self.apply_delete_no_wal(entry.id)?;
206                }
207                WalEntryKind::Checkpoint => {
208                    // Checkpoint entries are just markers
209                }
210            }
211        }
212
213        Ok(())
214    }
215
216    /// Inserts a vector with the given ID.
217    pub fn insert(&self, id: VectorId, vector: Vec<f32>, payload: Payload) -> Result<()> {
218        // Write to WAL first
219        {
220            let mut wal = self.wal.write();
221            wal.append(&WalEntry::insert(id, vector.clone(), payload.clone()))?;
222        }
223
224        // Apply to index and data file
225        self.apply_insert_no_wal(id, vector, payload)
226    }
227
228    /// Inserts with auto-generated ID. Returns the ID.
229    pub fn insert_auto(&self, vector: Vec<f32>, payload: Payload) -> Result<VectorId> {
230        let id = {
231            let mut next_id = self.next_id.write();
232            let id = *next_id;
233            *next_id += 1;
234            id
235        };
236
237        self.insert(id, vector, payload)?;
238        Ok(id)
239    }
240
241    /// Updates an existing vector.
242    pub fn update(&self, id: VectorId, vector: Vec<f32>, payload: Payload) -> Result<()> {
243        // Write to WAL first
244        {
245            let mut wal = self.wal.write();
246            wal.append(&WalEntry::update(id, vector.clone(), payload.clone()))?;
247        }
248
249        self.apply_update_no_wal(id, vector, payload)
250    }
251
252    /// Deletes a vector.
253    pub fn delete(&self, id: VectorId) -> Result<bool> {
254        // Write to WAL first
255        {
256            let mut wal = self.wal.write();
257            wal.append(&WalEntry::delete(id))?;
258        }
259
260        self.apply_delete_no_wal(id)
261    }
262
263    /// Searches for similar vectors.
264    pub fn search(&self, query: &[f32], k: usize, filter: Option<Filter>) -> Vec<SearchResult> {
265        let index = self.index.read();
266        index.search(query, k, filter)
267    }
268
269    /// Gets a vector by ID.
270    pub fn get(&self, id: VectorId) -> Option<(Vec<f32>, Payload)> {
271        let index = self.index.read();
272        index
273            .get(id)
274            .map(|(v, p)| (v.as_slice().to_vec(), p.clone()))
275    }
276
277    /// Returns the number of vectors.
278    pub fn len(&self) -> usize {
279        self.index.read().len()
280    }
281
282    /// Returns true if empty.
283    pub fn is_empty(&self) -> bool {
284        self.index.read().is_empty()
285    }
286
287    /// Flushes all pending writes and performs a checkpoint.
288    pub fn flush(&self) -> Result<()> {
289        // Flush data file
290        {
291            let mut df = self.data_file.write();
292            df.flush()?;
293        }
294
295        // Checkpoint WAL
296        {
297            let mut wal = self.wal.write();
298            wal.checkpoint()?;
299        }
300
301        // Save metadata
302        self.save_meta()?;
303
304        Ok(())
305    }
306
307    /// Returns the collection path.
308    pub fn path(&self) -> &Path {
309        &self.path
310    }
311
312    // Internal: apply insert without writing to WAL
313    fn apply_insert_no_wal(&self, id: VectorId, vector: Vec<f32>, payload: Payload) -> Result<()> {
314        // Write to data file
315        let offset = {
316            let mut df = self.data_file.write();
317            df.append(id, &vector, &payload)?
318        };
319
320        // Update index
321        {
322            let mut index = self.index.write();
323            // Remove if exists (for recovery idempotence)
324            index.delete(id);
325            index.insert(id, vector, payload)?;
326        }
327
328        // Track offset
329        {
330            let mut offsets = self.offsets.write();
331            offsets.insert(id, offset);
332        }
333
334        // Update next_id
335        {
336            let mut next_id = self.next_id.write();
337            *next_id = (*next_id).max(id + 1);
338        }
339
340        Ok(())
341    }
342
343    // Internal: apply update without writing to WAL
344    fn apply_update_no_wal(&self, id: VectorId, vector: Vec<f32>, payload: Payload) -> Result<()> {
345        // Mark old record as deleted
346        {
347            let offsets = self.offsets.read();
348            if let Some(&offset) = offsets.get(&id) {
349                let df = self.data_file.read();
350                df.mark_deleted(offset)?;
351            }
352        }
353
354        // Write new record to data file
355        let offset = {
356            let mut df = self.data_file.write();
357            df.append(id, &vector, &payload)?
358        };
359
360        // Update index
361        {
362            let mut index = self.index.write();
363            index.delete(id);
364            index.insert(id, vector, payload)?;
365        }
366
367        // Track new offset
368        {
369            let mut offsets = self.offsets.write();
370            offsets.insert(id, offset);
371        }
372
373        Ok(())
374    }
375
376    // Internal: apply delete without writing to WAL
377    fn apply_delete_no_wal(&self, id: VectorId) -> Result<bool> {
378        // Mark record as deleted in data file
379        {
380            let offsets = self.offsets.read();
381            if let Some(&offset) = offsets.get(&id) {
382                let df = self.data_file.read();
383                df.mark_deleted(offset)?;
384            }
385        }
386
387        // Remove from index
388        let deleted = {
389            let mut index = self.index.write();
390            index.delete(id)
391        };
392
393        // Remove offset tracking
394        {
395            let mut offsets = self.offsets.write();
396            offsets.remove(&id);
397        }
398
399        Ok(deleted)
400    }
401
402    // Saves metadata to disk
403    fn save_meta(&self) -> Result<()> {
404        let meta = CollectionMeta {
405            dimension: self.config.dimension,
406            metric: format!("{:?}", self.config.metric),
407            vector_count: self.len() as u64,
408            next_id: *self.next_id.read(),
409        };
410
411        let content = serde_json::to_string_pretty(&meta)
412            .map_err(|e| Error::CollectionError(format!("serialize meta failed: {}", e)))?;
413
414        let meta_path = self.path.join("meta.json");
415        fs::write(&meta_path, content)
416            .map_err(|e| Error::CollectionError(format!("write meta failed: {}", e)))?;
417
418        Ok(())
419    }
420}
421
// Async API when tokio feature is enabled
#[cfg(feature = "async")]
mod async_api {
    use super::*;
    use std::sync::Arc;

    /// Async wrapper for Collection.
    ///
    /// Provides async versions of Collection methods using `spawn_blocking`
    /// for compatibility with async runtimes like Tokio. Cloning is cheap:
    /// clones share the same underlying collection through an `Arc`.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use polarisdb_core::{AsyncCollection, CollectionConfig, DistanceMetric, Payload};
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let config = CollectionConfig::new(384, DistanceMetric::Cosine);
    ///     let collection = AsyncCollection::open_or_create("./my_vectors", config).await.unwrap();
    ///
    ///     collection.insert(1, vec![0.1; 384], Payload::new()).await.unwrap();
    ///     let results = collection.search(&[0.1; 384], 10, None).await;
    /// }
    /// ```
    #[derive(Clone)]
    pub struct AsyncCollection {
        inner: Arc<Collection>,
    }

    // Converts a spawn_blocking join failure into a collection error.
    fn join_err(e: tokio::task::JoinError) -> Error {
        Error::CollectionError(format!("spawn_blocking failed: {}", e))
    }

    impl AsyncCollection {
        /// Opens or creates a collection asynchronously.
        pub async fn open_or_create<P: AsRef<std::path::Path> + Send + 'static>(
            path: P,
            config: CollectionConfig,
        ) -> Result<Self> {
            let path = path.as_ref().to_path_buf();
            let collection =
                tokio::task::spawn_blocking(move || Collection::open_or_create(path, config))
                    .await
                    .map_err(join_err)??;

            Ok(Self {
                inner: Arc::new(collection),
            })
        }

        /// Wraps an existing Collection in an async wrapper.
        pub fn from_sync(collection: Collection) -> Self {
            Self {
                inner: Arc::new(collection),
            }
        }

        /// Inserts a vector asynchronously.
        pub async fn insert(&self, id: VectorId, vector: Vec<f32>, payload: Payload) -> Result<()> {
            let inner = Arc::clone(&self.inner);
            tokio::task::spawn_blocking(move || inner.insert(id, vector, payload))
                .await
                .map_err(join_err)?
        }

        /// Inserts with auto-generated ID asynchronously.
        pub async fn insert_auto(&self, vector: Vec<f32>, payload: Payload) -> Result<VectorId> {
            let inner = Arc::clone(&self.inner);
            tokio::task::spawn_blocking(move || inner.insert_auto(vector, payload))
                .await
                .map_err(join_err)?
        }

        /// Updates a vector asynchronously.
        pub async fn update(&self, id: VectorId, vector: Vec<f32>, payload: Payload) -> Result<()> {
            let inner = Arc::clone(&self.inner);
            tokio::task::spawn_blocking(move || inner.update(id, vector, payload))
                .await
                .map_err(join_err)?
        }

        /// Deletes a vector asynchronously.
        pub async fn delete(&self, id: VectorId) -> Result<bool> {
            let inner = Arc::clone(&self.inner);
            tokio::task::spawn_blocking(move || inner.delete(id))
                .await
                .map_err(join_err)?
        }

        /// Searches for similar vectors asynchronously.
        pub async fn search(
            &self,
            query: &[f32],
            k: usize,
            filter: Option<Filter>,
        ) -> Vec<SearchResult> {
            let inner = Arc::clone(&self.inner);
            let query = query.to_vec();
            match tokio::task::spawn_blocking(move || inner.search(&query, k, filter)).await {
                Ok(results) => results,
                // Previously a panicked task was swallowed by
                // `unwrap_or_default()`, making a crashed search
                // indistinguishable from "no matches". Resurface the panic.
                Err(e) if e.is_panic() => std::panic::resume_unwind(e.into_panic()),
                // Task cancelled (e.g. runtime shutting down): nothing to return.
                Err(_) => Vec::new(),
            }
        }

        /// Gets a vector by ID asynchronously.
        pub async fn get(&self, id: VectorId) -> Option<(Vec<f32>, Payload)> {
            let inner = Arc::clone(&self.inner);
            match tokio::task::spawn_blocking(move || inner.get(id)).await {
                Ok(found) => found,
                // Same rationale as `search`: do not turn a panic into `None`.
                Err(e) if e.is_panic() => std::panic::resume_unwind(e.into_panic()),
                Err(_) => None,
            }
        }

        /// Returns the number of vectors.
        pub fn len(&self) -> usize {
            self.inner.len()
        }

        /// Returns true if empty.
        pub fn is_empty(&self) -> bool {
            self.inner.is_empty()
        }

        /// Flushes all pending writes asynchronously.
        pub async fn flush(&self) -> Result<()> {
            let inner = Arc::clone(&self.inner);
            tokio::task::spawn_blocking(move || inner.flush())
                .await
                .map_err(join_err)?
        }

        /// Returns reference to inner sync collection.
        pub fn inner(&self) -> &Collection {
            &self.inner
        }
    }
}

#[cfg(feature = "async")]
pub use async_api::AsyncCollection;

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    // Monotonic counter so tests running in parallel get distinct dirs.
    static TEST_COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Returns a fresh temp directory path (removed first if present).
    fn temp_collection_path() -> PathBuf {
        let seq = TEST_COUNTER.fetch_add(1, Ordering::SeqCst);
        let dir = std::env::temp_dir()
            .join("polarisdb_test_col")
            .join(format!("col_{}_{}", std::process::id(), seq));
        let _ = fs::remove_dir_all(&dir);
        dir
    }

    #[test]
    fn test_collection_create_and_insert() {
        let dir = temp_collection_path();
        let cfg = CollectionConfig::new(3, DistanceMetric::Euclidean);
        let coll = Collection::open_or_create(&dir, cfg).unwrap();

        coll.insert(
            1,
            vec![1.0, 2.0, 3.0],
            Payload::new().with_field("key", "val"),
        )
        .unwrap();
        assert_eq!(coll.len(), 1);

        let (stored, payload) = coll.get(1).unwrap();
        assert_eq!(stored, vec![1.0, 2.0, 3.0]);
        assert_eq!(payload.get_str("key"), Some("val"));

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_collection_persistence() {
        let dir = temp_collection_path();
        let cfg = CollectionConfig::new(3, DistanceMetric::Euclidean);

        // Create, insert, and checkpoint.
        {
            let coll = Collection::open_or_create(&dir, cfg.clone()).unwrap();
            coll.insert(1, vec![1.0, 2.0, 3.0], Payload::new()).unwrap();
            coll.insert(2, vec![4.0, 5.0, 6.0], Payload::new()).unwrap();
            coll.flush().unwrap();
        }

        // Reopen and verify both vectors survived.
        {
            let coll = Collection::open_or_create(&dir, cfg).unwrap();
            assert_eq!(coll.len(), 2);
            assert!(coll.get(1).is_some());
            assert!(coll.get(2).is_some());
        }

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_collection_delete() {
        let dir = temp_collection_path();
        let cfg = CollectionConfig::new(3, DistanceMetric::Euclidean);
        let coll = Collection::open_or_create(&dir, cfg).unwrap();

        coll.insert(1, vec![1.0, 2.0, 3.0], Payload::new()).unwrap();
        assert_eq!(coll.len(), 1);

        coll.delete(1).unwrap();
        assert_eq!(coll.len(), 0);
        assert!(coll.get(1).is_none());

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_collection_search() {
        let dir = temp_collection_path();
        let cfg = CollectionConfig::new(3, DistanceMetric::Euclidean);
        let coll = Collection::open_or_create(&dir, cfg).unwrap();

        // Three orthogonal unit vectors; searching for the first axis must
        // return vector 1.
        coll.insert(1, vec![1.0, 0.0, 0.0], Payload::new()).unwrap();
        coll.insert(2, vec![0.0, 1.0, 0.0], Payload::new()).unwrap();
        coll.insert(3, vec![0.0, 0.0, 1.0], Payload::new()).unwrap();

        let hits = coll.search(&[1.0, 0.0, 0.0], 1, None);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, 1);

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_collection_update() {
        let dir = temp_collection_path();
        let cfg = CollectionConfig::new(3, DistanceMetric::Euclidean);
        let coll = Collection::open_or_create(&dir, cfg).unwrap();

        coll.insert(1, vec![1.0, 2.0, 3.0], Payload::new().with_field("v", 1))
            .unwrap();
        coll.update(1, vec![4.0, 5.0, 6.0], Payload::new().with_field("v", 2))
            .unwrap();

        // Both vector and payload must reflect the update.
        let (stored, payload) = coll.get(1).unwrap();
        assert_eq!(stored, vec![4.0, 5.0, 6.0]);
        assert_eq!(payload.get_i64("v"), Some(2));

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_collection_recovery_after_crash() {
        let dir = temp_collection_path();
        let cfg = CollectionConfig::new(3, DistanceMetric::Euclidean);

        // Write without flushing, mimicking a crash before checkpoint.
        {
            let coll = Collection::open_or_create(&dir, cfg.clone()).unwrap();
            coll.insert(1, vec![1.0, 2.0, 3.0], Payload::new()).unwrap();
            coll.insert(2, vec![4.0, 5.0, 6.0], Payload::new()).unwrap();
            // No flush() — simulates crash before checkpoint.
        }

        // Reopen: WAL replay must restore both inserts.
        {
            let coll = Collection::open_or_create(&dir, cfg).unwrap();
            assert_eq!(coll.len(), 2);
        }

        let _ = fs::remove_dir_all(&dir);
    }
}