// nodedb_cluster/rebalancer/driver.rs
use std::sync::{Arc, RwLock};
27use std::time::{Duration, Instant};
28
29use async_trait::async_trait;
30use tokio::sync::{Notify, watch};
31use tokio::time::{MissedTickBehavior, interval};
32use tracing::{debug, info, warn};
33
34use crate::error::Result;
35use crate::loop_metrics::LoopMetrics;
36use crate::rebalance::PlannedMove;
37use crate::routing::RoutingTable;
38use crate::topology::ClusterTopology;
39
40use super::metrics::LoadMetricsProvider;
41use super::plan::{RebalancerPlanConfig, compute_load_based_plan};
42
/// Gate consulted before every rebalancing sweep. Implementations report
/// whether any Raft group is currently holding a leader election, in which
/// case the sweep for that tick is skipped (see `RebalancerLoop::sweep_once`).
#[async_trait]
pub trait ElectionGate: Send + Sync {
    /// Returns `true` if a sweep should be deferred because an election
    /// is in progress.
    async fn any_group_electing(&self) -> bool;
}
53
/// [`ElectionGate`] that never defers: always reports that no election is
/// in progress, so every sweep proceeds. Used in tests and in deployments
/// without an election signal wired up.
pub struct AlwaysReadyGate;

