
hermes_core/index/vector_builder.rs

//! Vector index building for IndexWriter
//!
//! Training is **manual-only** — decoupled from commit.
//! Call `build_vector_index()` explicitly when ready.
//! ANN indexes are built naturally during subsequent merges.
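//!
//! A minimal lifecycle sketch (illustrative; `writer` is assumed to be an
//! already-constructed `IndexWriter`, with indexing and commits happening elsewhere):
//!
//! ```ignore
//! // Segments written before training store flat (brute-force) vectors.
//! // Once enough data has accumulated, train the ANN structures explicitly:
//! writer.build_vector_index().await?;
//! // Later merges pick up the trained centroids/codebooks and build ANN indexes.
//! ```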

use std::sync::Arc;

use rustc_hash::FxHashMap;

use crate::directories::DirectoryWriter;
use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
use crate::error::{Error, Result};
use crate::segment::{SegmentId, SegmentReader};

use super::IndexWriter;

impl<D: DirectoryWriter + 'static> IndexWriter<D> {
    /// Train vector index from accumulated Flat vectors (manual, not auto-triggered).
    ///
    /// 1. Acquires a snapshot (segments safe to read)
    /// 2. Collects vectors for training
    /// 3. Trains centroids/codebooks
    /// 4. Updates metadata (marks fields as Built)
    /// 5. Publishes to ArcSwap — merges will use these automatically
    ///
    /// Existing flat segments get ANN during normal merges. No rebuild needed.
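    ///
    /// A call sketch (illustrative; `writer` is an `IndexWriter` handle). Calling it
    /// again is cheap, since fields already marked Built are skipped:
    ///
    /// ```ignore
    /// writer.build_vector_index().await?; // trains and publishes centroids/codebooks
    /// writer.build_vector_index().await?; // no-op: all fields already built
    /// ```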
    pub async fn build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            log::info!("No dense vector fields configured for ANN indexing");
            return Ok(());
        }

        // Check which fields need building (skip already built)
        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
        if fields_to_build.is_empty() {
            log::info!("All vector fields already built, skipping training");
            return Ok(());
        }

        // Acquire snapshot — segments won't be deleted while we read them
        let snapshot = self.segment_manager.acquire_snapshot().await;
        let segment_ids = snapshot.segment_ids();
        if segment_ids.is_empty() {
            return Ok(());
        }

        // Collect vectors for training
        let all_vectors = self
            .collect_vectors_for_training(segment_ids, &fields_to_build)
            .await?;

        // Train centroids/codebooks for each field
        for (field, config) in &fields_to_build {
            self.train_field_index(*field, config, &all_vectors).await?;
        }

        // Publish to ArcSwap — merges and new segment builds will use these
        self.segment_manager.load_and_publish_trained().await;

        log::info!("Vector index training complete, ANN will be built during merges");

        Ok(())
    }

    /// Rebuild vector index by retraining centroids/codebooks.
    ///
    /// Resets Built state to Flat, clears trained structures, then trains fresh.
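    ///
    /// A retraining sketch (illustrative; `writer` as above), e.g. after the data
    /// distribution has drifted away from the original training sample:
    ///
    /// ```ignore
    /// writer.rebuild_vector_index().await?; // reset to Flat, delete old artifacts, retrain
    /// ```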
    pub async fn rebuild_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }
        let dense_fields: Vec<Field> = dense_fields.into_iter().map(|(f, _)| f).collect();

        // Reset fields to Flat and collect files to delete
        let dense_field_ids: Vec<u32> = dense_fields.iter().map(|f| f.0).collect();
        let mut files_to_delete = Vec::new();
        self.segment_manager
            .update_metadata(|meta| {
                for field_id in &dense_field_ids {
                    if let Some(field_meta) = meta.vector_fields.get_mut(field_id) {
                        field_meta.state = super::VectorIndexState::Flat;
                        if let Some(ref f) = field_meta.centroids_file {
                            files_to_delete.push(f.clone());
                        }
                        if let Some(ref f) = field_meta.codebook_file {
                            files_to_delete.push(f.clone());
                        }
                        field_meta.centroids_file = None;
                        field_meta.codebook_file = None;
                    }
                }
            })
            .await?;

        // Delete old files
        for file in files_to_delete {
            let _ = self.directory.delete(std::path::Path::new(&file)).await;
        }

        // Clear ArcSwap so workers produce flat segments during retraining
        self.segment_manager.clear_trained();

        log::info!("Reset vector index state to Flat, triggering rebuild...");

        self.build_vector_index().await
    }

    // ========================================================================
    // Helper methods
    // ========================================================================

    /// Get all dense vector fields that need ANN indexes
    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
        self.schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    entry
                        .dense_vector_config
                        .as_ref()
                        .filter(|c| !c.is_flat())
                        .map(|c| (field, c.clone()))
                } else {
                    None
                }
            })
            .collect()
    }

    /// Get fields that need building (not already built)
    async fn get_fields_to_build(
        &self,
        dense_fields: &[(Field, DenseVectorConfig)],
    ) -> Vec<(Field, DenseVectorConfig)> {
        let field_ids: Vec<u32> = dense_fields.iter().map(|(f, _)| f.0).collect();
        let built: Vec<u32> = self
            .segment_manager
            .read_metadata(|meta| {
                field_ids
                    .iter()
                    .filter(|fid| meta.is_field_built(**fid))
                    .copied()
                    .collect()
            })
            .await;
        dense_fields
            .iter()
            .filter(|(field, _)| !built.contains(&field.0))
            .cloned()
            .collect()
    }

    /// Collect vectors from segments for training, with sampling for large datasets.
    ///
    /// K-means clustering converges well with ~100K samples, so we cap collection
    /// per field to avoid loading millions of vectors into memory.
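    ///
    /// Worked example (numbers illustrative): if 100_000 slots remain for a field and a
    /// segment holds 500_000 vectors, the sampling stride is 500_000 / 100_000 = 5, so
    /// every 5th vector is read until the cap is reached.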
    async fn collect_vectors_for_training(
        &self,
        segment_ids: &[String],
        fields_to_build: &[(Field, DenseVectorConfig)],
    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
        /// Maximum vectors per field for training. K-means converges well with ~100K samples.
        const MAX_TRAINING_VECTORS: usize = 100_000;

        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
        let mut doc_offset = 0u32;
        let mut total_skipped = 0usize;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;

            for (field_id, lazy_flat) in reader.flat_vectors() {
                if !fields_to_build.iter().any(|(f, _)| f.0 == *field_id) {
                    continue;
                }
                let entry = all_vectors.entry(*field_id).or_default();
                let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());

                if remaining == 0 {
                    total_skipped += lazy_flat.num_vectors;
                    continue;
                }

                let n = lazy_flat.num_vectors;
                let dim = lazy_flat.dim;
                let quant = lazy_flat.quantization;

                // Determine which vector indices to collect
                let indices: Vec<usize> = if n <= remaining {
                    (0..n).collect()
                } else {
                    let step = (n / remaining).max(1);
                    (0..n).step_by(step).take(remaining).collect()
                };

                if indices.len() < n {
                    total_skipped += n - indices.len();
                }

                // Batch-read and dequantize instead of one-by-one get_vector()
                const BATCH: usize = 1024;
                let mut f32_buf = vec![0f32; BATCH * dim];
                for chunk in indices.chunks(BATCH) {
                    // For contiguous ranges, use batch read
                    let start = chunk[0];
                    let end = *chunk.last().unwrap();
                    if end - start + 1 == chunk.len() {
                        // Contiguous — single batch read
                        if let Ok(batch_bytes) =
                            lazy_flat.read_vectors_batch(start, chunk.len()).await
                        {
                            let floats = chunk.len() * dim;
                            f32_buf.resize(floats, 0.0);
                            crate::segment::dequantize_raw(
                                batch_bytes.as_slice(),
                                quant,
                                floats,
                                &mut f32_buf,
                            );
                            for i in 0..chunk.len() {
                                entry.push(f32_buf[i * dim..(i + 1) * dim].to_vec());
                            }
                        }
                    } else {
                        // Non-contiguous (sampled) — read individually but reuse buffer
                        f32_buf.resize(dim, 0.0);
                        for &idx in chunk {
                            if let Ok(()) = lazy_flat.read_vector_into(idx, &mut f32_buf).await {
                                entry.push(f32_buf[..dim].to_vec());
                            }
                        }
                    }
                }
            }

            doc_offset += reader.meta().num_docs;
        }

        if total_skipped > 0 {
            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
            log::info!(
                "Sampled {} vectors for training (skipped {}, max {} per field)",
                collected,
                total_skipped,
                MAX_TRAINING_VECTORS,
            );
        }

        Ok(all_vectors)
    }

    /// Train index for a single field
    async fn train_field_index(
        &self,
        field: Field,
        config: &DenseVectorConfig,
        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
    ) -> Result<()> {
        let field_id = field.0;
        let vectors = match all_vectors.get(&field_id) {
            Some(v) if !v.is_empty() => v,
            _ => return Ok(()),
        };

        let dim = config.dim;
        let num_vectors = vectors.len();
        let num_clusters = config.optimal_num_clusters(num_vectors);

        log::info!(
            "Training vector index for field {} with {} vectors, {} clusters (dim={})",
            field_id,
            num_vectors,
            num_clusters,
            dim,
        );

        let centroids_filename = format!("field_{}_centroids.bin", field_id);
        let mut codebook_filename: Option<String> = None;

        match config.index_type {
            VectorIndexType::IvfRaBitQ => {
                self.train_ivf_rabitq(field_id, dim, num_clusters, vectors, &centroids_filename)
                    .await?;
            }
            VectorIndexType::ScaNN => {
                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
                self.train_scann(
                    field_id,
                    dim,
                    num_clusters,
                    vectors,
                    &centroids_filename,
                    codebook_filename.as_ref().unwrap(),
                )
                .await?;
            }
            _ => {
                // RaBitQ or Flat - no pre-training needed
                return Ok(());
            }
        }

        // Update metadata to mark this field as built
        self.segment_manager
            .update_metadata(|meta| {
                meta.init_field(field_id, config.index_type);
                meta.total_vectors = num_vectors;
                meta.mark_field_built(
                    field_id,
                    num_vectors,
                    num_clusters,
                    centroids_filename.clone(),
                    codebook_filename.clone(),
                );
            })
            .await?;

        Ok(())
    }

    /// Serialize a trained structure to bincode and save to an index-level file.
    async fn save_trained_artifact(
        &self,
        artifact: &impl serde::Serialize,
        filename: &str,
    ) -> Result<()> {
        let bytes = bincode::serde::encode_to_vec(artifact, bincode::config::standard())
            .map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory
            .write(std::path::Path::new(filename), &bytes)
            .await?;
        Ok(())
    }

    /// Train IVF-RaBitQ centroids
    async fn train_ivf_rabitq(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
        self.save_trained_artifact(&centroids, centroids_filename)
            .await?;

        log::info!(
            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );
        Ok(())
    }

    /// Train ScaNN (IVF-PQ) centroids and codebook
    async fn train_scann(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
        codebook_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
        self.save_trained_artifact(&centroids, centroids_filename)
            .await?;

        let pq_config = crate::structures::PQConfig::new(dim);
        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
        self.save_trained_artifact(&codebook, codebook_filename)
            .await?;

        log::info!(
            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );
        Ok(())
    }
}