1use std::collections::HashSet;
28use std::sync::Arc;
29
30use arc_swap::ArcSwapOption;
31use tokio::sync::Mutex as AsyncMutex;
32use tokio::task::JoinHandle;
33
34use crate::directories::DirectoryWriter;
35use crate::error::{Error, Result};
36use crate::index::IndexMetadata;
37use crate::segment::{SegmentId, SegmentSnapshot, SegmentTracker, TrainedVectorStructures};
38#[cfg(feature = "native")]
39use crate::segment::{SegmentMerger, SegmentReader};
40
41use super::{MergePolicy, SegmentInfo};
42
/// Tracks the set of segment IDs currently involved in an active merge
/// (both the merge inputs and the in-progress output segment), so that
/// overlapping merges are never scheduled concurrently.
struct MergeInventory {
    // Blocking mutex is fine here: critical sections are short and never await.
    inner: parking_lot::Mutex<HashSet<String>>,
}
56
57impl MergeInventory {
58 fn new() -> Self {
59 Self {
60 inner: parking_lot::Mutex::new(HashSet::new()),
61 }
62 }
63
64 fn try_register(self: &Arc<Self>, segment_ids: Vec<String>) -> Option<MergeGuard> {
67 let mut inner = self.inner.lock();
68 for id in &segment_ids {
70 if inner.contains(id) {
71 log::debug!(
72 "[merge_inventory] rejected: {} overlaps with active merge ({} active IDs)",
73 id,
74 inner.len()
75 );
76 return None;
77 }
78 }
79 log::debug!(
80 "[merge_inventory] registered {} IDs (total active: {})",
81 segment_ids.len(),
82 inner.len() + segment_ids.len()
83 );
84 for id in &segment_ids {
85 inner.insert(id.clone());
86 }
87 Some(MergeGuard {
88 inventory: Arc::clone(self),
89 segment_ids,
90 })
91 }
92
93 fn snapshot(&self) -> HashSet<String> {
95 self.inner.lock().clone()
96 }
97
98 fn contains(&self, segment_id: &str) -> bool {
100 self.inner.lock().contains(segment_id)
101 }
102}
103
/// RAII claim over a set of segment IDs in a [`MergeInventory`].
/// Unregisters the IDs on drop, whether the merge completed or failed.
struct MergeGuard {
    inventory: Arc<MergeInventory>,
    segment_ids: Vec<String>,
}
111
112impl Drop for MergeGuard {
113 fn drop(&mut self) {
114 let mut inner = self.inventory.inner.lock();
115 for id in &self.segment_ids {
116 inner.remove(id);
117 }
118 }
119}
120
/// Mutable manager state guarded by a single async mutex: the persistent
/// index metadata plus the policy that selects merge candidates.
struct ManagerState {
    metadata: IndexMetadata,
    merge_policy: Box<dyn MergePolicy>,
}
126
/// Coordinates the segment lifecycle for an index: commits, background
/// merges, segment reordering, snapshot acquisition, and deferred deletion.
pub struct SegmentManager<D: DirectoryWriter + 'static> {
    /// Metadata + merge policy; async mutex because it is held across awaits
    /// (e.g. while saving metadata).
    state: AsyncMutex<ManagerState>,

    /// Segment IDs currently claimed by in-flight merges.
    merge_inventory: Arc<MergeInventory>,

    /// Join handles of spawned background merge tasks.
    merge_handles: AsyncMutex<Vec<JoinHandle<()>>>,

    /// Trained vector structures, published/replaced atomically without a lock.
    trained: ArcSwapOption<TrainedVectorStructures>,

    /// Reference tracker deciding when replaced segments may be deleted.
    tracker: Arc<SegmentTracker>,

    /// Callback handed to snapshots to delete segments once unreferenced.
    delete_fn: Arc<dyn Fn(Vec<SegmentId>) + Send + Sync>,

    /// Storage backend for segment files and metadata.
    directory: Arc<D>,
    /// Index schema shared with readers and mergers.
    schema: Arc<crate::dsl::Schema>,
    /// Term-dictionary cache size (in blocks) for segment readers.
    term_cache_blocks: usize,
    /// Upper bound on concurrently running background merges (clamped >= 1).
    max_concurrent_merges: usize,
}
159
160impl<D: DirectoryWriter + 'static> SegmentManager<D> {
161 pub fn new(
163 directory: Arc<D>,
164 schema: Arc<crate::dsl::Schema>,
165 metadata: IndexMetadata,
166 merge_policy: Box<dyn MergePolicy>,
167 term_cache_blocks: usize,
168 max_concurrent_merges: usize,
169 ) -> Self {
170 let tracker = Arc::new(SegmentTracker::new());
171 for seg_id in metadata.segment_metas.keys() {
172 tracker.register(seg_id);
173 }
174
175 let delete_fn: Arc<dyn Fn(Vec<SegmentId>) + Send + Sync> = {
176 let dir = Arc::clone(&directory);
177 Arc::new(move |segment_ids| {
178 let Ok(handle) = tokio::runtime::Handle::try_current() else {
181 return;
182 };
183 let dir = Arc::clone(&dir);
184 handle.spawn(async move {
185 for segment_id in segment_ids {
186 log::info!(
187 "[segment_cleanup] deleting deferred segment {}",
188 segment_id.0
189 );
190 let _ = crate::segment::delete_segment(dir.as_ref(), segment_id).await;
191 }
192 });
193 })
194 };
195
196 Self {
197 state: AsyncMutex::new(ManagerState {
198 metadata,
199 merge_policy,
200 }),
201 merge_inventory: Arc::new(MergeInventory::new()),
202 merge_handles: AsyncMutex::new(Vec::new()),
203 trained: ArcSwapOption::new(None),
204 tracker,
205 delete_fn,
206 directory,
207 schema,
208 term_cache_blocks,
209 max_concurrent_merges: max_concurrent_merges.max(1),
210 }
211 }
212
213 pub async fn get_segment_ids(&self) -> Vec<String> {
219 self.state.lock().await.metadata.segment_ids()
220 }
221
222 pub fn trained(&self) -> Option<Arc<TrainedVectorStructures>> {
224 self.trained.load_full()
225 }
226
227 pub async fn load_and_publish_trained(&self) {
230 let vector_fields = {
232 let st = self.state.lock().await;
233 st.metadata.vector_fields.clone()
234 };
235 let trained =
237 IndexMetadata::load_trained_from_fields(&vector_fields, self.directory.as_ref()).await;
238 if let Some(t) = trained {
239 self.trained.store(Some(Arc::new(t)));
240 }
241 }
242
243 pub(crate) fn clear_trained(&self) {
245 self.trained.store(None);
246 }
247
248 pub(crate) async fn read_metadata<F, R>(&self, f: F) -> R
250 where
251 F: FnOnce(&IndexMetadata) -> R,
252 {
253 let st = self.state.lock().await;
254 f(&st.metadata)
255 }
256
257 pub(crate) async fn update_metadata<F>(&self, f: F) -> Result<()>
259 where
260 F: FnOnce(&mut IndexMetadata),
261 {
262 let mut st = self.state.lock().await;
263 f(&mut st.metadata);
264 st.metadata.save(self.directory.as_ref()).await
265 }
266
267 pub async fn acquire_snapshot(&self) -> SegmentSnapshot {
270 let acquired = {
271 let st = self.state.lock().await;
272 let segment_ids = st.metadata.segment_ids();
273 self.tracker.acquire(&segment_ids)
274 };
275
276 SegmentSnapshot::with_delete_fn(
277 Arc::clone(&self.tracker),
278 acquired,
279 Arc::clone(&self.delete_fn),
280 )
281 }
282
283 pub fn tracker(&self) -> Arc<SegmentTracker> {
285 Arc::clone(&self.tracker)
286 }
287
288 pub fn directory(&self) -> Arc<D> {
290 Arc::clone(&self.directory)
291 }
292}
293
294#[cfg(feature = "native")]
299impl<D: DirectoryWriter + 'static> SegmentManager<D> {
300 pub async fn commit(&self, new_segments: Vec<(String, u32)>) -> Result<()> {
302 let mut st = self.state.lock().await;
303 for (segment_id, num_docs) in new_segments {
304 if !st.metadata.has_segment(&segment_id) {
305 st.metadata.add_segment(segment_id.clone(), num_docs);
306 self.tracker.register(&segment_id);
307 }
308 }
309 st.metadata.save(self.directory.as_ref()).await
310 }
311
312 pub async fn maybe_merge(self: &Arc<Self>) {
324 let slots_available = {
326 let mut handles = self.merge_handles.lock().await;
327 handles.retain(|h| !h.is_finished());
328 self.max_concurrent_merges.saturating_sub(handles.len())
329 };
330
331 if slots_available == 0 {
332 log::debug!("[maybe_merge] at max concurrent merges, skipping");
333 return;
334 }
335
336 let new_handles = {
340 let st = self.state.lock().await;
341
342 let segments: Vec<SegmentInfo> = st
344 .metadata
345 .segment_metas
346 .iter()
347 .filter(|(id, _)| {
348 !self.tracker.is_pending_deletion(id) && !self.merge_inventory.contains(id)
349 })
350 .map(|(id, info)| SegmentInfo {
351 id: id.clone(),
352 num_docs: info.num_docs,
353 })
354 .collect();
355
356 log::debug!("[maybe_merge] {} eligible segments", segments.len());
357
358 let candidates = st.merge_policy.find_merges(&segments);
359
360 if candidates.is_empty() {
361 return;
362 }
363
364 log::debug!(
365 "[maybe_merge] {} merge candidates, {} slots available",
366 candidates.len(),
367 slots_available
368 );
369
370 let mut handles = Vec::new();
371 for c in candidates {
372 if handles.len() >= slots_available {
373 break;
374 }
375 if let Some(h) = self.spawn_merge(c.segment_ids) {
376 handles.push(h);
377 }
378 }
379 handles
380 };
382
383 if !new_handles.is_empty() {
384 self.merge_handles.lock().await.extend(new_handles);
385 }
386 }
387
388 fn spawn_merge(self: &Arc<Self>, segment_ids_to_merge: Vec<String>) -> Option<JoinHandle<()>> {
397 let output_id = SegmentId::new();
398 let output_hex = output_id.to_hex();
399
400 let mut all_ids = segment_ids_to_merge.clone();
401 all_ids.push(output_hex);
402
403 let guard = match self.merge_inventory.try_register(all_ids) {
404 Some(g) => g,
405 None => {
406 log::debug!("[spawn_merge] skipped: segments overlap with active merge");
407 return None;
408 }
409 };
410
411 let sm = Arc::clone(self);
412 let ids = segment_ids_to_merge;
413
414 Some(tokio::spawn(async move {
415 let _guard = guard;
416
417 let trained_snap = sm.trained();
418 let result = Self::do_merge(
419 sm.directory.as_ref(),
420 &sm.schema,
421 &ids,
422 output_id,
423 sm.term_cache_blocks,
424 trained_snap.as_deref(),
425 )
426 .await;
427
428 match result {
429 Ok((new_id, doc_count)) => {
430 if let Err(e) = sm.replace_segments(&ids, new_id, doc_count, false).await {
431 log::error!("[merge] Failed to replace segments after merge: {:?}", e);
432 }
433 }
434 Err(e) => {
435 log::error!(
436 "[merge] Background merge failed for segments {:?}: {:?}",
437 ids,
438 e
439 );
440 }
441 }
442 sm.maybe_merge().await;
447 }))
448 }
449
450 async fn replace_segments(
454 &self,
455 old_ids: &[String],
456 new_id: String,
457 doc_count: u32,
458 reordered: bool,
459 ) -> Result<()> {
460 self.tracker.register(&new_id);
461
462 {
463 let mut st = self.state.lock().await;
464 let parent_gen = old_ids
466 .iter()
467 .filter_map(|id| st.metadata.segment_metas.get(id))
468 .map(|info| info.generation)
469 .max()
470 .unwrap_or(0);
471 let ancestors: Vec<String> = old_ids.to_vec();
472
473 for id in old_ids {
474 st.metadata.remove_segment(id);
475 }
476 st.metadata
477 .add_merged_segment(new_id, doc_count, ancestors, parent_gen + 1, reordered);
478 st.metadata.save(self.directory.as_ref()).await?;
480 }
481
482 let ready_to_delete = self.tracker.mark_for_deletion(old_ids);
483 for segment_id in ready_to_delete {
484 let _ = crate::segment::delete_segment(self.directory.as_ref(), segment_id).await;
485 }
486 Ok(())
487 }
488
489 pub(crate) async fn do_merge(
493 directory: &D,
494 schema: &Arc<crate::dsl::Schema>,
495 segment_ids_to_merge: &[String],
496 output_segment_id: SegmentId,
497 term_cache_blocks: usize,
498 trained: Option<&TrainedVectorStructures>,
499 ) -> Result<(String, u32)> {
500 let output_hex = output_segment_id.to_hex();
501 let load_start = std::time::Instant::now();
502
503 let segment_ids: Vec<SegmentId> = segment_ids_to_merge
504 .iter()
505 .map(|id_str| {
506 SegmentId::from_hex(id_str)
507 .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))
508 })
509 .collect::<Result<Vec<_>>>()?;
510
511 let schema_arc = Arc::clone(schema);
512 let futures: Vec<_> = segment_ids
513 .iter()
514 .map(|&sid| {
515 let sch = Arc::clone(&schema_arc);
516 async move { SegmentReader::open(directory, sid, sch, term_cache_blocks).await }
517 })
518 .collect();
519
520 let results = futures::future::join_all(futures).await;
521 let mut readers = Vec::with_capacity(results.len());
522 let mut total_docs = 0u64;
523 for (i, result) in results.into_iter().enumerate() {
524 match result {
525 Ok(r) => {
526 total_docs += r.meta().num_docs as u64;
527 readers.push(r);
528 }
529 Err(e) => {
530 log::error!(
531 "[merge] Failed to open segment {}: {:?}",
532 segment_ids_to_merge[i],
533 e
534 );
535 return Err(e);
536 }
537 }
538 }
539
540 for (i, reader) in readers.iter().enumerate() {
544 let meta_docs = reader.meta().num_docs;
545 let store_docs = reader.store().num_docs();
546 if store_docs != meta_docs {
547 return Err(Error::Corruption(format!(
548 "pre-merge validation: segment {} store has {} docs but meta says {}",
549 segment_ids_to_merge[i], store_docs, meta_docs
550 )));
551 }
552 }
553
554 log::info!(
555 "[merge] loaded {} segment readers in {:.1}s",
556 readers.len(),
557 load_start.elapsed().as_secs_f64()
558 );
559
560 let merger = SegmentMerger::new(Arc::clone(schema));
561
562 log::info!(
563 "[merge] {} segments -> {} (trained={})",
564 segment_ids_to_merge.len(),
565 output_hex,
566 trained.map_or(0, |t| t.centroids.len()),
567 );
568
569 merger
570 .merge(directory, &readers, output_segment_id, trained)
571 .await?;
572
573 log::info!(
574 "[merge] total wall-clock: {:.1}s ({} segments, {} docs)",
575 load_start.elapsed().as_secs_f64(),
576 readers.len(),
577 total_docs,
578 );
579
580 if total_docs > u32::MAX as u64 {
581 return Err(Error::Internal(format!(
582 "Merged segment doc count ({}) exceeds u32::MAX",
583 total_docs
584 )));
585 }
586 Ok((output_hex, total_docs as u32))
587 }
588
589 pub async fn abort_merges(&self) {
592 let handles: Vec<JoinHandle<()>> =
593 { std::mem::take(&mut *self.merge_handles.lock().await) };
594 for h in handles {
595 h.abort();
596 }
597 }
598
599 pub async fn wait_for_merging_thread(self: &Arc<Self>) {
601 let handles: Vec<JoinHandle<()>> =
602 { std::mem::take(&mut *self.merge_handles.lock().await) };
603 for h in handles {
604 let _ = h.await;
605 }
606 }
607
608 pub async fn wait_for_all_merges(self: &Arc<Self>) {
614 loop {
615 let handles: Vec<JoinHandle<()>> =
616 { std::mem::take(&mut *self.merge_handles.lock().await) };
617 if handles.is_empty() {
618 break;
619 }
620 for h in handles {
621 let _ = h.await;
622 }
623 }
624 }
625
626 pub async fn force_merge(self: &Arc<Self>) -> Result<()> {
635 const FORCE_MERGE_BATCH: usize = 64;
636
637 let max_segment_docs = {
638 let st = self.state.lock().await;
639 st.merge_policy.max_segment_docs()
640 };
641
642 self.wait_for_all_merges().await;
645
646 loop {
647 let mut segments: Vec<(String, u32)> = {
649 let st = self.state.lock().await;
650 st.metadata
651 .segment_metas
652 .iter()
653 .map(|(id, info)| (id.clone(), info.num_docs))
654 .collect()
655 };
656
657 if segments.len() < 2 {
658 return Ok(());
659 }
660
661 segments.sort_by_key(|(_, docs)| *docs);
662
663 let max_docs = max_segment_docs.map(|m| m as u64).unwrap_or(u64::MAX);
665 let mut batch = Vec::new();
666 let mut batch_docs = 0u64;
667
668 for (id, docs) in &segments {
669 if batch.len() >= FORCE_MERGE_BATCH {
670 break;
671 }
672 let next_total = batch_docs + *docs as u64;
673 if next_total > max_docs && !batch.is_empty() {
674 break;
675 }
676 batch.push(id.clone());
677 batch_docs += *docs as u64;
678 }
679
680 if batch.len() < 2 {
681 return Ok(());
682 }
683
684 log::info!(
685 "[force_merge] merging batch of {} segments ({} docs)",
686 batch.len(),
687 batch_docs
688 );
689
690 let output_id = SegmentId::new();
691 let output_hex = output_id.to_hex();
692
693 let mut all_ids = batch.clone();
695 all_ids.push(output_hex);
696 let _guard = match self.merge_inventory.try_register(all_ids) {
697 Some(g) => g,
698 None => {
699 self.wait_for_merging_thread().await;
701 continue;
702 }
703 };
704
705 let trained_snap = self.trained();
706 let (new_segment_id, total_docs) = Self::do_merge(
707 self.directory.as_ref(),
708 &self.schema,
709 &batch,
710 output_id,
711 self.term_cache_blocks,
712 trained_snap.as_deref(),
713 )
714 .await?;
715
716 self.replace_segments(&batch, new_segment_id, total_docs, false)
717 .await?;
718
719 }
721 }
722
723 pub async fn reorder_segments(self: &Arc<Self>) -> Result<()> {
730 self.wait_for_all_merges().await;
731 let segment_ids = self.get_segment_ids().await;
732
733 if segment_ids.is_empty() {
734 log::info!("[reorder] no segments to reorder");
735 return Ok(());
736 }
737
738 log::info!("[reorder] reordering {} segments", segment_ids.len());
739
740 for seg_id in segment_ids {
741 match self.reorder_single_segment(&seg_id, None).await {
742 Ok(true) => {}
743 Ok(false) => log::warn!("[reorder] segment {} skipped (in merge)", seg_id),
744 Err(e) => return Err(e),
745 }
746 }
747
748 log::info!("[reorder] all segments reordered");
749 Ok(())
750 }
751
752 pub async fn unreordered_segment_ids(&self) -> Vec<String> {
757 let st = self.state.lock().await;
758 let in_merge = self.merge_inventory.snapshot();
759 st.metadata
760 .segment_metas
761 .iter()
762 .filter(|(id, info)| !info.reordered && !in_merge.contains(*id))
763 .map(|(id, _)| id.clone())
764 .collect()
765 }
766
767 pub async fn reorder_single_segment(
772 self: &Arc<Self>,
773 seg_id: &str,
774 rayon_pool: Option<Arc<rayon::ThreadPool>>,
775 ) -> Result<bool> {
776 let source_id = SegmentId::from_hex(seg_id)
777 .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", seg_id)))?;
778 let output_id = SegmentId::new();
779 let output_hex = output_id.to_hex();
780
781 let all_ids = vec![seg_id.to_string(), output_hex];
782 let _guard = match self.merge_inventory.try_register(all_ids) {
783 Some(g) => g,
784 None => {
785 log::debug!("[optimizer] segment {} in active merge, skipping", seg_id);
786 return Ok(false);
787 }
788 };
789
790 let (new_id, total_docs) = crate::segment::reorder::reorder_segment(
791 self.directory.as_ref(),
792 &self.schema,
793 source_id,
794 output_id,
795 self.term_cache_blocks,
796 crate::segment::reorder::DEFAULT_MEMORY_BUDGET,
797 rayon_pool,
798 )
799 .await?;
800
801 self.replace_segments(&[seg_id.to_string()], new_id, total_docs, true)
802 .await?;
803
804 Ok(true)
805 }
806
807 pub async fn cleanup_orphan_segments(&self) -> Result<usize> {
812 let (registered_set, in_merge_set) = {
816 let st = self.state.lock().await;
817 let registered = st
818 .metadata
819 .segment_metas
820 .keys()
821 .cloned()
822 .collect::<HashSet<String>>();
823 let in_merge = self.merge_inventory.snapshot();
824 (registered, in_merge)
825 };
826
827 let mut orphan_ids: HashSet<String> = HashSet::new();
828
829 if let Ok(entries) = self.directory.list_files(std::path::Path::new("")).await {
830 for entry in entries {
831 let filename = entry.to_string_lossy();
832 if filename.starts_with("seg_") && filename.len() > 37 {
833 let hex_part = &filename[4..36];
834 if !registered_set.contains(hex_part) && !in_merge_set.contains(hex_part) {
835 orphan_ids.insert(hex_part.to_string());
836 }
837 }
838 }
839 }
840
841 let mut deleted = 0;
842 for hex_id in &orphan_ids {
843 if let Some(segment_id) = SegmentId::from_hex(hex_id)
844 && crate::segment::delete_segment(self.directory.as_ref(), segment_id)
845 .await
846 .is_ok()
847 {
848 deleted += 1;
849 }
850 }
851
852 Ok(deleted)
853 }
854}
855
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inventory_guard_drop_unregisters() {
        let inventory = Arc::new(MergeInventory::new());
        {
            let _claim = inventory
                .try_register(vec!["a".to_string(), "b".to_string()])
                .unwrap();
            let active = inventory.snapshot();
            assert!(active.contains("a") && active.contains("b"));
        }
        // Dropping the guard must release every claimed ID.
        assert!(inventory.snapshot().is_empty());
    }

    #[test]
    fn test_inventory_concurrent_non_overlapping_merges() {
        let inventory = Arc::new(MergeInventory::new());
        let guard_ab = inventory
            .try_register(vec!["a".to_string(), "b".to_string()])
            .unwrap();
        let _guard_cd = inventory
            .try_register(vec!["c".to_string(), "d".to_string()])
            .unwrap();
        assert_eq!(inventory.snapshot().len(), 4);

        // Releasing one merge leaves the other's claims intact.
        drop(guard_ab);
        let remaining = inventory.snapshot();
        assert_eq!(remaining.len(), 2);
        assert!(remaining.contains("c"));
        assert!(remaining.contains("d"));
    }

    #[test]
    fn test_inventory_overlapping_merge_rejected() {
        let inventory = Arc::new(MergeInventory::new());
        let guard = inventory
            .try_register(vec!["a".to_string(), "b".to_string()])
            .unwrap();
        // "b" is already claimed, so the whole batch is rejected.
        assert!(inventory
            .try_register(vec!["b".to_string(), "c".to_string()])
            .is_none());
        drop(guard);
        // After release, the same batch succeeds.
        assert!(inventory
            .try_register(vec!["b".to_string(), "c".to_string()])
            .is_some());
    }

    #[test]
    fn test_inventory_snapshot() {
        let inventory = Arc::new(MergeInventory::new());
        let _claim = inventory
            .try_register(vec!["x".to_string(), "y".to_string()])
            .unwrap();
        let active = inventory.snapshot();
        assert!(active.contains("x"));
        assert!(active.contains("y"));
        assert!(!active.contains("z"));
    }
}