use std::sync::Arc;

use rustc_hash::FxHashMap;

use crate::directories::DirectoryWriter;
use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
use crate::error::{Error, Result};
use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};

use super::IndexWriter;

impl<D: DirectoryWriter + 'static> IndexWriter<D> {
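    /// Checks whether any dense vector field has crossed its build threshold
    /// and, if so, triggers an automatic ANN index build.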
    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }

        // Fast path: skip counting vectors when every field is already built.
        let all_built = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            dense_fields
                .iter()
                .all(|(field, _)| meta.is_field_built(field.0))
        };
        if all_built {
            return Ok(());
        }

        let segment_ids = self.segment_manager.get_segment_ids().await;
        let total_vectors = self.count_flat_vectors(&segment_ids).await;

        // Record the current count and check each field's build threshold.
        let should_build = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            meta.total_vectors = total_vectors;
            dense_fields.iter().any(|(field, config)| {
                let threshold = config.build_threshold.unwrap_or(1000);
                meta.should_build_field(field.0, threshold)
            })
        };

        if should_build {
            log::info!(
                "Threshold crossed ({} vectors), auto-triggering vector index build",
                total_vectors
            );
            self.build_vector_index().await?;
        }

        Ok(())
    }

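    /// Trains ANN structures (coarse centroids and, for ScaNN, a PQ codebook)
    /// for every dense vector field that is not yet built, then rebuilds all
    /// segments against the trained structures.
    ///
    /// ```ignore
    /// // Explicitly build the ANN index once enough vectors are ingested.
    /// // `my_vector_field` is a placeholder `Field` handle from your schema.
    /// writer.build_vector_index().await?;
    /// assert!(writer.is_vector_index_built(my_vector_field).await);
    /// ```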
    pub async fn build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            log::info!("No dense vector fields configured for ANN indexing");
            return Ok(());
        }

        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
        if fields_to_build.is_empty() {
            log::info!("All vector fields already built, skipping training");
            return Ok(());
        }

        // Let in-flight merges settle so training sees a stable segment set.
        self.segment_manager.wait_for_merges().await;

        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        let all_vectors = self
            .collect_vectors_for_training(&segment_ids, &fields_to_build)
            .await?;

        for (field, config) in &fields_to_build {
            self.train_field_index(*field, config, &all_vectors).await?;
        }

        log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");

        self.rebuild_segments_with_ann().await?;

        Ok(())
    }

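    /// Rebuilds all segments with ANN indexes. Merges are paused for the
    /// duration so the segment set cannot change underneath the rebuild.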
    pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
        self.segment_manager.pause_merges();
        self.segment_manager.wait_for_merges().await;

        let result = self.rebuild_segments_with_ann_inner().await;

        // Always resume merges, even if the rebuild failed.
        self.segment_manager.resume_merges();

        result
    }

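    /// Merges every existing segment into a single new segment whose vectors
    /// are encoded with the trained ANN structures, then swaps the old
    /// segments for the new one.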
    async fn rebuild_segments_with_ann_inner(&self) -> Result<()> {
        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        // Load the centroids/codebooks produced by training from the directory.
        let (trained_centroids, trained_codebooks) = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            meta.load_trained_structures(self.directory.as_ref()).await
        };

        if trained_centroids.is_empty() {
            log::info!("No trained structures to rebuild with");
            return Ok(());
        }

        let trained = TrainedVectorStructures {
            centroids: trained_centroids,
            codebooks: trained_codebooks,
        };

        let readers = self.load_segment_readers(&segment_ids).await?;

        let total_docs: u32 = readers.iter().map(|r| r.meta().num_docs).sum();

        // Merge all segments into one, building ANN indexes along the way.
        let merger = SegmentMerger::new(Arc::clone(&self.schema));
        let new_segment_id = SegmentId::new();
        merger
            .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
            .await?;

        self.segment_manager
            .replace_segments(vec![(new_segment_id.to_hex(), total_docs)], segment_ids)
            .await?;

        log::info!("Segments rebuilt with ANN indexes");
        Ok(())
    }

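    /// Returns the total number of vectors tracked in the index metadata.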
    pub async fn total_vector_count(&self) -> usize {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.total_vectors
    }

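    /// Returns whether the ANN index for `field` has been built.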
    pub async fn is_vector_index_built(&self, field: Field) -> bool {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.is_field_built(field.0)
    }

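    /// Forces a full retrain: resets every dense vector field back to the
    /// `Flat` state, deletes stale trained artifacts, and rebuilds from
    /// scratch via `build_vector_index`.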
    pub async fn rebuild_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }
        let dense_fields: Vec<Field> = dense_fields.into_iter().map(|(f, _)| f).collect();

        // Reset each field to Flat and collect its trained artifacts for deletion.
        let files_to_delete = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            let mut files = Vec::new();
            for field in &dense_fields {
                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
                    field_meta.state = super::VectorIndexState::Flat;
                    if let Some(ref f) = field_meta.centroids_file {
                        files.push(f.clone());
                    }
                    if let Some(ref f) = field_meta.codebook_file {
                        files.push(f.clone());
                    }
                    field_meta.centroids_file = None;
                    field_meta.codebook_file = None;
                }
            }
            meta.save(self.directory.as_ref()).await?;
            files
        };

        // Deletion is best-effort; errors are ignored.
        for file in files_to_delete {
            let _ = self.directory.delete(std::path::Path::new(&file)).await;
        }

        log::info!("Reset vector index state to Flat, triggering rebuild...");

        self.build_vector_index().await
    }

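    /// Returns all indexed dense vector fields whose configuration requests
    /// an ANN index (anything other than flat storage).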
    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
        self.schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    entry
                        .dense_vector_config
                        .as_ref()
                        .filter(|c| !c.is_flat())
                        .map(|c| (field, c.clone()))
                } else {
                    None
                }
            })
            .collect()
    }

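    /// Filters `dense_fields` down to those whose ANN index is not built yet.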
    async fn get_fields_to_build(
        &self,
        dense_fields: &[(Field, DenseVectorConfig)],
    ) -> Vec<(Field, DenseVectorConfig)> {
        let metadata_arc = self.segment_manager.metadata();
        let meta = metadata_arc.read().await;
        dense_fields
            .iter()
            .filter(|(field, _)| !meta.is_field_built(field.0))
            .cloned()
            .collect()
    }

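    /// Counts the flat vectors stored across the given segments, skipping
    /// segments that have no vector data or fail to open.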
    async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
        let mut total_vectors = 0usize;
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let Some(segment_id) = SegmentId::from_hex(id_str) else {
                continue;
            };

            // Cheap existence check before paying the cost of opening the segment.
            let files = crate::segment::SegmentFiles::new(segment_id.0);
            if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
                continue;
            }

            if let Ok(reader) = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await
            {
                for flat_data in reader.flat_vectors().values() {
                    total_vectors += flat_data.num_vectors;
                }
                doc_offset += reader.meta().num_docs;
            }
        }

        total_vectors
    }

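    /// Gathers training vectors per field from the flat vector storage of all
    /// segments, sampling evenly down to at most `MAX_TRAINING_VECTORS` per
    /// field to bound training cost.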
    async fn collect_vectors_for_training(
        &self,
        segment_ids: &[String],
        fields_to_build: &[(Field, DenseVectorConfig)],
    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
        const MAX_TRAINING_VECTORS: usize = 100_000;

        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
        let mut doc_offset = 0u32;
        let mut total_skipped = 0usize;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;

            for (field_id, lazy_flat) in reader.flat_vectors() {
                if !fields_to_build.iter().any(|(f, _)| f.0 == *field_id) {
                    continue;
                }
                let entry = all_vectors.entry(*field_id).or_default();
                let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());

                if remaining == 0 {
                    total_skipped += lazy_flat.num_vectors;
                    continue;
                }

                let n = lazy_flat.num_vectors;
                let dim = lazy_flat.dim;
                let quant = lazy_flat.quantization;

                // Take every vector if it fits; otherwise sample at a fixed
                // stride so the sample spans the whole segment.
                let indices: Vec<usize> = if n <= remaining {
                    (0..n).collect()
                } else {
                    let step = (n / remaining).max(1);
                    (0..n).step_by(step).take(remaining).collect()
                };

                if indices.len() < n {
                    total_skipped += n - indices.len();
                }

                const BATCH: usize = 1024;
                let mut f32_buf = vec![0f32; BATCH * dim];
                for chunk in indices.chunks(BATCH) {
                    let start = chunk[0];
                    let end = *chunk.last().unwrap();
                    // Contiguous run of indices: read the whole chunk in one
                    // batched I/O call and dequantize in place.
                    if end - start + 1 == chunk.len() {
                        if let Ok(batch_bytes) =
                            lazy_flat.read_vectors_batch(start, chunk.len()).await
                        {
                            let floats = chunk.len() * dim;
                            f32_buf.resize(floats, 0.0);
                            crate::segment::dequantize_raw(
                                batch_bytes.as_slice(),
                                quant,
                                floats,
                                &mut f32_buf,
                            );
                            for i in 0..chunk.len() {
                                entry.push(f32_buf[i * dim..(i + 1) * dim].to_vec());
                            }
                        }
                    } else {
                        // Strided sample: fall back to one read per vector.
                        f32_buf.resize(dim, 0.0);
                        for &idx in chunk {
                            if let Ok(()) = lazy_flat.read_vector_into(idx, &mut f32_buf).await {
                                entry.push(f32_buf[..dim].to_vec());
                            }
                        }
                    }
                }
            }

            doc_offset += reader.meta().num_docs;
        }

        if total_skipped > 0 {
            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
            log::info!(
                "Sampled {} vectors for training (skipped {}, max {} per field)",
                collected,
                total_skipped,
                MAX_TRAINING_VECTORS,
            );
        }

        Ok(all_vectors)
    }

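    /// Opens a `SegmentReader` for each segment ID, assigning each reader a
    /// contiguous global doc offset.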
    pub(super) async fn load_segment_readers(
        &self,
        segment_ids: &[String],
    ) -> Result<Vec<SegmentReader>> {
        let mut readers = Vec::new();
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;
            doc_offset += reader.meta().num_docs;
            readers.push(reader);
        }

        Ok(readers)
    }

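    /// Trains the configured index type for a single field from the sampled
    /// vectors, saves the trained artifacts, and marks the field as built in
    /// the index metadata.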
    async fn train_field_index(
        &self,
        field: Field,
        config: &DenseVectorConfig,
        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
    ) -> Result<()> {
        let field_id = field.0;
        let vectors = match all_vectors.get(&field_id) {
            Some(v) if !v.is_empty() => v,
            _ => return Ok(()),
        };

        let dim = config.dim;
        let num_vectors = vectors.len();
        let num_clusters = config.optimal_num_clusters(num_vectors);

        log::info!(
            "Training vector index for field {} with {} vectors, {} clusters (dim={})",
            field_id,
            num_vectors,
            num_clusters,
            dim,
        );

        let centroids_filename = format!("field_{}_centroids.bin", field_id);
        let mut codebook_filename: Option<String> = None;

        match config.index_type {
            VectorIndexType::IvfRaBitQ => {
                self.train_ivf_rabitq(field_id, dim, num_clusters, vectors, &centroids_filename)
                    .await?;
            }
            VectorIndexType::ScaNN => {
                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
                self.train_scann(
                    field_id,
                    dim,
                    num_clusters,
                    vectors,
                    &centroids_filename,
                    codebook_filename.as_ref().unwrap(),
                )
                .await?;
            }
            _ => {
                // Other index types (e.g. flat) require no training.
                return Ok(());
            }
        }

        self.segment_manager
            .update_metadata(|meta| {
                meta.init_field(field_id, config.index_type);
                meta.total_vectors = num_vectors;
                meta.mark_field_built(
                    field_id,
                    num_vectors,
                    num_clusters,
                    centroids_filename.clone(),
                    codebook_filename.clone(),
                );
            })
            .await?;

        Ok(())
    }

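    /// Serializes a trained artifact as JSON and writes it to the directory.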
    async fn save_trained_artifact(
        &self,
        artifact: &impl serde::Serialize,
        filename: &str,
    ) -> Result<()> {
        let bytes =
            serde_json::to_vec(artifact).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory
            .write(std::path::Path::new(filename), &bytes)
            .await?;
        Ok(())
    }

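    /// Trains and persists the coarse centroids for an IVF-RaBitQ index; only
    /// the centroids are produced at training time.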
    async fn train_ivf_rabitq(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
        self.save_trained_artifact(&centroids, centroids_filename)
            .await?;

        log::info!(
            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );
        Ok(())
    }

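    /// Trains and persists both the coarse centroids and the product
    /// quantization codebook for a ScaNN index.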
    async fn train_scann(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
        codebook_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
        self.save_trained_artifact(&centroids, centroids_filename)
            .await?;

        let pq_config = crate::structures::PQConfig::new(dim);
        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
        self.save_trained_artifact(&codebook, codebook_filename)
            .await?;

        log::info!(
            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );
        Ok(())
    }
}