use crate::error::{Error, Result};
use std::collections::HashMap;
use std::path::PathBuf;
/// Path and on-disk size of a single SSTable data file; the unit that gets bucketed.
#[cfg(feature = "write-support")]
#[derive(Clone, Debug)]
struct SSTableMetadata {
    /// Path of the SSTable data file.
    data_path: PathBuf,
    /// Size of the data file in bytes.
    data_size: u64,
}
#[cfg(feature = "write-support")]
impl SSTableMetadata {
    /// Pairs a data-file path with its size in bytes.
    fn new(data_path: PathBuf, data_size: u64) -> Self {
        SSTableMetadata {
            data_size,
            data_path,
        }
    }
}
/// Size-tiered compaction (STCS) merge policy.
///
/// SSTables of similar size are grouped into buckets; once a bucket holds at least
/// `min_threshold` SSTables, up to `max_threshold` of them are selected for merging.
#[cfg(feature = "write-support")]
#[derive(Clone, Debug)]
pub struct STCSPolicy {
    /// Minimum number of SSTables a bucket must hold before it is selected for merging.
    pub min_threshold: usize,
    /// Maximum number of SSTables selected from a bucket in one merge.
    pub max_threshold: usize,
    /// Lower size ratio: an SSTable fits a bucket if its size is >= `bucket_low` x the average.
    pub bucket_low: f64,
    /// Upper size ratio: an SSTable fits a bucket if its size is <= `bucket_high` x the average.
    pub bucket_high: f64,
    /// Size in bytes below which an SSTable may join any bucket whose average is also below
    /// this threshold, regardless of the ratio bounds.
    pub min_sstable_size: u64,
}
#[cfg(feature = "write-support")]
impl STCSPolicy {
    /// Default `min_sstable_size`: 50 MiB.
    pub const DEFAULT_MIN_SSTABLE_SIZE: u64 = 50 * 1024 * 1024;

    /// Creates a validated size-tiered compaction policy.
    ///
    /// * `min_threshold` - minimum bucket size that triggers a merge (must be > 0).
    /// * `max_threshold` - maximum SSTables merged at once (must be >= `min_threshold`).
    /// * `bucket_low` / `bucket_high` - an SSTable fits a bucket when its size lies within
    ///   `[bucket_low, bucket_high]` x the bucket's average size (inclusive).
    /// * `min_sstable_size` - SSTables smaller than this many bytes fit any bucket whose
    ///   average is also below it, regardless of the ratio bounds.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidInput`] if `min_threshold` is 0, `max_threshold < min_threshold`,
    /// either bucket ratio is NaN, `bucket_high <= bucket_low`, or `bucket_low <= 0.0`.
    pub fn new(
        min_threshold: usize,
        max_threshold: usize,
        bucket_low: f64,
        bucket_high: f64,
        min_sstable_size: u64,
    ) -> Result<Self> {
        if min_threshold == 0 {
            return Err(Error::InvalidInput(
                "min_threshold must be greater than 0".to_string(),
            ));
        }
        if max_threshold < min_threshold {
            return Err(Error::InvalidInput(format!(
                "max_threshold ({}) must be >= min_threshold ({})",
                max_threshold, min_threshold
            )));
        }
        // NaN compares false against everything, so it would slip past the ordering checks
        // below and then make every ratio test in `fits_bucket` fail.
        if bucket_low.is_nan() || bucket_high.is_nan() {
            return Err(Error::InvalidInput(format!(
                "bucket_low ({}) and bucket_high ({}) must not be NaN",
                bucket_low, bucket_high
            )));
        }
        if bucket_high <= bucket_low {
            return Err(Error::InvalidInput(format!(
                "bucket_high ({}) must be > bucket_low ({})",
                bucket_high, bucket_low
            )));
        }
        if bucket_low <= 0.0 {
            return Err(Error::InvalidInput(format!(
                "bucket_low ({}) must be > 0.0",
                bucket_low
            )));
        }
        Ok(Self {
            min_threshold,
            max_threshold,
            bucket_low,
            bucket_high,
            min_sstable_size,
        })
    }

