use regex::Regex;
use std::collections::{HashMap, HashSet};
/// Heuristic classifier that flags configuration fields which probably hold secrets.
///
/// Classification precedence (see [`SensitiveDataDetector::is_sensitive`]):
/// 1. lowercased field name exactly equals a high-sensitivity keyword -> `High`;
/// 2. lowercased field name was registered as a custom sensitive field -> `Medium`;
/// 3. any generic pattern matches the lowercased name or value -> `Medium`;
/// 4. otherwise -> `Low`.
#[derive(Debug, Clone)]
pub struct SensitiveDataDetector {
    /// Generic, unanchored patterns tried in order against both field name and value.
    sensitive_patterns: Vec<Regex>,
    /// Lowercase field names that are always classified as high sensitivity.
    high_sensitivity_keywords: HashSet<&'static str>,
    /// Caller-registered field names, stored lowercase.
    custom_sensitive_fields: HashSet<String>,
}

impl Default for SensitiveDataDetector {
    fn default() -> Self {
        SensitiveDataDetector::new()
    }
}

impl SensitiveDataDetector {
    /// Creates a detector with the built-in patterns and keywords and no custom fields.
    pub fn new() -> Self {
        let sensitive_patterns = Self::default_sensitive_patterns();
        let high_sensitivity_keywords = Self::default_high_sensitivity_keywords();
        Self {
            sensitive_patterns,
            high_sensitivity_keywords,
            custom_sensitive_fields: HashSet::new(),
        }
    }

    /// Built-in generic patterns in priority order; the first one that matches supplies the
    /// `reason` text, so the order (and the exact pattern source) is observable.
    ///
    /// NOTE(review): these are unanchored substring matches, so e.g. `key` also matches
    /// `keyboard`/`monkey` and `auth` matches `author`, in names and values alike.
    fn default_sensitive_patterns() -> Vec<Regex> {
        const PATTERNS: &[&str] = &[
            r"(?i)password",
            r"(?i)secret",
            r"(?i)token",
            r"(?i)api_key",
            r"(?i)access_key",
            r"(?i)private_key",
            r"(?i)credential",
            r"(?i)auth",
            r"(?i)key",
            r"(?i)cert",
            r"(?i)password_hash",
            r"(?i)session_id",
        ];
        PATTERNS
            .iter()
            .map(|source| Regex::new(source).expect("built-in sensitive pattern must compile"))
            .collect()
    }

    /// Field names (lowercase, exact match) that are always classified as `High`.
    fn default_high_sensitivity_keywords() -> HashSet<&'static str> {
        const KEYWORDS: &[&str] = &[
            "password",
            "secret",
            "private_key",
            "master_key",
            "encryption_key",
            "api_secret",
            "access_token",
            "refresh_token",
            "client_secret",
            "db_password",
            "admin_password",
        ];
        KEYWORDS.iter().copied().collect()
    }

    /// Registers an extra field name (matched case-insensitively, exactly) as `Medium` sensitivity.
    pub fn add_custom_sensitive_field(&mut self, field: impl Into<String>) {
        let normalized = field.into().to_lowercase();
        self.custom_sensitive_fields.insert(normalized);
    }

    /// Classifies one field. `field` in the result is the original, non-lowercased name.
    pub fn is_sensitive(&self, field_name: &str, field_value: &str) -> SensitivityResult {
        let name = field_name.to_lowercase();

        if self.high_sensitivity_keywords.contains(name.as_str()) {
            return SensitivityResult::High {
                field: field_name.to_string(),
                reason: "high sensitivity keyword in field name".to_string(),
            };
        }
        if self.custom_sensitive_fields.contains(&name) {
            return SensitivityResult::Medium {
                field: field_name.to_string(),
                reason: "custom sensitive field".to_string(),
            };
        }

        let value = field_value.to_lowercase();
        let first_hit = self
            .sensitive_patterns
            .iter()
            .find(|pattern| pattern.is_match(&name) || pattern.is_match(&value));
        match first_hit {
            Some(pattern) => SensitivityResult::Medium {
                field: field_name.to_string(),
                reason: format!("sensitive pattern detected: {}", pattern.as_str()),
            },
            None => SensitivityResult::Low,
        }
    }

    /// Classifies every entry and returns only the non-`Low` ones, in `HashMap` iteration order.
    pub fn detect_all<'a>(
        &self,
        data: &'a HashMap<String, String>,
    ) -> Vec<(&'a str, SensitivityResult)> {
        let mut findings = Vec::new();
        for (name, value) in data {
            let result = self.is_sensitive(name, value);
            if !result.is_low() {
                findings.push((name.as_str(), result));
            }
        }
        findings
    }
}
/// Sensitivity classification of a single configuration field.
#[derive(Debug, Clone, PartialEq)]
pub enum SensitivityResult {
    /// Nothing sensitive was detected.
    Low,
    /// Possibly sensitive; `reason` explains which rule fired.
    Medium { field: String, reason: String },
    /// Definitely sensitive; `reason` explains which rule fired.
    High { field: String, reason: String },
}

