use crate::checkpoint::save_checkpoint;
use crate::downloader::Downloader;
use crate::error::SpiderError;
use crate::item::{ParseOutput, ScrapedItem};
use crate::middleware::{Middleware, MiddlewareAction};
use crate::pipeline::Pipeline;
use crate::request::Request;
use crate::response::Response;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use crate::state::CrawlerState;

use anyhow::Result;
use futures_util::future::join_all;
use kanal::{AsyncReceiver, AsyncSender, bounded_async};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::sync::{Mutex, Semaphore};
use tokio::task::JoinSet;
use tracing::{debug, error, info, warn};
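
/// Coordinates a crawl: holds the scheduler, downloader, middleware chain, spider, and
/// item pipelines, and drives them as a set of cooperating async tasks.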
pub struct Crawler<S: Spider, C> {
    scheduler: Arc<Scheduler>,
    req_rx: AsyncReceiver<Request>,
    downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
    middlewares: Vec<Box<dyn Middleware<C> + Send + Sync>>,
    spider: Arc<Mutex<S>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    max_concurrent_downloads: usize,
    parser_workers: usize,
    max_concurrent_pipelines: usize,
    checkpoint_path: Option<PathBuf>,
    checkpoint_interval: Option<Duration>,
}

impl<S, C> Crawler<S, C>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
    C: Send + Sync + 'static,
{
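    /// Assembles a crawler from its pre-built components (crate-internal constructor).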
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        scheduler: Arc<Scheduler>,
        req_rx: AsyncReceiver<Request>,
        downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
        middlewares: Vec<Box<dyn Middleware<C> + Send + Sync>>,
        spider: S,
        item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
        max_concurrent_downloads: usize,
        parser_workers: usize,
        max_concurrent_pipelines: usize,
        checkpoint_path: Option<PathBuf>,
        checkpoint_interval: Option<Duration>,
    ) -> Self {
        Crawler {
            scheduler,
            req_rx,
            downloader,
            middlewares,
            spider: Arc::new(Mutex::new(spider)),
            item_pipelines,
            max_concurrent_downloads,
            parser_workers,
            max_concurrent_pipelines,
            checkpoint_path,
            checkpoint_interval,
        }
    }
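
    /// Runs the crawl to completion.
    ///
    /// Spawns the actor tasks (initial requests, downloader, parser, item processor),
    /// optionally saves periodic checkpoints, and waits for either Ctrl-C or an idle
    /// crawl before shutting everything down, writing a final checkpoint, and closing
    /// the item pipelines.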
    pub async fn start_crawl(self) -> Result<(), SpiderError> {
        info!("Crawler starting crawl");

        let Crawler {
            scheduler,
            req_rx,
            downloader,
            middlewares,
            spider,
            item_pipelines,
            max_concurrent_downloads,
            parser_workers,
            max_concurrent_pipelines,
            checkpoint_path,
            checkpoint_interval,
        } = self;
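
        // Shared crawl state plus bounded channels between the stages; sizing the
        // channels at twice the download concurrency lets them apply backpressure.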
        let state = CrawlerState::new();
        let pipelines = Arc::new(item_pipelines);
        let channel_capacity = max_concurrent_downloads * 2;

        let (res_tx, res_rx) = bounded_async(channel_capacity);
        let (item_tx, item_rx) = bounded_async(channel_capacity);
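
        // Spawn the actor stages: seed the scheduler with start requests, download,
        // parse, and run scraped items through the pipelines.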
        let initial_requests_task =
            spawn_initial_requests_task::<S>(scheduler.clone(), spider.clone());

        let downloader_task = spawn_downloader_task::<S, C>(
            scheduler.clone(),
            req_rx,
            downloader,
            Arc::new(Mutex::new(middlewares)),
            state.clone(),
            res_tx.clone(),
            max_concurrent_downloads,
        );

        let parser_task = spawn_parser_task::<S>(
            scheduler.clone(),
            spider.clone(),
            state.clone(),
            res_rx,
            item_tx.clone(),
            parser_workers,
        );

        let item_processor_task = spawn_item_processor_task::<S>(
            state.clone(),
            item_rx,
            pipelines.clone(),
            max_concurrent_pipelines,
        );
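
        // If configured, periodically snapshot the scheduler and save a checkpoint.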
        let periodic_checkpoint_task = if let (Some(path), Some(interval)) =
            (&checkpoint_path, checkpoint_interval)
        {
            let scheduler_clone = scheduler.clone();
            let pipelines_clone = pipelines.clone();
            let path_clone = path.clone();

            Some(tokio::spawn(async move {
                let mut interval_timer = tokio::time::interval(interval);
                // The first tick of a tokio interval completes immediately; consume it
                // so the first checkpoint is written one full interval after startup.
                interval_timer.tick().await;
                loop {
                    interval_timer.tick().await;
                    match scheduler_clone.snapshot().await {
                        Ok(scheduler_checkpoint) => {
                            if let Err(e) = save_checkpoint::<S>(
                                &path_clone,
                                scheduler_checkpoint,
                                &pipelines_clone,
                            )
                            .await
                            {
                                error!("Periodic checkpoint save failed: {}", e);
                            }
                        }
                        Err(e) => warn!("Periodic checkpoint snapshot failed: {}", e),
                    }
                }
            }))
        } else {
            None
        };
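
        // Run until Ctrl-C or until the scheduler and crawler state both report idle;
        // idleness is re-checked after a short delay to avoid stopping while work is
        // being handed off between stages.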
        tokio::select! {
            _ = tokio::signal::ctrl_c() => {
                info!("Ctrl-C received, initiating graceful shutdown.");
            }
            _ = async {
                loop {
                    if scheduler.is_idle() && state.is_idle() {
                        tokio::time::sleep(Duration::from_millis(50)).await;
                        if scheduler.is_idle() && state.is_idle() {
                            break;
                        }
                    }
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            } => {
                info!("Crawl has become idle, initiating shutdown.");
            }
        }

        info!("Initiating actor shutdowns.");

        // Stop periodic checkpointing so a late periodic save cannot race with, or
        // overwrite, the final checkpoint written below.
        if let Some(task) = periodic_checkpoint_task {
            task.abort();
        }

        let scheduler_checkpoint = scheduler.snapshot().await?;
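
        // Dropping the crawler's senders (together with the scheduler shutdown below)
        // closes each stage's input channel once its upstream producers exit, letting
        // the actors drain their queues and stop.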
        drop(res_tx);
        drop(item_tx);

        scheduler.shutdown().await?;
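
        // Wait for every stage to finish draining before the final checkpoint is
        // written and the pipelines are closed.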
        item_processor_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Item processor task failed: {}", e)))?;

        parser_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Parser task failed: {}", e)))?;

        downloader_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Downloader task failed: {}", e)))?;

        initial_requests_task.await.map_err(|e| {
            SpiderError::GeneralError(format!("Initial requests task failed: {}", e))
        })?;

        if let Some(path) = &checkpoint_path
            && let Err(e) = save_checkpoint::<S>(path, scheduler_checkpoint, &pipelines).await
        {
            error!("Final checkpoint save failed: {}", e);
        }

        info!("Closing item pipelines...");
        let closing_futures: Vec<_> = pipelines.iter().map(|p| p.close()).collect();
        join_all(closing_futures).await;

        info!("Crawl finished successfully.");
        Ok(())
    }
}
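
/// Seeds the scheduler with the spider's start requests (URL fragments are stripped
/// before enqueueing).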
fn spawn_initial_requests_task<S>(
    scheduler: Arc<Scheduler>,
    spider: Arc<Mutex<S>>,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    tokio::spawn(async move {
        match spider.lock().await.start_requests() {
            Ok(requests) => {
                for mut req in requests {
                    req.url.set_fragment(None);
                    if let Err(e) = scheduler.enqueue_request(req).await {
                        error!("Failed to enqueue initial request: {}", e);
                    }
                }
            }
            Err(e) => error!("Failed to create start requests: {}", e),
        }
    })
}
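
/// Download actor: receives requests from the scheduler, runs the request middlewares,
/// downloads (unless a middleware short-circuits with a response), runs the response
/// middlewares in reverse order, and forwards the result to the parser stage.
/// Concurrency is bounded by a semaphore with `max_concurrent_downloads` permits.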
#[allow(clippy::too_many_arguments)]
fn spawn_downloader_task<S, C>(
    scheduler: Arc<Scheduler>,
    req_rx: AsyncReceiver<Request>,
    downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
    middlewares: Arc<Mutex<Vec<Box<dyn Middleware<C> + Send + Sync>>>>,
    state: Arc<CrawlerState>,
    res_tx: AsyncSender<Response>,
    max_concurrent_downloads: usize,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
    C: Send + Sync + 'static,
{
    let semaphore = Arc::new(Semaphore::new(max_concurrent_downloads));
    let mut tasks = JoinSet::new();

    tokio::spawn(async move {
        while let Ok(request) = req_rx.recv().await {
            let permit = match semaphore.clone().acquire_owned().await {
                Ok(permit) => permit,
                Err(_) => {
                    warn!("Semaphore closed, shutting down downloader actor.");
                    break;
                }
            };

            state.in_flight_requests.fetch_add(1, Ordering::SeqCst);
            let downloader_clone = Arc::clone(&downloader);
            let middlewares_clone = Arc::clone(&middlewares);
            let res_tx_clone = res_tx.clone();
            let state_clone = Arc::clone(&state);
            let scheduler_clone = Arc::clone(&scheduler);
            tasks.spawn(async move {
                // Run the middleware chains and the download inside an inner future so
                // the in-flight counter is decremented on every exit path (drops,
                // retries, and errors included), keeping idle detection accurate.
                let response = async {
                    let mut processed_request = request;
                    let mut early_returned_response: Option<Response> = None;

                    for mw in middlewares_clone.lock().await.iter_mut() {
                        match mw
                            .process_request(downloader_clone.client(), processed_request.clone())
                            .await
                        {
                            Ok(MiddlewareAction::Continue(req)) => {
                                processed_request = req;
                            }
                            Ok(MiddlewareAction::Retry(req, delay)) => {
                                tokio::time::sleep(delay).await;
                                if scheduler_clone.enqueue_request(*req).await.is_err() {
                                    error!("Failed to re-enqueue retried request.");
                                }
                                return None;
                            }
                            Ok(MiddlewareAction::Drop) => {
                                debug!("Request dropped by middleware.");
                                return None;
                            }
                            Ok(MiddlewareAction::ReturnResponse(resp)) => {
                                early_returned_response = Some(resp);
                                break;
                            }
                            Err(e) => {
                                error!("Request middleware error: {:?}", e);
                                return None;
                            }
                        }
                    }

                    let mut response = match early_returned_response {
                        Some(resp) => resp,
                        None => match downloader_clone.download(processed_request).await {
                            Ok(resp) => resp,
                            Err(e) => {
                                error!("Download error: {:?}", e);
                                return None;
                            }
                        },
                    };

                    for mw in middlewares_clone.lock().await.iter_mut().rev() {
                        match mw.process_response(response.clone()).await {
                            Ok(MiddlewareAction::Continue(res)) => {
                                response = res;
                            }
                            Ok(MiddlewareAction::Retry(req, delay)) => {
                                tokio::time::sleep(delay).await;
                                if scheduler_clone.enqueue_request(*req).await.is_err() {
                                    error!("Failed to re-enqueue retried request.");
                                }
                                return None;
                            }
                            Ok(MiddlewareAction::Drop) => {
                                debug!("Response dropped by middleware.");
                                return None;
                            }
                            Ok(MiddlewareAction::ReturnResponse(_)) => {
                                debug!("ReturnResponse is not meaningful in process_response; ignoring.");
                                continue;
                            }
                            Err(e) => {
                                error!("Response middleware error: {:?}", e);
                                return None;
                            }
                        }
                    }

                    Some(response)
                }
                .await;

                if let Some(response) = response {
                    if res_tx_clone.send(response).await.is_err() {
                        error!("Response channel closed, cannot forward downloaded response.");
                    }
                }

                state_clone.in_flight_requests.fetch_sub(1, Ordering::SeqCst);
                drop(permit);
            });
        }
        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("A download task failed: {:?}", e);
            }
        }
    })
}
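
/// Parser actor: a dispatcher plus a pool of `parser_workers` workers. The dispatcher
/// forwards downloaded responses to the workers, which call `Spider::parse` and route
/// the resulting requests and items onward.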
fn spawn_parser_task<S>(
    scheduler: Arc<Scheduler>,
    spider: Arc<Mutex<S>>,
    state: Arc<CrawlerState>,
    res_rx: AsyncReceiver<Response>,
    item_tx: AsyncSender<S::Item>,
    parser_workers: usize,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let mut tasks = JoinSet::new();
    let (internal_parse_tx, internal_parse_rx) = bounded_async::<Response>(parser_workers * 2);
    for _ in 0..parser_workers {
        let internal_parse_rx_clone = internal_parse_rx.clone();
        let spider_clone = Arc::clone(&spider);
        let scheduler_clone = Arc::clone(&scheduler);
        let item_tx_clone = item_tx.clone();
        let state_clone = Arc::clone(&state);

        tasks.spawn(async move {
            while let Ok(response) = internal_parse_rx_clone.recv().await {
                debug!("Parsing response from {}", response.url);
                match spider_clone.lock().await.parse(response).await {
                    Ok(outputs) => {
                        process_crawl_outputs::<S>(
                            outputs,
                            scheduler_clone.clone(),
                            item_tx_clone.clone(),
                            state_clone.clone(),
                        )
                        .await;
                    }
                    Err(e) => error!("Spider parsing error: {:?}", e),
                }
                state_clone.parsing_responses.fetch_sub(1, Ordering::SeqCst);
            }
        });
    }
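
    // Dispatcher: marks each response as being parsed, then hands it to the worker pool.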
    tokio::spawn(async move {
        while let Ok(response) = res_rx.recv().await {
            state.parsing_responses.fetch_add(1, Ordering::SeqCst);
            if internal_parse_tx.send(response).await.is_err() {
                error!("Internal parse channel closed, cannot send response to parser worker.");
                state.parsing_responses.fetch_sub(1, Ordering::SeqCst);
            }
        }

        drop(internal_parse_tx);

        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("A parsing worker task failed: {:?}", e);
            }
        }
    })
}
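
/// Routes a spider's parse output: new requests go back to the scheduler and scraped
/// items are sent to the item-processing stage.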
async fn process_crawl_outputs<S>(
    outputs: ParseOutput<S::Item>,
    scheduler: Arc<Scheduler>,
    item_tx: AsyncSender<S::Item>,
    state: Arc<CrawlerState>,
) where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let (items, requests) = outputs.into_parts();
    info!(
        "Spider output yielded {} requests and {} items.",
        requests.len(),
        items.len()
    );

    let mut request_error_total = 0;
    for request in requests {
        if scheduler.enqueue_request(request).await.is_err() {
            request_error_total += 1;
        }
    }
    if request_error_total > 0 {
        error!(
            "Failed to enqueue {} requests: scheduler channel closed.",
            request_error_total
        );
    }

    let mut item_error_total = 0;
    for item in items {
        state.processing_items.fetch_add(1, Ordering::SeqCst);
        if item_tx.send(item).await.is_err() {
            item_error_total += 1;
            state.processing_items.fetch_sub(1, Ordering::SeqCst);
        }
    }
    if item_error_total > 0 {
        error!(
            "Failed to send {} scraped items: channel closed.",
            item_error_total
        );
    }
}
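
/// Item-processing actor: runs each scraped item through the pipeline chain, where a
/// pipeline may transform the item, drop it by returning `Ok(None)`, or fail with an
/// error. Pipeline concurrency is bounded by `max_concurrent_pipelines` permits.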
fn spawn_item_processor_task<S>(
    state: Arc<CrawlerState>,
    item_rx: AsyncReceiver<S::Item>,
    pipelines: Arc<Vec<Box<dyn Pipeline<S::Item>>>>,
    max_concurrent_pipelines: usize,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let mut tasks = JoinSet::new();
    let semaphore = Arc::new(Semaphore::new(max_concurrent_pipelines));
    tokio::spawn(async move {
        while let Ok(item) = item_rx.recv().await {
            let permit = match semaphore.clone().acquire_owned().await {
                Ok(p) => p,
                Err(_) => {
                    warn!("Semaphore closed, shutting down item processor actor.");
                    break;
                }
            };

            let state_clone = Arc::clone(&state);
            let pipelines_clone = Arc::clone(&pipelines);
            tasks.spawn(async move {
                let mut item_to_process = Some(item);
                for pipeline in pipelines_clone.iter() {
                    if let Some(item) = item_to_process.take() {
                        match pipeline.process_item(item).await {
                            Ok(Some(next_item)) => item_to_process = Some(next_item),
                            Ok(None) => break,
                            Err(e) => {
                                error!("Pipeline error for {}: {:?}", pipeline.name(), e);
                                break;
                            }
                        }
                    } else {
                        break;
                    }
                }
                state_clone.processing_items.fetch_sub(1, Ordering::SeqCst);
                drop(permit);
            });
        }
        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("An item processing task failed: {:?}", e);
            }
        }
    })
}