1use std::sync::{
7 Arc,
8 atomic::{AtomicUsize, Ordering},
9};
10
11use futures_util::{FutureExt, future::BoxFuture};
12use http::{Request, Response, StatusCode};
13use rand::RngExt;
14use tower::{Layer, Service};
15
16use crate::{
17 Body, Error, Proxy, client::layer::config::RequestOptions, config::RequestConfig,
18 error::BoxError, proxy::Matcher,
19};
20
/// How a [`ProxyPool`] chooses a proxy for each outgoing request.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ProxyPoolStrategy {
    /// Draw a proxy index at random (`random_range(0..len)`) for every request.
    /// Failures are not tracked and never influence later selections.
    RandomPerRequest,
    /// Keep routing through the same proxy until a request through it fails
    /// (a status accepted by `ProxyPool::is_failure_status`, or a transport
    /// error), then advance to the next proxy in insertion order, wrapping
    /// back to the first after the last.
    ///
    /// This is the default strategy.
    #[default]
    StickyFailover,
}
30
/// Incrementally configures a [`ProxyPool`].
///
/// A fresh builder holds no proxies and uses the default
/// [`ProxyPoolStrategy`] (`StickyFailover`).
#[derive(Default)]
pub struct ProxyPoolBuilder {
    // Proxies in the order they were added; with `StickyFailover` this is
    // also the order in which the pool fails over.
    proxies: Vec<Proxy>,
    // Strategy the built pool will use.
    strategy: ProxyPoolStrategy,
}
37
/// A non-empty set of proxies plus the strategy used to pick one per request.
///
/// Cloning is cheap: every clone shares the same proxy list and the same
/// failover cursor through an [`Arc`], so a failure observed through one
/// clone affects selection through all of them.
#[derive(Clone, Debug)]
pub struct ProxyPool {
    inner: Arc<Inner>,
}
43
/// Shared state behind every clone of a [`ProxyPool`].
#[derive(Debug)]
struct Inner {
    // Selection strategy fixed at construction time.
    strategy: ProxyPoolStrategy,
    // One matcher per configured proxy, in insertion order; never empty
    // (construction rejects an empty proxy list).
    matchers: Vec<Matcher>,
    // Index of the proxy currently preferred by `StickyFailover`. Starts at 0
    // and is only ever written with `(selected + 1) % matchers.len()`, so it
    // always stays below `matchers.len()`. Unused by `RandomPerRequest`.
    sticky_index: AtomicUsize,
}
50
/// Tower [`Layer`] that wraps a service in a [`ProxyPoolService`] sharing
/// this layer's pool.
#[derive(Clone)]
pub(crate) struct ProxyPoolLayer {
    pool: ProxyPool,
}
55
/// Service that assigns a pooled proxy to every request before forwarding it
/// to `inner`, and reports the request's outcome back to the pool.
#[derive(Clone)]
pub(crate) struct ProxyPoolService<S> {
    // The wrapped service that actually performs the request.
    inner: S,
    // Pool the proxy is selected from and failures are reported to.
    pool: ProxyPool,
}
61
/// The proxy chosen for one request.
#[derive(Clone)]
struct Selection {
    // Position of the proxy in the pool; reported back on failure so the
    // sticky cursor only advances if it still points at this proxy.
    index: usize,
    // Matcher to install on the request.
    matcher: Matcher,
}
67
68impl ProxyPoolBuilder {
71 #[inline]
73 pub fn new() -> Self {
74 Self::default()
75 }
76
77 #[inline]
79 pub fn strategy(mut self, strategy: ProxyPoolStrategy) -> Self {
80 self.strategy = strategy;
81 self
82 }
83
84 #[inline]
86 pub fn proxy(mut self, proxy: Proxy) -> Self {
87 self.proxies.push(proxy);
88 self
89 }
90
91 #[inline]
93 pub fn proxies<I>(mut self, proxies: I) -> Self
94 where
95 I: IntoIterator<Item = Proxy>,
96 {
97 self.proxies.extend(proxies);
98 self
99 }
100
101 #[inline]
103 pub fn build(self) -> crate::Result<ProxyPool> {
104 ProxyPool::with_strategy(self.proxies, self.strategy)
105 }
106}
107
108impl ProxyPool {
111 #[inline]
113 pub fn new(proxies: Vec<Proxy>) -> crate::Result<Self> {
114 Self::with_strategy(proxies, ProxyPoolStrategy::StickyFailover)
115 }
116
117 pub fn with_strategy(proxies: Vec<Proxy>, strategy: ProxyPoolStrategy) -> crate::Result<Self> {
119 let matchers: Vec<Matcher> = proxies.into_iter().map(Proxy::into_matcher).collect();
120
121 if matchers.is_empty() {
122 return Err(Error::builder("proxy pool cannot be empty"));
123 }
124
125 Ok(Self {
126 inner: Arc::new(Inner {
127 strategy,
128 matchers,
129 sticky_index: AtomicUsize::new(0),
130 }),
131 })
132 }
133
134 #[inline]
136 pub fn builder() -> ProxyPoolBuilder {
137 ProxyPoolBuilder::new()
138 }
139
140 #[inline]
142 pub fn strategy(&self) -> ProxyPoolStrategy {
143 self.inner.strategy
144 }
145
146 #[inline]
148 pub fn len(&self) -> usize {
149 self.inner.matchers.len()
150 }
151
152 #[inline]
154 pub fn is_empty(&self) -> bool {
155 self.inner.matchers.is_empty()
156 }
157
158 #[inline]
160 pub fn is_failure_status(status: StatusCode) -> bool {
161 status == StatusCode::PROXY_AUTHENTICATION_REQUIRED
162 || status == StatusCode::TOO_MANY_REQUESTS
163 || status.is_server_error()
164 }
165
166 #[inline]
167 pub(crate) fn layer(&self) -> ProxyPoolLayer {
168 ProxyPoolLayer { pool: self.clone() }
169 }
170
171 fn select(&self) -> Selection {
172 let len = self.inner.matchers.len();
173 let index = match self.inner.strategy {
174 ProxyPoolStrategy::RandomPerRequest => {
175 let mut rng = rand::rng();
176 rng.random_range(0..len)
177 }
178 ProxyPoolStrategy::StickyFailover => {
179 self.inner.sticky_index.load(Ordering::Relaxed) % len
180 }
181 };
182
183 Selection {
184 index,
185 matcher: self.inner.matchers[index].clone(),
186 }
187 }
188
189 fn record_status(&self, selected_index: usize, status: StatusCode) {
190 if Self::is_failure_status(status) {
191 self.record_failure(selected_index);
192 }
193 }
194
195 fn record_error(&self, selected_index: usize, _error: &BoxError) {
196 self.record_failure(selected_index);
197 }
198
199 fn record_failure(&self, selected_index: usize) {
200 if self.inner.strategy != ProxyPoolStrategy::StickyFailover {
201 return;
202 }
203
204 let len = self.inner.matchers.len();
205 if len <= 1 {
206 return;
207 }
208
209 let next = (selected_index + 1) % len;
210 let _ = self.inner.sticky_index.compare_exchange(
211 selected_index,
212 next,
213 Ordering::AcqRel,
214 Ordering::Relaxed,
215 );
216 }
217}
218
219impl ProxyPoolLayer {
222 #[inline]
223 pub(crate) fn new(pool: ProxyPool) -> Self {
224 Self { pool }
225 }
226}
227
228impl<S> Layer<S> for ProxyPoolLayer {
229 type Service = ProxyPoolService<S>;
230
231 #[inline]
232 fn layer(&self, inner: S) -> Self::Service {
233 ProxyPoolService {
234 inner,
235 pool: self.pool.clone(),
236 }
237 }
238}
239
240impl<S, ResBody> Service<Request<Body>> for ProxyPoolService<S>
243where
244 S: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + 'static,
245 S::Error: Into<BoxError> + Send,
246 S::Future: Send + 'static,
247 ResBody: Send + 'static,
248{
249 type Response = Response<ResBody>;
250 type Error = BoxError;
251 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
252
253 #[inline]
254 fn poll_ready(
255 &mut self,
256 cx: &mut std::task::Context<'_>,
257 ) -> std::task::Poll<Result<(), Self::Error>> {
258 self.inner.poll_ready(cx).map_err(Into::into)
259 }
260
261 fn call(&mut self, mut req: Request<Body>) -> Self::Future {
262 let selected = self.pool.select();
263
264 RequestConfig::<RequestOptions>::get_mut(req.extensions_mut())
265 .get_or_insert_default()
266 .proxy_matcher_mut()
267 .replace(selected.matcher.clone());
268
269 let pool = self.pool.clone();
270 let mut inner = self.inner.clone();
271
272 async move {
273 match inner.call(req).await {
274 Ok(response) => {
275 pool.record_status(selected.index, response.status());
276 Ok(response)
277 }
278 Err(error) => {
279 let boxed_error: BoxError = error.into();
280 pool.record_error(selected.index, &boxed_error);
281 Err(boxed_error)
282 }
283 }
284 }
285 .boxed()
286 }
287}
288
#[cfg(test)]
mod tests {
    use http::StatusCode;

