use thiserror::Error;
#[derive(Error, Debug, PartialEq)]
pub enum ValidationError {
#[error("Invalid paper ID: {0}")]
InvalidPaperId(String),
#[error("Invalid URL: {0}")]
InvalidUrl(String),
#[error("Invalid DOI format: {0}")]
InvalidDoi(String),
#[error("Invalid filename: contains disallowed characters")]
InvalidFilename,
#[error("URL contains potentially dangerous characters")]
DangerousUrl,
#[error("Path traversal detected: {0}")]
PathTraversal(String),
}
pub fn sanitize_paper_id(id: &str) -> Result<String, ValidationError> {
let id = id.trim();
if id.is_empty() {
return Err(ValidationError::InvalidPaperId("empty ID".to_string()));
}
if id.contains("..") || id.contains("./") || id.contains(".\\") {
return Err(ValidationError::PathTraversal(id.to_string()));
}
if id.contains('\0') {
return Err(ValidationError::InvalidPaperId(
"contains null byte".to_string(),
));
}
for ch in id.chars() {
if ch.is_control() && ch != '\t' && ch != '\n' && ch != '\r' {
return Err(ValidationError::InvalidPaperId(
"contains control characters".to_string(),
));
}
}
let dangerous_chars = [
';', '|', '&', '$', '`', '(', ')', '{', '}', '[', ']', '<', '>', '*', '?', '!',
];
for ch in dangerous_chars {
if id.contains(ch) {
return Err(ValidationError::InvalidPaperId(format!(
"contains dangerous character: {}",
ch
)));
}
}
Ok(id.to_string())
}
pub fn validate_url(url: &str) -> Result<String, ValidationError> {
let url = url.trim();
if url.is_empty() {
return Err(ValidationError::InvalidUrl("empty URL".to_string()));
}
if url.contains('\0') {
return Err(ValidationError::InvalidUrl(
"contains null byte".to_string(),
));
}
let parsed = url::Url::parse(url).map_err(|e| ValidationError::InvalidUrl(e.to_string()))?;
match parsed.scheme() {
"http" | "https" => {}
_ => {
return Err(ValidationError::InvalidUrl(format!(
"invalid scheme: {}",
parsed.scheme()
)))
}
}
let url_lower = url.to_lowercase();
if url.contains('\n') || url.contains('\r') || url.contains('\0') {
return Err(ValidationError::DangerousUrl);
}
if url_lower.starts_with("data:") || url_lower.starts_with("javascript:") {
return Err(ValidationError::DangerousUrl);
}
if let Some(host) = parsed.host_str() {
let host_lower = host.to_lowercase();
if host_lower == "localhost"
|| host_lower == "127.0.0.1"
|| host_lower == "::1"
|| host_lower == "0.0.0.0"
{
return Err(ValidationError::DangerousUrl);
}
if host_lower.parse::<std::net::Ipv4Addr>().is_ok() {
let octets: Vec<&str> = host_lower.split('.').collect();
if octets.len() == 4 {
if let Ok(first) = octets[0].parse::<u8>() {
if first == 10
|| (first == 172
&& octets[1]
.parse::<u8>()
.is_ok_and(|v| (16..=31).contains(&v)))
|| (first == 192 && octets[1] == "168")
{
return Err(ValidationError::DangerousUrl);
}
}
}
}
}
Ok(url.to_string())
}
pub fn validate_doi(doi: &str) -> Result<String, ValidationError> {
let doi = doi.trim().to_lowercase();
if doi.is_empty() {
return Err(ValidationError::InvalidDoi("empty DOI".to_string()));
}
let doi = doi.strip_prefix("doi:").unwrap_or(&doi);
let doi = doi.strip_prefix("https://doi.org/").unwrap_or(doi);
let doi = doi.strip_prefix("http://doi.org/").unwrap_or(doi);
if !doi.starts_with("10.") {
return Err(ValidationError::InvalidDoi(
"DOI must start with '10.'".to_string(),
));
}
if !doi.contains('/') {
return Err(ValidationError::InvalidDoi(
"DOI must contain a slash".to_string(),
));
}
if doi.contains("..") {
return Err(ValidationError::InvalidDoi(
"path traversal detected".to_string(),
));
}
Ok(doi.to_string())
}
pub fn sanitize_filename(filename: &str) -> Result<String, ValidationError> {
let filename = filename.trim();
if filename.is_empty() {
return Err(ValidationError::InvalidFilename);
}
if filename.contains("..")
|| filename.starts_with('/')
|| filename.starts_with('\\')
|| filename.contains(":/")
|| filename.contains(":\\")
{
return Err(ValidationError::PathTraversal(filename.to_string()));
}
let filename = filename.replace('\0', "");
let mut sanitized = String::new();
for ch in filename.chars() {
if ch.is_alphanumeric() || ch == '-' || ch == '_' || ch == '.' || ch == ' ' {
sanitized.push(ch);
}
}
const MAX_FILENAME_LENGTH: usize = 255;
if sanitized.len() > MAX_FILENAME_LENGTH {
let ext_pos = sanitized.rfind('.').unwrap_or(sanitized.len());
let ext = sanitized.split_at(ext_pos).1;
let base_len = MAX_FILENAME_LENGTH.saturating_sub(ext.len());
sanitized = format!("{}{}", &sanitized[..base_len.min(sanitized.len())], ext);
}
if sanitized.is_empty() {
return Err(ValidationError::InvalidFilename);
}
Ok(sanitized)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sanitize_paper_id_valid() {
assert!(sanitize_paper_id("2301.12345").is_ok());
assert!(sanitize_paper_id("arxiv:2301.12345").is_ok());
assert!(sanitize_paper_id("PMC12345").is_ok());
assert!(sanitize_paper_id("hal-12345").is_ok());
assert!(sanitize_paper_id("10.1234/test").is_ok());
}
#[test]
fn test_sanitize_paper_id_empty() {
assert!(sanitize_paper_id("").is_err());
assert!(sanitize_paper_id(" ").is_err());
}
#[test]
fn test_sanitize_paper_id_path_traversal() {
assert!(sanitize_paper_id("../etc/passwd").is_err());
assert!(sanitize_paper_id("foo/../../bar").is_err());
assert!(sanitize_paper_id("foo\\..\\bar").is_err());
}
#[test]
fn test_sanitize_paper_id_dangerous_chars() {
assert!(sanitize_paper_id("foo;rm -rf /").is_err());
assert!(sanitize_paper_id("foo|whoami").is_err());
assert!(sanitize_paper_id("foo`ls`").is_err());
assert!(sanitize_paper_id("foo$(whoami)").is_err());
}
#[test]
fn test_validate_url_valid() {
assert!(validate_url("https://api.semanticscholar.org/graph/v1/paper/search").is_ok());
assert!(validate_url("http://export.arxiv.org/api/query").is_ok());
}
#[test]
fn test_validate_url_invalid() {
assert!(validate_url("").is_err());
assert!(validate_url("ftp://example.com").is_err());
assert!(validate_url("javascript:alert(1)").is_err());
assert!(validate_url("http://localhost:8000").is_err());
assert!(validate_url("http://127.0.0.1:8000").is_err());
}
#[test]
fn test_validate_doi_valid() {
assert!(validate_doi("10.1234/abc123").is_ok());
assert!(validate_doi("10.1038/nature12345").is_ok());
assert_eq!(
validate_doi("doi:10.1234/abc123").unwrap(),
"10.1234/abc123"
);
assert_eq!(
validate_doi("https://doi.org/10.1234/abc123").unwrap(),
"10.1234/abc123"
);
}
#[test]
fn test_validate_doi_invalid() {
assert!(validate_doi("").is_err());
assert!(validate_doi("10.1234").is_err()); assert!(validate_doi("9.1234/abc").is_err()); assert!(validate_doi("10.1234/../abc").is_err()); }
#[test]
fn test_sanitize_filename_valid() {
assert_eq!(sanitize_filename("my_paper.pdf").unwrap(), "my_paper.pdf");
assert_eq!(
sanitize_filename("2023-01-15-test.pdf").unwrap(),
"2023-01-15-test.pdf"
);
assert_eq!(
sanitize_filename("Test Paper Final.pdf").unwrap(),
"Test Paper Final.pdf"
);
}
#[test]
fn test_sanitize_filename_dangerous() {
assert!(sanitize_filename("../etc/passwd").is_err());
assert!(sanitize_filename("/etc/passwd").is_err());
assert!(sanitize_filename("C:\\Windows\\System32").is_err());
assert!(sanitize_filename("../../../etc/passwd").is_err());
}
#[test]
fn test_sanitize_filename_removes_dangerous_chars() {
let result = sanitize_filename("test;rm -rf /;file.pdf").unwrap();
assert!(!result.contains(';'), "semicolon should be removed");
}
}