use std::fs;
use std::io::Read;
use std::path::{Path, PathBuf};

use super::error::{ContractsError, ContractsResult};
/// Hard cap on the size of a file that will be read for analysis (10 MiB).
pub const MAX_FILE_SIZE: u64 = 10 << 20;
/// Files larger than this (1 MiB) are still read, but `read_file_safe_with_warning`
/// emits a "may be slow" warning for them.
pub const WARN_FILE_SIZE: u64 = 1 << 20;
/// Depth limit intended for CFG traversal.
/// NOTE(review): not referenced in this file; presumably passed to
/// `check_depth_limit` by the CFG code — confirm at call sites.
pub const MAX_CFG_DEPTH: usize = 1_000;
/// Largest SSA node count accepted by `check_ssa_node_limit` (inclusive).
pub const MAX_SSA_NODES: usize = 100_000;
/// Deepest AST accepted by `check_ast_depth` (inclusive).
pub const MAX_AST_DEPTH: usize = 100;
/// Maximum function-name length accepted by `validate_function_name`, measured in
/// bytes (`str::len`), not characters.
pub const MAX_FUNCTION_NAME_LEN: usize = 256;
/// NOTE(review): not referenced in this file; presumably caps the number of
/// contract conditions collected per function — confirm at call sites.
pub const MAX_CONDITIONS_PER_FUNCTION: usize = 100;
/// Path prefixes that `validate_file_path` refuses to open.
///
/// A canonical (symlink-resolved) path is rejected when it starts with one of
/// these entries. Because matching happens *after* canonicalization, directories
/// that are commonly symlinks are also listed under their resolved locations:
/// `/var/run` -> `/run` on modern Linux, and `/etc` / `/var` -> `/private/...`
/// on macOS. Without those entries the unresolved forms could never match.
const BLOCKED_PREFIXES: &[&str] = &[
    "/etc/",
    "/etc/passwd",
    "/etc/shadow",
    "/root/",
    "/sys/",
    "/proc/",
    "/dev/",
    "/run/",
    "/var/run/",
    "/var/log/",
    "/private/etc/",
    "/private/var/run/",
    "/private/var/log/",
    "C:\\Windows\\",
    // NOTE(review): the real system directory is C:\Windows\System32, which the
    // entry above already covers; this entry is kept for backward compatibility.
    "C:\\System32\\",
];
/// Resolves `path` to a canonical absolute path and rejects sensitive system locations.
///
/// The path is canonicalized with [`fs::canonicalize`] (resolving symlinks and `..`)
/// *before* it is compared against [`BLOCKED_PREFIXES`], so a symlink that points
/// into a blocked directory is rejected as well.
///
/// # Errors
///
/// - [`ContractsError::FileNotFound`] if `path` cannot be canonicalized (for
///   example because it does not exist).
/// - [`ContractsError::PathTraversal`] if the canonical path is not valid UTF-8 or
///   falls under a blocked prefix.
pub fn validate_file_path(path: &Path) -> ContractsResult<PathBuf> {
    // `canonicalize` already fails for missing paths, so a separate `exists()`
    // probe would only add a syscall and a check-then-use window.
    let canonical = fs::canonicalize(path).map_err(|_| ContractsError::FileNotFound {
        path: path.to_path_buf(),
    })?;
    // Non-UTF-8 paths are refused outright. For all other paths, match on the
    // exact text rather than a lossy rendering.
    let Some(canonical_str) = canonical.to_str() else {
        return Err(ContractsError::PathTraversal {
            path: path.to_path_buf(),
        });
    };
    if is_blocked_path(canonical_str) {
        return Err(ContractsError::PathTraversal {
            path: path.to_path_buf(),
        });
    }
    Ok(canonical)
}

