use crate::{
crawler::config::{CircuitBreaker, CrawlerConfig, MemoryMonitor},
crawler::filter::{is_same_domain, should_crawl_url},
crawler::pagination::{PaginationConfig, PaginationDetector},
crawler::rate_limiter::DomainRateLimiter,
engines::{http::HttpEngine, ScrapeEngine},
error::{Result, ScrapeError},
format,
types::{CrawlRequest, Document, ScrapeRequest},
utils::{normalize_url_string, robots},
};
use scraper::{Html, Selector};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::{mpsc, Semaphore};
use tokio::time::{timeout, Duration};
use tracing::{debug, info, warn, error};
use url::Url;
/// Crawler that processes a site with a pool of concurrent worker tasks.
pub struct ParallelCrawler {
    /// Permits a worker must hold while it scrapes a page.
    scrape_semaphore: Arc<Semaphore>,
    // Constructed by both constructors but never read in this file.
    #[allow(dead_code)]
    process_semaphore: Arc<Semaphore>,
    /// How many worker tasks each crawl spawns.
    max_workers: usize,
}
impl ParallelCrawler {
pub fn new() -> Self {
let num_cpus = num_cpus::get();
Self {
scrape_semaphore: Arc::new(Semaphore::new(num_cpus)),
process_semaphore: Arc::new(Semaphore::new(num_cpus / 2)),
max_workers: num_cpus,
}
}
pub fn with_config(config: &CrawlerConfig) -> Self {
let max_concurrent = config.max_concurrent_requests;
Self {
scrape_semaphore: Arc::new(Semaphore::new(max_concurrent)),
process_semaphore: Arc::new(Semaphore::new(max_concurrent / 2)),
max_workers: max_concurrent,
}
}
pub async fn crawl_parallel(&self, request: &CrawlRequest) -> Result<Vec<Document>> {
info!(
"Starting parallel crawl from URL: {} with {} workers",
request.url, self.max_workers
);
let _base_url = Url::parse(&request.url)
.map_err(|e| ScrapeError::InvalidUrl(format!("Invalid base URL: {}", e)))?;
let normalized_base_url = normalize_url_string(&request.url)
.map_err(|e| ScrapeError::InvalidUrl(format!("Failed to normalize base URL: {}", e)))?;
let config = CrawlerConfig::default();
let circuit_breaker = Arc::new(CircuitBreaker::new(config.circuit_breaker_threshold));
let memory_monitor = Arc::new(MemoryMonitor::new(
config.max_memory_mb,
config.enable_memory_monitoring,
));
let visited = Arc::new(tokio::sync::RwLock::new(HashSet::new()));
let url_depths = Arc::new(tokio::sync::RwLock::new(HashMap::new()));
let robots_allowed = check_robots_txt(&request.url, request.ignore_sitemap).await;
let rate_limit = std::env::var("CRAWL_RATE_LIMIT_PER_SEC")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(2);
let rate_limiter = Arc::new(DomainRateLimiter::new(rate_limit));
info!(
"Rate limiting enabled: {} requests/second per domain",
rate_limit
);
let pagination_config = PaginationConfig {
max_pages: request.max_pagination_pages.unwrap_or(50) as usize,
max_depth: request.max_depth as usize,
detect_circular: true,
};
let pagination_detector = Arc::new(tokio::sync::Mutex::new(PaginationDetector::new(
pagination_config,
)));
let detect_pagination = request.detect_pagination.unwrap_or(true);
let (url_tx, url_rx) = mpsc::channel::<(String, u32)>(config.max_queue_size);
let (doc_tx, doc_rx) = mpsc::channel::<Document>(config.max_queue_size);
url_depths.write().await.insert(normalized_base_url.clone(), 0);
if let Err(e) = url_tx.send((normalized_base_url.clone(), 0)).await {
return Err(ScrapeError::Internal(format!("Failed to enqueue base URL: {}", e)));
}
let request_clone = request.clone();
let url_rx = Arc::new(tokio::sync::Mutex::new(url_rx));
let doc_tx_clone = doc_tx.clone();
let mut worker_handles = Vec::new();
for worker_id in 0..self.max_workers {
let url_rx_clone = Arc::clone(&url_rx);
let doc_tx_worker = doc_tx_clone.clone();
let url_tx_worker = url_tx.clone();
let visited_clone = Arc::clone(&visited);
let url_depths_clone = Arc::clone(&url_depths);
let circuit_breaker_clone = Arc::clone(&circuit_breaker);
let memory_monitor_clone = Arc::clone(&memory_monitor);
let rate_limiter_clone = Arc::clone(&rate_limiter);
let scrape_semaphore_clone = Arc::clone(&self.scrape_semaphore);
let pagination_detector_clone = Arc::clone(&pagination_detector);
let config_clone = config.clone();
let request_clone2 = request_clone.clone();
let handle = tokio::spawn(async move {
Self::scrape_worker(
worker_id,
url_rx_clone,
doc_tx_worker,
url_tx_worker,
visited_clone,
url_depths_clone,
circuit_breaker_clone,
memory_monitor_clone,
rate_limiter_clone,
scrape_semaphore_clone,
pagination_detector_clone,
config_clone,
request_clone2,
robots_allowed,
detect_pagination,
)
.await
});
worker_handles.push(handle);
}
drop(url_tx);
drop(doc_tx);
let doc_limit = request.limit as usize;
let mut documents = Vec::with_capacity(doc_limit.min(1000));
let mut doc_rx = doc_rx;
while let Some(doc) = doc_rx.recv().await {
documents.push(doc);
if documents.len() >= doc_limit {
info!("Reached document limit of {}, stopping collection", doc_limit);
break;
}
}
for handle in worker_handles {
if documents.len() >= doc_limit {
handle.abort();
} else if let Err(e) = handle.await {
if !e.is_cancelled() {
error!("Worker task failed: {}", e);
}
}
}
let visited_count = visited.read().await.len();
info!(
"Parallel crawl completed. Total pages crawled: {}, visited: {}",
documents.len(),
visited_count
);
Ok(documents)
}
#[allow(clippy::too_many_arguments)]
async fn scrape_worker(
worker_id: usize,
url_rx: Arc<tokio::sync::Mutex<mpsc::Receiver<(String, u32)>>>,
doc_tx: mpsc::Sender<Document>,
url_tx: mpsc::Sender<(String, u32)>,
visited: Arc<tokio::sync::RwLock<HashSet<String>>>,
url_depths: Arc<tokio::sync::RwLock<HashMap<String, u32>>>,
circuit_breaker: Arc<CircuitBreaker>,
memory_monitor: Arc<MemoryMonitor>,
rate_limiter: Arc<DomainRateLimiter>,
scrape_semaphore: Arc<Semaphore>,
pagination_detector: Arc<tokio::sync::Mutex<PaginationDetector>>,
config: CrawlerConfig,
request: CrawlRequest,
robots_allowed: bool,
detect_pagination: bool,
) -> Result<()> {
debug!("Worker {} started", worker_id);
let engine = HttpEngine::new()?;
loop {
let url_item = {
let mut rx = url_rx.lock().await;
rx.recv().await
};
let (current_url, current_depth) = match url_item {
Some(item) => item,
None => {
debug!("Worker {}: Queue closed, exiting", worker_id);
break;
}
};
{
let visited_read = visited.read().await;
if visited_read.contains(¤t_url) {
continue;
}
}
{
let mut visited_write = visited.write().await;
if visited_write.contains(¤t_url) {
continue;
}
visited_write.insert(current_url.clone());
}
if current_depth > request.max_depth {
debug!("Worker {}: Skipping URL due to depth limit: {}", worker_id, current_url);
continue;
}
let domain = match Url::parse(¤t_url) {
Ok(parsed) => parsed.host_str().unwrap_or("unknown").to_string(),
Err(_) => "unknown".to_string(),
};
if config.enable_circuit_breaker && circuit_breaker.should_skip(&domain) {
warn!(
"Worker {}: Circuit breaker: Skipping domain {} due to excessive failures",
worker_id, domain
);
continue;
}
if !robots_allowed {
match robots::is_allowed_default(¤t_url).await {
Ok(allowed) => {
if !allowed {
warn!("Worker {}: Robots.txt disallows crawling: {}", worker_id, current_url);
continue;
}
}
Err(e) => {
warn!("Worker {}: Failed to check robots.txt for {}: {}", worker_id, current_url, e);
}
}
}
if !should_crawl_url(¤t_url, &request.include_paths, &request.exclude_paths) {
debug!(
"Worker {}: URL filtered out by include/exclude patterns: {}",
worker_id, current_url
);
continue;
}
if config.enable_memory_monitoring {
if let Err(e) = memory_monitor.check_memory_limit() {
warn!("Worker {}: Memory limit check failed: {}", worker_id, e);
return Err(e);
}
}
let _permit = scrape_semaphore.acquire().await
.map_err(|e| ScrapeError::Internal(format!("Failed to acquire scrape permit: {}", e)))?;
if let Err(e) = rate_limiter.wait_for_permission(¤t_url).await {
warn!("Worker {}: Rate limiter error for {}: {}", worker_id, current_url, e);
continue;
}
info!(
"Worker {}: Crawling URL: {} (depth: {})",
worker_id, current_url, current_depth
);
let scrape_result = timeout(
Duration::from_secs(30),
scrape_url(¤t_url, &engine)
).await;
match scrape_result {
Ok(Ok((document, links, html))) => {
if config.enable_circuit_breaker {
circuit_breaker.record_success(&domain);
}
if doc_tx.send(document).await.is_err() {
debug!("Worker {}: Document receiver closed, exiting", worker_id);
return Ok(());
}
if current_depth < request.max_depth {
let pagination_links = if detect_pagination {
let mut detector = pagination_detector.lock().await;
detector.detect_pagination(&html, ¤t_url)
} else {
Vec::new()
};
for link in &pagination_links {
let normalized_link = match normalize_url_string(link) {
Ok(url) => url,
Err(e) => {
debug!("Worker {}: Failed to normalize pagination link {}: {}", worker_id, link, e);
continue;
}
};
{
let visited_read = visited.read().await;
let depths_read = url_depths.read().await;
if visited_read.contains(&normalized_link) || depths_read.contains_key(&normalized_link) {
continue;
}
}
if is_same_domain(&normalized_link, &request.url) {
{
let mut depths_write = url_depths.write().await;
depths_write.insert(normalized_link.clone(), current_depth);
}
if url_tx.send((normalized_link, current_depth)).await.is_err() {
debug!("Worker {}: URL receiver closed, exiting", worker_id);
return Ok(());
}
}
}
for link in links {
let normalized_link = match normalize_url_string(&link) {
Ok(url) => url,
Err(e) => {
debug!("Worker {}: Failed to normalize link {}: {}", worker_id, link, e);
continue;
}
};
{
let visited_read = visited.read().await;
let depths_read = url_depths.read().await;
if visited_read.contains(&normalized_link) || depths_read.contains_key(&normalized_link) {
continue;
}
}
if pagination_links.contains(&link) || pagination_links.contains(&normalized_link) {
continue;
}
let allow_link = if request.allow_external_links.unwrap_or(false) {
true
} else if request.allow_backward_links.unwrap_or(false) {
is_same_domain(&normalized_link, &request.url)
} else {
is_same_domain(&normalized_link, &request.url)
&& is_forward_link(&normalized_link, ¤t_url)
};
if allow_link {
{
let mut depths_write = url_depths.write().await;
if depths_write.len() < config.max_queue_size {
depths_write.insert(normalized_link.clone(), current_depth + 1);
} else {
debug!("Worker {}: Queue limit reached, skipping link: {}", worker_id, link);
continue;
}
}
if url_tx.send((normalized_link, current_depth + 1)).await.is_err() {
debug!("Worker {}: URL receiver closed, exiting", worker_id);
return Ok(());
}
}
}
}
}
Ok(Err(e)) => {
warn!("Worker {}: Failed to scrape {}: {}", worker_id, current_url, e);
if config.enable_circuit_breaker {
circuit_breaker.record_failure(&domain);
}
}
Err(_) => {
warn!("Worker {}: Timeout scraping {}", worker_id, current_url);
if config.enable_circuit_breaker {
circuit_breaker.record_failure(&domain);
}
}
}
}
debug!("Worker {} finished", worker_id);
Ok(())
}
}
impl Default for ParallelCrawler {
fn default() -> Self {
Self::new()
}
}
/// Reports whether robots.txt permits crawling `url`.
///
/// Returns `true` without any lookup when `ignore_sitemap` is `Some(true)`. If the
/// robots.txt lookup itself fails, logs a warning and also returns `true`.
async fn check_robots_txt(url: &str, ignore_sitemap: Option<bool>) -> bool {
    if matches!(ignore_sitemap, Some(true)) {
        return true;
    }
    robots::is_allowed_default(url).await.unwrap_or_else(|e| {
        warn!("Failed to check robots.txt: {}, allowing by default", e);
        true
    })
}
/// Fetches `url` with the HTTP engine and converts the response into a [`Document`].
///
/// Returns the document, the absolute links found on the page (via `extract_links`),
/// and the raw HTML of the page.
///
/// # Errors
/// Propagates errors from the engine, from link extraction, and from
/// `format::process_scrape_result`.
async fn scrape_url(url: &str, engine: &HttpEngine) -> Result<(Document, Vec<String>, String)> {
    // Request markdown + links for the main content only, using the plain HTTP engine.
    let request = ScrapeRequest {
        url: url.to_string(),
        formats: vec!["markdown".to_string(), "links".to_string()],
        headers: HashMap::new(),
        include_tags: vec![],
        exclude_tags: vec![],
        only_main_content: true,
        timeout: 30000,
        wait_for: 0,
        remove_base64_images: true,
        skip_tls_verification: false,
        engine: "http".to_string(),
        wait_for_selector: None,
        actions: vec![],
        screenshot: false,
        screenshot_format: "png".to_string(),
    };
    let raw = engine.scrape(&request).await?;
    // Keep a copy of the HTML: `process_scrape_result` consumes the raw result.
    let page_html = raw.html.clone();
    let outgoing_links = extract_links(&page_html, url)?;
    let document = format::process_scrape_result(raw, &request).await?;
    Ok((document, outgoing_links, page_html))
}
/// Extracts absolute `http`/`https` link targets from every `<a href>` in `html`.
///
/// Relative hrefs are resolved against `base_url` and fragments are removed. The result
/// is sorted and de-duplicated. Three kinds of href are dropped:
/// - in-page anchors (`#...`);
/// - hrefs that fail to resolve;
/// - any non-HTTP(S) scheme (`javascript:`, `mailto:`, `tel:`, `data:`, ...), in any
///   letter case.
///
/// # Errors
/// Returns `ScrapeError::InvalidUrl` if `base_url` is not a valid absolute URL.
fn extract_links(html: &str, base_url: &str) -> Result<Vec<String>> {
    let document = Html::parse_document(html);
    let selector = Selector::parse("a[href]")
        .map_err(|e| ScrapeError::Internal(format!("Failed to create link selector: {:?}", e)))?;
    let base = Url::parse(base_url)
        .map_err(|e| ScrapeError::InvalidUrl(format!("Invalid base URL: {}", e)))?;

    let mut links: Vec<String> = document
        .select(&selector)
        .filter_map(|element| element.value().attr("href"))
        // Pure in-page anchors would only resolve back to the current page.
        .filter(|href| !href.trim_start().starts_with('#'))
        .filter_map(|href| base.join(href).ok())
        // `Url` lower-cases the scheme while parsing, so this also rejects
        // `JavaScript:`, `MAILTO:` and friends, which a prefix check on the raw href misses.
        .filter(|url| matches!(url.scheme(), "http" | "https"))
        .map(|mut url| {
            url.set_fragment(None);
            url.to_string()
        })
        .collect();
    links.sort_unstable();
    links.dedup();
    Ok(links)
}
/// Returns `true` when `link`'s path equals `current`'s path or lies beneath it.
///
/// The comparison works on whole path segments: `/blog/post` is forward of `/blog`, but
/// `/blogger` is not. A trailing `/` on `current`'s path is ignored, so a root path (`/`)
/// makes every link forward. Hosts are not compared. Returns `false` if either URL fails
/// to parse.
fn is_forward_link(link: &str, current: &str) -> bool {
    let (link_url, current_url) = match (Url::parse(link), Url::parse(current)) {
        (Ok(link_url), Ok(current_url)) => (link_url, current_url),
        _ => return false,
    };
    let link_path = link_url.path();
    let current_path = current_url.path();
    if link_path == current_path {
        return true;
    }
    // A bare `starts_with` would accept `/blogger` under `/blog`; require that the
    // remainder begins a new path segment.
    let prefix = current_path.trim_end_matches('/');
    link_path == prefix
        || matches!(link_path.strip_prefix(prefix), Some(rest) if rest.starts_with('/'))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parallel_crawler_creation() {
        let expected_workers = num_cpus::get();
        let crawler = ParallelCrawler::new();
        assert_eq!(crawler.max_workers, expected_workers);
    }

    #[tokio::test]
    async fn test_parallel_crawler_with_config() {
        let mut config = CrawlerConfig::default();
        config.max_concurrent_requests = 5;
        let crawler = ParallelCrawler::with_config(&config);
        assert_eq!(crawler.max_workers, 5);
    }

    #[test]
    fn test_extract_links() {
        let html = r##"
<html>
<body>
<a href="/page1">Page 1</a>
<a href="/page2">Page 2</a>
<a href="https://example.com/page3">Page 3</a>
<a href="javascript:void(0)">JS</a>
<a href="mailto:test@example.com">Email</a>
<a href="#section">Section</a>
</body>
</html>
"##;
        let links = extract_links(html, "https://example.com").unwrap();

        let expected = [
            "https://example.com/page1",
            "https://example.com/page2",
            "https://example.com/page3",
        ];
        for wanted in &expected {
            assert!(links.contains(&wanted.to_string()), "missing link {}", wanted);
        }

        let forbidden = ["javascript:", "mailto:", "#"];
        for needle in &forbidden {
            assert!(
                links.iter().all(|l| !l.contains(needle)),
                "unexpected link containing {}",
                needle
            );
        }
    }

    #[test]
    fn test_is_forward_link() {
        let cases = [
            ("https://example.com/blog/post1", "https://example.com/blog", true),
            ("https://example.com/blog", "https://example.com/blog", true),
            ("https://example.com/about", "https://example.com/blog", false),
        ];
        for &(link, current, expected) in &cases {
            assert_eq!(
                is_forward_link(link, current),
                expected,
                "is_forward_link({}, {})",
                link,
                current
            );
        }
    }
}