Skip to main content

spider_middleware/
proxy.rs

1//! Auto-Rotate Proxy Middleware for rotating proxies during crawling.
2//!
3//! This middleware manages and rotates proxy URLs for outgoing requests.
4//! It supports loading proxies from a list or a file and offers different
5//! rotation strategies.
6
7use async_trait::async_trait;
8use rand::seq::SliceRandom;
9use serde::{Deserialize, Serialize};
10use std::borrow::Cow;
11use std::fmt::Debug;
12use std::fs::File;
13use std::io::{BufRead, BufReader};
14use std::path::{Path, PathBuf};
15use std::sync::atomic::{AtomicUsize, Ordering};
16use std::sync::Arc;
17use tracing::{info, warn};
18
19use spider_util::error::SpiderError;
20use crate::middleware::{Middleware, MiddlewareAction};
21use spider_util::request::Request;
22use spider_util::response::Response;
23
24/// Defines the strategy for rotating proxies.
25#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
26pub enum ProxyRotationStrategy {
27    /// Sequentially cycles through the available proxies.
28    #[default]
29    Sequential,
30    /// Randomly selects a proxy from the available pool.
31    Random,
32    /// Uses one proxy until a failure is detected (based on status or body), then rotates.
33    StickyFailover,
34}
35
36/// Defines the source from which proxies are loaded.
37#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
38#[serde(untagged)]
39pub enum ProxySource {
40    /// A direct list of proxy URLs.
41    List(Vec<String>),
42    /// Path to a file containing proxy URLs, one per line.
43    File(PathBuf),
44}
45
46impl Default for ProxySource {
47    fn default() -> Self {
48        ProxySource::List(Vec::new())
49    }
50}
51
52/// Builder for creating an `ProxyMiddleware`.
53#[derive(Debug, Clone, Default, Serialize, Deserialize)]
54pub struct ProxyMiddlewareBuilder {
55    source: ProxySource,
56    strategy: ProxyRotationStrategy,
57    block_detection_texts: Vec<String>,
58}
59
60impl ProxyMiddlewareBuilder {
61    /// Sets the primary source for proxies.
62    pub fn source(mut self, source: ProxySource) -> Self {
63        self.source = source;
64        self
65    }
66
67    /// Sets the strategy to use for rotating proxies.
68    pub fn strategy(mut self, strategy: ProxyRotationStrategy) -> Self {
69        self.strategy = strategy;
70        self
71    }
72
73    /// Sets the texts to detect in the response body to trigger a proxy rotation.
74    /// This is only used with the `StickyFailover` strategy.
75    pub fn with_block_detection_texts(mut self, texts: Vec<String>) -> Self {
76        self.block_detection_texts = texts;
77        self
78    }
79
80    /// Builds the `ProxyMiddleware`.
81    /// This can fail if a proxy source file is specified but cannot be read.
82    pub fn build(self) -> Result<ProxyMiddleware, SpiderError> {
83        let proxies = Arc::new(ProxyMiddleware::load_proxies(&self.source)?);
84
85        let block_texts = if self.block_detection_texts.is_empty() {
86            None
87        } else {
88            Some(self.block_detection_texts)
89        };
90
91        let middleware = ProxyMiddleware {
92            strategy: self.strategy,
93            proxies,
94            current_index: AtomicUsize::new(0),
95            block_detection_texts: block_texts,
96        };
97
98        info!(
99            "Initializing ProxyMiddleware with config: {:?}",
100            middleware
101        );
102
103        Ok(middleware)
104    }
105}
106
107pub struct ProxyMiddleware {
108    strategy: ProxyRotationStrategy,
109    proxies: Arc<Vec<String>>,
110    current_index: AtomicUsize,
111    block_detection_texts: Option<Vec<String>>,
112}
113
114impl Debug for ProxyMiddleware {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        f.debug_struct("ProxyMiddleware")
117            .field("strategy", &self.strategy)
118            .field("proxies", &format!("Pool({})", self.proxies.len()))
119            .field("current_index", &self.current_index)
120            .field("block_detection_texts", &self.block_detection_texts)
121            .finish()
122    }
123}
124
125impl ProxyMiddleware {
126    /// Creates a new `ProxyMiddlewareBuilder` to start building the middleware.
127    pub fn builder() -> ProxyMiddlewareBuilder {
128        ProxyMiddlewareBuilder::default()
129    }
130
131    fn load_proxies(source: &ProxySource) -> Result<Vec<String>, SpiderError> {
132        match source {
133            ProxySource::List(list) => Ok(list.clone()),
134            ProxySource::File(path) => Self::load_from_file(path),
135        }
136    }
137
138    fn load_from_file(path: &Path) -> Result<Vec<String>, SpiderError> {
139        if !path.exists() {
140            return Err(SpiderError::IoError(
141                std::io::Error::new(
142                    std::io::ErrorKind::NotFound,
143                    format!("Proxy file not found: {}", path.display()),
144                )
145                .to_string(),
146            ));
147        }
148        let file = File::open(path)?;
149        let reader = BufReader::new(file);
150        let proxies: Vec<String> = reader
151            .lines()
152            .map_while(Result::ok)
153            .filter(|line| !line.trim().is_empty())
154            .collect();
155
156        if proxies.is_empty() {
157            warn!(
158                "Proxy file {:?} is empty or contains no valid proxy URLs.",
159                path
160            );
161        }
162        Ok(proxies)
163    }
164
165    fn get_proxy(&self) -> Option<String> {
166        if self.proxies.is_empty() {
167            return None;
168        }
169
170        match self.strategy {
171            ProxyRotationStrategy::Sequential => {
172                let current = self.current_index.fetch_add(1, Ordering::SeqCst);
173                let index = current % self.proxies.len();
174                self.proxies.get(index).cloned()
175            }
176            ProxyRotationStrategy::Random => {
177                let mut rng = rand::thread_rng();
178                self.proxies.choose(&mut rng).cloned()
179            }
180            ProxyRotationStrategy::StickyFailover => {
181                let current = self.current_index.load(Ordering::SeqCst);
182                let index = current % self.proxies.len();
183                self.proxies.get(index).cloned()
184            }
185        }
186    }
187
188    fn rotate_proxy(&self) {
189        if !self.proxies.is_empty() {
190            self.current_index.fetch_add(1, Ordering::SeqCst);
191            info!("Proxy rotation triggered due to failure.");
192        }
193    }
194}
195
196#[async_trait]
197impl<C: Send + Sync> Middleware<C> for ProxyMiddleware {
198    fn name(&self) -> &str {
199        "ProxyMiddleware"
200    }
201
202    async fn process_request(
203        &mut self,
204        _client: &C,
205        request: Request,
206    ) -> Result<MiddlewareAction<Request>, SpiderError> {
207        if let Some(proxy) = self.get_proxy() {
208            request.meta.insert(Cow::Borrowed("proxy"), proxy.into());
209        }
210        Ok(MiddlewareAction::Continue(request))
211    }
212
213    async fn process_response(
214        &mut self,
215        response: Response,
216    ) -> Result<MiddlewareAction<Response>, SpiderError> {
217        if self.strategy != ProxyRotationStrategy::StickyFailover {
218            return Ok(MiddlewareAction::Continue(response));
219        }
220
221        let mut rotate = false;
222        let status = response.status;
223
224        // Check for bad status codes
225        if status.is_client_error() || status.is_server_error() {
226            // e.g., 403 Forbidden, 429 Too Many Requests, 5xx errors
227            rotate = true;
228        }
229
230        // Check for block texts in body if status is OK
231        if status.is_success() && let Some(texts) = &self.block_detection_texts {
232            let body_str = String::from_utf8_lossy(&response.body);
233            if texts.iter().any(|text| body_str.contains(text)) {
234                rotate = true;
235                info!(
236                    "Block detection text found in response body from {}",
237                    response.url
238                );
239            }
240        }
241
242        if rotate {
243            self.rotate_proxy();
244        }
245
246        Ok(MiddlewareAction::Continue(response))
247    }
248
249    async fn handle_error(
250        &mut self,
251        _request: &Request,
252        error: &SpiderError,
253    ) -> Result<MiddlewareAction<Request>, SpiderError> {
254        if self.strategy == ProxyRotationStrategy::StickyFailover {
255            self.rotate_proxy();
256        }
257
258        // Pass the error along for other middlewares (like Retry) to handle.
259        Err(error.clone())
260    }
261}