1use std::sync::Arc;
8
9use rustc_hash::FxHashMap;
10
11use crate::directories::DirectoryWriter;
12use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
13use crate::error::{Error, Result};
14use crate::segment::{SegmentId, SegmentReader};
15
16use super::IndexWriter;
17
18impl<D: DirectoryWriter + 'static> IndexWriter<D> {
    /// Trains ANN artifacts (coarse centroids, and per-config PQ codebooks)
    /// for every dense vector field that is configured for ANN but not yet
    /// marked as built.
    ///
    /// No-ops when there are no eligible fields, when every field is already
    /// built, or when there are no segments to sample training vectors from.
    /// The ANN structures themselves are assembled later, during merges;
    /// this step only trains and publishes the shared trained artifacts.
    pub async fn build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            log::info!("No dense vector fields configured for ANN indexing");
            return Ok(());
        }

        // Drop fields whose artifacts are already recorded in metadata.
        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
        if fields_to_build.is_empty() {
            log::info!("All vector fields already built, skipping training");
            return Ok(());
        }

        // Hold a snapshot for the duration of sampling so the segment set
        // stays stable while we read vectors from it.
        let snapshot = self.segment_manager.acquire_snapshot().await;
        let segment_ids = snapshot.segment_ids();
        if segment_ids.is_empty() {
            return Ok(());
        }

        // Sample training vectors per field across all segments (capped
        // per field inside collect_vectors_for_training).
        let all_vectors = self
            .collect_vectors_for_training(segment_ids, &fields_to_build)
            .await?;

        for (field, config) in &fields_to_build {
            self.train_field_index(*field, config, &all_vectors).await?;
        }

        // Make the freshly trained artifacts visible to consumers.
        self.segment_manager.load_and_publish_trained().await;

        log::info!("Vector index training complete, ANN will be built during merges");

        Ok(())
    }
66
67 pub async fn rebuild_vector_index(&self) -> Result<()> {
71 let dense_fields = self.get_dense_vector_fields();
72 if dense_fields.is_empty() {
73 return Ok(());
74 }
75 let dense_fields: Vec<Field> = dense_fields.into_iter().map(|(f, _)| f).collect();
76
77 let dense_field_ids: Vec<u32> = dense_fields.iter().map(|f| f.0).collect();
79 let mut files_to_delete = Vec::new();
80 self.segment_manager
81 .update_metadata(|meta| {
82 for field_id in &dense_field_ids {
83 if let Some(field_meta) = meta.vector_fields.get_mut(field_id) {
84 field_meta.state = super::VectorIndexState::Flat;
85 if let Some(ref f) = field_meta.centroids_file {
86 files_to_delete.push(f.clone());
87 }
88 if let Some(ref f) = field_meta.codebook_file {
89 files_to_delete.push(f.clone());
90 }
91 field_meta.centroids_file = None;
92 field_meta.codebook_file = None;
93 }
94 }
95 })
96 .await?;
97
98 for file in files_to_delete {
100 let _ = self.directory.delete(std::path::Path::new(&file)).await;
101 }
102
103 self.segment_manager.clear_trained();
105
106 log::info!("Reset vector index state to Flat, triggering rebuild...");
107
108 self.build_vector_index().await
109 }
110
111 fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
117 self.schema
118 .fields()
119 .filter_map(|(field, entry)| {
120 if entry.field_type == FieldType::DenseVector && entry.indexed {
121 entry
122 .dense_vector_config
123 .as_ref()
124 .filter(|c| !c.is_flat())
125 .map(|c| (field, c.clone()))
126 } else {
127 None
128 }
129 })
130 .collect()
131 }
132
133 async fn get_fields_to_build(
135 &self,
136 dense_fields: &[(Field, DenseVectorConfig)],
137 ) -> Vec<(Field, DenseVectorConfig)> {
138 let field_ids: Vec<u32> = dense_fields.iter().map(|(f, _)| f.0).collect();
139 let built: Vec<u32> = self
140 .segment_manager
141 .read_metadata(|meta| {
142 field_ids
143 .iter()
144 .filter(|fid| meta.is_field_built(**fid))
145 .copied()
146 .collect()
147 })
148 .await;
149 dense_fields
150 .iter()
151 .filter(|(field, _)| !built.contains(&field.0))
152 .cloned()
153 .collect()
154 }
155
    /// Samples up to `MAX_TRAINING_VECTORS` vectors per field from the flat
    /// vector storage of the given segments, dequantized to `f32` rows.
    ///
    /// Returns a map from field id to its sampled vectors. Segments are
    /// visited in order; once a field's cap is reached, further vectors are
    /// only counted as skipped. Within an oversized segment the subsample is
    /// evenly strided (every `step`-th vector), not random.
    async fn collect_vectors_for_training(
        &self,
        segment_ids: &[String],
        fields_to_build: &[(Field, DenseVectorConfig)],
    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
        // Per-field cap keeps training-set memory bounded.
        const MAX_TRAINING_VECTORS: usize = 100_000;

        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
        // Running doc count passed to each reader — presumably the absolute
        // doc-id base for that segment; confirm against SegmentReader::open.
        let mut doc_offset = 0u32;
        let mut total_skipped = 0usize;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;

            for (field_id, lazy_flat) in reader.flat_vectors() {
                // Only sample fields that still need training.
                if !fields_to_build.iter().any(|(f, _)| f.0 == *field_id) {
                    continue;
                }
                let entry = all_vectors.entry(*field_id).or_default();
                let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());

                if remaining == 0 {
                    // Cap already reached for this field; just count the rest.
                    total_skipped += lazy_flat.num_vectors;
                    continue;
                }

                let n = lazy_flat.num_vectors;
                let dim = lazy_flat.dim;
                let quant = lazy_flat.quantization;

                // Take everything, or an evenly strided subsample that fits
                // the remaining budget (`take` clamps rounding overshoot).
                let indices: Vec<usize> = if n <= remaining {
                    (0..n).collect()
                } else {
                    let step = (n / remaining).max(1);
                    (0..n).step_by(step).take(remaining).collect()
                };

                if indices.len() < n {
                    total_skipped += n - indices.len();
                }

                // Read in batches; a contiguous run of indices becomes one
                // bulk read, otherwise fall back to per-vector reads.
                const BATCH: usize = 1024;
                let mut f32_buf = vec![0f32; BATCH * dim];
                for chunk in indices.chunks(BATCH) {
                    let start = chunk[0];
                    let end = *chunk.last().unwrap();
                    // Indices are strictly ascending, so this equality holds
                    // exactly when the chunk is the contiguous range
                    // [start, end] (i.e. step == 1).
                    if end - start + 1 == chunk.len() {
                        // NOTE(review): a failed batch read is silently
                        // dropped — those vectors are neither collected nor
                        // added to `total_skipped`. Confirm this best-effort
                        // behavior is intentional.
                        if let Ok(batch_bytes) =
                            lazy_flat.read_vectors_batch(start, chunk.len()).await
                        {
                            let floats = chunk.len() * dim;
                            f32_buf.resize(floats, 0.0);
                            crate::segment::dequantize_raw(
                                batch_bytes.as_slice(),
                                quant,
                                floats,
                                &mut f32_buf,
                            );
                            // Split the dequantized batch into per-vector rows.
                            for i in 0..chunk.len() {
                                entry.push(f32_buf[i * dim..(i + 1) * dim].to_vec());
                            }
                        }
                    } else {
                        f32_buf.resize(dim, 0.0);
                        for &idx in chunk {
                            if let Ok(()) = lazy_flat.read_vector_into(idx, &mut f32_buf).await {
                                entry.push(f32_buf[..dim].to_vec());
                            }
                        }
                    }
                }
            }

            doc_offset += reader.meta().num_docs;
        }

        if total_skipped > 0 {
            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
            log::info!(
                "Sampled {} vectors for training (skipped {}, max {} per field)",
                collected,
                total_skipped,
                MAX_TRAINING_VECTORS,
            );
        }

        Ok(all_vectors)
    }
