use std::sync::Arc;
use std::time::Duration;
use parking_lot::RwLock;
use tokio::sync::Notify;
use tracing::{debug, info, warn};
use crate::cluster_command::ClusterCommand;
use crate::log::Command;
use crate::node::RaftNode;
use crate::placement::{PlacementCoordinator, PlacementPolicy};
use crate::shard::ShardRegistry;
/// Tunable settings for the placement scheduler.
///
/// All fields are public so callers can build a configuration with struct
/// literal syntax or start from [`Default::default`] and override fields.
#[derive(Debug, Clone)]
pub struct PlacementSchedulerConfig {
    /// Delay the run loop sleeps between consecutive ticks.
    pub tick_interval: Duration,
    /// Upper bound on how many placement actions a single tick proposes.
    pub max_actions_per_tick: usize,
    /// A tick is skipped when the measured imbalance is below this value.
    pub imbalance_threshold: f64,
}
impl Default for PlacementSchedulerConfig {
    /// Returns the stock configuration: a 30-second tick interval, at most
    /// five actions per tick, and an imbalance threshold of `0.2`.
    fn default() -> Self {
        let tick_interval = Duration::from_secs(30);
        let max_actions_per_tick = 5;
        let imbalance_threshold = 0.2;
        Self {
            tick_interval,
            max_actions_per_tick,
            imbalance_threshold,
        }
    }
}
/// Measures how unevenly shards are spread across the nodes that own them.
///
/// The score is `(max - min) / mean`, where `max` and `min` are the largest
/// and smallest per-node shard counts and `mean` is the total shard count
/// divided by the number of distinct owning nodes. Only nodes that appear as
/// the owner of at least one shard are counted, so a node holding zero shards
/// does not influence the result.
///
/// Returns `0.0` when the registry holds no shards or when fewer than two
/// distinct nodes own shards.
fn compute_imbalance(registry: &ShardRegistry) -> f64 {
    use std::collections::HashMap;

    let all_shards = registry.get_all();
    if all_shards.is_empty() {
        return 0.0;
    }

    // Tally how many shards each owning node holds.
    let mut per_node: HashMap<crate::types::NodeId, usize> = HashMap::new();
    for shard_meta in &all_shards {
        *per_node.entry(shard_meta.node_id).or_default() += 1;
    }

    let owner_count = per_node.len();
    if owner_count < 2 {
        return 0.0;
    }

    // One pass over the tallies yields both extremes. The seeds are replaced
    // on the first element, and at least two elements exist at this point.
    let (fewest, most) = per_node
        .values()
        .fold((usize::MAX, 0usize), |(lo, hi), &n| (lo.min(n), hi.max(n)));

    let mean = all_shards.len() as f64 / owner_count as f64;
    if mean == 0.0 {
        return 0.0;
    }

    (most as f64 - fewest as f64) / mean
}
/// Cloneable handle used to ask a running placement scheduler to stop.
///
/// The handle carries its own reference to the scheduler's `Notify`, so it
/// stays usable after the scheduler value has been moved into a task.
#[derive(Debug, Clone)]
pub struct PlacementSchedulerHandle {
    /// Notification primitive shared with the scheduler's run loop.
    stop: Arc<Notify>,
}
impl PlacementSchedulerHandle {
    /// Signals the scheduler's run loop to exit.
    ///
    /// This calls `notify_one` on the shared `Notify`. Per tokio's documented
    /// semantics, `notify_one` stores a single permit when no task is
    /// currently waiting, so a stop requested while the loop is busy is
    /// expected to be observed the next time the loop waits on the signal.
    pub fn stop(&self) {
        let signal = &self.stop;
        signal.notify_one();
    }
}
/// Background task that periodically proposes shard placement actions.
///
/// On each tick it consults the Raft node for leadership, reads the shard
/// registry, asks the coordinator for a plan and proposes the resulting
/// actions as log commands.
pub struct PlacementScheduler {
    /// Raft node queried for leadership and used to propose commands.
    node: Arc<RaftNode>,
    /// Shared shard registry; only read locks are taken by the scheduler.
    registry: Arc<RwLock<ShardRegistry>>,
    /// Produces the placement plan from the registry contents.
    coordinator: PlacementCoordinator,
    /// Tick cadence, per-tick action cap and imbalance threshold.
    config: PlacementSchedulerConfig,
    /// Stop notification shared with every handle handed out.
    stop: Arc<Notify>,
}
impl PlacementScheduler {
    /// Creates a scheduler bound to `node` and `registry`.
    ///
    /// `policy` is handed to a freshly constructed [`PlacementCoordinator`];
    /// `config` controls tick cadence, the per-tick action cap and the
    /// imbalance threshold. A new stop signal is allocated for this instance.
    pub fn new(
        node: Arc<RaftNode>,
        registry: Arc<RwLock<ShardRegistry>>,
        policy: PlacementPolicy,
        config: PlacementSchedulerConfig,
    ) -> Self {
        Self {
            node,
            registry,
            coordinator: PlacementCoordinator::new(policy),
            config,
            stop: Arc::new(Notify::new()),
        }
    }

