use std::sync::Arc;

use rustc_hash::FxHashMap;

use crate::directories::DirectoryWriter;
use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
use crate::error::{Error, Result};
use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};

use super::IndexWriter;

impl<D: DirectoryWriter + 'static> IndexWriter<D> {
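    /// Checks whether any dense vector field has crossed its `build_threshold`
    /// and, if so, triggers a full vector index build.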
    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }

        // Fast path: nothing to do if every field already has a built index.
        let all_built = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            dense_fields
                .iter()
                .all(|(field, _)| meta.is_field_built(field.0))
        };
        if all_built {
            return Ok(());
        }

        // Count vectors across all live segments.
        let segment_ids = self.segment_manager.get_segment_ids().await;
        let total_vectors = self.count_flat_vectors(&segment_ids).await;

        // Record the count and check each field against its build threshold.
        let should_build = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            meta.total_vectors = total_vectors;
            dense_fields.iter().any(|(field, config)| {
                let threshold = config.build_threshold.unwrap_or(1000);
                meta.should_build_field(field.0, threshold)
            })
        };

        if should_build {
            log::info!(
                "Threshold crossed ({} vectors), auto-triggering vector index build",
                total_vectors
            );
            self.build_vector_index().await?;
        }

        Ok(())
    }

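    /// Trains ANN structures for every dense vector field that is not yet
    /// built, then rewrites all segments so they carry the new indexes.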
    pub async fn build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            log::info!("No dense vector fields configured for ANN indexing");
            return Ok(());
        }

        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
        if fields_to_build.is_empty() {
            log::info!("All vector fields already built, skipping training");
            return Ok(());
        }

        // Let in-flight merges settle so training sees a stable segment list.
        self.segment_manager.wait_for_merges().await;

        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        let all_vectors = self
            .collect_vectors_for_training(&segment_ids, &fields_to_build)
            .await?;

        for (field, config) in &fields_to_build {
            self.train_field_index(*field, config, &all_vectors).await?;
        }

        log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");

        self.rebuild_segments_with_ann().await?;

        Ok(())
    }

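    /// Rebuilds all segments with ANN indexes while background merges are
    /// paused, so the segment set cannot change underneath the rebuild.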
    pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
        // Pause merges and wait for any in-flight ones before rebuilding;
        // always resume afterwards, even if the rebuild failed.
        self.segment_manager.pause_merges();
        self.segment_manager.wait_for_merges().await;

        let result = self.rebuild_segments_with_ann_inner().await;

        self.segment_manager.resume_merges();

        result
    }

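    /// Merges every live segment into one new segment written with the trained
    /// ANN structures, then swaps it in for the old segments.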
    async fn rebuild_segments_with_ann_inner(&self) -> Result<()> {
        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        // Load the trained centroids/codebooks referenced by the metadata.
        let (trained_centroids, trained_codebooks) = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            meta.load_trained_structures(self.directory.as_ref()).await
        };

        if trained_centroids.is_empty() {
            log::info!("No trained structures to rebuild with");
            return Ok(());
        }

        let trained = TrainedVectorStructures {
            centroids: trained_centroids,
            codebooks: trained_codebooks,
        };

        let readers = self.load_segment_readers(&segment_ids).await?;

        let total_docs: u32 = readers.iter().map(|r| r.meta().num_docs).sum();

        // Merge everything into a single new segment, building ANN indexes as we go.
        let merger = SegmentMerger::new(Arc::clone(&self.schema));
        let new_segment_id = SegmentId::new();
        merger
            .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
            .await?;

        self.segment_manager
            .replace_segments(vec![(new_segment_id.to_hex(), total_docs)], segment_ids)
            .await?;

        log::info!("Segments rebuilt with ANN indexes");
        Ok(())
    }

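    /// Total number of vectors recorded in the index metadata.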
    pub async fn total_vector_count(&self) -> usize {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.total_vectors
    }

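    /// Whether the ANN index for `field` has already been built.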
    pub async fn is_vector_index_built(&self, field: Field) -> bool {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.is_field_built(field.0)
    }

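    /// Resets every dense vector field to the flat state, deletes the trained
    /// artifacts, and then retrains the ANN indexes from scratch.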
    pub async fn rebuild_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }
        let dense_fields: Vec<Field> = dense_fields.into_iter().map(|(f, _)| f).collect();

        // Reset each field to Flat and collect the now-stale artifact files.
        let files_to_delete = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            let mut files = Vec::new();
            for field in &dense_fields {
                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
                    field_meta.state = super::VectorIndexState::Flat;
                    if let Some(ref f) = field_meta.centroids_file {
                        files.push(f.clone());
                    }
                    if let Some(ref f) = field_meta.codebook_file {
                        files.push(f.clone());
                    }
                    field_meta.centroids_file = None;
                    field_meta.codebook_file = None;
                }
            }
            meta.save(self.directory.as_ref()).await?;
            files
        };

        // Best-effort deletion: a leftover file is harmless once the metadata
        // no longer references it.
        for file in files_to_delete {
            let _ = self.directory.delete(std::path::Path::new(&file)).await;
        }

        log::info!("Reset vector index state to Flat, triggering rebuild...");

        self.build_vector_index().await
    }

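    /// All indexed dense vector fields whose configuration requests an ANN
    /// index; flat-only fields are excluded.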
    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
        self.schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    entry
                        .dense_vector_config
                        .as_ref()
                        .filter(|c| !c.is_flat())
                        .map(|c| (field, c.clone()))
                } else {
                    None
                }
            })
            .collect()
    }

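    /// Subset of `dense_fields` whose index has not been built yet.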
    async fn get_fields_to_build(
        &self,
        dense_fields: &[(Field, DenseVectorConfig)],
    ) -> Vec<(Field, DenseVectorConfig)> {
        let metadata_arc = self.segment_manager.metadata();
        let meta = metadata_arc.read().await;
        dense_fields
            .iter()
            .filter(|(field, _)| !meta.is_field_built(field.0))
            .cloned()
            .collect()
    }

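    /// Counts flat vectors across the given segments, skipping segments that
    /// have no vectors file or that fail to open.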
    async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
        let mut total_vectors = 0usize;
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let Some(segment_id) = SegmentId::from_hex(id_str) else {
                continue;
            };

            // Skip segments that carry no vector data at all.
            let files = crate::segment::SegmentFiles::new(segment_id.0);
            if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
                continue;
            }

            if let Ok(reader) = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await
            {
                for flat_data in reader.flat_vectors().values() {
                    total_vectors += flat_data.num_vectors;
                }
                doc_offset += reader.meta().num_docs;
            }
        }

        total_vectors
    }

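    /// Collects a training sample from the flat vector storage of the given
    /// segments, capped at `MAX_TRAINING_VECTORS` per field, dequantized to
    /// f32 and keyed by field id.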
    async fn collect_vectors_for_training(
        &self,
        segment_ids: &[String],
        fields_to_build: &[(Field, DenseVectorConfig)],
    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
        // Cap the sample so training stays tractable on large indexes.
        const MAX_TRAINING_VECTORS: usize = 100_000;

        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
        let mut doc_offset = 0u32;
        let mut total_skipped = 0usize;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;

            for (field_id, lazy_flat) in reader.flat_vectors() {
                if !fields_to_build.iter().any(|(f, _)| f.0 == *field_id) {
                    continue;
                }
                let entry = all_vectors.entry(*field_id).or_default();
                let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());

                if remaining == 0 {
                    total_skipped += lazy_flat.num_vectors;
                    continue;
                }

                let n = lazy_flat.num_vectors;
                let dim = lazy_flat.dim;
                let quant = lazy_flat.quantization;

                // Take every vector if it fits under the cap; otherwise sample
                // evenly across the segment with a fixed stride.
                let indices: Vec<usize> = if n <= remaining {
                    (0..n).collect()
                } else {
                    let step = (n / remaining).max(1);
                    (0..n).step_by(step).take(remaining).collect()
                };

                if indices.len() < n {
                    total_skipped += n - indices.len();
                }

                const BATCH: usize = 1024;
                let mut f32_buf = vec![0f32; BATCH * dim];
                for chunk in indices.chunks(BATCH) {
                    let start = chunk[0];
                    let end = *chunk.last().unwrap();
                    if end - start + 1 == chunk.len() {
                        // Contiguous run: one batched read, then dequantize in bulk.
                        if let Ok(batch_bytes) =
                            lazy_flat.read_vectors_batch(start, chunk.len()).await
                        {
                            let floats = chunk.len() * dim;
                            f32_buf.resize(floats, 0.0);
                            crate::segment::dequantize_raw(
                                batch_bytes.as_slice(),
                                quant,
                                floats,
                                &mut f32_buf,
                            );
                            for i in 0..chunk.len() {
                                entry.push(f32_buf[i * dim..(i + 1) * dim].to_vec());
                            }
                        }
                    } else {
                        // Strided sample: fall back to one read per vector.
                        f32_buf.resize(dim, 0.0);
                        for &idx in chunk {
                            if let Ok(()) = lazy_flat.read_vector_into(idx, &mut f32_buf).await {
                                entry.push(f32_buf[..dim].to_vec());
                            }
                        }
                    }
                }
            }

            doc_offset += reader.meta().num_docs;
        }

        if total_skipped > 0 {
            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
            log::info!(
                "Sampled {} vectors for training (skipped {}, max {} per field)",
                collected,
                total_skipped,
                MAX_TRAINING_VECTORS,
            );
        }

        Ok(all_vectors)
    }

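    /// Opens a reader for every segment, assigning contiguous doc-id offsets
    /// in segment order.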
    pub(super) async fn load_segment_readers(
        &self,
        segment_ids: &[String],
    ) -> Result<Vec<SegmentReader>> {
        let mut readers = Vec::new();
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;
            doc_offset += reader.meta().num_docs;
            readers.push(reader);
        }

        Ok(readers)
    }

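    /// Trains the ANN structures for a single field, persists the artifacts,
    /// and marks the field as built in the metadata.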
    async fn train_field_index(
        &self,
        field: Field,
        config: &DenseVectorConfig,
        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
    ) -> Result<()> {
        let field_id = field.0;
        let vectors = match all_vectors.get(&field_id) {
            Some(v) if !v.is_empty() => v,
            _ => return Ok(()),
        };

        let dim = config.dim;
        let num_vectors = vectors.len();
        let num_clusters = config.optimal_num_clusters(num_vectors);

        log::info!(
            "Training vector index for field {} with {} vectors, {} clusters (dim={})",
            field_id,
            num_vectors,
            num_clusters,
            dim,
        );

        let centroids_filename = format!("field_{}_centroids.bin", field_id);
        let mut codebook_filename: Option<String> = None;

        match config.index_type {
            VectorIndexType::IvfRaBitQ => {
                self.train_ivf_rabitq(field_id, dim, num_clusters, vectors, &centroids_filename)
                    .await?;
            }
            VectorIndexType::ScaNN => {
                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
                self.train_scann(
                    field_id,
                    dim,
                    num_clusters,
                    vectors,
                    &centroids_filename,
                    codebook_filename.as_ref().unwrap(),
                )
                .await?;
            }
            _ => {
                // Other index types require no training.
                return Ok(());
            }
        }

        self.segment_manager
            .update_metadata(|meta| {
                meta.init_field(field_id, config.index_type);
                meta.total_vectors = num_vectors;
                meta.mark_field_built(
                    field_id,
                    num_vectors,
                    num_clusters,
                    centroids_filename.clone(),
                    codebook_filename.clone(),
                );
            })
            .await?;

        Ok(())
    }

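    /// Serializes a trained artifact as JSON and writes it into the directory.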
    async fn save_trained_artifact(
        &self,
        artifact: &impl serde::Serialize,
        filename: &str,
    ) -> Result<()> {
        let bytes =
            serde_json::to_vec(artifact).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory
            .write(std::path::Path::new(filename), &bytes)
            .await?;
        Ok(())
    }

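    /// Trains and saves the coarse IVF centroids used by the RaBitQ index.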
    async fn train_ivf_rabitq(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
        self.save_trained_artifact(&centroids, centroids_filename)
            .await?;

        log::info!(
            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );
        Ok(())
    }

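    /// Trains and saves the coarse centroids plus the product-quantization
    /// codebook used by the ScaNN index.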
    async fn train_scann(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
        codebook_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
        self.save_trained_artifact(&centroids, centroids_filename)
            .await?;

        let pq_config = crate::structures::PQConfig::new(dim);
        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
        self.save_trained_artifact(&codebook, codebook_filename)
            .await?;

        log::info!(
            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );
        Ok(())
    }
}