faucet_server/client/load_balancing/mod.rs

1pub mod cookie_hash;
2mod ip_extractor;
3pub mod ip_hash;
4pub mod round_robin;
5use super::worker::WorkerConfig;
6use crate::client::Client;
7use crate::error::FaucetResult;
8use crate::leak;
9use cookie_hash::CookieHash;
10use hyper::Request;
11pub use ip_extractor::IpExtractor;
12use std::net::IpAddr;
13use std::str::FromStr;
14use uuid::Uuid;
15
16use self::ip_hash::IpHash;
17use self::round_robin::RoundRobin;
18
19trait LoadBalancingStrategy {
20    type Input;
21    async fn entry(&self, ip: Self::Input) -> Client;
22}
23
24#[derive(Debug, Clone, Copy, clap::ValueEnum, Eq, PartialEq, serde::Deserialize)]
25#[serde(rename = "snake_case")]
26pub enum Strategy {
27    #[serde(alias = "round_robin", alias = "RoundRobin", alias = "round-robin")]
28    RoundRobin,
29    #[serde(alias = "ip_hash", alias = "IpHash", alias = "ip-hash")]
30    IpHash,
31    #[serde(alias = "cookie_hash", alias = "CookieHash", alias = "cookie-hash")]
32    CookieHash,
33}
34
35impl FromStr for Strategy {
36    type Err = &'static str;
37    fn from_str(s: &str) -> Result<Self, Self::Err> {
38        match s {
39            "round_robin" => Ok(Self::RoundRobin),
40            "ip_hash" => Ok(Self::IpHash),
41            "cookie_hash" => Ok(Self::CookieHash),
42            _ => Err("invalid strategy"),
43        }
44    }
45}
46
47#[derive(Debug, Clone, Copy)]
48enum LBIdent {
49    Ip(IpAddr),
50    Uuid(Uuid),
51}
52
53#[derive(Copy, Clone)]
54enum DynLoadBalancer {
55    IpHash(&'static ip_hash::IpHash),
56    RoundRobin(&'static round_robin::RoundRobin),
57    CookieHash(&'static cookie_hash::CookieHash),
58}
59
60impl LoadBalancingStrategy for DynLoadBalancer {
61    type Input = LBIdent;
62    async fn entry(&self, ip: LBIdent) -> Client {
63        match ip {
64            LBIdent::Ip(ip) => match self {
65                DynLoadBalancer::RoundRobin(rr) => rr.entry(ip).await,
66                DynLoadBalancer::IpHash(ih) => ih.entry(ip).await,
67                _ => unreachable!(
68                    "This should never happen, ip should never be passed to cookie hash"
69                ),
70            },
71            LBIdent::Uuid(uuid) => match self {
72                DynLoadBalancer::CookieHash(ch) => ch.entry(uuid).await,
73                _ => unreachable!(
74                    "This should never happen, uuid should never be passed to round robin or ip hash"
75                ),
76            },
77        }
78    }
79}
80
81pub(crate) struct LoadBalancer {
82    strategy: DynLoadBalancer,
83    extractor: IpExtractor,
84}
85
86impl LoadBalancer {
87    pub fn new(
88        strategy: Strategy,
89        extractor: IpExtractor,
90        workers: &[WorkerConfig],
91    ) -> FaucetResult<Self> {
92        let strategy: DynLoadBalancer = match strategy {
93            Strategy::RoundRobin => DynLoadBalancer::RoundRobin(leak!(RoundRobin::new(workers)?)),
94            Strategy::IpHash => DynLoadBalancer::IpHash(leak!(IpHash::new(workers)?)),
95            Strategy::CookieHash => DynLoadBalancer::CookieHash(leak!(CookieHash::new(workers)?)),
96        };
97        Ok(Self {
98            strategy,
99            extractor,
100        })
101    }
102    pub fn get_strategy(&self) -> Strategy {
103        match self.strategy {
104            DynLoadBalancer::RoundRobin(_) => Strategy::RoundRobin,
105            DynLoadBalancer::IpHash(_) => Strategy::IpHash,
106            DynLoadBalancer::CookieHash(_) => Strategy::CookieHash,
107        }
108    }
109    async fn get_client_ip(&self, ip: IpAddr) -> FaucetResult<Client> {
110        Ok(self.strategy.entry(LBIdent::Ip(ip)).await)
111    }
112    async fn get_client_uuid(&self, uuid: Uuid) -> FaucetResult<Client> {
113        Ok(self.strategy.entry(LBIdent::Uuid(uuid)).await)
114    }
115    pub async fn get_client(&self, ip: IpAddr, uuid: Option<Uuid>) -> FaucetResult<Client> {
116        if let Some(uuid) = uuid {
117            self.get_client_uuid(uuid).await
118        } else {
119            self.get_client_ip(ip).await
120        }
121    }
122    pub fn extract_ip<B>(
123        &self,
124        request: &Request<B>,
125        socket: Option<IpAddr>,
126    ) -> FaucetResult<IpAddr> {
127        self.extractor.extract(request, socket)
128    }
129}
130
131impl Clone for LoadBalancer {
132    fn clone(&self) -> Self {
133        Self {
134            strategy: self.strategy,
135            extractor: self.extractor,
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    //! Unit tests for strategy parsing and `LoadBalancer` construction and routing.

    use super::*;

    #[test]
    fn test_strategy_from_str() {
        assert_eq!(
            Strategy::from_str("round_robin").unwrap(),
            Strategy::RoundRobin
        );
        assert_eq!(Strategy::from_str("ip_hash").unwrap(), Strategy::IpHash);
        // Previously untested variant.
        assert_eq!(
            Strategy::from_str("cookie_hash").unwrap(),
            Strategy::CookieHash
        );
        assert!(Strategy::from_str("invalid").is_err());
        assert!(Strategy::from_str("").is_err());
    }

    #[test]
    fn test_load_balancer_new_round_robin() {
        let configs = Vec::new();
        let load_balancer =
            LoadBalancer::new(Strategy::RoundRobin, IpExtractor::XForwardedFor, &configs)
                .expect("failed to create load balancer");
        // The constructed balancer must report the strategy it was built with.
        assert_eq!(load_balancer.get_strategy(), Strategy::RoundRobin);
    }

    #[test]
    fn test_load_balancer_new_ip_hash() {
        let configs = Vec::new();
        let load_balancer =
            LoadBalancer::new(Strategy::IpHash, IpExtractor::XForwardedFor, &configs)
                .expect("failed to create load balancer");
        assert_eq!(load_balancer.get_strategy(), Strategy::IpHash);
    }

    #[test]
    fn test_load_balancer_extract_ip() {
        let configs = Vec::new();
        let load_balancer =
            LoadBalancer::new(Strategy::RoundRobin, IpExtractor::XForwardedFor, &configs)
                .expect("failed to create load balancer");
        let request = Request::builder()
            .header("x-forwarded-for", "192.168.0.1")
            .body(())
            .unwrap();
        let ip = load_balancer
            .extract_ip(&request, Some("127.0.0.1".parse().unwrap()))
            .expect("failed to extract ip");

        assert_eq!(ip, "192.168.0.1".parse::<IpAddr>().unwrap());
    }

    #[tokio::test]
    async fn test_load_balancer_get_client() {
        use crate::client::ExtractSocketAddr;
        let configs = [
            WorkerConfig::dummy("test", "127.0.0.1:9999", true),
            WorkerConfig::dummy("test", "127.0.0.1:9998", true),
        ];
        let load_balancer =
            LoadBalancer::new(Strategy::RoundRobin, IpExtractor::XForwardedFor, &configs)
                .expect("failed to create load balancer");
        let ip = "192.168.0.1".parse().unwrap();
        let client = load_balancer
            .get_client_ip(ip)
            .await
            .expect("failed to get client");
        assert_eq!(client.socket_addr(), "127.0.0.1:9999".parse().unwrap());

        let client = load_balancer
            .get_client_ip(ip)
            .await
            .expect("failed to get client");

        assert_eq!(client.socket_addr(), "127.0.0.1:9998".parse().unwrap());
    }

    #[test]
    fn test_clone_load_balancer() {
        let configs = Vec::new();
        let load_balancer =
            LoadBalancer::new(Strategy::RoundRobin, IpExtractor::XForwardedFor, &configs)
                .expect("failed to create load balancer");
        let cloned = load_balancer.clone();
        // A clone must keep the original's strategy.
        assert_eq!(cloned.get_strategy(), Strategy::RoundRobin);
    }
}