Skip to main content

essence/crawler/
parallel.rs

use crate::{
    crawler::config::{CircuitBreaker, CrawlerConfig, MemoryMonitor},
    crawler::filter::{is_same_domain, should_crawl_url},
    crawler::pagination::{PaginationConfig, PaginationDetector},
    crawler::rate_limiter::DomainRateLimiter,
    engines::{http::HttpEngine, ScrapeEngine},
    error::{Result, ScrapeError},
    format,
    types::{CrawlRequest, Document, ScrapeRequest},
    utils::{normalize_url_string, robots},
};
use scraper::{Html, Selector};
use std::collections::{HashMap, HashSet};
use std::ops::ControlFlow;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::{mpsc, Semaphore};
use tokio::time::{timeout, Duration};
use tracing::{debug, error, info, warn};
use url::Url;
19
20/// Parallel crawler with configurable concurrency
21pub struct ParallelCrawler {
22    /// Semaphore to limit concurrent scrapes
23    scrape_semaphore: Arc<Semaphore>,
24    /// Semaphore to limit concurrent processing (reserved for future use)
25    #[allow(dead_code)]
26    process_semaphore: Arc<Semaphore>,
27    /// Maximum number of worker tasks
28    max_workers: usize,
29}
30
31impl ParallelCrawler {
32    /// Create a new parallel crawler with default settings
33    pub fn new() -> Self {
34        let num_cpus = num_cpus::get();
35        Self {
36            scrape_semaphore: Arc::new(Semaphore::new(num_cpus)),
37            process_semaphore: Arc::new(Semaphore::new(num_cpus / 2)),
38            max_workers: num_cpus,
39        }
40    }
41
42    /// Create a new parallel crawler with custom concurrency settings
43    pub fn with_config(config: &CrawlerConfig) -> Self {
44        let max_concurrent = config.max_concurrent_requests;
45        Self {
46            scrape_semaphore: Arc::new(Semaphore::new(max_concurrent)),
47            process_semaphore: Arc::new(Semaphore::new(max_concurrent / 2)),
48            max_workers: max_concurrent,
49        }
50    }
51
52    /// Main parallel crawl method
53    pub async fn crawl_parallel(&self, request: &CrawlRequest) -> Result<Vec<Document>> {
54        info!(
55            "Starting parallel crawl from URL: {} with {} workers",
56            request.url, self.max_workers
57        );
58
59        // Parse and validate base URL
60        let _base_url = Url::parse(&request.url)
61            .map_err(|e| ScrapeError::InvalidUrl(format!("Invalid base URL: {}", e)))?;
62
63        // Normalize the base URL to prevent duplicates
64        let normalized_base_url = normalize_url_string(&request.url)
65            .map_err(|e| ScrapeError::InvalidUrl(format!("Failed to normalize base URL: {}", e)))?;
66
67        // Initialize crawler config with bounds
68        let config = CrawlerConfig::default();
69
70        // Initialize circuit breaker and memory monitor
71        let circuit_breaker = Arc::new(CircuitBreaker::new(config.circuit_breaker_threshold));
72        let memory_monitor = Arc::new(MemoryMonitor::new(
73            config.max_memory_mb,
74            config.enable_memory_monitoring,
75        ));
76
77        // Shared state
78        let visited = Arc::new(tokio::sync::RwLock::new(HashSet::new()));
79        let url_depths = Arc::new(tokio::sync::RwLock::new(HashMap::new()));
80
81        // Check robots.txt for the domain
82        let robots_allowed = check_robots_txt(&request.url, request.ignore_sitemap).await;
83
84        // Create rate limiter
85        let rate_limit = std::env::var("CRAWL_RATE_LIMIT_PER_SEC")
86            .ok()
87            .and_then(|v| v.parse().ok())
88            .unwrap_or(2);
89        let rate_limiter = Arc::new(DomainRateLimiter::new(rate_limit));
90
91        info!(
92            "Rate limiting enabled: {} requests/second per domain",
93            rate_limit
94        );
95
96        // Initialize pagination detector with configuration
97        let pagination_config = PaginationConfig {
98            max_pages: request.max_pagination_pages.unwrap_or(50) as usize,
99            max_depth: request.max_depth as usize,
100            detect_circular: true,
101        };
102        let pagination_detector = Arc::new(tokio::sync::Mutex::new(PaginationDetector::new(
103            pagination_config,
104        )));
105        let detect_pagination = request.detect_pagination.unwrap_or(true);
106
107        // Channels for communication
108        // Buffer size matches queue size to prevent unbounded growth
109        let (url_tx, url_rx) = mpsc::channel::<(String, u32)>(config.max_queue_size);
110        let (doc_tx, doc_rx) = mpsc::channel::<Document>(config.max_queue_size);
111
112        // Add base URL to the queue
113        url_depths.write().await.insert(normalized_base_url.clone(), 0);
114        if let Err(e) = url_tx.send((normalized_base_url.clone(), 0)).await {
115            return Err(ScrapeError::Internal(format!("Failed to enqueue base URL: {}", e)));
116        }
117
118        // Clone request for workers
119        let request_clone = request.clone();
120
121        // Wrap receivers in Arc<Mutex> for sharing among workers
122        let url_rx = Arc::new(tokio::sync::Mutex::new(url_rx));
123        let doc_tx_clone = doc_tx.clone();
124
125        // Spawn scraping workers
126        let mut worker_handles = Vec::new();
127
128        for worker_id in 0..self.max_workers {
129            let url_rx_clone = Arc::clone(&url_rx);
130            let doc_tx_worker = doc_tx_clone.clone();
131            let url_tx_worker = url_tx.clone();
132            let visited_clone = Arc::clone(&visited);
133            let url_depths_clone = Arc::clone(&url_depths);
134            let circuit_breaker_clone = Arc::clone(&circuit_breaker);
135            let memory_monitor_clone = Arc::clone(&memory_monitor);
136            let rate_limiter_clone = Arc::clone(&rate_limiter);
137            let scrape_semaphore_clone = Arc::clone(&self.scrape_semaphore);
138            let pagination_detector_clone = Arc::clone(&pagination_detector);
139            let config_clone = config.clone();
140            let request_clone2 = request_clone.clone();
141
142            let handle = tokio::spawn(async move {
143                Self::scrape_worker(
144                    worker_id,
145                    url_rx_clone,
146                    doc_tx_worker,
147                    url_tx_worker,
148                    visited_clone,
149                    url_depths_clone,
150                    circuit_breaker_clone,
151                    memory_monitor_clone,
152                    rate_limiter_clone,
153                    scrape_semaphore_clone,
154                    pagination_detector_clone,
155                    config_clone,
156                    request_clone2,
157                    robots_allowed,
158                    detect_pagination,
159                )
160                .await
161            });
162
163            worker_handles.push(handle);
164        }
165
166        // Drop the original senders so receivers know when to close
167        drop(url_tx);
168        drop(doc_tx);
169
170        // Collect documents as they arrive
171        let doc_limit = request.limit as usize;
172        let mut documents = Vec::with_capacity(doc_limit.min(1000));
173        let mut doc_rx = doc_rx;
174
175        // Collect documents up to limit
176        while let Some(doc) = doc_rx.recv().await {
177            documents.push(doc);
178            
179            // Stop collecting once we reach the limit
180            if documents.len() >= doc_limit {
181                info!("Reached document limit of {}, stopping collection", doc_limit);
182                break;
183            }
184        }
185
186        // Wait for all workers to complete (or cancel them if we've reached the limit)
187        for handle in worker_handles {
188            // Abort remaining workers if we've reached the limit
189            if documents.len() >= doc_limit {
190                handle.abort();
191            } else if let Err(e) = handle.await {
192                if !e.is_cancelled() {
193                    error!("Worker task failed: {}", e);
194                }
195            }
196        }
197
198        // Log final stats
199        let visited_count = visited.read().await.len();
200        info!(
201            "Parallel crawl completed. Total pages crawled: {}, visited: {}",
202            documents.len(),
203            visited_count
204        );
205
206        Ok(documents)
207    }
208
209    /// Scraping worker that processes URLs from the queue
210    #[allow(clippy::too_many_arguments)]
211    async fn scrape_worker(
212        worker_id: usize,
213        url_rx: Arc<tokio::sync::Mutex<mpsc::Receiver<(String, u32)>>>,
214        doc_tx: mpsc::Sender<Document>,
215        url_tx: mpsc::Sender<(String, u32)>,
216        visited: Arc<tokio::sync::RwLock<HashSet<String>>>,
217        url_depths: Arc<tokio::sync::RwLock<HashMap<String, u32>>>,
218        circuit_breaker: Arc<CircuitBreaker>,
219        memory_monitor: Arc<MemoryMonitor>,
220        rate_limiter: Arc<DomainRateLimiter>,
221        scrape_semaphore: Arc<Semaphore>,
222        pagination_detector: Arc<tokio::sync::Mutex<PaginationDetector>>,
223        config: CrawlerConfig,
224        request: CrawlRequest,
225        robots_allowed: bool,
226        detect_pagination: bool,
227    ) -> Result<()> {
228        debug!("Worker {} started", worker_id);
229
230        let engine = HttpEngine::new()?;
231
232        loop {
233            // Acquire next URL from the shared queue
234            let url_item = {
235                let mut rx = url_rx.lock().await;
236                rx.recv().await
237            };
238
239            let (current_url, current_depth) = match url_item {
240                Some(item) => item,
241                None => {
242                    debug!("Worker {}: Queue closed, exiting", worker_id);
243                    break;
244                }
245            };
246
247            // Check if already visited
248            {
249                let visited_read = visited.read().await;
250                if visited_read.contains(&current_url) {
251                    continue;
252                }
253            }
254
255            // Mark as visited
256            {
257                let mut visited_write = visited.write().await;
258                if visited_write.contains(&current_url) {
259                    continue;
260                }
261                visited_write.insert(current_url.clone());
262            }
263
264            // Check depth limit
265            if current_depth > request.max_depth {
266                debug!("Worker {}: Skipping URL due to depth limit: {}", worker_id, current_url);
267                continue;
268            }
269
270            // Extract domain for circuit breaker
271            let domain = match Url::parse(&current_url) {
272                Ok(parsed) => parsed.host_str().unwrap_or("unknown").to_string(),
273                Err(_) => "unknown".to_string(),
274            };
275
276            // Check circuit breaker
277            if config.enable_circuit_breaker && circuit_breaker.should_skip(&domain) {
278                warn!(
279                    "Worker {}: Circuit breaker: Skipping domain {} due to excessive failures",
280                    worker_id, domain
281                );
282                continue;
283            }
284
285            // Check robots.txt
286            if !robots_allowed {
287                match robots::is_allowed_default(&current_url).await {
288                    Ok(allowed) => {
289                        if !allowed {
290                            warn!("Worker {}: Robots.txt disallows crawling: {}", worker_id, current_url);
291                            continue;
292                        }
293                    }
294                    Err(e) => {
295                        warn!("Worker {}: Failed to check robots.txt for {}: {}", worker_id, current_url, e);
296                    }
297                }
298            }
299
300            // Check if URL should be crawled based on filters
301            if !should_crawl_url(&current_url, &request.include_paths, &request.exclude_paths) {
302                debug!(
303                    "Worker {}: URL filtered out by include/exclude patterns: {}",
304                    worker_id, current_url
305                );
306                continue;
307            }
308
309            // Check memory limit
310            if config.enable_memory_monitoring {
311                if let Err(e) = memory_monitor.check_memory_limit() {
312                    warn!("Worker {}: Memory limit check failed: {}", worker_id, e);
313                    return Err(e);
314                }
315            }
316
317            // Acquire scrape semaphore permit
318            let _permit = scrape_semaphore.acquire().await
319                .map_err(|e| ScrapeError::Internal(format!("Failed to acquire scrape permit: {}", e)))?;
320
321            // Apply rate limiting before scraping
322            if let Err(e) = rate_limiter.wait_for_permission(&current_url).await {
323                warn!("Worker {}: Rate limiter error for {}: {}", worker_id, current_url, e);
324                continue;
325            }
326
327            // Scrape the URL with timeout
328            info!(
329                "Worker {}: Crawling URL: {} (depth: {})",
330                worker_id, current_url, current_depth
331            );
332
333            let scrape_result = timeout(
334                Duration::from_secs(30),
335                scrape_url(&current_url, &engine)
336            ).await;
337
338            match scrape_result {
339                Ok(Ok((document, links, html))) => {
340                    // Record success in circuit breaker
341                    if config.enable_circuit_breaker {
342                        circuit_breaker.record_success(&domain);
343                    }
344
345                    // Send document to collector
346                    if doc_tx.send(document).await.is_err() {
347                        debug!("Worker {}: Document receiver closed, exiting", worker_id);
348                        return Ok(());
349                    }
350
351                    // Process discovered links if we haven't reached max depth
352                    if current_depth < request.max_depth {
353                        // First, detect pagination links if enabled
354                        let pagination_links = if detect_pagination {
355                            let mut detector = pagination_detector.lock().await;
356                            detector.detect_pagination(&html, &current_url)
357                        } else {
358                            Vec::new()
359                        };
360
361                        // Add pagination links with priority (same depth as current)
362                        for link in &pagination_links {
363                            let normalized_link = match normalize_url_string(link) {
364                                Ok(url) => url,
365                                Err(e) => {
366                                    debug!("Worker {}: Failed to normalize pagination link {}: {}", worker_id, link, e);
367                                    continue;
368                                }
369                            };
370
371                            // Skip if already visited or queued
372                            {
373                                let visited_read = visited.read().await;
374                                let depths_read = url_depths.read().await;
375                                if visited_read.contains(&normalized_link) || depths_read.contains_key(&normalized_link) {
376                                    continue;
377                                }
378                            }
379
380                            // Check domain restrictions
381                            if is_same_domain(&normalized_link, &request.url) {
382                                // Add to queue
383                                {
384                                    let mut depths_write = url_depths.write().await;
385                                    depths_write.insert(normalized_link.clone(), current_depth);
386                                }
387                                
388                                if url_tx.send((normalized_link, current_depth)).await.is_err() {
389                                    debug!("Worker {}: URL receiver closed, exiting", worker_id);
390                                    return Ok(());
391                                }
392                            }
393                        }
394
395                        // Then process regular links
396                        for link in links {
397                            // Normalize the link to prevent duplicates
398                            let normalized_link = match normalize_url_string(&link) {
399                                Ok(url) => url,
400                                Err(e) => {
401                                    debug!("Worker {}: Failed to normalize link {}: {}", worker_id, link, e);
402                                    continue;
403                                }
404                            };
405
406                            // Skip if already visited or queued
407                            {
408                                let visited_read = visited.read().await;
409                                let depths_read = url_depths.read().await;
410                                if visited_read.contains(&normalized_link) || depths_read.contains_key(&normalized_link) {
411                                    continue;
412                                }
413                            }
414
415                            // Skip if this is a pagination link (already processed)
416                            if pagination_links.contains(&link) || pagination_links.contains(&normalized_link) {
417                                continue;
418                            }
419
420                            // Check domain restrictions
421                            let allow_link = if request.allow_external_links.unwrap_or(false) {
422                                true
423                            } else if request.allow_backward_links.unwrap_or(false) {
424                                // Allow backward links means crawl entire domain
425                                is_same_domain(&normalized_link, &request.url)
426                            } else {
427                                // Only allow links that are "forward" (same path or deeper)
428                                is_same_domain(&normalized_link, &request.url)
429                                    && is_forward_link(&normalized_link, &current_url)
430                            };
431
432                            if allow_link {
433                                // Add to queue
434                                {
435                                    let mut depths_write = url_depths.write().await;
436                                    if depths_write.len() < config.max_queue_size {
437                                        depths_write.insert(normalized_link.clone(), current_depth + 1);
438                                    } else {
439                                        debug!("Worker {}: Queue limit reached, skipping link: {}", worker_id, link);
440                                        continue;
441                                    }
442                                }
443                                
444                                if url_tx.send((normalized_link, current_depth + 1)).await.is_err() {
445                                    debug!("Worker {}: URL receiver closed, exiting", worker_id);
446                                    return Ok(());
447                                }
448                            }
449                        }
450                    }
451                }
452                Ok(Err(e)) => {
453                    warn!("Worker {}: Failed to scrape {}: {}", worker_id, current_url, e);
454
455                    // Record failure in circuit breaker
456                    if config.enable_circuit_breaker {
457                        circuit_breaker.record_failure(&domain);
458                    }
459                }
460                Err(_) => {
461                    warn!("Worker {}: Timeout scraping {}", worker_id, current_url);
462
463                    // Record failure in circuit breaker
464                    if config.enable_circuit_breaker {
465                        circuit_breaker.record_failure(&domain);
466                    }
467                }
468            }
469        }
470
471        debug!("Worker {} finished", worker_id);
472        Ok(())
473    }
474}
475
476impl Default for ParallelCrawler {
477    fn default() -> Self {
478        Self::new()
479    }
480}
481
482/// Check if robots.txt allows crawling
483async fn check_robots_txt(url: &str, ignore_sitemap: Option<bool>) -> bool {
484    if ignore_sitemap.unwrap_or(false) {
485        return true;
486    }
487
488    match robots::is_allowed_default(url).await {
489        Ok(allowed) => allowed,
490        Err(e) => {
491            warn!("Failed to check robots.txt: {}, allowing by default", e);
492            true
493        }
494    }
495}
496
497/// Scrape a single URL and extract links
498async fn scrape_url(url: &str, engine: &HttpEngine) -> Result<(Document, Vec<String>, String)> {
499    // Create a scrape request
500    let scrape_request = ScrapeRequest {
501        url: url.to_string(),
502        formats: vec!["markdown".to_string(), "links".to_string()],
503        headers: HashMap::new(),
504        include_tags: vec![],
505        exclude_tags: vec![],
506        only_main_content: true,
507        timeout: 30000,
508        wait_for: 0,
509        remove_base64_images: true,
510        skip_tls_verification: false,
511        engine: "http".to_string(),
512        wait_for_selector: None,
513        actions: vec![],
514        screenshot: false,
515        screenshot_format: "png".to_string(),
516    };
517
518    // Scrape the URL
519    let raw_result = engine.scrape(&scrape_request).await?;
520
521    // Extract links from HTML
522    let links = extract_links(&raw_result.html, url)?;
523
524    // Store HTML for pagination detection
525    let html = raw_result.html.clone();
526
527    // Process the result into a document
528    let document = format::process_scrape_result(raw_result, &scrape_request).await?;
529
530    Ok((document, links, html))
531}
532
533/// Extract all links from HTML
534fn extract_links(html: &str, base_url: &str) -> Result<Vec<String>> {
535    let document = Html::parse_document(html);
536    let selector = Selector::parse("a[href]")
537        .map_err(|e| ScrapeError::Internal(format!("Failed to create link selector: {:?}", e)))?;
538
539    let base = Url::parse(base_url)
540        .map_err(|e| ScrapeError::InvalidUrl(format!("Invalid base URL: {}", e)))?;
541
542    let mut links = Vec::new();
543
544    for element in document.select(&selector) {
545        if let Some(href) = element.value().attr("href") {
546            // Skip javascript:, mailto:, tel:, etc.
547            if href.starts_with("javascript:")
548                || href.starts_with("mailto:")
549                || href.starts_with("tel:")
550                || href.starts_with('#')
551            {
552                continue;
553            }
554
555            // Parse and resolve the URL
556            match base.join(href) {
557                Ok(absolute_url) => {
558                    let url_str = absolute_url.to_string();
559                    // Remove fragment
560                    let url_without_fragment = url_str.split('#').next().unwrap_or(&url_str);
561                    links.push(url_without_fragment.to_string());
562                }
563                Err(_) => {
564                    // Skip invalid URLs
565                    continue;
566                }
567            }
568        }
569    }
570
571    // Deduplicate
572    links.sort();
573    links.dedup();
574
575    Ok(links)
576}
577
578/// Check if a link is "forward" (same path or deeper)
579fn is_forward_link(link: &str, current: &str) -> bool {
580    let link_parsed = match Url::parse(link) {
581        Ok(u) => u,
582        Err(_) => return false,
583    };
584
585    let current_parsed = match Url::parse(current) {
586        Ok(u) => u,
587        Err(_) => return false,
588    };
589
590    let link_path = link_parsed.path();
591    let current_path = current_parsed.path();
592
593    // A link is forward if:
594    // 1. It has the same path as current, or
595    // 2. It's a subpath of current (starts with current path)
596    link_path == current_path || link_path.starts_with(current_path)
597}
598
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_crawler_creation() {
        let expected_workers = num_cpus::get();
        assert_eq!(ParallelCrawler::new().max_workers, expected_workers);
    }

    #[test]
    fn test_parallel_crawler_with_config() {
        let mut custom = CrawlerConfig::default();
        custom.max_concurrent_requests = 5;

        assert_eq!(ParallelCrawler::with_config(&custom).max_workers, 5);
    }

    #[test]
    fn test_extract_links() {
        let page = r##"
            <html>
                <body>
                    <a href="/page1">Page 1</a>
                    <a href="/page2">Page 2</a>
                    <a href="https://example.com/page3">Page 3</a>
                    <a href="javascript:void(0)">JS</a>
                    <a href="mailto:test@example.com">Email</a>
                    <a href="#section">Section</a>
                </body>
            </html>
        "##;

        let found = extract_links(page, "https://example.com").unwrap();

        for expected in [
            "https://example.com/page1",
            "https://example.com/page2",
            "https://example.com/page3",
        ] {
            assert!(found.iter().any(|l| l == expected), "missing link {}", expected);
        }

        for forbidden in ["javascript:", "mailto:", "#"] {
            assert!(
                !found.iter().any(|l| l.contains(forbidden)),
                "unexpected link containing {}",
                forbidden
            );
        }
    }

    #[test]
    fn test_is_forward_link() {
        let blog = "https://example.com/blog";

        assert!(is_forward_link("https://example.com/blog/post1", blog));
        assert!(is_forward_link(blog, blog));
        assert!(!is_forward_link("https://example.com/about", blog));
    }
}