1use std::path::{Component, Path};
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::{Result, ShieldError};
6use crate::rules::policy::Policy;
7
8#[derive(Debug, Clone, Default, Serialize, Deserialize)]
10pub struct Config {
11 #[serde(default)]
12 pub policy: Policy,
13 #[serde(default)]
14 pub scan: ScanConfig,
15 #[serde(default)]
16 pub runtime: RuntimeConfig,
17}
18
19#[derive(Debug, Clone, Default, Serialize, Deserialize)]
21pub struct ScanConfig {
22 #[serde(default)]
24 pub ignore_tests: bool,
25 #[serde(default)]
26 pub include: Vec<String>,
27 #[serde(default)]
28 pub exclude: Vec<String>,
29}
30
/// Human-readable view of a `ScanPathFilter`: the include/exclude patterns exactly
/// as they were written in the configuration (before normalization).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ScanPathFilterSummary {
    /// Raw `scan.include` patterns, in configuration order.
    pub include: Vec<String>,
    /// Raw `scan.exclude` patterns, in configuration order.
    pub exclude: Vec<String>,
}
36
/// Compiled form of the `[scan]` path settings, used to decide whether a file
/// under the scan root should be scanned.
#[derive(Debug, Clone)]
pub struct ScanPathFilter {
    // Carried alongside the patterns; not consulted by `allows_path` itself.
    ignore_tests: bool,
    // Empty means "include everything".
    include: Vec<CompiledPathPattern>,
    // Applied after include filtering.
    exclude: Vec<CompiledPathPattern>,
}
43
/// One configured path pattern together with the glob(s) it expands to.
#[derive(Debug, Clone)]
struct CompiledPathPattern {
    // The pattern exactly as written in the config, kept for summaries and error messages.
    raw: String,
    // Normalized/expanded globs; the pattern matches if any of these match.
    patterns: Vec<glob::Pattern>,
}
49
/// Glob matching options applied to every scan path pattern.
///
/// - `case_sensitive`: patterns match paths case-sensitively.
/// - `require_literal_separator`: `*` and `?` never match `/`, so they stay within one
///   path segment; only `**` crosses directories (as the starter config documents).
/// - `require_literal_leading_dot: false`: wildcards may match names that begin with `.`.
const PATH_PATTERN_MATCH_OPTIONS: glob::MatchOptions = glob::MatchOptions {
    case_sensitive: true,
    require_literal_separator: true,
    require_literal_leading_dot: false,
};
55
56impl ScanPathFilter {
57 pub fn for_ignore_tests(ignore_tests: bool) -> Self {
58 Self {
59 ignore_tests,
60 include: Vec::new(),
61 exclude: Vec::new(),
62 }
63 }
64
65 pub fn from_scan_config(config: &ScanConfig, ignore_tests: bool) -> Result<Self> {
66 Ok(Self {
67 ignore_tests,
68 include: compile_path_patterns("scan.include", &config.include)?,
69 exclude: compile_path_patterns("scan.exclude", &config.exclude)?,
70 })
71 }
72
73 pub const fn ignore_tests(&self) -> bool {
74 self.ignore_tests
75 }
76
77 pub fn allows_path(&self, root: &Path, path: &Path) -> bool {
78 let relative = relative_path(root, path);
79 let included = self.include.is_empty()
80 || self
81 .include
82 .iter()
83 .any(|pattern| pattern.matches(&relative));
84 let excluded = self
85 .exclude
86 .iter()
87 .any(|pattern| pattern.matches(&relative));
88
89 included && !excluded
90 }
91
92 pub fn summary(&self) -> ScanPathFilterSummary {
93 ScanPathFilterSummary {
94 include: self
95 .include
96 .iter()
97 .map(|pattern| pattern.raw.clone())
98 .collect(),
99 exclude: self
100 .exclude
101 .iter()
102 .map(|pattern| pattern.raw.clone())
103 .collect(),
104 }
105 }
106}
107
108impl CompiledPathPattern {
109 fn new(section: &str, raw: &str) -> Result<Self> {
110 let normalized = normalize_config_pattern(raw);
111 if normalized.is_empty() {
112 return Err(ShieldError::Config(format!(
113 "{section} pattern must not be empty"
114 )));
115 }
116 let patterns = expand_config_pattern(&normalized)
117 .into_iter()
118 .map(|pattern| {
119 glob::Pattern::new(&pattern).map_err(|err| {
120 ShieldError::Config(format!("invalid {section} pattern '{raw}': {err}"))
121 })
122 })
123 .collect::<Result<Vec<_>>>()?;
124
125 Ok(Self {
126 raw: raw.to_string(),
127 patterns,
128 })
129 }
130
131 fn matches(&self, relative_path: &str) -> bool {
132 self.patterns
133 .iter()
134 .any(|pattern| pattern.matches_with(relative_path, PATH_PATTERN_MATCH_OPTIONS))
135 }
136}
137
138fn compile_path_patterns(section: &str, patterns: &[String]) -> Result<Vec<CompiledPathPattern>> {
139 patterns
140 .iter()
141 .map(|pattern| CompiledPathPattern::new(section, pattern))
142 .collect()
143}
144
/// Normalizes a user-written path pattern into the `/`-separated, root-relative
/// form that is matched against relative paths.
///
/// - surrounding whitespace is trimmed and `\` is converted to `/`;
/// - runs of `/` are collapsed to one;
/// - any mix of leading `/` and `./` prefixes is removed, since matched paths are
///   always relative to the scan root;
/// - a trailing `/` means "everything under this directory" and becomes `/**`.
///
/// May return an empty string (e.g. for `""`, `"/"` or `"./"`); callers reject that.
fn normalize_config_pattern(pattern: &str) -> String {
    let mut normalized = pattern.trim().replace('\\', "/");
    // Collapse separators first so the prefix stripping below sees canonical input.
    while normalized.contains("//") {
        normalized = normalized.replace("//", "/");
    }

    // Strip `/` and `./` prefixes in any interleaving (e.g. `/./`, `.//`, `././`).
    // Stripping them in a fixed order could leave a leading `/`, and such a pattern
    // would never match a relative path.
    let mut rest = normalized.as_str();
    loop {
        if let Some(stripped) = rest.strip_prefix('/') {
            rest = stripped;
        } else if let Some(stripped) = rest.strip_prefix("./") {
            rest = stripped;
        } else {
            break;
        }
    }

    let mut anchored = rest.to_string();
    if anchored.ends_with('/') {
        anchored.push_str("**");
    }
    anchored
}
159
/// Expands a normalized pattern into the globs to compile.
///
/// A pattern starting with `**/` also gets a variant with that prefix removed, so
/// that e.g. `**/generated/**` matches `generated/...` at the scan root as well.
/// The original pattern always comes first.
fn expand_config_pattern(pattern: &str) -> Vec<String> {
    let root_variant = pattern
        .strip_prefix("**/")
        .filter(|rest| !rest.is_empty())
        .map(str::to_string);
    std::iter::once(pattern.to_string())
        .chain(root_variant)
        .collect()
}
169
/// Renders `path` relative to `root` as a `/`-joined string for glob matching.
///
/// Both paths are canonicalized when possible (falling back to them as given).
/// If the canonical path is not under the canonical root, a plain prefix strip of
/// the original paths is tried; if that also fails, the whole path is used.
/// Root and prefix components and `.` are dropped; `..` is kept.
fn relative_path(root: &Path, path: &Path) -> String {
    let resolved_root = match root.canonicalize() {
        Ok(resolved) => resolved,
        Err(_) => root.to_path_buf(),
    };
    let resolved_path = match path.canonicalize() {
        Ok(resolved) => resolved,
        Err(_) => path.to_path_buf(),
    };

    let stripped: &Path = match resolved_path.strip_prefix(&resolved_root) {
        Ok(rest) => rest,
        Err(_) => path.strip_prefix(root).unwrap_or(path),
    };

    let mut segments: Vec<String> = Vec::new();
    for component in stripped.components() {
        match component {
            Component::Normal(part) => segments.push(part.to_string_lossy().into_owned()),
            Component::ParentDir => segments.push("..".to_string()),
            Component::CurDir | Component::RootDir | Component::Prefix(_) => {}
        }
    }
    segments.join("/")
}
189
190#[derive(Debug, Clone, Default, Serialize, Deserialize)]
192pub struct RuntimeConfig {
193 #[serde(default)]
194 pub proxy: RuntimeProxyConfig,
195}
196
/// Blocking threshold for the runtime MCP proxy guard, set globally via
/// `runtime.proxy.fail_on` or per tool via `[[runtime.proxy.tool]]`.
///
/// Written in lowercase in TOML: `"block"`, `"warn"` or `"never"`.
/// NOTE(review): the exact enforcement behavior of each level lives in the proxy
/// guard, not in this file — confirm there before relying on these descriptions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProxyFailOn {
    /// `"block"` — the default when no value is configured.
    #[default]
    Block,
    /// `"warn"`.
    Warn,
    /// `"never"` — presumably the guard never fails for this scope (the starter
    /// config uses it for a single tool override); confirm in the proxy guard.
    Never,
}
209
/// Per-tool override of the proxy blocking threshold (one `[[runtime.proxy.tool]]`
/// entry in TOML).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyToolOverride {
    /// Tool name the override applies to (e.g. `"calculator.add"`); required.
    pub name: String,
    /// Threshold for this tool; defaults to `ProxyFailOn::default()` when omitted.
    #[serde(default)]
    pub fail_on: ProxyFailOn,
}
217
218#[derive(Debug, Clone, Default, Serialize, Deserialize)]
220pub struct RuntimeProxyConfig {
221 #[serde(default)]
222 pub fail_on: ProxyFailOn,
223 #[serde(default, rename = "tool")]
224 pub tool_overrides: Vec<ProxyToolOverride>,
225}
226
227impl Config {
228 pub fn load(path: &Path) -> Result<Self> {
230 if !path.exists() {
231 return Ok(Self::default());
232 }
233 let content = std::fs::read_to_string(path)?;
234 let config: Config = toml::from_str(&content)?;
235 config.validate()?;
236 Ok(config)
237 }
238
239 fn validate(&self) -> Result<()> {
244 for s in &self.policy.suppressions {
245 if s.reason.trim().is_empty() {
246 return Err(ShieldError::Config(format!(
247 "Suppression for fingerprint '{}' must have a non-empty reason",
248 s.fingerprint,
249 )));
250 }
251 }
252 let _ = ScanPathFilter::from_scan_config(&self.scan, self.scan.ignore_tests)?;
253 Ok(())
254 }
255
256 #[cfg(test)]
258 pub fn validate_for_test(&self) -> Result<()> {
259 self.validate()
260 }
261
262 pub fn starter_toml() -> &'static str {
264 r#"# AgentShield configuration
265# See https://github.com/limaronaldo/agentshield for documentation.
266
267[policy]
268# Minimum severity to fail the scan (info, low, medium, high, critical).
269fail_on = "high"
270
271# Rule IDs to ignore entirely.
272# ignore_rules = ["SHIELD-008"]
273
274# Per-rule severity overrides.
275# [policy.overrides]
276# "SHIELD-012" = "info"
277
278# Suppress specific findings by fingerprint.
279# Run `agentshield scan . --format json` to see fingerprints.
280# [[policy.suppressions]]
281# fingerprint = "abc123..."
282# reason = "False positive: input is validated by middleware"
283# expires = "2026-06-01"
284
285# [scan]
286# Skip test files (test/, tests/, __tests__/, *.test.ts, *.spec.ts, etc.).
287# ignore_tests = false
288# Include only matching paths. Empty means include all scan-supported files.
289# Use ** for recursive directories; * and ? stay within one path segment.
290# include = ["src/**", "tools/**"]
291# Exclude matching paths after include filtering.
292# exclude = ["legacy/**", "**/generated/**", "vendor/**"]
293
294# [runtime.proxy]
295# Runtime MCP proxy guard blocking threshold: block, warn, or never.
296# fail_on = "block"
297
298# [[runtime.proxy.tool]]
299# name = "calculator.add"
300# fail_on = "never"
301"#
302 }
303}