//! argus_worker/worker.rs — crawl worker loop and backend wiring.
1use std::sync::atomic::{AtomicU64, Ordering};
2use std::sync::Arc;
3
4use anyhow::Result;
5use argus_common::CrawlJob;
6use argus_dedupe::SeenSet;
7use argus_fetcher::http::HttpFetcher;
8use argus_frontier::Frontier;
9use argus_parser::html;
10use argus_robots;
11use argus_storage::Storage;
12
13use crate::rate_limit::{InMemoryRateLimiter, RateLimiter};
14use crate::shutdown::ShutdownSignal;
15
/// Configuration for a single crawl run.
#[derive(Clone, Debug)]
pub struct CrawlConfig {
    /// If Some, push this URL as the seed job before running. If None (e.g. Redis consumer-only), just drain the queue.
    pub seed_url: Option<String>,
    /// Maximum link depth to fetch; jobs deeper than this are dropped.
    pub max_depth: u16,
    /// Number of concurrent worker tasks (clamped to at least 1 by `run`).
    pub global_concurrency: usize,
    /// Per-host concurrency cap.
    /// NOTE(review): not read anywhere in this file — confirm it is enforced
    /// elsewhere (e.g. inside the rate limiter or fetcher).
    pub per_host_concurrency: usize,
    /// Minimum delay between requests to the same host, in milliseconds
    /// (passed to `RateLimiter::wait_for_host`).
    pub per_host_delay_ms: u64,
}

