cedar_policy_core/ast/
request.rs

/*
 * Copyright Cedar Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17use crate::entities::json::{
18    ContextJsonDeserializationError, ContextJsonParser, NullContextSchema,
19};
20use crate::evaluator::{EvaluationError, RestrictedEvaluator};
21use crate::extensions::Extensions;
22use crate::parser::Loc;
23use miette::Diagnostic;
24use serde::{Deserialize, Serialize};
25use smol_str::SmolStr;
26use std::collections::{BTreeMap, HashMap};
27use std::sync::Arc;
28use thiserror::Error;
29
30use super::{
31    BorrowedRestrictedExpr, EntityType, EntityUID, Expr, ExprKind, ExpressionConstructionError,
32    PartialValue, RestrictedExpr, Unknown, Value, ValueKind, Var,
33};
34
35/// Represents the request tuple <P, A, R, C> (see the Cedar design doc).
36#[derive(Debug, Clone, Serialize)]
37pub struct Request {
38    /// Principal associated with the request
39    pub(crate) principal: EntityUIDEntry,
40
41    /// Action associated with the request
42    pub(crate) action: EntityUIDEntry,
43
44    /// Resource associated with the request
45    pub(crate) resource: EntityUIDEntry,
46
47    /// Context associated with the request.
48    /// `None` means that variable will result in a residual for partial evaluation.
49    pub(crate) context: Option<Context>,
50}
51
52/// Represents the principal type, resource type, and action UID.
53#[derive(Debug, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
54#[serde(rename_all = "camelCase")]
55pub struct RequestType {
56    /// Principal type
57    pub principal: EntityType,
58    /// Action type
59    pub action: EntityUID,
60    /// Resource type
61    pub resource: EntityType,
62}
63
64/// An entry in a request for a Entity UID.
65/// It may either be a concrete EUID
66/// or an unknown in the case of partial evaluation
67#[derive(Debug, Clone, Serialize)]
68pub enum EntityUIDEntry {
69    /// A concrete EntityUID
70    Known {
71        /// The concrete `EntityUID`
72        euid: Arc<EntityUID>,
73        /// Source location associated with the `EntityUIDEntry`, if any
74        loc: Option<Loc>,
75    },
76    /// An EntityUID left as unknown for partial evaluation
77    Unknown {
78        /// Source location associated with the `EntityUIDEntry`, if any
79        loc: Option<Loc>,
80    },
81}
82
83impl EntityUIDEntry {
84    /// Evaluate the entry to either:
85    /// A value, if the entry is concrete
86    /// An unknown corresponding to the passed `var`
87    pub fn evaluate(&self, var: Var) -> PartialValue {
88        match self {
89            EntityUIDEntry::Known { euid, loc } => {
90                Value::new(Arc::unwrap_or_clone(Arc::clone(euid)), loc.clone()).into()
91            }
92            EntityUIDEntry::Unknown { loc } => Expr::unknown(Unknown::new_untyped(var.to_string()))
93                .with_maybe_source_loc(loc.clone())
94                .into(),
95        }
96    }
97
98    /// Create an entry with a concrete EntityUID and the given source location
99    pub fn known(euid: EntityUID, loc: Option<Loc>) -> Self {
100        Self::Known {
101            euid: Arc::new(euid),
102            loc,
103        }
104    }
105
106    /// Get the UID of the entry, or `None` if it is unknown (partial evaluation)
107    pub fn uid(&self) -> Option<&EntityUID> {
108        match self {
109            Self::Known { euid, .. } => Some(euid),
110            Self::Unknown { .. } => None,
111        }
112    }
113}
114
115impl Request {
116    /// Default constructor.
117    ///
118    /// If `schema` is provided, this constructor validates that this `Request`
119    /// complies with the given `schema`.
120    pub fn new<S: RequestSchema>(
121        principal: (EntityUID, Option<Loc>),
122        action: (EntityUID, Option<Loc>),
123        resource: (EntityUID, Option<Loc>),
124        context: Context,
125        schema: Option<&S>,
126        extensions: &Extensions<'_>,
127    ) -> Result<Self, S::Error> {
128        let req = Self {
129            principal: EntityUIDEntry::known(principal.0, principal.1),
130            action: EntityUIDEntry::known(action.0, action.1),
131            resource: EntityUIDEntry::known(resource.0, resource.1),
132            context: Some(context),
133        };
134        if let Some(schema) = schema {
135            schema.validate_request(&req, extensions)?;
136        }
137        Ok(req)
138    }
139
140    /// Create a new `Request` with potentially unknown (for partial eval) variables.
141    ///
142    /// If `schema` is provided, this constructor validates that this `Request`
143    /// complies with the given `schema` (at least to the extent that we can
144    /// validate with the given information)
145    pub fn new_with_unknowns<S: RequestSchema>(
146        principal: EntityUIDEntry,
147        action: EntityUIDEntry,
148        resource: EntityUIDEntry,
149        context: Option<Context>,
150        schema: Option<&S>,
151        extensions: &Extensions<'_>,
152    ) -> Result<Self, S::Error> {
153        let req = Self {
154            principal,
155            action,
156            resource,
157            context,
158        };
159        if let Some(schema) = schema {
160            schema.validate_request(&req, extensions)?;
161        }
162        Ok(req)
163    }
164
165    /// Create a new `Request` with potentially unknown (for partial eval) variables/context
166    /// and without schema validation.
167    pub fn new_unchecked(
168        principal: EntityUIDEntry,
169        action: EntityUIDEntry,
170        resource: EntityUIDEntry,
171        context: Option<Context>,
172    ) -> Self {
173        Self {
174            principal,
175            action,
176            resource,
177            context,
178        }
179    }
180
181    /// Get the principal associated with the request
182    pub fn principal(&self) -> &EntityUIDEntry {
183        &self.principal
184    }
185
186    /// Get the action associated with the request
187    pub fn action(&self) -> &EntityUIDEntry {
188        &self.action
189    }
190
191    /// Get the resource associated with the request
192    pub fn resource(&self) -> &EntityUIDEntry {
193        &self.resource
194    }
195
196    /// Get the context associated with the request
197    /// Returning `None` means the variable is unknown, and will result in a residual expression
198    pub fn context(&self) -> Option<&Context> {
199        self.context.as_ref()
200    }
201
202    /// Get the request types that correspond to this request.
203    /// This includes the types of the principal, action, and resource.
204    /// [`RequestType`] is used by the entity manifest.
205    /// The context type is implied by the action's type.
206    /// Returns `None` if the request is not fully concrete.
207    pub fn to_request_type(&self) -> Option<RequestType> {
208        Some(RequestType {
209            principal: self.principal().uid()?.entity_type().clone(),
210            action: self.action().uid()?.clone(),
211            resource: self.resource().uid()?.entity_type().clone(),
212        })
213    }
214}
215
216impl std::fmt::Display for Request {
217    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218        let display_euid = |maybe_euid: &EntityUIDEntry| match maybe_euid {
219            EntityUIDEntry::Known { euid, .. } => format!("{euid}"),
220            EntityUIDEntry::Unknown { .. } => "unknown".to_string(),
221        };
222        write!(
223            f,
224            "request with principal {}, action {}, resource {}, and context {}",
225            display_euid(&self.principal),
226            display_euid(&self.action),
227            display_euid(&self.resource),
228            match &self.context {
229                Some(x) => format!("{x}"),
230                None => "unknown".to_string(),
231            }
232        )
233    }
234}
235
236/// `Context` field of a `Request`
237#[derive(Debug, Clone, PartialEq, Serialize)]
238// Serialization is used for differential testing, which requires that `Context`
239// is serialized as a `RestrictedExpr`.
240#[serde(into = "RestrictedExpr")]
241pub enum Context {
242    /// The context is a concrete value.
243    Value(Arc<BTreeMap<SmolStr, Value>>),
244    /// The context is a residual expression, containing some unknown value in
245    /// the record attributes.
246    /// INVARIANT(restricted): Each `Expr` in this map must be a `RestrictedExpr`.
247    /// INVARIANT(unknown): At least one `Expr` must contain an `unknown`.
248    RestrictedResidual(Arc<BTreeMap<SmolStr, Expr>>),
249}
250
251impl Context {
252    /// Create an empty `Context`
253    pub fn empty() -> Self {
254        Self::Value(Arc::new(BTreeMap::new()))
255    }
256
257    /// Create a `Context` from a `PartialValue` without checking that the
258    /// residual is a restricted expression.  This function does check that the
259    /// value or residual is a record and returns `Err` when it is not.
260    ///
261    /// INVARIANT: if `value` is a residual, then it must be a valid restricted expression.
262    fn from_restricted_partial_val_unchecked(
263        value: PartialValue,
264    ) -> Result<Self, ContextCreationError> {
265        match value {
266            PartialValue::Value(v) => {
267                if let ValueKind::Record(attrs) = v.value {
268                    Ok(Context::Value(attrs))
269                } else {
270                    Err(ContextCreationError::not_a_record(v.into()))
271                }
272            }
273            PartialValue::Residual(e) => {
274                if let ExprKind::Record(attrs) = e.expr_kind() {
275                    // From the invariant on `PartialValue::Residual`, there is
276                    // an unknown in `e`. It is a record, so there must be an
277                    // unknown in one of the attributes expressions, satisfying
278                    // INVARIANT(unknown). From the invariant on this function,
279                    // `e` is a valid restricted expression, satisfying
280                    // INVARIANT(restricted).
281                    Ok(Context::RestrictedResidual(attrs.clone()))
282                } else {
283                    Err(ContextCreationError::not_a_record(e))
284                }
285            }
286        }
287    }
288
289    /// Create a `Context` from a `RestrictedExpr`, which must be a `Record`.
290    ///
291    /// `extensions` provides the `Extensions` which should be active for
292    /// evaluating the `RestrictedExpr`.
293    pub fn from_expr(
294        expr: BorrowedRestrictedExpr<'_>,
295        extensions: &Extensions<'_>,
296    ) -> Result<Self, ContextCreationError> {
297        match expr.expr_kind() {
298            ExprKind::Record { .. } => {
299                let evaluator = RestrictedEvaluator::new(extensions);
300                let pval = evaluator.partial_interpret(expr)?;
301                // The invariant on `from_restricted_partial_val_unchecked`
302                // is satisfied because `expr` is a restricted expression,
303                // and must still be restricted after `partial_interpret`.
304                // The function call cannot return `Err` because `expr` is a
305                // record, and partially evaluating a record expression will
306                // yield a record expression or a record value.
307                // PANIC SAFETY: See above
308                #[allow(clippy::expect_used)]
309                Ok(Self::from_restricted_partial_val_unchecked(pval).expect(
310                    "`from_restricted_partial_val_unchecked` should succeed when called on a record.",
311                ))
312            }
313            _ => Err(ContextCreationError::not_a_record(expr.to_owned().into())),
314        }
315    }
316
317    /// Create a `Context` from a map of key to `RestrictedExpr`, or a Vec of
318    /// `(key, RestrictedExpr)` pairs, or any other iterator of `(key, RestrictedExpr)` pairs
319    ///
320    /// `extensions` provides the `Extensions` which should be active for
321    /// evaluating the `RestrictedExpr`.
322    pub fn from_pairs(
323        pairs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
324        extensions: &Extensions<'_>,
325    ) -> Result<Self, ContextCreationError> {
326        match RestrictedExpr::record(pairs) {
327            Ok(record) => Self::from_expr(record.as_borrowed(), extensions),
328            Err(ExpressionConstructionError::DuplicateKey(err)) => Err(
329                ExpressionConstructionError::DuplicateKey(err.with_context("in context")).into(),
330            ),
331        }
332    }
333
334    /// Create a `Context` from a string containing JSON (which must be a JSON
335    /// object, not any other JSON type, or you will get an error here).
336    /// JSON here must use the `__entity` and `__extn` escapes for entity
337    /// references, extension values, etc.
338    ///
339    /// For schema-based parsing, use `ContextJsonParser`.
340    pub fn from_json_str(json: &str) -> Result<Self, ContextJsonDeserializationError> {
341        ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
342            .from_json_str(json)
343    }
344
345    /// Create a `Context` from a `serde_json::Value` (which must be a JSON
346    /// object, not any other JSON type, or you will get an error here).
347    /// JSON here must use the `__entity` and `__extn` escapes for entity
348    /// references, extension values, etc.
349    ///
350    /// For schema-based parsing, use `ContextJsonParser`.
351    pub fn from_json_value(
352        json: serde_json::Value,
353    ) -> Result<Self, ContextJsonDeserializationError> {
354        ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
355            .from_json_value(json)
356    }
357
358    /// Create a `Context` from a JSON file.  The JSON file must contain a JSON
359    /// object, not any other JSON type, or you will get an error here.
360    /// JSON here must use the `__entity` and `__extn` escapes for entity
361    /// references, extension values, etc.
362    ///
363    /// For schema-based parsing, use `ContextJsonParser`.
364    pub fn from_json_file(
365        json: impl std::io::Read,
366    ) -> Result<Self, ContextJsonDeserializationError> {
367        ContextJsonParser::new(None::<&NullContextSchema>, Extensions::all_available())
368            .from_json_file(json)
369    }
370
371    /// Private helper function to implement `into_iter()` for `Context`.
372    /// Gets an iterator over the (key, value) pairs in the `Context`, cloning
373    /// only if necessary.
374    fn into_values(self) -> Box<dyn Iterator<Item = (SmolStr, RestrictedExpr)>> {
375        match self {
376            Context::Value(record) => Box::new(
377                Arc::unwrap_or_clone(record)
378                    .into_iter()
379                    .map(|(k, v)| (k, RestrictedExpr::from(v))),
380            ),
381            Context::RestrictedResidual(record) => Box::new(
382                Arc::unwrap_or_clone(record)
383                    .into_iter()
384                    // By INVARIANT(restricted), all attributes expressions are
385                    // restricted expressions.
386                    .map(|(k, v)| (k, RestrictedExpr::new_unchecked(v))),
387            ),
388        }
389    }
390
391    /// Substitute unknowns with concrete values in this context. If this is
392    /// already a `Context::Value`, then this returns `self` unchanged and will
393    /// not error. Otherwise delegate to [`Expr::substitute`].
394    pub fn substitute(self, mapping: &HashMap<SmolStr, Value>) -> Result<Self, EvaluationError> {
395        match self {
396            Context::RestrictedResidual(residual_context) => {
397                // From Invariant(Restricted), `residual_context` contains only
398                // restricted expressions, so `Expr::record_arc` of the attributes
399                // will also be a restricted expression. This doesn't change after
400                // substitution, so we know `expr` must be a restricted expression.
401                let expr = Expr::record_arc(residual_context).substitute(mapping);
402                let expr = BorrowedRestrictedExpr::new_unchecked(&expr);
403
404                let extns = Extensions::all_available();
405                let eval = RestrictedEvaluator::new(extns);
406                let partial_value = eval.partial_interpret(expr)?;
407
408                // The invariant on `from_restricted_partial_val_unchecked`
409                // is satisfied because `expr` is restricted and must still be
410                // restricted after `partial_interpret`.
411                // The function call cannot fail because because `expr` was
412                // constructed as a record, and substitution and partial
413                // evaluation does not change this.
414                // PANIC SAFETY: See above
415                #[allow(clippy::expect_used)]
416                Ok(
417                    Self::from_restricted_partial_val_unchecked(partial_value).expect(
418                        "`from_restricted_partial_val_unchecked` should succeed when called on a record.",
419                    ),
420                )
421            }
422            Context::Value(_) => Ok(self),
423        }
424    }
425}
426
427/// Utilities for implementing `IntoIterator` for `Context`
428mod iter {
429    use super::*;
430
431    /// `IntoIter` iterator for `Context`
432    pub struct IntoIter(pub(super) Box<dyn Iterator<Item = (SmolStr, RestrictedExpr)>>);
433
434    impl std::fmt::Debug for IntoIter {
435        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
436            write!(f, "IntoIter(<context>)")
437        }
438    }
439
440    impl Iterator for IntoIter {
441        type Item = (SmolStr, RestrictedExpr);
442
443        fn next(&mut self) -> Option<Self::Item> {
444            self.0.next()
445        }
446    }
447}
448
449impl IntoIterator for Context {
450    type Item = (SmolStr, RestrictedExpr);
451
452    type IntoIter = iter::IntoIter;
453
454    fn into_iter(self) -> Self::IntoIter {
455        iter::IntoIter(self.into_values())
456    }
457}
458
459impl From<Context> for RestrictedExpr {
460    fn from(value: Context) -> Self {
461        match value {
462            Context::Value(attrs) => Value::record_arc(attrs, None).into(),
463            Context::RestrictedResidual(attrs) => {
464                // By INVARIANT(restricted), all attributes expressions are
465                // restricted expressions, so the result of `record_arc` will be
466                // a restricted expression.
467                RestrictedExpr::new_unchecked(Expr::record_arc(attrs))
468            }
469        }
470    }
471}
472
473impl From<Context> for PartialValue {
474    fn from(ctx: Context) -> PartialValue {
475        match ctx {
476            Context::Value(attrs) => Value::record_arc(attrs, None).into(),
477            Context::RestrictedResidual(attrs) => {
478                // A `PartialValue::Residual` must contain an unknown in the
479                // expression. By INVARIANT(unknown), at least one expr in
480                // `attrs` contains an unknown, so the `record_arc` expression
481                // contains at least one unknown.
482                PartialValue::Residual(Expr::record_arc(attrs))
483            }
484        }
485    }
486}
487
488impl std::default::Default for Context {
489    fn default() -> Context {
490        Context::empty()
491    }
492}
493
494impl std::fmt::Display for Context {
495    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
496        write!(f, "{}", PartialValue::from(self.clone()))
497    }
498}
499
500/// Errors while trying to create a `Context`
501#[derive(Debug, Diagnostic, Error)]
502pub enum ContextCreationError {
503    /// Tried to create a `Context` out of something other than a record
504    #[error(transparent)]
505    #[diagnostic(transparent)]
506    NotARecord(#[from] context_creation_errors::NotARecord),
507    /// Error evaluating the expression given for the `Context`
508    #[error(transparent)]
509    #[diagnostic(transparent)]
510    Evaluation(#[from] EvaluationError),
511    /// Error constructing a record for the `Context`.
512    /// Only returned by `Context::from_pairs()` and `Context::merge()`
513    #[error(transparent)]
514    #[diagnostic(transparent)]
515    ExpressionConstruction(#[from] ExpressionConstructionError),
516}
517
518impl ContextCreationError {
519    pub(crate) fn not_a_record(expr: Expr) -> Self {
520        Self::NotARecord(context_creation_errors::NotARecord {
521            expr: Box::new(expr),
522        })
523    }
524}
525
526/// Error subtypes for [`ContextCreationError`]
527pub mod context_creation_errors {
528    use super::Expr;
529    use crate::impl_diagnostic_from_expr_field;
530    use miette::Diagnostic;
531    use thiserror::Error;
532
533    /// Error type for an expression that needed to be a record, but is not
534    //
535    // CAUTION: this type is publicly exported in `cedar-policy`.
536    // Don't make fields `pub`, don't make breaking changes, and use caution
537    // when adding public methods.
538    #[derive(Debug, Error)]
539    #[error("expression is not a record: {expr}")]
540    pub struct NotARecord {
541        /// Expression which is not a record
542        pub(super) expr: Box<Expr>,
543    }
544
545    // custom impl of `Diagnostic`: take source location from the `expr` field
546    impl Diagnostic for NotARecord {
547        impl_diagnostic_from_expr_field!(expr);
548    }
549}
550
551/// Trait for schemas capable of validating `Request`s
552pub trait RequestSchema {
553    /// Error type returned when a request fails validation
554    type Error: miette::Diagnostic;
555    /// Validate the given `request`, returning `Err` if it fails validation
556    fn validate_request(
557        &self,
558        request: &Request,
559        extensions: &Extensions<'_>,
560    ) -> Result<(), Self::Error>;
561}
562
/// A `RequestSchema` that performs no validation: every request passes.
#[derive(Debug, Clone)]
pub struct RequestSchemaAllPass;
566impl RequestSchema for RequestSchemaAllPass {
567    type Error = Infallible;
568    fn validate_request(
569        &self,
570        _request: &Request,
571        _extensions: &Extensions<'_>,
572    ) -> Result<(), Self::Error> {
573        Ok(())
574    }
575}
576
577/// Wrapper around `std::convert::Infallible` which also implements
578/// `miette::Diagnostic`
579#[derive(Debug, Diagnostic, Error)]
580#[error(transparent)]
581pub struct Infallible(pub std::convert::Infallible);
582
#[cfg(test)]
mod test {
    use super::*;
    use cool_asserts::assert_matches;

    /// A non-record input must be rejected with `NotARecord`, both when it is
    /// given as an expression and when it is given as JSON.
    #[test]
    fn test_json_from_str_non_record() {
        let non_record = RestrictedExpr::val("1");
        let from_expr = Context::from_expr(non_record.as_borrowed(), Extensions::none());
        assert_matches!(from_expr, Err(ContextCreationError::NotARecord { .. }));

        let from_json = Context::from_json_str("1");
        assert_matches!(
            from_json,
            Err(ContextJsonDeserializationError::ContextCreation(
                ContextCreationError::NotARecord { .. }
            ))
        );
    }
}