1use std::sync::Arc;
9
10use rustc_hash::FxHashMap;
11
12use crate::directories::DirectoryWriter;
13use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
14use crate::error::{Error, Result};
15use crate::segment::{SegmentId, SegmentMerger, SegmentReader, TrainedVectorStructures};
16
17use super::IndexWriter;
18
19impl<D: DirectoryWriter + 'static> IndexWriter<D> {
20 pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
22 let dense_fields = self.get_dense_vector_fields();
23 if dense_fields.is_empty() {
24 return Ok(());
25 }
26
27 let segment_ids = self.segment_manager.get_segment_ids().await;
29 let total_vectors = self.count_flat_vectors(&segment_ids).await;
30
31 let should_build = {
33 let metadata_arc = self.segment_manager.metadata();
34 let mut meta = metadata_arc.lock().await;
35 meta.total_vectors = total_vectors;
36 dense_fields.iter().any(|(field, config)| {
37 let threshold = config.build_threshold.unwrap_or(1000);
38 meta.should_build_field(field.0, threshold)
39 })
40 };
41
42 if should_build {
43 log::info!(
44 "Threshold crossed ({} vectors), auto-triggering vector index build",
45 total_vectors
46 );
47 self.build_vector_index().await?;
48 }
49
50 Ok(())
51 }
52
53 pub async fn build_vector_index(&self) -> Result<()> {
67 let dense_fields = self.get_dense_vector_fields();
68 if dense_fields.is_empty() {
69 log::info!("No dense vector fields configured for ANN indexing");
70 return Ok(());
71 }
72
73 let fields_to_build = self.get_fields_to_build(&dense_fields).await;
75 if fields_to_build.is_empty() {
76 log::info!("All vector fields already built, skipping training");
77 return Ok(());
78 }
79
80 let segment_ids = self.segment_manager.get_segment_ids().await;
81 if segment_ids.is_empty() {
82 return Ok(());
83 }
84
85 let all_vectors = self
87 .collect_vectors_for_training(&segment_ids, &fields_to_build)
88 .await?;
89
90 for (field, config) in &fields_to_build {
92 self.train_field_index(*field, config, &all_vectors).await?;
93 }
94
95 log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");
96
97 self.rebuild_segments_with_ann().await?;
99
100 Ok(())
101 }
102
103 pub(super) async fn rebuild_segments_with_ann(&self) -> Result<()> {
105 let segment_ids = self.segment_manager.get_segment_ids().await;
106 if segment_ids.is_empty() {
107 return Ok(());
108 }
109
110 let (trained_centroids, trained_codebooks) = {
112 let metadata_arc = self.segment_manager.metadata();
113 let meta = metadata_arc.lock().await;
114 meta.load_trained_structures(self.directory.as_ref()).await
115 };
116
117 if trained_centroids.is_empty() {
118 log::info!("No trained structures to rebuild with");
119 return Ok(());
120 }
121
122 let trained = TrainedVectorStructures {
123 centroids: trained_centroids,
124 codebooks: trained_codebooks,
125 };
126
127 let readers = self.load_segment_readers(&segment_ids).await?;
129
130 let merger = SegmentMerger::new(Arc::clone(&self.schema));
132 let new_segment_id = SegmentId::new();
133 merger
134 .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
135 .await?;
136
137 self.segment_manager
139 .replace_segments(vec![new_segment_id.to_hex()], segment_ids)
140 .await?;
141
142 log::info!("Segments rebuilt with ANN indexes");
143 Ok(())
144 }
145
146 pub async fn total_vector_count(&self) -> usize {
148 let metadata_arc = self.segment_manager.metadata();
149 metadata_arc.lock().await.total_vectors
150 }
151
152 pub async fn is_vector_index_built(&self, field: Field) -> bool {
154 let metadata_arc = self.segment_manager.metadata();
155 metadata_arc.lock().await.is_field_built(field.0)
156 }
157
158 pub async fn rebuild_vector_index(&self) -> Result<()> {
167 let dense_fields: Vec<Field> = self
168 .schema
169 .fields()
170 .filter_map(|(field, entry)| {
171 if entry.field_type == FieldType::DenseVector && entry.indexed {
172 Some(field)
173 } else {
174 None
175 }
176 })
177 .collect();
178
179 if dense_fields.is_empty() {
180 return Ok(());
181 }
182
183 let files_to_delete: Vec<String> = {
185 let metadata_arc = self.segment_manager.metadata();
186 let mut meta = metadata_arc.lock().await;
187 let mut files = Vec::new();
188 for field in &dense_fields {
189 if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
190 field_meta.state = super::VectorIndexState::Flat;
191 if let Some(ref f) = field_meta.centroids_file {
192 files.push(f.clone());
193 }
194 if let Some(ref f) = field_meta.codebook_file {
195 files.push(f.clone());
196 }
197 field_meta.centroids_file = None;
198 field_meta.codebook_file = None;
199 }
200 }
201 meta.save(self.directory.as_ref()).await?;
202 files
203 };
204
205 for file in files_to_delete {
207 let _ = self.directory.delete(std::path::Path::new(&file)).await;
208 }
209
210 log::info!("Reset vector index state to Flat, triggering rebuild...");
211
212 self.build_vector_index().await
214 }
215
216 fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
222 self.schema
223 .fields()
224 .filter_map(|(field, entry)| {
225 if entry.field_type == FieldType::DenseVector && entry.indexed {
226 entry
227 .dense_vector_config
228 .as_ref()
229 .filter(|c| !c.is_flat())
230 .map(|c| (field, c.clone()))
231 } else {
232 None
233 }
234 })
235 .collect()
236 }
237
238 async fn get_fields_to_build(
240 &self,
241 dense_fields: &[(Field, DenseVectorConfig)],
242 ) -> Vec<(Field, DenseVectorConfig)> {
243 let metadata_arc = self.segment_manager.metadata();
244 let meta = metadata_arc.lock().await;
245 dense_fields
246 .iter()
247 .filter(|(field, _)| !meta.is_field_built(field.0))
248 .cloned()
249 .collect()
250 }
251
252 async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
254 let mut total_vectors = 0usize;
255 let mut doc_offset = 0u32;
256
257 for id_str in segment_ids {
258 if let Some(segment_id) = SegmentId::from_hex(id_str)
259 && let Ok(reader) = SegmentReader::open(
260 self.directory.as_ref(),
261 segment_id,
262 Arc::clone(&self.schema),
263 doc_offset,
264 self.config.term_cache_blocks,
265 )
266 .await
267 {
268 for index in reader.vector_indexes().values() {
269 if let crate::segment::VectorIndex::Flat(flat_data) = index {
270 total_vectors += flat_data.vectors.len();
271 }
272 }
273 doc_offset += reader.meta().num_docs;
274 }
275 }
276
277 total_vectors
278 }
279
280 async fn collect_vectors_for_training(
282 &self,
283 segment_ids: &[String],
284 fields_to_build: &[(Field, DenseVectorConfig)],
285 ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
286 let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
287 let mut doc_offset = 0u32;
288
289 for id_str in segment_ids {
290 let segment_id = SegmentId::from_hex(id_str)
291 .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
292 let reader = SegmentReader::open(
293 self.directory.as_ref(),
294 segment_id,
295 Arc::clone(&self.schema),
296 doc_offset,
297 self.config.term_cache_blocks,
298 )
299 .await?;
300
301 for (field_id, index) in reader.vector_indexes() {
303 if fields_to_build.iter().any(|(f, _)| f.0 == *field_id)
304 && let crate::segment::VectorIndex::Flat(flat_data) = index
305 {
306 all_vectors
307 .entry(*field_id)
308 .or_default()
309 .extend(flat_data.vectors.iter().cloned());
310 }
311 }
312
313 doc_offset += reader.meta().num_docs;
314 }
315
316 Ok(all_vectors)
317 }
318
319 pub(super) async fn load_segment_readers(
321 &self,
322 segment_ids: &[String],
323 ) -> Result<Vec<SegmentReader>> {
324 let mut readers = Vec::new();
325 let mut doc_offset = 0u32;
326
327 for id_str in segment_ids {
328 let segment_id = SegmentId::from_hex(id_str)
329 .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
330 let reader = SegmentReader::open(
331 self.directory.as_ref(),
332 segment_id,
333 Arc::clone(&self.schema),
334 doc_offset,
335 self.config.term_cache_blocks,
336 )
337 .await?;
338 doc_offset += reader.meta().num_docs;
339 readers.push(reader);
340 }
341
342 Ok(readers)
343 }
344
345 async fn train_field_index(
347 &self,
348 field: Field,
349 config: &DenseVectorConfig,
350 all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
351 ) -> Result<()> {
352 let field_id = field.0;
353 let vectors = match all_vectors.get(&field_id) {
354 Some(v) if !v.is_empty() => v,
355 _ => return Ok(()),
356 };
357
358 let index_dim = config.index_dim();
359 let num_vectors = vectors.len();
360 let num_clusters = config.optimal_num_clusters(num_vectors);
361
362 log::info!(
363 "Training vector index for field {} with {} vectors, {} clusters",
364 field_id,
365 num_vectors,
366 num_clusters
367 );
368
369 let centroids_filename = format!("field_{}_centroids.bin", field_id);
370 let mut codebook_filename: Option<String> = None;
371
372 match config.index_type {
373 VectorIndexType::IvfRaBitQ => {
374 self.train_ivf_rabitq(
375 field_id,
376 index_dim,
377 num_clusters,
378 vectors,
379 ¢roids_filename,
380 )
381 .await?;
382 }
383 VectorIndexType::ScaNN => {
384 codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
385 self.train_scann(
386 field_id,
387 index_dim,
388 num_clusters,
389 vectors,
390 ¢roids_filename,
391 codebook_filename.as_ref().unwrap(),
392 )
393 .await?;
394 }
395 _ => {
396 return Ok(());
398 }
399 }
400
401 self.segment_manager
403 .update_metadata(|meta| {
404 meta.init_field(field_id, config.index_type);
405 meta.total_vectors = num_vectors;
406 meta.mark_field_built(
407 field_id,
408 num_vectors,
409 num_clusters,
410 centroids_filename.clone(),
411 codebook_filename.clone(),
412 );
413 })
414 .await?;
415
416 Ok(())
417 }
418
419 async fn train_ivf_rabitq(
421 &self,
422 field_id: u32,
423 index_dim: usize,
424 num_clusters: usize,
425 vectors: &[Vec<f32>],
426 centroids_filename: &str,
427 ) -> Result<()> {
428 let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
429 let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
430
431 let centroids_path = std::path::Path::new(centroids_filename);
433 let centroids_bytes =
434 serde_json::to_vec(¢roids).map_err(|e| Error::Serialization(e.to_string()))?;
435 self.directory
436 .write(centroids_path, ¢roids_bytes)
437 .await?;
438
439 log::info!(
440 "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
441 field_id,
442 centroids.num_clusters
443 );
444
445 Ok(())
446 }
447
448 async fn train_scann(
450 &self,
451 field_id: u32,
452 index_dim: usize,
453 num_clusters: usize,
454 vectors: &[Vec<f32>],
455 centroids_filename: &str,
456 codebook_filename: &str,
457 ) -> Result<()> {
458 let coarse_config = crate::structures::CoarseConfig::new(index_dim, num_clusters);
460 let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
461
462 let pq_config = crate::structures::PQConfig::new(index_dim);
464 let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
465
466 let centroids_path = std::path::Path::new(centroids_filename);
468 let centroids_bytes =
469 serde_json::to_vec(¢roids).map_err(|e| Error::Serialization(e.to_string()))?;
470 self.directory
471 .write(centroids_path, ¢roids_bytes)
472 .await?;
473
474 let codebook_path = std::path::Path::new(codebook_filename);
476 let codebook_bytes =
477 serde_json::to_vec(&codebook).map_err(|e| Error::Serialization(e.to_string()))?;
478 self.directory.write(codebook_path, &codebook_bytes).await?;
479
480 log::info!(
481 "Saved ScaNN centroids and codebook for field {} ({} clusters)",
482 field_id,
483 centroids.num_clusters
484 );
485
486 Ok(())
487 }
488}