Skip to main content

reqwest_proxy_pool/
middleware.rs

1//! Middleware implementation for reqwest.
2
3use crate::classifier::ProxyBodyVerdict;
4use crate::config::{ClientBuilderFactory, HostConfig, ProxyPoolConfig, RetryStrategy};
5use crate::error::NoProxyAvailable;
6use crate::pool::ProxyPool;
7
8use anyhow::anyhow;
9use async_trait::async_trait;
10use log::{info, warn};
11use parking_lot::RwLock;
12use reqwest::ResponseBuilderExt;
13use reqwest_middleware::{Error, Middleware, Next, Result};
14use std::collections::{HashMap, HashSet};
15use std::sync::Arc;
16
/// Middleware that routes requests to host-bound proxy pools.
///
/// A request is matched to a pool by the lowercased host of its URL; when no
/// pool is registered for that host, the pool stored under `primary_host` is
/// used instead.
#[derive(Clone)]
pub struct ProxyPoolMiddleware {
    /// Host -> pool mapping, keyed by the lowercased host name.
    pools: HashMap<String, Arc<ProxyPool>>,
    /// Key into `pools` of the primary pool, used for hosts without their own pool.
    primary_host: String,
    /// Factory producing the `reqwest::ClientBuilder` a proxy is attached to.
    client_builder_factory: ClientBuilderFactory,
    /// Built client cache keyed by proxy URL; shared between clones of the middleware.
    client_cache: Arc<RwLock<HashMap<String, reqwest::Client>>>,
}
29
30impl ProxyPoolMiddleware {
31    /// Create middleware with host-based routing.
32    pub async fn new(config: ProxyPoolConfig) -> Result<Self> {
33        if config.hosts().is_empty() {
34            return Err(Error::Middleware(anyhow!(
35                "ProxyPoolConfig must contain at least one HostConfig"
36            )));
37        }
38        if config.sources().is_empty() {
39            return Err(Error::Middleware(anyhow!(
40                "ProxyPoolConfig.sources cannot be empty"
41            )));
42        }
43
44        let primary_host = validate_hosts(config.hosts())?;
45
46        let client_builder_factory = Arc::clone(config.client_builder_factory());
47        let mut pools = HashMap::new();
48        for host_config in config.hosts().iter().cloned() {
49            let host = host_config.host().to_ascii_lowercase();
50            let pool = ProxyPool::new(config.sources().to_vec(), host_config)
51                .await
52                .map_err(Error::Reqwest)?;
53            let (total, healthy) = pool.get_stats();
54            info!(
55                "Host pool [{}] initialized with {}/{} healthy proxies",
56                host, healthy, total
57            );
58            if healthy == 0 {
59                warn!("No healthy proxies available in host pool [{}]", host);
60            }
61
62            pools.insert(host, pool);
63        }
64
65        Ok(Self {
66            pools,
67            primary_host,
68            client_builder_factory,
69            client_cache: Arc::new(RwLock::new(HashMap::new())),
70        })
71    }
72
73    fn resolve_pool(&self, req: &reqwest::Request) -> Option<Arc<ProxyPool>> {
74        let host = req.url().host_str().map(|h| h.to_ascii_lowercase());
75        if let Some(host) = host {
76            if let Some(pool) = self.pools.get(&host) {
77                return Some(Arc::clone(pool));
78            }
79        }
80        self.pools.get(&self.primary_host).map(Arc::clone)
81    }
82
83    fn get_or_build_client(
84        &self,
85        proxy_url: &str,
86        reqwest_proxy: reqwest::Proxy,
87    ) -> std::result::Result<reqwest::Client, reqwest::Error> {
88        if let Some(existing) = self.client_cache.read().get(proxy_url).cloned() {
89            return Ok(existing);
90        }
91
92        let built = (self.client_builder_factory)()
93            .proxy(reqwest_proxy)
94            .build()?;
95
96        let mut cache = self.client_cache.write();
97        let cached = cache
98            .entry(proxy_url.to_string())
99            .or_insert_with(|| built.clone());
100        Ok(cached.clone())
101    }
102}
103
104fn validate_hosts(hosts: &[HostConfig]) -> Result<String> {
105    let mut seen = HashSet::new();
106    let mut primary_hosts = Vec::new();
107
108    for host_config in hosts {
109        let host = host_config.host().trim().to_ascii_lowercase();
110        if host.is_empty() {
111            return Err(Error::Middleware(anyhow!(
112                "HostConfig.host cannot be empty"
113            )));
114        }
115        if !seen.insert(host.clone()) {
116            return Err(Error::Middleware(anyhow!(
117                "duplicate HostConfig for host: {}",
118                host
119            )));
120        }
121        if host_config.primary() {
122            primary_hosts.push(host);
123        }
124    }
125
126    if primary_hosts.is_empty() {
127        return Err(Error::Middleware(anyhow!(
128            "exactly one HostConfig must set primary=true, found none"
129        )));
130    }
131    if primary_hosts.len() > 1 {
132        return Err(Error::Middleware(anyhow!(
133            "exactly one HostConfig must set primary=true, found {} ({:?})",
134            primary_hosts.len(),
135            primary_hosts
136        )));
137    }
138
139    Ok(primary_hosts.remove(0))
140}
141
#[cfg(test)]
mod tests {
    use super::validate_hosts;
    use crate::config::HostConfig;

