use crate::error::{RaftError, RaftResult};
use crate::types::NodeId;
use amaters_core::Key;
use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
/// Numeric identifier of a shard.
pub type ShardId = u64;

/// Lifecycle state of a shard, governing which operations it may serve.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardState {
    /// Serves both reads and writes.
    Active,
    /// Being split; serves reads only.
    Splitting,
    /// Being merged; serves neither reads nor writes.
    Merging,
    /// Being moved between nodes; serves reads only.
    Transferring,
    /// Serves neither reads nor writes.
    Offline,
}

impl ShardState {
    /// Returns `true` for states that may serve reads
    /// (`Active`, `Splitting`, `Transferring`).
    pub fn can_read(&self) -> bool {
        match self {
            ShardState::Active | ShardState::Splitting | ShardState::Transferring => true,
            ShardState::Merging | ShardState::Offline => false,
        }
    }

    /// Returns `true` only for `Active`: every in-flight reorganisation
    /// blocks writes.
    pub fn can_write(&self) -> bool {
        *self == ShardState::Active
    }

    /// Returns the variant name as a static string.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Active => "Active",
            Self::Splitting => "Splitting",
            Self::Merging => "Merging",
            Self::Transferring => "Transferring",
            Self::Offline => "Offline",
        }
    }
}
/// Half-open key interval `[start, end)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KeyRange {
    /// Inclusive lower bound.
    pub start: Key,
    /// Exclusive upper bound.
    pub end: Key,
}

impl KeyRange {
    /// Creates a range, rejecting empty or inverted bounds.
    ///
    /// # Errors
    ///
    /// Returns `RaftError::ConfigError` if `start >= end`.
    pub fn new(start: Key, end: Key) -> RaftResult<Self> {
        if start >= end {
            return Err(RaftError::ConfigError {
                message: format!("Invalid key range: start {:?} >= end {:?}", start, end),
            });
        }
        Ok(Self { start, end })
    }

    /// Returns `true` if `start <= key < end`.
    pub fn contains(&self, key: &Key) -> bool {
        key >= &self.start && key < &self.end
    }

    /// Returns `true` if the two half-open ranges share at least one key.
    /// Ranges that merely touch (`a.end == b.start`) do not overlap.
    pub fn overlaps(&self, other: &KeyRange) -> bool {
        self.start < other.end && other.start < self.end
    }

    /// Returns a key that splits this range roughly in half.
    ///
    /// Both bounds are treated as big-endian base-256 fractions, zero-padded on
    /// the right to a common width, and averaged with full carry/borrow
    /// propagation. The result is strictly inside `(start, end)` whenever any
    /// key lies strictly between the bounds. If none does (e.g. `end` is
    /// `start` followed by a single zero byte), `start` itself is returned.
    /// `KeyRange::new` rejects `start` as a split point, because it requires
    /// `start < end`.
    ///
    /// NOTE(review): assumes `Key` orders by its raw bytes lexicographically —
    /// confirm against `amaters_core::Key`. The final bounds check falls back
    /// to `start` rather than returning an out-of-range key if it does not.
    pub fn midpoint(&self) -> Key {
        let start = self.start.as_bytes();
        let end = self.end.as_bytes();
        let width = start.len().max(end.len());
        // Digit `i` of a key, treating positions past its end as zero padding.
        let byte_at = |bytes: &[u8], i: usize| -> u16 { u16::from(bytes.get(i).copied().unwrap_or(0)) };

        // start + end, carrying from the least significant (rightmost) digit.
        // The final `carry` is the integer part of the sum (0 or 1).
        let mut sum = vec![0u8; width];
        let mut carry = 0u16;
        for (i, slot) in sum.iter_mut().enumerate().rev() {
            let digit = byte_at(start, i) + byte_at(end, i) + carry;
            *slot = (digit & 0xFF) as u8;
            carry = digit >> 8;
        }

        // Halve from the most significant digit; an odd remainder contributes
        // 256 / 2 = 128 to the next digit down.
        let mut mid = Vec::with_capacity(width + 1);
        let mut remainder = carry;
        for &digit in &sum {
            let value = (remainder << 8) | u16::from(digit);
            mid.push((value >> 1) as u8);
            remainder = value & 1;
        }

        // When the bounds are one unit apart at this width, the truncated
        // average equals `start`. Keep the dropped half unit so the result
        // lands strictly above `start`. Equal-length vectors compare
        // numerically here.
        let mut start_padded = start.to_vec();
        start_padded.resize(width, 0);
        if remainder == 1 && mid <= start_padded {
            mid.push(0x80);
        }

        // Trailing zero bytes do not change the fractional value; drop them
        // for a shorter key.
        while mid.last() == Some(&0) {
            mid.pop();
        }

        let candidate = Key::from_slice(&mid);
        if candidate > self.start && candidate < self.end {
            return candidate;
        }

        // The bounds are numerically equal (`end` is `start` plus zero bytes).
        // The immediate successor `start ++ [0x00]` is then the only possible
        // interior key.
        let mut successor = start.to_vec();
        successor.push(0);
        let successor = Key::from_slice(&successor);
        if successor > self.start && successor < self.end {
            successor
        } else {
            self.start.clone()
        }
    }

