1use crate::{CweMapping, CvssV3Mapping, VrtNode, VrtTaxonomy};
2
/// A single vulnerability classified against the VRT taxonomy.
///
/// Produced by [`VulnerabilityCategorizer`]; `cwes` and `cvss_vector` are only populated when
/// the categorizer was configured with the corresponding mapping and that mapping has an
/// entry for `vrt_id`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CategorizedFinding {
    /// Identifier of the matched VRT node.
    pub vrt_id: String,
    /// Human-readable name of the matched VRT node.
    pub vrt_name: String,
    /// Priority recorded on the matched node, if it has one.
    pub priority: Option<u8>,
    /// Node names from the top-level category down to, and including, the matched node.
    pub category_path: Vec<String>,
    /// CWE identifiers mapped to this VRT id; empty when no CWE mapping applies.
    pub cwes: Vec<String>,
    /// CVSS v3 vector string mapped to this VRT id, if any.
    pub cvss_vector: Option<String>,
}
19
/// Classifies findings against a VRT taxonomy, either by node id or by free-text description.
///
/// CWE and CVSS v3 enrichment is optional: each mapping is consulted only when it has been
/// supplied (via `with_all_mappings`, `with_cwe_mapping`, or `with_cvss_mapping`).
pub struct VulnerabilityCategorizer {
    /// The taxonomy's top-level nodes; lookups walk them in order, depth-first.
    vrt: VrtTaxonomy,
    /// Optional VRT-id -> CWE lookup used to fill `CategorizedFinding::cwes`.
    cwe_mapping: Option<CweMapping>,
    /// Optional VRT-id -> CVSS v3 vector lookup used to fill `CategorizedFinding::cvss_vector`.
    cvss_mapping: Option<CvssV3Mapping>,
}
26
27impl VulnerabilityCategorizer {
28 pub fn new(vrt: VrtTaxonomy) -> Self {
30 Self {
31 vrt,
32 cwe_mapping: None,
33 cvss_mapping: None,
34 }
35 }
36
37 pub fn with_all_mappings(
39 vrt: VrtTaxonomy,
40 cwe_mapping: CweMapping,
41 cvss_mapping: CvssV3Mapping,
42 ) -> Self {
43 Self {
44 vrt,
45 cwe_mapping: Some(cwe_mapping),
46 cvss_mapping: Some(cvss_mapping),
47 }
48 }
49
50 pub fn with_cwe_mapping(mut self, cwe_mapping: CweMapping) -> Self {
52 self.cwe_mapping = Some(cwe_mapping);
53 self
54 }
55
56 pub fn with_cvss_mapping(mut self, cvss_mapping: CvssV3Mapping) -> Self {
58 self.cvss_mapping = Some(cvss_mapping);
59 self
60 }
61
62 pub fn categorize_by_id(&self, vrt_id: &str) -> Option<CategorizedFinding> {
76 let (node, path) = self.find_node_with_path(vrt_id)?;
78
79 let cwes = self
81 .cwe_mapping
82 .as_ref()
83 .and_then(|m| m.lookup_cwe(vrt_id))
84 .map(|cwes| cwes.iter().map(|c| c.as_str().to_string()).collect())
85 .unwrap_or_default();
86
87 let cvss_vector = self
89 .cvss_mapping
90 .as_ref()
91 .and_then(|m| m.lookup_cvss(vrt_id))
92 .map(|v| v.to_string());
93
94 Some(CategorizedFinding {
95 vrt_id: node.id.clone(),
96 vrt_name: node.name.clone(),
97 priority: node.priority,
98 category_path: path,
99 cwes,
100 cvss_vector,
101 })
102 }
103
104 pub fn search_by_name(&self, query: &str) -> Vec<String> {
121 let query_lower = query.to_lowercase();
122 let mut results = Vec::new();
123
124 for category in &self.vrt {
125 self.search_node_by_name(&query_lower, category, &mut results);
126 }
127
128 results
129 }
130
131 pub fn categorize_by_description(&self, description: &str) -> Option<CategorizedFinding> {
148 let description_lower = description.to_lowercase();
149
150 let keyword_mappings = self.build_keyword_mappings();
152
153 let mut best_match: Option<(&str, usize)> = None;
155
156 for (vrt_id, keywords) in &keyword_mappings {
157 let mut score = 0;
158 for keyword in keywords {
159 if description_lower.contains(keyword) {
160 score += keyword.len(); }
162 }
163
164 if score > 0 {
165 if let Some((_, best_score)) = best_match {
166 if score > best_score {
167 best_match = Some((vrt_id, score));
168 }
169 } else {
170 best_match = Some((vrt_id, score));
171 }
172 }
173 }
174
175 best_match.and_then(|(vrt_id, _)| self.categorize_by_id(vrt_id))
176 }
177
178 pub fn list_all_variants(&self) -> Vec<String> {
180 let mut variants = Vec::new();
181 for category in &self.vrt {
182 self.collect_variant_ids(category, &mut variants);
183 }
184 variants
185 }
186
187 pub fn get_all_categorizations(&self) -> Vec<CategorizedFinding> {
189 let mut findings = Vec::new();
190 for variant_id in self.list_all_variants() {
191 if let Some(finding) = self.categorize_by_id(&variant_id) {
192 findings.push(finding);
193 }
194 }
195 findings
196 }
197
198 fn find_node_with_path(&self, vrt_id: &str) -> Option<(&VrtNode, Vec<String>)> {
201 for category in &self.vrt {
202 let mut path = vec![category.name.clone()];
203 if let Some((node, mut node_path)) =
204 self.find_node_recursive(vrt_id, category, &path)
205 {
206 path.append(&mut node_path);
207 return Some((node, path));
208 }
209 }
210 None
211 }
212
213 fn find_node_recursive<'a>(
214 &self,
215 vrt_id: &str,
216 node: &'a VrtNode,
217 current_path: &[String],
218 ) -> Option<(&'a VrtNode, Vec<String>)> {
219 if node.id == vrt_id {
220 return Some((node, vec![]));
221 }
222
223 for child in &node.children {
224 let mut path = current_path.to_vec();
225 path.push(child.name.clone());
226
227 if child.id == vrt_id {
228 return Some((child, vec![child.name.clone()]));
229 }
230
231 if let Some((found, mut subpath)) = self.find_node_recursive(vrt_id, child, &path) {
232 let mut result_path = vec![child.name.clone()];
233 result_path.append(&mut subpath);
234 return Some((found, result_path));
235 }
236 }
237
238 None
239 }
240
241 fn search_node_by_name(&self, query: &str, node: &VrtNode, results: &mut Vec<String>) {
242 if node.name.to_lowercase().contains(query) || node.id.contains(query) {
243 results.push(node.id.clone());
244 }
245
246 for child in &node.children {
247 self.search_node_by_name(query, child, results);
248 }
249 }
250
251 fn collect_variant_ids(&self, node: &VrtNode, variants: &mut Vec<String>) {
252 if node.is_variant() {
253 variants.push(node.id.clone());
254 }
255
256 for child in &node.children {
257 self.collect_variant_ids(child, variants);
258 }
259 }
260
261 fn build_keyword_mappings(&self) -> Vec<(&str, Vec<&str>)> {
262 vec![
263 ("sql_injection", vec!["sql injection", "sqli", "sql"]),
265 (
266 "cross_site_scripting_xss",
267 vec!["xss", "cross-site scripting", "cross site scripting"],
268 ),
269 (
270 "server_side_request_forgery_ssrf",
271 vec!["ssrf", "server-side request forgery", "server side request forgery"],
272 ),
273 ("remote_code_execution_rce", vec!["rce", "remote code execution", "code execution"]),
274 ("command_injection", vec!["command injection", "os command"]),
275 ("ldap_injection", vec!["ldap injection", "ldap"]),
276 ("xml_external_entity_injection_xxe", vec!["xxe", "xml external entity"]),
277 ("idor", vec!["idor", "insecure direct object", "object reference"]),
279 ("broken_access_control", vec!["access control", "authorization"]),
280 ("privilege_escalation", vec!["privilege escalation", "privesc"]),
281 ("csrf", vec!["csrf", "cross-site request forgery", "cross site request"]),
283 ("authentication_bypass", vec!["auth bypass", "authentication bypass"]),
285 ("session_fixation", vec!["session fixation"]),
286 ("weak_login_function", vec!["weak login", "plaintext password"]),
287 ("weak_hash", vec!["weak hash", "md5", "sha1"]),
289 ("insecure_ssl", vec!["weak ssl", "weak tls", "insecure ssl"]),
290 ("disclosure_of_secrets", vec!["secret disclosure", "credential leak", "api key"]),
292 ("visible_detailed_error_page", vec!["stack trace", "error page", "debug"]),
293 ("path_traversal", vec!["path traversal", "directory traversal", "../"]),
295 ("unsafe_file_upload", vec!["file upload", "upload"]),
296 ("clickjacking", vec!["clickjacking", "iframe", "x-frame-options"]),
298 ("open_redirect", vec!["open redirect", "unvalidated redirect"]),
300 ]
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::VrtNodeType;

    /// Builds a two-level taxonomy: one category holding a single SQL-injection variant.
    fn create_test_taxonomy() -> VrtTaxonomy {
        vec![VrtNode {
            id: "server_side_injection".to_string(),
            name: "Server-Side Injection".to_string(),
            node_type: VrtNodeType::Category,
            children: vec![VrtNode {
                id: "sql_injection".to_string(),
                name: "SQL Injection".to_string(),
                node_type: VrtNodeType::Variant,
                children: vec![],
                priority: Some(1),
            }],
            priority: None,
        }]
    }

    #[test]
    fn test_categorize_by_id() {
        let categorizer = VulnerabilityCategorizer::new(create_test_taxonomy());

        let finding = categorizer
            .categorize_by_id("sql_injection")
            .expect("Should find SQL injection");

        assert_eq!(finding.vrt_id, "sql_injection");
        assert_eq!(finding.vrt_name, "SQL Injection");
        assert_eq!(finding.priority, Some(1));
        assert_eq!(
            finding.category_path,
            vec!["Server-Side Injection", "SQL Injection"]
        );
        // No mappings were configured, so enrichment fields stay empty.
        assert!(finding.cwes.is_empty());
        assert_eq!(finding.cvss_vector, None);
    }

    #[test]
    fn test_categorize_by_id_category_and_unknown() {
        let categorizer = VulnerabilityCategorizer::new(create_test_taxonomy());

        let category = categorizer
            .categorize_by_id("server_side_injection")
            .expect("Categories are addressable by id too");
        assert_eq!(category.category_path, vec!["Server-Side Injection"]);
        assert_eq!(category.priority, None);

        assert!(categorizer.categorize_by_id("does_not_exist").is_none());
    }

    #[test]
    fn test_search_by_name() {
        let categorizer = VulnerabilityCategorizer::new(create_test_taxonomy());

        let results = categorizer.search_by_name("sql");
        assert!(results.contains(&"sql_injection".to_string()));
    }

    #[test]
    fn test_search_by_name_is_case_insensitive() {
        let categorizer = VulnerabilityCategorizer::new(create_test_taxonomy());

        assert_eq!(
            categorizer.search_by_name("INJECTION"),
            vec!["server_side_injection", "sql_injection"]
        );
    }

    #[test]
    fn test_list_and_categorize_all_variants() {
        let categorizer = VulnerabilityCategorizer::new(create_test_taxonomy());

        assert_eq!(categorizer.list_all_variants(), vec!["sql_injection"]);

        let all = categorizer.get_all_categorizations();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].vrt_id, "sql_injection");
    }

    #[test]
    fn test_categorize_by_description() {
        let categorizer = VulnerabilityCategorizer::new(create_test_taxonomy());

        let finding = categorizer
            .categorize_by_description("SQL Injection detected in login form")
            .expect("Should categorize");

        assert_eq!(finding.vrt_id, "sql_injection");
    }

    #[test]
    fn test_categorize_by_description_without_keywords() {
        let categorizer = VulnerabilityCategorizer::new(create_test_taxonomy());

        assert!(categorizer
            .categorize_by_description("Totally benign text")
            .is_none());
    }
}