1use std::collections::HashMap;
9use std::fs;
10use std::path::{Path, PathBuf};
11
12use parking_lot::RwLock;
13
14use crate::distance::DistanceMetric;
15use crate::error::{Error, Result};
16use crate::filter::Filter;
17use crate::index::brute_force::{BruteForceIndex, SearchResult};
18use crate::payload::Payload;
19use crate::storage::data_file::DataFile;
20use crate::storage::wal::{SyncMode, Wal, WalEntry, WalEntryKind};
21use crate::vector::VectorId;
22
23#[derive(Debug, Clone)]
25pub struct CollectionConfig {
26 pub dimension: usize,
28 pub metric: DistanceMetric,
30 pub sync_mode: SyncMode,
32}
33
34impl CollectionConfig {
35 pub fn new(dimension: usize, metric: DistanceMetric) -> Self {
37 Self {
38 dimension,
39 metric,
40 sync_mode: SyncMode::Batched,
41 }
42 }
43
44 pub fn with_sync_mode(mut self, mode: SyncMode) -> Self {
46 self.sync_mode = mode;
47 self
48 }
49}
50
51#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
53struct CollectionMeta {
54 dimension: usize,
55 metric: String,
56 vector_count: u64,
57 next_id: u64,
58}
59
60pub struct Collection {
86 path: PathBuf,
88 config: CollectionConfig,
90 index: RwLock<BruteForceIndex>,
92 wal: RwLock<Wal>,
94 data_file: RwLock<DataFile>,
96 offsets: RwLock<HashMap<VectorId, u64>>,
98 next_id: RwLock<u64>,
100}
101
102impl Collection {
103 pub fn open_or_create<P: AsRef<Path>>(path: P, config: CollectionConfig) -> Result<Self> {
105 let path = path.as_ref().to_path_buf();
106
107 if !path.exists() {
109 fs::create_dir_all(&path)
110 .map_err(|e| Error::CollectionError(format!("create dir failed: {}", e)))?;
111 }
112
113 let meta_path = path.join("meta.json");
114 let wal_path = path.join("wal.log");
115 let data_path = path.join("data.pdb");
116
117 let (meta, is_new) = if meta_path.exists() {
119 let content = fs::read_to_string(&meta_path)
120 .map_err(|e| Error::CollectionError(format!("read meta failed: {}", e)))?;
121 let meta: CollectionMeta = serde_json::from_str(&content)
122 .map_err(|e| Error::CollectionError(format!("parse meta failed: {}", e)))?;
123 (meta, false)
124 } else {
125 let meta = CollectionMeta {
126 dimension: config.dimension,
127 metric: format!("{:?}", config.metric),
128 vector_count: 0,
129 next_id: 1,
130 };
131 (meta, true)
132 };
133
134 if !is_new && meta.dimension != config.dimension {
136 return Err(Error::CollectionError(format!(
137 "dimension mismatch: collection has {}, config has {}",
138 meta.dimension, config.dimension
139 )));
140 }
141
142 let wal = Wal::open(&wal_path, config.sync_mode)?;
144 let data_file = DataFile::open(&data_path)?;
145
146 let index = BruteForceIndex::new(config.metric, config.dimension);
148
149 let collection = Self {
150 path: path.clone(),
151 config,
152 index: RwLock::new(index),
153 wal: RwLock::new(wal),
154 data_file: RwLock::new(data_file),
155 offsets: RwLock::new(HashMap::new()),
156 next_id: RwLock::new(meta.next_id),
157 };
158
159 collection.recover()?;
161
162 if is_new {
164 collection.save_meta()?;
165 }
166
167 Ok(collection)
168 }
169
170 fn recover(&self) -> Result<()> {
172 let records = {
174 let df = self.data_file.read();
175 df.iter_active()?
176 };
177
178 {
179 let mut index = self.index.write();
180 let mut offsets = self.offsets.write();
181 let mut max_id = 0u64;
182
183 for record in records {
184 let _ = index.insert(record.id, record.vector.clone(), record.payload.clone());
185 offsets.insert(record.id, record.offset);
186 max_id = max_id.max(record.id);
187 }
188
189 *self.next_id.write() = max_id + 1;
190 }
191
192 let wal_path = self.path.join("wal.log");
194 let entries = Wal::read_all(&wal_path)?;
195
196 for entry in entries {
197 match entry.kind {
198 WalEntryKind::Insert => {
199 self.apply_insert_no_wal(entry.id, entry.vector, entry.payload)?;
200 }
201 WalEntryKind::Update => {
202 self.apply_update_no_wal(entry.id, entry.vector, entry.payload)?;
203 }
204 WalEntryKind::Delete => {
205 self.apply_delete_no_wal(entry.id)?;
206 }
207 WalEntryKind::Checkpoint => {
208 }
210 }
211 }
212
213 Ok(())
214 }
215
216 pub fn insert(&self, id: VectorId, vector: Vec<f32>, payload: Payload) -> Result<()> {
218 {
220 let mut wal = self.wal.write();
221 wal.append(&WalEntry::insert(id, vector.clone(), payload.clone()))?;
222 }
223
224 self.apply_insert_no_wal(id, vector, payload)
226 }
227
228 pub fn insert_auto(&self, vector: Vec<f32>, payload: Payload) -> Result<VectorId> {
230 let id = {
231 let mut next_id = self.next_id.write();
232 let id = *next_id;
233 *next_id += 1;
234 id
235 };
236
237 self.insert(id, vector, payload)?;
238 Ok(id)
239 }
240
241 pub fn update(&self, id: VectorId, vector: Vec<f32>, payload: Payload) -> Result<()> {
243 {
245 let mut wal = self.wal.write();
246 wal.append(&WalEntry::update(id, vector.clone(), payload.clone()))?;
247 }
248
249 self.apply_update_no_wal(id, vector, payload)
250 }
251
252 pub fn delete(&self, id: VectorId) -> Result<bool> {
254 {
256 let mut wal = self.wal.write();
257 wal.append(&WalEntry::delete(id))?;
258 }
259
260 self.apply_delete_no_wal(id)
261 }
262
263 pub fn search(&self, query: &[f32], k: usize, filter: Option<Filter>) -> Vec<SearchResult> {
265 let index = self.index.read();
266 index.search(query, k, filter)
267 }
268
269 pub fn get(&self, id: VectorId) -> Option<(Vec<f32>, Payload)> {
271 let index = self.index.read();
272 index
273 .get(id)
274 .map(|(v, p)| (v.as_slice().to_vec(), p.clone()))
275 }
276
277 pub fn len(&self) -> usize {
279 self.index.read().len()
280 }
281
282 pub fn is_empty(&self) -> bool {
284 self.index.read().is_empty()
285 }
286
287 pub fn flush(&self) -> Result<()> {
289 {
291 let mut df = self.data_file.write();
292 df.flush()?;
293 }
294
295 {
297 let mut wal = self.wal.write();
298 wal.checkpoint()?;
299 }
300
301 self.save_meta()?;
303
304 Ok(())
305 }
306
307 pub fn path(&self) -> &Path {
309 &self.path
310 }
311
312 fn apply_insert_no_wal(&self, id: VectorId, vector: Vec<f32>, payload: Payload) -> Result<()> {
314 let offset = {
316 let mut df = self.data_file.write();
317 df.append(id, &vector, &payload)?
318 };
319
320 {
322 let mut index = self.index.write();
323 index.delete(id);
325 index.insert(id, vector, payload)?;
326 }
327
328 {
330 let mut offsets = self.offsets.write();
331 offsets.insert(id, offset);
332 }
333
334 {
336 let mut next_id = self.next_id.write();
337 *next_id = (*next_id).max(id + 1);
338 }
339
340 Ok(())
341 }
342
343 fn apply_update_no_wal(&self, id: VectorId, vector: Vec<f32>, payload: Payload) -> Result<()> {
345 {
347 let offsets = self.offsets.read();
348 if let Some(&offset) = offsets.get(&id) {
349 let df = self.data_file.read();
350 df.mark_deleted(offset)?;
351 }
352 }
353
354 let offset = {
356 let mut df = self.data_file.write();
357 df.append(id, &vector, &payload)?
358 };
359
360 {
362 let mut index = self.index.write();
363 index.delete(id);
364 index.insert(id, vector, payload)?;
365 }
366
367 {
369 let mut offsets = self.offsets.write();
370 offsets.insert(id, offset);
371 }
372
373 Ok(())
374 }
375
376 fn apply_delete_no_wal(&self, id: VectorId) -> Result<bool> {
378 {
380 let offsets = self.offsets.read();
381 if let Some(&offset) = offsets.get(&id) {
382 let df = self.data_file.read();
383 df.mark_deleted(offset)?;
384 }
385 }
386
387 let deleted = {
389 let mut index = self.index.write();
390 index.delete(id)
391 };
392
393 {
395 let mut offsets = self.offsets.write();
396 offsets.remove(&id);
397 }
398
399 Ok(deleted)
400 }
401
402 fn save_meta(&self) -> Result<()> {
404 let meta = CollectionMeta {
405 dimension: self.config.dimension,
406 metric: format!("{:?}", self.config.metric),
407 vector_count: self.len() as u64,
408 next_id: *self.next_id.read(),
409 };
410
411 let content = serde_json::to_string_pretty(&meta)
412 .map_err(|e| Error::CollectionError(format!("serialize meta failed: {}", e)))?;
413
414 let meta_path = self.path.join("meta.json");
415 fs::write(&meta_path, content)
416 .map_err(|e| Error::CollectionError(format!("write meta failed: {}", e)))?;
417
418 Ok(())
419 }
420}
421
#[cfg(feature = "async")]
mod async_api {
    use super::*;
    use std::sync::Arc;

