1use std::sync::Arc;
9
10use rustc_hash::FxHashMap;
11
12use crate::directories::DirectoryWriter;
13use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
14use crate::error::{Error, Result};
15use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};
16
17use super::IndexWriter;
18
19impl<D: DirectoryWriter + 'static> IndexWriter<D> {
20 pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
22 let dense_fields = self.get_dense_vector_fields();
23 if dense_fields.is_empty() {
24 return Ok(());
25 }
26
27 let all_built = {
30 let metadata_arc = self.segment_manager.metadata();
31 let meta = metadata_arc.read().await;
32 dense_fields
33 .iter()
34 .all(|(field, _)| meta.is_field_built(field.0))
35 };
36 if all_built {
37 return Ok(());
38 }
39
40 let segment_ids = self.segment_manager.get_segment_ids().await;
42 let total_vectors = self.count_flat_vectors(&segment_ids).await;
43
44 let should_build = {
46 let metadata_arc = self.segment_manager.metadata();
47 let mut meta = metadata_arc.write().await;
48 meta.total_vectors = total_vectors;
49 dense_fields.iter().any(|(field, config)| {
50 let threshold = config.build_threshold.unwrap_or(1000);
51 meta.should_build_field(field.0, threshold)
52 })
53 };
54
55 if should_build {
56 log::info!(
57 "Threshold crossed ({} vectors), auto-triggering vector index build",
58 total_vectors
59 );
60 self.build_vector_index().await?;
61 }
62
63 Ok(())
64 }
65
66 pub async fn build_vector_index(&self) -> Result<()> {
80 let dense_fields = self.get_dense_vector_fields();
81 if dense_fields.is_empty() {
82 log::info!("No dense vector fields configured for ANN indexing");
83 return Ok(());
84 }
85
86 let fields_to_build = self.get_fields_to_build(&dense_fields).await;
88 if fields_to_build.is_empty() {
89 log::info!("All vector fields already built, skipping training");
90 return Ok(());
91 }
92
93 let segment_ids = self.segment_manager.get_segment_ids().await;
94 if segment_ids.is_empty() {
95 return Ok(());
96 }
97
98 let all_vectors = self
100 .collect_vectors_for_training(&segment_ids, &fields_to_build)
101 .await?;
102
103 for (field, config) in &fields_to_build {
105 self.train_field_index(*field, config, &all_vectors).await?;
106 }
107
108 log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");
109
110 self.rebuild_segments_with_ann().await?;
112
113 Ok(())
114 }
115
116 pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
118 let segment_ids = self.segment_manager.get_segment_ids().await;
119 if segment_ids.is_empty() {
120 return Ok(());
121 }
122
123 let (trained_centroids, trained_codebooks) = {
125 let metadata_arc = self.segment_manager.metadata();
126 let meta = metadata_arc.read().await;
127 meta.load_trained_structures(self.directory.as_ref()).await
128 };
129
130 if trained_centroids.is_empty() {
131 log::info!("No trained structures to rebuild with");
132 return Ok(());
133 }
134
135 let trained = TrainedVectorStructures {
136 centroids: trained_centroids,
137 codebooks: trained_codebooks,
138 };
139
140 let readers = self.load_segment_readers(&segment_ids).await?;
142
143 let total_docs: u32 = readers.iter().map(|r| r.meta().num_docs).sum();
145
146 let merger = SegmentMerger::new(Arc::clone(&self.schema));
148 let new_segment_id = SegmentId::new();
149 merger
150 .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
151 .await?;
152
153 self.segment_manager
155 .replace_segments(vec![(new_segment_id.to_hex(), total_docs)], segment_ids)
156 .await?;
157
158 log::info!("Segments rebuilt with ANN indexes");
159 Ok(())
160 }
161
162 pub async fn total_vector_count(&self) -> usize {
164 let metadata_arc = self.segment_manager.metadata();
165 metadata_arc.read().await.total_vectors
166 }
167
168 pub async fn is_vector_index_built(&self, field: Field) -> bool {
170 let metadata_arc = self.segment_manager.metadata();
171 metadata_arc.read().await.is_field_built(field.0)
172 }
173
174 pub async fn rebuild_vector_index(&self) -> Result<()> {
183 let dense_fields: Vec<Field> = self
184 .schema
185 .fields()
186 .filter_map(|(field, entry)| {
187 if entry.field_type == FieldType::DenseVector && entry.indexed {
188 Some(field)
189 } else {
190 None
191 }
192 })
193 .collect();
194
195 if dense_fields.is_empty() {
196 return Ok(());
197 }
198
199 let files_to_delete = {
201 let metadata_arc = self.segment_manager.metadata();
202 let mut meta = metadata_arc.write().await;
203 let mut files = Vec::new();
204 for field in &dense_fields {
205 if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
206 field_meta.state = super::VectorIndexState::Flat;
207 if let Some(ref f) = field_meta.centroids_file {
208 files.push(f.clone());
209 }
210 if let Some(ref f) = field_meta.codebook_file {
211 files.push(f.clone());
212 }
213 field_meta.centroids_file = None;
214 field_meta.codebook_file = None;
215 }
216 }
217 meta.save(self.directory.as_ref()).await?;
218 files
219 };
220
221 for file in files_to_delete {
223 let _ = self.directory.delete(std::path::Path::new(&file)).await;
224 }
225
226 log::info!("Reset vector index state to Flat, triggering rebuild...");
227
228 self.build_vector_index().await
230 }
231
232 fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
238 self.schema
239 .fields()
240 .filter_map(|(field, entry)| {
241 if entry.field_type == FieldType::DenseVector && entry.indexed {
242 entry
243 .dense_vector_config
244 .as_ref()
245 .filter(|c| !c.is_flat())
246 .map(|c| (field, c.clone()))
247 } else {
248 None
249 }
250 })
251 .collect()
252 }
253
254 async fn get_fields_to_build(
256 &self,
257 dense_fields: &[(Field, DenseVectorConfig)],
258 ) -> Vec<(Field, DenseVectorConfig)> {
259 let metadata_arc = self.segment_manager.metadata();
260 let meta = metadata_arc.read().await;
261 dense_fields
262 .iter()
263 .filter(|(field, _)| !meta.is_field_built(field.0))
264 .cloned()
265 .collect()
266 }
267
268 async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
271 let mut total_vectors = 0usize;
272 let mut doc_offset = 0u32;
273
274 for id_str in segment_ids {
275 let Some(segment_id) = SegmentId::from_hex(id_str) else {
276 continue;
277 };
278
279 let files = crate::segment::SegmentFiles::new(segment_id.0);
281 if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
282 continue;
284 }
285
286 if let Ok(reader) = SegmentReader::open(
288 self.directory.as_ref(),
289 segment_id,
290 Arc::clone(&self.schema),
291 doc_offset,
292 self.config.term_cache_blocks,
293 )
294 .await
295 {
296 for index in reader.vector_indexes().values() {
297 if let crate::segment::VectorIndex::Flat(flat_data) = index {
298 total_vectors += flat_data.vectors.len();
299 }
300 }
301 doc_offset += reader.meta().num_docs;
302 }
303 }
304
305 total_vectors
306 }
307
308 async fn collect_vectors_for_training(
310 &self,
311 segment_ids: &[String],
312 fields_to_build: &[(Field, DenseVectorConfig)],
313 ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
314 let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
315 let mut doc_offset = 0u32;
316
317 for id_str in segment_ids {
318 let segment_id = SegmentId::from_hex(id_str)
319 .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
320 let reader = SegmentReader::open(
321 self.directory.as_ref(),
322 segment_id,
323 Arc::clone(&self.schema),
324 doc_offset,
325 self.config.term_cache_blocks,
326 )
327 .await?;
328
329 for (field_id, index) in reader.vector_indexes() {
331 if fields_to_build.iter().any(|(f, _)| f.0 == *field_id)
332 && let crate::segment::VectorIndex::Flat(flat_data) = index
333 {
334 all_vectors
335 .entry(*field_id)
336 .or_default()
337 .extend(flat_data.vectors.iter().cloned());
338 }
339 }
340
341 doc_offset += reader.meta().num_docs;
342 }
343
344 Ok(all_vectors)
345 }
346
347 pub(super) async fn load_segment_readers(
349 &self,
350 segment_ids: &[String],
351 ) -> Result<Vec<SegmentReader>> {
352 let mut readers = Vec::new();
353 let mut doc_offset = 0u32;
354
355 for id_str in segment_ids {
356 let segment_id = SegmentId::from_hex(id_str)
357 .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
358 let reader = SegmentReader::open(
359 self.directory.as_ref(),
360 segment_id,
361 Arc::clone(&self.schema),
362 doc_offset,
363 self.config.term_cache_blocks,
364 )
365 .await?;
366 doc_offset += reader.meta().num_docs;
367 readers.push(reader);
368 }
369
370 Ok(readers)
371 }
372
373 async fn train_field_index(
375 &self,
376 field: Field,
377 config: &DenseVectorConfig,
378 all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
379 ) -> Result<()> {
380 let field_id = field.0;
381 let vectors = match all_vectors.get(&field_id) {
382 Some(v) if !v.is_empty() => v,
383 _ => return Ok(()),
384 };
385
386 let index_dim = config.index_dim();
387 let num_vectors = vectors.len();
388 let num_clusters = config.optimal_num_clusters(num_vectors);
389
390 log::info!(
391 "Training vector index for field {} with {} vectors, {} clusters",
392 field_id,
393 num_vectors,
394 num_clusters
395 );
396
397 let centroids_filename = format!("field_{}_centroids.bin", field_id);
398 let mut codebook_filename: Option<String> = None;
399
400 match config.index_type {
401 VectorIndexType::IvfRaBitQ => {
402 self.train_ivf_rabitq(
403 field_id,
404 index_dim,
405 num_clusters,
406 vectors,
407 ¢roids_filename,
408 )
409 .await?;
410 }
411 VectorIndexType::ScaNN => {
412 codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
413 self.train_scann(
414 field_id,
415 index_dim,
416 num_clusters,
417 vectors,
418 ¢roids_filename,
419 codebook_filename.as_ref().unwrap(),
420 )
421 .await?;
422 }
423 _ => {
424 return Ok(());
426 }
427 }
428
429 self.segment_manager
431 .update_metadata(|meta| {
432 meta.init_field(field_id, config.index_type);
433 meta.total_vectors = num_vectors;
434 meta.mark_field_built(
435 field_id,
436 num_vectors,
437 num_clusters,
438 centroids_filename.clone(),
439 codebook_filename.clone(),
440 );
441 })
442 .await?;
443
444 Ok(())
445 }
446
447 async fn train_ivf_rabitq(
449 &self,
450 field_id: u32,
451 index_dim: usize,
452 num_clusters: usize,
453 vectors: &[Vec<f32>],
454 centroids_filename: &str,
455 ) -> Result<()> {
456 let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
457 let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
458
459 let centroids_path = std::path::Path::new(centroids_filename);
461 let centroids_bytes =
462 serde_json::to_vec(¢roids).map_err(|e| Error::Serialization(e.to_string()))?;
463 self.directory
464 .write(centroids_path, ¢roids_bytes)
465 .await?;
466
467 log::info!(
468 "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
469 field_id,
470 centroids.num_clusters
471 );
472
473 Ok(())
474 }
475
476 async fn train_scann(
478 &self,
479 field_id: u32,
480 index_dim: usize,
481 num_clusters: usize,
482 vectors: &[Vec<f32>],
483 centroids_filename: &str,
484 codebook_filename: &str,
485 ) -> Result<()> {
486 let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
488 let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
489
490 let pq_config = crate::structures::PQConfig::new(index_dim);
492 let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
493
494 let centroids_path = std::path::Path::new(centroids_filename);
496 let centroids_bytes =
497 serde_json::to_vec(¢roids).map_err(|e| Error::Serialization(e.to_string()))?;
498 self.directory
499 .write(centroids_path, ¢roids_bytes)
500 .await?;
501
502 let codebook_path = std::path::Path::new(codebook_filename);
504 let codebook_bytes =
505 serde_json::to_vec(&codebook).map_err(|e| Error::Serialization(e.to_string()))?;
506 self.directory.write(codebook_path, &codebook_bytes).await?;
507
508 log::info!(
509 "Saved ScaNN centroids and codebook for field {} ({} clusters)",
510 field_id,
511 centroids.num_clusters
512 );
513
514 Ok(())
515 }
516}