
spider_lib/builder.rs

//! Builder for constructing and configuring the `Crawler` instance.
//!
//! This module provides the `CrawlerBuilder`, a fluent API for
//! setting up and customizing a web crawler. It simplifies the process of
//! assembling various `spider-lib` components, including:
//! - Defining concurrency settings for downloads, parsing, and pipelines.
//! - Attaching custom `Downloader` implementations.
//! - Registering `Middleware`s to process requests and responses.
//! - Adding `Pipeline`s to process scraped items.
//! - Configuring checkpointing for persistence and fault tolerance.
//! - Initializing and integrating a `StatCollector` for gathering crawl statistics.
//!
//! The builder also applies sensible defaults (installing a `UserAgentMiddleware`
//! and a `ConsoleWriterPipeline` when none are specified) and loads any existing
//! checkpoint when the crawler is built.
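//!
//! # Example
//!
//! A minimal usage sketch; `MySpider` stands in for any type implementing
//! `Spider`, and the exact import paths are assumptions:
//!
//! ```ignore
//! let crawler = CrawlerBuilder::new(MySpider::new())
//!     .max_concurrent_downloads(10)
//!     .max_parser_workers(4)
//!     .add_pipeline(ConsoleWriterPipeline::new())
//!     .build()
//!     .await?;
//! ```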

use crate::ConsoleWriterPipeline;
#[cfg(feature = "checkpoint")]
use crate::checkpoint::Checkpoint;
use crate::{CookieStore, SchedulerCheckpoint};

use crate::downloader::Downloader;
use crate::downloaders::reqwest_client::ReqwestClientDownloader;
use crate::error::SpiderError;
use crate::middleware::Middleware;
use crate::middlewares::user_agent::{UserAgentMiddleware, UserAgentSource};
use crate::pipeline::Pipeline;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use num_cpus;
#[cfg(feature = "middleware-cookies")]
use std::any::Any;
#[cfg(feature = "checkpoint")]
use std::fs;
use std::marker::PhantomData;
#[cfg(feature = "checkpoint")]
use std::path::{Path, PathBuf};
use std::sync::Arc;
#[cfg(feature = "checkpoint")]
use std::time::Duration;
#[cfg(feature = "middleware-cookies")]
use tracing::info;
#[cfg(feature = "checkpoint")]
use tracing::{debug, warn};

#[cfg(feature = "middleware-cookies")]
use crate::middlewares::cookies::CookieMiddleware;

#[cfg(feature = "middleware-cookies")]
use tokio::sync::Mutex;

use super::Crawler;
use crate::stats::StatCollector;

/// Configuration for the crawler's concurrency settings.
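///
/// A short sketch of overriding a single field, taking the rest from
/// `Default` (the builder's setter methods normally manage this for you):
///
/// ```ignore
/// let config = CrawlerConfig {
///     max_concurrent_downloads: 16,
///     ..Default::default()
/// };
/// ```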
pub struct CrawlerConfig {
    /// The maximum number of concurrent downloads.
    pub max_concurrent_downloads: usize,
    /// The number of workers dedicated to parsing responses.
    pub parser_workers: usize,
    /// The maximum number of concurrent item processing pipelines.
    pub max_concurrent_pipelines: usize,
}

impl Default for CrawlerConfig {
    fn default() -> Self {
        CrawlerConfig {
            max_concurrent_downloads: 5,
            parser_workers: num_cpus::get(),
            max_concurrent_pipelines: 5,
        }
    }
}

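/// A fluent builder for assembling and configuring a [`Crawler`].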
pub struct CrawlerBuilder<S: Spider, D = ReqwestClientDownloader>
where
    D: Downloader,
{
    crawler_config: CrawlerConfig,
    downloader: D,
    spider: Option<S>,
    middlewares: Vec<Box<dyn Middleware<D::Client> + Send + Sync>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    #[cfg(feature = "checkpoint")]
    checkpoint_path: Option<PathBuf>,
    #[cfg(feature = "checkpoint")]
    checkpoint_interval: Option<Duration>,
    _phantom: PhantomData<S>,
}

impl<S: Spider, D: Default + Downloader> Default for CrawlerBuilder<S, D> {
    fn default() -> Self {
        Self {
            crawler_config: CrawlerConfig::default(),
            downloader: D::default(),
            spider: None,
            middlewares: Vec::new(),
            item_pipelines: Vec::new(),
            #[cfg(feature = "checkpoint")]
            checkpoint_path: None,
            #[cfg(feature = "checkpoint")]
            checkpoint_interval: None,
            _phantom: PhantomData,
        }
    }
}

impl<S: Spider> CrawlerBuilder<S> {
    /// Creates a new `CrawlerBuilder` for a given spider.
    pub fn new(spider: S) -> Self {
        Self {
            spider: Some(spider),
            ..Default::default()
        }
    }
}

impl<S: Spider, D: Downloader> CrawlerBuilder<S, D> {
    /// Sets the maximum number of concurrent downloads.
    pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_downloads = limit;
        self
    }

    /// Sets the maximum number of concurrent parser workers.
    pub fn max_parser_workers(mut self, limit: usize) -> Self {
        self.crawler_config.parser_workers = limit;
        self
    }

    /// Sets the maximum number of concurrent pipelines.
    pub fn max_concurrent_pipelines(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_pipelines = limit;
        self
    }

    /// Sets a custom downloader for the crawler.
    pub fn downloader(mut self, downloader: D) -> Self {
        self.downloader = downloader;
        self
    }

    /// Adds a middleware to the crawler.
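    ///
    /// A minimal sketch registering this crate's `UserAgentMiddleware`; the
    /// user-agent string is illustrative:
    ///
    /// ```ignore
    /// let mw = UserAgentMiddleware::builder()
    ///     .source(UserAgentSource::List(vec!["my-crawler/1.0".to_string()]))
    ///     .fallback_user_agent("my-crawler/1.0".to_string())
    ///     .build()?;
    /// let builder = builder.add_middleware(mw);
    /// ```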
    pub fn add_middleware<M>(mut self, middleware: M) -> Self
    where
        M: Middleware<D::Client> + Send + Sync + 'static,
    {
        self.middlewares.push(Box::new(middleware));
        self
    }

    /// Adds an item pipeline to the crawler.
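    ///
    /// A minimal sketch using this crate's `ConsoleWriterPipeline` (the same
    /// pipeline the builder falls back to when none is registered):
    ///
    /// ```ignore
    /// let builder = builder.add_pipeline(ConsoleWriterPipeline::new());
    /// ```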
    pub fn add_pipeline<P>(mut self, pipeline: P) -> Self
    where
        P: Pipeline<S::Item> + 'static,
    {
        self.item_pipelines.push(Box::new(pipeline));
        self
    }

    /// Enables checkpointing and sets the path for the checkpoint file.
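    ///
    /// A minimal sketch pairing a checkpoint file with a periodic save
    /// interval (both values are illustrative):
    ///
    /// ```ignore
    /// let builder = builder
    ///     .with_checkpoint_path("crawl.checkpoint")
    ///     .with_checkpoint_interval(Duration::from_secs(60));
    /// ```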
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.checkpoint_path = Some(path.as_ref().to_path_buf());
        self
    }

    /// Sets the interval for periodic checkpointing.
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

    /// Consumes the builder and constructs the `Crawler`, wiring together the
    /// scheduler, downloader, middlewares, pipelines, and a fresh `StatCollector`.
    /// If checkpointing is enabled, any existing checkpoint is loaded and
    /// restored before the crawler is assembled.
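    ///
    /// # Errors
    ///
    /// Returns `SpiderError::ConfigurationError` if no spider was supplied or
    /// if one of the configured concurrency limits is zero.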
    #[allow(unused_variables)]
    pub async fn build(mut self) -> Result<Crawler<S, D::Client>, SpiderError>
    where
        D: Downloader + Send + Sync + 'static,
        D::Client: Send + Sync + Clone,
    {
        let spider = self.validate_and_get_spider()?;
        self.ensure_default_components().await?;

        #[cfg(feature = "checkpoint")]
        let (initial_scheduler_state, loaded_cookie_store) =
            self.load_and_restore_checkpoint_state().await?;

        #[cfg(not(feature = "checkpoint"))]
        let (initial_scheduler_state, loaded_cookie_store): (
            Option<SchedulerCheckpoint>,
            Option<CookieStore>,
        ) = (None, None);

        let (scheduler_arc, req_rx) = Scheduler::new(initial_scheduler_state);

        #[cfg(feature = "middleware-cookies")]
        let final_cookie_store = self.resolve_final_cookie_store(loaded_cookie_store);

        let downloader_arc = Arc::new(self.downloader);
        let stats = Arc::new(StatCollector::new());

        let crawler = Crawler::new(
            scheduler_arc,
            req_rx,
            downloader_arc,
            self.middlewares,
            spider,
            self.item_pipelines,
            self.crawler_config.max_concurrent_downloads,
            self.crawler_config.parser_workers,
            self.crawler_config.max_concurrent_pipelines,
            #[cfg(feature = "checkpoint")]
            self.checkpoint_path.take(),
            #[cfg(feature = "checkpoint")]
            self.checkpoint_interval,
            stats,
            #[cfg(feature = "middleware-cookies")]
            final_cookie_store,
        );

        Ok(crawler)
    }

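    /// Chooses the cookie store handed to the `Crawler`: a registered
    /// `CookieMiddleware`'s store takes precedence over any checkpointed or
    /// default store.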
    #[cfg(feature = "middleware-cookies")]
    fn resolve_final_cookie_store(
        &self,
        loaded_cookie_store: Option<CookieStore>,
    ) -> Arc<Mutex<CookieStore>> {
        let mut final_cookie_store = Arc::new(Mutex::new(loaded_cookie_store.unwrap_or_default()));

        // Detect CookieMiddleware and use its store, overriding checkpoint or default
        for mw_box in &self.middlewares {
            if let Some(cookie_mw) =
                (mw_box.as_ref() as &dyn Any).downcast_ref::<CookieMiddleware>()
            {
                info!(
                    "Found CookieMiddleware, using its cookie store for Crawler. This overrides any checkpointed store."
                );
                final_cookie_store = cookie_mw.store.clone();
                break;
            }
        }
        final_cookie_store
    }

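    /// Loads the checkpoint file (when one is configured), restores saved
    /// pipeline state, and returns the scheduler and cookie-store state.
    /// Unreadable or corrupt checkpoints are logged and skipped rather than
    /// treated as fatal errors.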
    #[cfg(feature = "checkpoint")]
    async fn load_and_restore_checkpoint_state(
        &mut self,
    ) -> Result<(Option<SchedulerCheckpoint>, Option<CookieStore>), SpiderError> {
        let mut initial_scheduler_state = None;
        let mut loaded_pipelines_state = None;
        #[cfg(feature = "middleware-cookies")]
        let mut loaded_cookie_store = None;

        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        initial_scheduler_state = Some(checkpoint.scheduler);
                        loaded_pipelines_state = Some(checkpoint.pipelines);

                        #[cfg(feature = "middleware-cookies")]
                        {
                            info!("Checkpoint loaded, restoring cookie store data.");
                            loaded_cookie_store = Some(checkpoint.cookie_store);
                        }
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        if let Some(pipeline_states) = loaded_pipelines_state {
            for (name, state) in pipeline_states {
                if let Some(pipeline) = self.item_pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

        #[cfg(not(feature = "middleware-cookies"))]
        return Ok((initial_scheduler_state, None));
        #[cfg(feature = "middleware-cookies")]
        return Ok((initial_scheduler_state, loaded_cookie_store));
    }

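    /// Checks that the configured concurrency limits are non-zero and takes
    /// ownership of the spider, failing with a `ConfigurationError` otherwise.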
    fn validate_and_get_spider(&mut self) -> Result<S, SpiderError> {
        if self.crawler_config.max_concurrent_downloads == 0 {
            return Err(SpiderError::ConfigurationError(
                "max_concurrent_downloads must be greater than 0.".to_string(),
            ));
        }
        if self.crawler_config.parser_workers == 0 {
            return Err(SpiderError::ConfigurationError(
                "parser_workers must be greater than 0.".to_string(),
            ));
        }
        if self.crawler_config.max_concurrent_pipelines == 0 {
            return Err(SpiderError::ConfigurationError(
                "max_concurrent_pipelines must be greater than 0.".to_string(),
            ));
        }
        self.spider.take().ok_or_else(|| {
            SpiderError::ConfigurationError("Crawler must have a spider.".to_string())
        })
    }

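    /// Installs fallback components: a `ConsoleWriterPipeline` when no item
    /// pipeline is registered, and a `UserAgentMiddleware` derived from the
    /// crate name and version when no user-agent middleware is present.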
    async fn ensure_default_components(&mut self) -> Result<(), SpiderError> {
        if self.item_pipelines.is_empty() {
            self.item_pipelines
                .push(Box::new(ConsoleWriterPipeline::new()));
        }

        let has_user_agent_middleware = self
            .middlewares
            .iter()
            .any(|m| m.name() == "UserAgentMiddleware");

        if !has_user_agent_middleware {
            let pkg_name = env!("CARGO_PKG_NAME");
            let pkg_version = env!("CARGO_PKG_VERSION");
            let default_user_agent = format!("{}/{}", pkg_name, pkg_version);

            let default_user_agent_mw = UserAgentMiddleware::builder()
                .source(UserAgentSource::List(vec![default_user_agent.clone()]))
                .fallback_user_agent(default_user_agent)
                .build()?;
            self.middlewares.insert(0, Box::new(default_user_agent_mw));
        }

        Ok(())
    }
}