    /// A wide default range from `[0x00]` up to (excluding) 32 bytes of `0xFF`.
    ///
    /// NOTE(review): under byte-lexicographic ordering this excludes the empty
    /// key and keys sorting at or above `[0xFF; 32]` (e.g. 33 × `0xFF`), so
    /// it is not literally every key.
    pub fn full() -> Self {
        Self {
            start: Key::from_slice(&[0u8]),
            end: Key::from_slice(&[0xFFu8; 32]),
        }
    }
}
/// Descriptive and statistical metadata for a single shard.
#[derive(Debug, Clone)]
pub struct ShardMetadata {
    /// Shard identifier.
    pub id: ShardId,
    /// Key interval served by this shard.
    pub range: KeyRange,
    /// Current lifecycle state; `Active` on creation.
    pub state: ShardState,
    /// Node the shard is placed on.
    pub node_id: NodeId,
    /// Replica nodes; `add_replica` rejects duplicates.
    pub replicas: Vec<NodeId>,
    /// Estimated number of keys, as last reported via `update_stats`.
    pub estimated_keys: u64,
    /// Estimated data size in bytes, as last reported via `update_stats`.
    pub estimated_size_bytes: u64,
    /// Time of the most recent mutation made through this type's methods.
    pub last_updated: SystemTime,
    /// Time this metadata record was constructed.
    pub created_at: SystemTime,
    /// Starts at 1 and is bumped by every mutating method.
    pub version: u64,
}

impl ShardMetadata {
    /// Creates `Active` metadata with no replicas, zeroed statistics and
    /// version 1; both timestamps share the same `now()` reading.
    pub fn new(id: ShardId, range: KeyRange, node_id: NodeId) -> Self {
        let created = SystemTime::now();
        ShardMetadata {
            id,
            range,
            node_id,
            state: ShardState::Active,
            replicas: Vec::new(),
            estimated_keys: 0,
            estimated_size_bytes: 0,
            created_at: created,
            last_updated: created,
            version: 1,
        }
    }

    /// Stamps `last_updated` with the current time and bumps `version`.
    fn record_mutation(&mut self) {
        self.last_updated = SystemTime::now();
        self.version += 1;
    }

    /// Moves the shard into `state`.
    pub fn set_state(&mut self, state: ShardState) {
        self.state = state;
        self.record_mutation();
    }

    /// Replaces the key-count and size estimates.
    pub fn update_stats(&mut self, estimated_keys: u64, estimated_size_bytes: u64) {
        self.estimated_keys = estimated_keys;
        self.estimated_size_bytes = estimated_size_bytes;
        self.record_mutation();
    }

    /// Appends `node_id` to the replica list.
    ///
    /// # Errors
    ///
    /// `RaftError::ConfigError` if `node_id` is already listed.
    pub fn add_replica(&mut self, node_id: NodeId) -> RaftResult<()> {
        let already_present = self.replicas.iter().any(|existing| *existing == node_id);
        if already_present {
            return Err(RaftError::ConfigError {
                message: format!("Replica {} already exists for shard {}", node_id, self.id),
            });
        }
        self.replicas.push(node_id);
        self.record_mutation();
        Ok(())
    }

    /// Removes every occurrence of `node_id` from the replica list.
    ///
    /// # Errors
    ///
    /// `RaftError::ConfigError` if `node_id` was not listed.
    pub fn remove_replica(&mut self, node_id: NodeId) -> RaftResult<()> {
        let before = self.replicas.len();
        self.replicas.retain(|candidate| *candidate != node_id);
        let removed_any = self.replicas.len() != before;
        if !removed_any {
            return Err(RaftError::ConfigError {
                message: format!("Replica {} not found for shard {}", node_id, self.id),
            });
        }
        self.record_mutation();
        Ok(())
    }

