use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::Arc;
use parking_lot::RwLock;
use crate::error::Result;
/// Tracks which vector ids have been marked as deleted.
///
/// Both fields are wrapped in `Arc`, so the derived `Clone` produces a handle
/// that shares the same id set and counter as the value it was cloned from.
#[derive(Debug, Clone)]
pub struct DeletionManager {
/// Ids currently marked as deleted.
deleted_ids: Arc<RwLock<HashSet<u64>>>,
/// Running counter: incremented when an id is newly marked, decremented
/// (saturating at zero) when a mark is removed, and reset to zero on clear.
deletion_count: Arc<RwLock<u64>>,
}
impl DeletionManager {
    /// Creates a manager with an empty id set and a counter of zero.
    pub fn new() -> Self {
        let deleted_ids = Arc::new(RwLock::new(HashSet::new()));
        let deletion_count = Arc::new(RwLock::new(0));
        Self {
            deleted_ids,
            deletion_count,
        }
    }

    /// Marks `vector_id` as deleted.
    ///
    /// Returns `Ok(true)` when the id was not marked before this call and
    /// `Ok(false)` when it already was; the counter only moves in the first case.
    pub fn delete(&self, vector_id: u64) -> Result<bool> {
        // The set guard stays alive for the whole call; the counter lock is
        // taken second, matching the order used by the other methods.
        let mut ids = self.deleted_ids.write();
        let newly_marked = ids.insert(vector_id);
        if newly_marked {
            *self.deletion_count.write() += 1;
        }
        Ok(newly_marked)
    }

    /// Marks every id in `vector_ids` as deleted.
    ///
    /// Returns the number of ids that were not already marked; ids that were
    /// marked earlier (or repeated within the slice) are not counted again.
    pub fn delete_batch(&self, vector_ids: &[u64]) -> Result<u64> {
        // Both locks are held across the whole batch: set first, counter second.
        let mut ids = self.deleted_ids.write();
        let mut total = self.deletion_count.write();
        // `insert` yields `true` only for ids that were absent, so summing the
        // converted booleans gives the number of fresh marks.
        let newly_marked = vector_ids
            .iter()
            .fold(0u64, |acc, &id| acc + u64::from(ids.insert(id)));
        *total += newly_marked;
        Ok(newly_marked)
    }

    /// Reports whether `vector_id` is currently marked as deleted.
    pub fn is_deleted(&self, vector_id: u64) -> bool {
        self.deleted_ids.read().contains(&vector_id)
    }

    /// Returns the current value of the deletion counter.
    pub fn deletion_count(&self) -> u64 {
        let total = self.deletion_count.read();
        *total
    }

    /// Returns a copy of all ids currently marked as deleted.
    ///
    /// The ids come from iterating a `HashSet`, so their order is unspecified.
    pub fn deleted_ids(&self) -> Vec<u64> {
        self.deleted_ids.read().iter().copied().collect()
    }

    /// Removes every mark and resets the counter to zero.
    pub fn clear(&self) -> Result<()> {
        // Take both locks (set first, counter second) before mutating either.
        let mut ids = self.deleted_ids.write();
        let mut total = self.deletion_count.write();
        *total = 0;
        ids.clear();
        Ok(())
    }

    /// Removes the deleted mark from `vector_id`.
    ///
    /// Returns `Ok(true)` when the id was marked and has now been unmarked,
    /// `Ok(false)` when it was not marked. The counter is decremented with
    /// saturation, so it never drops below zero.
    pub fn undelete(&self, vector_id: u64) -> Result<bool> {
        let mut ids = self.deleted_ids.write();
        if !ids.remove(&vector_id) {
            return Ok(false);
        }
        let mut total = self.deletion_count.write();
        *total = total.saturating_sub(1);
        Ok(true)
    }
}
impl Default for DeletionManager {
fn default() -> Self {
Self::new()
}
}
/// Deletion statistics for a single segment, serializable via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SegmentDeletionInfo {
/// Identifier of the segment these statistics describe.
pub segment_id: String,
/// Number of deleted entries recorded for the segment.
pub deleted_count: u64,
/// Total number of entries recorded for the segment.
pub total_count: u64,
/// `deleted_count / total_count` as computed by `new`, or `0.0` when
/// `total_count` is zero. The field is public, so values built or
/// deserialized by other means are not guaranteed to match the counts.
pub deletion_ratio: f64,
}
impl SegmentDeletionInfo {
    /// Builds the statistics for a segment and derives its deletion ratio.
    ///
    /// The ratio is `deleted_count / total_count` in floating point; a segment
    /// with `total_count == 0` gets a ratio of `0.0` instead of dividing by zero.
    pub fn new(segment_id: String, deleted_count: u64, total_count: u64) -> Self {
        let deletion_ratio = match total_count {
            0 => 0.0,
            total => deleted_count as f64 / total as f64,
        };
        Self {
            segment_id,
            deleted_count,
            total_count,
            deletion_ratio,
        }
    }

    /// Returns `true` when the stored deletion ratio is at least `threshold`.
    pub fn needs_compaction(&self, threshold: f64) -> bool {
        threshold <= self.deletion_ratio
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deleting an id marks it and bumps the counter once; deleting it again
    /// leaves the counter untouched.
    #[test]
    fn test_deletion_manager_basic() {
        let mgr = DeletionManager::new();
        assert!(!mgr.is_deleted(1));
        assert_eq!(mgr.deletion_count(), 0);

        mgr.delete(1).unwrap();
        assert!(mgr.is_deleted(1));
        assert_eq!(mgr.deletion_count(), 1);

        // Second delete of the same id must not change the counter.
        mgr.delete(1).unwrap();
        assert_eq!(mgr.deletion_count(), 1);
    }

    /// A batch reports only newly marked ids, so replaying it reports zero.
    #[test]
    fn test_deletion_batch() {
        let mgr = DeletionManager::new();
        let batch: Vec<u64> = (1..=5).collect();

        let first_pass = mgr.delete_batch(&batch).unwrap();
        assert_eq!(first_pass, 5);
        assert_eq!(mgr.deletion_count(), 5);

        let second_pass = mgr.delete_batch(&batch).unwrap();
        assert_eq!(second_pass, 0);
        assert_eq!(mgr.deletion_count(), 5);
    }

    /// Half-deleted segment: ratio is 0.5 and the threshold check is inclusive
    /// of lower thresholds only.
    #[test]
    fn test_segment_deletion_info() {
        let stats = SegmentDeletionInfo::new(String::from("seg1"), 50, 100);
        assert_eq!(stats.deletion_ratio, 0.5);
        assert!(stats.needs_compaction(0.3));
        assert!(!stats.needs_compaction(0.6));
    }
}