use std::sync::Arc;

use rustc_hash::FxHashMap;

use crate::directories::DirectoryWriter;
use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
use crate::error::{Error, Result};
use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};

use super::IndexWriter;

impl<D: DirectoryWriter + 'static> IndexWriter<D> {
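    /// Check whether any dense-vector field has crossed its build threshold
    /// and, if so, trigger a full ANN index build. Returns early (cheaply)
    /// when no non-flat dense fields are configured or all are already built.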
    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }

        let all_built = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            dense_fields
                .iter()
                .all(|(field, _)| meta.is_field_built(field.0))
        };
        if all_built {
            return Ok(());
        }

        let segment_ids = self.segment_manager.get_segment_ids().await;
        let total_vectors = self.count_flat_vectors(&segment_ids).await;

        let should_build = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            meta.total_vectors = total_vectors;
            dense_fields.iter().any(|(field, config)| {
                let threshold = config.build_threshold.unwrap_or(1000);
                meta.should_build_field(field.0, threshold)
            })
        };

        if should_build {
            log::info!(
                "Threshold crossed ({} vectors), auto-triggering vector index build",
                total_vectors
            );
            self.build_vector_index().await?;
        }

        Ok(())
    }

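    /// Train ANN structures for every dense-vector field that is not yet
    /// built, then rebuild all segments so they carry the trained indexes.
    /// Waits for in-flight merges before collecting training vectors.
    ///
    /// A minimal usage sketch (assumes an already-constructed
    /// `writer: IndexWriter<_>` and a dense-vector `Field` handle
    /// `vec_field`; both names are illustrative):
    ///
    /// ```ignore
    /// writer.build_vector_index().await?;
    /// assert!(writer.is_vector_index_built(vec_field).await);
    /// println!("total vectors: {}", writer.total_vector_count().await);
    /// ```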
    pub async fn build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            log::info!("No dense vector fields configured for ANN indexing");
            return Ok(());
        }

        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
        if fields_to_build.is_empty() {
            log::info!("All vector fields already built, skipping training");
            return Ok(());
        }

        self.segment_manager.wait_for_merges().await;

        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        let all_vectors = self
            .collect_vectors_for_training(&segment_ids, &fields_to_build)
            .await?;

        for (field, config) in &fields_to_build {
            self.train_field_index(*field, config, &all_vectors).await?;
        }

        log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");

        self.rebuild_segments_with_ann().await?;

        Ok(())
    }

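    /// Merge all current segments into one new segment that carries the
    /// trained ANN structures. Merges are paused (and any in-flight merge
    /// awaited) for the duration so the segment set cannot change underneath
    /// the rebuild.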
    pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
        self.segment_manager.pause_merges();
        self.segment_manager.wait_for_merges().await;

        let result = self.rebuild_segments_with_ann_inner().await;

        // Resume merges regardless of whether the rebuild succeeded.
        self.segment_manager.resume_merges();

        result
    }

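    /// The rebuild itself: load trained centroids/codebooks referenced by the
    /// metadata, open readers for every segment, merge them into a single new
    /// segment with ANN indexes, and swap it in via `replace_segments`.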
    async fn rebuild_segments_with_ann_inner(&self) -> Result<()> {
        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        let (trained_centroids, trained_codebooks) = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            meta.load_trained_structures(self.directory.as_ref()).await
        };

        if trained_centroids.is_empty() {
            log::info!("No trained structures to rebuild with");
            return Ok(());
        }

        let trained = TrainedVectorStructures {
            centroids: trained_centroids,
            codebooks: trained_codebooks,
        };

        let readers = self.load_segment_readers(&segment_ids).await?;

        let total_docs: u32 = readers.iter().map(|r| r.meta().num_docs).sum();

        let merger = SegmentMerger::new(Arc::clone(&self.schema));
        let new_segment_id = SegmentId::new();
        merger
            .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
            .await?;

        self.segment_manager
            .replace_segments(vec![(new_segment_id.to_hex(), total_docs)], segment_ids)
            .await?;

        log::info!("Segments rebuilt with ANN indexes");
        Ok(())
    }

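    /// Total number of indexed vectors, as recorded in segment metadata.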
    pub async fn total_vector_count(&self) -> usize {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.total_vectors
    }

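    /// Whether the ANN index for `field` has been built, per segment metadata.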
    pub async fn is_vector_index_built(&self, field: Field) -> bool {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.is_field_built(field.0)
    }

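    /// Force a full retrain: reset every indexed dense-vector field back to
    /// the flat state, delete its persisted centroid/codebook files, and run
    /// a fresh build.
    ///
    /// Illustrative call (assumes a constructed `writer`):
    ///
    /// ```ignore
    /// // Drop existing centroids/codebooks, then retrain from flat vectors.
    /// writer.rebuild_vector_index().await?;
    /// ```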
    pub async fn rebuild_vector_index(&self) -> Result<()> {
        let dense_fields: Vec<Field> = self
            .schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    Some(field)
                } else {
                    None
                }
            })
            .collect();

        if dense_fields.is_empty() {
            return Ok(());
        }

        let files_to_delete = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            let mut files = Vec::new();
            for field in &dense_fields {
                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
                    field_meta.state = super::VectorIndexState::Flat;
                    if let Some(ref f) = field_meta.centroids_file {
                        files.push(f.clone());
                    }
                    if let Some(ref f) = field_meta.codebook_file {
                        files.push(f.clone());
                    }
                    field_meta.centroids_file = None;
                    field_meta.codebook_file = None;
                }
            }
            meta.save(self.directory.as_ref()).await?;
            files
        };

        // Best-effort cleanup: metadata is already saved, so a failed delete
        // only leaves an orphaned file behind.
        for file in files_to_delete {
            let _ = self.directory.delete(std::path::Path::new(&file)).await;
        }

        log::info!("Reset vector index state to Flat, triggering rebuild...");

        self.build_vector_index().await
    }

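    /// All indexed dense-vector fields whose config requests a non-flat
    /// (ANN) index, paired with a clone of that config.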
    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
        self.schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    entry
                        .dense_vector_config
                        .as_ref()
                        .filter(|c| !c.is_flat())
                        .map(|c| (field, c.clone()))
                } else {
                    None
                }
            })
            .collect()
    }

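    /// Filter `dense_fields` down to those not yet marked built in metadata.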
    async fn get_fields_to_build(
        &self,
        dense_fields: &[(Field, DenseVectorConfig)],
    ) -> Vec<(Field, DenseVectorConfig)> {
        let metadata_arc = self.segment_manager.metadata();
        let meta = metadata_arc.read().await;
        dense_fields
            .iter()
            .filter(|(field, _)| !meta.is_field_built(field.0))
            .cloned()
            .collect()
    }

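    /// Count the vectors held in flat indexes across the given segments.
    /// Segments with no vectors file, or that fail to open, are skipped
    /// rather than treated as errors.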
    async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
        let mut total_vectors = 0usize;
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let Some(segment_id) = SegmentId::from_hex(id_str) else {
                continue;
            };

            let files = crate::segment::SegmentFiles::new(segment_id.0);
            if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
                continue;
            }

            if let Ok(reader) = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await
            {
                for index in reader.vector_indexes().values() {
                    if let crate::segment::VectorIndex::Flat(flat_data) = index {
                        total_vectors += flat_data.num_vectors();
                    }
                }
                doc_offset += reader.meta().num_docs;
            }
        }

        total_vectors
    }

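    /// Gather training vectors per field from the flat indexes of the given
    /// segments, capping each field's sample at `MAX_TRAINING_VECTORS`.
    /// Once a field nears its cap, remaining segments are stride-subsampled.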
    async fn collect_vectors_for_training(
        &self,
        segment_ids: &[String],
        fields_to_build: &[(Field, DenseVectorConfig)],
    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
        const MAX_TRAINING_VECTORS: usize = 100_000;

        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
        let mut doc_offset = 0u32;
        let mut total_skipped = 0usize;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;

            for (field_id, index) in reader.vector_indexes() {
                if fields_to_build.iter().any(|(f, _)| f.0 == *field_id)
                    && let crate::segment::VectorIndex::Flat(flat_data) = index
                {
                    let entry = all_vectors.entry(*field_id).or_default();
                    let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());

                    if remaining == 0 {
                        total_skipped += flat_data.num_vectors();
                        continue;
                    }

                    let n = flat_data.num_vectors();
                    if n <= remaining {
                        // Everything fits under the cap: take all vectors.
                        entry.extend((0..n).map(|i| flat_data.get_vector(i).to_vec()));
                    } else {
                        // Subsample with a fixed stride so the sample spans
                        // the whole segment rather than just its prefix.
                        let step = (n / remaining).max(1);
                        for i in 0..n {
                            if i % step == 0 && entry.len() < MAX_TRAINING_VECTORS {
                                entry.push(flat_data.get_vector(i).to_vec());
                            }
                        }
                        total_skipped += n - remaining;
                    }
                }
            }

            doc_offset += reader.meta().num_docs;
        }

        if total_skipped > 0 {
            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
            log::info!(
                "Sampled {} vectors for training (skipped {}, max {} per field)",
                collected,
                total_skipped,
                MAX_TRAINING_VECTORS,
            );
        }

        Ok(all_vectors)
    }

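    /// Open a `SegmentReader` for each segment, assigning consecutive base
    /// doc offsets so document IDs remain globally unique across readers.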
    pub(super) async fn load_segment_readers(
        &self,
        segment_ids: &[String],
    ) -> Result<Vec<SegmentReader>> {
        let mut readers = Vec::new();
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;
            doc_offset += reader.meta().num_docs;
            readers.push(reader);
        }

        Ok(readers)
    }

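    /// Train the ANN structures configured for `field` from its sampled
    /// vectors, persist them, and mark the field built in segment metadata.
    /// Fields with no sampled vectors, or an unhandled index type, are
    /// skipped.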
    async fn train_field_index(
        &self,
        field: Field,
        config: &DenseVectorConfig,
        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
    ) -> Result<()> {
        let field_id = field.0;
        let vectors = match all_vectors.get(&field_id) {
            Some(v) if !v.is_empty() => v,
            _ => return Ok(()),
        };

        let index_dim = config.index_dim();
        let num_vectors = vectors.len();
        let num_clusters = config.optimal_num_clusters(num_vectors);

        log::info!(
            "Training vector index for field {} with {} vectors, {} clusters",
            field_id,
            num_vectors,
            num_clusters
        );

        let centroids_filename = format!("field_{}_centroids.bin", field_id);
        let mut codebook_filename: Option<String> = None;

        match config.index_type {
            VectorIndexType::IvfRaBitQ => {
                self.train_ivf_rabitq(
                    field_id,
                    index_dim,
                    num_clusters,
                    vectors,
                    &centroids_filename,
                )
                .await?;
            }
            VectorIndexType::ScaNN => {
                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
                self.train_scann(
                    field_id,
                    index_dim,
                    num_clusters,
                    vectors,
                    &centroids_filename,
                    codebook_filename.as_ref().unwrap(),
                )
                .await?;
            }
            // Remaining index types (e.g. flat) need no training.
            _ => {
                return Ok(());
            }
        }

        self.segment_manager
            .update_metadata(|meta| {
                meta.init_field(field_id, config.index_type);
                meta.total_vectors = num_vectors;
                meta.mark_field_built(
                    field_id,
                    num_vectors,
                    num_clusters,
                    centroids_filename.clone(),
                    codebook_filename.clone(),
                );
            })
            .await?;

        Ok(())
    }

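    /// Train coarse IVF centroids for an IVF-RaBitQ index and persist them
    /// (serde_json-encoded) under `centroids_filename`.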
    async fn train_ivf_rabitq(
        &self,
        field_id: u32,
        index_dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);

        let centroids_path = std::path::Path::new(centroids_filename);
        let centroids_bytes =
            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory
            .write(centroids_path, &centroids_bytes)
            .await?;

        log::info!(
            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );

        Ok(())
    }

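    /// Train coarse centroids plus a product-quantization codebook for a
    /// ScaNN-style index and persist both files to the directory.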
    async fn train_scann(
        &self,
        field_id: u32,
        index_dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
        codebook_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);

        let pq_config = crate::structures::PQConfig::new(index_dim);
        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);

        let centroids_path = std::path::Path::new(centroids_filename);
        let centroids_bytes =
            serde_json::to_vec(&centroids).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory
            .write(centroids_path, &centroids_bytes)
            .await?;

        let codebook_path = std::path::Path::new(codebook_filename);
        let codebook_bytes =
            serde_json::to_vec(&codebook).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory.write(codebook_path, &codebook_bytes).await?;

        log::info!(
            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );

        Ok(())
    }
}