    /// Returns a handle that shares this scheduler's stop signal.
    ///
    /// Obtain it before calling [`run`](Self::run), which consumes `self`.
    pub fn handle(&self) -> PlacementSchedulerHandle {
        PlacementSchedulerHandle {
            stop: Arc::clone(&self.stop),
        }
    }

    /// Returns another reference to the underlying stop `Notify`.
    pub fn stop_signal(&self) -> Arc<Notify> {
        Arc::clone(&self.stop)
    }

    /// Runs the scheduler loop until the stop signal fires.
    ///
    /// Each iteration races a sleep of `tick_interval` against the stop
    /// notification. When the sleep wins, one [`tick`](Self::tick) is run to
    /// completion before the next iteration starts, so the effective period
    /// is `tick_interval` plus the time the tick took.
    pub async fn run(self) {
        info!(
            tick_interval_secs = self.config.tick_interval.as_secs_f64(),
            max_actions_per_tick = self.config.max_actions_per_tick,
            "PlacementScheduler: started",
        );
        loop {
            tokio::select! {
                _ = tokio::time::sleep(self.config.tick_interval) => {
                    self.tick().await;
                }
                _ = self.stop.notified() => {
                    info!("PlacementScheduler: stop signal received, exiting");
                    break;
                }
            }
        }
    }

