nodedb_cluster/rebalancer/driver.rs

//! Rebalancer driver loop.
//!
//! [`RebalancerLoop`] is the active half of the load-based rebalancer.
//! Every `interval`, it walks this sequence:
//!
//! 1. Ask the injected [`ElectionGate`] whether any raft group is
//!    currently mid-election. If so, skip this tick entirely:
//!    moves dispatched during an election race with the new
//!    leader's log and are almost guaranteed to be wasted work.
//! 2. Ask the injected [`LoadMetricsProvider`] for a snapshot of
//!    every node's current load metrics. If any node is above the
//!    configured CPU back-pressure threshold, skip the sweep rather
//!    than amplify the load.
//! 3. Call [`compute_load_based_plan`] against the live routing +
//!    topology with the configured plan config. If the plan is
//!    empty (cluster within threshold, or no cold candidates), do
//!    nothing.
//! 4. Dispatch each planned move through the injected
//!    [`MigrationDispatcher`], fire-and-forget. The dispatcher is
//!    where the bridge to the production `MigrationExecutor` lives;
//!    tests use a mock that records the calls.
//!
//! The loop holds no state of its own; the dispatcher tracks
//! in-flight work, and the breaker/scheduler state lives on the
//! underlying subsystems. This keeps the driver trivially
//! restartable: crash mid-tick, respawn, resume.
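//!
//! # Example
//!
//! A minimal wiring sketch: the metrics provider, dispatcher,
//! routing, and topology handles are assumed to come from the
//! surrounding cluster bootstrap, so the block is marked `ignore`.
//!
//! ```ignore
//! use std::sync::Arc;
//! use tokio::sync::watch;
//!
//! let rloop = Arc::new(RebalancerLoop::new(
//!     RebalancerLoopConfig::default(),
//!     metrics,    // Arc<dyn LoadMetricsProvider>
//!     dispatcher, // Arc<dyn MigrationDispatcher>
//!     Arc::new(AlwaysReadyGate),
//!     routing,    // Arc<RwLock<RoutingTable>>
//!     topology,   // Arc<RwLock<ClusterTopology>>
//! ));
//! let kick = rloop.kick_handle();
//! let (shutdown_tx, shutdown_rx) = watch::channel(false);
//! tokio::spawn(Arc::clone(&rloop).run(shutdown_rx));
//!
//! // Later: force an immediate sweep, then stop the driver.
//! kick.notify_one();
//! let _ = shutdown_tx.send(true);
//! ```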

use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};

use async_trait::async_trait;
use tokio::sync::{Notify, watch};
use tokio::time::{MissedTickBehavior, interval};
use tracing::{debug, info, warn};

use crate::error::Result;
use crate::loop_metrics::LoopMetrics;
use crate::rebalance::PlannedMove;
use crate::routing::RoutingTable;
use crate::topology::ClusterTopology;

use super::metrics::LoadMetricsProvider;
use super::plan::{RebalancerPlanConfig, compute_load_based_plan};

/// Injection seam: tells the driver whether it's safe to dispatch
/// moves. Production wraps a `MultiRaft` status probe; tests return
/// a constant boolean.
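///
/// A production-style sketch (the `group_ids`/`group_has_leader`
/// probes below are hypothetical stand-ins for the real `MultiRaft`
/// status API):
///
/// ```ignore
/// struct RaftStatusGate {
///     raft: Arc<MultiRaft>,
/// }
///
/// #[async_trait]
/// impl ElectionGate for RaftStatusGate {
///     async fn any_group_electing(&self) -> bool {
///         // "Electing" means at least one group lacks a stable leader.
///         self.raft
///             .group_ids()
///             .iter()
///             .any(|g| !self.raft.group_has_leader(*g))
///     }
/// }
/// ```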
#[async_trait]
pub trait ElectionGate: Send + Sync {
    /// Return `true` if **any** raft group is currently holding an
    /// election (no stable leader). The driver skips its tick when
    /// this is `true`.
    async fn any_group_electing(&self) -> bool;
}

/// Permissive gate that never blocks the driver. Useful in tests
/// and in single-node clusters where elections are instantaneous.
pub struct AlwaysReadyGate;

#[async_trait]
impl ElectionGate for AlwaysReadyGate {
    async fn any_group_electing(&self) -> bool {
        false
    }
}

/// Injection seam: executes a single planned move. Production
/// wraps `MigrationExecutor::execute` and reports success/failure
/// via logging + the tracker; tests record the move.
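///
/// A production-style sketch (the `MigrationExecutor` field and the
/// exact `execute` signature are assumptions based on the
/// description above):
///
/// ```ignore
/// struct ExecutorDispatcher {
///     executor: Arc<MigrationExecutor>,
/// }
///
/// #[async_trait]
/// impl MigrationDispatcher for ExecutorDispatcher {
///     async fn dispatch(&self, mv: PlannedMove) -> Result<()> {
///         // The executor's tracker owns the in-flight bookkeeping,
///         // so this returns once the move is accepted.
///         self.executor.execute(mv).await
///     }
/// }
/// ```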
#[async_trait]
pub trait MigrationDispatcher: Send + Sync {
    async fn dispatch(&self, mv: PlannedMove) -> Result<()>;
}

/// Configuration for [`RebalancerLoop`].
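///
/// Override individual fields and default the rest, as the tests do
/// (a usage sketch):
///
/// ```ignore
/// let cfg = RebalancerLoopConfig {
///     interval: Duration::from_secs(10),
///     ..Default::default()
/// };
/// ```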
#[derive(Debug, Clone)]
pub struct RebalancerLoopConfig {
    /// Period between rebalance sweeps. Defaults to 30 s.
    pub interval: Duration,
    /// Plan computation config propagated to
    /// [`compute_load_based_plan`] on every tick.
    pub plan: RebalancerPlanConfig,
    /// CPU utilization threshold (0.0–1.0) above which the
    /// rebalancer pauses to avoid amplifying load. If **any** node in
    /// the metrics snapshot exceeds this value, the sweep is skipped
    /// and an info-level back-pressure event is logged. Default
    /// 0.80 (80%).
    pub backpressure_cpu_threshold: f64,
}

impl Default for RebalancerLoopConfig {
    fn default() -> Self {
        Self {
            interval: Duration::from_secs(30),
            plan: RebalancerPlanConfig::default(),
            backpressure_cpu_threshold: 0.80,
        }
    }
}

/// The driver itself.
pub struct RebalancerLoop {
    cfg: RebalancerLoopConfig,
    metrics: Arc<dyn LoadMetricsProvider>,
    dispatcher: Arc<dyn MigrationDispatcher>,
    gate: Arc<dyn ElectionGate>,
    routing: Arc<RwLock<RoutingTable>>,
    topology: Arc<RwLock<ClusterTopology>>,
    /// Standardized loop observations (iterations, last-iteration
    /// duration, errors by kind, up flag). Register this handle with
    /// the cluster's [`LoopMetricsRegistry`](crate::LoopMetricsRegistry)
    /// so scrapes include its samples.
    loop_metrics: Arc<LoopMetrics>,
    /// Membership-change notification. When any caller (a SWIM
    /// subscriber, a manual admin trigger, etc.) calls
    /// [`notify_one`](Notify::notify_one) on this handle, the run
    /// loop wakes up immediately and runs an extra sweep instead of
    /// waiting for the next `interval` tick.
    kick: Arc<Notify>,
}

impl RebalancerLoop {
    pub fn new(
        cfg: RebalancerLoopConfig,
        metrics: Arc<dyn LoadMetricsProvider>,
        dispatcher: Arc<dyn MigrationDispatcher>,
        gate: Arc<dyn ElectionGate>,
        routing: Arc<RwLock<RoutingTable>>,
        topology: Arc<RwLock<ClusterTopology>>,
    ) -> Self {
        Self {
            cfg,
            metrics,
            dispatcher,
            gate,
            routing,
            topology,
            loop_metrics: LoopMetrics::new("rebalancer_loop"),
            kick: Arc::new(Notify::new()),
        }
    }

    /// Return a handle that callers can use to trigger an immediate
    /// sweep. Cloning the `Arc<Notify>` is cheap; every clone
    /// shares the same waker.
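    ///
    /// A subscriber sketch (the membership-event channel is a
    /// hypothetical stand-in for the real SWIM subscription):
    ///
    /// ```ignore
    /// let kick = rloop.kick_handle();
    /// tokio::spawn(async move {
    ///     while membership_events.recv().await.is_some() {
    ///         kick.notify_one(); // wake the driver for an extra sweep
    ///     }
    /// });
    /// ```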
    pub fn kick_handle(&self) -> Arc<Notify> {
        Arc::clone(&self.kick)
    }

    /// Shared handle to this loop's standardized metrics. Lifecycle
    /// owners register this with the cluster registry on spawn.
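    ///
    /// A registration sketch (the `register` method name on
    /// [`LoopMetricsRegistry`](crate::LoopMetricsRegistry) is an
    /// assumption):
    ///
    /// ```ignore
    /// registry.register(rloop.loop_metrics());
    /// ```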
    pub fn loop_metrics(&self) -> Arc<LoopMetrics> {
        Arc::clone(&self.loop_metrics)
    }

    /// Run the driver until `shutdown` flips to `true`.
    pub async fn run(self: Arc<Self>, mut shutdown: watch::Receiver<bool>) {
        let mut tick = interval(self.cfg.interval);
        tick.set_missed_tick_behavior(MissedTickBehavior::Delay);
        // Consume the immediate first tick so the first sweep fires
        // a full interval after start. Prevents start-up stampedes
        // when many nodes restart together.
        tick.tick().await;
        self.loop_metrics.set_up(true);
        loop {
            tokio::select! {
                biased;
                changed = shutdown.changed() => {
                    // Treat a dropped sender as shutdown: a closed
                    // channel makes `changed()` resolve immediately on
                    // every iteration, which would spin this biased
                    // select forever.
                    if changed.is_err() || *shutdown.borrow() {
                        break;
                    }
                }
                _ = tick.tick() => {
                    self.sweep_once().await;
                }
                _ = self.kick.notified() => {
                    debug!("rebalancer: membership-change kick received");
                    self.sweep_once().await;
                }
            }
        }
        self.loop_metrics.set_up(false);
        debug!("rebalancer loop shutting down");
    }

    /// Run a single sweep. Exposed for tests that drive the loop
    /// manually rather than through `run`.
    pub async fn sweep_once(&self) {
        let started = Instant::now();
        if self.gate.any_group_electing().await {
            debug!("rebalancer: raft election in progress, skipping tick");
            self.loop_metrics.observe(started.elapsed());
            return;
        }
        let metrics = match self.metrics.snapshot().await {
            Ok(m) => m,
            Err(e) => {
                warn!(error = %e, "rebalancer: failed to collect metrics");
                self.loop_metrics.record_error("metrics_snapshot");
                self.loop_metrics.observe(started.elapsed());
                return;
            }
        };
        if let Some(hot) = metrics
            .iter()
            .find(|m| m.cpu_utilization > self.cfg.backpressure_cpu_threshold)
        {
            info!(
                node_id = hot.node_id,
                cpu = format!("{:.0}%", hot.cpu_utilization * 100.0),
                threshold = format!("{:.0}%", self.cfg.backpressure_cpu_threshold * 100.0),
                "rebalancer: back-pressure, cluster under load, skipping sweep"
            );
            self.loop_metrics.record_error("backpressure");
            self.loop_metrics.observe(started.elapsed());
            return;
        }
        let plan = {
            let routing = self.routing.read().unwrap_or_else(|p| p.into_inner());
            let topo = self.topology.read().unwrap_or_else(|p| p.into_inner());
            compute_load_based_plan(&metrics, &routing, &topo, &self.cfg.plan)
        };
        if plan.is_empty() {
            debug!("rebalancer: no moves needed this tick");
            self.loop_metrics.observe(started.elapsed());
            return;
        }
        info!(
            move_count = plan.len(),
            "rebalancer: dispatching planned moves"
        );
        let dispatcher = Arc::clone(&self.dispatcher);
        let err_metrics = Arc::clone(&self.loop_metrics);
        for mv in plan {
            let dispatcher = Arc::clone(&dispatcher);
            let err_metrics = Arc::clone(&err_metrics);
            tokio::spawn(async move {
                if let Err(e) = dispatcher.dispatch(mv).await {
                    warn!(error = %e, "rebalancer: dispatch failed");
                    err_metrics.record_error("dispatch");
                }
            });
        }
        self.loop_metrics.observe(started.elapsed());
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::rebalancer::metrics::LoadMetrics;
    use crate::topology::{NodeInfo, NodeState};
    use std::net::SocketAddr;
    use std::sync::Mutex;

    struct StaticMetrics(Vec<LoadMetrics>);

    #[async_trait]
    impl LoadMetricsProvider for StaticMetrics {
        async fn snapshot(&self) -> Result<Vec<LoadMetrics>> {
            Ok(self.0.clone())
        }
    }

    struct RecordingDispatcher {
        calls: Mutex<Vec<PlannedMove>>,
    }

    impl RecordingDispatcher {
        fn new() -> Arc<Self> {
            Arc::new(Self {
                calls: Mutex::new(Vec::new()),
            })
        }
        fn take(&self) -> Vec<PlannedMove> {
            let mut g = self.calls.lock().unwrap();
            let out = g.clone();
            g.clear();
            out
        }
    }

    #[async_trait]
    impl MigrationDispatcher for RecordingDispatcher {
        async fn dispatch(&self, mv: PlannedMove) -> Result<()> {
            self.calls.lock().unwrap().push(mv);
            Ok(())
        }
    }

    struct BlockingGate(bool);

    #[async_trait]
    impl ElectionGate for BlockingGate {
        async fn any_group_electing(&self) -> bool {
            self.0
        }
    }

    fn topo(nodes: &[u64]) -> Arc<RwLock<ClusterTopology>> {
        let mut t = ClusterTopology::new();
        for (i, id) in nodes.iter().enumerate() {
            let a: SocketAddr = format!("127.0.0.1:{}", 9000 + i).parse().unwrap();
            t.add_node(NodeInfo::new(*id, a, NodeState::Active));
        }
        Arc::new(RwLock::new(t))
    }

    fn routing_hot_on(node: u64) -> Arc<RwLock<RoutingTable>> {
        let mut r = RoutingTable::uniform(6, &[1, 2, 3], 1);
        for gid in 0..6 {
            r.set_leader(gid, node);
        }
        Arc::new(RwLock::new(r))
    }

    fn lm(id: u64, v: u32, bytes_mib: u64, w: f64, r: f64) -> LoadMetrics {
        LoadMetrics {
            node_id: id,
            vshards_led: v,
            bytes_stored: bytes_mib * 1_048_576,
            writes_per_sec: w,
            reads_per_sec: r,
            qps_recent: 0.0,
            p95_latency_us: 0,
            cpu_utilization: 0.0,
        }
    }

    fn hot_cluster_loop(
        gate: Arc<dyn ElectionGate>,
    ) -> (Arc<RebalancerLoop>, Arc<RecordingDispatcher>) {
        let metrics: Arc<dyn LoadMetricsProvider> = Arc::new(StaticMetrics(vec![
            lm(1, 500, 5000, 200.0, 200.0),
            lm(2, 5, 5, 5.0, 5.0),
            lm(3, 5, 5, 5.0, 5.0),
        ]));
        let dispatcher = RecordingDispatcher::new();
        let disp_dyn: Arc<dyn MigrationDispatcher> = dispatcher.clone();
        let rloop = Arc::new(RebalancerLoop::new(
            RebalancerLoopConfig {
                interval: Duration::from_millis(50),
                ..Default::default()
            },
            metrics,
            disp_dyn,
            gate,
            routing_hot_on(1),
            topo(&[1, 2, 3]),
        ));
        (rloop, dispatcher)
    }

    #[tokio::test]
    async fn sweep_dispatches_moves_when_imbalanced() {
        let (rloop, dispatcher) = hot_cluster_loop(Arc::new(AlwaysReadyGate));
        rloop.sweep_once().await;
        for _ in 0..16 {
            tokio::task::yield_now().await;
        }
        let calls = dispatcher.take();
        assert!(!calls.is_empty());
        for c in &calls {
            assert_eq!(c.source_node, 1);
        }
    }

    #[tokio::test]
    async fn sweep_skipped_during_election() {
        let (rloop, dispatcher) = hot_cluster_loop(Arc::new(BlockingGate(true)));
        rloop.sweep_once().await;
        for _ in 0..8 {
            tokio::task::yield_now().await;
        }
        assert!(dispatcher.take().is_empty());
    }

    #[tokio::test]
    async fn sweep_noop_on_balanced_cluster() {
        let metrics: Arc<dyn LoadMetricsProvider> = Arc::new(StaticMetrics(vec![
            lm(1, 50, 500, 100.0, 100.0),
            lm(2, 50, 500, 100.0, 100.0),
            lm(3, 50, 500, 100.0, 100.0),
        ]));
        let dispatcher = RecordingDispatcher::new();
        let rloop = Arc::new(RebalancerLoop::new(
            RebalancerLoopConfig::default(),
            metrics,
            dispatcher.clone() as Arc<dyn MigrationDispatcher>,
            Arc::new(AlwaysReadyGate),
            routing_hot_on(1),
            topo(&[1, 2, 3]),
        ));
        rloop.sweep_once().await;
        for _ in 0..8 {
            tokio::task::yield_now().await;
        }
        assert!(dispatcher.take().is_empty());
    }

    #[tokio::test(start_paused = true)]
    async fn run_loop_fires_sweeps_and_shuts_down() {
        let (rloop, dispatcher) = hot_cluster_loop(Arc::new(AlwaysReadyGate));
        let (tx, rx) = watch::channel(false);
        let handle = tokio::spawn({
            let d = Arc::clone(&rloop);
            async move { d.run(rx).await }
        });
        // First tick consumed immediately by run(); advance past a
        // couple of real intervals with interleaved yields so the
        // run-loop's select + spawned dispatch tasks all get to poll.
        for _ in 0..4 {
            tokio::time::advance(Duration::from_millis(80)).await;
            for _ in 0..16 {
                tokio::task::yield_now().await;
            }
        }
        assert!(!dispatcher.take().is_empty());

        let _ = tx.send(true);
        let _ = tokio::time::timeout(Duration::from_millis(500), handle).await;
    }

    #[tokio::test]
    async fn sweep_skipped_under_cpu_backpressure() {
        let metrics: Arc<dyn LoadMetricsProvider> = Arc::new(StaticMetrics(vec![
            LoadMetrics {
                node_id: 1,
                vshards_led: 500,
                bytes_stored: 5000 * 1_048_576,
                writes_per_sec: 200.0,
                reads_per_sec: 200.0,
                qps_recent: 0.0,
                p95_latency_us: 0,
                cpu_utilization: 0.95, // above 80% threshold
            },
            lm(2, 5, 5, 5.0, 5.0),
            lm(3, 5, 5, 5.0, 5.0),
        ]));
        let dispatcher = RecordingDispatcher::new();
        let rloop = Arc::new(RebalancerLoop::new(
            RebalancerLoopConfig {
                interval: Duration::from_millis(50),
                ..Default::default()
            },
            metrics,
            dispatcher.clone() as Arc<dyn MigrationDispatcher>,
            Arc::new(AlwaysReadyGate),
            routing_hot_on(1),
            topo(&[1, 2, 3]),
        ));
        rloop.sweep_once().await;
        for _ in 0..8 {
            tokio::task::yield_now().await;
        }
        assert!(
            dispatcher.take().is_empty(),
            "dispatcher should not fire when cluster is under CPU backpressure"
        );
    }
}