reqwest_proxy_pool/
middleware.rs1use crate::classifier::ProxyBodyVerdict;
4use crate::config::{ClientBuilderFactory, HostConfig, ProxyPoolConfig, RetryStrategy};
5use crate::error::NoProxyAvailable;
6use crate::pool::ProxyPool;
7
8use anyhow::anyhow;
9use async_trait::async_trait;
10use log::{info, warn};
11use parking_lot::RwLock;
12use reqwest::ResponseBuilderExt;
13use reqwest_middleware::{Error, Middleware, Next, Result};
14use std::collections::{HashMap, HashSet};
15use std::sync::Arc;
16
17#[derive(Clone)]
19pub struct ProxyPoolMiddleware {
20 pools: HashMap<String, Arc<ProxyPool>>,
22 primary_host: String,
24 client_builder_factory: ClientBuilderFactory,
26 client_cache: Arc<RwLock<HashMap<String, reqwest::Client>>>,
28}
29
30impl ProxyPoolMiddleware {
31 pub async fn new(config: ProxyPoolConfig) -> Result<Self> {
33 if config.hosts().is_empty() {
34 return Err(Error::Middleware(anyhow!(
35 "ProxyPoolConfig must contain at least one HostConfig"
36 )));
37 }
38 if config.sources().is_empty() {
39 return Err(Error::Middleware(anyhow!(
40 "ProxyPoolConfig.sources cannot be empty"
41 )));
42 }
43
44 let primary_host = validate_hosts(config.hosts())?;
45
46 let client_builder_factory = Arc::clone(config.client_builder_factory());
47 let mut pools = HashMap::new();
48 for host_config in config.hosts().iter().cloned() {
49 let host = host_config.host().to_ascii_lowercase();
50 let pool = ProxyPool::new(config.sources().to_vec(), host_config)
51 .await
52 .map_err(Error::Reqwest)?;
53 let (total, healthy) = pool.get_stats();
54 info!(
55 "Host pool [{}] initialized with {}/{} healthy proxies",
56 host, healthy, total
57 );
58 if healthy == 0 {
59 warn!("No healthy proxies available in host pool [{}]", host);
60 }
61
62 pools.insert(host, pool);
63 }
64
65 Ok(Self {
66 pools,
67 primary_host,
68 client_builder_factory,
69 client_cache: Arc::new(RwLock::new(HashMap::new())),
70 })
71 }
72
73 fn resolve_pool(&self, req: &reqwest::Request) -> Option<Arc<ProxyPool>> {
74 let host = req.url().host_str().map(|h| h.to_ascii_lowercase());
75 if let Some(host) = host {
76 if let Some(pool) = self.pools.get(&host) {
77 return Some(Arc::clone(pool));
78 }
79 }
80 self.pools.get(&self.primary_host).map(Arc::clone)
81 }
82
83 fn get_or_build_client(
84 &self,
85 proxy_url: &str,
86 reqwest_proxy: reqwest::Proxy,
87 ) -> std::result::Result<reqwest::Client, reqwest::Error> {
88 if let Some(existing) = self.client_cache.read().get(proxy_url).cloned() {
89 return Ok(existing);
90 }
91
92 let built = (self.client_builder_factory)()
93 .proxy(reqwest_proxy)
94 .build()?;
95
96 let mut cache = self.client_cache.write();
97 let cached = cache
98 .entry(proxy_url.to_string())
99 .or_insert_with(|| built.clone());
100 Ok(cached.clone())
101 }
102}
103
104fn validate_hosts(hosts: &[HostConfig]) -> Result<String> {
105 let mut seen = HashSet::new();
106 let mut primary_hosts = Vec::new();
107
108 for host_config in hosts {
109 let host = host_config.host().trim().to_ascii_lowercase();
110 if host.is_empty() {
111 return Err(Error::Middleware(anyhow!(
112 "HostConfig.host cannot be empty"
113 )));
114 }
115 if !seen.insert(host.clone()) {
116 return Err(Error::Middleware(anyhow!(
117 "duplicate HostConfig for host: {}",
118 host
119 )));
120 }
121 if host_config.primary() {
122 primary_hosts.push(host);
123 }
124 }
125
126 if primary_hosts.is_empty() {
127 return Err(Error::Middleware(anyhow!(
128 "exactly one HostConfig must set primary=true, found none"
129 )));
130 }
131 if primary_hosts.len() > 1 {
132 return Err(Error::Middleware(anyhow!(
133 "exactly one HostConfig must set primary=true, found {} ({:?})",
134 primary_hosts.len(),
135 primary_hosts
136 )));
137 }
138
139 Ok(primary_hosts.remove(0))
140}
141
#[cfg(test)]
mod tests {
    use super::validate_hosts;
    use crate::config::HostConfig;

    /// No host is marked primary, so validation must fail.
    #[test]
    fn validate_hosts_requires_one_primary() {
        let first = HostConfig::builder("a.example.com").build();
        let second = HostConfig::builder("b.example.com").build();
        let outcome = validate_hosts(&[first, second]);
        assert!(outcome.is_err());
    }

    /// Two hosts are marked primary, so validation must fail.
    #[test]
    fn validate_hosts_rejects_multiple_primary() {
        let first = HostConfig::builder("a.example.com").primary(true).build();
        let second = HostConfig::builder("b.example.com").primary(true).build();
        let outcome = validate_hosts(&[first, second]);
        assert!(outcome.is_err());
    }

