// faucet_server/client/load_balancing/round_robin.rs

1use super::LoadBalancingStrategy;
2use crate::client::{worker::WorkerConfig, Client};
3use std::{net::IpAddr, sync::atomic::AtomicUsize};
4
5struct Targets {
6    targets: &'static [Client],
7    index: AtomicUsize,
8}
9
10// 500us is the time it takes for the round robin to move to the next target
11// in the unlikely event that the target is offline
12const WAIT_TIME_UNTIL_RETRY: std::time::Duration = std::time::Duration::from_micros(500);
13
14impl Targets {
15    fn new(configs: &[&'static WorkerConfig]) -> Self {
16        let mut targets = Vec::new();
17        for state in configs {
18            let client = Client::new(state);
19            targets.push(client);
20        }
21        let targets = Box::leak(targets.into_boxed_slice());
22        Targets {
23            targets,
24            index: AtomicUsize::new(0),
25        }
26    }
27    fn next(&self) -> Client {
28        let index = self.index.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
29        self.targets[index % self.targets.len()].clone()
30    }
31}
32
33pub struct RoundRobin {
34    targets: Targets,
35}
36
37impl RoundRobin {
38    pub(crate) async fn new(configs: &[&'static WorkerConfig]) -> Self {
39        // Start the process of each config
40        for config in configs {
41            config.spawn_worker_task().await;
42        }
43        Self {
44            targets: Targets::new(configs),
45        }
46    }
47}
48
49impl LoadBalancingStrategy for RoundRobin {
50    type Input = IpAddr;
51    async fn entry(&self, _ip: IpAddr) -> Client {
52        let mut client = self.targets.next();
53        loop {
54            if client.is_online() {
55                break client;
56            }
57            tokio::time::sleep(WAIT_TIME_UNTIL_RETRY).await;
58            client = self.targets.next();
59        }
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ExtractSocketAddr;

    /// Address string of the `i`-th dummy worker: `127.0.0.1:9000`, `127.0.0.1:9001`, ...
    fn dummy_addr(i: usize) -> String {
        format!("127.0.0.1:{}", 9000 + i)
    }

    /// Parsed form of [`dummy_addr`], for comparing against `socket_addr()`.
    fn dummy_socket_addr(i: usize) -> std::net::SocketAddr {
        dummy_addr(i).parse().expect("Failed to parse addr")
    }

    /// Leaks one dummy `WorkerConfig` per entry of `online`, so it can be
    /// borrowed as `'static`. Each config gets address `dummy_addr(i)` and
    /// the matching flag as `WorkerConfig::dummy`'s third argument.
    fn leak_dummy_configs(online: &[bool]) -> Vec<&'static WorkerConfig> {
        online
            .iter()
            .enumerate()
            .map(|(i, &is_online)| {
                let addr = dummy_addr(i);
                &*Box::leak(Box::new(WorkerConfig::dummy("test", &addr, is_online)))
            })
            .collect()
    }

    /// Awaits `wait_until_done` on every config, in order.
    async fn wait_all(configs: &[&'static WorkerConfig]) {
        for config in configs {
            config.wait_until_done().await;
        }
    }

    /// Caller IP passed to `entry`; round robin ignores it.
    fn any_ip() -> std::net::IpAddr {
        "0.0.0.0".parse().expect("failed to parse ip")
    }

    #[test]
    fn test_new_targets() {
        let configs = leak_dummy_configs(&[true; 3]);
        let _ = Targets::new(&configs);
    }

    #[tokio::test]
    async fn test_new_round_robin() {
        let configs = leak_dummy_configs(&[true; 3]);

        let _ = RoundRobin::new(&configs).await;

        wait_all(&configs).await;
    }

    #[tokio::test]
    async fn test_round_robin_entry() {
        let configs = leak_dummy_configs(&[true; 3]);
        let rr = RoundRobin::new(&configs).await;
        let ip = any_ip();

        // Two full rotations: 9000, 9001, 9002, 9000, 9001, 9002.
        for i in (0..3).cycle().take(6) {
            assert_eq!(rr.entry(ip).await.socket_addr(), dummy_socket_addr(i));
        }

        wait_all(&configs).await;
    }

    #[tokio::test]
    async fn test_round_robin_entry_with_offline_target() {
        let configs = leak_dummy_configs(&[false, false, true]);
        let rr = RoundRobin::new(&configs).await;

        assert_eq!(rr.entry(any_ip()).await.socket_addr(), dummy_socket_addr(2));

        wait_all(&configs).await;
    }

    #[tokio::test]
    async fn test_round_robin_entry_skips_offline_target_in_rotation() {
        let configs = leak_dummy_configs(&[true, false, true]);
        let rr = RoundRobin::new(&configs).await;
        let ip = any_ip();

        // The offline middle target is skipped while rotation continues.
        for i in [0, 2, 0, 2] {
            assert_eq!(rr.entry(ip).await.socket_addr(), dummy_socket_addr(i));
        }

        wait_all(&configs).await;
    }
}
175}