use crate::{
    crawler::config::{CircuitBreaker, CrawlerConfig, MemoryMonitor},
    crawler::filter::{is_same_domain, should_crawl_url},
    crawler::pagination::{PaginationConfig, PaginationDetector},
    crawler::rate_limiter::DomainRateLimiter,
    engines::{http::HttpEngine, ScrapeEngine},
    error::{Result, ScrapeError},
    format,
    types::{CrawlRequest, Document, ScrapeRequest},
    utils::{normalize_url_string, robots},
};
use scraper::{Html, Selector};
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::{mpsc, watch, Semaphore};
use tokio::time::{timeout, Duration};
use tracing::{debug, error, info, warn};
use url::Url;
19
20pub struct ParallelCrawler {
22 scrape_semaphore: Arc<Semaphore>,
24 #[allow(dead_code)]
26 process_semaphore: Arc<Semaphore>,
27 max_workers: usize,
29}
30
31impl ParallelCrawler {
32 pub fn new() -> Self {
34 let num_cpus = num_cpus::get();
35 Self {
36 scrape_semaphore: Arc::new(Semaphore::new(num_cpus)),
37 process_semaphore: Arc::new(Semaphore::new(num_cpus / 2)),
38 max_workers: num_cpus,
39 }
40 }
41
42 pub fn with_config(config: &CrawlerConfig) -> Self {
44 let max_concurrent = config.max_concurrent_requests;
45 Self {
46 scrape_semaphore: Arc::new(Semaphore::new(max_concurrent)),
47 process_semaphore: Arc::new(Semaphore::new(max_concurrent / 2)),
48 max_workers: max_concurrent,
49 }
50 }
51
52 pub async fn crawl_parallel(&self, request: &CrawlRequest) -> Result<Vec<Document>> {
54 info!(
55 "Starting parallel crawl from URL: {} with {} workers",
56 request.url, self.max_workers
57 );
58
59 let _base_url = Url::parse(&request.url)
61 .map_err(|e| ScrapeError::InvalidUrl(format!("Invalid base URL: {}", e)))?;
62
63 let normalized_base_url = normalize_url_string(&request.url)
65 .map_err(|e| ScrapeError::InvalidUrl(format!("Failed to normalize base URL: {}", e)))?;
66
67 let config = CrawlerConfig::default();
69
70 let circuit_breaker = Arc::new(CircuitBreaker::new(config.circuit_breaker_threshold));
72 let memory_monitor = Arc::new(MemoryMonitor::new(
73 config.max_memory_mb,
74 config.enable_memory_monitoring,
75 ));
76
77 let visited = Arc::new(tokio::sync::RwLock::new(HashSet::new()));
79 let url_depths = Arc::new(tokio::sync::RwLock::new(HashMap::new()));
80
81 let robots_allowed = check_robots_txt(&request.url, request.ignore_sitemap).await;
83
84 let rate_limit = std::env::var("CRAWL_RATE_LIMIT_PER_SEC")
86 .ok()
87 .and_then(|v| v.parse().ok())
88 .unwrap_or(2);
89 let rate_limiter = Arc::new(DomainRateLimiter::new(rate_limit));
90
91 info!(
92 "Rate limiting enabled: {} requests/second per domain",
93 rate_limit
94 );
95
96 let pagination_config = PaginationConfig {
98 max_pages: request.max_pagination_pages.unwrap_or(50) as usize,
99 max_depth: request.max_depth as usize,
100 detect_circular: true,
101 };
102 let pagination_detector = Arc::new(tokio::sync::Mutex::new(PaginationDetector::new(
103 pagination_config,
104 )));
105 let detect_pagination = request.detect_pagination.unwrap_or(true);
106
107 let (url_tx, url_rx) = mpsc::channel::<(String, u32)>(config.max_queue_size);
110 let (doc_tx, doc_rx) = mpsc::channel::<Document>(config.max_queue_size);
111
112 url_depths.write().await.insert(normalized_base_url.clone(), 0);
114 if let Err(e) = url_tx.send((normalized_base_url.clone(), 0)).await {
115 return Err(ScrapeError::Internal(format!("Failed to enqueue base URL: {}", e)));
116 }
117
118 let request_clone = request.clone();
120
121 let url_rx = Arc::new(tokio::sync::Mutex::new(url_rx));
123 let doc_tx_clone = doc_tx.clone();
124
125 let mut worker_handles = Vec::new();
127
128 for worker_id in 0..self.max_workers {
129 let url_rx_clone = Arc::clone(&url_rx);
130 let doc_tx_worker = doc_tx_clone.clone();
131 let url_tx_worker = url_tx.clone();
132 let visited_clone = Arc::clone(&visited);
133 let url_depths_clone = Arc::clone(&url_depths);
134 let circuit_breaker_clone = Arc::clone(&circuit_breaker);
135 let memory_monitor_clone = Arc::clone(&memory_monitor);
136 let rate_limiter_clone = Arc::clone(&rate_limiter);
137 let scrape_semaphore_clone = Arc::clone(&self.scrape_semaphore);
138 let pagination_detector_clone = Arc::clone(&pagination_detector);
139 let config_clone = config.clone();
140 let request_clone2 = request_clone.clone();
141
142 let handle = tokio::spawn(async move {
143 Self::scrape_worker(
144 worker_id,
145 url_rx_clone,
146 doc_tx_worker,
147 url_tx_worker,
148 visited_clone,
149 url_depths_clone,
150 circuit_breaker_clone,
151 memory_monitor_clone,
152 rate_limiter_clone,
153 scrape_semaphore_clone,
154 pagination_detector_clone,
155 config_clone,
156 request_clone2,
157 robots_allowed,
158 detect_pagination,
159 )
160 .await
161 });
162
163 worker_handles.push(handle);
164 }
165
166 drop(url_tx);
168 drop(doc_tx);
169
170 let doc_limit = request.limit as usize;
172 let mut documents = Vec::with_capacity(doc_limit.min(1000));
173 let mut doc_rx = doc_rx;
174
175 while let Some(doc) = doc_rx.recv().await {
177 documents.push(doc);
178
179 if documents.len() >= doc_limit {
181 info!("Reached document limit of {}, stopping collection", doc_limit);
182 break;
183 }
184 }
185
186 for handle in worker_handles {
188 if documents.len() >= doc_limit {
190 handle.abort();
191 } else if let Err(e) = handle.await {
192 if !e.is_cancelled() {
193 error!("Worker task failed: {}", e);
194 }
195 }
196 }
197
198 let visited_count = visited.read().await.len();
200 info!(
201 "Parallel crawl completed. Total pages crawled: {}, visited: {}",
202 documents.len(),
203 visited_count
204 );
205
206 Ok(documents)
207 }
208
209 #[allow(clippy::too_many_arguments)]
211 async fn scrape_worker(
212 worker_id: usize,
213 url_rx: Arc<tokio::sync::Mutex<mpsc::Receiver<(String, u32)>>>,
214 doc_tx: mpsc::Sender<Document>,
215 url_tx: mpsc::Sender<(String, u32)>,
216 visited: Arc<tokio::sync::RwLock<HashSet<String>>>,
217 url_depths: Arc<tokio::sync::RwLock<HashMap<String, u32>>>,
218 circuit_breaker: Arc<CircuitBreaker>,
219 memory_monitor: Arc<MemoryMonitor>,
220 rate_limiter: Arc<DomainRateLimiter>,
221 scrape_semaphore: Arc<Semaphore>,
222 pagination_detector: Arc<tokio::sync::Mutex<PaginationDetector>>,
223 config: CrawlerConfig,
224 request: CrawlRequest,
225 robots_allowed: bool,
226 detect_pagination: bool,
227 ) -> Result<()> {
228 debug!("Worker {} started", worker_id);
229
230 let engine = HttpEngine::new()?;
231
232 loop {
233 let url_item = {
235 let mut rx = url_rx.lock().await;
236 rx.recv().await
237 };
238
239 let (current_url, current_depth) = match url_item {
240 Some(item) => item,
241 None => {
242 debug!("Worker {}: Queue closed, exiting", worker_id);
243 break;
244 }
245 };
246
247 {
249 let visited_read = visited.read().await;
250 if visited_read.contains(¤t_url) {
251 continue;
252 }
253 }
254
255 {
257 let mut visited_write = visited.write().await;
258 if visited_write.contains(¤t_url) {
259 continue;
260 }
261 visited_write.insert(current_url.clone());
262 }
263
264 if current_depth > request.max_depth {
266 debug!("Worker {}: Skipping URL due to depth limit: {}", worker_id, current_url);
267 continue;
268 }
269
270 let domain = match Url::parse(¤t_url) {
272 Ok(parsed) => parsed.host_str().unwrap_or("unknown").to_string(),
273 Err(_) => "unknown".to_string(),
274 };
275
276 if config.enable_circuit_breaker && circuit_breaker.should_skip(&domain) {
278 warn!(
279 "Worker {}: Circuit breaker: Skipping domain {} due to excessive failures",
280 worker_id, domain
281 );
282 continue;
283 }
284
285 if !robots_allowed {
287 match robots::is_allowed_default(¤t_url).await {
288 Ok(allowed) => {
289 if !allowed {
290 warn!("Worker {}: Robots.txt disallows crawling: {}", worker_id, current_url);
291 continue;
292 }
293 }
294 Err(e) => {
295 warn!("Worker {}: Failed to check robots.txt for {}: {}", worker_id, current_url, e);
296 }
297 }
298 }
299
300 if !should_crawl_url(¤t_url, &request.include_paths, &request.exclude_paths) {
302 debug!(
303 "Worker {}: URL filtered out by include/exclude patterns: {}",
304 worker_id, current_url
305 );
306 continue;
307 }
308
309 if config.enable_memory_monitoring {
311 if let Err(e) = memory_monitor.check_memory_limit() {
312 warn!("Worker {}: Memory limit check failed: {}", worker_id, e);
313 return Err(e);
314 }
315 }
316
317 let _permit = scrape_semaphore.acquire().await
319 .map_err(|e| ScrapeError::Internal(format!("Failed to acquire scrape permit: {}", e)))?;
320
321 if let Err(e) = rate_limiter.wait_for_permission(¤t_url).await {
323 warn!("Worker {}: Rate limiter error for {}: {}", worker_id, current_url, e);
324 continue;
325 }
326
327 info!(
329 "Worker {}: Crawling URL: {} (depth: {})",
330 worker_id, current_url, current_depth
331 );
332
333 let scrape_result = timeout(
334 Duration::from_secs(30),
335 scrape_url(¤t_url, &engine)
336 ).await;
337
338 match scrape_result {
339 Ok(Ok((document, links, html))) => {
340 if config.enable_circuit_breaker {
342 circuit_breaker.record_success(&domain);
343 }
344
345 if doc_tx.send(document).await.is_err() {
347 debug!("Worker {}: Document receiver closed, exiting", worker_id);
348 return Ok(());
349 }
350
351 if current_depth < request.max_depth {
353 let pagination_links = if detect_pagination {
355 let mut detector = pagination_detector.lock().await;
356 detector.detect_pagination(&html, ¤t_url)
357 } else {
358 Vec::new()
359 };
360
361 for link in &pagination_links {
363 let normalized_link = match normalize_url_string(link) {
364 Ok(url) => url,
365 Err(e) => {
366 debug!("Worker {}: Failed to normalize pagination link {}: {}", worker_id, link, e);
367 continue;
368 }
369 };
370
371 {
373 let visited_read = visited.read().await;
374 let depths_read = url_depths.read().await;
375 if visited_read.contains(&normalized_link) || depths_read.contains_key(&normalized_link) {
376 continue;
377 }
378 }
379
380 if is_same_domain(&normalized_link, &request.url) {
382 {
384 let mut depths_write = url_depths.write().await;
385 depths_write.insert(normalized_link.clone(), current_depth);
386 }
387
388 if url_tx.send((normalized_link, current_depth)).await.is_err() {
389 debug!("Worker {}: URL receiver closed, exiting", worker_id);
390 return Ok(());
391 }
392 }
393 }
394
395 for link in links {
397 let normalized_link = match normalize_url_string(&link) {
399 Ok(url) => url,
400 Err(e) => {
401 debug!("Worker {}: Failed to normalize link {}: {}", worker_id, link, e);
402 continue;
403 }
404 };
405
406 {
408 let visited_read = visited.read().await;
409 let depths_read = url_depths.read().await;
410 if visited_read.contains(&normalized_link) || depths_read.contains_key(&normalized_link) {
411 continue;
412 }
413 }
414
415 if pagination_links.contains(&link) || pagination_links.contains(&normalized_link) {
417 continue;
418 }
419
420 let allow_link = if request.allow_external_links.unwrap_or(false) {
422 true
423 } else if request.allow_backward_links.unwrap_or(false) {
424 is_same_domain(&normalized_link, &request.url)
426 } else {
427 is_same_domain(&normalized_link, &request.url)
429 && is_forward_link(&normalized_link, ¤t_url)
430 };
431
432 if allow_link {
433 {
435 let mut depths_write = url_depths.write().await;
436 if depths_write.len() < config.max_queue_size {
437 depths_write.insert(normalized_link.clone(), current_depth + 1);
438 } else {
439 debug!("Worker {}: Queue limit reached, skipping link: {}", worker_id, link);
440 continue;
441 }
442 }
443
444 if url_tx.send((normalized_link, current_depth + 1)).await.is_err() {
445 debug!("Worker {}: URL receiver closed, exiting", worker_id);
446 return Ok(());
447 }
448 }
449 }
450 }
451 }
452 Ok(Err(e)) => {
453 warn!("Worker {}: Failed to scrape {}: {}", worker_id, current_url, e);
454
455 if config.enable_circuit_breaker {
457 circuit_breaker.record_failure(&domain);
458 }
459 }
460 Err(_) => {
461 warn!("Worker {}: Timeout scraping {}", worker_id, current_url);
462
463 if config.enable_circuit_breaker {
465 circuit_breaker.record_failure(&domain);
466 }
467 }
468 }
469 }
470
471 debug!("Worker {} finished", worker_id);
472 Ok(())
473 }
474}
475
476impl Default for ParallelCrawler {
477 fn default() -> Self {
478 Self::new()
479 }
480}
481
482async fn check_robots_txt(url: &str, ignore_sitemap: Option<bool>) -> bool {
484 if ignore_sitemap.unwrap_or(false) {
485 return true;
486 }
487
488 match robots::is_allowed_default(url).await {
489 Ok(allowed) => allowed,
490 Err(e) => {
491 warn!("Failed to check robots.txt: {}, allowing by default", e);
492 true
493 }
494 }
495}
496
497async fn scrape_url(url: &str, engine: &HttpEngine) -> Result<(Document, Vec<String>, String)> {
499 let scrape_request = ScrapeRequest {
501 url: url.to_string(),
502 formats: vec!["markdown".to_string(), "links".to_string()],
503 headers: HashMap::new(),
504 include_tags: vec![],
505 exclude_tags: vec![],
506 only_main_content: true,
507 timeout: 30000,
508 wait_for: 0,
509 remove_base64_images: true,
510 skip_tls_verification: false,
511 engine: "http".to_string(),
512 wait_for_selector: None,
513 actions: vec![],
514 screenshot: false,
515 screenshot_format: "png".to_string(),
516 };
517
518 let raw_result = engine.scrape(&scrape_request).await?;
520
521 let links = extract_links(&raw_result.html, url)?;
523
524 let html = raw_result.html.clone();
526
527 let document = format::process_scrape_result(raw_result, &scrape_request).await?;
529
530 Ok((document, links, html))
531}
532
533fn extract_links(html: &str, base_url: &str) -> Result<Vec<String>> {
535 let document = Html::parse_document(html);
536 let selector = Selector::parse("a[href]")
537 .map_err(|e| ScrapeError::Internal(format!("Failed to create link selector: {:?}", e)))?;
538
539 let base = Url::parse(base_url)
540 .map_err(|e| ScrapeError::InvalidUrl(format!("Invalid base URL: {}", e)))?;
541
542 let mut links = Vec::new();
543
544 for element in document.select(&selector) {
545 if let Some(href) = element.value().attr("href") {
546 if href.starts_with("javascript:")
548 || href.starts_with("mailto:")
549 || href.starts_with("tel:")
550 || href.starts_with('#')
551 {
552 continue;
553 }
554
555 match base.join(href) {
557 Ok(absolute_url) => {
558 let url_str = absolute_url.to_string();
559 let url_without_fragment = url_str.split('#').next().unwrap_or(&url_str);
561 links.push(url_without_fragment.to_string());
562 }
563 Err(_) => {
564 continue;
566 }
567 }
568 }
569 }
570
571 links.sort();
573 links.dedup();
574
575 Ok(links)
576}
577
578fn is_forward_link(link: &str, current: &str) -> bool {
580 let link_parsed = match Url::parse(link) {
581 Ok(u) => u,
582 Err(_) => return false,
583 };
584
585 let current_parsed = match Url::parse(current) {
586 Ok(u) => u,
587 Err(_) => return false,
588 };
589
590 let link_path = link_parsed.path();
591 let current_path = current_parsed.path();
592
593 link_path == current_path || link_path.starts_with(current_path)
597}
598
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parallel_crawler_creation() {
        let expected_workers = num_cpus::get();
        let crawler = ParallelCrawler::new();
        assert_eq!(crawler.max_workers, expected_workers);
    }

    #[tokio::test]
    async fn test_parallel_crawler_with_config() {
        let mut config = CrawlerConfig::default();
        config.max_concurrent_requests = 5;

        assert_eq!(ParallelCrawler::with_config(&config).max_workers, 5);
    }

    #[test]
    fn test_extract_links() {
        let html = r##"
            <html>
                <body>
                    <a href="/page1">Page 1</a>
                    <a href="/page2">Page 2</a>
                    <a href="https://example.com/page3">Page 3</a>
                    <a href="javascript:void(0)">JS</a>
                    <a href="mailto:test@example.com">Email</a>
                    <a href="#section">Section</a>
                </body>
            </html>
        "##;

        let links = extract_links(html, "https://example.com").unwrap();

        for page in ["page1", "page2", "page3"] {
            let expected = format!("https://example.com/{}", page);
            assert!(links.contains(&expected), "missing {}", expected);
        }
        for scheme in ["javascript:", "mailto:"] {
            assert!(links.iter().all(|l| !l.contains(scheme)), "leaked {}", scheme);
        }
        assert!(links.iter().all(|l| !l.contains('#')));
    }

    #[test]
    fn test_is_forward_link() {
        let cases = [
            ("https://example.com/blog/post1", "https://example.com/blog", true),
            ("https://example.com/blog", "https://example.com/blog", true),
            ("https://example.com/about", "https://example.com/blog", false),
        ];

        for (link, current, expected) in cases {
            assert_eq!(
                is_forward_link(link, current),
                expected,
                "is_forward_link({}, {})",
                link,
                current
            );
        }
    }
}