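//! Threat analysis for `ai_coding_shield` (`analyzer/mod.rs`): commands listed in
//! scanned artifacts are matched against the threat catalog's rules and reported as
//! findings.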

use anyhow::Result;
use regex::Regex;

// Package-name helpers (typosquatting checks and popular-package lists).
mod packages;

use crate::catalog::ThreatCatalog;
use crate::scanner::vuln_db::VulnerabilityScanner;
use crate::types::{AnalysisResult, Artifact, Finding, RiskLevel};

11pub struct Analyzer {
12    catalog: ThreatCatalog,
13    vuln_scanner: VulnerabilityScanner,
14}
15
16impl Analyzer {
17    pub fn new(catalog: ThreatCatalog) -> Self {
18        Self { 
19            catalog,
20            vuln_scanner: VulnerabilityScanner::new(),
21        }
22    }
23
24    pub async fn analyze(&self, artifacts: &[Artifact]) -> Result<Vec<AnalysisResult>> {
25        let mut results = Vec::new();
26
27        for artifact in artifacts {
28            let result = self.analyze_artifact(artifact).await?;
29            results.push(result);
30        }
31
32        Ok(results)
33    }
34
35    async fn analyze_artifact(&self, artifact: &Artifact) -> Result<AnalysisResult> {
36        let mut result = AnalysisResult::new(artifact.clone());
37
38        // Get all threat rules
39        let rules = self.catalog.all_rules();
40
41        // Analyze each command in the artifact
42        for (cmd_idx, command) in artifact.metadata.commands.iter().enumerate() {
43            for rule in &rules {
44                // Compile regex pattern
45                let pattern = match Regex::new(&rule.pattern) {
46                    Ok(p) => p,
47                    Err(_) => continue, // Skip invalid patterns
48                };
49
50                // Check if command matches the pattern
51                if pattern.is_match(command) {
52                    // Check for false positives
53                    let is_false_positive = rule.false_positives.iter().any(|fp| {
54                        Regex::new(fp).map(|r| r.is_match(command)).unwrap_or(false)
55                    });
56
57                    if is_false_positive {
58                        continue;
59                    }
60
61                    // Calculate risk score with context
62                    let has_network = command.contains("curl") 
63                        || command.contains("wget") 
64                        || command.contains("nc");
65                    
66                    let mut risk_score = rule.calculate_score(
67                        artifact.metadata.has_auto_run,
68                        has_network,
69                    );
70
71                    // Determine risk level from score
72                    let mut risk_level = if risk_score >= 80 {
73                        RiskLevel::Critical
74                    } else if risk_score >= 60 {
75                        RiskLevel::High
76                    } else if risk_score >= 30 {
77                        RiskLevel::Medium
78                    } else {
79                        RiskLevel::Low
80                    };
81
82                    // Build flags
83                    let mut flags = Vec::new();
84                    if artifact.metadata.has_auto_run {
85                        flags.push("Auto-run enabled".to_string());
86                    }
87                    if !artifact.metadata.has_turbo_annotation && artifact.metadata.has_auto_run {
88                        flags.push("Auto-run without turbo annotation".to_string());
89                    }
90
91                    // Find the category for MITRE/CWE info
92                    let (mitre_id, cwe_id) = self.catalog.categories.iter()
93                        .find(|cat| cat.rules.iter().any(|r| r.id == rule.id))
94                        .map(|cat| (cat.mitre_id.clone(), cat.cwe_id.clone()))
95                        .unwrap_or((None, None));
96
97                    // DYNAMIC ANALYSIS: Package extraction
98                    let mut extra_desc = String::new();
99                    
100                    if rule.id.starts_with("PKG") || rule.id.starts_with("TYPO") {
101                        if let Some(typo) = self.check_command_for_typos(command) {
102                            extra_desc = format!("\nPossible typosquatting detected! Did you mean '{}'?", typo);
103                            // Escalate risk if typo found
104                            if risk_score < 90 {
105                                risk_score = 90; // Critical
106                                risk_level = RiskLevel::Critical;
107                            }
108                        }
109                    }
110
111                    // DYNAMIC: Vulnerability Check (only if command is package install)
112                    if rule.id.starts_with("PKG") {
113                         if let Some((eco, pkg)) = self.extract_package_info(command) {
114                             // Only check if we haven't found critical risk yet (optimization)
115                             if let Ok(Some(vuln_info)) = self.vuln_scanner.check_package(eco, pkg).await {
116                                 extra_desc.push_str(&format!("\n❌ VULNERABILITY ALERT: {}", vuln_info));
117                                 if risk_score < 60 {
118                                     risk_score = 60; // High
119                                     risk_level = RiskLevel::High;
120                                 }
121                             }
122                         }
123                    }
124
125                    let finding = Finding {
126                        artifact_path: artifact.path.clone(),
127                        threat_id: rule.id.clone(),
128                        threat_name: rule.description.clone(),
129                        risk_level,
130                        risk_score,
131                        line_number: Some(cmd_idx + 1),
132                        matched_pattern: command.clone(),
133                        description: format!("{}{}", rule.description, extra_desc),
134                        recommendation: rule.remediation.clone(),
135                        mitre_id,
136                        cwe_id,
137                        flags,
138                    };
139
140                    result.add_finding(finding);
141                }
142            }
143        }
144
145        Ok(result)
146    }
147
148    fn check_command_for_typos(&self, command: &str) -> Option<String> {
149        // Simple extraction for MVP
150        // npm install package
151        if command.contains("npm install") || command.contains("npm i ") {
152            let parts: Vec<&str> = command.split_whitespace().collect();
153            for part in parts {
154                if part.starts_with("-") || part == "npm" || part == "install" || part == "i" {
155                    continue;
156                }
157                // Check extracted word
158                if let Some(target) = packages::check_typosquatting(part, &packages::POPULAR_NPM) {
159                    return Some(format!("{} (npm matched {})", part, target));
160                }
161            }
162        }
163        
164        // pip install package
165        if command.contains("pip install") {
166            let parts: Vec<&str> = command.split_whitespace().collect();
167            for part in parts {
168                if part.starts_with("-") || part == "pip" || part == "install" {
169                    continue;
170                }
171                if let Some(target) = packages::check_typosquatting(part, &packages::POPULAR_PIP) {
172                    return Some(format!("{} (pip matched {})", part, target));
173                }
174            }
175        }
176
177        // cargo install package
178        if command.contains("cargo install") {
179             let parts: Vec<&str> = command.split_whitespace().collect();
180            for part in parts {
181                if part.starts_with("-") || part == "cargo" || part == "install" {
182                    continue;
183                }
184                if let Some(target) = packages::check_typosquatting(part, &packages::POPULAR_CRATES) {
185                    return Some(format!("{} (cargo matched {})", part, target));
186                }
187            }
188        }
189
190        None
191    }
192
193    fn extract_package_info<'a>(&self, command: &'a str) -> Option<(&'static str, &'a str)> {
194        let parts: Vec<&str> = command.split_whitespace().collect();
195        if command.contains("npm install") || command.contains("npm i ") {
196            for part in &parts {
197                if !part.starts_with("-") && *part != "npm" && *part != "install" && *part != "i" {
198                    return Some(("npm", part));
199                }
200            }
201        }
202        if command.contains("pip install") {
203             for part in &parts {
204                if !part.starts_with("-") && *part != "pip" && *part != "install" {
205                    return Some(("PyPI", part));
206                }
207            }
208        }
209        if command.contains("cargo install") {
210             for part in &parts {
211                if !part.starts_with("-") && *part != "cargo" && *part != "install" {
212                    return Some(("crates.io", part)); // OSV uses crates.io
213                }
214            }
215        }
216        None
217    }
218}
219