Skip to main content

hermes_core/index/
vector_builder.rs

1//! Vector index building for IndexWriter
2//!
3//! This module handles:
4//! - Training centroids/codebooks from accumulated Flat vectors
5//! - Rebuilding segments with ANN indexes
6//! - Threshold-based auto-triggering of vector index builds
7
8use std::sync::Arc;
9
10use rustc_hash::FxHashMap;
11
12use crate::directories::DirectoryWriter;
13use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
14use crate::error::{Error, Result};
15use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};
16
17use super::IndexWriter;
18
19impl<D: DirectoryWriter + 'static> IndexWriter<D> {
20    /// Check if any dense vector field should be built and trigger training
21    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
22        let dense_fields = self.get_dense_vector_fields();
23        if dense_fields.is_empty() {
24            return Ok(());
25        }
26
27        // Quick check: if all fields are already built, skip entirely
28        // This avoids loading segments just to count vectors when index is already built
29        let all_built = {
30            let metadata_arc = self.segment_manager.metadata();
31            let meta = metadata_arc.read().await;
32            dense_fields
33                .iter()
34                .all(|(field, _)| meta.is_field_built(field.0))
35        };
36        if all_built {
37            return Ok(());
38        }
39
40        // Count total vectors across all segments
41        let segment_ids = self.segment_manager.get_segment_ids().await;
42        let total_vectors = self.count_flat_vectors(&segment_ids).await;
43
44        // Update total in metadata and check if any field should be built
45        let should_build = {
46            let metadata_arc = self.segment_manager.metadata();
47            let mut meta = metadata_arc.write().await;
48            meta.total_vectors = total_vectors;
49            dense_fields.iter().any(|(field, config)| {
50                let threshold = config.build_threshold.unwrap_or(1000);
51                meta.should_build_field(field.0, threshold)
52            })
53        };
54
55        if should_build {
56            log::info!(
57                "Threshold crossed ({} vectors), auto-triggering vector index build",
58                total_vectors
59            );
60            self.build_vector_index().await?;
61        }
62
63        Ok(())
64    }
65
66    /// Build vector index from accumulated Flat vectors (trains ONCE)
67    ///
68    /// This trains centroids/codebooks from ALL vectors across all segments.
69    /// Training happens only ONCE - subsequent calls are no-ops if already built.
70    ///
71    /// **Note:** This is auto-triggered by `commit()` when threshold is crossed.
72    /// You typically don't need to call this manually.
73    ///
74    /// The process:
75    /// 1. Check if already built (skip if so)
76    /// 2. Collect all vectors from all segments
77    /// 3. Train centroids/codebooks based on schema's index_type
78    /// 4. Update metadata to mark as built (prevents re-training)
79    pub async fn build_vector_index(&self) -> Result<()> {
80        let dense_fields = self.get_dense_vector_fields();
81        if dense_fields.is_empty() {
82            log::info!("No dense vector fields configured for ANN indexing");
83            return Ok(());
84        }
85
86        // Check which fields need building (skip already built)
87        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
88        if fields_to_build.is_empty() {
89            log::info!("All vector fields already built, skipping training");
90            return Ok(());
91        }
92
93        // Wait for any background merges to complete before training.
94        // rebuild_segments_with_ann() calls replace_segments() which clears ALL
95        // segments atomically — concurrent merges would lose data or operate on
96        // stale/deleted segments.
97        self.segment_manager.wait_for_merges().await;
98
99        let segment_ids = self.segment_manager.get_segment_ids().await;
100        if segment_ids.is_empty() {
101            return Ok(());
102        }
103
104        // Collect all vectors from all segments for fields that need building
105        let all_vectors = self
106            .collect_vectors_for_training(&segment_ids, &fields_to_build)
107            .await?;
108
109        // Train centroids/codebooks for each field
110        for (field, config) in &fields_to_build {
111            self.train_field_index(*field, config, &all_vectors).await?;
112        }
113
114        log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");
115
116        // Rebuild segments with ANN indexes using trained structures
117        self.rebuild_segments_with_ann().await?;
118
119        Ok(())
120    }
121
122    /// Rebuild all segments with ANN indexes using trained centroids/codebooks
123    pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
124        // Pause background merges and wait for any in-flight ones to finish.
125        // rebuild replaces ALL segments atomically — concurrent merges would
126        // operate on stale/deleted segments or lose their output.
127        self.segment_manager.pause_merges();
128        self.segment_manager.wait_for_merges().await;
129
130        let result = self.rebuild_segments_with_ann_inner().await;
131
132        // Always resume merges, even on error
133        self.segment_manager.resume_merges();
134
135        result
136    }
137
138    async fn rebuild_segments_with_ann_inner(&self) -> Result<()> {
139        let segment_ids = self.segment_manager.get_segment_ids().await;
140        if segment_ids.is_empty() {
141            return Ok(());
142        }
143
144        // Load trained structures from metadata
145        let (trained_centroids, trained_codebooks) = {
146            let metadata_arc = self.segment_manager.metadata();
147            let meta = metadata_arc.read().await;
148            meta.load_trained_structures(self.directory.as_ref()).await
149        };
150
151        if trained_centroids.is_empty() {
152            log::info!("No trained structures to rebuild with");
153            return Ok(());
154        }
155
156        let trained = TrainedVectorStructures {
157            centroids: trained_centroids,
158            codebooks: trained_codebooks,
159        };
160
161        // Load all segment readers
162        let readers = self.load_segment_readers(&segment_ids).await?;
163
164        // Calculate total doc count for the merged segment
165        let total_docs: u32 = readers.iter().map(|r| r.meta().num_docs).sum();
166
167        // Merge all segments into one with ANN indexes
168        let merger = SegmentMerger::new(Arc::clone(&self.schema));
169        let new_segment_id = SegmentId::new();
170        merger
171            .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
172            .await?;
173
174        // Atomically update segments and delete old ones via SegmentManager
175        self.segment_manager
176            .replace_segments(vec![(new_segment_id.to_hex(), total_docs)], segment_ids)
177            .await?;
178
179        log::info!("Segments rebuilt with ANN indexes");
180        Ok(())
181    }
182
183    /// Get total vector count across all segments (for threshold checking)
184    pub async fn total_vector_count(&self) -> usize {
185        let metadata_arc = self.segment_manager.metadata();
186        metadata_arc.read().await.total_vectors
187    }
188
189    /// Check if vector index has been built for a field
190    pub async fn is_vector_index_built(&self, field: Field) -> bool {
191        let metadata_arc = self.segment_manager.metadata();
192        metadata_arc.read().await.is_field_built(field.0)
193    }
194
195    /// Rebuild vector index by retraining centroids/codebooks
196    ///
197    /// Use this when:
198    /// - Significant new data has been added and you want better centroids
199    /// - You want to change the number of clusters
200    /// - The vector distribution has changed significantly
201    ///
202    /// This resets the Built state to Flat, then triggers a fresh training.
203    pub async fn rebuild_vector_index(&self) -> Result<()> {
204        let dense_fields: Vec<Field> = self
205            .schema
206            .fields()
207            .filter_map(|(field, entry)| {
208                if entry.field_type == FieldType::DenseVector && entry.indexed {
209                    Some(field)
210                } else {
211                    None
212                }
213            })
214            .collect();
215
216        if dense_fields.is_empty() {
217            return Ok(());
218        }
219
220        // Collect files to delete and reset fields to Flat state
221        let files_to_delete = {
222            let metadata_arc = self.segment_manager.metadata();
223            let mut meta = metadata_arc.write().await;
224            let mut files = Vec::new();
225            for field in &dense_fields {
226                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
227                    field_meta.state = super::VectorIndexState::Flat;
228                    if let Some(ref f) = field_meta.centroids_file {
229                        files.push(f.clone());
230                    }
231                    if let Some(ref f) = field_meta.codebook_file {
232                        files.push(f.clone());
233                    }
234                    field_meta.centroids_file = None;
235                    field_meta.codebook_file = None;
236                }
237            }
238            meta.save(self.directory.as_ref()).await?;
239            files
240        };
241
242        // Delete old centroids/codebook files
243        for file in files_to_delete {
244            let _ = self.directory.delete(std::path::Path::new(&file)).await;
245        }
246
247        log::info!("Reset vector index state to Flat, triggering rebuild...");
248
249        // Now build fresh
250        self.build_vector_index().await
251    }
252
253    // ========================================================================
254    // Helper methods
255    // ========================================================================
256
257    /// Get all dense vector fields that need ANN indexes
258    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
259        self.schema
260            .fields()
261            .filter_map(|(field, entry)| {
262                if entry.field_type == FieldType::DenseVector && entry.indexed {
263                    entry
264                        .dense_vector_config
265                        .as_ref()
266                        .filter(|c| !c.is_flat())
267                        .map(|c| (field, c.clone()))
268                } else {
269                    None
270                }
271            })
272            .collect()
273    }
274
275    /// Get fields that need building (not already built)
276    async fn get_fields_to_build(
277        &self,
278        dense_fields: &[(Field, DenseVectorConfig)],
279    ) -> Vec<(Field, DenseVectorConfig)> {
280        let metadata_arc = self.segment_manager.metadata();
281        let meta = metadata_arc.read().await;
282        dense_fields
283            .iter()
284            .filter(|(field, _)| !meta.is_field_built(field.0))
285            .cloned()
286            .collect()
287    }
288
289    /// Count flat vectors across all segments
290    /// Only loads segments that have a vectors file to avoid unnecessary I/O
291    async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
292        let mut total_vectors = 0usize;
293        let mut doc_offset = 0u32;
294
295        for id_str in segment_ids {
296            let Some(segment_id) = SegmentId::from_hex(id_str) else {
297                continue;
298            };
299
300            // Quick check: skip segments without vectors file
301            let files = crate::segment::SegmentFiles::new(segment_id.0);
302            if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
303                // No vectors file - segment has no vectors, skip loading
304                continue;
305            }
306
307            // Only load segments that have vectors
308            if let Ok(reader) = SegmentReader::open(
309                self.directory.as_ref(),
310                segment_id,
311                Arc::clone(&self.schema),
312                doc_offset,
313                self.config.term_cache_blocks,
314            )
315            .await
316            {
317                for index in reader.vector_indexes().values() {
318                    if let crate::segment::VectorIndex::Flat(flat_data) = index {
319                        total_vectors += flat_data.num_vectors();
320                    }
321                }
322                doc_offset += reader.meta().num_docs;
323            }
324        }
325
326        total_vectors
327    }
328
329    /// Collect vectors from segments for training, with sampling for large datasets.
330    ///
331    /// K-means clustering converges well with ~100K samples, so we cap collection
332    /// per field to avoid loading millions of vectors into memory.
333    async fn collect_vectors_for_training(
334        &self,
335        segment_ids: &[String],
336        fields_to_build: &[(Field, DenseVectorConfig)],
337    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
338        /// Maximum vectors per field for training. K-means converges well with ~100K samples.
339        const MAX_TRAINING_VECTORS: usize = 100_000;
340
341        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
342        let mut doc_offset = 0u32;
343        let mut total_skipped = 0usize;
344
345        for id_str in segment_ids {
346            let segment_id = SegmentId::from_hex(id_str)
347                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
348            let reader = SegmentReader::open(
349                self.directory.as_ref(),
350                segment_id,
351                Arc::clone(&self.schema),
352                doc_offset,
353                self.config.term_cache_blocks,
354            )
355            .await?;
356
357            for (field_id, index) in reader.vector_indexes() {
358                if fields_to_build.iter().any(|(f, _)| f.0 == *field_id)
359                    && let crate::segment::VectorIndex::Flat(flat_data) = index
360                {
361                    let entry = all_vectors.entry(*field_id).or_default();
362                    let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());
363
364                    if remaining == 0 {
365                        total_skipped += flat_data.num_vectors();
366                        continue;
367                    }
368
369                    let n = flat_data.num_vectors();
370                    if n <= remaining {
371                        // Take all vectors from this segment
372                        entry.extend((0..n).map(|i| flat_data.get_vector(i).to_vec()));
373                    } else {
374                        // Uniform sample: take every Nth vector
375                        let step = (n / remaining).max(1);
376                        for i in 0..n {
377                            if i % step == 0 && entry.len() < MAX_TRAINING_VECTORS {
378                                entry.push(flat_data.get_vector(i).to_vec());
379                            }
380                        }
381                        total_skipped += n - remaining;
382                    }
383                }
384            }
385
386            doc_offset += reader.meta().num_docs;
387        }
388
389        if total_skipped > 0 {
390            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
391            log::info!(
392                "Sampled {} vectors for training (skipped {}, max {} per field)",
393                collected,
394                total_skipped,
395                MAX_TRAINING_VECTORS,
396            );
397        }
398
399        Ok(all_vectors)
400    }
401
402    /// Load segment readers for given IDs
403    pub(super) async fn load_segment_readers(
404        &self,
405        segment_ids: &[String],
406    ) -> Result<Vec<SegmentReader>> {
407        let mut readers = Vec::new();
408        let mut doc_offset = 0u32;
409
410        for id_str in segment_ids {
411            let segment_id = SegmentId::from_hex(id_str)
412                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
413            let reader = SegmentReader::open(
414                self.directory.as_ref(),
415                segment_id,
416                Arc::clone(&self.schema),
417                doc_offset,
418                self.config.term_cache_blocks,
419            )
420            .await?;
421            doc_offset += reader.meta().num_docs;
422            readers.push(reader);
423        }
424
425        Ok(readers)
426    }
427
428    /// Train index for a single field
429    async fn train_field_index(
430        &self,
431        field: Field,
432        config: &DenseVectorConfig,
433        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
434    ) -> Result<()> {
435        let field_id = field.0;
436        let vectors = match all_vectors.get(&field_id) {
437            Some(v) if !v.is_empty() => v,
438            _ => return Ok(()),
439        };
440
441        let index_dim = config.index_dim();
442        let num_vectors = vectors.len();
443        let num_clusters = config.optimal_num_clusters(num_vectors);
444
445        log::info!(
446            "Training vector index for field {} with {} vectors, {} clusters",
447            field_id,
448            num_vectors,
449            num_clusters
450        );
451
452        let centroids_filename = format!("field_{}_centroids.bin", field_id);
453        let mut codebook_filename: Option<String> = None;
454
455        match config.index_type {
456            VectorIndexType::IvfRaBitQ => {
457                self.train_ivf_rabitq(
458                    field_id,
459                    index_dim,
460                    num_clusters,
461                    vectors,
462                    &centroids_filename,
463                )
464                .await?;
465            }
466            VectorIndexType::ScaNN => {
467                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
468                self.train_scann(
469                    field_id,
470                    index_dim,
471                    num_clusters,
472                    vectors,
473                    &centroids_filename,
474                    codebook_filename.as_ref().unwrap(),
475                )
476                .await?;
477            }
478            _ => {
479                // RaBitQ or Flat - no pre-training needed
480                return Ok(());
481            }
482        }
483
484        // Update metadata to mark this field as built
485        self.segment_manager
486            .update_metadata(|meta| {
487                meta.init_field(field_id, config.index_type);
488                meta.total_vectors = num_vectors;
489                meta.mark_field_built(
490                    field_id,
491                    num_vectors,
492                    num_clusters,
493                    centroids_filename.clone(),
494                    codebook_filename.clone(),
495                );
496            })
497            .await?;
498
499        Ok(())
500    }
501
502    /// Train IVF-RaBitQ centroids
503    async fn train_ivf_rabitq(
504        &self,
505        field_id: u32,
506        index_dim: usize,
507        num_clusters: usize,
508        vectors: &[Vec<f32>],
509        centroids_filename: &str,
510    ) -> Result<()> {
511        let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
512        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
513
514        // Save centroids to index-level file
515        let centroids_path = std::path::Path::new(centroids_filename);
516        let centroids_bytes =
517            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
518        self.directory
519            .write(centroids_path, &centroids_bytes)
520            .await?;
521
522        log::info!(
523            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
524            field_id,
525            centroids.num_clusters
526        );
527
528        Ok(())
529    }
530
531    /// Train ScaNN (IVF-PQ) centroids and codebook
532    async fn train_scann(
533        &self,
534        field_id: u32,
535        index_dim: usize,
536        num_clusters: usize,
537        vectors: &[Vec<f32>],
538        centroids_filename: &str,
539        codebook_filename: &str,
540    ) -> Result<()> {
541        // Train coarse centroids
542        let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
543        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
544
545        // Train PQ codebook
546        let pq_config = crate::structures::PQConfig::new(index_dim);
547        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
548
549        // Save centroids
550        let centroids_path = std::path::Path::new(centroids_filename);
551        let centroids_bytes =
552            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
553        self.directory
554            .write(centroids_path, &centroids_bytes)
555            .await?;
556
557        // Save codebook
558        let codebook_path = std::path::Path::new(codebook_filename);
559        let codebook_bytes =
560            serde_json::to_vec(&codebook).map_err(|e| Error::Serialization(e.to_string()))?;
561        self.directory.write(codebook_path, &codebook_bytes).await?;
562
563        log::info!(
564            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
565            field_id,
566            centroids.num_clusters
567        );
568
569        Ok(())
570    }
571}