Skip to main content

mockforge_intelligence/threat_modeling/
dos_analyzer.rs

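//! Denial of Service (DoS) risk analysis
//!
//! This module detects DoS risks in API contracts, such as:
//! - Unbounded arrays
//! - Missing pagination
//! - Deeply nested schemas
//! - Large payload sizes
//!
//! Note: `DosAnalyzer` as implemented in this file checks array bounds
//! (`maxItems`) and schema nesting depth; pagination and payload-size checks
//! are not implemented here.
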
9use super::types::{ThreatCategory, ThreatFinding, ThreatLevel};
10use mockforge_openapi::OpenApiSpec;
11use openapiv3::ReferenceOr;
12use std::collections::HashMap;
13
/// DoS risk analyzer for OpenAPI request and response schemas.
///
/// Configured with an optional array-size threshold and a maximum schema
/// nesting depth; see `DosAnalyzer::new`.
#[derive(Debug, Clone)]
pub struct DosAnalyzer {
    /// Optional recommended upper bound for an array schema's `maxItems`;
    /// `None` means no threshold is configured.
    max_array_size_threshold: Option<usize>,
    /// Schema nesting depth beyond which a DoS finding is reported.
    max_nesting_depth: usize,
}
21
22impl DosAnalyzer {
23    /// Create a new DoS analyzer
24    pub fn new(max_array_size_threshold: Option<usize>, max_nesting_depth: usize) -> Self {
25        Self {
26            max_array_size_threshold,
27            max_nesting_depth,
28        }
29    }
30
31    /// Analyze spec for DoS risks
32    pub fn analyze_dos_risks(&self, spec: &OpenApiSpec) -> Vec<ThreatFinding> {
33        let mut findings = Vec::new();
34
35        for (path, path_item) in &spec.spec.paths.paths {
36            if let ReferenceOr::Item(path_item) = path_item {
37                // Iterate over all HTTP methods
38                let methods = vec![
39                    ("GET", path_item.get.as_ref()),
40                    ("POST", path_item.post.as_ref()),
41                    ("PUT", path_item.put.as_ref()),
42                    ("DELETE", path_item.delete.as_ref()),
43                    ("PATCH", path_item.patch.as_ref()),
44                    ("HEAD", path_item.head.as_ref()),
45                    ("OPTIONS", path_item.options.as_ref()),
46                    ("TRACE", path_item.trace.as_ref()),
47                ];
48
49                for (method, operation_opt) in methods {
50                    let Some(operation) = operation_opt else {
51                        continue;
52                    };
53                    let base_path = format!("{}.{}", method, path);
54
55                    // Analyze request body
56                    if let Some(request_body) = &operation.request_body {
57                        if let Some(ref_or_item) = request_body.as_item() {
58                            for media_type in ref_or_item.content.values() {
59                                if let Some(schema) = &media_type.schema {
60                                    // Convert ReferenceOr<Schema> to ReferenceOr<Box<Schema>>
61                                    let boxed_schema_ref = match schema {
62                                        ReferenceOr::Item(s) => {
63                                            ReferenceOr::Item(Box::new(s.clone()))
64                                        }
65                                        ReferenceOr::Reference { reference } => {
66                                            ReferenceOr::Reference {
67                                                reference: reference.clone(),
68                                            }
69                                        }
70                                    };
71                                    findings.extend(self.analyze_schema_for_dos(
72                                        &boxed_schema_ref,
73                                        &base_path,
74                                        "request",
75                                        0,
76                                    ));
77                                }
78                            }
79                        }
80                    }
81
82                    // Analyze responses
83                    for (status_code, response) in &operation.responses.responses {
84                        if let ReferenceOr::Item(resp) = response {
85                            for media_type in resp.content.values() {
86                                if let Some(schema) = &media_type.schema {
87                                    // Convert ReferenceOr<Schema> to ReferenceOr<Box<Schema>>
88                                    let boxed_schema_ref = match schema {
89                                        ReferenceOr::Item(s) => {
90                                            ReferenceOr::Item(Box::new(s.clone()))
91                                        }
92                                        ReferenceOr::Reference { reference } => {
93                                            ReferenceOr::Reference {
94                                                reference: reference.clone(),
95                                            }
96                                        }
97                                    };
98                                    findings.extend(self.analyze_schema_for_dos(
99                                        &boxed_schema_ref,
100                                        &base_path,
101                                        &format!("response.{}", status_code),
102                                        0,
103                                    ));
104                                }
105                            }
106                        }
107                    }
108                }
109            }
110        }
111
112        findings
113    }
114
115    /// Analyze schema for DoS risks
116    fn analyze_schema_for_dos(
117        &self,
118        schema_ref: &ReferenceOr<Box<openapiv3::Schema>>,
119        base_path: &str,
120        context: &str,
121        depth: usize,
122    ) -> Vec<ThreatFinding> {
123        let mut findings = Vec::new();
124
125        if depth > self.max_nesting_depth {
126            findings.push(ThreatFinding {
127                finding_type: ThreatCategory::DoSRisk,
128                severity: ThreatLevel::Medium,
129                description: format!(
130                    "Schema nesting depth ({}) exceeds recommended maximum ({})",
131                    depth, self.max_nesting_depth
132                ),
133                field_path: Some(base_path.to_string()),
134                context: HashMap::new(),
135                confidence: 0.8,
136            });
137            return findings;
138        }
139
140        if let ReferenceOr::Item(schema) = schema_ref {
141            // Check for unbounded arrays
142            if let openapiv3::SchemaKind::Type(openapiv3::Type::Array(array_type)) =
143                &schema.as_ref().schema_kind
144            {
145                // max_items might be in extensions
146                let max_items =
147                    schema.as_ref().schema_data.extensions.get("maxItems").and_then(|v| v.as_u64());
148
149                if max_items.is_none() && self.max_array_size_threshold.is_none() {
150                    findings.push(ThreatFinding {
151                        finding_type: ThreatCategory::UnboundedArrays,
152                        severity: ThreatLevel::High,
153                        description: format!(
154                            "Unbounded array detected in {} - no maxItems constraint",
155                            context
156                        ),
157                        field_path: Some(base_path.to_string()),
158                        context: HashMap::new(),
159                        confidence: 1.0,
160                    });
161                } else if let Some(threshold) = self.max_array_size_threshold {
162                    if let Some(max) = max_items {
163                        if max > threshold as u64 {
164                            findings.push(ThreatFinding {
165                                finding_type: ThreatCategory::UnboundedArrays,
166                                severity: ThreatLevel::Medium,
167                                description: format!(
168                                    "Array maxItems ({}) exceeds recommended threshold ({})",
169                                    max, threshold
170                                ),
171                                field_path: Some(base_path.to_string()),
172                                context: HashMap::new(),
173                                confidence: 0.7,
174                            });
175                        }
176                    }
177                }
178
179                // Recursively check array items
180                if let Some(items) = &array_type.items {
181                    findings.extend(self.analyze_schema_for_dos(
182                        items,
183                        &format!("{}.items", base_path),
184                        context,
185                        depth + 1,
186                    ));
187                }
188            }
189
190            // Check properties recursively
191            if let openapiv3::SchemaKind::Type(openapiv3::Type::Object(obj_type)) =
192                &schema.as_ref().schema_kind
193            {
194                for (prop_name, prop_schema) in &obj_type.properties {
195                    findings.extend(self.analyze_schema_for_dos(
196                        prop_schema,
197                        &format!("{}.{}", base_path, prop_name),
198                        context,
199                        depth + 1,
200                    ));
201                }
202            }
203        }
204
205        findings
206    }
207}
208
209impl Default for DosAnalyzer {
210    fn default() -> Self {
211        Self::new(None, 10)
212    }
213}