    /// Groups SSTables of similar size into buckets.
    ///
    /// SSTables are visited in ascending size order. Each one joins the first existing bucket
    /// it fits (see `fits_bucket`); otherwise it starts a new bucket. Every input SSTable ends
    /// up in exactly one bucket. Buckets are returned in creation order, and each bucket's
    /// members are in ascending size order.
    fn group_into_buckets(&self, sstables: &[SSTableMetadata]) -> Vec<Vec<SSTableMetadata>> {
        let mut sorted = sstables.to_vec();
        // Stable sort: equal-sized SSTables keep their input order.
        sorted.sort_by_key(|s| s.data_size);

        // (total bytes, members) per bucket, in creation order. A Vec instead of a map keyed by
        // average size means two buckets whose averages coincide can never overwrite each
        // other, and bucket assignment does not depend on hash iteration order. Tracking the
        // exact total avoids the rounding drift of re-deriving it from a truncated average.
        let mut buckets: Vec<(u64, Vec<SSTableMetadata>)> = Vec::new();
        for sstable in sorted {
            let size = sstable.data_size;
            // Buckets are never empty (each is created with one member), so the division is safe.
            let home = buckets
                .iter_mut()
                .find(|(total, members)| self.fits_bucket(size, *total / members.len() as u64));
            match home {
                Some((total, members)) => {
                    *total = total.saturating_add(size);
                    members.push(sstable);
                }
                None => buckets.push((size, vec![sstable])),
            }
        }
        buckets.into_iter().map(|(_, members)| members).collect()
    }

    /// Whether an SSTable of `size` bytes belongs in a bucket whose average size is `avg_size`.
    ///
    /// True when `size` lies within `[bucket_low, bucket_high]` x `avg_size` (inclusive), or
    /// when both `size` and `avg_size` are below `min_sstable_size`.
    fn fits_bucket(&self, size: u64, avg_size: u64) -> bool {
        let (size_f, avg_f) = (size as f64, avg_size as f64);
        let within_ratio =
            size_f >= avg_f * self.bucket_low && size_f <= avg_f * self.bucket_high;
        let both_small = size < self.min_sstable_size && avg_size < self.min_sstable_size;
        within_ratio || both_small
    }
}
#[cfg(feature = "write-support")]
impl Default for STCSPolicy {
    /// Default tuning.
    ///
    /// Merge between 4 and 32 SSTables per bucket. An SSTable fits a bucket when it is within
    /// 0.5x-1.5x of the bucket average. SSTables under
    /// [`STCSPolicy::DEFAULT_MIN_SSTABLE_SIZE`] bytes are grouped together as "small".
    fn default() -> Self {
        let min_sstable_size = Self::DEFAULT_MIN_SSTABLE_SIZE;
        let (bucket_low, bucket_high) = (0.5, 1.5);
        let (min_threshold, max_threshold) = (4, 32);
        Self {
            min_threshold,
            max_threshold,
            bucket_low,
            bucket_high,
            min_sstable_size,
        }
    }
}
#[cfg(feature = "write-support")]
impl STCSPolicy {
    /// Chooses the SSTables to merge next from `candidates`.
    ///
    /// Reads each candidate's on-disk size and groups the candidates with
    /// `group_into_buckets`. Among buckets holding at least `min_threshold` SSTables, it picks
    /// the one with the smallest mean size, breaking ties by the bucket's first path. The
    /// choice therefore does not depend on the order in which buckets are returned. At most
    /// `max_threshold` SSTables from that bucket are returned, in bucket order.
    ///
    /// Returns an empty list when there are fewer than `min_threshold` candidates or no bucket
    /// qualifies.
    ///
    /// # Errors
    ///
    /// Returns `Error::Storage` if a candidate's filesystem metadata cannot be read.
    fn select_merge_internal(&self, candidates: &[PathBuf]) -> Result<Vec<PathBuf>> {
        // Mean data size of a bucket (0 if empty); summed as u128 so it cannot overflow.
        fn mean_size(bucket: &[SSTableMetadata]) -> u128 {
            let total: u128 = bucket.iter().map(|s| u128::from(s.data_size)).sum();
            total.checked_div(bucket.len() as u128).unwrap_or(0)
        }
        // Path of a bucket's first member, used only as a deterministic tie-breaker.
        fn first_path(bucket: &[SSTableMetadata]) -> Option<&PathBuf> {
            bucket.first().map(|s| &s.data_path)
        }

        if candidates.len() < self.min_threshold {
            return Ok(Vec::new());
        }

        let sstables = candidates
            .iter()
            .map(|path| {
                std::fs::metadata(path)
                    .map(|metadata| SSTableMetadata::new(path.clone(), metadata.len()))
                    .map_err(|e| {
                        Error::Storage(format!(
                            "Failed to read SSTable metadata for {:?}: {}",
                            path, e
                        ))
                    })
            })
            .collect::<Result<Vec<_>>>()?;

        // Select deterministically rather than taking whichever qualifying bucket happens to
        // come first; the smaller-average bucket is the cheaper merge.
        let chosen = self
            .group_into_buckets(&sstables)
            .into_iter()
            .filter(|bucket| bucket.len() >= self.min_threshold)
            .min_by(|a, b| {
                mean_size(a)
                    .cmp(&mean_size(b))
                    .then_with(|| first_path(a).cmp(&first_path(b)))
            });

        Ok(chosen
            .map(|bucket| {
                bucket
                    .into_iter()
                    .take(self.max_threshold)
                    .map(|s| s.data_path)
                    .collect::<Vec<PathBuf>>()
            })
            .unwrap_or_default())
    }
}
#[cfg(feature = "write-support")]
impl super::MergePolicy for STCSPolicy {
    /// Size-tiered selection; all of the work happens in `STCSPolicy::select_merge_internal`.
    fn select_merge(&self, candidates: &[PathBuf]) -> Result<Vec<PathBuf>> {
        STCSPolicy::select_merge_internal(self, candidates)
    }
}
#[cfg(all(test, feature = "write-support"))]
mod tests {
    use super::*;
    use crate::storage::write_engine::MergePolicy;
    use std::path::PathBuf;
    use tempfile::TempDir;