#[async_trait]
impl ElectionGate for AlwaysReadyGate {
    async fn any_group_electing(&self) -> bool {
        false
    }
}
64
/// Sink for planned shard moves. `RebalancerLoop` hands each [`PlannedMove`]
/// produced by the planner to `dispatch`; failures are logged and counted
/// by the loop, not retried here.
#[async_trait]
pub trait MigrationDispatcher: Send + Sync {
    /// Executes (or enqueues) a single planned move.
    async fn dispatch(&self, mv: PlannedMove) -> Result<()>;
}
72
/// Tuning knobs for the periodic rebalancer loop.
#[derive(Debug, Clone)]
pub struct RebalancerLoopConfig {
    /// How often a sweep is attempted.
    pub interval: Duration,
    /// Parameters forwarded to the load-based planner
    /// (`compute_load_based_plan`).
    pub plan: RebalancerPlanConfig,
    /// Back-pressure cutoff as a fraction (0.0–1.0): if any node reports
    /// CPU utilization above this, the whole sweep is skipped.
    pub backpressure_cpu_threshold: f64,
}
87
88impl Default for RebalancerLoopConfig {
89 fn default() -> Self {
90 Self {
91 interval: Duration::from_secs(30),
92 plan: RebalancerPlanConfig::default(),
93 backpressure_cpu_threshold: 0.80,
94 }
95 }
96}
97
/// Periodic background loop that observes per-node load, plans shard moves,
/// and dispatches them. Driven by [`RebalancerLoop::run`]; an immediate
/// out-of-band sweep can be requested through [`RebalancerLoop::kick_handle`].
pub struct RebalancerLoop {
    cfg: RebalancerLoopConfig,
    // Source of per-node load snapshots consumed by the planner.
    metrics: Arc<dyn LoadMetricsProvider>,
    // Executes the planned moves; each dispatch runs in its own task.
    dispatcher: Arc<dyn MigrationDispatcher>,
    // Vetoes sweeps while Raft elections are in flight.
    gate: Arc<dyn ElectionGate>,
    routing: Arc<RwLock<RoutingTable>>,
    topology: Arc<RwLock<ClusterTopology>>,
    // Health/latency counters for observability of this loop itself.
    loop_metrics: Arc<LoopMetrics>,
    // Notified by external subsystems (e.g. on membership change) to
    // trigger a sweep without waiting for the next interval tick.
    kick: Arc<Notify>,
}
118
119impl RebalancerLoop {
120 pub fn new(
121 cfg: RebalancerLoopConfig,
122 metrics: Arc<dyn LoadMetricsProvider>,
123 dispatcher: Arc<dyn MigrationDispatcher>,
124 gate: Arc<dyn ElectionGate>,
125 routing: Arc<RwLock<RoutingTable>>,
126 topology: Arc<RwLock<ClusterTopology>>,
127 ) -> Self {
128 Self {
129 cfg,
130 metrics,
131 dispatcher,
132 gate,
133 routing,
134 topology,
135 loop_metrics: LoopMetrics::new("rebalancer_loop"),
136 kick: Arc::new(Notify::new()),
137 }
138 }
139
140 pub fn kick_handle(&self) -> Arc<Notify> {
144 Arc::clone(&self.kick)
145 }
146
147 pub fn loop_metrics(&self) -> Arc<LoopMetrics> {
150 Arc::clone(&self.loop_metrics)
151 }
152
153 pub async fn run(self: Arc<Self>, mut shutdown: watch::Receiver<bool>) {
155 let mut tick = interval(self.cfg.interval);
156 tick.set_missed_tick_behavior(MissedTickBehavior::Delay);
157 tick.tick().await;
161 self.loop_metrics.set_up(true);
162 loop {
163 tokio::select! {
164 biased;
165 changed = shutdown.changed() => {
166 if changed.is_ok() && *shutdown.borrow() {
167 break;
168 }
169 }
170 _ = tick.tick() => {
171 self.sweep_once().await;
172 }
173 _ = self.kick.notified() => {
174 debug!("rebalancer: membership-change kick received");
175 self.sweep_once().await;
176 }
177 }
178 }
179 self.loop_metrics.set_up(false);
180 debug!("rebalancer loop shutting down");
181 }
182
183 pub async fn sweep_once(&self) {
186 let started = Instant::now();
187 if self.gate.any_group_electing().await {
188 debug!("rebalancer: raft election in progress, skipping tick");
189 self.loop_metrics.observe(started.elapsed());
190 return;
191 }
192 let metrics = match self.metrics.snapshot().await {
193 Ok(m) => m,
194 Err(e) => {
195 warn!(error = %e, "rebalancer: failed to collect metrics");
196 self.loop_metrics.record_error("metrics_snapshot");
197 self.loop_metrics.observe(started.elapsed());
198 return;
199 }
200 };
201 if let Some(hot) = metrics
202 .iter()
203 .find(|m| m.cpu_utilization > self.cfg.backpressure_cpu_threshold)
204 {
205 info!(
206 node_id = hot.node_id,
207 cpu = format!("{:.0}%", hot.cpu_utilization * 100.0),
208 threshold = format!("{:.0}%", self.cfg.backpressure_cpu_threshold * 100.0),
209 "rebalancer: back-pressure — cluster under load, skipping sweep"
210 );
211 self.loop_metrics.record_error("backpressure");
212 self.loop_metrics.observe(started.elapsed());
213 return;
214 }
215 let plan = {
216 let routing = self.routing.read().unwrap_or_else(|p| p.into_inner());
217 let topo = self.topology.read().unwrap_or_else(|p| p.into_inner());
218 compute_load_based_plan(&metrics, &routing, &topo, &self.cfg.plan)
219 };
220 if plan.is_empty() {
221 debug!("rebalancer: no moves needed this tick");
222 self.loop_metrics.observe(started.elapsed());
223 return;
224 }
225 info!(
226 move_count = plan.len(),
227 "rebalancer: dispatching planned moves"
228 );
229 let dispatcher = Arc::clone(&self.dispatcher);
230 let err_metrics = Arc::clone(&self.loop_metrics);
231 for mv in plan {
232 let dispatcher = Arc::clone(&dispatcher);
233 let err_metrics = Arc::clone(&err_metrics);
234 tokio::spawn(async move {
235 if let Err(e) = dispatcher.dispatch(mv).await {
236 warn!(error = %e, "rebalancer: dispatch failed");
237 err_metrics.record_error("dispatch");
238 }
239 });
240 }
241 self.loop_metrics.observe(started.elapsed());
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rebalancer::metrics::LoadMetrics;
    use crate::topology::{NodeInfo, NodeState};
    use std::net::SocketAddr;
    use std::sync::Mutex;

    /// Metrics provider that returns the same fixed snapshot on every call.
    struct StaticMetrics(Vec<LoadMetrics>);

    #[async_trait]
    impl LoadMetricsProvider for StaticMetrics {
        async fn snapshot(&self) -> Result<Vec<LoadMetrics>> {
            Ok(self.0.clone())
        }
    }

    /// Dispatcher that records every planned move instead of executing it,
    /// so tests can assert on what the sweep tried to do.
    struct RecordingDispatcher {
        calls: Mutex<Vec<PlannedMove>>,
    }

    impl RecordingDispatcher {
        fn new() -> Arc<Self> {
            Arc::new(Self {
                calls: Mutex::new(Vec::new()),
            })
        }
        /// Returns all recorded moves and clears the buffer.
        fn take(&self) -> Vec<PlannedMove> {
            let mut g = self.calls.lock().unwrap();
            let out = g.clone();
            g.clear();
            out
        }
    }

    #[async_trait]
    impl MigrationDispatcher for RecordingDispatcher {
        async fn dispatch(&self, mv: PlannedMove) -> Result<()> {
            self.calls.lock().unwrap().push(mv);
            Ok(())
        }
    }

    /// Gate with a fixed answer: `BlockingGate(true)` simulates an election
    /// permanently in progress.
    struct BlockingGate(bool);

    #[async_trait]
    impl ElectionGate for BlockingGate {
        async fn any_group_electing(&self) -> bool {
            self.0
        }
    }

    /// Topology with one Active node per id, each on a distinct local port.
    fn topo(nodes: &[u64]) -> Arc<RwLock<ClusterTopology>> {
        let mut t = ClusterTopology::new();
        for (i, id) in nodes.iter().enumerate() {
            let a: SocketAddr = format!("127.0.0.1:{}", 9000 + i).parse().unwrap();
            t.add_node(NodeInfo::new(*id, a, NodeState::Active));
        }
        Arc::new(RwLock::new(t))
    }

    /// Routing table where `node` leads all 6 groups — maximally skewed, so
    /// the planner should always propose moves away from it.
    fn routing_hot_on(node: u64) -> Arc<RwLock<RoutingTable>> {
        let mut r = RoutingTable::uniform(6, &[1, 2, 3], 1);
        for gid in 0..6 {
            r.set_leader(gid, node);
        }
        Arc::new(RwLock::new(r))
    }

    /// Shorthand for a LoadMetrics sample with zero CPU (so the
    /// back-pressure check never trips unless a test overrides it).
    fn lm(id: u64, v: u32, bytes_mib: u64, w: f64, r: f64) -> LoadMetrics {
        LoadMetrics {
            node_id: id,
            vshards_led: v,
            bytes_stored: bytes_mib * 1_048_576,
            writes_per_sec: w,
            reads_per_sec: r,
            qps_recent: 0.0,
            p95_latency_us: 0,
            cpu_utilization: 0.0,
        }
    }

    /// Assembles a 3-node loop where node 1 is heavily overloaded, wired to
    /// a recording dispatcher; the given gate controls election vetoes.
    fn hot_cluster_loop(
        gate: Arc<dyn ElectionGate>,
    ) -> (Arc<RebalancerLoop>, Arc<RecordingDispatcher>) {
        let metrics: Arc<dyn LoadMetricsProvider> = Arc::new(StaticMetrics(vec![
            lm(1, 500, 5000, 200.0, 200.0),
            lm(2, 5, 5, 5.0, 5.0),
            lm(3, 5, 5, 5.0, 5.0),
        ]));
        let dispatcher = RecordingDispatcher::new();
        let disp_dyn: Arc<dyn MigrationDispatcher> = dispatcher.clone();
        let rloop = Arc::new(RebalancerLoop::new(
            RebalancerLoopConfig {
                interval: Duration::from_millis(50),
                ..Default::default()
            },
            metrics,
            disp_dyn,
            gate,
            routing_hot_on(1),
            topo(&[1, 2, 3]),
        ));
        (rloop, dispatcher)
    }

    #[tokio::test]
    async fn sweep_dispatches_moves_when_imbalanced() {
        let (rloop, dispatcher) = hot_cluster_loop(Arc::new(AlwaysReadyGate));
        rloop.sweep_once().await;
        // Yield so the detached dispatch tasks get a chance to run.
        for _ in 0..16 {
            tokio::task::yield_now().await;
        }
        let calls = dispatcher.take();
        assert!(!calls.is_empty());
        // All moves must come off the overloaded node.
        for c in &calls {
            assert_eq!(c.source_node, 1);
        }
    }

    #[tokio::test]
    async fn sweep_skipped_during_election() {
        let (rloop, dispatcher) = hot_cluster_loop(Arc::new(BlockingGate(true)));
        rloop.sweep_once().await;
        for _ in 0..8 {
            tokio::task::yield_now().await;
        }
        assert!(dispatcher.take().is_empty());
    }

    #[tokio::test]
    async fn sweep_noop_on_balanced_cluster() {
        // Identical load on all three nodes: the planner should emit no moves.
        let metrics: Arc<dyn LoadMetricsProvider> = Arc::new(StaticMetrics(vec![
            lm(1, 50, 500, 100.0, 100.0),
            lm(2, 50, 500, 100.0, 100.0),
            lm(3, 50, 500, 100.0, 100.0),
        ]));
        let dispatcher = RecordingDispatcher::new();
        let rloop = Arc::new(RebalancerLoop::new(
            RebalancerLoopConfig::default(),
            metrics,
            dispatcher.clone() as Arc<dyn MigrationDispatcher>,
            Arc::new(AlwaysReadyGate),
            routing_hot_on(1),
            topo(&[1, 2, 3]),
        ));
        rloop.sweep_once().await;
        for _ in 0..8 {
            tokio::task::yield_now().await;
        }
        assert!(dispatcher.take().is_empty());
    }

    #[tokio::test(start_paused = true)]
    async fn run_loop_fires_sweeps_and_shuts_down() {
        let (rloop, dispatcher) = hot_cluster_loop(Arc::new(AlwaysReadyGate));
        let (tx, rx) = watch::channel(false);
        let handle = tokio::spawn({
            let d = Arc::clone(&rloop);
            async move { d.run(rx).await }
        });
        // With the paused clock, advance past several 50ms intervals (80ms
        // steps) and yield so ticks fire and dispatch tasks complete.
        for _ in 0..4 {
            tokio::time::advance(Duration::from_millis(80)).await;
            for _ in 0..16 {
                tokio::task::yield_now().await;
            }
        }
        assert!(!dispatcher.take().is_empty());

        let _ = tx.send(true);
        let _ = tokio::time::timeout(Duration::from_millis(500), handle).await;
    }

    #[tokio::test]
    async fn sweep_skipped_under_cpu_backpressure() {
        // Node 1 reports 95% CPU — above the default 80% threshold — so the
        // sweep must bail before planning, despite the heavy imbalance.
        let metrics: Arc<dyn LoadMetricsProvider> = Arc::new(StaticMetrics(vec![
            LoadMetrics {
                node_id: 1,
                vshards_led: 500,
                bytes_stored: 5000 * 1_048_576,
                writes_per_sec: 200.0,
                reads_per_sec: 200.0,
                qps_recent: 0.0,
                p95_latency_us: 0,
                cpu_utilization: 0.95,
            },
            lm(2, 5, 5, 5.0, 5.0),
            lm(3, 5, 5, 5.0, 5.0),
        ]));
        let dispatcher = RecordingDispatcher::new();
        let rloop = Arc::new(RebalancerLoop::new(
            RebalancerLoopConfig {
                interval: Duration::from_millis(50),
                ..Default::default()
            },
            metrics,
            dispatcher.clone() as Arc<dyn MigrationDispatcher>,
            Arc::new(AlwaysReadyGate),
            routing_hot_on(1),
            topo(&[1, 2, 3]),
        ));
        rloop.sweep_once().await;
        for _ in 0..8 {
            tokio::task::yield_now().await;
        }
        assert!(
            dispatcher.take().is_empty(),
            "dispatcher should not fire when cluster is under CPU backpressure"
        );
    }
}