263
264 async fn train_field_index(
266 &self,
267 field: Field,
268 config: &DenseVectorConfig,
269 all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
270 ) -> Result<()> {
271 let field_id = field.0;
272 let vectors = match all_vectors.get(&field_id) {
273 Some(v) if !v.is_empty() => v,
274 _ => return Ok(()),
275 };
276
277 let dim = config.dim;
278 let num_vectors = vectors.len();
279 let num_clusters = config.optimal_num_clusters(num_vectors);
280
281 log::info!(
282 "Training vector index for field {} with {} vectors, {} clusters (dim={})",
283 field_id,
284 num_vectors,
285 num_clusters,
286 dim,
287 );
288
289 let centroids_filename = format!("field_{}_centroids.bin", field_id);
290 let mut codebook_filename: Option<String> = None;
291
292 match config.index_type {
293 VectorIndexType::IvfRaBitQ => {
294 self.train_ivf_rabitq(field_id, dim, num_clusters, vectors, ¢roids_filename)
295 .await?;
296 }
297 VectorIndexType::ScaNN => {
298 codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
299 self.train_scann(
300 field_id,
301 dim,
302 num_clusters,
303 vectors,
304 ¢roids_filename,
305 codebook_filename.as_ref().unwrap(),
306 )
307 .await?;
308 }
309 _ => {
310 return Ok(());
312 }
313 }
314
315 self.segment_manager
317 .update_metadata(|meta| {
318 meta.init_field(field_id, config.index_type);
319 meta.total_vectors = num_vectors;
320 meta.mark_field_built(
321 field_id,
322 num_vectors,
323 num_clusters,
324 centroids_filename.clone(),
325 codebook_filename.clone(),
326 );
327 })
328 .await?;
329
330 Ok(())
331 }
332
333 async fn save_trained_artifact(
335 &self,
336 artifact: &impl serde::Serialize,
337 filename: &str,
338 ) -> Result<()> {
339 let bytes = bincode::serde::encode_to_vec(artifact, bincode::config::standard())
340 .map_err(|e| Error::Serialization(e.to_string()))?;
341 self.directory
342 .write(std::path::Path::new(filename), &bytes)
343 .await?;
344 Ok(())
345 }
346
347 async fn train_ivf_rabitq(
349 &self,
350 field_id: u32,
351 dim: usize,
352 num_clusters: usize,
353 vectors: &[Vec<f32>],
354 centroids_filename: &str,
355 ) -> Result<()> {
356 let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
357 let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
358 self.save_trained_artifact(¢roids, centroids_filename)
359 .await?;
360
361 log::info!(
362 "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
363 field_id,
364 centroids.num_clusters
365 );
366 Ok(())
367 }
368
369 async fn train_scann(
371 &self,
372 field_id: u32,
373 dim: usize,
374 num_clusters: usize,
375 vectors: &[Vec<f32>],
376 centroids_filename: &str,
377 codebook_filename: &str,
378 ) -> Result<()> {
379 let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
380 let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
381 self.save_trained_artifact(¢roids, centroids_filename)
382 .await?;
383
384 let pq_config = crate::structures::PQConfig::new(dim);
385 let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
386 self.save_trained_artifact(&codebook, codebook_filename)
387 .await?;
388
389 log::info!(
390 "Saved ScaNN centroids and codebook for field {} ({} clusters)",
391 field_id,
392 centroids.num_clusters
393 );
394 Ok(())
395 }
396}