Skip to main content

hermes_core/index/
vector_builder.rs

1//! Vector index building for IndexWriter
2//!
3//! This module handles:
4//! - Training centroids/codebooks from accumulated Flat vectors
5//! - Rebuilding segments with ANN indexes
6//! - Threshold-based auto-triggering of vector index builds
7
8use std::sync::Arc;
9
10use rustc_hash::FxHashMap;
11
12use crate::directories::DirectoryWriter;
13use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
14use crate::error::{Error, Result};
15use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};
16
17use super::IndexWriter;
18
19impl<D: DirectoryWriter + 'static> IndexWriter<D> {
20    /// Check if any dense vector field should be built and trigger training
21    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
22        let dense_fields = self.get_dense_vector_fields();
23        if dense_fields.is_empty() {
24            return Ok(());
25        }
26
27        // Count total vectors across all segments
28        let segment_ids = self.segment_manager.get_segment_ids().await;
29        let total_vectors = self.count_flat_vectors(&segment_ids).await;
30
31        // Update total in metadata and check if any field should be built
32        let should_build = {
33            let metadata_arc = self.segment_manager.metadata();
34            let mut meta = metadata_arc.lock().await;
35            meta.total_vectors = total_vectors;
36            dense_fields.iter().any(|(field, config)| {
37                let threshold = config.build_threshold.unwrap_or(1000);
38                meta.should_build_field(field.0, threshold)
39            })
40        };
41
42        if should_build {
43            log::info!(
44                "Threshold crossed ({} vectors), auto-triggering vector index build",
45                total_vectors
46            );
47            self.build_vector_index().await?;
48        }
49
50        Ok(())
51    }
52
53    /// Build vector index from accumulated Flat vectors (trains ONCE)
54    ///
55    /// This trains centroids/codebooks from ALL vectors across all segments.
56    /// Training happens only ONCE - subsequent calls are no-ops if already built.
57    ///
58    /// **Note:** This is auto-triggered by `commit()` when threshold is crossed.
59    /// You typically don't need to call this manually.
60    ///
61    /// The process:
62    /// 1. Check if already built (skip if so)
63    /// 2. Collect all vectors from all segments
64    /// 3. Train centroids/codebooks based on schema's index_type
65    /// 4. Update metadata to mark as built (prevents re-training)
66    pub async fn build_vector_index(&self) -> Result<()> {
67        let dense_fields = self.get_dense_vector_fields();
68        if dense_fields.is_empty() {
69            log::info!("No dense vector fields configured for ANN indexing");
70            return Ok(());
71        }
72
73        // Check which fields need building (skip already built)
74        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
75        if fields_to_build.is_empty() {
76            log::info!("All vector fields already built, skipping training");
77            return Ok(());
78        }
79
80        let segment_ids = self.segment_manager.get_segment_ids().await;
81        if segment_ids.is_empty() {
82            return Ok(());
83        }
84
85        // Collect all vectors from all segments for fields that need building
86        let all_vectors = self
87            .collect_vectors_for_training(&segment_ids, &fields_to_build)
88            .await?;
89
90        // Train centroids/codebooks for each field
91        for (field, config) in &fields_to_build {
92            self.train_field_index(*field, config, &all_vectors).await?;
93        }
94
95        log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");
96
97        // Rebuild segments with ANN indexes using trained structures
98        self.rebuild_segments_with_ann().await?;
99
100        Ok(())
101    }
102
103    /// Rebuild all segments with ANN indexes using trained centroids/codebooks
104    pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
105        let segment_ids = self.segment_manager.get_segment_ids().await;
106        if segment_ids.is_empty() {
107            return Ok(());
108        }
109
110        // Load trained structures from metadata
111        let (trained_centroids, trained_codebooks) = {
112            let metadata_arc = self.segment_manager.metadata();
113            let meta = metadata_arc.lock().await;
114            meta.load_trained_structures(self.directory.as_ref()).await
115        };
116
117        if trained_centroids.is_empty() {
118            log::info!("No trained structures to rebuild with");
119            return Ok(());
120        }
121
122        let trained = TrainedVectorStructures {
123            centroids: trained_centroids,
124            codebooks: trained_codebooks,
125        };
126
127        // Load all segment readers
128        let readers = self.load_segment_readers(&segment_ids).await?;
129
130        // Merge all segments into one with ANN indexes
131        let merger = SegmentMerger::new(Arc::clone(&self.schema));
132        let new_segment_id = SegmentId::new();
133        merger
134            .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
135            .await?;
136
137        // Atomically update segments and delete old ones via SegmentManager
138        self.segment_manager
139            .replace_segments(vec![new_segment_id.to_hex()], segment_ids)
140            .await?;
141
142        log::info!("Segments rebuilt with ANN indexes");
143        Ok(())
144    }
145
146    /// Get total vector count across all segments (for threshold checking)
147    pub async fn total_vector_count(&self) -> usize {
148        let metadata_arc = self.segment_manager.metadata();
149        metadata_arc.lock().await.total_vectors
150    }
151
152    /// Check if vector index has been built for a field
153    pub async fn is_vector_index_built(&self, field: Field) -> bool {
154        let metadata_arc = self.segment_manager.metadata();
155        metadata_arc.lock().await.is_field_built(field.0)
156    }
157
158    /// Rebuild vector index by retraining centroids/codebooks
159    ///
160    /// Use this when:
161    /// - Significant new data has been added and you want better centroids
162    /// - You want to change the number of clusters
163    /// - The vector distribution has changed significantly
164    ///
165    /// This resets the Built state to Flat, then triggers a fresh training.
166    pub async fn rebuild_vector_index(&self) -> Result<()> {
167        let dense_fields: Vec<Field> = self
168            .schema
169            .fields()
170            .filter_map(|(field, entry)| {
171                if entry.field_type == FieldType::DenseVector && entry.indexed {
172                    Some(field)
173                } else {
174                    None
175                }
176            })
177            .collect();
178
179        if dense_fields.is_empty() {
180            return Ok(());
181        }
182
183        // Collect files to delete and reset fields to Flat state
184        let files_to_delete: Vec<String> = {
185            let metadata_arc = self.segment_manager.metadata();
186            let mut meta = metadata_arc.lock().await;
187            let mut files = Vec::new();
188            for field in &dense_fields {
189                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
190                    field_meta.state = super::VectorIndexState::Flat;
191                    if let Some(ref f) = field_meta.centroids_file {
192                        files.push(f.clone());
193                    }
194                    if let Some(ref f) = field_meta.codebook_file {
195                        files.push(f.clone());
196                    }
197                    field_meta.centroids_file = None;
198                    field_meta.codebook_file = None;
199                }
200            }
201            meta.save(self.directory.as_ref()).await?;
202            files
203        };
204
205        // Delete old centroids/codebook files
206        for file in files_to_delete {
207            let _ = self.directory.delete(std::path::Path::new(&file)).await;
208        }
209
210        log::info!("Reset vector index state to Flat, triggering rebuild...");
211
212        // Now build fresh
213        self.build_vector_index().await
214    }
215
216    // ========================================================================
217    // Helper methods
218    // ========================================================================
219
220    /// Get all dense vector fields that need ANN indexes
221    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
222        self.schema
223            .fields()
224            .filter_map(|(field, entry)| {
225                if entry.field_type == FieldType::DenseVector && entry.indexed {
226                    entry
227                        .dense_vector_config
228                        .as_ref()
229                        .filter(|c| !c.is_flat())
230                        .map(|c| (field, c.clone()))
231                } else {
232                    None
233                }
234            })
235            .collect()
236    }
237
238    /// Get fields that need building (not already built)
239    async fn get_fields_to_build(
240        &self,
241        dense_fields: &[(Field, DenseVectorConfig)],
242    ) -> Vec<(Field, DenseVectorConfig)> {
243        let metadata_arc = self.segment_manager.metadata();
244        let meta = metadata_arc.lock().await;
245        dense_fields
246            .iter()
247            .filter(|(field, _)| !meta.is_field_built(field.0))
248            .cloned()
249            .collect()
250    }
251
252    /// Count flat vectors across all segments
253    async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
254        let mut total_vectors = 0usize;
255        let mut doc_offset = 0u32;
256
257        for id_str in segment_ids {
258            if let Some(segment_id) = SegmentId::from_hex(id_str)
259                && let Ok(reader) = SegmentReader::open(
260                    self.directory.as_ref(),
261                    segment_id,
262                    Arc::clone(&self.schema),
263                    doc_offset,
264                    self.config.term_cache_blocks,
265                )
266                .await
267            {
268                for index in reader.vector_indexes().values() {
269                    if let crate::segment::VectorIndex::Flat(flat_data) = index {
270                        total_vectors += flat_data.vectors.len();
271                    }
272                }
273                doc_offset += reader.meta().num_docs;
274            }
275        }
276
277        total_vectors
278    }
279
280    /// Collect all vectors from segments for fields that need building
281    async fn collect_vectors_for_training(
282        &self,
283        segment_ids: &[String],
284        fields_to_build: &[(Field, DenseVectorConfig)],
285    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
286        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
287        let mut doc_offset = 0u32;
288
289        for id_str in segment_ids {
290            let segment_id = SegmentId::from_hex(id_str)
291                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
292            let reader = SegmentReader::open(
293                self.directory.as_ref(),
294                segment_id,
295                Arc::clone(&self.schema),
296                doc_offset,
297                self.config.term_cache_blocks,
298            )
299            .await?;
300
301            // Extract vectors from each Flat index
302            for (field_id, index) in reader.vector_indexes() {
303                if fields_to_build.iter().any(|(f, _)| f.0 == *field_id)
304                    && let crate::segment::VectorIndex::Flat(flat_data) = index
305                {
306                    all_vectors
307                        .entry(*field_id)
308                        .or_default()
309                        .extend(flat_data.vectors.iter().cloned());
310                }
311            }
312
313            doc_offset += reader.meta().num_docs;
314        }
315
316        Ok(all_vectors)
317    }
318
319    /// Load segment readers for given IDs
320    pub(super) async fn load_segment_readers(
321        &self,
322        segment_ids: &[String],
323    ) -> Result<Vec<SegmentReader>> {
324        let mut readers = Vec::new();
325        let mut doc_offset = 0u32;
326
327        for id_str in segment_ids {
328            let segment_id = SegmentId::from_hex(id_str)
329                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
330            let reader = SegmentReader::open(
331                self.directory.as_ref(),
332                segment_id,
333                Arc::clone(&self.schema),
334                doc_offset,
335                self.config.term_cache_blocks,
336            )
337            .await?;
338            doc_offset += reader.meta().num_docs;
339            readers.push(reader);
340        }
341
342        Ok(readers)
343    }
344
345    /// Train index for a single field
346    async fn train_field_index(
347        &self,
348        field: Field,
349        config: &DenseVectorConfig,
350        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
351    ) -> Result<()> {
352        let field_id = field.0;
353        let vectors = match all_vectors.get(&field_id) {
354            Some(v) if !v.is_empty() => v,
355            _ => return Ok(()),
356        };
357
358        let index_dim = config.index_dim();
359        let num_vectors = vectors.len();
360        let num_clusters = config.optimal_num_clusters(num_vectors);
361
362        log::info!(
363            "Training vector index for field {} with {} vectors, {} clusters",
364            field_id,
365            num_vectors,
366            num_clusters
367        );
368
369        let centroids_filename = format!("field_{}_centroids.bin", field_id);
370        let mut codebook_filename: Option<String> = None;
371
372        match config.index_type {
373            VectorIndexType::IvfRaBitQ => {
374                self.train_ivf_rabitq(
375                    field_id,
376                    index_dim,
377                    num_clusters,
378                    vectors,
379                    &centroids_filename,
380                )
381                .await?;
382            }
383            VectorIndexType::ScaNN => {
384                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
385                self.train_scann(
386                    field_id,
387                    index_dim,
388                    num_clusters,
389                    vectors,
390                    &centroids_filename,
391                    codebook_filename.as_ref().unwrap(),
392                )
393                .await?;
394            }
395            _ => {
396                // RaBitQ or Flat - no pre-training needed
397                return Ok(());
398            }
399        }
400
401        // Update metadata to mark this field as built
402        self.segment_manager
403            .update_metadata(|meta| {
404                meta.init_field(field_id, config.index_type);
405                meta.total_vectors = num_vectors;
406                meta.mark_field_built(
407                    field_id,
408                    num_vectors,
409                    num_clusters,
410                    centroids_filename.clone(),
411                    codebook_filename.clone(),
412                );
413            })
414            .await?;
415
416        Ok(())
417    }
418
419    /// Train IVF-RaBitQ centroids
420    async fn train_ivf_rabitq(
421        &self,
422        field_id: u32,
423        index_dim: usize,
424        num_clusters: usize,
425        vectors: &[Vec<f32>],
426        centroids_filename: &str,
427    ) -> Result<()> {
428        let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
429        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
430
431        // Save centroids to index-level file
432        let centroids_path = std::path::Path::new(centroids_filename);
433        let centroids_bytes =
434            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
435        self.directory
436            .write(centroids_path, &centroids_bytes)
437            .await?;
438
439        log::info!(
440            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
441            field_id,
442            centroids.num_clusters
443        );
444
445        Ok(())
446    }
447
448    /// Train ScaNN (IVF-PQ) centroids and codebook
449    async fn train_scann(
450        &self,
451        field_id: u32,
452        index_dim: usize,
453        num_clusters: usize,
454        vectors: &[Vec<f32>],
455        centroids_filename: &str,
456        codebook_filename: &str,
457    ) -> Result<()> {
458        // Train coarse centroids
459        let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
460        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
461
462        // Train PQ codebook
463        let pq_config = crate::structures::PQConfig::new(index_dim);
464        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
465
466        // Save centroids
467        let centroids_path = std::path::Path::new(centroids_filename);
468        let centroids_bytes =
469            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
470        self.directory
471            .write(centroids_path, &centroids_bytes)
472            .await?;
473
474        // Save codebook
475        let codebook_path = std::path::Path::new(codebook_filename);
476        let codebook_bytes =
477            serde_json::to_vec(&codebook).map_err(|e| Error::Serialization(e.to_string()))?;
478        self.directory.write(codebook_path, &codebook_bytes).await?;
479
480        log::info!(
481            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
482            field_id,
483            centroids.num_clusters
484        );
485
486        Ok(())
487    }
488}