    // No entry is flagged as primary, so validation must fail.
    #[test]
    fn validate_hosts_requires_one_primary() {
        let without_primary = [
            HostConfig::builder("a.example.com").build(),
            HostConfig::builder("b.example.com").build(),
        ];

        let outcome = validate_hosts(&without_primary);

        assert!(outcome.is_err());
    }

    // Two entries flagged as primary is ambiguous, so validation must fail.
    #[test]
    fn validate_hosts_rejects_multiple_primary() {
        let two_primaries = [
            HostConfig::builder("a.example.com").primary(true).build(),
            HostConfig::builder("b.example.com").primary(true).build(),
        ];

        let outcome = validate_hosts(&two_primaries);

        assert!(outcome.is_err());
    }

    // With exactly one primary entry, its host name is returned.
    #[test]
    fn validate_hosts_returns_primary_host() {
        let one_primary = [
            HostConfig::builder("a.example.com").build(),
            HostConfig::builder("b.example.com").primary(true).build(),
        ];

        let resolved = validate_hosts(&one_primary).expect("primary host should be valid");

        assert_eq!(resolved, "b.example.com");
    }
}
175
176#[async_trait]
177impl Middleware for ProxyPoolMiddleware {
178    async fn handle(
179        &self,
180        req: reqwest::Request,
181        _extensions: &mut http::Extensions,
182        _next: Next<'_>,
183    ) -> Result<reqwest::Response> {
184        let pool = self.resolve_pool(&req).ok_or_else(|| {
185            Error::Middleware(anyhow!(
186                "No pool available for request host and no primary host pool configured"
187            ))
188        })?;
189
190        let max_retries = pool.config.retry_count;
191        let mut retry_count = 0;
192        let mut used_proxy_urls = HashSet::new();
193
194        loop {
195            let proxy_result = match pool.config.retry_strategy {
196                RetryStrategy::DefaultSelection => pool.get_proxy(),
197                RetryStrategy::NewProxyOnRetry => {
198                    if retry_count == 0 {
199                        pool.get_proxy()
200                    } else {
201                        pool.get_proxy_excluding(&used_proxy_urls)
202                    }
203                }
204            };
205
206            match proxy_result {
207                Ok(proxy) => {
208                    let proxied_request = req.try_clone().ok_or_else(|| {
209                        Error::Middleware(anyhow!(
210                            "Request object is not cloneable. Are you passing a streaming body?"
211                                .to_string()
212                        ))
213                    })?;
214
215                    let proxy_url = proxy.url.clone();
216                    used_proxy_urls.insert(proxy_url.clone());
217                    info!("Using proxy: {} (attempt {})", proxy_url, retry_count + 1);
218
219                    proxy.limiter.until_ready().await;
220
221                    let reqwest_proxy = match proxy.to_reqwest_proxy() {
222                        Ok(p) => p,
223                        Err(e) => {
224                            warn!("Failed to create proxy from {}: {}", proxy_url, e);
225                            pool.report_proxy_failure(&proxy_url);
226                            retry_count += 1;
227                            if retry_count > max_retries {
228                                return Err(Error::Reqwest(e));
229                            }
230                            continue;
231                        }
232                    };
233
234                    let client = match self.get_or_build_client(&proxy_url, reqwest_proxy) {
235                        Ok(c) => c,
236                        Err(e) => {
237                            warn!("Failed to build client with proxy {}: {}", proxy_url, e);
238                            pool.report_proxy_failure(&proxy_url);
239                            retry_count += 1;
240                            if retry_count > max_retries {
241                                return Err(Error::Reqwest(e));
242                            }
243                            continue;
244                        }
245                    };
246
247                    match client.execute(proxied_request).await {
248                        Ok(response) => {
249                            let status = response.status();
250                            let version = response.version();
251                            let headers = response.headers().clone();
252                            let url = response.url().clone();
253
254                            let body = match response.bytes().await {
255                                Ok(body) => body,
256                                Err(err) => {
257                                    warn!(
258                                        "Read body failed with proxy {} (attempt {}): {}",
259                                        proxy_url,
260                                        retry_count + 1,
261                                        err
262                                    );
263                                    pool.report_proxy_failure(&proxy_url);
264                                    retry_count += 1;
265                                    if retry_count > max_retries {
266                                        return Err(Error::Reqwest(err));
267                                    }
268                                    continue;
269                                }
270                            };
271
272                            let verdict = pool.config.body_classifier.classify(
273                                status,
274                                &headers,
275                                body.as_ref(),
276                            );
277                            let rebuilt =
278                                rebuild_response(status, version, headers, url, body.to_vec())
279                                    .map_err(|e| {
280                                        Error::Middleware(anyhow!(
281                                        "Failed to rebuild response after body classification: {}",
282                                        e
283                                    ))
284                                    })?;
285
286                            match verdict {
287                                ProxyBodyVerdict::Success => {
288                                    pool.report_proxy_success(&proxy_url);
289                                    return Ok(rebuilt);
290                                }
291                                ProxyBodyVerdict::ProxyBlocked => {
292                                    warn!(
293                                        "Proxy {} blocked by target site (attempt {})",
294                                        proxy_url,
295                                        retry_count + 1
296                                    );
297                                    pool.report_proxy_failure(&proxy_url);
298                                    retry_count += 1;
299                                    if retry_count > max_retries {
300                                        return Ok(rebuilt);
301                                    }
302                                }
303                                ProxyBodyVerdict::Passthrough => {
304                                    return Ok(rebuilt);
305                                }
306                            }
307                        }
308                        Err(err) => {
309                            warn!(
310                                "Request failed with proxy {} (attempt {}): {}",
311                                proxy_url,
312                                retry_count + 1,
313                                err
314                            );
315                            pool.report_proxy_failure(&proxy_url);
316                            retry_count += 1;
317                            if retry_count > max_retries {
318                                return Err(Error::Reqwest(err));
319                            }
320                        }
321                    }
322                }
323                Err(_) => {
324                    let (total, healthy) = pool.get_stats();
325                    warn!(
326                        "No proxy available in selected host pool. Total: {}, Healthy: {}",
327                        total, healthy
328                    );
329                    return Err(Error::Middleware(anyhow!(NoProxyAvailable)));
330                }
331            }
332        }
333    }
334}
335
336fn rebuild_response(
337    status: reqwest::StatusCode,
338    version: reqwest::Version,
339    headers: reqwest::header::HeaderMap,
340    url: reqwest::Url,
341    body: Vec<u8>,
342) -> std::result::Result<reqwest::Response, http::Error> {
343    let mut builder = http::Response::builder()
344        .status(status)
345        .version(version)
346        .url(url);
347    if let Some(headers_mut) = builder.headers_mut() {
348        *headers_mut = headers;
349    }
350    let http_response = builder.body(body)?;
351    Ok(reqwest::Response::from(http_response))
352}