1use std::sync::Arc;
9
10use rustc_hash::FxHashMap;
11
12use crate::directories::DirectoryWriter;
13use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
14use crate::error::{Error, Result};
15use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};
16
17use super::IndexWriter;
18
impl<D: DirectoryWriter + 'static> IndexWriter<D> {
    /// Check per-field build thresholds and auto-trigger an ANN index build
    /// when any dense vector field has accumulated enough flat vectors.
    ///
    /// Cheap no-op when there are no trainable dense fields or when every
    /// field already has a built index.
    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }

        // Fast path under a read lock: nothing to do if every field is built.
        let everything_built = {
            let metadata_arc = self.segment_manager.metadata();
            let guard = metadata_arc.read().await;
            dense_fields
                .iter()
                .all(|(field, _)| guard.is_field_built(field.0))
        };
        if everything_built {
            return Ok(());
        }

        let segment_ids = self.segment_manager.get_segment_ids().await;
        let total_vectors = self.count_flat_vectors(&segment_ids).await;

        // Record the fresh count and ask whether any field crossed its
        // configured threshold (default 1000 vectors).
        let crossed = {
            let metadata_arc = self.segment_manager.metadata();
            let mut guard = metadata_arc.write().await;
            guard.total_vectors = total_vectors;
            dense_fields.iter().any(|(field, config)| {
                guard.should_build_field(field.0, config.build_threshold.unwrap_or(1000))
            })
        };

        if !crossed {
            return Ok(());
        }

        log::info!(
            "Threshold crossed ({} vectors), auto-triggering vector index build",
            total_vectors
        );
        self.build_vector_index().await
    }
