use std::collections::HashSet;
use std::num::NonZeroU32;
use std::path::PathBuf;
use std::sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc,
};
use std::time::{Duration, Instant};

use futures::future::join_all;
use governor::clock::DefaultClock;
use governor::state::{InMemoryState, NotKeyed};
use governor::{Quota, RateLimiter};
use indicatif::{ProgressBar, ProgressStyle};
use regex::Regex;
use reqwest_middleware::ClientBuilder;
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryTransientMiddleware;
use tokio::sync::{mpsc, Mutex, Notify, RwLock, Semaphore};
use tracing::{debug, warn};
use url::Url;

use crate::cache::{Cache, ResourceMeta};
use crate::config::Config;
use crate::extract::extract_page;
use crate::robots::RobotsManager;
use crate::security::{
clean_markdown, is_safe_image_content_type, sanitize_html_for_md, sanitize_markdown,
};
use crate::sink::{FileSink, Sink, Stats};
use crate::sitemap::fetch_and_parse_sitemaps;
use crate::util::{
is_same_host, normalize_url, path_for_asset, path_for_url, relpath, site_name_from_url,
};
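/// Options for a single crawl run. `config` carries the finer-grained
/// settings (scope rules, exclude patterns, asset policy, page limits).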
pub struct CrawlConfig {
pub base_url: Url,
pub output_dir: PathBuf,
pub user_agent: String,
pub max_depth: Option<usize>,
pub rate_limit_per_sec: u32,
pub follow_sitemaps: bool,
pub concurrency: usize,
pub timeout: Option<Duration>,
pub resume: bool,
pub config: Config,
pub silence: bool,
}
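/// RAII guard that decrements the shared pending-task counter when a crawl
/// task finishes (including on panic) and wakes the main loop once the
/// counter reaches zero.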
struct PendingGuard {
pending: Arc<AtomicUsize>,
notify: Arc<Notify>,
}
impl PendingGuard {
fn new(pending: Arc<AtomicUsize>, notify: Arc<Notify>) -> Self {
Self { pending, notify }
}
}
impl Drop for PendingGuard {
fn drop(&mut self) {
let remaining = self.pending.fetch_sub(1, Ordering::SeqCst) - 1;
debug!("Task done. Pending: {}", remaining);
if remaining == 0 {
debug!("No more pending tasks, notifying waiters");
self.notify.notify_waiters();
}
}
}
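/// State shared by every in-flight crawl task; a clone is handed to each
/// spawned task.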
#[derive(Clone)]
struct SharedState {
visited: Arc<RwLock<HashSet<String>>>,
cache: Option<Arc<Cache>>,
exclude_res: Arc<Vec<Regex>>,
base: Arc<Url>,
robots: Arc<Mutex<RobotsManager>>,
limiter: Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>,
client: reqwest_middleware::ClientWithMiddleware,
sink: Arc<FileSink>,
output_dir: PathBuf,
saved_pages: Arc<AtomicUsize>,
saved_assets: Arc<AtomicUsize>,
skipped: Arc<AtomicUsize>,
semaphore: Arc<Semaphore>,
pending: Arc<AtomicUsize>,
notify: Arc<Notify>,
config: Config,
stop_flag: Arc<AtomicBool>,
tx: mpsc::UnboundedSender<(Url, usize)>,
max_depth: Option<usize>,
path_prefix: String,
}
impl SharedState {
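/// Applies the depth, scope, and path-prefix checks, records the URL in
/// the persistent frontier, and spawns a task to fetch it. Once the stop
/// flag is set, URLs are only persisted so a resumed run can pick them up.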
fn enqueue(&self, url: Url, depth: usize) {
debug!("Enqueue called for {} at depth {}", url, depth);
if let Some(maxd) = self.max_depth {
if depth > maxd {
debug!("Skipping {}: depth {} > max {}", url, depth, maxd);
return;
}
}
if !within_scope(&self.base, &url, &self.config) {
debug!("Skipping {}: out of scope", url);
return;
}
if !path_in_prefix(url.path(), &self.path_prefix) {
debug!("Skipping {}: outside path prefix {}", url, self.path_prefix);
return;
}
if self.stop_flag.load(Ordering::Acquire) {
debug!("Skipping {}: stop flag set", url);
if let Some(c) = &self.cache {
let canon = normalize_url(&url);
let _ = c.add_frontier(&canon, depth);
}
return;
}
self.pending.fetch_add(1, Ordering::SeqCst);
if let Some(c) = &self.cache {
let canon = normalize_url(&url);
let _ = c.add_frontier(&canon, depth);
}
let state = self.clone();
tokio::spawn(async move {
state.process_url(url, depth).await;
});
}
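/// Fetches one URL end to end: dedup against the in-memory and cached
/// visited sets, robots.txt and exclude-pattern checks, a conditional GET
/// (ETag / Last-Modified), HTML extraction, Markdown conversion, asset
/// downloads, and finally handing the page's outgoing links back to the
/// main loop.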
async fn process_url(&self, url: Url, depth: usize) {
// Hold the guard for the whole task so `pending` is balanced on every
// exit path, including the early return below.
let _guard = PendingGuard::new(self.pending.clone(), self.notify.clone());
let _permit = match self.semaphore.acquire().await {
Ok(p) => p,
// The semaphore is only closed during shutdown; drop the task silently.
Err(_) => return,
};
let canon = normalize_url(&url);
debug!("Processing {} (canonical: {})", url, canon);
if let Some(c) = &self.cache {
let _ = c.remove_frontier(&canon);
}
{
let mut vis = self.visited.write().await;
if !vis.insert(canon.clone()) {
debug!("Already visited: {}", canon);
self.skipped.fetch_add(1, Ordering::Relaxed);
return;
}
}
if let Some(c) = &self.cache {
// `check_and_mark_visited` returns Ok(false) when the URL was already
// recorded; on a cache error we err on the side of re-fetching.
let cache_result = c.check_and_mark_visited(&canon);
debug!("Cache check for {}: {:?}", canon, cache_result);
if !cache_result.unwrap_or(true) {
debug!("Already in cache, skipping: {}", canon);
self.skipped.fetch_add(1, Ordering::Relaxed);
return;
}
}
if self.exclude_res.iter().any(|re| re.is_match(canon.as_str())) {
self.skipped.fetch_add(1, Ordering::Relaxed);
return;
}
{
let rb = self.robots.lock().await;
if !rb.allowed(&self.base, url.path()) {
debug!("Blocked by robots.txt: {}", url);
self.skipped.fetch_add(1, Ordering::Relaxed);
return;
}
}
let mut req = self.client.get(url.clone());
let cached_meta = self
.cache
.as_ref()
.and_then(|c| c.get_meta(&canon).ok())
.flatten();
if let Some(m) = &cached_meta {
if let Some(etag) = &m.etag {
req = req.header(reqwest::header::IF_NONE_MATCH, etag);
}
if let Some(lm) = &m.last_modified {
req = req.header(reqwest::header::IF_MODIFIED_SINCE, lm);
}
}
self.limiter.until_ready().await;
let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
warn!(error=%e, "request failed");
self.skipped.fetch_add(1, Ordering::Relaxed);
return;
}
};
if resp.status() == reqwest::StatusCode::NOT_MODIFIED {
debug!("Not modified: {}", url);
self.skipped.fetch_add(1, Ordering::Relaxed);
return;
}
let resp = match resp.error_for_status() {
Ok(r) => r,
Err(e) => {
warn!(error=%e, "bad status");
self.skipped.fetch_add(1, Ordering::Relaxed);
return;
}
};
let final_url = resp.url().clone();
let content_type = resp
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_lowercase();
if !content_type.contains("text/html") && !content_type.contains("application/xhtml") {
debug!("Skipping non-HTML: {} [{}]", final_url, content_type);
self.skipped.fetch_add(1, Ordering::Relaxed);
return;
}
let headers = resp.headers().clone();
// Hard cap on HTML body size so a single response cannot exhaust memory.
const MAX_BODY_BYTES: u64 = 10 * 1024 * 1024;
if let Some(cl) = headers
.get(reqwest::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u64>().ok())
{
if cl > MAX_BODY_BYTES {
debug!("Skipping {}: content-length {} exceeds limit", url, cl);
self.skipped.fetch_add(1, Ordering::Relaxed);
return;
}
}
let body = match resp.text().await {
Ok(t) if t.len() as u64 > MAX_BODY_BYTES => {
debug!("Skipping {}: body size {} exceeds limit", url, t.len());
self.skipped.fetch_add(1, Ordering::Relaxed);
return;
}
Ok(t) => t,
Err(e) => {
warn!(error=%e, "read body failed");
self.skipped.fetch_add(1, Ordering::Relaxed);
return;
}
};
if let Some(c) = &self.cache {
let meta = ResourceMeta {
etag: headers
.get(reqwest::header::ETAG)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string()),
last_modified: headers
.get(reqwest::header::LAST_MODIFIED)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string()),
};
let _ = c.set_meta(&canon, &meta);
}
let page = extract_page(&final_url, &body, self.config.selectors.as_deref());
let sanitized_html = sanitize_html_for_md(&final_url, &page.main_html);
// Convert HTML to Markdown, then run the security and cleanup passes.
let md = html2md::rewrite_html(&sanitized_html, true);
let (md, security_flags) = sanitize_markdown(&md);
let mut md = clean_markdown(&md);
let out_path = path_for_url(&self.output_dir, &self.base, &final_url);
let quarantined = !security_flags.is_empty();
if !quarantined && !self.config.skip_assets {
let mut replace_pairs: Vec<(String, String)> = vec![];
let mut asset_futures = vec![];
for img in page.images.iter() {
if !self.config.external_assets && !within_scope(&self.base, img, &self.config) {
continue;
}
let asset_path = path_for_asset(&self.output_dir, &self.base, img);
let img_url = img.clone();
let client = self.client.clone();
let allow_svg = self.config.allow_svg;
let fut = async move {
match fetch_asset_checked(&client, img_url.clone(), allow_svg).await {
Ok((ct, bytes)) => Some((img_url, asset_path.clone(), ct, bytes)),
Err(e) => {
debug!(error=%e, "asset download failed: {}", img_url);
None
}
}
};
asset_futures.push(fut);
}
let asset_results = join_all(asset_futures).await;
for (img_url, asset_path, ct, bytes) in asset_results.into_iter().flatten() {
if let Err(e) = self
.sink
.save_asset(&img_url, &asset_path, ct.as_deref(), bytes)
.await
{
debug!(error=%e, "asset save failed: {}", img_url);
} else {
self.saved_assets.fetch_add(1, Ordering::Relaxed);
}
if let Some(rel) = relpath(&out_path, &asset_path) {
replace_pairs.push((
img_url.as_str().to_string(),
rel.to_string_lossy().to_string(),
));
}
}
if !replace_pairs.is_empty() {
md = rewrite_md_images(md, &replace_pairs);
}
}
let body_to_write = if quarantined {
let reasons = security_flags
.iter()
.map(|f| format!("- {}", f))
.collect::<Vec<_>>()
.join("\n");
format!(
"Content skipped by docrawl due to detected security issues.\n\nDetected flags:\n{}\n\nReason: one or more patterns indicating potential prompt injection, risky content, or unsafe assets were found. To protect downstream consumers, docrawl omitted this page's content.",
reasons
)
} else {
md
};
if let Err(e) = self
.sink
.save_page(
&out_path,
&page.title,
&final_url,
&body_to_write,
&security_flags,
quarantined,
)
.await
{
warn!(error=%e, "save page failed");
}
// Count the page and honor `max_pages`. Note the counter also advances
// when the save itself failed.
let new_total = self.saved_pages.fetch_add(1, Ordering::Relaxed) + 1;
if let Some(max) = self.config.max_pages {
if new_total >= max {
self.stop_flag.store(true, Ordering::Release);
}
}
debug!("Saved {}", out_path.display());
let next_depth = depth.saturating_add(1);
debug!("Found {} links on {}", page.links.len(), final_url);
if !self.stop_flag.load(Ordering::Acquire) {
for link in page.links {
debug!("Sending link to queue: {} at depth {}", link, next_depth);
let _ = self.tx.send((link, next_depth));
}
} else if let Some(c) = &self.cache {
for link in page.links {
if within_scope(&self.base, &link, &self.config)
&& path_in_prefix(link.path(), &self.path_prefix)
{
let canon = normalize_url(&link);
let _ = c.add_frontier(&canon, next_depth);
}
}
}
}
}
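/// Crawls `cfg.base_url`, writing Markdown pages (and optionally assets)
/// under `cfg.output_dir`, and returns aggregate [`Stats`].
///
/// A minimal usage sketch (illustrative only; it assumes `Config`
/// implements `Default`, which this module does not guarantee):
///
/// ```ignore
/// let stats = crawl(CrawlConfig {
///     base_url: Url::parse("https://docs.example.com/")?,
///     output_dir: PathBuf::from("./out"),
///     user_agent: "docrawl/0.1".into(),
///     max_depth: Some(3),
///     rate_limit_per_sec: 4,
///     follow_sitemaps: false,
///     concurrency: 8,
///     timeout: None,
///     resume: false,
///     config: Config::default(),
///     silence: true,
/// })
/// .await?;
/// println!("saved {} pages, skipped {}", stats.pages, stats.skipped);
/// ```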
pub async fn crawl(cfg: CrawlConfig) -> Result<Stats, Box<dyn std::error::Error>> {
let base_origin = origin_of(&cfg.base_url)?;
std::fs::create_dir_all(&cfg.output_dir)?;
let client = build_client(&cfg.user_agent)?;
let limiter = Arc::new(RateLimiter::direct(Quota::per_second(
NonZeroU32::new(cfg.rate_limit_per_sec.max(1)).unwrap(),
)));
let mut robots_loaded = RobotsManager::new(cfg.user_agent.clone());
robots_loaded.load_for(&client, &base_origin).await;
let robots = Arc::new(Mutex::new(robots_loaded));
let cache = Cache::open(&cfg.output_dir.join(".docrawl_cache"))
.ok()
.map(Arc::new);
let mut visited_set = HashSet::new();
if cfg.resume {
if let Some(c) = &cache {
if let Ok(list) = c.list_visited() {
debug!("Pre-loaded {} visited URLs from cache", list.len());
visited_set.extend(list);
}
}
}
let visited = Arc::new(RwLock::new(visited_set));
let host_dir = cfg.output_dir.join(site_name_from_url(&base_origin));
let sink = Arc::new(FileSink::new(host_dir.clone()));
let path_prefix = if cfg.follow_sitemaps {
// Sitemap URLs can live anywhere on the host, so don't constrain the path.
"/".to_string()
} else {
derive_path_prefix(&cfg.base_url)
};
debug!("Path prefix: {:?}", path_prefix);
let mut seeds: Vec<(Url, usize)> = vec![(cfg.base_url.clone(), 0)];
if cfg.follow_sitemaps {
let sitemap_urls = fetch_and_parse_sitemaps(&client, &base_origin).await;
for u in sitemap_urls {
if within_scope(&base_origin, &u, &cfg.config)
&& path_in_prefix(u.path(), &path_prefix)
{
seeds.push((u, 0));
}
}
}
let exclude_res: Vec<Regex> = cfg
.config
.exclude_patterns
.iter()
.filter_map(|p| match Regex::new(p) {
Ok(re) => Some(re),
Err(e) => {
warn!("Invalid exclude pattern '{}': {}", p, e);
None
}
})
.collect();
let seeds_count = seeds.len();
let pb = if cfg.silence {
None
} else {
let p = ProgressBar::new_spinner();
p.set_style(
ProgressStyle::with_template(
"{spinner:.cyan} [{elapsed_precise}] {msg}",
)
.unwrap()
.tick_strings(&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]),
);
p.enable_steady_tick(std::time::Duration::from_millis(80));
Some(p)
};
let start_time = Instant::now();
let pending = Arc::new(AtomicUsize::new(0));
let notify = Arc::new(Notify::new());
let stop_flag = Arc::new(AtomicBool::new(false));
if let Some(dur) = cfg.timeout {
let stop = stop_flag.clone();
let wake = notify.clone();
tokio::spawn(async move {
tokio::time::sleep(dur).await;
stop.store(true, Ordering::Release);
wake.notify_waiters();
});
}
let (tx, mut rx) = mpsc::unbounded_channel::<(Url, usize)>();
let state = SharedState {
visited,
cache: cache.clone(),
exclude_res: Arc::new(exclude_res),
base: Arc::new(base_origin),
robots,
limiter,
client,
sink: sink.clone(),
output_dir: cfg.output_dir.clone(),
saved_pages: Arc::new(AtomicUsize::new(0)),
saved_assets: Arc::new(AtomicUsize::new(0)),
skipped: Arc::new(AtomicUsize::new(0)),
semaphore: Arc::new(Semaphore::new(cfg.concurrency)),
pending: pending.clone(),
notify: notify.clone(),
config: cfg.config.clone(),
stop_flag,
tx,
max_depth: cfg.max_depth,
path_prefix,
};
let seeds_to_enqueue = if let Some(c) = &cache {
if cfg.resume {
match c.list_frontier() {
Ok(list) if !list.is_empty() => list
.into_iter()
.filter_map(|(u, d)| Url::parse(&u).ok().map(|url| (url, d)))
.collect(),
_ => seeds,
}
} else {
let _ = c.clear_frontier();
let _ = c.clear_visited();
let _ = c.clear_meta();
seeds
}
} else {
seeds
};
debug!("Enqueueing {} seed URLs", seeds_to_enqueue.len());
for (u, d) in seeds_to_enqueue {
debug!("Seed URL: {} at depth {}", u, d);
state.enqueue(u, d);
}
if let Some(pb) = &pb {
pb.set_message(format!(
"{} | Saved: 0 | Skipped: 0 | Queue: {} | 0.0 pg/s",
state.base.host_str().unwrap_or("site"),
pending.load(Ordering::SeqCst)
));
}
let mut has_processed_any = false;
let follow_sitemaps = cfg.follow_sitemaps;
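// Completion detection: `enqueue` increments `pending` and each task's
// `PendingGuard` decrements it on exit. Every iteration refreshes the
// progress bar, drains newly discovered links from `rx`, and exits once
// `pending` is zero and the channel is empty (after a short grace period
// so the initial seeds have a chance to start).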
loop {
let pending_count = pending.load(Ordering::SeqCst);
let total_processed =
state.saved_pages.load(Ordering::Relaxed) + state.saved_assets.load(Ordering::Relaxed);
debug!(
"Main loop: pending tasks = {}, total processed = {}",
pending_count, total_processed
);
if let Some(pb) = &pb {
let saved = state.saved_pages.load(Ordering::Relaxed);
let skipped = state.skipped.load(Ordering::Relaxed);
let elapsed_secs = start_time.elapsed().as_secs_f64();
let rate = if elapsed_secs > 0.0 {
saved as f64 / elapsed_secs
} else {
0.0
};
pb.set_message(format!(
"{} | Saved: {} | Skipped: {} | Queue: {} | {:.1} pg/s",
state.base.host_str().unwrap_or("site"),
saved,
skipped,
pending_count,
rate
));
}
if pending_count > 0 {
has_processed_any = true;
}
if pending_count == 0 {
match rx.try_recv() {
Ok((u, d)) => {
debug!("Found pending message in channel: {} at depth {}", u, d);
state.enqueue(u, d);
continue;
}
Err(_) => {
if !has_processed_any && total_processed == 0 {
debug!("No tasks started yet, waiting for initial processing...");
let wait_time = if follow_sitemaps && seeds_count > 10 {
1000 + (seeds_count * 10).min(2000)
} else if follow_sitemaps {
500
} else {
100
};
debug!(
"Waiting {}ms for {} seeds to start processing",
wait_time, seeds_count
);
tokio::time::sleep(tokio::time::Duration::from_millis(wait_time as u64))
.await;
if pending.load(Ordering::SeqCst) > 0 {
continue;
}
}
debug!("No pending tasks and no messages in channel, exiting");
break;
}
}
}
tokio::select! {
_ = notify.notified() => { },
maybe = rx.recv() => {
if let Some((u,d)) = maybe {
debug!("Received link from channel: {} at depth {}", u, d);
state.enqueue(u,d);
}
}
}
}
let was_stopped = state.stop_flag.load(Ordering::Acquire);
if !was_stopped {
if let Some(c) = &cache {
let _ = c.clear_frontier();
}
}
let total_pages = state.saved_pages.load(Ordering::Relaxed);
let total_assets = state.saved_assets.load(Ordering::Relaxed);
let total_skipped = state.skipped.load(Ordering::Relaxed);
let elapsed = start_time.elapsed();
let host = state.base.host_str().unwrap_or("site");
if let Some(pb) = pb {
let done_msg = if total_assets > 0 {
format!(
"{} | Done: {} pages, {} assets saved, {} skipped in {:.1}s",
host, total_pages, total_assets, total_skipped, elapsed.as_secs_f32()
)
} else {
format!(
"{} | Done: {} pages saved, {} skipped in {:.1}s",
host, total_pages, total_skipped, elapsed.as_secs_f32()
)
};
pb.finish_with_message(done_msg);
}
let _ = sink.finalize().await;
Ok(Stats {
pages: total_pages,
assets: total_assets,
skipped: total_skipped,
})
}
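/// Reduces `u` to its origin by stripping path, query, and fragment.
/// Infallible in practice; the `Result` keeps call sites uniform.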
fn origin_of(u: &Url) -> Result<Url, Box<dyn std::error::Error>> {
let mut o = u.clone();
o.set_path("");
o.set_query(None);
o.set_fragment(None);
Ok(o)
}
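/// Builds the shared HTTP client: custom user agent, bounded redirects, a
/// cookie store, connection pooling, and transparent retries with
/// exponential backoff for transient failures.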
fn build_client(
user_agent: &str,
) -> Result<reqwest_middleware::ClientWithMiddleware, Box<dyn std::error::Error>> {
let mut builder = reqwest::Client::builder()
.user_agent(user_agent.to_string())
.redirect(reqwest::redirect::Policy::limited(10))
.cookie_store(true)
.pool_max_idle_per_host(32)
.pool_idle_timeout(std::time::Duration::from_secs(30))
.tcp_keepalive(std::time::Duration::from_secs(30))
.http2_adaptive_window(true);
if std::env::var_os("HTTP_PROXY").is_none() && std::env::var_os("HTTPS_PROXY").is_none() {
builder = builder.no_proxy();
}
let base = builder.build()?;
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
let client = ClientBuilder::new(base)
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();
Ok(client)
}
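/// Downloads an asset and rejects it unless its `Content-Type` is on the
/// image allow-list, or is SVG and `allow_svg` is set.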
async fn fetch_asset_checked(
client: &reqwest_middleware::ClientWithMiddleware,
url: Url,
allow_svg: bool,
) -> Result<(Option<String>, bytes::Bytes), Box<dyn std::error::Error + Send + Sync>> {
let resp = client.get(url).send().await?;
if !resp.status().is_success() {
return Err(format!("non-success: {}", resp.status()).into());
}
let ct_str = resp
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let is_svg = ct_str
.as_deref()
.map(|s| s.eq_ignore_ascii_case("image/svg+xml"))
.unwrap_or(false);
if !is_safe_image_content_type(ct_str.as_deref()) && !(is_svg && allow_svg) {
return Err(if is_svg {
"svg not allowed"
} else {
"unsupported image content-type"
}
.into());
}
let bytes = resp.bytes().await?;
Ok((ct_str, bytes))
}
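/// Scope check for candidate URLs. With `host_only`, the target must match
/// the base's scheme, port, and host (ignoring a leading `www.`); otherwise
/// any URL on the same host (as judged by `is_same_host`) qualifies.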
fn within_scope(base_origin: &Url, target: &Url, cfg: &Config) -> bool {
if cfg.host_only {
fn bare_host(u: &Url) -> Option<String> {
u.host_str()
.map(|h| h.strip_prefix("www.").unwrap_or(h).to_string())
}
bare_host(base_origin) == bare_host(target)
&& base_origin.scheme() == target.scheme()
&& base_origin.port_or_known_default() == target.port_or_known_default()
} else {
is_same_host(base_origin, target)
}
}
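/// Derives the path prefix that confines the crawl: `/docs/intro.html`
/// yields `/docs`, `/docs/` yields `/docs`, and `/` yields `/`. A final
/// segment containing a dot is treated as a file and dropped.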
fn derive_path_prefix(url: &Url) -> String {
let path = url.path();
if path == "/" || path.is_empty() {
return "/".to_string();
}
let trimmed = path.trim_end_matches('/');
if let Some(last) = trimmed.rsplit('/').next() {
if last.contains('.') {
let parent = trimmed[..trimmed.len() - last.len()].trim_end_matches('/');
if parent.is_empty() {
return "/".to_string();
}
return parent.to_string();
}
}
trimmed.to_string()
}
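/// Whether `target_path` equals `prefix` or lies beneath it. Matches whole
/// path segments, so with prefix `/guide`, `/guide/intro` passes while
/// `/guidebook` does not.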
fn path_in_prefix(target_path: &str, prefix: &str) -> bool {
if prefix == "/" {
return true;
}
let tp = target_path.trim_end_matches('/');
let pp = prefix.trim_end_matches('/');
tp == pp || target_path.starts_with(&format!("{}/", pp))
}
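/// Rewrites Markdown image destinations from the original asset URLs to the
/// relative paths they were saved under. This is plain substring replacement
/// on `](url)`, so only exact destination matches are rewritten.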
fn rewrite_md_images(mut md: String, pairs: &[(String, String)]) -> String {
for (src, rel) in pairs {
let needle = format!("]({})", src);
let replacement = format!("]({})", rel);
md = md.replace(&needle, &replacement);
}
md
}