Skip to main content

hermes_core/index/
vector_builder.rs

1//! Vector index building for IndexWriter
2//!
3//! This module handles:
4//! - Training centroids/codebooks from accumulated Flat vectors
5//! - Rebuilding segments with ANN indexes
6//! - Threshold-based auto-triggering of vector index builds
7
8use std::sync::Arc;
9
10use rustc_hash::FxHashMap;
11
12use crate::directories::DirectoryWriter;
13use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
14use crate::error::{Error, Result};
15use crate::segment::{SegmentId, SegmentReader};
16
17use super::IndexWriter;
18
19impl<D: DirectoryWriter + 'static> IndexWriter<D> {
    /// Check if any dense vector field should be built and trigger training.
    ///
    /// Fast path: a single metadata read-lock when every field is already
    /// built. Slow path: counts flat vectors across all segments and calls
    /// [`Self::build_vector_index`] once any field's `build_threshold`
    /// (default 1000) is crossed.
    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }

        // Quick check: if all fields are already built, skip entirely
        // This avoids loading segments just to count vectors when index is already built
        // (read lock is dropped at the end of this block, before any await below)
        let all_built = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            dense_fields
                .iter()
                .all(|(field, _)| meta.is_field_built(field.0))
        };
        if all_built {
            // Ensure workers have trained structures (handles from_index cold start)
            // A poisoned lock (`.ok()` -> None) also falls through to republish.
            if self
                .trained_structures
                .read()
                .ok()
                .is_none_or(|g| g.is_none())
            {
                self.publish_trained_structures().await;
            }
            return Ok(());
        }

        // Count total vectors across all segments
        let segment_ids = self.segment_manager.get_segment_ids().await;
        let total_vectors = self.count_flat_vectors(&segment_ids).await;

        // Update total in metadata and check if any field should be built.
        // Write lock is confined to this block so it is not held across the
        // (potentially long) build below.
        let should_build = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            meta.total_vectors = total_vectors;
            dense_fields.iter().any(|(field, config)| {
                let threshold = config.build_threshold.unwrap_or(1000);
                meta.should_build_field(field.0, threshold)
            })
        };

        if should_build {
            log::info!(
                "Threshold crossed ({} vectors), auto-triggering vector index build",
                total_vectors
            );
            self.build_vector_index().await?;
        }

        Ok(())
    }
