Skip to main content

essence/
validation.rs

1use crate::error::ScrapeError;
2use crate::types::{CrawlRequest, MapRequest, ScrapeRequest, SearchRequest};
3use crate::utils::ssrf_protection;
4use scraper::Selector;
5use std::time::Duration;
6use url::Url;
7
// Input limits enforced by the validators in this module. The units differ per
// constant (bytes, counts, milliseconds), so each one states what it measures.

/// Maximum URL length, in bytes (compared against `str::len`).
const MAX_URL_LENGTH: usize = 2048;
/// Maximum number of entries in a scrape request's `headers`.
const MAX_HEADERS_COUNT: usize = 50;
/// Maximum length of a single header value, in bytes.
const MAX_HEADER_VALUE_LENGTH: usize = 4096;
/// Maximum number of entries in a scrape request's `actions`.
const MAX_ACTIONS_COUNT: usize = 20;
/// Maximum CSS selector length, in bytes.
const MAX_CSS_SELECTOR_LENGTH: usize = 1000;
/// Maximum scrape `timeout`, in milliseconds (5 minutes).
const MAX_TIMEOUT_MS: u64 = 300_000;
/// Crawl timeout, in milliseconds (5 minutes); returned by `get_crawl_timeout`.
const MAX_CRAWL_TIMEOUT_MS: u64 = 300_000;
/// Maximum `limit` accepted on a crawl request.
const MAX_CRAWL_LIMIT: u32 = 10_000;
/// Maximum `limit` accepted on a map request.
const MAX_MAP_LIMIT: u32 = 100_000;
/// Maximum `limit` accepted on a search request.
const MAX_SEARCH_LIMIT: u32 = 100;

