use std::sync::Arc;

use rustc_hash::FxHashMap;

use crate::directories::DirectoryWriter;
use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
use crate::error::{Error, Result};
use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};

use super::IndexWriter;

impl<D: DirectoryWriter + 'static> IndexWriter<D> {
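    /// Auto-build check: if any dense-vector field's flat vector count has
    /// crossed its configured `build_threshold` (default 1000), triggers a
    /// full vector index build. Cheap no-op when no dense-vector fields are
    /// configured or every field is already built.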
    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }

        let all_built = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            dense_fields
                .iter()
                .all(|(field, _)| meta.is_field_built(field.0))
        };
        if all_built {
            return Ok(());
        }

        let segment_ids = self.segment_manager.get_segment_ids().await;
        let total_vectors = self.count_flat_vectors(&segment_ids).await;

        let should_build = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            meta.total_vectors = total_vectors;
            dense_fields.iter().any(|(field, config)| {
                let threshold = config.build_threshold.unwrap_or(1000);
                meta.should_build_field(field.0, threshold)
            })
        };

        if should_build {
            log::info!(
                "Threshold crossed ({} vectors), auto-triggering vector index build",
                total_vectors
            );
            self.build_vector_index().await?;
        }

        Ok(())
    }

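    /// Trains ANN structures (coarse centroids, plus a PQ codebook for ScaNN)
    /// for every dense-vector field that is not yet built, then rebuilds all
    /// segments so they carry the new ANN indexes.
    ///
    /// A minimal usage sketch, assuming an open writer `writer` and a
    /// hypothetical dense-vector field handle `my_vector_field`:
    ///
    /// ```ignore
    /// writer.build_vector_index().await?;
    /// assert!(writer.is_vector_index_built(my_vector_field).await);
    /// ```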
    pub async fn build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            log::info!("No dense vector fields configured for ANN indexing");
            return Ok(());
        }

        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
        if fields_to_build.is_empty() {
            log::info!("All vector fields already built, skipping training");
            return Ok(());
        }

        self.segment_manager.wait_for_merges().await;

        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        let all_vectors = self
            .collect_vectors_for_training(&segment_ids, &fields_to_build)
            .await?;

        for (field, config) in &fields_to_build {
            self.train_field_index(*field, config, &all_vectors).await?;
        }

        log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");

        self.rebuild_segments_with_ann().await?;

        Ok(())
    }

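    /// Rebuilds segments with ANN indexes while merges are paused, so the
    /// segment set cannot change underneath the rebuild. Merges are resumed
    /// even if the inner rebuild fails.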
    pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
        self.segment_manager.pause_merges();
        self.segment_manager.wait_for_merges().await;

        let result = self.rebuild_segments_with_ann_inner().await;

        self.segment_manager.resume_merges();

        result
    }

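    /// Merges all current segments into a single new segment built with the
    /// trained ANN structures, then atomically swaps it in for the old ones.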
    async fn rebuild_segments_with_ann_inner(&self) -> Result<()> {
        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        let (trained_centroids, trained_codebooks) = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            meta.load_trained_structures(self.directory.as_ref()).await
        };

        if trained_centroids.is_empty() {
            log::info!("No trained structures to rebuild with");
            return Ok(());
        }

        let trained = TrainedVectorStructures {
            centroids: trained_centroids,
            codebooks: trained_codebooks,
        };

        let readers = self.load_segment_readers(&segment_ids).await?;

        let total_docs: u32 = readers.iter().map(|r| r.meta().num_docs).sum();

        let merger = SegmentMerger::new(Arc::clone(&self.schema));
        let new_segment_id = SegmentId::new();
        merger
            .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
            .await?;

        self.segment_manager
            .replace_segments(vec![(new_segment_id.to_hex(), total_docs)], segment_ids)
            .await?;

        log::info!("Segments rebuilt with ANN indexes");
        Ok(())
    }

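    /// Returns the total number of flat vectors recorded in the index
    /// metadata (updated by threshold checks and index builds).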
    pub async fn total_vector_count(&self) -> usize {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.total_vectors
    }

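    /// Returns `true` if an ANN index has been built for the given field.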
    pub async fn is_vector_index_built(&self, field: Field) -> bool {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.is_field_built(field.0)
    }

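    /// Forces a full rebuild: resets every indexed dense-vector field back to
    /// the `Flat` state, deletes the stale centroid/codebook files, and then
    /// runs `build_vector_index` from scratch.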
    pub async fn rebuild_vector_index(&self) -> Result<()> {
        let dense_fields: Vec<Field> = self
            .schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    Some(field)
                } else {
                    None
                }
            })
            .collect();

        if dense_fields.is_empty() {
            return Ok(());
        }

        let files_to_delete = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            let mut files = Vec::new();
            for field in &dense_fields {
                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
                    field_meta.state = super::VectorIndexState::Flat;
                    if let Some(ref f) = field_meta.centroids_file {
                        files.push(f.clone());
                    }
                    if let Some(ref f) = field_meta.codebook_file {
                        files.push(f.clone());
                    }
                    field_meta.centroids_file = None;
                    field_meta.codebook_file = None;
                }
            }
            meta.save(self.directory.as_ref()).await?;
            files
        };

        for file in files_to_delete {
            let _ = self.directory.delete(std::path::Path::new(&file)).await;
        }

        log::info!("Reset vector index state to Flat, triggering rebuild...");

        self.build_vector_index().await
    }

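    /// Collects the indexed dense-vector fields whose configuration requests
    /// an ANN index (i.e. anything other than flat storage).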
    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
        self.schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    entry
                        .dense_vector_config
                        .as_ref()
                        .filter(|c| !c.is_flat())
                        .map(|c| (field, c.clone()))
                } else {
                    None
                }
            })
            .collect()
    }

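    /// Filters the dense-vector fields down to those without a built index.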
    async fn get_fields_to_build(
        &self,
        dense_fields: &[(Field, DenseVectorConfig)],
    ) -> Vec<(Field, DenseVectorConfig)> {
        let metadata_arc = self.segment_manager.metadata();
        let meta = metadata_arc.read().await;
        dense_fields
            .iter()
            .filter(|(field, _)| !meta.is_field_built(field.0))
            .cloned()
            .collect()
    }

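    /// Counts the flat (unindexed) vectors across the given segments.
    /// Segments with invalid IDs, no vectors file, or that fail to open are
    /// skipped rather than treated as errors.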
    async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
        let mut total_vectors = 0usize;
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let Some(segment_id) = SegmentId::from_hex(id_str) else {
                continue;
            };

            let files = crate::segment::SegmentFiles::new(segment_id.0);
            if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
                continue;
            }

            if let Ok(reader) = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await
            {
                for flat_data in reader.flat_vectors().values() {
                    total_vectors += flat_data.num_vectors;
                }
                doc_offset += reader.meta().num_docs;
            }
        }

        total_vectors
    }

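    /// Gathers training vectors per field, capped at `MAX_TRAINING_VECTORS`
    /// per field. Once a field's budget cannot hold a whole segment, the
    /// remaining vectors are stride-sampled (every `step`-th vector) instead
    /// of read in full, which bounds training memory on large indexes.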
    async fn collect_vectors_for_training(
        &self,
        segment_ids: &[String],
        fields_to_build: &[(Field, DenseVectorConfig)],
    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
        const MAX_TRAINING_VECTORS: usize = 100_000;

        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
        let mut doc_offset = 0u32;
        let mut total_skipped = 0usize;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;

            for (field_id, lazy_flat) in reader.flat_vectors() {
                if !fields_to_build.iter().any(|(f, _)| f.0 == *field_id) {
                    continue;
                }
                let entry = all_vectors.entry(*field_id).or_default();
                let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());

                if remaining == 0 {
                    total_skipped += lazy_flat.num_vectors;
                    continue;
                }

                let n = lazy_flat.num_vectors;
                if n <= remaining {
                    for i in 0..n {
                        if let Ok(vec) = lazy_flat.get_vector(i).await {
                            entry.push(vec);
                        }
                    }
                } else {
                    let step = (n / remaining).max(1);
                    for i in 0..n {
                        if i % step == 0
                            && entry.len() < MAX_TRAINING_VECTORS
                            && let Ok(vec) = lazy_flat.get_vector(i).await
                        {
                            entry.push(vec);
                        }
                    }
                    total_skipped += n - remaining;
                }
            }

            doc_offset += reader.meta().num_docs;
        }

        if total_skipped > 0 {
            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
            log::info!(
                "Sampled {} vectors for training (skipped {}, max {} per field)",
                collected,
                total_skipped,
                MAX_TRAINING_VECTORS,
            );
        }

        Ok(all_vectors)
    }

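    /// Opens a `SegmentReader` for each segment ID, assigning contiguous
    /// document offsets so doc IDs remain globally unique across readers.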
    pub(super) async fn load_segment_readers(
        &self,
        segment_ids: &[String],
    ) -> Result<Vec<SegmentReader>> {
        let mut readers = Vec::new();
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;
            doc_offset += reader.meta().num_docs;
            readers.push(reader);
        }

        Ok(readers)
    }

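    /// Trains the ANN structures for one field and records the result in the
    /// index metadata. IVF-RaBitQ needs only coarse centroids; ScaNN also
    /// trains a PQ codebook. Fields with no collected vectors, or with an
    /// index type that needs no training, are silently skipped.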
    async fn train_field_index(
        &self,
        field: Field,
        config: &DenseVectorConfig,
        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
    ) -> Result<()> {
        let field_id = field.0;
        let vectors = match all_vectors.get(&field_id) {
            Some(v) if !v.is_empty() => v,
            _ => return Ok(()),
        };

        let dim = config.dim;
        let num_vectors = vectors.len();
        let num_clusters = config.optimal_num_clusters(num_vectors);

        log::info!(
            "Training vector index for field {} with {} vectors, {} clusters (dim={})",
            field_id,
            num_vectors,
            num_clusters,
            dim,
        );

        let centroids_filename = format!("field_{}_centroids.bin", field_id);
        let mut codebook_filename: Option<String> = None;

        match config.index_type {
            VectorIndexType::IvfRaBitQ => {
                self.train_ivf_rabitq(field_id, dim, num_clusters, vectors, &centroids_filename)
                    .await?;
            }
            VectorIndexType::ScaNN => {
                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
                self.train_scann(
                    field_id,
                    dim,
                    num_clusters,
                    vectors,
                    &centroids_filename,
                    codebook_filename.as_ref().unwrap(),
                )
                .await?;
            }
            _ => {
                return Ok(());
            }
        }

        self.segment_manager
            .update_metadata(|meta| {
                meta.init_field(field_id, config.index_type);
                meta.total_vectors = num_vectors;
                meta.mark_field_built(
                    field_id,
                    num_vectors,
                    num_clusters,
                    centroids_filename.clone(),
                    codebook_filename.clone(),
                );
            })
            .await?;

        Ok(())
    }

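    /// Trains the coarse centroids for an IVF-RaBitQ field and persists them
    /// as JSON under `centroids_filename`.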
    async fn train_ivf_rabitq(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);

        let centroids_path = std::path::Path::new(centroids_filename);
        let centroids_bytes =
            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory
            .write(centroids_path, &centroids_bytes)
            .await?;

        log::info!(
            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );

        Ok(())
    }

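    /// Trains both the coarse centroids and the product-quantization codebook
    /// for a ScaNN field, persisting each as JSON in the directory.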
    async fn train_scann(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
        codebook_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);

        let pq_config = crate::structures::PQConfig::new(dim);
        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);

        let centroids_path = std::path::Path::new(centroids_filename);
        let centroids_bytes =
            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory
            .write(centroids_path, &centroids_bytes)
            .await?;

        let codebook_path = std::path::Path::new(codebook_filename);
        let codebook_bytes =
            serde_json::to_vec(&codebook).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory.write(codebook_path, &codebook_bytes).await?;

        log::info!(
            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );

        Ok(())
    }
}