    /// Runs `job` on tokio's blocking pool and returns its result.
    ///
    /// If the blocking task cannot be joined, the `JoinError` is turned into
    /// `Error::CollectionError`.
    async fn run_blocking<T, F>(job: F) -> Result<T>
    where
        F: FnOnce() -> Result<T> + Send + 'static,
        T: Send + 'static,
    {
        match tokio::task::spawn_blocking(job).await {
            Ok(outcome) => outcome,
            Err(join_error) => Err(Error::CollectionError(format!(
                "spawn_blocking failed: {}",
                join_error
            ))),
        }
    }

    /// Async wrapper around a shared [`Collection`].
    ///
    /// Blocking operations are moved onto tokio's blocking pool. Cloning is
    /// cheap because the collection sits behind an `Arc`.
    #[derive(Clone)]
    pub struct AsyncCollection {
        inner: Arc<Collection>,
    }

    impl AsyncCollection {
        /// Opens or creates a collection on the blocking pool.
        pub async fn open_or_create<P: AsRef<std::path::Path> + Send + 'static>(
            path: P,
            config: CollectionConfig,
        ) -> Result<Self> {
            let owned_path = path.as_ref().to_path_buf();
            let collection =
                run_blocking(move || Collection::open_or_create(owned_path, config)).await?;
            Ok(Self::from_sync(collection))
        }

