use std::sync::Arc;

use rustc_hash::FxHashMap;

use crate::directories::DirectoryWriter;
use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
use crate::error::{Error, Result};
use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};

use super::IndexWriter;

impl<D: DirectoryWriter + 'static> IndexWriter<D> {
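    /// Check whether any dense vector field has crossed its `build_threshold`
    /// and, if so, trigger a full vector index build. No-op when no non-flat
    /// dense vector fields are configured or all of them are already built.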
    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }

        // Fast path: every configured field already has its ANN index built.
        let all_built = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            dense_fields
                .iter()
                .all(|(field, _)| meta.is_field_built(field.0))
        };
        if all_built {
            return Ok(());
        }

        let segment_ids = self.segment_manager.get_segment_ids().await;
        let total_vectors = self.count_flat_vectors(&segment_ids).await;

        // Record the current count and check each field's build threshold.
        let should_build = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            meta.total_vectors = total_vectors;
            dense_fields.iter().any(|(field, config)| {
                let threshold = config.build_threshold.unwrap_or(1000);
                meta.should_build_field(field.0, threshold)
            })
        };

        if should_build {
            log::info!(
                "Threshold crossed ({} vectors), auto-triggering vector index build",
                total_vectors
            );
            self.build_vector_index().await?;
        }

        Ok(())
    }

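    /// Train ANN structures (coarse centroids, plus a PQ codebook for ScaNN)
    /// for every dense vector field that is not yet built, then rebuild all
    /// segments so their flat vector storage is replaced with ANN indexes.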
    pub async fn build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            log::info!("No dense vector fields configured for ANN indexing");
            return Ok(());
        }

        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
        if fields_to_build.is_empty() {
            log::info!("All vector fields already built, skipping training");
            return Ok(());
        }

        // Let any in-flight merges finish before collecting training vectors.
        self.segment_manager.wait_for_merges().await;

        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        let all_vectors = self
            .collect_vectors_for_training(&segment_ids, &fields_to_build)
            .await?;

        for (field, config) in &fields_to_build {
            self.train_field_index(*field, config, &all_vectors).await?;
        }

        log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");

        self.rebuild_segments_with_ann().await?;

        Ok(())
    }

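    /// Rebuild all segments with ANN indexes. Background merges are paused
    /// (and in-flight ones awaited) for the duration so the segment set stays
    /// stable, and resumed afterwards regardless of the outcome.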
    pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
        self.segment_manager.pause_merges();
        self.segment_manager.wait_for_merges().await;

        let result = self.rebuild_segments_with_ann_inner().await;

        self.segment_manager.resume_merges();

        result
    }

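    /// Merge every existing segment into one new segment built with the
    /// trained ANN structures, then swap it in for the old segments.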
    async fn rebuild_segments_with_ann_inner(&self) -> Result<()> {
        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        let (trained_centroids, trained_codebooks) = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            meta.load_trained_structures(self.directory.as_ref()).await
        };

        if trained_centroids.is_empty() {
            log::info!("No trained structures to rebuild with");
            return Ok(());
        }

        let trained = TrainedVectorStructures {
            centroids: trained_centroids,
            codebooks: trained_codebooks,
        };

        let readers = self.load_segment_readers(&segment_ids).await?;

        let total_docs: u32 = readers.iter().map(|r| r.meta().num_docs).sum();

        // Merge all segments into a single new segment carrying ANN indexes.
        let merger = SegmentMerger::new(Arc::clone(&self.schema));
        let new_segment_id = SegmentId::new();
        merger
            .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
            .await?;

        self.segment_manager
            .replace_segments(vec![(new_segment_id.to_hex(), total_docs)], segment_ids)
            .await?;

        log::info!("Segments rebuilt with ANN indexes");
        Ok(())
    }

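    /// Total number of vectors tracked in the index metadata.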
    pub async fn total_vector_count(&self) -> usize {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.total_vectors
    }

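    /// Whether the ANN index for `field` has been built.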
    pub async fn is_vector_index_built(&self, field: Field) -> bool {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.is_field_built(field.0)
    }

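    /// Force a retrain: reset every indexed dense vector field back to the
    /// `Flat` state, delete its persisted centroids/codebook files
    /// (best-effort), and run a fresh build.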
    pub async fn rebuild_vector_index(&self) -> Result<()> {
        let dense_fields: Vec<Field> = self
            .schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    Some(field)
                } else {
                    None
                }
            })
            .collect();

        if dense_fields.is_empty() {
            return Ok(());
        }

        // Reset per-field state and collect the stale files while holding the
        // metadata write lock; the files themselves are deleted after the
        // lock is released.
        let files_to_delete = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            let mut files = Vec::new();
            for field in &dense_fields {
                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
                    field_meta.state = super::VectorIndexState::Flat;
                    if let Some(ref f) = field_meta.centroids_file {
                        files.push(f.clone());
                    }
                    if let Some(ref f) = field_meta.codebook_file {
                        files.push(f.clone());
                    }
                    field_meta.centroids_file = None;
                    field_meta.codebook_file = None;
                }
            }
            meta.save(self.directory.as_ref()).await?;
            files
        };

        for file in files_to_delete {
            let _ = self.directory.delete(std::path::Path::new(&file)).await;
        }

        log::info!("Reset vector index state to Flat, triggering rebuild...");

        self.build_vector_index().await
    }

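    /// All indexed dense vector fields whose config requests an ANN index
    /// (flat-only fields are excluded), paired with their configs.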
    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
        self.schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    entry
                        .dense_vector_config
                        .as_ref()
                        .filter(|c| !c.is_flat())
                        .map(|c| (field, c.clone()))
                } else {
                    None
                }
            })
            .collect()
    }

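    /// Subset of `dense_fields` whose ANN index has not been built yet.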
    async fn get_fields_to_build(
        &self,
        dense_fields: &[(Field, DenseVectorConfig)],
    ) -> Vec<(Field, DenseVectorConfig)> {
        let metadata_arc = self.segment_manager.metadata();
        let meta = metadata_arc.read().await;
        dense_fields
            .iter()
            .filter(|(field, _)| !meta.is_field_built(field.0))
            .cloned()
            .collect()
    }

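    /// Count vectors held in flat (unindexed) storage across the given
    /// segments. Segments without a vectors file, or that fail to open, are
    /// skipped.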
    async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
        let mut total_vectors = 0usize;
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let Some(segment_id) = SegmentId::from_hex(id_str) else {
                continue;
            };

            let files = crate::segment::SegmentFiles::new(segment_id.0);
            if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
                continue;
            }

            if let Ok(reader) = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await
            {
                for index in reader.vector_indexes().values() {
                    if let crate::segment::VectorIndex::Flat(flat_data) = index {
                        total_vectors += flat_data.vectors.len();
                    }
                }
                doc_offset += reader.meta().num_docs;
            }
        }

        total_vectors
    }

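    /// Gather training vectors per field from the flat indexes of the given
    /// segments, capped at `MAX_TRAINING_VECTORS` per field. When a segment
    /// would overflow the cap, its vectors are stride-sampled instead of
    /// taken wholesale.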
    async fn collect_vectors_for_training(
        &self,
        segment_ids: &[String],
        fields_to_build: &[(Field, DenseVectorConfig)],
    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
        const MAX_TRAINING_VECTORS: usize = 100_000;

        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
        let mut doc_offset = 0u32;
        let mut total_skipped = 0usize;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;

            for (field_id, index) in reader.vector_indexes() {
                if fields_to_build.iter().any(|(f, _)| f.0 == *field_id)
                    && let crate::segment::VectorIndex::Flat(flat_data) = index
                {
                    let entry = all_vectors.entry(*field_id).or_default();
                    let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());

                    if remaining == 0 {
                        total_skipped += flat_data.vectors.len();
                        continue;
                    }

                    if flat_data.vectors.len() <= remaining {
                        // The whole segment fits under the cap.
                        entry.extend(flat_data.vectors.iter().cloned());
                    } else {
                        // Stride-sample so the training set spans the whole
                        // segment rather than just its prefix.
                        let step = (flat_data.vectors.len() / remaining).max(1);
                        for (i, vec) in flat_data.vectors.iter().enumerate() {
                            if i % step == 0 && entry.len() < MAX_TRAINING_VECTORS {
                                entry.push(vec.clone());
                            }
                        }
                        total_skipped += flat_data.vectors.len() - remaining;
                    }
                }
            }

            doc_offset += reader.meta().num_docs;
        }

        if total_skipped > 0 {
            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
            log::info!(
                "Sampled {} vectors for training (skipped {}, max {} per field)",
                collected,
                total_skipped,
                MAX_TRAINING_VECTORS,
            );
        }

        Ok(all_vectors)
    }

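    /// Open a `SegmentReader` for each segment ID, assigning contiguous
    /// document offsets in segment order.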
    pub(super) async fn load_segment_readers(
        &self,
        segment_ids: &[String],
    ) -> Result<Vec<SegmentReader>> {
        let mut readers = Vec::new();
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;
            doc_offset += reader.meta().num_docs;
            readers.push(reader);
        }

        Ok(readers)
    }

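    /// Train and persist the ANN structures for a single field, then record
    /// the result in the index metadata. Fields with no collected vectors, or
    /// with a flat/unsupported index type, are left untouched.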
    async fn train_field_index(
        &self,
        field: Field,
        config: &DenseVectorConfig,
        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
    ) -> Result<()> {
        let field_id = field.0;
        let vectors = match all_vectors.get(&field_id) {
            Some(v) if !v.is_empty() => v,
            _ => return Ok(()),
        };

        let index_dim = config.index_dim();
        let num_vectors = vectors.len();
        let num_clusters = config.optimal_num_clusters(num_vectors);

        log::info!(
            "Training vector index for field {} with {} vectors, {} clusters",
            field_id,
            num_vectors,
            num_clusters
        );

        let centroids_filename = format!("field_{}_centroids.bin", field_id);
        let mut codebook_filename: Option<String> = None;

        match config.index_type {
            VectorIndexType::IvfRaBitQ => {
                self.train_ivf_rabitq(
                    field_id,
                    index_dim,
                    num_clusters,
                    vectors,
                    &centroids_filename,
                )
                .await?;
            }
            VectorIndexType::ScaNN => {
                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
                self.train_scann(
                    field_id,
                    index_dim,
                    num_clusters,
                    vectors,
                    &centroids_filename,
                    codebook_filename.as_ref().unwrap(),
                )
                .await?;
            }
            _ => {
                return Ok(());
            }
        }

        self.segment_manager
            .update_metadata(|meta| {
                meta.init_field(field_id, config.index_type);
                meta.total_vectors = num_vectors;
                meta.mark_field_built(
                    field_id,
                    num_vectors,
                    num_clusters,
                    centroids_filename.clone(),
                    codebook_filename.clone(),
                );
            })
            .await?;

        Ok(())
    }

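    /// Train coarse IVF centroids for an IVF-RaBitQ index and persist them
    /// as JSON.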
    async fn train_ivf_rabitq(
        &self,
        field_id: u32,
        index_dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);

        let centroids_path = std::path::Path::new(centroids_filename);
        let centroids_bytes =
            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory
            .write(centroids_path, &centroids_bytes)
            .await?;

        log::info!(
            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );

        Ok(())
    }

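    /// Train coarse centroids plus a product-quantization codebook for a
    /// ScaNN-style index and persist both as JSON.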
    async fn train_scann(
        &self,
        field_id: u32,
        index_dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
        codebook_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);

        let pq_config = crate::structures::PQConfig::new(index_dim);
        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);

        let centroids_path = std::path::Path::new(centroids_filename);
        let centroids_bytes =
            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory
            .write(centroids_path, &centroids_bytes)
            .await?;

        let codebook_path = std::path::Path::new(codebook_filename);
        let codebook_bytes =
            serde_json::to_vec(&codebook).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory.write(codebook_path, &codebook_bytes).await?;

        log::info!(
            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );

        Ok(())
    }
}