Skip to main content

auths_cli/commands/
policy.rs

1//! Policy governance commands for Auths.
2//!
3//! Commands for linting, compiling, testing, and comparing policies.
4
5use crate::ux::format::{JsonResponse, Output, is_json_mode};
6use anyhow::{Context, Result, anyhow};
7use auths_policy::{
8    CompileError, CompiledExpr, EvalContext, Expr, Outcome, PolicyLimits,
9    compile_from_json_with_limits,
10};
11use auths_sdk::workflows::policy_diff::{compute_policy_diff, overall_risk_score};
12use chrono::{DateTime, Utc};
13use clap::{Parser, Subcommand};
14use serde::{Deserialize, Serialize};
15use std::fs;
16use std::path::PathBuf;
17
18/// Manage authorization policies.
19#[derive(Parser, Debug, Clone)]
20#[command(name = "policy", about = "Manage authorization policies")]
21pub struct PolicyCommand {
22    #[command(subcommand)]
23    pub command: PolicySubcommand,
24}
25
26#[derive(Subcommand, Debug, Clone)]
27pub enum PolicySubcommand {
28    /// Validate policy JSON syntax without full compilation.
29    Lint(LintCommand),
30
31    /// Compile a policy file with full validation.
32    Compile(CompileCommand),
33
34    /// Evaluate a policy against a context and show the decision.
35    Explain(ExplainCommand),
36
37    /// Run a policy against a test suite.
38    Test(TestCommand),
39
40    /// Compare two policies and show semantic differences.
41    Diff(DiffCommand),
42}
43
44/// Validate policy JSON syntax.
45#[derive(Parser, Debug, Clone)]
46pub struct LintCommand {
47    /// Path to the policy file (JSON).
48    pub file: PathBuf,
49}
50
51/// Compile a policy with full validation.
52#[derive(Parser, Debug, Clone)]
53pub struct CompileCommand {
54    /// Path to the policy file (JSON).
55    pub file: PathBuf,
56}
57
58/// Evaluate a policy against a context.
59#[derive(Parser, Debug, Clone)]
60pub struct ExplainCommand {
61    /// Path to the policy file (JSON).
62    pub file: PathBuf,
63
64    /// Path to the context file (JSON).
65    #[clap(long, short = 'c')]
66    pub context: PathBuf,
67}
68
69/// Run a policy against a test suite.
70#[derive(Parser, Debug, Clone)]
71pub struct TestCommand {
72    /// Path to the policy file (JSON).
73    pub file: PathBuf,
74
75    /// Path to the test suite file (JSON).
76    #[clap(long, short = 't')]
77    pub tests: PathBuf,
78}
79
80/// Compare two policies.
81#[derive(Parser, Debug, Clone)]
82pub struct DiffCommand {
83    /// Path to the old policy file (JSON).
84    pub old: PathBuf,
85
86    /// Path to the new policy file (JSON).
87    pub new: PathBuf,
88}
89
90// ── JSON Output Types ───────────────────────────────────────────────────
91
92#[derive(Debug, Serialize)]
93struct LintData {
94    bytes: usize,
95    byte_limit: usize,
96}
97
98#[derive(Debug, Serialize)]
99struct CompileData {
100    #[serde(skip_serializing_if = "Option::is_none")]
101    nodes: Option<u32>,
102    #[serde(skip_serializing_if = "Option::is_none")]
103    depth: Option<u32>,
104    #[serde(skip_serializing_if = "Option::is_none")]
105    hash: Option<String>,
106    #[serde(skip_serializing_if = "Vec::is_empty")]
107    errors: Vec<String>,
108}
109
110#[derive(Debug, Serialize)]
111struct ExplainOutput {
112    decision: String,
113    reason_code: String,
114    message: String,
115    policy_hash: String,
116}
117
118#[derive(Debug, Serialize)]
119struct TestOutput {
120    passed: usize,
121    failed: usize,
122    total: usize,
123    results: Vec<TestResult>,
124}
125
126#[derive(Debug, Serialize)]
127struct TestResult {
128    name: String,
129    passed: bool,
130    expected: String,
131    actual: String,
132    #[serde(skip_serializing_if = "Option::is_none")]
133    message: Option<String>,
134}
135
136#[derive(Debug, Serialize)]
137struct DiffOutput {
138    changes: Vec<DiffChange>,
139    risk_score: String,
140}
141
142#[derive(Debug, Serialize)]
143struct DiffChange {
144    kind: String,
145    description: String,
146    risk: String,
147}
148
149// ── Test Suite Types ────────────────────────────────────────────────────
150
151#[derive(Debug, Deserialize)]
152struct TestCase {
153    name: String,
154    context: TestContext,
155    expect: String,
156}
157
158#[derive(Debug, Deserialize)]
159struct TestContext {
160    issuer: String,
161    subject: String,
162    #[serde(default)]
163    revoked: bool,
164    #[serde(default)]
165    capabilities: Vec<String>,
166    #[serde(default)]
167    role: Option<String>,
168    #[serde(default)]
169    expires_at: Option<DateTime<Utc>>,
170    #[serde(default)]
171    timestamp: Option<DateTime<Utc>>,
172    #[serde(default)]
173    chain_depth: u32,
174    #[serde(default)]
175    repo: Option<String>,
176    #[serde(default)]
177    git_ref: Option<String>,
178    #[serde(default)]
179    paths: Vec<String>,
180    #[serde(default)]
181    environment: Option<String>,
182}
183
184// ── Handler ─────────────────────────────────────────────────────────────
185
186pub fn handle_policy(cmd: PolicyCommand) -> Result<()> {
187    match cmd.command {
188        PolicySubcommand::Lint(lint) => handle_lint(lint),
189        PolicySubcommand::Compile(compile) => handle_compile(compile),
190        PolicySubcommand::Explain(explain) => handle_explain(explain),
191        PolicySubcommand::Test(test) => handle_test(test),
192        PolicySubcommand::Diff(diff) => handle_diff(diff),
193    }
194}
195
196fn handle_lint(cmd: LintCommand) -> Result<()> {
197    let out = Output::new();
198    let limits = PolicyLimits::default();
199
200    // Read the file
201    let content =
202        fs::read(&cmd.file).with_context(|| format!("failed to read {}", cmd.file.display()))?;
203
204    let bytes = content.len();
205
206    // Check size limit
207    if bytes > limits.max_json_bytes {
208        if is_json_mode() {
209            JsonResponse::<()>::error(
210                "policy lint",
211                format!(
212                    "file exceeds size limit: {} > {}",
213                    bytes, limits.max_json_bytes
214                ),
215            )
216            .print()?;
217        } else {
218            out.println(&format!(
219                "{} File exceeds size limit: {} bytes (limit: {})",
220                out.error("x"),
221                bytes,
222                limits.max_json_bytes
223            ));
224        }
225        anyhow::bail!(
226            "file exceeds size limit: {} > {}",
227            bytes,
228            limits.max_json_bytes
229        );
230    }
231
232    // Parse JSON
233    match serde_json::from_slice::<Expr>(&content) {
234        Ok(_expr) => {
235            if is_json_mode() {
236                JsonResponse::success(
237                    "policy lint",
238                    LintData {
239                        bytes,
240                        byte_limit: limits.max_json_bytes,
241                    },
242                )
243                .print()?;
244            } else {
245                out.println(&format!("{} Valid JSON", out.success("ok")));
246                out.println(&format!("{} All ops recognized", out.success("ok")));
247                out.println(&format!(
248                    "{} {} bytes (limit: {})",
249                    out.success("ok"),
250                    bytes,
251                    limits.max_json_bytes
252                ));
253            }
254        }
255        Err(e) => {
256            if is_json_mode() {
257                JsonResponse::<()>::error("policy lint", e.to_string()).print()?;
258            } else {
259                out.println(&format!("{} Invalid JSON: {}", out.error("x"), e));
260            }
261            anyhow::bail!("lint failed: {}", e);
262        }
263    }
264
265    Ok(())
266}
267
268fn handle_compile(cmd: CompileCommand) -> Result<()> {
269    let out = Output::new();
270    let limits = PolicyLimits::default();
271
272    let content =
273        fs::read(&cmd.file).with_context(|| format!("failed to read {}", cmd.file.display()))?;
274
275    match compile_from_json_with_limits(&content, &limits) {
276        Ok(policy) => {
277            let stats = compute_policy_stats(policy.expr());
278            let hash = hex::encode(policy.source_hash());
279
280            if is_json_mode() {
281                JsonResponse::success(
282                    "policy compile",
283                    CompileData {
284                        nodes: Some(stats.nodes),
285                        depth: Some(stats.depth),
286                        hash: Some(hash),
287                        errors: vec![],
288                    },
289                )
290                .print()?;
291            } else {
292                out.println(&format!("{} Compiled successfully", out.success("ok")));
293                out.println(&format!(
294                    "  Nodes: {} (limit: {})",
295                    stats.nodes, limits.max_total_nodes
296                ));
297                out.println(&format!(
298                    "  Depth: {} (limit: {})",
299                    stats.depth, limits.max_depth
300                ));
301                out.println(&format!("  Hash:  {}", hash));
302            }
303        }
304        Err(errors) => {
305            let error_strs: Vec<String> = errors.iter().map(format_compile_error).collect();
306
307            if is_json_mode() {
308                JsonResponse {
309                    success: false,
310                    command: "policy compile".to_string(),
311                    data: Some(CompileData {
312                        nodes: None,
313                        depth: None,
314                        hash: None,
315                        errors: error_strs,
316                    }),
317                    error: None,
318                }
319                .print()?;
320            } else {
321                out.println(&format!(
322                    "{} Compilation failed ({} errors):",
323                    out.error("x"),
324                    errors.len()
325                ));
326                for error in &error_strs {
327                    out.println(&format!("  {}", error));
328                }
329            }
330        }
331    }
332
333    Ok(())
334}
335
336fn handle_explain(cmd: ExplainCommand) -> Result<()> {
337    let out = Output::new();
338    let limits = PolicyLimits::default();
339
340    // Load and compile policy
341    let policy_content = fs::read(&cmd.file)
342        .with_context(|| format!("failed to read policy: {}", cmd.file.display()))?;
343
344    let policy = compile_from_json_with_limits(&policy_content, &limits).map_err(|errors| {
345        anyhow!(
346            "policy compilation failed: {}",
347            errors
348                .iter()
349                .map(format_compile_error)
350                .collect::<Vec<_>>()
351                .join("; ")
352        )
353    })?;
354
355    // Load context
356    let ctx_content = fs::read(&cmd.context)
357        .with_context(|| format!("failed to read context: {}", cmd.context.display()))?;
358
359    let test_ctx: TestContext =
360        serde_json::from_slice(&ctx_content).with_context(|| "failed to parse context JSON")?;
361
362    let eval_ctx = build_eval_context(&test_ctx)?;
363
364    // Evaluate
365    let decision = auths_policy::evaluate3(&policy, &eval_ctx);
366    let hash = hex::encode(policy.source_hash());
367
368    if is_json_mode() {
369        JsonResponse::success(
370            "policy explain",
371            ExplainOutput {
372                decision: format!("{:?}", decision.outcome),
373                reason_code: format!("{:?}", decision.reason),
374                message: decision.message.clone(),
375                policy_hash: hash,
376            },
377        )
378        .print()?;
379    } else {
380        let decision_str = match decision.outcome {
381            Outcome::Allow => out.success("ALLOW"),
382            Outcome::Deny => out.error("DENY"),
383            Outcome::Indeterminate => out.warn("INDETERMINATE"),
384            Outcome::RequiresApproval => out.warn("REQUIRES_APPROVAL"),
385        };
386        out.println(&format!("Decision: {}", decision_str));
387        out.println(&format!("  Reason: {:?}", decision.reason));
388        out.println(&format!("  Message: {}", decision.message));
389        out.println(&format!("Policy hash: {}", hash));
390    }
391
392    Ok(())
393}
394
395fn handle_test(cmd: TestCommand) -> Result<()> {
396    let out = Output::new();
397    let limits = PolicyLimits::default();
398
399    // Load and compile policy
400    let policy_content = fs::read(&cmd.file)
401        .with_context(|| format!("failed to read policy: {}", cmd.file.display()))?;
402
403    let policy = compile_from_json_with_limits(&policy_content, &limits).map_err(|errors| {
404        anyhow!(
405            "policy compilation failed: {}",
406            errors
407                .iter()
408                .map(format_compile_error)
409                .collect::<Vec<_>>()
410                .join("; ")
411        )
412    })?;
413
414    // Load test suite
415    let tests_content = fs::read(&cmd.tests)
416        .with_context(|| format!("failed to read tests: {}", cmd.tests.display()))?;
417
418    let test_cases: Vec<TestCase> = serde_json::from_slice(&tests_content)
419        .with_context(|| "failed to parse test suite JSON")?;
420
421    let mut results: Vec<TestResult> = Vec::new();
422    let mut passed = 0;
423    let mut failed = 0;
424
425    for test in test_cases {
426        let eval_ctx = match build_eval_context(&test.context) {
427            Ok(ctx) => ctx,
428            Err(e) => {
429                results.push(TestResult {
430                    name: test.name.clone(),
431                    passed: false,
432                    expected: test.expect.clone(),
433                    actual: "ERROR".into(),
434                    message: Some(e.to_string()),
435                });
436                failed += 1;
437                continue;
438            }
439        };
440
441        let decision = auths_policy::evaluate3(&policy, &eval_ctx);
442        let actual = format!("{:?}", decision.outcome);
443        let expected_normalized = normalize_outcome(&test.expect);
444        let test_passed = actual == expected_normalized;
445
446        if test_passed {
447            passed += 1;
448        } else {
449            failed += 1;
450        }
451
452        results.push(TestResult {
453            name: test.name,
454            passed: test_passed,
455            expected: expected_normalized,
456            actual,
457            message: if test_passed {
458                None
459            } else {
460                Some(decision.message.clone())
461            },
462        });
463    }
464
465    let total = passed + failed;
466
467    if is_json_mode() {
468        JsonResponse::success(
469            "policy test",
470            TestOutput {
471                passed,
472                failed,
473                total,
474                results,
475            },
476        )
477        .print()?;
478    } else {
479        for result in &results {
480            let status = if result.passed {
481                out.success("ok")
482            } else {
483                out.error("FAIL")
484            };
485            out.println(&format!(
486                "  {} {}: {} (expected {})",
487                status, result.name, result.actual, result.expected
488            ));
489            if let Some(msg) = &result.message {
490                out.println(&format!("      {}", out.dim(msg)));
491            }
492        }
493        out.println(&format!("{}/{} passed", passed, total));
494    }
495
496    if failed > 0 {
497        anyhow::bail!("{} test(s) failed", failed);
498    }
499
500    Ok(())
501}
502
503fn handle_diff(cmd: DiffCommand) -> Result<()> {
504    let out = Output::new();
505
506    // Parse both policy files (don't need full compilation for structural diff)
507    let old_content = fs::read(&cmd.old)
508        .with_context(|| format!("failed to read old policy: {}", cmd.old.display()))?;
509    let new_content = fs::read(&cmd.new)
510        .with_context(|| format!("failed to read new policy: {}", cmd.new.display()))?;
511
512    let old_expr: Expr =
513        serde_json::from_slice(&old_content).with_context(|| "failed to parse old policy JSON")?;
514    let new_expr: Expr =
515        serde_json::from_slice(&new_content).with_context(|| "failed to parse new policy JSON")?;
516
517    let changes = compute_policy_diff(&old_expr, &new_expr);
518    let risk_score = overall_risk_score(&changes);
519
520    if is_json_mode() {
521        JsonResponse::success(
522            "policy diff",
523            DiffOutput {
524                changes: changes
525                    .iter()
526                    .map(|c| DiffChange {
527                        kind: c.kind.clone(),
528                        description: c.description.clone(),
529                        risk: c.risk.clone(),
530                    })
531                    .collect(),
532                risk_score: risk_score.clone(),
533            },
534        )
535        .print()?;
536    } else if changes.is_empty() {
537        out.println("No changes detected");
538    } else {
539        out.println("Changes:");
540        for change in &changes {
541            let risk_marker = match change.risk.as_str() {
542                "HIGH" => out.error("HIGH RISK"),
543                "MEDIUM" => out.warn("MEDIUM"),
544                _ => out.dim("LOW"),
545            };
546            let kind_marker = match change.kind.as_str() {
547                "added" => "+",
548                "removed" => "-",
549                "changed" => "~",
550                _ => "?",
551            };
552            out.println(&format!(
553                "  {} {}: {} [{}]",
554                kind_marker, change.description, risk_marker, change.risk
555            ));
556        }
557        out.println("");
558        let risk_display = match risk_score.as_str() {
559            "HIGH" => out.error(&risk_score),
560            "MEDIUM" => out.warn(&risk_score),
561            _ => out.dim(&risk_score),
562        };
563        out.println(&format!("Risk score: {}", risk_display));
564    }
565
566    Ok(())
567}
568
569// ── Helpers ─────────────────────────────────────────────────────────────
570
571fn format_compile_error(error: &CompileError) -> String {
572    format!("at {}: {}", error.path, error.message)
573}
574
575struct PolicyStats {
576    nodes: u32,
577    depth: u32,
578}
579
580fn compute_policy_stats(expr: &CompiledExpr) -> PolicyStats {
581    fn count_nodes(expr: &CompiledExpr) -> u32 {
582        match expr {
583            CompiledExpr::True | CompiledExpr::False => 1,
584            CompiledExpr::And(children) | CompiledExpr::Or(children) => {
585                1 + children.iter().map(count_nodes).sum::<u32>()
586            }
587            CompiledExpr::Not(inner) => 1 + count_nodes(inner),
588            _ => 1,
589        }
590    }
591
592    fn compute_depth(expr: &CompiledExpr) -> u32 {
593        match expr {
594            CompiledExpr::True | CompiledExpr::False => 1,
595            CompiledExpr::And(children) | CompiledExpr::Or(children) => {
596                1 + children.iter().map(compute_depth).max().unwrap_or(0)
597            }
598            CompiledExpr::Not(inner) => 1 + compute_depth(inner),
599            _ => 1,
600        }
601    }
602
603    PolicyStats {
604        nodes: count_nodes(expr),
605        depth: compute_depth(expr),
606    }
607}
608
609fn build_eval_context(test: &TestContext) -> Result<EvalContext> {
610    let mut ctx = EvalContext::try_from_strings(Utc::now(), &test.issuer, &test.subject)
611        .map_err(|e| anyhow!("invalid DID: {}", e))?;
612
613    ctx = ctx.revoked(test.revoked);
614    ctx = ctx.chain_depth(test.chain_depth);
615
616    for cap in &test.capabilities {
617        let canonical = auths_policy::CanonicalCapability::parse(cap)
618            .map_err(|e| anyhow!("invalid capability '{}': {}", cap, e))?;
619        ctx = ctx.capability(canonical);
620    }
621
622    if let Some(role) = &test.role {
623        ctx = ctx.role(role.clone());
624    }
625
626    if let Some(exp) = test.expires_at {
627        ctx = ctx.expires_at(exp);
628    }
629
630    if let Some(ts) = test.timestamp {
631        ctx = ctx.timestamp(ts);
632    }
633
634    if let Some(repo) = &test.repo {
635        ctx = ctx.repo(repo.clone());
636    }
637
638    if let Some(git_ref) = &test.git_ref {
639        ctx = ctx.git_ref(git_ref.clone());
640    }
641
642    if !test.paths.is_empty() {
643        ctx = ctx.paths(test.paths.clone());
644    }
645
646    if let Some(env) = &test.environment {
647        ctx = ctx.environment(env.clone());
648    }
649
650    Ok(ctx)
651}
652
/// Map a user-written expected outcome onto the `Debug` name of the
/// corresponding `Outcome` variant, case-insensitively.
///
/// Recognizes `allow`, `deny`, `indeterminate`, and `requires approval`
/// spelled as `requiresapproval`, `requires_approval` (the form `policy
/// explain` prints) or `requires-approval`. Any other input is returned
/// unchanged, so it only matches an outcome whose `Debug` name is exactly that
/// string.
fn normalize_outcome(s: &str) -> String {
    match s.to_lowercase().as_str() {
        "allow" => "Allow".into(),
        "deny" => "Deny".into(),
        "indeterminate" => "Indeterminate".into(),
        // Previously unmapped: only the exact spelling "RequiresApproval"
        // could ever match, unlike the case-insensitive handling above.
        "requiresapproval" | "requires_approval" | "requires-approval" => {
            "RequiresApproval".into()
        }
        _ => s.to_string(),
    }
}
661
662use crate::commands::executable::ExecutableCommand;
663use crate::config::CliConfig;
664
665impl ExecutableCommand for PolicyCommand {
666    fn execute(&self, _ctx: &CliConfig) -> Result<()> {
667        handle_policy(self.clone())
668    }
669}
670
#[cfg(test)]
mod tests {
    use super::*;
    use auths_sdk::workflows::policy_diff::{
        PolicyChange, compute_policy_diff, overall_risk_score,
    };

