1use crate::{Error, Point, Result, Vector, HnswIndex, BM25Index, Filter};
2use parking_lot::RwLock;
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6
7#[derive(Debug, Clone)]
9pub struct CollectionConfig {
10 pub name: String,
11 pub vector_dim: usize,
12 pub distance: Distance,
13 pub use_hnsw: bool,
14 pub enable_bm25: bool,
15}
16
17impl Default for CollectionConfig {
18 fn default() -> Self {
19 Self {
20 name: String::new(),
21 vector_dim: 128,
22 distance: Distance::Cosine,
23 use_hnsw: true,
24 enable_bm25: false,
25 }
26 }
27}
28
/// Similarity metric a [`Collection`] uses to score search results.
///
/// Search results are ordered so that a higher score means a better match.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Distance {
    /// Cosine similarity between query and point vectors.
    Cosine,
    /// Euclidean (L2) distance, scored as its negation so that closer is higher.
    Euclidean,
    /// Dot product between query and point vectors.
    Dot,
}
35
36pub struct Collection {
38 config: CollectionConfig,
39 points: Arc<RwLock<HashMap<String, Point>>>,
40 hnsw: Option<Arc<RwLock<HnswIndex>>>,
41 bm25: Option<Arc<RwLock<BM25Index>>>,
42 hnsw_built: Arc<RwLock<bool>>,
43 hnsw_rebuilding: Arc<AtomicBool>,
44 batch_mode: Arc<RwLock<bool>>,
45 pending_points: Arc<RwLock<Vec<Point>>>,
46}
47
48impl Collection {
49 pub fn new(config: CollectionConfig) -> Self {
50 let hnsw = if config.use_hnsw {
51 Some(Arc::new(RwLock::new(HnswIndex::new(16, 3))))
52 } else {
53 None
54 };
55
56 let bm25 = if config.enable_bm25 {
57 Some(Arc::new(RwLock::new(BM25Index::new())))
58 } else {
59 None
60 };
61
62 Self {
63 config,
64 points: Arc::new(RwLock::new(HashMap::new())),
65 hnsw,
66 bm25,
67 hnsw_built: Arc::new(RwLock::new(false)),
68 hnsw_rebuilding: Arc::new(AtomicBool::new(false)),
69 batch_mode: Arc::new(RwLock::new(false)),
70 pending_points: Arc::new(RwLock::new(Vec::new())),
71 }
72 }
73
74 #[inline]
75 #[must_use]
76 pub fn name(&self) -> &str {
77 &self.config.name
78 }
79
80 #[inline]
81 #[must_use]
82 pub fn vector_dim(&self) -> usize {
83 self.config.vector_dim
84 }
85
86 #[inline]
87 #[must_use]
88 pub fn distance(&self) -> Distance {
89 self.config.distance
90 }
91
92 #[inline]
93 #[must_use]
94 pub fn count(&self) -> usize {
95 self.points.read().len()
96 }
97
98 #[inline]
99 #[must_use]
100 pub fn is_empty(&self) -> bool {
101 self.points.read().is_empty()
102 }
103
104 pub fn get_all_points(&self) -> Vec<Point> {
106 self.points.read().values().cloned().collect()
107 }
108
109 pub fn upsert(&self, point: Point) -> Result<()> {
111 if point.vector.dim() != self.config.vector_dim {
112 return Err(Error::InvalidDimension {
113 expected: self.config.vector_dim,
114 actual: point.vector.dim(),
115 });
116 }
117
118 let id_str = point.id.to_string();
119
120 let in_batch = *self.batch_mode.read();
121 if in_batch {
122 self.points.write().insert(id_str.clone(), point.clone());
123 self.pending_points.write().push(point);
124 return Ok(());
125 }
126
127 if let Some(hnsw) = &self.hnsw {
128 let built = *self.hnsw_built.read();
129 if built {
130 let mut normalized_point = point.clone();
131 normalized_point.vector.normalize();
132
133 let mut index = hnsw.write();
134 index.insert(normalized_point);
135 }
136 }
137
138 if let Some(bm25) = &self.bm25 {
139 if let Some(payload) = &point.payload {
140 if let Some(text) = payload.get("text").and_then(|v| v.as_str()) {
141 let mut index = bm25.write();
142 index.insert_doc(&id_str, text);
143 }
144 }
145 }
146
147 self.points.write().insert(id_str, point);
148 Ok(())
149 }
150
151 pub fn start_batch(&self) {
153 *self.batch_mode.write() = true;
154 self.pending_points.write().clear();
155 }
156
157 pub fn end_batch(&self) -> Result<()> {
159 *self.batch_mode.write() = false;
160
161 if let Some(hnsw) = &self.hnsw {
162 let points = self.points.read();
163 let point_count = points.len();
164
165 const HNSW_REBUILD_THRESHOLD: usize = 10_000;
166
167 if point_count > HNSW_REBUILD_THRESHOLD && !self.hnsw_rebuilding.load(Ordering::Acquire) {
168 self.hnsw_rebuilding.store(true, Ordering::Release);
169 let points_clone: Vec<Point> = points.values().cloned().collect();
170 let hnsw_clone = hnsw.clone();
171 let built_flag = self.hnsw_built.clone();
172 let rebuilding_flag = self.hnsw_rebuilding.clone();
173
174 let job = crate::background::HnswRebuildJob::new(
175 points_clone,
176 hnsw_clone,
177 built_flag,
178 rebuilding_flag,
179 );
180 crate::background::get_background_system().submit(Box::new(job));
181 }
182 }
183
184 self.pending_points.write().clear();
185 Ok(())
186 }
187
188 pub fn batch_upsert(&self, points: Vec<Point>) -> Result<()> {
190 self.start_batch();
191 for point in points {
192 self.upsert(point)?;
193 }
194 self.end_batch()?;
195 Ok(())
196 }
197
198 pub fn batch_upsert_with_prewarm(&self, points: Vec<Point>, prewarm: bool) -> Result<()> {
200 self.batch_upsert(points)?;
201 if prewarm {
202 self.prewarm_index()?;
203 }
204 Ok(())
205 }
206
207 #[inline]
209 pub fn get(&self, id: &str) -> Option<Point> {
210 self.points.read().get(id).cloned()
211 }
212
213 pub fn delete(&self, id: &str) -> Result<bool> {
215 if let Some(hnsw) = &self.hnsw {
216 let mut index = hnsw.write();
217 index.remove(id);
218 }
219
220 if let Some(bm25) = &self.bm25 {
221 let mut index = bm25.write();
222 index.delete_doc(id);
223 }
224
225 let mut points = self.points.write();
226 Ok(points.remove(id).is_some())
227 }
228
229 pub fn prewarm_index(&self) -> Result<()> {
231 if let Some(hnsw) = &self.hnsw {
232 let mut built = self.hnsw_built.write();
233 if !*built {
234 let points = self.points.read();
235 if !points.is_empty() {
236 let mut index = hnsw.write();
237 *index = HnswIndex::new(16, 3);
238 for point in points.values() {
239 index.insert(point.clone());
240 }
241 *built = true;
242 }
243 }
244 }
245 Ok(())
246 }
247
248 fn brute_force_search(&self, query: &Vector, limit: usize, filter: Option<&dyn Filter>) -> Vec<(Point, f32)> {
250 let points = self.points.read();
251 let query_slice = query.as_slice();
252
253 let mut results: Vec<(Point, f32)> = Vec::with_capacity(points.len().min(limit * 2));
255
256 for point in points.values() {
257 if let Some(f) = filter {
258 if !f.matches(point) {
259 continue;
260 }
261 }
262
263 let score = match self.config.distance {
265 Distance::Cosine => {
266 crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
267 }
268 Distance::Euclidean => {
269 -crate::simd::l2_distance_simd(query_slice, point.vector.as_slice())
270 }
271 Distance::Dot => {
272 crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
273 }
274 };
275
276 results.push((point.clone(), score));
277 }
278
279 if results.len() > limit {
281 results.select_nth_unstable_by(limit, |a, b| {
282 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
283 });
284 results.truncate(limit);
285 }
286
287 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
288 results
289 }
290
291 pub fn search(
294 &self,
295 query: &Vector,
296 limit: usize,
297 filter: Option<&dyn Filter>,
298 ) -> Vec<(Point, f32)> {
299 let normalized_query = query.normalized();
300 let point_count = self.points.read().len();
301
302 const BRUTE_FORCE_THRESHOLD: usize = 1000;
304 if point_count < BRUTE_FORCE_THRESHOLD {
305 return self.brute_force_search(&normalized_query, limit, filter);
306 }
307
308 if let Some(hnsw) = &self.hnsw {
309 {
311 let mut built = self.hnsw_built.write();
312 if !*built {
313 let points = self.points.read();
314 if !points.is_empty() {
315 let mut index = hnsw.write();
316 *index = HnswIndex::new(16, 3);
317 for point in points.values() {
318 index.insert(point.clone());
319 }
320 *built = true;
321 }
322 }
323 }
324
325 let mut index = hnsw.write();
327 let mut results = index.search(&normalized_query, limit, None);
328
329 if let Some(f) = filter {
330 results.retain(|(point, _)| f.matches(point));
331 }
332
333 results
334 } else {
335 let points = self.points.read();
336 let results: Vec<(Point, f32)> = points
337 .values()
338 .filter(|point| {
339 filter.map(|f| f.matches(point)).unwrap_or(true)
340 })
341 .map(|point| {
342 let score = match self.config.distance {
343 Distance::Cosine => point.vector.cosine_similarity(query),
344 Distance::Euclidean => -point.vector.l2_distance(query),
345 Distance::Dot => {
346 point.vector.as_slice()
347 .iter()
348 .zip(query.as_slice().iter())
349 .map(|(a, b)| a * b)
350 .sum()
351 }
352 };
353 (point.clone(), score)
354 })
355 .collect();
356
357 let mut sorted = results;
358 sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
359 sorted.truncate(limit);
360 sorted
361 }
362 }
363
364 pub fn search_text(&self, query: &str, limit: usize) -> Vec<(String, f32)> {
366 if let Some(bm25) = &self.bm25 {
367 let index = bm25.read();
368 index.search(query, limit)
369 } else {
370 Vec::new()
371 }
372 }
373
374 pub fn iter(&self) -> Vec<Point> {
376 self.points.read().values().cloned().collect()
377 }
378}
379