hermes_core/index/vector_builder.rs

//! Vector index building for IndexWriter
//!
//! This module handles:
//! - Training centroids/codebooks from accumulated Flat vectors
//! - Rebuilding segments with ANN indexes
//! - Threshold-based auto-triggering of vector index builds
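//!
//! A minimal usage sketch (hypothetical `writer: IndexWriter<_>`; assumes the
//! schema has at least one dense vector field configured for ANN indexing):
//!
//! ```ignore
//! // commit() auto-triggers training once a field's build_threshold is crossed
//! writer.commit().await?;
//!
//! // or force a build explicitly (a no-op for fields that are already built)
//! writer.build_vector_index().await?;
//! ```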

use std::sync::Arc;

use rustc_hash::FxHashMap;

use crate::directories::DirectoryWriter;
use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
use crate::error::{Error, Result};
use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};

use super::IndexWriter;

impl<D: DirectoryWriter + 'static> IndexWriter<D> {
    /// Check if any dense vector field should be built and trigger training
    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }

        // Quick check: if all fields are already built, skip entirely.
        // This avoids loading segments just to count vectors when the index is already built.
        let all_built = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            dense_fields
                .iter()
                .all(|(field, _)| meta.is_field_built(field.0))
        };
        if all_built {
            return Ok(());
        }

        // Count total vectors across all segments
        let segment_ids = self.segment_manager.get_segment_ids().await;
        let total_vectors = self.count_flat_vectors(&segment_ids).await;

        // Update the total in metadata and check if any field should be built
        let should_build = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            meta.total_vectors = total_vectors;
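            // Each field may set its own build_threshold; 1000 is the fallback.
            // E.g. build_threshold = Some(50_000) defers training until ~50k
            // vectors have accumulated (50_000 is an illustrative value).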
            dense_fields.iter().any(|(field, config)| {
                let threshold = config.build_threshold.unwrap_or(1000);
                meta.should_build_field(field.0, threshold)
            })
        };

        if should_build {
            log::info!(
                "Threshold crossed ({} vectors), auto-triggering vector index build",
                total_vectors
            );
            self.build_vector_index().await?;
        }

        Ok(())
    }

    /// Build vector index from accumulated Flat vectors (trains ONCE)
    ///
    /// This trains centroids/codebooks from ALL vectors across all segments.
    /// Training happens only ONCE; subsequent calls are no-ops if already built.
    ///
    /// **Note:** This is auto-triggered by `commit()` when the threshold is crossed.
    /// You typically don't need to call this manually.
    ///
    /// The process:
    /// 1. Check if already built (skip if so)
    /// 2. Collect all vectors from all segments
    /// 3. Train centroids/codebooks based on the schema's index_type
    /// 4. Update metadata to mark as built (prevents re-training)
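    ///
    /// A minimal sketch of a manual build (`writer` and `my_vector_field` are
    /// hypothetical stand-ins):
    ///
    /// ```ignore
    /// writer.build_vector_index().await?;
    /// // once training has run, per-field state is queryable:
    /// let built = writer.is_vector_index_built(my_vector_field).await;
    /// ```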
    pub async fn build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            log::info!("No dense vector fields configured for ANN indexing");
            return Ok(());
        }

        // Check which fields need building (skip already built)
        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
        if fields_to_build.is_empty() {
            log::info!("All vector fields already built, skipping training");
            return Ok(());
        }

        // Wait for any background merges to complete before training.
        // rebuild_segments_with_ann() calls replace_segments(), which clears ALL
        // segments atomically; concurrent merges would lose data or operate on
        // stale/deleted segments.
        self.segment_manager.wait_for_merges().await;

        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        // Collect all vectors from all segments for fields that need building
        let all_vectors = self
            .collect_vectors_for_training(&segment_ids, &fields_to_build)
            .await?;

        // Train centroids/codebooks for each field
        for (field, config) in &fields_to_build {
            self.train_field_index(*field, config, &all_vectors).await?;
        }

        log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");

        // Rebuild segments with ANN indexes using trained structures
        self.rebuild_segments_with_ann().await?;

        Ok(())
    }

    /// Rebuild all segments with ANN indexes using trained centroids/codebooks
    pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
        // Pause background merges and wait for any in-flight ones to finish.
        // The rebuild replaces ALL segments atomically; concurrent merges would
        // operate on stale/deleted segments or lose their output.
        self.segment_manager.pause_merges();
        self.segment_manager.wait_for_merges().await;

        let result = self.rebuild_segments_with_ann_inner().await;

        // Always resume merges, even on error
        self.segment_manager.resume_merges();

        result
    }

    async fn rebuild_segments_with_ann_inner(&self) -> Result<()> {
        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        // Load trained structures from metadata
        let (trained_centroids, trained_codebooks) = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            meta.load_trained_structures(self.directory.as_ref()).await
        };

        if trained_centroids.is_empty() {
            log::info!("No trained structures to rebuild with");
            return Ok(());
        }

        let trained = TrainedVectorStructures {
            centroids: trained_centroids,
            codebooks: trained_codebooks,
        };

        // Load all segment readers
        let readers = self.load_segment_readers(&segment_ids).await?;

        // Calculate total doc count for the merged segment
        let total_docs: u32 = readers.iter().map(|r| r.meta().num_docs).sum();

        // Merge all segments into one with ANN indexes
        let merger = SegmentMerger::new(Arc::clone(&self.schema));
        let new_segment_id = SegmentId::new();
        merger
            .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
            .await?;

        // Atomically update segments and delete old ones via SegmentManager
        self.segment_manager
            .replace_segments(vec![(new_segment_id.to_hex(), total_docs)], segment_ids)
            .await?;

        log::info!("Segments rebuilt with ANN indexes");
        Ok(())
    }

    /// Get total vector count across all segments (for threshold checking)
    pub async fn total_vector_count(&self) -> usize {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.total_vectors
    }

    /// Check if vector index has been built for a field
    pub async fn is_vector_index_built(&self, field: Field) -> bool {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.is_field_built(field.0)
    }

    /// Rebuild vector index by retraining centroids/codebooks
    ///
    /// Use this when:
    /// - Significant new data has been added and you want better centroids
    /// - You want to change the number of clusters
    /// - The vector distribution has changed significantly
    ///
    /// This resets the Built state to Flat, then triggers fresh training.
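    ///
    /// Sketch (hypothetical `writer`; e.g. after a large bulk ingestion has
    /// shifted the vector distribution):
    ///
    /// ```ignore
    /// writer.rebuild_vector_index().await?; // retrains from current segments
    /// ```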
    pub async fn rebuild_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }
        let dense_fields: Vec<Field> = dense_fields.into_iter().map(|(f, _)| f).collect();

        // Collect files to delete and reset fields to Flat state
        let files_to_delete = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            let mut files = Vec::new();
            for field in &dense_fields {
                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
                    field_meta.state = super::VectorIndexState::Flat;
                    if let Some(ref f) = field_meta.centroids_file {
                        files.push(f.clone());
                    }
                    if let Some(ref f) = field_meta.codebook_file {
                        files.push(f.clone());
                    }
                    field_meta.centroids_file = None;
                    field_meta.codebook_file = None;
                }
            }
            meta.save(self.directory.as_ref()).await?;
            files
        };

        // Delete old centroids/codebook files
        for file in files_to_delete {
            let _ = self.directory.delete(std::path::Path::new(&file)).await;
        }

        log::info!("Reset vector index state to Flat, triggering rebuild...");

        // Now build fresh
        self.build_vector_index().await
    }

    // ========================================================================
    // Helper methods
    // ========================================================================

    /// Get all dense vector fields that need ANN indexes
    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
        self.schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    entry
                        .dense_vector_config
                        .as_ref()
                        .filter(|c| !c.is_flat())
                        .map(|c| (field, c.clone()))
                } else {
                    None
                }
            })
            .collect()
    }

    /// Get fields that need building (not already built)
    async fn get_fields_to_build(
        &self,
        dense_fields: &[(Field, DenseVectorConfig)],
    ) -> Vec<(Field, DenseVectorConfig)> {
        let metadata_arc = self.segment_manager.metadata();
        let meta = metadata_arc.read().await;
        dense_fields
            .iter()
            .filter(|(field, _)| !meta.is_field_built(field.0))
            .cloned()
            .collect()
    }

    /// Count flat vectors across all segments.
    /// Only loads segments that have a vectors file, to avoid unnecessary I/O.
    async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
        let mut total_vectors = 0usize;
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let Some(segment_id) = SegmentId::from_hex(id_str) else {
                continue;
            };

            // Quick check: skip segments without a vectors file
            let files = crate::segment::SegmentFiles::new(segment_id.0);
            if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
                // No vectors file - segment has no vectors, skip loading
                continue;
            }

            // Only load segments that have vectors
            if let Ok(reader) = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await
            {
                for flat_data in reader.flat_vectors().values() {
                    total_vectors += flat_data.num_vectors;
                }
                doc_offset += reader.meta().num_docs;
            }
        }

        total_vectors
    }

    /// Collect vectors from segments for training, with sampling for large datasets.
    ///
    /// K-means clustering converges well with ~100K samples, so we cap collection
    /// per field to avoid loading millions of vectors into memory.
    async fn collect_vectors_for_training(
        &self,
        segment_ids: &[String],
        fields_to_build: &[(Field, DenseVectorConfig)],
    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
        /// Maximum vectors per field for training. K-means converges well with ~100K samples.
        const MAX_TRAINING_VECTORS: usize = 100_000;

        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
        let mut doc_offset = 0u32;
        let mut total_skipped = 0usize;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;

            for (field_id, lazy_flat) in reader.flat_vectors() {
                if !fields_to_build.iter().any(|(f, _)| f.0 == *field_id) {
                    continue;
                }
                let entry = all_vectors.entry(*field_id).or_default();
                let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());

                if remaining == 0 {
                    total_skipped += lazy_flat.num_vectors;
                    continue;
                }

                let n = lazy_flat.num_vectors;
                let dim = lazy_flat.dim;
                let quant = lazy_flat.quantization;

                // Determine which vector indices to collect
                let indices: Vec<usize> = if n <= remaining {
                    (0..n).collect()
                } else {
                    let step = (n / remaining).max(1);
                    (0..n).step_by(step).take(remaining).collect()
                };
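                // Illustrative arithmetic: n = 250_000, remaining = 100_000 gives
                // step = 2; take(100_000) then keeps indices 0, 2, ..., 199_998.
                // Integer division floors the step, so the sampled window can end
                // before the last vectors in the segment.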

                if indices.len() < n {
                    total_skipped += n - indices.len();
                }

                // Batch-read and dequantize instead of one-by-one get_vector()
                const BATCH: usize = 1024;
                let mut f32_buf = vec![0f32; BATCH * dim];
                for chunk in indices.chunks(BATCH) {
                    // For contiguous ranges, use batch read
                    let start = chunk[0];
                    let end = *chunk.last().unwrap();
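                    // Indices are strictly increasing, so a chunk is gap-free
                    // exactly when end - start + 1 == chunk.len().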
                    if end - start + 1 == chunk.len() {
                        // Contiguous: single batch read
                        if let Ok(batch_bytes) =
                            lazy_flat.read_vectors_batch(start, chunk.len()).await
                        {
                            let floats = chunk.len() * dim;
                            f32_buf.resize(floats, 0.0);
                            crate::segment::dequantize_raw(
                                batch_bytes.as_slice(),
                                quant,
                                floats,
                                &mut f32_buf,
                            );
                            for i in 0..chunk.len() {
                                entry.push(f32_buf[i * dim..(i + 1) * dim].to_vec());
                            }
                        }
                    } else {
                        // Non-contiguous (sampled): read individually but reuse buffer
                        f32_buf.resize(dim, 0.0);
                        for &idx in chunk {
                            if let Ok(()) = lazy_flat.read_vector_into(idx, &mut f32_buf).await {
                                entry.push(f32_buf[..dim].to_vec());
                            }
                        }
                    }
                }
            }

            doc_offset += reader.meta().num_docs;
        }

        if total_skipped > 0 {
            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
            log::info!(
                "Sampled {} vectors for training (skipped {}, max {} per field)",
                collected,
                total_skipped,
                MAX_TRAINING_VECTORS,
            );
        }

        Ok(all_vectors)
    }

    /// Load segment readers for given IDs
    pub(super) async fn load_segment_readers(
        &self,
        segment_ids: &[String],
    ) -> Result<Vec<SegmentReader>> {
        let mut readers = Vec::new();
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;
            doc_offset += reader.meta().num_docs;
            readers.push(reader);
        }

        Ok(readers)
    }

    /// Train index for a single field
    async fn train_field_index(
        &self,
        field: Field,
        config: &DenseVectorConfig,
        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
    ) -> Result<()> {
        let field_id = field.0;
        let vectors = match all_vectors.get(&field_id) {
            Some(v) if !v.is_empty() => v,
            _ => return Ok(()),
        };

        let dim = config.dim;
        let num_vectors = vectors.len();
        let num_clusters = config.optimal_num_clusters(num_vectors);

        log::info!(
            "Training vector index for field {} with {} vectors, {} clusters (dim={})",
            field_id,
            num_vectors,
            num_clusters,
            dim,
        );

        let centroids_filename = format!("field_{}_centroids.bin", field_id);
        let mut codebook_filename: Option<String> = None;

        match config.index_type {
            VectorIndexType::IvfRaBitQ => {
                self.train_ivf_rabitq(field_id, dim, num_clusters, vectors, &centroids_filename)
                    .await?;
            }
            VectorIndexType::ScaNN => {
                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
                self.train_scann(
                    field_id,
                    dim,
                    num_clusters,
                    vectors,
                    &centroids_filename,
                    codebook_filename.as_ref().unwrap(),
                )
                .await?;
            }
            _ => {
                // RaBitQ or Flat - no pre-training needed
                return Ok(());
            }
        }

        // Update metadata to mark this field as built
        self.segment_manager
            .update_metadata(|meta| {
                meta.init_field(field_id, config.index_type);
                meta.total_vectors = num_vectors;
                meta.mark_field_built(
                    field_id,
                    num_vectors,
                    num_clusters,
                    centroids_filename.clone(),
                    codebook_filename.clone(),
                );
            })
            .await?;

        Ok(())
    }

    /// Serialize a trained structure to JSON and save to an index-level file.
    async fn save_trained_artifact(
        &self,
        artifact: &impl serde::Serialize,
        filename: &str,
    ) -> Result<()> {
        let bytes =
            serde_json::to_vec(artifact).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory
            .write(std::path::Path::new(filename), &bytes)
            .await?;
        Ok(())
    }

    /// Train IVF-RaBitQ centroids
    async fn train_ivf_rabitq(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
        self.save_trained_artifact(&centroids, centroids_filename)
            .await?;

        log::info!(
            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );
        Ok(())
    }

    /// Train ScaNN (IVF-PQ) centroids and codebook
    async fn train_scann(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
        codebook_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
        self.save_trained_artifact(&centroids, centroids_filename)
            .await?;

        let pq_config = crate::structures::PQConfig::new(dim);
        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
        self.save_trained_artifact(&codebook, codebook_filename)
            .await?;

        log::info!(
            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );
        Ok(())
    }
}