    /// Returns `true` if either estimate strictly exceeds its threshold.
    pub fn is_hot(&self, key_threshold: u64, size_threshold: u64) -> bool {
        let too_many_keys = self.estimated_keys > key_threshold;
        let too_large = self.estimated_size_bytes > size_threshold;
        too_many_keys || too_large
    }

    /// Returns `true` if both estimates are strictly below their thresholds.
    pub fn is_cold(&self, key_threshold: u64, size_threshold: u64) -> bool {
        !(self.estimated_keys >= key_threshold || self.estimated_size_bytes >= size_threshold)
    }

    /// Returns `true` if more than `max_age` has passed since `last_updated`.
    /// If the clock reads earlier than `last_updated`, the shard is treated as
    /// not stale.
    pub fn is_stale(&self, max_age: Duration) -> bool {
        match self.last_updated.elapsed() {
            Ok(age) => age > max_age,
            Err(_) => false,
        }
    }
}
/// Plan for splitting one shard into two adjacent shards at `split_key`.
#[derive(Debug, Clone)]
pub struct ShardSplit {
    /// Shard being split.
    pub source_shard_id: ShardId,
    /// Id for the child covering `[source.start, split_key)`.
    pub left_shard_id: ShardId,
    /// Id for the child covering `[split_key, source.end)`.
    pub right_shard_id: ShardId,
    /// Boundary key; the first key owned by the right child.
    pub split_key: Key,
    /// Wall-clock time at which this plan was created.
    pub initiated_at: SystemTime,
}

impl ShardSplit {
    /// Creates a split plan stamped with the current time.
    pub fn new(
        source_shard_id: ShardId,
        left_shard_id: ShardId,
        right_shard_id: ShardId,
        split_key: Key,
    ) -> Self {
        Self {
            source_shard_id,
            left_shard_id,
            right_shard_id,
            split_key,
            initiated_at: SystemTime::now(),
        }
    }

    /// Builds metadata for the two child shards of `source`.
    ///
    /// The left child covers `[source.range.start, split_key)` and the right
    /// child covers `[split_key, source.range.end)`. Both children:
    /// - are placed on `source.node_id`;
    /// - copy `source.replicas`;
    /// - are created through `ShardMetadata::new`.
    ///
    /// The estimated key count and size are divided so that the two children
    /// add back up to the source totals. The right child receives any odd
    /// remainder.
    ///
    /// # Errors
    ///
    /// `RaftError::ConfigError` (from `KeyRange::new`) if `split_key` is not
    /// strictly inside the source range.
    ///
    /// NOTE(review): `source.id` is not compared with `self.source_shard_id`.
    pub fn create_shards(
        &self,
        source: &ShardMetadata,
    ) -> RaftResult<(ShardMetadata, ShardMetadata)> {
        let left_range = KeyRange::new(source.range.start.clone(), self.split_key.clone())?;
        let right_range = KeyRange::new(self.split_key.clone(), source.range.end.clone())?;

        let mut left_shard = ShardMetadata::new(self.left_shard_id, left_range, source.node_id);
        let mut right_shard = ShardMetadata::new(self.right_shard_id, right_range, source.node_id);
        left_shard.replicas = source.replicas.clone();
        right_shard.replicas = source.replicas.clone();

        // Floor-halve for the left child and give the remainder to the right
        // child, so odd totals are not lost (a later merge sums them back).
        let left_keys = source.estimated_keys / 2;
        let left_bytes = source.estimated_size_bytes / 2;
        left_shard.estimated_keys = left_keys;
        left_shard.estimated_size_bytes = left_bytes;
        right_shard.estimated_keys = source.estimated_keys - left_keys;
        right_shard.estimated_size_bytes = source.estimated_size_bytes - left_bytes;

        Ok((left_shard, right_shard))
    }
}
/// Plan for merging two adjacent shards on the same node into one.
#[derive(Debug, Clone)]
pub struct ShardMerge {
    /// Shard covering the lower part of the merged range.
    pub left_shard_id: ShardId,
    /// Shard covering the upper part of the merged range.
    pub right_shard_id: ShardId,
    /// Id assigned to the merged shard.
    pub target_shard_id: ShardId,
    /// Wall-clock time at which this plan was created.
    pub initiated_at: SystemTime,
}

