// nodedb_cluster/rebalancer/driver.rs

use std::sync::{Arc, RwLock};
29use std::time::{Duration, Instant};
30
31use async_trait::async_trait;
32use tokio::sync::{Notify, watch};
33use tokio::time::{MissedTickBehavior, interval};
34use tracing::{debug, info, warn};
35
36use crate::error::Result;
37use crate::loop_metrics::LoopMetrics;
38use crate::rebalance::PlannedMove;
39use crate::routing::RoutingTable;
40use crate::topology::ClusterTopology;
41
42use super::metrics::LoadMetricsProvider;
43use super::plan::{RebalancerPlanConfig, compute_load_based_plan};
44
/// Gate the rebalancer consults before each sweep: while any raft group is
/// mid-election, shard moves are deferred to the next tick.
#[async_trait]
pub trait ElectionGate: Send + Sync {
    /// Returns `true` if at least one raft group is currently electing a leader.
    async fn any_group_electing(&self) -> bool;
}
55
/// [`ElectionGate`] that never reports an election in progress — sweeps are
/// always allowed. Useful for tests and deployments without an election signal.
pub struct AlwaysReadyGate;

#[async_trait]
impl ElectionGate for AlwaysReadyGate {
    async fn any_group_electing(&self) -> bool {
        false
    }
}
66
/// Sink for planned shard moves; implementations carry out (or enqueue) the
/// actual data migration for each [`PlannedMove`].
#[async_trait]
pub trait MigrationDispatcher: Send + Sync {
    /// Executes (or schedules) a single planned move.
    async fn dispatch(&self, mv: PlannedMove) -> Result<()>;
}
74
/// Tuning knobs for [`RebalancerLoop`].
#[derive(Debug, Clone)]
pub struct RebalancerLoopConfig {
    /// How often a periodic sweep runs.
    pub interval: Duration,
    /// Parameters forwarded to the load-based planner.
    pub plan: RebalancerPlanConfig,
    /// Back-pressure cutoff: if any node reports CPU utilization (0.0–1.0)
    /// strictly above this value, the whole sweep is skipped.
    pub backpressure_cpu_threshold: f64,
}
89
90impl Default for RebalancerLoopConfig {
91 fn default() -> Self {
92 Self {
93 interval: Duration::from_secs(30),
94 plan: RebalancerPlanConfig::default(),
95 backpressure_cpu_threshold: 0.80,
96 }
97 }
98}
99
/// Background loop that periodically plans load-based shard moves and hands
/// them to a [`MigrationDispatcher`].
pub struct RebalancerLoop {
    cfg: RebalancerLoopConfig,
    /// Source of per-node load snapshots.
    metrics: Arc<dyn LoadMetricsProvider>,
    /// Executes planned moves (each dispatched on its own task).
    dispatcher: Arc<dyn MigrationDispatcher>,
    /// Blocks sweeps while a raft election is in progress.
    gate: Arc<dyn ElectionGate>,
    /// Current shard→node routing; read-locked while planning.
    routing: Arc<RwLock<RoutingTable>>,
    /// Cluster membership view; read-locked while planning.
    topology: Arc<RwLock<ClusterTopology>>,
    /// Liveness / latency / error counters for this loop.
    loop_metrics: Arc<LoopMetrics>,
    /// Notified to trigger an immediate sweep (e.g. on membership change).
    kick: Arc<Notify>,
}
120
121impl RebalancerLoop {
122 pub fn new(
123 cfg: RebalancerLoopConfig,
124 metrics: Arc<dyn LoadMetricsProvider>,
125 dispatcher: Arc<dyn MigrationDispatcher>,
126 gate: Arc<dyn ElectionGate>,
127 routing: Arc<RwLock<RoutingTable>>,
128 topology: Arc<RwLock<ClusterTopology>>,
129 ) -> Self {
130 Self {
131 cfg,
132 metrics,
133 dispatcher,
134 gate,
135 routing,
136 topology,
137 loop_metrics: LoopMetrics::new("rebalancer_loop"),
138 kick: Arc::new(Notify::new()),
139 }
140 }
141
142 pub fn kick_handle(&self) -> Arc<Notify> {
146 Arc::clone(&self.kick)
147 }
148
149 pub fn loop_metrics(&self) -> Arc<LoopMetrics> {
152 Arc::clone(&self.loop_metrics)
153 }
154
155 pub async fn run(self: Arc<Self>, mut shutdown: watch::Receiver<bool>) {
157 let mut tick = interval(self.cfg.interval);
158 tick.set_missed_tick_behavior(MissedTickBehavior::Delay);
159 tick.tick().await;
163 self.loop_metrics.set_up(true);
164 loop {
165 tokio::select! {
166 biased;
167 changed = shutdown.changed() => {
168 if changed.is_ok() && *shutdown.borrow() {
169 break;
170 }
171 }
172 _ = tick.tick() => {
173 self.sweep_once().await;
174 }
175 _ = self.kick.notified() => {
176 debug!("rebalancer: membership-change kick received");
177 self.sweep_once().await;
178 }
179 }
180 }
181 self.loop_metrics.set_up(false);
182 debug!("rebalancer loop shutting down");
183 }
184
185 pub async fn sweep_once(&self) {
188 let started = Instant::now();
189 if self.gate.any_group_electing().await {
190 debug!("rebalancer: raft election in progress, skipping tick");
191 self.loop_metrics.observe(started.elapsed());
192 return;
193 }
194 let metrics = match self.metrics.snapshot().await {
195 Ok(m) => m,
196 Err(e) => {
197 warn!(error = %e, "rebalancer: failed to collect metrics");
198 self.loop_metrics.record_error("metrics_snapshot");
199 self.loop_metrics.observe(started.elapsed());
200 return;
201 }
202 };
203 if let Some(hot) = metrics
204 .iter()
205 .find(|m| m.cpu_utilization > self.cfg.backpressure_cpu_threshold)
206 {
207 info!(
208 node_id = hot.node_id,
209 cpu = format!("{:.0}%", hot.cpu_utilization * 100.0),
210 threshold = format!("{:.0}%", self.cfg.backpressure_cpu_threshold * 100.0),
211 "rebalancer: back-pressure — cluster under load, skipping sweep"
212 );
213 self.loop_metrics.record_error("backpressure");
214 self.loop_metrics.observe(started.elapsed());
215 return;
216 }
217 let plan = {
218 let routing = self.routing.read().unwrap_or_else(|p| p.into_inner());
219 let topo = self.topology.read().unwrap_or_else(|p| p.into_inner());
220 compute_load_based_plan(&metrics, &routing, &topo, &self.cfg.plan)
221 };
222 if plan.is_empty() {
223 debug!("rebalancer: no moves needed this tick");
224 self.loop_metrics.observe(started.elapsed());
225 return;
226 }
227 info!(
228 move_count = plan.len(),
229 "rebalancer: dispatching planned moves"
230 );
231 let dispatcher = Arc::clone(&self.dispatcher);
232 let err_metrics = Arc::clone(&self.loop_metrics);
233 for mv in plan {
240 let dispatcher = Arc::clone(&dispatcher);
241 let err_metrics = Arc::clone(&err_metrics);
242 tokio::spawn(async move {
243 if let Err(e) = dispatcher.dispatch(mv).await {
244 warn!(error = %e, "rebalancer: dispatch failed");
245 err_metrics.record_error("dispatch");
246 }
247 });
248 }
249 self.loop_metrics.observe(started.elapsed());
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rebalancer::metrics::LoadMetrics;
    use crate::topology::{NodeInfo, NodeState};
    use std::net::SocketAddr;
    use std::sync::Mutex;

    /// `LoadMetricsProvider` that always returns the same canned snapshot.
    struct StaticMetrics(Vec<LoadMetrics>);

    #[async_trait]
    impl LoadMetricsProvider for StaticMetrics {
        async fn snapshot(&self) -> Result<Vec<LoadMetrics>> {
            Ok(self.0.clone())
        }
    }

    /// Dispatcher that records every move so tests can inspect what fired.
    struct RecordingDispatcher {
        calls: Mutex<Vec<PlannedMove>>,
    }

    impl RecordingDispatcher {
        fn new() -> Arc<Self> {
            Arc::new(Self {
                calls: Mutex::new(Vec::new()),
            })
        }
        /// Drains and returns all moves recorded so far.
        fn take(&self) -> Vec<PlannedMove> {
            let mut g = self.calls.lock().unwrap();
            let out = g.clone();
            g.clear();
            out
        }
    }

    #[async_trait]
    impl MigrationDispatcher for RecordingDispatcher {
        async fn dispatch(&self, mv: PlannedMove) -> Result<()> {
            self.calls.lock().unwrap().push(mv);
            Ok(())
        }
    }

    /// Gate with a fixed answer; `BlockingGate(true)` simulates an election.
    struct BlockingGate(bool);

    #[async_trait]
    impl ElectionGate for BlockingGate {
        async fn any_group_electing(&self) -> bool {
            self.0
        }
    }

    /// Builds an all-active topology with the given node ids on distinct ports.
    fn topo(nodes: &[u64]) -> Arc<RwLock<ClusterTopology>> {
        let mut t = ClusterTopology::new();
        for (i, id) in nodes.iter().enumerate() {
            let a: SocketAddr = format!("127.0.0.1:{}", 9000 + i).parse().unwrap();
            t.add_node(NodeInfo::new(*id, a, NodeState::Active));
        }
        Arc::new(RwLock::new(t))
    }

    /// Routing table where every group is led by `node` — maximally skewed.
    fn routing_hot_on(node: u64) -> Arc<RwLock<RoutingTable>> {
        let mut r = RoutingTable::uniform(6, &[1, 2, 3], 1);
        for gid in 0..6 {
            r.set_leader(gid, node);
        }
        Arc::new(RwLock::new(r))
    }

    /// `LoadMetrics` shorthand: bytes given in MiB, CPU left at 0 so the
    /// back-pressure check never triggers unless a test sets it explicitly.
    fn lm(id: u64, v: u32, bytes_mib: u64, w: f64, r: f64) -> LoadMetrics {
        LoadMetrics {
            node_id: id,
            vshards_led: v,
            bytes_stored: bytes_mib * 1_048_576,
            writes_per_sec: w,
            reads_per_sec: r,
            qps_recent: 0.0,
            p95_latency_us: 0,
            cpu_utilization: 0.0,
        }
    }

    /// Loop over a 3-node cluster where node 1 is heavily overloaded, paired
    /// with a recording dispatcher for inspection.
    fn hot_cluster_loop(
        gate: Arc<dyn ElectionGate>,
    ) -> (Arc<RebalancerLoop>, Arc<RecordingDispatcher>) {
        let metrics: Arc<dyn LoadMetricsProvider> = Arc::new(StaticMetrics(vec![
            lm(1, 500, 5000, 200.0, 200.0),
            lm(2, 5, 5, 5.0, 5.0),
            lm(3, 5, 5, 5.0, 5.0),
        ]));
        let dispatcher = RecordingDispatcher::new();
        let disp_dyn: Arc<dyn MigrationDispatcher> = dispatcher.clone();
        let rloop = Arc::new(RebalancerLoop::new(
            RebalancerLoopConfig {
                interval: Duration::from_millis(50),
                ..Default::default()
            },
            metrics,
            disp_dyn,
            gate,
            routing_hot_on(1),
            topo(&[1, 2, 3]),
        ));
        (rloop, dispatcher)
    }

    #[tokio::test]
    async fn sweep_dispatches_moves_when_imbalanced() {
        let (rloop, dispatcher) = hot_cluster_loop(Arc::new(AlwaysReadyGate));
        rloop.sweep_once().await;
        // Dispatches run on spawned tasks; yield so they get a chance to run.
        for _ in 0..16 {
            tokio::task::yield_now().await;
        }
        let calls = dispatcher.take();
        assert!(!calls.is_empty());
        // Every move should drain load off the hot node.
        for c in &calls {
            assert_eq!(c.source_node, 1);
        }
    }

    #[tokio::test]
    async fn sweep_skipped_during_election() {
        let (rloop, dispatcher) = hot_cluster_loop(Arc::new(BlockingGate(true)));
        rloop.sweep_once().await;
        for _ in 0..8 {
            tokio::task::yield_now().await;
        }
        // The gate reports an election, so nothing may be dispatched.
        assert!(dispatcher.take().is_empty());
    }

    #[tokio::test]
    async fn sweep_noop_on_balanced_cluster() {
        // All nodes report identical load: the planner should emit no moves.
        let metrics: Arc<dyn LoadMetricsProvider> = Arc::new(StaticMetrics(vec![
            lm(1, 50, 500, 100.0, 100.0),
            lm(2, 50, 500, 100.0, 100.0),
            lm(3, 50, 500, 100.0, 100.0),
        ]));
        let dispatcher = RecordingDispatcher::new();
        let rloop = Arc::new(RebalancerLoop::new(
            RebalancerLoopConfig::default(),
            metrics,
            dispatcher.clone() as Arc<dyn MigrationDispatcher>,
            Arc::new(AlwaysReadyGate),
            routing_hot_on(1),
            topo(&[1, 2, 3]),
        ));
        rloop.sweep_once().await;
        for _ in 0..8 {
            tokio::task::yield_now().await;
        }
        assert!(dispatcher.take().is_empty());
    }

    #[tokio::test(start_paused = true)]
    async fn run_loop_fires_sweeps_and_shuts_down() {
        let (rloop, dispatcher) = hot_cluster_loop(Arc::new(AlwaysReadyGate));
        let (tx, rx) = watch::channel(false);
        let handle = tokio::spawn({
            let d = Arc::clone(&rloop);
            async move { d.run(rx).await }
        });
        // Each 80ms advance crosses at least one 50ms interval tick; the
        // inner yields let spawned dispatch tasks run under the paused clock.
        for _ in 0..4 {
            tokio::time::advance(Duration::from_millis(80)).await;
            for _ in 0..16 {
                tokio::task::yield_now().await;
            }
        }
        assert!(!dispatcher.take().is_empty());

        let _ = tx.send(true);
        let _ = tokio::time::timeout(Duration::from_millis(500), handle).await;
    }

    #[tokio::test]
    async fn sweep_skipped_under_cpu_backpressure() {
        // Node 1 is both imbalanced AND above the default 0.80 CPU threshold;
        // the back-pressure check must win and suppress the sweep entirely.
        let metrics: Arc<dyn LoadMetricsProvider> = Arc::new(StaticMetrics(vec![
            LoadMetrics {
                node_id: 1,
                vshards_led: 500,
                bytes_stored: 5000 * 1_048_576,
                writes_per_sec: 200.0,
                reads_per_sec: 200.0,
                qps_recent: 0.0,
                p95_latency_us: 0,
                cpu_utilization: 0.95, },
            lm(2, 5, 5, 5.0, 5.0),
            lm(3, 5, 5, 5.0, 5.0),
        ]));
        let dispatcher = RecordingDispatcher::new();
        let rloop = Arc::new(RebalancerLoop::new(
            RebalancerLoopConfig {
                interval: Duration::from_millis(50),
                ..Default::default()
            },
            metrics,
            dispatcher.clone() as Arc<dyn MigrationDispatcher>,
            Arc::new(AlwaysReadyGate),
            routing_hot_on(1),
            topo(&[1, 2, 3]),
        ));
        rloop.sweep_once().await;
        for _ in 0..8 {
            tokio::task::yield_now().await;
        }
        assert!(
            dispatcher.take().is_empty(),
            "dispatcher should not fire when cluster is under CPU backpressure"
        );
    }
}