impl SensitivityResult {
    /// `true` only for [`SensitivityResult::Low`].
    pub fn is_low(&self) -> bool {
        *self == Self::Low
    }

    /// `true` for `Medium` and `High`, i.e. anything that should be masked or protected.
    pub fn needs_protection(&self) -> bool {
        matches!(self, Self::Medium { .. } | Self::High { .. })
    }

    /// Human-readable summary, e.g. `"high sensitivity: password - <reason>"`.
    pub fn description(&self) -> String {
        let (level, field, reason) = match self {
            Self::Low => return "low sensitivity".to_string(),
            Self::Medium { field, reason } => ("medium", field, reason),
            Self::High { field, reason } => ("high", field, reason),
        };
        format!("{} sensitivity: {} - {}", level, field, reason)
    }
}
/// Validates (and optionally sanitizes) untrusted string input such as configuration values.
///
/// All length limits are measured in bytes (`str::len`), not characters.
#[derive(Debug, Clone)]
pub struct InputValidator {
    /// Maximum byte length accepted for strings, field names, URLs and email addresses.
    max_string_length: usize,
    /// Maximum element count accepted by [`InputValidator::validate_array_length`].
    max_array_length: usize,
    /// Maximum nesting depth accepted by [`InputValidator::validate_depth`].
    max_depth: usize,
    /// Optional pattern every value must match (unanchored `is_match`) in `validate_string`.
    allowed_chars_pattern: Option<Regex>,
    /// Patterns that reject a value when they match anywhere in it; checked in order.
    dangerous_patterns: Vec<Regex>,
    /// When non-empty, `validate_whitelist` requires at least one of these to match.
    whitelist_patterns: Vec<Regex>,
}

impl Default for InputValidator {
    fn default() -> Self {
        Self::new()
    }
}

impl InputValidator {
    /// Creates a validator with a 1024-byte string limit, a 100-element array limit, a nesting
    /// limit of 10, no allowed-character restriction, the built-in dangerous patterns and an
    /// empty whitelist.
    pub fn new() -> Self {
        Self {
            max_string_length: 1024,
            max_array_length: 100,
            max_depth: 10,
            allowed_chars_pattern: None,
            dangerous_patterns: Self::default_dangerous_patterns(),
            whitelist_patterns: Vec::new(),
        }
    }

