1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
5use anyhow::{Context, Result};
6use tokio::sync::RwLock;
7use url::Url;
8
9use crate::parser::RobotsTxt;
10
/// A parsed robots.txt rule set together with the instant it was fetched,
/// used by `RobotsCache` to decide TTL-based expiry.
#[derive(Clone)]
struct CachedRobots {
    /// Rules parsed for the cache's configured user agent.
    robots: RobotsTxt,
    /// When the rules were fetched; compared against the cache TTL on lookup.
    fetched_at: Instant,
}
16
/// TTL-based, origin-keyed cache of robots.txt rule sets.
///
/// Cloning is cheap and clones share state: the map is behind an
/// `Arc<RwLock<...>>` and `reqwest::Client` is internally reference-counted.
#[derive(Clone)]
pub struct RobotsCache {
    /// Maps an origin string (`scheme://host[:port]`) to its cached rules.
    cache: Arc<RwLock<HashMap<String, CachedRobots>>>,
    /// Shared HTTP client used for all robots.txt fetches.
    client: reqwest::Client,
    /// User-Agent header value; also used to select the applicable rule group.
    user_agent: String,
    /// How long a cached entry stays fresh before it is refetched.
    ttl: Duration,
}
24
25impl RobotsCache {
26 pub fn new(user_agent: String, ttl: Duration) -> Result<Self> {
27 let client = reqwest::Client::builder()
28 .timeout(Duration::from_secs(10))
29 .redirect(reqwest::redirect::Policy::limited(5))
30 .build()
31 .context("failed to build HTTP client")?;
32
33 Ok(Self {
34 cache: Arc::new(RwLock::new(HashMap::new())),
35 client,
36 user_agent,
37 ttl,
38 })
39 }
40
41 pub async fn is_allowed(&self, url: &str) -> Result<bool> {
42 let parsed = Url::parse(url).context("invalid URL")?;
43 let origin = format!(
44 "{}://{}{}",
45 parsed.scheme(),
46 parsed.host_str().unwrap_or(""),
47 if let Some(port) = parsed.port() {
48 format!(":{}", port)
49 } else {
50 String::new()
51 }
52 );
53
54 let robots = self.get_robots(&origin).await?;
55 let path = parsed.path();
56 Ok(robots.is_allowed(path))
57 }
58
59 pub async fn get_crawl_delay(&self, url: &str) -> Result<Option<Duration>> {
60 let parsed = Url::parse(url).context("invalid URL")?;
61 let origin = format!(
62 "{}://{}{}",
63 parsed.scheme(),
64 parsed.host_str().unwrap_or(""),
65 if let Some(port) = parsed.port() {
66 format!(":{}", port)
67 } else {
68 String::new()
69 }
70 );
71
72 let robots = self.get_robots(&origin).await?;
73 Ok(robots.crawl_delay())
74 }
75
76 async fn get_robots(&self, origin: &str) -> Result<RobotsTxt> {
77 {
78 let cache = self.cache.read().await;
79 if let Some(cached) = cache.get(origin) {
80 if cached.fetched_at.elapsed() < self.ttl {
81 return Ok(cached.robots.clone());
82 }
83 }
84 }
85
86 let robots_url = format!("{}/robots.txt", origin);
87 tracing::debug!("fetching robots.txt from {}", robots_url);
88
89 let robots = match self.fetch_robots(&robots_url).await {
90 Ok(r) => r,
91 Err(e) => {
92 tracing::warn!("failed to fetch robots.txt from {}: {}", robots_url, e);
93 RobotsTxt::parse("", &self.user_agent)
94 }
95 };
96
97 let mut cache = self.cache.write().await;
98 cache.insert(
99 origin.to_string(),
100 CachedRobots {
101 robots: robots.clone(),
102 fetched_at: Instant::now(),
103 },
104 );
105
106 Ok(robots)
107 }
108
109 async fn fetch_robots(&self, url: &str) -> Result<RobotsTxt> {
110 let response = self
111 .client
112 .get(url)
113 .header("User-Agent", &self.user_agent)
114 .send()
115 .await
116 .context("failed to send request")?;
117
118 if !response.status().is_success() {
119 anyhow::bail!("non-success status: {}", response.status());
120 }
121
122 let content = response.text().await.context("failed to read response")?;
123 Ok(RobotsTxt::parse(&content, &self.user_agent))
124 }
125
126 pub async fn clear_cache(&self) {
127 let mut cache = self.cache.write().await;
128 cache.clear();
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    #[tokio::test]
    async fn cache_robots_txt() {
        let server = MockServer::start().await;
        let body = "User-agent: *\nDisallow: /admin/\n";

        Mock::given(method("GET"))
            .and(path("/robots.txt"))
            .respond_with(ResponseTemplate::new(200).set_body_string(body))
            .expect(1) // the repeated lookup below must be served from cache
            .mount(&server)
            .await;

        let robots = RobotsCache::new("TestBot".to_string(), Duration::from_secs(3600)).unwrap();
        let allowed_url = format!("{}/page", server.uri());
        let blocked_url = format!("{}/admin/secret", server.uri());

        assert!(robots.is_allowed(&allowed_url).await.unwrap());
        assert!(!robots.is_allowed(&blocked_url).await.unwrap());
        // Cache hit — a second fetch would trip the expect(1) above.
        assert!(robots.is_allowed(&allowed_url).await.unwrap());
    }

    #[tokio::test]
    async fn handle_missing_robots_txt() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/robots.txt"))
            .respond_with(ResponseTemplate::new(404))
            .mount(&server)
            .await;

        let robots = RobotsCache::new("TestBot".to_string(), Duration::from_secs(3600)).unwrap();

        // A missing robots.txt means everything is crawlable.
        let any_url = format!("{}/any-page", server.uri());
        assert!(robots.is_allowed(&any_url).await.unwrap());
    }

    #[tokio::test]
    async fn respect_crawl_delay() {
        let server = MockServer::start().await;
        let body = "User-agent: *\nCrawl-delay: 1.5\n";

        Mock::given(method("GET"))
            .and(path("/robots.txt"))
            .respond_with(ResponseTemplate::new(200).set_body_string(body))
            .mount(&server)
            .await;

        let robots = RobotsCache::new("TestBot".to_string(), Duration::from_secs(3600)).unwrap();

        let page = format!("{}/page", server.uri());
        assert_eq!(
            robots.get_crawl_delay(&page).await.unwrap(),
            Some(Duration::from_secs_f64(1.5))
        );
    }

    #[tokio::test]
    async fn cache_expiration() {
        let server = MockServer::start().await;
        let body = "User-agent: *\nDisallow: /\n";

        Mock::given(method("GET"))
            .and(path("/robots.txt"))
            .respond_with(ResponseTemplate::new(200).set_body_string(body))
            .expect(2) // the entry expires, so a second fetch must occur
            .mount(&server)
            .await;

        let robots = RobotsCache::new("TestBot".to_string(), Duration::from_millis(100)).unwrap();
        let page = format!("{}/page", server.uri());

        assert!(!robots.is_allowed(&page).await.unwrap());
        // Outlive the 100 ms TTL so the next lookup refetches.
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(!robots.is_allowed(&page).await.unwrap());
    }
}