1use std::collections::HashSet;
28use std::sync::Arc;
29
30use arc_swap::ArcSwapOption;
31use tokio::sync::Mutex as AsyncMutex;
32use tokio::task::JoinHandle;
33
34use crate::directories::DirectoryWriter;
35use crate::error::{Error, Result};
36use crate::index::IndexMetadata;
37use crate::segment::{SegmentId, SegmentSnapshot, SegmentTracker, TrainedVectorStructures};
38#[cfg(feature = "native")]
39use crate::segment::{SegmentMerger, SegmentReader};
40
41use super::{MergePolicy, SegmentInfo};
42
/// Tracks segment IDs that are currently involved in an in-flight merge
/// (both the merge inputs and the reserved output ID), so that overlapping
/// merges are never scheduled concurrently.
struct MergeInventory {
    // Synchronous lock is fine here: every critical section is short and
    // never awaits.
    inner: parking_lot::Mutex<HashSet<String>>,
}
56
57impl MergeInventory {
58 fn new() -> Self {
59 Self {
60 inner: parking_lot::Mutex::new(HashSet::new()),
61 }
62 }
63
64 fn try_register(self: &Arc<Self>, segment_ids: Vec<String>) -> Option<MergeGuard> {
67 let mut inner = self.inner.lock();
68 for id in &segment_ids {
70 if inner.contains(id) {
71 return None;
72 }
73 }
74 for id in &segment_ids {
75 inner.insert(id.clone());
76 }
77 Some(MergeGuard {
78 inventory: Arc::clone(self),
79 segment_ids,
80 })
81 }
82
83 fn snapshot(&self) -> HashSet<String> {
85 self.inner.lock().clone()
86 }
87
88 fn contains(&self, segment_id: &str) -> bool {
90 self.inner.lock().contains(segment_id)
91 }
92}
93
/// RAII guard over a set of segment IDs reserved in a [`MergeInventory`].
/// Dropping it releases the reservation, including when the owning merge
/// task fails or panics.
struct MergeGuard {
    inventory: Arc<MergeInventory>,
    segment_ids: Vec<String>,
}
101
102impl Drop for MergeGuard {
103 fn drop(&mut self) {
104 let mut inner = self.inventory.inner.lock();
105 for id in &self.segment_ids {
106 inner.remove(id);
107 }
108 }
109}
110
/// Mutable state protected by the manager's async lock: the index metadata
/// (persisted on every change) plus the policy that picks merge candidates.
struct ManagerState {
    metadata: IndexMetadata,
    merge_policy: Box<dyn MergePolicy>,
}
116
pub struct SegmentManager<D: DirectoryWriter + 'static> {
    // Metadata + merge policy, guarded by an async mutex because saves
    // happen while the lock is held.
    state: AsyncMutex<ManagerState>,

    // Segments currently participating in an in-flight merge.
    merge_inventory: Arc<MergeInventory>,

    // Join handles for spawned background merge tasks.
    merge_handles: AsyncMutex<Vec<JoinHandle<()>>>,

    // Published trained vector structures (lock-free read path).
    trained: ArcSwapOption<TrainedVectorStructures>,

    // Reference-counts live segments and defers file deletion until the
    // last reader releases them.
    tracker: Arc<SegmentTracker>,

    // Callback handed to snapshots to delete segment files once unused.
    delete_fn: Arc<dyn Fn(Vec<SegmentId>) + Send + Sync>,

    directory: Arc<D>,
    schema: Arc<crate::dsl::Schema>,
    term_cache_blocks: usize,
    // Upper bound on simultaneously running merge tasks (clamped to >= 1).
    max_concurrent_merges: usize,
}
149
150impl<D: DirectoryWriter + 'static> SegmentManager<D> {
151 pub fn new(
153 directory: Arc<D>,
154 schema: Arc<crate::dsl::Schema>,
155 metadata: IndexMetadata,
156 merge_policy: Box<dyn MergePolicy>,
157 term_cache_blocks: usize,
158 max_concurrent_merges: usize,
159 ) -> Self {
160 let tracker = Arc::new(SegmentTracker::new());
161 for seg_id in metadata.segment_metas.keys() {
162 tracker.register(seg_id);
163 }
164
165 let delete_fn: Arc<dyn Fn(Vec<SegmentId>) + Send + Sync> = {
166 let dir = Arc::clone(&directory);
167 Arc::new(move |segment_ids| {
168 let Ok(handle) = tokio::runtime::Handle::try_current() else {
171 return;
172 };
173 let dir = Arc::clone(&dir);
174 handle.spawn(async move {
175 for segment_id in segment_ids {
176 log::info!(
177 "[segment_cleanup] deleting deferred segment {}",
178 segment_id.0
179 );
180 let _ = crate::segment::delete_segment(dir.as_ref(), segment_id).await;
181 }
182 });
183 })
184 };
185
186 Self {
187 state: AsyncMutex::new(ManagerState {
188 metadata,
189 merge_policy,
190 }),
191 merge_inventory: Arc::new(MergeInventory::new()),
192 merge_handles: AsyncMutex::new(Vec::new()),
193 trained: ArcSwapOption::new(None),
194 tracker,
195 delete_fn,
196 directory,
197 schema,
198 term_cache_blocks,
199 max_concurrent_merges: max_concurrent_merges.max(1),
200 }
201 }
202
203 pub async fn get_segment_ids(&self) -> Vec<String> {
209 self.state.lock().await.metadata.segment_ids()
210 }
211
212 pub fn trained(&self) -> Option<Arc<TrainedVectorStructures>> {
214 self.trained.load_full()
215 }
216
217 pub async fn load_and_publish_trained(&self) {
220 let vector_fields = {
222 let st = self.state.lock().await;
223 st.metadata.vector_fields.clone()
224 };
225 let trained =
227 IndexMetadata::load_trained_from_fields(&vector_fields, self.directory.as_ref()).await;
228 if let Some(t) = trained {
229 self.trained.store(Some(Arc::new(t)));
230 }
231 }
232
233 pub(crate) fn clear_trained(&self) {
235 self.trained.store(None);
236 }
237
238 pub(crate) async fn read_metadata<F, R>(&self, f: F) -> R
240 where
241 F: FnOnce(&IndexMetadata) -> R,
242 {
243 let st = self.state.lock().await;
244 f(&st.metadata)
245 }
246
247 pub(crate) async fn update_metadata<F>(&self, f: F) -> Result<()>
249 where
250 F: FnOnce(&mut IndexMetadata),
251 {
252 let mut st = self.state.lock().await;
253 f(&mut st.metadata);
254 st.metadata.save(self.directory.as_ref()).await
255 }
256
257 pub async fn acquire_snapshot(&self) -> SegmentSnapshot {
260 let acquired = {
261 let st = self.state.lock().await;
262 let segment_ids = st.metadata.segment_ids();
263 self.tracker.acquire(&segment_ids)
264 };
265
266 SegmentSnapshot::with_delete_fn(
267 Arc::clone(&self.tracker),
268 acquired,
269 Arc::clone(&self.delete_fn),
270 )
271 }
272
273 pub fn tracker(&self) -> Arc<SegmentTracker> {
275 Arc::clone(&self.tracker)
276 }
277
278 pub fn directory(&self) -> Arc<D> {
280 Arc::clone(&self.directory)
281 }
282}
283
284#[cfg(feature = "native")]
289impl<D: DirectoryWriter + 'static> SegmentManager<D> {
290 pub async fn commit(&self, new_segments: Vec<(String, u32)>) -> Result<()> {
292 let mut st = self.state.lock().await;
293 for (segment_id, num_docs) in new_segments {
294 if !st.metadata.has_segment(&segment_id) {
295 st.metadata.add_segment(segment_id.clone(), num_docs);
296 self.tracker.register(&segment_id);
297 }
298 }
299 st.metadata.save(self.directory.as_ref()).await
300 }
301
302 pub async fn maybe_merge(self: &Arc<Self>) {
313 let slots_available = {
315 let mut handles = self.merge_handles.lock().await;
316 handles.retain(|h| !h.is_finished());
317 self.max_concurrent_merges.saturating_sub(handles.len())
318 };
319
320 if slots_available == 0 {
321 log::debug!("[maybe_merge] at max concurrent merges, skipping");
322 return;
323 }
324
325 let candidates = {
326 let st = self.state.lock().await;
327
328 let segments: Vec<SegmentInfo> = st
333 .metadata
334 .segment_metas
335 .iter()
336 .filter(|(id, _)| {
337 !self.tracker.is_pending_deletion(id) && !self.merge_inventory.contains(id)
338 })
339 .map(|(id, info)| SegmentInfo {
340 id: id.clone(),
341 num_docs: info.num_docs,
342 })
343 .collect();
344
345 log::debug!("[maybe_merge] {} eligible segments", segments.len());
346
347 st.merge_policy.find_merges(&segments)
348 };
349
350 if candidates.is_empty() {
351 return;
352 }
353
354 log::debug!(
355 "[maybe_merge] {} merge candidates, {} slots available",
356 candidates.len(),
357 slots_available
358 );
359
360 let mut new_handles = Vec::new();
361 for c in candidates {
362 if new_handles.len() >= slots_available {
363 break;
364 }
365 if let Some(h) = self.spawn_merge(c.segment_ids) {
366 new_handles.push(h);
367 }
368 }
369 if !new_handles.is_empty() {
370 self.merge_handles.lock().await.extend(new_handles);
371 }
372 }
373
374 fn spawn_merge(self: &Arc<Self>, segment_ids_to_merge: Vec<String>) -> Option<JoinHandle<()>> {
383 let output_id = SegmentId::new();
384 let output_hex = output_id.to_hex();
385
386 let mut all_ids = segment_ids_to_merge.clone();
387 all_ids.push(output_hex);
388
389 let guard = match self.merge_inventory.try_register(all_ids) {
390 Some(g) => g,
391 None => {
392 log::debug!("[spawn_merge] skipped: segments overlap with active merge");
393 return None;
394 }
395 };
396
397 let sm = Arc::clone(self);
398 let ids = segment_ids_to_merge;
399
400 Some(tokio::spawn(async move {
401 let _guard = guard;
402
403 let trained_snap = sm.trained();
404 let result = Self::do_merge(
405 sm.directory.as_ref(),
406 &sm.schema,
407 &ids,
408 output_id,
409 sm.term_cache_blocks,
410 trained_snap.as_deref(),
411 )
412 .await;
413
414 match result {
415 Ok((new_id, doc_count)) => {
416 if let Err(e) = sm.replace_segments(&ids, new_id, doc_count).await {
417 log::error!("[merge] Failed to replace segments after merge: {:?}", e);
418 }
419 }
420 Err(e) => {
421 log::error!(
422 "[merge] Background merge failed for segments {:?}: {:?}",
423 ids,
424 e
425 );
426 }
427 }
428 sm.maybe_merge().await;
433 }))
434 }
435
436 async fn replace_segments(
439 &self,
440 old_ids: &[String],
441 new_id: String,
442 doc_count: u32,
443 ) -> Result<()> {
444 self.tracker.register(&new_id);
445
446 {
447 let mut st = self.state.lock().await;
448 let parent_gen = old_ids
450 .iter()
451 .filter_map(|id| st.metadata.segment_metas.get(id))
452 .map(|info| info.generation)
453 .max()
454 .unwrap_or(0);
455 let ancestors: Vec<String> = old_ids.to_vec();
456
457 for id in old_ids {
458 st.metadata.remove_segment(id);
459 }
460 st.metadata
461 .add_merged_segment(new_id, doc_count, ancestors, parent_gen + 1);
462 st.metadata.save(self.directory.as_ref()).await?;
463 }
464
465 let ready_to_delete = self.tracker.mark_for_deletion(old_ids);
466 for segment_id in ready_to_delete {
467 let _ = crate::segment::delete_segment(self.directory.as_ref(), segment_id).await;
468 }
469 Ok(())
470 }
471
472 pub(crate) async fn do_merge(
476 directory: &D,
477 schema: &Arc<crate::dsl::Schema>,
478 segment_ids_to_merge: &[String],
479 output_segment_id: SegmentId,
480 term_cache_blocks: usize,
481 trained: Option<&TrainedVectorStructures>,
482 ) -> Result<(String, u32)> {
483 let output_hex = output_segment_id.to_hex();
484 let load_start = std::time::Instant::now();
485
486 let segment_ids: Vec<SegmentId> = segment_ids_to_merge
487 .iter()
488 .map(|id_str| {
489 SegmentId::from_hex(id_str)
490 .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))
491 })
492 .collect::<Result<Vec<_>>>()?;
493
494 let schema_arc = Arc::clone(schema);
495 let futures: Vec<_> = segment_ids
496 .iter()
497 .map(|&sid| {
498 let sch = Arc::clone(&schema_arc);
499 async move { SegmentReader::open(directory, sid, sch, 0, term_cache_blocks).await }
500 })
501 .collect();
502
503 let results = futures::future::join_all(futures).await;
504 let mut readers = Vec::with_capacity(results.len());
505 let mut total_docs = 0u64;
506 for (i, result) in results.into_iter().enumerate() {
507 match result {
508 Ok(r) => {
509 total_docs += r.meta().num_docs as u64;
510 readers.push(r);
511 }
512 Err(e) => {
513 log::error!(
514 "[merge] Failed to open segment {}: {:?}",
515 segment_ids_to_merge[i],
516 e
517 );
518 return Err(e);
519 }
520 }
521 }
522
523 log::info!(
524 "[merge] loaded {} segment readers in {:.1}s",
525 readers.len(),
526 load_start.elapsed().as_secs_f64()
527 );
528
529 let merger = SegmentMerger::new(Arc::clone(schema));
530
531 log::info!(
532 "[merge] {} segments -> {} (trained={})",
533 segment_ids_to_merge.len(),
534 output_hex,
535 trained.map_or(0, |t| t.centroids.len())
536 );
537
538 merger
539 .merge(directory, &readers, output_segment_id, trained)
540 .await?;
541
542 log::info!(
543 "[merge] total wall-clock: {:.1}s ({} segments, {} docs)",
544 load_start.elapsed().as_secs_f64(),
545 readers.len(),
546 total_docs,
547 );
548
549 Ok((output_hex, total_docs.min(u32::MAX as u64) as u32))
550 }
551
552 pub async fn wait_for_merging_thread(self: &Arc<Self>) {
554 let handles: Vec<JoinHandle<()>> =
555 { std::mem::take(&mut *self.merge_handles.lock().await) };
556 for h in handles {
557 let _ = h.await;
558 }
559 }
560
561 pub async fn wait_for_all_merges(self: &Arc<Self>) {
567 loop {
568 let handles: Vec<JoinHandle<()>> =
569 { std::mem::take(&mut *self.merge_handles.lock().await) };
570 if handles.is_empty() {
571 break;
572 }
573 for h in handles {
574 let _ = h.await;
575 }
576 }
577 }
578
579 pub async fn force_merge(self: &Arc<Self>) -> Result<()> {
584 const FORCE_MERGE_BATCH: usize = 64;
585
586 self.wait_for_all_merges().await;
589
590 loop {
591 let ids_to_merge = self.get_segment_ids().await;
592 if ids_to_merge.len() < 2 {
593 return Ok(());
594 }
595
596 let batch: Vec<String> = ids_to_merge.into_iter().take(FORCE_MERGE_BATCH).collect();
597
598 log::info!("[force_merge] merging batch of {} segments", batch.len());
599
600 let output_id = SegmentId::new();
601 let output_hex = output_id.to_hex();
602
603 let mut all_ids = batch.clone();
605 all_ids.push(output_hex);
606 let _guard = match self.merge_inventory.try_register(all_ids) {
607 Some(g) => g,
608 None => {
609 self.wait_for_merging_thread().await;
611 continue;
612 }
613 };
614
615 let trained_snap = self.trained();
616 let (new_segment_id, total_docs) = Self::do_merge(
617 self.directory.as_ref(),
618 &self.schema,
619 &batch,
620 output_id,
621 self.term_cache_blocks,
622 trained_snap.as_deref(),
623 )
624 .await?;
625
626 self.replace_segments(&batch, new_segment_id, total_docs)
627 .await?;
628
629 }
631 }
632
633 pub async fn cleanup_orphan_segments(&self) -> Result<usize> {
638 let (registered_set, in_merge_set) = {
642 let st = self.state.lock().await;
643 let registered = st
644 .metadata
645 .segment_metas
646 .keys()
647 .cloned()
648 .collect::<HashSet<String>>();
649 let in_merge = self.merge_inventory.snapshot();
650 (registered, in_merge)
651 };
652
653 let mut orphan_ids: HashSet<String> = HashSet::new();
654
655 if let Ok(entries) = self.directory.list_files(std::path::Path::new("")).await {
656 for entry in entries {
657 let filename = entry.to_string_lossy();
658 if filename.starts_with("seg_") && filename.len() > 37 {
659 let hex_part = &filename[4..36];
660 if !registered_set.contains(hex_part) && !in_merge_set.contains(hex_part) {
661 orphan_ids.insert(hex_part.to_string());
662 }
663 }
664 }
665 }
666
667 let mut deleted = 0;
668 for hex_id in &orphan_ids {
669 if let Some(segment_id) = SegmentId::from_hex(hex_id)
670 && crate::segment::delete_segment(self.directory.as_ref(), segment_id)
671 .await
672 .is_ok()
673 {
674 deleted += 1;
675 }
676 }
677
678 Ok(deleted)
679 }
680}
681
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_inventory_guard_drop_unregisters() {
        let inventory = Arc::new(MergeInventory::new());
        {
            let _guard = inventory
                .try_register(vec!["a".to_string(), "b".to_string()])
                .unwrap();
            let reserved = inventory.snapshot();
            assert!(reserved.contains("a"));
            assert!(reserved.contains("b"));
        }
        // Guard dropped above: every reservation must be gone.
        assert!(inventory.snapshot().is_empty());
    }

    #[test]
    fn test_inventory_concurrent_non_overlapping_merges() {
        let inventory = Arc::new(MergeInventory::new());
        let guard_ab = inventory
            .try_register(vec!["a".to_string(), "b".to_string()])
            .unwrap();
        let _guard_cd = inventory
            .try_register(vec!["c".to_string(), "d".to_string()])
            .unwrap();
        assert_eq!(inventory.snapshot().len(), 4);

        drop(guard_ab);
        let reserved = inventory.snapshot();
        assert_eq!(reserved.len(), 2);
        assert!(reserved.contains("c"));
        assert!(reserved.contains("d"));
    }

    #[test]
    fn test_inventory_overlapping_merge_rejected() {
        let inventory = Arc::new(MergeInventory::new());
        let guard = inventory
            .try_register(vec!["a".to_string(), "b".to_string()])
            .unwrap();
        // "b" is still reserved, so this overlapping batch must fail.
        assert!(
            inventory
                .try_register(vec!["b".to_string(), "c".to_string()])
                .is_none()
        );
        drop(guard);
        // Once released, the same batch is accepted.
        assert!(
            inventory
                .try_register(vec!["b".to_string(), "c".to_string()])
                .is_some()
        );
    }

    #[test]
    fn test_inventory_snapshot() {
        let inventory = Arc::new(MergeInventory::new());
        let _guard = inventory
            .try_register(vec!["x".to_string(), "y".to_string()])
            .unwrap();
        let reserved = inventory.snapshot();
        assert!(reserved.contains("x"));
        assert!(reserved.contains("y"));
        assert!(!reserved.contains("z"));
    }
}