    /// Built-in rejection patterns in checking order; the first match is the one reported in
    /// [`InputValidationError::DangerousPattern`].
    fn default_dangerous_patterns() -> Vec<Regex> {
        vec![
            // Shell metacharacters, command substitution and redirection.
            Regex::new(r"[;<>&|`$()]").expect("built-in pattern must compile"),
            Regex::new(r"\$\{.*\}").expect("built-in pattern must compile"),
            Regex::new(r"`[^`]+`").expect("built-in pattern must compile"),
            Regex::new(r"\|").expect("built-in pattern must compile"),
            Regex::new(r"&&").expect("built-in pattern must compile"),
            Regex::new(r"\|\|").expect("built-in pattern must compile"),
            Regex::new(r">>").expect("built-in pattern must compile"),
            Regex::new(r"2>").expect("built-in pattern must compile"),
            // Path traversal.
            Regex::new(r"\.\.[/\\]").expect("built-in pattern must compile"),
            Regex::new(r"[/\\]\.\.[/\\]").expect("built-in pattern must compile"),
            // SQL injection fragments.
            // NOTE(review): this keyword pattern has no leading `\b`, so it also matches words
            // that merely end in a keyword (e.g. `recreate`) and plain prose like `please update`.
            Regex::new(r"(?i)(;?\s*(drop|delete|update|insert|alter|create)\b)")
                .expect("built-in pattern must compile"),
            Regex::new(r"(?i)(union\s+select\b)").expect("built-in pattern must compile"),
            Regex::new(r"(?i)'+\s*(or|and)\b").expect("built-in pattern must compile"),
            Regex::new(r"(?i)--\s*$").expect("built-in pattern must compile"),
        ]
    }

    /// Sets the maximum byte length for strings, field names, URLs and email addresses.
    pub fn with_max_string_length(mut self, length: usize) -> Self {
        self.max_string_length = length;
        self
    }

    /// Sets the maximum element count accepted by [`InputValidator::validate_array_length`].
    pub fn with_max_array_length(mut self, length: usize) -> Self {
        self.max_array_length = length;
        self
    }

