1use std::collections::HashSet;
28use std::sync::Arc;
29
30use arc_swap::ArcSwapOption;
31use tokio::sync::Mutex as AsyncMutex;
32use tokio::task::JoinHandle;
33
34use crate::directories::DirectoryWriter;
35use crate::error::{Error, Result};
36use crate::index::IndexMetadata;
37use crate::segment::{SegmentId, SegmentSnapshot, SegmentTracker, TrainedVectorStructures};
38#[cfg(feature = "native")]
39use crate::segment::{SegmentMerger, SegmentReader};
40
41use super::{MergePolicy, SegmentInfo};
42
/// Tracks the segment IDs currently involved in an in-flight merge (both the
/// merge inputs and the planned output ID) so that concurrent merges never
/// operate on overlapping segments.
struct MergeInventory {
    // Synchronous mutex is safe here: holders never await while locked.
    inner: parking_lot::Mutex<HashSet<String>>,
}
56
57impl MergeInventory {
58 fn new() -> Self {
59 Self {
60 inner: parking_lot::Mutex::new(HashSet::new()),
61 }
62 }
63
64 fn try_register(self: &Arc<Self>, segment_ids: Vec<String>) -> Option<MergeGuard> {
67 let mut inner = self.inner.lock();
68 for id in &segment_ids {
70 if inner.contains(id) {
71 log::debug!(
72 "[merge_inventory] rejected: {} overlaps with active merge ({} active IDs)",
73 id,
74 inner.len()
75 );
76 return None;
77 }
78 }
79 log::debug!(
80 "[merge_inventory] registered {} IDs (total active: {})",
81 segment_ids.len(),
82 inner.len() + segment_ids.len()
83 );
84 for id in &segment_ids {
85 inner.insert(id.clone());
86 }
87 Some(MergeGuard {
88 inventory: Arc::clone(self),
89 segment_ids,
90 })
91 }
92
93 fn snapshot(&self) -> HashSet<String> {
95 self.inner.lock().clone()
96 }
97
98 fn contains(&self, segment_id: &str) -> bool {
100 self.inner.lock().contains(segment_id)
101 }
102}
103
/// RAII guard returned by `MergeInventory::try_register`; removes its segment
/// IDs from the inventory when dropped, i.e. when the merge finishes —
/// successfully or not.
struct MergeGuard {
    inventory: Arc<MergeInventory>,
    segment_ids: Vec<String>,
}
111
112impl Drop for MergeGuard {
113 fn drop(&mut self) {
114 let mut inner = self.inventory.inner.lock();
115 for id in &self.segment_ids {
116 inner.remove(id);
117 }
118 }
119}
120
/// Mutable manager state guarded by the async state lock: the persisted
/// index metadata plus the policy that selects merge candidates.
struct ManagerState {
    metadata: IndexMetadata,
    merge_policy: Box<dyn MergePolicy>,
}
126
/// Owns the segment lifecycle for an index: metadata, background merges,
/// snapshot-based liveness tracking, and deferred segment deletion.
pub struct SegmentManager<D: DirectoryWriter + 'static> {
    // Metadata + merge policy, guarded by an async mutex.
    state: AsyncMutex<ManagerState>,

    // Segment IDs reserved by in-flight merges (inputs and outputs).
    merge_inventory: Arc<MergeInventory>,

    // Join handles for spawned background merge tasks.
    merge_handles: AsyncMutex<Vec<JoinHandle<()>>>,

    // Lock-free, atomically swappable trained vector structures (None until
    // published by load_and_publish_trained).
    trained: ArcSwapOption<TrainedVectorStructures>,

    // Reference-counts live segments so deletion can be deferred while
    // snapshots still reference them.
    tracker: Arc<SegmentTracker>,

    // Callback handed to snapshots to delete segments once unpinned.
    delete_fn: Arc<dyn Fn(Vec<SegmentId>) + Send + Sync>,

    directory: Arc<D>,
    schema: Arc<crate::dsl::Schema>,
    term_cache_blocks: usize,
    // Clamped to >= 1 in new().
    max_concurrent_merges: usize,
}
159
160impl<D: DirectoryWriter + 'static> SegmentManager<D> {
161 pub fn new(
163 directory: Arc<D>,
164 schema: Arc<crate::dsl::Schema>,
165 metadata: IndexMetadata,
166 merge_policy: Box<dyn MergePolicy>,
167 term_cache_blocks: usize,
168 max_concurrent_merges: usize,
169 ) -> Self {
170 let tracker = Arc::new(SegmentTracker::new());
171 for seg_id in metadata.segment_metas.keys() {
172 tracker.register(seg_id);
173 }
174
175 let delete_fn: Arc<dyn Fn(Vec<SegmentId>) + Send + Sync> = {
176 let dir = Arc::clone(&directory);
177 Arc::new(move |segment_ids| {
178 let Ok(handle) = tokio::runtime::Handle::try_current() else {
181 return;
182 };
183 let dir = Arc::clone(&dir);
184 handle.spawn(async move {
185 for segment_id in segment_ids {
186 log::info!(
187 "[segment_cleanup] deleting deferred segment {}",
188 segment_id.0
189 );
190 let _ = crate::segment::delete_segment(dir.as_ref(), segment_id).await;
191 }
192 });
193 })
194 };
195
196 Self {
197 state: AsyncMutex::new(ManagerState {
198 metadata,
199 merge_policy,
200 }),
201 merge_inventory: Arc::new(MergeInventory::new()),
202 merge_handles: AsyncMutex::new(Vec::new()),
203 trained: ArcSwapOption::new(None),
204 tracker,
205 delete_fn,
206 directory,
207 schema,
208 term_cache_blocks,
209 max_concurrent_merges: max_concurrent_merges.max(1),
210 }
211 }
212
213 pub async fn get_segment_ids(&self) -> Vec<String> {
219 self.state.lock().await.metadata.segment_ids()
220 }
221
222 pub fn trained(&self) -> Option<Arc<TrainedVectorStructures>> {
224 self.trained.load_full()
225 }
226
227 pub async fn load_and_publish_trained(&self) {
230 let vector_fields = {
232 let st = self.state.lock().await;
233 st.metadata.vector_fields.clone()
234 };
235 let trained =
237 IndexMetadata::load_trained_from_fields(&vector_fields, self.directory.as_ref()).await;
238 if let Some(t) = trained {
239 self.trained.store(Some(Arc::new(t)));
240 }
241 }
242
243 pub(crate) fn clear_trained(&self) {
245 self.trained.store(None);
246 }
247
248 pub(crate) async fn read_metadata<F, R>(&self, f: F) -> R
250 where
251 F: FnOnce(&IndexMetadata) -> R,
252 {
253 let st = self.state.lock().await;
254 f(&st.metadata)
255 }
256
257 pub(crate) async fn update_metadata<F>(&self, f: F) -> Result<()>
259 where
260 F: FnOnce(&mut IndexMetadata),
261 {
262 let mut st = self.state.lock().await;
263 f(&mut st.metadata);
264 st.metadata.save(self.directory.as_ref()).await
265 }
266
267 pub async fn acquire_snapshot(&self) -> SegmentSnapshot {
270 let acquired = {
271 let st = self.state.lock().await;
272 let segment_ids = st.metadata.segment_ids();
273 self.tracker.acquire(&segment_ids)
274 };
275
276 SegmentSnapshot::with_delete_fn(
277 Arc::clone(&self.tracker),
278 acquired,
279 Arc::clone(&self.delete_fn),
280 )
281 }
282
283 pub fn tracker(&self) -> Arc<SegmentTracker> {
285 Arc::clone(&self.tracker)
286 }
287
288 pub fn directory(&self) -> Arc<D> {
290 Arc::clone(&self.directory)
291 }
292}
293
294#[cfg(feature = "native")]
299impl<D: DirectoryWriter + 'static> SegmentManager<D> {
300 pub async fn commit(&self, new_segments: Vec<(String, u32)>) -> Result<()> {
302 let mut st = self.state.lock().await;
303 for (segment_id, num_docs) in new_segments {
304 if !st.metadata.has_segment(&segment_id) {
305 st.metadata.add_segment(segment_id.clone(), num_docs);
306 self.tracker.register(&segment_id);
307 }
308 }
309 st.metadata.save(self.directory.as_ref()).await
310 }
311
312 pub async fn maybe_merge(self: &Arc<Self>) {
324 let slots_available = {
326 let mut handles = self.merge_handles.lock().await;
327 handles.retain(|h| !h.is_finished());
328 self.max_concurrent_merges.saturating_sub(handles.len())
329 };
330
331 if slots_available == 0 {
332 log::debug!("[maybe_merge] at max concurrent merges, skipping");
333 return;
334 }
335
336 let new_handles = {
340 let st = self.state.lock().await;
341
342 let segments: Vec<SegmentInfo> = st
344 .metadata
345 .segment_metas
346 .iter()
347 .filter(|(id, _)| {
348 !self.tracker.is_pending_deletion(id) && !self.merge_inventory.contains(id)
349 })
350 .map(|(id, info)| SegmentInfo {
351 id: id.clone(),
352 num_docs: info.num_docs,
353 })
354 .collect();
355
356 log::debug!("[maybe_merge] {} eligible segments", segments.len());
357
358 let candidates = st.merge_policy.find_merges(&segments);
359
360 if candidates.is_empty() {
361 return;
362 }
363
364 log::debug!(
365 "[maybe_merge] {} merge candidates, {} slots available",
366 candidates.len(),
367 slots_available
368 );
369
370 let mut handles = Vec::new();
371 for c in candidates {
372 if handles.len() >= slots_available {
373 break;
374 }
375 if let Some(h) = self.spawn_merge(c.segment_ids) {
376 handles.push(h);
377 }
378 }
379 handles
380 };
382
383 if !new_handles.is_empty() {
384 self.merge_handles.lock().await.extend(new_handles);
385 }
386 }
387
388 fn spawn_merge(self: &Arc<Self>, segment_ids_to_merge: Vec<String>) -> Option<JoinHandle<()>> {
397 let output_id = SegmentId::new();
398 let output_hex = output_id.to_hex();
399
400 let mut all_ids = segment_ids_to_merge.clone();
401 all_ids.push(output_hex);
402
403 let guard = match self.merge_inventory.try_register(all_ids) {
404 Some(g) => g,
405 None => {
406 log::debug!("[spawn_merge] skipped: segments overlap with active merge");
407 return None;
408 }
409 };
410
411 let sm = Arc::clone(self);
412 let ids = segment_ids_to_merge;
413
414 Some(tokio::spawn(async move {
415 let _guard = guard;
416
417 let trained_snap = sm.trained();
418 let result = Self::do_merge(
419 sm.directory.as_ref(),
420 &sm.schema,
421 &ids,
422 output_id,
423 sm.term_cache_blocks,
424 trained_snap.as_deref(),
425 false,
426 )
427 .await;
428
429 match result {
430 Ok((new_id, doc_count)) => {
431 if let Err(e) = sm.replace_segments(&ids, new_id, doc_count).await {
432 log::error!("[merge] Failed to replace segments after merge: {:?}", e);
433 }
434 }
435 Err(e) => {
436 log::error!(
437 "[merge] Background merge failed for segments {:?}: {:?}",
438 ids,
439 e
440 );
441 }
442 }
443 sm.maybe_merge().await;
448 }))
449 }
450
451 async fn replace_segments(
454 &self,
455 old_ids: &[String],
456 new_id: String,
457 doc_count: u32,
458 ) -> Result<()> {
459 self.tracker.register(&new_id);
460
461 {
462 let mut st = self.state.lock().await;
463 let parent_gen = old_ids
465 .iter()
466 .filter_map(|id| st.metadata.segment_metas.get(id))
467 .map(|info| info.generation)
468 .max()
469 .unwrap_or(0);
470 let ancestors: Vec<String> = old_ids.to_vec();
471
472 for id in old_ids {
473 st.metadata.remove_segment(id);
474 }
475 st.metadata
476 .add_merged_segment(new_id, doc_count, ancestors, parent_gen + 1);
477 st.metadata.save(self.directory.as_ref()).await?;
479 }
480
481 let ready_to_delete = self.tracker.mark_for_deletion(old_ids);
482 for segment_id in ready_to_delete {
483 let _ = crate::segment::delete_segment(self.directory.as_ref(), segment_id).await;
484 }
485 Ok(())
486 }
487
488 pub(crate) async fn do_merge(
493 directory: &D,
494 schema: &Arc<crate::dsl::Schema>,
495 segment_ids_to_merge: &[String],
496 output_segment_id: SegmentId,
497 term_cache_blocks: usize,
498 trained: Option<&TrainedVectorStructures>,
499 force_reorder: bool,
500 ) -> Result<(String, u32)> {
501 let output_hex = output_segment_id.to_hex();
502 let load_start = std::time::Instant::now();
503
504 let segment_ids: Vec<SegmentId> = segment_ids_to_merge
505 .iter()
506 .map(|id_str| {
507 SegmentId::from_hex(id_str)
508 .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))
509 })
510 .collect::<Result<Vec<_>>>()?;
511
512 let schema_arc = Arc::clone(schema);
513 let futures: Vec<_> = segment_ids
514 .iter()
515 .map(|&sid| {
516 let sch = Arc::clone(&schema_arc);
517 async move { SegmentReader::open(directory, sid, sch, term_cache_blocks).await }
518 })
519 .collect();
520
521 let results = futures::future::join_all(futures).await;
522 let mut readers = Vec::with_capacity(results.len());
523 let mut total_docs = 0u64;
524 for (i, result) in results.into_iter().enumerate() {
525 match result {
526 Ok(r) => {
527 total_docs += r.meta().num_docs as u64;
528 readers.push(r);
529 }
530 Err(e) => {
531 log::error!(
532 "[merge] Failed to open segment {}: {:?}",
533 segment_ids_to_merge[i],
534 e
535 );
536 return Err(e);
537 }
538 }
539 }
540
541 for (i, reader) in readers.iter().enumerate() {
545 let meta_docs = reader.meta().num_docs;
546 let store_docs = reader.store().num_docs();
547 if store_docs != meta_docs {
548 return Err(Error::Corruption(format!(
549 "pre-merge validation: segment {} store has {} docs but meta says {}",
550 segment_ids_to_merge[i], store_docs, meta_docs
551 )));
552 }
553 }
554
555 log::info!(
556 "[merge] loaded {} segment readers in {:.1}s",
557 readers.len(),
558 load_start.elapsed().as_secs_f64()
559 );
560
561 let merger = SegmentMerger::new(Arc::clone(schema)).with_force_reorder(force_reorder);
562
563 log::info!(
564 "[merge] {} segments -> {} (trained={}, force_reorder={})",
565 segment_ids_to_merge.len(),
566 output_hex,
567 trained.map_or(0, |t| t.centroids.len()),
568 force_reorder,
569 );
570
571 merger
572 .merge(directory, &readers, output_segment_id, trained)
573 .await?;
574
575 log::info!(
576 "[merge] total wall-clock: {:.1}s ({} segments, {} docs)",
577 load_start.elapsed().as_secs_f64(),
578 readers.len(),
579 total_docs,
580 );
581
582 if total_docs > u32::MAX as u64 {
583 return Err(Error::Internal(format!(
584 "Merged segment doc count ({}) exceeds u32::MAX",
585 total_docs
586 )));
587 }
588 Ok((output_hex, total_docs as u32))
589 }
590
591 pub async fn abort_merges(&self) {
594 let handles: Vec<JoinHandle<()>> =
595 { std::mem::take(&mut *self.merge_handles.lock().await) };
596 for h in handles {
597 h.abort();
598 }
599 }
600
601 pub async fn wait_for_merging_thread(self: &Arc<Self>) {
603 let handles: Vec<JoinHandle<()>> =
604 { std::mem::take(&mut *self.merge_handles.lock().await) };
605 for h in handles {
606 let _ = h.await;
607 }
608 }
609
610 pub async fn wait_for_all_merges(self: &Arc<Self>) {
616 loop {
617 let handles: Vec<JoinHandle<()>> =
618 { std::mem::take(&mut *self.merge_handles.lock().await) };
619 if handles.is_empty() {
620 break;
621 }
622 for h in handles {
623 let _ = h.await;
624 }
625 }
626 }
627
628 pub async fn force_merge(self: &Arc<Self>) -> Result<()> {
633 const FORCE_MERGE_BATCH: usize = 64;
634
635 self.wait_for_all_merges().await;
638
639 loop {
640 let ids_to_merge = self.get_segment_ids().await;
641 if ids_to_merge.len() < 2 {
642 return Ok(());
643 }
644
645 let batch: Vec<String> = ids_to_merge.into_iter().take(FORCE_MERGE_BATCH).collect();
646
647 log::info!("[force_merge] merging batch of {} segments", batch.len());
648
649 let output_id = SegmentId::new();
650 let output_hex = output_id.to_hex();
651
652 let mut all_ids = batch.clone();
654 all_ids.push(output_hex);
655 let _guard = match self.merge_inventory.try_register(all_ids) {
656 Some(g) => g,
657 None => {
658 self.wait_for_merging_thread().await;
660 continue;
661 }
662 };
663
664 let trained_snap = self.trained();
665 let (new_segment_id, total_docs) = Self::do_merge(
666 self.directory.as_ref(),
667 &self.schema,
668 &batch,
669 output_id,
670 self.term_cache_blocks,
671 trained_snap.as_deref(),
672 false,
673 )
674 .await?;
675
676 self.replace_segments(&batch, new_segment_id, total_docs)
677 .await?;
678
679 }
681 }
682
683 pub async fn reorder_segments(self: &Arc<Self>) -> Result<()> {
693 self.wait_for_all_merges().await;
694 let segment_ids = self.get_segment_ids().await;
695
696 if segment_ids.is_empty() {
697 log::info!("[reorder] no segments to reorder");
698 return Ok(());
699 }
700
701 log::info!("[reorder] reordering {} segments", segment_ids.len());
702
703 for seg_id in segment_ids {
704 let output_id = SegmentId::new();
705 let output_hex = output_id.to_hex();
706
707 let all_ids = vec![seg_id.clone(), output_hex];
708 let _guard = match self.merge_inventory.try_register(all_ids) {
709 Some(g) => g,
710 None => {
711 log::warn!("[reorder] segment {} in active merge, skipping", seg_id);
712 continue;
713 }
714 };
715
716 let trained_snap = self.trained();
717 let (new_id, total_docs) = Self::do_merge(
718 self.directory.as_ref(),
719 &self.schema,
720 std::slice::from_ref(&seg_id),
721 output_id,
722 self.term_cache_blocks,
723 trained_snap.as_deref(),
724 true, )
726 .await?;
727
728 self.replace_segments(&[seg_id], new_id, total_docs).await?;
729 }
730
731 log::info!("[reorder] all segments reordered");
732 Ok(())
733 }
734
735 pub async fn cleanup_orphan_segments(&self) -> Result<usize> {
740 let (registered_set, in_merge_set) = {
744 let st = self.state.lock().await;
745 let registered = st
746 .metadata
747 .segment_metas
748 .keys()
749 .cloned()
750 .collect::<HashSet<String>>();
751 let in_merge = self.merge_inventory.snapshot();
752 (registered, in_merge)
753 };
754
755 let mut orphan_ids: HashSet<String> = HashSet::new();
756
757 if let Ok(entries) = self.directory.list_files(std::path::Path::new("")).await {
758 for entry in entries {
759 let filename = entry.to_string_lossy();
760 if filename.starts_with("seg_") && filename.len() > 37 {
761 let hex_part = &filename[4..36];
762 if !registered_set.contains(hex_part) && !in_merge_set.contains(hex_part) {
763 orphan_ids.insert(hex_part.to_string());
764 }
765 }
766 }
767 }
768
769 let mut deleted = 0;
770 for hex_id in &orphan_ids {
771 if let Some(segment_id) = SegmentId::from_hex(hex_id)
772 && crate::segment::delete_segment(self.directory.as_ref(), segment_id)
773 .await
774 .is_ok()
775 {
776 deleted += 1;
777 }
778 }
779
780 Ok(deleted)
781 }
782}
783
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inventory_guard_drop_unregisters() {
        let inventory = Arc::new(MergeInventory::new());
        {
            let _guard = inventory
                .try_register(vec!["a".to_string(), "b".to_string()])
                .unwrap();
            let active = inventory.snapshot();
            assert!(active.contains("a"));
            assert!(active.contains("b"));
        }
        // Guard dropped above: the inventory must be empty again.
        assert!(inventory.snapshot().is_empty());
    }

    #[test]
    fn test_inventory_concurrent_non_overlapping_merges() {
        let inventory = Arc::new(MergeInventory::new());
        let guard_ab = inventory
            .try_register(vec!["a".to_string(), "b".to_string()])
            .unwrap();
        let _guard_cd = inventory
            .try_register(vec!["c".to_string(), "d".to_string()])
            .unwrap();
        assert_eq!(inventory.snapshot().len(), 4);

        // Releasing the first merge leaves only the second's IDs.
        drop(guard_ab);
        let remaining = inventory.snapshot();
        assert_eq!(remaining.len(), 2);
        assert!(remaining.contains("c"));
        assert!(remaining.contains("d"));
    }

    #[test]
    fn test_inventory_overlapping_merge_rejected() {
        let inventory = Arc::new(MergeInventory::new());
        let held = inventory
            .try_register(vec!["a".to_string(), "b".to_string()])
            .unwrap();
        // "b" is still reserved, so the overlapping request must fail.
        assert!(inventory
            .try_register(vec!["b".to_string(), "c".to_string()])
            .is_none());
        drop(held);
        // Once released, the identical request succeeds.
        assert!(inventory
            .try_register(vec!["b".to_string(), "c".to_string()])
            .is_some());
    }

    #[test]
    fn test_inventory_snapshot() {
        let inventory = Arc::new(MergeInventory::new());
        let _held = inventory
            .try_register(vec!["x".to_string(), "y".to_string()])
            .unwrap();
        let snapshot = inventory.snapshot();
        assert!(snapshot.contains("x"));
        assert!(snapshot.contains("y"));
        assert!(!snapshot.contains("z"));
    }
}