hermes_core/index/vector_builder.rs

//! Vector index building for IndexWriter
//!
//! This module handles:
//! - Training centroids/codebooks from accumulated Flat vectors
//! - Rebuilding segments with ANN indexes
//! - Threshold-based auto-triggering of vector index builds
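//!
//! A minimal lifecycle sketch, assuming a configured `IndexWriter` named
//! `writer` (setup elided; `add_document` is an assumed ingestion method,
//! not defined in this module):
//!
//! ```ignore
//! // Vectors accumulate in Flat storage until a field's build_threshold
//! // is crossed; commit() then auto-triggers training via
//! // maybe_build_vector_index() below.
//! writer.add_document(doc).await?;
//! writer.commit().await?;
//!
//! // Training can also be forced ahead of the threshold:
//! writer.build_vector_index().await?;
//! ```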

use std::sync::Arc;

use rustc_hash::FxHashMap;

use crate::directories::DirectoryWriter;
use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
use crate::error::{Error, Result};
use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};

use super::IndexWriter;

impl<D: DirectoryWriter + 'static> IndexWriter<D> {
    /// Check if any dense vector field should be built and trigger training
    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }

        // Quick check: if all fields are already built, skip entirely.
        // This avoids loading segments just to count vectors when the index is already built.
        let all_built = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            dense_fields
                .iter()
                .all(|(field, _)| meta.is_field_built(field.0))
        };
        if all_built {
            return Ok(());
        }

        // Count total vectors across all segments
        let segment_ids = self.segment_manager.get_segment_ids().await;
        let total_vectors = self.count_flat_vectors(&segment_ids).await;

        // Update total in metadata and check if any field should be built
        let should_build = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            meta.total_vectors = total_vectors;
            dense_fields.iter().any(|(field, config)| {
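                // Default to building once 1,000 vectors have accumulated
                // if the schema sets no explicit build_threshold.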
                let threshold = config.build_threshold.unwrap_or(1000);
                meta.should_build_field(field.0, threshold)
            })
        };

        if should_build {
            log::info!(
                "Threshold crossed ({} vectors), auto-triggering vector index build",
                total_vectors
            );
            self.build_vector_index().await?;
        }

        Ok(())
    }

    /// Build vector index from accumulated Flat vectors (trains ONCE)
    ///
    /// This trains centroids/codebooks from ALL vectors across all segments.
    /// Training happens only ONCE - subsequent calls are no-ops if already built.
    ///
    /// **Note:** This is auto-triggered by `commit()` when the threshold is crossed.
    /// You typically don't need to call this manually.
    ///
    /// The process:
    /// 1. Check if already built (skip if so)
    /// 2. Collect all vectors from all segments
    /// 3. Train centroids/codebooks based on schema's index_type
    /// 4. Update metadata to mark as built (prevents re-training)
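    ///
    /// # Example
    ///
    /// A minimal sketch of forcing a build, assuming an open `IndexWriter`
    /// named `writer` and a hypothetical schema field `vec_field`:
    ///
    /// ```ignore
    /// writer.build_vector_index().await?;
    /// assert!(writer.is_vector_index_built(vec_field).await);
    /// ```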
    pub async fn build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            log::info!("No dense vector fields configured for ANN indexing");
            return Ok(());
        }

        // Check which fields need building (skip already built)
        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
        if fields_to_build.is_empty() {
            log::info!("All vector fields already built, skipping training");
            return Ok(());
        }

        // Wait for any background merges to complete before training.
        // rebuild_segments_with_ann() calls replace_segments() which replaces
        // merged segments — concurrent merges would operate on stale/deleted segments.
        self.segment_manager.wait_for_merges().await;

        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        // Collect all vectors from all segments for fields that need building
        let all_vectors = self
            .collect_vectors_for_training(&segment_ids, &fields_to_build)
            .await?;

        // Train centroids/codebooks for each field
        for (field, config) in &fields_to_build {
            self.train_field_index(*field, config, &all_vectors).await?;
        }

        log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");

        // Rebuild segments with ANN indexes using trained structures
        self.rebuild_segments_with_ann().await?;

        Ok(())
    }

    /// Rebuild all segments with ANN indexes using trained centroids/codebooks
    pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
        // Pause background merges and wait for any in-flight ones to finish.
        // rebuild replaces merged segments — concurrent merges would
        // operate on stale/deleted segments.
        self.segment_manager.pause_merges();
        self.segment_manager.wait_for_merges().await;

        let result = self.rebuild_segments_with_ann_inner().await;

        // Always resume merges, even on error
        self.segment_manager.resume_merges();

        result
    }

    async fn rebuild_segments_with_ann_inner(&self) -> Result<()> {
        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        // Load trained structures from metadata
        let (trained_centroids, trained_codebooks) = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            meta.load_trained_structures(self.directory.as_ref()).await
        };

        if trained_centroids.is_empty() {
            log::info!("No trained structures to rebuild with");
            return Ok(());
        }

        let trained = TrainedVectorStructures {
            centroids: trained_centroids,
            codebooks: trained_codebooks,
        };

        // Load all segment readers
        let readers = self.load_segment_readers(&segment_ids).await?;

        // Calculate total doc count for the merged segment
        let total_docs: u32 = readers.iter().map(|r| r.meta().num_docs).sum();

        // Merge all segments into one with ANN indexes
        let merger = SegmentMerger::new(Arc::clone(&self.schema));
        let new_segment_id = SegmentId::new();
        merger
            .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
            .await?;

        // Atomically update segments and delete old ones via SegmentManager
        self.segment_manager
            .replace_segments(vec![(new_segment_id.to_hex(), total_docs)], segment_ids)
            .await?;

        log::info!("Segments rebuilt with ANN indexes");
        Ok(())
    }

    /// Get total vector count across all segments (for threshold checking)
    pub async fn total_vector_count(&self) -> usize {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.total_vectors
    }

    /// Check if vector index has been built for a field
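    ///
    /// A small sketch combining this with `total_vector_count()` above
    /// (`vec_field` is a hypothetical schema field):
    ///
    /// ```ignore
    /// if !writer.is_vector_index_built(vec_field).await {
    ///     log::info!("still Flat: {} vectors", writer.total_vector_count().await);
    /// }
    /// ```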
    pub async fn is_vector_index_built(&self, field: Field) -> bool {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.is_field_built(field.0)
    }

    /// Rebuild vector index by retraining centroids/codebooks
    ///
    /// Use this when:
    /// - Significant new data has been added and you want better centroids
    /// - You want to change the number of clusters
    /// - The vector distribution has changed significantly
    ///
    /// This resets the Built state to Flat, then triggers a fresh training run.
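    ///
    /// # Example
    ///
    /// A sketch of retraining after a large ingest, assuming an open
    /// `IndexWriter` named `writer`:
    ///
    /// ```ignore
    /// // Drops old centroids/codebook files, resets state to Flat,
    /// // and retrains from the vectors currently in the index.
    /// writer.rebuild_vector_index().await?;
    /// ```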
    pub async fn rebuild_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }
        let dense_fields: Vec<Field> = dense_fields.into_iter().map(|(f, _)| f).collect();

        // Collect files to delete and reset fields to Flat state
        let files_to_delete = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            let mut files = Vec::new();
            for field in &dense_fields {
                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
                    field_meta.state = super::VectorIndexState::Flat;
                    if let Some(ref f) = field_meta.centroids_file {
                        files.push(f.clone());
                    }
                    if let Some(ref f) = field_meta.codebook_file {
                        files.push(f.clone());
                    }
                    field_meta.centroids_file = None;
                    field_meta.codebook_file = None;
                }
            }
            meta.save(self.directory.as_ref()).await?;
            files
        };

        // Delete old centroids/codebook files
        for file in files_to_delete {
            let _ = self.directory.delete(std::path::Path::new(&file)).await;
        }

        log::info!("Reset vector index state to Flat, triggering rebuild...");

        // Now build fresh
        self.build_vector_index().await
    }

    // ========================================================================
    // Helper methods
    // ========================================================================

    /// Get all dense vector fields that need ANN indexes
    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
        self.schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    entry
                        .dense_vector_config
                        .as_ref()
                        .filter(|c| !c.is_flat())
                        .map(|c| (field, c.clone()))
                } else {
                    None
                }
            })
            .collect()
    }

    /// Get fields that need building (not already built)
    async fn get_fields_to_build(
        &self,
        dense_fields: &[(Field, DenseVectorConfig)],
    ) -> Vec<(Field, DenseVectorConfig)> {
        let metadata_arc = self.segment_manager.metadata();
        let meta = metadata_arc.read().await;
        dense_fields
            .iter()
            .filter(|(field, _)| !meta.is_field_built(field.0))
            .cloned()
            .collect()
    }

    /// Count flat vectors across all segments
    /// Only loads segments that have a vectors file to avoid unnecessary I/O
    async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
        let mut total_vectors = 0usize;
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let Some(segment_id) = SegmentId::from_hex(id_str) else {
                continue;
            };

            // Quick check: skip segments without vectors file
            let files = crate::segment::SegmentFiles::new(segment_id.0);
            if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
                // No vectors file - segment has no vectors, skip loading
                continue;
            }

            // Only load segments that have vectors
            if let Ok(reader) = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await
            {
                for flat_data in reader.flat_vectors().values() {
                    total_vectors += flat_data.num_vectors;
                }
                doc_offset += reader.meta().num_docs;
            }
        }

        total_vectors
    }

    /// Collect vectors from segments for training, with sampling for large datasets.
    ///
    /// K-means clustering converges well with ~100K samples, so we cap collection
    /// per field to avoid loading millions of vectors into memory.
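    ///
    /// For example, a segment holding 500_000 vectors against a fresh
    /// 100_000 budget yields a stride of 5 in the sampling loop below,
    /// so every fifth vector is collected and the rest are counted as skipped.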
    async fn collect_vectors_for_training(
        &self,
        segment_ids: &[String],
        fields_to_build: &[(Field, DenseVectorConfig)],
    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
        /// Maximum vectors per field for training. K-means converges well with ~100K samples.
        const MAX_TRAINING_VECTORS: usize = 100_000;

        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
        let mut doc_offset = 0u32;
        let mut total_skipped = 0usize;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;

            for (field_id, lazy_flat) in reader.flat_vectors() {
                if !fields_to_build.iter().any(|(f, _)| f.0 == *field_id) {
                    continue;
                }
                let entry = all_vectors.entry(*field_id).or_default();
                let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());

                if remaining == 0 {
                    total_skipped += lazy_flat.num_vectors;
                    continue;
                }

                let n = lazy_flat.num_vectors;
                let dim = lazy_flat.dim;
                let quant = lazy_flat.quantization;

                // Determine which vector indices to collect
                let indices: Vec<usize> = if n <= remaining {
                    (0..n).collect()
                } else {
                    let step = (n / remaining).max(1);
                    (0..n).step_by(step).take(remaining).collect()
                };

                if indices.len() < n {
                    total_skipped += n - indices.len();
                }

                // Batch-read and dequantize instead of one-by-one get_vector()
                const BATCH: usize = 1024;
                let mut f32_buf = vec![0f32; BATCH * dim];
                for chunk in indices.chunks(BATCH) {
                    // For contiguous ranges, use batch read
                    let start = chunk[0];
                    let end = *chunk.last().unwrap();
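                    // Strictly increasing indices are contiguous exactly when the
                    // span matches the count: [7, 8, 9] gives end - start + 1 == 3
                    // == chunk.len(), while a stride-sampled [0, 5, 10] does not.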
                    if end - start + 1 == chunk.len() {
                        // Contiguous — single batch read
                        if let Ok(batch_bytes) =
                            lazy_flat.read_vectors_batch(start, chunk.len()).await
                        {
                            let floats = chunk.len() * dim;
                            f32_buf.resize(floats, 0.0);
                            crate::segment::dequantize_raw(
                                batch_bytes.as_slice(),
                                quant,
                                floats,
                                &mut f32_buf,
                            );
                            for i in 0..chunk.len() {
                                entry.push(f32_buf[i * dim..(i + 1) * dim].to_vec());
                            }
                        }
                    } else {
                        // Non-contiguous (sampled) — read individually but reuse buffer
                        f32_buf.resize(dim, 0.0);
                        for &idx in chunk {
                            if let Ok(()) = lazy_flat.read_vector_into(idx, &mut f32_buf).await {
                                entry.push(f32_buf[..dim].to_vec());
                            }
                        }
                    }
                }
            }

            doc_offset += reader.meta().num_docs;
        }

        if total_skipped > 0 {
            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
            log::info!(
                "Sampled {} vectors for training (skipped {}, max {} per field)",
                collected,
                total_skipped,
                MAX_TRAINING_VECTORS,
            );
        }

        Ok(all_vectors)
    }

    /// Load segment readers for given IDs
    pub(super) async fn load_segment_readers(
        &self,
        segment_ids: &[String],
    ) -> Result<Vec<SegmentReader>> {
        let mut readers = Vec::new();
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;
            doc_offset += reader.meta().num_docs;
            readers.push(reader);
        }

        Ok(readers)
    }

    /// Train index for a single field
    async fn train_field_index(
        &self,
        field: Field,
        config: &DenseVectorConfig,
        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
    ) -> Result<()> {
        let field_id = field.0;
        let vectors = match all_vectors.get(&field_id) {
            Some(v) if !v.is_empty() => v,
            _ => return Ok(()),
        };

        let dim = config.dim;
        let num_vectors = vectors.len();
        let num_clusters = config.optimal_num_clusters(num_vectors);

        log::info!(
            "Training vector index for field {} with {} vectors, {} clusters (dim={})",
            field_id,
            num_vectors,
            num_clusters,
            dim,
        );

        let centroids_filename = format!("field_{}_centroids.bin", field_id);
        let mut codebook_filename: Option<String> = None;

        match config.index_type {
            VectorIndexType::IvfRaBitQ => {
                self.train_ivf_rabitq(field_id, dim, num_clusters, vectors, &centroids_filename)
                    .await?;
            }
            VectorIndexType::ScaNN => {
                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
                self.train_scann(
                    field_id,
                    dim,
                    num_clusters,
                    vectors,
                    &centroids_filename,
                    codebook_filename.as_ref().unwrap(),
                )
                .await?;
            }
            _ => {
                // RaBitQ or Flat - no pre-training needed
                return Ok(());
            }
        }

        // Update metadata to mark this field as built
        self.segment_manager
            .update_metadata(|meta| {
                meta.init_field(field_id, config.index_type);
                meta.total_vectors = num_vectors;
                meta.mark_field_built(
                    field_id,
                    num_vectors,
                    num_clusters,
                    centroids_filename.clone(),
                    codebook_filename.clone(),
                );
            })
            .await?;

        Ok(())
    }

    /// Serialize a trained structure to JSON and save to an index-level file.
    async fn save_trained_artifact(
        &self,
        artifact: &impl serde::Serialize,
        filename: &str,
    ) -> Result<()> {
        let bytes =
            serde_json::to_vec(artifact).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory
            .write(std::path::Path::new(filename), &bytes)
            .await?;
        Ok(())
    }

    /// Train IVF-RaBitQ centroids
    async fn train_ivf_rabitq(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
        self.save_trained_artifact(&centroids, centroids_filename)
            .await?;

        log::info!(
            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );
        Ok(())
    }

    /// Train ScaNN (IVF-PQ) centroids and codebook
    async fn train_scann(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
        codebook_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
        self.save_trained_artifact(&centroids, centroids_filename)
            .await?;

        let pq_config = crate::structures::PQConfig::new(dim);
        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
        self.save_trained_artifact(&codebook, codebook_filename)
            .await?;

        log::info!(
            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );
        Ok(())
    }
}