20/// Validate a URL (including SSRF protection)
21pub async fn validate_url(url: &str) -> Result<(), ScrapeError> {
22    if url.is_empty() {
23        return Err(ScrapeError::InvalidUrl("URL cannot be empty".to_string()));
24    }
25
26    if url.len() > MAX_URL_LENGTH {
27        return Err(ScrapeError::InvalidUrl(format!(
28            "URL too long: {} > {} characters",
29            url.len(),
30            MAX_URL_LENGTH
31        )));
32    }
33
34    Url::parse(url).map_err(|e| ScrapeError::InvalidUrl(format!("Invalid URL: {}", e)))?;
35
36    // SSRF protection: check for private IPs and DNS rebinding attacks
37    ssrf_protection::validate_url_safe(url).await?;
38
39    Ok(())
40}
41
42/// Validate a CSS selector
43pub fn validate_css_selector(selector: &str) -> Result<(), ScrapeError> {
44    if selector.is_empty() {
45        return Ok(());
46    }
47
48    if selector.len() > MAX_CSS_SELECTOR_LENGTH {
49        return Err(ScrapeError::InvalidRequest(format!(
50            "CSS selector too long: {} > {} characters",
51            selector.len(),
52            MAX_CSS_SELECTOR_LENGTH
53        )));
54    }
55
56    // Check for dangerous patterns
57    let dangerous_patterns = [
58        "<script", "javascript:", "eval(", "onclick=", "onerror=", "onload=",
59    ];
60
61    for pattern in &dangerous_patterns {
62        if selector.to_lowercase().contains(pattern) {
63            return Err(ScrapeError::InvalidRequest(format!(
64                "Invalid CSS selector: contains dangerous pattern '{}'",
65                pattern
66            )));
67        }
68    }
69
70    // Validate it parses correctly
71    Selector::parse(selector).map_err(|e| {
72        ScrapeError::InvalidRequest(format!("Invalid CSS selector syntax: {:?}", e))
73    })?;
74
75    Ok(())
76}
77
78/// Validate scrape request
79pub async fn validate_scrape_request(req: &ScrapeRequest) -> Result<(), ScrapeError> {
80    // URL validation
81    validate_url(&req.url).await?;
82
83    // Timeout validation
84    if req.timeout > MAX_TIMEOUT_MS {
85        return Err(ScrapeError::InvalidRequest(format!(
86            "Timeout too large: {}ms > {}ms",
87            req.timeout, MAX_TIMEOUT_MS
88        )));
89    }
90
91    // Headers validation
92    if req.headers.len() > MAX_HEADERS_COUNT {
93        return Err(ScrapeError::InvalidRequest(format!(
94            "Too many headers: {} > {}",
95            req.headers.len(),
96            MAX_HEADERS_COUNT
97        )));
98    }
99
100    for (key, value) in &req.headers {
101        if value.len() > MAX_HEADER_VALUE_LENGTH {
102            return Err(ScrapeError::InvalidRequest(format!(
103                "Header '{}' value too long: {} > {} characters",
104                key,
105                value.len(),
106                MAX_HEADER_VALUE_LENGTH
107            )));
108        }
109    }
110
111    // Actions validation
112    if req.actions.len() > MAX_ACTIONS_COUNT {
113        return Err(ScrapeError::InvalidRequest(format!(
114            "Too many browser actions: {} > {}",
115            req.actions.len(),
116            MAX_ACTIONS_COUNT
117        )));
118    }
119
120    // Selector validation
121    if let Some(ref selector) = req.wait_for_selector {
122        validate_css_selector(selector)?;
123    }
124
125    for tag in &req.include_tags {
126        validate_css_selector(tag)?;
127    }
128
129    for tag in &req.exclude_tags {
130        validate_css_selector(tag)?;
131    }
132
133    // Format validation
134    let valid_formats = [
135        "markdown",
136        "html",
137        "rawHtml",
138        "links",
139        "images",
140        "screenshot",
141    ];
142
143    for format in &req.formats {
144        if !valid_formats.contains(&format.as_str()) {
145            return Err(ScrapeError::UnsupportedFormat(format.clone()));
146        }
147    }
148
149    Ok(())
150}
151
152/// Validate map request
153pub async fn validate_map_request(req: &MapRequest) -> Result<(), ScrapeError> {
154    validate_url(&req.url).await?;
155
156    if let Some(limit) = req.limit {
157        if limit > MAX_MAP_LIMIT {
158            return Err(ScrapeError::InvalidRequest(format!(
159                "Map limit too large: {} > {}",
160                limit, MAX_MAP_LIMIT
161            )));
162        }
163    }
164
165    Ok(())
166}
167
168/// Validate crawl request
169pub async fn validate_crawl_request(req: &CrawlRequest) -> Result<(), ScrapeError> {
170    validate_url(&req.url).await?;
171
172    if req.limit > MAX_CRAWL_LIMIT {
173        return Err(ScrapeError::InvalidRequest(format!(
174            "Crawl limit too large: {} > {}",
175            req.limit, MAX_CRAWL_LIMIT
176        )));
177    }
178
179    Ok(())
180}
181
182/// Validate search request
183pub fn validate_search_request(req: &SearchRequest) -> Result<(), ScrapeError> {
184    if req.query.is_empty() {
185        return Err(ScrapeError::InvalidRequest(
186            "Search query cannot be empty".to_string(),
187        ));
188    }
189
190    if req.limit > MAX_SEARCH_LIMIT {
191        return Err(ScrapeError::InvalidRequest(format!(
192            "Search limit too large: {} > {}",
193            req.limit, MAX_SEARCH_LIMIT
194        )));
195    }
196
197    Ok(())
198}
199
200/// Get timeout duration for crawl operations
201pub fn get_crawl_timeout() -> Duration {
202    Duration::from_millis(MAX_CRAWL_TIMEOUT_MS)
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): the URL-based tests below go through
    // `ssrf_protection::validate_url_safe`. If that performs DNS resolution,
    // these tests need network access — confirm before running them in a
    // sandboxed CI.

    #[tokio::test]
    async fn test_validate_url() {
        assert!(validate_url("https://example.com").await.is_ok());
        assert!(validate_url("").await.is_err());
        assert!(validate_url(&"a".repeat(3000)).await.is_err());
        assert!(validate_url("not-a-url").await.is_err());
    }

    #[test]
    fn test_validate_css_selector() {
        assert!(validate_css_selector("div.class").is_ok());
        assert!(validate_css_selector("#id").is_ok());
        assert!(validate_css_selector("<script>alert('xss')</script>").is_err());
        assert!(validate_css_selector("javascript:void(0)").is_err());
        assert!(validate_css_selector(&"a".repeat(2000)).is_err());
    }

    #[test]
    fn test_validate_css_selector_edge_cases() {
        // An empty selector means "no selector" and is accepted.
        assert!(validate_css_selector("").is_ok());
        // Dangerous-pattern matching is case-insensitive.
        assert!(validate_css_selector("DIV[ONCLICK=x]").is_err());
        // Syntactically invalid selectors are rejected by the parser.
        assert!(validate_css_selector("div[").is_err());
        // A selector exactly at the length limit is still allowed.
        assert!(validate_css_selector(&"a".repeat(MAX_CSS_SELECTOR_LENGTH)).is_ok());
    }

    #[test]
    fn test_get_crawl_timeout() {
        assert_eq!(get_crawl_timeout(), Duration::from_secs(300));
    }

    #[tokio::test]
    async fn test_validate_scrape_request() {
        let valid_req = ScrapeRequest {
            url: "https://example.com".to_string(),
            formats: vec!["markdown".to_string()],
            headers: Default::default(),
            include_tags: vec![],
            exclude_tags: vec![],
            only_main_content: true,
            timeout: 30000,
            wait_for: 0,
            remove_base64_images: true,
            skip_tls_verification: false,
            engine: "auto".to_string(),
            wait_for_selector: None,
            actions: vec![],
            screenshot: false,
            screenshot_format: "png".to_string(),
        };

        assert!(validate_scrape_request(&valid_req).await.is_ok());

        // Test timeout validation
        let mut invalid_req = valid_req.clone();
        invalid_req.timeout = 400_000; // > 5 minutes
        assert!(validate_scrape_request(&invalid_req).await.is_err());

        // Test format validation
        let mut invalid_req = valid_req.clone();
        invalid_req.formats = vec!["invalid_format".to_string()];
        assert!(validate_scrape_request(&invalid_req).await.is_err());

        // Test that selectors embedded in the request are validated
        let mut invalid_req = valid_req.clone();
        invalid_req.wait_for_selector = Some("div[".to_string());
        assert!(validate_scrape_request(&invalid_req).await.is_err());
    }
}