    /// Sets the maximum nesting depth accepted by [`InputValidator::validate_depth`].
    pub fn with_max_depth(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

    /// Requires every value passed to [`InputValidator::validate_string`] to match `pattern`.
    ///
    /// The match is unanchored, so wrap the pattern in `^...$` to constrain the whole value.
    /// An invalid pattern is ignored and any previously configured pattern is kept, like
    /// [`InputValidator::add_whitelist_pattern`].
    pub fn with_allowed_chars_pattern(mut self, pattern: &str) -> Self {
        // Previously an invalid pattern reset this to `None`, silently disabling a restriction
        // configured earlier (fail-open).
        if let Ok(regex) = Regex::new(pattern) {
            self.allowed_chars_pattern = Some(regex);
        }
        self
    }

    /// Adds a pattern to the whitelist used by [`InputValidator::validate_whitelist`].
    /// Invalid patterns are ignored.
    pub fn add_whitelist_pattern(mut self, pattern: &str) -> Self {
        if let Ok(regex) = Regex::new(pattern) {
            self.whitelist_patterns.push(regex);
        }
        self
    }

    /// Validates a free-form string: byte length, then the optional allowed-characters
    /// pattern, then the dangerous patterns.
    ///
    /// # Errors
    /// [`InputValidationError::TooLong`], [`InputValidationError::InvalidCharacters`], or
    /// [`InputValidationError::DangerousPattern`] naming the first pattern that matched.
    pub fn validate_string(&self, value: &str) -> Result<(), InputValidationError> {
        self.check_length(value)?;
        if let Some(pattern) = &self.allowed_chars_pattern {
            if !pattern.is_match(value) {
                return Err(InputValidationError::InvalidCharacters);
            }
        }
        self.check_dangerous_patterns(value)
    }

    /// Removes dangerous characters from `value`, then validates what remains.
    ///
    /// Only single characters are stripped; multi-character patterns such as `../` or SQL
    /// keywords survive, so the cleaned string can still be rejected.
    ///
    /// # Errors
    /// Any error from [`InputValidator::validate_string`] on the cleaned string.
    pub fn sanitize_string(&self, value: &str) -> Result<String, InputValidationError> {
        let sanitized: String = value.chars().filter(|&c| !self.is_dangerous_char(c)).collect();
        self.validate_string(&sanitized)?;
        Ok(sanitized)
    }

    /// Characters removed by [`InputValidator::sanitize_string`].
    fn is_dangerous_char(&self, c: char) -> bool {
        matches!(
            c,
            ';' | '<' | '>' | '&' | '|' | '`' | '$' | '(' | ')' | '\0' | '{' | '}'
        )
    }

    /// Validates a field name: non-empty, within the length limit, and an ASCII letter
    /// followed by ASCII letters, digits, `_` or `-`.
    ///
    /// # Errors
    /// [`InputValidationError::EmptyFieldName`], [`InputValidationError::TooLong`] or
    /// [`InputValidationError::InvalidFieldNameFormat`], checked in that order.
    pub fn validate_field_name(&self, name: &str) -> Result<(), InputValidationError> {
        if name.is_empty() {
            return Err(InputValidationError::EmptyFieldName);
        }
        self.check_length(name)?;
        // Same language as `^[a-zA-Z][a-zA-Z0-9_-]*$`, without compiling a regex on every call
        // (this runs once per field when validating whole maps).
        let mut chars = name.chars();
        let starts_with_letter = matches!(chars.next(), Some(c) if c.is_ascii_alphabetic());
        let rest_is_valid = chars.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-');
        if !(starts_with_letter && rest_is_valid) {
            return Err(InputValidationError::InvalidFieldNameFormat);
        }
        Ok(())
    }

    /// Validates an absolute `http`/`https` URL and screens it with the dangerous patterns.
    ///
    /// NOTE(review): the dangerous patterns include `&`, `(` and `)`, so URLs with
    /// multi-parameter query strings (`?a=1&b=2`) are rejected. Confirm this is intended.
    ///
    /// # Errors
    /// [`InputValidationError::TooLong`], [`InputValidationError::InvalidUrl`],
    /// [`InputValidationError::InvalidUrlScheme`] or [`InputValidationError::DangerousPattern`].
    pub fn validate_url(&self, url: &str) -> Result<(), InputValidationError> {
        self.check_length(url)?;
        let parsed = url::Url::parse(url).map_err(|_| InputValidationError::InvalidUrl)?;
        if !matches!(parsed.scheme(), "http" | "https") {
            return Err(InputValidationError::InvalidUrlScheme);
        }
        self.check_dangerous_patterns(url)
    }

    /// Performs a simple syntactic email check (`local@domain.tld`, where the TLD has at least
    /// two ASCII letters). This is not full RFC 5322 validation.
    ///
    /// # Errors
    /// [`InputValidationError::TooLong`] or [`InputValidationError::InvalidEmail`].
    pub fn validate_email(&self, email: &str) -> Result<(), InputValidationError> {
        self.check_length(email)?;
        let email_pattern = Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
            .expect("built-in email pattern must compile");
        if !email_pattern.is_match(email) {
            return Err(InputValidationError::InvalidEmail);
        }
        Ok(())
    }

    /// Accepts `value` if the whitelist is empty or any whitelist pattern matches it.
    ///
    /// # Errors
    /// [`InputValidationError::NotInWhitelist`] when no configured pattern matches.
    pub fn validate_whitelist(&self, value: &str) -> Result<(), InputValidationError> {
        let allowed = self.whitelist_patterns.is_empty()
            || self.whitelist_patterns.iter().any(|pattern| pattern.is_match(value));
        if allowed {
            Ok(())
        } else {
            Err(InputValidationError::NotInWhitelist)
        }
    }

    /// Checks an element count against the configured maximum array length.
    ///
    /// # Errors
    /// [`InputValidationError::ArrayTooLong`] when `len` exceeds the limit.
    pub fn validate_array_length(&self, len: usize) -> Result<(), InputValidationError> {
        if len > self.max_array_length {
            return Err(InputValidationError::ArrayTooLong {
                max: self.max_array_length,
                actual: len,
            });
        }
        Ok(())
    }

    /// Checks a nesting depth against the configured maximum depth.
    ///
    /// # Errors
    /// [`InputValidationError::DepthExceeded`] when `depth` exceeds the limit.
    pub fn validate_depth(&self, depth: usize) -> Result<(), InputValidationError> {
        if depth > self.max_depth {
            return Err(InputValidationError::DepthExceeded {
                max: self.max_depth,
                actual: depth,
            });
        }
        Ok(())
    }

    /// Validates every entry of `data` and collects the failures.
    ///
    /// A field whose name is invalid is reported once, and its value is not checked. Output
    /// order follows `HashMap` iteration order.
    pub fn validate_all<'a>(
        &'a self,
        data: &'a HashMap<String, String>,
    ) -> Vec<(&'a String, InputValidationError)> {
        let mut errors = Vec::new();
        for (name, value) in data {
            if let Err(e) = self.validate_field_name(name) {
                errors.push((name, e));
                continue;
            }
            if let Err(e) = self.validate_string(value) {
                errors.push((name, e));
            }
        }
        errors
    }