    const MB: u64 = 1024 * 1024;

    /// In-memory SSTable descriptor named after `generation`, `size_mb` MiB in size.
    fn create_sstable(generation: u64, size_mb: u64) -> SSTableMetadata {
        SSTableMetadata::new(
            PathBuf::from(format!("nb-{}-big-Data.db", generation)),
            size_mb * MB,
        )
    }

    /// Creates one file per entry of `sizes_mb` (sized with `set_len`) in a fresh temp dir.
    /// The returned `TempDir` must be kept alive for the files to exist.
    fn create_temp_sstables(sizes_mb: &[u64]) -> (TempDir, Vec<PathBuf>) {
        let temp_dir = TempDir::new().unwrap();
        let mut paths = Vec::new();
        for (i, &size_mb) in sizes_mb.iter().enumerate() {
            let path = temp_dir.path().join(format!("nb-{}-big-Data.db", i + 1));
            let file = std::fs::File::create(&path).unwrap();
            file.set_len(size_mb * MB).unwrap();
            paths.push(path);
        }
        (temp_dir, paths)
    }

    /// Member counts of each bucket, sorted, so assertions don't depend on bucket order.
    fn sorted_bucket_lens(buckets: &[Vec<SSTableMetadata>]) -> Vec<usize> {
        let mut lens: Vec<usize> = buckets.iter().map(Vec::len).collect();
        lens.sort_unstable();
        lens
    }

    #[test]
    fn test_stcs_policy_default() {
        let policy = STCSPolicy::default();
        assert_eq!(policy.min_threshold, 4);
        assert_eq!(policy.max_threshold, 32);
        assert_eq!(policy.bucket_low, 0.5);
        assert_eq!(policy.bucket_high, 1.5);
        assert_eq!(policy.min_sstable_size, 50 * MB);
    }

