use std::path::{Component, Path, PathBuf};
use std::sync::LazyLock;
use std::{collections::HashSet, ffi::OsStr};
use tracing::warn;
/// Reports whether `path` exists, after first screening it with `is_safe_path`.
///
/// Unsafe paths are logged with `warn!` and reported as non-existent. For safe paths the
/// result is whether `tokio::fs::metadata` succeeds; any error (missing file, permission
/// problem, ...) yields `false`.
pub async fn file_exists(path: &Path) -> bool {
    if is_safe_path(path) {
        return tokio::fs::metadata(path).await.is_ok();
    }
    warn!("Unsafe path access attempt: {:?}", path);
    false
}
/// Returns the extension of `path` as `&str`.
///
/// Returns `None` when the file name has no extension (including dot-files such as
/// `.gitignore`) or when the extension is not valid UTF-8.
pub fn get_file_extension(path: &Path) -> Option<&str> {
    let raw_ext = path.extension()?;
    raw_ext.to_str()
}
/// Returns `true` when the extension of `path` is `md` or `markdown`, ignoring case.
///
/// Paths without a UTF-8 extension are never considered markdown.
pub fn is_markdown_file(path: &Path) -> bool {
    path.extension()
        .and_then(OsStr::to_str)
        .is_some_and(|ext| matches!(ext.to_lowercase().as_str(), "md" | "markdown"))
}
/// Returns `true` when the extension of `path` is one of `jpg`, `jpeg`, `png`, `gif`,
/// `webp` or `bmp`, ignoring case.
///
/// Paths without a UTF-8 extension are never considered images.
pub fn is_image_file(path: &Path) -> bool {
    path.extension()
        .and_then(OsStr::to_str)
        .is_some_and(|ext| {
            let lowered = ext.to_lowercase();
            matches!(
                lowered.as_str(),
                "jpg" | "jpeg" | "png" | "gif" | "webp" | "bmp"
            )
        })
}
/// Performs shape checks on an app ID / app secret pair.
///
/// Lengths are measured in bytes (`str::len`).
///
/// # Errors
/// Returns a message for the first failing check, evaluated in this order:
/// 1. `app_id` is empty;
/// 2. `app_secret` is empty;
/// 3. `app_id` does not start with `wx` or is not 18 bytes long;
/// 4. `app_secret` is not 32 bytes long.
pub fn validate_app_credentials(app_id: &str, app_secret: &str) -> Result<(), String> {
    let failure = if app_id.is_empty() {
        Some("App ID cannot be empty")
    } else if app_secret.is_empty() {
        Some("App secret cannot be empty")
    } else if !(app_id.starts_with("wx") && app_id.len() == 18) {
        Some("Invalid app ID format (should start with 'wx' and be 18 characters)")
    } else if app_secret.len() != 32 {
        Some("Invalid app secret format (should be 32 characters)")
    } else {
        None
    };

    match failure {
        Some(message) => Err(message.to_string()),
        None => Ok(()),
    }
}
/// Lower-case file extensions (without the leading dot) that `is_safe_path` rejects.
///
/// Callers compare against this set after lower-casing the extension, so entries must be
/// lower-case.
static DANGEROUS_EXTENSIONS: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {
    HashSet::from([
        // executables and command scripts
        "exe", "bat", "cmd", "com", "scr", "pif",
        // script-host files
        "vbs", "js", "jse", "wsf", "wsh",
        // installers, libraries, shell/registry helpers
        "msi", "dll", "scf", "lnk", "inf", "reg",
    ])
});
/// Returns `true` when `path` passes this module's safety heuristics.
///
/// A path is rejected when:
/// - its extension, lower-cased, is listed in `DANGEROUS_EXTENSIONS`;
/// - any normal component is a hidden name (leading `.`, longer than one byte), other than
///   `.gitignore`, `.env`, `.dockerignore`, or a name starting with `.tmp`;
/// - any normal component contains `\0` or `\x01`;
/// - any normal component is a Windows reserved device name (see `is_reserved_name`).
///
/// Paths inside a well-known temporary directory are subject only to the extension check.
/// This covers `/tmp/`, `/var/folders/` and `/private/var/folders/` as prefixes, and any
/// path containing `\Temp\`.
///
/// `..` components are *not* rejected by this function. Callers that must confine a path to
/// a base directory need an additional check such as `resolve_path`.
pub fn is_safe_path(path: &Path) -> bool {
    if has_dangerous_extension(path) {
        return false;
    }
    if is_temp_location(path) {
        return true;
    }
    path.components().all(|component| match component {
        Component::Normal(name) => is_safe_component(&name.to_string_lossy()),
        // `..` is deliberately tolerated here (see doc comment); the rest carry no file name.
        Component::ParentDir | Component::RootDir | Component::CurDir | Component::Prefix(_) => {
            true
        }
    })
}

