Skip to main content

reqwest_proxy_pool/
middleware.rs

1//! Middleware implementation for reqwest.
2
3use crate::classifier::ProxyResponseVerdict;
4use crate::config::{HostConfig, ProxyPoolConfig, RetryStrategy};
5use crate::error::NoProxyAvailable;
6use crate::pool::ProxyPool;
7
8use anyhow::anyhow;
9use async_trait::async_trait;
10use log::{info, warn};
11use reqwest_middleware::{Error, Middleware, Next, Result};
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14
15/// Middleware that routes requests to host-bound proxy pools.
16#[derive(Clone)]
17pub struct ProxyPoolMiddleware {
18    /// Host -> pool mapping.
19    pools: HashMap<String, Arc<ProxyPool>>,
20    /// Primary pool used for unknown hosts.
21    primary_host: String,
22}
23
24impl ProxyPoolMiddleware {
25    /// Create middleware with host-based routing.
26    pub async fn new(config: ProxyPoolConfig) -> Result<Self> {
27        if config.hosts().is_empty() {
28            return Err(Error::Middleware(anyhow!(
29                "ProxyPoolConfig must contain at least one HostConfig"
30            )));
31        }
32        if config.sources().is_empty() {
33            return Err(Error::Middleware(anyhow!(
34                "ProxyPoolConfig.sources cannot be empty"
35            )));
36        }
37
38        let primary_host = validate_hosts(config.hosts())?;
39
40        let mut pools = HashMap::new();
41        for host_config in config.hosts().iter().cloned() {
42            let host = host_config.host().to_ascii_lowercase();
43            let pool = ProxyPool::new(config.sources().to_vec(), host_config)
44                .await
45                .map_err(Error::Reqwest)?;
46            let (total, healthy) = pool.get_stats();
47            info!(
48                "Host pool [{}] initialized with {}/{} healthy proxies",
49                host, healthy, total
50            );
51            if healthy == 0 {
52                warn!("No healthy proxies available in host pool [{}]", host);
53            }
54
55            pools.insert(host, pool);
56        }
57
58        Ok(Self {
59            pools,
60            primary_host,
61        })
62    }
63
64    fn resolve_pool(&self, req: &reqwest::Request) -> Option<Arc<ProxyPool>> {
65        let host = req.url().host_str().map(|h| h.to_ascii_lowercase());
66        if let Some(host) = host {
67            if let Some(pool) = self.pools.get(&host) {
68                return Some(Arc::clone(pool));
69            }
70        }
71        self.pools.get(&self.primary_host).map(Arc::clone)
72    }
73}
74
75fn validate_hosts(hosts: &[HostConfig]) -> Result<String> {
76    let mut seen = HashSet::new();
77    let mut primary_hosts = Vec::new();
78
79    for host_config in hosts {
80        let host = host_config.host().trim().to_ascii_lowercase();
81        if host.is_empty() {
82            return Err(Error::Middleware(anyhow!(
83                "HostConfig.host cannot be empty"
84            )));
85        }
86        if !seen.insert(host.clone()) {
87            return Err(Error::Middleware(anyhow!(
88                "duplicate HostConfig for host: {}",
89                host
90            )));
91        }
92        if host_config.primary() {
93            primary_hosts.push(host);
94        }
95    }
96
97    if primary_hosts.is_empty() {
98        return Err(Error::Middleware(anyhow!(
99            "exactly one HostConfig must set primary=true, found none"
100        )));
101    }
102    if primary_hosts.len() > 1 {
103        return Err(Error::Middleware(anyhow!(
104            "exactly one HostConfig must set primary=true, found {} ({:?})",
105            primary_hosts.len(),
106            primary_hosts
107        )));
108    }
109
110    Ok(primary_hosts.remove(0))
111}
112
#[cfg(test)]
mod tests {
    use super::validate_hosts;
    use crate::config::HostConfig;

    /// With no host flagged as primary, validation must fail.
    #[test]
    fn validate_hosts_requires_one_primary() {
        let first = HostConfig::builder("a.example.com").build();
        let second = HostConfig::builder("b.example.com").build();

        let outcome = validate_hosts(&[first, second]);

        assert!(outcome.is_err());
    }

    /// With two hosts flagged as primary, validation must fail.
    #[test]
    fn validate_hosts_rejects_multiple_primary() {
        let first = HostConfig::builder("a.example.com").primary(true).build();
        let second = HostConfig::builder("b.example.com").primary(true).build();

        let outcome = validate_hosts(&[first, second]);

        assert!(outcome.is_err());
    }

