use std::sync::Arc;

use rustc_hash::FxHashMap;

use crate::directories::DirectoryWriter;
use crate::dsl::{DenseVectorConfig, Field, FieldType, VectorIndexType};
use crate::error::{Error, Result};
use crate::segment::{SegmentId, SegmentReader};

use super::IndexWriter;

impl<D: DirectoryWriter + 'static> IndexWriter<D> {
    pub async fn build_vector_index(&self) -> Result<()> {
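        // Train ANN structures (coarse centroids and, for ScaNN, a PQ codebook)
        // for any dense vector fields that are not yet built. Training samples
        // flat vectors from existing segments; the ANN index itself is built
        // later, during merges.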
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            log::info!("No dense vector fields configured for ANN indexing");
            return Ok(());
        }

        let fields_to_build = self.get_fields_to_build(&dense_fields).await;
        if fields_to_build.is_empty() {
            log::info!("All vector fields already built, skipping training");
            return Ok(());
        }

        let snapshot = self.segment_manager.acquire_snapshot().await;
        let segment_ids = snapshot.segment_ids();
        if segment_ids.is_empty() {
            return Ok(());
        }

        let all_vectors = self
            .collect_vectors_for_training(segment_ids, &fields_to_build)
            .await?;

        for (field, config) in &fields_to_build {
            self.train_field_index(*field, config, &all_vectors).await?;
        }

        self.segment_manager.load_and_publish_trained().await;

        log::info!("Vector index training complete, ANN will be built during merges");

        Ok(())
    }

    pub async fn rebuild_vector_index(&self) -> Result<()> {
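        // Force a retrain: mark every dense vector field as Flat again, delete
        // any previously trained centroid/codebook artifacts, then run the
        // normal build path from scratch.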
        let dense_fields = self.get_dense_vector_fields();
        if dense_fields.is_empty() {
            return Ok(());
        }
        let dense_fields: Vec<Field> = dense_fields.into_iter().map(|(f, _)| f).collect();

        let dense_field_ids: Vec<u32> = dense_fields.iter().map(|f| f.0).collect();
        let mut files_to_delete = Vec::new();
        self.segment_manager
            .update_metadata(|meta| {
                for field_id in &dense_field_ids {
                    if let Some(field_meta) = meta.vector_fields.get_mut(field_id) {
                        field_meta.state = super::VectorIndexState::Flat;
                        if let Some(ref f) = field_meta.centroids_file {
                            files_to_delete.push(f.clone());
                        }
                        if let Some(ref f) = field_meta.codebook_file {
                            files_to_delete.push(f.clone());
                        }
                        field_meta.centroids_file = None;
                        field_meta.codebook_file = None;
                    }
                }
            })
            .await?;

        for file in files_to_delete {
            let _ = self.directory.delete(std::path::Path::new(&file)).await;
        }

        self.segment_manager.clear_trained();

        log::info!("Reset vector index state to Flat, triggering rebuild...");

        self.build_vector_index().await
    }

    fn get_dense_vector_fields(&self) -> Vec<(Field, DenseVectorConfig)> {
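        // Schema fields that are indexed dense vectors and whose config asks
        // for an ANN structure (anything other than a flat scan).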
        self.schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    entry
                        .dense_vector_config
                        .as_ref()
                        .filter(|c| !c.is_flat())
                        .map(|c| (field, c.clone()))
                } else {
                    None
                }
            })
            .collect()
    }

    async fn get_fields_to_build(
        &self,
        dense_fields: &[(Field, DenseVectorConfig)],
    ) -> Vec<(Field, DenseVectorConfig)> {
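        // Keep only the fields that segment metadata does not already mark as
        // built, so training runs at most once per field.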
        let field_ids: Vec<u32> = dense_fields.iter().map(|(f, _)| f.0).collect();
        let built: Vec<u32> = self
            .segment_manager
            .read_metadata(|meta| {
                field_ids
                    .iter()
                    .filter(|fid| meta.is_field_built(**fid))
                    .copied()
                    .collect()
            })
            .await;
        dense_fields
            .iter()
            .filter(|(field, _)| !built.contains(&field.0))
            .cloned()
            .collect()
    }

    async fn collect_vectors_for_training(
        &self,
        segment_ids: &[String],
        fields_to_build: &[(Field, DenseVectorConfig)],
    ) -> Result<FxHashMap<u32, Vec<Vec<f32>>>> {
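        // Collect up to MAX_TRAINING_VECTORS training vectors per field across
        // all segments. When a segment holds more vectors than the remaining
        // budget, vectors are sampled at a fixed stride rather than truncated,
        // so the sample covers the whole segment.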
        const MAX_TRAINING_VECTORS: usize = 100_000;

        let mut all_vectors: FxHashMap<u32, Vec<Vec<f32>>> = FxHashMap::default();
        let mut total_skipped = 0usize;

        for id_str in segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                self.config.term_cache_blocks,
            )
            .await?;

            for (field_id, lazy_flat) in reader.flat_vectors() {
                if !fields_to_build.iter().any(|(f, _)| f.0 == *field_id) {
                    continue;
                }
                let entry = all_vectors.entry(*field_id).or_default();
                let remaining = MAX_TRAINING_VECTORS.saturating_sub(entry.len());

                if remaining == 0 {
                    total_skipped += lazy_flat.num_vectors;
                    continue;
                }

                let n = lazy_flat.num_vectors;
                let dim = lazy_flat.dim;
                let quant = lazy_flat.quantization;

                let indices: Vec<usize> = if n <= remaining {
                    (0..n).collect()
                } else {
                    let step = (n / remaining).max(1);
                    (0..n).step_by(step).take(remaining).collect()
                };

                if indices.len() < n {
                    total_skipped += n - indices.len();
                }

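                // Read sampled vectors in batches: a run of consecutive indices
                // is fetched with one batched read and dequantized into f32_buf;
                // otherwise fall back to one read per vector.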
                const BATCH: usize = 1024;
                let mut f32_buf = vec![0f32; BATCH * dim];
                for chunk in indices.chunks(BATCH) {
                    let start = chunk[0];
                    let end = *chunk.last().unwrap();
                    if end - start + 1 == chunk.len() {
                        if let Ok(batch_bytes) =
                            lazy_flat.read_vectors_batch(start, chunk.len()).await
                        {
                            let floats = chunk.len() * dim;
                            f32_buf.resize(floats, 0.0);
                            crate::segment::dequantize_raw(
                                batch_bytes.as_slice(),
                                quant,
                                floats,
                                &mut f32_buf,
                            );
                            for i in 0..chunk.len() {
                                entry.push(f32_buf[i * dim..(i + 1) * dim].to_vec());
                            }
                        }
                    } else {
                        f32_buf.resize(dim, 0.0);
                        for &idx in chunk {
                            if let Ok(()) = lazy_flat.read_vector_into(idx, &mut f32_buf).await {
                                entry.push(f32_buf[..dim].to_vec());
                            }
                        }
                    }
                }
            }
        }

        if total_skipped > 0 {
            let collected: usize = all_vectors.values().map(|v| v.len()).sum();
            log::info!(
                "Sampled {} vectors for training (skipped {}, max {} per field)",
                collected,
                total_skipped,
                MAX_TRAINING_VECTORS,
            );
        }

        Ok(all_vectors)
    }

    async fn train_field_index(
        &self,
        field: Field,
        config: &DenseVectorConfig,
        all_vectors: &FxHashMap<u32, Vec<Vec<f32>>>,
    ) -> Result<()> {
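        // Train the coarse quantizer (plus a PQ codebook for ScaNN) for one
        // field, persist the artifacts, and record them in segment metadata.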
        let field_id = field.0;
        let vectors = match all_vectors.get(&field_id) {
            Some(v) if !v.is_empty() => v,
            _ => return Ok(()),
        };

        let dim = config.dim;
        let num_vectors = vectors.len();
        let num_clusters = config.optimal_num_clusters(num_vectors);

        log::info!(
            "Training vector index for field {} with {} vectors, {} clusters (dim={})",
            field_id,
            num_vectors,
            num_clusters,
            dim,
        );

        let centroids_filename = format!("field_{}_centroids.bin", field_id);
        let mut codebook_filename: Option<String> = None;

        match config.index_type {
            VectorIndexType::IvfRaBitQ => {
                self.train_ivf_rabitq(field_id, dim, num_clusters, vectors, &centroids_filename)
                    .await?;
            }
            VectorIndexType::ScaNN => {
                codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
                self.train_scann(
                    field_id,
                    dim,
                    num_clusters,
                    vectors,
                    &centroids_filename,
                    codebook_filename.as_ref().unwrap(),
                )
                .await?;
            }
            _ => {
                // No training needed for other index types.
                return Ok(());
            }
        }

        self.segment_manager
            .update_metadata(|meta| {
                meta.init_field(field_id, config.index_type);
                meta.total_vectors = num_vectors;
                meta.mark_field_built(
                    field_id,
                    num_vectors,
                    num_clusters,
                    centroids_filename.clone(),
                    codebook_filename.clone(),
                );
            })
            .await?;

        Ok(())
    }

    async fn save_trained_artifact(
        &self,
        artifact: &impl serde::Serialize,
        filename: &str,
    ) -> Result<()> {
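        // Serialize the trained artifact with bincode and write it through the
        // directory abstraction.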
        let bytes = bincode::serde::encode_to_vec(artifact, bincode::config::standard())
            .map_err(|e| Error::Serialization(e.to_string()))?;
        self.directory
            .write(std::path::Path::new(filename), &bytes)
            .await?;
        Ok(())
    }

    async fn train_ivf_rabitq(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
    ) -> Result<()> {
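        // For IVF-RaBitQ only the coarse centroids are trained here; the
        // quantized codes are produced later, when merges build the ANN index.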
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
        self.save_trained_artifact(&centroids, centroids_filename)
            .await?;

        log::info!(
            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );
        Ok(())
    }

    async fn train_scann(
        &self,
        field_id: u32,
        dim: usize,
        num_clusters: usize,
        vectors: &[Vec<f32>],
        centroids_filename: &str,
        codebook_filename: &str,
    ) -> Result<()> {
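        // ScaNN training produces two artifacts: coarse partition centroids
        // and a product-quantization codebook.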
        let coarse_config = crate::structures::CoarseConfig::new(dim, num_clusters);
        let centroids = crate::structures::CoarseCentroids::train(&coarse_config, vectors);
        self.save_trained_artifact(&centroids, centroids_filename)
            .await?;

        let pq_config = crate::structures::PQConfig::new(dim);
        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);
        self.save_trained_artifact(&codebook, codebook_filename)
            .await?;

        log::info!(
            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
            field_id,
            centroids.num_clusters
        );
        Ok(())
    }
}