    /// Rejects inputs longer than `max_string_length` bytes.
    fn check_length(&self, value: &str) -> Result<(), InputValidationError> {
        if value.len() > self.max_string_length {
            return Err(InputValidationError::TooLong {
                max: self.max_string_length,
                actual: value.len(),
            });
        }
        Ok(())
    }

    /// Fails with the first dangerous pattern (in configured order) that matches `value`.
    fn check_dangerous_patterns(&self, value: &str) -> Result<(), InputValidationError> {
        match self.dangerous_patterns.iter().find(|pattern| pattern.is_match(value)) {
            Some(pattern) => Err(InputValidationError::DangerousPattern {
                pattern: pattern.as_str().to_string(),
            }),
            None => Ok(()),
        }
    }
}
/// Reasons an input can be rejected by `InputValidator`.
#[derive(Debug, Clone, PartialEq)]
pub enum InputValidationError {
    /// Input exceeds the byte-length limit.
    TooLong { max: usize, actual: usize },
    /// Input does not match the configured allowed-characters pattern.
    InvalidCharacters,
    /// Input matches a dangerous pattern; `pattern` is that pattern's source text.
    DangerousPattern { pattern: String },
    /// Field name is the empty string.
    EmptyFieldName,
    /// Field name does not have the required identifier shape.
    InvalidFieldNameFormat,
    /// URL could not be parsed.
    InvalidUrl,
    /// URL scheme is not permitted.
    InvalidUrlScheme,
    /// Email address is malformed.
    InvalidEmail,
    /// Input matches none of the whitelist patterns.
    NotInWhitelist,
    /// Nesting depth exceeds the limit.
    DepthExceeded { max: usize, actual: usize },
    /// Element count exceeds the limit.
    ArrayTooLong { max: usize, actual: usize },
}

impl std::fmt::Display for InputValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Variants carrying data need formatting...
            Self::TooLong { max, actual } => {
                write!(f, "Input too long: max={}, actual={}", max, actual)
            }
            Self::DangerousPattern { pattern } => {
                write!(f, "Input contains dangerous pattern: {}", pattern)
            }
            Self::DepthExceeded { max, actual } => {
                write!(f, "Nesting depth exceeded: max={}, actual={}", max, actual)
            }
            Self::ArrayTooLong { max, actual } => {
                write!(f, "Array too long: max={}, actual={}", max, actual)
            }
            // ...the rest are fixed messages.
            Self::InvalidCharacters => f.write_str("Input contains invalid characters"),
            Self::EmptyFieldName => f.write_str("Field name is empty"),
            Self::InvalidFieldNameFormat => f.write_str("Field name format is invalid"),
            Self::InvalidUrl => f.write_str("URL is invalid"),
            Self::InvalidUrlScheme => f.write_str("URL scheme is not allowed"),
            Self::InvalidEmail => f.write_str("Email format is invalid"),
            Self::NotInWhitelist => f.write_str("Input does not match any whitelist pattern"),
        }
    }
}

impl std::error::Error for InputValidationError {}
/// Fluent builder for [`ConfigValidator`].
///
/// Derives `Debug` and `Clone` like `ConfigValidator` itself. Previously it was the only public
/// type here without `Debug`, so it could not appear in debug output or in derived `Debug` impls.
#[derive(Debug, Clone, Default)]
pub struct ConfigValidatorBuilder {
    validator: InputValidator,
    sensitive_detector: SensitiveDataDetector,
}

impl ConfigValidatorBuilder {
    /// Starts from the default input validator and sensitive-data detector.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Consumes the builder and returns the configured [`ConfigValidator`].
    #[must_use]
    pub fn build(self) -> ConfigValidator {
        ConfigValidator {
            input_validator: self.validator,
            sensitive_detector: self.sensitive_detector,
        }
    }

