faucet_server/client/load_balancing/
mod.rs1pub mod cookie_hash;
2mod ip_extractor;
3pub mod ip_hash;
4pub mod round_robin;
5pub mod rps_autoscale;
6
7use super::worker::WorkerConfig;
8use crate::client::Client;
9use crate::error::FaucetResult;
10use crate::leak;
11use cookie_hash::CookieHash;
12use hyper::Request;
13pub use ip_extractor::IpExtractor;
14use std::net::IpAddr;
15use std::str::FromStr;
16use uuid::Uuid;
17
18use self::ip_hash::IpHash;
19use self::round_robin::RoundRobin;
20use self::rps_autoscale::RpsAutoscale;
21
22const DEFAULT_MAX_RPS: f64 = 10.0;
23
24trait LoadBalancingStrategy {
25 type Input;
26 async fn entry(&self, ip: Self::Input) -> Client;
27}
28
29#[derive(Debug, Clone, Copy, clap::ValueEnum, Eq, PartialEq, serde::Deserialize)]
30#[serde(rename = "snake_case")]
31pub enum Strategy {
32 #[serde(alias = "round_robin", alias = "RoundRobin", alias = "round-robin")]
33 RoundRobin,
34 #[serde(alias = "ip_hash", alias = "IpHash", alias = "ip-hash")]
35 IpHash,
36 #[serde(alias = "cookie_hash", alias = "CookieHash", alias = "cookie-hash")]
37 CookieHash,
38 #[serde(alias = "rps", alias = "Rps", alias = "rps")]
39 Rps,
40}
41
42impl FromStr for Strategy {
43 type Err = &'static str;
44 fn from_str(s: &str) -> Result<Self, Self::Err> {
45 match s {
46 "round_robin" => Ok(Self::RoundRobin),
47 "ip_hash" => Ok(Self::IpHash),
48 "cookie_hash" => Ok(Self::CookieHash),
49 "rps" => Ok(Self::Rps),
50 _ => Err("invalid strategy"),
51 }
52 }
53}
54
55#[derive(Debug, Clone, Copy)]
56enum LBIdent {
57 Ip(IpAddr),
58 Uuid(Uuid),
59}
60
61#[derive(Copy, Clone)]
62enum DynLoadBalancer {
63 IpHash(&'static ip_hash::IpHash),
64 RoundRobin(&'static round_robin::RoundRobin),
65 CookieHash(&'static cookie_hash::CookieHash),
66 Rps(&'static rps_autoscale::RpsAutoscale),
67}
68
69impl LoadBalancingStrategy for DynLoadBalancer {
70 type Input = LBIdent;
71 async fn entry(&self, ip: LBIdent) -> Client {
72 match ip {
73 LBIdent::Ip(ip) => match self {
74 DynLoadBalancer::RoundRobin(rr) => rr.entry(ip).await,
75 DynLoadBalancer::IpHash(ih) => ih.entry(ip).await,
76 DynLoadBalancer::Rps(rr) => rr.entry(ip).await,
77 _ => unreachable!(
78 "This should never happen, ip should never be passed to cookie hash"
79 ),
80 },
81 LBIdent::Uuid(uuid) => match self {
82 DynLoadBalancer::CookieHash(ch) => ch.entry(uuid).await,
83 _ => unreachable!(
84 "This should never happen, uuid should never be passed to round robin or ip hash"
85 ),
86 },
87 }
88 }
89}
90
91pub(crate) struct LoadBalancer {
92 strategy: DynLoadBalancer,
93 extractor: IpExtractor,
94}
95
96impl LoadBalancer {
97 pub async fn new(
98 strategy: Strategy,
99 extractor: IpExtractor,
100 workers: &[&'static WorkerConfig],
101 max_rps_config: Option<f64>, ) -> FaucetResult<Self> {
103 let strategy: DynLoadBalancer = match strategy {
104 Strategy::RoundRobin => {
105 DynLoadBalancer::RoundRobin(leak!(RoundRobin::new(workers).await))
106 }
107 Strategy::IpHash => DynLoadBalancer::IpHash(leak!(IpHash::new(workers).await)),
108 Strategy::CookieHash => {
109 DynLoadBalancer::CookieHash(leak!(CookieHash::new(workers).await))
110 }
111 Strategy::Rps => {
112 let rps_value = max_rps_config.unwrap_or(DEFAULT_MAX_RPS);
113 DynLoadBalancer::Rps(leak!(RpsAutoscale::new(workers, rps_value).await))
114 }
115 };
116 Ok(Self {
117 strategy,
118 extractor,
119 })
120 }
121 pub fn get_strategy(&self) -> Strategy {
122 match self.strategy {
123 DynLoadBalancer::RoundRobin(_) => Strategy::RoundRobin,
124 DynLoadBalancer::IpHash(_) => Strategy::IpHash,
125 DynLoadBalancer::CookieHash(_) => Strategy::CookieHash,
126 DynLoadBalancer::Rps(_) => Strategy::Rps,
127 }
128 }
129 async fn get_client_ip(&self, ip: IpAddr) -> FaucetResult<Client> {
130 Ok(self.strategy.entry(LBIdent::Ip(ip)).await)
131 }
132 async fn get_client_uuid(&self, uuid: Uuid) -> FaucetResult<Client> {
133 Ok(self.strategy.entry(LBIdent::Uuid(uuid)).await)
134 }
135 pub async fn get_client(&self, ip: IpAddr, uuid: Option<Uuid>) -> FaucetResult<Client> {
136 if let Some(uuid) = uuid {
137 self.get_client_uuid(uuid).await
138 } else {
139 self.get_client_ip(ip).await
140 }
141 }
142 pub fn extract_ip<B>(
143 &self,
144 request: &Request<B>,
145 socket: Option<IpAddr>,
146 ) -> FaucetResult<IpAddr> {
147 self.extractor.extract(request, socket)
148 }
149}
150
151impl Clone for LoadBalancer {
152 fn clone(&self) -> Self {
153 Self {
154 strategy: self.strategy,
155 extractor: self.extractor,
156 }
157 }
158}
159
#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_strategy_from_str() {
        assert_eq!(
            Strategy::from_str("round_robin").unwrap(),
            Strategy::RoundRobin
        );
        assert_eq!(Strategy::from_str("ip_hash").unwrap(), Strategy::IpHash);
        assert_eq!(
            Strategy::from_str("cookie_hash").unwrap(),
            Strategy::CookieHash
        );
        assert_eq!(Strategy::from_str("rps").unwrap(), Strategy::Rps);
        assert!(Strategy::from_str("invalid").is_err());
        // Parsing is exact: no case folding and no whitespace trimming.
        assert!(Strategy::from_str("").is_err());
        assert!(Strategy::from_str("ROUND_ROBIN").is_err());
        assert!(Strategy::from_str(" rps").is_err());
    }

    #[tokio::test]
    async fn test_load_balancer_new_round_robin() {
        let configs = Vec::new();
        let load_balancer = LoadBalancer::new(
            Strategy::RoundRobin,
            IpExtractor::XForwardedFor,
            &configs,
            None,
        )
        .await
        .expect("failed to create load balancer");
        assert_eq!(load_balancer.get_strategy(), Strategy::RoundRobin);
    }

    #[tokio::test]
    async fn test_load_balancer_new_ip_hash() {
        let configs = Vec::new();
        let load_balancer =
            LoadBalancer::new(Strategy::IpHash, IpExtractor::XForwardedFor, &configs, None)
                .await
                .expect("failed to create load balancer");
        assert_eq!(load_balancer.get_strategy(), Strategy::IpHash);
    }

    #[tokio::test]
    async fn test_load_balancer_extract_ip() {
        let configs = Vec::new();
        let load_balancer = LoadBalancer::new(
            Strategy::RoundRobin,
            IpExtractor::XForwardedFor,
            &configs,
            None,
        )
        .await
        .expect("failed to create load balancer");
        let request = Request::builder()
            .header("x-forwarded-for", "192.168.0.1")
            .body(())
            .unwrap();
        let ip = load_balancer
            .extract_ip(&request, Some("127.0.0.1".parse().unwrap()))
            .expect("failed to extract ip");

        assert_eq!(ip, "192.168.0.1".parse::<IpAddr>().unwrap());
    }

    #[tokio::test]
    async fn test_load_balancer_get_client() {
        use crate::client::ExtractSocketAddr;
        let configs: [&'static WorkerConfig; 2] = [
            &*Box::leak(Box::new(WorkerConfig::dummy(
                "test",
                "127.0.0.1:9999",
                true,
            ))),
            &*Box::leak(Box::new(WorkerConfig::dummy(
                "test",
                "127.0.0.1:9998",
                true,
            ))),
        ];
        let load_balancer = LoadBalancer::new(
            Strategy::RoundRobin,
            IpExtractor::XForwardedFor,
            &configs,
            None,
        )
        .await
        .expect("failed to create load balancer");
        let ip = "192.168.0.1".parse().unwrap();
        let client = load_balancer
            .get_client_ip(ip)
            .await
            .expect("failed to get client");
        assert_eq!(client.socket_addr(), "127.0.0.1:9999".parse().unwrap());

        let client = load_balancer
            .get_client_ip(ip)
            .await
            .expect("failed to get client");

        assert_eq!(client.socket_addr(), "127.0.0.1:9998".parse().unwrap());

        for config in configs.iter() {
            config.wait_until_done().await;
        }
    }

    #[tokio::test]
    async fn test_clone_load_balancer() {
        let configs = Vec::new();
        let load_balancer = LoadBalancer::new(
            Strategy::RoundRobin,
            IpExtractor::XForwardedFor,
            &configs,
            None,
        )
        .await
        .expect("failed to create load balancer");
        let _ = load_balancer.clone();
    }
}