65
66 pub async fn build_vector_index(&self) -> Result<()> {
80 let dense_fields = self.get_dense_vector_fields();
81 if dense_fields.is_empty() {
82 log::info!("No dense vector fields configured for ANN indexing");
83 return Ok(());
84 }
85
86 let fields_to_build = self.get_fields_to_build(&dense_fields).await;
88 if fields_to_build.is_empty() {
89 log::info!("All vector fields already built, skipping training");
90 return Ok(());
91 }
92
93 let segment_ids = self.segment_manager.get_segment_ids().await;
94 if segment_ids.is_empty() {
95 return Ok(());
96 }
97
98 let all_vectors = self
100 .collect_vectors_for_training(&segment_ids, &fields_to_build)
101 .await?;
102
103 for (field, config) in &fields_to_build {
105 self.train_field_index(*field, config, &all_vectors).await?;
106 }
107
108 log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");
109
110 self.rebuild_segments_with_ann().await?;
112
113 Ok(())
114 }
115
116 pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
118 let segment_ids = self.segment_manager.get_segment_ids().await;
119 if segment_ids.is_empty() {
120 return Ok(());
121 }
122
123 let (trained_centroids, trained_codebooks) = {
125 let metadata_arc = self.segment_manager.metadata();
126 let meta = metadata_arc.read().await;
127 meta.load_trained_structures(self.directory.as_ref()).await
128 };
129
130 if trained_centroids.is_empty() {
131 log::info!("No trained structures to rebuild with");
132 return Ok(());
133 }
134
135 let trained = TrainedVectorStructures {
136 centroids: trained_centroids,
137 codebooks: trained_codebooks,
138 };
139
140 let readers = self.load_segment_readers(&segment_ids).await?;
142
143 let total_docs: u32 = readers.iter().map(|r| r.meta().num_docs).sum();
145
146 let merger = SegmentMerger::new(Arc::clone(&self.schema));
148 let new_segment_id = SegmentId::new();
149 merger
150 .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
151 .await?;
152
153 self.segment_manager
155 .replace_segments(vec![(new_segment_id.to_hex(), total_docs)], segment_ids)
156 .await?;
157
158 log::info!("Segments rebuilt with ANN indexes");
159 Ok(())
160 }
161
162 pub async fn total_vector_count(&self) -> usize {
164 let metadata_arc = self.segment_manager.metadata();
165 metadata_arc.read().await.total_vectors
166 }
167
168 pub async fn is_vector_index_built(&self, field: Field) -> bool {
170 let metadata_arc = self.segment_manager.metadata();
171 metadata_arc.read().await.is_field_built(field.0)
172 }
173
174 pub async fn rebuild_vector_index(&self) -> Result<()> {
183 let dense_fields: Vec<Field> = self
184 .schema
185 .fields()
186 .filter_map(|(field, entry)| {
187 if entry.field_type == FieldType::DenseVector && entry.indexed {
188 Some(field)
189 } else {
190 None
191 }
192 })
193 .collect();
194
195 if dense_fields.is_empty() {
196 return Ok(());
197 }
198
199 let files_to_delete = {
201 let metadata_arc = self.segment_manager.metadata();
202 let mut meta = metadata_arc.write().await;
203 let mut files = Vec::new();
204 for field in &dense_fields {
205 if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
206 field_meta.state = super::VectorIndexState::Flat;
207 if let Some(ref f) = field_meta.centroids_file {
208 files.push(f.clone());
209 }
210 if let Some(ref f) = field_meta.codebook_file {
211 files.push(f.clone());
212 }
213 field_meta.centroids_file = None;
214 field_meta.codebook_file = None;
215 }
216 }
217 meta.save(self.directory.as_ref()).await?;
218 files
219 };
220
221 for file in files_to_delete {
223 let _ = self.directory.delete(std::path::Path::new(&file)).await;
224 }
225
226 log::info!("Reset vector index state to Flat, triggering rebuild...");
227
228 self.build_vector_index().await
230 }
231
232 fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
238 self.schema
239 .fields()
240 .filter_map(|(field, entry)| {
241 if entry.field_type == FieldType::DenseVector && entry.indexed {
242 entry
243 .dense_vector_config
244 .as_ref()
245 .filter(|c| !c.is_flat())
246 .map(|c| (field, c.clone()))
247 } else {
248 None
249 }
250 })
251 .collect()
252 }
253
254 async fn get_fields_to_build(
256 &self,
257 dense_fields: &[(Field, DenseVectorConfig)],
258 ) -> Vec<(Field, DenseVectorConfig)> {
259 let metadata_arc = self.segment_manager.metadata();
260 let meta = metadata_arc.read().await;
261 dense_fields
262 .iter()
263 .filter(|(field, _)| !meta.is_field_built(field.0))
264 .cloned()
265 .collect()
266 }
267
268 async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
271 let mut total_vectors = 0usize;
272 let mut doc_offset = 0u32;
273
274 for id_str in segment_ids {
275 let Some(segment_id) = SegmentId::from_hex(id_str) else {
276 continue;
277 };
278
279 let files = crate::segment::SegmentFiles::new(segment_id.0);
281 if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
282 continue;
284 }
285
286 if let Ok(reader) = SegmentReader::open(
288 self.directory.as_ref(),
289 segment_id,
290 Arc::clone(&self.schema),
291 doc_offset,
292 self.config.term_cache_blocks,
293 )
294 .await
295 {
296 for index in reader.vector_indexes().values() {
297 if let crate::segment::VectorIndex::Flat(flat_data) = index {
298 total_vectors += flat_data.vectors.len();
299 }
300 }
301 doc_offset += reader.meta().num_docs;
302 }
303 }
304
305 total_vectors
306 }
307
    /// Gather a bounded training sample of flat vectors for each field in
    /// `fields_to_build`, keyed by field id.
    ///
    /// At most `MAX_TRAINING_VECTORS` vectors are collected per field; once a
    /// field's quota is exhausted, additional segments contribute via strided
    /// subsampling (or are skipped entirely). Returns an error if a segment id
    /// is not valid hex or a segment fails to open.
    async fn collect_vectors_for_training(
        &self,
        segment_ids: &[String],
        fields_to_build: &[(Field, DenseVectorConfig)],
    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
        // Hard cap on training vectors per field to bound memory and k-means cost.
        const MAX_TRAINING_VECTORS: usize = 100_000;

        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
        let mut doc_offset = 0u32;
        let mut total_skipped = 0usize;

        for id_str in segment_ids {
            // Unlike count_flat_vectors, a bad segment id here is fatal:
            // training on a partial view could silently bias the result.
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;

            for (field_id, index) in reader.vector_indexes() {
                // Only flat indexes of fields scheduled for building contribute.
                if fields_to_build.iter().any(|(f, _)| f.0 == *field_id)
                    && let crate::segment::VectorIndex::Flat(flat_data) = index
                {
                    let entry = all_vectors.entry(*field_id).or_default();
                    // Quota left for this field before hitting the cap.
                    let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());

                    if remaining == 0 {
                        // Field quota already full: skip this segment's vectors.
                        total_skipped += flat_data.vectors.len();
                        continue;
                    }

                    if flat_data.vectors.len() <= remaining {
                        // Everything fits: take the segment's vectors verbatim.
                        entry.extend(flat_data.vectors.iter().cloned());
                    } else {
                        // Stride-sample roughly `remaining` vectors evenly.
                        // step = floor(len/remaining) >= 1 guarantees at least
                        // `remaining` candidates, so the entry fills to the cap.
                        let step = (flat_data.vectors.len() / remaining).max(1);
                        for (i, vec) in flat_data.vectors.iter().enumerate() {
                            // Second clause enforces the global per-field cap.
                            if i % step == 0 && entry.len() < MAX_TRAINING_VECTORS {
                                entry.push(vec.clone());
                            }
                        }
                        total_skipped += flat_data.vectors.len() - remaining;
                    }
                }
            }

            doc_offset += reader.meta().num_docs;
        }

        if total_skipped > 0 {
            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
            log::info!(
                "Sampled {} vectors for training (skipped {}, max {} per field)",
                collected,
                total_skipped,
                MAX_TRAINING_VECTORS,
            );
        }

        Ok(all_vectors)
    }
