use crate::raft::OxirsNodeId;
use crate::shard::ShardId;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::Display;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{info, warn};
/// Key type used for range boundaries (lexicographically ordered strings).
pub type RangeKey = String;
/// A half-open key range `[start, end)` over string keys.
///
/// `None` on either side means the range is unbounded in that direction.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Range {
    /// Inclusive lower bound; `None` = -∞.
    pub start: Option<RangeKey>,
    /// Exclusive upper bound; `None` = +∞.
    pub end: Option<RangeKey>,
}
impl Range {
pub fn new(start: Option<RangeKey>, end: Option<RangeKey>) -> Self {
Self { start, end }
}
pub fn unbounded() -> Self {
Self {
start: None,
end: None,
}
}
pub fn from(start: RangeKey) -> Self {
Self {
start: Some(start),
end: None,
}
}
pub fn to(end: RangeKey) -> Self {
Self {
start: None,
end: Some(end),
}
}
pub fn between(start: RangeKey, end: RangeKey) -> Self {
Self {
start: Some(start),
end: Some(end),
}
}
pub fn contains(&self, key: &str) -> bool {
if let Some(ref start) = self.start {
if key < start.as_str() {
return false;
}
}
if let Some(ref end) = self.end {
if key >= end.as_str() {
return false;
}
}
true
}
pub fn overlaps(&self, other: &Range) -> bool {
if let (Some(s1), Some(e2)) = (&self.start, &other.end) {
if s1 >= e2 {
return false;
}
}
if let (Some(s2), Some(e1)) = (&other.start, &self.end) {
if s2 >= e1 {
return false;
}
}
true
}
pub fn split_at(&self, split_key: &str) -> (Range, Range) {
let left = Range {
start: self.start.clone(),
end: Some(split_key.to_string()),
};
let right = Range {
start: Some(split_key.to_string()),
end: self.end.clone(),
};
(left, right)
}
pub fn can_merge_with(&self, other: &Range) -> bool {
matches!((&self.end, &other.start), (Some(e1), Some(s2)) if e1 == s2)
|| matches!((&other.end, &self.start), (Some(e2), Some(s1)) if e2 == s1)
}
pub fn merge_with(&self, other: &Range) -> Option<Range> {
if !self.can_merge_with(other) {
return None;
}
let start = match (&self.start, &other.start) {
(None, _) | (_, None) => None,
(Some(s1), Some(s2)) => Some(s1.min(s2).clone()),
};
let end = match (&self.end, &other.end) {
(None, _) | (_, None) => None,
(Some(e1), Some(e2)) => Some(e1.max(e2).clone()),
};
Some(Range { start, end })
}
}
impl Display for Range {
    /// Renders the range in interval notation, e.g. `[a, b)`, `(-∞, +∞)`.
    /// A present start bound uses `[` (inclusive); the end is always `)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match &self.start {
            Some(start) => write!(f, "[{start}, ")?,
            None => write!(f, "(-∞, ")?,
        }
        match &self.end {
            Some(end) => write!(f, "{end})"),
            None => write!(f, "+∞)"),
        }
    }
}
/// A contiguous key range assigned to a shard, together with the nodes that
/// host it and its most recently reported load statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RangePartition {
    /// Unique id: a UUID for initial partitions, or `"<op-id>-left"`,
    /// `"<op-id>-right"`, `"<op-id>-merged"` for split/merge products.
    pub partition_id: String,
    /// Shard this partition belongs to.
    pub shard_id: ShardId,
    /// Half-open key range `[start, end)` covered by this partition.
    pub range: Range,
    /// Nodes assigned to this partition (copied to both halves on split).
    pub nodes: BTreeSet<OxirsNodeId>,
    /// Latest load statistics, replaced wholesale by `update_load_stats`.
    pub load_stats: LoadStats,
    /// Creation time, seconds since the UNIX epoch.
    pub created_at: u64,
    /// Last mutation time, seconds since the UNIX epoch (see `touch`).
    pub modified_at: u64,
}
impl RangePartition {
    /// Current wall-clock time as seconds since the UNIX epoch.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("SystemTime should be after UNIX_EPOCH")
            .as_secs()
    }
    /// Creates a partition over `range` with no assigned nodes and
    /// zeroed load statistics; both timestamps are set to now.
    pub fn new(partition_id: String, shard_id: ShardId, range: Range) -> Self {
        let now = Self::now_secs();
        Self {
            partition_id,
            shard_id,
            range,
            nodes: BTreeSet::new(),
            load_stats: LoadStats::default(),
            created_at: now,
            modified_at: now,
        }
    }
    /// Registers `node_id` as a host of this partition.
    pub fn add_node(&mut self, node_id: OxirsNodeId) {
        self.nodes.insert(node_id);
        self.touch();
    }
    /// Drops `node_id` from the set of hosts (no-op if absent).
    pub fn remove_node(&mut self, node_id: OxirsNodeId) {
        self.nodes.remove(&node_id);
        self.touch();
    }
    /// Replaces the partition's load statistics wholesale.
    pub fn update_load_stats(&mut self, stats: LoadStats) {
        self.load_stats = stats;
        self.touch();
    }
    /// True when the partition holds more than `max_load` keys, or more
    /// than `max_load * 1024` bytes of data (the byte threshold is derived
    /// from the key threshold — presumably "1 KiB per key"; confirm).
    pub fn needs_split(&self, max_load: u64) -> bool {
        self.load_stats.key_count > max_load || self.load_stats.data_size > max_load * 1024
    }
    /// True when the partition is below `min_load` keys AND below
    /// `min_load * 1024` bytes, making it a merge candidate.
    pub fn can_merge(&self, min_load: u64) -> bool {
        self.load_stats.key_count < min_load && self.load_stats.data_size < min_load * 1024
    }
    /// Bumps `modified_at` to the current time.
    fn touch(&mut self) {
        self.modified_at = Self::now_secs();
    }
}
/// Load metrics reported for a single partition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadStats {
    /// Number of keys stored in the partition.
    pub key_count: u64,
    // Total stored data; presumably bytes (compared against key limits
    // × 1024 in `needs_split`/`can_merge`) — TODO confirm units.
    pub data_size: u64,
    /// Read throughput, operations per second.
    pub read_ops_per_sec: f64,
    /// Write throughput, operations per second.
    pub write_ops_per_sec: f64,
    /// Average key size; presumably bytes — confirm against the reporter.
    pub avg_key_size: u64,
    /// When these stats were produced, seconds since the UNIX epoch.
    pub last_updated: u64,
}
impl Default for LoadStats {
    /// All-zero statistics stamped with the current wall-clock time.
    fn default() -> Self {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("SystemTime should be after UNIX_EPOCH")
            .as_secs();
        Self {
            key_count: 0,
            data_size: 0,
            read_ops_per_sec: 0.0,
            write_ops_per_sec: 0.0,
            avg_key_size: 0,
            last_updated: now,
        }
    }
}
/// Tracks an asynchronous partition split (see
/// `RangePartitionManager::split_partition`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitOperation {
    /// Unique id of this operation (UUID string).
    pub operation_id: String,
    /// Id of the partition being split.
    pub source_partition: String,
    /// Pivot key; it becomes the inclusive start of the right half.
    pub split_key: RangeKey,
    /// Id of the resulting partition covering `[start, split_key)`.
    pub left_partition: String,
    /// Id of the resulting partition covering `[split_key, end)`.
    pub right_partition: String,
    /// Current lifecycle state.
    pub status: OperationStatus,
    /// Completion percentage, 0–100.
    pub progress: u8,
    /// Creation time, seconds since the UNIX epoch.
    pub created_at: u64,
    /// Completion time; `None` until the operation finishes.
    pub completed_at: Option<u64>,
}
/// Tracks an asynchronous merge of multiple partitions into one (see
/// `RangePartitionManager::merge_partitions`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MergeOperation {
    /// Unique id of this operation (UUID string).
    pub operation_id: String,
    /// Ids of the partitions being merged, in the order the caller gave them.
    pub source_partitions: Vec<String>,
    /// Id of the partition that will replace the sources.
    pub target_partition: String,
    /// Current lifecycle state.
    pub status: OperationStatus,
    /// Completion percentage, 0–100.
    pub progress: u8,
    /// Creation time, seconds since the UNIX epoch.
    pub created_at: u64,
    /// Completion time; `None` until the operation finishes.
    pub completed_at: Option<u64>,
}
/// Lifecycle state of a split or merge operation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum OperationStatus {
    /// Created but not yet started.
    Planned,
    /// Currently executing on a background task.
    InProgress,
    /// Finished successfully.
    Completed,
    /// Aborted due to an error.
    Failed,
    /// Cancelled before completion.
    Cancelled,
}
/// Coordinates range partitions for shards: key lookup, load tracking, and
/// asynchronous split/merge operations.
///
/// All shared state lives behind `Arc<RwLock<…>>`, so `Clone` yields a
/// handle onto the same underlying maps.
pub struct RangePartitionManager {
    /// partition_id -> partition, for every live partition.
    partitions: Arc<RwLock<BTreeMap<String, RangePartition>>>,
    /// operation_id -> split operation; entries are retained after
    /// completion so callers can poll final status.
    split_operations: Arc<RwLock<BTreeMap<String, SplitOperation>>>,
    /// operation_id -> merge operation; retained like split operations.
    merge_operations: Arc<RwLock<BTreeMap<String, MergeOperation>>>,
    /// Sizing thresholds and rebalancing settings.
    config: RangePartitionConfig,
}
/// Tuning knobs for partition sizing and rebalancing.
#[derive(Debug, Clone)]
pub struct RangePartitionConfig {
    /// Split a partition once it holds more keys than this.
    pub max_keys_per_partition: u64,
    /// Consider merging a partition once it holds fewer keys than this.
    pub min_keys_per_partition: u64,
    // Maximum partition size in bytes.
    // NOTE(review): not read anywhere in this file — `needs_split` derives
    // its byte threshold from the key limit instead. Confirm intended.
    pub max_partition_size: u64,
    // Minimum partition size in bytes (same caveat as `max_partition_size`).
    pub min_partition_size: u64,
    // NOTE(review): not consulted in this file; presumably read by the
    // caller that drives `check_rebalancing_needed`.
    pub auto_rebalance: bool,
    /// Rebalance interval in seconds; likewise consumed by callers.
    pub rebalance_interval: u64,
}
impl Default for RangePartitionConfig {
    /// Sensible defaults: 100K–1M keys, 10 MiB–1 GiB per partition,
    /// automatic rebalancing every 5 minutes.
    fn default() -> Self {
        Self {
            max_keys_per_partition: 1_000_000,
            min_keys_per_partition: 100_000,
            // 1 GiB upper bound.
            max_partition_size: 1024 * 1024 * 1024,
            // 10 MiB lower bound.
            min_partition_size: 10 * 1024 * 1024,
            auto_rebalance: true,
            // 5 minutes, in seconds.
            rebalance_interval: 300,
        }
    }
}
impl RangePartitionManager {
    /// Creates a manager with no partitions and no in-flight operations.
    pub fn new(config: RangePartitionConfig) -> Self {
        Self {
            partitions: Arc::new(RwLock::new(BTreeMap::new())),
            split_operations: Arc::new(RwLock::new(BTreeMap::new())),
            merge_operations: Arc::new(RwLock::new(BTreeMap::new())),
            config,
        }
    }
    /// Creates one partition spanning the entire key space for `shard_id`
    /// and returns its id.
    pub async fn create_initial_partition(&self, shard_id: ShardId) -> Result<String> {
        let partition_id = uuid::Uuid::new_v4().to_string();
        let partition = RangePartition::new(partition_id.clone(), shard_id, Range::unbounded());
        let mut partitions = self.partitions.write().await;
        partitions.insert(partition_id.clone(), partition);
        info!(
            "Created initial partition {} for shard {}",
            partition_id, shard_id
        );
        Ok(partition_id)
    }
    /// Returns the id of the partition whose range contains `key`, if any.
    /// Linear scan over all partitions.
    pub async fn find_partition_for_key(&self, key: &str) -> Option<String> {
        let partitions = self.partitions.read().await;
        partitions
            .iter()
            .find(|(_, partition)| partition.range.contains(key))
            .map(|(partition_id, _)| partition_id.clone())
    }
    /// Snapshot of every live partition.
    pub async fn get_all_partitions(&self) -> Vec<RangePartition> {
        let partitions = self.partitions.read().await;
        partitions.values().cloned().collect()
    }
    /// Snapshot of a single partition, or `None` if it does not exist.
    pub async fn get_partition(&self, partition_id: &str) -> Option<RangePartition> {
        let partitions = self.partitions.read().await;
        partitions.get(partition_id).cloned()
    }
    /// Replaces the load statistics for `partition_id`.
    ///
    /// # Errors
    /// Fails when the partition does not exist.
    pub async fn update_partition_load(
        &self,
        partition_id: &str,
        load_stats: LoadStats,
    ) -> Result<()> {
        let mut partitions = self.partitions.write().await;
        if let Some(partition) = partitions.get_mut(partition_id) {
            partition.update_load_stats(load_stats);
            info!("Updated load stats for partition {}", partition_id);
            Ok(())
        } else {
            Err(anyhow::anyhow!("Partition {} not found", partition_id))
        }
    }
    /// Starts an asynchronous split of `partition_id` at `split_key`,
    /// returning the operation id for polling via
    /// `get_split_operation_status`.
    ///
    /// # Errors
    /// Fails when the partition does not exist or `split_key` lies outside
    /// its range.
    pub async fn split_partition(&self, partition_id: &str, split_key: &str) -> Result<String> {
        let operation_id = uuid::Uuid::new_v4().to_string();
        let partition = {
            let partitions = self.partitions.read().await;
            partitions
                .get(partition_id)
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("Partition {} not found", partition_id))?
        };
        if !partition.range.contains(split_key) {
            return Err(anyhow::anyhow!(
                "Split key '{}' is not within partition range {}",
                split_key,
                partition.range
            ));
        }
        let left_partition_id = format!("{operation_id}-left");
        let right_partition_id = format!("{operation_id}-right");
        let split_op = SplitOperation {
            operation_id: operation_id.clone(),
            source_partition: partition_id.to_string(),
            split_key: split_key.to_string(),
            left_partition: left_partition_id.clone(),
            right_partition: right_partition_id.clone(),
            status: OperationStatus::Planned,
            progress: 0,
            created_at: Self::now_secs(),
            completed_at: None,
        };
        {
            let mut operations = self.split_operations.write().await;
            operations.insert(operation_id.clone(), split_op);
        }
        let manager = self.clone();
        let operation_id_clone = operation_id.clone();
        tokio::spawn(async move {
            if let Err(e) = manager.execute_split_operation(&operation_id_clone).await {
                warn!("Split operation {} failed: {}", operation_id_clone, e);
                // BUG FIX: record the failure; previously the status stayed
                // Planned/InProgress forever and pollers could never tell.
                manager.mark_split_failed(&operation_id_clone).await;
            }
        });
        info!(
            "Started split operation {} for partition {}",
            operation_id, partition_id
        );
        Ok(operation_id)
    }
    /// Starts an asynchronous merge of the given partitions into a single
    /// partition, returning the operation id for polling via
    /// `get_merge_operation_status`.
    ///
    /// # Errors
    /// Fails when fewer than two partitions are given, when any partition is
    /// missing, or when the partitions do not form a contiguous key range.
    pub async fn merge_partitions(&self, partition_ids: Vec<String>) -> Result<String> {
        let operation_id = uuid::Uuid::new_v4().to_string();
        if partition_ids.len() < 2 {
            return Err(anyhow::anyhow!("Need at least 2 partitions to merge"));
        }
        {
            let partitions = self.partitions.read().await;
            let mut partition_ranges = Vec::new();
            for partition_id in &partition_ids {
                let partition = partitions
                    .get(partition_id)
                    .ok_or_else(|| anyhow::anyhow!("Partition {} not found", partition_id))?;
                partition_ranges.push((partition_id.clone(), partition.range.clone()));
            }
            partition_ranges.sort_by(|a, b| match (&a.1.start, &b.1.start) {
                (None, Some(_)) => std::cmp::Ordering::Less,
                (Some(_), None) => std::cmp::Ordering::Greater,
                (None, None) => std::cmp::Ordering::Equal,
                (Some(a), Some(b)) => a.cmp(b),
            });
            // BUG FIX: the sorted ranges used to be computed and then
            // discarded, so non-adjacent partitions were accepted and the
            // merge later dropped part of the key space. Reject the request
            // up front unless every consecutive pair is mergeable.
            for pair in partition_ranges.windows(2) {
                if !pair[0].1.can_merge_with(&pair[1].1) {
                    return Err(anyhow::anyhow!(
                        "Partitions {} and {} are not adjacent; cannot merge",
                        pair[0].0,
                        pair[1].0
                    ));
                }
            }
        }
        let target_partition_id = format!("{operation_id}-merged");
        let merge_op = MergeOperation {
            operation_id: operation_id.clone(),
            source_partitions: partition_ids,
            target_partition: target_partition_id,
            status: OperationStatus::Planned,
            progress: 0,
            created_at: Self::now_secs(),
            completed_at: None,
        };
        {
            let mut operations = self.merge_operations.write().await;
            operations.insert(operation_id.clone(), merge_op);
        }
        let manager = self.clone();
        let operation_id_clone = operation_id.clone();
        tokio::spawn(async move {
            if let Err(e) = manager.execute_merge_operation(&operation_id_clone).await {
                warn!("Merge operation {} failed: {}", operation_id_clone, e);
                // BUG FIX: surface the failure to status pollers.
                manager.mark_merge_failed(&operation_id_clone).await;
            }
        });
        info!("Started merge operation {}", operation_id);
        Ok(operation_id)
    }
    /// Executes a planned split: builds the left/right partitions, removes
    /// the source, and installs the halves. Runs on a background task.
    async fn execute_split_operation(&self, operation_id: &str) -> Result<()> {
        {
            let mut operations = self.split_operations.write().await;
            if let Some(op) = operations.get_mut(operation_id) {
                op.status = OperationStatus::InProgress;
                op.progress = 10;
            }
        }
        let operation = {
            let operations = self.split_operations.read().await;
            operations
                .get(operation_id)
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("Split operation {} not found", operation_id))?
        };
        let source_partition = {
            let partitions = self.partitions.read().await;
            partitions
                .get(&operation.source_partition)
                .cloned()
                .ok_or_else(|| {
                    anyhow::anyhow!("Source partition {} not found", operation.source_partition)
                })?
        };
        let (left_range, right_range) = source_partition.range.split_at(&operation.split_key);
        // Both halves inherit the source's node placement.
        // NOTE(review): load stats restart from zero on the halves until the
        // next report — confirm this is acceptable for rebalancing decisions.
        let mut left_partition = RangePartition::new(
            operation.left_partition.clone(),
            source_partition.shard_id,
            left_range,
        );
        left_partition.nodes = source_partition.nodes.clone();
        let mut right_partition = RangePartition::new(
            operation.right_partition.clone(),
            source_partition.shard_id,
            right_range,
        );
        right_partition.nodes = source_partition.nodes.clone();
        {
            let mut operations = self.split_operations.write().await;
            if let Some(op) = operations.get_mut(operation_id) {
                op.progress = 50;
            }
        }
        {
            // Swap the source for its halves under one write lock so no
            // reader observes the source gone with the halves missing.
            let mut partitions = self.partitions.write().await;
            partitions.remove(&operation.source_partition);
            partitions.insert(operation.left_partition.clone(), left_partition);
            partitions.insert(operation.right_partition.clone(), right_partition);
        }
        {
            let mut operations = self.split_operations.write().await;
            if let Some(op) = operations.get_mut(operation_id) {
                op.status = OperationStatus::Completed;
                op.progress = 100;
                op.completed_at = Some(Self::now_secs());
            }
        }
        info!("Completed split operation {}", operation_id);
        Ok(())
    }
    /// Executes a planned merge: sorts the sources by range, folds them into
    /// a single range, removes the sources, and installs the merged
    /// partition. Runs on a background task.
    async fn execute_merge_operation(&self, operation_id: &str) -> Result<()> {
        let operation = {
            let operations = self.merge_operations.read().await;
            operations
                .get(operation_id)
                .cloned()
                .ok_or_else(|| anyhow::anyhow!("Merge operation {} not found", operation_id))?
        };
        {
            let mut operations = self.merge_operations.write().await;
            if let Some(op) = operations.get_mut(operation_id) {
                op.status = OperationStatus::InProgress;
                op.progress = 10;
            }
        }
        let mut source_partitions: Vec<RangePartition> = {
            let partitions = self.partitions.read().await;
            operation
                .source_partitions
                .iter()
                .map(|id| partitions.get(id).cloned())
                .collect::<Option<Vec<_>>>()
                .ok_or_else(|| anyhow::anyhow!("Some source partitions not found"))?
        };
        // BUG FIX: the sources arrive in caller order, but `merge_with` only
        // joins adjacent ranges. Fold them in key order, and fail loudly
        // instead of silently skipping (and thereby dropping) any range that
        // did not happen to be adjacent to the running union.
        source_partitions.sort_by(|a, b| match (&a.range.start, &b.range.start) {
            (None, Some(_)) => std::cmp::Ordering::Less,
            (Some(_), None) => std::cmp::Ordering::Greater,
            (None, None) => std::cmp::Ordering::Equal,
            (Some(x), Some(y)) => x.cmp(y),
        });
        let first = source_partitions.first().ok_or_else(|| {
            anyhow::anyhow!("Merge operation {} has no source partitions", operation_id)
        })?;
        let shard_id = first.shard_id;
        let mut merged_range = first.range.clone();
        for partition in &source_partitions[1..] {
            merged_range = merged_range.merge_with(&partition.range).ok_or_else(|| {
                anyhow::anyhow!(
                    "Partition {} is not adjacent to the merged range",
                    partition.partition_id
                )
            })?;
        }
        let mut merged_partition =
            RangePartition::new(operation.target_partition.clone(), shard_id, merged_range);
        // BUG FIX: carry the union of node placements over to the merged
        // partition; previously it started with an empty node set.
        for partition in &source_partitions {
            merged_partition.nodes.extend(partition.nodes.iter().cloned());
        }
        {
            // Remove sources and install the merged partition under one
            // write lock so readers never see a gap in the key space.
            let mut partitions = self.partitions.write().await;
            for partition_id in &operation.source_partitions {
                partitions.remove(partition_id);
            }
            partitions.insert(operation.target_partition.clone(), merged_partition);
        }
        {
            let mut operations = self.merge_operations.write().await;
            if let Some(op) = operations.get_mut(operation_id) {
                op.status = OperationStatus::Completed;
                op.progress = 100;
                op.completed_at = Some(Self::now_secs());
            }
        }
        info!("Completed merge operation {}", operation_id);
        Ok(())
    }
    /// Returns ids of partitions that should be rebalanced: all split
    /// candidates first, then all merge candidates.
    ///
    /// NOTE(review): both checks are driven by the key-count limits; the byte
    /// thresholds inside `needs_split`/`can_merge` are derived as keys × 1024,
    /// so `max_partition_size`/`min_partition_size` from the config are
    /// currently unused here — confirm intended.
    pub async fn check_rebalancing_needed(&self) -> Vec<String> {
        let mut partitions_needing_split = Vec::new();
        let mut partitions_needing_merge = Vec::new();
        {
            let partitions = self.partitions.read().await;
            for (partition_id, partition) in partitions.iter() {
                if partition.needs_split(self.config.max_keys_per_partition) {
                    partitions_needing_split.push(partition_id.clone());
                } else if partition.can_merge(self.config.min_keys_per_partition) {
                    partitions_needing_merge.push(partition_id.clone());
                }
            }
        }
        info!(
            "Rebalancing check: {} partitions need split, {} need merge",
            partitions_needing_split.len(),
            partitions_needing_merge.len()
        );
        [partitions_needing_split, partitions_needing_merge].concat()
    }
    /// Current state of a split operation, if known.
    pub async fn get_split_operation_status(&self, operation_id: &str) -> Option<SplitOperation> {
        let operations = self.split_operations.read().await;
        operations.get(operation_id).cloned()
    }
    /// Current state of a merge operation, if known.
    pub async fn get_merge_operation_status(&self, operation_id: &str) -> Option<MergeOperation> {
        let operations = self.merge_operations.read().await;
        operations.get(operation_id).cloned()
    }
    /// Current wall-clock time as seconds since the UNIX epoch.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("SystemTime should be after UNIX_EPOCH")
            .as_secs()
    }
    /// Marks a split operation as failed (best-effort; no-op if missing).
    async fn mark_split_failed(&self, operation_id: &str) {
        let mut operations = self.split_operations.write().await;
        if let Some(op) = operations.get_mut(operation_id) {
            op.status = OperationStatus::Failed;
        }
    }
    /// Marks a merge operation as failed (best-effort; no-op if missing).
    async fn mark_merge_failed(&self, operation_id: &str) {
        let mut operations = self.merge_operations.write().await;
        if let Some(op) = operations.get_mut(operation_id) {
            op.status = OperationStatus::Failed;
        }
    }
}
impl Clone for RangePartitionManager {
    /// Cheap handle clone: the `Arc`-wrapped maps are shared, not copied,
    /// so clones observe and mutate the same partitions and operations.
    /// Only the `config` is deep-copied.
    fn clone(&self) -> Self {
        Self {
            partitions: Arc::clone(&self.partitions),
            split_operations: Arc::clone(&self.split_operations),
            merge_operations: Arc::clone(&self.merge_operations),
            config: self.config.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Half-open semantics: start inclusive, end exclusive.
    #[test]
    fn test_range_contains() {
        let range = Range::between("b".to_string(), "f".to_string());
        assert!(!range.contains("a"));
        assert!(range.contains("b"));
        assert!(range.contains("d"));
        assert!(!range.contains("f"));
        assert!(!range.contains("g"));
    }

    /// An unbounded range accepts every key, including the empty string.
    #[test]
    fn test_unbounded_range_contains_everything() {
        let range = Range::unbounded();
        assert!(range.contains(""));
        assert!(range.contains("zzz"));
    }

    /// Ranges sharing only a boundary do not overlap (end is exclusive).
    #[test]
    fn test_range_overlaps() {
        let left = Range::between("a".to_string(), "m".to_string());
        let right = Range::between("m".to_string(), "z".to_string());
        assert!(!left.overlaps(&right));
        let middle = Range::between("g".to_string(), "p".to_string());
        assert!(left.overlaps(&middle));
        assert!(middle.overlaps(&right));
        assert!(Range::unbounded().overlaps(&left));
    }

    /// Display uses interval notation with -∞/+∞ for missing bounds.
    #[test]
    fn test_range_display() {
        assert_eq!(Range::unbounded().to_string(), "(-∞, +∞)");
        assert_eq!(
            Range::between("a".to_string(), "b".to_string()).to_string(),
            "[a, b)"
        );
        assert_eq!(Range::from("a".to_string()).to_string(), "[a, +∞)");
        assert_eq!(Range::to("b".to_string()).to_string(), "(-∞, b)");
    }

    #[test]
    fn test_range_split() {
        let range = Range::between("a".to_string(), "z".to_string());
        let (left, right) = range.split_at("m");
        assert_eq!(left.start, Some("a".to_string()));
        assert_eq!(left.end, Some("m".to_string()));
        assert_eq!(right.start, Some("m".to_string()));
        assert_eq!(right.end, Some("z".to_string()));
    }

    #[test]
    fn test_range_merge() {
        let left = Range::between("a".to_string(), "m".to_string());
        let right = Range::between("m".to_string(), "z".to_string());
        assert!(left.can_merge_with(&right));
        let merged = left.merge_with(&right).unwrap();
        assert_eq!(merged.start, Some("a".to_string()));
        assert_eq!(merged.end, Some("z".to_string()));
    }

    /// Non-adjacent ranges must not merge; a silent gap would lose keys.
    #[test]
    fn test_merge_rejects_non_adjacent() {
        let left = Range::between("a".to_string(), "f".to_string());
        let right = Range::between("m".to_string(), "z".to_string());
        assert!(!left.can_merge_with(&right));
        assert!(left.merge_with(&right).is_none());
    }

    #[tokio::test]
    async fn test_partition_manager_basic() {
        let config = RangePartitionConfig::default();
        let manager = RangePartitionManager::new(config);
        let partition_id = manager.create_initial_partition(1).await.unwrap();
        let partition = manager.get_partition(&partition_id).await.unwrap();
        assert_eq!(partition.shard_id, 1);
        assert_eq!(partition.range, Range::unbounded());
    }

    #[tokio::test]
    async fn test_find_partition_for_key() {
        let config = RangePartitionConfig::default();
        let manager = RangePartitionManager::new(config);
        let partition_id = manager.create_initial_partition(1).await.unwrap();
        let found = manager.find_partition_for_key("any_key").await;
        assert_eq!(found, Some(partition_id));
    }

    #[test]
    fn test_load_stats_default() {
        let stats = LoadStats::default();
        assert_eq!(stats.key_count, 0);
        assert_eq!(stats.data_size, 0);
        assert_eq!(stats.read_ops_per_sec, 0.0);
        assert_eq!(stats.write_ops_per_sec, 0.0);
    }
}