impl ShardMerge {
    /// Creates a merge plan stamped with the current time.
    pub fn new(
        left_shard_id: ShardId,
        right_shard_id: ShardId,
        target_shard_id: ShardId,
    ) -> Self {
        Self {
            left_shard_id,
            right_shard_id,
            target_shard_id,
            initiated_at: SystemTime::now(),
        }
    }

    /// Checks that `left` and `right` can be merged.
    ///
    /// # Errors
    ///
    /// `RaftError::ConfigError` if `left.range.end != right.range.start`, or
    /// if the two shards have different `node_id`s.
    ///
    /// NOTE(review): `left.id` and `right.id` are not checked against
    /// `left_shard_id` and `right_shard_id`.
    pub fn validate(&self, left: &ShardMetadata, right: &ShardMetadata) -> RaftResult<()> {
        if left.range.end != right.range.start {
            return Err(RaftError::ConfigError {
                message: format!(
                    "Shards {} and {} are not adjacent (left.end={:?}, right.start={:?})",
                    left.id, right.id, left.range.end, right.range.start
                ),
            });
        }
        if left.node_id != right.node_id {
            return Err(RaftError::ConfigError {
                message: format!(
                    "Shards {} and {} are on different nodes ({} vs {})",
                    left.id, right.id, left.node_id, right.node_id
                ),
            });
        }
        Ok(())
    }

    /// Validates the pair, then builds metadata for the merged shard.
    ///
    /// The merged shard:
    /// - covers `[left.range.start, right.range.end)`;
    /// - is placed on `left.node_id`;
    /// - copies `left.replicas` only; the right shard's replica list is not
    ///   merged in.
    ///
    /// Statistics are summed, saturating at `u64::MAX`, so oversized estimates
    /// cannot panic (debug) or wrap (release).
    ///
    /// # Errors
    ///
    /// Any error from [`ShardMerge::validate`] or `KeyRange::new`.
    pub fn create_merged_shard(
        &self,
        left: &ShardMetadata,
        right: &ShardMetadata,
    ) -> RaftResult<ShardMetadata> {
        self.validate(left, right)?;
        let merged_range = KeyRange::new(
            left.range.start.clone(),
            right.range.end.clone(),
        )?;
        let mut merged = ShardMetadata::new(
            self.target_shard_id,
            merged_range,
            left.node_id,
        );
        merged.estimated_keys = left.estimated_keys.saturating_add(right.estimated_keys);
        merged.estimated_size_bytes = left
            .estimated_size_bytes
            .saturating_add(right.estimated_size_bytes);
        merged.replicas = left.replicas.clone();
        Ok(merged)
    }
}
/// Tracks the progress of moving a shard from one node to another.
#[derive(Debug, Clone)]
pub struct ShardTransfer {
    /// Shard being moved.
    pub shard_id: ShardId,
    /// Node the shard is moving away from.
    pub from_node: NodeId,
    /// Node the shard is moving to.
    pub to_node: NodeId,
    /// Fraction complete; `update_progress` clamps it into `[0.0, 1.0]`.
    pub progress: f64,
    /// Wall-clock time at which the transfer was created.
    pub initiated_at: SystemTime,
    /// Linear extrapolation of the finish time.
    ///
    /// `None` until an estimate has been computed, or when the estimate cannot
    /// be represented as a `SystemTime`.
    pub estimated_completion: Option<SystemTime>,
}

impl ShardTransfer {
    /// Creates a transfer at 0% progress with no completion estimate.
    pub fn new(shard_id: ShardId, from_node: NodeId, to_node: NodeId) -> Self {
        Self {
            shard_id,
            from_node,
            to_node,
            progress: 0.0,
            initiated_at: SystemTime::now(),
            estimated_completion: None,
        }
    }

