use thiserror::Error;
use crate::middleware::Middleware;
use crate::model::ParsedCommand;
#[derive(Debug, Error, PartialEq)]
pub enum ValidationError {
#[error(
"field `{field}` contains a path traversal sequence in value: {value:?}"
)]
PathTraversal {
field: String,
value: String,
},
#[error(
"field `{field}` contains a control character in value: {value:?}"
)]
ControlCharacter {
field: String,
value: String,
},
#[error(
"field `{field}` contains an embedded query parameter in value: {value:?}"
)]
QueryInjection {
field: String,
value: String,
},
#[error(
"field `{field}` contains a URL-encoded sequence in value: {value:?}"
)]
UrlEncoding {
field: String,
value: String,
},
}
/// Configurable validator for command argument and flag values.
///
/// Every check starts out disabled (`Default`, `InputValidator::new`). Turn
/// checks on one at a time with the `check_*` builder methods, or all at once
/// with `InputValidator::strict`.
#[derive(Debug, Clone, Default)]
pub struct InputValidator {
    /// When set, values are screened with `contains_path_traversal`.
    path_traversal: bool,
    /// When set, values are screened with `contains_control_char`.
    control_chars: bool,
    /// When set, values are screened with `contains_query_injection`.
    query_injection: bool,
    /// When set, values are screened with `contains_url_encoding`.
    url_encoding: bool,
}
impl InputValidator {
    /// Creates a validator with every check disabled.
    ///
    /// Enable checks with the `check_*` builder methods, or use
    /// [`InputValidator::strict`] to enable all of them at once.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a validator with all four checks enabled: path traversal,
    /// control characters, query injection and URL encoding.
    #[must_use]
    pub fn strict() -> Self {
        Self {
            path_traversal: true,
            control_chars: true,
            query_injection: true,
            url_encoding: true,
        }
    }

    /// Enables the path-traversal check ([`ValidationError::PathTraversal`]).
    ///
    /// Consumes and returns `self`; the result must be kept for the check to apply.
    #[must_use]
    pub fn check_path_traversal(mut self) -> Self {
        self.path_traversal = true;
        self
    }

    /// Enables the control-character check ([`ValidationError::ControlCharacter`]).
    ///
    /// Consumes and returns `self`; the result must be kept for the check to apply.
    #[must_use]
    pub fn check_control_chars(mut self) -> Self {
        self.control_chars = true;
        self
    }

    /// Enables the query-injection check ([`ValidationError::QueryInjection`]).
    ///
    /// Consumes and returns `self`; the result must be kept for the check to apply.
    #[must_use]
    pub fn check_query_injection(mut self) -> Self {
        self.query_injection = true;
        self
    }

    /// Enables the URL-encoding check ([`ValidationError::UrlEncoding`]).
    ///
    /// Consumes and returns `self`; the result must be kept for the check to apply.
    #[must_use]
    pub fn check_url_encoding(mut self) -> Self {
        self.url_encoding = true;
        self
    }

    /// Validates one `value`, naming it `field` in any error.
    ///
    /// Enabled checks run in a fixed order: path traversal, control
    /// characters, query injection, then URL encoding. The first check that
    /// fails decides the returned error. Disabled checks are skipped.
    ///
    /// # Errors
    ///
    /// Returns the [`ValidationError`] variant for the first enabled check that
    /// `value` fails. `field` and `value` are copied into the error.
    pub fn validate_value(&self, field: &str, value: &str) -> Result<(), ValidationError> {
        if self.path_traversal && contains_path_traversal(value) {
            return Err(ValidationError::PathTraversal {
                field: field.to_owned(),
                value: value.to_owned(),
            });
        }
        if self.control_chars && contains_control_char(value) {
            return Err(ValidationError::ControlCharacter {
                field: field.to_owned(),
                value: value.to_owned(),
            });
        }
        if self.query_injection && contains_query_injection(value) {
            return Err(ValidationError::QueryInjection {
                field: field.to_owned(),
                value: value.to_owned(),
            });
        }
        if self.url_encoding && contains_url_encoding(value) {
            return Err(ValidationError::UrlEncoding {
                field: field.to_owned(),
                value: value.to_owned(),
            });
        }
        Ok(())
    }