74
75    /// Build vector index from accumulated Flat vectors (trains ONCE)
76    ///
77    /// This trains centroids/codebooks from ALL vectors across all segments.
78    /// Training happens only ONCE - subsequent calls are no-ops if already built.
79    ///
80    /// **Note:** This is auto-triggered by `commit()` when threshold is crossed.
81    /// You typically don't need to call this manually.
82    ///
83    /// The process:
84    /// 1. Check if already built (skip if so)
85    /// 2. Collect all vectors from all segments
86    /// 3. Train centroids/codebooks based on schema's index_type
87    /// 4. Update metadata to mark as built (prevents re-training)
88    pub async fn build_vector_index(&self) -> Result<()> {
89        let dense_fields = self.get_dense_vector_fields();
90        if dense_fields.is_empty() {
91            log::info!("No dense vector fields configured for ANN indexing");
92            return Ok(());
93        }
94
95        // Check which fields need building (skip already built)
96        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
97        if fields_to_build.is_empty() {
98            log::info!("All vector fields already built, skipping training");
99            return Ok(());
100        }
101
102        let segment_ids = self.segment_manager.get_segment_ids().await;
103        if segment_ids.is_empty() {
104            return Ok(());
105        }
106
107        // Collect all vectors from all segments for fields that need building
108        let all_vectors = self
109            .collect_vectors_for_training(&segment_ids, &fields_to_build)
110            .await?;
111
112        // Train centroids/codebooks for each field
113        for (field, config) in &fields_to_build {
114            self.train_field_index(*field, config, &all_vectors).await?;
115        }
116
117        // Publish trained structures to workers so new segments get ANN inline.
118        // Existing flat segments acquire ANN during regular background merges.
119        self.publish_trained_structures().await;
120
121        log::info!("Vector index training complete, new segments will have ANN inline");
122
123        Ok(())
124    }
125
126    /// Publish trained structures to shared worker state so new segment builds
127    /// include ANN indexes inline. Called after training completes.
128    pub(super) async fn publish_trained_structures(&self) {
129        let trained = {
130            let metadata_arc = self.segment_manager.metadata();
131            let meta = metadata_arc.read().await;
132            meta.load_trained_structures(self.directory.as_ref()).await
133        };
134        if let Some(trained) = trained
135            && let Ok(mut guard) = self.trained_structures.write()
136        {
137            log::info!(
138                "[writer] published trained structures to workers ({} fields)",
139                trained.centroids.len()
140            );
141            *guard = Some(trained);
142        }
143    }
144
145    /// Get total vector count across all segments (for threshold checking)
146    pub async fn total_vector_count(&self) -> usize {
147        let metadata_arc = self.segment_manager.metadata();
148        metadata_arc.read().await.total_vectors
149    }
150
151    /// Check if vector index has been built for a field
152    pub async fn is_vector_index_built(&self, field: Field) -> bool {
153        let metadata_arc = self.segment_manager.metadata();
154        metadata_arc.read().await.is_field_built(field.0)
155    }
156
    /// Rebuild vector index by retraining centroids/codebooks
    ///
    /// Use this when:
    /// - Significant new data has been added and you want better centroids
    /// - You want to change the number of clusters
    /// - The vector distribution has changed significantly
    ///
    /// This resets the Built state to Flat, then triggers a fresh training.
    pub async fn rebuild_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }
        // Only the field ids are needed from here on.
        let dense_fields: Vec<Field> = dense_fields.into_iter().map(|(f, _)| f).collect();

        // Collect files to delete and reset fields to Flat state.
        // Metadata is saved BEFORE deleting the artifact files, so a crash in
        // between leaves orphaned files rather than dangling references.
        let files_to_delete = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            let mut files = Vec::new();
            for field in &dense_fields {
                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
                    field_meta.state = super::VectorIndexState::Flat;
                    if let Some(ref f) = field_meta.centroids_file {
                        files.push(f.clone());
                    }
                    if let Some(ref f) = field_meta.codebook_file {
                        files.push(f.clone());
                    }
                    field_meta.centroids_file = None;
                    field_meta.codebook_file = None;
                }
            }
            meta.save(self.directory.as_ref()).await?;
            files
        };

        // Delete old centroids/codebook files (best-effort: a failed delete
        // only leaves an orphaned file behind, never blocks the rebuild)
        for file in files_to_delete {
            let _ = self.directory.delete(std::path::Path::new(&file)).await;
        }

        // Clear shared trained structures so workers produce flat segments
        // during retraining (avoids stale centroid mismatch)
        if let Ok(mut guard) = self.trained_structures.write() {
            *guard = None;
        }

        log::info!("Reset vector index state to Flat, triggering rebuild...");

        // Now build fresh
        self.build_vector_index().await
    }
210
211    // ========================================================================
212    // Helper methods
213    // ========================================================================
214
215    /// Get all dense vector fields that need ANN indexes
216    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
217        self.schema
218            .fields()
219            .filter_map(|(field, entry)| {
220                if entry.field_type == FieldType::DenseVector && entry.indexed {
221                    entry
222                        .dense_vector_config
223                        .as_ref()
224                        .filter(|c| !c.is_flat())
225                        .map(|c| (field, c.clone()))
226                } else {
227                    None
228                }
229            })
230            .collect()
231    }
232
233    /// Get fields that need building (not already built)
234    async fn get_fields_to_build(
235        &self,
236        dense_fields: &[(Field, DenseVectorConfig)],
237    ) -> Vec<(Field, DenseVectorConfig)> {
238        let metadata_arc = self.segment_manager.metadata();
239        let meta = metadata_arc.read().await;
240        dense_fields
241            .iter()
242            .filter(|(field, _)| !meta.is_field_built(field.0))
243            .cloned()
244            .collect()
245    }
246
    /// Count flat vectors across all segments
    /// Only loads segments that have a vectors file to avoid unnecessary I/O
    ///
    /// Best-effort: malformed segment IDs and unreadable segments are skipped
    /// silently; the result is used only for threshold checks.
    async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
        let mut total_vectors = 0usize;
        // Running doc-id base handed to each reader we open.
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let Some(segment_id) = SegmentId::from_hex(id_str) else {
                continue;
            };

            // Quick check: skip segments without vectors file
            let files = crate::segment::SegmentFiles::new(segment_id.0);
            if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
                // No vectors file - segment has no vectors, skip loading
                continue;
            }

            // Only load segments that have vectors
            // NOTE(review): doc_offset does not advance past skipped segments,
            // so it is not a true global doc base here (unlike
            // collect_vectors_for_training, which opens every segment). Looks
            // harmless since this reader is only used to sum flat-vector
            // counts — confirm readers don't rely on the offset for that.
            if let Ok(reader) = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await
            {
                for flat_data in reader.flat_vectors().values() {
                    total_vectors += flat_data.num_vectors;
                }
                doc_offset += reader.meta().num_docs;
            }
        }

        total_vectors
    }
