// hermes_core/index/vector_builder.rs

1//! Vector index building for IndexWriter
2//!
3//! This module handles:
4//! - Training centroids/codebooks from accumulated Flat vectors
5//! - Rebuilding segments with ANN indexes
6//! - Threshold-based auto-triggering of vector index builds
7
8use std::sync::Arc;
9
10use rustc_hash::FxHashMap;
11
12use crate::directories::DirectoryWriter;
13use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
14use crate::error::{Error, Result};
15use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};
16
17use super::IndexWriter;
18
19impl<D: DirectoryWriter + 'static> IndexWriter<D> {
20    /// Check if any dense vector field should be built and trigger training
21    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
22        let dense_fields = self.get_dense_vector_fields();
23        if dense_fields.is_empty() {
24            return Ok(());
25        }
26
27        // Quick check: if all fields are already built, skip entirely
28        // This avoids loading segments just to count vectors when index is already built
29        let all_built = {
30            let metadata_arc = self.segment_manager.metadata();
31            let meta = metadata_arc.read().await;
32            dense_fields
33                .iter()
34                .all(|(field, _)| meta.is_field_built(field.0))
35        };
36        if all_built {
37            return Ok(());
38        }
39
40        // Count total vectors across all segments
41        let segment_ids = self.segment_manager.get_segment_ids().await;
42        let total_vectors = self.count_flat_vectors(&segment_ids).await;
43
44        // Update total in metadata and check if any field should be built
45        let should_build = {
46            let metadata_arc = self.segment_manager.metadata();
47            let mut meta = metadata_arc.write().await;
48            meta.total_vectors = total_vectors;
49            dense_fields.iter().any(|(field, config)| {
50                let threshold = config.build_threshold.unwrap_or(1000);
51                meta.should_build_field(field.0, threshold)
52            })
53        };
54
55        if should_build {
56            log::info!(
57                "Threshold crossed ({} vectors), auto-triggering vector index build",
58                total_vectors
59            );
60            self.build_vector_index().await?;
61        }
62
63        Ok(())
64    }
65
66    /// Build vector index from accumulated Flat vectors (trains ONCE)
67    ///
68    /// This trains centroids/codebooks from ALL vectors across all segments.
69    /// Training happens only ONCE - subsequent calls are no-ops if already built.
70    ///
71    /// **Note:** This is auto-triggered by `commit()` when threshold is crossed.
72    /// You typically don't need to call this manually.
73    ///
74    /// The process:
75    /// 1. Check if already built (skip if so)
76    /// 2. Collect all vectors from all segments
77    /// 3. Train centroids/codebooks based on schema's index_type
78    /// 4. Update metadata to mark as built (prevents re-training)
79    pub async fn build_vector_index(&self) -> Result<()> {
80        let dense_fields = self.get_dense_vector_fields();
81        if dense_fields.is_empty() {
82            log::info!("No dense vector fields configured for ANN indexing");
83            return Ok(());
84        }
85
86        // Check which fields need building (skip already built)
87        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
88        if fields_to_build.is_empty() {
89            log::info!("All vector fields already built, skipping training");
90            return Ok(());
91        }
92
93        let segment_ids = self.segment_manager.get_segment_ids().await;
94        if segment_ids.is_empty() {
95            return Ok(());
96        }
97
98        // Collect all vectors from all segments for fields that need building
99        let all_vectors = self
100            .collect_vectors_for_training(&segment_ids, &fields_to_build)
101            .await?;
102
103        // Train centroids/codebooks for each field
104        for (field, config) in &fields_to_build {
105            self.train_field_index(*field, config, &all_vectors).await?;
106        }
107
108        log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");
109
110        // Rebuild segments with ANN indexes using trained structures
111        self.rebuild_segments_with_ann().await?;
112
113        Ok(())
114    }
115
116    /// Rebuild all segments with ANN indexes using trained centroids/codebooks
117    pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
118        let segment_ids = self.segment_manager.get_segment_ids().await;
119        if segment_ids.is_empty() {
120            return Ok(());
121        }
122
123        // Load trained structures from metadata
124        let (trained_centroids, trained_codebooks) = {
125            let metadata_arc = self.segment_manager.metadata();
126            let meta = metadata_arc.read().await;
127            meta.load_trained_structures(self.directory.as_ref()).await
128        };
129
130        if trained_centroids.is_empty() {
131            log::info!("No trained structures to rebuild with");
132            return Ok(());
133        }
134
135        let trained = TrainedVectorStructures {
136            centroids: trained_centroids,
137            codebooks: trained_codebooks,
138        };
139
140        // Load all segment readers
141        let readers = self.load_segment_readers(&segment_ids).await?;
142
143        // Calculate total doc count for the merged segment
144        let total_docs: u32 = readers.iter().map(|r| r.meta().num_docs).sum();
145
146        // Merge all segments into one with ANN indexes
147        let merger = SegmentMerger::new(Arc::clone(&self.schema));
148        let new_segment_id = SegmentId::new();
149        merger
150            .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
151            .await?;
152
153        // Atomically update segments and delete old ones via SegmentManager
154        self.segment_manager
155            .replace_segments(vec![(new_segment_id.to_hex(), total_docs)], segment_ids)
156            .await?;
157
158        log::info!("Segments rebuilt with ANN indexes");
159        Ok(())
160    }
161
162    /// Get total vector count across all segments (for threshold checking)
163    pub async fn total_vector_count(&self) -> usize {
164        let metadata_arc = self.segment_manager.metadata();
165        metadata_arc.read().await.total_vectors
166    }
167
168    /// Check if vector index has been built for a field
169    pub async fn is_vector_index_built(&self, field: Field) -> bool {
170        let metadata_arc = self.segment_manager.metadata();
171        metadata_arc.read().await.is_field_built(field.0)
172    }
173
174    /// Rebuild vector index by retraining centroids/codebooks
175    ///
176    /// Use this when:
177    /// - Significant new data has been added and you want better centroids
178    /// - You want to change the number of clusters
179    /// - The vector distribution has changed significantly
180    ///
181    /// This resets the Built state to Flat, then triggers a fresh training.
182    pub async fn rebuild_vector_index(&self) -> Result<()> {
183        let dense_fields: Vec<Field> = self
184            .schema
185            .fields()
186            .filter_map(|(field, entry)| {
187                if entry.field_type == FieldType::DenseVector && entry.indexed {
188                    Some(field)
189                } else {
190                    None
191                }
192            })
193            .collect();
194
195        if dense_fields.is_empty() {
196            return Ok(());
197        }
198
199        // Collect files to delete and reset fields to Flat state
200        let files_to_delete = {
201            let metadata_arc = self.segment_manager.metadata();
202            let mut meta = metadata_arc.write().await;
203            let mut files = Vec::new();
204            for field in &dense_fields {
205                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
206                    field_meta.state = super::VectorIndexState::Flat;
207                    if let Some(ref f) = field_meta.centroids_file {
208                        files.push(f.clone());
209                    }
210                    if let Some(ref f) = field_meta.codebook_file {
211                        files.push(f.clone());
212                    }
213                    field_meta.centroids_file = None;
214                    field_meta.codebook_file = None;
215                }
216            }
217            meta.save(self.directory.as_ref()).await?;
218            files
219        };
220
221        // Delete old centroids/codebook files
222        for file in files_to_delete {
223            let _ = self.directory.delete(std::path::Path::new(&file)).await;
224        }
225
226        log::info!("Reset vector index state to Flat, triggering rebuild...");
227
228        // Now build fresh
229        self.build_vector_index().await
230    }
231
232    // ========================================================================
233    // Helper methods
234    // ========================================================================
235
236    /// Get all dense vector fields that need ANN indexes
237    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
238        self.schema
239            .fields()
240            .filter_map(|(field, entry)| {
241                if entry.field_type == FieldType::DenseVector && entry.indexed {
242                    entry
243                        .dense_vector_config
244                        .as_ref()
245                        .filter(|c| !c.is_flat())
246                        .map(|c| (field, c.clone()))
247                } else {
248                    None
249                }
250            })
251            .collect()
252    }
253
254    /// Get fields that need building (not already built)
255    async fn get_fields_to_build(
256        &self,
257        dense_fields: &[(Field, DenseVectorConfig)],
258    ) -> Vec<(Field, DenseVectorConfig)> {
259        let metadata_arc = self.segment_manager.metadata();
260        let meta = metadata_arc.read().await;
261        dense_fields
262            .iter()
263            .filter(|(field, _)| !meta.is_field_built(field.0))
264            .cloned()
265            .collect()
266    }
267
268    /// Count flat vectors across all segments
269    /// Only loads segments that have a vectors file to avoid unnecessary I/O
270    async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
271        let mut total_vectors = 0usize;
272        let mut doc_offset = 0u32;
273
274        for id_str in segment_ids {
275            let Some(segment_id) = SegmentId::from_hex(id_str) else {
276                continue;
277            };
278
279            // Quick check: skip segments without vectors file
280            let files = crate::segment::SegmentFiles::new(segment_id.0);
281            if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
282                // No vectors file - segment has no vectors, skip loading
283                continue;
284            }
285
286            // Only load segments that have vectors
287            if let Ok(reader) = SegmentReader::open(
288                self.directory.as_ref(),
289                segment_id,
290                Arc::clone(&self.schema),
291                doc_offset,
292                self.config.term_cache_blocks,
293            )
294            .await
295            {
296                for index in reader.vector_indexes().values() {
297                    if let crate::segment::VectorIndex::Flat(flat_data) = index {
298                        total_vectors += flat_data.vectors.len();
299                    }
300                }
301                doc_offset += reader.meta().num_docs;
302            }
303        }
304
305        total_vectors
306    }
307
308    /// Collect vectors from segments for training, with sampling for large datasets.
309    ///
310    /// K-means clustering converges well with ~100K samples, so we cap collection
311    /// per field to avoid loading millions of vectors into memory.
312    async fn collect_vectors_for_training(
313        &self,
314        segment_ids: &[String],
315        fields_to_build: &[(Field, DenseVectorConfig)],
316    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
317        /// Maximum vectors per field for training. K-means converges well with ~100K samples.
318        const MAX_TRAINING_VECTORS: usize = 100_000;
319
320        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
321        let mut doc_offset = 0u32;
322        let mut total_skipped = 0usize;
323
324        for id_str in segment_ids {
325            let segment_id = SegmentId::from_hex(id_str)
326                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
327            let reader = SegmentReader::open(
328                self.directory.as_ref(),
329                segment_id,
330                Arc::clone(&self.schema),
331                doc_offset,
332                self.config.term_cache_blocks,
333            )
334            .await?;
335
336            for (field_id, index) in reader.vector_indexes() {
337                if fields_to_build.iter().any(|(f, _)| f.0 == *field_id)
338                    && let crate::segment::VectorIndex::Flat(flat_data) = index
339                {
340                    let entry = all_vectors.entry(*field_id).or_default();
341                    let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());
342
343                    if remaining == 0 {
344                        total_skipped += flat_data.vectors.len();
345                        continue;
346                    }
347
348                    if flat_data.vectors.len() <= remaining {
349                        // Take all vectors from this segment
350                        entry.extend(flat_data.vectors.iter().cloned());
351                    } else {
352                        // Uniform sample: take every Nth vector
353                        let step = (flat_data.vectors.len() / remaining).max(1);
354                        for (i, vec) in flat_data.vectors.iter().enumerate() {
355                            if i % step == 0 && entry.len() < MAX_TRAINING_VECTORS {
356                                entry.push(vec.clone());
357                            }
358                        }
359                        total_skipped += flat_data.vectors.len() - remaining;
360                    }
361                }
362            }
363
364            doc_offset += reader.meta().num_docs;
365        }
366
367        if total_skipped > 0 {
368            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
369            log::info!(
370                "Sampled {} vectors for training (skipped {}, max {} per field)",
371                collected,
372                total_skipped,
373                MAX_TRAINING_VECTORS,
374            );
375        }
376
377        Ok(all_vectors)
378    }
379
380    /// Load segment readers for given IDs
381    pub(super) async fn load_segment_readers(
382        &self,
383        segment_ids: &[String],
384    ) -> Result<Vec<SegmentReader>> {
385        let mut readers = Vec::new();
386        let mut doc_offset = 0u32;
387
388        for id_str in segment_ids {
389            let segment_id = SegmentId::from_hex(id_str)
390                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
391            let reader = SegmentReader::open(
392                self.directory.as_ref(),
393                segment_id,
394                Arc::clone(&self.schema),
395                doc_offset,
396                self.config.term_cache_blocks,
397            )
398            .await?;
399            doc_offset += reader.meta().num_docs;
400            readers.push(reader);
401        }
402
403        Ok(readers)
404    }
405
    /// Train index for a single field
    ///
    /// Dispatches on the field's configured `index_type`: IVF-RaBitQ trains
    /// coarse centroids; ScaNN trains centroids plus a PQ codebook. On
    /// success the field is marked Built in the index metadata so subsequent
    /// builds skip it. Fields with no collected vectors are a no-op.
    async fn train_field_index(
        &self,
        field: Field,
        config: &DenseVectorConfig,
        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
    ) -> Result<()> {
        let field_id = field.0;
        // Nothing sampled for this field: nothing to train on.
        let vectors = match all_vectors.get(&field_id) {
            Some(v) if !v.is_empty() => v,
            _ => return Ok(()),
        };

        let index_dim = config.index_dim();
        let num_vectors = vectors.len();
        // Cluster count scales with the (possibly sampled) vector count.
        let num_clusters = config.optimal_num_clusters(num_vectors);

        log::info!(
            "Training vector index for field {} with {} vectors, {} clusters",
            field_id,
            num_vectors,
            num_clusters
        );

        // Index-level artifact filenames; codebook only exists for ScaNN.
        let centroids_filename = format!("field_{}_centroids.bin", field_id);
        let mut codebook_filename: Option<String> = None;

        match config.index_type {
            VectorIndexType::IvfRaBitQ => {
                self.train_ivf_rabitq(
                    field_id,
                    index_dim,
                    num_clusters,
                    vectors,
                    &centroids_filename,
                )
                .await?;
            }
            VectorIndexType::ScaNN => {
                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
                self.train_scann(
                    field_id,
                    index_dim,
                    num_clusters,
                    vectors,
                    &centroids_filename,
                    codebook_filename.as_ref().unwrap(),
                )
                .await?;
            }
            _ => {
                // RaBitQ or Flat - no pre-training needed
                // NOTE(review): returning here skips `mark_field_built` below,
                // so these index types are never marked Built and
                // `maybe_build_vector_index` may keep re-triggering builds for
                // them on every commit — confirm this is intended.
                return Ok(());
            }
        }

        // Update metadata to mark this field as built
        self.segment_manager
            .update_metadata(|meta| {
                meta.init_field(field_id, config.index_type);
                // NOTE(review): `num_vectors` is this field's training-sample
                // size (capped by sampling), which overwrites the global total
                // written by `maybe_build_vector_index` — verify intended.
                meta.total_vectors = num_vectors;
                meta.mark_field_built(
                    field_id,
                    num_vectors,
                    num_clusters,
                    centroids_filename.clone(),
                    codebook_filename.clone(),
                );
            })
            .await?;

        Ok(())
    }
