cedar_policy_core/ast/
entity.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::ast::*;
18use crate::evaluator::{EvaluationError, RestrictedEvaluator};
19use crate::extensions::Extensions;
20use crate::parser::err::ParseErrors;
21use crate::parser::Loc;
22use crate::transitive_closure::TCNode;
23use crate::FromNormalizedStr;
24use itertools::Itertools;
25use miette::Diagnostic;
26use serde::{Deserialize, Serialize};
27use serde_with::{serde_as, TryFromInto};
28use smol_str::SmolStr;
29use std::collections::{BTreeMap, HashMap, HashSet};
30use thiserror::Error;
31
/// The type of an entity.
///
/// We support two types of entities. The first is a nominal type (e.g., `User`, `Action`)
/// and the second is an unspecified type, which is used (internally) to represent cases
/// where the input request does not provide a principal, action, and/or resource.
///
/// The derived `PartialOrd`/`Ord` follow variant declaration order, so every
/// `Specified` type sorts before `Unspecified`.
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum EntityType {
    /// Concrete nominal type, identified by its `Name`
    Specified(Name),
    /// Unspecified type; displayed as `<Unspecified>`
    Unspecified,
}
43
44impl EntityType {
45    /// Is this an Action entity type
46    pub fn is_action(&self) -> bool {
47        match self {
48            Self::Specified(name) => name.basename() == &Id::new_unchecked("Action"),
49            Self::Unspecified => false,
50        }
51    }
52}
53
54// Note: the characters '<' and '>' are not allowed in `Name`s, so the display for
55// `Unspecified` never conflicts with `Specified(name)`.
56impl std::fmt::Display for EntityType {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        match self {
59            Self::Unspecified => write!(f, "<Unspecified>"),
60            Self::Specified(name) => write!(f, "{}", name),
61        }
62    }
63}
64
/// Unique ID for an entity. These represent entities in the AST.
///
/// An `EntityUID` is the pair of an entity type and an EID. The optional
/// source location is carried along for diagnostics only: it is skipped
/// during (de)serialization and is not part of the identity of the UID.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct EntityUID {
    /// Typename of the entity
    ty: EntityType,
    /// EID of the entity
    eid: Eid,
    /// Location of the entity in policy source, if known.
    /// Never serialized (`#[serde(skip)]`).
    #[serde(skip)]
    loc: Option<Loc>,
}
76
77/// `PartialEq` implementation ignores the `loc`.
78impl PartialEq for EntityUID {
79    fn eq(&self, other: &Self) -> bool {
80        self.ty == other.ty && self.eid == other.eid
81    }
82}
83impl Eq for EntityUID {}
84
85impl std::hash::Hash for EntityUID {
86    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
87        // hash the ty and eid, in line with the `PartialEq` impl which compares
88        // the ty and eid.
89        self.ty.hash(state);
90        self.eid.hash(state);
91    }
92}
93
94impl PartialOrd for EntityUID {
95    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
96        Some(self.cmp(other))
97    }
98}
99impl Ord for EntityUID {
100    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
101        self.ty.cmp(&other.ty).then(self.eid.cmp(&other.eid))
102    }
103}
104
105impl StaticallyTyped for EntityUID {
106    fn type_of(&self) -> Type {
107        Type::Entity {
108            ty: self.ty.clone(),
109        }
110    }
111}
112
113impl EntityUID {
114    /// Create an `EntityUID` with the given string as its EID.
115    /// Useful for testing.
116    #[cfg(test)]
117    pub(crate) fn with_eid(eid: &str) -> Self {
118        Self {
119            ty: Self::test_entity_type(),
120            eid: Eid(eid.into()),
121            loc: None,
122        }
123    }
124    // by default, Coverlay does not track coverage for lines after a line
125    // containing #[cfg(test)].
126    // we use the following sentinel to "turn back on" coverage tracking for
127    // remaining lines of this file, until the next #[cfg(test)]
128    // GRCOV_BEGIN_COVERAGE
129
130    /// The type of entities created with the above `with_eid()`.
131    #[cfg(test)]
132    pub(crate) fn test_entity_type() -> EntityType {
133        let name = Name::parse_unqualified_name("test_entity_type")
134            .expect("test_entity_type should be a valid identifier");
135        EntityType::Specified(name)
136    }
137    // by default, Coverlay does not track coverage for lines after a line
138    // containing #[cfg(test)].
139    // we use the following sentinel to "turn back on" coverage tracking for
140    // remaining lines of this file, until the next #[cfg(test)]
141    // GRCOV_BEGIN_COVERAGE
142
143    /// Create an `EntityUID` with the given (unqualified) typename, and the given string as its EID.
144    pub fn with_eid_and_type(typename: &str, eid: &str) -> Result<Self, ParseErrors> {
145        Ok(Self {
146            ty: EntityType::Specified(Name::parse_unqualified_name(typename)?),
147            eid: Eid(eid.into()),
148            loc: None,
149        })
150    }
151
152    /// Split into the `EntityType` representing the entity type, and the `Eid`
153    /// representing its name
154    pub fn components(self) -> (EntityType, Eid) {
155        (self.ty, self.eid)
156    }
157
158    /// Get the source location for this `EntityUID`.
159    pub fn loc(&self) -> Option<&Loc> {
160        self.loc.as_ref()
161    }
162
163    /// Create a nominally-typed `EntityUID` with the given typename and EID
164    pub fn from_components(name: Name, eid: Eid, loc: Option<Loc>) -> Self {
165        Self {
166            ty: EntityType::Specified(name),
167            eid,
168            loc,
169        }
170    }
171
172    /// Create an unspecified `EntityUID` with the given EID
173    pub fn unspecified_from_eid(eid: Eid) -> Self {
174        Self {
175            ty: EntityType::Unspecified,
176            eid,
177            loc: None,
178        }
179    }
180
181    /// Get the type component.
182    pub fn entity_type(&self) -> &EntityType {
183        &self.ty
184    }
185
186    /// Get the Eid component.
187    pub fn eid(&self) -> &Eid {
188        &self.eid
189    }
190
191    /// Does this EntityUID refer to an action entity?
192    pub fn is_action(&self) -> bool {
193        self.entity_type().is_action()
194    }
195}
196
197impl std::fmt::Display for EntityUID {
198    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199        write!(f, "{}::\"{}\"", self.entity_type(), self.eid)
200    }
201}
202
203// allow `.parse()` on a string to make an `EntityUID`
204impl std::str::FromStr for EntityUID {
205    type Err = ParseErrors;
206
207    fn from_str(s: &str) -> Result<Self, Self::Err> {
208        crate::parser::parse_euid(s)
209    }
210}
211
impl FromNormalizedStr for EntityUID {
    /// Human-readable name for this kind of item.
    // NOTE(review): presumably used in `FromNormalizedStr` error messages —
    // confirm against the trait definition.
    fn describe_self() -> &'static str {
        "Entity UID"
    }
}
217
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for EntityUID {
    /// Generates an arbitrary type and then an arbitrary EID; generated UIDs
    /// never carry a source location.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let ty: EntityType = u.arbitrary()?;
        let eid: Eid = u.arbitrary()?;
        Ok(Self { ty, eid, loc: None })
    }
}
228
/// EID type is just a SmolStr for now.
///
/// This is a newtype wrapper: equality, ordering, and hashing are all derived
/// from the wrapped string.
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
pub struct Eid(SmolStr);
232
233impl Eid {
234    /// Construct an Eid
235    pub fn new(eid: impl Into<SmolStr>) -> Self {
236        Eid(eid.into())
237    }
238}
239
240impl AsRef<SmolStr> for Eid {
241    fn as_ref(&self) -> &SmolStr {
242        &self.0
243    }
244}
245
246impl AsRef<str> for Eid {
247    fn as_ref(&self) -> &str {
248        &self.0
249    }
250}
251
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Eid {
    /// Generates an arbitrary `String` and wraps it as an `Eid`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let raw = u.arbitrary::<String>()?;
        Ok(Self(raw.into()))
    }
}
259
260impl std::fmt::Display for Eid {
261    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262        write!(f, "{}", self.0.escape_debug())
263    }
264}
265
/// Entity datatype: a UID together with the entity's attributes and the set
/// of its ancestors in the entity hierarchy.
///
/// Only `Serialize` is derived here; `Deserialize` is not derived for this
/// type.
#[derive(Debug, Clone, Serialize)]
pub struct Entity {
    /// UID
    uid: EntityUID,