    use super::*;

    /// Builds a three-proxy pool (a, b, c in that order) using `strategy`.
    fn make_pool(strategy: ProxyPoolStrategy) -> ProxyPool {
        ProxyPool::with_strategy(
            vec![
                Proxy::all("http://proxy-a:8080").expect("proxy a should parse"),
                Proxy::all("http://proxy-b:8080").expect("proxy b should parse"),
                Proxy::all("http://proxy-c:8080").expect("proxy c should parse"),
            ],
            strategy,
        )
        .expect("pool should be non-empty")
    }

    #[test]
    fn sticky_failover_switches_only_after_failure() {
        let pool = make_pool(ProxyPoolStrategy::StickyFailover);

        assert_eq!(pool.select().index, 0);

        pool.record_status(0, StatusCode::OK);
        assert_eq!(pool.select().index, 0);

        pool.record_status(0, StatusCode::BAD_GATEWAY);
        assert_eq!(pool.select().index, 1);

        pool.record_status(1, StatusCode::OK);
        assert_eq!(pool.select().index, 1);

        pool.record_status(1, StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(pool.select().index, 2);
    }

    #[test]
    fn sticky_failover_wraps_around_after_last_proxy() {
        let pool = make_pool(ProxyPoolStrategy::StickyFailover);

        pool.record_status(0, StatusCode::BAD_GATEWAY);
        pool.record_status(1, StatusCode::BAD_GATEWAY);
        assert_eq!(pool.select().index, 2);

        pool.record_status(2, StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(pool.select().index, 0);
    }

    #[test]
    fn stale_failure_does_not_skip_a_proxy() {
        let pool = make_pool(ProxyPoolStrategy::StickyFailover);

        // Two in-flight requests both selected proxy 0 and both failed.
        pool.record_status(0, StatusCode::BAD_GATEWAY);
        pool.record_status(0, StatusCode::BAD_GATEWAY);

        // Only the first report advances the cursor; the second is stale.
        assert_eq!(pool.select().index, 1);
    }

    #[test]
    fn transport_error_triggers_sticky_failover() {
        let pool = make_pool(ProxyPoolStrategy::StickyFailover);
        let error: BoxError = "connection refused".into();

        pool.record_error(0, &error);

        assert_eq!(pool.select().index, 1);
    }

    #[test]
    fn single_proxy_pool_never_advances() {
        let pool = ProxyPool::new(vec![
            Proxy::all("http://proxy-a:8080").expect("proxy a should parse"),
        ])
        .expect("pool should be non-empty");

        assert_eq!(pool.strategy(), ProxyPoolStrategy::StickyFailover);

        pool.record_status(0, StatusCode::BAD_GATEWAY);

        assert_eq!(pool.select().index, 0);
        assert_eq!(pool.inner.sticky_index.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn random_strategy_does_not_advance_sticky_cursor_on_failure() {
        let pool = make_pool(ProxyPoolStrategy::RandomPerRequest);

        pool.record_status(0, StatusCode::BAD_GATEWAY);

        assert_eq!(pool.inner.sticky_index.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn random_strategy_selects_in_bounds() {
        let pool = make_pool(ProxyPoolStrategy::RandomPerRequest);

        for _ in 0..64 {
            assert!(pool.select().index < pool.len());
        }
    }

    #[test]
    fn failure_status_classifier_matches_policy() {
        assert!(ProxyPool::is_failure_status(
            StatusCode::PROXY_AUTHENTICATION_REQUIRED
        ));
        assert!(ProxyPool::is_failure_status(StatusCode::TOO_MANY_REQUESTS));
        assert!(ProxyPool::is_failure_status(
            StatusCode::SERVICE_UNAVAILABLE
        ));
        assert!(!ProxyPool::is_failure_status(StatusCode::BAD_REQUEST));
        assert!(!ProxyPool::is_failure_status(StatusCode::OK));
    }

    #[test]
    fn empty_pool_is_rejected() {
        let result = ProxyPool::with_strategy(Vec::new(), ProxyPoolStrategy::StickyFailover);
        assert!(result.is_err());
    }
}