// load_balancer/ip.rs

1use crate::{BoxLoadBalancer, LoadBalancer, interval::IntervalLoadBalancer};
2use async_trait::async_trait;
3use futures::future::join_all;
4use get_if_addrs::get_if_addrs;
5use reqwest::{Client, ClientBuilder, Proxy};
6use std::{net::IpAddr, sync::Arc, time::Duration};
7
8/// Load balancer for `reqwest::Client` instances bound to specific IP addresses.
9/// Uses interval-based allocation.
10#[derive(Clone)]
11pub struct IPClient {
12    inner: IntervalLoadBalancer<Client>,
13}
14
15impl IPClient {
16    /// Create a new interval-based load balancer with given clients.
17    pub fn new(entries: Vec<(Duration, Client)>) -> Self {
18        Self {
19            inner: IntervalLoadBalancer::new(entries),
20        }
21    }
22
23    /// Build a load balancer using all local IP addresses.
24    pub fn with_ip(ip: Vec<IpAddr>, interval: Duration) -> Self {
25        Self {
26            inner: IntervalLoadBalancer::new(
27                ip.into_iter()
28                    .map(|v| {
29                        (
30                            interval,
31                            ClientBuilder::new().local_address(v).build().unwrap(),
32                        )
33                    })
34                    .collect(),
35            ),
36        }
37    }
38
39    /// Build a load balancer using IP addresses with a per-client timeout.
40    pub fn with_timeout(ip: Vec<IpAddr>, interval: Duration, timeout: Duration) -> Self {
41        Self {
42            inner: IntervalLoadBalancer::new(
43                ip.into_iter()
44                    .map(|v| {
45                        (
46                            interval,
47                            ClientBuilder::new()
48                                .local_address(v)
49                                .timeout(timeout)
50                                .build()
51                                .unwrap(),
52                        )
53                    })
54                    .collect(),
55            ),
56        }
57    }
58
59    /// Build a load balancer using IP addresses with a per-client timeout, use proxy.
60    pub fn with_timeout_proxy(
61        ip: Vec<IpAddr>,
62        interval: Duration,
63        timeout: Duration,
64        proxy: Proxy,
65    ) -> Self {
66        Self {
67            inner: IntervalLoadBalancer::new(
68                ip.into_iter()
69                    .map(|v| {
70                        (
71                            interval,
72                            ClientBuilder::new()
73                                .local_address(v)
74                                .timeout(timeout)
75                                .proxy(proxy.clone())
76                                .build()
77                                .unwrap(),
78                        )
79                    })
80                    .collect(),
81            ),
82        }
83    }
84
85    /// Update the internal load balancer.
86    pub async fn update<F, R>(&self, handle: F) -> anyhow::Result<()>
87    where
88        F: Fn(Arc<std::sync::RwLock<Vec<crate::interval::Entry<Client>>>>) -> R,
89        R: std::future::Future<Output = anyhow::Result<()>>,
90    {
91        self.inner.update(handle).await
92    }
93}
94
95impl LoadBalancer<Client> for IPClient {
96    fn alloc(&self) -> impl std::future::Future<Output = Client> + Send {
97        LoadBalancer::alloc(&self.inner)
98    }
99
100    fn try_alloc(&self) -> Option<Client> {
101        LoadBalancer::try_alloc(&self.inner)
102    }
103}
104
105#[async_trait]
106impl BoxLoadBalancer<Client> for IPClient {
107    async fn alloc(&self) -> Client {
108        LoadBalancer::alloc(self).await
109    }
110
111    fn try_alloc(&self) -> Option<Client> {
112        LoadBalancer::try_alloc(self)
113    }
114}
115
116/// Get all non-loopback IP addresses of the machine.
117pub fn get_ip_list() -> anyhow::Result<Vec<IpAddr>> {
118    Ok(get_if_addrs()?
119        .into_iter()
120        .filter(|v| !v.is_loopback())
121        .map(|v| v.ip())
122        .collect::<Vec<_>>())
123}
124
125/// Get all non-loopback IPv4 addresses of the machine.
126pub fn get_ipv4_list() -> anyhow::Result<Vec<IpAddr>> {
127    Ok(get_if_addrs()?
128        .into_iter()
129        .filter(|v| !v.is_loopback() && v.ip().is_ipv4())
130        .map(|v| v.ip())
131        .collect::<Vec<_>>())
132}
133
134/// Get all non-loopback IPv6 addresses of the machine.
135pub fn get_ipv6_list() -> anyhow::Result<Vec<IpAddr>> {
136    Ok(get_if_addrs()?
137        .into_iter()
138        .filter(|v| !v.is_loopback() && v.ip().is_ipv6())
139        .map(|v| v.ip())
140        .collect::<Vec<_>>())
141}
142
143pub async fn test_ip(ip: IpAddr) -> anyhow::Result<IpAddr> {
144    reqwest::ClientBuilder::new()
145        .timeout(Duration::from_secs(3))
146        .local_address(ip)
147        .build()?
148        .get("https://bilibili.com")
149        .send()
150        .await?;
151
152    Ok(ip)
153}
154
155pub async fn test_all_ip() -> Vec<anyhow::Result<IpAddr>> {
156    match get_ip_list() {
157        Ok(v) => join_all(v.into_iter().map(|v| test_ip(v))).await,
158        Err(_) => Vec::new(),
159    }
160}
161
162pub async fn test_all_ipv4() -> Vec<anyhow::Result<IpAddr>> {
163    match get_ipv4_list() {
164        Ok(v) => join_all(v.into_iter().map(|v| test_ip(v))).await,
165        Err(_) => Vec::new(),
166    }
167}
168
169pub async fn test_all_ipv6() -> Vec<anyhow::Result<IpAddr>> {
170    match get_ipv6_list() {
171        Ok(v) => join_all(v.into_iter().map(|v| test_ip(v))).await,
172        Err(_) => Vec::new(),
173    }
174}