distx_core/
collection.rs

1use crate::{Error, Point, Result, Vector, HnswIndex, BM25Index, Filter};
2use parking_lot::RwLock;
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6
7/// Configuration for a collection
8#[derive(Debug, Clone)]
9pub struct CollectionConfig {
10    pub name: String,
11    pub vector_dim: usize,
12    pub distance: Distance,
13    pub use_hnsw: bool,
14    pub enable_bm25: bool,
15}
16
17impl Default for CollectionConfig {
18    fn default() -> Self {
19        Self {
20            name: String::new(),
21            vector_dim: 128,
22            distance: Distance::Cosine,
23            use_hnsw: true,
24            enable_bm25: false,
25        }
26    }
27}
28
/// Metric used to score stored vectors against a query vector.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Distance {
    /// Cosine similarity.
    Cosine,
    /// Euclidean (L2) distance.
    Euclidean,
    /// Dot product.
    Dot,
}
35
36/// A collection of vectors with metadata
37pub struct Collection {
38    config: CollectionConfig,
39    points: Arc<RwLock<HashMap<String, Point>>>,
40    hnsw: Option<Arc<RwLock<HnswIndex>>>,
41    bm25: Option<Arc<RwLock<BM25Index>>>,
42    hnsw_built: Arc<RwLock<bool>>,
43    hnsw_rebuilding: Arc<AtomicBool>,
44    batch_mode: Arc<RwLock<bool>>,
45    pending_points: Arc<RwLock<Vec<Point>>>,
46}
47
48impl Collection {
49    pub fn new(config: CollectionConfig) -> Self {
50        let hnsw = if config.use_hnsw {
51            Some(Arc::new(RwLock::new(HnswIndex::new(16, 3))))
52        } else {
53            None
54        };
55
56        let bm25 = if config.enable_bm25 {
57            Some(Arc::new(RwLock::new(BM25Index::new())))
58        } else {
59            None
60        };
61
62        Self {
63            config,
64            points: Arc::new(RwLock::new(HashMap::new())),
65            hnsw,
66            bm25,
67            hnsw_built: Arc::new(RwLock::new(false)),
68            hnsw_rebuilding: Arc::new(AtomicBool::new(false)),
69            batch_mode: Arc::new(RwLock::new(false)),
70            pending_points: Arc::new(RwLock::new(Vec::new())),
71        }
72    }
73
74    #[inline]
75    #[must_use]
76    pub fn name(&self) -> &str {
77        &self.config.name
78    }
79
80    #[inline]
81    #[must_use]
82    pub fn vector_dim(&self) -> usize {
83        self.config.vector_dim
84    }
85
86    #[inline]
87    #[must_use]
88    pub fn distance(&self) -> Distance {
89        self.config.distance
90    }
91
92    #[inline]
93    #[must_use]
94    pub fn count(&self) -> usize {
95        self.points.read().len()
96    }
97
98    #[inline]
99    #[must_use]
100    pub fn is_empty(&self) -> bool {
101        self.points.read().is_empty()
102    }
103
104    /// Get all points in the collection
105    pub fn get_all_points(&self) -> Vec<Point> {
106        self.points.read().values().cloned().collect()
107    }
108
109    /// Insert or update a point
110    pub fn upsert(&self, point: Point) -> Result<()> {
111        if point.vector.dim() != self.config.vector_dim {
112            return Err(Error::InvalidDimension {
113                expected: self.config.vector_dim,
114                actual: point.vector.dim(),
115            });
116        }
117
118        let id_str = point.id.to_string();
119        
120        let in_batch = *self.batch_mode.read();
121        if in_batch {
122            self.points.write().insert(id_str.clone(), point.clone());
123            self.pending_points.write().push(point);
124            return Ok(());
125        }
126        
127        if let Some(hnsw) = &self.hnsw {
128            let built = *self.hnsw_built.read();
129            if built {
130                let mut normalized_point = point.clone();
131                normalized_point.vector.normalize();
132                
133                let mut index = hnsw.write();
134                index.insert(normalized_point);
135            }
136        }
137
138        if let Some(bm25) = &self.bm25 {
139            if let Some(payload) = &point.payload {
140                if let Some(text) = payload.get("text").and_then(|v| v.as_str()) {
141                    let mut index = bm25.write();
142                    index.insert_doc(&id_str, text);
143                }
144            }
145        }
146
147        self.points.write().insert(id_str, point);
148        Ok(())
149    }
150
151    /// Start batch insert mode
152    pub fn start_batch(&self) {
153        *self.batch_mode.write() = true;
154        self.pending_points.write().clear();
155    }
156
157    /// End batch insert mode
158    pub fn end_batch(&self) -> Result<()> {
159        *self.batch_mode.write() = false;
160        
161        if let Some(hnsw) = &self.hnsw {
162            let points = self.points.read();
163            let point_count = points.len();
164            
165            const HNSW_REBUILD_THRESHOLD: usize = 10_000;
166            
167            if point_count > HNSW_REBUILD_THRESHOLD && !self.hnsw_rebuilding.load(Ordering::Acquire) {
168                self.hnsw_rebuilding.store(true, Ordering::Release);
169                let points_clone: Vec<Point> = points.values().cloned().collect();
170                let hnsw_clone = hnsw.clone();
171                let built_flag = self.hnsw_built.clone();
172                let rebuilding_flag = self.hnsw_rebuilding.clone();
173                
174                let job = crate::background::HnswRebuildJob::new(
175                    points_clone,
176                    hnsw_clone,
177                    built_flag,
178                    rebuilding_flag,
179                );
180                crate::background::get_background_system().submit(Box::new(job));
181            }
182        }
183        
184        self.pending_points.write().clear();
185        Ok(())
186    }
187
188    /// Batch insert multiple points
189    pub fn batch_upsert(&self, points: Vec<Point>) -> Result<()> {
190        self.start_batch();
191        for point in points {
192            self.upsert(point)?;
193        }
194        self.end_batch()?;
195        Ok(())
196    }
197
198    /// Batch insert with optional pre-warming
199    pub fn batch_upsert_with_prewarm(&self, points: Vec<Point>, prewarm: bool) -> Result<()> {
200        self.batch_upsert(points)?;
201        if prewarm {
202            self.prewarm_index()?;
203        }
204        Ok(())
205    }
206
207    /// Get a point by ID
208    #[inline]
209    pub fn get(&self, id: &str) -> Option<Point> {
210        self.points.read().get(id).cloned()
211    }
212
213    /// Delete a point by ID
214    pub fn delete(&self, id: &str) -> Result<bool> {
215        if let Some(hnsw) = &self.hnsw {
216            let mut index = hnsw.write();
217            index.remove(id);
218        }
219
220        if let Some(bm25) = &self.bm25 {
221            let mut index = bm25.write();
222            index.delete_doc(id);
223        }
224
225        let mut points = self.points.write();
226        Ok(points.remove(id).is_some())
227    }
228
229    /// Pre-warm HNSW index
230    pub fn prewarm_index(&self) -> Result<()> {
231        if let Some(hnsw) = &self.hnsw {
232            let mut built = self.hnsw_built.write();
233            if !*built {
234                let points = self.points.read();
235                if !points.is_empty() {
236                    let mut index = hnsw.write();
237                    *index = HnswIndex::new(16, 3);
238                    for point in points.values() {
239                        index.insert(point.clone());
240                    }
241                    *built = true;
242                }
243            }
244        }
245        Ok(())
246    }
247
248    /// Fast brute-force search using SIMD - optimal for small datasets
249    fn brute_force_search(&self, query: &Vector, limit: usize, filter: Option<&dyn Filter>) -> Vec<(Point, f32)> {
250        let points = self.points.read();
251        let query_slice = query.as_slice();
252        
253        // Pre-allocate results with capacity
254        let mut results: Vec<(Point, f32)> = Vec::with_capacity(points.len().min(limit * 2));
255        
256        for point in points.values() {
257            if let Some(f) = filter {
258                if !f.matches(point) {
259                    continue;
260                }
261            }
262            
263            // Use SIMD-optimized dot product for cosine similarity (vectors are normalized)
264            let score = match self.config.distance {
265                Distance::Cosine => {
266                    crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
267                }
268                Distance::Euclidean => {
269                    -crate::simd::l2_distance_simd(query_slice, point.vector.as_slice())
270                }
271                Distance::Dot => {
272                    crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
273                }
274            };
275            
276            results.push((point.clone(), score));
277        }
278        
279        // Use partial sort for efficiency when limit << len
280        if results.len() > limit {
281            results.select_nth_unstable_by(limit, |a, b| {
282                b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
283            });
284            results.truncate(limit);
285        }
286        
287        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
288        results
289    }
290
291    /// Search for similar vectors
292    /// Uses brute-force for small datasets (<1000), HNSW for larger ones
293    pub fn search(
294        &self,
295        query: &Vector,
296        limit: usize,
297        filter: Option<&dyn Filter>,
298    ) -> Vec<(Point, f32)> {
299        let normalized_query = query.normalized();
300        let point_count = self.points.read().len();
301        
302        // Use brute-force for small datasets - faster than HNSW overhead
303        const BRUTE_FORCE_THRESHOLD: usize = 1000;
304        if point_count < BRUTE_FORCE_THRESHOLD {
305            return self.brute_force_search(&normalized_query, limit, filter);
306        }
307        
308        if let Some(hnsw) = &self.hnsw {
309            // Check if we need to build the index first
310            {
311                let mut built = self.hnsw_built.write();
312                if !*built {
313                    let points = self.points.read();
314                    if !points.is_empty() {
315                        let mut index = hnsw.write();
316                        *index = HnswIndex::new(16, 3);
317                        for point in points.values() {
318                            index.insert(point.clone());
319                        }
320                        *built = true;
321                    }
322                }
323            }
324            
325            // Use write lock for search (HNSW search is now mutable for performance)
326            let mut index = hnsw.write();
327            let mut results = index.search(&normalized_query, limit, None);
328            
329            if let Some(f) = filter {
330                results.retain(|(point, _)| f.matches(point));
331            }
332            
333            results
334        } else {
335            let points = self.points.read();
336            let results: Vec<(Point, f32)> = points
337                .values()
338                .filter(|point| {
339                    filter.map(|f| f.matches(point)).unwrap_or(true)
340                })
341                .map(|point| {
342                    let score = match self.config.distance {
343                        Distance::Cosine => point.vector.cosine_similarity(query),
344                        Distance::Euclidean => -point.vector.l2_distance(query),
345                        Distance::Dot => {
346                            point.vector.as_slice()
347                                .iter()
348                                .zip(query.as_slice().iter())
349                                .map(|(a, b)| a * b)
350                                .sum()
351                        }
352                    };
353                    (point.clone(), score)
354                })
355                .collect();
356
357            let mut sorted = results;
358            sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
359            sorted.truncate(limit);
360            sorted
361        }
362    }
363
364    /// BM25 text search
365    pub fn search_text(&self, query: &str, limit: usize) -> Vec<(String, f32)> {
366        if let Some(bm25) = &self.bm25 {
367            let index = bm25.read();
368            index.search(query, limit)
369        } else {
370            Vec::new()
371        }
372    }
373
374    /// Get all points
375    pub fn iter(&self) -> Vec<Point> {
376        self.points.read().values().cloned().collect()
377    }
378}
379