use crate::checkpoint::{Checkpoint, SchedulerCheckpoint};
use crate::downloader::Downloader;
use crate::error::SpiderError;
use crate::item::{ParseOutput, ScrapedItem};
use crate::middleware::{Middleware, MiddlewareAction};
use crate::pipeline::Pipeline;
use crate::request::Request;
use crate::response::Response;
use crate::scheduler::Scheduler;
use crate::spider::Spider;
use crate::state::CrawlerState;
use anyhow::Result;
use futures_util::future::join_all;
use kanal::{AsyncReceiver, AsyncSender, bounded_async};
use std::collections::{HashMap, VecDeque};
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
use tracing::{debug, error, info, warn};

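/// Drives a crawl end to end: pulls requests from the [`Scheduler`], downloads
/// them through the configured middlewares, hands responses to the [`Spider`]
/// for parsing, and pushes scraped items through the item [`Pipeline`]s.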
pub struct Crawler<S: Spider, C> {
    scheduler: Arc<Scheduler>,
    req_rx: AsyncReceiver<Request>,
    downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
    middlewares: Vec<Box<dyn Middleware<C> + Send + Sync>>,
    spider: Arc<Mutex<S>>,
    item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
    max_concurrent_downloads: usize,
    parser_workers: usize,
    max_concurrent_pipelines: usize,
    checkpoint_path: Option<PathBuf>,
    checkpoint_interval: Option<Duration>,
}

impl<S, C> Crawler<S, C>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
    C: Send + Sync + 'static,
{
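    /// Assembles a crawler from already-constructed components.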
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        scheduler: Arc<Scheduler>,
        req_rx: AsyncReceiver<Request>,
        downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
        middlewares: Vec<Box<dyn Middleware<C> + Send + Sync>>,
        spider: S,
        item_pipelines: Vec<Box<dyn Pipeline<S::Item>>>,
        max_concurrent_downloads: usize,
        parser_workers: usize,
        max_concurrent_pipelines: usize,
        checkpoint_path: Option<PathBuf>,
        checkpoint_interval: Option<Duration>,
    ) -> Self {
        Crawler {
            scheduler,
            req_rx,
            downloader,
            middlewares,
            spider: Arc::new(Mutex::new(spider)),
            item_pipelines,
            max_concurrent_downloads,
            parser_workers,
            max_concurrent_pipelines,
            checkpoint_path,
            checkpoint_interval,
        }
    }

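    /// Runs the crawl to completion.
    ///
    /// Spawns the downloader, parser, and item-processor actors, optionally
    /// saves periodic checkpoints, and shuts everything down gracefully once
    /// the crawl becomes idle or Ctrl-C is received.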
    pub async fn start_crawl(self) -> Result<(), SpiderError> {
        info!("Crawler starting crawl");

        let Crawler {
            scheduler,
            req_rx,
            downloader,
            middlewares,
            spider,
            item_pipelines,
            max_concurrent_downloads,
            parser_workers,
            max_concurrent_pipelines,
            checkpoint_path,
            checkpoint_interval,
        } = self;

        let state = CrawlerState::new();
        let pipelines = Arc::new(item_pipelines);
        let channel_capacity = max_concurrent_downloads * 2;

        let (res_tx, res_rx) = bounded_async(channel_capacity);
        let (item_tx, item_rx) = bounded_async(channel_capacity);

        let (salvaged_requests_tx, salvaged_requests_rx) = bounded_async(channel_capacity);

        let initial_requests_task = spawn_initial_requests_task::<S>(
            scheduler.clone(),
            spider.clone(),
            salvaged_requests_tx.clone(),
        );

        let downloader_task = spawn_downloader_task::<S, C>(
            scheduler.clone(),
            req_rx,
            downloader,
            Arc::new(Mutex::new(middlewares)),
            state.clone(),
            res_tx.clone(),
            max_concurrent_downloads,
            salvaged_requests_tx.clone(),
        );

        let parser_task = spawn_parser_task::<S>(
            scheduler.clone(),
            spider.clone(),
            state.clone(),
            res_rx,
            item_tx.clone(),
            parser_workers,
            salvaged_requests_tx.clone(),
        );

        let item_processor_task = spawn_item_processor_task::<S>(
            state.clone(),
            item_rx,
            pipelines.clone(),
            max_concurrent_pipelines,
        );

        if let (Some(path), Some(interval)) = (&checkpoint_path, checkpoint_interval) {
            let scheduler_clone = scheduler.clone();
            let pipelines_clone = pipelines.clone();
            let path_clone = path.clone();
            let salvaged_requests_rx_clone = salvaged_requests_rx.clone();
            tokio::spawn(async move {
                let mut interval_timer = tokio::time::interval(interval);
                // The first tick completes immediately; consume it so we do not
                // write a checkpoint right at startup.
                interval_timer.tick().await;
                loop {
                    interval_timer.tick().await;
                    if let Ok(scheduler_checkpoint) = scheduler_clone.snapshot().await
                        && let Err(e) = save_checkpoint::<S>(
                            &path_clone,
                            scheduler_checkpoint,
                            &pipelines_clone,
                            salvaged_requests_rx_clone.clone(),
                        )
                        .await
                    {
                        error!("Periodic checkpoint save failed: {}", e);
                    }
                }
            });
        }

        tokio::select! {
            _ = tokio::signal::ctrl_c() => {
                info!("Ctrl-C received, initiating graceful shutdown.");
            }
            _ = async {
                loop {
                    if scheduler.is_idle() && state.is_idle() {
                        // Re-check after a short pause so a momentary lull between
                        // pipeline stages does not trigger a premature shutdown.
                        tokio::time::sleep(Duration::from_millis(50)).await;
                        if scheduler.is_idle() && state.is_idle() {
                            break;
                        }
                    }
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            } => {
                info!("Crawl has become idle, initiating shutdown.");
            }
        }

        info!("Initiating actor shutdowns.");

        let scheduler_checkpoint = scheduler.snapshot().await?;

        drop(res_tx);
        drop(item_tx);
        drop(salvaged_requests_tx);

        scheduler.shutdown().await?;

        item_processor_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Item processor task failed: {}", e)))?;

        parser_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Parser task failed: {}", e)))?;

        downloader_task
            .await
            .map_err(|e| SpiderError::GeneralError(format!("Downloader task failed: {}", e)))?;

        initial_requests_task.await.map_err(|e| {
            SpiderError::GeneralError(format!("Initial requests task failed: {}", e))
        })?;

        if let Some(path) = &checkpoint_path
            && let Err(e) =
                save_checkpoint::<S>(path, scheduler_checkpoint, &pipelines, salvaged_requests_rx)
                    .await
        {
            error!("Final checkpoint save failed: {}", e);
        }

        info!("Closing item pipelines...");
        let closing_futures: Vec<_> = pipelines.iter().map(|p| p.close()).collect();
        join_all(closing_futures).await;

        info!("Crawl finished successfully.");
        Ok(())
    }
}

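/// Asks the spider for its start requests and enqueues them on the scheduler.
/// Requests that cannot be enqueued are forwarded to the salvage channel so
/// they can be captured by the next checkpoint.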
fn spawn_initial_requests_task<S>(
    scheduler: Arc<Scheduler>,
    spider: Arc<Mutex<S>>,
    salvaged_requests_tx: AsyncSender<Request>,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    tokio::spawn(async move {
        match spider.lock().await.start_requests() {
            Ok(requests) => {
                for mut req in requests {
                    // Fragments are client-side only; strip them before scheduling.
                    req.url.set_fragment(None);
                    match scheduler.enqueue_request(req).await {
                        Ok(_) => {}
                        Err((req, e)) => {
                            error!("Failed to enqueue initial request: {}", e);
                            if salvaged_requests_tx.send(req).await.is_err() {
                                error!("Failed to send salvaged request to channel.");
                            }
                        }
                    }
                }
            }
            Err(e) => error!("Failed to create start requests: {}", e),
        }
    })
}

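/// Serializes the scheduler state, per-pipeline state, and any salvaged
/// requests to `path`. The checkpoint is written to a temporary file and then
/// renamed into place so a partially written file never replaces a good one.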
async fn save_checkpoint<S: Spider>(
    path: &Path,
    mut scheduler_checkpoint: SchedulerCheckpoint,
    pipelines: &Arc<Vec<Box<dyn Pipeline<S::Item>>>>,
    salvaged_requests_rx: AsyncReceiver<Request>,
) -> Result<(), SpiderError>
where
    S::Item: ScrapedItem,
{
    info!("Saving checkpoint to {:?}", path);

    let mut pipelines_checkpoint_map = HashMap::new();
    for pipeline in pipelines.iter() {
        if let Some(state) = pipeline.get_state().await? {
            pipelines_checkpoint_map.insert(pipeline.name().to_string(), state);
        }
    }

    // Drain the salvage channel without blocking; the loop ends once the
    // channel is empty (`Ok(None)`) or closed (`Err`).
    let mut salvaged_requests_vec: VecDeque<Request> = VecDeque::new();
    while let Ok(Some(req)) = salvaged_requests_rx.try_recv() {
        salvaged_requests_vec.push_back(req);
    }
    if !salvaged_requests_vec.is_empty() {
        warn!(
            "Found {} salvaged requests during checkpoint; re-adding them to the request queue.",
            salvaged_requests_vec.len()
        );
    }
    scheduler_checkpoint
        .request_queue
        .extend(salvaged_requests_vec.drain(..));

    let checkpoint = Checkpoint {
        scheduler: scheduler_checkpoint,
        pipelines: pipelines_checkpoint_map,
    };

    let tmp_path = path.with_extension("tmp");
    let encoded = rmp_serde::to_vec(&checkpoint)
        .map_err(|e| SpiderError::GeneralError(format!("Failed to serialize checkpoint: {}", e)))?;
    fs::write(&tmp_path, encoded).map_err(|e| {
        SpiderError::GeneralError(format!(
            "Failed to write checkpoint to temporary file: {}",
            e
        ))
    })?;
    fs::rename(&tmp_path, path).map_err(|e| {
        SpiderError::GeneralError(format!("Failed to rename temporary checkpoint file: {}", e))
    })?;

    info!("Checkpoint saved successfully.");
    Ok(())
}

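/// Spawns the downloader actor: receives requests from the scheduler channel,
/// runs them through the request middlewares, downloads them with at most
/// `max_concurrent_downloads` in flight, runs the response middlewares, and
/// forwards the resulting responses to the parser.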
#[allow(clippy::too_many_arguments)]
fn spawn_downloader_task<S, C>(
    scheduler: Arc<Scheduler>,
    req_rx: AsyncReceiver<Request>,
    downloader: Arc<dyn Downloader<Client = C> + Send + Sync>,
    middlewares: Arc<Mutex<Vec<Box<dyn Middleware<C> + Send + Sync>>>>,
    state: Arc<CrawlerState>,
    res_tx: AsyncSender<Response>,
    max_concurrent_downloads: usize,
    salvaged_requests_tx: AsyncSender<Request>,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
    C: Send + Sync + 'static,
{
    let semaphore = Arc::new(Semaphore::new(max_concurrent_downloads));
    let mut tasks = JoinSet::new();

    tokio::spawn(async move {
        while let Ok(request) = req_rx.recv().await {
            let permit = match semaphore.clone().acquire_owned().await {
                Ok(permit) => permit,
                Err(_) => {
                    warn!("Semaphore closed, shutting down downloader actor.");
                    break;
                }
            };

            state.in_flight_requests.fetch_add(1, Ordering::SeqCst);
            let downloader_clone = Arc::clone(&downloader);
            let middlewares_clone = Arc::clone(&middlewares);
            let res_tx_clone = res_tx.clone();
            let state_clone = Arc::clone(&state);
            let scheduler_clone = Arc::clone(&scheduler);
            let _salvaged_requests_tx_clone = salvaged_requests_tx.clone();

            tasks.spawn(async move {
                // Run the middleware/download flow in an inner block so the
                // in-flight counter is decremented on every exit path,
                // including the early returns below.
                async {
                    let mut processed_request = request;
                    let mut early_returned_response: Option<Response> = None;

                    for mw in middlewares_clone.lock().await.iter_mut() {
                        match mw
                            .process_request(downloader_clone.client(), processed_request.clone())
                            .await
                        {
                            Ok(MiddlewareAction::Continue(req)) => {
                                processed_request = req;
                            }
                            Ok(MiddlewareAction::Retry(req, delay)) => {
                                tokio::time::sleep(delay).await;
                                if scheduler_clone.enqueue_request(*req).await.is_err() {
                                    error!("Failed to re-enqueue retried request.");
                                }
                                return;
                            }
                            Ok(MiddlewareAction::Drop) => {
                                debug!("Request dropped by middleware.");
                                return;
                            }
                            Ok(MiddlewareAction::ReturnResponse(resp)) => {
                                early_returned_response = Some(resp);
                                break;
                            }
                            Err(e) => {
                                error!("Request middleware error: {:?}", e);
                                return;
                            }
                        }
                    }

                    let mut response = match early_returned_response {
                        Some(resp) => resp,
                        None => match downloader_clone.download(processed_request).await {
                            Ok(resp) => resp,
                            Err(e) => {
                                error!("Download error: {:?}", e);
                                return;
                            }
                        },
                    };

                    // Response middlewares run in reverse registration order.
                    for mw in middlewares_clone.lock().await.iter_mut().rev() {
                        match mw.process_response(response.clone()).await {
                            Ok(MiddlewareAction::Continue(res)) => {
                                response = res;
                            }
                            Ok(MiddlewareAction::Retry(req, delay)) => {
                                tokio::time::sleep(delay).await;
                                if scheduler_clone.enqueue_request(*req).await.is_err() {
                                    error!("Failed to re-enqueue retried request.");
                                }
                                return;
                            }
                            Ok(MiddlewareAction::Drop) => {
                                debug!("Response dropped by middleware.");
                                return;
                            }
                            Ok(MiddlewareAction::ReturnResponse(_)) => {
                                debug!(
                                    "ReturnResponse action encountered in process_response; this is unexpected."
                                );
                                continue;
                            }
                            Err(e) => {
                                error!("Response middleware error: {:?}", e);
                                return;
                            }
                        }
                    }

                    if res_tx_clone.send(response).await.is_err() {
                        error!("Response channel closed, cannot send downloaded response.");
                    }
                }
                .await;

                state_clone.in_flight_requests.fetch_sub(1, Ordering::SeqCst);
                drop(permit);
            });
        }
        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("A download task failed: {:?}", e);
            }
        }
    })
}

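/// Spawns the parser actor: fans responses out to `parser_workers` worker
/// tasks that call the spider's `parse` method and feed the resulting requests
/// and items back into the scheduler and the item channel.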
fn spawn_parser_task<S>(
    scheduler: Arc<Scheduler>,
    spider: Arc<Mutex<S>>,
    state: Arc<CrawlerState>,
    res_rx: AsyncReceiver<Response>,
    item_tx: AsyncSender<S::Item>,
    parser_workers: usize,
    salvaged_requests_tx: AsyncSender<Request>,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let mut tasks = JoinSet::new();
    let (internal_parse_tx, internal_parse_rx): (AsyncSender<Response>, AsyncReceiver<Response>) =
        bounded_async(parser_workers * 2);

    for _ in 0..parser_workers {
        let internal_parse_rx_clone = internal_parse_rx.clone();
        let spider_clone = Arc::clone(&spider);
        let scheduler_clone = Arc::clone(&scheduler);
        let item_tx_clone = item_tx.clone();
        let state_clone = Arc::clone(&state);
        let salvaged_requests_tx_clone = salvaged_requests_tx.clone();

        tasks.spawn(async move {
            while let Ok(response) = internal_parse_rx_clone.recv().await {
                debug!("Parsing response from {}", response.url);
                match spider_clone.lock().await.parse(response).await {
                    Ok(outputs) => {
                        process_crawl_outputs::<S>(
                            outputs,
                            scheduler_clone.clone(),
                            item_tx_clone.clone(),
                            state_clone.clone(),
                            salvaged_requests_tx_clone.clone(),
                        )
                        .await;
                    }
                    Err(e) => error!("Spider parsing error: {:?}", e),
                }
                state_clone.parsing_responses.fetch_sub(1, Ordering::SeqCst);
            }
        });
    }

    tokio::spawn(async move {
        while let Ok(response) = res_rx.recv().await {
            state.parsing_responses.fetch_add(1, Ordering::SeqCst);
            if internal_parse_tx.send(response).await.is_err() {
                error!("Internal parse channel closed, cannot send response to parser worker.");
                state.parsing_responses.fetch_sub(1, Ordering::SeqCst);
            }
        }
        // Close the internal channel so the worker loops terminate once drained.
        drop(internal_parse_tx);

        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("A parsing worker task failed: {:?}", e);
            }
        }
    })
}

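/// Splits a spider's parse output into follow-up requests, which are enqueued
/// on the scheduler (or salvaged if enqueueing fails), and scraped items,
/// which are sent to the item-processing channel.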
async fn process_crawl_outputs<S>(
    outputs: ParseOutput<S::Item>,
    scheduler: Arc<Scheduler>,
    item_tx: AsyncSender<S::Item>,
    state: Arc<CrawlerState>,
    salvaged_requests_tx: AsyncSender<Request>,
) where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let (items, requests) = outputs.into_parts();
    info!(
        "Processed {} requests and {} items from spider output.",
        requests.len(),
        items.len()
    );

    let mut request_error_total = 0;
    for request in requests {
        match scheduler.enqueue_request(request).await {
            Ok(_) => {}
            Err((req, _)) => {
                request_error_total += 1;
                if salvaged_requests_tx.send(req).await.is_err() {
                    error!(
                        "Failed to send salvaged request to channel from process_crawl_outputs."
                    );
                }
            }
        }
    }
    if request_error_total > 0 {
        error!(
            "Failed to enqueue {} requests: scheduler channel closed.",
            request_error_total
        );
    }

    let mut item_error_total = 0;
    for item in items {
        state.processing_items.fetch_add(1, Ordering::SeqCst);
        if item_tx.send(item).await.is_err() {
            item_error_total += 1;
            state.processing_items.fetch_sub(1, Ordering::SeqCst);
        }
    }
    if item_error_total > 0 {
        error!(
            "Failed to send {} scraped items: channel closed.",
            item_error_total
        );
    }
}

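/// Spawns the item-processor actor: receives scraped items and runs each one
/// through the pipeline chain with at most `max_concurrent_pipelines` items in
/// flight; a pipeline returning `Ok(None)` drops the item and ends the chain.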
fn spawn_item_processor_task<S>(
    state: Arc<CrawlerState>,
    item_rx: AsyncReceiver<S::Item>,
    pipelines: Arc<Vec<Box<dyn Pipeline<S::Item>>>>,
    max_concurrent_pipelines: usize,
) -> tokio::task::JoinHandle<()>
where
    S: Spider + 'static,
    S::Item: ScrapedItem,
{
    let mut tasks = JoinSet::new();
    let semaphore = Arc::new(Semaphore::new(max_concurrent_pipelines));
    tokio::spawn(async move {
        while let Ok(item) = item_rx.recv().await {
            let permit = match semaphore.clone().acquire_owned().await {
                Ok(p) => p,
                Err(_) => {
                    warn!("Semaphore closed, shutting down item processor actor.");
                    break;
                }
            };

            let state_clone = Arc::clone(&state);
            let pipelines_clone = Arc::clone(&pipelines);
            tasks.spawn(async move {
                let mut item_to_process = Some(item);
                for pipeline in pipelines_clone.iter() {
                    if let Some(item) = item_to_process.take() {
                        match pipeline.process_item(item).await {
                            Ok(Some(next_item)) => item_to_process = Some(next_item),
                            Ok(None) => break,
                            Err(e) => {
                                error!("Pipeline error for {}: {:?}", pipeline.name(), e);
                                break;
                            }
                        }
                    } else {
                        break;
                    }
                }
                state_clone.processing_items.fetch_sub(1, Ordering::SeqCst);
                drop(permit);
            });
        }
        while let Some(res) = tasks.join_next().await {
            if let Err(e) = res {
                error!("An item processing task failed: {:?}", e);
            }
        }
    })
}