// nodedb_cluster/rebalancer/driver.rs

// SPDX-License-Identifier: BUSL-1.1

//! Rebalancer driver loop.
//!
//! [`RebalancerLoop`] is the active half of the load-based rebalancer.
//! Every `interval` it walks this sequence:
//!
//! 1. Ask the injected `ElectionGate` whether any raft group is
//!    currently mid-election. If so, skip this tick entirely —
//!    moves during an election race with the new leader's log and
//!    are almost guaranteed to be wasted work.
//! 2. Ask the injected [`LoadMetricsProvider`] for a snapshot of
//!    every node's current load metrics.
//! 3. Call [`compute_load_based_plan`] against the live routing +
//!    topology with the configured plan config. If the plan is
//!    empty (cluster within threshold, or no cold candidates), do
//!    nothing.
//! 4. Dispatch each planned move through the injected
//!    [`MigrationDispatcher`], fire-and-forget. The dispatcher is
//!    where the bridge to the production `MigrationExecutor` lives
//!    — tests use a mock that records the calls.
//!
//! The loop holds no state of its own; the dispatcher tracks
//! in-flight work and the breaker/scheduler state is on the
//! underlying subsystems. This keeps the driver trivially
//! restartable: crash mid-tick, respawn, resume.

use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};

use async_trait::async_trait;
use tokio::sync::{Notify, watch};
use tokio::time::{MissedTickBehavior, interval};
use tracing::{debug, info, warn};

use crate::error::Result;
use crate::loop_metrics::LoopMetrics;
use crate::rebalance::PlannedMove;
use crate::routing::RoutingTable;
use crate::topology::ClusterTopology;

use super::metrics::LoadMetricsProvider;
use super::plan::{RebalancerPlanConfig, compute_load_based_plan};

/// Injection seam: tells the driver whether it's safe to dispatch
/// moves. Production wraps a `MultiRaft` status probe; tests return
/// a constant boolean.
#[async_trait]
pub trait ElectionGate: Send + Sync {
    /// Return `true` if **any** raft group is currently holding an
    /// election (no stable leader). The driver skips its tick when
    /// this is `true`.
    async fn any_group_electing(&self) -> bool;
}

56/// Permissive gate that never blocks the driver. Useful in tests
57/// and in single-node clusters where elections are instantaneous.
58pub struct AlwaysReadyGate;
59
60#[async_trait]
61impl ElectionGate for AlwaysReadyGate {
62    async fn any_group_electing(&self) -> bool {
63        false
64    }
65}
66
/// Injection seam: executes a single planned move. Production
/// wraps `MigrationExecutor::execute` and reports success/failure
/// via logging + the tracker; tests record the move.
#[async_trait]
pub trait MigrationDispatcher: Send + Sync {
    /// Execute (or enqueue) one planned vshard move. On `Err`, the
    /// driver logs a warning and counts a `dispatch` loop-metric
    /// error; it does not retry the move itself.
    async fn dispatch(&self, mv: PlannedMove) -> Result<()>;
}

/// Configuration for [`RebalancerLoop`].
#[derive(Debug, Clone)]
pub struct RebalancerLoopConfig {
    /// Period between rebalance sweeps. Defaults to 30 s.
    pub interval: Duration,
    /// Plan computation config propagated to
    /// [`compute_load_based_plan`] on every tick.
    pub plan: RebalancerPlanConfig,
    /// CPU utilization threshold (0.0–1.0) above which the
    /// rebalancer pauses to avoid amplifying load. If ANY node in
    /// the metrics snapshot exceeds this value, the sweep is skipped
    /// and a STATUS event is logged. Default 0.80 (80%).
    pub backpressure_cpu_threshold: f64,
}

90impl Default for RebalancerLoopConfig {
91    fn default() -> Self {
92        Self {
93            interval: Duration::from_secs(30),
94            plan: RebalancerPlanConfig::default(),
95            backpressure_cpu_threshold: 0.80,
96        }
97    }
98}
99
/// The driver itself.
pub struct RebalancerLoop {
    /// Sweep interval, plan config, and CPU back-pressure threshold.
    cfg: RebalancerLoopConfig,
    /// Source of per-node load snapshots, queried once per sweep.
    metrics: Arc<dyn LoadMetricsProvider>,
    /// Executes planned moves; each dispatch is spawned
    /// fire-and-forget from the sweep.
    dispatcher: Arc<dyn MigrationDispatcher>,
    /// Safety gate: the sweep bails out early while any raft group
    /// is mid-election.
    gate: Arc<dyn ElectionGate>,
    /// Live routing table, read-locked while computing a plan.
    routing: Arc<RwLock<RoutingTable>>,
    /// Live cluster topology, read-locked alongside `routing`.
    topology: Arc<RwLock<ClusterTopology>>,
    /// Standardized loop observations (iterations, last-iteration
    /// duration, errors by kind, up flag). Register this handle with
    /// the cluster's [`LoopMetricsRegistry`](crate::LoopMetricsRegistry)
    /// so scrapes include its samples.
    loop_metrics: Arc<LoopMetrics>,
    /// Membership-change notification. When any caller (a SWIM
    /// subscriber, a manual admin trigger, etc.) calls
    /// [`notify`](Notify::notify_one) on this handle, the run loop
    /// wakes up immediately and runs an extra sweep instead of
    /// waiting for the next 30 s tick.
    kick: Arc<Notify>,
}