    /// With exactly one primary host, its name is returned.
    #[test]
    fn validate_hosts_returns_primary_host() {
        let plain = HostConfig::builder("a.example.com").build();
        let flagged = HostConfig::builder("b.example.com").primary(true).build();

        let primary =
            validate_hosts(&[plain, flagged]).expect("primary host should be valid");

        assert_eq!(primary, "b.example.com");
    }
}
146
147#[async_trait]
148impl Middleware for ProxyPoolMiddleware {
149    async fn handle(
150        &self,
151        req: reqwest::Request,
152        _extensions: &mut http::Extensions,
153        _next: Next<'_>,
154    ) -> Result<reqwest::Response> {
155        let pool = self.resolve_pool(&req).ok_or_else(|| {
156            Error::Middleware(anyhow!(
157                "No pool available for request host and no primary host pool configured"
158            ))
159        })?;
160
161        let max_retries = pool.config.retry_count;
162        let mut retry_count = 0;
163        let mut used_proxy_urls = HashSet::new();
164
165        loop {
166            let proxy_result = match pool.config.retry_strategy {
167                RetryStrategy::DefaultSelection => pool.get_proxy(),
168                RetryStrategy::NewProxyOnRetry => {
169                    if retry_count == 0 {
170                        pool.get_proxy()
171                    } else {
172                        pool.get_proxy_excluding(&used_proxy_urls)
173                    }
174                }
175            };
176
177            match proxy_result {
178                Ok(proxy) => {
179                    let proxied_request = req.try_clone().ok_or_else(|| {
180                        Error::Middleware(anyhow!(
181                            "Request object is not cloneable. Are you passing a streaming body?"
182                                .to_string()
183                        ))
184                    })?;
185
186                    let proxy_url = proxy.url.clone();
187                    used_proxy_urls.insert(proxy_url.clone());
188                    info!("Using proxy: {} (attempt {})", proxy_url, retry_count + 1);
189
190                    proxy.limiter.until_ready().await;
191
192                    let reqwest_proxy = match proxy.to_reqwest_proxy() {
193                        Ok(p) => p,
194                        Err(e) => {
195                            warn!("Failed to create proxy from {}: {}", proxy_url, e);
196                            pool.report_proxy_failure(&proxy_url);
197                            retry_count += 1;
198                            if retry_count > max_retries {
199                                return Err(Error::Reqwest(e));
200                            }
201                            continue;
202                        }
203                    };
204
205                    let client = match reqwest::Client::builder()
206                        .proxy(reqwest_proxy)
207                        .timeout(pool.config.health_check_timeout)
208                        .danger_accept_invalid_certs(pool.config.danger_accept_invalid_certs)
209                        .build()
210                    {
211                        Ok(c) => c,
212                        Err(e) => {
213                            warn!("Failed to build client with proxy {}: {}", proxy_url, e);
214                            pool.report_proxy_failure(&proxy_url);
215                            retry_count += 1;
216                            if retry_count > max_retries {
217                                return Err(Error::Reqwest(e));
218                            }
219                            continue;
220                        }
221                    };
222
223                    match client.execute(proxied_request).await {
224                        Ok(response) => match pool.config.response_classifier.classify(&response) {
225                            ProxyResponseVerdict::Success => {
226                                pool.report_proxy_success(&proxy_url);
227                                return Ok(response);
228                            }
229                            ProxyResponseVerdict::ProxyBlocked => {
230                                warn!(
231                                    "Proxy {} blocked by target site (attempt {})",
232                                    proxy_url,
233                                    retry_count + 1
234                                );
235                                pool.report_proxy_failure(&proxy_url);
236                                retry_count += 1;
237                                if retry_count > max_retries {
238                                    return Ok(response);
239                                }
240                            }
241                            ProxyResponseVerdict::Passthrough => {
242                                return Ok(response);
243                            }
244                        },
245                        Err(err) => {
246                            warn!(
247                                "Request failed with proxy {} (attempt {}): {}",
248                                proxy_url,
249                                retry_count + 1,
250                                err
251                            );
252                            pool.report_proxy_failure(&proxy_url);
253                            retry_count += 1;
254                            if retry_count > max_retries {
255                                return Err(Error::Reqwest(err));
256                            }
257                        }
258                    }
259                }
260                Err(_) => {
261                    let (total, healthy) = pool.get_stats();
262                    warn!(
263                        "No proxy available in selected host pool. Total: {}, Healthy: {}",
264                        total, healthy
265                    );
266                    return Err(Error::Middleware(anyhow!(NoProxyAvailable)));
267                }
268            }
269        }
270    }
271}