use serde::{Deserialize, Serialize};
use std::io::Write;
use std::sync::Arc;
use parking_lot::RwLock;
use crate::error::Result;
use crate::storage::Storage;
use super::merge_policy::MergePolicy;
/// Tunable limits that drive segment bookkeeping and merge decisions.
///
/// The struct is serializable via serde; see the `Default` impl for the
/// out-of-the-box values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SegmentManagerConfig {
/// Upper bound on the number of vectors in one segment.
/// NOTE(review): not read anywhere in this file; presumably enforced by the
/// code that builds segments — confirm.
pub max_vectors_per_segment: u64,
/// Segments holding fewer vectors than this are reported as merge-worthy by
/// `ManagedSegmentInfo::should_merge`.
pub min_vectors_per_segment: u64,
/// Segment-count threshold: `SegmentManager::needs_merge` returns `true` once
/// the number of segments exceeds it, and `create_merge_plan` uses it to
/// grade urgency.
pub max_segments: u32,
/// Maximum number of segments `SegmentManager::create_merge_plan` places in a
/// single merge candidate.
pub merge_factor: u32,
}
impl Default for SegmentManagerConfig {
    /// Returns the stock configuration: segments of 10 000 to 1 000 000
    /// vectors, at most 100 segments, and up to 10 segments per merge.
    fn default() -> Self {
        Self {
            max_vectors_per_segment: 1_000_000,
            min_vectors_per_segment: 10_000,
            max_segments: 100,
            merge_factor: 10,
        }
    }
}
/// Metadata record describing one segment tracked by `SegmentManager`.
///
/// A list of these records is what gets serialized to `segments.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ManagedSegmentInfo {
/// Unique identifier of the segment. IDs produced by
/// `SegmentManager::generate_segment_id` have the form `segment_NNNNNN`;
/// the segment's index file is named `<segment_id>.hnsw`.
pub segment_id: String,
/// Number of vectors stored in the segment.
pub vector_count: u64,
/// Sort key used by `MergeStrategy::Adjacent`.
/// NOTE(review): presumably the global position of the segment's first
/// vector — confirm against the writer.
pub vector_offset: u64,
/// Generation counter supplied by the creator; not interpreted in this file.
pub generation: u64,
/// Whether the segment contains deleted vectors (`new` initializes it to `false`).
pub has_deletions: bool,
/// Size of the segment in bytes (`new` initializes it to `0`).
pub size_bytes: u64,
}
impl ManagedSegmentInfo {
    /// Builds the record for a segment from its identity and vector layout.
    ///
    /// The new record reports no deletions and a size of zero bytes; callers
    /// are expected to fill those fields in afterwards if they matter.
    pub fn new(segment_id: String, vector_count: u64, vector_offset: u64, generation: u64) -> Self {
        Self {
            size_bytes: 0,
            has_deletions: false,
            generation,
            vector_offset,
            vector_count,
            segment_id,
        }
    }