        /// Wraps an already-opened synchronous collection.
        pub fn from_sync(collection: Collection) -> Self {
            let inner = Arc::new(collection);
            Self { inner }
        }

        /// Async form of [`Collection::insert`].
        pub async fn insert(&self, id: VectorId, vector: Vec<f32>, payload: Payload) -> Result<()> {
            let shared = Arc::clone(&self.inner);
            run_blocking(move || shared.insert(id, vector, payload)).await
        }

        /// Async form of [`Collection::insert_auto`].
        pub async fn insert_auto(&self, vector: Vec<f32>, payload: Payload) -> Result<VectorId> {
            let shared = Arc::clone(&self.inner);
            run_blocking(move || shared.insert_auto(vector, payload)).await
        }

        /// Async form of [`Collection::update`].
        pub async fn update(&self, id: VectorId, vector: Vec<f32>, payload: Payload) -> Result<()> {
            let shared = Arc::clone(&self.inner);
            run_blocking(move || shared.update(id, vector, payload)).await
        }

        /// Async form of [`Collection::delete`].
        pub async fn delete(&self, id: VectorId) -> Result<bool> {
            let shared = Arc::clone(&self.inner);
            run_blocking(move || shared.delete(id)).await
        }

        /// Async form of [`Collection::search`].
        ///
        /// Returns an empty list if the blocking task cannot be joined.
        pub async fn search(
            &self,
            query: &[f32],
            k: usize,
            filter: Option<Filter>,
        ) -> Vec<SearchResult> {
            let shared = Arc::clone(&self.inner);
            let owned_query = query.to_vec();
            let joined =
                tokio::task::spawn_blocking(move || shared.search(&owned_query, k, filter)).await;
            match joined {
                Ok(hits) => hits,
                Err(_) => Vec::new(),
            }
        }

