//! Vector index construction for `IndexWriter`: trains ANN structures for
//! dense-vector fields and rebuilds segments to carry the trained indexes.

use std::sync::Arc;

use rustc_hash::FxHashMap;

use crate::directories::DirectoryWriter;
use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
use crate::error::{Error, Result};
use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};

use super::IndexWriter;

impl<D: DirectoryWriter + 'static> IndexWriter<D> {
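    /// Checks whether any dense-vector field has crossed its build threshold
    /// and, if so, triggers a full ANN index build. Cheap when no ANN fields
    /// are configured or all of them are already built.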
    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }

        // Fast path: nothing to do if every ANN field is already built.
        let all_built = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            dense_fields
                .iter()
                .all(|(field, _)| meta.is_field_built(field.0))
        };
        if all_built {
            return Ok(());
        }

        let segment_ids = self.segment_manager.get_segment_ids().await;
        let total_vectors = self.count_flat_vectors(&segment_ids).await;

        // Record the current count and check each field's build threshold.
        let should_build = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            meta.total_vectors = total_vectors;
            dense_fields.iter().any(|(field, config)| {
                let threshold = config.build_threshold.unwrap_or(1000);
                meta.should_build_field(field.0, threshold)
            })
        };

        if should_build {
            log::info!(
                "Threshold crossed ({} vectors), auto-triggering vector index build",
                total_vectors
            );
            self.build_vector_index().await?;
        }

        Ok(())
    }
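
    /// Trains ANN structures for every dense-vector field that is not built
    /// yet, then rebuilds all segments so they carry the new indexes.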
    pub async fn build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            log::info!("No dense vector fields configured for ANN indexing");
            return Ok(());
        }

        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
        if fields_to_build.is_empty() {
            log::info!("All vector fields already built, skipping training");
            return Ok(());
        }

        // Let in-flight merges settle so training sees a stable segment set.
        self.segment_manager.wait_for_merges().await;

        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        let all_vectors = self
            .collect_vectors_for_training(&segment_ids, &fields_to_build)
            .await?;

        for (field, config) in &fields_to_build {
            self.train_field_index(*field, config, &all_vectors).await?;
        }

        log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");

        self.rebuild_segments_with_ann().await?;

        Ok(())
    }
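
    /// Rebuilds segments with ANN indexes while background merges are paused,
    /// so the merge scheduler cannot race with the segment replacement.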
    pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
        // Pause background merges and wait for in-flight ones to finish so
        // the rebuild operates on a stable segment set.
        self.segment_manager.pause_merges();
        self.segment_manager.wait_for_merges().await;

        let result = self.rebuild_segments_with_ann_inner().await;

        // Always resume merges, even if the rebuild failed.
        self.segment_manager.resume_merges();

        result
    }
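
    /// Merges all current segments into a single new segment built with the
    /// trained ANN structures, then swaps it in for the old segment set.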
    async fn rebuild_segments_with_ann_inner(&self) -> Result<()> {
        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        // Load the trained centroids and codebooks referenced by the metadata.
        let (trained_centroids, trained_codebooks) = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            meta.load_trained_structures(self.directory.as_ref()).await
        };

        if trained_centroids.is_empty() {
            log::info!("No trained structures to rebuild with");
            return Ok(());
        }

        let trained = TrainedVectorStructures {
            centroids: trained_centroids,
            codebooks: trained_codebooks,
        };

        let readers = self.load_segment_readers(&segment_ids).await?;

        let total_docs: u32 = readers.iter().map(|r| r.meta().num_docs).sum();

        // Merge every segment into one new segment built with the trained
        // structures, then replace the old segments with it.
        let merger = SegmentMerger::new(Arc::clone(&self.schema));
        let new_segment_id = SegmentId::new();
        merger
            .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
            .await?;

        self.segment_manager
            .replace_segments(vec![(new_segment_id.to_hex(), total_docs)], segment_ids)
            .await?;

        log::info!("Segments rebuilt with ANN indexes");
        Ok(())
    }
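
    /// Total number of flat vectors recorded in the index metadata.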
    pub async fn total_vector_count(&self) -> usize {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.total_vectors
    }
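
    /// Whether the ANN index for `field` has been built.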
    pub async fn is_vector_index_built(&self, field: Field) -> bool {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.is_field_built(field.0)
    }
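
    /// Forces a rebuild from scratch: resets every indexed dense-vector field
    /// to the flat state, deletes the persisted centroid/codebook files, and
    /// runs `build_vector_index` again.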
    pub async fn rebuild_vector_index(&self) -> Result<()> {
        let dense_fields: Vec<Field> = self
            .schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    Some(field)
                } else {
                    None
                }
            })
            .collect();

        if dense_fields.is_empty() {
            return Ok(());
        }

        // Reset each field to the flat state and collect the now-unreferenced
        // trained-structure files, persisting the metadata before any deletes.
        let files_to_delete = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            let mut files = Vec::new();
            for field in &dense_fields {
                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
                    field_meta.state = super::VectorIndexState::Flat;
                    if let Some(ref f) = field_meta.centroids_file {
                        files.push(f.clone());
                    }
                    if let Some(ref f) = field_meta.codebook_file {
                        files.push(f.clone());
                    }
                    field_meta.centroids_file = None;
                    field_meta.codebook_file = None;
                }
            }
            meta.save(self.directory.as_ref()).await?;
            files
        };

        // Best-effort cleanup: a leftover file is harmless once dereferenced.
        for file in files_to_delete {
            let _ = self.directory.delete(std::path::Path::new(&file)).await;
        }

        log::info!("Reset vector index state to Flat, triggering rebuild...");

        self.build_vector_index().await
    }
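
    /// Indexed dense-vector fields whose configuration requests an ANN index
    /// (flat-only fields are excluded).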
    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
        self.schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    entry
                        .dense_vector_config
                        .as_ref()
                        .filter(|c| !c.is_flat())
                        .map(|c| (field, c.clone()))
                } else {
                    None
                }
            })
            .collect()
    }
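
    /// Filters `dense_fields` down to the fields whose ANN index has not been
    /// built yet.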
    async fn get_fields_to_build(
        &self,
        dense_fields: &[(Field, DenseVectorConfig)],
    ) -> Vec<(Field, DenseVectorConfig)> {
        let metadata_arc = self.segment_manager.metadata();
        let meta = metadata_arc.read().await;
        dense_fields
            .iter()
            .filter(|(field, _)| !meta.is_field_built(field.0))
            .cloned()
            .collect()
    }
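
    /// Counts flat vectors across the given segments, skipping segments that
    /// have no vector file or that fail to open.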
    async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
        let mut total_vectors = 0usize;
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let Some(segment_id) = SegmentId::from_hex(id_str) else {
                continue;
            };

            // Skip segments that carry no vector data at all.
            let files = crate::segment::SegmentFiles::new(segment_id.0);
            if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
                continue;
            }

            if let Ok(reader) = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await
            {
                for flat_data in reader.flat_vectors().values() {
                    total_vectors += flat_data.num_vectors;
                }
                doc_offset += reader.meta().num_docs;
            }
        }

        total_vectors
    }
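
    /// Collects training vectors per field from all segments, capping each
    /// field at `MAX_TRAINING_VECTORS`; once a field approaches the cap, the
    /// remaining vectors are subsampled at a fixed stride rather than read in
    /// full.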
    async fn collect_vectors_for_training(
        &self,
        segment_ids: &[String],
        fields_to_build: &[(Field, DenseVectorConfig)],
    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
        // Per-field cap on how many vectors are used for training.
        const MAX_TRAINING_VECTORS: usize = 100_000;

        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
        let mut doc_offset = 0u32;
        let mut total_skipped = 0usize;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;

            for (field_id, lazy_flat) in reader.flat_vectors() {
                if !fields_to_build.iter().any(|(f, _)| f.0 == *field_id) {
                    continue;
                }
                let entry = all_vectors.entry(*field_id).or_default();
                let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());

                if remaining == 0 {
                    total_skipped += lazy_flat.num_vectors;
                    continue;
                }

                let n = lazy_flat.num_vectors;
                if n <= remaining {
                    // Everything still fits under the cap: take all vectors.
                    for i in 0..n {
                        if let Ok(vec) = lazy_flat.get_vector(i).await {
                            entry.push(vec);
                        }
                    }
                } else {
                    // Subsample at a fixed stride to stay under the cap.
                    let step = (n / remaining).max(1);
                    for i in 0..n {
                        if i % step == 0
                            && entry.len() < MAX_TRAINING_VECTORS
                            && let Ok(vec) = lazy_flat.get_vector(i).await
                        {
                            entry.push(vec);
                        }
                    }
                    total_skipped += n - remaining;
                }
            }

            doc_offset += reader.meta().num_docs;
        }

        if total_skipped > 0 {
            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
            log::info!(
                "Sampled {} vectors for training (skipped {}, max {} per field)",
                collected,
                total_skipped,
                MAX_TRAINING_VECTORS,
            );
        }

        Ok(all_vectors)
    }
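
    /// Opens a `SegmentReader` for every segment, assigning contiguous
    /// document offsets in segment order.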
    pub(super) async fn load_segment_readers(
        &self,
        segment_ids: &[String],
    ) -> Result<Vec<SegmentReader>> {
        let mut readers = Vec::new();
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;
            doc_offset += reader.meta().num_docs;
            readers.push(reader);
        }

        Ok(readers)
    }
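
    /// Trains the configured index type for a single field from the sampled
    /// vectors, persists the trained structures, and marks the field as built
    /// in the metadata.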
    async fn train_field_index(
        &self,
        field: Field,
        config: &DenseVectorConfig,
        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
    ) -> Result<()> {
        let field_id = field.0;
        let vectors = match all_vectors.get(&field_id) {
            Some(v) if !v.is_empty() => v,
            _ => return Ok(()),
        };

        let index_dim = config.index_dim();
        let num_vectors = vectors.len();
        let num_clusters = config.optimal_num_clusters(num_vectors);

        log::info!(
            "Training vector index for field {} with {} vectors, {} clusters (index_dim={})",
            field_id,
            num_vectors,
            num_clusters,
            index_dim,
        );

        // If the stored vectors are wider than the index dimension, train on
        // a prefix of each vector.
        let trimmed: Vec<Vec<f32>>;
        let training_vectors = if vectors.first().is_some_and(|v| v.len() > index_dim) {
            trimmed = vectors.iter().map(|v| v[..index_dim].to_vec()).collect();
            &trimmed
        } else {
            vectors
        };

        let centroids_filename = format!("field_{}_centroids.bin", field_id);
        let mut codebook_filename: Option<String> = None;

        match config.index_type {
            VectorIndexType::IvfRaBitQ => {
                self.train_ivf_rabitq(
                    field_id,
                    index_dim,
                    num_clusters,
                    training_vectors,
                    &centroids_filename,
                )
                .await?;
            }
            VectorIndexType::ScaNN => {
                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
                self.train_scann(
                    field_id,
                    index_dim,
                    num_clusters,
                    training_vectors,
                    &centroids_filename,
                    codebook_filename.as_ref().unwrap(),
                )
                .await?;
            }
            _ => {
                // Remaining index types need no training.
                return Ok(());
            }
        }

        self.segment_manager
            .update_metadata(|meta| {
                meta.init_field(field_id, config.index_type);
                meta.total_vectors = num_vectors;
                meta.mark_field_built(
                    field_id,
                    num_vectors,
                    num_clusters,
                    centroids_filename.clone(),
                    codebook_filename.clone(),
                );
            })
            .await?;

        Ok(())
    }
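
    /// Trains coarse IVF centroids for an IVF-RaBitQ index and persists them
    /// via the directory as JSON.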
    async fn train_ivf_rabitq(
        &self,
        field_id: u32,
        index_dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);

        let centroids_path = std::path::Path::new(centroids_filename);
        let centroids_bytes =
            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory
            .write(centroids_path, &centroids_bytes)
            .await?;

        log::info!(
            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );

        Ok(())
    }
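
    /// Trains coarse centroids and a product-quantization codebook for a
    /// ScaNN index and persists both via the directory as JSON.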
    async fn train_scann(
        &self,
        field_id: u32,
        index_dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
        codebook_filename: &str,
    ) -> Result<()> {
        // Train the coarse IVF centroids.
        let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);

        // Train the product-quantization codebook.
        let pq_config = crate::structures::PQConfig::new(index_dim);
        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);

        let centroids_path = std::path::Path::new(centroids_filename);
        let centroids_bytes =
            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory
            .write(centroids_path, &centroids_bytes)
            .await?;

        let codebook_path = std::path::Path::new(codebook_filename);
        let codebook_bytes =
            serde_json::to_vec(&codebook).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory.write(codebook_path, &codebook_bytes).await?;

        log::info!(
            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );

        Ok(())
    }
}