load_balancer/
ip.rs

1use crate::{
2    BoxLoadBalancer, LoadBalancer,
3    limit::{LimitLoadBalancer, LimitLoadBalancerRef},
4};
5use async_trait::async_trait;
6use get_if_addrs::get_if_addrs;
7use reqwest::{Client, ClientBuilder};
8use std::{net::IpAddr, sync::Arc, time::Duration};
9
10/// Load balancer for `reqwest::Client` instances bound to specific IP addresses.
11/// Supports limits per client and optional interval-based resets.
12#[derive(Clone)]
13pub struct IPClientLoadBalancer {
14    inner: LimitLoadBalancer<Client>,
15}
16
17impl IPClientLoadBalancer {
18    /// Create a new load balancer with fixed clients and limits.
19    pub fn new(entries: Vec<(u64, Client)>) -> Self {
20        Self {
21            inner: LimitLoadBalancer::new(entries),
22        }
23    }
24
25    /// Create a new load balancer with a custom interval for resetting limits.
26    pub fn new_interval(entries: Vec<(u64, Client)>, interval: Duration) -> Self {
27        Self {
28            inner: LimitLoadBalancer::new_interval(entries, interval),
29        }
30    }
31
32    /// Build a load balancer using all local IPv4 addresses.
33    pub fn with_ipv4(limit: u64) -> Self {
34        Self {
35            inner: LimitLoadBalancer::new(
36                get_ipv4_list()
37                    .clone()
38                    .into_iter()
39                    .map(|v| ClientBuilder::new().local_address(v).build().unwrap())
40                    .map(move |v| (limit, v))
41                    .collect(),
42            ),
43        }
44    }
45
46    /// Build a load balancer using IPv4 addresses with a custom interval.
47    pub fn with_ipv4_interval(limit: u64, interval: Duration) -> Self {
48        Self {
49            inner: LimitLoadBalancer::new_interval(
50                get_ipv4_list()
51                    .clone()
52                    .into_iter()
53                    .map(|v| ClientBuilder::new().local_address(v).build().unwrap())
54                    .map(move |v| (limit, v))
55                    .collect(),
56                interval,
57            ),
58        }
59    }
60
61    /// Build a load balancer using all local IPv6 addresses.
62    pub fn with_ipv6(limit: u64) -> Self {
63        Self {
64            inner: LimitLoadBalancer::new(
65                get_ipv6_list()
66                    .clone()
67                    .into_iter()
68                    .map(|v| ClientBuilder::new().local_address(v).build().unwrap())
69                    .map(move |v| (limit, v))
70                    .collect(),
71            ),
72        }
73    }
74
75    /// Build a load balancer using IPv6 addresses with a custom interval.
76    pub fn with_ipv6_interval(limit: u64, interval: Duration) -> Self {
77        Self {
78            inner: LimitLoadBalancer::new_interval(
79                get_ipv6_list()
80                    .clone()
81                    .into_iter()
82                    .map(|v| ClientBuilder::new().local_address(v).build().unwrap())
83                    .map(move |v| (limit, v))
84                    .collect(),
85                interval,
86            ),
87        }
88    }
89
90    /// Build a load balancer using IPv4 addresses with a per-client timeout.
91    pub fn with_ipv4_timeout(limit: u64, timeout: Duration) -> Self {
92        Self {
93            inner: LimitLoadBalancer::new(
94                get_ipv4_list()
95                    .clone()
96                    .into_iter()
97                    .map(|v| {
98                        ClientBuilder::new()
99                            .local_address(v)
100                            .timeout(timeout)
101                            .build()
102                            .unwrap()
103                    })
104                    .map(move |v| (limit, v))
105                    .collect(),
106            ),
107        }
108    }
109
110    /// Build a load balancer using IPv4 addresses with interval and timeout.
111    pub fn with_ipv4_interval_timeout(limit: u64, interval: Duration, timeout: Duration) -> Self {
112        Self {
113            inner: LimitLoadBalancer::new_interval(
114                get_ipv4_list()
115                    .clone()
116                    .into_iter()
117                    .map(|v| {
118                        ClientBuilder::new()
119                            .local_address(v)
120                            .timeout(timeout)
121                            .build()
122                            .unwrap()
123                    })
124                    .map(move |v| (limit, v))
125                    .collect(),
126                interval,
127            ),
128        }
129    }
130
131    /// Build a load balancer using IPv6 addresses with a per-client timeout.
132    pub fn with_ipv6_timeout(limit: u64, timeout: Duration) -> Self {
133        Self {
134            inner: LimitLoadBalancer::new(
135                get_ipv6_list()
136                    .clone()
137                    .into_iter()
138                    .map(|v| {
139                        ClientBuilder::new()
140                            .local_address(v)
141                            .timeout(timeout)
142                            .build()
143                            .unwrap()
144                    })
145                    .map(move |v| (limit, v))
146                    .collect(),
147            ),
148        }
149    }
150
151    /// Build a load balancer using IPv6 addresses with interval and timeout.
152    pub fn with_ipv6_interval_timeout(limit: u64, interval: Duration, timeout: Duration) -> Self {
153        Self {
154            inner: LimitLoadBalancer::new_interval(
155                get_ipv6_list()
156                    .clone()
157                    .into_iter()
158                    .map(|v| {
159                        ClientBuilder::new()
160                            .local_address(v)
161                            .timeout(timeout)
162                            .build()
163                            .unwrap()
164                    })
165                    .map(move |v| (limit, v))
166                    .collect(),
167                interval,
168            ),
169        }
170    }
171
172    /// Update the internal load balancer using a custom async closure.
173    pub async fn update<F, R>(&self, handle: F) -> anyhow::Result<()>
174    where
175        F: Fn(Arc<LimitLoadBalancerRef<Client>>) -> R,
176        R: Future<Output = anyhow::Result<()>>,
177    {
178        self.inner.update(handle).await
179    }
180}
181
182impl LoadBalancer<Client> for IPClientLoadBalancer {
183    /// Allocate a client asynchronously.
184    fn alloc(&self) -> impl Future<Output = Option<Client>> + Send {
185        LoadBalancer::alloc(&self.inner)
186    }
187
188    /// Attempt to allocate a client immediately.
189    fn try_alloc(&self) -> Option<Client> {
190        LoadBalancer::try_alloc(&self.inner)
191    }
192}
193
194#[async_trait]
195impl BoxLoadBalancer<Client> for IPClientLoadBalancer {
196    /// Allocate a client asynchronously.
197    async fn alloc(&self) -> Option<Client> {
198        LoadBalancer::alloc(self).await
199    }
200
201    /// Attempt to allocate a client immediately.
202    fn try_alloc(&self) -> Option<Client> {
203        LoadBalancer::try_alloc(self)
204    }
205}
206
207/// Get all non-loopback IP addresses of the machine.
208pub fn get_ip_list() -> Vec<IpAddr> {
209    get_if_addrs()
210        .unwrap()
211        .into_iter()
212        .filter(|v| !v.is_loopback())
213        .map(|v| v.ip())
214        .collect::<Vec<_>>()
215}
216
217/// Get all non-loopback IPv4 addresses of the machine.
218pub fn get_ipv4_list() -> Vec<IpAddr> {
219    get_if_addrs()
220        .unwrap()
221        .into_iter()
222        .filter(|v| !v.is_loopback() && v.ip().is_ipv4())
223        .map(|v| v.ip())
224        .collect::<Vec<_>>()
225}
226
227/// Get all non-loopback IPv6 addresses of the machine.
228pub fn get_ipv6_list() -> Vec<IpAddr> {
229    get_if_addrs()
230        .unwrap()
231        .into_iter()
232        .filter(|v| !v.is_loopback() && v.ip().is_ipv6())
233        .map(|v| v.ip())
234        .collect::<Vec<_>>()
235}