284
    /// Collect vectors from segments for training, with sampling for large datasets.
    ///
    /// K-means clustering converges well with ~100K samples, so we cap collection
    /// per field to avoid loading millions of vectors into memory.
    ///
    /// Returns a map of field_id -> dequantized f32 vectors. Individual read
    /// failures are skipped silently (best-effort sampling); only an invalid
    /// segment ID or a failed segment open is treated as an error.
    async fn collect_vectors_for_training(
        &self,
        segment_ids: &[String],
        fields_to_build: &[(Field, DenseVectorConfig)],
    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
        /// Maximum vectors per field for training. K-means converges well with ~100K samples.
        const MAX_TRAINING_VECTORS: usize = 100_000;

        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
        let mut doc_offset = 0u32;
        // Vectors we deliberately did not load (cap reached or sampled past).
        let mut total_skipped = 0usize;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;

            for (field_id, lazy_flat) in reader.flat_vectors() {
                // Only collect for fields that actually need training.
                if !fields_to_build.iter().any(|(f, _)| f.0 == *field_id) {
                    continue;
                }
                let entry = all_vectors.entry(*field_id).or_default();
                let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());

                if remaining == 0 {
                    // Per-field cap already reached in earlier segments.
                    total_skipped += lazy_flat.num_vectors;
                    continue;
                }

                let n = lazy_flat.num_vectors;
                let dim = lazy_flat.dim;
                let quant = lazy_flat.quantization;

                // Determine which vector indices to collect:
                // take everything if it fits, otherwise stride-sample evenly.
                let indices: Vec<usize> = if n <= remaining {
                    (0..n).collect()
                } else {
                    let step = (n / remaining).max(1);
                    (0..n).step_by(step).take(remaining).collect()
                };

                if indices.len() < n {
                    total_skipped += n - indices.len();
                }

                // Batch-read and dequantize instead of one-by-one get_vector()
                const BATCH: usize = 1024;
                // Scratch buffer reused across chunks to avoid per-read allocation.
                let mut f32_buf = vec![0f32; BATCH * dim];
                for chunk in indices.chunks(BATCH) {
                    // For contiguous ranges, use batch read
                    let start = chunk[0];
                    let end = *chunk.last().unwrap();
                    // Indices are ascending, so span == len implies step == 1.
                    if end - start + 1 == chunk.len() {
                        // Contiguous — single batch read
                        if let Ok(batch_bytes) =
                            lazy_flat.read_vectors_batch(start, chunk.len()).await
                        {
                            let floats = chunk.len() * dim;
                            f32_buf.resize(floats, 0.0);
                            crate::segment::dequantize_raw(
                                batch_bytes.as_slice(),
                                quant,
                                floats,
                                &mut f32_buf,
                            );
                            for i in 0..chunk.len() {
                                entry.push(f32_buf[i * dim..(i + 1) * dim].to_vec());
                            }
                        }
                    } else {
                        // Non-contiguous (sampled) — read individually but reuse buffer
                        f32_buf.resize(dim, 0.0);
                        for &idx in chunk {
                            if let Ok(()) = lazy_flat.read_vector_into(idx, &mut f32_buf).await {
                                entry.push(f32_buf[..dim].to_vec());
                            }
                        }
                    }
                }
            }

            doc_offset += reader.meta().num_docs;
        }

        if total_skipped > 0 {
            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
            log::info!(
                "Sampled {} vectors for training (skipped {}, max {} per field)",
                collected,
                total_skipped,
                MAX_TRAINING_VECTORS,
            );
        }

        Ok(all_vectors)
    }
