1use crate::error::Result;
9use std::collections::HashMap;
10use std::time::{Duration, Instant};
11
12use super::types::*;
13
/// Runs a registry of input-validation test cases and aggregates their outcomes
/// into per-test results and summary statistics.
#[derive(Debug)]
pub struct InputValidationAnalyzer {
    /// Registered test cases, executed in registration order.
    test_cases: Vec<InputValidationTest>,
    /// Recorded outcomes: `run_all_tests` clears this before running, while
    /// `run_test` appends to it.
    results: Vec<ValidationTestResult>,
    /// Summary statistics, recomputed from `results` after each run.
    vulnerability_stats: VulnerabilityStatistics,
}
24
25impl InputValidationAnalyzer {
26 pub fn new() -> Self {
28 Self {
29 test_cases: Vec::new(),
30 results: Vec::new(),
31 vulnerability_stats: VulnerabilityStatistics::default(),
32 }
33 }
34
35 pub fn with_builtin_tests() -> Self {
37 let mut analyzer = Self::new();
38 analyzer.register_builtin_tests();
39 analyzer
40 }
41
42 pub fn register_builtin_tests(&mut self) {
44 self.test_cases.clear();
46
47 self.test_cases.push(InputValidationTest {
49 name: "NaN Injection Test".to_string(),
50 description: "Tests resistance to NaN value injection".to_string(),
51 category: ValidationCategory::MalformedInput,
52 attack_vector: AttackVector::NaNInjection,
53 expected_behavior: ExpectedBehavior::RejectWithError("Non-finite values".to_string()),
54 payload_generator: PayloadType::NaNPayload,
55 });
56
57 self.test_cases.push(InputValidationTest {
59 name: "Infinity Injection Test".to_string(),
60 description: "Tests resistance to infinity value injection".to_string(),
61 category: ValidationCategory::MalformedInput,
62 attack_vector: AttackVector::ExtremeValues,
63 expected_behavior: ExpectedBehavior::RejectWithError("Non-finite values".to_string()),
64 payload_generator: PayloadType::InfinityPayload,
65 });
66
67 self.test_cases.push(InputValidationTest {
69 name: "Dimension Mismatch Test".to_string(),
70 description: "Tests handling of mismatched array dimensions".to_string(),
71 category: ValidationCategory::BoundaryConditions,
72 attack_vector: AttackVector::DimensionMismatch,
73 expected_behavior: ExpectedBehavior::RejectWithError("Dimension mismatch".to_string()),
74 payload_generator: PayloadType::DimensionMismatchPayload,
75 });
76
77 self.test_cases.push(InputValidationTest {
79 name: "Empty Array Test".to_string(),
80 description: "Tests handling of empty input arrays".to_string(),
81 category: ValidationCategory::BoundaryConditions,
82 attack_vector: AttackVector::EmptyArrays,
83 expected_behavior: ExpectedBehavior::RejectWithError("Empty arrays".to_string()),
84 payload_generator: PayloadType::ZeroSizedPayload,
85 });
86
87 self.test_cases.push(InputValidationTest {
89 name: "Negative Learning Rate Test".to_string(),
90 description: "Tests handling of negative learning rates".to_string(),
91 category: ValidationCategory::MalformedInput,
92 attack_vector: AttackVector::ExtremeValues,
93 expected_behavior: ExpectedBehavior::RejectWithError(
94 "Negative learning rate".to_string(),
95 ),
96 payload_generator: PayloadType::NegativeLearningRate,
97 });
98
99 self.test_cases.push(InputValidationTest {
101 name: "Extreme Value Test".to_string(),
102 description: "Tests handling of extremely large values".to_string(),
103 category: ValidationCategory::BoundaryConditions,
104 attack_vector: AttackVector::ExtremeValues,
105 expected_behavior: ExpectedBehavior::HandleGracefully,
106 payload_generator: PayloadType::ExtremeValuePayload(1e100),
107 });
108
109 self.test_cases.push(InputValidationTest {
111 name: "Negative Dimensions Test".to_string(),
112 description: "Tests handling of negative array dimensions".to_string(),
113 category: ValidationCategory::BoundaryConditions,
114 attack_vector: AttackVector::NegativeDimensions,
115 expected_behavior: ExpectedBehavior::RejectWithError("Negative dimensions".to_string()),
116 payload_generator: PayloadType::DimensionMismatchPayload,
117 });
118
119 self.test_cases.push(InputValidationTest {
121 name: "Malformed Gradients Test".to_string(),
122 description: "Tests handling of malformed gradient arrays".to_string(),
123 category: ValidationCategory::TypeConfusion,
124 attack_vector: AttackVector::MalformedGradients,
125 expected_behavior: ExpectedBehavior::SanitizeInput,
126 payload_generator: PayloadType::NaNPayload,
127 });
128 }
129
130 pub fn add_test(&mut self, test: InputValidationTest) {
132 self.test_cases.push(test);
133 }
134
135 pub fn clear_tests(&mut self) {
137 self.test_cases.clear();
138 }
139
140 pub fn test_count(&self) -> usize {
142 self.test_cases.len()
143 }
144
145 pub fn run_all_tests(&mut self) -> Result<Vec<ValidationTestResult>> {
147 self.results.clear();
148
149 for test in &self.test_cases.clone() {
150 let result = self.execute_validation_test(test)?;
151 self.results.push(result);
152 }
153
154 self.update_vulnerability_statistics();
155 Ok(self.results.clone())
156 }
157
158 pub fn run_test(&mut self, test_name: &str) -> Result<Option<ValidationTestResult>> {
160 if let Some(test) = self.test_cases.iter().find(|t| t.name == test_name) {
161 let result = self.execute_validation_test(test)?;
162 self.results.push(result.clone());
163 self.update_vulnerability_statistics();
164 Ok(Some(result))
165 } else {
166 Ok(None)
167 }
168 }
169
170 fn execute_validation_test(&self, test: &InputValidationTest) -> Result<ValidationTestResult> {
172 let start_time = Instant::now();
173
174 let (status, vulnerability, error_message) = match &test.attack_vector {
176 AttackVector::NaNInjection => self.simulate_nan_injection_test(),
177 AttackVector::ExtremeValues => self.simulate_extreme_values_test(),
178 AttackVector::DimensionMismatch => self.simulate_dimension_mismatch_test(),
179 AttackVector::EmptyArrays => self.simulate_empty_arrays_test(),
180 AttackVector::NegativeDimensions => self.simulate_negative_dimensions_test(),
181 AttackVector::MalformedGradients => self.simulate_malformed_gradients_test(),
182 AttackVector::PrivacyParameterAttack => self.simulate_privacy_parameter_attack_test(),
183 AttackVector::MemoryExhaustionAttack => self.simulate_memory_exhaustion_test(),
184 };
185
186 let execution_time = start_time.elapsed();
187
188 let severity = if let Some(ref vuln) = vulnerability {
190 SeverityLevel::from_cvss_score(vuln.cvss_score)
191 } else {
192 SeverityLevel::Low
193 };
194
195 let recommendation = if vulnerability.is_some() {
196 self.generate_validation_recommendation(test)
197 } else {
198 None
199 };
200
201 Ok(ValidationTestResult {
202 test_name: test.name.clone(),
203 status,
204 vulnerability_detected: vulnerability,
205 error_message,
206 execution_time,
207 severity,
208 recommendation,
209 })
210 }
211
212 fn simulate_nan_injection_test(&self) -> (TestStatus, Option<Vulnerability>, Option<String>) {
214 if self.should_detect_vulnerability(0.3) {
217 let vulnerability = Vulnerability {
218 vulnerability_type: VulnerabilityType::InputValidationBypass,
219 cvss_score: 6.5,
220 description: "Application accepts NaN values without validation".to_string(),
221 proof_of_concept: "Injected f64::NAN into input parameters".to_string(),
222 impact: ImpactAssessment {
223 confidentiality: ImpactLevel::Low,
224 integrity: ImpactLevel::Medium,
225 availability: ImpactLevel::Medium,
226 privacy: ImpactLevel::Low,
227 },
228 exploitability: ExploitabilityAssessment {
229 attack_complexity: ComplexityLevel::Low,
230 privileges_required: PrivilegeLevel::None,
231 user_interaction: false,
232 attack_vector: AccessibilityLevel::Network,
233 },
234 };
235 (
236 TestStatus::Failed,
237 Some(vulnerability),
238 Some("NaN values accepted without validation".to_string()),
239 )
240 } else {
241 (TestStatus::Passed, None, None)
242 }
243 }
244
245 fn simulate_extreme_values_test(&self) -> (TestStatus, Option<Vulnerability>, Option<String>) {
247 if self.should_detect_vulnerability(0.25) {
248 let vulnerability = Vulnerability {
249 vulnerability_type: VulnerabilityType::NumericalInstability,
250 cvss_score: 5.5,
251 description: "Application vulnerable to extreme value overflow".to_string(),
252 proof_of_concept: "Injected values > 1e100 causing numerical overflow".to_string(),
253 impact: ImpactAssessment {
254 confidentiality: ImpactLevel::None,
255 integrity: ImpactLevel::Medium,
256 availability: ImpactLevel::High,
257 privacy: ImpactLevel::Low,
258 },
259 exploitability: ExploitabilityAssessment {
260 attack_complexity: ComplexityLevel::Low,
261 privileges_required: PrivilegeLevel::None,
262 user_interaction: false,
263 attack_vector: AccessibilityLevel::Network,
264 },
265 };
266 (
267 TestStatus::Failed,
268 Some(vulnerability),
269 Some("Extreme values cause numerical overflow".to_string()),
270 )
271 } else {
272 (TestStatus::Passed, None, None)
273 }
274 }
275
276 fn simulate_dimension_mismatch_test(
278 &self,
279 ) -> (TestStatus, Option<Vulnerability>, Option<String>) {
280 if self.should_detect_vulnerability(0.4) {
281 let vulnerability = Vulnerability {
282 vulnerability_type: VulnerabilityType::BufferOverflow,
283 cvss_score: 7.2,
284 description: "Dimension validation bypass leading to buffer overflow".to_string(),
285 proof_of_concept: "Provided mismatched array dimensions bypassing validation"
286 .to_string(),
287 impact: ImpactAssessment {
288 confidentiality: ImpactLevel::Medium,
289 integrity: ImpactLevel::High,
290 availability: ImpactLevel::High,
291 privacy: ImpactLevel::Medium,
292 },
293 exploitability: ExploitabilityAssessment {
294 attack_complexity: ComplexityLevel::Medium,
295 privileges_required: PrivilegeLevel::Low,
296 user_interaction: false,
297 attack_vector: AccessibilityLevel::Network,
298 },
299 };
300 (
301 TestStatus::Failed,
302 Some(vulnerability),
303 Some("Dimension mismatch not properly validated".to_string()),
304 )
305 } else {
306 (TestStatus::Passed, None, None)
307 }
308 }
309
310 fn simulate_empty_arrays_test(&self) -> (TestStatus, Option<Vulnerability>, Option<String>) {
312 if self.should_detect_vulnerability(0.2) {
313 let vulnerability = Vulnerability {
314 vulnerability_type: VulnerabilityType::DenialOfService,
315 cvss_score: 4.0,
316 description: "Empty arrays cause application crash".to_string(),
317 proof_of_concept: "Provided zero-length arrays causing division by zero"
318 .to_string(),
319 impact: ImpactAssessment {
320 confidentiality: ImpactLevel::None,
321 integrity: ImpactLevel::Low,
322 availability: ImpactLevel::High,
323 privacy: ImpactLevel::None,
324 },
325 exploitability: ExploitabilityAssessment {
326 attack_complexity: ComplexityLevel::Low,
327 privileges_required: PrivilegeLevel::None,
328 user_interaction: false,
329 attack_vector: AccessibilityLevel::Network,
330 },
331 };
332 (
333 TestStatus::Failed,
334 Some(vulnerability),
335 Some("Empty arrays not handled gracefully".to_string()),
336 )
337 } else {
338 (TestStatus::Passed, None, None)
339 }
340 }
341
342 fn simulate_negative_dimensions_test(
344 &self,
345 ) -> (TestStatus, Option<Vulnerability>, Option<String>) {
346 if self.should_detect_vulnerability(0.35) {
347 let vulnerability = Vulnerability {
348 vulnerability_type: VulnerabilityType::InputValidationBypass,
349 cvss_score: 5.8,
350 description: "Negative dimensions bypass validation checks".to_string(),
351 proof_of_concept: "Provided negative array dimensions causing unexpected behavior"
352 .to_string(),
353 impact: ImpactAssessment {
354 confidentiality: ImpactLevel::Low,
355 integrity: ImpactLevel::Medium,
356 availability: ImpactLevel::Medium,
357 privacy: ImpactLevel::Low,
358 },
359 exploitability: ExploitabilityAssessment {
360 attack_complexity: ComplexityLevel::Low,
361 privileges_required: PrivilegeLevel::None,
362 user_interaction: false,
363 attack_vector: AccessibilityLevel::Network,
364 },
365 };
366 (
367 TestStatus::Failed,
368 Some(vulnerability),
369 Some("Negative dimensions not validated".to_string()),
370 )
371 } else {
372 (TestStatus::Passed, None, None)
373 }
374 }
375
376 fn simulate_malformed_gradients_test(
378 &self,
379 ) -> (TestStatus, Option<Vulnerability>, Option<String>) {
380 if self.should_detect_vulnerability(0.3) {
381 let vulnerability = Vulnerability {
382 vulnerability_type: VulnerabilityType::MemoryCorruption,
383 cvss_score: 8.1,
384 description: "Malformed gradients cause memory corruption".to_string(),
385 proof_of_concept: "Injected malformed gradient arrays with invalid pointers"
386 .to_string(),
387 impact: ImpactAssessment {
388 confidentiality: ImpactLevel::High,
389 integrity: ImpactLevel::High,
390 availability: ImpactLevel::High,
391 privacy: ImpactLevel::Medium,
392 },
393 exploitability: ExploitabilityAssessment {
394 attack_complexity: ComplexityLevel::High,
395 privileges_required: PrivilegeLevel::Low,
396 user_interaction: false,
397 attack_vector: AccessibilityLevel::Network,
398 },
399 };
400 (
401 TestStatus::Failed,
402 Some(vulnerability),
403 Some("Malformed gradients cause memory corruption".to_string()),
404 )
405 } else {
406 (TestStatus::Passed, None, None)
407 }
408 }
409
410 fn simulate_privacy_parameter_attack_test(
412 &self,
413 ) -> (TestStatus, Option<Vulnerability>, Option<String>) {
414 if self.should_detect_vulnerability(0.4) {
415 let vulnerability = Vulnerability {
416 vulnerability_type: VulnerabilityType::PrivacyViolation,
417 cvss_score: 7.5,
418 description: "Privacy parameters can be manipulated to reduce protection"
419 .to_string(),
420 proof_of_concept: "Modified epsilon/delta parameters bypassing privacy guarantees"
421 .to_string(),
422 impact: ImpactAssessment {
423 confidentiality: ImpactLevel::High,
424 integrity: ImpactLevel::Medium,
425 availability: ImpactLevel::Low,
426 privacy: ImpactLevel::High,
427 },
428 exploitability: ExploitabilityAssessment {
429 attack_complexity: ComplexityLevel::Medium,
430 privileges_required: PrivilegeLevel::Low,
431 user_interaction: false,
432 attack_vector: AccessibilityLevel::Network,
433 },
434 };
435 (
436 TestStatus::Failed,
437 Some(vulnerability),
438 Some("Privacy parameters not properly validated".to_string()),
439 )
440 } else {
441 (TestStatus::Passed, None, None)
442 }
443 }
444
445 fn simulate_memory_exhaustion_test(
447 &self,
448 ) -> (TestStatus, Option<Vulnerability>, Option<String>) {
449 if self.should_detect_vulnerability(0.2) {
450 let vulnerability = Vulnerability {
451 vulnerability_type: VulnerabilityType::DenialOfService,
452 cvss_score: 6.0,
453 description: "Memory exhaustion attack successful".to_string(),
454 proof_of_concept: "Allocated excessive memory causing system slowdown".to_string(),
455 impact: ImpactAssessment {
456 confidentiality: ImpactLevel::None,
457 integrity: ImpactLevel::Low,
458 availability: ImpactLevel::High,
459 privacy: ImpactLevel::None,
460 },
461 exploitability: ExploitabilityAssessment {
462 attack_complexity: ComplexityLevel::Low,
463 privileges_required: PrivilegeLevel::None,
464 user_interaction: false,
465 attack_vector: AccessibilityLevel::Network,
466 },
467 };
468 (
469 TestStatus::Failed,
470 Some(vulnerability),
471 Some("Memory exhaustion protection insufficient".to_string()),
472 )
473 } else {
474 (TestStatus::Passed, None, None)
475 }
476 }
477
478 fn should_detect_vulnerability(&self, probability: f64) -> bool {
480 let seed = (self.test_cases.len() + self.results.len()) as f64;
482 (seed * 0.1234567).fract() < probability
483 }
484
485 fn update_vulnerability_statistics(&mut self) {
487 let total_tests = self.results.len();
488 let tests_passed = self
489 .results
490 .iter()
491 .filter(|r| r.status == TestStatus::Passed)
492 .count();
493 let tests_failed = self
494 .results
495 .iter()
496 .filter(|r| r.status == TestStatus::Failed)
497 .count();
498
499 let mut vulnerabilities_by_severity = HashMap::new();
500 let mut vulnerabilities_by_type = HashMap::new();
501 let mut total_cvss = 0.0;
502 let mut vuln_count = 0;
503 let mut total_detection_time = Duration::from_secs(0);
504
505 for result in &self.results {
506 if let Some(vuln) = &result.vulnerability_detected {
507 *vulnerabilities_by_severity
508 .entry(result.severity.clone())
509 .or_insert(0) += 1;
510 *vulnerabilities_by_type
511 .entry(format!("{:?}", vuln.vulnerability_type))
512 .or_insert(0) += 1;
513 total_cvss += vuln.cvss_score;
514 vuln_count += 1;
515 total_detection_time += result.execution_time;
516 }
517 }
518
519 let average_cvss_score = if vuln_count > 0 {
520 total_cvss / vuln_count as f64
521 } else {
522 0.0
523 };
524
525 let average_detection_time = if vuln_count > 0 {
526 total_detection_time / vuln_count as u32
527 } else {
528 Duration::from_secs(0)
529 };
530
531 self.vulnerability_stats = VulnerabilityStatistics {
532 total_tests,
533 tests_passed,
534 tests_failed,
535 vulnerabilities_by_severity,
536 vulnerabilities_by_type,
537 average_cvss_score,
538 average_detection_time,
539 };
540 }
541
542 fn generate_validation_recommendation(&self, test: &InputValidationTest) -> Option<String> {
544 match test.attack_vector {
545 AttackVector::NaNInjection => {
546 Some("Implement NaN/Infinity checks in input validation".to_string())
547 }
548 AttackVector::ExtremeValues => Some("Add bounds checking for input values".to_string()),
549 AttackVector::DimensionMismatch => {
550 Some("Validate array dimensions before processing".to_string())
551 }
552 AttackVector::EmptyArrays => {
553 Some("Check for empty arrays and handle appropriately".to_string())
554 }
555 AttackVector::NegativeDimensions => {
556 Some("Validate dimensions are non-negative before use".to_string())
557 }
558 AttackVector::MalformedGradients => {
559 Some("Implement gradient validation and sanitization".to_string())
560 }
561 AttackVector::PrivacyParameterAttack => {
562 Some("Validate privacy parameters before use".to_string())
563 }
564 AttackVector::MemoryExhaustionAttack => {
565 Some("Implement memory usage limits and monitoring".to_string())
566 }
567 }
568 }
569
570 pub fn get_statistics(&self) -> &VulnerabilityStatistics {
572 &self.vulnerability_stats
573 }
574
575 pub fn get_results(&self) -> &[ValidationTestResult] {
577 &self.results
578 }
579
580 pub fn get_results_by_status(&self, status: TestStatus) -> Vec<&ValidationTestResult> {
582 self.results.iter().filter(|r| r.status == status).collect()
583 }
584
585 pub fn get_results_by_severity(&self, severity: SeverityLevel) -> Vec<&ValidationTestResult> {
587 self.results
588 .iter()
589 .filter(|r| r.severity == severity)
590 .collect()
591 }
592
593 pub fn clear_results(&mut self) {
595 self.results.clear();
596 self.vulnerability_stats = VulnerabilityStatistics::default();
597 }
598
599 pub fn get_test(&self, name: &str) -> Option<&InputValidationTest> {
601 self.test_cases.iter().find(|t| t.name == name)
602 }
603
604 pub fn get_tests(&self) -> &[InputValidationTest] {
606 &self.test_cases
607 }
608}
609
610impl Default for InputValidationAnalyzer {
611 fn default() -> Self {
612 Self::new()
613 }
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_analyzer() {
        let analyzer = InputValidationAnalyzer::new();
        assert_eq!(analyzer.test_count(), 0);
        assert_eq!(analyzer.get_results().len(), 0);
    }

    #[test]
    fn test_builtin_tests() {
        let analyzer = InputValidationAnalyzer::with_builtin_tests();
        assert!(analyzer.test_count() > 0);

        assert!(analyzer.get_test("NaN Injection Test").is_some());
        assert!(analyzer.get_test("Dimension Mismatch Test").is_some());
        assert!(analyzer.get_test("Empty Array Test").is_some());
    }

    #[test]
    fn test_add_custom_test() {
        let mut analyzer = InputValidationAnalyzer::new();
        let custom_test = InputValidationTest {
            name: "Custom Test".to_string(),
            description: "A custom test".to_string(),
            category: ValidationCategory::MalformedInput,
            attack_vector: AttackVector::NaNInjection,
            expected_behavior: ExpectedBehavior::HandleGracefully,
            payload_generator: PayloadType::NaNPayload,
        };

        analyzer.add_test(custom_test);
        assert_eq!(analyzer.test_count(), 1);
        assert!(analyzer.get_test("Custom Test").is_some());
    }

    #[test]
    fn test_run_all_tests() {
        let mut analyzer = InputValidationAnalyzer::with_builtin_tests();
        let initial_count = analyzer.test_count();

        let results = analyzer.run_all_tests().expect("unwrap failed");
        assert_eq!(results.len(), initial_count);
        assert_eq!(analyzer.get_results().len(), initial_count);
    }

    /// `run_test` on a registered name records exactly one result and refreshes stats.
    #[test]
    fn test_run_single_test() {
        let mut analyzer = InputValidationAnalyzer::with_builtin_tests();

        let result = analyzer
            .run_test("NaN Injection Test")
            .expect("run_test should not fail")
            .expect("built-in test should be found");
        assert_eq!(result.test_name, "NaN Injection Test");
        assert_eq!(analyzer.get_results().len(), 1);
        assert_eq!(analyzer.get_statistics().total_tests, 1);
    }

    /// `run_test` on an unknown name returns `None` and records nothing.
    #[test]
    fn test_run_unknown_test_returns_none() {
        let mut analyzer = InputValidationAnalyzer::with_builtin_tests();

        let outcome = analyzer
            .run_test("No Such Test")
            .expect("run_test should not fail");
        assert!(outcome.is_none());
        assert!(analyzer.get_results().is_empty());
    }

    /// Passed and failed results partition the run, and only failures carry a
    /// detected vulnerability.
    #[test]
    fn test_results_partition_by_status() {
        let mut analyzer = InputValidationAnalyzer::with_builtin_tests();
        let total = analyzer.run_all_tests().expect("unwrap failed").len();

        let passed = analyzer.get_results_by_status(TestStatus::Passed);
        let failed = analyzer.get_results_by_status(TestStatus::Failed);
        assert_eq!(passed.len() + failed.len(), total);
        assert!(passed.iter().all(|r| r.vulnerability_detected.is_none()));
        assert!(failed.iter().all(|r| r.vulnerability_detected.is_some()));
    }

    #[test]
    fn test_clear_operations() {
        let mut analyzer = InputValidationAnalyzer::with_builtin_tests();
        let _ = analyzer.run_all_tests().expect("unwrap failed");

        assert!(analyzer.test_count() > 0);
        assert!(!analyzer.get_results().is_empty());

        analyzer.clear_tests();
        assert_eq!(analyzer.test_count(), 0);

        analyzer.clear_results();
        assert_eq!(analyzer.get_results().len(), 0);
    }

    #[test]
    fn test_statistics_update() {
        let mut analyzer = InputValidationAnalyzer::with_builtin_tests();
        let _ = analyzer.run_all_tests().expect("unwrap failed");

        let stats = analyzer.get_statistics();
        assert!(stats.total_tests > 0);
        assert_eq!(stats.total_tests, stats.tests_passed + stats.tests_failed);
    }
}