26/// Runs the crawl with the given frontier, seen set, storage, and rate limiter.
27pub async fn run<F, S>(
28    config: CrawlConfig,
29    frontier: F,
30    seen: S,
31    storage: Arc<dyn Storage>,
32    rate_limiter: Arc<dyn RateLimiter>,
33    shutdown: Option<ShutdownSignal>,
34) -> Result<()>
35where
36    F: Frontier + Clone + Send + Sync + 'static,
37    S: SeenSet + Clone + Send + Sync + 'static,
38{
39    argus_storage::init_storage();
40
41    if let Some(ref seed_url) = config.seed_url {
42        let (normalized_seed, host) = match argus_common::url::normalize_url(seed_url) {
43            Some(pair) => pair,
44            None => anyhow::bail!("invalid seed URL: {}", seed_url),
45        };
46        let seed_job = CrawlJob {
47            url: seed_url.clone(),
48            normalized_url: normalized_seed.clone(),
49            host: host.clone(),
50            depth: 0,
51        };
52        if !seen.insert_if_new(normalized_seed).await {
53            tracing::info!("seed URL already seen, skipping push");
54        } else {
55            frontier.push(seed_job).await;
56        }
57        tracing::info!(
58            "crawl started seed={} concurrency={} max_depth={}",
59            seed_url,
60            config.global_concurrency,
61            config.max_depth
62        );
63    } else {
64        tracing::info!(
65            "crawl started (consumer only) concurrency={} max_depth={}",
66            config.global_concurrency,
67            config.max_depth
68        );
69    }
70
71    let fetcher = HttpFetcher::new()?;
72
73    let fetched = Arc::new(AtomicU64::new(0));
74    let active = Arc::new(AtomicU64::new(0));
75    let concurrency = config.global_concurrency.max(1);
76    let mut handles = Vec::with_capacity(concurrency);
77
78    let shutdown_signal = shutdown.unwrap_or_default();
79
80    for _ in 0..concurrency {
81        let frontier = frontier.clone();
82        let seen = seen.clone();
83        let fetcher = fetcher.clone();
84        let storage = Arc::clone(&storage);
85        let rate_limiter = Arc::clone(&rate_limiter);
86        let config = config.clone();
87        let fetched = Arc::clone(&fetched);
88        let active = Arc::clone(&active);
89        let shutdown_clone = shutdown_signal.clone();
90
91        handles.push(tokio::spawn(async move {
92            loop {
93                if shutdown_clone.is_shutdown() {
94                    tracing::info!("worker shutting down gracefully");
95                    break;
96                }
97
98                let job = match frontier.pop().await {
99                    Some(j) => j,
100                    None => {
101                        if active.load(Ordering::SeqCst) == 0 {
102                            break;
103                        }
104                        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
105                        continue;
106                    }
107                };
108
109                active.fetch_add(1, Ordering::SeqCst);
110
111                if job.depth > config.max_depth {
112                    active.fetch_sub(1, Ordering::SeqCst);
113                    continue;
114                }
115                if !argus_robots::is_allowed(&job.url) {
116                    active.fetch_sub(1, Ordering::SeqCst);
117                    continue;
118                }
119
120                rate_limiter
121                    .wait_for_host(&job.host, config.per_host_delay_ms)
122                    .await;
123
124                let fetch_result = match fetcher.fetch(&job).await {
125                    Ok(r) => r,
126                    Err(e) => {
127                        tracing::warn!("fetch failed url={} error={}", job.url, e);
128                        active.fetch_sub(1, Ordering::SeqCst);
129                        continue;
130                    }
131                };
132
133                let n = fetched.fetch_add(1, Ordering::SeqCst) + 1;
134                if n == 1 || n.is_multiple_of(10) {
135                    tracing::info!("fetched {} pages (current: {})", n, job.url);
136                }
137
138                if let Err(e) = storage.record_fetch(&job, &fetch_result).await {
139                    tracing::warn!("storage record failed url={} error={}", job.url, e);
140                }
141
142                if fetch_result.status != 200 {
143                    active.fetch_sub(1, Ordering::SeqCst);
144                    continue;
145                }
146
147                let is_html = fetch_result
148                    .content_type
149                    .as_deref()
150                    .is_some_and(|ct| ct.starts_with("text/html"));
151                if !is_html {
152                    active.fetch_sub(1, Ordering::SeqCst);
153                    continue;
154                }
155
156                let links = html::extract_links(&fetch_result.final_url, &fetch_result.body);
157
158                for link in links {
159                    let Some((norm_url, link_host)) =
160                        argus_common::url::normalize_url(&link.to_url)
161                    else {
162                        continue;
163                    };
164                    if !seen.insert_if_new(norm_url.clone()).await {
165                        continue;
166                    }
167                    let new_job = CrawlJob {
168                        url: link.to_url,
169                        normalized_url: norm_url,
170                        host: link_host,
171                        depth: job.depth + 1,
172                    };
173                    frontier.push(new_job).await;
174                }
175
176                active.fetch_sub(1, Ordering::SeqCst);
177            }
178        }));
179    }
180
181    for h in handles {
182        let _ = h.await;
183    }
184
185    let total = fetched.load(Ordering::SeqCst);
186    tracing::info!("crawl finished, fetched {} pages", total);
187    Ok(())
188}
189
190/// In-memory backend for single-node runs.
191pub async fn run_in_memory(
192    config: CrawlConfig,
193    storage: Arc<dyn Storage>,
194    shutdown: Option<ShutdownSignal>,
195) -> Result<()> {
196    let frontier = argus_frontier::InMemoryFrontier::default();
197    let seen = argus_dedupe::SeenUrlSet::default();
198    let rate_limiter = Arc::new(InMemoryRateLimiter::default());
199    run(config, frontier, seen, storage, rate_limiter, shutdown).await
200}
201
/// Redis-backed frontier and seen set; optional Redis-backed rate limiter for global per-host delay.
#[cfg(feature = "redis")]
pub async fn run_redis(
    config: CrawlConfig,
    redis_url: &str,
    storage: Arc<dyn Storage>,
    use_redis_rate_limit: bool,
    shutdown: Option<ShutdownSignal>,
) -> Result<()> {
    use argus_dedupe::RedisSeenSet;
    use argus_frontier::RedisFrontier;

    use crate::rate_limit::RedisRateLimiter;

    // Queue and dedupe state both live in Redis so multiple worker
    // processes can cooperate on one crawl.
    let frontier = RedisFrontier::connect(redis_url, None).await?;
    let seen = RedisSeenSet::connect(redis_url, None).await?;

    // Redis-backed limiting enforces the per-host delay cluster-wide;
    // the in-memory limiter only throttles this process.
    let rate_limiter: Arc<dyn RateLimiter> = match use_redis_rate_limit {
        true => Arc::new(RedisRateLimiter::connect(redis_url).await?),
        false => Arc::new(InMemoryRateLimiter::default()),
    };

    run(config, frontier, seen, storage, rate_limiter, shutdown).await
}

/// Push URLs onto the Redis frontier (and mark them in the seen set). Exits after pushing; no crawl.
#[cfg(feature = "redis")]
pub async fn seed_redis(redis_url: &str, urls: &[String]) -> Result<()> {
    use argus_dedupe::RedisSeenSet;
    use argus_frontier::RedisFrontier;

    let frontier = RedisFrontier::connect(redis_url, None).await?;
    let seen = RedisSeenSet::connect(redis_url, None).await?;

    for url in urls {
        // Skip anything that doesn't normalize to a crawlable URL.
        let (normalized_url, host) = match argus_common::url::normalize_url(url) {
            Some(pair) => pair,
            None => {
                tracing::warn!("invalid URL, skipping: {}", url);
                continue;
            }
        };
        // Push only URLs the crawl hasn't recorded yet, at depth 0.
        if seen.insert_if_new(normalized_url.clone()).await {
            frontier
                .push(CrawlJob {
                    url: url.clone(),
                    normalized_url,
                    host,
                    depth: 0,
                })
                .await;
            tracing::info!("seeded: {}", url);
        }
    }
    Ok(())
}