1use std::sync::atomic::{AtomicU64, Ordering};
2use std::sync::Arc;
3
4use anyhow::Result;
5use argus_common::CrawlJob;
6use argus_dedupe::SeenSet;
7use argus_fetcher::http::HttpFetcher;
8use argus_frontier::Frontier;
9use argus_parser::html;
10use argus_robots;
11use argus_storage::Storage;
12
13use crate::rate_limit::{InMemoryRateLimiter, RateLimiter};
14use crate::shutdown::ShutdownSignal;
15
/// Tunable settings for a single crawl run.
#[derive(Clone, Debug)]
pub struct CrawlConfig {
    /// URL to seed the frontier with. `None` runs in consumer-only mode,
    /// draining whatever the frontier already contains.
    pub seed_url: Option<String>,
    /// Maximum link depth to process; jobs with a greater depth are
    /// popped and silently dropped by the workers.
    pub max_depth: u16,
    /// Number of concurrent worker tasks to spawn (clamped to at least 1).
    pub global_concurrency: usize,
    /// Per-host concurrency cap. NOTE(review): not read anywhere in this
    /// file — presumably consumed by a RateLimiter implementation; verify.
    pub per_host_concurrency: usize,
    /// Minimum delay between requests to the same host, in milliseconds,
    /// forwarded to the rate limiter before each fetch.
    pub per_host_delay_ms: u64,
}
25
/// Core crawl loop, generic over the frontier (work queue) backend `F`
/// and the seen-set (URL dedupe) backend `S`.
///
/// If `config.seed_url` is set it is normalized, deduped through `seen`,
/// and pushed as a depth-0 job; otherwise workers run consumer-only.
/// `config.global_concurrency` worker tasks (minimum 1) are spawned and,
/// per popped job, each worker:
///
/// 1. drops the job if `depth > max_depth` or robots.txt disallows it,
/// 2. waits on the per-host rate limiter,
/// 3. fetches the page and records the result via `storage`,
/// 4. on a 200 `text/html` response, extracts links, dedupes them through
///    `seen`, and pushes first sightings back at `depth + 1`.
///
/// Workers stop when the shutdown signal fires, or when the frontier is
/// empty while no job is in flight (`active == 0`). Returns `Ok(())` once
/// all workers have joined; errors are returned only for an invalid seed
/// URL or fetcher construction failure — per-job fetch/storage errors are
/// logged and skipped.
pub async fn run<F, S>(
    config: CrawlConfig,
    frontier: F,
    seen: S,
    storage: Arc<dyn Storage>,
    rate_limiter: Arc<dyn RateLimiter>,
    shutdown: Option<ShutdownSignal>,
) -> Result<()>
where
    F: Frontier + Clone + Send + Sync + 'static,
    S: SeenSet + Clone + Send + Sync + 'static,
{
    argus_storage::init_storage();

    if let Some(ref seed_url) = config.seed_url {
        // Normalize so the dedupe key matches the form used for links
        // discovered later in the crawl.
        let (normalized_seed, host) = match argus_common::url::normalize_url(seed_url) {
            Some(pair) => pair,
            None => anyhow::bail!("invalid seed URL: {}", seed_url),
        };
        let seed_job = CrawlJob {
            url: seed_url.clone(),
            normalized_url: normalized_seed.clone(),
            host: host.clone(),
            depth: 0,
        };
        // With a persistent seen-set (e.g. Redis), a restarted crawl
        // skips re-pushing a seed that was already recorded.
        if !seen.insert_if_new(normalized_seed).await {
            tracing::info!("seed URL already seen, skipping push");
        } else {
            frontier.push(seed_job).await;
        }
        tracing::info!(
            "crawl started seed={} concurrency={} max_depth={}",
            seed_url,
            config.global_concurrency,
            config.max_depth
        );
    } else {
        tracing::info!(
            "crawl started (consumer only) concurrency={} max_depth={}",
            config.global_concurrency,
            config.max_depth
        );
    }

    let fetcher = HttpFetcher::new()?;

    // `fetched`: total pages fetched, for progress logging and the final
    // summary. `active`: jobs currently being processed, used by idle
    // workers to decide when the crawl has drained.
    let fetched = Arc::new(AtomicU64::new(0));
    let active = Arc::new(AtomicU64::new(0));
    let concurrency = config.global_concurrency.max(1);
    let mut handles = Vec::with_capacity(concurrency);

    let shutdown_signal = shutdown.unwrap_or_default();

    for _ in 0..concurrency {
        // Each worker task gets its own clones/handles of the shared state.
        let frontier = frontier.clone();
        let seen = seen.clone();
        let fetcher = fetcher.clone();
        let storage = Arc::clone(&storage);
        let rate_limiter = Arc::clone(&rate_limiter);
        let config = config.clone();
        let fetched = Arc::clone(&fetched);
        let active = Arc::clone(&active);
        let shutdown_clone = shutdown_signal.clone();

        handles.push(tokio::spawn(async move {
            loop {
                // Shutdown is only checked between jobs, so an in-flight
                // fetch completes before the worker exits.
                if shutdown_clone.is_shutdown() {
                    tracing::info!("worker shutting down gracefully");
                    break;
                }

                let job = match frontier.pop().await {
                    Some(j) => j,
                    None => {
                        // Empty frontier + nothing in flight means the
                        // crawl has drained; otherwise another worker may
                        // still push new links, so poll again shortly.
                        // NOTE(review): there is a window between a
                        // sibling's successful pop and its fetch_add
                        // below during which `active` reads 0, so an
                        // idle worker could exit early here — confirm
                        // this is an acceptable trade-off.
                        if active.load(Ordering::SeqCst) == 0 {
                            break;
                        }
                        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                        continue;
                    }
                };

                // Mark the job in flight; every early-`continue` path
                // below must balance this with a matching fetch_sub.
                active.fetch_add(1, Ordering::SeqCst);

                // Depth limit and robots.txt act as pre-fetch filters.
                if job.depth > config.max_depth {
                    active.fetch_sub(1, Ordering::SeqCst);
                    continue;
                }
                if !argus_robots::is_allowed(&job.url) {
                    active.fetch_sub(1, Ordering::SeqCst);
                    continue;
                }

                // Politeness delay, keyed by host.
                rate_limiter
                    .wait_for_host(&job.host, config.per_host_delay_ms)
                    .await;

                let fetch_result = match fetcher.fetch(&job).await {
                    Ok(r) => r,
                    Err(e) => {
                        // Fetch failures are logged and dropped; the job
                        // is not retried.
                        tracing::warn!("fetch failed url={} error={}", job.url, e);
                        active.fetch_sub(1, Ordering::SeqCst);
                        continue;
                    }
                };

                // Progress log on the first page and every 10th after.
                let n = fetched.fetch_add(1, Ordering::SeqCst) + 1;
                if n == 1 || n.is_multiple_of(10) {
                    tracing::info!("fetched {} pages (current: {})", n, job.url);
                }

                // Every fetch is recorded (including non-200s); storage
                // errors are non-fatal.
                if let Err(e) = storage.record_fetch(&job, &fetch_result).await {
                    tracing::warn!("storage record failed url={} error={}", job.url, e);
                }

                // Only successful HTML responses are parsed for links.
                if fetch_result.status != 200 {
                    active.fetch_sub(1, Ordering::SeqCst);
                    continue;
                }

                let is_html = fetch_result
                    .content_type
                    .as_deref()
                    .is_some_and(|ct| ct.starts_with("text/html"));
                if !is_html {
                    active.fetch_sub(1, Ordering::SeqCst);
                    continue;
                }

                // Links are resolved against the final (post-redirect) URL.
                let links = html::extract_links(&fetch_result.final_url, &fetch_result.body);

                for link in links {
                    // Unparseable links are silently skipped.
                    let Some((norm_url, link_host)) =
                        argus_common::url::normalize_url(&link.to_url)
                    else {
                        continue;
                    };
                    // Dedupe on the normalized form; only first sightings
                    // are queued.
                    if !seen.insert_if_new(norm_url.clone()).await {
                        continue;
                    }
                    // NOTE(review): links are enqueued even when
                    // job.depth == max_depth; they are later popped and
                    // dropped by the depth filter above — harmless but
                    // wasted queue traffic.
                    let new_job = CrawlJob {
                        url: link.to_url,
                        normalized_url: norm_url,
                        host: link_host,
                        depth: job.depth + 1,
                    };
                    frontier.push(new_job).await;
                }

                active.fetch_sub(1, Ordering::SeqCst);
            }
        }));
    }

    // Wait for all workers; a worker panic (JoinError) is ignored here.
    for h in handles {
        let _ = h.await;
    }

    let total = fetched.load(Ordering::SeqCst);
    tracing::info!("crawl finished, fetched {} pages", total);
    Ok(())
}
189
190pub async fn run_in_memory(
192 config: CrawlConfig,
193 storage: Arc<dyn Storage>,
194 shutdown: Option<ShutdownSignal>,
195) -> Result<()> {
196 let frontier = argus_frontier::InMemoryFrontier::default();
197 let seen = argus_dedupe::SeenUrlSet::default();
198 let rate_limiter = Arc::new(InMemoryRateLimiter::default());
199 run(config, frontier, seen, storage, rate_limiter, shutdown).await
200}
201
/// Runs a crawl backed by Redis so multiple crawler processes can share
/// one frontier and one seen-set.
///
/// When `use_redis_rate_limit` is true the per-host rate limiter is also
/// coordinated through Redis; otherwise it stays process-local. Returns
/// an error if any Redis connection fails; see [`run`] for the crawl
/// semantics.
#[cfg(feature = "redis")]
pub async fn run_redis(
    config: CrawlConfig,
    redis_url: &str,
    storage: Arc<dyn Storage>,
    use_redis_rate_limit: bool,
    shutdown: Option<ShutdownSignal>,
) -> Result<()> {
    use argus_dedupe::RedisSeenSet;
    use argus_frontier::RedisFrontier;

    use crate::rate_limit::RedisRateLimiter;

    // Shared work queue and dedupe set live in Redis.
    let frontier = RedisFrontier::connect(redis_url, None).await?;
    let seen = RedisSeenSet::connect(redis_url, None).await?;

    // Rate limiting is either cluster-wide (Redis) or local, per caller.
    let rate_limiter: Arc<dyn RateLimiter> = match use_redis_rate_limit {
        true => Arc::new(RedisRateLimiter::connect(redis_url).await?),
        false => Arc::new(InMemoryRateLimiter::default()),
    };

    run(config, frontier, seen, storage, rate_limiter, shutdown).await
}
225
/// Seeds a Redis-backed crawl with the given URLs at depth 0.
///
/// Each URL is normalized and inserted into the shared seen-set; only
/// URLs not previously recorded are pushed onto the frontier. URLs that
/// fail normalization are logged and skipped rather than aborting the
/// batch. Returns an error only if the Redis connections fail.
#[cfg(feature = "redis")]
pub async fn seed_redis(redis_url: &str, urls: &[String]) -> Result<()> {
    use argus_dedupe::RedisSeenSet;
    use argus_frontier::RedisFrontier;

    let frontier = RedisFrontier::connect(redis_url, None).await?;
    let seen = RedisSeenSet::connect(redis_url, None).await?;

    for url in urls {
        // Bad URLs are skipped; the rest of the batch still seeds.
        let (normalized_url, host) = match argus_common::url::normalize_url(url) {
            Some(pair) => pair,
            None => {
                tracing::warn!("invalid URL, skipping: {}", url);
                continue;
            }
        };
        let job = CrawlJob {
            url: url.clone(),
            normalized_url: normalized_url.clone(),
            host,
            depth: 0,
        };
        // Dedupe gate: push only first sightings of the normalized URL.
        if seen.insert_if_new(normalized_url).await {
            frontier.push(job).await;
            tracing::info!("seeded: {}", url);
        }
    }
    Ok(())
}