use crate::ConsoleWriterPipeline;
#[cfg(feature = "checkpoint")]
use crate::checkpoint::Checkpoint;
use crate::downloader::Downloader;
use crate::downloaders::reqwest_client::ReqwestClientDownloader;
use crate::error::SpiderError;
use crate::middleware::Middleware;
use crate::middlewares::user_agent::{UserAgentMiddleware, UserAgentSource};
use crate::pipeline::Pipeline;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use num_cpus;
#[cfg(feature = "checkpoint")]
use std::fs;
use std::marker::PhantomData;
#[cfg(feature = "checkpoint")]
use std::path::{Path, PathBuf};
use std::sync::Arc;
#[cfg(feature = "checkpoint")]
use std::time::Duration;
#[cfg(feature = "checkpoint")]
use tracing::{debug, warn};

use super::Crawler;

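/// Concurrency settings shared by the crawl loop's download, parse, and
/// pipeline stages.
///
/// A construction sketch using struct-update syntax over the defaults:
///
/// ```ignore
/// let config = CrawlerConfig {
///     max_concurrent_downloads: 16,
///     ..CrawlerConfig::default()
/// };
/// ```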
pub struct CrawlerConfig {
    /// Maximum number of requests downloaded concurrently.
    pub max_concurrent_downloads: usize,
    /// Number of worker tasks that parse downloaded responses.
    pub parser_workers: usize,
    /// Maximum number of items run through the item pipelines concurrently.
    pub max_concurrent_pipelines: usize,
}

impl Default for CrawlerConfig {
    fn default() -> Self {
        CrawlerConfig {
            max_concurrent_downloads: 5,
            // One parser worker per logical CPU by default.
            parser_workers: num_cpus::get(),
            max_concurrent_pipelines: 5,
        }
    }
}

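/// Fluent builder that assembles a [`Crawler`] from a [`Spider`], an optional
/// custom [`Downloader`], middlewares, and item pipelines.
///
/// A minimal usage sketch; `MySpider` is a hypothetical type assumed to
/// implement [`Spider`]:
///
/// ```ignore
/// let crawler = CrawlerBuilder::<MySpider>::new(MySpider::new())
///     .max_concurrent_downloads(10)
///     .max_parser_workers(4)
///     .build()
///     .await?;
/// ```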
pub struct CrawlerBuilder<S: Spider, D = ReqwestClientDownloader>
where
    D: Downloader,
{
    crawler_config: CrawlerConfig,
    downloader: D,
    spider: Option<S>,
    middlewares: Vec<Box<dyn Middleware<D::Client> + Send + Sync>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    #[cfg(feature = "checkpoint")]
    checkpoint_path: Option<PathBuf>,
    #[cfg(feature = "checkpoint")]
    checkpoint_interval: Option<Duration>,
    _phantom: PhantomData<S>,
}

impl<S: Spider, D: Default + Downloader> Default for CrawlerBuilder<S, D> {
    fn default() -> Self {
        Self {
            crawler_config: CrawlerConfig::default(),
            downloader: D::default(),
            spider: None,
            middlewares: Vec::new(),
            item_pipelines: Vec::new(),
            #[cfg(feature = "checkpoint")]
            checkpoint_path: None,
            #[cfg(feature = "checkpoint")]
            checkpoint_interval: None,
            _phantom: PhantomData,
        }
    }
}

impl<S: Spider, D: Downloader> CrawlerBuilder<S, D> {
    /// Creates a builder around `spider`, using `D::default()` as the downloader.
    pub fn new(spider: S) -> Self
    where
        D: Default,
    {
        Self {
            spider: Some(spider),
            ..Default::default()
        }
    }

    /// Sets the maximum number of concurrent downloads. Defaults to 5.
    pub fn max_concurrent_downloads(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_downloads = limit;
        self
    }

    /// Sets the number of parser worker tasks. Defaults to the number of
    /// logical CPUs.
    pub fn max_parser_workers(mut self, limit: usize) -> Self {
        self.crawler_config.parser_workers = limit;
        self
    }

    /// Sets the maximum number of items processed concurrently by the
    /// pipelines. Defaults to 5.
    pub fn max_concurrent_pipelines(mut self, limit: usize) -> Self {
        self.crawler_config.max_concurrent_pipelines = limit;
        self
    }

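    /// Replaces the downloader used to fetch requests.
    ///
    /// A sketch; `MyDownloader` is a hypothetical [`Downloader`] implementation:
    ///
    /// ```ignore
    /// let builder = CrawlerBuilder::<MySpider, MyDownloader>::new(spider)
    ///     .downloader(MyDownloader::default());
    /// ```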
    pub fn downloader(mut self, downloader: D) -> Self {
        self.downloader = downloader;
        self
    }

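    /// Appends a middleware to the request/response chain.
    ///
    /// A sketch using the crate's own [`UserAgentMiddleware`], mirroring the
    /// default that [`build`](Self::build) installs when none is configured:
    ///
    /// ```ignore
    /// let ua = UserAgentMiddleware::builder()
    ///     .source(UserAgentSource::List(vec!["my-bot/1.0".to_string()]))
    ///     .fallback_user_agent("my-bot/1.0".to_string())
    ///     .build()?;
    /// let builder = builder.add_middleware(ua);
    /// ```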
    pub fn add_middleware<M>(mut self, middleware: M) -> Self
    where
        M: Middleware<D::Client> + Send + Sync + 'static,
    {
        self.middlewares.push(Box::new(middleware));
        self
    }

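    /// Appends an item pipeline. When no pipeline is added,
    /// [`build`](Self::build) falls back to a [`ConsoleWriterPipeline`].
    ///
    /// ```ignore
    /// let builder = builder.add_pipeline(ConsoleWriterPipeline::new());
    /// ```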
    pub fn add_pipeline<P>(mut self, pipeline: P) -> Self
    where
        P: Pipeline<S::Item> + 'static,
    {
        self.item_pipelines.push(Box::new(pipeline));
        self
    }

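    /// Persists crawl state to `path` so that an interrupted crawl can resume.
    ///
    /// A sketch combining both checkpoint knobs; the file name is illustrative:
    ///
    /// ```ignore
    /// let builder = builder
    ///     .with_checkpoint_path("crawl.checkpoint")
    ///     .with_checkpoint_interval(Duration::from_secs(60));
    /// ```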
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.checkpoint_path = Some(path.as_ref().to_path_buf());
        self
    }

    /// Sets how often the checkpoint file is rewritten during the crawl.
    #[cfg(feature = "checkpoint")]
    pub fn with_checkpoint_interval(mut self, interval: Duration) -> Self {
        self.checkpoint_interval = Some(interval);
        self
    }

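    /// Validates the configuration and assembles the [`Crawler`].
    ///
    /// Defaults applied here: a [`ConsoleWriterPipeline`] when no pipeline was
    /// added, and a "crate-name/version" [`UserAgentMiddleware`] when none was
    /// configured. With the `checkpoint` feature enabled, a readable checkpoint
    /// at the configured path restores scheduler and pipeline state before the
    /// crawler is constructed.
    ///
    /// # Errors
    ///
    /// Returns [`SpiderError::ConfigurationError`] if no spider was supplied or
    /// if `max_concurrent_downloads` or `parser_workers` is zero.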
    pub async fn build(mut self) -> Result<Crawler<S, D::Client>, SpiderError>
    where
        D: Send + Sync + 'static,
        D::Client: Send + Sync,
    {
        // Fall back to printing items to the console when no pipeline was added.
        if self.item_pipelines.is_empty() {
            self = self.add_pipeline(ConsoleWriterPipeline::new());
        }

        let spider = self.spider.take().ok_or_else(|| {
            SpiderError::ConfigurationError("Crawler must have a spider.".to_string())
        })?;

        // A crawl cannot make progress with zero downloaders or parser workers.
        if self.crawler_config.max_concurrent_downloads == 0 {
            return Err(SpiderError::ConfigurationError(
                "max_concurrent_downloads must be greater than 0.".to_string(),
            ));
        }
        if self.crawler_config.parser_workers == 0 {
            return Err(SpiderError::ConfigurationError(
                "parser_workers must be greater than 0.".to_string(),
            ));
        }

        // Restore prior crawl state from the checkpoint file, if one exists.
        #[cfg(feature = "checkpoint")]
        let mut initial_scheduler_state = None;
        #[cfg(not(feature = "checkpoint"))]
        let initial_scheduler_state = None;
        #[cfg(feature = "checkpoint")]
        let mut loaded_pipelines_state = None;

        #[cfg(feature = "checkpoint")]
        if let Some(path) = &self.checkpoint_path {
            debug!("Attempting to load checkpoint from {:?}", path);
            // A missing or corrupt checkpoint is not fatal: warn and start fresh.
            match fs::read(path) {
                Ok(bytes) => match rmp_serde::from_slice::<Checkpoint>(&bytes) {
                    Ok(checkpoint) => {
                        initial_scheduler_state = Some(checkpoint.scheduler);
                        loaded_pipelines_state = Some(checkpoint.pipelines);
                    }
                    Err(e) => warn!("Failed to deserialize checkpoint from {:?}: {}", path, e),
                },
                Err(e) => warn!("Failed to read checkpoint file {:?}: {}", path, e),
            }
        }

        // Hand each pipeline its saved state back, matched by pipeline name.
        #[cfg(feature = "checkpoint")]
        if let Some(pipeline_states) = loaded_pipelines_state {
            for (name, state) in pipeline_states {
                if let Some(pipeline) = self.item_pipelines.iter().find(|p| p.name() == name) {
                    pipeline.restore_state(state).await?;
                } else {
                    warn!("Checkpoint contains state for unknown pipeline: {}", name);
                }
            }
        }

        let (scheduler_arc, req_rx) = Scheduler::new(initial_scheduler_state);

        // Install a default "crate-name/version" User-Agent at the front of the
        // chain unless the caller already configured a UserAgentMiddleware.
        let has_user_agent_middleware = self
            .middlewares
            .iter()
            .any(|m| m.name() == "UserAgentMiddleware");

        if !has_user_agent_middleware {
            let pkg_name = env!("CARGO_PKG_NAME");
            let pkg_version = env!("CARGO_PKG_VERSION");
            let default_user_agent = format!("{}/{}", pkg_name, pkg_version);

            let default_user_agent_mw = UserAgentMiddleware::builder()
                .source(UserAgentSource::List(vec![default_user_agent.clone()]))
                .fallback_user_agent(default_user_agent)
                .build()?;
            self.middlewares.insert(0, Box::new(default_user_agent_mw));
        }

        let downloader_arc = Arc::new(self.downloader);

        let crawler = Crawler::new(
            scheduler_arc,
            req_rx,
            downloader_arc,
            self.middlewares,
            spider,
            self.item_pipelines,
            self.crawler_config.max_concurrent_downloads,
            self.crawler_config.parser_workers,
            self.crawler_config.max_concurrent_pipelines,
            #[cfg(feature = "checkpoint")]
            self.checkpoint_path.take(),
            #[cfg(feature = "checkpoint")]
            self.checkpoint_interval,
        );

        Ok(crawler)
    }
}
260}