faucet_server/client/load_balancing/
round_robin.rs1use super::LoadBalancingStrategy;
2use crate::client::{worker::WorkerConfig, Client};
3use std::{net::IpAddr, sync::atomic::AtomicUsize};
4
5struct Targets {
6 targets: &'static [Client],
7 index: AtomicUsize,
8}
9
/// How long `RoundRobin::entry` sleeps before retrying after it has run into
/// offline targets (500 µs).
const WAIT_TIME_UNTIL_RETRY: std::time::Duration = std::time::Duration::from_nanos(500_000);
13
14impl Targets {
15 fn new(configs: &[&'static WorkerConfig]) -> Self {
16 let mut targets = Vec::new();
17 for state in configs {
18 let client = Client::new(state);
19 targets.push(client);
20 }
21 let targets = Box::leak(targets.into_boxed_slice());
22 Targets {
23 targets,
24 index: AtomicUsize::new(0),
25 }
26 }
27 fn next(&self) -> Client {
28 let index = self.index.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
29 self.targets[index % self.targets.len()].clone()
30 }
31}
32
33pub struct RoundRobin {
34 targets: Targets,
35}
36
37impl RoundRobin {
38 pub(crate) async fn new(configs: &[&'static WorkerConfig]) -> Self {
39 for config in configs {
41 config.spawn_worker_task().await;
42 }
43 Self {
44 targets: Targets::new(configs),
45 }
46 }
47}
48
49impl LoadBalancingStrategy for RoundRobin {
50 type Input = IpAddr;
51 async fn entry(&self, _ip: IpAddr) -> Client {
52 let mut client = self.targets.next();
53 loop {
54 if client.is_online() {
55 break client;
56 }
57 tokio::time::sleep(WAIT_TIME_UNTIL_RETRY).await;
58 client = self.targets.next();
59 }
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    /// Leaks a dummy worker config named "test" so it can be handed out as `&'static`.
    fn leak_config(addr: &str, online: bool) -> &'static WorkerConfig {
        Box::leak(Box::new(WorkerConfig::dummy("test", addr, online)))
    }

    /// Local address of the `i`-th test worker: `127.0.0.1:9000`, `:9001`, ...
    fn worker_addr(i: u16) -> String {
        format!("127.0.0.1:{}", 9000 + i)
    }

    /// Three online dummy configs at `127.0.0.1:9000..=9002`, in order.
    fn online_configs() -> Vec<&'static WorkerConfig> {
        (0..3).map(|i| leak_config(&worker_addr(i), true)).collect()
    }

    /// Waits for every config's worker task to report completion.
    async fn wait_for_workers(configs: &[&'static WorkerConfig]) {
        for config in configs {
            config.wait_until_done().await;
        }
    }

    #[test]
    fn test_new_targets() {
        let configs = online_configs();
        let _ = Targets::new(&configs);
    }

    #[tokio::test]
    async fn test_new_round_robin() {
        let configs = online_configs();

        let _ = RoundRobin::new(&configs).await;

        wait_for_workers(&configs).await;
    }

    #[tokio::test]
    async fn test_round_robin_entry() {
        use crate::client::ExtractSocketAddr;

        let expected: Vec<std::net::SocketAddr> = (0..3)
            .map(|i| worker_addr(i).parse().expect("Failed to parse addr"))
            .collect();
        let configs = online_configs();

        let rr = RoundRobin::new(&configs).await;

        let ip = "0.0.0.0".parse().expect("failed to parse ip");

        // Two full rotations: the same order must repeat exactly.
        for round in 0..2 {
            for addr in &expected {
                assert_eq!(rr.entry(ip).await.socket_addr(), *addr, "round {round}");
            }
        }

        wait_for_workers(&configs).await;
    }

    #[tokio::test]
    async fn test_round_robin_entry_with_offline_target() {
        use crate::client::ExtractSocketAddr;

        let target_online_addr: std::net::SocketAddr = "127.0.0.1:9002".parse().unwrap();

        let configs = [
            leak_config("127.0.0.1:9000", false),
            leak_config("127.0.0.1:9001", false),
            leak_config("127.0.0.1:9002", true),
        ];

        let rr = RoundRobin::new(&configs).await;

        let ip = "0.0.0.0".parse().expect("failed to parse ip");

        // Offline targets must be skipped on every rotation, not just the first.
        assert_eq!(rr.entry(ip).await.socket_addr(), target_online_addr);
        assert_eq!(rr.entry(ip).await.socket_addr(), target_online_addr);

        wait_for_workers(&configs).await;
    }
}