use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
/// Default value for `InputLimits::max_string_length`: 1024 * 1024 bytes (1 MiB).
pub const DEFAULT_MAX_STRING_LENGTH: usize = 1024 * 1024;
/// Default value for `InputLimits::max_param_name_length`, in bytes.
pub const DEFAULT_MAX_PARAM_NAME_LENGTH: usize = 256;
/// Default value for `InputLimits::max_uri_length`, in bytes.
pub const DEFAULT_MAX_URI_LENGTH: usize = 8192;
/// Default value for `InputLimits::max_params` (a count, not a byte size).
pub const DEFAULT_MAX_PARAMS: usize = 100;
/// Size and count limits used to validate incoming input.
///
/// The `check_*` methods on this type compare `str::len()` (a byte length, not
/// a character count) or a caller-supplied count against these fields.
#[derive(Debug, Clone, Copy)]
pub struct InputLimits {
/// Largest accepted string, in bytes.
pub max_string_length: usize,
/// Largest accepted parameter name, in bytes.
pub max_param_name_length: usize,
/// Largest accepted URI, in bytes.
pub max_uri_length: usize,
/// Largest accepted number of parameters.
pub max_params: usize,
}
impl Default for InputLimits {
fn default() -> Self {
Self {
max_string_length: DEFAULT_MAX_STRING_LENGTH,
max_param_name_length: DEFAULT_MAX_PARAM_NAME_LENGTH,
max_uri_length: DEFAULT_MAX_URI_LENGTH,
max_params: DEFAULT_MAX_PARAMS,
}
}
}
impl InputLimits {
#[must_use]
pub const fn production() -> Self {
Self {
max_string_length: 64 * 1024, max_param_name_length: 128,
max_uri_length: 2048,
max_params: 50,
}
}
#[must_use]
pub const fn development() -> Self {
Self {
max_string_length: 10 * 1024 * 1024, max_param_name_length: 512,
max_uri_length: 65536,
max_params: 1000,
}
}
pub fn check_string_length(&self, s: &str) -> Result<(), InputValidationError> {
if s.len() > self.max_string_length {
return Err(InputValidationError::StringTooLong {
actual: s.len(),
max: self.max_string_length,
});
}
Ok(())
}
pub fn check_param_name(&self, name: &str) -> Result<(), InputValidationError> {
if name.len() > self.max_param_name_length {
return Err(InputValidationError::ParamNameTooLong {
actual: name.len(),
max: self.max_param_name_length,
});
}
Ok(())
}
pub fn check_uri_length(&self, uri: &str) -> Result<(), InputValidationError> {
if uri.len() > self.max_uri_length {
return Err(InputValidationError::UriTooLong {
actual: uri.len(),
max: self.max_uri_length,
});
}
Ok(())
}
pub fn check_param_count(&self, count: usize) -> Result<(), InputValidationError> {
if count > self.max_params {
return Err(InputValidationError::TooManyParams {
actual: count,
max: self.max_params,
});
}
Ok(())
}
}
/// Reasons an input is rejected by the `InputLimits` checks or by
/// `check_uri_scheme_safety`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InputValidationError {
/// A string was longer than the configured maximum.
StringTooLong {
/// Observed length in bytes.
actual: usize,
/// Configured maximum in bytes.
max: usize,
},
/// A parameter name was longer than the configured maximum.
ParamNameTooLong {
/// Observed length in bytes.
actual: usize,
/// Configured maximum in bytes.
max: usize,
},
/// A URI was longer than the configured maximum.
UriTooLong {
/// Observed length in bytes.
actual: usize,
/// Configured maximum in bytes.
max: usize,
},
/// More parameters were supplied than the configured maximum.
TooManyParams {
/// Observed parameter count.
actual: usize,
/// Configured maximum count.
max: usize,
},
/// The URI's scheme is on the blocklist.
DangerousUriScheme {
/// The offending scheme, lowercased.
scheme: String,
},
}
impl core::fmt::Display for InputValidationError {
    /// Writes a one-line, human-readable description of the rejection,
    /// including the observed value and the limit where applicable.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::StringTooLong {
                actual: got,
                max: limit,
            } => f.write_fmt(format_args!(
                "String too long: {} bytes (max: {})",
                got, limit
            )),
            Self::ParamNameTooLong {
                actual: got,
                max: limit,
            } => f.write_fmt(format_args!(
                "Parameter name too long: {} bytes (max: {})",
                got, limit
            )),
            Self::UriTooLong {
                actual: got,
                max: limit,
            } => f.write_fmt(format_args!(
                "URI too long: {} bytes (max: {})",
                got, limit
            )),
            Self::TooManyParams {
                actual: got,
                max: limit,
            } => f.write_fmt(format_args!(
                "Too many parameters: {} (max: {})",
                got, limit
            )),
            Self::DangerousUriScheme { scheme: blocked } => f.write_fmt(format_args!(
                "URI scheme is blocked for security: {}",
                blocked
            )),
        }
    }
}
/// Lowercase URI schemes that `check_uri_scheme_safety` rejects.
pub const DANGEROUS_URI_SCHEMES: &[&str] = &["javascript", "vbscript"];
/// Extracts the scheme of `uri` and rejects it when it is on the
/// [`DANGEROUS_URI_SCHEMES`] blocklist.
///
/// The scheme is everything before the **first** `:` (RFC 3986 §3.1). Looking
/// for `://` first would mis-parse `javascript:location='http://x'` as having
/// the scheme `javascript:location='http` and let it through.
///
/// The comparison is ASCII case-insensitive. Spaces and ASCII control
/// characters are ignored for the blocklist comparison only, because URL
/// parsers used by browsers strip leading whitespace and embedded
/// tab/newline characters (so `" javascript:"` and `"java\tscript:"` still
/// execute).
///
/// # Returns
///
/// The ASCII-lowercased scheme, or an empty string when `uri` contains no `:`.
///
/// # Errors
///
/// Returns [`InputValidationError::DangerousUriScheme`] when the scheme is
/// blocklisted.
pub fn check_uri_scheme_safety(uri: &str) -> Result<String, InputValidationError> {
    // "://" always contains ':', so the first ':' is the only delimiter needed.
    let scheme = uri.split_once(':').map_or("", |(scheme, _rest)| scheme);
    let normalized = scheme.to_ascii_lowercase();

    // Form used only for the blocklist lookup; the Ok value stays untouched so
    // callers that relied on the previous return value see no difference.
    let comparable: String = normalized
        .chars()
        .filter(|c| !(c.is_ascii_control() || *c == ' '))
        .collect();

    if DANGEROUS_URI_SCHEMES.contains(&comparable.as_str()) {
        return Err(InputValidationError::DangerousUriScheme { scheme: comparable });
    }
    Ok(normalized)
}
#[must_use]
pub fn sanitize_error_message(message: &str) -> String {
let mut result = String::from(message);
result = sanitize_connection_strings(&result);
result = sanitize_urls_with_credentials(&result);
result = sanitize_secrets(&result);
result = sanitize_ip_addresses(&result);
result = sanitize_file_paths(&result);
result = sanitize_emails(&result);
result
}
/// Replaces database/broker connection strings with `[CONNECTION]`.
///
/// A connection string starts at one of the known scheme prefixes
/// (case-sensitive) and runs up to the next whitespace, `'` or `"`, or to the
/// end of the input.
fn sanitize_connection_strings(s: &str) -> String {
    const PLACEHOLDER: &str = "[CONNECTION]";
    const SCHEMES: [&str; 8] = [
        "postgres://",
        "postgresql://",
        "mysql://",
        "mongodb://",
        "redis://",
        "amqp://",
        "kafka://",
        "sqlite://",
    ];

    let mut text = String::from(s);
    for scheme in SCHEMES {
        // The placeholder contains no scheme prefix, so every replacement
        // removes one occurrence and this loop terminates.
        loop {
            let begin = match text.find(scheme) {
                Some(index) => index,
                None => break,
            };
            let stop = match text[begin..]
                .find(|ch: char| ch.is_whitespace() || matches!(ch, '\'' | '"'))
            {
                Some(offset) => begin + offset,
                None => text.len(),
            };
            text.replace_range(begin..stop, PLACEHOLDER);
        }
    }
    text
}
/// Replaces `http://`, `https://` and `ftp://` URLs that embed credentials
/// (`user:pass@host`) with `[URL]`.
///
/// A URL token runs from the scheme prefix up to the next whitespace, `'` or
/// `"`, or to the end of the input. It is redacted when an `@` appears in the
/// token before its first `/` (or anywhere in the token when it has no `/`).
///
/// Every occurrence is inspected: a credential-free URL no longer stops the
/// scan, and text after the URL token (for example a later e-mail address) is
/// no longer mistaken for part of the URL.
fn sanitize_urls_with_credentials(s: &str) -> String {
    const PLACEHOLDER: &str = "[URL]";

    let mut result = String::from(s);
    for prefix in ["http://", "https://", "ftp://"] {
        // Byte offset from which the next occurrence is searched; always on a
        // char boundary because it follows ASCII text.
        let mut cursor = 0;
        while let Some(found) = result[cursor..].find(prefix) {
            let start = cursor + found;
            let after_proto = start + prefix.len();
            let end = result[after_proto..]
                .find(|c: char| c.is_whitespace() || c == '\'' || c == '"')
                .map_or(result.len(), |i| after_proto + i);

            // Only look inside this URL token, never at the rest of the message.
            let url_rest = &result[after_proto..end];
            let authority_len = url_rest.find('/').unwrap_or(url_rest.len());
            let has_credentials = url_rest[..authority_len].contains('@');

            if has_credentials {
                result.replace_range(start..end, PLACEHOLDER);
                cursor = start + PLACEHOLDER.len();
            } else {
                // Skip this occurrence but keep scanning for later ones.
                cursor = after_proto;
            }
        }
    }
    result
}
/// Replaces the value following a known secret keyword with `[REDACTED]`.
///
/// Keywords are matched ASCII case-insensitively; the keyword is kept with its
/// original casing. The value runs up to the next whitespace, `,`, `;`, `'`,
/// `"` or `)`, or to the end of the input. A keyword at the very end of the
/// input (nothing after it) is left untouched.
///
/// Match offsets are recomputed from the current text for every keyword.
/// Reusing offsets taken from the original input would point at the wrong
/// bytes once an earlier replacement changed the length, leaking part of a
/// later secret. ASCII lowercasing is used because it keeps byte offsets and
/// char boundaries identical to the text being edited; full Unicode
/// lowercasing can change byte lengths.
fn sanitize_secrets(s: &str) -> String {
    // Matching is case-insensitive, so a separate lowercase "bearer " entry is
    // not needed.
    let patterns = [
        "api_key=",
        "api-key=",
        "apikey=",
        "password=",
        "passwd=",
        "token=",
        "secret=",
        "api_key:",
        "api-key:",
        "password:",
        "token:",
        "secret:",
        "Bearer ",
        "Authorization: ",
    ];

    let mut result = String::from(s);
    for pattern in patterns {
        let needle = pattern.to_ascii_lowercase();
        let haystack = result.to_ascii_lowercase();

        let mut positions: Vec<usize> = Vec::new();
        let mut search_start = 0;
        while let Some(pos) = haystack[search_start..].find(needle.as_str()) {
            let start = search_start + pos;
            positions.push(start);
            search_start = start + needle.len();
        }

        // Replace back-to-front so earlier offsets stay valid.
        for start in positions.into_iter().rev() {
            let prefix_end = start + pattern.len();
            if prefix_end >= result.len() {
                continue;
            }
            let end = result[prefix_end..]
                .find(|c: char| {
                    c.is_whitespace() || c == ',' || c == ';' || c == '\'' || c == '"' || c == ')'
                })
                .map_or(result.len(), |i| prefix_end + i);

            let keyword = &result[start..prefix_end];
            let replacement = if keyword.ends_with('=') || keyword.ends_with(':') {
                format!("{}[REDACTED]", keyword)
            } else {
                format!("{} [REDACTED]", keyword.trim())
            };
            result.replace_range(start..end, &replacement);
        }
    }
    result
}
/// Replaces dotted-quad IPv4 literals with `[IP]`.
///
/// Starting at each ASCII digit, the scan collects the following digits and up
/// to three dots as one candidate. The candidate is redacted only when it
/// consists of exactly four non-empty segments of at most three digits, each
/// parsing as a `u8`; otherwise it is copied through unchanged.
///
/// The whole digit run is consumed even when a segment is too long, so the
/// tail of an over-long number (e.g. `1234.5.6.7`) is not re-scanned and
/// misreported as an address.
fn sanitize_ip_addresses(s: &str) -> String {
    /// True when `candidate` (digits and dots only) is a dotted-quad whose
    /// four segments each fit in a `u8`.
    fn is_ipv4_literal(candidate: &str) -> bool {
        let mut octets = 0usize;
        for segment in candidate.split('.') {
            octets += 1;
            if octets > 4
                || segment.is_empty()
                || segment.len() > 3
                || segment.parse::<u8>().is_err()
            {
                return false;
            }
        }
        octets == 4
    }

    let mut result = String::with_capacity(s.len());
    let mut chars = s.chars().peekable();
    while let Some(c) = chars.next() {
        if !c.is_ascii_digit() {
            result.push(c);
            continue;
        }

        let mut candidate = String::from(c);
        let mut dot_count = 0;
        while let Some(&next) = chars.peek() {
            if next == '.' && dot_count < 3 {
                dot_count += 1;
            } else if !next.is_ascii_digit() {
                // Also reached for a fourth '.', which ends the candidate.
                break;
            }
            candidate.push(next);
            chars.next();
        }

        if is_ipv4_literal(&candidate) {
            result.push_str("[IP]");
        } else {
            result.push_str(&candidate);
        }
    }
    result
}
/// Replaces Unix-style and Windows-style file paths with `[PATH]`.
///
/// Pass 1 (Unix): a `/` that is at the start of the input or not preceded by
/// an ASCII alphanumeric byte, and is followed by an alphanumeric character or
/// `.`, starts a path that runs up to the next whitespace, `'`, `"`, `)` or
/// `]`, or to the end of the input.
///
/// Pass 2 (Windows): an ASCII letter followed by `:\` starts a path that runs
/// up to the next whitespace, `'` or `"`, or to the end of the input.
fn sanitize_file_paths(s: &str) -> String {
    const PLACEHOLDER: &str = "[PATH]";
    let mut out = String::from(s);

    // Pass 1: Unix-style paths.
    let mut idx = 0;
    while idx < out.len() {
        let raw = out.as_bytes();
        let at_slash = raw[idx] == b'/';
        let boundary_before = idx == 0 || !raw[idx - 1].is_ascii_alphanumeric();
        if at_slash && boundary_before {
            let tail = &out[idx..];
            let starts_name =
                matches!(tail.chars().nth(1), Some(ch) if ch.is_alphanumeric() || ch == '.');
            if starts_name {
                let stop = match tail[1..].find(|ch: char| {
                    ch.is_whitespace() || matches!(ch, '\'' | '"' | ')' | ']')
                }) {
                    Some(offset) => idx + 1 + offset,
                    None => out.len(),
                };
                // NOTE(review): the previous form also tested
                // `segment.contains('/') || (segment.contains('.') && len > 4)`,
                // but the segment always begins with '/', so that test could
                // never fail. Every candidate is therefore redacted; confirm
                // whether single-component paths such as "/health" were meant
                // to be kept.
                out.replace_range(idx..stop, PLACEHOLDER);
            }
        }
        idx += 1;
    }

    // Pass 2: Windows drive-letter paths.
    let mut idx = 0;
    while idx < out.len() {
        let raw = out.as_bytes();
        let drive_prefix = idx + 2 < raw.len()
            && raw[idx].is_ascii_alphabetic()
            && raw[idx + 1] == b':'
            && raw[idx + 2] == b'\\';
        if drive_prefix {
            let stop = match out[idx..]
                .find(|ch: char| ch.is_whitespace() || matches!(ch, '\'' | '"'))
            {
                Some(offset) => idx + offset,
                None => out.len(),
            };
            out.replace_range(idx..stop, PLACEHOLDER);
        }
        idx += 1;
    }

    out
}
/// Replaces e-mail-like tokens with `[EMAIL]`.
///
/// For each `@`, the token starts after the nearest preceding whitespace, `<`,
/// `(` or `,` (or at the start of the input) and ends at the next whitespace,
/// `>`, `)` or `,` (or at the end of the input). It is redacted when the part
/// after `@` contains a `.` and is longer than three bytes. An `@` with an
/// empty local part is skipped.
///
/// The start offset is advanced by the delimiter's UTF-8 width rather than by
/// one byte, so a multi-byte whitespace delimiter (e.g. U+00A0) cannot produce
/// a non-boundary index that makes `replace_range` panic.
fn sanitize_emails(s: &str) -> String {
    const PLACEHOLDER: &str = "[EMAIL]";

    let mut result = String::from(s);
    let mut i = 0;
    while i < result.len() {
        let abs_at = match result[i..].find('@') {
            Some(at_pos) => i + at_pos,
            None => break,
        };

        let start = result[..abs_at]
            .char_indices()
            .rev()
            .find(|&(_, c)| c.is_whitespace() || c == '<' || c == '(' || c == ',')
            .map_or(0, |(p, c)| p + c.len_utf8());
        if start >= abs_at {
            // Nothing before the '@': not an address, move past it.
            i = abs_at + 1;
            continue;
        }

        let after_at = abs_at + 1;
        if after_at >= result.len() {
            break;
        }
        let end = result[after_at..]
            .find(|c: char| c.is_whitespace() || c == '>' || c == ')' || c == ',')
            .map_or(result.len(), |p| after_at + p);

        let domain = &result[after_at..end];
        if domain.contains('.') && domain.len() > 3 {
            result.replace_range(start..end, PLACEHOLDER);
            // Resume right after the placeholder that was just written.
            i = start + PLACEHOLDER.len();
        } else {
            i = end;
        }
    }
    result
}
/// Fixed, detail-free error text.
// NOTE(review): presumably shown to clients in place of the real error; no
// caller is visible in this file, so confirm against its users.
pub const GENERIC_ERROR_MESSAGE: &str = "An error occurred. Please try again.";
// Unit tests for the sanitizer pipeline, the input limits and the URI scheme check.
#[cfg(test)]
mod tests {
use super::*;
// A connection string must be collapsed to the placeholder, leaving neither
// the user name nor the password behind.
#[test]
fn test_sanitize_connection_strings() {
let msg = "Failed: postgres://admin:secret@localhost:5432/mydb";
let safe = sanitize_error_message(msg);
assert!(!safe.contains("admin"));
assert!(!safe.contains("secret"));
assert!(safe.contains("[CONNECTION]"));
}
// An IPv4 address followed by a port must be replaced by `[IP]`.
#[test]
fn test_sanitize_ip_addresses() {
let msg = "Connection to 192.168.1.100:5432 failed";
let safe = sanitize_error_message(msg);
assert!(!safe.contains("192.168.1.100"));
assert!(safe.contains("[IP]"));
}
// An absolute Unix path must be replaced by `[PATH]`.
#[test]
fn test_sanitize_file_paths() {
let msg = "File not found: /etc/secrets/api_key.txt";
let safe = sanitize_error_message(msg);
assert!(!safe.contains("/etc/secrets"));
assert!(safe.contains("[PATH]"));
}
// The value after `api_key=` must be replaced by `[REDACTED]`.
#[test]
fn test_sanitize_secrets() {
let msg = "Auth failed: api_key=sk_live_abc123xyz";
let safe = sanitize_error_message(msg);
assert!(!safe.contains("sk_live"));
assert!(safe.contains("[REDACTED]"));
}
// An e-mail address must be replaced by `[EMAIL]`.
#[test]
fn test_sanitize_emails() {
let msg = "User admin@example.com not found";
let safe = sanitize_error_message(msg);
assert!(!safe.contains("admin@example.com"));
assert!(safe.contains("[EMAIL]"));
}
// A string one byte over the production limit must be rejected.
#[test]
fn test_input_limits() {
let limits = InputLimits::production();
assert!(limits.check_string_length("short").is_ok());
let long_string = "x".repeat(limits.max_string_length + 1);
assert!(limits.check_string_length(&long_string).is_err());
}
// Schemes that are not on the blocklist are accepted, including `file` and `data`.
#[test]
fn test_uri_scheme_safety_accepts_standard_schemes() {
assert!(check_uri_scheme_safety("file:///etc/passwd").is_ok());
assert!(check_uri_scheme_safety("https://example.com").is_ok());
assert!(check_uri_scheme_safety("data:text/html,hello").is_ok());
assert!(check_uri_scheme_safety("mcp://server/resource").is_ok());
}
// Application-defined schemes are accepted and returned verbatim (lowercased).
#[test]
fn test_uri_scheme_safety_accepts_custom_schemes() {
assert_eq!(
check_uri_scheme_safety("apple-doc://swift/StringProtocol").unwrap(),
"apple-doc"
);
assert_eq!(
check_uri_scheme_safety("notion://workspace/page/abc123").unwrap(),
"notion"
);
assert_eq!(
check_uri_scheme_safety("slack://workspace/C01234567").unwrap(),
"slack"
);
assert_eq!(
check_uri_scheme_safety("weather://api/current").unwrap(),
"weather"
);
assert_eq!(
check_uri_scheme_safety("custom+scheme://data").unwrap(),
"custom+scheme"
);
}
// Blocklisted script schemes are rejected with `DangerousUriScheme`.
#[test]
fn test_uri_scheme_safety_rejects_dangerous_schemes() {
assert!(matches!(
check_uri_scheme_safety("javascript:alert(1)"),
Err(InputValidationError::DangerousUriScheme { .. })
));
assert!(matches!(
check_uri_scheme_safety("vbscript:msgbox(\"xss\")"),
Err(InputValidationError::DangerousUriScheme { .. })
));
}
// Scheme matching ignores ASCII case, and the returned scheme is lowercased.
#[test]
fn test_uri_scheme_safety_is_case_insensitive() {
assert!(check_uri_scheme_safety("JavaScript:alert(1)").is_err());
assert!(check_uri_scheme_safety("JAVASCRIPT:alert(1)").is_err());
assert!(check_uri_scheme_safety("VBScript:msgbox(1)").is_err());
assert_eq!(
check_uri_scheme_safety("HTTPS://example.com").unwrap(),
"https"
);
}
// Plain numbers (ids, ports) must pass through the pipeline untouched.
#[test]
fn test_no_false_positives() {
let msg = "User 123 requested tool list on port 8080";
let safe = sanitize_error_message(msg);
assert_eq!(msg, safe);
}
}