// load_balancer/proxy.rs

use crate::{
    BoxLoadBalancer, LoadBalancer,
    simple::{Entry, SimpleLoadBalancer, SimpleLoadBalancerRef},
};
use async_trait::async_trait;
use reqwest::Proxy;
use std::{
    collections::HashSet,
    ops::Range,
    sync::{Arc, atomic::Ordering},
    time::Duration,
};
use tokio::{
    spawn,
    sync::Semaphore,
    task::JoinHandle,
    time::{Instant, sleep},
};
19/// An advanced proxy pool that measures latency, removes dead proxies,
20/// and sorts proxies by response time in ascending order.
21#[derive(Clone)]
22pub struct ProxyPool {
23    code_range: Range<u16>,
24    test_url: String,
25    timeout: Duration,
26    proxy: Option<Proxy>,
27    max_check_concurrency: usize,
28    lb: SimpleLoadBalancer<Arc<str>>,
29}
31impl ProxyPool {
32    /// Create a new `LatencyProxyPool` from a list of proxy URLs.
33    pub fn new<T: IntoIterator<Item = impl AsRef<str>>>(url: T) -> Self {
34        Self {
35            code_range: (200..300),
36            test_url: "https://apple.com".to_string(),
37            timeout: Duration::from_secs(5),
38            proxy: None,
39            max_check_concurrency: 1000,
40            lb: SimpleLoadBalancer::new(url.into_iter().map(|v| v.as_ref().into()).collect()),
41        }
42    }
43
44    /// Set the range of HTTP status codes that are considered successful.
45    pub fn code_range(mut self, code_range: Range<u16>) -> Self {
46        self.code_range = code_range;
47        self
48    }
49
50    /// Set the URL used for testing proxy connectivity.
51    pub fn test_url(mut self, test_url: String) -> Self {
52        self.test_url = test_url;
53        self
54    }
55
56    /// Set the request timeout for proxy testing.
57    pub fn timeout(mut self, timeout: Duration) -> Self {
58        self.timeout = timeout;
59        self
60    }
61
62    /// Set an optional upstream proxy for proxy validation.
63    pub fn proxy(mut self, proxy: Proxy) -> Self {
64        self.proxy = Some(proxy);
65        self
66    }
67
68    /// Set the maximum number of concurrent proxy checks during health validation.
69    pub fn max_check_concurrency(mut self, max_check_concurrency: usize) -> Self {
70        self.max_check_concurrency = max_check_concurrency;
71        self
72    }
73
74    /// Get the number of currently available (healthy) proxies.
75    pub async fn available_count(&self) -> usize {
76        self.lb
77            .update(async |v| Ok(v.entries.read().await.len()))
78            .await
79            .unwrap()
80    }
81
82    /// Get available proxies.
83    pub async fn available(&self) -> Vec<String> {
84        self.lb
85            .update(async |v| {
86                Ok(v.entries
87                    .read()
88                    .await
89                    .iter()
90                    .map(|v| v.value.to_string())
91                    .collect::<Vec<_>>())
92            })
93            .await
94            .unwrap()
95    }
96
97    /// Add new proxies to the pool without performing immediate validation.
98    ///
99    /// New entries are appended, the cursor is reset, and the available count is updated.
100    /// Validation occurs on the next `check()` call.
101    pub async fn extend<T: IntoIterator<Item = impl AsRef<str>>>(&self, urls: T) {
102        let new_entries = urls
103            .into_iter()
104            .map(|v| Entry {
105                value: Arc::from(v.as_ref()),
106            })
107            .collect::<Vec<_>>();
108
109        self.lb
110            .update(async |v| {
111                let mut lock = v.entries.write().await;
112
113                lock.extend(new_entries.clone());
114                v.cursor.store(0, Ordering::Relaxed);
115
116                Ok(())
117            })
118            .await
119            .unwrap();
120    }
121
122    /// Add new proxies and immediately perform connectivity and latency checks.
123    ///
124    /// Proxies are validated, failed ones are removed, and remaining entries
125    /// are sorted by latency (ascending).
126    pub async fn extend_check<T: IntoIterator<Item = impl AsRef<str>>>(
127        &self,
128        url: T,
129        retry_count: usize,
130    ) -> anyhow::Result<()> {
131        let new_entries = url
132            .into_iter()
133            .map(|v| Entry {
134                value: Arc::from(v.as_ref()),
135            })
136            .collect::<Vec<Entry<Arc<str>>>>();
137
138        self.lb
139            .update(async |v| {
140                let old_entries = {
141                    let lock = v.entries.read().await;
142                    let mut result = Vec::with_capacity(lock.len() + new_entries.len());
143
144                    result.extend_from_slice(&new_entries);
145                    result.extend(lock.iter().cloned());
146
147                    result
148                };
149
150                let result = self.internal_check(&old_entries, retry_count).await?;
151
152                let mut new_entries = Vec::with_capacity(result.len());
153
154                for (index, _) in result {
155                    new_entries.push(old_entries[index].clone());
156                }
157
158                let mut lock = v.entries.write().await;
159
160                *lock = new_entries;
161                v.cursor.store(0, Ordering::Relaxed);
162
163                Ok(())
164            })
165            .await
166    }
167
168    /// Validate all proxies, remove dead ones, and sort by latency.
169    pub async fn check(&self, retry_count: usize) -> anyhow::Result<()> {
170        self.lb
171            .update(async |v| {
172                let old_entries = v.entries.read().await;
173
174                let result = self.internal_check(&old_entries, retry_count).await?;
175
176                let mut new_entries = Vec::with_capacity(result.len());
177
178                for (index, _) in result {
179                    new_entries.push(old_entries[index].clone());
180                }
181
182                drop(old_entries);
183
184                let mut lock = v.entries.write().await;
185
186                *lock = new_entries;
187                v.cursor.store(0, Ordering::Relaxed);
188
189                Ok(())
190            })
191            .await
192    }
193
194    /// Spawn a background task to periodically validate proxies and update order by latency.
195    ///
196    /// Returns a `JoinHandle` to allow cancellation or awaiting of the task.
197    pub async fn spawn_check(
198        &self,
199        check_interval: Duration,
200        retry_count: usize,
201    ) -> anyhow::Result<JoinHandle<()>> {
202        self.check(retry_count).await?;
203
204        let this = self.clone();
205
206        Ok(spawn(async move {
207            loop {
208                sleep(check_interval).await;
209                _ = this.check(retry_count).await;
210            }
211        }))
212    }
213
214    /// Spawn a background task with a callback after each check.
215    pub async fn spawn_check_callback<F, R>(
216        &self,
217        check_interval: Duration,
218        retry_count: usize,
219        callback: F,
220    ) -> anyhow::Result<JoinHandle<anyhow::Result<()>>>
221    where
222        F: Fn() -> R + Send + 'static,
223        R: Future<Output = anyhow::Result<()>> + Send,
224    {
225        self.check(retry_count).await?;
226        callback().await?;
227
228        let this = self.clone();
229
230        Ok(spawn(async move {
231            loop {
232                sleep(check_interval).await;
233                _ = this.check(retry_count).await;
234                callback().await?;
235            }
236        }))
237    }
238
239    /// Update the load balancer using a custom async handler.
240    pub async fn update<F, R>(&self, handle: F) -> anyhow::Result<()>
241    where
242        F: Fn(Arc<SimpleLoadBalancerRef<Arc<str>>>) -> R,
243        R: Future<Output = anyhow::Result<()>>,
244    {
245        self.lb.update(handle).await
246    }
247
248    async fn internal_check(
249        &self,
250        entries: &Vec<Entry<Arc<str>>>,
251        retry_count: usize,
252    ) -> anyhow::Result<Vec<(usize, u128)>> {
253        let semaphore = Arc::new(Semaphore::new(self.max_check_concurrency));
254        let mut task = Vec::with_capacity(entries.len());
255
256        for (index, entry) in entries.iter().enumerate() {
257            let permit = semaphore.clone().acquire_owned().await.unwrap();
258            let entry = entry.clone();
259            let code_range = self.code_range.clone();
260            let test_url = self.test_url.clone();
261            let timeout = self.timeout;
262            let upstream_proxy = self.proxy.clone();
263            let entry_value = entry.value.clone();
264
265            task.push(tokio::spawn(async move {
266                let _permit = permit;
267                let mut latency = None;
268
269                for _ in 0..=retry_count {
270                    let client = if let Some(proxy) = upstream_proxy.clone() {
271                        reqwest::ClientBuilder::new()
272                            .proxy(proxy)
273                            .proxy(Proxy::all(&*entry_value)?)
274                            .timeout(timeout)
275                            .build()?
276                    } else {
277                        reqwest::ClientBuilder::new()
278                            .proxy(Proxy::all(&*entry_value)?)
279                            .timeout(timeout)
280                            .build()?
281                    };
282
283                    let start = Instant::now();
284
285                    if let Ok(v) = client.get(&test_url).send().await {
286                        if code_range.contains(&v.status().as_u16()) {
287                            latency = Some(start.elapsed().as_millis());
288                            break;
289                        }
290                    }
291                }
292
293                anyhow::Ok(latency.map(|v| (index, v)))
294            }));
295        }
296
297        let mut result = Vec::new();
298
299        for i in task {
300            if let Ok(Ok(Some(r))) = i.await {
301                result.push(r);
302            }
303        }
304
305        result.sort_by_key(|(_, latency)| *latency);
306
307        Ok(result)
308    }
309}
311impl LoadBalancer<String> for ProxyPool {
312    async fn alloc(&self) -> String {
313        LoadBalancer::alloc(&self.lb).await.to_string()
314    }
315
316    fn try_alloc(&self) -> Option<String> {
317        LoadBalancer::try_alloc(&self.lb).map(|v| v.to_string())
318    }
319}
321#[async_trait]
322impl BoxLoadBalancer<String> for ProxyPool {
323    async fn alloc(&self) -> String {
324        LoadBalancer::alloc(self).await
325    }
326
327    fn try_alloc(&self) -> Option<String> {
328        LoadBalancer::try_alloc(self)
329    }
330}