Skip to main content

agentshield/config/
mod.rs

1use std::path::{Component, Path};
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::{Result, ShieldError};
6use crate::rules::policy::Policy;
7
8/// Top-level configuration from `.agentshield.toml`.
9#[derive(Debug, Clone, Default, Serialize, Deserialize)]
10pub struct Config {
11    #[serde(default)]
12    pub policy: Policy,
13    #[serde(default)]
14    pub scan: ScanConfig,
15    #[serde(default)]
16    pub runtime: RuntimeConfig,
17}
18
19/// `[scan]` section of the config file.
20#[derive(Debug, Clone, Default, Serialize, Deserialize)]
21pub struct ScanConfig {
22    /// Skip test files when true.
23    #[serde(default)]
24    pub ignore_tests: bool,
25    #[serde(default)]
26    pub include: Vec<String>,
27    #[serde(default)]
28    pub exclude: Vec<String>,
29}
30
/// Snapshot of a scan path filter's patterns, as originally written in the
/// config (before normalization).
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ScanPathFilterSummary {
    /// Configured `scan.include` patterns, verbatim.
    pub include: Vec<String>,
    /// Configured `scan.exclude` patterns, verbatim.
    pub exclude: Vec<String>,
}
36
37#[derive(Debug, Clone)]
38pub struct ScanPathFilter {
39    ignore_tests: bool,
40    include: Vec<CompiledPathPattern>,
41    exclude: Vec<CompiledPathPattern>,
42}
43
44#[derive(Debug, Clone)]
45struct CompiledPathPattern {
46    raw: String,
47    patterns: Vec<glob::Pattern>,
48}
49
50const PATH_PATTERN_MATCH_OPTIONS: glob::MatchOptions = glob::MatchOptions {
51    case_sensitive: true,
52    require_literal_separator: true,
53    require_literal_leading_dot: false,
54};
55
56impl ScanPathFilter {
57    pub fn for_ignore_tests(ignore_tests: bool) -> Self {
58        Self {
59            ignore_tests,
60            include: Vec::new(),
61            exclude: Vec::new(),
62        }
63    }
64
65    pub fn from_scan_config(config: &ScanConfig, ignore_tests: bool) -> Result<Self> {
66        Ok(Self {
67            ignore_tests,
68            include: compile_path_patterns("scan.include", &config.include)?,
69            exclude: compile_path_patterns("scan.exclude", &config.exclude)?,
70        })
71    }
72
73    pub const fn ignore_tests(&self) -> bool {
74        self.ignore_tests
75    }
76
77    pub fn allows_path(&self, root: &Path, path: &Path) -> bool {
78        let relative = relative_path(root, path);
79        let included = self.include.is_empty()
80            || self
81                .include
82                .iter()
83                .any(|pattern| pattern.matches(&relative));
84        let excluded = self
85            .exclude
86            .iter()
87            .any(|pattern| pattern.matches(&relative));
88
89        included && !excluded
90    }
91
92    pub fn summary(&self) -> ScanPathFilterSummary {
93        ScanPathFilterSummary {
94            include: self
95                .include
96                .iter()
97                .map(|pattern| pattern.raw.clone())
98                .collect(),
99            exclude: self
100                .exclude
101                .iter()
102                .map(|pattern| pattern.raw.clone())
103                .collect(),
104        }
105    }
106}
107
108impl CompiledPathPattern {
109    fn new(section: &str, raw: &str) -> Result<Self> {
110        let normalized = normalize_config_pattern(raw);
111        if normalized.is_empty() {
112            return Err(ShieldError::Config(format!(
113                "{section} pattern must not be empty"
114            )));
115        }
116        let patterns = expand_config_pattern(&normalized)
117            .into_iter()
118            .map(|pattern| {
119                glob::Pattern::new(&pattern).map_err(|err| {
120                    ShieldError::Config(format!("invalid {section} pattern '{raw}': {err}"))
121                })
122            })
123            .collect::<Result<Vec<_>>>()?;
124
125        Ok(Self {
126            raw: raw.to_string(),
127            patterns,
128        })
129    }
130
131    fn matches(&self, relative_path: &str) -> bool {
132        self.patterns
133            .iter()
134            .any(|pattern| pattern.matches_with(relative_path, PATH_PATTERN_MATCH_OPTIONS))
135    }
136}
137
138fn compile_path_patterns(section: &str, patterns: &[String]) -> Result<Vec<CompiledPathPattern>> {
139    patterns
140        .iter()
141        .map(|pattern| CompiledPathPattern::new(section, pattern))
142        .collect()
143}
144
/// Normalize a user-supplied `[scan]` path pattern into the form matched
/// against root-relative, `/`-separated paths.
///
/// Steps, in order:
/// 1. Trim surrounding whitespace and turn `\` separators into `/`.
/// 2. Collapse runs of `/` into a single `/`.
/// 3. Drop interior `/./` segments (relative paths never contain `.`).
/// 4. Strip leading `/` and `./` prefixes in any interleaving, so `/./src`,
///    `.//src` and `././src` all become `src`.
/// 5. Expand a trailing `/` (directory shorthand) to `/**`.
///
/// Returns an empty string when nothing meaningful remains; the caller
/// reports that as a configuration error.
fn normalize_config_pattern(pattern: &str) -> String {
    let mut normalized = pattern.trim().replace('\\', "/");

    // Collapse separators *before* stripping prefixes: otherwise `.//src`
    // loses its `./` and is left with a stray leading `/` that never matches.
    while normalized.contains("//") {
        normalized = normalized.replace("//", "/");
    }
    // `replace` is non-overlapping, so `a/././b` needs more than one pass.
    while normalized.contains("/./") {
        normalized = normalized.replace("/./", "/");
    }

    let mut rest = normalized.as_str();
    loop {
        if let Some(stripped) = rest.strip_prefix('/') {
            rest = stripped;
        } else if let Some(stripped) = rest.strip_prefix("./") {
            rest = stripped;
        } else {
            break;
        }
    }

    let mut normalized = rest.to_string();
    if normalized.ends_with('/') {
        normalized.push_str("**");
    }
    normalized
}
159
/// Expand a normalized pattern into the list of globs to compile.
///
/// The pattern itself always comes first. A pattern starting with `**/` also
/// gets a variant without that prefix (unless nothing would remain), so it
/// explicitly matches at the scan root as well.
fn expand_config_pattern(pattern: &str) -> Vec<String> {
    let root_variant = pattern
        .strip_prefix("**/")
        .filter(|remainder| !remainder.is_empty());

    std::iter::once(pattern)
        .chain(root_variant)
        .map(str::to_owned)
        .collect()
}
169
/// Express `path` relative to `root` as a `/`-joined string.
///
/// Both paths are canonicalized when possible (each falls back to itself if
/// canonicalization fails) and the canonical forms are compared first. If that
/// fails, the lexical `path.strip_prefix(root)` is tried. If both fail, `path`
/// itself is used. Root/prefix and `.` components are dropped, and `..` is
/// kept as a literal segment.
fn relative_path(root: &Path, path: &Path) -> String {
    let resolve = |p: &Path| p.canonicalize().unwrap_or_else(|_| p.to_path_buf());
    let resolved_root = resolve(root);
    let resolved_path = resolve(path);

    let relative: &Path = match resolved_path.strip_prefix(&resolved_root) {
        Ok(stripped) => stripped,
        Err(_) => path.strip_prefix(root).unwrap_or(path),
    };

    let mut segments: Vec<String> = Vec::new();
    for component in relative.components() {
        match component {
            Component::Normal(segment) => segments.push(segment.to_string_lossy().into_owned()),
            Component::ParentDir => segments.push(String::from("..")),
            Component::CurDir | Component::RootDir | Component::Prefix(_) => {}
        }
    }
    segments.join("/")
}
189
190/// `[runtime]` section of the config file.
191#[derive(Debug, Clone, Default, Serialize, Deserialize)]
192pub struct RuntimeConfig {
193    #[serde(default)]
194    pub proxy: RuntimeProxyConfig,
195}
196
197/// Blocking threshold for the MCP proxy guard.
198#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
199#[serde(rename_all = "lowercase")]
200pub enum ProxyFailOn {
201    /// Block only `block` verdicts (default).
202    #[default]
203    Block,
204    /// Block `warn` and `block` verdicts.
205    Warn,
206    /// Never block; still evaluated and audited.
207    Never,
208}
209
210/// Per-tool proxy policy override.
211#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct ProxyToolOverride {
213    pub name: String,
214    #[serde(default)]
215    pub fail_on: ProxyFailOn,
216}
217
218/// `[runtime.proxy]` section: MCP proxy guard policy.
219#[derive(Debug, Clone, Default, Serialize, Deserialize)]
220pub struct RuntimeProxyConfig {
221    #[serde(default)]
222    pub fail_on: ProxyFailOn,
223    #[serde(default, rename = "tool")]
224    pub tool_overrides: Vec<ProxyToolOverride>,
225}
226
227impl Config {
228    /// Load config from a TOML file. Returns default if file doesn't exist.
229    pub fn load(path: &Path) -> Result<Self> {
230        if !path.exists() {
231            return Ok(Self::default());
232        }
233        let content = std::fs::read_to_string(path)?;
234        let config: Config = toml::from_str(&content)?;
235        config.validate()?;
236        Ok(config)
237    }
238
239    /// Validate the loaded configuration.
240    ///
241    /// Called automatically by `load()`. Exposed for testing via
242    /// `validate_for_test()`.
243    fn validate(&self) -> Result<()> {
244        for s in &self.policy.suppressions {
245            if s.reason.trim().is_empty() {
246                return Err(ShieldError::Config(format!(
247                    "Suppression for fingerprint '{}' must have a non-empty reason",
248                    s.fingerprint,
249                )));
250            }
251        }
252        let _ = ScanPathFilter::from_scan_config(&self.scan, self.scan.ignore_tests)?;
253        Ok(())
254    }
255
256    /// Validate without loading from file. Used by tests.
257    #[cfg(test)]
258    pub fn validate_for_test(&self) -> Result<()> {
259        self.validate()
260    }
261
262    /// Generate a starter config file.
263    pub fn starter_toml() -> &'static str {
264        r#"# AgentShield configuration
265# See https://github.com/limaronaldo/agentshield for documentation.
266
267[policy]
268# Minimum severity to fail the scan (info, low, medium, high, critical).
269fail_on = "high"
270
271# Rule IDs to ignore entirely.
272# ignore_rules = ["SHIELD-008"]
273
274# Per-rule severity overrides.
275# [policy.overrides]
276# "SHIELD-012" = "info"
277
278# Suppress specific findings by fingerprint.
279# Run `agentshield scan . --format json` to see fingerprints.
280# [[policy.suppressions]]
281# fingerprint = "abc123..."
282# reason = "False positive: input is validated by middleware"
283# expires = "2026-06-01"
284
285# [scan]
286# Skip test files (test/, tests/, __tests__/, *.test.ts, *.spec.ts, etc.).
287# ignore_tests = false
288# Include only matching paths. Empty means include all scan-supported files.
289# Use ** for recursive directories; * and ? stay within one path segment.
290# include = ["src/**", "tools/**"]
291# Exclude matching paths after include filtering.
292# exclude = ["legacy/**", "**/generated/**", "vendor/**"]
293
294# [runtime.proxy]
295# Runtime MCP proxy guard blocking threshold: block, warn, or never.
296# fail_on = "block"
297
298# [[runtime.proxy.tool]]
299# name = "calculator.add"
300# fail_on = "never"
301"#
302    }
303}