Skip to main content

aperion_shield/diff/
mod.rs

1//! Pre-merge behavior-diff explainer for shieldset changes.
2//!
3//! This module implements `aperion-shield --diff`, the native Rust
4//! port of `scripts/shield-diff.py` (the Python prototype shipped
5//! alongside `docs/shieldset-as-code.md`). Both produce a
6//! source-compatible JSON output schema, so CI wired up against the
7//! Python prototype keeps working unchanged when the flag flips to
8//! `aperion-shield --diff`.
9//!
10//! ## Why this exists
11//!
12//! Shieldset changes are policy changes -- they need PR review like
13//! code, but `diff shieldset.before.yaml shieldset.after.yaml` only
14//! tells you the YAML changed. It does not tell you which calls in
15//! the real corpus will now flip from `allow` to `block`, or worse,
16//! from `block` to `allow`. That is what this mode is for.
17//!
18//! ## Pipeline
19//!
20//! ```text
21//!   shieldset.before.yaml ─┐
22//!   shieldset.after.yaml  ─┤── load Engine x2  ── evaluate corpus x2
23//!   corpus.jsonl          ─┘                       │
24//!                                                  v
25//!                                          DecisionLine sets x2
26//!                                                  │
27//!                                                  v
28//!                       diff rulesets (added/removed/modified/unchanged)
29//!                                                  │
30//!                                                  v
31//!                       pair decisions by index, attribute every flip to
32//!                       the changed rule(s) that fired under the after-state
33//!                                                  │
34//!                                                  v
35//!                       render text / markdown / json
36//! ```
37//!
38//! ## In-process, not subprocess
39//!
40//! The Python prototype shells out to `aperion-shield --check` twice.
41//! This native port skips the subprocess: both runs use the same
42//! `Engine::evaluate` path the proxy uses. It is materially faster
43//! on big corpora (no JSON re-encode + re-decode trip per line) and
44//! removes the runtime PATH dependency the Python prototype carried.
45
46pub mod evaluate;
47pub mod render;
48
49use std::collections::BTreeMap;
50use std::path::{Path, PathBuf};
51
52use anyhow::{anyhow, Context};
53use serde::Serialize;
54use serde_yaml::Value as YamlValue;
55
56pub use evaluate::{evaluate_corpus, DecisionLine, EvalOptions};
57
58/// CLI-level options for `aperion-shield --diff`. Mirrors the Python
59/// prototype 1:1 so the `--format json` output schema stays
60/// source-compatible.
61#[derive(Debug, Clone)]
62pub struct DiffOptions {
63    pub rules_before: PathBuf,
64    pub rules_after: PathBuf,
65    /// Corpus path; `None` means read JSON-Lines from stdin.
66    pub corpus: Option<PathBuf>,
67    pub workspace: Option<PathBuf>,
68    pub format: OutputFormat,
69    pub max_samples: usize,
70    pub fail_if_flipped: bool,
71    pub fail_if_loosened: bool,
72    pub fail_if_allows_loosened: Option<usize>,
73}
74
75/// Output format for the diff report.
76#[derive(Debug, Clone, Copy, PartialEq, Eq)]
77pub enum OutputFormat {
78    Text,
79    Markdown,
80    Json,
81}
82
83impl OutputFormat {
84    pub fn parse(s: &str) -> anyhow::Result<Self> {
85        match s {
86            "text" => Ok(OutputFormat::Text),
87            "markdown" | "md" => Ok(OutputFormat::Markdown),
88            "json" => Ok(OutputFormat::Json),
89            other => Err(anyhow!(
90                "unknown --format '{}': must be one of text|markdown|json",
91                other
92            )),
93        }
94    }
95}
96
97/// Per-rule change: YAML-level (textual) + behavioral (corpus-level).
98/// Mirrors `shield-diff.py::RuleDelta`. Serialised in `--format json`
99/// output -- keep the field names stable.
100#[derive(Debug, Clone, Serialize)]
101pub struct RuleDelta {
102    pub rule_id: String,
103    /// "added" | "removed" | "modified" | "unchanged"
104    pub status: String,
105    #[serde(skip_serializing_if = "String::is_empty")]
106    pub yaml_diff: String,
107    pub fires_before: usize,
108    pub fires_after: usize,
109    /// Each entry: (decision_before, decision_after, input_obj).
110    #[serde(skip_serializing)]
111    pub flipped_lines_caused: Vec<(String, String, serde_json::Value)>,
112}
113
114/// Aggregate counter: `(decision_before, decision_after) -> count`.
115/// Lexically ordered so render order is deterministic.
116pub type FlipCounter = BTreeMap<(String, String), usize>;
117
118pub const DECISIONS: [&str; 4] = ["allow", "warn", "approval", "block"];
119
120/// Numeric ordering of decision severities. Used by [`loosening_count`]
121/// to decide whether a flip moved toward a more permissive decision.
122fn severity_rank(d: &str) -> u8 {
123    match d {
124        "allow" => 0,
125        "warn" => 1,
126        "approval" | "identity_verification" => 2,
127        "block" => 3,
128        _ => 99,
129    }
130}
131
132/// How many flipped lines moved toward a more permissive decision.
133/// `identity_verification` counts at the same severity as `approval`
134/// because both gate the call before it runs upstream.
135pub fn loosening_count(flips: &FlipCounter) -> usize {
136    flips
137        .iter()
138        .filter(|((b, a), _)| severity_rank(a) < severity_rank(b))
139        .map(|(_, c)| *c)
140        .sum()
141}
142
143/// How many flipped lines ended at `allow`. Used by
144/// `--fail-if-allows-loosened`.
145pub fn flips_to_allow(flips: &FlipCounter) -> usize {
146    flips
147        .iter()
148        .filter(|((_, a), _)| a == "allow")
149        .map(|(_, c)| *c)
150        .sum()
151}
152
153/// Parse a shieldset YAML file into a `BTreeMap<rule_id, rule_body>`
154/// where `rule_body` is the rule's YAML node MINUS its `id` field.
155/// Used for diffing rules textually. Tolerates both the wrapped
156/// (`shieldset:\n  rules:`) and bare (`rules:`) forms, matching the
157/// Python prototype.
158pub fn load_ruleset_yaml(
159    path: &Path,
160) -> anyhow::Result<BTreeMap<String, YamlValue>> {
161    let raw = std::fs::read_to_string(path)
162        .with_context(|| format!("reading shieldset YAML from {}", path.display()))?;
163    let root: YamlValue = serde_yaml::from_str(&raw)
164        .with_context(|| format!("parsing YAML at {}", path.display()))?;
165    let YamlValue::Mapping(top) = &root else {
166        anyhow::bail!("{} did not parse as a YAML mapping", path.display());
167    };
168    let shieldset = top
169        .get(YamlValue::String("shieldset".into()))
170        .unwrap_or(&root);
171    let rules = match shieldset {
172        YamlValue::Mapping(m) => m.get(YamlValue::String("rules".into())).cloned(),
173        _ => None,
174    };
175    let Some(YamlValue::Sequence(rules)) = rules else {
176        return Ok(BTreeMap::new());
177    };
178    let mut out: BTreeMap<String, YamlValue> = BTreeMap::new();
179    for r in rules {
180        let YamlValue::Mapping(mut m) = r else { continue };
181        let Some(YamlValue::String(rid)) = m.remove(YamlValue::String("id".into())) else {
182            continue;
183        };
184        out.insert(rid, YamlValue::Mapping(m));
185    }
186    Ok(out)
187}
188
189/// Dump one rule (id + body) back to YAML for textual diffing.
190/// Always emits `id` first to keep the diff stable across runs.
191pub fn yaml_dump_rule(rid: &str, body: &YamlValue) -> String {
192    let mut top = serde_yaml::Mapping::new();
193    top.insert(YamlValue::String("id".into()), YamlValue::String(rid.into()));
194    if let YamlValue::Mapping(m) = body {
195        for (k, v) in m {
196            top.insert(k.clone(), v.clone());
197        }
198    }
199    let wrapped = YamlValue::Sequence(vec![YamlValue::Mapping(top)]);
200    serde_yaml::to_string(&wrapped).unwrap_or_default()
201}
202
203/// Classify every rule that appears in either ruleset. The YAML diff
204/// is rendered eagerly so we don't pay the cost twice if the renderer
205/// is asked to embed it.
206pub fn diff_rulesets(
207    before: &BTreeMap<String, YamlValue>,
208    after: &BTreeMap<String, YamlValue>,
209) -> BTreeMap<String, RuleDelta> {
210    use similar::{ChangeTag, TextDiff};
211
212    let mut all_ids: std::collections::BTreeSet<&String> = before.keys().collect();
213    all_ids.extend(after.keys());
214
215    let mut deltas = BTreeMap::new();
216    for rid in all_ids {
217        let in_before = before.contains_key(rid);
218        let in_after = after.contains_key(rid);
219        let (status, yaml_diff): (&str, String) = match (in_before, in_after) {
220            (true, false) => {
221                let dumped = yaml_dump_rule(rid, &before[rid]);
222                let diff = dumped
223                    .lines()
224                    .map(|l| format!("- {}", l))
225                    .collect::<Vec<_>>()
226                    .join("\n");
227                ("removed", diff)
228            }
229            (false, true) => {
230                let dumped = yaml_dump_rule(rid, &after[rid]);
231                let diff = dumped
232                    .lines()
233                    .map(|l| format!("+ {}", l))
234                    .collect::<Vec<_>>()
235                    .join("\n");
236                ("added", diff)
237            }
238            (true, true) if before[rid] == after[rid] => ("unchanged", String::new()),
239            _ => {
240                let b_yaml = yaml_dump_rule(rid, &before[rid]);
241                let a_yaml = yaml_dump_rule(rid, &after[rid]);
242                let diff = TextDiff::from_lines(&b_yaml, &a_yaml);
243                let mut out = String::new();
244                out.push_str(&format!("--- {}.before\n", rid));
245                out.push_str(&format!("+++ {}.after\n", rid));
246                for change in diff.iter_all_changes() {
247                    let sign = match change.tag() {
248                        ChangeTag::Delete => "-",
249                        ChangeTag::Insert => "+",
250                        ChangeTag::Equal => " ",
251                    };
252                    out.push_str(sign);
253                    out.push_str(change.value());
254                }
255                ("modified", out)
256            }
257        };
258        deltas.insert(
259            rid.clone(),
260            RuleDelta {
261                rule_id: rid.clone(),
262                status: status.to_string(),
263                yaml_diff,
264                fires_before: 0,
265                fires_after: 0,
266                flipped_lines_caused: Vec::new(),
267            },
268        );
269    }
270    deltas
271}
272
273/// Walk paired before/after decision lists, fill in `fires_before` /
274/// `fires_after`, build the global flip counter, attribute each flip
275/// to the rule(s) that materially changed under the after-state, and
276/// return the global flip counter.
277///
278/// Pairing is by index, mirroring the Python prototype. If the two
279/// runs produced different counts (which can happen if a shieldset
280/// change causes evaluation errors on some lines), we pair as many
281/// as we have and emit a stderr warning.
282pub fn populate_behavior(
283    deltas: &mut BTreeMap<String, RuleDelta>,
284    before: &[DecisionLine],
285    after: &[DecisionLine],
286) -> FlipCounter {
287    if before.len() != after.len() {
288        eprintln!(
289            "warn: decision counts differ ({} vs {}); pairing by index",
290            before.len(),
291            after.len()
292        );
293    }
294    let n = before.len().min(after.len());
295    let mut flips: FlipCounter = BTreeMap::new();
296    for i in 0..n {
297        let b = &before[i];
298        let a = &after[i];
299        for rid in &b.matched_rules {
300            if let Some(d) = deltas.get_mut(rid) {
301                d.fires_before += 1;
302            }
303        }
304        for rid in &a.matched_rules {
305            if let Some(d) = deltas.get_mut(rid) {
306                d.fires_after += 1;
307            }
308        }
309        if b.decision != a.decision {
310            *flips
311                .entry((b.decision.clone(), a.decision.clone()))
312                .or_insert(0) += 1;
313            // Attribute the flip to whichever changed rule(s) actually
314            // fired under the new state. For removals we attribute to
315            // the rule that fired under the OLD state and is now gone.
316            for rid in &a.matched_rules {
317                if let Some(d) = deltas.get_mut(rid) {
318                    if matches!(d.status.as_str(), "added" | "modified") {
319                        d.flipped_lines_caused.push((
320                            b.decision.clone(),
321                            a.decision.clone(),
322                            b.input.clone(),
323                        ));
324                    }
325                }
326            }
327            for rid in &b.matched_rules {
328                if let Some(d) = deltas.get_mut(rid) {
329                    if d.status == "removed" {
330                        d.flipped_lines_caused.push((
331                            b.decision.clone(),
332                            a.decision.clone(),
333                            b.input.clone(),
334                        ));
335                    }
336                }
337            }
338        }
339    }
340    flips
341}
342
343/// Top-level entry point for `aperion-shield --diff`. Returns the
344/// shell exit code: 0 for success, 1 for a policy-gate trip (e.g.
345/// `--fail-if-flipped`), 2 for an internal / I/O error.
346pub async fn run_diff_mode(opts: DiffOptions) -> anyhow::Result<i32> {
347    let before_yaml = load_ruleset_yaml(&opts.rules_before)?;
348    let after_yaml = load_ruleset_yaml(&opts.rules_after)?;
349
350    let corpus_bytes = read_corpus(opts.corpus.as_deref())?;
351    if corpus_bytes.trim().is_empty() {
352        anyhow::bail!("corpus is empty");
353    }
354    // Count non-comment, non-blank lines for the "corpus: N commands"
355    // header. Matches the Python prototype's line-count semantics AND
356    // the `evaluate_corpus` skip rules (which treat both `#` and `//`
357    // as comment markers) so the line count agrees with the number
358    // of decisions reported.
359    let corpus_line_count = corpus_bytes
360        .lines()
361        .filter(|l| {
362            let t = l.trim();
363            !t.is_empty() && !t.starts_with('#') && !t.starts_with("//")
364        })
365        .count();
366
367    let eval_opts = EvalOptions {
368        workspace: opts.workspace.clone(),
369    };
370
371    let before_decisions = evaluate_corpus(&opts.rules_before, &corpus_bytes, &eval_opts)?;
372    let after_decisions = evaluate_corpus(&opts.rules_after, &corpus_bytes, &eval_opts)?;
373
374    let mut decision_before: BTreeMap<String, usize> = BTreeMap::new();
375    for d in DECISIONS {
376        decision_before.insert(d.into(), 0);
377    }
378    for d in &before_decisions {
379        *decision_before.entry(d.decision.clone()).or_insert(0) += 1;
380    }
381    let mut decision_after: BTreeMap<String, usize> = BTreeMap::new();
382    for d in DECISIONS {
383        decision_after.insert(d.into(), 0);
384    }
385    for d in &after_decisions {
386        *decision_after.entry(d.decision.clone()).or_insert(0) += 1;
387    }
388
389    let mut deltas = diff_rulesets(&before_yaml, &after_yaml);
390    let flips = populate_behavior(&mut deltas, &before_decisions, &after_decisions);
391
392    let before_label = opts.rules_before.display().to_string();
393    let after_label = opts.rules_after.display().to_string();
394
395    let out = match opts.format {
396        OutputFormat::Text => render::render_text(
397            &before_label,
398            &after_label,
399            corpus_line_count,
400            &decision_before,
401            &decision_after,
402            &deltas,
403            &flips,
404            opts.max_samples,
405        ),
406        OutputFormat::Markdown => render::render_markdown(
407            &before_label,
408            &after_label,
409            corpus_line_count,
410            &decision_before,
411            &decision_after,
412            &deltas,
413            &flips,
414            opts.max_samples,
415        ),
416        OutputFormat::Json => render::render_json(
417            &before_label,
418            &after_label,
419            corpus_line_count,
420            &decision_before,
421            &decision_after,
422            &deltas,
423            &flips,
424        ),
425    };
426    print!("{}", out);
427    if !out.ends_with('\n') {
428        println!();
429    }
430
431    // Exit-code policy: order matches the Python prototype so a
432    // shell wrapper that does `aperion-shield --diff || exit $?` keeps
433    // its same semantics across the Python/Rust swap.
434    let total_flipped: usize = flips.values().sum();
435    if let Some(threshold) = opts.fail_if_allows_loosened {
436        if flips_to_allow(&flips) > threshold {
437            return Ok(1);
438        }
439    }
440    if opts.fail_if_loosened && loosening_count(&flips) > 0 {
441        return Ok(1);
442    }
443    if opts.fail_if_flipped && total_flipped > 0 {
444        return Ok(1);
445    }
446    Ok(0)
447}
448
449/// Read the JSON-Lines corpus from a file or stdin. The Python
450/// prototype refuses to read from a TTY; we keep that behaviour so
451/// `aperion-shield --diff` doesn't hang waiting for input when the
452/// user forgot `--corpus`.
453fn read_corpus(path: Option<&Path>) -> anyhow::Result<String> {
454    use std::io::Read;
455    if let Some(p) = path {
456        return std::fs::read_to_string(p)
457            .with_context(|| format!("reading corpus from {}", p.display()));
458    }
459    if atty_stdin() {
460        anyhow::bail!(
461            "no corpus on stdin and no --corpus PATH given.\n\
462             hint: aperion-shield --diff --corpus tests/corpus/golden.jsonl \
463             --rules-before X --rules-after Y"
464        );
465    }
466    let mut buf = String::new();
467    std::io::stdin().read_to_string(&mut buf)?;
468    Ok(buf)
469}
470
471/// Minimal TTY detection that avoids pulling in the `atty` /
472/// `is-terminal` crate just for one call site. Falls back to "no
473/// pipe" (i.e. treat as TTY) on any error.
474fn atty_stdin() -> bool {
475    #[cfg(unix)]
476    {
477        // SAFETY: isatty is a thread-safe libc call.
478        unsafe { libc_isatty(0) }
479    }
480    #[cfg(not(unix))]
481    {
482        true
483    }
484}
485
486#[cfg(unix)]
487unsafe fn libc_isatty(fd: i32) -> bool {
488    extern "C" {
489        fn isatty(fd: i32) -> i32;
490    }
491    isatty(fd) == 1
492}