Skip to main content

auths_cli/commands/
policy.rs

1//! Policy governance commands for Auths.
2//!
3//! Commands for linting, compiling, testing, and comparing policies.
4
5use crate::ux::format::{JsonResponse, Output, is_json_mode};
6use anyhow::{Context, Result, anyhow};
7use auths_policy::{
8    CompileError, CompiledExpr, EvalContext, Expr, Outcome, PolicyLimits,
9    compile_from_json_with_limits,
10};
11use auths_sdk::workflows::policy_diff::{compute_policy_diff, overall_risk_score};
12use chrono::{DateTime, Utc};
13use clap::{Parser, Subcommand};
14use serde::{Deserialize, Serialize};
15use std::fs;
16use std::path::PathBuf;
17
18/// Manage authorization policies.
19#[derive(Parser, Debug, Clone)]
20#[command(name = "policy", about = "Manage authorization policies")]
21pub struct PolicyCommand {
22    #[command(subcommand)]
23    pub command: PolicySubcommand,
24}
25
26#[derive(Subcommand, Debug, Clone)]
27pub enum PolicySubcommand {
28    /// Validate policy JSON syntax without full compilation.
29    Lint(LintCommand),
30
31    /// Compile a policy file with full validation.
32    Compile(CompileCommand),
33
34    /// Evaluate a policy against a context and show the decision.
35    Explain(ExplainCommand),
36
37    /// Run a policy against a test suite.
38    Test(TestCommand),
39
40    /// Compare two policies and show semantic differences.
41    Diff(DiffCommand),
42}
43
44/// Validate policy JSON syntax.
45#[derive(Parser, Debug, Clone)]
46pub struct LintCommand {
47    /// Path to the policy file (JSON).
48    pub file: PathBuf,
49}
50
51/// Compile a policy with full validation.
52#[derive(Parser, Debug, Clone)]
53pub struct CompileCommand {
54    /// Path to the policy file (JSON).
55    pub file: PathBuf,
56}
57
58/// Evaluate a policy against a context.
59#[derive(Parser, Debug, Clone)]
60pub struct ExplainCommand {
61    /// Path to the policy file (JSON).
62    pub file: PathBuf,
63
64    /// Path to the context file (JSON).
65    #[clap(long, short = 'c')]
66    pub context: PathBuf,
67}
68
69/// Run a policy against a test suite.
70#[derive(Parser, Debug, Clone)]
71pub struct TestCommand {
72    /// Path to the policy file (JSON).
73    pub file: PathBuf,
74
75    /// Path to the test suite file (JSON).
76    #[clap(long, short = 't')]
77    pub tests: PathBuf,
78}
79
80/// Compare two policies.
81#[derive(Parser, Debug, Clone)]
82pub struct DiffCommand {
83    /// Path to the old policy file (JSON).
84    pub old: PathBuf,
85
86    /// Path to the new policy file (JSON).
87    pub new: PathBuf,
88}
89
90// ── JSON Output Types ───────────────────────────────────────────────────
91
92#[derive(Debug, Serialize)]
93struct LintData {
94    bytes: usize,
95    byte_limit: usize,
96}
97
98#[derive(Debug, Serialize)]
99struct CompileData {
100    #[serde(skip_serializing_if = "Option::is_none")]
101    nodes: Option<u32>,
102    #[serde(skip_serializing_if = "Option::is_none")]
103    depth: Option<u32>,
104    #[serde(skip_serializing_if = "Option::is_none")]
105    hash: Option<String>,
106    #[serde(skip_serializing_if = "Vec::is_empty")]
107    errors: Vec<String>,
108}
109
110#[derive(Debug, Serialize)]
111struct ExplainOutput {
112    decision: String,
113    reason_code: String,
114    message: String,
115    policy_hash: String,
116}
117
118#[derive(Debug, Serialize)]
119struct TestOutput {
120    passed: usize,
121    failed: usize,
122    total: usize,
123    results: Vec<TestResult>,
124}
125
126#[derive(Debug, Serialize)]
127struct TestResult {
128    name: String,
129    passed: bool,
130    expected: String,
131    actual: String,
132    #[serde(skip_serializing_if = "Option::is_none")]
133    message: Option<String>,
134}
135
136#[derive(Debug, Serialize)]
137struct DiffOutput {
138    changes: Vec<DiffChange>,
139    risk_score: String,
140}
141
142#[derive(Debug, Serialize)]
143struct DiffChange {
144    kind: String,
145    description: String,
146    risk: String,
147}
148
149// ── Test Suite Types ────────────────────────────────────────────────────
150
151#[derive(Debug, Deserialize)]
152struct TestCase {
153    name: String,
154    context: TestContext,
155    expect: String,
156}
157
158#[derive(Debug, Deserialize)]
159struct TestContext {
160    issuer: String,
161    subject: String,
162    #[serde(default)]
163    revoked: bool,
164    #[serde(default)]
165    capabilities: Vec<String>,
166    #[serde(default)]
167    role: Option<String>,
168    #[serde(default)]
169    expires_at: Option<DateTime<Utc>>,
170    #[serde(default)]
171    timestamp: Option<DateTime<Utc>>,
172    #[serde(default)]
173    chain_depth: u32,
174    #[serde(default)]
175    repo: Option<String>,
176    #[serde(default)]
177    git_ref: Option<String>,
178    #[serde(default)]
179    paths: Vec<String>,
180    #[serde(default)]
181    environment: Option<String>,
182}
183
184// ── Handler ─────────────────────────────────────────────────────────────
185
186pub fn handle_policy(cmd: PolicyCommand) -> Result<()> {
187    match cmd.command {
188        PolicySubcommand::Lint(lint) => handle_lint(lint),
189        PolicySubcommand::Compile(compile) => handle_compile(compile),
190        PolicySubcommand::Explain(explain) => handle_explain(explain),
191        PolicySubcommand::Test(test) => handle_test(test),
192        PolicySubcommand::Diff(diff) => handle_diff(diff),
193    }
194}
195
196fn handle_lint(cmd: LintCommand) -> Result<()> {
197    let out = Output::new();
198    let limits = PolicyLimits::default();
199
200    // Read the file
201    let content =
202        fs::read(&cmd.file).with_context(|| format!("failed to read {}", cmd.file.display()))?;
203
204    let bytes = content.len();
205
206    // Check size limit
207    if bytes > limits.max_json_bytes {
208        if is_json_mode() {
209            JsonResponse::<()>::error(
210                "policy lint",
211                format!(
212                    "file exceeds size limit: {} > {}",
213                    bytes, limits.max_json_bytes
214                ),
215            )
216            .print()?;
217        } else {
218            out.println(&format!(
219                "{} File exceeds size limit: {} bytes (limit: {})",
220                out.error("x"),
221                bytes,
222                limits.max_json_bytes
223            ));
224        }
225        return Ok(());
226    }
227
228    // Parse JSON
229    match serde_json::from_slice::<Expr>(&content) {
230        Ok(_expr) => {
231            if is_json_mode() {
232                JsonResponse::success(
233                    "policy lint",
234                    LintData {
235                        bytes,
236                        byte_limit: limits.max_json_bytes,
237                    },
238                )
239                .print()?;
240            } else {
241                out.println(&format!("{} Valid JSON", out.success("ok")));
242                out.println(&format!("{} All ops recognized", out.success("ok")));
243                out.println(&format!(
244                    "{} {} bytes (limit: {})",
245                    out.success("ok"),
246                    bytes,
247                    limits.max_json_bytes
248                ));
249            }
250        }
251        Err(e) => {
252            if is_json_mode() {
253                JsonResponse::<()>::error("policy lint", e.to_string()).print()?;
254            } else {
255                out.println(&format!("{} Invalid JSON: {}", out.error("x"), e));
256            }
257        }
258    }
259
260    Ok(())
261}
262
263fn handle_compile(cmd: CompileCommand) -> Result<()> {
264    let out = Output::new();
265    let limits = PolicyLimits::default();
266
267    let content =
268        fs::read(&cmd.file).with_context(|| format!("failed to read {}", cmd.file.display()))?;
269
270    match compile_from_json_with_limits(&content, &limits) {
271        Ok(policy) => {
272            let stats = compute_policy_stats(policy.expr());
273            let hash = hex::encode(policy.source_hash());
274
275            if is_json_mode() {
276                JsonResponse::success(
277                    "policy compile",
278                    CompileData {
279                        nodes: Some(stats.nodes),
280                        depth: Some(stats.depth),
281                        hash: Some(hash),
282                        errors: vec![],
283                    },
284                )
285                .print()?;
286            } else {
287                out.println(&format!("{} Compiled successfully", out.success("ok")));
288                out.println(&format!(
289                    "  Nodes: {} (limit: {})",
290                    stats.nodes, limits.max_total_nodes
291                ));
292                out.println(&format!(
293                    "  Depth: {} (limit: {})",
294                    stats.depth, limits.max_depth
295                ));
296                out.println(&format!("  Hash:  {}", hash));
297            }
298        }
299        Err(errors) => {
300            let error_strs: Vec<String> = errors.iter().map(format_compile_error).collect();
301
302            if is_json_mode() {
303                JsonResponse {
304                    success: false,
305                    command: "policy compile".to_string(),
306                    data: Some(CompileData {
307                        nodes: None,
308                        depth: None,
309                        hash: None,
310                        errors: error_strs,
311                    }),
312                    error: None,
313                }
314                .print()?;
315            } else {
316                out.println(&format!(
317                    "{} Compilation failed ({} errors):",
318                    out.error("x"),
319                    errors.len()
320                ));
321                for error in &error_strs {
322                    out.println(&format!("  {}", error));
323                }
324            }
325        }
326    }
327
328    Ok(())
329}
330
331fn handle_explain(cmd: ExplainCommand) -> Result<()> {
332    let out = Output::new();
333    let limits = PolicyLimits::default();
334
335    // Load and compile policy
336    let policy_content = fs::read(&cmd.file)
337        .with_context(|| format!("failed to read policy: {}", cmd.file.display()))?;
338
339    let policy = compile_from_json_with_limits(&policy_content, &limits).map_err(|errors| {
340        anyhow!(
341            "policy compilation failed: {}",
342            errors
343                .iter()
344                .map(format_compile_error)
345                .collect::<Vec<_>>()
346                .join("; ")
347        )
348    })?;
349
350    // Load context
351    let ctx_content = fs::read(&cmd.context)
352        .with_context(|| format!("failed to read context: {}", cmd.context.display()))?;
353
354    let test_ctx: TestContext =
355        serde_json::from_slice(&ctx_content).with_context(|| "failed to parse context JSON")?;
356
357    let eval_ctx = build_eval_context(&test_ctx)?;
358
359    // Evaluate
360    let decision = auths_policy::evaluate3(&policy, &eval_ctx);
361    let hash = hex::encode(policy.source_hash());
362
363    if is_json_mode() {
364        JsonResponse::success(
365            "policy explain",
366            ExplainOutput {
367                decision: format!("{:?}", decision.outcome),
368                reason_code: format!("{:?}", decision.reason),
369                message: decision.message.clone(),
370                policy_hash: hash,
371            },
372        )
373        .print()?;
374    } else {
375        let decision_str = match decision.outcome {
376            Outcome::Allow => out.success("ALLOW"),
377            Outcome::Deny => out.error("DENY"),
378            Outcome::Indeterminate => out.warn("INDETERMINATE"),
379        };
380        out.println(&format!("Decision: {}", decision_str));
381        out.println(&format!("  Reason: {:?}", decision.reason));
382        out.println(&format!("  Message: {}", decision.message));
383        out.println(&format!("Policy hash: {}", hash));
384    }
385
386    Ok(())
387}
388
389fn handle_test(cmd: TestCommand) -> Result<()> {
390    let out = Output::new();
391    let limits = PolicyLimits::default();
392
393    // Load and compile policy
394    let policy_content = fs::read(&cmd.file)
395        .with_context(|| format!("failed to read policy: {}", cmd.file.display()))?;
396
397    let policy = compile_from_json_with_limits(&policy_content, &limits).map_err(|errors| {
398        anyhow!(
399            "policy compilation failed: {}",
400            errors
401                .iter()
402                .map(format_compile_error)
403                .collect::<Vec<_>>()
404                .join("; ")
405        )
406    })?;
407
408    // Load test suite
409    let tests_content = fs::read(&cmd.tests)
410        .with_context(|| format!("failed to read tests: {}", cmd.tests.display()))?;
411
412    let test_cases: Vec<TestCase> = serde_json::from_slice(&tests_content)
413        .with_context(|| "failed to parse test suite JSON")?;
414
415    let mut results: Vec<TestResult> = Vec::new();
416    let mut passed = 0;
417    let mut failed = 0;
418
419    for test in test_cases {
420        let eval_ctx = match build_eval_context(&test.context) {
421            Ok(ctx) => ctx,
422            Err(e) => {
423                results.push(TestResult {
424                    name: test.name.clone(),
425                    passed: false,
426                    expected: test.expect.clone(),
427                    actual: "ERROR".into(),
428                    message: Some(e.to_string()),
429                });
430                failed += 1;
431                continue;
432            }
433        };
434
435        let decision = auths_policy::evaluate3(&policy, &eval_ctx);
436        let actual = format!("{:?}", decision.outcome);
437        let expected_normalized = normalize_outcome(&test.expect);
438        let test_passed = actual == expected_normalized;
439
440        if test_passed {
441            passed += 1;
442        } else {
443            failed += 1;
444        }
445
446        results.push(TestResult {
447            name: test.name,
448            passed: test_passed,
449            expected: expected_normalized,
450            actual,
451            message: if test_passed {
452                None
453            } else {
454                Some(decision.message.clone())
455            },
456        });
457    }
458
459    let total = passed + failed;
460
461    if is_json_mode() {
462        JsonResponse::success(
463            "policy test",
464            TestOutput {
465                passed,
466                failed,
467                total,
468                results,
469            },
470        )
471        .print()?;
472    } else {
473        for result in &results {
474            let status = if result.passed {
475                out.success("ok")
476            } else {
477                out.error("FAIL")
478            };
479            out.println(&format!(
480                "  {} {}: {} (expected {})",
481                status, result.name, result.actual, result.expected
482            ));
483            if let Some(msg) = &result.message {
484                out.println(&format!("      {}", out.dim(msg)));
485            }
486        }
487        out.println(&format!("{}/{} passed", passed, total));
488    }
489
490    if failed > 0 {
491        anyhow::bail!("{} test(s) failed", failed);
492    }
493
494    Ok(())
495}
496
497fn handle_diff(cmd: DiffCommand) -> Result<()> {
498    let out = Output::new();
499
500    // Parse both policy files (don't need full compilation for structural diff)
501    let old_content = fs::read(&cmd.old)
502        .with_context(|| format!("failed to read old policy: {}", cmd.old.display()))?;
503    let new_content = fs::read(&cmd.new)
504        .with_context(|| format!("failed to read new policy: {}", cmd.new.display()))?;
505
506    let old_expr: Expr =
507        serde_json::from_slice(&old_content).with_context(|| "failed to parse old policy JSON")?;
508    let new_expr: Expr =
509        serde_json::from_slice(&new_content).with_context(|| "failed to parse new policy JSON")?;
510
511    let changes = compute_policy_diff(&old_expr, &new_expr);
512    let risk_score = overall_risk_score(&changes);
513
514    if is_json_mode() {
515        JsonResponse::success(
516            "policy diff",
517            DiffOutput {
518                changes: changes
519                    .iter()
520                    .map(|c| DiffChange {
521                        kind: c.kind.clone(),
522                        description: c.description.clone(),
523                        risk: c.risk.clone(),
524                    })
525                    .collect(),
526                risk_score: risk_score.clone(),
527            },
528        )
529        .print()?;
530    } else if changes.is_empty() {
531        out.println("No changes detected");
532    } else {
533        out.println("Changes:");
534        for change in &changes {
535            let risk_marker = match change.risk.as_str() {
536                "HIGH" => out.error("HIGH RISK"),
537                "MEDIUM" => out.warn("MEDIUM"),
538                _ => out.dim("LOW"),
539            };
540            let kind_marker = match change.kind.as_str() {
541                "added" => "+",
542                "removed" => "-",
543                "changed" => "~",
544                _ => "?",
545            };
546            out.println(&format!(
547                "  {} {}: {} [{}]",
548                kind_marker, change.description, risk_marker, change.risk
549            ));
550        }
551        out.println("");
552        let risk_display = match risk_score.as_str() {
553            "HIGH" => out.error(&risk_score),
554            "MEDIUM" => out.warn(&risk_score),
555            _ => out.dim(&risk_score),
556        };
557        out.println(&format!("Risk score: {}", risk_display));
558    }
559
560    Ok(())
561}
562
563// ── Helpers ─────────────────────────────────────────────────────────────
564
565fn format_compile_error(error: &CompileError) -> String {
566    format!("at {}: {}", error.path, error.message)
567}
568
569struct PolicyStats {
570    nodes: u32,
571    depth: u32,
572}
573
574fn compute_policy_stats(expr: &CompiledExpr) -> PolicyStats {
575    fn count_nodes(expr: &CompiledExpr) -> u32 {
576        match expr {
577            CompiledExpr::True | CompiledExpr::False => 1,
578            CompiledExpr::And(children) | CompiledExpr::Or(children) => {
579                1 + children.iter().map(count_nodes).sum::<u32>()
580            }
581            CompiledExpr::Not(inner) => 1 + count_nodes(inner),
582            _ => 1,
583        }
584    }
585
586    fn compute_depth(expr: &CompiledExpr) -> u32 {
587        match expr {
588            CompiledExpr::True | CompiledExpr::False => 1,
589            CompiledExpr::And(children) | CompiledExpr::Or(children) => {
590                1 + children.iter().map(compute_depth).max().unwrap_or(0)
591            }
592            CompiledExpr::Not(inner) => 1 + compute_depth(inner),
593            _ => 1,
594        }
595    }
596
597    PolicyStats {
598        nodes: count_nodes(expr),
599        depth: compute_depth(expr),
600    }
601}
602
603fn build_eval_context(test: &TestContext) -> Result<EvalContext> {
604    let mut ctx = EvalContext::try_from_strings(Utc::now(), &test.issuer, &test.subject)
605        .map_err(|e| anyhow!("invalid DID: {}", e))?;
606
607    ctx = ctx.revoked(test.revoked);
608    ctx = ctx.chain_depth(test.chain_depth);
609
610    for cap in &test.capabilities {
611        let canonical = auths_policy::CanonicalCapability::parse(cap)
612            .map_err(|e| anyhow!("invalid capability '{}': {}", cap, e))?;
613        ctx = ctx.capability(canonical);
614    }
615
616    if let Some(role) = &test.role {
617        ctx = ctx.role(role.clone());
618    }
619
620    if let Some(exp) = test.expires_at {
621        ctx = ctx.expires_at(exp);
622    }
623
624    if let Some(ts) = test.timestamp {
625        ctx = ctx.timestamp(ts);
626    }
627
628    if let Some(repo) = &test.repo {
629        ctx = ctx.repo(repo.clone());
630    }
631
632    if let Some(git_ref) = &test.git_ref {
633        ctx = ctx.git_ref(git_ref.clone());
634    }
635
636    if !test.paths.is_empty() {
637        ctx = ctx.paths(test.paths.clone());
638    }
639
640    if let Some(env) = &test.environment {
641        ctx = ctx.environment(env.clone());
642    }
643
644    Ok(ctx)
645}
646
/// Map a user-written expectation (`"allow"`, `"DENY"`, ...) onto the `Debug` spelling of
/// `Outcome` so it can be compared with an evaluated decision.
///
/// Matching ignores case. Any other input is returned unchanged, and therefore never equals
/// a real outcome.
fn normalize_outcome(s: &str) -> String {
    const CANONICAL: [&str; 3] = ["Allow", "Deny", "Indeterminate"];
    CANONICAL
        .iter()
        .find(|name| name.eq_ignore_ascii_case(s))
        .map_or_else(|| s.to_string(), |name| (*name).to_string())
}
655
656use crate::commands::executable::ExecutableCommand;
657use crate::config::CliConfig;
658
659impl ExecutableCommand for PolicyCommand {
660    fn execute(&self, _ctx: &CliConfig) -> Result<()> {
661        handle_policy(self.clone())
662    }
663}
664
#[cfg(test)]
mod tests {
    use super::*;
    use auths_sdk::workflows::policy_diff::{
        PolicyChange, compute_policy_diff, overall_risk_score,
    };

