1use thiserror::Error;
26
27use crate::middleware::Middleware;
28use crate::model::ParsedCommand;
29
30#[derive(Debug, Error, PartialEq)]
32pub enum ValidationError {
33 #[error(
35 "field `{field}` contains a path traversal sequence in value: {value:?}"
36 )]
37 PathTraversal {
38 field: String,
40 value: String,
42 },
43
44 #[error(
47 "field `{field}` contains a control character in value: {value:?}"
48 )]
49 ControlCharacter {
50 field: String,
52 value: String,
54 },
55
56 #[error(
59 "field `{field}` contains an embedded query parameter in value: {value:?}"
60 )]
61 QueryInjection {
62 field: String,
64 value: String,
66 },
67
68 #[error(
70 "field `{field}` contains a URL-encoded sequence in value: {value:?}"
71 )]
72 UrlEncoding {
73 field: String,
75 value: String,
77 },
78}
79
/// Configurable validator that rejects suspicious argument and flag values.
///
/// Every check is disabled in the `Default` value (all fields `false`); enable
/// checks individually through the `check_*` builder methods, or all at once
/// through the strict constructor.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct InputValidator {
    /// Enables the path-traversal check.
    path_traversal: bool,
    /// Enables the control-character check.
    control_chars: bool,
    /// Enables the embedded-query-parameter check.
    query_injection: bool,
    /// Enables the URL-encoding check.
    url_encoding: bool,
}
103
104impl InputValidator {
105 pub fn new() -> Self {
111 Self::default()
112 }
113
114 pub fn strict() -> Self {
126 Self {
127 path_traversal: true,
128 control_chars: true,
129 query_injection: true,
130 url_encoding: true,
131 }
132 }
133
134 pub fn check_path_traversal(mut self) -> Self {
138 self.path_traversal = true;
139 self
140 }
141
142 pub fn check_control_chars(mut self) -> Self {
147 self.control_chars = true;
148 self
149 }
150
151 pub fn check_query_injection(mut self) -> Self {
156 self.query_injection = true;
157 self
158 }
159
160 pub fn check_url_encoding(mut self) -> Self {
166 self.url_encoding = true;
167 self
168 }
169
170 pub fn validate_value(&self, field: &str, value: &str) -> Result<(), ValidationError> {
191 if self.path_traversal {
192 if contains_path_traversal(value) {
193 return Err(ValidationError::PathTraversal {
194 field: field.to_owned(),
195 value: value.to_owned(),
196 });
197 }
198 }
199
200 if self.control_chars {
201 if contains_control_char(value) {
202 return Err(ValidationError::ControlCharacter {
203 field: field.to_owned(),
204 value: value.to_owned(),
205 });
206 }
207 }
208
209 if self.query_injection {
210 if contains_query_injection(value) {
211 return Err(ValidationError::QueryInjection {
212 field: field.to_owned(),
213 value: value.to_owned(),
214 });
215 }
216 }
217
218 if self.url_encoding {
219 if contains_url_encoding(value) {
220 return Err(ValidationError::UrlEncoding {
221 field: field.to_owned(),
222 value: value.to_owned(),
223 });
224 }
225 }
226
227 Ok(())
228 }
229
230 pub fn validate_parsed(&self, parsed: &ParsedCommand<'_>) -> Result<(), ValidationError> {
254 for (field, value) in &parsed.args {
255 self.validate_value(field, value)?;
256 }
257 for (field, value) in &parsed.flags {
258 self.validate_value(field, value)?;
259 }
260 Ok(())
261 }
262}
263
264impl Middleware for InputValidator {
265 fn before_dispatch(
271 &self,
272 parsed: &ParsedCommand<'_>,
273 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
274 self.validate_parsed(parsed)
275 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
276 }
277}
278
/// Returns `true` if `value` looks like a path that could escape the working
/// directory.
///
/// A value is rejected when any of the following holds:
/// - it contains a `../` or `..\` sequence;
/// - one of its components (split on `/` or `\`) is exactly `..`. This also
///   catches a bare `..` and a trailing `/..`, which name the parent directory
///   without a following separator and so slip past the sequence checks;
/// - it is absolute: it starts with `/`, with `\` (Windows root-relative or
///   UNC `\\server\share`), or with a drive prefix such as `C:\` or `C:/`;
/// - it starts with `~` (home-directory shorthand).
fn contains_path_traversal(value: &str) -> bool {
    let has_dotdot_sequence = value.contains("../") || value.contains("..\\");
    let has_dotdot_component = value
        .split(|c: char| c == '/' || c == '\\')
        .any(|component| component == "..");
    let is_absolute = value.starts_with('/')
        || value.starts_with('\\')
        || matches!(
            value.as_bytes(),
            [drive, b':', b'/' | b'\\', ..] if drive.is_ascii_alphabetic()
        );
    let is_home_relative = value.starts_with('~');

    has_dotdot_sequence || has_dotdot_component || is_absolute || is_home_relative
}
288
/// Returns `true` if `value` contains a control character other than tab
/// (`\t`) or line feed (`\n`).
///
/// Uses [`char::is_control`] (Unicode general category `Cc`). That covers the
/// C0 range U+0000–U+001F, DEL U+007F and the C1 range U+0080–U+009F. C1 codes
/// such as U+009B (CSI) can start terminal escape sequences just like ESC, so
/// an ASCII-only byte check is not sufficient.
fn contains_control_char(value: &str) -> bool {
    value
        .chars()
        .any(|c| c.is_control() && c != '\t' && c != '\n')
}
298
/// Returns `true` if `value` appears to smuggle in extra query parameters.
///
/// Two patterns are rejected:
/// - any `?`, which would start a query string;
/// - an `&` that is immediately followed by at least one byte other than `=`
///   and has an `=` somewhere after it (e.g. `value&role=admin`).
///
/// A lone `&` with no later `=` (`Tom & Jerry`) is accepted, as is `&=` with an
/// empty key. The `=` need not be adjacent to the key, so `a & b = c` is
/// rejected too.
///
/// The check is linear in `value.len()`. A naive "scan the tail after every
/// `&`" approach is quadratic on inputs such as a long run of `&` with no `=`.
fn contains_query_injection(value: &str) -> bool {
    if value.contains('?') {
        return true;
    }
    let bytes = value.as_bytes();
    // Any qualifying `&` must lie before the last `=`; with no `=` at all,
    // nothing can match.
    match bytes.iter().rposition(|&b| b == b'=') {
        // A pair `[&, k]` with `k != '='`, lying entirely before `last_eq`,
        // means that `&` has a non-empty key and an `=` somewhere after it.
        Some(last_eq) => bytes[..last_eq]
            .windows(2)
            .any(|pair| pair[0] == b'&' && pair[1] != b'='),
        None => false,
    }
}
323
/// Returns `true` if `value` contains a percent-encoded byte: a `%` directly
/// followed by two ASCII hex digits, in either case.
///
/// A `%` that is not followed by two hex digits (`50% off`, or a truncated
/// `%4` at the end) is treated as ordinary text.
fn contains_url_encoding(value: &str) -> bool {
    value.as_bytes().windows(3).any(|triple| match triple {
        [b'%', hi, lo] => hi.is_ascii_hexdigit() && lo.is_ascii_hexdigit(),
        _ => false,
    })
}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{Argument, Command, Flag};
    use crate::parser::Parser;

    // ---- path traversal ----------------------------------------------------

    #[test]
    fn path_traversal_forward_slash_prefix() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "/etc/passwd").is_err());
    }

    #[test]
    fn path_traversal_tilde_prefix() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "~/.ssh/id_rsa").is_err());
    }

    #[test]
    fn path_traversal_dotdot_unix() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "../../secret").is_err());
    }

    #[test]
    fn path_traversal_dotdot_windows() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "..\\windows\\system32").is_err());
    }

    #[test]
    fn path_traversal_safe_relative_path() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "subdir/file.txt").is_ok());
    }

    #[test]
    fn path_traversal_safe_filename() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "README.md").is_ok());
    }

    #[test]
    fn path_traversal_disabled_does_not_flag() {
        let v = InputValidator::new();
        assert!(v.validate_value("f", "/etc/passwd").is_ok());
    }

    // ---- control characters ------------------------------------------------

    #[test]
    fn control_char_null_byte() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "hello\x00world").is_err());
    }

    #[test]
    fn control_char_carriage_return() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "hello\rworld").is_err());
    }

    #[test]
    fn control_char_delete() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "hello\x7fworld").is_err());
    }

    #[test]
    fn control_char_escape() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "\x1b[31mred").is_err());
    }

    #[test]
    fn control_char_tab_is_allowed() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "hello\tworld").is_ok());
    }

    #[test]
    fn control_char_newline_is_allowed() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "hello\nworld").is_ok());
    }

    #[test]
    fn control_char_safe_value() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "ordinary text 123").is_ok());
    }

    #[test]
    fn control_char_disabled_does_not_flag() {
        let v = InputValidator::new();
        assert!(v.validate_value("f", "hello\x00world").is_ok());
    }

    // ---- query injection ---------------------------------------------------

    #[test]
    fn query_injection_question_mark() {
        let v = InputValidator::new().check_query_injection();
        assert!(v.validate_value("url", "example.com?admin=1").is_err());
    }

    #[test]
    fn query_injection_ampersand_key_val() {
        let v = InputValidator::new().check_query_injection();
        assert!(v.validate_value("q", "value&role=admin").is_err());
    }

    #[test]
    fn query_injection_ampersand_no_equals_safe() {
        // A bare `&` with no `=` after it is ordinary text.
        let v = InputValidator::new().check_query_injection();
        assert!(v.validate_value("q", "Tom & Jerry").is_ok());
    }

    #[test]
    fn query_injection_ampersand_without_key_safe() {
        // `&=` has an empty key, so it is not treated as a parameter.
        let v = InputValidator::new().check_query_injection();
        assert!(v.validate_value("q", "&=x").is_ok());
    }

    #[test]
    fn query_injection_safe_value() {
        let v = InputValidator::new().check_query_injection();
        assert!(v.validate_value("q", "normal search term").is_ok());
    }

    #[test]
    fn query_injection_disabled_does_not_flag() {
        let v = InputValidator::new();
        assert!(v.validate_value("q", "example.com?admin=1").is_ok());
    }

    // ---- URL encoding ------------------------------------------------------

    #[test]
    fn url_encoding_percent_2f() {
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "hello%2Fworld").is_err());
    }

    #[test]
    fn url_encoding_percent_00() {
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "null%00byte").is_err());
    }

    #[test]
    fn url_encoding_uppercase_hex() {
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "%2E%2E%2F").is_err());
    }

    #[test]
    fn url_encoding_lowercase_hex() {
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "%2e%2e%2f").is_err());
    }

    #[test]
    fn url_encoding_lone_percent_is_safe() {
        // A `%` not followed by two hex digits is ordinary text.
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "50% off").is_ok());
    }

    #[test]
    fn url_encoding_truncated_sequence_is_safe() {
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "abc%4").is_ok());
    }

    #[test]
    fn url_encoding_safe_value() {
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "hello world").is_ok());
    }

    #[test]
    fn url_encoding_disabled_does_not_flag() {
        let v = InputValidator::new();
        assert!(v.validate_value("f", "hello%2Fworld").is_ok());
    }

    // ---- strict preset -----------------------------------------------------

    #[test]
    fn strict_catches_path_traversal() {
        let v = InputValidator::strict();
        let err = v.validate_value("f", "../etc").unwrap_err();
        assert!(matches!(err, ValidationError::PathTraversal { .. }));
    }

    #[test]
    fn strict_catches_control_char() {
        let v = InputValidator::strict();
        let err = v.validate_value("f", "a\x01b").unwrap_err();
        assert!(matches!(err, ValidationError::ControlCharacter { .. }));
    }

    #[test]
    fn strict_catches_query_injection() {
        let v = InputValidator::strict();
        let err = v.validate_value("f", "x?y=z").unwrap_err();
        assert!(matches!(err, ValidationError::QueryInjection { .. }));
    }

    #[test]
    fn strict_catches_url_encoding() {
        let v = InputValidator::strict();
        let err = v.validate_value("f", "%41").unwrap_err();
        assert!(matches!(err, ValidationError::UrlEncoding { .. }));
    }

    #[test]
    fn strict_reports_first_failing_check() {
        // The value trips every check; path traversal runs first, so it wins.
        let v = InputValidator::strict();
        let err = v.validate_value("f", "../x?y=%41").unwrap_err();
        assert_eq!(
            err,
            ValidationError::PathTraversal {
                field: "f".to_owned(),
                value: "../x?y=%41".to_owned(),
            }
        );
    }

    #[test]
    fn strict_safe_value_passes() {
        let v = InputValidator::strict();
        assert!(v.validate_value("f", "hello world").is_ok());
    }

    // ---- validate_parsed ---------------------------------------------------

    #[test]
    fn validate_parsed_clean_args_pass() {
        let cmd = Command::builder("get")
            .argument(Argument::builder("id").required().build().unwrap())
            .build()
            .unwrap();
        let cmds = vec![cmd];
        let parser = Parser::new(&cmds);
        let parsed = parser.parse(&["get", "42"]).unwrap();

        let v = InputValidator::strict();
        assert!(v.validate_parsed(&parsed).is_ok());
    }

    #[test]
    fn validate_parsed_bad_arg_fails() {
        let cmd = Command::builder("get")
            .argument(Argument::builder("id").required().build().unwrap())
            .build()
            .unwrap();
        let cmds = vec![cmd];
        let parser = Parser::new(&cmds);
        let parsed = parser.parse(&["get", "../secret"]).unwrap();

        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_parsed(&parsed).is_err());
    }

    #[test]
    fn validate_parsed_bad_flag_fails() {
        let cmd = Command::builder("deploy")
            .flag(
                Flag::builder("env")
                    .takes_value()
                    .required()
                    .build()
                    .unwrap(),
            )
            .build()
            .unwrap();
        let cmds = vec![cmd];
        let parser = Parser::new(&cmds);
        let parsed = parser.parse(&["deploy", "--env", "prod?debug=1"]).unwrap();

        let v = InputValidator::new().check_query_injection();
        assert!(v.validate_parsed(&parsed).is_err());
    }

    // ---- Middleware integration --------------------------------------------

    #[test]
    fn middleware_before_dispatch_ok_for_clean_input() {
        let cmd = Command::builder("ping").build().unwrap();
        let cmds = vec![cmd];
        let parsed = Parser::new(&cmds).parse(&["ping"]).unwrap();

        let v = InputValidator::strict();
        assert!(v.before_dispatch(&parsed).is_ok());
    }

    #[test]
    fn middleware_before_dispatch_err_for_bad_input() {
        let cmd = Command::builder("get")
            .argument(Argument::builder("path").required().build().unwrap())
            .build()
            .unwrap();
        let cmds = vec![cmd];
        let parsed = Parser::new(&cmds).parse(&["get", "/etc/passwd"]).unwrap();

        let v = InputValidator::new().check_path_traversal();
        let result = v.before_dispatch(&parsed);
        assert!(result.is_err());
    }

    // ---- error Display -----------------------------------------------------

    #[test]
    fn error_display_path_traversal() {
        let err = ValidationError::PathTraversal {
            field: "file".to_owned(),
            value: "../secret".to_owned(),
        };
        let msg = err.to_string();
        assert!(msg.contains("file"));
        assert!(msg.contains("../secret"));
    }

    #[test]
    fn error_display_control_character() {
        let err = ValidationError::ControlCharacter {
            field: "name".to_owned(),
            value: "a\x00b".to_owned(),
        };
        let msg = err.to_string();
        assert!(msg.contains("name"));
    }

    #[test]
    fn error_display_query_injection() {
        let err = ValidationError::QueryInjection {
            field: "q".to_owned(),
            value: "x?y=1".to_owned(),
        };
        let msg = err.to_string();
        assert!(msg.contains("q"));
    }

    #[test]
    fn error_display_url_encoding() {
        let err = ValidationError::UrlEncoding {
            field: "val".to_owned(),
            value: "%2F".to_owned(),
        };
        let msg = err.to_string();
        assert!(msg.contains("val"));
        assert!(msg.contains("%2F"));
    }
}