1use std::collections::HashMap;
17use std::sync::Arc;
18
19use parking_lot::RwLock;
20use serde::{Deserialize, Serialize};
21
/// A resource sample describing one cluster node at one point in time.
///
/// Samples are stored in `ClusterMetrics` (keyed by `address`) and carried
/// between nodes inside `MetricsPdu`, hence the serde derives.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct NodeMetrics {
    /// Address of the node the sample describes. `ClusterMetrics::publish`
    /// uses it as the storage key and `AdaptiveLoadBalancer::pick` matches
    /// candidate addresses against it.
    pub address: String,
    /// Sample time as supplied by whoever produced the sample.
    /// NOTE(review): unit and epoch are not established in this file — confirm with callers.
    pub timestamp: u64,
    /// CPU load; `AdaptiveLoadBalancer::pick` prefers lower values. The
    /// sysinfo-backed probe reports it within `0.0..=1.0`; nothing in this
    /// file range-checks values coming from other producers.
    pub cpu_load: f64,
    /// Memory currently in use. NOTE(review): presumably bytes — confirm against the probe in use.
    pub memory_used: u64,
    /// Total memory available, in the same unit as `memory_used`. May be `0`;
    /// `memory_usage` treats that case specially.
    pub memory_max: u64,
}
30
31impl NodeMetrics {
32 pub fn memory_usage(&self) -> f64 {
35 if self.memory_max == 0 {
36 0.0
37 } else {
38 self.memory_used as f64 / self.memory_max as f64
39 }
40 }
41}
42
/// Store holding one `NodeMetrics` sample per node address.
///
/// The map sits behind a `RwLock`, which is what lets every method take
/// `&self`. `Default` produces an empty store.
#[derive(Default)]
pub struct ClusterMetrics {
    /// Sample per node, keyed by the sample's own `address`.
    entries: RwLock<HashMap<String, NodeMetrics>>,
}
47
48impl ClusterMetrics {
49 pub fn new() -> Self {
50 Self::default()
51 }
52
53 pub fn publish(&self, m: NodeMetrics) {
54 self.entries.write().insert(m.address.clone(), m);
55 }
56
57 pub fn snapshot(&self) -> Vec<NodeMetrics> {
58 self.entries.read().values().cloned().collect()
59 }
60
61 pub fn get(&self, address: &str) -> Option<NodeMetrics> {
62 self.entries.read().get(address).cloned()
63 }
64
65 pub fn node_count(&self) -> usize {
66 self.entries.read().len()
67 }
68}
69
/// A source of resource samples.
///
/// Implementors must be `Send + Sync + 'static`. `gossip_local_metrics`
/// accepts any implementor, including `dyn MetricsProbe`.
pub trait MetricsProbe: Send + Sync + 'static {
    /// Produces a sample for the node identified by `address`, stamped with
    /// `timestamp`. The implementations in this file copy both arguments
    /// into the returned `NodeMetrics` unchanged.
    fn sample(&self, address: &str, timestamp: u64) -> NodeMetrics;
}
78
/// Probe that reports the same configured figures on every call.
///
/// All fields are plain numbers, so the type is freely copyable and
/// comparable; this file's tests use it as a deterministic stand-in for a
/// real probe.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StaticProbe {
    /// Value reported as `NodeMetrics::cpu_load`.
    pub cpu_load: f64,
    /// Value reported as `NodeMetrics::memory_used`.
    pub memory_used: u64,
    /// Value reported as `NodeMetrics::memory_max`.
    pub memory_max: u64,
}
86
87impl MetricsProbe for StaticProbe {
88 fn sample(&self, address: &str, timestamp: u64) -> NodeMetrics {
89 NodeMetrics {
90 address: address.into(),
91 timestamp,
92 cpu_load: self.cpu_load,
93 memory_used: self.memory_used,
94 memory_max: self.memory_max,
95 }
96 }
97}
98
/// Chooses among candidate node addresses using the CPU load recorded in a
/// shared `ClusterMetrics` store.
pub struct AdaptiveLoadBalancer {
    /// Store consulted on every `pick`; shared via `Arc` with whoever
    /// publishes samples into it.
    metrics: Arc<ClusterMetrics>,
}
107
108impl AdaptiveLoadBalancer {
109 pub fn new(metrics: Arc<ClusterMetrics>) -> Self {
110 Self { metrics }
111 }
112
113 pub fn pick<'a>(&self, candidates: &'a [&'a str]) -> Option<&'a str> {
116 if candidates.is_empty() {
117 return None;
118 }
119 let snapshot = self.metrics.snapshot();
120 let lookup: HashMap<&str, &NodeMetrics> = snapshot.iter().map(|m| (m.address.as_str(), m)).collect();
121 let mut sorted: Vec<&&str> = candidates.iter().collect();
122 sorted.sort_by(|a, b| {
123 let load_a = lookup.get(*a).map(|m| m.cpu_load).unwrap_or(f64::INFINITY);
124 let load_b = lookup.get(*b).map(|m| m.cpu_load).unwrap_or(f64::INFINITY);
125 load_a.partial_cmp(&load_b).unwrap_or(std::cmp::Ordering::Equal).then_with(|| a.cmp(b))
126 });
127 sorted.first().copied().copied()
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample with the given address and load. The remaining fields do not
    /// matter to the balancer tests and are fixed at `timestamp: 0`,
    /// `memory_used: 0`, `memory_max: 1`.
    fn node(address: &str, cpu_load: f64) -> NodeMetrics {
        NodeMetrics { address: address.into(), timestamp: 0, cpu_load, memory_used: 0, memory_max: 1 }
    }

    #[test]
    fn publish_and_fetch() {
        let m = ClusterMetrics::new();
        m.publish(NodeMetrics {
            address: "a".into(),
            timestamp: 1,
            cpu_load: 0.5,
            memory_used: 100,
            memory_max: 1000,
        });
        assert_eq!(m.snapshot().len(), 1);
        assert_eq!(m.node_count(), 1);
        assert_eq!(m.get("a").unwrap().cpu_load, 0.5);
        assert!(m.get("missing").is_none());
    }

    #[test]
    fn publish_replaces_existing_entry() {
        let m = ClusterMetrics::new();
        m.publish(node("a", 0.5));
        m.publish(node("a", 0.8));
        // Same address: the second sample overwrites the first.
        assert_eq!(m.node_count(), 1);
        assert_eq!(m.get("a").unwrap().cpu_load, 0.8);
    }

    #[test]
    fn memory_usage_ratio() {
        let m = NodeMetrics {
            address: "a".into(),
            timestamp: 0,
            cpu_load: 0.0,
            memory_used: 250,
            memory_max: 1000,
        };
        assert_eq!(m.memory_usage(), 0.25);
    }

    #[test]
    fn memory_usage_handles_zero_max() {
        let m =
            NodeMetrics { address: "a".into(), timestamp: 0, cpu_load: 0.0, memory_used: 0, memory_max: 0 };
        assert_eq!(m.memory_usage(), 0.0);
    }

    #[test]
    fn static_probe_returns_configured_values() {
        let probe = StaticProbe { cpu_load: 0.7, memory_used: 5, memory_max: 10 };
        let m = probe.sample("nodeA", 42);
        assert_eq!(m.address, "nodeA");
        assert_eq!(m.timestamp, 42);
        assert_eq!(m.cpu_load, 0.7);
        assert_eq!(m.memory_used, 5);
        assert_eq!(m.memory_max, 10);
    }

    #[test]
    fn adaptive_picks_lowest_load() {
        let m = Arc::new(ClusterMetrics::new());
        m.publish(node("a", 0.9));
        m.publish(node("b", 0.1));
        m.publish(node("c", 0.5));
        let lb = AdaptiveLoadBalancer::new(m);
        assert_eq!(lb.pick(&["a", "b", "c"]), Some("b"));
    }

    #[test]
    fn adaptive_falls_back_to_address_order_when_no_metrics() {
        let m = Arc::new(ClusterMetrics::new());
        let lb = AdaptiveLoadBalancer::new(m);
        assert_eq!(lb.pick(&["c", "a", "b"]), Some("a"));
    }

    #[test]
    fn adaptive_returns_none_for_empty_candidates() {
        let m = Arc::new(ClusterMetrics::new());
        let lb = AdaptiveLoadBalancer::new(m);
        assert_eq!(lb.pick(&[]), None);
    }
}
220
#[cfg(feature = "sysinfo-probe")]
/// Probe implementation backed by the `sysinfo` crate; compiled only with
/// the `sysinfo-probe` feature.
pub mod sys {
    use super::{MetricsProbe, NodeMetrics};
    use std::sync::Mutex;
    use sysinfo::System;

    /// Probe that reads CPU and memory figures from the local machine.
    ///
    /// The `System` handle is kept behind a `Mutex` because sampling refreshes
    /// it in place while `MetricsProbe::sample` only receives `&self`.
    pub struct SysinfoProbe {
        sys: Mutex<System>,
    }

    impl Default for SysinfoProbe {
        fn default() -> Self {
            Self::new()
        }
    }

    impl SysinfoProbe {
        /// Creates a probe around `System::new_all()`.
        pub fn new() -> Self {
            Self { sys: Mutex::new(System::new_all()) }
        }
    }

    impl MetricsProbe for SysinfoProbe {
        /// Refreshes CPU and memory readings and returns them as a sample for
        /// `address` stamped with `timestamp`.
        ///
        /// `cpu_load` is the global CPU usage divided by 100 and clamped to
        /// `0.0..=1.0`.
        /// NOTE(review): sysinfo derives CPU usage from the difference between
        /// two refreshes, so the first sample after construction may be
        /// inaccurate — confirm against the sysinfo version in use.
        fn sample(&self, address: &str, timestamp: u64) -> NodeMetrics {
            // A panic in another thread while it held the lock poisons the
            // mutex. The guarded state is refreshed below on every call, so
            // it is safe to keep using it instead of panicking forever after.
            let mut sys = self.sys.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            sys.refresh_cpu_usage();
            sys.refresh_memory();
            let cpu_load = (sys.global_cpu_usage() as f64 / 100.0).clamp(0.0, 1.0);
            let memory_max = sys.total_memory();
            let memory_used = sys.used_memory();
            NodeMetrics { address: address.into(), timestamp, cpu_load, memory_used, memory_max }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn sysinfo_probe_returns_finite_load() {
            let p = SysinfoProbe::new();
            let m = p.sample("a", 1);
            assert!(m.cpu_load.is_finite());
            assert!(m.memory_max >= m.memory_used);
        }
    }
}
273
/// Message carrying metrics samples; it is what `MetricsTransport::send`
/// accepts and what `apply_metrics_pdu` consumes.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum MetricsPdu {
    /// A single sample.
    Push(NodeMetrics),
    /// Several samples in one message; `apply_metrics_pdu` publishes them in
    /// the order they appear.
    PushBatch(Vec<NodeMetrics>),
}
285
/// Outbound channel for metrics messages.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait MetricsTransport: Send + Sync + 'static {
    /// Hands `pdu` to the transport for delivery to `target_node`.
    ///
    /// There is no return value, so this interface gives the caller no way
    /// to observe a delivery failure.
    fn send(&self, target_node: &str, pdu: MetricsPdu);
}
291
292pub fn apply_metrics_pdu(metrics: &ClusterMetrics, pdu: MetricsPdu) {
294 match pdu {
295 MetricsPdu::Push(m) => metrics.publish(m),
296 MetricsPdu::PushBatch(v) => {
297 for m in v {
298 metrics.publish(m);
299 }
300 }
301 }
302}
303
304pub fn gossip_local_metrics<P: MetricsProbe + ?Sized>(
306 probe: &P,
307 self_address: &str,
308 target_node: &str,
309 transport: &dyn MetricsTransport,
310 now: u64,
311) {
312 let m = probe.sample(self_address, now);
313 transport.send(target_node, MetricsPdu::Push(m));
314}
315
#[cfg(test)]
mod gossip_tests {
    use super::*;
    use std::sync::Mutex;

    /// Transport double that records every `(target, pdu)` pair instead of
    /// delivering it anywhere.
    #[derive(Default)]
    struct CaptureTransport {
        sent: Mutex<Vec<(String, MetricsPdu)>>,
    }

    impl MetricsTransport for CaptureTransport {
        fn send(&self, target: &str, pdu: MetricsPdu) {
            let mut log = self.sent.lock().unwrap();
            log.push((target.to_string(), pdu));
        }
    }

    /// Shorthand for a fully specified sample.
    fn sample(address: &str, timestamp: u64, cpu_load: f64, memory_used: u64, memory_max: u64) -> NodeMetrics {
        NodeMetrics { address: address.into(), timestamp, cpu_load, memory_used, memory_max }
    }

    #[test]
    fn gossip_pushes_local_sample_to_target() {
        let transport = CaptureTransport::default();
        let probe = StaticProbe { cpu_load: 0.3, memory_used: 1, memory_max: 4 };
        gossip_local_metrics(&probe, "self", "peer", &transport, 1);

        let recorded = transport.sent.lock().unwrap();
        assert_eq!(recorded.len(), 1);
        if let MetricsPdu::Push(pushed) = &recorded[0].1 {
            assert_eq!(pushed.address, "self");
        } else {
            panic!("expected Push");
        }
    }

    #[test]
    fn apply_pdu_merges_into_metrics() {
        let store = ClusterMetrics::new();
        apply_metrics_pdu(&store, MetricsPdu::Push(sample("x", 7, 0.1, 1, 2)));
        assert_eq!(store.node_count(), 1);
        let stored = store.get("x").unwrap();
        assert_eq!(stored.timestamp, 7);
    }

    #[test]
    fn adaptive_balancer_can_be_used_as_picker_closure() {
        const NODE_A: &str = "akka.tcp://Sys@a:1";
        const NODE_B: &str = "akka.tcp://Sys@b:1";
        type Picker = Arc<dyn Fn(&[String]) -> Option<String> + Send + Sync>;

        let store = Arc::new(ClusterMetrics::new());
        store.publish(sample(NODE_A, 0, 0.9, 0, 1));
        store.publish(sample(NODE_B, 0, 0.1, 0, 1));
        let balancer = Arc::new(AdaptiveLoadBalancer::new(store));

        // The closure owns its own handle to the balancer, as a real picker would.
        let for_closure = balancer.clone();
        let picker: Picker = Arc::new(move |cands: &[String]| {
            let borrowed: Vec<&str> = cands.iter().map(|s| s.as_str()).collect();
            for_closure.pick(&borrowed).map(str::to_owned)
        });

        let candidates = vec![NODE_A.to_string(), NODE_B.to_string()];
        let chosen = picker(candidates.as_slice()).unwrap();
        assert_eq!(chosen, NODE_B);
    }

    #[test]
    fn batch_pdu_merges_each() {
        let store = ClusterMetrics::new();
        let batch = vec![sample("a", 1, 0.0, 0, 0), sample("b", 2, 0.0, 0, 0)];
        apply_metrics_pdu(&store, MetricsPdu::PushBatch(batch));
        assert_eq!(store.node_count(), 2);
    }
}
399}