Skip to main content

hermes_core/index/
vector_builder.rs

//! Vector index building for `IndexWriter`
//!
//! This module handles:
//! - Training centroids/codebooks from accumulated Flat vectors
//! - Rebuilding segments with ANN indexes
//! - Threshold-based auto-triggering of vector index builds

8use std::sync::Arc;
9
10use rustc_hash::FxHashMap;
11
12use crate::directories::DirectoryWriter;
13use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
14use crate::error::{Error, Result};
15use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};
16
17use super::IndexWriter;
18
19impl<D: DirectoryWriter + 'static> IndexWriter<D> {
20    /// Check if any dense vector field should be built and trigger training
21    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
22        let dense_fields = self.get_dense_vector_fields();
23        if dense_fields.is_empty() {
24            return Ok(());
25        }
26
27        // Quick check: if all fields are already built, skip entirely
28        // This avoids loading segments just to count vectors when index is already built
29        let all_built = {
30            let metadata_arc = self.segment_manager.metadata();
31            let meta = metadata_arc.read().await;
32            dense_fields
33                .iter()
34                .all(|(field, _)| meta.is_field_built(field.0))
35        };
36        if all_built {
37            return Ok(());
38        }
39
40        // Count total vectors across all segments
41        let segment_ids = self.segment_manager.get_segment_ids().await;
42        let total_vectors = self.count_flat_vectors(&segment_ids).await;
43
44        // Update total in metadata and check if any field should be built
45        let should_build = {
46            let metadata_arc = self.segment_manager.metadata();
47            let mut meta = metadata_arc.write().await;
48            meta.total_vectors = total_vectors;
49            dense_fields.iter().any(|(field, config)| {
50                let threshold = config.build_threshold.unwrap_or(1000);
51                meta.should_build_field(field.0, threshold)
52            })
53        };
54
55        if should_build {
56            log::info!(
57                "Threshold crossed ({} vectors), auto-triggering vector index build",
58                total_vectors
59            );
60            self.build_vector_index().await?;
61        }
62
63        Ok(())
64    }
65
66    /// Build vector index from accumulated Flat vectors (trains ONCE)
67    ///
68    /// This trains centroids/codebooks from ALL vectors across all segments.
69    /// Training happens only ONCE - subsequent calls are no-ops if already built.
70    ///
71    /// **Note:** This is auto-triggered by `commit()` when threshold is crossed.
72    /// You typically don't need to call this manually.
73    ///
74    /// The process:
75    /// 1. Check if already built (skip if so)
76    /// 2. Collect all vectors from all segments
77    /// 3. Train centroids/codebooks based on schema's index_type
78    /// 4. Update metadata to mark as built (prevents re-training)
79    pub async fn build_vector_index(&self) -> Result<()> {
80        let dense_fields = self.get_dense_vector_fields();
81        if dense_fields.is_empty() {
82            log::info!("No dense vector fields configured for ANN indexing");
83            return Ok(());
84        }
85
86        // Check which fields need building (skip already built)
87        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
88        if fields_to_build.is_empty() {
89            log::info!("All vector fields already built, skipping training");
90            return Ok(());
91        }
92
93        // Wait for any background merges to complete before training.
94        // rebuild_segments_with_ann() calls replace_segments() which clears ALL
95        // segments atomically — concurrent merges would lose data or operate on
96        // stale/deleted segments.
97        self.segment_manager.wait_for_merges().await;
98
99        let segment_ids = self.segment_manager.get_segment_ids().await;
100        if segment_ids.is_empty() {
101            return Ok(());
102        }
103
104        // Collect all vectors from all segments for fields that need building
105        let all_vectors = self
106            .collect_vectors_for_training(&segment_ids, &fields_to_build)
107            .await?;
108
109        // Train centroids/codebooks for each field
110        for (field, config) in &fields_to_build {
111            self.train_field_index(*field, config, &all_vectors).await?;
112        }
113
114        log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");
115
116        // Rebuild segments with ANN indexes using trained structures
117        self.rebuild_segments_with_ann().await?;
118
119        Ok(())
120    }
121
122    /// Rebuild all segments with ANN indexes using trained centroids/codebooks
123    pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
124        // Pause background merges and wait for any in-flight ones to finish.
125        // rebuild replaces ALL segments atomically — concurrent merges would
126        // operate on stale/deleted segments or lose their output.
127        self.segment_manager.pause_merges();
128        self.segment_manager.wait_for_merges().await;
129
130        let result = self.rebuild_segments_with_ann_inner().await;
131
132        // Always resume merges, even on error
133        self.segment_manager.resume_merges();
134
135        result
136    }
137
138    async fn rebuild_segments_with_ann_inner(&self) -> Result<()> {
139        let segment_ids = self.segment_manager.get_segment_ids().await;
140        if segment_ids.is_empty() {
141            return Ok(());
142        }
143
144        // Load trained structures from metadata
145        let (trained_centroids, trained_codebooks) = {
146            let metadata_arc = self.segment_manager.metadata();
147            let meta = metadata_arc.read().await;
148            meta.load_trained_structures(self.directory.as_ref()).await
149        };
150
151        if trained_centroids.is_empty() {
152            log::info!("No trained structures to rebuild with");
153            return Ok(());
154        }
155
156        let trained = TrainedVectorStructures {
157            centroids: trained_centroids,
158            codebooks: trained_codebooks,
159        };
160
161        // Load all segment readers
162        let readers = self.load_segment_readers(&segment_ids).await?;
163
164        // Calculate total doc count for the merged segment
165        let total_docs: u32 = readers.iter().map(|r| r.meta().num_docs).sum();
166
167        // Merge all segments into one with ANN indexes
168        let merger = SegmentMerger::new(Arc::clone(&self.schema));
169        let new_segment_id = SegmentId::new();
170        merger
171            .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
172            .await?;
173
174        // Atomically update segments and delete old ones via SegmentManager
175        self.segment_manager
176            .replace_segments(vec![(new_segment_id.to_hex(), total_docs)], segment_ids)
177            .await?;
178
179        log::info!("Segments rebuilt with ANN indexes");
180        Ok(())
181    }
182
183    /// Get total vector count across all segments (for threshold checking)
184    pub async fn total_vector_count(&self) -> usize {
185        let metadata_arc = self.segment_manager.metadata();
186        metadata_arc.read().await.total_vectors
187    }
188
189    /// Check if vector index has been built for a field
190    pub async fn is_vector_index_built(&self, field: Field) -> bool {
191        let metadata_arc = self.segment_manager.metadata();
192        metadata_arc.read().await.is_field_built(field.0)
193    }
194
195    /// Rebuild vector index by retraining centroids/codebooks
196    ///
197    /// Use this when:
198    /// - Significant new data has been added and you want better centroids
199    /// - You want to change the number of clusters
200    /// - The vector distribution has changed significantly
201    ///
202    /// This resets the Built state to Flat, then triggers a fresh training.
203    pub async fn rebuild_vector_index(&self) -> Result<()> {
204        let dense_fields: Vec<Field> = self
205            .schema
206            .fields()
207            .filter_map(|(field, entry)| {
208                if entry.field_type == FieldType::DenseVector && entry.indexed {
209                    Some(field)
210                } else {
211                    None
212                }
213            })
214            .collect();
215
216        if dense_fields.is_empty() {
217            return Ok(());
218        }
219
220        // Collect files to delete and reset fields to Flat state
221        let files_to_delete = {
222            let metadata_arc = self.segment_manager.metadata();
223            let mut meta = metadata_arc.write().await;
224            let mut files = Vec::new();
225            for field in &dense_fields {
226                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
227                    field_meta.state = super::VectorIndexState::Flat;
228                    if let Some(ref f) = field_meta.centroids_file {
229                        files.push(f.clone());
230                    }
231                    if let Some(ref f) = field_meta.codebook_file {
232                        files.push(f.clone());
233                    }
234                    field_meta.centroids_file = None;
235                    field_meta.codebook_file = None;
236                }
237            }
238            meta.save(self.directory.as_ref()).await?;
239            files
240        };
241
242        // Delete old centroids/codebook files
243        for file in files_to_delete {
244            let _ = self.directory.delete(std::path::Path::new(&file)).await;
245        }
246
247        log::info!("Reset vector index state to Flat, triggering rebuild...");
248
249        // Now build fresh
250        self.build_vector_index().await
251    }
252
253    // ========================================================================
254    // Helper methods
255    // ========================================================================
256
257    /// Get all dense vector fields that need ANN indexes
258    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
259        self.schema
260            .fields()
261            .filter_map(|(field, entry)| {
262                if entry.field_type == FieldType::DenseVector && entry.indexed {
263                    entry
264                        .dense_vector_config
265                        .as_ref()
266                        .filter(|c| !c.is_flat())
267                        .map(|c| (field, c.clone()))
268                } else {
269                    None
270                }
271            })
272            .collect()
273    }
274
275    /// Get fields that need building (not already built)
276    async fn get_fields_to_build(
277        &self,
278        dense_fields: &[(Field, DenseVectorConfig)],
279    ) -> Vec<(Field, DenseVectorConfig)> {
280        let metadata_arc = self.segment_manager.metadata();
281        let meta = metadata_arc.read().await;
282        dense_fields
283            .iter()
284            .filter(|(field, _)| !meta.is_field_built(field.0))
285            .cloned()
286            .collect()
287    }
288
289    /// Count flat vectors across all segments
290    /// Only loads segments that have a vectors file to avoid unnecessary I/O
291    async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
292        let mut total_vectors = 0usize;
293        let mut doc_offset = 0u32;
294
295        for id_str in segment_ids {
296            let Some(segment_id) = SegmentId::from_hex(id_str) else {
297                continue;
298            };
299
300            // Quick check: skip segments without vectors file
301            let files = crate::segment::SegmentFiles::new(segment_id.0);
302            if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
303                // No vectors file - segment has no vectors, skip loading
304                continue;
305            }
306
307            // Only load segments that have vectors
308            if let Ok(reader) = SegmentReader::open(
309                self.directory.as_ref(),
310                segment_id,
311                Arc::clone(&self.schema),
312                doc_offset,
313                self.config.term_cache_blocks,
314            )
315            .await
316            {
317                for flat_data in reader.flat_vectors().values() {
318                    total_vectors += flat_data.num_vectors;
319                }
320                doc_offset += reader.meta().num_docs;
321            }
322        }
323
324        total_vectors
325    }
326
327    /// Collect vectors from segments for training, with sampling for large datasets.
328    ///
329    /// K-means clustering converges well with ~100K samples, so we cap collection
330    /// per field to avoid loading millions of vectors into memory.
331    async fn collect_vectors_for_training(
332        &self,
333        segment_ids: &[String],
334        fields_to_build: &[(Field, DenseVectorConfig)],
335    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
336        /// Maximum vectors per field for training. K-means converges well with ~100K samples.
337        const MAX_TRAINING_VECTORS: usize = 100_000;
338
339        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
340        let mut doc_offset = 0u32;
341        let mut total_skipped = 0usize;
342
343        for id_str in segment_ids {
344            let segment_id = SegmentId::from_hex(id_str)
345                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
346            let reader = SegmentReader::open(
347                self.directory.as_ref(),
348                segment_id,
349                Arc::clone(&self.schema),
350                doc_offset,
351                self.config.term_cache_blocks,
352            )
353            .await?;
354
355            for (field_id, lazy_flat) in reader.flat_vectors() {
356                if !fields_to_build.iter().any(|(f, _)| f.0 == *field_id) {
357                    continue;
358                }
359                let entry = all_vectors.entry(*field_id).or_default();
360                let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());
361
362                if remaining == 0 {
363                    total_skipped += lazy_flat.num_vectors;
364                    continue;
365                }
366
367                let n = lazy_flat.num_vectors;
368                if n <= remaining {
369                    // Take all vectors from this segment (async reads)
370                    for i in 0..n {
371                        if let Ok(vec) = lazy_flat.get_vector(i).await {
372                            entry.push(vec);
373                        }
374                    }
375                } else {
376                    // Uniform sample: take every Nth vector
377                    let step = (n / remaining).max(1);
378                    for i in 0..n {
379                        if i % step == 0
380                            && entry.len() < MAX_TRAINING_VECTORS
381                            && let Ok(vec) = lazy_flat.get_vector(i).await
382                        {
383                            entry.push(vec);
384                        }
385                    }
386                    total_skipped += n - remaining;
387                }
388            }
389
390            doc_offset += reader.meta().num_docs;
391        }
392
393        if total_skipped > 0 {
394            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
395            log::info!(
396                "Sampled {} vectors for training (skipped {}, max {} per field)",
397                collected,
398                total_skipped,
399                MAX_TRAINING_VECTORS,
400            );
401        }
402
403        Ok(all_vectors)
404    }
405
406    /// Load segment readers for given IDs
407    pub(super) async fn load_segment_readers(
408        &self,
409        segment_ids: &[String],
410    ) -> Result<Vec<SegmentReader>> {
411        let mut readers = Vec::new();
412        let mut doc_offset = 0u32;
413
414        for id_str in segment_ids {
415            let segment_id = SegmentId::from_hex(id_str)
416                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
417            let reader = SegmentReader::open(
418                self.directory.as_ref(),
419                segment_id,
420                Arc::clone(&self.schema),
421                doc_offset,
422                self.config.term_cache_blocks,
423            )
424            .await?;
425            doc_offset += reader.meta().num_docs;
426            readers.push(reader);
427        }
428
429        Ok(readers)
430    }
431
432    /// Train index for a single field
433    async fn train_field_index(
434        &self,
435        field: Field,
436        config: &DenseVectorConfig,
437        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
438    ) -> Result<()> {
439        let field_id = field.0;
440        let vectors = match all_vectors.get(&field_id) {
441            Some(v) if !v.is_empty() => v,
442            _ => return Ok(()),
443        };
444
445        let dim = config.dim;
446        let num_vectors = vectors.len();
447        let num_clusters = config.optimal_num_clusters(num_vectors);
448
449        log::info!(
450            "Training vector index for field {} with {} vectors, {} clusters (dim={})",
451            field_id,
452            num_vectors,
453            num_clusters,
454            dim,
455        );
456
457        let centroids_filename = format!("field_{}_centroids.bin", field_id);
458        let mut codebook_filename: Option<String> = None;
459
460        match config.index_type {
461            VectorIndexType::IvfRaBitQ => {
462                self.train_ivf_rabitq(field_id, dim, num_clusters, vectors, &centroids_filename)
463                    .await?;
464            }
465            VectorIndexType::ScaNN => {
466                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
467                self.train_scann(
468                    field_id,
469                    dim,
470                    num_clusters,
471                    vectors,
472                    &centroids_filename,
473                    codebook_filename.as_ref().unwrap(),
474                )
475                .await?;
476            }
477            _ => {
478                // RaBitQ or Flat - no pre-training needed
479                return Ok(());
480            }
481        }
482
483        // Update metadata to mark this field as built
484        self.segment_manager
485            .update_metadata(|meta| {
486                meta.init_field(field_id, config.index_type);
487                meta.total_vectors = num_vectors;
488                meta.mark_field_built(
489                    field_id,
490                    num_vectors,
491                    num_clusters,
492                    centroids_filename.clone(),
493                    codebook_filename.clone(),
494                );
495            })
496            .await?;
497
498        Ok(())
499    }
500
501    /// Train IVF-RaBitQ centroids
502    async fn train_ivf_rabitq(
503        &self,
504        field_id: u32,
505        dim: usize,
506        num_clusters: usize,
507        vectors: &[Vec<f32>],
508        centroids_filename: &str,
509    ) -> Result<()> {
510        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
511        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
512
513        // Save centroids to index-level file
514        let centroids_path = std::path::Path::new(centroids_filename);
515        let centroids_bytes =
516            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
517        self.directory
518            .write(centroids_path, &centroids_bytes)
519            .await?;
520
521        log::info!(
522            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
523            field_id,
524            centroids.num_clusters
525        );
526
527        Ok(())
528    }
529
530    /// Train ScaNN (IVF-PQ) centroids and codebook
531    async fn train_scann(
532        &self,
533        field_id: u32,
534        dim: usize,
535        num_clusters: usize,
536        vectors: &[Vec<f32>],
537        centroids_filename: &str,
538        codebook_filename: &str,
539    ) -> Result<()> {
540        // Train coarse centroids
541        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
542        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
543
544        // Train PQ codebook
545        let pq_config = crate::structures::PQConfig::new(dim);
546        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
547
548        // Save centroids
549        let centroids_path = std::path::Path::new(centroids_filename);
550        let centroids_bytes =
551            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
552        self.directory
553            .write(centroids_path, &centroids_bytes)
554            .await?;
555
556        // Save codebook
557        let codebook_path = std::path::Path::new(codebook_filename);
558        let codebook_bytes =
559            serde_json::to_vec(&codebook).map_err(|e| Error::Serialization(e.to_string()))?;
560        self.directory.write(codebook_path, &codebook_bytes).await?;
561
562        log::info!(
563            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
564            field_id,
565            centroids.num_clusters
566        );
567
568        Ok(())
569    }
570}