1pub mod evaluate;
47pub mod render;
48
49use std::collections::BTreeMap;
50use std::path::{Path, PathBuf};
51
52use anyhow::{anyhow, Context};
53use serde::Serialize;
54use serde_yaml::Value as YamlValue;
55
56pub use evaluate::{evaluate_corpus, DecisionLine, EvalOptions};
57
58#[derive(Debug, Clone)]
62pub struct DiffOptions {
63 pub rules_before: PathBuf,
64 pub rules_after: PathBuf,
65 pub corpus: Option<PathBuf>,
67 pub workspace: Option<PathBuf>,
68 pub format: OutputFormat,
69 pub max_samples: usize,
70 pub fail_if_flipped: bool,
71 pub fail_if_loosened: bool,
72 pub fail_if_allows_loosened: Option<usize>,
73}
74
75#[derive(Debug, Clone, Copy, PartialEq, Eq)]
77pub enum OutputFormat {
78 Text,
79 Markdown,
80 Json,
81}
82
83impl OutputFormat {
84 pub fn parse(s: &str) -> anyhow::Result<Self> {
85 match s {
86 "text" => Ok(OutputFormat::Text),
87 "markdown" | "md" => Ok(OutputFormat::Markdown),
88 "json" => Ok(OutputFormat::Json),
89 other => Err(anyhow!(
90 "unknown --format '{}': must be one of text|markdown|json",
91 other
92 )),
93 }
94 }
95}
96
97#[derive(Debug, Clone, Serialize)]
101pub struct RuleDelta {
102 pub rule_id: String,
103 pub status: String,
105 #[serde(skip_serializing_if = "String::is_empty")]
106 pub yaml_diff: String,
107 pub fires_before: usize,
108 pub fires_after: usize,
109 #[serde(skip_serializing)]
111 pub flipped_lines_caused: Vec<(String, String, serde_json::Value)>,
112}
113
114pub type FlipCounter = BTreeMap<(String, String), usize>;
117
118pub const DECISIONS: [&str; 4] = ["allow", "warn", "approval", "block"];
119
120fn severity_rank(d: &str) -> u8 {
123 match d {
124 "allow" => 0,
125 "warn" => 1,
126 "approval" | "identity_verification" => 2,
127 "block" => 3,
128 _ => 99,
129 }
130}
131
132pub fn loosening_count(flips: &FlipCounter) -> usize {
136 flips
137 .iter()
138 .filter(|((b, a), _)| severity_rank(a) < severity_rank(b))
139 .map(|(_, c)| *c)
140 .sum()
141}
142
143pub fn flips_to_allow(flips: &FlipCounter) -> usize {
146 flips
147 .iter()
148 .filter(|((_, a), _)| a == "allow")
149 .map(|(_, c)| *c)
150 .sum()
151}
152
153pub fn load_ruleset_yaml(
159 path: &Path,
160) -> anyhow::Result<BTreeMap<String, YamlValue>> {
161 let raw = std::fs::read_to_string(path)
162 .with_context(|| format!("reading shieldset YAML from {}", path.display()))?;
163 let root: YamlValue = serde_yaml::from_str(&raw)
164 .with_context(|| format!("parsing YAML at {}", path.display()))?;
165 let YamlValue::Mapping(top) = &root else {
166 anyhow::bail!("{} did not parse as a YAML mapping", path.display());
167 };
168 let shieldset = top
169 .get(YamlValue::String("shieldset".into()))
170 .unwrap_or(&root);
171 let rules = match shieldset {
172 YamlValue::Mapping(m) => m.get(YamlValue::String("rules".into())).cloned(),
173 _ => None,
174 };
175 let Some(YamlValue::Sequence(rules)) = rules else {
176 return Ok(BTreeMap::new());
177 };
178 let mut out: BTreeMap<String, YamlValue> = BTreeMap::new();
179 for r in rules {
180 let YamlValue::Mapping(mut m) = r else { continue };
181 let Some(YamlValue::String(rid)) = m.remove(YamlValue::String("id".into())) else {
182 continue;
183 };
184 out.insert(rid, YamlValue::Mapping(m));
185 }
186 Ok(out)
187}
188
189pub fn yaml_dump_rule(rid: &str, body: &YamlValue) -> String {
192 let mut top = serde_yaml::Mapping::new();
193 top.insert(YamlValue::String("id".into()), YamlValue::String(rid.into()));
194 if let YamlValue::Mapping(m) = body {
195 for (k, v) in m {
196 top.insert(k.clone(), v.clone());
197 }
198 }
199 let wrapped = YamlValue::Sequence(vec![YamlValue::Mapping(top)]);
200 serde_yaml::to_string(&wrapped).unwrap_or_default()
201}
202
203pub fn diff_rulesets(
207 before: &BTreeMap<String, YamlValue>,
208 after: &BTreeMap<String, YamlValue>,
209) -> BTreeMap<String, RuleDelta> {
210 use similar::{ChangeTag, TextDiff};
211
212 let mut all_ids: std::collections::BTreeSet<&String> = before.keys().collect();
213 all_ids.extend(after.keys());
214
215 let mut deltas = BTreeMap::new();
216 for rid in all_ids {
217 let in_before = before.contains_key(rid);
218 let in_after = after.contains_key(rid);
219 let (status, yaml_diff): (&str, String) = match (in_before, in_after) {
220 (true, false) => {
221 let dumped = yaml_dump_rule(rid, &before[rid]);
222 let diff = dumped
223 .lines()
224 .map(|l| format!("- {}", l))
225 .collect::<Vec<_>>()
226 .join("\n");
227 ("removed", diff)
228 }
229 (false, true) => {
230 let dumped = yaml_dump_rule(rid, &after[rid]);
231 let diff = dumped
232 .lines()
233 .map(|l| format!("+ {}", l))
234 .collect::<Vec<_>>()
235 .join("\n");
236 ("added", diff)
237 }
238 (true, true) if before[rid] == after[rid] => ("unchanged", String::new()),
239 _ => {
240 let b_yaml = yaml_dump_rule(rid, &before[rid]);
241 let a_yaml = yaml_dump_rule(rid, &after[rid]);
242 let diff = TextDiff::from_lines(&b_yaml, &a_yaml);
243 let mut out = String::new();
244 out.push_str(&format!("--- {}.before\n", rid));
245 out.push_str(&format!("+++ {}.after\n", rid));
246 for change in diff.iter_all_changes() {
247 let sign = match change.tag() {
248 ChangeTag::Delete => "-",
249 ChangeTag::Insert => "+",
250 ChangeTag::Equal => " ",
251 };
252 out.push_str(sign);
253 out.push_str(change.value());
254 }
255 ("modified", out)
256 }
257 };
258 deltas.insert(
259 rid.clone(),
260 RuleDelta {
261 rule_id: rid.clone(),
262 status: status.to_string(),
263 yaml_diff,
264 fires_before: 0,
265 fires_after: 0,
266 flipped_lines_caused: Vec::new(),
267 },
268 );
269 }
270 deltas
271}
272
273pub fn populate_behavior(
283 deltas: &mut BTreeMap<String, RuleDelta>,
284 before: &[DecisionLine],
285 after: &[DecisionLine],
286) -> FlipCounter {
287 if before.len() != after.len() {
288 eprintln!(
289 "warn: decision counts differ ({} vs {}); pairing by index",
290 before.len(),
291 after.len()
292 );
293 }
294 let n = before.len().min(after.len());
295 let mut flips: FlipCounter = BTreeMap::new();
296 for i in 0..n {
297 let b = &before[i];
298 let a = &after[i];
299 for rid in &b.matched_rules {
300 if let Some(d) = deltas.get_mut(rid) {
301 d.fires_before += 1;
302 }
303 }
304 for rid in &a.matched_rules {
305 if let Some(d) = deltas.get_mut(rid) {
306 d.fires_after += 1;
307 }
308 }
309 if b.decision != a.decision {
310 *flips
311 .entry((b.decision.clone(), a.decision.clone()))
312 .or_insert(0) += 1;
313 for rid in &a.matched_rules {
317 if let Some(d) = deltas.get_mut(rid) {
318 if matches!(d.status.as_str(), "added" | "modified") {
319 d.flipped_lines_caused.push((
320 b.decision.clone(),
321 a.decision.clone(),
322 b.input.clone(),
323 ));
324 }
325 }
326 }
327 for rid in &b.matched_rules {
328 if let Some(d) = deltas.get_mut(rid) {
329 if d.status == "removed" {
330 d.flipped_lines_caused.push((
331 b.decision.clone(),
332 a.decision.clone(),
333 b.input.clone(),
334 ));
335 }
336 }
337 }
338 }
339 }
340 flips
341}
342
343pub async fn run_diff_mode(opts: DiffOptions) -> anyhow::Result<i32> {
347 let before_yaml = load_ruleset_yaml(&opts.rules_before)?;
348 let after_yaml = load_ruleset_yaml(&opts.rules_after)?;
349
350 let corpus_bytes = read_corpus(opts.corpus.as_deref())?;
351 if corpus_bytes.trim().is_empty() {
352 anyhow::bail!("corpus is empty");
353 }
354 let corpus_line_count = corpus_bytes
360 .lines()
361 .filter(|l| {
362 let t = l.trim();
363 !t.is_empty() && !t.starts_with('#') && !t.starts_with("//")
364 })
365 .count();
366
367 let eval_opts = EvalOptions {
368 workspace: opts.workspace.clone(),
369 };
370
371 let before_decisions = evaluate_corpus(&opts.rules_before, &corpus_bytes, &eval_opts)?;
372 let after_decisions = evaluate_corpus(&opts.rules_after, &corpus_bytes, &eval_opts)?;
373
374 let mut decision_before: BTreeMap<String, usize> = BTreeMap::new();
375 for d in DECISIONS {
376 decision_before.insert(d.into(), 0);
377 }
378 for d in &before_decisions {
379 *decision_before.entry(d.decision.clone()).or_insert(0) += 1;
380 }
381 let mut decision_after: BTreeMap<String, usize> = BTreeMap::new();
382 for d in DECISIONS {
383 decision_after.insert(d.into(), 0);
384 }
385 for d in &after_decisions {
386 *decision_after.entry(d.decision.clone()).or_insert(0) += 1;
387 }
388
389 let mut deltas = diff_rulesets(&before_yaml, &after_yaml);
390 let flips = populate_behavior(&mut deltas, &before_decisions, &after_decisions);
391
392 let before_label = opts.rules_before.display().to_string();
393 let after_label = opts.rules_after.display().to_string();
394
395 let out = match opts.format {
396 OutputFormat::Text => render::render_text(
397 &before_label,
398 &after_label,
399 corpus_line_count,
400 &decision_before,
401 &decision_after,
402 &deltas,
403 &flips,
404 opts.max_samples,
405 ),
406 OutputFormat::Markdown => render::render_markdown(
407 &before_label,
408 &after_label,
409 corpus_line_count,
410 &decision_before,
411 &decision_after,
412 &deltas,
413 &flips,
414 opts.max_samples,
415 ),
416 OutputFormat::Json => render::render_json(
417 &before_label,
418 &after_label,
419 corpus_line_count,
420 &decision_before,
421 &decision_after,
422 &deltas,
423 &flips,
424 ),
425 };
426 print!("{}", out);
427 if !out.ends_with('\n') {
428 println!();
429 }
430
431 let total_flipped: usize = flips.values().sum();
435 if let Some(threshold) = opts.fail_if_allows_loosened {
436 if flips_to_allow(&flips) > threshold {
437 return Ok(1);
438 }
439 }
440 if opts.fail_if_loosened && loosening_count(&flips) > 0 {
441 return Ok(1);
442 }
443 if opts.fail_if_flipped && total_flipped > 0 {
444 return Ok(1);
445 }
446 Ok(0)
447}
448
449fn read_corpus(path: Option<&Path>) -> anyhow::Result<String> {
454 use std::io::Read;
455 if let Some(p) = path {
456 return std::fs::read_to_string(p)
457 .with_context(|| format!("reading corpus from {}", p.display()));
458 }
459 if atty_stdin() {
460 anyhow::bail!(
461 "no corpus on stdin and no --corpus PATH given.\n\
462 hint: aperion-shield --diff --corpus tests/corpus/golden.jsonl \
463 --rules-before X --rules-after Y"
464 );
465 }
466 let mut buf = String::new();
467 std::io::stdin().read_to_string(&mut buf)?;
468 Ok(buf)
469}
470
471fn atty_stdin() -> bool {
475 #[cfg(unix)]
476 {
477 unsafe { libc_isatty(0) }
479 }
480 #[cfg(not(unix))]
481 {
482 true
483 }
484}
485
486#[cfg(unix)]
487unsafe fn libc_isatty(fd: i32) -> bool {
488 extern "C" {
489 fn isatty(fd: i32) -> i32;
490 }
491 isatty(fd) == 1
492}