479
480    /// Train IVF-RaBitQ centroids
481    async fn train_ivf_rabitq(
482        &self,
483        field_id: u32,
484        index_dim: usize,
485        num_clusters: usize,
486        vectors: &[Vec<f32>],
487        centroids_filename: &str,
488    ) -> Result<()> {
489        let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
490        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
491
492        // Save centroids to index-level file
493        let centroids_path = std::path::Path::new(centroids_filename);
494        let centroids_bytes =
495            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
496        self.directory
497            .write(centroids_path, &centroids_bytes)
498            .await?;
499
500        log::info!(
501            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
502            field_id,
503            centroids.num_clusters
504        );
505
506        Ok(())
507    }
508
509    /// Train ScaNN (IVF-PQ) centroids and codebook
510    async fn train_scann(
511        &self,
512        field_id: u32,
513        index_dim: usize,
514        num_clusters: usize,
515        vectors: &[Vec<f32>],
516        centroids_filename: &str,
517        codebook_filename: &str,
518    ) -> Result<()> {
519        // Train coarse centroids
520        let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
521        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
522
523        // Train PQ codebook
524        let pq_config = crate::structures::PQConfig::new(index_dim);
525        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
526
527        // Save centroids
528        let centroids_path = std::path::Path::new(centroids_filename);
529        let centroids_bytes =
530            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
531        self.directory
532            .write(centroids_path, &centroids_bytes)
533            .await?;
534
535        // Save codebook
536        let codebook_path = std::path::Path::new(codebook_filename);
537        let codebook_bytes =
538            serde_json::to_vec(&codebook).map_err(|e| Error::Serialization(e.to_string()))?;
539        self.directory.write(codebook_path, &codebook_bytes).await?;
540
541        log::info!(
542            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
543            field_id,
544            centroids.num_clusters
545        );
546
547        Ok(())
548    }
549}