//! Per-origin robots.txt fetching and caching for the `argus_robots` crate.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use anyhow::{Context, Result};
use tokio::sync::RwLock;
use url::Url;

use crate::parser::RobotsTxt;

11#[derive(Clone)]
12struct CachedRobots {
13    robots: RobotsTxt,
14    fetched_at: Instant,
15}
16
17#[derive(Clone)]
18pub struct RobotsCache {
19    cache: Arc<RwLock<HashMap<String, CachedRobots>>>,
20    client: reqwest::Client,
21    user_agent: String,
22    ttl: Duration,
23}
24
25impl RobotsCache {
26    pub fn new(user_agent: String, ttl: Duration) -> Result<Self> {
27        let client = reqwest::Client::builder()
28            .timeout(Duration::from_secs(10))
29            .redirect(reqwest::redirect::Policy::limited(5))
30            .build()
31            .context("failed to build HTTP client")?;
32
33        Ok(Self {
34            cache: Arc::new(RwLock::new(HashMap::new())),
35            client,
36            user_agent,
37            ttl,
38        })
39    }
40
41    pub async fn is_allowed(&self, url: &str) -> Result<bool> {
42        let parsed = Url::parse(url).context("invalid URL")?;
43        let origin = format!(
44            "{}://{}{}",
45            parsed.scheme(),
46            parsed.host_str().unwrap_or(""),
47            if let Some(port) = parsed.port() {
48                format!(":{}", port)
49            } else {
50                String::new()
51            }
52        );
53
54        let robots = self.get_robots(&origin).await?;
55        let path = parsed.path();
56        Ok(robots.is_allowed(path))
57    }
58
59    pub async fn get_crawl_delay(&self, url: &str) -> Result<Option<Duration>> {
60        let parsed = Url::parse(url).context("invalid URL")?;
61        let origin = format!(
62            "{}://{}{}",
63            parsed.scheme(),
64            parsed.host_str().unwrap_or(""),
65            if let Some(port) = parsed.port() {
66                format!(":{}", port)
67            } else {
68                String::new()
69            }
70        );
71
72        let robots = self.get_robots(&origin).await?;
73        Ok(robots.crawl_delay())
74    }
75
76    async fn get_robots(&self, origin: &str) -> Result<RobotsTxt> {
77        {
78            let cache = self.cache.read().await;
79            if let Some(cached) = cache.get(origin) {
80                if cached.fetched_at.elapsed() < self.ttl {
81                    return Ok(cached.robots.clone());
82                }
83            }
84        }
85
86        let robots_url = format!("{}/robots.txt", origin);
87        tracing::debug!("fetching robots.txt from {}", robots_url);
88
89        let robots = match self.fetch_robots(&robots_url).await {
90            Ok(r) => r,
91            Err(e) => {
92                tracing::warn!("failed to fetch robots.txt from {}: {}", robots_url, e);
93                RobotsTxt::parse("", &self.user_agent)
94            }
95        };
96
97        let mut cache = self.cache.write().await;
98        cache.insert(
99            origin.to_string(),
100            CachedRobots {
101                robots: robots.clone(),
102                fetched_at: Instant::now(),
103            },
104        );
105
106        Ok(robots)
107    }
108
109    async fn fetch_robots(&self, url: &str) -> Result<RobotsTxt> {
110        let response = self
111            .client
112            .get(url)
113            .header("User-Agent", &self.user_agent)
114            .send()
115            .await
116            .context("failed to send request")?;
117
118        if !response.status().is_success() {
119            anyhow::bail!("non-success status: {}", response.status());
120        }
121
122        let content = response.text().await.context("failed to read response")?;
123        Ok(RobotsTxt::parse(&content, &self.user_agent))
124    }
125
126    pub async fn clear_cache(&self) {
127        let mut cache = self.cache.write().await;
128        cache.clear();
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Disallowed paths are rejected, and the second lookup for the same
    /// origin is served from cache (`expect(1)` fails on a second fetch).
    #[tokio::test]
    async fn cache_robots_txt() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/robots.txt"))
            .respond_with(
                ResponseTemplate::new(200).set_body_string("User-agent: *\nDisallow: /admin/\n"),
            )
            .expect(1)
            .mount(&mock_server)
            .await;

        let cache = RobotsCache::new("TestBot".to_string(), Duration::from_secs(3600)).unwrap();

        let url1 = format!("{}/page", mock_server.uri());
        let url2 = format!("{}/admin/secret", mock_server.uri());

        assert!(cache.is_allowed(&url1).await.unwrap());
        assert!(!cache.is_allowed(&url2).await.unwrap());

        // Must hit the cache, not the server (expect(1) above).
        assert!(cache.is_allowed(&url1).await.unwrap());
    }

    /// A 404 for robots.txt means everything is allowed.
    #[tokio::test]
    async fn handle_missing_robots_txt() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/robots.txt"))
            .respond_with(ResponseTemplate::new(404))
            .mount(&mock_server)
            .await;

        let cache = RobotsCache::new("TestBot".to_string(), Duration::from_secs(3600)).unwrap();

        let url = format!("{}/any-page", mock_server.uri());
        assert!(cache.is_allowed(&url).await.unwrap());
    }

    /// A fractional `Crawl-delay` directive is surfaced as a `Duration`.
    #[tokio::test]
    async fn respect_crawl_delay() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/robots.txt"))
            .respond_with(
                ResponseTemplate::new(200).set_body_string("User-agent: *\nCrawl-delay: 1.5\n"),
            )
            .mount(&mock_server)
            .await;

        let cache = RobotsCache::new("TestBot".to_string(), Duration::from_secs(3600)).unwrap();

        let url = format!("{}/page", mock_server.uri());
        let delay = cache.get_crawl_delay(&url).await.unwrap();
        assert_eq!(delay, Some(Duration::from_secs_f64(1.5)));
    }

    /// After the TTL elapses, the next lookup refetches (`expect(2)`).
    #[tokio::test]
    async fn cache_expiration() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/robots.txt"))
            .respond_with(
                ResponseTemplate::new(200).set_body_string("User-agent: *\nDisallow: /\n"),
            )
            .expect(2)
            .mount(&mock_server)
            .await;

        let cache = RobotsCache::new("TestBot".to_string(), Duration::from_millis(100)).unwrap();

        let url = format!("{}/page", mock_server.uri());

        assert!(!cache.is_allowed(&url).await.unwrap());

        // Let the 100 ms TTL expire, then check that a second fetch happens.
        tokio::time::sleep(Duration::from_millis(150)).await;

        assert!(!cache.is_allowed(&url).await.unwrap());
    }
}