    /// Validates every `(field, value)` pair in `parsed.args`, then every pair
    /// in `parsed.flags`, using [`InputValidator::validate_value`].
    ///
    /// Pairs are visited in each collection's own iteration order.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first [`ValidationError`] encountered.
    pub fn validate_parsed(&self, parsed: &ParsedCommand<'_>) -> Result<(), ValidationError> {
        for (field, value) in &parsed.args {
            self.validate_value(field, value)?;
        }
        for (field, value) in &parsed.flags {
            self.validate_value(field, value)?;
        }
        Ok(())
    }
}
impl Middleware for InputValidator {
    /// Runs [`InputValidator::validate_parsed`] before the command is dispatched.
    ///
    /// # Errors
    ///
    /// Returns the first [`ValidationError`] found, boxed as a
    /// `dyn Error + Send + Sync` trait object to match the middleware signature.
    fn before_dispatch(
        &self,
        parsed: &ParsedCommand<'_>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // `?` boxes the error through std's `From<E> for Box<dyn Error + Send + Sync>`.
        self.validate_parsed(parsed)?;
        Ok(())
    }
}
/// Reports whether `value` looks like a path that could escape the intended
/// directory.
///
/// Flags:
/// - `../` or `..\` anywhere in the value;
/// - a `..` component with no separator after it: the whole value `..`, or a
///   value ending in `/..` or `\..`;
/// - rooted paths, i.e. values starting with `/` or `\`;
/// - home-directory paths, i.e. values starting with `~`.
///
/// Names that merely contain dots, such as `file..txt` or `...`, are not
/// flagged.
fn contains_path_traversal(value: &str) -> bool {
    value.contains("../")
        || value.contains("..\\")
        // A bare or trailing `..` climbs a directory without any separator after it.
        || value == ".."
        || value.ends_with("/..")
        || value.ends_with("\\..")
        || value.starts_with(|c: char| matches!(c, '/' | '\\' | '~'))
}
/// Reports whether `value` contains a Unicode control character other than
/// tab (`\t`) or line feed (`\n`).
///
/// "Control character" means general category `Cc`, as tested by
/// [`char::is_control`]: the C0 range U+0000–U+001F, DEL U+007F, and the C1
/// range U+0080–U+009F. The check works on `char`s rather than bytes because
/// UTF-8 encodes C1 characters as two bytes (`0xC2 0x80..=0x9F`). A byte-level
/// `<= 0x1F || == 0x7F` test would never see them. C1 includes U+009B (CSI),
/// which some terminals interpret like `ESC [`.
fn contains_control_char(value: &str) -> bool {
    value
        .chars()
        .any(|c| c.is_control() && c != '\t' && c != '\n')
}
/// Reports whether `value` looks like it carries embedded query parameters.
///
/// Flags:
/// - any `?`;
/// - any `&` that has an `=` somewhere after it, unless the byte right after
///   that `&` is itself `=` (so `value&role=admin` is flagged, `a&=b` is not).
///
/// An `&` with no `=` after it, as in `"Tom & Jerry"`, is allowed.
///
/// This runs in linear time. Rescanning the tail for `=` after every `&`
/// would be O(n²) on inputs like `"&&&…"` with no `=`. Instead, the last `=`
/// is located once. An `&` qualifies exactly when it sits before that last
/// `=` and is not immediately followed by `=`.
fn contains_query_injection(value: &str) -> bool {
    if value.contains('?') {
        return true;
    }
    let bytes = value.as_bytes();
    let Some(last_eq) = bytes.iter().rposition(|&b| b == b'=') else {
        return false;
    };
    // Each window `[&, next]` lies strictly before `last_eq`, so an `=` exists
    // after the `&` and not at its immediate successor.
    bytes[..last_eq]
        .windows(2)
        .any(|pair| pair[0] == b'&' && pair[1] != b'=')
}
/// Reports whether `value` contains a percent-encoded byte: a `%` followed
/// immediately by two ASCII hex digits of either case, such as `%2F` or `%2e`.
///
/// A `%` without two hex digits right after it (e.g. `"50% off"`, `"100%"`,
/// `"%zz"`) is not flagged.
fn contains_url_encoding(value: &str) -> bool {
    value.as_bytes().windows(3).any(|triple| {
        matches!(
            triple,
            [b'%', hi, lo] if hi.is_ascii_hexdigit() && lo.is_ascii_hexdigit()
        )
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{Argument, Command, Flag};
    use crate::parser::Parser;

    // --- path traversal ---

    #[test]
    fn path_traversal_forward_slash_prefix() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "/etc/passwd").is_err());
    }

    #[test]
    fn path_traversal_tilde_prefix() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "~/.ssh/id_rsa").is_err());
    }

    #[test]
    fn path_traversal_dotdot_unix() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "../../secret").is_err());
    }

    #[test]
    fn path_traversal_dotdot_windows() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "..\\windows\\system32").is_err());
    }

    #[test]
    fn path_traversal_safe_relative_path() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "subdir/file.txt").is_ok());
    }

    #[test]
    fn path_traversal_safe_filename() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "README.md").is_ok());
    }

    #[test]
    fn path_traversal_disabled_does_not_flag() {
        let v = InputValidator::new();
        assert!(v.validate_value("f", "/etc/passwd").is_ok());
    }

    // --- control characters ---

    #[test]
    fn control_char_null_byte() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "hello\x00world").is_err());
    }

    #[test]
    fn control_char_carriage_return() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "hello\rworld").is_err());
    }

    #[test]
    fn control_char_delete() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "hello\x7fworld").is_err());
    }

    #[test]
    fn control_char_tab_is_allowed() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "hello\tworld").is_ok());
    }

    #[test]
    fn control_char_newline_is_allowed() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "hello\nworld").is_ok());
    }

    #[test]
    fn control_char_safe_value() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "ordinary text 123").is_ok());
    }

    #[test]
    fn control_char_disabled_does_not_flag() {
        let v = InputValidator::new();
        assert!(v.validate_value("f", "hello\x00world").is_ok());
    }

    // --- query injection ---

    #[test]
    fn query_injection_question_mark() {
        let v = InputValidator::new().check_query_injection();
        assert!(v.validate_value("url", "example.com?admin=1").is_err());
    }

    #[test]
    fn query_injection_ampersand_key_val() {
        let v = InputValidator::new().check_query_injection();
        assert!(v.validate_value("q", "value&role=admin").is_err());
    }

    #[test]
    fn query_injection_ampersand_no_equals_safe() {
        let v = InputValidator::new().check_query_injection();
        assert!(v.validate_value("q", "Tom & Jerry").is_ok());
    }

    #[test]
    fn query_injection_safe_value() {
        let v = InputValidator::new().check_query_injection();
        assert!(v.validate_value("q", "normal search term").is_ok());
    }

    #[test]
    fn query_injection_disabled_does_not_flag() {
        let v = InputValidator::new();
        assert!(v.validate_value("q", "example.com?admin=1").is_ok());
    }

    // --- URL encoding ---

    #[test]
    fn url_encoding_percent_2f() {
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "hello%2Fworld").is_err());
    }

    #[test]
    fn url_encoding_percent_00() {
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "null%00byte").is_err());
    }

    #[test]
    fn url_encoding_uppercase_hex() {
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "%2E%2E%2F").is_err());
    }

    #[test]
    fn url_encoding_lone_percent_is_safe() {
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "50% off").is_ok());
    }

    #[test]
    fn url_encoding_safe_value() {
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "hello world").is_ok());
    }

    #[test]
    fn url_encoding_disabled_does_not_flag() {
        let v = InputValidator::new();
        assert!(v.validate_value("f", "hello%2Fworld").is_ok());
    }

    // --- strict preset ---

    #[test]
    fn strict_catches_path_traversal() {
        let v = InputValidator::strict();
        let err = v.validate_value("f", "../etc").unwrap_err();
        assert!(matches!(err, ValidationError::PathTraversal { .. }));
    }

    #[test]
    fn strict_catches_control_char() {
        let v = InputValidator::strict();
        let err = v.validate_value("f", "a\x01b").unwrap_err();
        assert!(matches!(err, ValidationError::ControlCharacter { .. }));
    }

    #[test]
    fn strict_catches_query_injection() {
        let v = InputValidator::strict();
        let err = v.validate_value("f", "x?y=z").unwrap_err();
        assert!(matches!(err, ValidationError::QueryInjection { .. }));
    }

    #[test]
    fn strict_catches_url_encoding() {
        let v = InputValidator::strict();
        let err = v.validate_value("f", "%41").unwrap_err();
        assert!(matches!(err, ValidationError::UrlEncoding { .. }));
    }

    #[test]
    fn strict_safe_value_passes() {
        let v = InputValidator::strict();
        assert!(v.validate_value("f", "hello world").is_ok());
    }

    // --- parsed commands ---

    #[test]
    fn validate_parsed_clean_args_pass() {
        let cmd = Command::builder("get")
            .argument(Argument::builder("id").required().build().unwrap())
            .build()
            .unwrap();
        let cmds = vec![cmd];
        let parser = Parser::new(&cmds);
        let parsed = parser.parse(&["get", "42"]).unwrap();
        let v = InputValidator::strict();
        assert!(v.validate_parsed(&parsed).is_ok());
    }

    #[test]
    fn validate_parsed_bad_arg_fails() {
        let cmd = Command::builder("get")
            .argument(Argument::builder("id").required().build().unwrap())
            .build()
            .unwrap();
        let cmds = vec![cmd];
        let parser = Parser::new(&cmds);
        let parsed = parser.parse(&["get", "../secret"]).unwrap();
        let v = InputValidator::new().check_path_traversal();
        let err = v.validate_parsed(&parsed).unwrap_err();
        assert!(matches!(err, ValidationError::PathTraversal { .. }));
    }

    #[test]
    fn validate_parsed_bad_flag_fails() {
        let cmd = Command::builder("deploy")
            .flag(
                Flag::builder("env")
                    .takes_value()
                    .required()
                    .build()
                    .unwrap(),
            )
            .build()
            .unwrap();
        let cmds = vec![cmd];
        let parser = Parser::new(&cmds);
        let parsed = parser.parse(&["deploy", "--env", "prod?debug=1"]).unwrap();
        let v = InputValidator::new().check_query_injection();
        let err = v.validate_parsed(&parsed).unwrap_err();
        assert!(matches!(err, ValidationError::QueryInjection { .. }));
    }

    // --- middleware integration ---

    #[test]
    fn middleware_before_dispatch_ok_for_clean_input() {
        let cmd = Command::builder("ping").build().unwrap();
        let cmds = vec![cmd];
        let parsed = Parser::new(&cmds).parse(&["ping"]).unwrap();
        let v = InputValidator::strict();
        assert!(v.before_dispatch(&parsed).is_ok());
    }

    #[test]
    fn middleware_before_dispatch_err_for_bad_input() {
        let cmd = Command::builder("get")
            .argument(Argument::builder("path").required().build().unwrap())
            .build()
            .unwrap();
        let cmds = vec![cmd];
        let parsed = Parser::new(&cmds).parse(&["get", "/etc/passwd"]).unwrap();
        let v = InputValidator::new().check_path_traversal();
        let err = v.before_dispatch(&parsed).unwrap_err();
        // The boxed error must still be the concrete validation error.
        assert!(matches!(
            err.downcast_ref::<ValidationError>(),
            Some(ValidationError::PathTraversal { .. })
        ));
    }

    // --- Display output ---
    // Field names are matched with their surrounding backticks: bare
    // substrings such as "q" or "val" also occur in the fixed message text
    // ("query", "value"), which would make those assertions vacuous.

    #[test]
    fn error_display_path_traversal() {
        let err = ValidationError::PathTraversal {
            field: "file".to_owned(),
            value: "../secret".to_owned(),
        };
        let msg = err.to_string();
        assert!(msg.contains("`file`"));
        assert!(msg.contains("../secret"));
    }

    #[test]
    fn error_display_control_character() {
        let err = ValidationError::ControlCharacter {
            field: "name".to_owned(),
            value: "a\x00b".to_owned(),
        };
        let msg = err.to_string();
        assert!(msg.contains("`name`"));
        // The value is rendered with `{:?}`, so the NUL shows up escaped, not raw.
        assert!(msg.contains(r#""a\0b""#));
        assert!(!msg.contains('\0'));
    }

    #[test]
    fn error_display_query_injection() {
        let err = ValidationError::QueryInjection {
            field: "q".to_owned(),
            value: "x?y=1".to_owned(),
        };
        let msg = err.to_string();
        assert!(msg.contains("`q`"));
        assert!(msg.contains("x?y=1"));
    }

    #[test]
    fn error_display_url_encoding() {
        let err = ValidationError::UrlEncoding {
            field: "val".to_owned(),
            value: "%2F".to_owned(),
        };
        let msg = err.to_string();
        assert!(msg.contains("`val`"));
        assert!(msg.contains("%2F"));
    }
}