use std::collections::HashSet;
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use tokio::sync::Mutex as AsyncMutex;
use tokio::task::JoinHandle;
use crate::directories::DirectoryWriter;
use crate::error::{Error, Result};
use crate::index::IndexMetadata;
use crate::segment::{SegmentId, SegmentSnapshot, SegmentTracker, TrainedVectorStructures};
#[cfg(feature = "native")]
use crate::segment::{SegmentMerger, SegmentReader};
use super::{MergePolicy, SegmentInfo};
/// Tracks the set of segment IDs currently involved in an in-flight merge
/// (merge inputs *and* the planned output ID), so overlapping merges can be
/// rejected before any work starts and orphan cleanup can skip in-progress
/// outputs.
struct MergeInventory {
    // Synchronous lock: only held for short, non-await critical sections.
    inner: parking_lot::Mutex<HashSet<String>>,
}
impl MergeInventory {
    /// Creates an empty inventory with no active merges.
    fn new() -> Self {
        Self {
            inner: parking_lot::Mutex::new(HashSet::new()),
        }
    }

    /// Atomically registers `segment_ids` as belonging to one merge.
    ///
    /// Fails (returns `None`) if any ID is already part of an active merge.
    /// On success the returned [`MergeGuard`] owns the registration and
    /// unregisters all IDs when dropped.
    fn try_register(self: &Arc<Self>, segment_ids: Vec<String>) -> Option<MergeGuard> {
        let mut active = self.inner.lock();
        // Reject the whole batch on the first overlapping ID.
        if let Some(conflict) = segment_ids.iter().find(|id| active.contains(id.as_str())) {
            log::debug!(
                "[merge_inventory] rejected: {} overlaps with active merge ({} active IDs)",
                conflict,
                active.len()
            );
            return None;
        }
        log::debug!(
            "[merge_inventory] registered {} IDs (total active: {})",
            segment_ids.len(),
            active.len() + segment_ids.len()
        );
        active.extend(segment_ids.iter().cloned());
        Some(MergeGuard {
            inventory: Arc::clone(self),
            segment_ids,
        })
    }

    /// Returns a point-in-time copy of all registered IDs.
    fn snapshot(&self) -> HashSet<String> {
        let active = self.inner.lock();
        active.clone()
    }