    /// Shorthand for a change list holding exactly one entry.
    fn one_change(
        kind: &'static str,
        description: &'static str,
        risk: &'static str,
    ) -> Vec<PolicyChange> {
        vec![PolicyChange {
            kind: kind.into(),
            description: description.into(),
            risk: risk.into(),
        }]
    }

    #[test]
    fn test_normalize_outcome() {
        let cases = [
            ("allow", "Allow"),
            ("Allow", "Allow"),
            ("ALLOW", "Allow"),
            ("deny", "Deny"),
            ("Deny", "Deny"),
            ("indeterminate", "Indeterminate"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(normalize_outcome(input), expected, "normalizing {:?}", input);
        }
    }

    #[test]
    fn test_overall_risk_score() {
        assert_eq!(
            overall_risk_score(&one_change("removed", "NotRevoked", "HIGH")),
            "HIGH"
        );
        assert_eq!(
            overall_risk_score(&one_change("added", "HasCapability(sign)", "MEDIUM")),
            "MEDIUM"
        );
        assert_eq!(
            overall_risk_score(&one_change("added", "RepoIs(org/repo)", "LOW")),
            "LOW"
        );
        // An empty change set is the lowest risk.
        assert_eq!(overall_risk_score(&[]), "LOW");
    }

    #[test]
    fn test_collect_predicates_via_diff() {
        let before = Expr::And(vec![Expr::NotRevoked, Expr::HasCapability("sign".into())]);
        let after = Expr::And(vec![Expr::NotRevoked]);
        let changes = compute_policy_diff(&before, &after);
        let dropped_capability = changes
            .iter()
            .any(|c| c.kind == "removed" && c.description.contains("HasCapability"));
        assert!(dropped_capability);
    }

    #[test]
    fn test_structural_change_and_to_or() {
        let before = Expr::And(vec![Expr::True]);
        let after = Expr::Or(vec![Expr::True]);
        let changes = compute_policy_diff(&before, &after);
        let structural = changes
            .iter()
            .find(|c| c.kind == "changed")
            .expect("And -> Or should yield a `changed` entry");
        assert_eq!(structural.risk, "HIGH");
    }
}