    /// Records a new progress fraction, clamped into `[0.0, 1.0]`.
    ///
    /// For progress strictly between 0 and 1, the completion estimate is
    /// refreshed by linearly extrapolating the time elapsed since
    /// `initiated_at`. The estimate becomes `None` if it is too far in the
    /// future to represent. A NaN report is ignored and leaves all fields
    /// unchanged.
    pub fn update_progress(&mut self, progress: f64) {
        // `f64::clamp` passes NaN through, which would make `progress` NaN and
        // `is_complete` report false; ignore such reports instead.
        if progress.is_nan() {
            return;
        }
        self.progress = progress.clamp(0.0, 1.0);
        if self.progress > 0.0 && self.progress < 1.0 {
            if let Ok(elapsed) = self.initiated_at.elapsed() {
                let total_time = elapsed.as_secs_f64() / self.progress;
                let remaining_time = total_time * (1.0 - self.progress);
                self.estimated_completion = Self::completion_after(remaining_time);
            }
        }
    }

    /// Converts a remaining-seconds estimate into an absolute time, or `None`
    /// if it cannot be represented.
    fn completion_after(remaining_secs: f64) -> Option<SystemTime> {
        // `Duration::from_secs_f64` panics on non-finite, negative or
        // >= 2^64-second inputs. Tiny progress values reach these (e.g.
        // 10s / 1e-300).
        if !(remaining_secs.is_finite() && remaining_secs >= 0.0 && remaining_secs < u64::MAX as f64) {
            return None;
        }
        // `SystemTime + Duration` panics on overflow; `checked_add` does not.
        SystemTime::now().checked_add(Duration::from_secs_f64(remaining_secs))
    }

    /// Returns `true` once progress has reached 1.0.
    pub fn is_complete(&self) -> bool {
        self.progress >= 1.0
    }
}
/// Thread-safe catalogue of shard metadata keyed by shard id.
///
/// Both fields are `Arc`-wrapped, so clones share the same underlying state.
#[derive(Debug, Clone)]
pub struct ShardRegistry {
    shards: Arc<parking_lot::RwLock<BTreeMap<ShardId, ShardMetadata>>>,
    next_shard_id: Arc<parking_lot::Mutex<ShardId>>,
}

impl ShardRegistry {
    /// Creates an empty registry whose id allocator starts at 1.
    pub fn new() -> Self {
        Self {
            shards: Arc::new(parking_lot::RwLock::new(BTreeMap::new())),
            next_shard_id: Arc::new(parking_lot::Mutex::new(1)),
        }
    }

    /// Returns the next id from a monotonically increasing counter.
    ///
    /// `register` advances the counter past any explicitly chosen id, so
    /// returned ids do not collide with registered shards.
    pub fn allocate_shard_id(&self) -> ShardId {
        let mut next_id = self.next_shard_id.lock();
        let id = *next_id;
        *next_id += 1;
        id
    }

    /// Adds a new shard.
    ///
    /// # Errors
    ///
    /// `RaftError::ConfigError` in either of these cases:
    /// - the shard's range overlaps any registered shard's range;
    /// - a shard with the same id is already registered.
    ///
    /// Before the duplicate-id check, a non-overlapping shard with a reused
    /// id silently replaced the existing entry.
    pub fn register(&self, shard: ShardMetadata) -> RaftResult<()> {
        let mut shards = self.shards.write();
        if let Some(existing) = shards
            .values()
            .find(|existing| existing.range.overlaps(&shard.range))
        {
            return Err(RaftError::ConfigError {
                message: format!(
                    "Shard {} range overlaps with existing shard {} range",
                    shard.id, existing.id
                ),
            });
        }
        if shards.contains_key(&shard.id) {
            return Err(RaftError::ConfigError {
                message: format!("Shard {} is already registered", shard.id),
            });
        }
        // Keep the allocator ahead of explicitly chosen ids.
        // Lock order is shards -> next_shard_id. `allocate_shard_id` takes
        // only the latter, so the two cannot deadlock.
        {
            let mut next_id = self.next_shard_id.lock();
            if *next_id <= shard.id {
                *next_id = shard.id.saturating_add(1);
            }
        }
        shards.insert(shard.id, shard);
        Ok(())
    }

    /// Returns a copy of the shard with `shard_id`, if registered.
    pub fn get(&self, shard_id: ShardId) -> Option<ShardMetadata> {
        self.shards.read().get(&shard_id).cloned()
    }

    /// Inserts or replaces the entry for `shard.id`.
    ///
    /// Unlike `register`, this performs no overlap or existence check and
    /// never fails.
    pub fn update(&self, shard: ShardMetadata) -> RaftResult<()> {
        let mut shards = self.shards.write();
        shards.insert(shard.id, shard);
        Ok(())
    }

