mod metachar;
mod path_guard;
mod whitelist;
use crate::error::{NlResult, ValidatorError};
use crate::types::ValidationStatus;
/// Upper bound on the length of a command, measured in bytes (`str::len`).
///
/// [`SafetyValidator::validate`] rejects any command whose length exceeds
/// this value with `ValidatorError::CommandTooLong`; a command of exactly
/// this length still passes the length check.
pub const MAX_COMMAND_LENGTH: usize = 512;
#[must_use]
pub fn validate_command(command: &str) -> ValidationStatus {
let validator = SafetyValidator::new();
match validator.validate(command) {
Ok(status) => status,
Err(crate::error::NlError::Validator(err)) => match err {
ValidatorError::MetacharDetected => ValidationStatus::RejectedMetachar,
ValidatorError::EnvVarDetected => ValidationStatus::RejectedEnvVar,
ValidatorError::PathTraversal | ValidatorError::AbsolutePath => {
ValidationStatus::RejectedPathTraversal
}
ValidatorError::CommandTooLong => ValidationStatus::RejectedTooLong,
ValidatorError::WriteOperation => ValidationStatus::RejectedWriteMode,
ValidatorError::TemplateMismatch => ValidationStatus::RejectedUnknown,
},
Err(_) => ValidationStatus::RejectedUnknown,
}
}
/// Runs a fixed sequence of safety checks over a command string.
///
/// Construct it with `SafetyValidator::new` (strict mode enabled) or
/// `SafetyValidator::with_strict_mode` to choose the mode explicitly.
#[derive(Debug, Clone)]
pub struct SafetyValidator {
    /// When `true`, a command must additionally match an allowed template
    /// (checked last, after every other safety check has passed).
    strict_mode: bool,
}
impl Default for SafetyValidator {
fn default() -> Self {
Self::new()
}
}
impl SafetyValidator {
    /// Creates a validator with strict mode enabled.
    #[must_use]
    pub fn new() -> Self {
        Self::with_strict_mode(true)
    }

    /// Creates a validator with strict mode set to `strict_mode`.
    ///
    /// With strict mode disabled, the allowed-template check in
    /// [`validate`](Self::validate) is skipped; every other check still runs.
    #[must_use]
    pub fn with_strict_mode(strict_mode: bool) -> Self {
        Self { strict_mode }
    }

    /// Runs every safety check against `command`, stopping at the first failure.
    ///
    /// Checks are applied in this order:
    ///
    /// 1. length (in bytes) against [`MAX_COMMAND_LENGTH`];
    /// 2. `metachar::contains_dangerous_chars`;
    /// 3. `path_guard::contains_path_traversal`;
    /// 4. `path_guard::contains_absolute_path`;
    /// 5. `metachar::contains_write_operation`;
    /// 6. in strict mode only, `whitelist::matches_allowed_template`.
    ///
    /// Returns `Ok(ValidationStatus::Valid)` when no check rejects the command.
    ///
    /// # Errors
    ///
    /// Returns the `ValidatorError` (converted into the crate error type)
    /// for the first check that fails:
    ///
    /// * `CommandTooLong` — the command is longer than [`MAX_COMMAND_LENGTH`];
    /// * `EnvVarDetected` — the dangerous-character check reported
    ///   `ValidationStatus::RejectedEnvVar`;
    /// * `MetacharDetected` — the dangerous-character check reported any
    ///   other status;
    /// * `PathTraversal` — the path-traversal check matched;
    /// * `AbsolutePath` — the absolute-path check matched;
    /// * `WriteOperation` — the write-operation check matched;
    /// * `TemplateMismatch` — strict mode is on and no allowed template matched.
    pub fn validate(&self, command: &str) -> NlResult<ValidationStatus> {
        // The `else if` chain keeps the checks lazy and in their original
        // order: a later check is only evaluated when all earlier ones pass.
        let rejection = if command.len() > MAX_COMMAND_LENGTH {
            Some(ValidatorError::CommandTooLong)
        } else if let Some(status) = metachar::contains_dangerous_chars(command) {
            Some(match status {
                ValidationStatus::RejectedEnvVar => ValidatorError::EnvVarDetected,
                _ => ValidatorError::MetacharDetected,
            })
        } else if path_guard::contains_path_traversal(command) {
            Some(ValidatorError::PathTraversal)
        } else if path_guard::contains_absolute_path(command) {
            Some(ValidatorError::AbsolutePath)
        } else if metachar::contains_write_operation(command) {
            Some(ValidatorError::WriteOperation)
        } else if self.strict_mode && !whitelist::matches_allowed_template(command) {
            Some(ValidatorError::TemplateMismatch)
        } else {
            None
        };

        match rejection {
            Some(reason) => Err(reason.into()),
            None => Ok(ValidationStatus::Valid),
        }
    }

    /// Returns `true` when [`validate`](Self::validate) succeeds for `command`.
    ///
    /// The specific rejection reason, if any, is discarded.
    #[must_use]
    pub fn is_valid(&self, command: &str) -> bool {
        let outcome = self.validate(command);
        outcome.is_ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Validates `command` with a default (strict) validator and returns the
    /// validator-level rejection reason, or `None` when the command was
    /// accepted or failed with a non-validator error.
    fn rejection_reason(command: &str) -> Option<ValidatorError> {
        match SafetyValidator::new().validate(command) {
            Err(crate::error::NlError::Validator(reason)) => Some(reason),
            _ => None,
        }
    }

    #[test]
    fn test_valid_query_command() {
        let outcome = SafetyValidator::new().validate("sqry query \"authenticate\"");
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_reject_semicolon() {
        let reason = rejection_reason("sqry query \"foo\"; rm -rf /");
        assert!(matches!(reason, Some(ValidatorError::MetacharDetected)));
    }

    #[test]
    fn test_reject_env_var() {
        let reason = rejection_reason("sqry query \"$HOME\"");
        assert!(matches!(reason, Some(ValidatorError::EnvVarDetected)));
    }

    #[test]
    fn test_reject_path_traversal() {
        let reason = rejection_reason("sqry query ../../../etc/passwd");
        assert!(matches!(reason, Some(ValidatorError::PathTraversal)));
    }

    #[test]
    fn test_allow_quoted_double_dot() {
        // Only asserts that the quoted `..` is not flagged as path traversal;
        // the command may still be accepted or rejected for another reason.
        let reason = rejection_reason("sqry query \"..*password\"");
        assert!(!matches!(reason, Some(ValidatorError::PathTraversal)));
    }

    #[test]
    fn test_reject_too_long() {
        // The payload alone already reaches the limit, so the full command
        // is guaranteed to exceed it.
        let long_command = format!("sqry query \"{}\"", "x".repeat(MAX_COMMAND_LENGTH));
        let reason = rejection_reason(&long_command);
        assert!(matches!(reason, Some(ValidatorError::CommandTooLong)));
    }
}