        /// Async form of [`Collection::get`].
        ///
        /// Returns `None` if the blocking task cannot be joined.
        pub async fn get(&self, id: VectorId) -> Option<(Vec<f32>, Payload)> {
            let shared = Arc::clone(&self.inner);
            match tokio::task::spawn_blocking(move || shared.get(id)).await {
                Ok(found) => found,
                Err(_) => None,
            }
        }

        /// Number of vectors in the wrapped collection.
        pub fn len(&self) -> usize {
            self.inner.len()
        }

        /// Whether the wrapped collection is empty.
        pub fn is_empty(&self) -> bool {
            self.inner.is_empty()
        }

        /// Async form of [`Collection::flush`].
        pub async fn flush(&self) -> Result<()> {
            let shared = Arc::clone(&self.inner);
            run_blocking(move || shared.flush()).await
        }

        /// Borrows the wrapped synchronous collection.
        pub fn inner(&self) -> &Collection {
            &self.inner
        }
    }
}
557
558#[cfg(feature = "async")]
559pub use async_api::AsyncCollection;
560
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    static TEST_COUNTER: AtomicU64 = AtomicU64::new(0);

    /// A unique scratch directory for one test, removed when dropped.
    ///
    /// Cleanup happens in `Drop`, so it runs even when an assertion panics.
    /// Declare it before the `Collection` so the collection is dropped first.
    struct TempCollectionDir(PathBuf);

    impl TempCollectionDir {
        fn new() -> Self {
            let id = TEST_COUNTER.fetch_add(1, Ordering::SeqCst);
            let dir = std::env::temp_dir()
                .join("polarisdb_test_col")
                .join(format!("col_{}_{}", std::process::id(), id));
            // Clear leftovers from an earlier run that happened to reuse this pid.
            let _ = fs::remove_dir_all(&dir);
            Self(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempCollectionDir {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_collection_create_and_insert() {
        let dir = TempCollectionDir::new();
        let config = CollectionConfig::new(3, DistanceMetric::Euclidean);

        let col = Collection::open_or_create(dir.path(), config).unwrap();
        col.insert(
            1,
            vec![1.0, 2.0, 3.0],
            Payload::new().with_field("key", "val"),
        )
        .unwrap();

        assert_eq!(col.len(), 1);

        let (vec, payload) = col.get(1).unwrap();
        assert_eq!(vec, vec![1.0, 2.0, 3.0]);
        assert_eq!(payload.get_str("key"), Some("val"));
    }

    #[test]
    fn test_collection_persistence() {
        let dir = TempCollectionDir::new();
        let config = CollectionConfig::new(3, DistanceMetric::Euclidean);

        {
            let col = Collection::open_or_create(dir.path(), config.clone()).unwrap();
            col.insert(1, vec![1.0, 2.0, 3.0], Payload::new()).unwrap();
            col.insert(2, vec![4.0, 5.0, 6.0], Payload::new()).unwrap();
            col.flush().unwrap();
        }

        {
            let col = Collection::open_or_create(dir.path(), config).unwrap();
            assert_eq!(col.len(), 2);
            assert!(col.get(1).is_some());
            assert!(col.get(2).is_some());
        }
    }

    #[test]
    fn test_collection_delete() {
        let dir = TempCollectionDir::new();
        let config = CollectionConfig::new(3, DistanceMetric::Euclidean);

        let col = Collection::open_or_create(dir.path(), config).unwrap();
        col.insert(1, vec![1.0, 2.0, 3.0], Payload::new()).unwrap();
        assert_eq!(col.len(), 1);

        col.delete(1).unwrap();
        assert_eq!(col.len(), 0);
        assert!(col.get(1).is_none());
    }

    #[test]
    fn test_collection_search() {
        let dir = TempCollectionDir::new();
        let config = CollectionConfig::new(3, DistanceMetric::Euclidean);

        let col = Collection::open_or_create(dir.path(), config).unwrap();
        col.insert(1, vec![1.0, 0.0, 0.0], Payload::new()).unwrap();
        col.insert(2, vec![0.0, 1.0, 0.0], Payload::new()).unwrap();
        col.insert(3, vec![0.0, 0.0, 1.0], Payload::new()).unwrap();

        let results = col.search(&[1.0, 0.0, 0.0], 1, None);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 1);
    }

    #[test]
    fn test_collection_update() {
        let dir = TempCollectionDir::new();
        let config = CollectionConfig::new(3, DistanceMetric::Euclidean);

        let col = Collection::open_or_create(dir.path(), config).unwrap();
        col.insert(1, vec![1.0, 2.0, 3.0], Payload::new().with_field("v", 1))
            .unwrap();
        col.update(1, vec![4.0, 5.0, 6.0], Payload::new().with_field("v", 2))
            .unwrap();

        let (vec, payload) = col.get(1).unwrap();
        assert_eq!(vec, vec![4.0, 5.0, 6.0]);
        assert_eq!(payload.get_i64("v"), Some(2));
    }

    #[test]
    fn test_collection_update_persists_across_reopen() {
        let dir = TempCollectionDir::new();
        let config = CollectionConfig::new(3, DistanceMetric::Euclidean);

        {
            let col = Collection::open_or_create(dir.path(), config.clone()).unwrap();
            col.insert(1, vec![1.0, 2.0, 3.0], Payload::new()).unwrap();
            col.update(1, vec![4.0, 5.0, 6.0], Payload::new()).unwrap();
            col.flush().unwrap();
        }

        {
            let col = Collection::open_or_create(dir.path(), config).unwrap();
            assert_eq!(col.len(), 1);
            let (vec, _) = col.get(1).unwrap();
            assert_eq!(vec, vec![4.0, 5.0, 6.0]);
        }
    }

    #[test]
    fn test_collection_insert_auto_assigns_increasing_ids() {
        let dir = TempCollectionDir::new();
        let config = CollectionConfig::new(3, DistanceMetric::Euclidean);

        let col = Collection::open_or_create(dir.path(), config).unwrap();
        let first = col.insert_auto(vec![1.0, 0.0, 0.0], Payload::new()).unwrap();
        let second = col.insert_auto(vec![0.0, 1.0, 0.0], Payload::new()).unwrap();

        assert!(second > first);
        assert_eq!(col.len(), 2);
        assert!(col.get(first).is_some());
        assert!(col.get(second).is_some());
    }

    #[test]
    fn test_collection_recovery_after_crash() {
        let dir = TempCollectionDir::new();
        let config = CollectionConfig::new(3, DistanceMetric::Euclidean);

        {
            // Dropped without `flush`, so state must come back via recovery.
            let col = Collection::open_or_create(dir.path(), config.clone()).unwrap();
            col.insert(1, vec![1.0, 2.0, 3.0], Payload::new()).unwrap();
            col.insert(2, vec![4.0, 5.0, 6.0], Payload::new()).unwrap();
        }

        {
            let col = Collection::open_or_create(dir.path(), config).unwrap();
            assert_eq!(col.len(), 2);
        }
    }
}