    /// Removes the shard with `shard_id`.
    ///
    /// # Errors
    ///
    /// `RaftError::ConfigError` if no such shard is registered.
    pub fn remove(&self, shard_id: ShardId) -> RaftResult<()> {
        let mut shards = self.shards.write();
        shards.remove(&shard_id).ok_or_else(|| RaftError::ConfigError {
            message: format!("Shard {} not found", shard_id),
        })?;
        Ok(())
    }

    /// Returns copies of all shards in ascending id order.
    pub fn get_all(&self) -> Vec<ShardMetadata> {
        self.shards.read().values().cloned().collect()
    }

    /// Returns copies of shards whose `node_id` equals `node_id`.
    ///
    /// Replica membership is not considered.
    pub fn get_by_node(&self, node_id: NodeId) -> Vec<ShardMetadata> {
        self.shards
            .read()
            .values()
            .filter(|shard| shard.node_id == node_id)
            .cloned()
            .collect()
    }

    /// Returns the lowest-id shard whose range contains `key`.
    ///
    /// This is a linear scan over all shards.
    pub fn find_shard_for_key(&self, key: &Key) -> Option<ShardMetadata> {
        self.shards
            .read()
            .values()
            .find(|shard| shard.range.contains(key))
            .cloned()
    }

    /// Number of registered shards.
    pub fn count(&self) -> usize {
        self.shards.read().len()
    }
}

impl Default for ShardRegistry {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shard_state() {
        assert!(ShardState::Active.can_read());
        assert!(ShardState::Active.can_write());
        assert!(ShardState::Splitting.can_read());
        assert!(!ShardState::Splitting.can_write());
        assert!(!ShardState::Offline.can_read());
        assert!(!ShardState::Offline.can_write());
    }

    #[test]
    fn test_key_range_contains() -> RaftResult<()> {
        let range = KeyRange::new(
            Key::from_str("a"),
            Key::from_str("z"),
        )?;
        assert!(range.contains(&Key::from_str("m")));
        assert!(range.contains(&Key::from_str("a")));
        assert!(!range.contains(&Key::from_str("z")));
        // "zz" sorts after the exclusive end "z" under byte-lexicographic
        // ordering and under length-first ordering alike.
        // NOTE(review): "aa" is *inside* this range if `Key` compares bytes
        // lexicographically ("a" < "aa" < "z"), so it is not used as an
        // outside probe.
        assert!(!range.contains(&Key::from_str("zz")));
        Ok(())
    }

    #[test]
    fn test_key_range_overlaps() -> RaftResult<()> {
        let range1 = KeyRange::new(Key::from_str("a"), Key::from_str("m"))?;
        let range2 = KeyRange::new(Key::from_str("g"), Key::from_str("z"))?;
        let range3 = KeyRange::new(Key::from_str("m"), Key::from_str("z"))?;
        assert!(range1.overlaps(&range2));
        assert!(range2.overlaps(&range1));
        assert!(!range1.overlaps(&range3));
        Ok(())
    }

    #[test]
    fn test_key_range_midpoint() -> RaftResult<()> {
        let range = KeyRange::new(
            Key::from_str("a"),
            Key::from_str("z"),
        )?;
        let mid = range.midpoint();
        assert!(mid > range.start);
        assert!(mid < range.end);
        Ok(())
    }

    #[test]
    fn test_shard_metadata_creation() {
        let range = KeyRange::new(Key::from_str("a"), Key::from_str("z"))
            .expect("valid range");
        let shard = ShardMetadata::new(1, range, 100);
        assert_eq!(shard.id, 1);
        assert_eq!(shard.node_id, 100);
        assert_eq!(shard.state, ShardState::Active);
        assert_eq!(shard.version, 1);
    }

    #[test]
    fn test_shard_metadata_update_stats() {
        let range = KeyRange::new(Key::from_str("a"), Key::from_str("z"))
            .expect("valid range");
        let mut shard = ShardMetadata::new(1, range, 100);
        let initial_version = shard.version;
        shard.update_stats(1000, 50000);
        assert_eq!(shard.estimated_keys, 1000);
        assert_eq!(shard.estimated_size_bytes, 50000);
        assert_eq!(shard.version, initial_version + 1);
    }

