Skip to main content

cbtop/adversarial/
validators.rs

1//! Input and configuration validators for adversarial testing.
2
3use super::{AdversarialError, AdversarialResult};
4
/// Size and floating-point sanity checks applied to inputs before they reach
/// the code under adversarial test.
///
/// Every switch is a public field, so a validator obtained from `new()` or
/// `Default` can be adjusted in place.
#[derive(Clone, Debug)]
pub struct InputValidator {
    /// Largest accepted input, measured in bytes.
    pub max_size: usize,
    /// When `false`, checksum verification accepts any data without hashing it.
    pub verify_checksum: bool,
    /// When `true`, float validation rejects NaN elements.
    pub detect_nan: bool,
    /// When `true`, float validation rejects positive or negative infinity.
    pub detect_inf: bool,
}
17
18impl Default for InputValidator {
19    fn default() -> Self {
20        Self {
21            max_size: 1024 * 1024 * 1024, // 1GB default max
22            verify_checksum: true,
23            detect_nan: true,
24            detect_inf: true,
25        }
26    }
27}
28
29impl InputValidator {
30    /// Create a new input validator
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    /// Set maximum input size
36    pub fn with_max_size(mut self, max_size: usize) -> Self {
37        self.max_size = max_size;
38        self
39    }
40
41    /// Validate a byte slice
42    pub fn validate_bytes(&self, data: &[u8]) -> AdversarialResult<()> {
43        // F1004: Zero-size inputs handled
44        if data.is_empty() {
45            return Err(AdversarialError::ZeroSizeInput);
46        }
47
48        // F1005: Maximum-size inputs handled
49        if data.len() > self.max_size {
50            return Err(AdversarialError::MaxSizeExceeded {
51                size: data.len(),
52                max: self.max_size,
53            });
54        }
55
56        Ok(())
57    }
58
59    /// Validate a float slice for NaN and Inf
60    pub fn validate_floats(&self, data: &[f32]) -> AdversarialResult<()> {
61        // F1004: Zero-size inputs handled
62        if data.is_empty() {
63            return Err(AdversarialError::ZeroSizeInput);
64        }
65
66        // F1005: Maximum-size inputs handled
67        let byte_size = std::mem::size_of_val(data);
68        if byte_size > self.max_size {
69            return Err(AdversarialError::MaxSizeExceeded {
70                size: byte_size,
71                max: self.max_size,
72            });
73        }
74
75        // F1014: NaN propagation controlled
76        if self.detect_nan {
77            for (i, &v) in data.iter().enumerate() {
78                if v.is_nan() {
79                    return Err(AdversarialError::NaNDetected { index: i });
80                }
81            }
82        }
83
84        // F1015: Inf propagation controlled
85        if self.detect_inf {
86            for (i, &v) in data.iter().enumerate() {
87                if v.is_infinite() {
88                    return Err(AdversarialError::InfinityDetected {
89                        index: i,
90                        positive: v.is_sign_positive(),
91                    });
92                }
93            }
94        }
95
96        Ok(())
97    }
98
99    /// Compute a simple checksum for data validation
100    pub fn compute_checksum(data: &[u8]) -> u32 {
101        // Simple Adler-32-like checksum
102        let mut a: u32 = 1;
103        let mut b: u32 = 0;
104        for &byte in data {
105            a = (a.wrapping_add(u32::from(byte))) % 65521;
106            b = (b.wrapping_add(a)) % 65521;
107        }
108        (b << 16) | a
109    }
110
111    /// Verify data against expected checksum
112    pub fn verify_checksum(&self, data: &[u8], expected: u32) -> AdversarialResult<()> {
113        if !self.verify_checksum {
114            return Ok(());
115        }
116
117        let actual = Self::compute_checksum(data);
118        if actual != expected {
119            return Err(AdversarialError::CorruptedInput {
120                byte_index: 0, // Can't pinpoint exact corruption
121                expected_checksum: expected,
122                actual_checksum: actual,
123            });
124        }
125        Ok(())
126    }
127}
128
/// Per-field numeric bounds used when fuzzing configuration values (F1008, F1009).
///
/// `mins` and `maxs` are independent maps keyed by field name. A field can
/// therefore carry a lower bound, an upper bound, both, or neither.
#[derive(Clone, Debug)]
pub struct ConfigValidator {
    /// Inclusive lower bound for each field that has one.
    pub mins: std::collections::HashMap<String, f64>,
    /// Inclusive upper bound for each field that has one.
    pub maxs: std::collections::HashMap<String, f64>,
}
137
138impl Default for ConfigValidator {
139    fn default() -> Self {
140        Self {
141            mins: std::collections::HashMap::new(),
142            maxs: std::collections::HashMap::new(),
143        }
144    }
145}
146
147impl ConfigValidator {
148    /// Create a new config validator
149    pub fn new() -> Self {
150        Self::default()
151    }
152
153    /// Add a bound for a field
154    pub fn with_bound(mut self, field: &str, min: f64, max: f64) -> Self {
155        self.mins.insert(field.to_string(), min);
156        self.maxs.insert(field.to_string(), max);
157        self
158    }
159
160    /// Validate a numeric config value (F1009)
161    pub fn validate_numeric(&self, field: &str, value: f64) -> AdversarialResult<f64> {
162        // Check for NaN
163        if value.is_nan() {
164            return Err(AdversarialError::ConfigParseError {
165                field: field.to_string(),
166                reason: "value is NaN".to_string(),
167            });
168        }
169
170        // Check bounds
171        if let Some(&min) = self.mins.get(field) {
172            if value < min {
173                return Err(AdversarialError::ConfigOutOfBounds {
174                    field: field.to_string(),
175                    value: value.to_string(),
176                    min: min.to_string(),
177                    max: self
178                        .maxs
179                        .get(field)
180                        .map_or("unbounded".to_string(), |m| m.to_string()),
181                });
182            }
183        }
184
185        if let Some(&max) = self.maxs.get(field) {
186            if value > max {
187                return Err(AdversarialError::ConfigOutOfBounds {
188                    field: field.to_string(),
189                    value: value.to_string(),
190                    min: self
191                        .mins
192                        .get(field)
193                        .map_or("unbounded".to_string(), |m| m.to_string()),
194                    max: max.to_string(),
195                });
196            }
197        }
198
199        Ok(value)
200    }
201
202    /// Validate TOML-like string config (F1008)
203    pub fn validate_toml_string(&self, input: &str) -> AdversarialResult<()> {
204        // Check for common malformed TOML patterns
205        let trimmed = input.trim();
206
207        // Empty input
208        if trimmed.is_empty() {
209            return Err(AdversarialError::ConfigParseError {
210                field: "root".to_string(),
211                reason: "empty config".to_string(),
212            });
213        }
214
215        // Unclosed brackets
216        let open_brackets = trimmed.matches('[').count();
217        let close_brackets = trimmed.matches(']').count();
218        if open_brackets != close_brackets {
219            return Err(AdversarialError::ConfigParseError {
220                field: "root".to_string(),
221                reason: format!(
222                    "mismatched brackets: {} open, {} close",
223                    open_brackets, close_brackets
224                ),
225            });
226        }
227
228        // Unclosed quotes
229        let quotes = trimmed.matches('"').count();
230        if !quotes.is_multiple_of(2) {
231            return Err(AdversarialError::ConfigParseError {
232                field: "root".to_string(),
233                reason: "unclosed string literal".to_string(),
234            });
235        }
236
237        Ok(())
238    }
239}