// faucet_server/client/load_balancing/mod.rs

1pub mod cookie_hash;
2mod ip_extractor;
3pub mod ip_hash;
4pub mod round_robin;
5pub mod rps_autoscale;
6
7use super::worker::WorkerConfig;
8use crate::client::Client;
9use crate::error::FaucetResult;
10use crate::leak;
11use cookie_hash::CookieHash;
12use hyper::Request;
13pub use ip_extractor::IpExtractor;
14use std::net::IpAddr;
15use std::str::FromStr;
16use uuid::Uuid;
17
18use self::ip_hash::IpHash;
19use self::round_robin::RoundRobin;
20use self::rps_autoscale::RpsAutoscale;
21
/// Value handed to `RpsAutoscale::new` for [`Strategy::Rps`] when the caller
/// of `LoadBalancer::new` supplies no `max_rps_config`.
const DEFAULT_MAX_RPS: f64 = 10.0_f64;
23
24trait LoadBalancingStrategy {
25    type Input;
26    async fn entry(&self, ip: Self::Input) -> Client;
27}
28
29#[derive(Debug, Clone, Copy, clap::ValueEnum, Eq, PartialEq, serde::Deserialize)]
30#[serde(rename = "snake_case")]
31pub enum Strategy {
32    #[serde(alias = "round_robin", alias = "RoundRobin", alias = "round-robin")]
33    RoundRobin,
34    #[serde(alias = "ip_hash", alias = "IpHash", alias = "ip-hash")]
35    IpHash,
36    #[serde(alias = "cookie_hash", alias = "CookieHash", alias = "cookie-hash")]
37    CookieHash,
38    #[serde(alias = "rps", alias = "Rps", alias = "rps")]
39    Rps,
40}
41
42impl FromStr for Strategy {
43    type Err = &'static str;
44    fn from_str(s: &str) -> Result<Self, Self::Err> {
45        match s {
46            "round_robin" => Ok(Self::RoundRobin),
47            "ip_hash" => Ok(Self::IpHash),
48            "cookie_hash" => Ok(Self::CookieHash),
49            "rps" => Ok(Self::Rps),
50            _ => Err("invalid strategy"),
51        }
52    }
53}
54
55#[derive(Debug, Clone, Copy)]
56enum LBIdent {
57    Ip(IpAddr),
58    Uuid(Uuid),
59}
60
61#[derive(Copy, Clone)]
62enum DynLoadBalancer {
63    IpHash(&'static ip_hash::IpHash),
64    RoundRobin(&'static round_robin::RoundRobin),
65    CookieHash(&'static cookie_hash::CookieHash),
66    Rps(&'static rps_autoscale::RpsAutoscale),
67}
68
69impl LoadBalancingStrategy for DynLoadBalancer {
70    type Input = LBIdent;
71    async fn entry(&self, ip: LBIdent) -> Client {
72        match ip {
73            LBIdent::Ip(ip) => match self {
74                DynLoadBalancer::RoundRobin(rr) => rr.entry(ip).await,
75                DynLoadBalancer::IpHash(ih) => ih.entry(ip).await,
76                DynLoadBalancer::Rps(rr) => rr.entry(ip).await,
77                _ => unreachable!(
78                    "This should never happen, ip should never be passed to cookie hash"
79                ),
80            },
81            LBIdent::Uuid(uuid) => match self {
82                DynLoadBalancer::CookieHash(ch) => ch.entry(uuid).await,
83                _ => unreachable!(
84                    "This should never happen, uuid should never be passed to round robin or ip hash"
85                ),
86            },
87        }
88    }
89}
90
91pub(crate) struct LoadBalancer {
92    strategy: DynLoadBalancer,
93    extractor: IpExtractor,
94}
95
96impl LoadBalancer {
97    pub async fn new(
98        strategy: Strategy,
99        extractor: IpExtractor,
100        workers: &[&'static WorkerConfig],
101        max_rps_config: Option<f64>, // New parameter
102    ) -> FaucetResult<Self> {
103        let strategy: DynLoadBalancer = match strategy {
104            Strategy::RoundRobin => {
105                DynLoadBalancer::RoundRobin(leak!(RoundRobin::new(workers).await))
106            }
107            Strategy::IpHash => DynLoadBalancer::IpHash(leak!(IpHash::new(workers).await)),
108            Strategy::CookieHash => {
109                DynLoadBalancer::CookieHash(leak!(CookieHash::new(workers).await))
110            }
111            Strategy::Rps => {
112                let rps_value = max_rps_config.unwrap_or(DEFAULT_MAX_RPS);
113                DynLoadBalancer::Rps(leak!(RpsAutoscale::new(workers, rps_value).await))
114            }
115        };
116        Ok(Self {
117            strategy,
118            extractor,
119        })
120    }
121    pub fn get_strategy(&self) -> Strategy {
122        match self.strategy {
123            DynLoadBalancer::RoundRobin(_) => Strategy::RoundRobin,
124            DynLoadBalancer::IpHash(_) => Strategy::IpHash,
125            DynLoadBalancer::CookieHash(_) => Strategy::CookieHash,
126            DynLoadBalancer::Rps(_) => Strategy::Rps,
127        }
128    }
129    async fn get_client_ip(&self, ip: IpAddr) -> FaucetResult<Client> {
130        Ok(self.strategy.entry(LBIdent::Ip(ip)).await)
131    }
132    async fn get_client_uuid(&self, uuid: Uuid) -> FaucetResult<Client> {
133        Ok(self.strategy.entry(LBIdent::Uuid(uuid)).await)
134    }
135    pub async fn get_client(&self, ip: IpAddr, uuid: Option<Uuid>) -> FaucetResult<Client> {
136        if let Some(uuid) = uuid {
137            self.get_client_uuid(uuid).await
138        } else {
139            self.get_client_ip(ip).await
140        }
141    }
142    pub fn extract_ip<B>(
143        &self,
144        request: &Request<B>,
145        socket: Option<IpAddr>,
146    ) -> FaucetResult<IpAddr> {
147        self.extractor.extract(request, socket)
148    }
149}
150
151impl Clone for LoadBalancer {
152    fn clone(&self) -> Self {
153        Self {
154            strategy: self.strategy,
155            extractor: self.extractor,
156        }
157    }
158}
159
#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_strategy_from_str() {
        assert_eq!(
            Strategy::from_str("round_robin").unwrap(),
            Strategy::RoundRobin
        );
        assert_eq!(Strategy::from_str("ip_hash").unwrap(), Strategy::IpHash);
        assert_eq!(
            Strategy::from_str("cookie_hash").unwrap(),
            Strategy::CookieHash
        );
        assert_eq!(Strategy::from_str("rps").unwrap(), Strategy::Rps);
        assert!(Strategy::from_str("invalid").is_err());
        assert!(Strategy::from_str("").is_err());
    }

    #[tokio::test]
    async fn test_load_balancer_new_round_robin() {
        let configs = Vec::new();
        let load_balancer = LoadBalancer::new(
            Strategy::RoundRobin,
            IpExtractor::XForwardedFor,
            &configs,
            None,
        )
        .await
        .expect("failed to create load balancer");
        assert_eq!(load_balancer.get_strategy(), Strategy::RoundRobin);
    }

    #[tokio::test]
    async fn test_load_balancer_new_ip_hash() {
        let configs = Vec::new();
        let load_balancer =
            LoadBalancer::new(Strategy::IpHash, IpExtractor::XForwardedFor, &configs, None)
                .await
                .expect("failed to create load balancer");
        assert_eq!(load_balancer.get_strategy(), Strategy::IpHash);
    }

    #[tokio::test]
    async fn test_load_balancer_extract_ip() {
        let configs = Vec::new();
        let load_balancer = LoadBalancer::new(
            Strategy::RoundRobin,
            IpExtractor::XForwardedFor,
            &configs,
            None,
        )
        .await
        .expect("failed to create load balancer");
        let request = Request::builder()
            .header("x-forwarded-for", "192.168.0.1")
            .body(())
            .unwrap();
        let ip = load_balancer
            .extract_ip(&request, Some("127.0.0.1".parse().unwrap()))
            .expect("failed to extract ip");

        assert_eq!(ip, "192.168.0.1".parse::<IpAddr>().unwrap());
    }

    #[tokio::test]
    async fn test_load_balancer_get_client() {
        use crate::client::ExtractSocketAddr;
        let configs: [&'static WorkerConfig; 2] = [
            &*Box::leak(Box::new(WorkerConfig::dummy(
                "test",
                "127.0.0.1:9999",
                true,
            ))),
            &*Box::leak(Box::new(WorkerConfig::dummy(
                "test",
                "127.0.0.1:9998",
                true,
            ))),
        ];
        let load_balancer = LoadBalancer::new(
            Strategy::RoundRobin,
            IpExtractor::XForwardedFor,
            &configs,
            None,
        )
        .await
        .expect("failed to create load balancer");
        let ip = "192.168.0.1".parse().unwrap();

        // Go through the public entry point so the `uuid == None` dispatch is
        // covered, then through the IP path directly.
        let client = load_balancer
            .get_client(ip, None)
            .await
            .expect("failed to get client");
        assert_eq!(client.socket_addr(), "127.0.0.1:9999".parse().unwrap());

        let client = load_balancer
            .get_client_ip(ip)
            .await
            .expect("failed to get client");
        assert_eq!(client.socket_addr(), "127.0.0.1:9998".parse().unwrap());

        for config in configs.iter() {
            config.wait_until_done().await;
        }
    }

    #[tokio::test]
    async fn test_clone_load_balancer() {
        let configs = Vec::new();
        let load_balancer = LoadBalancer::new(
            Strategy::RoundRobin,
            IpExtractor::XForwardedFor,
            &configs,
            None,
        )
        .await
        .expect("failed to create load balancer");
        let cloned = load_balancer.clone();
        assert_eq!(cloned.get_strategy(), load_balancer.get_strategy());
    }
}