Skip to main content

ai_coding_shield/catalog/
mod.rs

1pub mod threats;
2
3use anyhow::Result;
4use colored::*;
5use serde::{Deserialize, Serialize};
6use std::fs;
7use std::path::PathBuf;
8
9pub use threats::*;
10
11const CATALOG_FILE: &str = "threats.yaml";
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ThreatCatalog {
15    pub version: String,
16    pub last_updated: String,
17    pub categories: Vec<ThreatCategory>,
18    #[serde(default)]
19    pub trusted_authors: Vec<String>,
20    #[serde(default)]
21    pub trusted_domains: Vec<String>,
22}
23
24impl ThreatCatalog {
25    /// Load threat catalog from embedded or local file
26    pub fn load() -> Result<Self> {
27        // Try to load from local file first
28        let local_path = Self::local_catalog_path();
29        
30        if local_path.exists() {
31            let content = fs::read_to_string(&local_path)?;
32            let catalog: ThreatCatalog = serde_yaml::from_str(&content)?;
33            return Ok(catalog);
34        }
35
36        // Fall back to embedded catalog
37        let embedded = include_str!("../../config/threats.yaml");
38        let catalog: ThreatCatalog = serde_yaml::from_str(embedded)?;
39        Ok(catalog)
40    }
41
42    fn local_catalog_path() -> PathBuf {
43        dirs::config_dir()
44            .unwrap_or_else(|| PathBuf::from("."))
45            .join("ai-coding-shield")
46            .join("config")
47            .join(CATALOG_FILE)
48    }
49
50    /// Get all rules across all categories
51    pub fn all_rules(&self) -> Vec<&ThreatRule> {
52        self.categories
53            .iter()
54            .flat_map(|cat| cat.rules.iter())
55            .collect()
56    }
57
58    /// Find a specific threat by ID
59    pub fn find_threat(&self, threat_id: &str) -> Option<(&ThreatCategory, &ThreatRule)> {
60        for category in &self.categories {
61            for rule in &category.rules {
62                if rule.id == threat_id {
63                    return Some((category, rule));
64                }
65            }
66        }
67        None
68    }
69
70    /// Show detailed information about a threat
71    pub fn show_threat_info(&self, threat_id: &str) -> Result<()> {
72        match self.find_threat(threat_id) {
73            Some((category, rule)) => {
74                println!("{}", format!("🔍 Threat: {}", rule.id).bold().cyan());
75                println!("{}", "━".repeat(50).cyan());
76                println!();
77                println!("{}: {}", "Name".bold(), rule.description);
78                println!("{}: {}", "Category".bold(), category.name);
79                println!("{}: {}", "Severity".bold(), rule.severity.as_str().red());
80                println!();
81                
82                if let Some(mitre) = &category.mitre_id {
83                    println!("{}: {}", "MITRE ATT&CK".bold(), mitre);
84                }
85                if let Some(cwe) = &category.cwe_id {
86                    println!("{}: {}", "CWE".bold(), cwe);
87                }
88                println!();
89                
90                println!("{}", "Pattern:".bold());
91                println!("  {}", rule.pattern);
92                println!();
93                
94                if !rule.examples.is_empty() {
95                    println!("{}", "Examples:".bold());
96                    for example in &rule.examples {
97                        println!("  • {}", example.dimmed());
98                    }
99                    println!();
100                }
101                
102                println!("{}", "Remediation:".bold());
103                println!("  {}", rule.remediation);
104                
105                Ok(())
106            }
107            None => {
108                eprintln!("{}", format!("❌ Threat ID '{}' not found", threat_id).red());
109                Err(anyhow::anyhow!("Threat not found"))
110            }
111        }
112    }
113
114    /// Save current catalog to local config file
115    pub fn save(&self) -> Result<()> {
116        let local_path = Self::local_catalog_path();
117        
118        if let Some(parent) = local_path.parent() {
119            fs::create_dir_all(parent)?;
120        }
121        
122        let content = serde_yaml::to_string(self)?;
123        fs::write(local_path, content)?;
124        Ok(())
125    }
126
127    pub fn add_trusted_author(&mut self, author: String) {
128        if !self.trusted_authors.contains(&author) {
129            self.trusted_authors.push(author);
130        }
131    }
132
133    pub fn remove_trusted_author(&mut self, author: &str) -> bool {
134        if let Some(pos) = self.trusted_authors.iter().position(|x| x == author) {
135            self.trusted_authors.remove(pos);
136            true
137        } else {
138            false
139        }
140    }
141
142    pub fn add_trusted_domain(&mut self, domain: String) {
143        if !self.trusted_domains.contains(&domain) {
144            self.trusted_domains.push(domain);
145        }
146    }
147
148    pub fn remove_trusted_domain(&mut self, domain: &str) -> bool {
149        if let Some(pos) = self.trusted_domains.iter().position(|x| x == domain) {
150            self.trusted_domains.remove(pos);
151            true
152        } else {
153            false
154        }
155    }
156
157    /// List all threats, optionally filtered
158    pub fn list_threats(&self, category_filter: Option<&str>, severity_filter: Option<&str>) -> Result<()> {
159        println!("{}", "📋 Threat Catalog".bold().cyan());
160        println!("{}", "━".repeat(50).cyan());
161        println!();
162
163        for category in &self.categories {
164            if let Some(filter) = category_filter {
165                if !category.id.to_lowercase().contains(&filter.to_lowercase()) {
166                    continue;
167                }
168            }
169
170            println!("{}", format!("▸ {}", category.name).bold());
171            
172            for rule in &category.rules {
173                if let Some(filter) = severity_filter {
174                    if rule.severity.as_str().to_lowercase() != filter.to_lowercase() {
175                        continue;
176                    }
177                }
178
179                let severity_color = match rule.severity {
180                    RiskLevel::Critical => "CRITICAL".red().bold(),
181                    RiskLevel::High => "HIGH".yellow().bold(),
182                    RiskLevel::Medium => "MEDIUM".blue(),
183                    RiskLevel::Low => "LOW".green(),
184                };
185
186                println!("  {} [{}] {}", 
187                    rule.id.cyan(), 
188                    severity_color,
189                    rule.description.dimmed()
190                );
191            }
192            println!();
193        }
194
195        Ok(())
196    }
197}
198
199
200pub fn update_catalog(_force: bool) -> Result<()> {
201    // TODO: Implement fetching from MITRE, CWE, etc.
202    // For now, just a placeholder
203    println!("⚠️  Threat catalog update not yet implemented");
204    println!("   This will fetch from:");
205    println!("   • MITRE ATT&CK");
206    println!("   • CWE Database");
207    println!("   • GitHub Security Advisories");
208    Ok(())
209}
210
211use crate::types::RiskLevel;