// spider_lib/crawler.rs

//! The core Crawler implementation for the `spider-lib` framework.
//!
//! This module defines the `Crawler` struct, which acts as the central orchestrator
//! for the web scraping process. It ties together the scheduler, downloader,
//! middlewares, spiders, and item pipelines to execute a crawl. The crawler
//! manages the lifecycle of requests and items, handles concurrency, and supports
//! checkpointing for fault tolerance.
//!
//! It uses a task-based asynchronous model, spawning distinct tasks for
//! handling initial requests, downloading web pages, parsing responses, and
//! processing scraped items.
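//!
//! # Example (illustrative)
//!
//! A minimal sketch of driving a crawl. `Crawler::new` is `pub(crate)`, so a
//! crawler is normally constructed elsewhere in the crate; the builder name
//! and methods below are assumptions for illustration, not this crate's
//! exact API:
//!
//! ```ignore
//! // Hypothetical construction API -- adjust to the crate's real entry point.
//! let crawler = CrawlerBuilder::new(MySpider::default())
//!     .max_concurrent_downloads(16)
//!     .parser_workers(4)
//!     .build()?;
//! crawler.start_crawl().await?;
//! ```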

use crate::downloader::Downloader;
use crate::error::SpiderError;
use crate::item::{ParseOutput, ScrapedItem};
use crate::middleware::{Middleware, MiddlewareAction};
use crate::pipeline::Pipeline;
use crate::request::Request;
use crate::response::Response;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use crate::state::CrawlerState;
use anyhow::Result;
use futures_util::future::join_all;
use kanal::{AsyncReceiver, AsyncSender, bounded_async};

#[cfg(feature = "checkpoint")]
use crate::checkpoint::save_checkpoint;
#[cfg(feature = "checkpoint")]
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
use tracing::{debug, error, info, warn};

/// The central orchestrator of a crawl, wiring the scheduler, downloader,
/// middlewares, spider, and item pipelines together.
pub struct Crawler<S: Spider, C> {
    scheduler: Arc<Scheduler>,
    req_rx: AsyncReceiver<Request>,
    downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
    middlewares: Vec<Box<dyn Middleware<C> + Send + Sync>>,
    spider: Arc<Mutex<S>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    max_concurrent_downloads: usize,
    parser_workers: usize,
    max_concurrent_pipelines: usize,
    #[cfg(feature = "checkpoint")]
    checkpoint_path: Option<PathBuf>,
    #[cfg(feature = "checkpoint")]
    checkpoint_interval: Option<Duration>,
}

impl<S, C> Crawler<S, C>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
    C: Send + Sync + 'static,
{
    /// Assembles a `Crawler` from its already-constructed parts; construction
    /// is crate-internal.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        scheduler: Arc<Scheduler>,
        req_rx: AsyncReceiver<Request>,
        downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
        middlewares: Vec<Box<dyn Middleware<C> + Send + Sync>>,
        spider: S,
        item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
        max_concurrent_downloads: usize,
        parser_workers: usize,
        max_concurrent_pipelines: usize,
        #[cfg(feature = "checkpoint")] checkpoint_path: Option<PathBuf>,
        #[cfg(feature = "checkpoint")] checkpoint_interval: Option<Duration>,
    ) -> Self {
        Crawler {
            scheduler,
            req_rx,
            downloader,
            middlewares,
            spider: Arc::new(Mutex::new(spider)),
            item_pipelines,
            max_concurrent_downloads,
            parser_workers,
            max_concurrent_pipelines,
            #[cfg(feature = "checkpoint")]
            checkpoint_path,
            #[cfg(feature = "checkpoint")]
            checkpoint_interval,
        }
    }

    /// Runs the crawl to completion: spawns the actor tasks, waits until the
    /// crawl becomes idle (or Ctrl-C is received), then shuts everything down
    /// gracefully.
    pub async fn start_crawl(self) -> Result<(), SpiderError> {
        info!("Crawler starting crawl");

        let Crawler {
            scheduler,
            req_rx,
            downloader,
            middlewares,
            spider,
            item_pipelines,
            max_concurrent_downloads,
            parser_workers,
            max_concurrent_pipelines,
            #[cfg(feature = "checkpoint")]
            checkpoint_path,
            #[cfg(feature = "checkpoint")]
            checkpoint_interval,
        } = self;

        let state = CrawlerState::new();
        let pipelines = Arc::new(item_pipelines);
        // Leave slack beyond the download concurrency so downloaders are not
        // immediately blocked on a full channel while parsing catches up.
        let channel_capacity = max_concurrent_downloads * 2;

        let (res_tx, res_rx) = bounded_async(channel_capacity);
        let (item_tx, item_rx) = bounded_async(channel_capacity);

        let initial_requests_task =
            spawn_initial_requests_task::<S>(scheduler.clone(), spider.clone());

        let downloader_task = spawn_downloader_task::<S, C>(
            scheduler.clone(),
            req_rx,
            downloader,
            Arc::new(Mutex::new(middlewares)),
            state.clone(),
            res_tx.clone(),
            max_concurrent_downloads,
        );

        let parser_task = spawn_parser_task::<S>(
            scheduler.clone(),
            spider.clone(),
            state.clone(),
            res_rx,
            item_tx.clone(),
            parser_workers,
        );

        let item_processor_task = spawn_item_processor_task::<S>(
            state.clone(),
            item_rx,
            pipelines.clone(),
            max_concurrent_pipelines,
        );

        #[cfg(feature = "checkpoint")]
        let periodic_checkpoint_task = if let (Some(path), Some(interval)) =
            (&checkpoint_path, checkpoint_interval)
        {
            let scheduler_clone = scheduler.clone();
            let pipelines_clone = pipelines.clone();
            let path_clone = path.clone();

            Some(tokio::spawn(async move {
                let mut interval_timer = tokio::time::interval(interval);
                // `interval` yields immediately; consume the first tick so the
                // first save happens one full interval after startup.
                interval_timer.tick().await;
                loop {
                    interval_timer.tick().await;
                    match scheduler_clone.snapshot().await {
                        Ok(checkpoint) => {
                            if let Err(e) =
                                save_checkpoint::<S>(&path_clone, checkpoint, &pipelines_clone)
                                    .await
                            {
                                error!("Periodic checkpoint save failed: {}", e);
                            }
                        }
                        Err(e) => warn!("Periodic checkpoint snapshot failed: {}", e),
                    }
                }
            }))
        } else {
            None
        };

        tokio::select! {
            _ = tokio::signal::ctrl_c() => {
                info!("Ctrl-C received, initiating graceful shutdown.");
            }
            _ = async {
                loop {
                    if scheduler.is_idle() && state.is_idle() {
                        // Re-check after a short pause so a request that is
                        // momentarily between queues does not trigger a
                        // premature shutdown.
                        tokio::time::sleep(Duration::from_millis(50)).await;
                        if scheduler.is_idle() && state.is_idle() {
                            break;
                        }
                    }
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            } => {
                info!("Crawl has become idle, initiating shutdown.");
            }
        }

        info!("Initiating actor shutdowns.");

        // Stop the periodic checkpoint saver (if any) before taking the final
        // snapshot so the two cannot race.
        #[cfg(feature = "checkpoint")]
        if let Some(task) = periodic_checkpoint_task {
            task.abort();
        }

        #[cfg(feature = "checkpoint")]
        let scheduler_checkpoint = scheduler.snapshot().await?;

        // Dropping the crawler's senders starts the cascade of channel
        // closures that lets each downstream actor drain and exit.
        drop(res_tx);
        drop(item_tx);

        scheduler.shutdown().await?;

        item_processor_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Item processor task failed: {}", e)))?;

        parser_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Parser task failed: {}", e)))?;

        downloader_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Downloader task failed: {}", e)))?;

        initial_requests_task.await.map_err(|e| {
            SpiderError::GeneralError(format!("Initial requests task failed: {}", e))
        })?;

        #[cfg(feature = "checkpoint")]
        if let Some(path) = &checkpoint_path
            && let Err(e) = save_checkpoint::<S>(path, scheduler_checkpoint, &pipelines).await
        {
            error!("Final checkpoint save failed: {}", e);
        }

        // Close all pipelines
        info!("Closing item pipelines...");
        let closing_futures: Vec<_> = pipelines.iter().map(|p| p.close()).collect();
        join_all(closing_futures).await;

        info!("Crawl finished successfully.");
        Ok(())
    }
}

/// Spawns a task that asks the spider for its start requests and enqueues
/// them with the scheduler.
fn spawn_initial_requests_task<S>(
    scheduler: Arc<Scheduler>,
    spider: Arc<Mutex<S>>,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    tokio::spawn(async move {
        match spider.lock().await.start_requests() {
            Ok(requests) => {
                for mut req in requests {
                    // Strip URL fragments so `page#a` and `page#b` schedule
                    // as the same request.
                    req.url.set_fragment(None);
                    if let Err(e) = scheduler.enqueue_request(req).await {
                        error!("Failed to enqueue initial request: {}", e);
                    }
                }
            }
            Err(e) => error!("Failed to create start requests: {}", e),
        }
    })
}

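/// Spawns the downloader actor: each request flows through the request
/// middlewares, is downloaded (unless a middleware returned a response
/// early), then flows through the response middlewares in reverse order
/// before being forwarded to the parser.
///
/// How each `MiddlewareAction` is interpreted here (a summary of the match
/// arms below, not additional behavior):
///
/// ```ignore
/// match action {
///     MiddlewareAction::Continue(next)       => { /* keep going with `next` */ }
///     MiddlewareAction::Retry(req, delay)    => { /* sleep `delay`, re-enqueue `*req`, stop */ }
///     MiddlewareAction::Drop                 => { /* discard the request/response, stop */ }
///     MiddlewareAction::ReturnResponse(resp) => { /* skip the download, use `resp` */ }
/// }
/// ```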
#[allow(clippy::too_many_arguments)]
fn spawn_downloader_task<S, C>(
    scheduler: Arc<Scheduler>,
    req_rx: AsyncReceiver<Request>,
    downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
    middlewares: Arc<Mutex<Vec<Box<dyn Middleware<C> + Send + Sync>>>>,
    state: Arc<CrawlerState>,
    res_tx: AsyncSender<Response>,
    max_concurrent_downloads: usize,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
    C: Send + Sync + 'static,
{
    let semaphore = Arc::new(Semaphore::new(max_concurrent_downloads));
    let mut tasks = JoinSet::new();

    tokio::spawn(async move {
        while let Ok(request) = req_rx.recv().await {
            let permit = match semaphore.clone().acquire_owned().await {
                Ok(permit) => permit,
                Err(_) => {
                    warn!("Semaphore closed, shutting down downloader actor.");
                    break;
                }
            };

            state.in_flight_requests.fetch_add(1, Ordering::SeqCst);
            let downloader_clone = Arc::clone(&downloader);
            let middlewares_clone = Arc::clone(&middlewares);
            let res_tx_clone = res_tx.clone();
            let state_clone = Arc::clone(&state);
            let scheduler_clone = Arc::clone(&scheduler);

            tasks.spawn(async move {
                // Run the per-request work in an inner future so that the
                // in-flight counter is decremented on every exit path,
                // including the early `return`s for Retry/Drop/errors.
                // (Decrementing only on the happy path would leak the count
                // and stall idle detection.)
                async {
                    let mut processed_request = request;
                    let mut early_returned_response: Option<Response> = None;

                    // Process request middlewares
                    for mw in middlewares_clone.lock().await.iter_mut() {
                        match mw
                            .process_request(downloader_clone.client(), processed_request.clone())
                            .await
                        {
                            Ok(MiddlewareAction::Continue(req)) => {
                                processed_request = req;
                            }
                            Ok(MiddlewareAction::Retry(req, delay)) => {
                                tokio::time::sleep(delay).await;
                                if scheduler_clone.enqueue_request(*req).await.is_err() {
                                    error!("Failed to re-enqueue retried request.");
                                }
                                return;
                            }
                            Ok(MiddlewareAction::Drop) => {
                                debug!("Request dropped by middleware.");
                                return;
                            }
                            Ok(MiddlewareAction::ReturnResponse(resp)) => {
                                early_returned_response = Some(resp);
                                break;
                            }
                            Err(e) => {
                                error!("Request middleware error: {:?}", e);
                                return;
                            }
                        }
                    }

                    // Download, unless a middleware already supplied a response.
                    let mut response = match early_returned_response {
                        Some(resp) => resp,
                        None => match downloader_clone.download(processed_request).await {
                            Ok(resp) => resp,
                            Err(e) => {
                                error!("Download error: {:?}", e);
                                return;
                            }
                        },
                    };

                    // Process response middlewares in reverse registration order.
                    for mw in middlewares_clone.lock().await.iter_mut().rev() {
                        match mw.process_response(response.clone()).await {
                            Ok(MiddlewareAction::Continue(res)) => {
                                response = res;
                            }
                            Ok(MiddlewareAction::Retry(req, delay)) => {
                                tokio::time::sleep(delay).await;
                                if scheduler_clone.enqueue_request(*req).await.is_err() {
                                    error!("Failed to re-enqueue retried request.");
                                }
                                return;
                            }
                            Ok(MiddlewareAction::Drop) => {
                                debug!("Response dropped by middleware.");
                                return;
                            }
                            Ok(MiddlewareAction::ReturnResponse(_)) => {
                                debug!("ReturnResponse action encountered in process_response; this is unexpected.");
                                continue;
                            }
                            Err(e) => {
                                error!("Response middleware error: {:?}", e);
                                return;
                            }
                        }
                    }

                    if res_tx_clone.send(response).await.is_err() {
                        error!("Response channel closed, cannot send downloaded response.");
                    }
                }
                .await;

                state_clone.in_flight_requests.fetch_sub(1, Ordering::SeqCst);
                drop(permit);
            });
        }
        // The request channel closed; wait for all outstanding downloads.
        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("A download task failed: {:?}", e);
            }
        }
    })
}

/// Spawns the parser actor: a fan-out stage that feeds downloaded responses
/// to `parser_workers` worker tasks, each of which runs `Spider::parse` and
/// forwards the resulting requests and items.
fn spawn_parser_task<S>(
    scheduler: Arc<Scheduler>,
    spider: Arc<Mutex<S>>,
    state: Arc<CrawlerState>,
    res_rx: AsyncReceiver<Response>,
    item_tx: AsyncSender<S::Item>,
    parser_workers: usize,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let mut tasks = JoinSet::new();
    let (internal_parse_tx, internal_parse_rx) = bounded_async(parser_workers * 2);

    // Spawn N parsing worker tasks
    for _ in 0..parser_workers {
        let internal_parse_rx_clone = internal_parse_rx.clone();
        let spider_clone = Arc::clone(&spider);
        let scheduler_clone = Arc::clone(&scheduler);
        let item_tx_clone = item_tx.clone();
        let state_clone = Arc::clone(&state);

        tasks.spawn(async move {
            while let Ok(response) = internal_parse_rx_clone.recv().await {
                debug!("Parsing response from {}", response.url);
                // Note: the spider mutex is held across `parse`, so workers
                // serialize on the parse call itself.
                match spider_clone.lock().await.parse(response).await {
                    Ok(outputs) => {
                        process_crawl_outputs::<S>(
                            outputs,
                            scheduler_clone.clone(),
                            item_tx_clone.clone(),
                            state_clone.clone(),
                        )
                        .await;
                    }
                    Err(e) => error!("Spider parsing error: {:?}", e),
                }
                state_clone.parsing_responses.fetch_sub(1, Ordering::SeqCst);
            }
        });
    }

    tokio::spawn(async move {
        while let Ok(response) = res_rx.recv().await {
            state.parsing_responses.fetch_add(1, Ordering::SeqCst);
            if internal_parse_tx.send(response).await.is_err() {
                error!("Internal parse channel closed, cannot send response to parser worker.");
                state.parsing_responses.fetch_sub(1, Ordering::SeqCst);
            }
        }

        // Closing the internal channel lets the workers drain and exit.
        drop(internal_parse_tx);

        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("A parsing worker task failed: {:?}", e);
            }
        }
    })
}

/// Enqueues the requests and forwards the items produced by a single
/// `Spider::parse` call.
async fn process_crawl_outputs<S>(
    outputs: ParseOutput<S::Item>,
    scheduler: Arc<Scheduler>,
    item_tx: AsyncSender<S::Item>,
    state: Arc<CrawlerState>,
) where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let (items, requests) = outputs.into_parts();
    info!(
        "Spider output yielded {} requests and {} items.",
        requests.len(),
        items.len()
    );

    let mut request_error_total = 0;
    for request in requests {
        if scheduler.enqueue_request(request).await.is_err() {
            request_error_total += 1;
        }
    }
    if request_error_total > 0 {
        error!(
            "Failed to enqueue {} requests: scheduler channel closed.",
            request_error_total
        );
    }

    let mut item_error_total = 0;
    for item in items {
        // Count the item as in progress before handing it off; undo the
        // count if the channel is already closed.
        state.processing_items.fetch_add(1, Ordering::SeqCst);
        if item_tx.send(item).await.is_err() {
            item_error_total += 1;
            state.processing_items.fetch_sub(1, Ordering::SeqCst);
        }
    }
    if item_error_total > 0 {
        error!(
            "Failed to send {} scraped items: channel closed.",
            item_error_total
        );
    }
}

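/// Spawns the item-processor actor: each scraped item is passed through the
/// pipeline chain in order; `Ok(Some(item))` hands the item to the next
/// pipeline, `Ok(None)` drops it, and `Err(_)` aborts that item's chain.
///
/// An illustrative pipeline sketch, inferred from the call sites below (the
/// exact `Pipeline` trait definition lives in `crate::pipeline`; the
/// `async_trait` attribute, method signatures, and the `Product` item type
/// here are assumptions):
///
/// ```ignore
/// struct MinPriceFilter(f64);
///
/// #[async_trait::async_trait]
/// impl Pipeline<Product> for MinPriceFilter {
///     fn name(&self) -> &str {
///         "min_price_filter"
///     }
///
///     async fn process_item(&self, item: Product) -> Result<Option<Product>> {
///         // Drop cheap items; pass the rest through unchanged.
///         Ok((item.price >= self.0).then_some(item))
///     }
///
///     async fn close(&self) {}
/// }
/// ```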
fn spawn_item_processor_task<S>(
    state: Arc<CrawlerState>,
    item_rx: AsyncReceiver<S::Item>,
    pipelines: Arc<Vec<Box<dyn Pipeline<S::Item>>>>,
    max_concurrent_pipelines: usize,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let mut tasks = JoinSet::new();
    let semaphore = Arc::new(Semaphore::new(max_concurrent_pipelines));
    tokio::spawn(async move {
        while let Ok(item) = item_rx.recv().await {
            let permit = match semaphore.clone().acquire_owned().await {
                Ok(p) => p,
                Err(_) => {
                    warn!("Semaphore closed, shutting down item processor actor.");
                    break;
                }
            };

            let state_clone = Arc::clone(&state);
            let pipelines_clone = Arc::clone(&pipelines);
            tasks.spawn(async move {
                let mut item_to_process = Some(item);
                for pipeline in pipelines_clone.iter() {
                    if let Some(item) = item_to_process.take() {
                        match pipeline.process_item(item).await {
                            Ok(Some(next_item)) => item_to_process = Some(next_item),
                            Ok(None) => break,
                            Err(e) => {
                                error!("Pipeline error for {}: {:?}", pipeline.name(), e);
                                break;
                            }
                        }
                    } else {
                        break;
                    }
                }
                state_clone.processing_items.fetch_sub(1, Ordering::SeqCst);
                drop(permit);
            });
        }
        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("An item processing task failed: {:?}", e);
            }
        }
    })
}