Skip to main content

research_master/utils/
validate.rs

1//! Input validation utilities for paper IDs, URLs, and filenames.
2//!
3//! This module provides validation functions to prevent injection attacks
4//! and path traversal vulnerabilities.
5
6use thiserror::Error;
7
/// Validation error types
///
/// Returned by every validator in this module. The `String` payloads carry a
/// human-readable reason (or the offending input) and are interpolated into
/// the `Display` message produced by `thiserror`.
#[derive(Error, Debug, PartialEq)]
pub enum ValidationError {
    /// The paper ID was empty or contained a disallowed character; the
    /// payload describes which rule was violated.
    #[error("Invalid paper ID: {0}")]
    InvalidPaperId(String),

    /// The URL was empty, failed to parse, or used a scheme other than
    /// HTTP/HTTPS; the payload holds the reason or the parser's message.
    #[error("Invalid URL: {0}")]
    InvalidUrl(String),

    /// The DOI did not match the expected format; the payload describes the
    /// failed check.
    #[error("Invalid DOI format: {0}")]
    InvalidDoi(String),

    /// The filename was empty, or nothing was left once disallowed
    /// characters had been stripped.
    #[error("Invalid filename: contains disallowed characters")]
    InvalidFilename,

    /// The URL parsed successfully but was rejected as unsafe (for example
    /// embedded line breaks or a host that points at an internal address).
    #[error("URL contains potentially dangerous characters")]
    DangerousUrl,

