Skip to main content

argot_cmd/
input_validation.rs

//! Hardened input validation middleware for agent-generated CLI input.
//!
//! AI agents may hallucinate or produce adversarial values. This module
//! provides [`InputValidator`], an opt-in validator that can be wired into
//! the [`crate::cli::Cli`] dispatch loop as a [`crate::middleware::Middleware`],
//! or called directly via [`InputValidator::validate_value`] and
//! [`InputValidator::validate_parsed`].
//!
//! ## Example
//!
//! ```
//! use argot_cmd::input_validation::InputValidator;
//! use argot_cmd::Middleware;
//! use argot_cmd::ParsedCommand;
//!
//! // Enable all checks at once.
//! let validator = InputValidator::strict();
//!
//! // Or selectively opt in.
//! let selective = InputValidator::new()
//!     .check_path_traversal()
//!     .check_control_chars();
//! ```
24
25use thiserror::Error;
26
27use crate::middleware::Middleware;
28use crate::model::ParsedCommand;
29
30/// Errors produced by [`InputValidator`].
31#[derive(Debug, Error, PartialEq)]
32pub enum ValidationError {
33    /// A field value contains a path traversal sequence (`../`, `..\`, `/…`, or `~…`).
34    #[error(
35        "field `{field}` contains a path traversal sequence in value: {value:?}"
36    )]
37    PathTraversal {
38        /// Name of the argument or flag that triggered the error.
39        field: String,
40        /// The offending value.
41        value: String,
42    },
43
44    /// A field value contains an ASCII control character (0x00–0x1F or 0x7F),
45    /// excluding horizontal tab (0x09) and newline (0x0A).
46    #[error(
47        "field `{field}` contains a control character in value: {value:?}"
48    )]
49    ControlCharacter {
50        /// Name of the argument or flag that triggered the error.
51        field: String,
52        /// The offending value.
53        value: String,
54    },
55
56    /// A field value appears to contain an embedded URL query string
57    /// (`?` or `&key=val` patterns).
58    #[error(
59        "field `{field}` contains an embedded query parameter in value: {value:?}"
60    )]
61    QueryInjection {
62        /// Name of the argument or flag that triggered the error.
63        field: String,
64        /// The offending value.
65        value: String,
66    },
67
68    /// A field value contains a percent-encoded (`%XX`) sequence.
69    #[error(
70        "field `{field}` contains a URL-encoded sequence in value: {value:?}"
71    )]
72    UrlEncoding {
73        /// Name of the argument or flag that triggered the error.
74        field: String,
75        /// The offending value.
76        value: String,
77    },
78}
79
/// A configurable, opt-in checker for argument and flag values of a parsed
/// command.
///
/// [`InputValidator::new`] yields an instance with every check switched off;
/// the builder methods turn individual checks on, and
/// [`InputValidator::strict`] turns all of them on in one call.
///
/// # Examples
///
/// ```
/// use argot_cmd::input_validation::InputValidator;
///
/// // Only check for path traversal.
/// let v = InputValidator::new().check_path_traversal();
/// assert!(v.validate_value("file", "safe_name.txt").is_ok());
/// assert!(v.validate_value("file", "../etc/passwd").is_err());
/// ```
#[derive(Debug, Clone, Default)]
pub struct InputValidator {
    /// When `true`, values are screened for path-traversal patterns.
    path_traversal: bool,
    /// When `true`, values are screened for ASCII control characters.
    control_chars: bool,
    /// When `true`, values are screened for embedded query parameters.
    query_injection: bool,
    /// When `true`, values are screened for `%XX` percent-encoded sequences.
    url_encoding: bool,
}
103
104impl InputValidator {
105    /// Create a new [`InputValidator`] with all checks **disabled**.
106    ///
107    /// Use the builder methods (`.check_path_traversal()`, etc.) to opt in
108    /// to specific checks, or call [`InputValidator::strict`] to enable all
109    /// of them at once.
110    pub fn new() -> Self {
111        Self::default()
112    }
113
114    /// Create an [`InputValidator`] with **all** checks enabled.
115    ///
116    /// Equivalent to:
117    /// ```
118    /// # use argot_cmd::input_validation::InputValidator;
119    /// InputValidator::new()
120    ///     .check_path_traversal()
121    ///     .check_control_chars()
122    ///     .check_query_injection()
123    ///     .check_url_encoding();
124    /// ```
125    pub fn strict() -> Self {
126        Self {
127            path_traversal: true,
128            control_chars: true,
129            query_injection: true,
130            url_encoding: true,
131        }
132    }
133
134    /// Enable path-traversal detection.
135    ///
136    /// Flags values containing `../`, `..\`, or starting with `/` or `~`.
137    pub fn check_path_traversal(mut self) -> Self {
138        self.path_traversal = true;
139        self
140    }
141
142    /// Enable control-character detection.
143    ///
144    /// Flags values containing ASCII bytes in the range 0x00–0x1F or 0x7F,
145    /// **except** horizontal tab (0x09) and newline (0x0A).
146    pub fn check_control_chars(mut self) -> Self {
147        self.control_chars = true;
148        self
149    }
150
151    /// Enable embedded query-parameter detection.
152    ///
153    /// Flags values that contain `?` or match the pattern `&<key>=<val>`,
154    /// which may indicate URL-injection attempts.
155    pub fn check_query_injection(mut self) -> Self {
156        self.query_injection = true;
157        self
158    }
159
160    /// Enable percent-encoded string detection.
161    ///
162    /// Flags values containing `%XX` sequences (where `XX` is a pair of hex
163    /// digits), which may indicate attempts to smuggle disallowed characters
164    /// past earlier checks.
165    pub fn check_url_encoding(mut self) -> Self {
166        self.url_encoding = true;
167        self
168    }
169
170    /// Validate a single named value against all enabled checks.
171    ///
172    /// Returns the first [`ValidationError`] encountered, or `Ok(())` if the
173    /// value passes every enabled check.
174    ///
175    /// # Arguments
176    ///
177    /// * `field` — the name of the argument or flag being validated (used in
178    ///   the error message).
179    /// * `value` — the string value to inspect.
180    ///
181    /// # Examples
182    ///
183    /// ```
184    /// use argot_cmd::input_validation::InputValidator;
185    ///
186    /// let v = InputValidator::strict();
187    /// assert!(v.validate_value("path", "hello.txt").is_ok());
188    /// assert!(v.validate_value("path", "../secret").is_err());
189    /// ```
190    pub fn validate_value(&self, field: &str, value: &str) -> Result<(), ValidationError> {
191        if self.path_traversal {
192            if contains_path_traversal(value) {
193                return Err(ValidationError::PathTraversal {
194                    field: field.to_owned(),
195                    value: value.to_owned(),
196                });
197            }
198        }
199
200        if self.control_chars {
201            if contains_control_char(value) {
202                return Err(ValidationError::ControlCharacter {
203                    field: field.to_owned(),
204                    value: value.to_owned(),
205                });
206            }
207        }
208
209        if self.query_injection {
210            if contains_query_injection(value) {
211                return Err(ValidationError::QueryInjection {
212                    field: field.to_owned(),
213                    value: value.to_owned(),
214                });
215            }
216        }
217
218        if self.url_encoding {
219            if contains_url_encoding(value) {
220                return Err(ValidationError::UrlEncoding {
221                    field: field.to_owned(),
222                    value: value.to_owned(),
223                });
224            }
225        }
226
227        Ok(())
228    }
229
230    /// Validate all argument and flag values in a [`ParsedCommand`].
231    ///
232    /// Iterates over every entry in `parsed.args` and `parsed.flags` and calls
233    /// [`InputValidator::validate_value`] on each. Returns the first error
234    /// encountered, or `Ok(())` when every value passes.
235    ///
236    /// # Examples
237    ///
238    /// ```
239    /// use argot_cmd::{Command, Argument, Parser};
240    /// use argot_cmd::input_validation::InputValidator;
241    ///
242    /// let cmd = Command::builder("get")
243    ///     .argument(Argument::builder("id").required().build().unwrap())
244    ///     .build()
245    ///     .unwrap();
246    /// let cmds = vec![cmd];
247    /// let parser = Parser::new(&cmds);
248    /// let parsed = parser.parse(&["get", "safe_value"]).unwrap();
249    ///
250    /// let v = InputValidator::strict();
251    /// assert!(v.validate_parsed(&parsed).is_ok());
252    /// ```
253    pub fn validate_parsed(&self, parsed: &ParsedCommand<'_>) -> Result<(), ValidationError> {
254        for (field, value) in &parsed.args {
255            self.validate_value(field, value)?;
256        }
257        for (field, value) in &parsed.flags {
258            self.validate_value(field, value)?;
259        }
260        Ok(())
261    }
262}
263
264impl Middleware for InputValidator {
265    /// Validate all argument and flag values before the handler is invoked.
266    ///
267    /// Returns a [`ValidationError`] (boxed) if any enabled check fails,
268    /// which causes [`crate::cli::Cli`] to abort dispatch and surface the
269    /// error to the caller.
270    fn before_dispatch(
271        &self,
272        parsed: &ParsedCommand<'_>,
273    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
274        self.validate_parsed(parsed)
275            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
276    }
277}
278
279// ── Internal helpers ──────────────────────────────────────────────────────────
280
/// Returns `true` when `value` contains a path traversal pattern.
///
/// A value is flagged when any of the following holds:
///
/// * it contains `../` or `..\` anywhere;
/// * it is exactly `..`, or ends with a `/..` or `\..` component — a parent
///   reference with no trailing separator climbs a directory just the same;
/// * it starts with `/` or `~`.
///
/// Names that merely contain dots (for example `notes..` or `v1..2`) are not
/// flagged unless they also match one of the rules above.
fn contains_path_traversal(value: &str) -> bool {
    let has_parent_component = value == ".."
        || value.contains("../")
        || value.contains("..\\")
        || value.ends_with("/..")
        || value.ends_with("\\..");

    has_parent_component || value.starts_with('/') || value.starts_with('~')
}
288
/// Reports whether `value` holds any ASCII control byte (0x00–0x1F or 0x7F)
/// apart from horizontal tab (0x09) and newline (0x0A), which are tolerated.
///
/// Only ASCII controls are considered; bytes of multi-byte UTF-8 sequences
/// are all >= 0x80 and therefore never match.
fn contains_control_char(value: &str) -> bool {
    value
        .bytes()
        .any(|byte| byte.is_ascii_control() && !matches!(byte, b'\t' | b'\n'))
}
298
/// Returns `true` when `value` contains `?` or an `&key=val`-like pattern.
///
/// The exact rule for the second case is: some `&` is followed, anywhere
/// later in the value, by an `=`, and the byte immediately after that `&` is
/// not itself `=`. No further restriction is placed on the bytes between the
/// `&` and the `=`, so `"Tom & Jerry = friends"` is flagged while
/// `"Tom & Jerry"` and `"a&=b"` are not.
///
/// The scan is a single linear pass, so long values made of many `&` bytes
/// are not rescanned once per `&`.
fn contains_query_injection(value: &str) -> bool {
    if value.contains('?') {
        return true;
    }

    let bytes = value.as_bytes();
    match bytes.iter().rposition(|&b| b == b'=') {
        // Every `&` that sits before the last `=` has an `=` somewhere after
        // it; it counts as a key when the byte right after it is not `=`.
        // Restricting the windows to the bytes before the last `=` keeps
        // both halves of each pair strictly in front of that `=`.
        Some(last_eq) => bytes[..last_eq]
            .windows(2)
            .any(|pair| pair[0] == b'&' && pair[1] != b'='),
        // Without any `=` there can be no `key=val` pair.
        None => false,
    }
}
323
/// Reports whether `value` holds a percent-encoded `%XX` sequence, i.e. a `%`
/// immediately followed by two ASCII hexadecimal digits (either case).
///
/// A `%` that is not followed by two hex digits — including one at the very
/// end of the value — is not reported.
fn contains_url_encoding(value: &str) -> bool {
    value.as_bytes().windows(3).any(|triple| {
        matches!(
            triple,
            [b'%', hi, lo] if hi.is_ascii_hexdigit() && lo.is_ascii_hexdigit()
        )
    })
}
338
339// ── Tests ─────────────────────────────────────────────────────────────────────
340
/// Unit tests for [`InputValidator`], grouped by check, followed by tests for
/// `validate_parsed`, the `Middleware` impl, and the error messages.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{Argument, Command, Flag};
    use crate::parser::Parser;

    // ── Path traversal ────────────────────────────────────────────────────────

    #[test]
    fn path_traversal_forward_slash_prefix() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "/etc/passwd").is_err());
    }

    #[test]
    fn path_traversal_tilde_prefix() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "~/.ssh/id_rsa").is_err());
    }

    #[test]
    fn path_traversal_dotdot_unix() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "../../secret").is_err());
    }

    #[test]
    fn path_traversal_dotdot_windows() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "..\\windows\\system32").is_err());
    }

    #[test]
    fn path_traversal_safe_relative_path() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "subdir/file.txt").is_ok());
    }

    #[test]
    fn path_traversal_safe_filename() {
        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_value("f", "README.md").is_ok());
    }

    #[test]
    fn path_traversal_disabled_does_not_flag() {
        let v = InputValidator::new(); // traversal check off
        assert!(v.validate_value("f", "/etc/passwd").is_ok());
    }

    // ── Control characters ────────────────────────────────────────────────────

    #[test]
    fn control_char_null_byte() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "hello\x00world").is_err());
    }

    #[test]
    fn control_char_carriage_return() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "hello\rworld").is_err());
    }

    #[test]
    fn control_char_delete() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "hello\x7fworld").is_err());
    }

    #[test]
    fn control_char_tab_is_allowed() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "hello\tworld").is_ok());
    }

    #[test]
    fn control_char_newline_is_allowed() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "hello\nworld").is_ok());
    }

    #[test]
    fn control_char_safe_value() {
        let v = InputValidator::new().check_control_chars();
        assert!(v.validate_value("f", "ordinary text 123").is_ok());
    }

    #[test]
    fn control_char_disabled_does_not_flag() {
        let v = InputValidator::new(); // control char check off
        assert!(v.validate_value("f", "hello\x00world").is_ok());
    }

    // ── Query injection ───────────────────────────────────────────────────────

    #[test]
    fn query_injection_question_mark() {
        let v = InputValidator::new().check_query_injection();
        assert!(v.validate_value("url", "example.com?admin=1").is_err());
    }

    #[test]
    fn query_injection_ampersand_key_val() {
        let v = InputValidator::new().check_query_injection();
        assert!(v.validate_value("q", "value&role=admin").is_err());
    }

    #[test]
    fn query_injection_ampersand_no_equals_safe() {
        let v = InputValidator::new().check_query_injection();
        // A lone '&' without an '=' after it is not flagged as key=val injection.
        assert!(v.validate_value("q", "Tom & Jerry").is_ok());
    }

    #[test]
    fn query_injection_safe_value() {
        let v = InputValidator::new().check_query_injection();
        assert!(v.validate_value("q", "normal search term").is_ok());
    }

    #[test]
    fn query_injection_disabled_does_not_flag() {
        let v = InputValidator::new(); // query check off
        assert!(v.validate_value("q", "example.com?admin=1").is_ok());
    }

    // ── URL encoding ──────────────────────────────────────────────────────────

    #[test]
    fn url_encoding_percent_2f() {
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "hello%2Fworld").is_err());
    }

    #[test]
    fn url_encoding_percent_00() {
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "null%00byte").is_err());
    }

    #[test]
    fn url_encoding_uppercase_hex() {
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "%2E%2E%2F").is_err());
    }

    #[test]
    fn url_encoding_lone_percent_is_safe() {
        let v = InputValidator::new().check_url_encoding();
        // A bare '%' not followed by two hex digits is not flagged.
        assert!(v.validate_value("f", "50% off").is_ok());
    }

    #[test]
    fn url_encoding_safe_value() {
        let v = InputValidator::new().check_url_encoding();
        assert!(v.validate_value("f", "hello world").is_ok());
    }

    #[test]
    fn url_encoding_disabled_does_not_flag() {
        let v = InputValidator::new(); // url encoding check off
        assert!(v.validate_value("f", "hello%2Fworld").is_ok());
    }

    // ── strict() helper ───────────────────────────────────────────────────────

    #[test]
    fn strict_catches_path_traversal() {
        let v = InputValidator::strict();
        let err = v.validate_value("f", "../etc").unwrap_err();
        assert!(matches!(err, ValidationError::PathTraversal { .. }));
    }

    #[test]
    fn strict_catches_control_char() {
        let v = InputValidator::strict();
        let err = v.validate_value("f", "a\x01b").unwrap_err();
        assert!(matches!(err, ValidationError::ControlCharacter { .. }));
    }

    #[test]
    fn strict_catches_query_injection() {
        let v = InputValidator::strict();
        let err = v.validate_value("f", "x?y=z").unwrap_err();
        assert!(matches!(err, ValidationError::QueryInjection { .. }));
    }

    #[test]
    fn strict_catches_url_encoding() {
        let v = InputValidator::strict();
        let err = v.validate_value("f", "%41").unwrap_err();
        assert!(matches!(err, ValidationError::UrlEncoding { .. }));
    }

    #[test]
    fn strict_safe_value_passes() {
        let v = InputValidator::strict();
        assert!(v.validate_value("f", "hello world").is_ok());
    }

    // ── validate_parsed ───────────────────────────────────────────────────────

    #[test]
    fn validate_parsed_clean_args_pass() {
        let cmd = Command::builder("get")
            .argument(Argument::builder("id").required().build().unwrap())
            .build()
            .unwrap();
        let cmds = vec![cmd];
        let parser = Parser::new(&cmds);
        let parsed = parser.parse(&["get", "42"]).unwrap();

        let v = InputValidator::strict();
        assert!(v.validate_parsed(&parsed).is_ok());
    }

    #[test]
    fn validate_parsed_bad_arg_fails() {
        let cmd = Command::builder("get")
            .argument(Argument::builder("id").required().build().unwrap())
            .build()
            .unwrap();
        let cmds = vec![cmd];
        let parser = Parser::new(&cmds);
        let parsed = parser.parse(&["get", "../secret"]).unwrap();

        let v = InputValidator::new().check_path_traversal();
        assert!(v.validate_parsed(&parsed).is_err());
    }

    #[test]
    fn validate_parsed_bad_flag_fails() {
        let cmd = Command::builder("deploy")
            .flag(
                Flag::builder("env")
                    .takes_value()
                    .required()
                    .build()
                    .unwrap(),
            )
            .build()
            .unwrap();
        let cmds = vec![cmd];
        let parser = Parser::new(&cmds);
        let parsed = parser.parse(&["deploy", "--env", "prod?debug=1"]).unwrap();

        let v = InputValidator::new().check_query_injection();
        assert!(v.validate_parsed(&parsed).is_err());
    }

    // ── Middleware impl ───────────────────────────────────────────────────────

    #[test]
    fn middleware_before_dispatch_ok_for_clean_input() {
        let cmd = Command::builder("ping").build().unwrap();
        let cmds = vec![cmd];
        let parsed = Parser::new(&cmds).parse(&["ping"]).unwrap();

        let v = InputValidator::strict();
        assert!(v.before_dispatch(&parsed).is_ok());
    }

    #[test]
    fn middleware_before_dispatch_err_for_bad_input() {
        let cmd = Command::builder("get")
            .argument(Argument::builder("path").required().build().unwrap())
            .build()
            .unwrap();
        let cmds = vec![cmd];
        let parsed = Parser::new(&cmds).parse(&["get", "/etc/passwd"]).unwrap();

        let v = InputValidator::new().check_path_traversal();
        let result = v.before_dispatch(&parsed);
        assert!(result.is_err());
    }

    // ── Error messages ────────────────────────────────────────────────────────
    //
    // Field names are matched together with the back-ticks the message wraps
    // them in. A bare substring such as "q" or "val" would also match the
    // words "query" / "value" in the fixed part of the message, making the
    // assertion pass even if the field were never interpolated.

    #[test]
    fn error_display_path_traversal() {
        let err = ValidationError::PathTraversal {
            field: "file".to_owned(),
            value: "../secret".to_owned(),
        };
        let msg = err.to_string();
        assert!(msg.contains("`file`"));
        assert!(msg.contains("../secret"));
    }

    #[test]
    fn error_display_control_character() {
        let err = ValidationError::ControlCharacter {
            field: "name".to_owned(),
            value: "a\x00b".to_owned(),
        };
        let msg = err.to_string();
        assert!(msg.contains("`name`"));
    }

    #[test]
    fn error_display_query_injection() {
        let err = ValidationError::QueryInjection {
            field: "q".to_owned(),
            value: "x?y=1".to_owned(),
        };
        let msg = err.to_string();
        assert!(msg.contains("`q`"));
        assert!(msg.contains("x?y=1"));
    }

    #[test]
    fn error_display_url_encoding() {
        let err = ValidationError::UrlEncoding {
            field: "val".to_owned(),
            value: "%2F".to_owned(),
        };
        let msg = err.to_string();
        assert!(msg.contains("`val`"));
        assert!(msg.contains("%2F"));
    }
}