    /// Exactly one host is marked primary, and its name is returned.
    #[test]
    fn validate_hosts_returns_primary_host() {
        let secondary = HostConfig::builder("a.example.com").build();
        let main_host = HostConfig::builder("b.example.com").primary(true).build();
        let primary =
            validate_hosts(&[secondary, main_host]).expect("primary host should be valid");
        assert_eq!(primary, "b.example.com");
    }
}
175
176#[async_trait]
177impl Middleware for ProxyPoolMiddleware {
178 async fn handle(
179 &self,
180 req: reqwest::Request,
181 _extensions: &mut http::Extensions,
182 _next: Next<'_>,
183 ) -> Result<reqwest::Response> {
184 let pool = self.resolve_pool(&req).ok_or_else(|| {
185 Error::Middleware(anyhow!(
186 "No pool available for request host and no primary host pool configured"
187 ))
188 })?;
189
190 let max_retries = pool.config.retry_count;
191 let mut retry_count = 0;
192 let mut used_proxy_urls = HashSet::new();
193
194 loop {
195 let proxy_result = match pool.config.retry_strategy {
196 RetryStrategy::DefaultSelection => pool.get_proxy(),
197 RetryStrategy::NewProxyOnRetry => {
198 if retry_count == 0 {
199 pool.get_proxy()
200 } else {
201 pool.get_proxy_excluding(&used_proxy_urls)
202 }
203 }
204 };
205
206 match proxy_result {
207 Ok(proxy) => {
208 let proxied_request = req.try_clone().ok_or_else(|| {
209 Error::Middleware(anyhow!(
210 "Request object is not cloneable. Are you passing a streaming body?"
211 .to_string()
212 ))
213 })?;
214
215 let proxy_url = proxy.url.clone();
216 used_proxy_urls.insert(proxy_url.clone());
217 info!("Using proxy: {} (attempt {})", proxy_url, retry_count + 1);
218
219 proxy.limiter.until_ready().await;
220
221 let reqwest_proxy = match proxy.to_reqwest_proxy() {
222 Ok(p) => p,
223 Err(e) => {
224 warn!("Failed to create proxy from {}: {}", proxy_url, e);
225 pool.report_proxy_failure(&proxy_url);
226 retry_count += 1;
227 if retry_count > max_retries {
228 return Err(Error::Reqwest(e));
229 }
230 continue;
231 }
232 };
233
234 let client = match self.get_or_build_client(&proxy_url, reqwest_proxy) {
235 Ok(c) => c,
236 Err(e) => {
237 warn!("Failed to build client with proxy {}: {}", proxy_url, e);
238 pool.report_proxy_failure(&proxy_url);
239 retry_count += 1;
240 if retry_count > max_retries {
241 return Err(Error::Reqwest(e));
242 }
243 continue;
244 }
245 };
246
247 match client.execute(proxied_request).await {
248 Ok(response) => {
249 let status = response.status();
250 let version = response.version();
251 let headers = response.headers().clone();
252 let url = response.url().clone();
253
254 let body = match response.bytes().await {
255 Ok(body) => body,
256 Err(err) => {
257 warn!(
258 "Read body failed with proxy {} (attempt {}): {}",
259 proxy_url,
260 retry_count + 1,
261 err
262 );
263 pool.report_proxy_failure(&proxy_url);
264 retry_count += 1;
265 if retry_count > max_retries {
266 return Err(Error::Reqwest(err));
267 }
268 continue;
269 }
270 };
271
272 let verdict = pool.config.body_classifier.classify(
273 status,
274 &headers,
275 body.as_ref(),
276 );
277 let rebuilt =
278 rebuild_response(status, version, headers, url, body.to_vec())
279 .map_err(|e| {
280 Error::Middleware(anyhow!(
281 "Failed to rebuild response after body classification: {}",
282 e
283 ))
284 })?;
285
286 match verdict {
287 ProxyBodyVerdict::Success => {
288 pool.report_proxy_success(&proxy_url);
289 return Ok(rebuilt);
290 }
291 ProxyBodyVerdict::ProxyBlocked => {
292 warn!(
293 "Proxy {} blocked by target site (attempt {})",
294 proxy_url,
295 retry_count + 1
296 );
297 pool.report_proxy_failure(&proxy_url);
298 retry_count += 1;
299 if retry_count > max_retries {
300 return Ok(rebuilt);
301 }
302 }
303 ProxyBodyVerdict::Passthrough => {
304 return Ok(rebuilt);
305 }
306 }
307 }
308 Err(err) => {
309 warn!(
310 "Request failed with proxy {} (attempt {}): {}",
311 proxy_url,
312 retry_count + 1,
313 err
314 );
315 pool.report_proxy_failure(&proxy_url);
316 retry_count += 1;
317 if retry_count > max_retries {
318 return Err(Error::Reqwest(err));
319 }
320 }
321 }
322 }
323 Err(_) => {
324 let (total, healthy) = pool.get_stats();
325 warn!(
326 "No proxy available in selected host pool. Total: {}, Healthy: {}",
327 total, healthy
328 );
329 return Err(Error::Middleware(anyhow!(NoProxyAvailable)));
330 }
331 }
332 }
333 }
334}
335
336fn rebuild_response(
337 status: reqwest::StatusCode,
338 version: reqwest::Version,
339 headers: reqwest::header::HeaderMap,
340 url: reqwest::Url,
341 body: Vec<u8>,
342) -> std::result::Result<reqwest::Response, http::Error> {
343 let mut builder = http::Response::builder()
344 .status(status)
345 .version(version)
346 .url(url);
347 if let Some(headers_mut) = builder.headers_mut() {
348 *headers_mut = headers;
349 }
350 let http_response = builder.body(body)?;
351 Ok(reqwest::Response::from(http_response))
352}