distx_core/
collection.rs

1use crate::{Error, Point, Result, Vector, HnswIndex, BM25Index, Filter, MultiVector};
2use parking_lot::RwLock;
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6
7/// Configuration for a collection
8#[derive(Debug, Clone)]
9pub struct CollectionConfig {
10    pub name: String,
11    pub vector_dim: usize,
12    pub distance: Distance,
13    pub use_hnsw: bool,
14    pub enable_bm25: bool,
15}
16
17impl Default for CollectionConfig {
18    fn default() -> Self {
19        Self {
20            name: String::new(),
21            vector_dim: 128,
22            distance: Distance::Cosine,
23            use_hnsw: true,
24            enable_bm25: false,
25        }
26    }
27}
28
/// Metric used to score a stored vector against a query.
///
/// Search results are always ordered by descending score.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Distance {
    /// Cosine similarity.
    Cosine,
    /// L2 distance; scored as the negated distance so that closer is higher.
    Euclidean,
    /// Plain dot product.
    Dot,
}
35
/// Kind of index registered for a payload field.
///
/// The enum is a plain tag with no data, so it is `Copy` and `Hash` and can
/// be passed by value or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PayloadIndexType {
    Keyword,
    Integer,
    Float,
    Bool,
    Geo,
    Text,
}
46
47/// A collection of vectors with metadata
48pub struct Collection {
49    config: CollectionConfig,
50    points: Arc<RwLock<HashMap<String, Point>>>,
51    hnsw: Option<Arc<RwLock<HnswIndex>>>,
52    bm25: Option<Arc<RwLock<BM25Index>>>,
53    hnsw_built: Arc<RwLock<bool>>,
54    hnsw_rebuilding: Arc<AtomicBool>,
55    batch_mode: Arc<RwLock<bool>>,
56    pending_points: Arc<RwLock<Vec<Point>>>,
57    /// Payload field indexes
58    payload_indexes: Arc<RwLock<HashMap<String, PayloadIndexType>>>,
59}
60
61impl Collection {
62    pub fn new(config: CollectionConfig) -> Self {
63        let hnsw = if config.use_hnsw {
64            Some(Arc::new(RwLock::new(HnswIndex::new(16, 3))))
65        } else {
66            None
67        };
68
69        let bm25 = if config.enable_bm25 {
70            Some(Arc::new(RwLock::new(BM25Index::new())))
71        } else {
72            None
73        };
74
75        Self {
76            config,
77            points: Arc::new(RwLock::new(HashMap::new())),
78            hnsw,
79            bm25,
80            hnsw_built: Arc::new(RwLock::new(false)),
81            hnsw_rebuilding: Arc::new(AtomicBool::new(false)),
82            batch_mode: Arc::new(RwLock::new(false)),
83            pending_points: Arc::new(RwLock::new(Vec::new())),
84            payload_indexes: Arc::new(RwLock::new(HashMap::new())),
85        }
86    }
87
88    #[inline]
89    #[must_use]
90    pub fn name(&self) -> &str {
91        &self.config.name
92    }
93
94    #[inline]
95    #[must_use]
96    pub fn vector_dim(&self) -> usize {
97        self.config.vector_dim
98    }
99
100    #[inline]
101    #[must_use]
102    pub fn distance(&self) -> Distance {
103        self.config.distance
104    }
105
106    #[inline]
107    #[must_use]
108    pub fn use_hnsw(&self) -> bool {
109        self.config.use_hnsw
110    }
111
112    #[inline]
113    #[must_use]
114    pub fn enable_bm25(&self) -> bool {
115        self.config.enable_bm25
116    }
117
118    #[inline]
119    #[must_use]
120    pub fn count(&self) -> usize {
121        self.points.read().len()
122    }
123
124    #[inline]
125    #[must_use]
126    pub fn is_empty(&self) -> bool {
127        self.points.read().is_empty()
128    }
129
130    /// Get all points in the collection
131    pub fn get_all_points(&self) -> Vec<Point> {
132        self.points.read().values().cloned().collect()
133    }
134
135    /// Insert or update a point
136    pub fn upsert(&self, point: Point) -> Result<()> {
137        // Skip dimension check for sparse-only collections (vector_dim == 0)
138        if self.config.vector_dim > 0 && point.vector.dim() != self.config.vector_dim {
139            return Err(Error::InvalidDimension {
140                expected: self.config.vector_dim,
141                actual: point.vector.dim(),
142            });
143        }
144
145        let id_str = point.id.to_string();
146        
147        let in_batch = *self.batch_mode.read();
148        if in_batch {
149            self.points.write().insert(id_str.clone(), point.clone());
150            self.pending_points.write().push(point);
151            return Ok(());
152        }
153        
154        if let Some(hnsw) = &self.hnsw {
155            let built = *self.hnsw_built.read();
156            if built {
157                let mut normalized_point = point.clone();
158                normalized_point.vector.normalize();
159                
160                let mut index = hnsw.write();
161                index.insert(normalized_point);
162            }
163        }
164
165        if let Some(bm25) = &self.bm25 {
166            if let Some(payload) = &point.payload {
167                if let Some(text) = payload.get("text").and_then(|v| v.as_str()) {
168                    let mut index = bm25.write();
169                    index.insert_doc(&id_str, text);
170                }
171            }
172        }
173
174        self.points.write().insert(id_str, point);
175        Ok(())
176    }
177
178    /// Start batch insert mode
179    pub fn start_batch(&self) {
180        *self.batch_mode.write() = true;
181        self.pending_points.write().clear();
182    }
183
184    /// End batch insert mode
185    pub fn end_batch(&self) -> Result<()> {
186        *self.batch_mode.write() = false;
187        
188        if let Some(hnsw) = &self.hnsw {
189            let points = self.points.read();
190            let point_count = points.len();
191            
192            const HNSW_REBUILD_THRESHOLD: usize = 10_000;
193            
194            if point_count > HNSW_REBUILD_THRESHOLD && !self.hnsw_rebuilding.load(Ordering::Acquire) {
195                self.hnsw_rebuilding.store(true, Ordering::Release);
196                let points_clone: Vec<Point> = points.values().cloned().collect();
197                let hnsw_clone = hnsw.clone();
198                let built_flag = self.hnsw_built.clone();
199                let rebuilding_flag = self.hnsw_rebuilding.clone();
200                
201                let job = crate::background::HnswRebuildJob::new(
202                    points_clone,
203                    hnsw_clone,
204                    built_flag,
205                    rebuilding_flag,
206                );
207                crate::background::get_background_system().submit(Box::new(job));
208            }
209        }
210        
211        self.pending_points.write().clear();
212        Ok(())
213    }
214
215    /// Batch insert multiple points
216    pub fn batch_upsert(&self, points: Vec<Point>) -> Result<()> {
217        self.start_batch();
218        for point in points {
219            self.upsert(point)?;
220        }
221        self.end_batch()?;
222        Ok(())
223    }
224
225    /// Batch insert with optional pre-warming
226    pub fn batch_upsert_with_prewarm(&self, points: Vec<Point>, prewarm: bool) -> Result<()> {
227        self.batch_upsert(points)?;
228        if prewarm {
229            self.prewarm_index()?;
230        }
231        Ok(())
232    }
233
234    /// Get a point by ID
235    #[inline]
236    pub fn get(&self, id: &str) -> Option<Point> {
237        self.points.read().get(id).cloned()
238    }
239
240    /// Delete a point by ID
241    pub fn delete(&self, id: &str) -> Result<bool> {
242        if let Some(hnsw) = &self.hnsw {
243            let mut index = hnsw.write();
244            index.remove(id);
245        }
246
247        if let Some(bm25) = &self.bm25 {
248            let mut index = bm25.write();
249            index.delete_doc(id);
250        }
251
252        let mut points = self.points.write();
253        Ok(points.remove(id).is_some())
254    }
255
256    /// Set payload values for a point (merge with existing)
257    pub fn set_payload(&self, id: &str, payload: serde_json::Value) -> Result<bool> {
258        let mut points = self.points.write();
259        if let Some(point) = points.get_mut(id) {
260            if let Some(existing) = &mut point.payload {
261                if let (Some(existing_obj), Some(new_obj)) = (existing.as_object_mut(), payload.as_object()) {
262                    for (key, value) in new_obj {
263                        existing_obj.insert(key.clone(), value.clone());
264                    }
265                }
266            } else {
267                point.payload = Some(payload);
268            }
269            Ok(true)
270        } else {
271            Ok(false)
272        }
273    }
274
275    /// Overwrite entire payload for a point
276    pub fn overwrite_payload(&self, id: &str, payload: serde_json::Value) -> Result<bool> {
277        let mut points = self.points.write();
278        if let Some(point) = points.get_mut(id) {
279            point.payload = Some(payload);
280            Ok(true)
281        } else {
282            Ok(false)
283        }
284    }
285
286    /// Delete specific payload keys from a point
287    pub fn delete_payload_keys(&self, id: &str, keys: &[String]) -> Result<bool> {
288        let mut points = self.points.write();
289        if let Some(point) = points.get_mut(id) {
290            if let Some(payload) = &mut point.payload {
291                if let Some(obj) = payload.as_object_mut() {
292                    for key in keys {
293                        obj.remove(key);
294                    }
295                }
296            }
297            Ok(true)
298        } else {
299            Ok(false)
300        }
301    }
302
303    /// Clear all payload from a point
304    pub fn clear_payload(&self, id: &str) -> Result<bool> {
305        let mut points = self.points.write();
306        if let Some(point) = points.get_mut(id) {
307            point.payload = None;
308            Ok(true)
309        } else {
310            Ok(false)
311        }
312    }
313
314    /// Update vector for a point
315    pub fn update_vector(&self, id: &str, vector: Vector) -> Result<bool> {
316        let mut points = self.points.write();
317        if let Some(point) = points.get_mut(id) {
318            point.vector = vector.clone();
319            
320            // Update HNSW index if present
321            if let Some(hnsw) = &self.hnsw {
322                let mut index = hnsw.write();
323                index.remove(id);
324                // Insert the updated point
325                index.insert(point.clone());
326            }
327            Ok(true)
328        } else {
329            Ok(false)
330        }
331    }
332
333    /// Update multivector for a point
334    pub fn update_multivector(&self, id: &str, multivector: Option<MultiVector>) -> Result<bool> {
335        let mut points = self.points.write();
336        if let Some(point) = points.get_mut(id) {
337            point.multivector = multivector;
338            Ok(true)
339        } else {
340            Ok(false)
341        }
342    }
343
344    /// Delete vector (set to empty) - for named vectors this would delete specific vector
345    pub fn delete_vector(&self, id: &str) -> Result<bool> {
346        // For now, deleting a vector means deleting the point
347        // In full implementation, named vectors would be individually deletable
348        self.delete(id)
349    }
350
351    /// Create a payload field index
352    pub fn create_payload_index(&self, field_name: &str, index_type: PayloadIndexType) -> Result<bool> {
353        let mut indexes = self.payload_indexes.write();
354        indexes.insert(field_name.to_string(), index_type);
355        Ok(true)
356    }
357
358    /// Delete a payload field index
359    pub fn delete_payload_index(&self, field_name: &str) -> Result<bool> {
360        let mut indexes = self.payload_indexes.write();
361        Ok(indexes.remove(field_name).is_some())
362    }
363
364    /// Get all payload indexes
365    pub fn get_payload_indexes(&self) -> HashMap<String, PayloadIndexType> {
366        self.payload_indexes.read().clone()
367    }
368
369    /// Check if a field is indexed
370    pub fn is_field_indexed(&self, field_name: &str) -> bool {
371        self.payload_indexes.read().contains_key(field_name)
372    }
373
374    /// Pre-warm HNSW index
375    pub fn prewarm_index(&self) -> Result<()> {
376        if let Some(hnsw) = &self.hnsw {
377            let mut built = self.hnsw_built.write();
378            if !*built {
379                let points = self.points.read();
380                if !points.is_empty() {
381                    let mut index = hnsw.write();
382                    *index = HnswIndex::new(16, 3);
383                    for point in points.values() {
384                        index.insert(point.clone());
385                    }
386                    *built = true;
387                }
388            }
389        }
390        Ok(())
391    }
392
393    /// Fast brute-force search using SIMD - optimal for small datasets
394    fn brute_force_search(&self, query: &Vector, limit: usize, filter: Option<&dyn Filter>) -> Vec<(Point, f32)> {
395        let points = self.points.read();
396        let query_slice = query.as_slice();
397        
398        // Pre-allocate results with capacity
399        let mut results: Vec<(Point, f32)> = Vec::with_capacity(points.len().min(limit * 2));
400        
401        for point in points.values() {
402            if let Some(f) = filter {
403                if !f.matches(point) {
404                    continue;
405                }
406            }
407            
408            // Use SIMD-optimized dot product for cosine similarity (vectors are normalized)
409            let score = match self.config.distance {
410                Distance::Cosine => {
411                    crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
412                }
413                Distance::Euclidean => {
414                    -crate::simd::l2_distance_simd(query_slice, point.vector.as_slice())
415                }
416                Distance::Dot => {
417                    crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
418                }
419            };
420            
421            results.push((point.clone(), score));
422        }
423        
424        // Use partial sort for efficiency when limit << len
425        if results.len() > limit {
426            results.select_nth_unstable_by(limit, |a, b| {
427                b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
428            });
429            results.truncate(limit);
430        }
431        
432        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
433        results
434    }
435
436    /// Search for similar vectors
437    /// Uses brute-force for small datasets (<1000), HNSW for larger ones
438    pub fn search(
439        &self,
440        query: &Vector,
441        limit: usize,
442        filter: Option<&dyn Filter>,
443    ) -> Vec<(Point, f32)> {
444        let normalized_query = query.normalized();
445        let point_count = self.points.read().len();
446        
447        // Use brute-force for small datasets - faster than HNSW overhead
448        const BRUTE_FORCE_THRESHOLD: usize = 1000;
449        if point_count < BRUTE_FORCE_THRESHOLD {
450            return self.brute_force_search(&normalized_query, limit, filter);
451        }
452        
453        if let Some(hnsw) = &self.hnsw {
454            // Check if we need to build the index first
455            {
456                let mut built = self.hnsw_built.write();
457                if !*built {
458                    let points = self.points.read();
459                    if !points.is_empty() {
460                        let mut index = hnsw.write();
461                        *index = HnswIndex::new(16, 3);
462                        for point in points.values() {
463                            index.insert(point.clone());
464                        }
465                        *built = true;
466                    }
467                }
468            }
469            
470            // Use write lock for search (HNSW search is now mutable for performance)
471            let mut index = hnsw.write();
472            let mut results = index.search(&normalized_query, limit, None);
473            
474            if let Some(f) = filter {
475                results.retain(|(point, _)| f.matches(point));
476            }
477            
478            results
479        } else {
480            let points = self.points.read();
481            let results: Vec<(Point, f32)> = points
482                .values()
483                .filter(|point| {
484                    filter.map(|f| f.matches(point)).unwrap_or(true)
485                })
486                .map(|point| {
487                    let score = match self.config.distance {
488                        Distance::Cosine => point.vector.cosine_similarity(query),
489                        Distance::Euclidean => -point.vector.l2_distance(query),
490                        Distance::Dot => {
491                            point.vector.as_slice()
492                                .iter()
493                                .zip(query.as_slice().iter())
494                                .map(|(a, b)| a * b)
495                                .sum()
496                        }
497                    };
498                    (point.clone(), score)
499                })
500                .collect();
501
502            let mut sorted = results;
503            sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
504            sorted.truncate(limit);
505            sorted
506        }
507    }
508
509    /// BM25 text search
510    pub fn search_text(&self, query: &str, limit: usize) -> Vec<(String, f32)> {
511        if let Some(bm25) = &self.bm25 {
512            let index = bm25.read();
513            index.search(query, limit)
514        } else {
515            Vec::new()
516        }
517    }
518    
519    /// Search using multivector MaxSim scoring (ColBERT-style)
520    /// 
521    /// For each sub-vector in the query, finds the maximum similarity 
522    /// with any sub-vector in each document, then sums all maximums.
523    pub fn search_multivector(
524        &self,
525        query: &MultiVector,
526        limit: usize,
527        filter: Option<&dyn Filter>,
528    ) -> Vec<(Point, f32)> {
529        let points = self.points.read();
530        
531        let mut results: Vec<(Point, f32)> = Vec::with_capacity(points.len().min(limit * 2));
532        
533        for point in points.values() {
534            if let Some(f) = filter {
535                if !f.matches(point) {
536                    continue;
537                }
538            }
539            
540            // Calculate MaxSim score
541            let score = if let Some(doc_mv) = &point.multivector {
542                // Both query and document have multivectors - use MaxSim
543                match self.config.distance {
544                    Distance::Cosine => query.max_sim_cosine(doc_mv),
545                    Distance::Euclidean => query.max_sim_l2(doc_mv),
546                    Distance::Dot => query.max_sim(doc_mv),
547                }
548            } else {
549                // Document has single vector - wrap it as multivector
550                let doc_mv = MultiVector::from_single(point.vector.as_slice().to_vec())
551                    .unwrap_or_else(|_| MultiVector::new(vec![vec![0.0; query.dim()]]).unwrap());
552                match self.config.distance {
553                    Distance::Cosine => query.max_sim_cosine(&doc_mv),
554                    Distance::Euclidean => query.max_sim_l2(&doc_mv),
555                    Distance::Dot => query.max_sim(&doc_mv),
556                }
557            };
558            
559            results.push((point.clone(), score));
560        }
561        
562        // Sort by score descending
563        if results.len() > limit {
564            results.select_nth_unstable_by(limit, |a, b| {
565                b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
566            });
567            results.truncate(limit);
568        }
569        
570        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
571        results
572    }
573
574    /// Get all points
575    pub fn iter(&self) -> Vec<Point> {
576        self.points.read().values().cloned().collect()
577    }
578}
579