1use crate::error::ScrapeError;
2use crate::types::{CrawlRequest, MapRequest, ScrapeRequest, SearchRequest};
3use crate::utils::ssrf_protection;
4use scraper::Selector;
5use std::time::Duration;
6use url::Url;
7
/// Longest URL accepted by `validate_url`, measured with `str::len` (bytes).
const MAX_URL_LENGTH: usize = 2048;

/// Largest number of headers a scrape request may carry.
const MAX_HEADERS_COUNT: usize = 50;

/// Longest accepted value for a single request header, measured with `len()`.
const MAX_HEADER_VALUE_LENGTH: usize = 4096;

/// Largest number of browser actions a scrape request may carry.
const MAX_ACTIONS_COUNT: usize = 20;

/// Longest CSS selector accepted by `validate_css_selector`, measured with `str::len`.
const MAX_CSS_SELECTOR_LENGTH: usize = 1000;

/// Upper bound for a scrape request's `timeout`, in milliseconds (five minutes).
const MAX_TIMEOUT_MS: u64 = 5 * 60 * 1_000;

/// Crawl timeout handed out by `get_crawl_timeout`, in milliseconds (five minutes).
const MAX_CRAWL_TIMEOUT_MS: u64 = 5 * 60 * 1_000;

/// Upper bound for a crawl request's `limit`.
const MAX_CRAWL_LIMIT: u32 = 10_000;

/// Upper bound for a map request's optional `limit`.
const MAX_MAP_LIMIT: u32 = 100_000;

