cedar_policy_core/ast/
entity.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::ast::*;
18use crate::entities::{EntitiesError, EntityJson, JsonSerializationError};
19use crate::evaluator::{EvaluationError, RestrictedEvaluator};
20use crate::extensions::Extensions;
21use crate::parser::err::ParseErrors;
22use crate::parser::Loc;
23use crate::transitive_closure::TCNode;
24use crate::FromNormalizedStr;
25use itertools::Itertools;
26use miette::Diagnostic;
27use serde::{Deserialize, Serialize};
28use serde_with::{serde_as, TryFromInto};
29use smol_str::SmolStr;
30use std::collections::{BTreeMap, HashMap, HashSet};
31use thiserror::Error;
32
/// We support two types of entities. The first is a nominal type (e.g., User, Action)
/// and the second is an unspecified type, which is used (internally) to represent cases
/// where the input request does not provide a principal, action, and/or resource.
// Note: the derived `PartialOrd`/`Ord` follow variant declaration order, so every
// `Specified(_)` sorts before `Unspecified`. Reordering the variants changes ordering.
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum EntityType {
    /// Concrete nominal type, identified by its (possibly namespaced) `Name`
    Specified(Name),
    /// Unspecified
    Unspecified,
}
44
45impl EntityType {
46    /// Is this an Action entity type
47    pub fn is_action(&self) -> bool {
48        match self {
49            Self::Specified(name) => name.basename() == &Id::new_unchecked("Action"),
50            Self::Unspecified => false,
51        }
52    }
53}
54
55// Note: the characters '<' and '>' are not allowed in `Name`s, so the display for
56// `Unspecified` never conflicts with `Specified(name)`.
57impl std::fmt::Display for EntityType {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        match self {
60            Self::Unspecified => write!(f, "<Unspecified>"),
61            Self::Specified(name) => write!(f, "{}", name),
62        }
63    }
64}
65
/// Unique ID for an entity. These represent entities in the AST.
///
/// Equality, hashing, and ordering for this type are implemented by hand and
/// consider only `ty` and `eid`; `loc` never participates.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct EntityUID {
    /// Typename of the entity
    ty: EntityType,
    /// EID of the entity
    eid: Eid,
    /// Location of the entity in policy source
    // Skipped by serde: never serialized, and when deserializing it is filled
    // with `Option::default()`, i.e. `None`.
    #[serde(skip)]
    loc: Option<Loc>,
}
77
78/// `PartialEq` implementation ignores the `loc`.
79impl PartialEq for EntityUID {
80    fn eq(&self, other: &Self) -> bool {
81        self.ty == other.ty && self.eid == other.eid
82    }
83}
84impl Eq for EntityUID {}
85
86impl std::hash::Hash for EntityUID {
87    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
88        // hash the ty and eid, in line with the `PartialEq` impl which compares
89        // the ty and eid.
90        self.ty.hash(state);
91        self.eid.hash(state);
92    }
93}
94
95impl PartialOrd for EntityUID {
96    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
97        Some(self.cmp(other))
98    }
99}
100impl Ord for EntityUID {
101    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
102        self.ty.cmp(&other.ty).then(self.eid.cmp(&other.eid))
103    }
104}
105
106impl StaticallyTyped for EntityUID {
107    fn type_of(&self) -> Type {
108        Type::Entity {
109            ty: self.ty.clone(),
110        }
111    }
112}
113
114impl EntityUID {
115    /// Create an `EntityUID` with the given string as its EID.
116    /// Useful for testing.
117    #[cfg(test)]
118    pub(crate) fn with_eid(eid: &str) -> Self {
119        Self {
120            ty: Self::test_entity_type(),
121            eid: Eid(eid.into()),
122            loc: None,
123        }
124    }
125    // by default, Coverlay does not track coverage for lines after a line
126    // containing #[cfg(test)].
127    // we use the following sentinel to "turn back on" coverage tracking for
128    // remaining lines of this file, until the next #[cfg(test)]
129    // GRCOV_BEGIN_COVERAGE
130
131    /// The type of entities created with the above `with_eid()`.
132    #[cfg(test)]
133    pub(crate) fn test_entity_type() -> EntityType {
134        let name = Name::parse_unqualified_name("test_entity_type")
135            .expect("test_entity_type should be a valid identifier");
136        EntityType::Specified(name)
137    }
138    // by default, Coverlay does not track coverage for lines after a line
139    // containing #[cfg(test)].
140    // we use the following sentinel to "turn back on" coverage tracking for
141    // remaining lines of this file, until the next #[cfg(test)]
142    // GRCOV_BEGIN_COVERAGE
143
144    /// Create an `EntityUID` with the given (unqualified) typename, and the given string as its EID.
145    pub fn with_eid_and_type(typename: &str, eid: &str) -> Result<Self, ParseErrors> {
146        Ok(Self {
147            ty: EntityType::Specified(Name::parse_unqualified_name(typename)?),
148            eid: Eid(eid.into()),
149            loc: None,
150        })
151    }
152
153    /// Split into the `EntityType` representing the entity type, and the `Eid`
154    /// representing its name
155    pub fn components(self) -> (EntityType, Eid) {
156        (self.ty, self.eid)
157    }
158
159    /// Get the source location for this `EntityUID`.
160    pub fn loc(&self) -> Option<&Loc> {
161        self.loc.as_ref()
162    }
163
164    /// Create a nominally-typed `EntityUID` with the given typename and EID
165    pub fn from_components(name: Name, eid: Eid, loc: Option<Loc>) -> Self {
166        Self {
167            ty: EntityType::Specified(name),
168            eid,
169            loc,
170        }
171    }
172
173    /// Create an unspecified `EntityUID` with the given EID
174    pub fn unspecified_from_eid(eid: Eid) -> Self {
175        Self {
176            ty: EntityType::Unspecified,
177            eid,
178            loc: None,
179        }
180    }
181
182    /// Get the type component.
183    pub fn entity_type(&self) -> &EntityType {
184        &self.ty
185    }
186
187    /// Get the Eid component.
188    pub fn eid(&self) -> &Eid {
189        &self.eid
190    }
191
192    /// Does this EntityUID refer to an action entity?
193    pub fn is_action(&self) -> bool {
194        self.entity_type().is_action()
195    }
196}
197
198impl std::fmt::Display for EntityUID {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        write!(f, "{}::\"{}\"", self.entity_type(), self.eid)
201    }
202}
203
204// allow `.parse()` on a string to make an `EntityUID`
205impl std::str::FromStr for EntityUID {
206    type Err = ParseErrors;
207
208    fn from_str(s: &str) -> Result<Self, Self::Err> {
209        crate::parser::parse_euid(s)
210    }
211}
212
213impl FromNormalizedStr for EntityUID {
214    fn describe_self() -> &'static str {
215        "Entity UID"
216    }
217}
218
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for EntityUID {
    /// Draws an arbitrary type and then an arbitrary EID (in that order);
    /// the source location is always `None`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let ty: EntityType = u.arbitrary()?;
        let eid: Eid = u.arbitrary()?;
        Ok(Self { ty, eid, loc: None })
    }
}
229
/// EID type is just a SmolStr for now
///
/// The string is stored exactly as given; `Display` and `Eid::escaped` apply
/// `str::escape_debug` escaping when rendering it.
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
pub struct Eid(SmolStr);
233
234impl Eid {
235    /// Construct an Eid
236    pub fn new(eid: impl Into<SmolStr>) -> Self {
237        Eid(eid.into())
238    }
239
240    /// Get the contents of the `Eid` as an escaped string
241    pub fn escaped(&self) -> SmolStr {
242        self.0.escape_debug().collect()
243    }
244}
245
246impl AsRef<SmolStr> for Eid {
247    fn as_ref(&self) -> &SmolStr {
248        &self.0
249    }
250}
251
252impl AsRef<str> for Eid {
253    fn as_ref(&self) -> &str {
254        &self.0
255    }
256}
257
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Eid {
    /// Draws an arbitrary `String` and wraps it, unmodified, as an `Eid`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        u.arbitrary::<String>().map(|raw| Self(SmolStr::from(raw)))
    }
}
265
266impl std::fmt::Display for Eid {
267    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
268        write!(f, "{}", self.0.escape_debug())
269    }
270}
271
/// Entity datatype
///
/// Note: the hand-written `PartialEq` and `Hash` impls for `Entity` consider
/// only `uid`; use `deep_eq` to also compare attributes and ancestors.
#[derive(Debug, Clone, Serialize)]
pub struct Entity {
    /// UID
    uid: EntityUID,