/// Returns `true` if `canonical` is, or lies beneath, an entry of [`BLOCKED_PREFIXES`].
fn is_blocked_path(canonical: &str) -> bool {
    // On Windows, `fs::canonicalize` returns verbatim paths such as
    // `\\?\C:\Windows\...`; unless that marker is stripped, the drive-letter
    // entries can never match.
    let canonical = canonical.strip_prefix(r"\\?\").unwrap_or(canonical);
    BLOCKED_PREFIXES.iter().any(|blocked| {
        let bare = blocked.trim_end_matches(|c: char| c == '/' || c == '\\');
        if blocked.contains('\\') {
            // Windows-style entry: Windows paths are normally case-insensitive,
            // so compare ASCII case-insensitively.
            let prefix_matches = canonical.len() >= blocked.len()
                && canonical.as_bytes()[..blocked.len()].eq_ignore_ascii_case(blocked.as_bytes());
            prefix_matches || canonical.eq_ignore_ascii_case(bare)
        } else {
            canonical.starts_with(blocked) || canonical == bare
        }
    })
}
/// Like [`validate_file_path`], but additionally requires the resolved file to live
/// under `project_root`.
///
/// Both paths are canonicalized first, so symlinks that escape the project are
/// caught. Containment is checked component-wise via [`Path::starts_with`].
///
/// # Errors
///
/// - Any error from [`validate_file_path`] (checked first).
/// - [`ContractsError::FileNotFound`] carrying `project_root` if the root cannot be
///   canonicalized.
/// - [`ContractsError::PathTraversal`] carrying `path` if the file lies outside the root.
pub fn validate_file_path_in_project(path: &Path, project_root: &Path) -> ContractsResult<PathBuf> {
    let resolved = validate_file_path(path)?;
    let root = match fs::canonicalize(project_root) {
        Ok(root) => root,
        Err(_) => {
            return Err(ContractsError::FileNotFound {
                path: project_root.to_path_buf(),
            })
        }
    };
    if resolved.starts_with(&root) {
        Ok(resolved)
    } else {
        Err(ContractsError::PathTraversal {
            path: path.to_path_buf(),
        })
    }
}
/// Reports whether `path` *textually* looks like a traversal attempt.
///
/// This is a purely lexical check on the lossy string form of the path and never
/// touches the filesystem. It flags any occurrence of `..`, including inside a
/// file name such as `a..b`, and any embedded NUL character.
pub fn has_path_traversal_pattern(path: &Path) -> bool {
    let text = path.to_string_lossy();
    let has_parent_ref = text.contains("..");
    let has_nul = text.contains('\0');
    has_parent_ref || has_nul
}
/// Checks that `start..=end` is a valid 1-based line range inside `1..=max`.
///
/// # Errors
///
/// Returns [`ContractsError::LineOutsideFunction`] for the first violated rule. The
/// error always has `function = "unknown"` and `start = 1`. The rules are checked
/// in this order:
/// 1. `start == 0` — reports `line = start`, `end = max`;
/// 2. `end == 0` — reports `line = end`, `end = max`;
/// 3. `start > end` — reports `line = start`, `end = end`;
/// 4. `start > max` — reports `line = start`, `end = max`;
/// 5. `end > max` — reports `line = end`, `end = max`.
pub fn validate_line_numbers(start: u32, end: u32, max: u32) -> ContractsResult<()> {
    // Each rule yields (offending line, upper bound to report).
    let violation = if start == 0 {
        Some((start, max))
    } else if end == 0 {
        Some((end, max))
    } else if start > end {
        Some((start, end))
    } else if start > max {
        Some((start, max))
    } else if end > max {
        Some((end, max))
    } else {
        None
    };

    match violation {
        None => Ok(()),
        Some((line, upper)) => Err(ContractsError::LineOutsideFunction {
            line,
            function: "unknown".to_string(),
            start: 1,
            end: upper,
        }),
    }
}
/// Validates a caller-supplied function name.
///
/// # Errors
///
/// Returns [`ContractsError::InvalidFunctionName`] when the name:
/// - is empty;
/// - is longer than [`MAX_FUNCTION_NAME_LEN`] bytes (UTF-8 encoded length);
/// - contains one of ``; ( ) { } [ ] ` " ' \ /`` or any control character
///   (NUL, newline, tab, ...). Control characters would otherwise let a name
///   inject line breaks into logs and reports;
/// - does not start with an alphabetic character or `_`.
pub fn validate_function_name(name: &str) -> ContractsResult<()> {
    const SUSPICIOUS_CHARS: [char; 13] = [
        ';', '(', ')', '{', '}', '[', ']', '`', '"', '\'', '\\', '/', '\0',
    ];

    if name.is_empty() {
        return Err(ContractsError::InvalidFunctionName {
            reason: "function name cannot be empty".to_string(),
        });
    }
    // `str::len` counts bytes, so the message says "bytes" (non-ASCII names
    // would otherwise report a misleading "chars" count).
    if name.len() > MAX_FUNCTION_NAME_LEN {
        return Err(ContractsError::InvalidFunctionName {
            reason: format!(
                "function name too long ({} bytes, max {})",
                name.len(),
                MAX_FUNCTION_NAME_LEN
            ),
        });
    }
    if let Some(bad) = name
        .chars()
        .find(|c| SUSPICIOUS_CHARS.contains(c) || c.is_control())
    {
        // `escape_debug` keeps NUL, newlines etc. from being echoed raw into the message.
        return Err(ContractsError::InvalidFunctionName {
            reason: format!(
                "function name contains invalid character: '{}'",
                bad.escape_debug()
            ),
        });
    }
    // `name` is non-empty here, so this inspects its first character.
    if !name.starts_with(|c: char| c.is_alphabetic() || c == '_') {
        return Err(ContractsError::InvalidFunctionName {
            reason: "function name must start with letter or underscore".to_string(),
        });
    }
    Ok(())
}
pub fn read_file_safe(path: &Path) -> ContractsResult<String> {
let canonical = validate_file_path(path)?;
let metadata = fs::metadata(&canonical)?;
let size = metadata.len();
if size > MAX_FILE_SIZE {
return Err(ContractsError::FileTooLarge {
path: path.to_path_buf(),
bytes: size,
max_bytes: MAX_FILE_SIZE,
});
}
let content = fs::read(&canonical)?;
String::from_utf8(content).map_err(|_| ContractsError::ParseError {
file: path.to_path_buf(),
message: "file is not valid UTF-8".to_string(),
})
}
pub fn read_file_safe_with_warning<F>(path: &Path, warn_fn: Option<F>) -> ContractsResult<String>
where
F: FnOnce(&str),
{
let canonical = validate_file_path(path)?;
let metadata = fs::metadata(&canonical)?;
let size = metadata.len();
if size > MAX_FILE_SIZE {
return Err(ContractsError::FileTooLarge {
path: path.to_path_buf(),
bytes: size,
max_bytes: MAX_FILE_SIZE,
});
}
if size > WARN_FILE_SIZE {
let warning = format!(
"Warning: {} is large ({:.1} MB), analysis may be slow",
path.display(),
size as f64 / 1024.0 / 1024.0
);
if let Some(f) = warn_fn {
f(&warning);
} else {
eprintln!("{}", warning);
}
}
let content = fs::read(&canonical)?;
String::from_utf8(content).map_err(|_| ContractsError::ParseError {
file: path.to_path_buf(),
message: "file is not valid UTF-8".to_string(),
})
}
/// Fails once a traversal has reached `max_depth`.
///
/// The bound is inclusive on the failing side: `current_depth == max_depth` is
/// already an error.
///
/// # Errors
///
/// Returns [`ContractsError::SliceDepthExceeded`] when `current_depth >= max_depth`.
/// A limit that does not fit in a `u32` is reported as `u32::MAX` instead of wrapping.
pub fn check_depth_limit(current_depth: usize, max_depth: usize) -> ContractsResult<()> {
    if current_depth < max_depth {
        return Ok(());
    }
    Err(ContractsError::SliceDepthExceeded {
        // A bare `as u32` would wrap (e.g. 1 << 32 -> 0); clamp first so it saturates.
        max_depth: max_depth.min(u32::MAX as usize) as u32,
    })
}
/// Fails when an SSA graph has more than [`MAX_SSA_NODES`] nodes.
///
/// Exactly `MAX_SSA_NODES` nodes is still accepted.
///
/// # Errors
///
/// Returns [`ContractsError::SsaTooLarge`] when `node_count > MAX_SSA_NODES`.
/// Counts that do not fit in a `u32` are reported as `u32::MAX` rather than being
/// silently truncated.
pub fn check_ssa_node_limit(node_count: usize) -> ContractsResult<()> {
    if node_count <= MAX_SSA_NODES {
        return Ok(());
    }
    // A bare `as u32` would wrap large counts (e.g. 2^32 + 5 -> 5) and produce a
    // misleading error; clamp first so the conversion saturates.
    let saturate = |n: usize| n.min(u32::MAX as usize) as u32;
    Err(ContractsError::SsaTooLarge {
        nodes: saturate(node_count),
        max_nodes: saturate(MAX_SSA_NODES),
    })
}
/// Fails when an AST in `file` is nested deeper than [`MAX_AST_DEPTH`].
///
/// A depth of exactly `MAX_AST_DEPTH` is still accepted.
///
/// # Errors
///
/// Returns [`ContractsError::AstTooDeep`] when `depth > MAX_AST_DEPTH`. A depth that
/// does not fit in a `u32` is reported as `u32::MAX` instead of wrapping.
pub fn check_ast_depth(depth: usize, file: &Path) -> ContractsResult<()> {
    if depth <= MAX_AST_DEPTH {
        return Ok(());
    }
    // Clamp before casting so oversized values saturate rather than wrap.
    let saturate = |n: usize| n.min(u32::MAX as usize) as u32;
    Err(ContractsError::AstTooDeep {
        file: file.to_path_buf(),
        depth: saturate(depth),
        max_depth: saturate(MAX_AST_DEPTH),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::{tempdir, NamedTempFile};

    #[test]
    fn test_validate_file_path_normal() {
        let file = NamedTempFile::new().unwrap();
        let path = file.path();
        let result = validate_file_path(path);
        assert!(result.is_ok());
        let canonical = result.unwrap();
        assert!(canonical.is_absolute());
    }

    #[test]
    fn test_validate_file_path_not_exists() {
        let result = validate_file_path(Path::new("/nonexistent/file.py"));
        assert!(result.is_err());
        match result.unwrap_err() {
            ContractsError::FileNotFound { path } => {
                assert!(path.to_string_lossy().contains("nonexistent"));
            }
            _ => panic!("Expected FileNotFound error"),
        }
    }

    #[test]
    fn test_validate_file_path_traversal_rejected() {
        // The raw `..` pattern is caught lexically by `has_path_traversal_pattern`;
        // `validate_file_path` itself canonicalizes such paths.
        let temp = tempdir().unwrap();
        let subdir = temp.path().join("subdir");
        fs::create_dir(&subdir).unwrap();
        let file_path = temp.path().join("secret.txt");
        fs::write(&file_path, "secret").unwrap();
        let suspicious = subdir.join("..").join("secret.txt");
        assert!(has_path_traversal_pattern(&suspicious));
    }

    // Symlink creation is only exercised on Unix.
    #[cfg(unix)]
    #[test]
    fn test_validate_file_path_symlink_outside_project() {
        let project = tempdir().unwrap();
        let outside = tempdir().unwrap();
        let outside_file = outside.path().join("secret.txt");
        fs::write(&outside_file, "secret").unwrap();
        let symlink_path = project.path().join("link.txt");
        std::os::unix::fs::symlink(&outside_file, &symlink_path).unwrap();
        let result = validate_file_path_in_project(&symlink_path, project.path());
        assert!(result.is_err());
        match result.unwrap_err() {
            ContractsError::PathTraversal { .. } => {}
            e => panic!("Expected PathTraversal error, got {:?}", e),
        }
    }

    #[test]
    fn test_validate_file_path_system_dir_rejected() {
        let blocked = [
            "/etc/passwd",
            "/root/.bashrc",
            "/sys/kernel/config",
            "/proc/self/status",
        ];
        for path_str in blocked {
            let path = Path::new(path_str);
            if path.exists() {
                let result = validate_file_path(path);
                assert!(result.is_err(), "Should reject system path: {}", path_str);
            }
        }
    }

    #[test]
    fn test_validate_line_numbers_valid_range() {
        assert!(validate_line_numbers(1, 10, 100).is_ok());
        assert!(validate_line_numbers(1, 1, 100).is_ok());
        assert!(validate_line_numbers(50, 100, 100).is_ok());
    }

    #[test]
    fn test_validate_line_numbers_start_after_end() {
        let result = validate_line_numbers(10, 5, 100);
        assert!(result.is_err());
        match result.unwrap_err() {
            ContractsError::LineOutsideFunction { line, .. } => {
                assert_eq!(line, 10);
            }
            _ => panic!("Expected LineOutsideFunction error"),
        }
    }

    #[test]
    fn test_validate_line_numbers_exceeds_max() {
        let result = validate_line_numbers(1, 200, 100);
        assert!(result.is_err());
        match result.unwrap_err() {
            ContractsError::LineOutsideFunction { line, .. } => {
                assert_eq!(line, 200);
            }
            _ => panic!("Expected LineOutsideFunction error"),
        }
    }

    #[test]
    fn test_validate_line_numbers_zero() {
        assert!(validate_line_numbers(0, 10, 100).is_err());
        assert!(validate_line_numbers(1, 0, 100).is_err());
    }

    #[test]
    fn test_validate_function_name_valid() {
        assert!(validate_function_name("my_function").is_ok());
        assert!(validate_function_name("_private").is_ok());
        assert!(validate_function_name("CamelCase").is_ok());
        assert!(validate_function_name("func123").is_ok());
        assert!(validate_function_name("__dunder__").is_ok());
    }

    #[test]
    fn test_validate_function_name_empty() {
        let result = validate_function_name("");
        assert!(result.is_err());
        match result.unwrap_err() {
            ContractsError::InvalidFunctionName { reason } => {
                assert!(reason.contains("empty"));
            }
            _ => panic!("Expected InvalidFunctionName error"),
        }
    }

    #[test]
    fn test_validate_function_name_invalid_chars() {
        let invalid_names = [
            "func;drop",
            "func()",
            "func{}",
            "func`cmd`",
            "func\"name",
            "func\\name",
            "func/name",
        ];
        for name in invalid_names {
            let result = validate_function_name(name);
            assert!(result.is_err(), "Should reject: {}", name);
        }
    }

    #[test]
    fn test_validate_function_name_starts_with_digit() {
        let result = validate_function_name("123func");
        assert!(result.is_err());
        match result.unwrap_err() {
            ContractsError::InvalidFunctionName { reason } => {
                assert!(reason.contains("start with"));
            }
            _ => panic!("Expected InvalidFunctionName error"),
        }
    }

    #[test]
    fn test_validate_function_name_too_long() {
        let long_name = "a".repeat(MAX_FUNCTION_NAME_LEN + 1);
        let result = validate_function_name(&long_name);
        assert!(result.is_err());
        match result.unwrap_err() {
            ContractsError::InvalidFunctionName { reason } => {
                assert!(reason.contains("too long"));
            }
            _ => panic!("Expected InvalidFunctionName error"),
        }
    }

    #[test]
    fn test_read_file_safe_normal() {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, "def hello():\n    print('hello')").unwrap();
        let content = read_file_safe(file.path()).unwrap();
        assert!(content.contains("def hello"));
        assert!(content.contains("print"));
    }

    #[test]
    fn test_read_file_safe_not_exists() {
        let result = read_file_safe(Path::new("/nonexistent/file.py"));
        assert!(result.is_err());
        match result.unwrap_err() {
            ContractsError::FileNotFound { .. } => {}
            e => panic!("Expected FileNotFound error, got {:?}", e),
        }
    }

    #[test]
    fn test_read_file_safe_too_large() {
        let temp = tempdir().unwrap();
        let large_file = temp.path().join("large.txt");
        // `set_len` typically produces a sparse file, so this stays cheap.
        fs::File::create(&large_file)
            .unwrap()
            .set_len(MAX_FILE_SIZE + 1)
            .unwrap();
        match read_file_safe(&large_file).unwrap_err() {
            ContractsError::FileTooLarge {
                bytes, max_bytes, ..
            } => {
                assert_eq!(bytes, MAX_FILE_SIZE + 1);
                assert_eq!(max_bytes, MAX_FILE_SIZE);
            }
            e => panic!("Expected FileTooLarge error, got {:?}", e),
        }
    }

    #[test]
    fn test_read_file_safe_with_warning_reports_large_file() {
        let temp = tempdir().unwrap();
        let big_file = temp.path().join("big.txt");
        // A zero-filled file is valid UTF-8; `set_len` avoids writing it byte by byte.
        fs::File::create(&big_file)
            .unwrap()
            .set_len(WARN_FILE_SIZE + 1)
            .unwrap();
        let mut warning = None;
        let content = read_file_safe_with_warning(
            &big_file,
            Some(|msg: &str| warning = Some(msg.to_string())),
        )
        .unwrap();
        assert_eq!(content.len() as u64, WARN_FILE_SIZE + 1);
        let warning = warning.expect("warning callback should fire above WARN_FILE_SIZE");
        assert!(warning.contains("is large"));
    }

    #[test]
    fn test_read_file_safe_with_warning_quiet_for_small_file() {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, "x = 1").unwrap();
        let mut warned = false;
        let content =
            read_file_safe_with_warning(file.path(), Some(|_: &str| warned = true)).unwrap();
        assert!(content.contains("x = 1"));
        assert!(!warned);
    }

    #[test]
    fn test_read_file_safe_not_utf8() {
        let temp = tempdir().unwrap();
        let binary_file = temp.path().join("binary.bin");
        let invalid_utf8 = vec![0xFF, 0xFE, 0x00, 0x01];
        fs::write(&binary_file, invalid_utf8).unwrap();
        let result = read_file_safe(&binary_file);
        assert!(result.is_err());
        match result.unwrap_err() {
            ContractsError::ParseError { message, .. } => {
                assert!(message.contains("UTF-8"));
            }
            e => panic!("Expected ParseError, got {:?}", e),
        }
    }

    #[test]
    fn test_resource_limits_constants() {
        assert_eq!(MAX_FILE_SIZE, 10 * 1024 * 1024);
        assert_eq!(MAX_CFG_DEPTH, 1000);
        assert_eq!(MAX_SSA_NODES, 100_000);
        assert_eq!(MAX_AST_DEPTH, 100);
    }

    #[test]
    fn test_check_depth_limit() {
        assert!(check_depth_limit(0, 1000).is_ok());
        assert!(check_depth_limit(999, 1000).is_ok());
        assert!(check_depth_limit(1000, 1000).is_err());
        assert!(check_depth_limit(1001, 1000).is_err());
    }

    #[test]
    fn test_check_ssa_node_limit() {
        assert!(check_ssa_node_limit(0).is_ok());
        assert!(check_ssa_node_limit(MAX_SSA_NODES).is_ok());
        assert!(check_ssa_node_limit(MAX_SSA_NODES + 1).is_err());
    }

    #[test]
    fn test_check_ast_depth() {
        let file = Path::new("test.py");
        assert!(check_ast_depth(0, file).is_ok());
        assert!(check_ast_depth(MAX_AST_DEPTH, file).is_ok());
        assert!(check_ast_depth(MAX_AST_DEPTH + 1, file).is_err());
    }

    #[test]
    fn test_has_path_traversal_pattern() {
        assert!(has_path_traversal_pattern(Path::new("../etc/passwd")));
        assert!(has_path_traversal_pattern(Path::new("foo/../bar")));
        assert!(has_path_traversal_pattern(Path::new("..\\Windows\\System32")));
        assert!(!has_path_traversal_pattern(Path::new("src/main.rs")));
        assert!(!has_path_traversal_pattern(Path::new("/home/user/project")));
        assert!(!has_path_traversal_pattern(Path::new(".")));
    }
}