cbtop/adversarial/
validators.rs1use super::{AdversarialError, AdversarialResult};
/// Configurable pre-flight checks for raw byte buffers and `f32` slices.
///
/// Each check can be toggled independently through the public fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InputValidator {
    /// Largest accepted input, measured in bytes (for `f32` slices, the slice's byte size).
    pub max_size: usize,
    /// When `false`, checksum verification is skipped and always succeeds.
    pub verify_checksum: bool,
    /// Reject `f32` inputs that contain any NaN.
    pub detect_nan: bool,
    /// Reject `f32` inputs that contain positive or negative infinity.
    pub detect_inf: bool,
}

18impl Default for InputValidator {
19 fn default() -> Self {
20 Self {
21 max_size: 1024 * 1024 * 1024, verify_checksum: true,
23 detect_nan: true,
24 detect_inf: true,
25 }
26 }
27}
28
29impl InputValidator {
30 pub fn new() -> Self {
32 Self::default()
33 }
34
35 pub fn with_max_size(mut self, max_size: usize) -> Self {
37 self.max_size = max_size;
38 self
39 }
40
41 pub fn validate_bytes(&self, data: &[u8]) -> AdversarialResult<()> {
43 if data.is_empty() {
45 return Err(AdversarialError::ZeroSizeInput);
46 }
47
48 if data.len() > self.max_size {
50 return Err(AdversarialError::MaxSizeExceeded {
51 size: data.len(),
52 max: self.max_size,
53 });
54 }
55
56 Ok(())
57 }
58
59 pub fn validate_floats(&self, data: &[f32]) -> AdversarialResult<()> {
61 if data.is_empty() {
63 return Err(AdversarialError::ZeroSizeInput);
64 }
65
66 let byte_size = std::mem::size_of_val(data);
68 if byte_size > self.max_size {
69 return Err(AdversarialError::MaxSizeExceeded {
70 size: byte_size,
71 max: self.max_size,
72 });
73 }
74
75 if self.detect_nan {
77 for (i, &v) in data.iter().enumerate() {
78 if v.is_nan() {
79 return Err(AdversarialError::NaNDetected { index: i });
80 }
81 }
82 }
83
84 if self.detect_inf {
86 for (i, &v) in data.iter().enumerate() {
87 if v.is_infinite() {
88 return Err(AdversarialError::InfinityDetected {
89 index: i,
90 positive: v.is_sign_positive(),
91 });
92 }
93 }
94 }
95
96 Ok(())
97 }
98
99 pub fn compute_checksum(data: &[u8]) -> u32 {
101 let mut a: u32 = 1;
103 let mut b: u32 = 0;
104 for &byte in data {
105 a = (a.wrapping_add(u32::from(byte))) % 65521;
106 b = (b.wrapping_add(a)) % 65521;
107 }
108 (b << 16) | a
109 }
110
111 pub fn verify_checksum(&self, data: &[u8], expected: u32) -> AdversarialResult<()> {
113 if !self.verify_checksum {
114 return Ok(());
115 }
116
117 let actual = Self::compute_checksum(data);
118 if actual != expected {
119 return Err(AdversarialError::CorruptedInput {
120 byte_index: 0, expected_checksum: expected,
122 actual_checksum: actual,
123 });
124 }
125 Ok(())
126 }
127}
128
/// Per-field inclusive numeric bounds plus light structural checks for TOML text.
#[derive(Debug, Clone, PartialEq)]
pub struct ConfigValidator {
    /// Inclusive lower bound per field name; a field missing here has no lower bound.
    pub mins: std::collections::HashMap<String, f64>,
    /// Inclusive upper bound per field name; a field missing here has no upper bound.
    pub maxs: std::collections::HashMap<String, f64>,
}

138impl Default for ConfigValidator {
139 fn default() -> Self {
140 Self {
141 mins: std::collections::HashMap::new(),
142 maxs: std::collections::HashMap::new(),
143 }
144 }
145}
146
147impl ConfigValidator {
148 pub fn new() -> Self {
150 Self::default()
151 }
152
153 pub fn with_bound(mut self, field: &str, min: f64, max: f64) -> Self {
155 self.mins.insert(field.to_string(), min);
156 self.maxs.insert(field.to_string(), max);
157 self
158 }
159
160 pub fn validate_numeric(&self, field: &str, value: f64) -> AdversarialResult<f64> {
162 if value.is_nan() {
164 return Err(AdversarialError::ConfigParseError {
165 field: field.to_string(),
166 reason: "value is NaN".to_string(),
167 });
168 }
169
170 if let Some(&min) = self.mins.get(field) {
172 if value < min {
173 return Err(AdversarialError::ConfigOutOfBounds {
174 field: field.to_string(),
175 value: value.to_string(),
176 min: min.to_string(),
177 max: self
178 .maxs
179 .get(field)
180 .map_or("unbounded".to_string(), |m| m.to_string()),
181 });
182 }
183 }
184
185 if let Some(&max) = self.maxs.get(field) {
186 if value > max {
187 return Err(AdversarialError::ConfigOutOfBounds {
188 field: field.to_string(),
189 value: value.to_string(),
190 min: self
191 .mins
192 .get(field)
193 .map_or("unbounded".to_string(), |m| m.to_string()),
194 max: max.to_string(),
195 });
196 }
197 }
198
199 Ok(value)
200 }
201
202 pub fn validate_toml_string(&self, input: &str) -> AdversarialResult<()> {
204 let trimmed = input.trim();
206
207 if trimmed.is_empty() {
209 return Err(AdversarialError::ConfigParseError {
210 field: "root".to_string(),
211 reason: "empty config".to_string(),
212 });
213 }
214
215 let open_brackets = trimmed.matches('[').count();
217 let close_brackets = trimmed.matches(']').count();
218 if open_brackets != close_brackets {
219 return Err(AdversarialError::ConfigParseError {
220 field: "root".to_string(),
221 reason: format!(
222 "mismatched brackets: {} open, {} close",
223 open_brackets, close_brackets
224 ),
225 });
226 }
227
228 let quotes = trimmed.matches('"').count();
230 if !quotes.is_multiple_of(2) {
231 return Err(AdversarialError::ConfigParseError {
232 field: "root".to_string(),
233 reason: "unclosed string literal".to_string(),
234 });
235 }
236
237 Ok(())
238 }
239}