    /// Internal `BTreeMap` of attributes.
    /// We use a `BTreeMap` so that the keys have a deterministic order.
    ///
    /// In the serialized form of `Entity`, attribute values appear as
    /// `RestrictedExpr`s, for mostly historical reasons.
    attrs: BTreeMap<SmolStr, PartialValueSerializedAsExpr>,

    /// Set of ancestors of this `Entity` (i.e., all direct and transitive
    /// parents), as UIDs
    ancestors: HashSet<EntityUID>,
}
283
284impl Entity {
285    /// Create a new `Entity` with this UID, attributes, and ancestors
286    pub fn new(
287        uid: EntityUID,
288        attrs: HashMap<SmolStr, RestrictedExpr>,
289        ancestors: HashSet<EntityUID>,
290        extensions: &Extensions<'_>,
291    ) -> Result<Self, EntityAttrEvaluationError> {
292        let evaluator = RestrictedEvaluator::new(extensions);
293        let evaluated_attrs = attrs
294            .into_iter()
295            .map(|(k, v)| {
296                let attr_val = evaluator
297                    .partial_interpret(v.as_borrowed())
298                    .map_err(|err| EntityAttrEvaluationError {
299                        uid: uid.clone(),
300                        attr: k.clone(),
301                        err,
302                    })?;
303                Ok((k, attr_val.into()))
304            })
305            .collect::<Result<_, EntityAttrEvaluationError>>()?;
306        Ok(Entity {
307            uid,
308            attrs: evaluated_attrs,
309            ancestors,
310        })
311    }
312
313    /// Create a new `Entity` with this UID, attributes, and ancestors.
314    ///
315    /// Unlike in `Entity::new()`, in this constructor, attributes are expressed
316    /// as `PartialValue`.
317    pub fn new_with_attr_partial_value(
318        uid: EntityUID,
319        attrs: HashMap<SmolStr, PartialValue>,
320        ancestors: HashSet<EntityUID>,
321    ) -> Self {
322        Entity {
323            uid,
324            attrs: attrs.into_iter().map(|(k, v)| (k, v.into())).collect(), // TODO(#540): can we do this without disassembling and reassembling the HashMap
325            ancestors,
326        }
327    }
328
329    /// Create a new `Entity` with this UID, attributes, and ancestors.
330    ///
331    /// Unlike in `Entity::new()`, in this constructor, attributes are expressed
332    /// as `PartialValueSerializedAsExpr`.
333    pub fn new_with_attr_partial_value_serialized_as_expr(
334        uid: EntityUID,
335        attrs: BTreeMap<SmolStr, PartialValueSerializedAsExpr>,
336        ancestors: HashSet<EntityUID>,
337    ) -> Self {
338        Entity {
339            uid,
340            attrs,
341            ancestors,
342        }
343    }
344
345    /// Get the UID of this entity
346    pub fn uid(&self) -> &EntityUID {
347        &self.uid
348    }
349
350    /// Get the value for the given attribute, or `None` if not present
351    pub fn get(&self, attr: &str) -> Option<&PartialValue> {
352        self.attrs.get(attr).map(|v| v.as_ref())
353    }
354
355    /// Is this `Entity` a descendant of `e` in the entity hierarchy?
356    pub fn is_descendant_of(&self, e: &EntityUID) -> bool {
357        self.ancestors.contains(e)
358    }
359
360    /// Iterate over this entity's ancestors
361    pub fn ancestors(&self) -> impl Iterator<Item = &EntityUID> {
362        self.ancestors.iter()
363    }
364
365    /// Get the number of attributes on this entity
366    pub fn attrs_len(&self) -> usize {
367        self.attrs.len()
368    }
369
370    /// Iterate over this entity's attribute names
371    pub fn keys(&self) -> impl Iterator<Item = &SmolStr> {
372        self.attrs.keys()
373    }
374
375    /// Iterate over this entity's attributes
376    pub fn attrs(&self) -> impl Iterator<Item = (&SmolStr, &PartialValue)> {
377        self.attrs.iter().map(|(k, v)| (k, v.as_ref()))
378    }
379
380    /// Create an `Entity` with the given UID, no attributes, and no parents.
381    pub fn with_uid(uid: EntityUID) -> Self {
382        Self {
383            uid,
384            attrs: BTreeMap::new(),
385            ancestors: HashSet::new(),
386        }
387    }
388
389    /// Test if two `Entity` objects are deep/structurally equal.
390    /// That is, not only do they have the same UID, but also the same
391    /// attributes, attribute values, and ancestors.
392    pub(crate) fn deep_eq(&self, other: &Self) -> bool {
393        self.uid == other.uid && self.attrs == other.attrs && self.ancestors == other.ancestors
394    }
395
396    /// Set the given attribute to the given value.
397    // Only used for convenience in some tests and when fuzzing
398    #[cfg(any(test, fuzzing))]
399    pub fn set_attr(
400        &mut self,
401        attr: SmolStr,
402        val: RestrictedExpr,
403        extensions: &Extensions<'_>,
404    ) -> Result<(), EvaluationError> {
405        let val = RestrictedEvaluator::new(extensions).partial_interpret(val.as_borrowed())?;
406        self.attrs.insert(attr, val.into());
407        Ok(())
408    }
409
410    /// Mark the given `UID` as an ancestor of this `Entity`.
411    // When fuzzing, `add_ancestor()` is fully `pub`.
412    #[cfg(not(fuzzing))]
413    pub(crate) fn add_ancestor(&mut self, uid: EntityUID) {
414        self.ancestors.insert(uid);
415    }
416    /// Mark the given `UID` as an ancestor of this `Entity`
417    #[cfg(fuzzing)]
418    pub fn add_ancestor(&mut self, uid: EntityUID) {
419        self.ancestors.insert(uid);
420    }
421
422    /// Consume the entity and return the entity's owned Uid, attributes and parents.
423    pub fn into_inner(
424        self,
425    ) -> (
426        EntityUID,
427        HashMap<SmolStr, PartialValue>,
428        HashSet<EntityUID>,
429    ) {
430        let Self {
431            uid,
432            attrs,
433            ancestors,
434        } = self;
435        (
436            uid,
437            attrs.into_iter().map(|(k, v)| (k, v.0)).collect(),
438            ancestors,
439        )
440    }
441}
442
443impl PartialEq for Entity {
444    fn eq(&self, other: &Self) -> bool {
445        self.uid() == other.uid()
446    }
447}
448
449impl Eq for Entity {}
450
451impl StaticallyTyped for Entity {
452    fn type_of(&self) -> Type {
453        self.uid.type_of()
454    }
455}
456
457impl TCNode<EntityUID> for Entity {
458    fn get_key(&self) -> EntityUID {
459        self.uid().clone()
460    }
461
462    fn add_edge_to(&mut self, k: EntityUID) {
463        self.add_ancestor(k)
464    }
465
466    fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
467        Box::new(self.ancestors())
468    }
469
470    fn has_edge_to(&self, e: &EntityUID) -> bool {
471        self.is_descendant_of(e)
472    }
473}
474
475impl std::fmt::Display for Entity {
476    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
477        write!(
478            f,
479            "{}:\n  attrs:{}\n  ancestors:{}",
480            self.uid,
481            self.attrs
482                .iter()
483                .map(|(k, v)| format!("{}: {}", k, v))
484                .join("; "),
485            self.ancestors.iter().join(", ")
486        )
487    }
488}
489
/// `PartialValue`, but serialized as a `RestrictedExpr`.
///
/// (Extension values can't be directly serialized, but can be serialized as
/// `RestrictedExpr`)
///
/// This is a thin newtype: apart from serialization it behaves like the
/// wrapped `PartialValue` (see the `AsRef`, `Deref`, `From`, and `Display`
/// impls).
#[serde_as]
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct PartialValueSerializedAsExpr(
    // Serialization goes through a conversion to `RestrictedExpr`.
    #[serde_as(as = "TryFromInto<RestrictedExpr>")] PartialValue,
);
499
500impl AsRef<PartialValue> for PartialValueSerializedAsExpr {
501    fn as_ref(&self) -> &PartialValue {
502        &self.0
503    }
504}
505
506impl std::ops::Deref for PartialValueSerializedAsExpr {
507    type Target = PartialValue;
508    fn deref(&self) -> &Self::Target {
509        &self.0
510    }
511}
512
513impl From<PartialValue> for PartialValueSerializedAsExpr {
514    fn from(value: PartialValue) -> PartialValueSerializedAsExpr {
515        PartialValueSerializedAsExpr(value)
516    }
517}
518
519impl From<PartialValueSerializedAsExpr> for PartialValue {
520    fn from(value: PartialValueSerializedAsExpr) -> PartialValue {
521        value.0
522    }
523}
524
525impl std::fmt::Display for PartialValueSerializedAsExpr {
526    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527        write!(f, "{}", self.0)
528    }
529}
530
/// Error type for evaluation errors when evaluating an entity attribute.
/// Contains some extra contextual information and the underlying
/// `EvaluationError`.
///
/// The rendered message has the form
/// ``failed to evaluate attribute `<attr>` of `<uid>`: <err>``.
#[derive(Debug, Diagnostic, Error)]
#[error("failed to evaluate attribute `{attr}` of `{uid}`: {err}")]
pub struct EntityAttrEvaluationError {
    /// UID of the entity where the error was encountered
    pub uid: EntityUID,
    /// Attribute of the entity where the error was encountered
    pub attr: SmolStr,
    /// Underlying evaluation error; its diagnostic information is forwarded
    /// as this error's own (`#[diagnostic(transparent)]`).
    #[diagnostic(transparent)]
    pub err: EvaluationError,
}
545
#[cfg(test)]
mod test {
    use super::*;

