// cedar_policy_cli/command/run_test.rs

/*
 * Copyright Cedar Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use clap::Args;
use miette::{miette, IntoDiagnostic, Report, Result, WrapErr};
use owo_colors::OwoColorize;
use serde::de::{DeserializeSeed, IntoDeserializer};
use serde::{Deserialize, Deserializer};
use std::collections::BTreeSet;
use std::io::{BufReader, Write};
use std::path::Path;

use cedar_policy::*;

use crate::{CedarExitCode, OptionalSchemaArgs, PoliciesArgs, RequestJSON};

30#[derive(Args, Debug)]
31pub struct RunTestsArgs {
32    /// Policies args (incorporated by reference)
33    #[command(flatten)]
34    pub policies: PoliciesArgs,
35    /// Tests in JSON format
36    #[arg(long, value_name = "FILE")]
37    pub tests: String,
38    #[command(flatten)]
39    pub schema: OptionalSchemaArgs,
40}
41
/// Outcome of a test case that was successfully parsed and evaluated.
#[derive(Debug, Clone)]
enum TestResult {
    /// Decision, reasons, and error count all matched the expectation.
    Pass,
    /// The expectation was not met; the payload is a human-readable explanation.
    Fail(String),
}

48/// Compare the test's expected decision against the actual decision
49fn compare_test_decisions(test: &TestCase, ans: &Response) -> TestResult {
50    if ans.decision() == test.decision.into() {
51        let mut errors = Vec::new();
52        let reason = ans.diagnostics().reason().collect::<BTreeSet<_>>();
53
54        // Check that the declared reason is a subset of the actual reason
55        let missing_reason = test
56            .reason
57            .iter()
58            .filter(|r| !reason.contains(&PolicyId::new(r)))
59            .collect::<Vec<_>>();
60
61        if !missing_reason.is_empty() {
62            errors.push(format!(
63                "missing reason(s): {}",
64                missing_reason
65                    .into_iter()
66                    .map(|r| format!("`{r}`"))
67                    .collect::<Vec<_>>()
68                    .join(", ")
69            ));
70        }
71
72        // Check that evaluation errors are expected
73        let num_errors = ans.diagnostics().errors().count();
74        if num_errors != test.num_errors {
75            errors.push(format!(
76                "expected {} error(s), but got {} runtime error(s){}",
77                test.num_errors,
78                num_errors,
79                if num_errors == 0 {
80                    "".to_string()
81                } else {
82                    format!(
83                        ": {}",
84                        ans.diagnostics()
85                            .errors()
86                            .map(|e| e.to_string())
87                            .collect::<Vec<_>>()
88                            .join(", ")
89                    )
90                },
91            ));
92        }
93
94        if errors.is_empty() {
95            TestResult::Pass
96        } else {
97            TestResult::Fail(errors.join("; "))
98        }
99    } else {
100        TestResult::Fail(format!(
101            "expected {:?}, got {:?}",
102            test.decision,
103            ans.decision()
104        ))
105    }
106}
107
108/// Parse the test, validate against schema,
109/// and then check the authorization decision
110fn run_one_test(
111    policies: &PolicySet,
112    test: &serde_json::Value,
113    validator: Option<&Validator>,
114) -> Result<TestResult> {
115    let test = CheckedTestCaseSeed(validator.map(Validator::schema))
116        .deserialize(test.into_deserializer())
117        .into_diagnostic()?;
118    if let Some(validator) = validator {
119        let val_res = validator.validate(policies, cedar_policy::ValidationMode::Strict);
120        if !val_res.validation_passed_without_warnings() {
121            return Err(Report::new(val_res).wrap_err("policy set validation failed"));
122        }
123    }
124    let ans = Authorizer::new().is_authorized(&test.request, policies, &test.entities);
125    Ok(compare_test_decisions(&test, &ans))
126}
127
128fn run_tests_inner(args: &RunTestsArgs) -> Result<CedarExitCode> {
129    let policies = args.policies.get_policy_set()?;
130    let tests = load_partial_tests(&args.tests)?;
131    let validator = args.schema.get_schema()?.map(Validator::new);
132
133    let mut total_fails: usize = 0;
134
135    println!("running {} test(s)", tests.len());
136    for test in tests.iter() {
137        if let Some(name) = test["name"].as_str() {
138            print!("  test {name} ... ");
139        } else {
140            print!("  test (unnamed) ... ");
141        }
142        std::io::stdout().flush().into_diagnostic()?;
143        match run_one_test(&policies, test, validator.as_ref()) {
144            Ok(TestResult::Pass) => {
145                println!(
146                    "{}",
147                    "ok".if_supports_color(owo_colors::Stream::Stdout, |s| s.green())
148                );
149            }
150            Ok(TestResult::Fail(reason)) => {
151                total_fails += 1;
152                println!(
153                    "{}: {}",
154                    "fail".if_supports_color(owo_colors::Stream::Stdout, |s| s.red()),
155                    reason
156                );
157            }
158            Err(e) => {
159                total_fails += 1;
160                println!(
161                    "{}:\n  {:?}",
162                    "error".if_supports_color(owo_colors::Stream::Stdout, |s| s.red()),
163                    e
164                );
165            }
166        }
167    }
168
169    println!(
170        "results: {} {}, {} {}",
171        tests.len() - total_fails,
172        if total_fails == 0 {
173            "passed"
174                .if_supports_color(owo_colors::Stream::Stdout, |s| s.green())
175                .to_string()
176        } else {
177            "passed".to_string()
178        },
179        total_fails,
180        if total_fails != 0 {
181            "failed"
182                .if_supports_color(owo_colors::Stream::Stdout, |s| s.red())
183                .to_string()
184        } else {
185            "failed".to_string()
186        },
187    );
188
189    Ok(if total_fails != 0 {
190        CedarExitCode::Failure
191    } else {
192        CedarExitCode::Success
193    })
194}
195
196pub fn run_tests(args: &RunTestsArgs) -> CedarExitCode {
197    run_tests_inner(args).unwrap_or_else(|e| {
198        println!("{e:?}");
199        CedarExitCode::Failure
200    })
201}
202
203#[derive(Copy, Clone, Debug, Deserialize)]
204enum ExpectedDecision {
205    #[serde(rename = "allow")]
206    Allow,
207    #[serde(rename = "deny")]
208    Deny,
209}
210
211impl From<ExpectedDecision> for Decision {
212    fn from(value: ExpectedDecision) -> Self {
213        match value {
214            ExpectedDecision::Allow => Decision::Allow,
215            ExpectedDecision::Deny => Decision::Deny,
216        }
217    }
218}
219
220#[derive(Clone, Debug, Deserialize)]
221struct UncheckedTestCase {
222    request: RequestJSON,
223    entities: serde_json::Value,
224    decision: ExpectedDecision,
225    reason: Vec<String>,
226    num_errors: usize,
227}
228
229#[derive(Clone, Debug)]
230struct TestCase {
231    request: Request,
232    entities: Entities,
233    decision: ExpectedDecision,
234    reason: Vec<String>,
235    num_errors: usize,
236}
237
238struct CheckedTestCaseSeed<'a>(Option<&'a Schema>);
239
240impl<'de, 'a> DeserializeSeed<'de> for CheckedTestCaseSeed<'a> {
241    type Value = TestCase;
242
243    fn deserialize<D>(self, deserializer: D) -> std::result::Result<Self::Value, D::Error>
244    where
245        D: Deserializer<'de>,
246    {
247        let UncheckedTestCase {
248            request,
249            entities,
250            decision,
251            reason,
252            num_errors,
253        } = UncheckedTestCase::deserialize(deserializer)?;
254
255        let principal = request
256            .principal
257            .ok_or_else(|| serde::de::Error::missing_field("principal"))?;
258        let principal = principal.parse().map_err(|e| {
259            serde::de::Error::custom(format!("failed to parse principal `{principal}`: {e}",))
260        })?;
261
262        let action = request
263            .action
264            .ok_or_else(|| serde::de::Error::missing_field("action"))?;
265        let action = action.parse().map_err(|e| {
266            serde::de::Error::custom(format!("failed to parse action `{action}`: {e}",))
267        })?;
268
269        let resource = request
270            .resource
271            .ok_or_else(|| serde::de::Error::missing_field("resource"))?;
272        let resource = resource.parse().map_err(|e| {
273            serde::de::Error::custom(format!("failed to parse resource `{resource}`: {e}",))
274        })?;
275
276        let context = Context::from_json_value(request.context.clone(), None).map_err(|e| {
277            serde::de::Error::custom(format!(
278                "failed to parse context `{}`: {}",
279                request.context, e
280            ))
281        })?;
282
283        let request = Request::new(principal, action, resource, context, self.0)
284            .map_err(|e| serde::de::Error::custom(format!("failed to create request: {e}")))?;
285
286        let entities = Entities::from_json_value(entities, self.0)
287            .map_err(|e| serde::de::Error::custom(format!("failed to parse entities: {e}")))?;
288
289        Ok(TestCase {
290            request,
291            entities,
292            decision,
293            reason,
294            num_errors,
295        })
296    }
297}
298
299/// Load partially parsed tests from a JSON file
300/// (as JSON values first without parsing to TestCase)
301fn load_partial_tests(tests_filename: impl AsRef<Path>) -> Result<Vec<serde_json::Value>> {
302    match std::fs::OpenOptions::new()
303        .read(true)
304        .open(tests_filename.as_ref())
305    {
306        Ok(f) => {
307            let reader = BufReader::new(f);
308            serde_json::from_reader(reader).map_err(|e| {
309                miette!(
310                    "failed to parse tests from file {}: {e}",
311                    tests_filename.as_ref().display()
312                )
313            })
314        }
315        Err(e) => Err(e).into_diagnostic().wrap_err_with(|| {
316            format!(
317                "failed to open test file {}",
318                tests_filename.as_ref().display()
319            )
320        }),
321    }
322}