    /// The input contained a path traversal pattern; the payload is the
    /// offending input.
    #[error("Path traversal detected: {0}")]
    PathTraversal(String),
}
29
30/// Validate a paper ID to prevent injection attacks
31///
32/// Paper IDs should only contain alphanumeric characters, hyphens, underscores,
33/// dots, slashes (for some formats like arXiv), and the prefix "arxiv:", "hal-", "PMC", etc.
34///
35/// Returns `Ok(String)` if valid, or `Err(ValidationError)` if invalid.
36pub fn sanitize_paper_id(id: &str) -> Result<String, ValidationError> {
37    let id = id.trim();
38
39    if id.is_empty() {
40        return Err(ValidationError::InvalidPaperId("empty ID".to_string()));
41    }
42
43    // Check for path traversal attempts
44    if id.contains("..") || id.contains("./") || id.contains(".\\") {
45        return Err(ValidationError::PathTraversal(id.to_string()));
46    }
47
48    // Check for null bytes
49    if id.contains('\0') {
50        return Err(ValidationError::InvalidPaperId(
51            "contains null byte".to_string(),
52        ));
53    }
54
55    // Check for control characters (except tab, newline, carriage return)
56    for ch in id.chars() {
57        if ch.is_control() && ch != '\t' && ch != '\n' && ch != '\r' {
58            return Err(ValidationError::InvalidPaperId(
59                "contains control characters".to_string(),
60            ));
61        }
62    }
63
64    // Check for shell metacharacters that could enable injection
65    let dangerous_chars = [
66        ';', '|', '&', '$', '`', '(', ')', '{', '}', '[', ']', '<', '>', '*', '?', '!',
67    ];
68    for ch in dangerous_chars {
69        if id.contains(ch) {
70            return Err(ValidationError::InvalidPaperId(format!(
71                "contains dangerous character: {}",
72                ch
73            )));
74        }
75    }
76
77    Ok(id.to_string())
78}
79
80/// Validate a URL to prevent injection and SSRF attacks
81///
82/// Returns `Ok(String)` if valid, or `Err(ValidationError)` if invalid.
83pub fn validate_url(url: &str) -> Result<String, ValidationError> {
84    let url = url.trim();
85
86    if url.is_empty() {
87        return Err(ValidationError::InvalidUrl("empty URL".to_string()));
88    }
89
90    // Check for null bytes
91    if url.contains('\0') {
92        return Err(ValidationError::InvalidUrl(
93            "contains null byte".to_string(),
94        ));
95    }
96
97    // Parse URL to validate structure
98    let parsed = url::Url::parse(url).map_err(|e| ValidationError::InvalidUrl(e.to_string()))?;
99
100    // Only allow HTTP and HTTPS schemes
101    match parsed.scheme() {
102        "http" | "https" => {}
103        _ => {
104            return Err(ValidationError::InvalidUrl(format!(
105                "invalid scheme: {}",
106                parsed.scheme()
107            )))
108        }
109    }
110
111    // Check for dangerous URL patterns
112    let url_lower = url.to_lowercase();
113
114    // Check for embedded newlines or nulls (already checked above, but double-check)
115    if url.contains('\n') || url.contains('\r') || url.contains('\0') {
116        return Err(ValidationError::DangerousUrl);
117    }
118
119    // Check for data: or javascript: URLs (already filtered by scheme check, but be explicit)
120    if url_lower.starts_with("data:") || url_lower.starts_with("javascript:") {
121        return Err(ValidationError::DangerousUrl);
122    }
123
124    // Check for internal IP addresses (basic check for SSRF prevention)
125    if let Some(host) = parsed.host_str() {
126        // Check for localhost variants
127        let host_lower = host.to_lowercase();
128        if host_lower == "localhost"
129            || host_lower == "127.0.0.1"
130            || host_lower == "::1"
131            || host_lower == "0.0.0.0"
132        {
133            return Err(ValidationError::DangerousUrl);
134        }
135
136        // Basic IPv4 check (simplified - doesn't catch all cases)
137        if host_lower.parse::<std::net::Ipv4Addr>().is_ok() {
138            let octets: Vec<&str> = host_lower.split('.').collect();
139            if octets.len() == 4 {
140                if let Ok(first) = octets[0].parse::<u8>() {
141                    // Check for private IP ranges (simplified)
142                    if first == 10
143                        || (first == 172
144                            && octets[1]
145                                .parse::<u8>()
146                                .is_ok_and(|v| (16..=31).contains(&v)))
147                        || (first == 192 && octets[1] == "168")
148                    {
149                        return Err(ValidationError::DangerousUrl);
150                    }
151                }
152            }
153        }
154    }
155
156    Ok(url.to_string())
157}
158
159/// Validate and sanitize a DOI
160///
161/// DOIs have the format "10.xxxx/xxxxxx" where xxxx is a registrant code
162/// and xxxxxx is an item ID.
163pub fn validate_doi(doi: &str) -> Result<String, ValidationError> {
164    let doi = doi.trim().to_lowercase();
165
166    if doi.is_empty() {
167        return Err(ValidationError::InvalidDoi("empty DOI".to_string()));
168    }
169
170    // Remove any URL prefix if present first
171    let doi = doi.strip_prefix("doi:").unwrap_or(&doi);
172    let doi = doi.strip_prefix("https://doi.org/").unwrap_or(doi);
173    let doi = doi.strip_prefix("http://doi.org/").unwrap_or(doi);
174
175    // DOI must start with "10."
176    if !doi.starts_with("10.") {
177        return Err(ValidationError::InvalidDoi(
178            "DOI must start with '10.'".to_string(),
179        ));
180    }
181
182    // DOI must contain a slash after the prefix
183    if !doi.contains('/') {
184        return Err(ValidationError::InvalidDoi(
185            "DOI must contain a slash".to_string(),
186        ));
187    }
188
189    // Check for path traversal in DOI (shouldn't happen but be safe)
190    if doi.contains("..") {
191        return Err(ValidationError::InvalidDoi(
192            "path traversal detected".to_string(),
193        ));
194    }
195
196    Ok(doi.to_string())
197}
198
199/// Sanitize a filename to prevent path traversal and other attacks
200///
201/// Removes path separators and dangerous characters, limits length,
202/// and ensures the filename is safe to use.
203pub fn sanitize_filename(filename: &str) -> Result<String, ValidationError> {
204    let filename = filename.trim();
205
206    if filename.is_empty() {
207        return Err(ValidationError::InvalidFilename);
208    }
209
210    // Check for path traversal
211    if filename.contains("..")
212        || filename.starts_with('/')
213        || filename.starts_with('\\')
214        || filename.contains(":/")
215        || filename.contains(":\\")
216    {
217        return Err(ValidationError::PathTraversal(filename.to_string()));
218    }
219
220    // Remove any null bytes
221    let filename = filename.replace('\0', "");
222
223    // Keep only safe characters: alphanumeric, dash, underscore, dot, space
224    let mut sanitized = String::new();
225    for ch in filename.chars() {
226        if ch.is_alphanumeric() || ch == '-' || ch == '_' || ch == '.' || ch == ' ' {
227            sanitized.push(ch);
228        }
229        // Replace other characters with underscore
230    }
231
232    // Limit filename length
233    const MAX_FILENAME_LENGTH: usize = 255;
234    if sanitized.len() > MAX_FILENAME_LENGTH {
235        let ext_pos = sanitized.rfind('.').unwrap_or(sanitized.len());
236        let ext = sanitized.split_at(ext_pos).1;
237        let base_len = MAX_FILENAME_LENGTH.saturating_sub(ext.len());
238        sanitized = format!("{}{}", &sanitized[..base_len.min(sanitized.len())], ext);
239    }
240
241    if sanitized.is_empty() {
242        return Err(ValidationError::InvalidFilename);
243    }
244
245    Ok(sanitized)
246}
247
/// Unit tests for the validators in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sanitize_paper_id_valid() {
        assert!(sanitize_paper_id("2301.12345").is_ok());
        assert!(sanitize_paper_id("arxiv:2301.12345").is_ok());
        assert!(sanitize_paper_id("PMC12345").is_ok());
        assert!(sanitize_paper_id("hal-12345").is_ok());
        assert!(sanitize_paper_id("10.1234/test").is_ok());
    }

    #[test]
    fn test_sanitize_paper_id_empty() {
        assert!(sanitize_paper_id("").is_err());
        // Whitespace-only input is empty after trimming
        assert!(sanitize_paper_id("   ").is_err());
    }

    #[test]
    fn test_sanitize_paper_id_path_traversal() {
        assert!(sanitize_paper_id("../etc/passwd").is_err());
        assert!(sanitize_paper_id("foo/../../bar").is_err());
        assert!(sanitize_paper_id("foo\\..\\bar").is_err());
    }

    #[test]
    fn test_sanitize_paper_id_dangerous_chars() {
        assert!(sanitize_paper_id("foo;rm -rf /").is_err());
        assert!(sanitize_paper_id("foo|whoami").is_err());
        assert!(sanitize_paper_id("foo`ls`").is_err());
        assert!(sanitize_paper_id("foo$(whoami)").is_err());
    }

    #[test]
    fn test_validate_url_valid() {
        assert!(validate_url("https://api.semanticscholar.org/graph/v1/paper/search").is_ok());
        assert!(validate_url("http://export.arxiv.org/api/query").is_ok());
    }

    #[test]
    fn test_validate_url_invalid() {
        assert!(validate_url("").is_err());
        assert!(validate_url("ftp://example.com").is_err());
        assert!(validate_url("javascript:alert(1)").is_err());
        assert!(validate_url("http://localhost:8000").is_err());
        assert!(validate_url("http://127.0.0.1:8000").is_err());
    }

    #[test]
    fn test_validate_doi_valid() {
        assert!(validate_doi("10.1234/abc123").is_ok());
        assert!(validate_doi("10.1038/nature12345").is_ok());
        // A "doi:" label or resolver URL prefix is stripped from the result
        assert_eq!(
            validate_doi("doi:10.1234/abc123").unwrap(),
            "10.1234/abc123"
        );
        assert_eq!(
            validate_doi("https://doi.org/10.1234/abc123").unwrap(),
            "10.1234/abc123"
        );
    }

    #[test]
    fn test_validate_doi_invalid() {
        assert!(validate_doi("").is_err());
        assert!(validate_doi("10.1234").is_err()); // No slash
        assert!(validate_doi("9.1234/abc").is_err()); // Doesn't start with 10
        assert!(validate_doi("10.1234/../abc").is_err()); // Path traversal
    }

    #[test]
    fn test_sanitize_filename_valid() {
        assert_eq!(sanitize_filename("my_paper.pdf").unwrap(), "my_paper.pdf");
        assert_eq!(
            sanitize_filename("2023-01-15-test.pdf").unwrap(),
            "2023-01-15-test.pdf"
        );
        // Spaces are allowed and preserved
        assert_eq!(
            sanitize_filename("Test Paper Final.pdf").unwrap(),
            "Test Paper Final.pdf"
        );
        // Parentheses are removed, only alphanumeric, dash, underscore, dot, space allowed
        assert_eq!(
            sanitize_filename("Test Paper (Final).pdf").unwrap(),
            "Test Paper Final.pdf"
        );
    }

    #[test]
    fn test_sanitize_filename_dangerous() {
        assert!(sanitize_filename("../etc/passwd").is_err());
        assert!(sanitize_filename("/etc/passwd").is_err());
        assert!(sanitize_filename("C:\\Windows\\System32").is_err());
        assert!(sanitize_filename("../../../etc/passwd").is_err());
    }

    #[test]
    fn test_sanitize_filename_removes_dangerous_chars() {
        // Disallowed characters are dropped, not replaced: both ';' and the
        // '/' disappear, leaving a single space between "-rf" and "file".
        let result = sanitize_filename("test;rm -rf /;file.pdf").unwrap();
        assert!(!result.contains(';'), "semicolon should be removed");
        assert!(!result.contains('/'), "slash should be removed");
        assert_eq!(result, "testrm -rf file.pdf");
    }
}
350}