    /// Sets the maximum byte length for field names and values.
    #[must_use]
    pub fn max_string_length(mut self, length: usize) -> Self {
        self.validator = self.validator.with_max_string_length(length);
        self
    }

    /// Registers `field` (matched case-insensitively) as a custom sensitive field.
    #[must_use]
    pub fn add_sensitive_field(mut self, field: &str) -> Self {
        self.sensitive_detector.add_custom_sensitive_field(field);
        self
    }

    /// Applies a stricter preset: a 256-byte length limit, plus `token` and `password`
    /// registered as custom sensitive fields.
    ///
    /// The length limit overrides any earlier `max_string_length` call. NOTE(review):
    /// `password` is already a high-sensitivity keyword, which the detector checks before custom
    /// fields, so registering it here does not change its classification.
    #[must_use]
    pub fn strict_mode(self) -> Self {
        self.max_string_length(256)
            .add_sensitive_field("token")
            .add_sensitive_field("password")
    }
}
/// Validates a flat string-to-string configuration map. Every field name and value is run
/// through an [`InputValidator`], and every entry is classified by a [`SensitiveDataDetector`].
#[derive(Debug, Clone)]
pub struct ConfigValidator {
    /// Syntax and safety checks for names and values.
    input_validator: InputValidator,
    /// Classification of fields that may hold secrets.
    sensitive_detector: SensitiveDataDetector,
}

impl ConfigValidator {
    /// Creates a validator with default input limits and default sensitivity rules.
    pub fn new() -> Self {
        let input_validator = InputValidator::new();
        let sensitive_detector = SensitiveDataDetector::new();
        ConfigValidator {
            input_validator,
            sensitive_detector,
        }
    }

    /// Returns a builder for customizing limits and sensitive fields.
    pub fn builder() -> ConfigValidatorBuilder {
        ConfigValidatorBuilder::new()
    }

    /// Test-only access to the configured detector.
    #[cfg(test)]
    pub fn sensitive_detector(&self) -> &SensitiveDataDetector {
        &self.sensitive_detector
    }

    /// Validates every entry of `data`.
    ///
    /// The name and the value are always both checked, so one entry can contribute two errors
    /// (name first). Entries classified above `Low` are collected in `sensitive_fields`
    /// whether or not they passed validation. Order follows `HashMap` iteration order.
    pub fn validate(&self, data: &HashMap<String, String>) -> ConfigValidationResult {
        let mut result = ConfigValidationResult {
            errors: Vec::new(),
            sensitive_fields: Vec::new(),
        };

        for (name, value) in data {
            let name_error = self.input_validator.validate_field_name(name).err();
            let value_error = self.input_validator.validate_string(value).err();
            result.errors.extend(name_error.into_iter().chain(value_error).map(|error| {
                ConfigValidationError::FieldError {
                    field: name.clone(),
                    error,
                }
            }));

            let sensitivity = self.sensitive_detector.is_sensitive(name, value);
            if sensitivity.needs_protection() {
                result.sensitive_fields.push((name.clone(), sensitivity));
            }
        }

        result
    }

    /// Returns `true` when every name and value passes validation. Sensitivity is ignored.
    pub fn validate_safe(&self, data: &HashMap<String, String>) -> bool {
        !data.iter().any(|(name, value)| {
            self.input_validator.validate_field_name(name).is_err()
                || self.input_validator.validate_string(value).is_err()
        })
    }
}

impl Default for ConfigValidator {
    fn default() -> Self {
        ConfigValidator::new()
    }
}
/// Outcome of validating a configuration map.
#[derive(Debug, Clone)]
pub struct ConfigValidationResult {
    /// Validation failures; empty means the configuration is valid.
    pub errors: Vec<ConfigValidationError>,
    /// Fields classified above `Low` sensitivity, paired with their classification.
    pub sensitive_fields: Vec<(String, SensitivityResult)>,
}

impl ConfigValidationResult {
    /// `true` when no errors were recorded. Sensitive fields do not affect validity.
    pub fn is_valid(&self) -> bool {
        self.errors.is_empty()
    }