392
    /// Train index for a single field
    ///
    /// Trains the structures required by the field's `index_type`, persists
    /// them as index-level artifact files, and marks the field built in the
    /// segment-manager metadata. No-op when no vectors were collected for the
    /// field or when the index type needs no pre-training.
    async fn train_field_index(
        &self,
        field: Field,
        config: &DenseVectorConfig,
        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
    ) -> Result<()> {
        let field_id = field.0;
        // Nothing collected for this field -> nothing to train.
        let vectors = match all_vectors.get(&field_id) {
            Some(v) if !v.is_empty() => v,
            _ => return Ok(()),
        };

        let dim = config.dim;
        // Number of TRAINING vectors (may be sampled, capped at 100K upstream).
        let num_vectors = vectors.len();
        let num_clusters = config.optimal_num_clusters(num_vectors);

        log::info!(
            "Training vector index for field {} with {} vectors, {} clusters (dim={})",
            field_id,
            num_vectors,
            num_clusters,
            dim,
        );

        let centroids_filename = format!("field_{}_centroids.bin", field_id);
        // Only ScaNN produces a codebook artifact.
        let mut codebook_filename: Option<String> = None;

        match config.index_type {
            VectorIndexType::IvfRaBitQ => {
                self.train_ivf_rabitq(field_id, dim, num_clusters, vectors, &centroids_filename)
                    .await?;
            }
            VectorIndexType::ScaNN => {
                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
                self.train_scann(
                    field_id,
                    dim,
                    num_clusters,
                    vectors,
                    &centroids_filename,
                    codebook_filename.as_ref().unwrap(),
                )
                .await?;
            }
            _ => {
                // RaBitQ or Flat - no pre-training needed
                // NOTE(review): returning here leaves the field un-built in
                // metadata, so such fields stay in fields_to_build and their
                // vectors are re-collected on every build — confirm intended.
                return Ok(());
            }
        }

        // Update metadata to mark this field as built
        // NOTE(review): this overwrites total_vectors (set to the true total
        // in maybe_build_vector_index) with the training-sample count, which
        // can be smaller when sampling kicked in — verify this is intended.
        self.segment_manager
            .update_metadata(|meta| {
                meta.init_field(field_id, config.index_type);
                meta.total_vectors = num_vectors;
                meta.mark_field_built(
                    field_id,
                    num_vectors,
                    num_clusters,
                    centroids_filename.clone(),
                    codebook_filename.clone(),
                );
            })
            .await?;

        Ok(())
    }
461
462    /// Serialize a trained structure to JSON and save to an index-level file.
463    async fn save_trained_artifact(
464        &self,
465        artifact: &impl serde::Serialize,
466        filename: &str,
467    ) -> Result<()> {
468        let bytes =
469            serde_json::to_vec(artifact).map_err(|e| Error::Serialization(e.to_string()))?;
470        self.directory
471            .write(std::path::Path::new(filename), &bytes)
472            .await?;
473        Ok(())
474    }
475
476    /// Train IVF-RaBitQ centroids
477    async fn train_ivf_rabitq(
478        &self,
479        field_id: u32,
480        dim: usize,
481        num_clusters: usize,
482        vectors: &[Vec<f32>],
483        centroids_filename: &str,
484    ) -> Result<()> {
485        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
486        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
487        self.save_trained_artifact(&centroids, centroids_filename)
488            .await?;
489
490        log::info!(
491            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
492            field_id,
493            centroids.num_clusters
494        );
495        Ok(())
496    }
497
498    /// Train ScaNN (IVF-PQ) centroids and codebook
499    async fn train_scann(
500        &self,
501        field_id: u32,
502        dim: usize,
503        num_clusters: usize,
504        vectors: &[Vec<f32>],
505        centroids_filename: &str,
506        codebook_filename: &str,
507    ) -> Result<()> {
508        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
509        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
510        self.save_trained_artifact(&centroids, centroids_filename)
511            .await?;
512
513        let pq_config = crate::structures::PQConfig::new(dim);
514        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
515        self.save_trained_artifact(&codebook, codebook_filename)
516            .await?;
517
518        log::info!(
519            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
520            field_id,
521            centroids.num_clusters
522        );
523        Ok(())
524    }
525}