379
380 pub(super) async fn load_segment_readers(
382 &self,
383 segment_ids: &[String],
384 ) -> Result<Vec<SegmentReader>> {
385 let mut readers = Vec::new();
386 let mut doc_offset = 0u32;
387
388 for id_str in segment_ids {
389 let segment_id = SegmentId::from_hex(id_str)
390 .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
391 let reader = SegmentReader::open(
392 self.directory.as_ref(),
393 segment_id,
394 Arc::clone(&self.schema),
395 doc_offset,
396 self.config.term_cache_blocks,
397 )
398 .await?;
399 doc_offset += reader.meta().num_docs;
400 readers.push(reader);
401 }
402
403 Ok(readers)
404 }
405
406 async fn train_field_index(
408 &self,
409 field: Field,
410 config: &DenseVectorConfig,
411 all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
412 ) -> Result<()> {
413 let field_id = field.0;
414 let vectors = match all_vectors.get(&field_id) {
415 Some(v) if !v.is_empty() => v,
416 _ => return Ok(()),
417 };
418
419 let index_dim = config.index_dim();
420 let num_vectors = vectors.len();
421 let num_clusters = config.optimal_num_clusters(num_vectors);
422
423 log::info!(
424 "Training vector index for field {} with {} vectors, {} clusters",
425 field_id,
426 num_vectors,
427 num_clusters
428 );
429
430 let centroids_filename = format!("field_{}_centroids.bin", field_id);
431 let mut codebook_filename: Option<String> = None;
432
433 match config.index_type {
434 VectorIndexType::IvfRaBitQ => {
435 self.train_ivf_rabitq(
436 field_id,
437 index_dim,
438 num_clusters,
439 vectors,
440 ¢roids_filename,
441 )
442 .await?;
443 }
444 VectorIndexType::ScaNN => {
445 codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
446 self.train_scann(
447 field_id,
448 index_dim,
449 num_clusters,
450 vectors,
451 ¢roids_filename,
452 codebook_filename.as_ref().unwrap(),
453 )
454 .await?;
455 }
456 _ => {
457 return Ok(());
459 }
460 }
461
462 self.segment_manager
464 .update_metadata(|meta| {
465 meta.init_field(field_id, config.index_type);
466 meta.total_vectors = num_vectors;
467 meta.mark_field_built(
468 field_id,
469 num_vectors,
470 num_clusters,
471 centroids_filename.clone(),
472 codebook_filename.clone(),
473 );
474 })
475 .await?;
476
477 Ok(())
478 }
479
480 async fn train_ivf_rabitq(
482 &self,
483 field_id: u32,
484 index_dim: usize,
485 num_clusters: usize,
486 vectors: &[Vec<f32>],
487 centroids_filename: &str,
488 ) -> Result<()> {
489 let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
490 let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
491
492 let centroids_path = std::path::Path::new(centroids_filename);
494 let centroids_bytes =
495 serde_json::to_vec(¢roids).map_err(|e| Error::Serialization(e.to_string()))?;
496 self.directory
497 .write(centroids_path, ¢roids_bytes)
498 .await?;
499
500 log::info!(
501 "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
502 field_id,
503 centroids.num_clusters
504 );
505
506 Ok(())
507 }
508
509 async fn train_scann(
511 &self,
512 field_id: u32,
513 index_dim: usize,
514 num_clusters: usize,
515 vectors: &[Vec<f32>],
516 centroids_filename: &str,
517 codebook_filename: &str,
518 ) -> Result<()> {
519 let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
521 let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
522
523 let pq_config = crate::structures::PQConfig::new(index_dim);
525 let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
526
527 let centroids_path = std::path::Path::new(centroids_filename);
529 let centroids_bytes =
530 serde_json::to_vec(¢roids).map_err(|e| Error::Serialization(e.to_string()))?;
531 self.directory
532 .write(centroids_path, ¢roids_bytes)
533 .await?;
534
535 let codebook_path = std::path::Path::new(codebook_filename);
537 let codebook_bytes =
538 serde_json::to_vec(&codebook).map_err(|e| Error::Serialization(e.to_string()))?;
539 self.directory.write(codebook_path, &codebook_bytes).await?;
540
541 log::info!(
542 "Saved ScaNN centroids and codebook for field {} ({} clusters)",
543 field_id,
544 centroids.num_clusters
545 );
546
547 Ok(())
548 }
549}