    /// Internal BTreeMap of attributes.
    /// We use a BTreeMap so that the keys have a deterministic order.
    ///
    /// In the serialized form of `Entity`, attribute values appear as
    /// `RestrictedExpr`s, for mostly historical reasons.
    attrs: BTreeMap<SmolStr, PartialValueSerializedAsExpr>,

    /// Set of ancestors of this `Entity` (i.e., all direct and transitive
    /// parents), as UIDs
    ancestors: HashSet<EntityUID>,
}
289
290impl std::hash::Hash for Entity {
291    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
292        self.uid.hash(state);
293    }
294}
295
296impl Entity {
297    /// Create a new `Entity` with this UID, attributes, and ancestors
298    ///
299    /// # Errors
300    /// - Will error if any of the [`RestrictedExpr]`s in `attrs` error when evaluated
301    pub fn new(
302        uid: EntityUID,
303        attrs: HashMap<SmolStr, RestrictedExpr>,
304        ancestors: HashSet<EntityUID>,
305        extensions: &Extensions<'_>,
306    ) -> Result<Self, EntityAttrEvaluationError> {
307        let evaluator = RestrictedEvaluator::new(extensions);
308        let evaluated_attrs = attrs
309            .into_iter()
310            .map(|(k, v)| {
311                let attr_val = evaluator
312                    .partial_interpret(v.as_borrowed())
313                    .map_err(|err| EntityAttrEvaluationError {
314                        uid: uid.clone(),
315                        attr: k.clone(),
316                        err,
317                    })?;
318                Ok((k, attr_val.into()))
319            })
320            .collect::<Result<_, EntityAttrEvaluationError>>()?;
321        Ok(Entity {
322            uid,
323            attrs: evaluated_attrs,
324            ancestors,
325        })
326    }
327
328    /// Create a new `Entity` with this UID, attributes, and ancestors.
329    ///
330    /// Unlike in `Entity::new()`, in this constructor, attributes are expressed
331    /// as `PartialValue`.
332    pub fn new_with_attr_partial_value(
333        uid: EntityUID,
334        attrs: HashMap<SmolStr, PartialValue>,
335        ancestors: HashSet<EntityUID>,
336    ) -> Self {
337        Entity {
338            uid,
339            attrs: attrs.into_iter().map(|(k, v)| (k, v.into())).collect(), // TODO(#540): can we do this without disassembling and reassembling the HashMap
340            ancestors,
341        }
342    }
343
344    /// Create a new `Entity` with this UID, attributes, and ancestors.
345    ///
346    /// Unlike in `Entity::new()`, in this constructor, attributes are expressed
347    /// as `PartialValueSerializedAsExpr`.
348    pub fn new_with_attr_partial_value_serialized_as_expr(
349        uid: EntityUID,
350        attrs: BTreeMap<SmolStr, PartialValueSerializedAsExpr>,
351        ancestors: HashSet<EntityUID>,
352    ) -> Self {
353        Entity {
354            uid,
355            attrs,
356            ancestors,
357        }
358    }
359
360    /// Get the UID of this entity
361    pub fn uid(&self) -> &EntityUID {
362        &self.uid
363    }
364
365    /// Get the value for the given attribute, or `None` if not present
366    pub fn get(&self, attr: &str) -> Option<&PartialValue> {
367        self.attrs.get(attr).map(|v| v.as_ref())
368    }
369
370    /// Is this `Entity` a descendant of `e` in the entity hierarchy?
371    pub fn is_descendant_of(&self, e: &EntityUID) -> bool {
372        self.ancestors.contains(e)
373    }
374
375    /// Iterate over this entity's ancestors
376    pub fn ancestors(&self) -> impl Iterator<Item = &EntityUID> {
377        self.ancestors.iter()
378    }
379
380    /// Get the number of attributes on this entity
381    pub fn attrs_len(&self) -> usize {
382        self.attrs.len()
383    }
384
385    /// Iterate over this entity's attribute names
386    pub fn keys(&self) -> impl Iterator<Item = &SmolStr> {
387        self.attrs.keys()
388    }
389
390    /// Iterate over this entity's attributes
391    pub fn attrs(&self) -> impl Iterator<Item = (&SmolStr, &PartialValue)> {
392        self.attrs.iter().map(|(k, v)| (k, v.as_ref()))
393    }
394
395    /// Create an `Entity` with the given UID, no attributes, and no parents.
396    pub fn with_uid(uid: EntityUID) -> Self {
397        Self {
398            uid,
399            attrs: BTreeMap::new(),
400            ancestors: HashSet::new(),
401        }
402    }
403
404    /// Test if two `Entity` objects are deep/structurally equal.
405    /// That is, not only do they have the same UID, but also the same
406    /// attributes, attribute values, and ancestors.
407    pub(crate) fn deep_eq(&self, other: &Self) -> bool {
408        self.uid == other.uid && self.attrs == other.attrs && self.ancestors == other.ancestors
409    }
410
411    /// Set the given attribute to the given value.
412    // Only used for convenience in some tests and when fuzzing
413    #[cfg(any(test, fuzzing))]
414    pub fn set_attr(
415        &mut self,
416        attr: SmolStr,
417        val: RestrictedExpr,
418        extensions: &Extensions<'_>,
419    ) -> Result<(), EvaluationError> {
420        let val = RestrictedEvaluator::new(extensions).partial_interpret(val.as_borrowed())?;
421        self.attrs.insert(attr, val.into());
422        Ok(())
423    }
424
425    /// Mark the given `UID` as an ancestor of this `Entity`.
426    // When fuzzing, `add_ancestor()` is fully `pub`.
427    #[cfg(not(fuzzing))]
428    pub(crate) fn add_ancestor(&mut self, uid: EntityUID) {
429        self.ancestors.insert(uid);
430    }
431    /// Mark the given `UID` as an ancestor of this `Entity`
432    #[cfg(fuzzing)]
433    pub fn add_ancestor(&mut self, uid: EntityUID) {
434        self.ancestors.insert(uid);
435    }
436
437    /// Consume the entity and return the entity's owned Uid, attributes and parents.
438    pub fn into_inner(
439        self,
440    ) -> (
441        EntityUID,
442        HashMap<SmolStr, PartialValue>,
443        HashSet<EntityUID>,
444    ) {
445        let Self {
446            uid,
447            attrs,
448            ancestors,
449        } = self;
450        (
451            uid,
452            attrs.into_iter().map(|(k, v)| (k, v.0)).collect(),
453            ancestors,
454        )
455    }
456
457    /// Write the entity to a json document
458    pub fn write_to_json(&self, f: impl std::io::Write) -> Result<(), EntitiesError> {
459        let ejson = EntityJson::from_entity(self)?;
460        serde_json::to_writer_pretty(f, &ejson).map_err(JsonSerializationError::from)?;
461        Ok(())
462    }
463
464    /// write the entity to a json value
465    pub fn to_json_value(&self) -> Result<serde_json::Value, EntitiesError> {
466        let ejson = EntityJson::from_entity(self)?;
467        let v = serde_json::to_value(ejson).map_err(JsonSerializationError::from)?;
468        Ok(v)
469    }
470
471    /// write the entity to a json string
472    pub fn to_json_string(&self) -> Result<String, EntitiesError> {
473        let ejson = EntityJson::from_entity(self)?;
474        let string = serde_json::to_string(&ejson).map_err(JsonSerializationError::from)?;
475        Ok(string)
476    }
477}
478
479impl PartialEq for Entity {
480    fn eq(&self, other: &Self) -> bool {
481        self.uid() == other.uid()
482    }
483}
484
485impl Eq for Entity {}
486
487impl StaticallyTyped for Entity {
488    fn type_of(&self) -> Type {
489        self.uid.type_of()
490    }
491}
492
493impl TCNode<EntityUID> for Entity {
494    fn get_key(&self) -> EntityUID {
495        self.uid().clone()
496    }
497
498    fn add_edge_to(&mut self, k: EntityUID) {
499        self.add_ancestor(k)
500    }
501
502    fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
503        Box::new(self.ancestors())
504    }
505
506    fn has_edge_to(&self, e: &EntityUID) -> bool {
507        self.is_descendant_of(e)
508    }
509}
510
511impl std::fmt::Display for Entity {
512    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
513        write!(
514            f,
515            "{}:\n  attrs:{}\n  ancestors:{}",
516            self.uid,
517            self.attrs
518                .iter()
519                .map(|(k, v)| format!("{}: {}", k, v))
520                .join("; "),
521            self.ancestors.iter().join(", ")
522        )
523    }
524}
525
/// `PartialValue`, but serialized as a `RestrictedExpr`.
///
/// (Extension values can't be directly serialized, but can be serialized as
/// `RestrictedExpr`)
// Serialization goes through `serde_with::TryFromInto<RestrictedExpr>`: the
// wrapped `PartialValue` is converted into a `RestrictedExpr` and that is what
// gets serialized. Only `Serialize` is derived for this type.
#[serde_as]
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct PartialValueSerializedAsExpr(
    #[serde_as(as = "TryFromInto<RestrictedExpr>")] PartialValue,
);
535
536impl AsRef<PartialValue> for PartialValueSerializedAsExpr {
537    fn as_ref(&self) -> &PartialValue {
538        &self.0
539    }
540}
541
542impl std::ops::Deref for PartialValueSerializedAsExpr {
543    type Target = PartialValue;
544    fn deref(&self) -> &Self::Target {
545        &self.0
546    }
547}
548
549impl From<PartialValue> for PartialValueSerializedAsExpr {
550    fn from(value: PartialValue) -> PartialValueSerializedAsExpr {
551        PartialValueSerializedAsExpr(value)
552    }
553}
554
555impl From<PartialValueSerializedAsExpr> for PartialValue {
556    fn from(value: PartialValueSerializedAsExpr) -> PartialValue {
557        value.0
558    }
559}
560
561impl std::fmt::Display for PartialValueSerializedAsExpr {
562    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
563        write!(f, "{}", self.0)
564    }
565}
566
/// Error type for evaluation errors when evaluating an entity attribute.
/// Contains some extra contextual information and the underlying
/// `EvaluationError`.
// Constructed by `Entity::new` when a `RestrictedExpr` attribute fails to
// evaluate. The `#[error]` message embeds the `Display` of `uid` and `err`.
// NOTE(review): `#[diagnostic(transparent)]` on `err` is presumably meant to
// surface the underlying error's miette diagnostic data — confirm against
// miette's semantics for field-level `transparent`.
#[derive(Debug, Diagnostic, Error)]
#[error("failed to evaluate attribute `{attr}` of `{uid}`: {err}")]
pub struct EntityAttrEvaluationError {
    /// UID of the entity where the error was encountered
    pub uid: EntityUID,
    /// Attribute of the entity where the error was encountered
    pub attr: SmolStr,
    /// Underlying evaluation error
    #[diagnostic(transparent)]
    pub err: EvaluationError,
}
581
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::*;

    #[test]
    fn display() {
        let e = EntityUID::with_eid("eid");
        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
    }

    #[test]
    fn test_euid_equality() {
        let e1 = EntityUID::with_eid("foo");
        let e2 = EntityUID::from_components(
            Name::parse_unqualified_name("test_entity_type").expect("should be a valid identifier"),
            Eid("foo".into()),
            None,
        );
        let e3 = EntityUID::unspecified_from_eid(Eid("foo".into()));
        let e4 = EntityUID::unspecified_from_eid(Eid("bar".into()));
        let e5 = EntityUID::from_components(
            Name::parse_unqualified_name("Unspecified").expect("should be a valid identifier"),
            Eid("foo".into()),
            None,
        );

        // an EUID is equal to itself
        assert_eq!(e1, e1);
        assert_eq!(e2, e2);
        assert_eq!(e3, e3);

        // constructing with `with_eid` or `from_components` is the same
        assert_eq!(e1, e2);

        // other pairs are not equal
        assert!(e1 != e3);
        assert!(e1 != e4);
        assert!(e1 != e5);
        assert!(e3 != e4);
        assert!(e3 != e5);
        assert!(e4 != e5);

        // e3 and e5 are displayed differently
        assert!(format!("{e3}") != format!("{e5}"));
    }

    /// `Eid`'s `Display` and `escaped()` both apply `str::escape_debug`, so
    /// they must agree, and an escaped EID must appear inside `EntityUID`'s
    /// quoted rendering.
    #[test]
    fn eid_escaping() {
        let eid = Eid::new("a\"b\n");
        assert_eq!(eid.to_string(), "a\\\"b\\n");
        assert_eq!(eid.escaped().as_str(), "a\\\"b\\n");
        assert_eq!(
            format!("{}", EntityUID::with_eid("a\"b")),
            "test_entity_type::\"a\\\"b\""
        );
    }

    /// `Hash` and `Ord` must be consistent with `PartialEq`: UIDs built in
    /// different ways but with equal type and EID collapse in a set and
    /// compare as `Equal`.
    #[test]
    fn euid_hash_and_ord_agree_with_eq() {
        let e1 = EntityUID::with_eid("foo");
        let e2 = EntityUID::from_components(
            Name::parse_unqualified_name("test_entity_type").expect("should be a valid identifier"),
            Eid::new("foo"),
            None,
        );
        let set: std::collections::HashSet<EntityUID> =
            vec![e1.clone(), e2.clone()].into_iter().collect();
        assert_eq!(set.len(), 1);
        assert_eq!(e1.cmp(&e2), std::cmp::Ordering::Equal);
    }

    #[test]
    fn action_checker() {
        let euid = EntityUID::from_str("Action::\"view\"").unwrap();
        assert!(euid.is_action());
        let euid = EntityUID::from_str("Foo::Action::\"view\"").unwrap();
        assert!(euid.is_action());
        let euid = EntityUID::from_str("Foo::\"view\"").unwrap();
        assert!(!euid.is_action());
        let euid = EntityUID::from_str("Action::Foo::\"view\"").unwrap();
        assert!(!euid.is_action());
    }
}