    #[test]
    fn test_normalize_outcome() {
        assert_eq!(normalize_outcome("allow"), "Allow");
        assert_eq!(normalize_outcome("Allow"), "Allow");
        assert_eq!(normalize_outcome("ALLOW"), "Allow");
        assert_eq!(normalize_outcome("deny"), "Deny");
        assert_eq!(normalize_outcome("Deny"), "Deny");
        assert_eq!(normalize_outcome("indeterminate"), "Indeterminate");
        assert_eq!(normalize_outcome("INDETERMINATE"), "Indeterminate");
        // Unrecognized spellings pass through untouched, so they never match a real outcome.
        assert_eq!(normalize_outcome("permit"), "permit");
        assert_eq!(normalize_outcome(""), "");
    }

    #[test]
    fn test_overall_risk_score() {
        let high = vec![PolicyChange {
            kind: "removed".into(),
            description: "NotRevoked".into(),
            risk: "HIGH".into(),
        }];
        assert_eq!(overall_risk_score(&high), "HIGH");

        let medium = vec![PolicyChange {
            kind: "added".into(),
            description: "HasCapability(sign)".into(),
            risk: "MEDIUM".into(),
        }];
        assert_eq!(overall_risk_score(&medium), "MEDIUM");

        let low = vec![PolicyChange {
            kind: "added".into(),
            description: "RepoIs(org/repo)".into(),
            risk: "LOW".into(),
        }];
        assert_eq!(overall_risk_score(&low), "LOW");

        assert_eq!(overall_risk_score(&[]), "LOW");
    }

    #[test]
    fn test_collect_predicates_via_diff() {
        let old = Expr::And(vec![Expr::NotRevoked, Expr::HasCapability("sign".into())]);
        let new = Expr::And(vec![Expr::NotRevoked]);
        let changes = compute_policy_diff(&old, &new);
        assert!(
            changes
                .iter()
                .any(|c| c.description.contains("HasCapability") && c.kind == "removed")
        );
    }

    #[test]
    fn test_structural_change_and_to_or() {
        let old = Expr::And(vec![Expr::True]);
        let new = Expr::Or(vec![Expr::True]);
        let changes = compute_policy_diff(&old, &new);
        let structural = changes
            .iter()
            .find(|c| c.kind == "changed")
            .expect("switching And to Or should produce a `changed` entry");
        assert_eq!(structural.risk, "HIGH");
    }
}