    /// An `EntityUID` displays as `<type>::"<eid>"`.
    #[test]
    fn display() {
        let euid = EntityUID::with_eid("eid");
        assert_eq!(euid.to_string(), "test_entity_type::\"eid\"");
    }

    /// Equality of `EntityUID`s depends on both the type and the EID, and an
    /// unspecified type is distinct from a type that is merely *named*
    /// `Unspecified`.
    #[test]
    fn test_euid_equality() {
        // `typename::"foo"` built through `from_components`
        let foo_of_type = |typename: &str| {
            EntityUID::from_components(
                Name::parse_unqualified_name(typename).expect("should be a valid identifier"),
                Eid::new("foo"),
                None,
            )
        };

        let test_foo = EntityUID::with_eid("foo");
        let test_foo_from_components = foo_of_type("test_entity_type");
        let unspecified_foo = EntityUID::unspecified_from_eid(Eid::new("foo"));
        let unspecified_bar = EntityUID::unspecified_from_eid(Eid::new("bar"));
        let named_unspecified_foo = foo_of_type("Unspecified");

        // an EUID is equal to itself
        assert_eq!(test_foo, test_foo);
        assert_eq!(test_foo_from_components, test_foo_from_components);
        assert_eq!(unspecified_foo, unspecified_foo);

        // constructing with `with_eid` or `from_components` is the same
        assert_eq!(test_foo, test_foo_from_components);

        // other pairs are not equal
        assert_ne!(test_foo, unspecified_foo);
        assert_ne!(test_foo, unspecified_bar);
        assert_ne!(test_foo, named_unspecified_foo);
        assert_ne!(unspecified_foo, unspecified_bar);
        assert_ne!(unspecified_foo, named_unspecified_foo);
        assert_ne!(unspecified_bar, named_unspecified_foo);

        // the unspecified type and the type named `Unspecified` are displayed differently
        assert_ne!(
            unspecified_foo.to_string(),
            named_unspecified_foo.to_string()
        );
    }
}