/// Returns `true` when the (lower-cased, UTF-8) extension of `path` is in `DANGEROUS_EXTENSIONS`.
fn has_dangerous_extension(path: &Path) -> bool {
    path.extension()
        .and_then(OsStr::to_str)
        .is_some_and(|ext| DANGEROUS_EXTENSIONS.contains(&ext.to_lowercase().as_str()))
}

/// Returns `true` when `path` lies in a well-known temp directory and cannot climb out of it.
fn is_temp_location(path: &Path) -> bool {
    // `/tmp/../home/user/.ssh/id_rsa` starts with `/tmp/` but escapes it, so any `..`
    // disqualifies the path from the relaxed temp-dir treatment.
    if path.components().any(|c| matches!(c, Component::ParentDir)) {
        return false;
    }
    let Some(path_str) = path.to_str() else {
        return false;
    };
    // Anchor the Unix roots at the start. A substring match would let any directory that
    // happens to be named `tmp` bypass the component checks. `/private/var/folders/` is
    // the canonicalized form of `/var/folders/` on macOS.
    const UNIX_TEMP_ROOTS: [&str; 3] = ["/tmp/", "/var/folders/", "/private/var/folders/"];
    // NOTE(review): the Windows marker is still matched anywhere in the path; consider
    // anchoring it to `std::env::temp_dir()` instead.
    UNIX_TEMP_ROOTS.iter().any(|root| path_str.starts_with(root)) || path_str.contains("\\Temp\\")
}