    /// Performs one scheduling pass.
    ///
    /// The pass returns early, proposing nothing, when this node is not the
    /// leader, when the measured imbalance is below the configured threshold,
    /// when planning fails, or when the plan is empty. Otherwise it proposes
    /// up to `max_actions_per_tick` actions and stops at the first failed
    /// proposal.
    pub(crate) async fn tick(&self) {
        if !self.node.is_leader() {
            debug!("PlacementScheduler: skipping tick — not leader");
            return;
        }

        // The read guard lives only for this block, so it is released before
        // any logging or proposing happens.
        let imbalance = {
            let registry = self.registry.read();
            compute_imbalance(&registry)
        };
        if imbalance < self.config.imbalance_threshold {
            debug!(
                imbalance,
                threshold = self.config.imbalance_threshold,
                "PlacementScheduler: imbalance below threshold, skipping tick",
            );
            return;
        }

        // A second, separate read lock is taken for planning; the registry
        // may have changed since the imbalance was measured.
        let plan = {
            let registry = self.registry.read();
            match self.coordinator.plan(&registry) {
                Ok(p) => p,
                Err(e) => {
                    warn!(error = ?e, "PlacementScheduler: plan() failed; skipping tick");
                    return;
                }
            }
        };
        if plan.is_empty() {
            debug!("PlacementScheduler: no placement actions needed");
            return;
        }

        // `action_count` is the size of the whole plan; the loop below
        // proposes at most `max_actions_per_tick` of those actions.
        info!(
            action_count = plan.len(),
            "PlacementScheduler: proposing placement actions",
        );
        for action in plan.actions.iter().take(self.config.max_actions_per_tick) {
            let cmd = ClusterCommand::from_placement_action(action);
            let encoded = cmd.encode();
            match self.node.propose(Command::new(encoded)) {
                Ok(index) => {
                    debug!(
                        log_index = index,
                        variant = ?cmd.tag(),
                        "PlacementScheduler: proposed action",
                    );
                }
                Err(e) => {
                    warn!(
                        error = ?e,
                        "PlacementScheduler: failed to propose action \
                         (likely stepped down from leadership); aborting this tick",
                    );
                    // Remaining actions are dropped for this tick.
                    break;
                }
            }
        }
    }
}
/// Unit tests for the scheduler configuration, stop signalling, leader gating
/// and the imbalance metric.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cluster_command::ClusterCommand;
    use crate::types::RaftConfig;
    use amaters_core::Key;
    use std::time::Duration;

    /// Builds a node from a three-member config. The tests rely on a freshly
    /// constructed node not reporting itself as leader.
    fn make_follower_node() -> Arc<RaftNode> {
        let config = RaftConfig::new(1, vec![1, 2, 3]);
        Arc::new(RaftNode::new(config).expect("RaftNode::new must succeed for valid config"))
    }

    /// Builds an empty, shareable shard registry.
    fn make_registry() -> Arc<RwLock<ShardRegistry>> {
        Arc::new(RwLock::new(ShardRegistry::new()))
    }

    /// Default tick interval and action cap match the documented values.
    #[test]
    fn test_scheduler_config_defaults() {
        let cfg = PlacementSchedulerConfig::default();
        assert_eq!(
            cfg.tick_interval,
            Duration::from_secs(30),
            "default tick_interval must be 30 seconds",
        );
        assert_eq!(
            cfg.max_actions_per_tick, 5,
            "default max_actions_per_tick must be 5",
        );
    }

    /// `stop_signal()` and `handle()` all refer to the same `Notify`.
    #[test]
    fn test_placement_scheduler_new() {
        let node = make_follower_node();
        let registry = make_registry();
        let policy = PlacementPolicy::default_policy();
        let config = PlacementSchedulerConfig::default();
        let scheduler =
            PlacementScheduler::new(Arc::clone(&node), Arc::clone(&registry), policy, config);
        let sig1 = scheduler.stop_signal();
        // Ask the scheduler a second time rather than cloning `sig1`, which
        // would make the comparison trivially true.
        let sig2 = scheduler.stop_signal();
        assert!(
            Arc::ptr_eq(&sig1, &sig2),
            "cloned stop signals must share the same Arc",
        );
        let handle = scheduler.handle();
        assert!(
            Arc::ptr_eq(&sig1, &handle.stop),
            "handle stop must share the same Arc as stop_signal()",
        );
    }

    /// A tick on a non-leader node must leave the log untouched.
    #[tokio::test]
    async fn test_scheduler_skips_when_not_leader() {
        let node = make_follower_node();
        let registry = make_registry();
        let policy = PlacementPolicy::default_policy();
        let config = PlacementSchedulerConfig {
            tick_interval: Duration::from_secs(3600),
            max_actions_per_tick: 5,
            imbalance_threshold: 0.2,
        };
        let scheduler =
            PlacementScheduler::new(Arc::clone(&node), Arc::clone(&registry), policy, config);
        let log_len_before = node.last_log_index();
        scheduler.tick().await;
        let log_len_after = node.last_log_index();
        assert_eq!(
            log_len_before, log_len_after,
            "tick() must not append any entries when node is not leader",
        );
    }

    /// Notifying the raw stop signal ends the run loop promptly.
    #[tokio::test]
    async fn test_scheduler_exits_on_stop_signal() {
        let node = make_follower_node();
        let registry = make_registry();
        let policy = PlacementPolicy::default_policy();
        let config = PlacementSchedulerConfig {
            tick_interval: Duration::from_secs(3600),
            max_actions_per_tick: 5,
            imbalance_threshold: 0.2,
        };
        let scheduler =
            PlacementScheduler::new(Arc::clone(&node), Arc::clone(&registry), policy, config);
        let stop = scheduler.stop_signal();
        let join = tokio::spawn(scheduler.run());
        stop.notify_one();
        tokio::time::timeout(Duration::from_secs(1), join)
            .await
            .expect("scheduler must exit within 1 second after stop signal")
            .expect("scheduler task must not panic");
    }

    /// `PlacementSchedulerHandle::stop` ends the run loop promptly.
    #[tokio::test]
    async fn test_handle_stop_exits_run_loop() {
        let node = make_follower_node();
        let registry = make_registry();
        let policy = PlacementPolicy::default_policy();
        let config = PlacementSchedulerConfig {
            tick_interval: Duration::from_secs(3600),
            max_actions_per_tick: 5,
            imbalance_threshold: 0.2,
        };
        let scheduler =
            PlacementScheduler::new(Arc::clone(&node), Arc::clone(&registry), policy, config);
        let handle = scheduler.handle();
        let join = tokio::spawn(scheduler.run());
        handle.stop();
        tokio::time::timeout(Duration::from_secs(1), join)
            .await
            .expect("scheduler must exit within 1 second after handle.stop()")
            .expect("scheduler task must not panic");
    }

    /// Every placement command variant survives an encode/decode round trip.
    #[test]
    fn test_cluster_command_round_trip_in_scheduler() {
        let split_cmd = ClusterCommand::PlaceSplit {
            shard_id: 10,
            split_key: Key::from_slice(&[0x80u8]).as_bytes().to_vec(),
        };
        let encoded = split_cmd.encode();
        let decoded = ClusterCommand::decode(&encoded).expect("PlaceSplit round-trip must succeed");
        assert_eq!(split_cmd, decoded, "PlaceSplit round-trip must be lossless");

        let merge_cmd = ClusterCommand::PlaceMerge {
            left_shard_id: 3,
            right_shard_id: 4,
        };
        let encoded = merge_cmd.encode();
        let decoded = ClusterCommand::decode(&encoded).expect("PlaceMerge round-trip must succeed");
        assert_eq!(merge_cmd, decoded, "PlaceMerge round-trip must be lossless");

        let transfer_cmd = ClusterCommand::PlaceTransfer {
            shard_id: 99,
            from_node: 1,
            to_node: 2,
        };
        let encoded = transfer_cmd.encode();
        let decoded =
            ClusterCommand::decode(&encoded).expect("PlaceTransfer round-trip must succeed");
        assert_eq!(
            transfer_cmd, decoded,
            "PlaceTransfer round-trip must be lossless"
        );
    }

    /// Repeated `stop()` calls are harmless and the loop still exits.
    #[tokio::test]
    async fn test_handle_stop_is_idempotent() {
        let node = make_follower_node();
        let registry = make_registry();
        let policy = PlacementPolicy::default_policy();
        let config = PlacementSchedulerConfig {
            tick_interval: Duration::from_secs(3600),
            max_actions_per_tick: 5,
            imbalance_threshold: 0.2,
        };
        let scheduler =
            PlacementScheduler::new(Arc::clone(&node), Arc::clone(&registry), policy, config);
        let handle = scheduler.handle();
        let join = tokio::spawn(scheduler.run());
        handle.stop();
        handle.stop();
        handle.stop();
        tokio::time::timeout(Duration::from_secs(1), join)
            .await
            .expect("scheduler must still exit cleanly after multiple stop() calls")
            .expect("scheduler task must not panic");
    }

    /// Default imbalance threshold is 0.2.
    #[test]
    fn test_placement_scheduler_config_imbalance_threshold_default() {
        let cfg = PlacementSchedulerConfig::default();
        assert!(
            (cfg.imbalance_threshold - 0.2).abs() < 1e-9,
            "default imbalance threshold must be 0.2, got {}",
            cfg.imbalance_threshold
        );
    }

    /// An empty registry has zero imbalance, which is below any positive
    /// threshold. Nothing here awaits, so a plain `#[test]` is sufficient.
    #[test]
    fn test_placement_skipped_below_threshold() {
        use crate::shard::ShardRegistry;
        let registry = ShardRegistry::new();
        let imbalance = super::compute_imbalance(&registry);
        assert_eq!(imbalance, 0.0, "empty registry imbalance must be 0.0");
    }

    /// Three shards on node 1 and one on node 2: (3 - 1) / (4 / 2) = 1.0.
    #[test]
    fn test_imbalance_computed_correctly() {
        use crate::shard::{KeyRange, ShardMetadata, ShardRegistry};
        use amaters_core::Key;
        let registry = ShardRegistry::new();
        // (start key, end key, shard id, owning node id)
        let ranges: &[(&str, &str, u64, u64)] = &[
            ("a00", "a10", 1, 1),
            ("a10", "a20", 2, 1),
            ("a20", "a30", 3, 1),
            ("b00", "b10", 4, 2),
        ];
        for &(start_s, end_s, shard_id, node_id) in ranges {
            let start = Key::from_str(start_s);
            let end = Key::from_str(end_s);
            let range = KeyRange::new(start, end).expect("valid range");
            let shard = ShardMetadata::new(shard_id, range, node_id);
            registry.register(shard).expect("register must succeed");
        }
        let imbalance = super::compute_imbalance(&registry);
        assert!(
            (imbalance - 1.0).abs() < 1e-9,
            "imbalance should be 1.0, got {}",
            imbalance
        );
    }
}