    /// Returns `true` when the segment holds fewer vectors than
    /// `config.min_vectors_per_segment`.
    pub fn should_merge(&self, config: &SegmentManagerConfig) -> bool {
        config.min_vectors_per_segment > self.vector_count
    }
}
/// A group of segments selected to be merged into one.
#[derive(Debug, Clone)]
pub struct MergeCandidate {
/// Snapshots of the segments that take part in the merge.
pub segments: Vec<ManagedSegmentInfo>,
/// Sum of `vector_count` over `segments`.
pub total_vectors: u64,
/// Sum of `size_bytes` over `segments`.
pub total_size: u64,
}
/// Ordering rule `SegmentManager::create_merge_plan` applies before picking
/// the first `merge_factor` segments.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MergeStrategy {
/// Prefer the segments with the fewest vectors.
Smallest,
/// Prefer segments whose `has_deletions` flag is set.
MostDeletions,
/// Prefer segments with the lowest `vector_offset`.
Adjacent,
}
/// How pressing a merge is, as graded by `SegmentManager::create_merge_plan`.
///
/// The derived ordering follows declaration order: `Low < Medium < High`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum MergeUrgency {
/// The segment count does not exceed `max_segments`.
Low,
/// The segment count exceeds `max_segments`.
Medium,
/// The segment count exceeds twice `max_segments`.
High,
}
/// Result of merge planning: what to merge, why, and how urgently.
#[derive(Debug, Clone)]
pub struct MergePlan {
/// Merge groups to execute; `SegmentManager::create_merge_plan` currently
/// emits exactly one.
pub candidates: Vec<MergeCandidate>,
/// Strategy that was used to select the candidates.
pub strategy: MergeStrategy,
/// Urgency derived from the segment count relative to `max_segments`.
pub urgency: MergeUrgency,
}
/// Aggregate figures over all tracked segments, produced by `SegmentManager::stats`.
#[derive(Debug, Clone)]
pub struct SegmentManagerStats {
/// Number of tracked segments.
pub segment_count: u32,
/// Sum of `vector_count` over all segments.
pub total_vectors: u64,
/// Sum of `size_bytes` over all segments.
pub total_size: u64,
/// Number of segments whose `has_deletions` flag is set.
pub segments_with_deletions: u32,
/// `total_vectors / segment_count`, or `0.0` when there are no segments.
pub avg_vectors_per_segment: f64,
}
/// Tracks the set of segments and persists it to `segments.json` in the
/// backing storage.
#[derive(Debug)]
pub struct SegmentManager {
/// Limits used for merge decisions.
config: SegmentManagerConfig,
/// Backend used to read/write `segments.json` and delete `.hnsw` files.
storage: Arc<dyn Storage>,
/// In-memory list of tracked segments, guarded by a read/write lock.
segments: Arc<RwLock<Vec<ManagedSegmentInfo>>>,
/// Next numeric suffix handed out by `generate_segment_id`.
next_segment_id: Arc<RwLock<u64>>,
}
impl SegmentManager {
/// Creates a segment manager on top of `storage` and restores any segment
/// list previously persisted in `segments.json`.
///
/// A missing (unopenable) or empty state file is treated as "no segments yet"
/// by `load_state` and is not an error.
///
/// # Errors
///
/// Returns an error when the state file can be opened but reading it or
/// parsing its JSON fails. Previously such failures were silently discarded,
/// which left the manager empty and let the next `save_state` overwrite the
/// damaged file, losing every persisted segment record.
pub fn new(config: SegmentManagerConfig, storage: Arc<dyn Storage>) -> Result<Self> {
    let manager = Self {
        config,
        storage,
        segments: Arc::new(RwLock::new(Vec::new())),
        next_segment_id: Arc::new(RwLock::new(0)),
    };
    // Surface corrupt or unreadable state instead of starting empty on top of it.
    manager.load_state()?;
    Ok(manager)
}
/// Restores the segment list and the ID counter from `segments.json`.
///
/// Returns `Ok(())` without touching any state when the file cannot be opened
/// or is empty. Otherwise the in-memory list is replaced by the decoded one
/// and the ID counter is set to one past the largest numeric suffix found in
/// IDs of the form `segment_<number>` (or to 1 when no ID matches).
///
/// # Errors
///
/// Fails when reading the file or decoding its JSON fails.
fn load_state(&self) -> Result<()> {
    let mut input = match self.storage.open_input("segments.json") {
        Err(_) => return Ok(()),
        Ok(handle) => handle,
    };

    let mut raw = Vec::new();
    input.read_to_end(&mut raw)?;
    if raw.is_empty() {
        return Ok(());
    }

    let restored: Vec<ManagedSegmentInfo> = serde_json::from_slice(&raw)?;

    // Highest numeric suffix among IDs that follow the generated naming scheme.
    let highest = restored
        .iter()
        .filter_map(|info| {
            info.segment_id
                .strip_prefix("segment_")?
                .parse::<u64>()
                .ok()
        })
        .max()
        .unwrap_or(0);

    let mut tracked = self.segments.write();
    *tracked = restored;
    *self.next_segment_id.write() = highest + 1;
    Ok(())
}
/// Serializes the current segment list as pretty-printed JSON and writes it
/// to `segments.json` in the backing storage.
///
/// The read lock on the segment list is held for the duration of the write.
///
/// # Errors
///
/// Fails when serialization, creating the output, writing, or flushing fails.
pub fn save_state(&self) -> Result<()> {
    let tracked = self.segments.read();
    let json = serde_json::to_vec_pretty(&*tracked)?;

    let mut out = self.storage.create_output("segments.json")?;
    out.write_all(&json)?;
    out.flush()?;

    Ok(())
}
/// Appends `info` to the tracked segments and persists the updated list.
///
/// # Errors
///
/// Propagates any failure from `save_state`; the segment stays in the
/// in-memory list even then.
pub fn add_segment(&self, info: ManagedSegmentInfo) -> Result<()> {
    // The temporary write guard is released at the end of this statement,
    // before `save_state` acquires the read lock.
    self.segments.write().push(info);
    self.save_state()
}
/// Stops tracking the segment with the given ID and persists the change.
///
/// Unknown IDs are ignored: nothing is removed and nothing is written.
///
/// # Errors
///
/// Propagates any failure from `save_state`.
pub fn remove_segment(&self, segment_id: &str) -> Result<()> {
    let removed = {
        let mut tracked = self.segments.write();
        match tracked.iter().position(|s| s.segment_id == segment_id) {
            Some(index) => {
                tracked.remove(index);
                true
            }
            None => false,
        }
    };

    if !removed {
        return Ok(());
    }
    self.save_state()
}
/// Asks the storage backend to delete `<segment_id>.hnsw`.
///
/// The outcome of the deletion is deliberately ignored, so this method always
/// returns `Ok(())`.
pub fn delete_segment_files(&self, segment_id: &str) -> Result<()> {
    let _ = self.storage.delete_file(&format!("{}.hnsw", segment_id));
    Ok(())
}
/// Replaces the tracked record whose ID equals `info.segment_id` with `info`
/// and persists the list.
///
/// The list is written out even when no record matched.
///
/// # Errors
///
/// Propagates any failure from `save_state`.
pub fn update_segment(&self, info: ManagedSegmentInfo) -> Result<()> {
    {
        let mut tracked = self.segments.write();
        if let Some(slot) = tracked
            .iter_mut()
            .find(|s| s.segment_id == info.segment_id)
        {
            *slot = info;
        }
        // Write guard released here, before `save_state` takes the read lock.
    }
    self.save_state()
}
/// Returns a copy of the record for `segment_id`, or `None` when no tracked
/// segment has that ID.
pub fn get_segment(&self, segment_id: &str) -> Option<ManagedSegmentInfo> {
    for info in self.segments.read().iter() {
        if info.segment_id == segment_id {
            return Some(info.clone());
        }
    }
    None
}
/// Returns a snapshot (deep copy) of all tracked segment records, in list order.
pub fn list_segments(&self) -> Vec<ManagedSegmentInfo> {
    self.segments.read().to_vec()
}
pub fn check_merge(&self, policy: &dyn MergePolicy) -> Option<MergeCandidate> {
let segments_lock = self.segments.read();
if let Some(candidate_ids) = policy.candidates(&segments_lock, &self.config) {
let mut total_vectors = 0;
let mut total_size = 0;
let mut candidates = Vec::new();
for id in &candidate_ids {
if let Some(segment) = segments_lock.iter().find(|s| s.segment_id == *id) {
total_vectors += segment.vector_count;
total_size += segment.size_bytes;
candidates.push(segment.clone());
}
}
return Some(MergeCandidate {
segments: candidates,
total_vectors,
total_size,
});
}
None
}
/// Records the outcome of a finished merge.
///
/// Every segment listed in `candidate` is dropped from the tracked list,
/// `merged_segment` is appended, the list is persisted, and finally the
/// `.hnsw` files of the replaced segments are deleted.
///
/// # Errors
///
/// Propagates any failure from `save_state`; in that case no files are deleted.
pub fn apply_merge(
    &self,
    candidate: MergeCandidate,
    merged_segment: ManagedSegmentInfo,
) -> Result<()> {
    {
        let mut tracked = self.segments.write();
        let retired: std::collections::HashSet<&str> = candidate
            .segments
            .iter()
            .map(|s| s.segment_id.as_str())
            .collect();
        tracked.retain(|s| !retired.contains(s.segment_id.as_str()));
        tracked.push(merged_segment);
        // Write guard released here, before `save_state` takes the read lock.
    }

    self.save_state()?;

    // Only remove the old files once the new state is safely persisted.
    candidate
        .segments
        .iter()
        .try_for_each(|old| self.delete_segment_files(&old.segment_id))
}
/// Returns the sum of `vector_count` over all tracked segments.
pub fn total_vectors(&self) -> u64 {
    let mut total: u64 = 0;
    for info in self.segments.read().iter() {
        total += info.vector_count;
    }
    total
}
/// Hands out the next unused segment ID and advances the internal counter.
///
/// IDs have the form `segment_` followed by the counter value, zero-padded to
/// at least six digits (e.g. `segment_000000`).
pub fn generate_segment_id(&self) -> String {
    let assigned = {
        let mut counter = self.next_segment_id.write();
        let current = *counter;
        *counter += 1;
        current
    };
    format!("segment_{:06}", assigned)
}
/// Returns `true` when more segments are tracked than `config.max_segments` allows.
pub fn needs_merge(&self) -> bool {
    let tracked_count = self.segments.read().len() as u32;
    self.config.max_segments < tracked_count
}
pub fn create_merge_plan(&self, strategy: MergeStrategy) -> Option<MergePlan> {
let segments = self.segments.read();
if segments.len() <= 1 {
return None;
}
let mut segment_list: Vec<_> = segments.iter().cloned().collect();
match strategy {
MergeStrategy::Smallest => {
segment_list.sort_by_key(|s| s.vector_count);
}
MergeStrategy::MostDeletions => {
segment_list.sort_by(|a, b| b.has_deletions.cmp(&a.has_deletions));
}
MergeStrategy::Adjacent => {
segment_list.sort_by_key(|s| s.vector_offset);
}
}
let merge_count = self.config.merge_factor.min(segment_list.len() as u32) as usize;
let to_merge = &segment_list[..merge_count];
let candidate = MergeCandidate {
segments: to_merge.to_vec(),
total_vectors: to_merge.iter().map(|s| s.vector_count).sum(),
total_size: to_merge.iter().map(|s| s.size_bytes).sum(),
};
let urgency = if segments.len() as u32 > self.config.max_segments * 2 {
MergeUrgency::High
} else if segments.len() as u32 > self.config.max_segments {
MergeUrgency::Medium
} else {
MergeUrgency::Low
};
Some(MergePlan {
candidates: vec![candidate],
strategy,
urgency,
})
}
/// Computes aggregate statistics over all tracked segments.
///
/// The average number of vectors per segment is reported as `0.0` when no
/// segments are tracked.
pub fn stats(&self) -> SegmentManagerStats {
    let tracked = self.segments.read();

    let mut total_vectors: u64 = 0;
    let mut total_size: u64 = 0;
    let mut with_deletions: usize = 0;
    for info in tracked.iter() {
        total_vectors += info.vector_count;
        total_size += info.size_bytes;
        if info.has_deletions {
            with_deletions += 1;
        }
    }

    let segment_count = tracked.len() as u32;
    let avg_vectors_per_segment = match segment_count {
        0 => 0.0,
        n => total_vectors as f64 / n as f64,
    };

    SegmentManagerStats {
        segment_count,
        total_vectors,
        total_size,
        segments_with_deletions: with_deletions as u32,
        avg_vectors_per_segment,
    }
}
}
// Unit tests for `SegmentManager`, run against the in-memory storage backend.
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::memory::{MemoryStorage, MemoryStorageConfig};
use crate::vector::index::hnsw::segment::merge_policy::SimpleMergePolicy;
// Builds a segment record with the given id and vector count; the size is
// set to 100 bytes per vector so totals are easy to predict.
fn create_info(id: &str, count: u64) -> ManagedSegmentInfo {
ManagedSegmentInfo {
segment_id: id.to_string(),
vector_count: count,
vector_offset: 0,
generation: 1,
has_deletions: false,
size_bytes: count * 100,
}
}
// A fresh manager hands out `segment_000000` first, and a segment that was
// added can be read back by id.
#[test]
fn test_segment_manager_basic() {
let config = SegmentManagerConfig::default();
let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
let manager = SegmentManager::new(config, storage).unwrap();
let segment_id = manager.generate_segment_id();
assert_eq!(segment_id, "segment_000000");
let info = ManagedSegmentInfo::new(segment_id.clone(), 1000, 0, 0);
manager.add_segment(info.clone()).unwrap();
let retrieved = manager.get_segment(&segment_id).unwrap();
assert_eq!(retrieved.vector_count, 1000);
}
// A second manager created on the same storage must see the segment that
// the first manager persisted.
#[test]
fn test_persistence() {
let config = SegmentManagerConfig::default();
let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
{
let manager = SegmentManager::new(config.clone(), storage.clone()).unwrap();
let info = ManagedSegmentInfo::new("segment_000000".to_string(), 1000, 0, 0);
manager.add_segment(info).unwrap();
}
{
let manager = SegmentManager::new(config, storage.clone()).unwrap();
let segments = manager.list_segments();
assert_eq!(segments.len(), 1);
assert_eq!(segments[0].segment_id, "segment_000000");
}
}
// With `max_segments: 5` and `merge_factor: 3`, the test expects
// `SimpleMergePolicy` to propose nothing for two segments and to propose
// segments "1", "2" and "3" once six segments are tracked.
#[test]
fn test_check_merge() {
let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
let config = SegmentManagerConfig {
max_segments: 5,
merge_factor: 3,
..Default::default()
};
let manager = SegmentManager::new(config, storage).unwrap();
manager.add_segment(create_info("1", 100)).unwrap();
manager.add_segment(create_info("2", 100)).unwrap();
// Two segments is below the threshold: no merge expected yet.
assert!(manager.check_merge(&SimpleMergePolicy::new()).is_none());
manager.add_segment(create_info("3", 100)).unwrap();
manager.add_segment(create_info("4", 100)).unwrap();
manager.add_segment(create_info("5", 100)).unwrap();
manager.add_segment(create_info("6", 100)).unwrap();
// Six segments exceeds `max_segments`: a candidate is expected.
let candidate = manager.check_merge(&SimpleMergePolicy::new());
assert!(candidate.is_some());
let candidate = candidate.unwrap();
assert_eq!(candidate.segments.len(), 3);
let ids: Vec<String> = candidate
.segments
.iter()
.map(|s| s.segment_id.clone())
.collect();
assert!(ids.contains(&"1".to_string()));
assert!(ids.contains(&"2".to_string()));
assert!(ids.contains(&"3".to_string()));
}
}