use clap::Args;
use miette::{miette, IntoDiagnostic, Report, Result, WrapErr};
use owo_colors::OwoColorize;
use serde::de::{DeserializeSeed, IntoDeserializer};
use serde::{Deserialize, Deserializer};
use std::collections::BTreeSet;
use std::io::{BufReader, Write};
use std::path::Path;
use cedar_policy::*;
use crate::{CedarExitCode, OptionalSchemaArgs, PoliciesArgs, RequestJSON};
#[derive(Args, Debug)]
pub struct RunTestsArgs {
#[command(flatten)]
pub policies: PoliciesArgs,
#[arg(long, value_name = "FILE")]
pub tests: String,
#[command(flatten)]
pub schema: OptionalSchemaArgs,
}
#[derive(Clone, Debug)]
enum TestResult {
Pass,
Fail(String),
}
fn compare_test_decisions(test: &TestCase, ans: &Response) -> TestResult {
if ans.decision() == test.decision.into() {
let mut errors = Vec::new();
let reason = ans.diagnostics().reason().collect::<BTreeSet<_>>();
let missing_reason = test
.reason
.iter()
.filter(|r| !reason.contains(&PolicyId::new(r)))
.collect::<Vec<_>>();
if !missing_reason.is_empty() {
errors.push(format!(
"missing reason(s): {}",
missing_reason
.into_iter()
.map(|r| format!("`{r}`"))
.collect::<Vec<_>>()
.join(", ")
));
}
let num_errors = ans.diagnostics().errors().count();
if num_errors != test.num_errors {
errors.push(format!(
"expected {} error(s), but got {} runtime error(s){}",
test.num_errors,
num_errors,
if num_errors == 0 {
"".to_string()
} else {
format!(
": {}",
ans.diagnostics()
.errors()
.map(|e| e.to_string())
.collect::<Vec<_>>()
.join(", ")
)
},
));
}
if errors.is_empty() {
TestResult::Pass
} else {
TestResult::Fail(errors.join("; "))
}
} else {
TestResult::Fail(format!(
"expected {:?}, got {:?}",
test.decision,
ans.decision()
))
}
}
fn run_one_test(
policies: &PolicySet,
test: &serde_json::Value,
validator: Option<&Validator>,
) -> Result<TestResult> {
let test = CheckedTestCaseSeed(validator.map(Validator::schema))
.deserialize(test.into_deserializer())
.into_diagnostic()?;
if let Some(validator) = validator {
let val_res = validator.validate(policies, cedar_policy::ValidationMode::Strict);
if !val_res.validation_passed_without_warnings() {
return Err(Report::new(val_res).wrap_err("policy set validation failed"));
}
}
let ans = Authorizer::new().is_authorized(&test.request, policies, &test.entities);
Ok(compare_test_decisions(&test, &ans))
}
fn run_tests_inner(args: &RunTestsArgs) -> Result<CedarExitCode> {
let policies = args.policies.get_policy_set()?;
let tests = load_partial_tests(&args.tests)?;
let validator = args.schema.get_schema()?.map(Validator::new);
let mut total_fails: usize = 0;
println!("running {} test(s)", tests.len());
for test in tests.iter() {
if let Some(name) = test["name"].as_str() {
print!(" test {name} ... ");
} else {
print!(" test (unnamed) ... ");
}
std::io::stdout().flush().into_diagnostic()?;
match run_one_test(&policies, test, validator.as_ref()) {
Ok(TestResult::Pass) => {
println!(
"{}",
"ok".if_supports_color(owo_colors::Stream::Stdout, |s| s.green())
);
}
Ok(TestResult::Fail(reason)) => {
total_fails += 1;
println!(
"{}: {}",
"fail".if_supports_color(owo_colors::Stream::Stdout, |s| s.red()),
reason
);
}
Err(e) => {
total_fails += 1;
println!(
"{}:\n {:?}",
"error".if_supports_color(owo_colors::Stream::Stdout, |s| s.red()),
e
);
}
}
}
println!(
"results: {} {}, {} {}",
tests.len() - total_fails,
if total_fails == 0 {
"passed"
.if_supports_color(owo_colors::Stream::Stdout, |s| s.green())
.to_string()
} else {
"passed".to_string()
},
total_fails,
if total_fails != 0 {
"failed"
.if_supports_color(owo_colors::Stream::Stdout, |s| s.red())
.to_string()
} else {
"failed".to_string()
},
);
Ok(if total_fails != 0 {
CedarExitCode::Failure
} else {
CedarExitCode::Success
})
}
pub fn run_tests(args: &RunTestsArgs) -> CedarExitCode {
run_tests_inner(args).unwrap_or_else(|e| {
println!("{e:?}");
CedarExitCode::Failure
})
}
#[derive(Copy, Clone, Debug, Deserialize)]
enum ExpectedDecision {
#[serde(rename = "allow")]
Allow,
#[serde(rename = "deny")]
Deny,
}
impl From<ExpectedDecision> for Decision {
fn from(value: ExpectedDecision) -> Self {
match value {
ExpectedDecision::Allow => Decision::Allow,
ExpectedDecision::Deny => Decision::Deny,
}
}
}
#[derive(Clone, Debug, Deserialize)]
struct UncheckedTestCase {
request: RequestJSON,
entities: serde_json::Value,
decision: ExpectedDecision,
reason: Vec<String>,
num_errors: usize,
}
#[derive(Clone, Debug)]
struct TestCase {
request: Request,
entities: Entities,
decision: ExpectedDecision,
reason: Vec<String>,
num_errors: usize,
}
struct CheckedTestCaseSeed<'a>(Option<&'a Schema>);
impl<'de, 'a> DeserializeSeed<'de> for CheckedTestCaseSeed<'a> {
type Value = TestCase;
fn deserialize<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
let UncheckedTestCase {
request,
entities,
decision,
reason,
num_errors,
} = UncheckedTestCase::deserialize(deserializer)?;
let principal = request
.principal
.ok_or_else(|| serde::de::Error::missing_field("principal"))?;
let principal = principal.parse().map_err(|e| {
serde::de::Error::custom(format!("failed to parse principal `{principal}`: {e}",))
})?;
let action = request
.action
.ok_or_else(|| serde::de::Error::missing_field("action"))?;
let action = action.parse().map_err(|e| {
serde::de::Error::custom(format!("failed to parse action `{action}`: {e}",))
})?;
let resource = request
.resource
.ok_or_else(|| serde::de::Error::missing_field("resource"))?;
let resource = resource.parse().map_err(|e| {
serde::de::Error::custom(format!("failed to parse resource `{resource}`: {e}",))
})?;
let context = Context::from_json_value(request.context.clone(), None).map_err(|e| {
serde::de::Error::custom(format!(
"failed to parse context `{}`: {}",
request.context, e
))
})?;
let request = Request::new(principal, action, resource, context, self.0)
.map_err(|e| serde::de::Error::custom(format!("failed to create request: {e}")))?;
let entities = Entities::from_json_value(entities, self.0)
.map_err(|e| serde::de::Error::custom(format!("failed to parse entities: {e}")))?;
Ok(TestCase {
request,
entities,
decision,
reason,
num_errors,
})
}
}
fn load_partial_tests(tests_filename: impl AsRef<Path>) -> Result<Vec<serde_json::Value>> {
match std::fs::OpenOptions::new()
.read(true)
.open(tests_filename.as_ref())
{
Ok(f) => {
let reader = BufReader::new(f);
serde_json::from_reader(reader).map_err(|e| {
miette!(
"failed to parse tests from file {}: {e}",
tests_filename.as_ref().display()
)
})
}
Err(e) => Err(e).into_diagnostic().wrap_err_with(|| {
format!(
"failed to open test file {}",
tests_filename.as_ref().display()
)
}),
}
}