    /// Whether `segment_id` is currently part of an active merge.
    fn contains(&self, segment_id: &str) -> bool {
        let active = self.inner.lock();
        active.contains(segment_id)
    }
}
/// RAII registration handle returned by [`MergeInventory::try_register`].
/// Dropping it removes its segment IDs from the inventory, even if the
/// owning merge task panics or is aborted.
struct MergeGuard {
    inventory: Arc<MergeInventory>,
    segment_ids: Vec<String>,
}
impl Drop for MergeGuard {
    /// Unregisters every ID this guard holds from the shared inventory.
    fn drop(&mut self) {
        let mut active = self.inventory.inner.lock();
        for id in self.segment_ids.drain(..) {
            active.remove(&id);
        }
    }
}
/// Mutable manager state protected by a single async mutex: the index
/// metadata (persisted to the directory) and the policy deciding which
/// segments should be merged.
struct ManagerState {
    metadata: IndexMetadata,
    merge_policy: Box<dyn MergePolicy>,
}
/// Coordinates the segment lifecycle for an index: commits of new segments,
/// background merges, segment reordering, and cleanup of orphaned files.
pub struct SegmentManager<D: DirectoryWriter + 'static> {
    // Async mutex because the metadata is held across `.await` points
    // (e.g. while saving to the directory).
    state: AsyncMutex<ManagerState>,
    // IDs (inputs + outputs) of merges currently in flight.
    merge_inventory: Arc<MergeInventory>,
    // Join handles for spawned background merge tasks.
    merge_handles: AsyncMutex<Vec<JoinHandle<()>>>,
    // Lock-free, atomically swappable snapshot of trained vector structures.
    trained: ArcSwapOption<TrainedVectorStructures>,
    // Tracks acquisitions of segments; gates physical deletion
    // (see `acquire_snapshot` / `replace_segments`).
    tracker: Arc<SegmentTracker>,
    // Deferred-deletion callback used by released snapshots; spawns the
    // actual deletes on the ambient tokio runtime.
    delete_fn: Arc<dyn Fn(Vec<SegmentId>) + Send + Sync>,
    directory: Arc<D>,
    schema: Arc<crate::dsl::Schema>,
    // Passed through to `SegmentReader::open`; presumably sizes the term
    // dictionary cache — confirm against SegmentReader docs.
    term_cache_blocks: usize,
    // Upper bound on simultaneously running merge tasks (clamped to >= 1).
    max_concurrent_merges: usize,
}
impl<D: DirectoryWriter + 'static> SegmentManager<D> {
    /// Creates a manager over `directory`, registering every segment already
    /// recorded in `metadata` with the tracker.
    ///
    /// `max_concurrent_merges` is clamped to a minimum of 1. The deferred
    /// delete callback spawns segment deletion onto the current tokio
    /// runtime; when invoked outside a runtime it is a silent no-op (the
    /// files remain until a later cleanup pass).
    pub fn new(
        directory: Arc<D>,
        schema: Arc<crate::dsl::Schema>,
        metadata: IndexMetadata,
        merge_policy: Box<dyn MergePolicy>,
        term_cache_blocks: usize,
        max_concurrent_merges: usize,
    ) -> Self {
        let tracker = Arc::new(SegmentTracker::new());
        // Pre-register all known segments so snapshots can acquire them.
        for seg_id in metadata.segment_metas.keys() {
            tracker.register(seg_id);
        }
        let delete_fn: Arc<dyn Fn(Vec<SegmentId>) + Send + Sync> = {
            let dir = Arc::clone(&directory);
            Arc::new(move |segment_ids| {
                // Best effort: without an ambient runtime we cannot spawn.
                let Ok(handle) = tokio::runtime::Handle::try_current() else {
                    return;
                };
                let dir = Arc::clone(&dir);
                handle.spawn(async move {
                    for segment_id in segment_ids {
                        log::info!(
                            "[segment_cleanup] deleting deferred segment {}",
                            segment_id.0
                        );
                        // Failures are ignored; orphan cleanup can retry later.
                        let _ = crate::segment::delete_segment(dir.as_ref(), segment_id).await;
                    }
                });
            })
        };
        Self {
            state: AsyncMutex::new(ManagerState {
                metadata,
                merge_policy,
            }),
            merge_inventory: Arc::new(MergeInventory::new()),
            merge_handles: AsyncMutex::new(Vec::new()),
            trained: ArcSwapOption::new(None),
            tracker,
            delete_fn,
            directory,
            schema,
            term_cache_blocks,
            // Guard against a configured 0, which would disable merging.
            max_concurrent_merges: max_concurrent_merges.max(1),
        }
    }

    /// IDs of all segments currently recorded in the metadata.
    pub async fn get_segment_ids(&self) -> Vec<String> {
        self.state.lock().await.metadata.segment_ids()
    }

    /// Currently published trained vector structures, if any.
    pub fn trained(&self) -> Option<Arc<TrainedVectorStructures>> {
        self.trained.load_full()
    }

    /// Loads trained vector structures for the metadata's vector fields from
    /// the directory and publishes them atomically. If loading yields
    /// nothing, any previously published value is left untouched.
    pub async fn load_and_publish_trained(&self) {
        let vector_fields = {
            let st = self.state.lock().await;
            st.metadata.vector_fields.clone()
        };
        let trained =
            IndexMetadata::load_trained_from_fields(&vector_fields, self.directory.as_ref()).await;
        if let Some(t) = trained {
            self.trained.store(Some(Arc::new(t)));
        }
    }

    /// Drops any published trained structures.
    pub(crate) fn clear_trained(&self) {
        self.trained.store(None);
    }

    /// Runs `f` with read access to the metadata under the state lock.
    pub(crate) async fn read_metadata<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&IndexMetadata) -> R,
    {
        let st = self.state.lock().await;
        f(&st.metadata)
    }

    /// Mutates the metadata via `f`, then persists it to the directory
    /// while still holding the state lock.
    pub(crate) async fn update_metadata<F>(&self, f: F) -> Result<()>
    where
        F: FnOnce(&mut IndexMetadata),
    {
        let mut st = self.state.lock().await;
        f(&mut st.metadata);
        st.metadata.save(self.directory.as_ref()).await
    }

    /// Acquires a snapshot of the current segment set. Segments held by the
    /// snapshot are protected from physical deletion until it is released;
    /// release routes any now-deletable segments through `delete_fn`.
    pub async fn acquire_snapshot(&self) -> SegmentSnapshot {
        // Acquire under the state lock so the ID list and the tracker
        // acquisition are consistent with each other.
        let acquired = {
            let st = self.state.lock().await;
            let segment_ids = st.metadata.segment_ids();
            self.tracker.acquire(&segment_ids)
        };
        SegmentSnapshot::with_delete_fn(
            Arc::clone(&self.tracker),
            acquired,
            Arc::clone(&self.delete_fn),
        )
    }

    /// Shared handle to the segment tracker.
    pub fn tracker(&self) -> Arc<SegmentTracker> {
        Arc::clone(&self.tracker)
    }

    /// Shared handle to the underlying directory.
    pub fn directory(&self) -> Arc<D> {
        Arc::clone(&self.directory)
    }
}
#[cfg(feature = "native")]
impl<D: DirectoryWriter + 'static> SegmentManager<D> {
    /// Records newly written segments `(id, num_docs)` in the metadata and
    /// tracker (skipping IDs already present), then persists the metadata.
    pub async fn commit(&self, new_segments: Vec<(String, u32)>) -> Result<()> {
        let mut st = self.state.lock().await;
        for (segment_id, num_docs) in new_segments {
            if !st.metadata.has_segment(&segment_id) {
                st.metadata.add_segment(segment_id.clone(), num_docs);
                self.tracker.register(&segment_id);
            }
        }
        st.metadata.save(self.directory.as_ref()).await
    }

    /// Asks the merge policy for merge candidates and spawns background
    /// merge tasks, bounded by `max_concurrent_merges`.
    pub async fn maybe_merge(self: &Arc<Self>) {
        // Prune finished handles first so their slots are reusable.
        let slots_available = {
            let mut handles = self.merge_handles.lock().await;
            handles.retain(|h| !h.is_finished());
            self.max_concurrent_merges.saturating_sub(handles.len())
        };
        if slots_available == 0 {
            log::debug!("[maybe_merge] at max concurrent merges, skipping");
            return;
        }
        let new_handles = {
            // Holding the state lock while spawning is safe: spawn_merge
            // only touches the merge inventory, never `state`.
            let st = self.state.lock().await;
            // Only segments that are neither pending deletion nor already
            // part of an active merge are eligible.
            let segments: Vec<SegmentInfo> = st
                .metadata
                .segment_metas
                .iter()
                .filter(|(id, _)| {
                    !self.tracker.is_pending_deletion(id) && !self.merge_inventory.contains(id)
                })
                .map(|(id, info)| SegmentInfo {
                    id: id.clone(),
                    num_docs: info.num_docs,
                })
                .collect();
            log::debug!("[maybe_merge] {} eligible segments", segments.len());
            let candidates = st.merge_policy.find_merges(&segments);
            if candidates.is_empty() {
                return;
            }
            log::debug!(
                "[maybe_merge] {} merge candidates, {} slots available",
                candidates.len(),
                slots_available
            );
            let mut handles = Vec::new();
            for c in candidates {
                if handles.len() >= slots_available {
                    break;
                }
                if let Some(h) = self.spawn_merge(c.segment_ids) {
                    handles.push(h);
                }
            }
            handles
        };
        if !new_handles.is_empty() {
            self.merge_handles.lock().await.extend(new_handles);
        }
    }

    /// Spawns one background merge of `segment_ids_to_merge` into a fresh
    /// segment. Returns `None` if any involved ID (including the new output
    /// ID) overlaps an active merge.
    fn spawn_merge(self: &Arc<Self>, segment_ids_to_merge: Vec<String>) -> Option<JoinHandle<()>> {
        let output_id = SegmentId::new();
        let output_hex = output_id.to_hex();
        // Register the output ID alongside the inputs so orphan cleanup
        // will not delete the half-written output segment.
        let mut all_ids = segment_ids_to_merge.clone();
        all_ids.push(output_hex);
        let guard = match self.merge_inventory.try_register(all_ids) {
            Some(g) => g,
            None => {
                log::debug!("[spawn_merge] skipped: segments overlap with active merge");
                return None;
            }
        };
        let sm = Arc::clone(self);
        let ids = segment_ids_to_merge;
        Some(tokio::spawn(async move {
            // Keep the inventory registration alive for the whole task;
            // dropped (and unregistered) even on panic/abort.
            let _guard = guard;
            let trained_snap = sm.trained();
            let result = Self::do_merge(
                sm.directory.as_ref(),
                &sm.schema,
                &ids,
                output_id,
                sm.term_cache_blocks,
                trained_snap.as_deref(),
            )
            .await;
            match result {
                Ok((new_id, doc_count)) => {
                    if let Err(e) = sm.replace_segments(&ids, new_id, doc_count, false).await {
                        log::error!("[merge] Failed to replace segments after merge: {:?}", e);
                    }
                }
                Err(e) => {
                    log::error!(
                        "[merge] Background merge failed for segments {:?}: {:?}",
                        ids,
                        e
                    );
                }
            }
            // Chain: freeing this slot may enable further merges.
            sm.maybe_merge().await;
        }))
    }

    /// Atomically swaps `old_ids` for the merged segment `new_id` in the
    /// metadata, persists it, then deletes old segments that are no longer
    /// referenced (others are deferred via the tracker).
    async fn replace_segments(
        &self,
        old_ids: &[String],
        new_id: String,
        doc_count: u32,
        reordered: bool,
    ) -> Result<()> {
        // Register the new segment before it becomes visible in metadata so
        // concurrent snapshots can acquire it.
        self.tracker.register(&new_id);
        {
            let mut st = self.state.lock().await;
            // New segment's generation is one past the max of its parents.
            let parent_gen = old_ids
                .iter()
                .filter_map(|id| st.metadata.segment_metas.get(id))
                .map(|info| info.generation)
                .max()
                .unwrap_or(0);
            let ancestors: Vec<String> = old_ids.to_vec();
            for id in old_ids {
                st.metadata.remove_segment(id);
            }
            st.metadata
                .add_merged_segment(new_id, doc_count, ancestors, parent_gen + 1, reordered);
            st.metadata.save(self.directory.as_ref()).await?;
        }
        // Segments still held by snapshots are only marked; they are deleted
        // later via `delete_fn` when the last snapshot releases them.
        let ready_to_delete = self.tracker.mark_for_deletion(old_ids);
        for segment_id in ready_to_delete {
            let _ = crate::segment::delete_segment(self.directory.as_ref(), segment_id).await;
        }
        Ok(())
    }

    /// Performs the actual merge: opens all input segments, validates their
    /// doc counts, and writes the merged output segment.
    ///
    /// Returns `(output_hex_id, total_docs)`. Associated function (no
    /// `&self`) so spawned tasks can call it with borrowed parts.
    ///
    /// # Errors
    /// - `Corruption` for invalid segment IDs or store/meta doc-count
    ///   mismatches found during pre-merge validation.
    /// - `Internal` if the merged doc count exceeds `u32::MAX` (checked
    ///   after the merge has been written).
    pub(crate) async fn do_merge(
        directory: &D,
        schema: &Arc<crate::dsl::Schema>,
        segment_ids_to_merge: &[String],
        output_segment_id: SegmentId,
        term_cache_blocks: usize,
        trained: Option<&TrainedVectorStructures>,
    ) -> Result<(String, u32)> {
        let output_hex = output_segment_id.to_hex();
        let load_start = std::time::Instant::now();
        let segment_ids: Vec<SegmentId> = segment_ids_to_merge
            .iter()
            .map(|id_str| {
                SegmentId::from_hex(id_str)
                    .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", id_str)))
            })
            .collect::<Result<Vec<_>>>()?;
        let schema_arc = Arc::clone(schema);
        // Open all readers concurrently.
        let futures: Vec<_> = segment_ids
            .iter()
            .map(|&sid| {
                let sch = Arc::clone(&schema_arc);
                async move { SegmentReader::open(directory, sid, sch, term_cache_blocks).await }
            })
            .collect();
        let results = futures::future::join_all(futures).await;
        let mut readers = Vec::with_capacity(results.len());
        let mut total_docs = 0u64;
        for (i, result) in results.into_iter().enumerate() {
            match result {
                Ok(r) => {
                    total_docs += r.meta().num_docs as u64;
                    readers.push(r);
                }
                Err(e) => {
                    log::error!(
                        "[merge] Failed to open segment {}: {:?}",
                        segment_ids_to_merge[i],
                        e
                    );
                    return Err(e);
                }
            }
        }
        // Pre-merge sanity check: store and meta must agree on doc counts.
        for (i, reader) in readers.iter().enumerate() {
            let meta_docs = reader.meta().num_docs;
            let store_docs = reader.store().num_docs();
            if store_docs != meta_docs {
                return Err(Error::Corruption(format!(
                    "pre-merge validation: segment {} store has {} docs but meta says {}",
                    segment_ids_to_merge[i], store_docs, meta_docs
                )));
            }
        }
        log::info!(
            "[merge] loaded {} segment readers in {:.1}s",
            readers.len(),
            load_start.elapsed().as_secs_f64()
        );
        let merger = SegmentMerger::new(Arc::clone(schema));
        log::info!(
            "[merge] {} segments -> {} (trained={})",
            segment_ids_to_merge.len(),
            output_hex,
            trained.map_or(0, |t| t.centroids.len()),
        );
        merger
            .merge(directory, &readers, output_segment_id, trained)
            .await?;
        log::info!(
            "[merge] total wall-clock: {:.1}s ({} segments, {} docs)",
            load_start.elapsed().as_secs_f64(),
            readers.len(),
            total_docs,
        );
        // num_docs is stored as u32; refuse to report an overflowing count.
        if total_docs > u32::MAX as u64 {
            return Err(Error::Internal(format!(
                "Merged segment doc count ({}) exceeds u32::MAX",
                total_docs
            )));
        }
        Ok((output_hex, total_docs as u32))
    }

    /// Aborts all in-flight merge tasks without waiting for them.
    pub async fn abort_merges(&self) {
        let handles: Vec<JoinHandle<()>> =
            { std::mem::take(&mut *self.merge_handles.lock().await) };
        for h in handles {
            h.abort();
        }
    }

    /// Waits for the currently registered merge tasks to finish (one pass;
    /// merges spawned afterwards are not awaited — see
    /// `wait_for_all_merges`).
    pub async fn wait_for_merging_thread(self: &Arc<Self>) {
        let handles: Vec<JoinHandle<()>> =
            { std::mem::take(&mut *self.merge_handles.lock().await) };
        for h in handles {
            let _ = h.await;
        }
    }

    /// Repeatedly drains and awaits merge handles until none remain,
    /// catching follow-up merges spawned by completing tasks.
    pub async fn wait_for_all_merges(self: &Arc<Self>) {
        loop {
            let handles: Vec<JoinHandle<()>> =
                { std::mem::take(&mut *self.merge_handles.lock().await) };
            if handles.is_empty() {
                break;
            }
            for h in handles {
                let _ = h.await;
            }
        }
    }

    /// Synchronously merges segments down to as few as possible, in batches
    /// of at most `FORCE_MERGE_BATCH`, respecting the policy's
    /// `max_segment_docs` cap. Returns once fewer than two mergeable
    /// segments remain.
    pub async fn force_merge(self: &Arc<Self>) -> Result<()> {
        const FORCE_MERGE_BATCH: usize = 64;
        let max_segment_docs = {
            let st = self.state.lock().await;
            st.merge_policy.max_segment_docs()
        };
        // Quiesce background merges so they don't race this loop.
        self.wait_for_all_merges().await;
        loop {
            let mut segments: Vec<(String, u32)> = {
                let st = self.state.lock().await;
                st.metadata
                    .segment_metas
                    .iter()
                    .map(|(id, info)| (id.clone(), info.num_docs))
                    .collect()
            };
            if segments.len() < 2 {
                return Ok(());
            }
            // Merge smallest segments first.
            segments.sort_by_key(|(_, docs)| *docs);
            let max_docs = max_segment_docs.map(|m| m as u64).unwrap_or(u64::MAX);
            let mut batch = Vec::new();
            let mut batch_docs = 0u64;
            for (id, docs) in &segments {
                if batch.len() >= FORCE_MERGE_BATCH {
                    break;
                }
                let next_total = batch_docs + *docs as u64;
                // Allow the first segment in even if it alone exceeds the
                // cap; otherwise stop before overflowing the cap.
                if next_total > max_docs && !batch.is_empty() {
                    break;
                }
                batch.push(id.clone());
                batch_docs += *docs as u64;
            }
            if batch.len() < 2 {
                return Ok(());
            }
            log::info!(
                "[force_merge] merging batch of {} segments ({} docs)",
                batch.len(),
                batch_docs
            );
            let output_id = SegmentId::new();
            let output_hex = output_id.to_hex();
            let mut all_ids = batch.clone();
            all_ids.push(output_hex);
            let _guard = match self.merge_inventory.try_register(all_ids) {
                Some(g) => g,
                None => {
                    // Another merge holds one of these IDs; wait and retry.
                    self.wait_for_merging_thread().await;
                    continue;
                }
            };
            let trained_snap = self.trained();
            let (new_segment_id, total_docs) = Self::do_merge(
                self.directory.as_ref(),
                &self.schema,
                &batch,
                output_id,
                self.term_cache_blocks,
                trained_snap.as_deref(),
            )
            .await?;
            self.replace_segments(&batch, new_segment_id, total_docs, false)
                .await?;
        }
    }

    /// Reorders every segment in the index, one at a time. Fails fast on
    /// the first error; segments caught in a concurrent merge are skipped
    /// with a warning.
    pub async fn reorder_segments(self: &Arc<Self>) -> Result<()> {
        self.wait_for_all_merges().await;
        let segment_ids = self.get_segment_ids().await;
        if segment_ids.is_empty() {
            log::info!("[reorder] no segments to reorder");
            return Ok(());
        }
        log::info!("[reorder] reordering {} segments", segment_ids.len());
        for seg_id in segment_ids {
            match self.reorder_single_segment(&seg_id, None).await {
                Ok(true) => {}
                Ok(false) => log::warn!("[reorder] segment {} skipped (in merge)", seg_id),
                Err(e) => return Err(e),
            }
        }
        log::info!("[reorder] all segments reordered");
        Ok(())
    }

    /// IDs of segments not yet marked `reordered` and not currently part of
    /// an active merge.
    pub async fn unreordered_segment_ids(&self) -> Vec<String> {
        let st = self.state.lock().await;
        let in_merge = self.merge_inventory.snapshot();
        st.metadata
            .segment_metas
            .iter()
            .filter(|(id, info)| !info.reordered && !in_merge.contains(*id))
            .map(|(id, _)| id.clone())
            .collect()
    }

    /// Rewrites one segment in reordered form, replacing the original.
    /// Returns `Ok(false)` (without error) when the segment is busy in an
    /// active merge.
    pub async fn reorder_single_segment(
        self: &Arc<Self>,
        seg_id: &str,
        rayon_pool: Option<Arc<rayon::ThreadPool>>,
    ) -> Result<bool> {
        let source_id = SegmentId::from_hex(seg_id)
            .ok_or_else(|| Error::Corruption(format!("Invalid segment ID: {}", seg_id)))?;
        let output_id = SegmentId::new();
        let output_hex = output_id.to_hex();
        // Reserve source + output so merges and orphan cleanup stay away.
        let all_ids = vec![seg_id.to_string(), output_hex];
        let _guard = match self.merge_inventory.try_register(all_ids) {
            Some(g) => g,
            None => {
                log::debug!("[optimizer] segment {} in active merge, skipping", seg_id);
                return Ok(false);
            }
        };
        let (new_id, total_docs) = crate::segment::reorder::reorder_segment(
            self.directory.as_ref(),
            &self.schema,
            source_id,
            output_id,
            self.term_cache_blocks,
            crate::segment::reorder::DEFAULT_MEMORY_BUDGET,
            rayon_pool,
        )
        .await?;
        self.replace_segments(&[seg_id.to_string()], new_id, total_docs, true)
            .await?;
        Ok(true)
    }

    /// Deletes segment files on disk that are neither registered in the
    /// metadata nor involved in an active merge. Returns the number of
    /// orphan segments deleted.
    pub async fn cleanup_orphan_segments(&self) -> Result<usize> {
        let (registered_set, in_merge_set) = {
            let st = self.state.lock().await;
            let registered = st
                .metadata
                .segment_metas
                .keys()
                .cloned()
                .collect::<HashSet<String>>();
            let in_merge = self.merge_inventory.snapshot();
            (registered, in_merge)
        };
        let mut orphan_ids: HashSet<String> = HashSet::new();
        if let Ok(entries) = self.directory.list_files(std::path::Path::new("")).await {
            for entry in entries {
                let filename = entry.to_string_lossy();
                // Assumes files are named "seg_" + 32 hex chars + suffix.
                // NOTE(review): `len() > 37` requires at least 2 chars after
                // the hex ID, so a bare "seg_<hex>" (len 36) or one with a
                // single-char suffix is never considered — confirm this
                // matches the actual segment file naming scheme.
                if filename.starts_with("seg_") && filename.len() > 37 {
                    let hex_part = &filename[4..36];
                    if !registered_set.contains(hex_part) && !in_merge_set.contains(hex_part) {
                        orphan_ids.insert(hex_part.to_string());
                    }
                }
            }
        }
        let mut deleted = 0;
        for hex_id in &orphan_ids {
            if let Some(segment_id) = SegmentId::from_hex(hex_id)
                && crate::segment::delete_segment(self.directory.as_ref(), segment_id)
                    .await
                    .is_ok()
            {
                deleted += 1;
            }
        }
        Ok(deleted)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dropping the guard must unregister every ID it holds.
    #[test]
    fn test_inventory_guard_drop_unregisters() {
        let inventory = Arc::new(MergeInventory::new());
        {
            let guard = inventory.try_register(vec!["a".into(), "b".into()]).unwrap();
            let active = inventory.snapshot();
            assert!(active.contains("a"));
            assert!(active.contains("b"));
            drop(guard);
        }
        assert!(inventory.snapshot().is_empty());
    }

    /// Disjoint merges may run concurrently; releasing one leaves the other.
    #[test]
    fn test_inventory_concurrent_non_overlapping_merges() {
        let inventory = Arc::new(MergeInventory::new());
        let first = inventory.try_register(vec!["a".into(), "b".into()]).unwrap();
        let _second = inventory.try_register(vec!["c".into(), "d".into()]).unwrap();
        assert_eq!(inventory.snapshot().len(), 4);
        drop(first);
        let remaining = inventory.snapshot();
        assert_eq!(remaining.len(), 2);
        assert!(remaining.contains("c"));
        assert!(remaining.contains("d"));
    }

    /// A batch sharing any ID with an active merge is rejected wholesale,
    /// and becomes registrable again once the conflict is released.
    #[test]
    fn test_inventory_overlapping_merge_rejected() {
        let inventory = Arc::new(MergeInventory::new());
        let first = inventory.try_register(vec!["a".into(), "b".into()]).unwrap();
        assert!(inventory.try_register(vec!["b".into(), "c".into()]).is_none());
        drop(first);
        assert!(inventory.try_register(vec!["b".into(), "c".into()]).is_some());
    }

    /// Snapshot reflects exactly the registered IDs.
    #[test]
    fn test_inventory_snapshot() {
        let inventory = Arc::new(MergeInventory::new());
        let _guard = inventory.try_register(vec!["x".into(), "y".into()]).unwrap();
        let active = inventory.snapshot();
        assert!(active.contains("x"));
        assert!(active.contains("y"));
        assert!(!active.contains("z"));
    }
}