1use crate::{Error, Point, Result, Vector, HnswIndex, BM25Index, Filter, MultiVector};
2use parking_lot::RwLock;
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6
7#[derive(Debug, Clone)]
9pub struct CollectionConfig {
10 pub name: String,
11 pub vector_dim: usize,
12 pub distance: Distance,
13 pub use_hnsw: bool,
14 pub enable_bm25: bool,
15}
16
17impl Default for CollectionConfig {
18 fn default() -> Self {
19 Self {
20 name: String::new(),
21 vector_dim: 128,
22 distance: Distance::Cosine,
23 use_hnsw: true,
24 enable_bm25: false,
25 }
26 }
27}
28
/// Metric a collection uses to score a stored vector against a query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Distance {
    /// Cosine similarity.
    Cosine,
    /// Euclidean (L2) distance; the search code negates it so that a larger
    /// score always means a closer match.
    Euclidean,
    /// Dot product.
    Dot,
}
35
/// Kind of index registered for a payload field via
/// `Collection::create_payload_index`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PayloadIndexType {
    /// Keyword index.
    Keyword,
    /// Integer index.
    Integer,
    /// Float index.
    Float,
    /// Boolean index.
    Bool,
    /// Geo index.
    Geo,
    /// Text index.
    Text,
}
46
47pub struct Collection {
49 config: CollectionConfig,
50 points: Arc<RwLock<HashMap<String, Point>>>,
51 hnsw: Option<Arc<RwLock<HnswIndex>>>,
52 bm25: Option<Arc<RwLock<BM25Index>>>,
53 hnsw_built: Arc<RwLock<bool>>,
54 hnsw_rebuilding: Arc<AtomicBool>,
55 batch_mode: Arc<RwLock<bool>>,
56 pending_points: Arc<RwLock<Vec<Point>>>,
57 payload_indexes: Arc<RwLock<HashMap<String, PayloadIndexType>>>,
59}
60
61impl Collection {
62 pub fn new(config: CollectionConfig) -> Self {
63 let hnsw = if config.use_hnsw {
64 Some(Arc::new(RwLock::new(HnswIndex::new(16, 3))))
65 } else {
66 None
67 };
68
69 let bm25 = if config.enable_bm25 {
70 Some(Arc::new(RwLock::new(BM25Index::new())))
71 } else {
72 None
73 };
74
75 Self {
76 config,
77 points: Arc::new(RwLock::new(HashMap::new())),
78 hnsw,
79 bm25,
80 hnsw_built: Arc::new(RwLock::new(false)),
81 hnsw_rebuilding: Arc::new(AtomicBool::new(false)),
82 batch_mode: Arc::new(RwLock::new(false)),
83 pending_points: Arc::new(RwLock::new(Vec::new())),
84 payload_indexes: Arc::new(RwLock::new(HashMap::new())),
85 }
86 }
87
88 #[inline]
89 #[must_use]
90 pub fn name(&self) -> &str {
91 &self.config.name
92 }
93
94 #[inline]
95 #[must_use]
96 pub fn vector_dim(&self) -> usize {
97 self.config.vector_dim
98 }
99
100 #[inline]
101 #[must_use]
102 pub fn distance(&self) -> Distance {
103 self.config.distance
104 }
105
106 #[inline]
107 #[must_use]
108 pub fn use_hnsw(&self) -> bool {
109 self.config.use_hnsw
110 }
111
112 #[inline]
113 #[must_use]
114 pub fn enable_bm25(&self) -> bool {
115 self.config.enable_bm25
116 }
117
118 #[inline]
119 #[must_use]
120 pub fn count(&self) -> usize {
121 self.points.read().len()
122 }
123
124 #[inline]
125 #[must_use]
126 pub fn is_empty(&self) -> bool {
127 self.points.read().is_empty()
128 }
129
130 pub fn get_all_points(&self) -> Vec<Point> {
132 self.points.read().values().cloned().collect()
133 }
134
135 pub fn upsert(&self, point: Point) -> Result<()> {
137 if self.config.vector_dim > 0 && point.vector.dim() != self.config.vector_dim {
139 return Err(Error::InvalidDimension {
140 expected: self.config.vector_dim,
141 actual: point.vector.dim(),
142 });
143 }
144
145 let id_str = point.id.to_string();
146
147 let in_batch = *self.batch_mode.read();
148 if in_batch {
149 self.points.write().insert(id_str.clone(), point.clone());
150 self.pending_points.write().push(point);
151 return Ok(());
152 }
153
154 if let Some(hnsw) = &self.hnsw {
155 let built = *self.hnsw_built.read();
156 if built {
157 let mut normalized_point = point.clone();
158 normalized_point.vector.normalize();
159
160 let mut index = hnsw.write();
161 index.insert(normalized_point);
162 }
163 }
164
165 if let Some(bm25) = &self.bm25 {
166 if let Some(payload) = &point.payload {
167 if let Some(text) = payload.get("text").and_then(|v| v.as_str()) {
168 let mut index = bm25.write();
169 index.insert_doc(&id_str, text);
170 }
171 }
172 }
173
174 self.points.write().insert(id_str, point);
175 Ok(())
176 }
177
178 pub fn start_batch(&self) {
180 *self.batch_mode.write() = true;
181 self.pending_points.write().clear();
182 }
183
184 pub fn end_batch(&self) -> Result<()> {
186 *self.batch_mode.write() = false;
187
188 if let Some(hnsw) = &self.hnsw {
189 let points = self.points.read();
190 let point_count = points.len();
191
192 const HNSW_REBUILD_THRESHOLD: usize = 10_000;
193
194 if point_count > HNSW_REBUILD_THRESHOLD && !self.hnsw_rebuilding.load(Ordering::Acquire) {
195 self.hnsw_rebuilding.store(true, Ordering::Release);
196 let points_clone: Vec<Point> = points.values().cloned().collect();
197 let hnsw_clone = hnsw.clone();
198 let built_flag = self.hnsw_built.clone();
199 let rebuilding_flag = self.hnsw_rebuilding.clone();
200
201 let job = crate::background::HnswRebuildJob::new(
202 points_clone,
203 hnsw_clone,
204 built_flag,
205 rebuilding_flag,
206 );
207 crate::background::get_background_system().submit(Box::new(job));
208 }
209 }
210
211 self.pending_points.write().clear();
212 Ok(())
213 }
214
215 pub fn batch_upsert(&self, points: Vec<Point>) -> Result<()> {
217 self.start_batch();
218 for point in points {
219 self.upsert(point)?;
220 }
221 self.end_batch()?;
222 Ok(())
223 }
224
225 pub fn batch_upsert_with_prewarm(&self, points: Vec<Point>, prewarm: bool) -> Result<()> {
227 self.batch_upsert(points)?;
228 if prewarm {
229 self.prewarm_index()?;
230 }
231 Ok(())
232 }
233
234 #[inline]
236 pub fn get(&self, id: &str) -> Option<Point> {
237 self.points.read().get(id).cloned()
238 }
239
240 pub fn delete(&self, id: &str) -> Result<bool> {
242 if let Some(hnsw) = &self.hnsw {
243 let mut index = hnsw.write();
244 index.remove(id);
245 }
246
247 if let Some(bm25) = &self.bm25 {
248 let mut index = bm25.write();
249 index.delete_doc(id);
250 }
251
252 let mut points = self.points.write();
253 Ok(points.remove(id).is_some())
254 }
255
256 pub fn set_payload(&self, id: &str, payload: serde_json::Value) -> Result<bool> {
258 let mut points = self.points.write();
259 if let Some(point) = points.get_mut(id) {
260 if let Some(existing) = &mut point.payload {
261 if let (Some(existing_obj), Some(new_obj)) = (existing.as_object_mut(), payload.as_object()) {
262 for (key, value) in new_obj {
263 existing_obj.insert(key.clone(), value.clone());
264 }
265 }
266 } else {
267 point.payload = Some(payload);
268 }
269 Ok(true)
270 } else {
271 Ok(false)
272 }
273 }
274
275 pub fn overwrite_payload(&self, id: &str, payload: serde_json::Value) -> Result<bool> {
277 let mut points = self.points.write();
278 if let Some(point) = points.get_mut(id) {
279 point.payload = Some(payload);
280 Ok(true)
281 } else {
282 Ok(false)
283 }
284 }
285
286 pub fn delete_payload_keys(&self, id: &str, keys: &[String]) -> Result<bool> {
288 let mut points = self.points.write();
289 if let Some(point) = points.get_mut(id) {
290 if let Some(payload) = &mut point.payload {
291 if let Some(obj) = payload.as_object_mut() {
292 for key in keys {
293 obj.remove(key);
294 }
295 }
296 }
297 Ok(true)
298 } else {
299 Ok(false)
300 }
301 }
302
303 pub fn clear_payload(&self, id: &str) -> Result<bool> {
305 let mut points = self.points.write();
306 if let Some(point) = points.get_mut(id) {
307 point.payload = None;
308 Ok(true)
309 } else {
310 Ok(false)
311 }
312 }
313
314 pub fn update_vector(&self, id: &str, vector: Vector) -> Result<bool> {
316 let mut points = self.points.write();
317 if let Some(point) = points.get_mut(id) {
318 point.vector = vector.clone();
319
320 if let Some(hnsw) = &self.hnsw {
322 let mut index = hnsw.write();
323 index.remove(id);
324 index.insert(point.clone());
326 }
327 Ok(true)
328 } else {
329 Ok(false)
330 }
331 }
332
333 pub fn update_multivector(&self, id: &str, multivector: Option<MultiVector>) -> Result<bool> {
335 let mut points = self.points.write();
336 if let Some(point) = points.get_mut(id) {
337 point.multivector = multivector;
338 Ok(true)
339 } else {
340 Ok(false)
341 }
342 }
343
344 pub fn delete_vector(&self, id: &str) -> Result<bool> {
346 self.delete(id)
349 }
350
351 pub fn create_payload_index(&self, field_name: &str, index_type: PayloadIndexType) -> Result<bool> {
353 let mut indexes = self.payload_indexes.write();
354 indexes.insert(field_name.to_string(), index_type);
355 Ok(true)
356 }
357
358 pub fn delete_payload_index(&self, field_name: &str) -> Result<bool> {
360 let mut indexes = self.payload_indexes.write();
361 Ok(indexes.remove(field_name).is_some())
362 }
363
364 pub fn get_payload_indexes(&self) -> HashMap<String, PayloadIndexType> {
366 self.payload_indexes.read().clone()
367 }
368
369 pub fn is_field_indexed(&self, field_name: &str) -> bool {
371 self.payload_indexes.read().contains_key(field_name)
372 }
373
374 pub fn prewarm_index(&self) -> Result<()> {
376 if let Some(hnsw) = &self.hnsw {
377 let mut built = self.hnsw_built.write();
378 if !*built {
379 let points = self.points.read();
380 if !points.is_empty() {
381 let mut index = hnsw.write();
382 *index = HnswIndex::new(16, 3);
383 for point in points.values() {
384 index.insert(point.clone());
385 }
386 *built = true;
387 }
388 }
389 }
390 Ok(())
391 }
392
393 fn brute_force_search(&self, query: &Vector, limit: usize, filter: Option<&dyn Filter>) -> Vec<(Point, f32)> {
395 let points = self.points.read();
396 let query_slice = query.as_slice();
397
398 let mut results: Vec<(Point, f32)> = Vec::with_capacity(points.len().min(limit * 2));
400
401 for point in points.values() {
402 if let Some(f) = filter {
403 if !f.matches(point) {
404 continue;
405 }
406 }
407
408 let score = match self.config.distance {
410 Distance::Cosine => {
411 crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
412 }
413 Distance::Euclidean => {
414 -crate::simd::l2_distance_simd(query_slice, point.vector.as_slice())
415 }
416 Distance::Dot => {
417 crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
418 }
419 };
420
421 results.push((point.clone(), score));
422 }
423
424 if results.len() > limit {
426 results.select_nth_unstable_by(limit, |a, b| {
427 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
428 });
429 results.truncate(limit);
430 }
431
432 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
433 results
434 }
435
436 pub fn search(
439 &self,
440 query: &Vector,
441 limit: usize,
442 filter: Option<&dyn Filter>,
443 ) -> Vec<(Point, f32)> {
444 let normalized_query = query.normalized();
445 let point_count = self.points.read().len();
446
447 const BRUTE_FORCE_THRESHOLD: usize = 1000;
449 if point_count < BRUTE_FORCE_THRESHOLD {
450 return self.brute_force_search(&normalized_query, limit, filter);
451 }
452
453 if let Some(hnsw) = &self.hnsw {
454 {
456 let mut built = self.hnsw_built.write();
457 if !*built {
458 let points = self.points.read();
459 if !points.is_empty() {
460 let mut index = hnsw.write();
461 *index = HnswIndex::new(16, 3);
462 for point in points.values() {
463 index.insert(point.clone());
464 }
465 *built = true;
466 }
467 }
468 }
469
470 let mut index = hnsw.write();
472 let mut results = index.search(&normalized_query, limit, None);
473
474 if let Some(f) = filter {
475 results.retain(|(point, _)| f.matches(point));
476 }
477
478 results
479 } else {
480 let points = self.points.read();
481 let results: Vec<(Point, f32)> = points
482 .values()
483 .filter(|point| {
484 filter.map(|f| f.matches(point)).unwrap_or(true)
485 })
486 .map(|point| {
487 let score = match self.config.distance {
488 Distance::Cosine => point.vector.cosine_similarity(query),
489 Distance::Euclidean => -point.vector.l2_distance(query),
490 Distance::Dot => {
491 point.vector.as_slice()
492 .iter()
493 .zip(query.as_slice().iter())
494 .map(|(a, b)| a * b)
495 .sum()
496 }
497 };
498 (point.clone(), score)
499 })
500 .collect();
501
502 let mut sorted = results;
503 sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
504 sorted.truncate(limit);
505 sorted
506 }
507 }
508
509 pub fn search_text(&self, query: &str, limit: usize) -> Vec<(String, f32)> {
511 if let Some(bm25) = &self.bm25 {
512 let index = bm25.read();
513 index.search(query, limit)
514 } else {
515 Vec::new()
516 }
517 }
518
519 pub fn search_multivector(
524 &self,
525 query: &MultiVector,
526 limit: usize,
527 filter: Option<&dyn Filter>,
528 ) -> Vec<(Point, f32)> {
529 let points = self.points.read();
530
531 let mut results: Vec<(Point, f32)> = Vec::with_capacity(points.len().min(limit * 2));
532
533 for point in points.values() {
534 if let Some(f) = filter {
535 if !f.matches(point) {
536 continue;
537 }
538 }
539
540 let score = if let Some(doc_mv) = &point.multivector {
542 match self.config.distance {
544 Distance::Cosine => query.max_sim_cosine(doc_mv),
545 Distance::Euclidean => query.max_sim_l2(doc_mv),
546 Distance::Dot => query.max_sim(doc_mv),
547 }
548 } else {
549 let doc_mv = MultiVector::from_single(point.vector.as_slice().to_vec())
551 .unwrap_or_else(|_| MultiVector::new(vec![vec![0.0; query.dim()]]).unwrap());
552 match self.config.distance {
553 Distance::Cosine => query.max_sim_cosine(&doc_mv),
554 Distance::Euclidean => query.max_sim_l2(&doc_mv),
555 Distance::Dot => query.max_sim(&doc_mv),
556 }
557 };
558
559 results.push((point.clone(), score));
560 }
561
562 if results.len() > limit {
564 results.select_nth_unstable_by(limit, |a, b| {
565 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
566 });
567 results.truncate(limit);
568 }
569
570 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
571 results
572 }
573
574 pub fn iter(&self) -> Vec<Point> {
576 self.points.read().values().cloned().collect()
577 }
578}
579