/// Applies the per-component rules (hidden names, control bytes, reserved device names).
fn is_safe_component(name: &str) -> bool {
    let is_hidden = name.starts_with('.') && name.len() > 1;
    if is_hidden
        && !matches!(name, ".gitignore" | ".env" | ".dockerignore")
        && !name.starts_with(".tmp")
    {
        return false;
    }
    if name.contains('\0') || name.contains('\x01') {
        return false;
    }
    !is_reserved_name(name)
}
/// Returns `true` when the portion of `name` before its first `.` is a Windows reserved device
/// name: `CON`, `PRN`, `AUX`, `NUL`, `COM1`–`COM9` or `LPT1`–`LPT9`.
///
/// The comparison is made after `to_uppercase`, so it is case-insensitive. Any extension is
/// ignored, so `prn.txt` counts as reserved.
fn is_reserved_name(name: &str) -> bool {
    let upper = name.to_uppercase();
    let stem = upper
        .split_once('.')
        .map_or(upper.as_str(), |(head, _tail)| head);

    if matches!(stem, "CON" | "PRN" | "AUX" | "NUL") {
        return true;
    }

    // Serial and parallel ports: a `COM`/`LPT` prefix followed by exactly one digit 1-9.
    let port_suffix = stem
        .strip_prefix("COM")
        .or_else(|| stem.strip_prefix("LPT"));
    matches!(
        port_suffix,
        Some("1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9")
    )
}
/// Heuristically reports whether the string `path` contains a `..` traversal sequence.
///
/// This is plain substring matching on both `/` and `\` separators. It is therefore
/// conservative: names such as `dir/..hidden` or `....` are also flagged.
pub fn has_path_traversal(path: &str) -> bool {
    const MARKERS: [&str; 5] = ["../", "..\\", "/..", "\\..", "...."];
    path == ".." || MARKERS.iter().any(|marker| path.contains(*marker))
}
/// Turns an arbitrary string into a single, portable file-name component.
///
/// - Removes `<`, `>`, `:`, `"`, `|`, `?`, `*` and ASCII control characters (`\0`–`\x1F`).
/// - Replaces path separators (`/`, `\`) with `_`, so the result cannot name a sub-path.
/// - Replaces a leading `.` with `_`, so the result is never a hidden name and never `.`.
/// - Returns `"unnamed"` if nothing is left.
/// - Caps the result at 255 bytes. Longer names are cut back to at most 252 bytes, always on
///   a UTF-8 character boundary, and then suffixed with `...`.
pub fn sanitize_filename(filename: &str) -> String {
    const MAX_LEN: usize = 255;
    const ELLIPSIS: &str = "...";

    let mut sanitized: String = filename
        .chars()
        .filter(|&c| !matches!(c, '<' | '>' | ':' | '"' | '|' | '?' | '*' | '\0'..='\x1F'))
        .map(|c| if matches!(c, '/' | '\\') { '_' } else { c })
        .collect();

    // A leading dot hides the file. A bare "." names the containing directory itself, so it
    // is rewritten too ('.' is one byte, so the range 0..1 is a valid char range).
    if sanitized.starts_with('.') {
        sanitized.replace_range(..1, "_");
    }
    if sanitized.is_empty() {
        sanitized = "unnamed".to_string();
    }
    if sanitized.len() > MAX_LEN {
        // `String::truncate` panics when the cut is not on a char boundary (multi-byte UTF-8),
        // so back up to the nearest boundary at or below the limit.
        let mut cut = MAX_LEN - ELLIPSIS.len();
        while !sanitized.is_char_boundary(cut) {
            cut -= 1;
        }
        sanitized.truncate(cut);
        sanitized.push_str(ELLIPSIS);
    }
    sanitized
}
/// Checks that `size` does not exceed `max_size`. A size equal to the limit is accepted.
///
/// # Errors
/// When `size > max_size`, returns a message naming `file_type` and both sizes.
pub fn validate_file_size(size: u64, max_size: u64, file_type: &str) -> Result<(), String> {
    if size <= max_size {
        Ok(())
    } else {
        Err(format!(
            "{file_type} file too large: {size} bytes (max: {max_size} bytes)"
        ))
    }
}
/// Returns the directory portion (the parent) of `file_path`.
///
/// Returns `None` for a root path or an empty path. A bare file name yields `Some("")`.
pub fn get_base_directory(file_path: &Path) -> Option<&Path> {
    Path::parent(file_path)
}
/// Resolves `relative_path` against `base_dir`, rejecting unsafe or escaping inputs.
///
/// - An absolute `relative_path` is returned unchanged if `is_safe_path` accepts it. It is
///   *not* required to lie under `base_dir`.
/// - Otherwise the path is joined onto `base_dir` and must pass `is_safe_path`.
/// - If both the joined path and `base_dir` can be canonicalized, the canonical joined path
///   must start with the canonical base. `canonicalize` fails for paths that do not exist.
/// - If either canonicalization fails, the textual `has_path_traversal` check on
///   `relative_path` decides instead.
///
/// On success the joined (non-canonical) path is returned.
///
/// # Errors
/// Returns a message describing which check rejected the input.
pub fn resolve_path(base_dir: &Path, relative_path: &str) -> Result<PathBuf, String> {
    let candidate = Path::new(relative_path);
    if candidate.is_absolute() {
        return if is_safe_path(candidate) {
            Ok(candidate.to_path_buf())
        } else {
            Err("Absolute path contains unsafe components".to_string())
        };
    }

    let joined = base_dir.join(relative_path);
    if !is_safe_path(&joined) {
        return Err("Resolved path contains unsafe components".to_string());
    }

    // `base_dir` is only canonicalized when the joined path itself could be.
    let canonical_pair = joined
        .canonicalize()
        .and_then(|canon_joined| base_dir.canonicalize().map(|canon_base| (canon_joined, canon_base)));

    match canonical_pair {
        Ok((canon_joined, canon_base)) if !canon_joined.starts_with(&canon_base) => {
            Err("Path traversal attempt detected".to_string())
        }
        Ok(_) => Ok(joined),
        Err(_) if has_path_traversal(relative_path) => {
            Err("Path contains traversal sequences".to_string())
        }
        Err(_) => Ok(joined),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// A secret that satisfies the 32-character length rule.
    const VALID_SECRET: &str = "12345678901234567890123456789012";

    #[test]
    fn test_get_file_extension() {
        let cases: [(&str, Option<&str>); 5] = [
            ("test.md", Some("md")),
            ("test.markdown", Some("markdown")),
            ("test.jpg", Some("jpg")),
            ("test", None),
            // A leading dot marks a hidden file name, not an extension.
            (".gitignore", None),
        ];
        for (input, expected) in cases {
            assert_eq!(get_file_extension(Path::new(input)), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_is_markdown_file() {
        for input in ["test.md", "test.markdown", "TEST.MD"] {
            assert!(is_markdown_file(Path::new(input)), "expected markdown: {:?}", input);
        }
        for input in ["test.txt", "test"] {
            assert!(!is_markdown_file(Path::new(input)), "expected non-markdown: {:?}", input);
        }
    }

    #[test]
    fn test_is_image_file() {
        for input in ["test.jpg", "test.PNG", "test.gif"] {
            assert!(is_image_file(Path::new(input)), "expected image: {:?}", input);
        }
        for input in ["test.txt", "test"] {
            assert!(!is_image_file(Path::new(input)), "expected non-image: {:?}", input);
        }
    }

    #[test]
    fn test_validate_app_credentials() {
        assert!(validate_app_credentials("wx1234567890123456", VALID_SECRET).is_ok());
        let rejected = [
            ("", VALID_SECRET),
            ("invalid", VALID_SECRET),
            ("wx123", VALID_SECRET),
            ("wx1234567890123456", ""),
            ("wx1234567890123456", "short"),
        ];
        for (app_id, app_secret) in rejected {
            assert!(
                validate_app_credentials(app_id, app_secret).is_err(),
                "expected rejection of ({:?}, {:?})",
                app_id,
                app_secret
            );
        }
    }

    #[test]
    fn test_get_base_directory() {
        let cases = [
            ("/path/to/file.md", Some(Path::new("/path/to"))),
            ("file.md", Some(Path::new(""))),
            ("/", None),
        ];
        for (input, expected) in cases {
            assert_eq!(get_base_directory(Path::new(input)), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_resolve_path() {
        let base = Path::new("/base/dir");
        let accepted = [
            ("relative.md", "/base/dir/relative.md"),
            ("/absolute.md", "/absolute.md"),
            ("./relative.md", "/base/dir/./relative.md"),
        ];
        for (input, expected) in accepted {
            assert_eq!(
                resolve_path(base, input).unwrap(),
                PathBuf::from(expected),
                "input: {:?}",
                input
            );
        }
        for input in [
            "../../../etc/passwd",
            "..\\..\\windows\\system32",
            "malware.exe",
            "script.bat",
        ] {
            assert!(resolve_path(base, input).is_err(), "expected rejection: {:?}", input);
        }
    }

    #[test]
    fn test_is_safe_path() {
        for input in ["document.md", "image.jpg", "folder/file.txt", ".gitignore", ".env"] {
            assert!(is_safe_path(Path::new(input)), "expected safe: {:?}", input);
        }
        for input in [
            "malware.exe",
            "script.bat",
            "virus.scr",
            "CON",
            "PRN.txt",
            "COM1.dat",
            ".hidden",
        ] {
            assert!(!is_safe_path(Path::new(input)), "expected unsafe: {:?}", input);
        }
    }

    #[test]
    fn test_has_path_traversal() {
        for input in [
            "../etc/passwd",
            "..\\windows\\system32",
            "folder/../../../etc",
            "..",
            "....",
        ] {
            assert!(has_path_traversal(input), "expected traversal: {:?}", input);
        }
        for input in ["normal/path/file.txt", "file.md", "folder/subfolder/file"] {
            assert!(!has_path_traversal(input), "expected no traversal: {:?}", input);
        }
    }

    #[test]
    fn test_sanitize_filename() {
        let cases = [
            ("normal_file.txt", "normal_file.txt"),
            ("file<>:\"|?*.txt", "file.txt"),
            ("path/to/file.txt", "path_to_file.txt"),
            ("path\\to\\file.txt", "path_to_file.txt"),
            (".hidden", "_hidden"),
            ("", "unnamed"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_filename(input), expected, "input: {:?}", input);
        }
        let truncated = sanitize_filename(&"a".repeat(300));
        assert!(truncated.len() <= 255);
        assert!(truncated.ends_with("..."));
    }

    #[test]
    fn test_validate_file_size() {
        assert!(validate_file_size(1000, 2000, "test").is_ok());
        let message = validate_file_size(3000, 2000, "image")
            .expect_err("3000 bytes exceeds the 2000-byte limit");
        assert!(message.contains("image file too large"));
    }
}