    #[test]
    fn test_shard_metadata_replicas() -> RaftResult<()> {
        let range = KeyRange::new(Key::from_str("a"), Key::from_str("z"))?;
        let mut shard = ShardMetadata::new(1, range, 100);
        shard.add_replica(101)?;
        shard.add_replica(102)?;
        assert_eq!(shard.replicas.len(), 2);
        assert!(shard.add_replica(101).is_err());
        shard.remove_replica(101)?;
        assert_eq!(shard.replicas.len(), 1);
        assert!(shard.replicas.contains(&102));
        Ok(())
    }

    #[test]
    fn test_shard_split() -> RaftResult<()> {
        let range = KeyRange::new(Key::from_str("a"), Key::from_str("z"))?;
        let mut source = ShardMetadata::new(1, range, 100);
        source.update_stats(1000, 100000);
        let split = ShardSplit::new(1, 2, 3, Key::from_str("m"));
        let (left, right) = split.create_shards(&source)?;
        assert_eq!(left.id, 2);
        assert_eq!(right.id, 3);
        assert_eq!(left.range.end, Key::from_str("m"));
        assert_eq!(right.range.start, Key::from_str("m"));
        assert_eq!(left.estimated_keys, 500);
        assert_eq!(right.estimated_keys, 500);
        Ok(())
    }

    #[test]
    fn test_shard_merge() -> RaftResult<()> {
        let left_range = KeyRange::new(Key::from_str("a"), Key::from_str("m"))?;
        let right_range = KeyRange::new(Key::from_str("m"), Key::from_str("z"))?;
        let mut left = ShardMetadata::new(1, left_range, 100);
        let mut right = ShardMetadata::new(2, right_range, 100);
        left.update_stats(500, 50000);
        right.update_stats(500, 50000);
        let merge = ShardMerge::new(1, 2, 3);
        let merged = merge.create_merged_shard(&left, &right)?;
        assert_eq!(merged.id, 3);
        assert_eq!(merged.range.start, Key::from_str("a"));
        assert_eq!(merged.range.end, Key::from_str("z"));
        assert_eq!(merged.estimated_keys, 1000);
        assert_eq!(merged.estimated_size_bytes, 100000);
        Ok(())
    }

    #[test]
    fn test_shard_transfer() {
        let mut transfer = ShardTransfer::new(1, 100, 101);
        assert_eq!(transfer.progress, 0.0);
        assert!(!transfer.is_complete());
        transfer.update_progress(0.5);
        assert_eq!(transfer.progress, 0.5);
        assert!(!transfer.is_complete());
        transfer.update_progress(1.0);
        assert!(transfer.is_complete());
    }

    #[test]
    fn test_shard_registry() -> RaftResult<()> {
        let registry = ShardRegistry::new();
        let id1 = registry.allocate_shard_id();
        let id2 = registry.allocate_shard_id();
        assert_ne!(id1, id2);
        let range1 = KeyRange::new(Key::from_str("a"), Key::from_str("m"))?;
        let shard1 = ShardMetadata::new(id1, range1, 100);
        registry.register(shard1.clone())?;
        let retrieved = registry.get(id1);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.expect("Shard should be retrieved from registry").id, id1);
        let found = registry.find_shard_for_key(&Key::from_str("g"));
        assert!(found.is_some());
        assert_eq!(found.expect("Shard should be found for key").id, id1);
        assert_eq!(registry.count(), 1);
        Ok(())
    }

    #[test]
    fn test_shard_registry_overlapping_ranges() -> RaftResult<()> {
        let registry = ShardRegistry::new();
        let range1 = KeyRange::new(Key::from_str("a"), Key::from_str("m"))?;
        let shard1 = ShardMetadata::new(1, range1, 100);
        registry.register(shard1)?;
        let range2 = KeyRange::new(Key::from_str("g"), Key::from_str("z"))?;
        let shard2 = ShardMetadata::new(2, range2, 100);
        let result = registry.register(shard2);
        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn test_hot_cold_shards() {
        let range = KeyRange::new(Key::from_str("a"), Key::from_str("z"))
            .expect("valid range");
        let mut shard = ShardMetadata::new(1, range, 100);
        shard.update_stats(1000, 50000);
        assert!(shard.is_hot(500, 25000));
        assert!(!shard.is_cold(500, 25000));
        shard.update_stats(100, 5000);
        assert!(!shard.is_hot(500, 25000));
        assert!(shard.is_cold(500, 25000));
    }
}