cedar_policy_cli/command/
run_test.rs1use clap::Args;
18use miette::{miette, IntoDiagnostic, Report, Result, WrapErr};
19use owo_colors::OwoColorize;
20use serde::de::{DeserializeSeed, IntoDeserializer};
21use serde::{Deserialize, Deserializer};
22use std::collections::BTreeSet;
23use std::io::{BufReader, Write};
24use std::path::Path;
25
26use cedar_policy::*;
27
28use crate::{CedarExitCode, OptionalSchemaArgs, PoliciesArgs, RequestJSON};
29
30#[derive(Args, Debug)]
31pub struct RunTestsArgs {
32 #[command(flatten)]
34 pub policies: PoliciesArgs,
35 #[arg(long, value_name = "FILE")]
37 pub tests: String,
38 #[command(flatten)]
39 pub schema: OptionalSchemaArgs,
40}
41
/// Outcome of comparing one test case's expectations with the actual response.
#[derive(Debug, Clone)]
enum TestResult {
    /// Every expectation of the test case was met.
    Pass,
    /// At least one expectation was not met; the message says which.
    Fail(String),
}
47
48fn compare_test_decisions(test: &TestCase, ans: &Response) -> TestResult {
50 if ans.decision() == test.decision.into() {
51 let mut errors = Vec::new();
52 let reason = ans.diagnostics().reason().collect::<BTreeSet<_>>();
53
54 let missing_reason = test
56 .reason
57 .iter()
58 .filter(|r| !reason.contains(&PolicyId::new(r)))
59 .collect::<Vec<_>>();
60
61 if !missing_reason.is_empty() {
62 errors.push(format!(
63 "missing reason(s): {}",
64 missing_reason
65 .into_iter()
66 .map(|r| format!("`{r}`"))
67 .collect::<Vec<_>>()
68 .join(", ")
69 ));
70 }
71
72 let num_errors = ans.diagnostics().errors().count();
74 if num_errors != test.num_errors {
75 errors.push(format!(
76 "expected {} error(s), but got {} runtime error(s){}",
77 test.num_errors,
78 num_errors,
79 if num_errors == 0 {
80 "".to_string()
81 } else {
82 format!(
83 ": {}",
84 ans.diagnostics()
85 .errors()
86 .map(|e| e.to_string())
87 .collect::<Vec<_>>()
88 .join(", ")
89 )
90 },
91 ));
92 }
93
94 if errors.is_empty() {
95 TestResult::Pass
96 } else {
97 TestResult::Fail(errors.join("; "))
98 }
99 } else {
100 TestResult::Fail(format!(
101 "expected {:?}, got {:?}",
102 test.decision,
103 ans.decision()
104 ))
105 }
106}
107
108fn run_one_test(
111 policies: &PolicySet,
112 test: &serde_json::Value,
113 validator: Option<&Validator>,
114) -> Result<TestResult> {
115 let test = CheckedTestCaseSeed(validator.map(Validator::schema))
116 .deserialize(test.into_deserializer())
117 .into_diagnostic()?;
118 if let Some(validator) = validator {
119 let val_res = validator.validate(policies, cedar_policy::ValidationMode::Strict);
120 if !val_res.validation_passed_without_warnings() {
121 return Err(Report::new(val_res).wrap_err("policy set validation failed"));
122 }
123 }
124 let ans = Authorizer::new().is_authorized(&test.request, policies, &test.entities);
125 Ok(compare_test_decisions(&test, &ans))
126}
127
128fn run_tests_inner(args: &RunTestsArgs) -> Result<CedarExitCode> {
129 let policies = args.policies.get_policy_set()?;
130 let tests = load_partial_tests(&args.tests)?;
131 let validator = args.schema.get_schema()?.map(Validator::new);
132
133 let mut total_fails: usize = 0;
134
135 println!("running {} test(s)", tests.len());
136 for test in tests.iter() {
137 if let Some(name) = test["name"].as_str() {
138 print!(" test {name} ... ");
139 } else {
140 print!(" test (unnamed) ... ");
141 }
142 std::io::stdout().flush().into_diagnostic()?;
143 match run_one_test(&policies, test, validator.as_ref()) {
144 Ok(TestResult::Pass) => {
145 println!(
146 "{}",
147 "ok".if_supports_color(owo_colors::Stream::Stdout, |s| s.green())
148 );
149 }
150 Ok(TestResult::Fail(reason)) => {
151 total_fails += 1;
152 println!(
153 "{}: {}",
154 "fail".if_supports_color(owo_colors::Stream::Stdout, |s| s.red()),
155 reason
156 );
157 }
158 Err(e) => {
159 total_fails += 1;
160 println!(
161 "{}:\n {:?}",
162 "error".if_supports_color(owo_colors::Stream::Stdout, |s| s.red()),
163 e
164 );
165 }
166 }
167 }
168
169 println!(
170 "results: {} {}, {} {}",
171 tests.len() - total_fails,
172 if total_fails == 0 {
173 "passed"
174 .if_supports_color(owo_colors::Stream::Stdout, |s| s.green())
175 .to_string()
176 } else {
177 "passed".to_string()
178 },
179 total_fails,
180 if total_fails != 0 {
181 "failed"
182 .if_supports_color(owo_colors::Stream::Stdout, |s| s.red())
183 .to_string()
184 } else {
185 "failed".to_string()
186 },
187 );
188
189 Ok(if total_fails != 0 {
190 CedarExitCode::Failure
191 } else {
192 CedarExitCode::Success
193 })
194}
195
196pub fn run_tests(args: &RunTestsArgs) -> CedarExitCode {
197 run_tests_inner(args).unwrap_or_else(|e| {
198 println!("{e:?}");
199 CedarExitCode::Failure
200 })
201}
202
203#[derive(Copy, Clone, Debug, Deserialize)]
204enum ExpectedDecision {
205 #[serde(rename = "allow")]
206 Allow,
207 #[serde(rename = "deny")]
208 Deny,
209}
210
211impl From<ExpectedDecision> for Decision {
212 fn from(value: ExpectedDecision) -> Self {
213 match value {
214 ExpectedDecision::Allow => Decision::Allow,
215 ExpectedDecision::Deny => Decision::Deny,
216 }
217 }
218}
219
220#[derive(Clone, Debug, Deserialize)]
221struct UncheckedTestCase {
222 request: RequestJSON,
223 entities: serde_json::Value,
224 decision: ExpectedDecision,
225 reason: Vec<String>,
226 num_errors: usize,
227}
228
229#[derive(Clone, Debug)]
230struct TestCase {
231 request: Request,
232 entities: Entities,
233 decision: ExpectedDecision,
234 reason: Vec<String>,
235 num_errors: usize,
236}
237
238struct CheckedTestCaseSeed<'a>(Option<&'a Schema>);
239
240impl<'de, 'a> DeserializeSeed<'de> for CheckedTestCaseSeed<'a> {
241 type Value = TestCase;
242
243 fn deserialize<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
244 where
245 D: Deserializer<'de>,
246 {
247 let UncheckedTestCase {
248 request,
249 entities,
250 decision,
251 reason,
252 num_errors,
253 } = UncheckedTestCase::deserialize(deserializer)?;
254
255 let principal = request
256 .principal
257 .ok_or_else(|| serde::de::Error::missing_field("principal"))?;
258 let principal = principal.parse().map_err(|e| {
259 serde::de::Error::custom(format!("failed to parse principal `{principal}`: {e}",))
260 })?;
261
262 let action = request
263 .action
264 .ok_or_else(|| serde::de::Error::missing_field("action"))?;
265 let action = action.parse().map_err(|e| {
266 serde::de::Error::custom(format!("failed to parse action `{action}`: {e}",))
267 })?;
268
269 let resource = request
270 .resource
271 .ok_or_else(|| serde::de::Error::missing_field("resource"))?;
272 let resource = resource.parse().map_err(|e| {
273 serde::de::Error::custom(format!("failed to parse resource `{resource}`: {e}",))
274 })?;
275
276 let context = Context::from_json_value(request.context.clone(), None).map_err(|e| {
277 serde::de::Error::custom(format!(
278 "failed to parse context `{}`: {}",
279 request.context, e
280 ))
281 })?;
282
283 let request = Request::new(principal, action, resource, context, self.0)
284 .map_err(|e| serde::de::Error::custom(format!("failed to create request: {e}")))?;
285
286 let entities = Entities::from_json_value(entities, self.0)
287 .map_err(|e| serde::de::Error::custom(format!("failed to parse entities: {e}")))?;
288
289 Ok(TestCase {
290 request,
291 entities,
292 decision,
293 reason,
294 num_errors,
295 })
296 }
297}
298
299fn load_partial_tests(tests_filename: impl AsRef<Path>) -> Result<Vec<serde_json::Value>> {
302 match std::fs::OpenOptions::new()
303 .read(true)
304 .open(tests_filename.as_ref())
305 {
306 Ok(f) => {
307 let reader = BufReader::new(f);
308 serde_json::from_reader(reader).map_err(|e| {
309 miette!(
310 "failed to parse tests from file {}: {e}",
311 tests_filename.as_ref().display()
312 )
313 })
314 }
315 Err(e) => Err(e).into_diagnostic().wrap_err_with(|| {
316 format!(
317 "failed to open test file {}",
318 tests_filename.as_ref().display()
319 )
320 }),
321 }
322}