/// Upper bound for a search request's `limit`.
const MAX_SEARCH_LIMIT: u32 = 100;
19
20pub async fn validate_url(url: &str) -> Result<(), ScrapeError> {
22 if url.is_empty() {
23 return Err(ScrapeError::InvalidUrl("URL cannot be empty".to_string()));
24 }
25
26 if url.len() > MAX_URL_LENGTH {
27 return Err(ScrapeError::InvalidUrl(format!(
28 "URL too long: {} > {} characters",
29 url.len(),
30 MAX_URL_LENGTH
31 )));
32 }
33
34 Url::parse(url).map_err(|e| ScrapeError::InvalidUrl(format!("Invalid URL: {}", e)))?;
35
36 ssrf_protection::validate_url_safe(url).await?;
38
39 Ok(())
40}
41
42pub fn validate_css_selector(selector: &str) -> Result<(), ScrapeError> {
44 if selector.is_empty() {
45 return Ok(());
46 }
47
48 if selector.len() > MAX_CSS_SELECTOR_LENGTH {
49 return Err(ScrapeError::InvalidRequest(format!(
50 "CSS selector too long: {} > {} characters",
51 selector.len(),
52 MAX_CSS_SELECTOR_LENGTH
53 )));
54 }
55
56 let dangerous_patterns = [
58 "<script", "javascript:", "eval(", "onclick=", "onerror=", "onload=",
59 ];
60
61 for pattern in &dangerous_patterns {
62 if selector.to_lowercase().contains(pattern) {
63 return Err(ScrapeError::InvalidRequest(format!(
64 "Invalid CSS selector: contains dangerous pattern '{}'",
65 pattern
66 )));
67 }
68 }
69
70 Selector::parse(selector).map_err(|e| {
72 ScrapeError::InvalidRequest(format!("Invalid CSS selector syntax: {:?}", e))
73 })?;
74
75 Ok(())
76}
77
78pub async fn validate_scrape_request(req: &ScrapeRequest) -> Result<(), ScrapeError> {
80 validate_url(&req.url).await?;
82
83 if req.timeout > MAX_TIMEOUT_MS {
85 return Err(ScrapeError::InvalidRequest(format!(
86 "Timeout too large: {}ms > {}ms",
87 req.timeout, MAX_TIMEOUT_MS
88 )));
89 }
90
91 if req.headers.len() > MAX_HEADERS_COUNT {
93 return Err(ScrapeError::InvalidRequest(format!(
94 "Too many headers: {} > {}",
95 req.headers.len(),
96 MAX_HEADERS_COUNT
97 )));
98 }
99
100 for (key, value) in &req.headers {
101 if value.len() > MAX_HEADER_VALUE_LENGTH {
102 return Err(ScrapeError::InvalidRequest(format!(
103 "Header '{}' value too long: {} > {} characters",
104 key,
105 value.len(),
106 MAX_HEADER_VALUE_LENGTH
107 )));
108 }
109 }
110
111 if req.actions.len() > MAX_ACTIONS_COUNT {
113 return Err(ScrapeError::InvalidRequest(format!(
114 "Too many browser actions: {} > {}",
115 req.actions.len(),
116 MAX_ACTIONS_COUNT
117 )));
118 }
119
120 if let Some(ref selector) = req.wait_for_selector {
122 validate_css_selector(selector)?;
123 }
124
125 for tag in &req.include_tags {
126 validate_css_selector(tag)?;
127 }
128
129 for tag in &req.exclude_tags {
130 validate_css_selector(tag)?;
131 }
132
133 let valid_formats = [
135 "markdown",
136 "html",
137 "rawHtml",
138 "links",
139 "images",
140 "screenshot",
141 ];
142
143 for format in &req.formats {
144 if !valid_formats.contains(&format.as_str()) {
145 return Err(ScrapeError::UnsupportedFormat(format.clone()));
146 }
147 }
148
149 Ok(())
150}
151
152pub async fn validate_map_request(req: &MapRequest) -> Result<(), ScrapeError> {
154 validate_url(&req.url).await?;
155
156 if let Some(limit) = req.limit {
157 if limit > MAX_MAP_LIMIT {
158 return Err(ScrapeError::InvalidRequest(format!(
159 "Map limit too large: {} > {}",
160 limit, MAX_MAP_LIMIT
161 )));
162 }
163 }
164
165 Ok(())
166}
167
168pub async fn validate_crawl_request(req: &CrawlRequest) -> Result<(), ScrapeError> {
170 validate_url(&req.url).await?;
171
172 if req.limit > MAX_CRAWL_LIMIT {
173 return Err(ScrapeError::InvalidRequest(format!(
174 "Crawl limit too large: {} > {}",
175 req.limit, MAX_CRAWL_LIMIT
176 )));
177 }
178
179 Ok(())
180}
181
182pub fn validate_search_request(req: &SearchRequest) -> Result<(), ScrapeError> {
184 if req.query.is_empty() {
185 return Err(ScrapeError::InvalidRequest(
186 "Search query cannot be empty".to_string(),
187 ));
188 }
189
190 if req.limit > MAX_SEARCH_LIMIT {
191 return Err(ScrapeError::InvalidRequest(format!(
192 "Search limit too large: {} > {}",
193 req.limit, MAX_SEARCH_LIMIT
194 )));
195 }
196
197 Ok(())
198}
199
200pub fn get_crawl_timeout() -> Duration {
202 Duration::from_millis(MAX_CRAWL_TIMEOUT_MS)
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a scrape request that is expected to pass validation; individual
    /// tests override single fields to provoke a rejection.
    fn baseline_request() -> ScrapeRequest {
        ScrapeRequest {
            url: "https://example.com".to_string(),
            formats: vec!["markdown".to_string()],
            headers: Default::default(),
            include_tags: vec![],
            exclude_tags: vec![],
            only_main_content: true,
            timeout: 30000,
            wait_for: 0,
            remove_base64_images: true,
            skip_tls_verification: false,
            engine: "auto".to_string(),
            wait_for_selector: None,
            actions: vec![],
            screenshot: false,
            screenshot_format: "png".to_string(),
        }
    }

    #[tokio::test]
    async fn test_validate_url() {
        assert!(validate_url("https://example.com").await.is_ok());

        // Empty, over-long, and unparseable inputs must all be rejected.
        let over_long = "a".repeat(3000);
        for rejected in ["", over_long.as_str(), "not-a-url"] {
            assert!(
                validate_url(rejected).await.is_err(),
                "expected URL to be rejected: {:?}",
                rejected
            );
        }
    }

    #[test]
    fn test_validate_css_selector() {
        for accepted in ["div.class", "#id"] {
            assert!(
                validate_css_selector(accepted).is_ok(),
                "expected selector to be accepted: {:?}",
                accepted
            );
        }

        let over_long = "a".repeat(2000);
        for rejected in [
            "<script>alert('xss')</script>",
            "javascript:void(0)",
            over_long.as_str(),
        ] {
            assert!(
                validate_css_selector(rejected).is_err(),
                "expected selector to be rejected: {:?}",
                rejected
            );
        }
    }

    #[tokio::test]
    async fn test_validate_scrape_request() {
        assert!(validate_scrape_request(&baseline_request()).await.is_ok());

        // Timeout above the allowed ceiling.
        let excessive_timeout = ScrapeRequest {
            timeout: 400_000,
            ..baseline_request()
        };
        assert!(validate_scrape_request(&excessive_timeout).await.is_err());

        // Output format that is not in the supported list.
        let unknown_format = ScrapeRequest {
            formats: vec!["invalid_format".to_string()],
            ..baseline_request()
        };
        assert!(validate_scrape_request(&unknown_format).await.is_err());
    }
}