use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use rustc_hash::FxHashMap;
use tokio::sync::Mutex as AsyncMutex;
use tokio::sync::mpsc;

use crate::DocId;
use crate::directories::DirectoryWriter;
use crate::dsl::{Document, Field, Schema};
use crate::error::{Error, Result};
use crate::segment::{
    SegmentBuilder, SegmentBuilderConfig, SegmentId, SegmentMerger, SegmentReader,
};
use crate::tokenizer::BoxedTokenizer;

use super::IndexConfig;

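/// Writer half of an index: buffers documents in per-thread [`SegmentBuilder`]s,
/// flushes full builders to on-disk segments in background tasks, and
/// coordinates segment merges and dense-vector (ANN) index builds.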
pub struct IndexWriter<D: DirectoryWriter + 'static> {
    directory: Arc<D>,
    schema: Arc<Schema>,
    config: IndexConfig,
    builder_config: SegmentBuilderConfig,
    /// Per-field tokenizer overrides applied to new segment builders.
    tokenizers: FxHashMap<Field, BoxedTokenizer>,
    /// One builder slot per indexing thread; `None` until a document is routed to it.
    builders: Vec<AsyncMutex<Option<SegmentBuilder>>>,
    /// Tracks live segments and schedules background merges.
    segment_manager: Arc<crate::merge::SegmentManager<D>>,
    /// Notifies the writer when a background segment build completes.
    segment_id_sender: mpsc::UnboundedSender<String>,
    /// Receives build-completion notifications; drained by `flush` and `commit`.
    segment_id_receiver: AsyncMutex<mpsc::UnboundedReceiver<String>>,
    /// Number of spawned background builds not yet acknowledged.
    pending_builds: Arc<AtomicUsize>,
    /// Persisted index metadata: segment list and vector-index state.
    metadata: AsyncMutex<super::IndexMetadata>,
}

impl<D: DirectoryWriter + 'static> IndexWriter<D> {
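    /// Creates a new index in `directory`, writing the schema and an empty
    /// segment list, using the default [`SegmentBuilderConfig`].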
    pub async fn create(directory: D, schema: Schema, config: IndexConfig) -> Result<Self> {
        Self::create_with_config(directory, schema, config, SegmentBuilderConfig::default()).await
    }

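    /// Creates a new index with an explicit [`SegmentBuilderConfig`].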
    pub async fn create_with_config(
        directory: D,
        schema: Schema,
        config: IndexConfig,
        builder_config: SegmentBuilderConfig,
    ) -> Result<Self> {
        let directory = Arc::new(directory);
        let schema = Arc::new(schema);

        let schema_bytes =
            serde_json::to_vec(&*schema).map_err(|e| Error::Serialization(e.to_string()))?;
        directory
            .write(Path::new("schema.json"), &schema_bytes)
            .await?;

        let segments_bytes = serde_json::to_vec(&Vec::<String>::new())
            .map_err(|e| Error::Serialization(e.to_string()))?;
        directory
            .write(Path::new("segments.json"), &segments_bytes)
            .await?;

        let num_builders = config.num_indexing_threads.max(1);
        let mut builders = Vec::with_capacity(num_builders);
        for _ in 0..num_builders {
            builders.push(AsyncMutex::new(None));
        }

        let (segment_id_sender, segment_id_receiver) = mpsc::unbounded_channel();

        let segment_manager = Arc::new(crate::merge::SegmentManager::new(
            Arc::clone(&directory),
            Arc::clone(&schema),
            Vec::new(),
            config.merge_policy.clone_box(),
            config.term_cache_blocks,
        ));

        let metadata = super::IndexMetadata::new();
        metadata.save(directory.as_ref()).await?;

        Ok(Self {
            directory,
            schema,
            config,
            builder_config,
            tokenizers: FxHashMap::default(),
            builders,
            segment_manager,
            segment_id_sender,
            segment_id_receiver: AsyncMutex::new(segment_id_receiver),
            pending_builds: Arc::new(AtomicUsize::new(0)),
            metadata: AsyncMutex::new(metadata),
        })
    }

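    /// Opens an existing index from `directory` with the default
    /// [`SegmentBuilderConfig`].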
    pub async fn open(directory: D, config: IndexConfig) -> Result<Self> {
        Self::open_with_config(directory, config, SegmentBuilderConfig::default()).await
    }

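    /// Opens an existing index with an explicit [`SegmentBuilderConfig`].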
    pub async fn open_with_config(
        directory: D,
        config: IndexConfig,
        builder_config: SegmentBuilderConfig,
    ) -> Result<Self> {
        let directory = Arc::new(directory);

        let schema_slice = directory.open_read(Path::new("schema.json")).await?;
        let schema_bytes = schema_slice.read_bytes().await?;
        let schema: Schema = serde_json::from_slice(schema_bytes.as_slice())
            .map_err(|e| Error::Serialization(e.to_string()))?;
        let schema = Arc::new(schema);

        let metadata = super::IndexMetadata::load(directory.as_ref()).await?;
        let segment_ids = metadata.segments.clone();

        let num_builders = config.num_indexing_threads.max(1);
        let mut builders = Vec::with_capacity(num_builders);
        for _ in 0..num_builders {
            builders.push(AsyncMutex::new(None));
        }

        let (segment_id_sender, segment_id_receiver) = mpsc::unbounded_channel();

        let segment_manager = Arc::new(crate::merge::SegmentManager::new(
            Arc::clone(&directory),
            Arc::clone(&schema),
            segment_ids,
            config.merge_policy.clone_box(),
            config.term_cache_blocks,
        ));

        Ok(Self {
            directory,
            schema,
            config,
            builder_config,
            tokenizers: FxHashMap::default(),
            builders,
            segment_manager,
            segment_id_sender,
            segment_id_receiver: AsyncMutex::new(segment_id_receiver),
            pending_builds: Arc::new(AtomicUsize::new(0)),
            metadata: AsyncMutex::new(metadata),
        })
    }

    pub fn schema(&self) -> &Schema {
        &self.schema
    }

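    /// Registers the tokenizer used to analyze `field`.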
    pub fn set_tokenizer<T: crate::tokenizer::Tokenizer>(&mut self, field: Field, tokenizer: T) {
        self.tokenizers.insert(field, Box::new(tokenizer));
    }

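    /// Adds a document to the index and returns its [`DocId`].
    ///
    /// The document is buffered in one of the in-memory segment builders,
    /// chosen at random; when a builder reaches `max_docs_per_segment` it is
    /// handed off to a background task that writes the segment to disk.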
    pub async fn add_document(&self, doc: Document) -> Result<DocId> {
        use rand::Rng;

        let builder_idx = rand::rng().random_range(0..self.builders.len());

        let mut builder_guard = self.builders[builder_idx].lock().await;

        if builder_guard.is_none() {
            let mut builder =
                SegmentBuilder::new((*self.schema).clone(), self.builder_config.clone())?;
            for (field, tokenizer) in &self.tokenizers {
                builder.set_tokenizer(*field, tokenizer.clone_box());
            }
            *builder_guard = Some(builder);
        }

        let builder = builder_guard.as_mut().unwrap();
        let doc_id = builder.add_document(doc)?;

        if builder.num_docs() >= self.config.max_docs_per_segment {
            let full_builder = builder_guard.take().unwrap();
            drop(builder_guard);
            self.spawn_background_build(full_builder);
        }

        Ok(doc_id)
    }

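    /// Spawns a background task that builds `builder` into an on-disk
    /// segment, registers it with the segment manager on success, and
    /// notifies the writer over the completion channel.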
    fn spawn_background_build(&self, builder: SegmentBuilder) {
        let directory = Arc::clone(&self.directory);
        let segment_id = SegmentId::new();
        let segment_hex = segment_id.to_hex();
        let sender = self.segment_id_sender.clone();
        let segment_manager = Arc::clone(&self.segment_manager);

        self.pending_builds.fetch_add(1, Ordering::SeqCst);

        tokio::spawn(async move {
            match builder.build(directory.as_ref(), segment_id).await {
                Ok(_) => {
                    segment_manager.register_segment(segment_hex.clone()).await;
                    let _ = sender.send(segment_hex);
                }
                Err(e) => {
                    eprintln!("Background segment build failed: {:?}", e);
                    // Still notify the writer so the pending-build counter is
                    // drained and `commit` does not wait forever on a failed build.
                    let _ = sender.send(segment_hex);
                }
            }
        });
    }

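    /// Drains completion notifications from finished background builds,
    /// decrementing the pending-build counter. Non-blocking.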
    async fn collect_completed_segments(&self) {
        let mut receiver = self.segment_id_receiver.lock().await;
        while let Ok(_segment_hex) = receiver.try_recv() {
            self.pending_builds.fetch_sub(1, Ordering::SeqCst);
        }
    }

    pub fn pending_build_count(&self) -> usize {
        self.pending_builds.load(Ordering::SeqCst)
    }

    pub fn pending_merge_count(&self) -> usize {
        self.segment_manager.pending_merge_count()
    }

    pub async fn maybe_merge(&self) {
        self.segment_manager.maybe_merge().await;
    }

    pub async fn wait_for_merges(&self) {
        self.segment_manager.wait_for_merges().await;
    }

    pub async fn cleanup_orphan_segments(&self) -> Result<usize> {
        self.segment_manager.cleanup_orphan_segments().await
    }

    pub async fn get_builder_stats(&self) -> Option<crate::segment::SegmentBuilderStats> {
        let mut total_stats: Option<crate::segment::SegmentBuilderStats> = None;

        for builder_mutex in &self.builders {
            let guard = builder_mutex.lock().await;
            if let Some(builder) = guard.as_ref() {
                let stats = builder.stats();
                if let Some(ref mut total) = total_stats {
                    total.num_docs += stats.num_docs;
                    total.unique_terms += stats.unique_terms;
                    total.postings_in_memory += stats.postings_in_memory;
                    total.interned_strings += stats.interned_strings;
                    total.doc_field_lengths_size += stats.doc_field_lengths_size;
                } else {
                    total_stats = Some(stats);
                }
            }
        }

        total_stats
    }

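    /// Hands every non-empty in-memory builder off to a background build.
    ///
    /// Returns immediately; use [`commit`](Self::commit) to wait for the
    /// resulting segments to be durable and recorded in the metadata.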
    pub async fn flush(&self) -> Result<()> {
        self.collect_completed_segments().await;

        for builder_mutex in &self.builders {
            let mut guard = builder_mutex.lock().await;
            if let Some(builder) = guard.take()
                && builder.num_docs() > 0
            {
                self.spawn_background_build(builder);
            }
        }

        Ok(())
    }

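    /// Flushes all in-memory builders, waits for every pending background
    /// build to finish, and persists the updated segment list to the index
    /// metadata. Also triggers an automatic vector-index build when a dense
    /// vector field crosses its configured threshold.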
    pub async fn commit(&self) -> Result<()> {
        self.flush().await?;

        let mut receiver = self.segment_id_receiver.lock().await;
        while self.pending_builds.load(Ordering::SeqCst) > 0 {
            match receiver.recv().await {
                Some(_segment_hex) => {
                    self.pending_builds.fetch_sub(1, Ordering::SeqCst);
                }
                None => break,
            }
        }
        drop(receiver);

        let segment_ids = self.segment_manager.get_segment_ids().await;
        {
            let mut meta = self.metadata.lock().await;
            meta.segments = segment_ids;
            meta.save(self.directory.as_ref()).await?;
        }

        self.maybe_build_vector_index().await?;

        Ok(())
    }

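    /// Recounts stored vectors across all segments and triggers a full
    /// vector-index build when the metadata reports that a dense vector
    /// field's build threshold has been crossed.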
    async fn maybe_build_vector_index(&self) -> Result<()> {
        use crate::dsl::FieldType;

        let dense_fields: Vec<(Field, crate::dsl::DenseVectorConfig)> = self
            .schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    entry
                        .dense_vector_config
                        .as_ref()
                        .filter(|c| !c.is_flat())
                        .map(|c| (field, c.clone()))
                } else {
                    None
                }
            })
            .collect();

        if dense_fields.is_empty() {
            return Ok(());
        }

        let segment_ids = self.segment_manager.get_segment_ids().await;
        let mut total_vectors = 0usize;
        let mut doc_offset = 0u32;

        for id_str in &segment_ids {
            if let Some(segment_id) = SegmentId::from_hex(id_str)
                && let Ok(reader) = SegmentReader::open(
                    self.directory.as_ref(),
                    segment_id,
                    Arc::clone(&self.schema),
                    doc_offset,
                    self.config.term_cache_blocks,
                )
                .await
            {
                for index in reader.vector_indexes().values() {
                    if let crate::segment::VectorIndex::Flat(flat_data) = index {
                        total_vectors += flat_data.vectors.len();
                    }
                }
                doc_offset += reader.meta().num_docs;
            }
        }

        {
            let mut meta = self.metadata.lock().await;
            meta.total_vectors = total_vectors;
        }

        let should_build = {
            let meta = self.metadata.lock().await;
            dense_fields.iter().any(|(field, config)| {
                let threshold = config.build_threshold.unwrap_or(1000);
                meta.should_build_field(field.0, threshold)
            })
        };

        if should_build {
            log::info!(
                "Threshold crossed ({} vectors), auto-triggering vector index build",
                total_vectors
            );
            self.build_vector_index().await?;
        }

        Ok(())
    }

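    /// Merges all current segments into a single new segment, updates the
    /// metadata, and deletes the old segment files.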
    async fn do_merge(&self) -> Result<()> {
        let segment_ids = self.segment_manager.get_segment_ids().await;

        if segment_ids.len() < 2 {
            return Ok(());
        }

        let ids_to_merge: Vec<String> = segment_ids.clone();
        drop(segment_ids);

        let mut readers = Vec::new();
        let mut doc_offset = 0u32;

        for id_str in &ids_to_merge {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;
            doc_offset += reader.meta().num_docs;
            readers.push(reader);
        }

        let merger = SegmentMerger::new(Arc::clone(&self.schema));
        let new_segment_id = SegmentId::new();
        merger
            .merge(self.directory.as_ref(), &readers, new_segment_id)
            .await?;

        {
            let segment_ids_arc = self.segment_manager.segment_ids();
            let mut segment_ids = segment_ids_arc.lock().await;
            segment_ids.clear();
            segment_ids.push(new_segment_id.to_hex());
        }

        let segment_ids = self.segment_manager.get_segment_ids().await;
        {
            let mut meta = self.metadata.lock().await;
            meta.segments = segment_ids;
            meta.save(self.directory.as_ref()).await?;
        }

        for id_str in ids_to_merge {
            if let Some(segment_id) = SegmentId::from_hex(&id_str) {
                let _ = crate::segment::delete_segment(self.directory.as_ref(), segment_id).await;
            }
        }

        Ok(())
    }

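    /// Commits any buffered documents and then merges all segments into one.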
    pub async fn force_merge(&self) -> Result<()> {
        self.commit().await?;
        self.do_merge().await
    }

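    /// Trains ANN structures (coarse centroids, plus a PQ codebook for ScaNN)
    /// for every dense vector field that is configured for ANN indexing and
    /// has not been built yet, persists them, and then rebuilds the segments
    /// so their vector indexes use the trained structures.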
    pub async fn build_vector_index(&self) -> Result<()> {
        use crate::dsl::{FieldType, VectorIndexType};

        let dense_fields: Vec<(Field, crate::dsl::DenseVectorConfig)> = self
            .schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    entry
                        .dense_vector_config
                        .as_ref()
                        .filter(|c| !c.is_flat())
                        .map(|c| (field, c.clone()))
                } else {
                    None
                }
            })
            .collect();

        if dense_fields.is_empty() {
            log::info!("No dense vector fields configured for ANN indexing");
            return Ok(());
        }

        let fields_to_build: Vec<_> = {
            let meta = self.metadata.lock().await;
            dense_fields
                .iter()
                .filter(|(field, _)| !meta.is_field_built(field.0))
                .cloned()
                .collect()
        };

        if fields_to_build.is_empty() {
            log::info!("All vector fields already built, skipping training");
            return Ok(());
        }

        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        let mut all_vectors: rustc_hash::FxHashMap<u32, Vec<Vec<f32>>> =
            rustc_hash::FxHashMap::default();
        let mut doc_offset = 0u32;

        for id_str in &segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;

            for (field_id, index) in reader.vector_indexes() {
                if fields_to_build.iter().any(|(f, _)| f.0 == *field_id)
                    && let crate::segment::VectorIndex::Flat(flat_data) = index
                {
                    all_vectors
                        .entry(*field_id)
                        .or_default()
                        .extend(flat_data.vectors.iter().cloned());
                }
            }

            doc_offset += reader.meta().num_docs;
        }

        for (field, config) in &fields_to_build {
            let field_id = field.0;
            if let Some(vectors) = all_vectors.get(&field_id) {
                if vectors.is_empty() {
                    continue;
                }

                let index_dim = config.index_dim();
                let num_vectors = vectors.len();
                let num_clusters = config.optimal_num_clusters(num_vectors);

                log::info!(
                    "Training vector index for field {} with {} vectors, {} clusters",
                    field_id,
                    num_vectors,
                    num_clusters
                );

                let centroids_filename = format!("field_{}_centroids.bin", field_id);
                let mut codebook_filename: Option<String> = None;

                match config.index_type {
                    VectorIndexType::IvfRaBitQ => {
                        let coarse_config =
                            crate::structures::CoarseConfig::new(index_dim, num_clusters);
                        let centroids =
                            crate::structures::CoarseCentroids::train(&coarse_config, vectors);

                        let centroids_path = std::path::Path::new(&centroids_filename);
                        let centroids_bytes = serde_json::to_vec(&centroids)
                            .map_err(|e| Error::Serialization(e.to_string()))?;
                        self.directory
                            .write(centroids_path, &centroids_bytes)
                            .await?;

                        log::info!(
                            "Saved IVF-RaBitQ centroids for field {} ({} clusters)",
                            field_id,
                            centroids.num_clusters
                        );
                    }
                    VectorIndexType::ScaNN => {
                        let coarse_config =
                            crate::structures::CoarseConfig::new(index_dim, num_clusters);
                        let centroids =
                            crate::structures::CoarseCentroids::train(&coarse_config, vectors);

                        let pq_config = crate::structures::PQConfig::new(index_dim);
                        let codebook = crate::structures::PQCodebook::train(pq_config, vectors, 10);

                        let centroids_path = std::path::Path::new(&centroids_filename);
                        let centroids_bytes = serde_json::to_vec(&centroids)
                            .map_err(|e| Error::Serialization(e.to_string()))?;
                        self.directory
                            .write(centroids_path, &centroids_bytes)
                            .await?;

                        codebook_filename = Some(format!("field_{}_codebook.bin", field_id));
                        let codebook_path =
                            std::path::Path::new(codebook_filename.as_ref().unwrap());
                        let codebook_bytes = serde_json::to_vec(&codebook)
                            .map_err(|e| Error::Serialization(e.to_string()))?;
                        self.directory.write(codebook_path, &codebook_bytes).await?;

                        log::info!(
                            "Saved ScaNN centroids and codebook for field {} ({} clusters)",
                            field_id,
                            centroids.num_clusters
                        );
                    }
                    _ => {
                        continue;
                    }
                }

                {
                    let mut meta = self.metadata.lock().await;
                    meta.init_field(field_id, config.index_type);
                    meta.total_vectors = num_vectors;
                    meta.mark_field_built(
                        field_id,
                        num_vectors,
                        num_clusters,
                        centroids_filename,
                        codebook_filename,
                    );
                    meta.save(self.directory.as_ref()).await?;
                }
            }
        }

        log::info!("Vector index training complete. Rebuilding segments with ANN indexes...");

        self.rebuild_segments_with_ann().await?;

        Ok(())
    }

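    /// Merges all segments into a single segment, re-encoding vector indexes
    /// with the trained centroids/codebooks loaded from the metadata, then
    /// deletes the old segments.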
    async fn rebuild_segments_with_ann(&self) -> Result<()> {
        use crate::segment::{SegmentMerger, TrainedVectorStructures};

        let segment_ids = self.segment_manager.get_segment_ids().await;
        if segment_ids.is_empty() {
            return Ok(());
        }

        let (trained_centroids, trained_codebooks) = {
            let meta = self.metadata.lock().await;
            meta.load_trained_structures(self.directory.as_ref()).await
        };

        if trained_centroids.is_empty() {
            log::info!("No trained structures to rebuild with");
            return Ok(());
        }

        let trained = TrainedVectorStructures {
            centroids: trained_centroids,
            codebooks: trained_codebooks,
        };

        let mut readers = Vec::new();
        let mut doc_offset = 0u32;

        for id_str in &segment_ids {
            let segment_id = SegmentId::from_hex(id_str)
                .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))?;
            let reader = SegmentReader::open(
                self.directory.as_ref(),
                segment_id,
                Arc::clone(&self.schema),
                doc_offset,
                self.config.term_cache_blocks,
            )
            .await?;
            doc_offset += reader.meta().num_docs;
            readers.push(reader);
        }

        let merger = SegmentMerger::new(Arc::clone(&self.schema));
        let new_segment_id = SegmentId::new();
        merger
            .merge_with_ann(self.directory.as_ref(), &readers, new_segment_id, &trained)
            .await?;

        {
            let segment_ids_arc = self.segment_manager.segment_ids();
            let mut segment_ids = segment_ids_arc.lock().await;
            let old_ids: Vec<String> = segment_ids.clone();
            segment_ids.clear();
            segment_ids.push(new_segment_id.to_hex());

            let mut meta = self.metadata.lock().await;
            meta.segments = segment_ids.clone();
            meta.save(self.directory.as_ref()).await?;

            for id_str in old_ids {
                if let Some(segment_id) = SegmentId::from_hex(&id_str) {
                    let _ =
                        crate::segment::delete_segment(self.directory.as_ref(), segment_id).await;
                }
            }
        }

        log::info!("Segments rebuilt with ANN indexes");
        Ok(())
    }

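    /// Returns the total number of dense vectors recorded in the metadata.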
    pub async fn total_vector_count(&self) -> usize {
        self.metadata.lock().await.total_vectors
    }

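    /// Returns whether the ANN index for `field` has been built.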
    pub async fn is_vector_index_built(&self, field: Field) -> bool {
        self.metadata.lock().await.is_field_built(field.0)
    }

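    /// Resets every dense vector field to the flat (untrained) state, deletes
    /// any persisted centroid/codebook files, and retrains the vector indexes
    /// from scratch.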
    pub async fn rebuild_vector_index(&self) -> Result<()> {
        use crate::dsl::FieldType;

        let dense_fields: Vec<Field> = self
            .schema
            .fields()
            .filter_map(|(field, entry)| {
                if entry.field_type == FieldType::DenseVector && entry.indexed {
                    Some(field)
                } else {
                    None
                }
            })
            .collect();

        if dense_fields.is_empty() {
            return Ok(());
        }

        {
            let mut meta = self.metadata.lock().await;
            for field in &dense_fields {
                if let Some(field_meta) = meta.vector_fields.get_mut(&field.0) {
                    field_meta.state = super::VectorIndexState::Flat;
                    if let Some(ref centroids_file) = field_meta.centroids_file {
                        let _ = self
                            .directory
                            .delete(std::path::Path::new(centroids_file))
                            .await;
                    }
                    if let Some(ref codebook_file) = field_meta.codebook_file {
                        let _ = self
                            .directory
                            .delete(std::path::Path::new(codebook_file))
                            .await;
                    }
                    field_meta.centroids_file = None;
                    field_meta.codebook_file = None;
                }
            }
            meta.save(self.directory.as_ref()).await?;
        }

        log::info!("Reset vector index state to Flat, triggering rebuild...");

        self.build_vector_index().await
    }
}