Skip to main content

parsentry_parser/
patterns.rs

1//! Security pattern matching for vulnerability detection.
2
3use parsentry_core::Language;
4use serde::Deserialize;
5use std::collections::HashMap;
6use std::path::Path;
7use streaming_iterator::StreamingIterator;
8use tree_sitter::{Language as TreeSitterLanguage, Parser, Query, QueryCursor};
9
/// Configuration for a security pattern.
///
/// One entry of a pattern file: the tree-sitter query to run plus the
/// metadata that is reported alongside every match of that query.
#[derive(Debug, Clone, Deserialize)]
pub struct PatternConfig {
    /// The query itself. Flattened, so the entry carries a `definition` or
    /// `reference` key directly next to `description` and `attack_vector`.
    #[serde(flatten)]
    pub pattern_type: PatternQuery,
    /// Free-form description of the pattern.
    pub description: String,
    /// Attack vector labels associated with the pattern.
    pub attack_vector: Vec<String>,
}
18
/// Query type for pattern matching.
///
/// Deserialized untagged: the variant is chosen by which key is present, and
/// serde tries the variants in declaration order. Both payloads are
/// tree-sitter query sources that are compiled with `Query::new`.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum PatternQuery {
    /// Query stored under the `definition` key.
    Definition { definition: String },
    /// Query stored under the `reference` key.
    Reference { reference: String },
}
26
/// Language-specific patterns configuration.
///
/// Every category is optional, so a pattern file may omit any of them. The
/// matcher flattens all three categories into a single list of patterns, in
/// the order principals, actions, resources.
#[derive(Debug, Clone, Deserialize)]
pub struct LanguagePatterns {
    /// Patterns listed under the `principals` key.
    pub principals: Option<Vec<PatternConfig>>,
    /// Patterns listed under the `actions` key.
    pub actions: Option<Vec<PatternConfig>>,
    /// Patterns listed under the `resources` key.
    pub resources: Option<Vec<PatternConfig>>,
}
34
/// Security risk pattern matcher.
///
/// Holds the compiled tree-sitter queries for one language together with the
/// configurations they were built from.
pub struct SecurityRiskPatterns {
    /// Compiled queries that came from `PatternQuery::Definition` configs.
    definition_queries: Vec<Query>,
    /// Compiled queries that came from `PatternQuery::Reference` configs.
    reference_queries: Vec<Query>,
    /// Grammar used both to compile the queries and to parse scanned content.
    language: TreeSitterLanguage,
    /// Pattern configurations in registration order. `get_pattern_matches`
    /// pairs the n-th definition (or reference) query with the n-th
    /// definition (or reference) config in this list.
    pattern_configs: Vec<PatternConfig>,
}
42
/// A matched security pattern.
#[derive(Debug, Clone)]
pub struct PatternMatch {
    /// Copy of the configuration whose query produced this match.
    pub pattern_config: PatternConfig,
    /// Byte offset in the scanned content where the selected syntax node starts.
    pub start_byte: usize,
    /// Byte offset in the scanned content where the selected syntax node ends.
    pub end_byte: usize,
    /// Source text of the selected syntax node.
    pub matched_text: String,
}
51
52impl SecurityRiskPatterns {
53    /// Create a new pattern matcher for the given language.
54    #[must_use]
55    pub fn new(language: Language) -> Self {
56        Self::new_with_root(language, None)
57    }
58
59    /// Create a new pattern matcher with a custom root directory for patterns.
60    #[must_use]
61    pub fn new_with_root(language: Language, root_dir: Option<&Path>) -> Self {
62        let pattern_map = Self::load_patterns(root_dir);
63        let lang_patterns = pattern_map
64            .get(&language)
65            .or_else(|| pattern_map.get(&Language::Other))
66            .unwrap_or(&LanguagePatterns {
67                principals: None,
68                actions: None,
69                resources: None,
70            });
71
72        let ts_language = Self::get_tree_sitter_language(language);
73
74        let mut definition_queries = Vec::new();
75        let mut reference_queries = Vec::new();
76        let mut pattern_configs = Vec::new();
77
78        // Collect all patterns from principals, actions, and resources into a flat list
79        let all_configs: Vec<&PatternConfig> = lang_patterns
80            .principals
81            .iter()
82            .chain(lang_patterns.actions.iter())
83            .chain(lang_patterns.resources.iter())
84            .flat_map(|v| v.iter())
85            .collect();
86
87        for config in all_configs {
88            pattern_configs.push(config.clone());
89            match &config.pattern_type {
90                PatternQuery::Definition { definition } => {
91                    if let Ok(query) = Query::new(&ts_language, definition) {
92                        definition_queries.push(query);
93                    }
94                }
95                PatternQuery::Reference { reference } => {
96                    if let Ok(query) = Query::new(&ts_language, reference) {
97                        reference_queries.push(query);
98                    }
99                }
100            }
101        }
102
103        Self {
104            definition_queries,
105            reference_queries,
106            language: ts_language,
107            pattern_configs,
108        }
109    }
110
111    fn get_tree_sitter_language(language: Language) -> TreeSitterLanguage {
112        match language {
113            Language::Python => tree_sitter_python::LANGUAGE.into(),
114            Language::JavaScript => tree_sitter_javascript::LANGUAGE.into(),
115            Language::TypeScript => tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into(),
116            Language::Rust => tree_sitter_rust::LANGUAGE.into(),
117            Language::Java => tree_sitter_java::LANGUAGE.into(),
118            Language::Go => tree_sitter_go::LANGUAGE.into(),
119            Language::Ruby => tree_sitter_ruby::LANGUAGE.into(),
120            Language::C => tree_sitter_c::LANGUAGE.into(),
121            Language::Cpp => tree_sitter_cpp::LANGUAGE.into(),
122            Language::Terraform => tree_sitter_hcl::LANGUAGE.into(),
123            Language::Php => tree_sitter_php::LANGUAGE_PHP.into(),
124            Language::Yaml => tree_sitter_yaml::LANGUAGE.into(),
125            _ => tree_sitter_javascript::LANGUAGE.into(),
126        }
127    }
128
129    /// Check if content matches any security pattern.
130    #[must_use]
131    pub fn matches(&self, content: &str) -> bool {
132        let mut parser = Parser::new();
133        if parser.set_language(&self.language).is_err() {
134            return false;
135        }
136
137        let tree = match parser.parse(content, None) {
138            Some(tree) => tree,
139            None => return false,
140        };
141
142        let root_node = tree.root_node();
143
144        let all_queries = [&self.definition_queries, &self.reference_queries];
145
146        for query_set in all_queries {
147            for query in query_set {
148                let mut cursor = QueryCursor::new();
149                let mut matches = cursor.matches(query, root_node, content.as_bytes());
150                while let Some(match_) = matches.next() {
151                    let mut has_valid_capture = false;
152
153                    for capture in match_.captures {
154                        let capture_name = &query.capture_names()[capture.index as usize];
155                        let node = capture.node;
156                        let start_byte = node.start_byte();
157                        let end_byte = node.end_byte();
158                        let matched_text = content[start_byte..end_byte].to_string();
159
160                        if matched_text.trim().len() <= 2 {
161                            continue;
162                        }
163
164                        match *capture_name {
165                            "function" | "definition" | "class" | "method_def" | "call"
166                            | "expression" | "attribute" => {
167                                has_valid_capture = true;
168                            }
169                            "name" | "func" | "attr" | "obj" | "method" => {
170                                has_valid_capture = true;
171                            }
172                            _ => {
173                                has_valid_capture = true;
174                            }
175                        }
176                    }
177
178                    if has_valid_capture {
179                        return true;
180                    }
181                }
182            }
183        }
184
185        false
186    }
187
188    /// Get attack vectors for content.
189    #[must_use]
190    pub fn get_attack_vectors(&self, _content: &str) -> Vec<String> {
191        Vec::new()
192    }
193
194    /// Get all pattern matches in content.
195    #[must_use]
196    pub fn get_pattern_matches(&self, content: &str) -> Vec<PatternMatch> {
197        let mut parser = Parser::new();
198        if parser.set_language(&self.language).is_err() {
199            return Vec::new();
200        }
201
202        let tree = match parser.parse(content, None) {
203            Some(tree) => tree,
204            None => return Vec::new(),
205        };
206
207        let root_node = tree.root_node();
208        let mut pattern_matches = Vec::new();
209        let content_bytes = content.as_bytes();
210
211        let mut process_queries = |queries: &[Query], is_definition: bool| {
212            for (query_idx, query) in queries.iter().enumerate() {
213                let mut cursor = QueryCursor::new();
214                let mut matches = cursor.matches(query, root_node, content_bytes);
215
216                while let Some(match_) = matches.next() {
217                    let mut best_node = None;
218                    let mut best_text = String::new();
219                    let mut best_priority = 0;
220
221                    for capture in match_.captures {
222                        let capture_name = &query.capture_names()[capture.index as usize];
223                        let node = capture.node;
224                        let start_byte = node.start_byte();
225                        let end_byte = node.end_byte();
226                        let matched_text = content[start_byte..end_byte].to_string();
227
228                        if matched_text.trim().len() <= 2 {
229                            continue;
230                        }
231
232                        let (priority, candidate_node, candidate_text) = match *capture_name {
233                            "function" | "definition" | "class" | "method_def" => {
234                                (100, Some(node), matched_text.clone())
235                            }
236                            "call" | "expression" | "attribute" => {
237                                (90, Some(node), matched_text.clone())
238                            }
239                            "name" | "func" | "attr" | "obj" | "method" => {
240                                let mut found_parent = None;
241                                let mut parent = node.parent();
242                                while let Some(p) = parent {
243                                    if (is_definition
244                                        && (p.kind().contains("definition")
245                                            || p.kind().contains("declaration")))
246                                        || (!is_definition
247                                            && (p.kind().contains("call")
248                                                || p.kind().contains("attribute")
249                                                || p.kind().contains("expression")))
250                                    {
251                                        found_parent = Some(p);
252                                        break;
253                                    }
254                                    parent = p.parent();
255                                }
256                                if let Some(p) = found_parent {
257                                    (
258                                        80,
259                                        Some(p),
260                                        content[p.start_byte()..p.end_byte()].to_string(),
261                                    )
262                                } else {
263                                    (70, Some(node), matched_text.clone())
264                                }
265                            }
266                            "param" | "func_name" => {
267                                let mut found_func = None;
268                                let mut parent = node.parent();
269                                while let Some(p) = parent {
270                                    if p.kind() == "function_definition" {
271                                        found_func = Some(p);
272                                        break;
273                                    }
274                                    parent = p.parent();
275                                }
276                                if let Some(p) = found_func {
277                                    (
278                                        85,
279                                        Some(p),
280                                        content[p.start_byte()..p.end_byte()].to_string(),
281                                    )
282                                } else {
283                                    (60, Some(node), matched_text.clone())
284                                }
285                            }
286                            _ => (50, Some(node), matched_text.clone()),
287                        };
288
289                        if priority > best_priority {
290                            best_priority = priority;
291                            best_node = candidate_node;
292                            best_text = candidate_text;
293                        }
294                    }
295
296                    if let Some(node) = best_node {
297                        let start_byte = node.start_byte();
298                        let end_byte = node.end_byte();
299
300                        // Find the matching config by counting definition/reference queries
301                        let mut config_idx = 0;
302                        for config in &self.pattern_configs {
303                            let matches_type = matches!(
304                                (&config.pattern_type, is_definition),
305                                (PatternQuery::Definition { .. }, true)
306                                    | (PatternQuery::Reference { .. }, false)
307                            );
308
309                            if matches_type {
310                                if config_idx == query_idx {
311                                    pattern_matches.push(PatternMatch {
312                                        pattern_config: config.clone(),
313                                        start_byte,
314                                        end_byte,
315                                        matched_text: best_text.clone(),
316                                    });
317                                    break;
318                                }
319                                config_idx += 1;
320                            }
321                        }
322                    }
323                }
324            }
325        };
326
327        process_queries(&self.definition_queries, true);
328        process_queries(&self.reference_queries, false);
329
330        pattern_matches
331    }
332
333    fn load_patterns(root_dir: Option<&Path>) -> HashMap<Language, LanguagePatterns> {
334        use Language::*;
335
336        let mut map = HashMap::new();
337
338        let languages = [
339            (Python, include_str!("patterns/python.yml")),
340            (JavaScript, include_str!("patterns/javascript.yml")),
341            (Rust, include_str!("patterns/rust.yml")),
342            (TypeScript, include_str!("patterns/typescript.yml")),
343            (Java, include_str!("patterns/java.yml")),
344            (Go, include_str!("patterns/go.yml")),
345            (Ruby, include_str!("patterns/ruby.yml")),
346            (C, include_str!("patterns/c.yml")),
347            (Cpp, include_str!("patterns/cpp.yml")),
348            (Php, include_str!("patterns/php.yml")),
349            (Terraform, include_str!("patterns/terraform.yml")),
350        ];
351
352        for (lang, content) in languages {
353            match serde_yaml::from_str::<LanguagePatterns>(content) {
354                Ok(patterns) => {
355                    map.insert(lang, patterns);
356                }
357                Err(e) => {
358                    eprintln!("Failed to parse patterns for {:?}: {}", lang, e);
359                }
360            }
361        }
362
363        // Load CI/CD platform patterns and merge into Yaml language
364        let cicd_patterns = [
365            include_str!("patterns/github-actions.yml"), // GitHub Actions
366            include_str!("patterns/gitlab-ci.yml"),
367            include_str!("patterns/circleci.yml"),
368            include_str!("patterns/travis.yml"),
369            include_str!("patterns/jenkins.yml"),
370        ];
371
372        let mut merged_yaml_patterns = LanguagePatterns {
373            principals: Some(Vec::new()),
374            actions: Some(Vec::new()),
375            resources: Some(Vec::new()),
376        };
377
378        for content in cicd_patterns {
379            if let Ok(patterns) = serde_yaml::from_str::<LanguagePatterns>(content) {
380                if let Some(principals) = patterns.principals {
381                    merged_yaml_patterns
382                        .principals
383                        .as_mut()
384                        .unwrap()
385                        .extend(principals);
386                }
387                if let Some(actions) = patterns.actions {
388                    merged_yaml_patterns
389                        .actions
390                        .as_mut()
391                        .unwrap()
392                        .extend(actions);
393                }
394                if let Some(resources) = patterns.resources {
395                    merged_yaml_patterns
396                        .resources
397                        .as_mut()
398                        .unwrap()
399                        .extend(resources);
400                }
401            }
402        }
403
404        map.insert(Yaml, merged_yaml_patterns);
405
406        Self::load_custom_patterns(&mut map, root_dir);
407
408        map
409    }
410
411    /// Add dynamic queries (e.g. from threat model) at runtime.
412    /// `query_type` is "definition" or "reference".
413    pub fn add_query(
414        &mut self,
415        query_type: &str,
416        query_str: &str,
417        description: &str,
418        attack_vector: Vec<String>,
419    ) -> bool {
420        let query = match Query::new(&self.language, query_str) {
421            Ok(q) => q,
422            Err(_) => return false,
423        };
424
425        let is_definition = query_type == "definition";
426        let pattern_query = if is_definition {
427            PatternQuery::Definition {
428                definition: query_str.to_string(),
429            }
430        } else {
431            PatternQuery::Reference {
432                reference: query_str.to_string(),
433            }
434        };
435
436        let config = PatternConfig {
437            pattern_type: pattern_query,
438            description: description.to_string(),
439            attack_vector,
440        };
441
442        self.pattern_configs.push(config);
443
444        if is_definition {
445            self.definition_queries.push(query);
446        } else {
447            self.reference_queries.push(query);
448        }
449
450        true
451    }
452
453    fn load_custom_patterns(
454        map: &mut HashMap<Language, LanguagePatterns>,
455        root_dir: Option<&Path>,
456    ) {
457        let vuln_patterns_path = if let Some(root) = root_dir {
458            root.join("vuln-patterns.yml")
459        } else {
460            Path::new("vuln-patterns.yml").to_path_buf()
461        };
462
463        if vuln_patterns_path.exists() {
464            match std::fs::read_to_string(&vuln_patterns_path) {
465                Ok(content) => {
466                    match serde_yaml::from_str::<HashMap<String, LanguagePatterns>>(&content) {
467                        Ok(custom_patterns) => {
468                            for (lang_name, patterns) in custom_patterns {
469                                let language = match lang_name.as_str() {
470                                    "Python" => Language::Python,
471                                    "JavaScript" => Language::JavaScript,
472                                    "TypeScript" => Language::TypeScript,
473                                    "Rust" => Language::Rust,
474                                    "Java" => Language::Java,
475                                    "Go" => Language::Go,
476                                    "Ruby" => Language::Ruby,
477                                    "C" => Language::C,
478                                    "Cpp" => Language::Cpp,
479                                    "Terraform" => Language::Terraform,
480                                    "CloudFormation" => Language::CloudFormation,
481                                    "Kubernetes" => Language::Kubernetes,
482                                    "YAML" => Language::Yaml,
483                                    "GitLabCI" => Language::Yaml,
484                                    "CircleCI" => Language::Yaml,
485                                    "TravisCI" => Language::Yaml,
486                                    "Jenkins" => Language::Yaml,
487                                    "Bash" => Language::Bash,
488                                    "Shell" => Language::Shell,
489                                    "Php" | "PHP" => Language::Php,
490                                    _ => continue,
491                                };
492
493                                match map.get_mut(&language) {
494                                    Some(existing) => {
495                                        if let Some(custom_principals) = patterns.principals {
496                                            match &mut existing.principals {
497                                                Some(principals) => {
498                                                    principals.extend(custom_principals)
499                                                }
500                                                None => {
501                                                    existing.principals = Some(custom_principals)
502                                                }
503                                            }
504                                        }
505                                        if let Some(custom_actions) = patterns.actions {
506                                            match &mut existing.actions {
507                                                Some(actions) => actions.extend(custom_actions),
508                                                None => existing.actions = Some(custom_actions),
509                                            }
510                                        }
511                                        if let Some(custom_resources) = patterns.resources {
512                                            match &mut existing.resources {
513                                                Some(resources) => {
514                                                    resources.extend(custom_resources)
515                                                }
516                                                None => existing.resources = Some(custom_resources),
517                                            }
518                                        }
519                                    }
520                                    None => {
521                                        map.insert(language, patterns);
522                                    }
523                                }
524                            }
525                        }
526                        Err(e) => {
527                            eprintln!("Failed to parse vuln-patterns.yml: {}", e);
528                        }
529                    }
530                }
531                Err(e) => {
532                    eprintln!("Failed to read vuln-patterns.yml: {}", e);
533                }
534            }
535        }
536    }
537}