121impl RebalancerLoop {
122    pub fn new(
123        cfg: RebalancerLoopConfig,
124        metrics: Arc<dyn LoadMetricsProvider>,
125        dispatcher: Arc<dyn MigrationDispatcher>,
126        gate: Arc<dyn ElectionGate>,
127        routing: Arc<RwLock<RoutingTable>>,
128        topology: Arc<RwLock<ClusterTopology>>,
129    ) -> Self {
130        Self {
131            cfg,
132            metrics,
133            dispatcher,
134            gate,
135            routing,
136            topology,
137            loop_metrics: LoopMetrics::new("rebalancer_loop"),
138            kick: Arc::new(Notify::new()),
139        }
140    }
141
142    /// Return a handle that callers can use to trigger an immediate
143    /// sweep. Cloning the `Arc<Notify>` is cheap; every clone
144    /// shares the same waker.
145    pub fn kick_handle(&self) -> Arc<Notify> {
146        Arc::clone(&self.kick)
147    }
148
149    /// Shared handle to this loop's standardized metrics. Lifecycle
150    /// owners register this with the cluster registry on spawn.
151    pub fn loop_metrics(&self) -> Arc<LoopMetrics> {
152        Arc::clone(&self.loop_metrics)
153    }
154
155    /// Run the driver until `shutdown` flips to `true`.
156    pub async fn run(self: Arc<Self>, mut shutdown: watch::Receiver<bool>) {
157        let mut tick = interval(self.cfg.interval);
158        tick.set_missed_tick_behavior(MissedTickBehavior::Delay);
159        // Consume the immediate first tick so the first sweep fires
160        // a full interval after start. Prevents start-up stampedes
161        // when many nodes restart together.
162        tick.tick().await;
163        self.loop_metrics.set_up(true);
164        loop {
165            tokio::select! {
166                biased;
167                changed = shutdown.changed() => {
168                    if changed.is_ok() && *shutdown.borrow() {
169                        break;
170                    }
171                }
172                _ = tick.tick() => {
173                    self.sweep_once().await;
174                }
175                _ = self.kick.notified() => {
176                    debug!("rebalancer: membership-change kick received");
177                    self.sweep_once().await;
178                }
179            }
180        }
181        self.loop_metrics.set_up(false);
182        debug!("rebalancer loop shutting down");
183    }
184
185    /// Run a single sweep. Exposed for tests that drive the loop
186    /// manually rather than through `run`.
187    pub async fn sweep_once(&self) {
188        let started = Instant::now();
189        if self.gate.any_group_electing().await {
190            debug!("rebalancer: raft election in progress, skipping tick");
191            self.loop_metrics.observe(started.elapsed());
192            return;
193        }
194        let metrics = match self.metrics.snapshot().await {
195            Ok(m) => m,
196            Err(e) => {
197                warn!(error = %e, "rebalancer: failed to collect metrics");
198                self.loop_metrics.record_error("metrics_snapshot");
199                self.loop_metrics.observe(started.elapsed());
200                return;
201            }
202        };
203        if let Some(hot) = metrics
204            .iter()
205            .find(|m| m.cpu_utilization > self.cfg.backpressure_cpu_threshold)
206        {
207            info!(
208                node_id = hot.node_id,
209                cpu = format!("{:.0}%", hot.cpu_utilization * 100.0),
210                threshold = format!("{:.0}%", self.cfg.backpressure_cpu_threshold * 100.0),
211                "rebalancer: back-pressure — cluster under load, skipping sweep"
212            );
213            self.loop_metrics.record_error("backpressure");
214            self.loop_metrics.observe(started.elapsed());
215            return;
216        }
217        let plan = {
218            let routing = self.routing.read().unwrap_or_else(|p| p.into_inner());
219            let topo = self.topology.read().unwrap_or_else(|p| p.into_inner());
220            compute_load_based_plan(&metrics, &routing, &topo, &self.cfg.plan)
221        };
222        if plan.is_empty() {
223            debug!("rebalancer: no moves needed this tick");
224            self.loop_metrics.observe(started.elapsed());
225            return;
226        }
227        info!(
228            move_count = plan.len(),
229            "rebalancer: dispatching planned moves"
230        );
231        let dispatcher = Arc::clone(&self.dispatcher);
232        let err_metrics = Arc::clone(&self.loop_metrics);
233        // PHASE-A-WIRE: recover_in_flight_migrations should be called in
234        // RebalancerSubsystem::start (after metadata-Raft-ready, before this
235        // loop begins) so stale migrations are resumed or aborted before new
236        // ones are dispatched.  The per-vshard dedup guard lives in
237        // MigrationExecutor::execute — duplicate dispatches for the same vshard
238        // are rejected with ClusterError::MigrationInProgress.
239        for mv in plan {
240            let dispatcher = Arc::clone(&dispatcher);
241            let err_metrics = Arc::clone(&err_metrics);
242            tokio::spawn(async move {
243                if let Err(e) = dispatcher.dispatch(mv).await {
244                    warn!(error = %e, "rebalancer: dispatch failed");
245                    err_metrics.record_error("dispatch");
246                }
247            });
248        }
249        self.loop_metrics.observe(started.elapsed());
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rebalancer::metrics::LoadMetrics;
    use crate::topology::{NodeInfo, NodeState};
    use std::net::SocketAddr;
    use std::sync::Mutex;

    /// Metrics provider that returns the same snapshot on every call.
    struct StaticMetrics(Vec<LoadMetrics>);

    #[async_trait]
    impl LoadMetricsProvider for StaticMetrics {
        async fn snapshot(&self) -> Result<Vec<LoadMetrics>> {
            Ok(self.0.clone())
        }
    }

    /// Dispatcher mock that records every dispatched move.
    struct RecordingDispatcher {
        calls: Mutex<Vec<PlannedMove>>,
    }

    impl RecordingDispatcher {
        fn new() -> Arc<Self> {
            Arc::new(Self {
                calls: Mutex::new(Vec::new()),
            })
        }
        /// Drain and return all recorded moves (clears the log so a
        /// later `take` only sees moves dispatched since this call).
        fn take(&self) -> Vec<PlannedMove> {
            let mut g = self.calls.lock().unwrap();
            let out = g.clone();
            g.clear();
            out
        }
    }

    #[async_trait]
    impl MigrationDispatcher for RecordingDispatcher {
        async fn dispatch(&self, mv: PlannedMove) -> Result<()> {
            self.calls.lock().unwrap().push(mv);
            Ok(())
        }
    }

    /// Gate with a fixed answer: `BlockingGate(true)` simulates a
    /// permanent election in progress.
    struct BlockingGate(bool);

    #[async_trait]
    impl ElectionGate for BlockingGate {
        async fn any_group_electing(&self) -> bool {
            self.0
        }
    }

    /// Build a topology of Active nodes with distinct loopback addrs.
    fn topo(nodes: &[u64]) -> Arc<RwLock<ClusterTopology>> {
        let mut t = ClusterTopology::new();
        for (i, id) in nodes.iter().enumerate() {
            let a: SocketAddr = format!("127.0.0.1:{}", 9000 + i).parse().unwrap();
            t.add_node(NodeInfo::new(*id, a, NodeState::Active));
        }
        Arc::new(RwLock::new(t))
    }

    /// Routing table of 6 groups over nodes 1-3 with ALL leaders
    /// forced onto `node` — maximally imbalanced.
    fn routing_hot_on(node: u64) -> Arc<RwLock<RoutingTable>> {
        let mut r = RoutingTable::uniform(6, &[1, 2, 3], 1);
        for gid in 0..6 {
            r.set_leader(gid, node);
        }
        Arc::new(RwLock::new(r))
    }

    /// Shorthand LoadMetrics constructor; `bytes_mib` is converted
    /// to bytes, cpu_utilization stays 0 (below back-pressure).
    fn lm(id: u64, v: u32, bytes_mib: u64, w: f64, r: f64) -> LoadMetrics {
        LoadMetrics {
            node_id: id,
            vshards_led: v,
            bytes_stored: bytes_mib * 1_048_576,
            writes_per_sec: w,
            reads_per_sec: r,
            qps_recent: 0.0,
            p95_latency_us: 0,
            cpu_utilization: 0.0,
        }
    }

    /// Standard fixture: node 1 drastically hotter than 2 and 3,
    /// 50 ms interval, recording dispatcher, caller-chosen gate.
    fn hot_cluster_loop(
        gate: Arc<dyn ElectionGate>,
    ) -> (Arc<RebalancerLoop>, Arc<RecordingDispatcher>) {
        let metrics: Arc<dyn LoadMetricsProvider> = Arc::new(StaticMetrics(vec![
            lm(1, 500, 5000, 200.0, 200.0),
            lm(2, 5, 5, 5.0, 5.0),
            lm(3, 5, 5, 5.0, 5.0),
        ]));
        let dispatcher = RecordingDispatcher::new();
        let disp_dyn: Arc<dyn MigrationDispatcher> = dispatcher.clone();
        let rloop = Arc::new(RebalancerLoop::new(
            RebalancerLoopConfig {
                interval: Duration::from_millis(50),
                ..Default::default()
            },
            metrics,
            disp_dyn,
            gate,
            routing_hot_on(1),
            topo(&[1, 2, 3]),
        ));
        (rloop, dispatcher)
    }

    #[tokio::test]
    async fn sweep_dispatches_moves_when_imbalanced() {
        let (rloop, dispatcher) = hot_cluster_loop(Arc::new(AlwaysReadyGate));
        rloop.sweep_once().await;
        // Dispatches are spawned tasks; yield so they get to run.
        for _ in 0..16 {
            tokio::task::yield_now().await;
        }
        let calls = dispatcher.take();
        assert!(!calls.is_empty());
        // Every planned move must drain the hot node.
        for c in &calls {
            assert_eq!(c.source_node, 1);
        }
    }

    #[tokio::test]
    async fn sweep_skipped_during_election() {
        let (rloop, dispatcher) = hot_cluster_loop(Arc::new(BlockingGate(true)));
        rloop.sweep_once().await;
        for _ in 0..8 {
            tokio::task::yield_now().await;
        }
        // Gate said "electing" — nothing may be dispatched.
        assert!(dispatcher.take().is_empty());
    }

    #[tokio::test]
    async fn sweep_noop_on_balanced_cluster() {
        // Three identical nodes: the plan must come back empty.
        let metrics: Arc<dyn LoadMetricsProvider> = Arc::new(StaticMetrics(vec![
            lm(1, 50, 500, 100.0, 100.0),
            lm(2, 50, 500, 100.0, 100.0),
            lm(3, 50, 500, 100.0, 100.0),
        ]));
        let dispatcher = RecordingDispatcher::new();
        let rloop = Arc::new(RebalancerLoop::new(
            RebalancerLoopConfig::default(),
            metrics,
            dispatcher.clone() as Arc<dyn MigrationDispatcher>,
            Arc::new(AlwaysReadyGate),
            routing_hot_on(1),
            topo(&[1, 2, 3]),
        ));
        rloop.sweep_once().await;
        for _ in 0..8 {
            tokio::task::yield_now().await;
        }
        assert!(dispatcher.take().is_empty());
    }

    // Paused tokio time: `advance` drives the 50 ms interval
    // deterministically without real sleeps.
    #[tokio::test(start_paused = true)]
    async fn run_loop_fires_sweeps_and_shuts_down() {
        let (rloop, dispatcher) = hot_cluster_loop(Arc::new(AlwaysReadyGate));
        let (tx, rx) = watch::channel(false);
        let handle = tokio::spawn({
            let d = Arc::clone(&rloop);
            async move { d.run(rx).await }
        });
        // First tick consumed immediately by run(); advance past a
        // couple of real intervals with interleaved yields so the
        // run-loop's select + spawned dispatch tasks all get to poll.
        for _ in 0..4 {
            tokio::time::advance(Duration::from_millis(80)).await;
            for _ in 0..16 {
                tokio::task::yield_now().await;
            }
        }
        assert!(!dispatcher.take().is_empty());

        let _ = tx.send(true);
        let _ = tokio::time::timeout(Duration::from_millis(500), handle).await;
    }

    #[tokio::test]
    async fn sweep_skipped_under_cpu_backpressure() {
        // Same hot/cold shape as hot_cluster_loop, but the hot node
        // also reports 95% CPU — above the default 80% threshold.
        let metrics: Arc<dyn LoadMetricsProvider> = Arc::new(StaticMetrics(vec![
            LoadMetrics {
                node_id: 1,
                vshards_led: 500,
                bytes_stored: 5000 * 1_048_576,
                writes_per_sec: 200.0,
                reads_per_sec: 200.0,
                qps_recent: 0.0,
                p95_latency_us: 0,
                cpu_utilization: 0.95, // above 80% threshold
            },
            lm(2, 5, 5, 5.0, 5.0),
            lm(3, 5, 5, 5.0, 5.0),
        ]));
        let dispatcher = RecordingDispatcher::new();
        let rloop = Arc::new(RebalancerLoop::new(
            RebalancerLoopConfig {
                interval: Duration::from_millis(50),
                ..Default::default()
            },
            metrics,
            dispatcher.clone() as Arc<dyn MigrationDispatcher>,
            Arc::new(AlwaysReadyGate),
            routing_hot_on(1),
            topo(&[1, 2, 3]),
        ));
        rloop.sweep_once().await;
        for _ in 0..8 {
            tokio::task::yield_now().await;
        }
        assert!(
            dispatcher.take().is_empty(),
            "dispatcher should not fire when cluster is under CPU backpressure"
        );
    }
}