Skip to main content

hermes_core/index/
vector_builder.rs

1//! Vector index building for IndexWriter
2//!
3//! This module handles:
4//! - Training centroids/codebooks from accumulated Flat vectors
5//! - Rebuilding segments with ANN indexes
6//! - Threshold-based auto-triggering of vector index builds
7
8use std::sync::Arc;
9
10use rustc_hash::FxHashMap;
11
12use crate::directories::DirectoryWriter;
13use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
14use crate::error::{Error, Result};
15use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};
16
17use super::IndexWriter;
18
19impl<D: DirectoryWriter + 'static> IndexWriter<D> {
20    /// Check if any dense vector field should be built and trigger training
21    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
22        let dense_fields = self.get_dense_vector_fields();
23        if dense_fields.is_empty() {
24            return Ok(());
25        }
26
27        // Quick check: if all fields are already built, skip entirely
28        // This avoids loading segments just to count vectors when index is already built
29        let all_built = {
30            let metadata_arc = self.segment_manager.metadata();
31            let meta = metadata_arc.read().await;
32            dense_fields
33                .iter()
34                .all(|(field, _)| meta.is_field_built(field.0))
35        };
36        if all_built {
37            return Ok(());
38        }
39
40        // Count total vectors across all segments
41        let segment_ids = self.segment_manager.get_segment_ids().await;
42        let total_vectors = self.count_flat_vectors(&segment_ids).await;
43
44        // Update total in metadata and check if any field should be built
45        let should_build = {
46            let metadata_arc = self.segment_manager.metadata();
47            let mut meta = metadata_arc.write().await;
48            meta.total_vectors = total_vectors;
49            dense_fields.iter().any(|(field, config)| {
50                let threshold = config.build_threshold.unwrap_or(1000);
51                meta.should_build_field(field.0, threshold)
52            })
53        };
54
55        if should_build {
56            log::info!(
57                "Threshold crossed ({} vectors), auto-triggering vector index build",
58                total_vectors
59            );
60            self.build_vector_index().await?;
61        }
62
63        Ok(())
64    }
65
66    /// Build vector index from accumulated Flat vectors (trains ONCE)
67    ///
68    /// This trains centroids/codebooks from ALL vectors across all segments.
69    /// Training happens only ONCE - subsequent calls are no-ops if already built.
70    ///
71    /// **Note:** This is auto-triggered by `commit()` when threshold is crossed.
72    /// You typically don't need to call this manually.
73    ///
74    /// The process:
75    /// 1. Check if already built (skip if so)
76    /// 2. Collect all vectors from all segments
77    /// 3. Train centroids/codebooks based on schema's index_type
78    /// 4. Update metadata to mark as built (prevents re-training)
79    pub async fn build_vector_index(&self) -> Result<()> {
80        let dense_fields = self.get_dense_vector_fields();
81        if dense_fields.is_empty() {
82            log::info!("No dense vector fields configured for ANN indexing");
83            return Ok(());
84        }
85
86        // Check which fields need building (skip already built)
87        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
88        if fields_to_build.is_empty() {
89            log::info!("All vector fields already built, skipping training");
90            return Ok(());
91        }
92
93        let segment_ids = self.segment_manager.get_segment_ids().await;
94        if segment_ids.is_empty() {
95            return Ok(());
96        }
97
98        // Collect all vectors from all segments for fields that need building
99        let all_vectors = self
100            .collect_vectors_for_training(&segment_ids, &fields_to_build)
101            .await?;
102
103        // Train centroids/codebooks for each field
104        for (field, config) in &fields_to_build {
105            self.train_field_index(*field, config, &all_vectors).await?;
106        }
107
108        log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");
109
110        // Rebuild segments with ANN indexes using trained structures
111        self.rebuild_segments_with_ann().await?;
112
113        Ok(())
114    }
115
116    /// Rebuild all segments with ANN indexes using trained centroids/codebooks
117    pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
118        let segment_ids = self.segment_manager.get_segment_ids().await;
119        if segment_ids.is_empty() {
120            return Ok(());
121        }
122
123        // Load trained structures from metadata
124        let (trained_centroids, trained_codebooks) = {
125            let metadata_arc = self.segment_manager.metadata();
126            let meta = metadata_arc.read().await;
127            meta.load_trained_structures(self.directory.as_ref()).await
128        };
129
130        if trained_centroids.is_empty() {
131            log::info!("No trained structures to rebuild with");
132            return Ok(());
133        }
134
135        let trained = TrainedVectorStructures {
136            centroids: trained_centroids,
137            codebooks: trained_codebooks,
138        };
139
140        // Load all segment readers
141        let readers = self.load_segment_readers(&segment_ids).await?;
142
143        // Calculate total doc count for the merged segment
144        let total_docs: u32 = readers.iter().map(|r| r.meta().num_docs).sum();
145
146        // Merge all segments into one with ANN indexes
147        let merger = SegmentMerger::new(Arc::clone(&self.schema));
148        let new_segment_id = SegmentId::new();
149        merger
150            .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
151            .await?;
152
153        // Atomically update segments and delete old ones via SegmentManager
154        self.segment_manager
155            .replace_segments(vec![(new_segment_id.to_hex(), total_docs)], segment_ids)
156            .await?;
157
158        log::info!("Segments rebuilt with ANN indexes");
159        Ok(())
160    }
161
162    /// Get total vector count across all segments (for threshold checking)
163    pub async fn total_vector_count(&self) -> usize {
164        let metadata_arc = self.segment_manager.metadata();
165        metadata_arc.read().await.total_vectors
166    }
167
168    /// Check if vector index has been built for a field
169    pub async fn is_vector_index_built(&self, field: Field) -> bool {
170        let metadata_arc = self.segment_manager.metadata();
171        metadata_arc.read().await.is_field_built(field.0)
172    }
173
174    /// Rebuild vector index by retraining centroids/codebooks
175    ///
176    /// Use this when:
177    /// - Significant new data has been added and you want better centroids
178    /// - You want to change the number of clusters
179    /// - The vector distribution has changed significantly
180    ///
181    /// This resets the Built state to Flat, then triggers a fresh training.
182    pub async fn rebuild_vector_index(&self) -> Result<()> {
183        let dense_fields: Vec<Field> = self
184            .schema
185            .fields()
186            .filter_map(|(field, entry)| {
187                if entry.field_type == FieldType::DenseVector && entry.indexed {
188                    Some(field)
189                } else {
190                    None
191                }
192            })
193            .collect();
194
195        if dense_fields.is_empty() {
196            return Ok(());
197        }
198
199        // Collect files to delete and reset fields to Flat state
200        let files_to_delete = {
201            let metadata_arc = self.segment_manager.metadata();
202            let mut meta = metadata_arc.write().await;
203            let mut files = Vec::new();
204            for field in &dense_fields {
205                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
206                    field_meta.state = super::VectorIndexState::Flat;
207                    if let Some(ref f) = field_meta.centroids_file {
208                        files.push(f.clone());
209                    }
210                    if let Some(ref f) = field_meta.codebook_file {
211                        files.push(f.clone());
212                    }
213                    field_meta.centroids_file = None;
214                    field_meta.codebook_file = None;
215                }
216            }
217            meta.save(self.directory.as_ref()).await?;
218            files
219        };
220
221        // Delete old centroids/codebook files
222        for file in files_to_delete {
223            let _ = self.directory.delete(std::path::Path::new(&file)).await;
224        }
225
226        log::info!("Reset vector index state to Flat, triggering rebuild...");
227
228        // Now build fresh
229        self.build_vector_index().await
230    }
231
232    // ========================================================================
233    // Helper methods
234    // ========================================================================
235
236    /// Get all dense vector fields that need ANN indexes
237    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
238        self.schema
239            .fields()
240            .filter_map(|(field, entry)| {
241                if entry.field_type == FieldType::DenseVector && entry.indexed {
242                    entry
243                        .dense_vector_config
244                        .as_ref()
245                        .filter(|c| !c.is_flat())
246                        .map(|c| (field, c.clone()))
247                } else {
248                    None
249                }
250            })
251            .collect()
252    }
253
254    /// Get fields that need building (not already built)
255    async fn get_fields_to_build(
256        &self,
257        dense_fields: &[(Field, DenseVectorConfig)],
258    ) -> Vec<(Field, DenseVectorConfig)> {
259        let metadata_arc = self.segment_manager.metadata();
260        let meta = metadata_arc.read().await;
261        dense_fields
262            .iter()
263            .filter(|(field, _)| !meta.is_field_built(field.0))
264            .cloned()
265            .collect()
266    }
267
268    /// Count flat vectors across all segments
269    /// Only loads segments that have a vectors file to avoid unnecessary I/O
270    async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
271        let mut total_vectors = 0usize;
272        let mut doc_offset = 0u32;
273
274        for id_str in segment_ids {
275            let Some(segment_id) = SegmentId::from_hex(id_str) else {
276                continue;
277            };
278
279            // Quick check: skip segments without vectors file
280            let files = crate::segment::SegmentFiles::new(segment_id.0);
281            if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
282                // No vectors file - segment has no vectors, skip loading
283                continue;
284            }
285
286            // Only load segments that have vectors
287            if let Ok(reader) = SegmentReader::open(
288                self.directory.as_ref(),
289                segment_id,
290                Arc::clone(&self.schema),
291                doc_offset,
292                self.config.term_cache_blocks,
293            )
294            .await
295            {
296                for index in reader.vector_indexes().values() {
297                    if let crate::segment::VectorIndex::Flat(flat_data) = index {
298                        total_vectors += flat_data.vectors.len();
299                    }
300                }
301                doc_offset += reader.meta().num_docs;
302            }
303        }
304
305        total_vectors
306    }
307
308    /// Collect all vectors from segments for fields that need building
309    async fn collect_vectors_for_training(
310        &self,
311        segment_ids: &[String],
312        fields_to_build: &[(Field, DenseVectorConfig)],
313    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
314        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
315        let mut doc_offset = 0u32;
316
317        for id_str in segment_ids {
318            let segment_id = SegmentId::from_hex(id_str)
319                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
320            let reader = SegmentReader::open(
321                self.directory.as_ref(),
322                segment_id,
323                Arc::clone(&self.schema),
324                doc_offset,
325                self.config.term_cache_blocks,
326            )
327            .await?;
328
329            // Extract vectors from each Flat index
330            for (field_id, index) in reader.vector_indexes() {
331                if fields_to_build.iter().any(|(f, _)| f.0 == *field_id)
332                    && let crate::segment::VectorIndex::Flat(flat_data) = index
333                {
334                    all_vectors
335                        .entry(*field_id)
336                        .or_default()
337                        .extend(flat_data.vectors.iter().cloned());
338                }
339            }
340
341            doc_offset += reader.meta().num_docs;
342        }
343
344        Ok(all_vectors)
345    }
346
347    /// Load segment readers for given IDs
348    pub(super) async fn load_segment_readers(
349        &self,
350        segment_ids: &[String],
351    ) -> Result<Vec<SegmentReader>> {
352        let mut readers = Vec::new();
353        let mut doc_offset = 0u32;
354
355        for id_str in segment_ids {
356            let segment_id = SegmentId::from_hex(id_str)
357                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
358            let reader = SegmentReader::open(
359                self.directory.as_ref(),
360                segment_id,
361                Arc::clone(&self.schema),
362                doc_offset,
363                self.config.term_cache_blocks,
364            )
365            .await?;
366            doc_offset += reader.meta().num_docs;
367            readers.push(reader);
368        }
369
370        Ok(readers)
371    }
372
373    /// Train index for a single field
374    async fn train_field_index(
375        &self,
376        field: Field,
377        config: &DenseVectorConfig,
378        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
379    ) -> Result<()> {
380        let field_id = field.0;
381        let vectors = match all_vectors.get(&field_id) {
382            Some(v) if !v.is_empty() => v,
383            _ => return Ok(()),
384        };
385
386        let index_dim = config.index_dim();
387        let num_vectors = vectors.len();
388        let num_clusters = config.optimal_num_clusters(num_vectors);
389
390        log::info!(
391            "Training vector index for field {} with {} vectors, {} clusters",
392            field_id,
393            num_vectors,
394            num_clusters
395        );
396
397        let centroids_filename = format!("field_{}_centroids.bin", field_id);
398        let mut codebook_filename: Option<String> = None;
399
400        match config.index_type {
401            VectorIndexType::IvfRaBitQ => {
402                self.train_ivf_rabitq(
403                    field_id,
404                    index_dim,
405                    num_clusters,
406                    vectors,
407                    &centroids_filename,
408                )
409                .await?;
410            }
411            VectorIndexType::ScaNN => {
412                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
413                self.train_scann(
414                    field_id,
415                    index_dim,
416                    num_clusters,
417                    vectors,
418                    &centroids_filename,
419                    codebook_filename.as_ref().unwrap(),
420                )
421                .await?;
422            }
423            _ => {
424                // RaBitQ or Flat - no pre-training needed
425                return Ok(());
426            }
427        }
428
429        // Update metadata to mark this field as built
430        self.segment_manager
431            .update_metadata(|meta| {
432                meta.init_field(field_id, config.index_type);
433                meta.total_vectors = num_vectors;
434                meta.mark_field_built(
435                    field_id,
436                    num_vectors,
437                    num_clusters,
438                    centroids_filename.clone(),
439                    codebook_filename.clone(),
440                );
441            })
442            .await?;
443
444        Ok(())
445    }
446
447    /// Train IVF-RaBitQ centroids
448    async fn train_ivf_rabitq(
449        &self,
450        field_id: u32,
451        index_dim: usize,
452        num_clusters: usize,
453        vectors: &[Vec<f32>],
454        centroids_filename: &str,
455    ) -> Result<()> {
456        let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
457        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
458
459        // Save centroids to index-level file
460        let centroids_path = std::path::Path::new(centroids_filename);
461        let centroids_bytes =
462            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
463        self.directory
464            .write(centroids_path, &centroids_bytes)
465            .await?;
466
467        log::info!(
468            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
469            field_id,
470            centroids.num_clusters
471        );
472
473        Ok(())
474    }
475
476    /// Train ScaNN (IVF-PQ) centroids and codebook
477    async fn train_scann(
478        &self,
479        field_id: u32,
480        index_dim: usize,
481        num_clusters: usize,
482        vectors: &[Vec<f32>],
483        centroids_filename: &str,
484        codebook_filename: &str,
485    ) -> Result<()> {
486        // Train coarse centroids
487        let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
488        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
489
490        // Train PQ codebook
491        let pq_config = crate::structures::PQConfig::new(index_dim);
492        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
493
494        // Save centroids
495        let centroids_path = std::path::Path::new(centroids_filename);
496        let centroids_bytes =
497            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
498        self.directory
499            .write(centroids_path, &centroids_bytes)
500            .await?;
501
502        // Save codebook
503        let codebook_path = std::path::Path::new(codebook_filename);
504        let codebook_bytes =
505            serde_json::to_vec(&codebook).map_err(|e| Error::Serialization(e.to_string()))?;
506        self.directory.write(codebook_path, &codebook_bytes).await?;
507
508        log::info!(
509            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
510            field_id,
511            centroids.num_clusters
512        );
513
514        Ok(())
515    }
516}