    #[test]
    fn test_stcs_policy_new_validates_min_threshold() {
        let result = STCSPolicy::new(0, 32, 0.5, 1.5, 50 * MB);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("min_threshold"));
    }

    #[test]
    fn test_stcs_policy_new_validates_max_threshold() {
        let result = STCSPolicy::new(10, 5, 0.5, 1.5, 50 * MB);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("max_threshold"));
    }

    #[test]
    fn test_stcs_policy_new_validates_bucket_ratio() {
        let result = STCSPolicy::new(4, 32, 1.5, 0.5, 50 * MB);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("bucket_high"));
    }

    #[test]
    fn test_stcs_policy_new_validates_bucket_low_positive() {
        let result = STCSPolicy::new(4, 32, 0.0, 1.5, 50 * MB);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("bucket_low"));
    }

    #[test]
    fn test_stcs_no_compaction_below_threshold() {
        let policy = STCSPolicy::default();
        let (_temp, paths) = create_temp_sstables(&[100, 100, 100]);
        let result = policy.select_merge(&paths).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_stcs_compaction_at_threshold() {
        let policy = STCSPolicy::default();
        let (_temp, paths) = create_temp_sstables(&[100, 100, 100, 100]);
        let result = policy.select_merge(&paths).unwrap();
        assert_eq!(result.len(), 4);
    }

    #[test]
    fn test_stcs_bucket_grouping_same_size() {
        let policy = STCSPolicy::default();
        let sstables: Vec<SSTableMetadata> = (1..=5).map(|g| create_sstable(g, 100)).collect();
        let buckets = policy.group_into_buckets(&sstables);
        assert_eq!(buckets.len(), 1);
        assert_eq!(buckets[0].len(), 5);
    }

    #[test]
    fn test_stcs_bucket_grouping_within_ratio() {
        let policy = STCSPolicy::default();
        let sstables = vec![
            create_sstable(1, 100),
            create_sstable(2, 120),
            create_sstable(3, 80),
            create_sstable(4, 110),
        ];
        let buckets = policy.group_into_buckets(&sstables);
        assert_eq!(buckets.len(), 1);
        assert_eq!(buckets[0].len(), 4);
    }

    #[test]
    fn test_stcs_bucket_grouping_outside_ratio() {
        let policy = STCSPolicy::default();
        let sstables = vec![
            create_sstable(1, 100),
            create_sstable(2, 100),
            create_sstable(3, 100),
            create_sstable(4, 100),
            create_sstable(5, 200),
            create_sstable(6, 200),
            create_sstable(7, 200),
            create_sstable(8, 200),
        ];
        let buckets = policy.group_into_buckets(&sstables);
        assert_eq!(sorted_bucket_lens(&buckets), vec![4, 4]);
    }

    #[test]
    fn test_stcs_small_sstables_grouped_together() {
        let policy = STCSPolicy::default();
        let sstables = vec![
            create_sstable(1, 10),
            create_sstable(2, 20),
            create_sstable(3, 30),
            create_sstable(4, 40),
            create_sstable(5, 100),
        ];
        let buckets = policy.group_into_buckets(&sstables);
        assert_eq!(sorted_bucket_lens(&buckets), vec![1, 4]);
        let small_bucket = buckets.iter().find(|b| b.len() == 4).unwrap();
        assert!(small_bucket
            .iter()
            .all(|s| s.data_size < policy.min_sstable_size));
    }

    #[test]
    fn test_stcs_respects_max_threshold() {
        let policy = STCSPolicy::default();
        let sizes: Vec<u64> = vec![100; 50];
        let (_temp, paths) = create_temp_sstables(&sizes);
        let result = policy.select_merge(&paths).unwrap();
        assert_eq!(result.len(), 32);
    }

    #[test]
    fn test_stcs_empty_input() {
        let policy = STCSPolicy::default();
        let paths: Vec<PathBuf> = Vec::new();
        let result = policy.select_merge(&paths).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_stcs_multiple_buckets_selects_first_eligible() {
        let policy = STCSPolicy::default();
        let (_temp, paths) = create_temp_sstables(&[100, 100, 100, 100, 500, 500, 500, 500, 500]);
        let result = policy.select_merge(&paths).unwrap();
        assert!(result.len() >= 4);
        // The selection comes from a single size tier: never a mix of 100 MiB and 500 MiB files.
        let selected_sizes: Vec<u64> = result
            .iter()
            .map(|p| std::fs::metadata(p).unwrap().len())
            .collect();
        assert!(selected_sizes.iter().all(|&s| s == selected_sizes[0]));
    }

    #[test]
    fn test_stcs_varied_sizes() {
        let policy = STCSPolicy::default();
        let sstables = vec![
            create_sstable(1, 1),
            create_sstable(2, 2),
            create_sstable(3, 3),
            create_sstable(4, 5),
            create_sstable(5, 100),
            create_sstable(6, 110),
            create_sstable(7, 120),
            create_sstable(8, 130),
            create_sstable(9, 1000),
        ];
        let buckets = policy.group_into_buckets(&sstables);
        // {1, 2, 3, 5} (all below min_sstable_size), {100, 110, 120, 130}, {1000}.
        assert_eq!(sorted_bucket_lens(&buckets), vec![1, 4, 4]);
    }

    #[test]
    fn test_stcs_grouping_keeps_every_sstable_once() {
        let policy = STCSPolicy::default();
        let sstables: Vec<SSTableMetadata> = [1, 2, 3, 5, 100, 110, 120, 130, 1000]
            .iter()
            .enumerate()
            .map(|(i, &size_mb)| create_sstable(i as u64 + 1, size_mb))
            .collect();
        let buckets = policy.group_into_buckets(&sstables);
        let mut grouped: Vec<PathBuf> = buckets
            .into_iter()
            .flatten()
            .map(|s| s.data_path)
            .collect();
        grouped.sort();
        let mut expected: Vec<PathBuf> = sstables.iter().map(|s| s.data_path.clone()).collect();
        expected.sort();
        assert_eq!(grouped, expected);
    }

    #[test]
    fn test_sstable_metadata_new() {
        let metadata = SSTableMetadata::new(PathBuf::from("test.db"), 12345);
        assert_eq!(metadata.data_path, PathBuf::from("test.db"));
        assert_eq!(metadata.data_size, 12345);
    }

    #[test]
    fn test_stcs_edge_case_exact_boundary() {
        // Ratio bounds are inclusive: 150 MiB is exactly 1.5x a 100 MiB bucket, so it joins it,
        // while 100 MiB exceeds 1.5x the 50 MiB bucket and starts its own.
        let policy = STCSPolicy::default();
        let sstables = vec![
            create_sstable(1, 100),
            create_sstable(2, 50),
            create_sstable(3, 150),
        ];
        let buckets = policy.group_into_buckets(&sstables);
        assert_eq!(sorted_bucket_lens(&buckets), vec![1, 2]);
        let pair = buckets.iter().find(|b| b.len() == 2).unwrap();
        let mut sizes: Vec<u64> = pair.iter().map(|s| s.data_size).collect();
        sizes.sort_unstable();
        assert_eq!(sizes, vec![100 * MB, 150 * MB]);
    }

    #[test]
    fn test_stcs_policy_clone() {
        let policy = STCSPolicy::default();
        let cloned = policy.clone();
        assert_eq!(policy.min_threshold, cloned.min_threshold);
        assert_eq!(policy.max_threshold, cloned.max_threshold);
    }

    #[test]
    fn test_sstable_metadata_clone() {
        let metadata = SSTableMetadata::new(PathBuf::from("test.db"), 12345);
        let cloned = metadata.clone();
        assert_eq!(metadata.data_path, cloned.data_path);
        assert_eq!(metadata.data_size, cloned.data_size);
    }
}