    /// `true` when at least one field was flagged as sensitive.
    pub fn has_sensitive_data(&self) -> bool {
        !self.sensitive_fields.is_empty()
    }

    /// All errors rendered with `Display`, one per line (no trailing newline).
    pub fn error_report(&self) -> String {
        let lines: Vec<String> = self.errors.iter().map(ToString::to_string).collect();
        lines.join("\n")
    }
}
/// A problem reported for a single configuration field.
#[derive(Debug, Clone)]
pub enum ConfigValidationError {
    /// The field's name or value failed input validation.
    FieldError {
        field: String,
        error: InputValidationError,
    },
    /// The field appears to hold sensitive data.
    /// NOTE(review): `ConfigValidator::validate` in this file does not construct this variant.
    SensitiveDataWarning {
        field: String,
        sensitivity: SensitivityResult,
    },
}

impl std::fmt::Display for ConfigValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (field, detail) = match self {
            ConfigValidationError::FieldError { field, error } => (field, error.to_string()),
            ConfigValidationError::SensitiveDataWarning { field, sensitivity } => {
                (field, sensitivity.description())
            }
        };
        write!(f, "Field '{}': {}", field, detail)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sensitive_data_detection() {
        let detector = SensitiveDataDetector::new();

        for &name in ["password", "secret_key", "api_token", "user_token"].iter() {
            assert!(
                detector.is_sensitive(name, "value").needs_protection(),
                "`{}` should be flagged",
                name
            );
        }

        for &(name, value) in [("username", "value"), ("port", "8080")].iter() {
            assert!(
                !detector.is_sensitive(name, value).needs_protection(),
                "`{}` should not be flagged",
                name
            );
        }
    }

    #[test]
    fn test_input_validation() {
        let validator = InputValidator::new();

        // Values.
        assert!(validator.validate_string("hello world").is_ok());
        assert!(validator.validate_string("hello;world").is_err());
        assert!(validator.validate_string("hello${world}").is_err());

        // Field names.
        assert!(validator.validate_field_name("app_port").is_ok());
        assert!(validator.validate_field_name("").is_err());
        assert!(validator.validate_field_name("123port").is_err());
    }

    #[test]
    fn test_url_validation() {
        let validator = InputValidator::new();

        for &url in ["https://example.com", "http://localhost:8080"].iter() {
            assert!(validator.validate_url(url).is_ok(), "`{}` should be accepted", url);
        }
        for &url in ["ftp://example.com", "javascript:alert(1)"].iter() {
            assert!(validator.validate_url(url).is_err(), "`{}` should be rejected", url);
        }
    }

    #[test]
    fn test_email_validation() {
        let validator = InputValidator::new();

        assert!(validator.validate_email("user@example.com").is_ok());
        for &email in ["invalid-email", "@example.com"].iter() {
            assert!(validator.validate_email(email).is_err(), "`{}` should be rejected", email);
        }
    }

    #[test]
    fn test_sanitization() {
        let sanitized = InputValidator::new()
            .sanitize_string("hello; world ${test}")
            .expect("sanitized input should validate");

        assert!(!sanitized.contains(';'));
        assert!(!sanitized.contains('$'));
        assert_eq!(sanitized, "hello world test");
    }

    #[test]
    fn test_config_validation() {
        let config: HashMap<String, String> = [
            ("app_name", "my-app"),
            ("app_port", "8080"),
            ("database_password", "secret"),
        ]
        .iter()
        .map(|&(key, value)| (key.to_string(), value.to_string()))
        .collect();

        let result = ConfigValidator::new().validate(&config);

        assert!(result.is_valid());
        assert!(result.has_sensitive_data());
        assert_eq!(result.sensitive_fields.len(), 1);
    }

    #[test]
    fn test_custom_sensitive_field() {
        let base = SensitiveDataDetector::new();
        let mut detector = base.clone();
        detector.add_custom_sensitive_field("custom_field");

        assert!(detector
            .is_sensitive("custom_field", "value")
            .needs_protection());
    }
}