use std::sync::Arc;

use rustc_hash::FxHashMap;

use crate::directories::DirectoryWriter;
use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
use crate::error::{Error, Result};
use crate::segment::{SegmentId, SegmentReader};

use super::IndexWriter;

impl<D: DirectoryWriter + 'static> IndexWriter<D> {
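    /// Checks whether any dense vector field has crossed its build
    /// threshold (default 1000 vectors) and, if so, triggers an automatic
    /// vector index build. Returns early when no dense vector fields are
    /// configured; if every field is already built, it only republishes the
    /// trained structures when they have not been loaded yet.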
    pub(super) async fn maybe_build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }

        let all_built = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            dense_fields
                .iter()
                .all(|(field, _)| meta.is_field_built(field.0))
        };
        if all_built {
            if self
                .trained_structures
                .read()
                .ok()
                .is_none_or(|g| g.is_none())
            {
                self.publish_trained_structures().await;
            }
            return Ok(());
        }

        let segment_ids = self.segment_manager.get_segment_ids().await;
        let total_vectors = self.count_flat_vectors(&segment_ids).await;

        let should_build = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            meta.total_vectors = total_vectors;
            dense_fields.iter().any(|(field, config)| {
                let threshold = config.build_threshold.unwrap_or(1000);
                meta.should_build_field(field.0, threshold)
            })
        };

        if should_build {
            log::info!(
                "Threshold crossed ({} vectors), auto-triggering vector index build",
                total_vectors
            );
            self.build_vector_index().await?;
        }

        Ok(())
    }

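    /// Trains ANN structures for every dense vector field that has not
    /// been built yet (coarse centroids, plus a PQ codebook for ScaNN),
    /// then publishes the trained structures so that newly written
    /// segments carry their ANN index inline.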
    pub async fn build_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            log::info!("No dense vector fields configured for ANN indexing");
            return Ok(());
        }

        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
        if fields_to_build.is_empty() {
            log::info!("All vector fields already built, skipping training");
            return Ok(());
        }

        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        let all_vectors = self
            .collect_vectors_for_training(&segment_ids, &fields_to_build)
            .await?;

        for (field, config) in &fields_to_build {
            self.train_field_index(*field, config, &all_vectors).await?;
        }

        self.publish_trained_structures().await;

        log::info!("Vector index training complete, new segments will have ANN inline");

        Ok(())
    }

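    /// Loads the trained structures recorded in the segment metadata and
    /// publishes them into the shared `trained_structures` slot consumed
    /// by the writer's workers.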
    pub(super) async fn publish_trained_structures(&self) {
        let trained = {
            let metadata_arc = self.segment_manager.metadata();
            let meta = metadata_arc.read().await;
            meta.load_trained_structures(self.directory.as_ref()).await
        };
        if let Some(trained) = trained
            && let Ok(mut guard) = self.trained_structures.write()
        {
            log::info!(
                "[writer] published trained structures to workers ({} fields)",
                trained.centroids.len()
            );
            *guard = Some(trained);
        }
    }

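    /// Returns the total number of vectors recorded in the segment metadata.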
    pub async fn total_vector_count(&self) -> usize {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.total_vectors
    }

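    /// Returns `true` if the ANN index for `field` has already been built.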
    pub async fn is_vector_index_built(&self, field: Field) -> bool {
        let metadata_arc = self.segment_manager.metadata();
        metadata_arc.read().await.is_field_built(field.0)
    }

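    /// Resets every dense vector field back to the `Flat` state, deletes
    /// the persisted centroid and codebook files, clears the published
    /// trained structures, and retrains from scratch.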
    pub async fn rebuild_vector_index(&self) -> Result<()> {
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }
        let dense_fields: Vec<Field> = dense_fields.into_iter().map(|(f, _)| f).collect();

        let files_to_delete = {
            let metadata_arc = self.segment_manager.metadata();
            let mut meta = metadata_arc.write().await;
            let mut files = Vec::new();
            for field in &dense_fields {
                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
                    field_meta.state = super::VectorIndexState::Flat;
                    if let Some(ref f) = field_meta.centroids_file {
                        files.push(f.clone());
                    }
                    if let Some(ref f) = field_meta.codebook_file {
                        files.push(f.clone());
                    }
                    field_meta.centroids_file = None;
                    field_meta.codebook_file = None;
                }
            }
            meta.save(self.directory.as_ref()).await?;
            files
        };

        for file in files_to_delete {
            let _ = self.directory.delete(std::path::Path::new(&file)).await;
        }

        if let Ok(mut guard) = self.trained_structures.write() {
            *guard = None;
        }

        log::info!("Reset vector index state to Flat, triggering rebuild...");

        self.build_vector_index().await
    }

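    /// Collects all indexed dense vector fields whose config requests an
    /// ANN index; flat-only fields are excluded.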
    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
        self.schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    entry
                        .dense_vector_config
                        .as_ref()
                        .filter(|c| !c.is_flat())
                        .map(|c| (field, c.clone()))
                } else {
                    None
                }
            })
            .collect()
    }

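    /// Filters `dense_fields` down to those whose index has not been built yet.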
    async fn get_fields_to_build(
        &self,
        dense_fields: &[(Field, DenseVectorConfig)],
    ) -> Vec<(Field, DenseVectorConfig)> {
        let metadata_arc = self.segment_manager.metadata();
        let meta = metadata_arc.read().await;
        dense_fields
            .iter()
            .filter(|(field, _)| !meta.is_field_built(field.0))
            .cloned()
            .collect()
    }

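    /// Counts the flat-stored vectors across the given segments. Segments
    /// without a vectors file, or that fail to open, are skipped.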
    async fn count_flat_vectors(&self, segment_ids: &[String]) -> usize {
        let mut total_vectors = 0usize;
        let mut doc_offset = 0u32;

        for id_str in segment_ids {
            let Some(segment_id) = SegmentId::from_hex(id_str) else {
                continue;
            };

            let files = crate::segment::SegmentFiles::new(segment_id.0);
            if !self.directory.exists(&files.vectors).await.unwrap_or(false) {
                continue;
            }

            if let Ok(reader) = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await
            {
                for flat_data in reader.flat_vectors().values() {
                    total_vectors += flat_data.num_vectors;
                }
                doc_offset += reader.meta().num_docs;
            }
        }

        total_vectors
    }

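    /// Gathers training vectors per field from all segments, capped at
    /// `MAX_TRAINING_VECTORS` per field. When a segment holds more vectors
    /// than the remaining budget, vectors are sampled with a uniform
    /// stride; contiguous index runs are read in batches, strided runs one
    /// vector at a time.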
    async fn collect_vectors_for_training(
        &self,
        segment_ids: &[String],
        fields_to_build: &[(Field, DenseVectorConfig)],
    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
        const MAX_TRAINING_VECTORS: usize = 100_000;

        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
        let mut doc_offset = 0u32;
        let mut total_skipped = 0usize;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;

            for (field_id, lazy_flat) in reader.flat_vectors() {
                if !fields_to_build.iter().any(|(f, _)| f.0 == *field_id) {
                    continue;
                }
                let entry = all_vectors.entry(*field_id).or_default();
                let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());

                if remaining == 0 {
                    total_skipped += lazy_flat.num_vectors;
                    continue;
                }

                let n = lazy_flat.num_vectors;
                let dim = lazy_flat.dim;
                let quant = lazy_flat.quantization;

                // Take everything if it fits in the remaining budget;
                // otherwise sample with a uniform stride.
                let indices: Vec<usize> = if n <= remaining {
                    (0..n).collect()
                } else {
                    let step = (n / remaining).max(1);
                    (0..n).step_by(step).take(remaining).collect()
                };

                if indices.len() < n {
                    total_skipped += n - indices.len();
                }

                const BATCH: usize = 1024;
                let mut f32_buf = vec![0f32; BATCH * dim];
                for chunk in indices.chunks(BATCH) {
                    let start = chunk[0];
                    let end = *chunk.last().unwrap();
                    if end - start + 1 == chunk.len() {
                        // Contiguous run: one batched read, then dequantize
                        // and split into per-vector rows.
                        if let Ok(batch_bytes) =
                            lazy_flat.read_vectors_batch(start, chunk.len()).await
                        {
                            let floats = chunk.len() * dim;
                            f32_buf.resize(floats, 0.0);
                            crate::segment::dequantize_raw(
                                batch_bytes.as_slice(),
                                quant,
                                floats,
                                &mut f32_buf,
                            );
                            for i in 0..chunk.len() {
                                entry.push(f32_buf[i * dim..(i + 1) * dim].to_vec());
                            }
                        }
                    } else {
                        // Strided sample: fall back to one read per vector.
                        f32_buf.resize(dim, 0.0);
                        for &idx in chunk {
                            if let Ok(()) = lazy_flat.read_vector_into(idx, &mut f32_buf).await {
                                entry.push(f32_buf[..dim].to_vec());
                            }
                        }
                    }
                }
            }

            doc_offset += reader.meta().num_docs;
        }

        if total_skipped > 0 {
            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
            log::info!(
                "Sampled {} vectors for training (skipped {}, max {} per field)",
                collected,
                total_skipped,
                MAX_TRAINING_VECTORS,
            );
        }

        Ok(all_vectors)
    }

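    /// Trains the configured index type for a single field from the
    /// sampled vectors, persists the resulting artifacts, and marks the
    /// field as built in the segment metadata. Fields with no sampled
    /// vectors are skipped.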
    async fn train_field_index(
        &self,
        field: Field,
        config: &DenseVectorConfig,
        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
    ) -> Result<()> {
        let field_id = field.0;
        let vectors = match all_vectors.get(&field_id) {
            Some(v) if !v.is_empty() => v,
            _ => return Ok(()),
        };

        let dim = config.dim;
        let num_vectors = vectors.len();
        let num_clusters = config.optimal_num_clusters(num_vectors);

        log::info!(
            "Training vector index for field {} with {} vectors, {} clusters (dim={})",
            field_id,
            num_vectors,
            num_clusters,
            dim,
        );

        let centroids_filename = format!("field_{}_centroids.bin", field_id);
        let mut codebook_filename: Option<String> = None;

        match config.index_type {
            VectorIndexType::IvfRaBitQ => {
                self.train_ivf_rabitq(field_id, dim, num_clusters, vectors, &centroids_filename)
                    .await?;
            }
            VectorIndexType::ScaNN => {
                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
                self.train_scann(
                    field_id,
                    dim,
                    num_clusters,
                    vectors,
                    &centroids_filename,
                    codebook_filename.as_ref().unwrap(),
                )
                .await?;
            }
            _ => {
                return Ok(());
            }
        }

        self.segment_manager
            .update_metadata(|meta| {
                meta.init_field(field_id, config.index_type);
                meta.total_vectors = num_vectors;
                meta.mark_field_built(
                    field_id,
                    num_vectors,
                    num_clusters,
                    centroids_filename.clone(),
                    codebook_filename.clone(),
                );
            })
            .await?;

        Ok(())
    }

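    /// Serializes a trained artifact as JSON and writes it to the
    /// directory under `filename`.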
    async fn save_trained_artifact(
        &self,
        artifact: &impl serde::Serialize,
        filename: &str,
    ) -> Result<()> {
        let bytes =
            serde_json::to_vec(artifact).map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory
            .write(std::path::Path::new(filename), &bytes)
            .await?;
        Ok(())
    }

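    /// Trains coarse IVF centroids for an IVF-RaBitQ index and persists them.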
    async fn train_ivf_rabitq(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
        self.save_trained_artifact(&centroids, centroids_filename)
            .await?;

        log::info!(
            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );
        Ok(())
    }

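    /// Trains coarse IVF centroids plus a product-quantization codebook
    /// for a ScaNN index and persists both artifacts.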
    async fn train_scann(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
        codebook_filename: &str,
    ) -> Result<()> {
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
        self.save_trained_artifact(&centroids, centroids_filename)
            .await?;

        let pq_config = crate::structures::PQConfig::new(dim);
        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
        self.save_trained_artifact(&codebook, codebook_filename)
            .await?;

        log::info!(
            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );
        Ok(())
    }
}