use crate::browser::BrowserPool;
use crate::extract::extract_urls;
use anyhow::{Context, Result};
use clap::Args;
use futures::future::join_all;
use serde::Serialize;
use std::io::{self, BufRead};
use std::sync::Arc;
use tokio::fs;
// CLI arguments for the `check-links` subcommand. Exactly one input source is
// used, resolved in this precedence order by `get_urls`: --url, --stdin, then
// the positional FILE. (Plain `//` comments on purpose: `///` doc comments on
// clap derive fields become --help text and would change CLI output.)
#[derive(Args)]
pub struct CheckLinksArgs {
// Positional path to a file whose text is scanned for URLs.
#[arg(value_name = "FILE")]
file: Option<String>,
// Single URL to check, bypassing file/stdin input.
#[arg(long)]
url: Option<String>,
// Read URLs (one per line, must start with "http") from stdin.
#[arg(long)]
stdin: bool,
// Parallel browser pages; clamped to 1..=20 by the value parser.
#[arg(short, long, default_value = "5", value_parser = clap::value_parser!(u8).range(1..=20))]
concurrency: u8,
// Per-navigation timeout in milliseconds.
#[arg(long, default_value = "15000")]
timeout: u64,
// Retry budget for navigations that come back with status 0.
#[arg(long, default_value = "1")]
retries: u8,
}
/// Runtime tuning for [`check_links`], decoupled from the clap argument layer
/// so the checker can be driven programmatically as well as from the CLI.
pub struct CheckLinksConfig {
/// Maximum number of browser pages used in parallel (sizes the pool).
pub concurrency: usize,
/// Per-navigation timeout in milliseconds.
pub timeout_ms: u64,
/// Retry budget for navigations that come back with status 0.
pub retries: u8,
}
/// Outcome of checking one URL; serialized as part of the JSON report.
#[derive(Debug, Serialize)]
pub struct LinkResult {
/// The URL exactly as it was requested.
pub url: String,
/// HTTP status code; 0 means the navigation itself failed.
pub status: u16,
/// Navigation/pool error message; omitted from JSON when absent.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
/// Final URL when the request was redirected to a different host
/// (a leading "www." is ignored); omitted from JSON when absent.
#[serde(skip_serializing_if = "Option::is_none")]
pub redirect_to: Option<String>,
}
/// Aggregate report printed as JSON by [`run_check_links`].
#[derive(Debug, Serialize)]
pub struct LinkReport {
/// Links with a 2xx/3xx status, no error, and no cross-host redirect.
pub ok: usize,
/// All remaining links (`results.len() - ok`).
pub failed: usize,
/// Per-URL details for every checked link.
pub results: Vec<LinkResult>,
}
/// Entry point for the `check-links` subcommand.
///
/// Gathers URLs from the chosen input source, checks them concurrently,
/// prints the JSON report to stdout and a human-readable summary to stderr.
/// Exits with status 1 when no URLs are found.
pub async fn run_check_links(args: CheckLinksArgs) -> Result<()> {
    let urls = get_urls(&args).await?;
    if urls.is_empty() {
        eprintln!("No URLs found.");
        std::process::exit(1);
    }
    eprintln!(
        "Checking {} URLs ({} parallel)...",
        urls.len(),
        args.concurrency
    );
    // Translate CLI arguments into the checker's runtime config inline.
    let report = check_links(
        &urls,
        &CheckLinksConfig {
            concurrency: usize::from(args.concurrency),
            timeout_ms: args.timeout,
            retries: args.retries,
        },
    )
    .await?;
    // Machine-readable report on stdout, human summary on stderr.
    println!("{}", serde_json::to_string(&report)?);
    let total = report.ok + report.failed;
    eprintln!("Done: {}/{} OK", report.ok, total);
    Ok(())
}
/// Resolve the list of URLs to check, in precedence order:
/// `--url` (single URL), `--stdin` (one URL per line, must start with
/// "http"), then the positional file (scanned with `extract_urls`).
/// When no source was given, prints usage to stderr and exits with 1.
async fn get_urls(args: &CheckLinksArgs) -> Result<Vec<String>> {
    if let Some(single) = args.url.as_ref() {
        return Ok(vec![single.clone()]);
    }
    if args.stdin {
        // NOTE(review): this is a blocking read inside an async fn; fine for
        // a CLI that does nothing else concurrently at this point.
        let mut urls = Vec::new();
        for line in io::stdin().lock().lines().map_while(Result::ok) {
            if line.starts_with("http") {
                urls.push(line);
            }
        }
        return Ok(urls);
    }
    if let Some(path) = args.file.as_ref() {
        let content = fs::read_to_string(path)
            .await
            .with_context(|| format!("Failed to read file: {path}"))?;
        return Ok(extract_urls(&content));
    }
    eprintln!("Usage:");
    eprintln!(" ref check-links <file.md> Check URLs in markdown file");
    eprintln!(" ref check-links --url <URL> Check single URL");
    eprintln!(" ref check-links --stdin Read URLs from stdin");
    std::process::exit(1);
}
pub async fn check_links(urls: &[String], config: &CheckLinksConfig) -> Result<LinkReport> {
let pool = Arc::new(BrowserPool::new(config.concurrency).await?);
let timeout = config.timeout_ms;
let retries = config.retries;
let tasks: Vec<_> = urls
.iter()
.map(|url| {
let pool = Arc::clone(&pool);
let url = url.clone();
tokio::spawn(async move { check_one(&pool, &url, timeout, retries).await })
})
.collect();
let results: Vec<LinkResult> = join_all(tasks)
.await
.into_iter()
.filter_map(std::result::Result::ok)
.collect();
if let Ok(pool) = Arc::try_unwrap(pool) {
pool.close().await?;
}
let ok = results
.iter()
.filter(|r| {
r.error.is_none() && r.redirect_to.is_none() && r.status >= 200 && r.status < 400
})
.count();
let failed = results.len() - ok;
Ok(LinkReport {
ok,
failed,
results,
})
}
/// Check a single URL using one page from the browser pool.
///
/// Never returns `Err`: pool or navigation failures produce a `LinkResult`
/// with `status: 0` and the error message. A navigation that completes with
/// status 0 (connection-level failure) is retried up to `retries` times with
/// a one-second pause between attempts. Previously only a single retry ever
/// ran regardless of the configured value; the default (`retries == 1`)
/// behaves exactly as before.
async fn check_one(pool: &BrowserPool, url: &str, timeout: u64, retries: u8) -> LinkResult {
    eprintln!(" -> {}", truncate(url, 60));
    // Shared shape for the two early-exit failure paths below.
    let failure = |msg: String| LinkResult {
        url: url.to_string(),
        status: 0,
        error: Some(msg),
        redirect_to: None,
    };
    let page = match pool.new_page().await {
        Ok(p) => p,
        Err(e) => return failure(e.to_string()),
    };
    let mut result = match page.goto(url, timeout).await {
        Ok(r) => r,
        Err(e) => return failure(e.to_string()),
    };
    // Honor the full retry budget instead of a single hard-coded attempt.
    let mut remaining = retries;
    while result.status == 0 && remaining > 0 {
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        if let Ok(r) = page.goto(url, timeout).await {
            result = r;
        }
        remaining -= 1;
    }
    // Only flag a redirect when the final page is on a different host; a
    // leading "www." is ignored so http://example.com -> https://www.example.com
    // still counts as OK.
    let redirect_to = if result.status >= 200 && result.status < 400 {
        page.current_url().await.and_then(|final_url| {
            let host_of = |s: &str| {
                url::Url::parse(s)
                    .ok()
                    .and_then(|u| u.host_str().map(String::from))
            };
            match (host_of(url), host_of(&final_url)) {
                (Some(orig), Some(fin))
                    if orig.trim_start_matches("www.") != fin.trim_start_matches("www.") =>
                {
                    Some(final_url)
                }
                _ => None,
            }
        })
    } else {
        None
    };
    LinkResult {
        url: url.to_string(),
        status: result.status,
        error: result.error,
        redirect_to,
    }
}
/// Shorten `s` to at most `max` bytes for display, appending "..." when cut.
///
/// The cut lands on a UTF-8 character boundary, so multi-byte input can no
/// longer panic (the previous `&s[..max - 3]` byte slice could split a char,
/// and underflowed when `max < 3`). ASCII behavior is unchanged; `max` values
/// below 3 degrade to just "...".
fn truncate(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }
    // Largest char boundary not exceeding the byte budget for the prefix.
    let mut end = max.saturating_sub(3);
    while end > 0 && !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &s[..end])
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Config fields hold exactly what they were constructed with.
    #[test]
    fn config_fields_round_trip() {
        let cfg = CheckLinksConfig {
            concurrency: 5,
            timeout_ms: 15000,
            retries: 1,
        };
        assert_eq!(
            (cfg.concurrency, cfg.timeout_ms, cfg.retries),
            (5, 15000, 1)
        );
    }

    /// Short strings pass through; long ones are cut with an ellipsis.
    #[test]
    fn truncate_short_and_long() {
        let cases = [
            ("short", 10, "short"),
            ("this is a very long string", 10, "this is..."),
        ];
        for (input, max, expected) in cases {
            assert_eq!(truncate(input, max), expected);
        }
    }
}