cedar_policy_core/ast/
entity.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::ast::*;
18use crate::entities::{EntitiesError, EntityJson, JsonSerializationError};
19use crate::evaluator::{EvaluationError, RestrictedEvaluator};
20use crate::extensions::Extensions;
21use crate::parser::err::ParseErrors;
22use crate::parser::Loc;
23use crate::transitive_closure::TCNode;
24use crate::FromNormalizedStr;
25use itertools::Itertools;
26use miette::Diagnostic;
27use serde::{Deserialize, Serialize};
28use serde_with::{serde_as, TryFromInto};
29use smol_str::SmolStr;
30use std::collections::{BTreeMap, HashMap, HashSet};
31use thiserror::Error;
32
33/// We support two types of entities. The first is a nominal type (e.g., User, Action)
34/// and the second is an unspecified type, which is used (internally) to represent cases
35/// where the input request does not provide a principal, action, and/or resource.
36#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
37#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
38pub enum EntityType {
39    /// Concrete nominal type
40    Specified(Name),
41    /// Unspecified
42    Unspecified,
43}
44
45impl EntityType {
46    /// Is this an Action entity type
47    pub fn is_action(&self) -> bool {
48        match self {
49            Self::Specified(name) => name.basename() == &Id::new_unchecked("Action"),
50            Self::Unspecified => false,
51        }
52    }
53}
54
55// Note: the characters '<' and '>' are not allowed in `Name`s, so the display for
56// `Unspecified` never conflicts with `Specified(name)`.
57impl std::fmt::Display for EntityType {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        match self {
60            Self::Unspecified => write!(f, "<Unspecified>"),
61            Self::Specified(name) => write!(f, "{}", name),
62        }
63    }
64}
65
66/// Unique ID for an entity. These represent entities in the AST.
67#[derive(Serialize, Deserialize, Debug, Clone)]
68pub struct EntityUID {
69    /// Typename of the entity
70    ty: EntityType,
71    /// EID of the entity
72    eid: Eid,
73    /// Location of the entity in policy source
74    #[serde(skip)]
75    loc: Option<Loc>,
76}
77
78/// `PartialEq` implementation ignores the `loc`.
79impl PartialEq for EntityUID {
80    fn eq(&self, other: &Self) -> bool {
81        self.ty == other.ty && self.eid == other.eid
82    }
83}
84impl Eq for EntityUID {}
85
86impl std::hash::Hash for EntityUID {
87    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
88        // hash the ty and eid, in line with the `PartialEq` impl which compares
89        // the ty and eid.
90        self.ty.hash(state);
91        self.eid.hash(state);
92    }
93}
94
95impl PartialOrd for EntityUID {
96    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
97        Some(self.cmp(other))
98    }
99}
100impl Ord for EntityUID {
101    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
102        self.ty.cmp(&other.ty).then(self.eid.cmp(&other.eid))
103    }
104}
105
106impl StaticallyTyped for EntityUID {
107    fn type_of(&self) -> Type {
108        Type::Entity {
109            ty: self.ty.clone(),
110        }
111    }
112}
113
114impl EntityUID {
115    /// Create an `EntityUID` with the given string as its EID.
116    /// Useful for testing.
117    #[cfg(test)]
118    pub(crate) fn with_eid(eid: &str) -> Self {
119        Self {
120            ty: Self::test_entity_type(),
121            eid: Eid(eid.into()),
122            loc: None,
123        }
124    }
125    // by default, Coverlay does not track coverage for lines after a line
126    // containing #[cfg(test)].
127    // we use the following sentinel to "turn back on" coverage tracking for
128    // remaining lines of this file, until the next #[cfg(test)]
129    // GRCOV_BEGIN_COVERAGE
130
131    /// The type of entities created with the above `with_eid()`.
132    #[cfg(test)]
133    pub(crate) fn test_entity_type() -> EntityType {
134        let name = Name::parse_unqualified_name("test_entity_type")
135            .expect("test_entity_type should be a valid identifier");
136        EntityType::Specified(name)
137    }
138    // by default, Coverlay does not track coverage for lines after a line
139    // containing #[cfg(test)].
140    // we use the following sentinel to "turn back on" coverage tracking for
141    // remaining lines of this file, until the next #[cfg(test)]
142    // GRCOV_BEGIN_COVERAGE
143
144    /// Create an `EntityUID` with the given (unqualified) typename, and the given string as its EID.
145    pub fn with_eid_and_type(typename: &str, eid: &str) -> Result<Self, ParseErrors> {
146        Ok(Self {
147            ty: EntityType::Specified(Name::parse_unqualified_name(typename)?),
148            eid: Eid(eid.into()),
149            loc: None,
150        })
151    }
152
153    /// Split into the `EntityType` representing the entity type, and the `Eid`
154    /// representing its name
155    pub fn components(self) -> (EntityType, Eid) {
156        (self.ty, self.eid)
157    }
158
159    /// Get the source location for this `EntityUID`.
160    pub fn loc(&self) -> Option<&Loc> {
161        self.loc.as_ref()
162    }
163
164    /// Create a nominally-typed `EntityUID` with the given typename and EID
165    pub fn from_components(name: Name, eid: Eid, loc: Option<Loc>) -> Self {
166        Self {
167            ty: EntityType::Specified(name),
168            eid,
169            loc,
170        }
171    }
172
173    /// Create an unspecified `EntityUID` with the given EID
174    pub fn unspecified_from_eid(eid: Eid) -> Self {
175        Self {
176            ty: EntityType::Unspecified,
177            eid,
178            loc: None,
179        }
180    }
181
182    /// Get the type component.
183    pub fn entity_type(&self) -> &EntityType {
184        &self.ty
185    }
186
187    /// Get the Eid component.
188    pub fn eid(&self) -> &Eid {
189        &self.eid
190    }
191
192    /// Does this EntityUID refer to an action entity?
193    pub fn is_action(&self) -> bool {
194        self.entity_type().is_action()
195    }
196}
197
198impl std::fmt::Display for EntityUID {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        write!(f, "{}::\"{}\"", self.entity_type(), self.eid)
201    }
202}
203
204// allow `.parse()` on a string to make an `EntityUID`
205impl std::str::FromStr for EntityUID {
206    type Err = ParseErrors;
207
208    fn from_str(s: &str) -> Result<Self, Self::Err> {
209        crate::parser::parse_euid(s)
210    }
211}
212
213impl FromNormalizedStr for EntityUID {
214    fn describe_self() -> &'static str {
215        "Entity UID"
216    }
217}
218
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for EntityUID {
    /// Generates an arbitrary type and EID (in that order); the source
    /// location is always `None`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let ty: EntityType = u.arbitrary()?;
        let eid: Eid = u.arbitrary()?;
        Ok(Self { ty, eid, loc: None })
    }
}
229
230/// EID type is just a SmolStr for now
231#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
232pub struct Eid(SmolStr);
233
234impl Eid {
235    /// Construct an Eid
236    pub fn new(eid: impl Into<SmolStr>) -> Self {
237        Eid(eid.into())
238    }
239
240    /// Get the contents of the `Eid` as an escaped string
241    pub fn escaped(&self) -> SmolStr {
242        self.0.escape_debug().collect()
243    }
244}
245
246impl AsRef<SmolStr> for Eid {
247    fn as_ref(&self) -> &SmolStr {
248        &self.0
249    }
250}
251
252impl AsRef<str> for Eid {
253    fn as_ref(&self) -> &str {
254        &self.0
255    }
256}
257
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Eid {
    /// Generates an arbitrary `String` and wraps it as an `Eid`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        u.arbitrary::<String>().map(|s| Self(s.into()))
    }
}
265
266impl std::fmt::Display for Eid {
267    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
268        write!(f, "{}", self.0.escape_debug())
269    }
270}
271
272/// Entity datatype
273#[derive(Debug, Clone, Serialize)]
274pub struct Entity {
275    /// UID
276    uid: EntityUID,
277
278    /// Internal BTreMap of attributes.
279    /// We use a btreemap so that the keys have a determenistic order.
280    ///
281    /// In the serialized form of `Entity`, attribute values appear as
282    /// `RestrictedExpr`s, for mostly historical reasons.
283    attrs: BTreeMap<SmolStr, PartialValueSerializedAsExpr>,
284
285    /// Set of ancestors of this `Entity` (i.e., all direct and transitive
286    /// parents), as UIDs
287    ancestors: HashSet<EntityUID>,
288}
289
290impl Entity {
291    /// Create a new `Entity` with this UID, attributes, and ancestors
292    pub fn new(
293        uid: EntityUID,
294        attrs: HashMap<SmolStr, RestrictedExpr>,
295        ancestors: HashSet<EntityUID>,
296        extensions: &Extensions<'_>,
297    ) -> Result<Self, EntityAttrEvaluationError> {
298        let evaluator = RestrictedEvaluator::new(extensions);
299        let evaluated_attrs = attrs
300            .into_iter()
301            .map(|(k, v)| {
302                let attr_val = evaluator
303                    .partial_interpret(v.as_borrowed())
304                    .map_err(|err| EntityAttrEvaluationError {
305                        uid: uid.clone(),
306                        attr: k.clone(),
307                        err,
308                    })?;
309                Ok((k, attr_val.into()))
310            })
311            .collect::<Result<_, EntityAttrEvaluationError>>()?;
312        Ok(Entity {
313            uid,
314            attrs: evaluated_attrs,
315            ancestors,
316        })
317    }
318
319    /// Create a new `Entity` with this UID, attributes, and ancestors.
320    ///
321    /// Unlike in `Entity::new()`, in this constructor, attributes are expressed
322    /// as `PartialValue`.
323    pub fn new_with_attr_partial_value(
324        uid: EntityUID,
325        attrs: HashMap<SmolStr, PartialValue>,
326        ancestors: HashSet<EntityUID>,
327    ) -> Self {
328        Entity {
329            uid,
330            attrs: attrs.into_iter().map(|(k, v)| (k, v.into())).collect(), // TODO(#540): can we do this without disassembling and reassembling the HashMap
331            ancestors,
332        }
333    }
334
335    /// Create a new `Entity` with this UID, attributes, and ancestors.
336    ///
337    /// Unlike in `Entity::new()`, in this constructor, attributes are expressed
338    /// as `PartialValueSerializedAsExpr`.
339    pub fn new_with_attr_partial_value_serialized_as_expr(
340        uid: EntityUID,
341        attrs: BTreeMap<SmolStr, PartialValueSerializedAsExpr>,
342        ancestors: HashSet<EntityUID>,
343    ) -> Self {
344        Entity {
345            uid,
346            attrs,
347            ancestors,
348        }
349    }
350
351    /// Get the UID of this entity
352    pub fn uid(&self) -> &EntityUID {
353        &self.uid
354    }
355
356    /// Get the value for the given attribute, or `None` if not present
357    pub fn get(&self, attr: &str) -> Option<&PartialValue> {
358        self.attrs.get(attr).map(|v| v.as_ref())
359    }
360
361    /// Is this `Entity` a descendant of `e` in the entity hierarchy?
362    pub fn is_descendant_of(&self, e: &EntityUID) -> bool {
363        self.ancestors.contains(e)
364    }
365
366    /// Iterate over this entity's ancestors
367    pub fn ancestors(&self) -> impl Iterator<Item = &EntityUID> {
368        self.ancestors.iter()
369    }
370
371    /// Get the number of attributes on this entity
372    pub fn attrs_len(&self) -> usize {
373        self.attrs.len()
374    }
375
376    /// Iterate over this entity's attribute names
377    pub fn keys(&self) -> impl Iterator<Item = &SmolStr> {
378        self.attrs.keys()
379    }
380
381    /// Iterate over this entity's attributes
382    pub fn attrs(&self) -> impl Iterator<Item = (&SmolStr, &PartialValue)> {
383        self.attrs.iter().map(|(k, v)| (k, v.as_ref()))
384    }
385
386    /// Create an `Entity` with the given UID, no attributes, and no parents.
387    pub fn with_uid(uid: EntityUID) -> Self {
388        Self {
389            uid,
390            attrs: BTreeMap::new(),
391            ancestors: HashSet::new(),
392        }
393    }
394
395    /// Test if two `Entity` objects are deep/structurally equal.
396    /// That is, not only do they have the same UID, but also the same
397    /// attributes, attribute values, and ancestors.
398    pub(crate) fn deep_eq(&self, other: &Self) -> bool {
399        self.uid == other.uid && self.attrs == other.attrs && self.ancestors == other.ancestors
400    }
401
402    /// Set the given attribute to the given value.
403    // Only used for convenience in some tests and when fuzzing
404    #[cfg(any(test, fuzzing))]
405    pub fn set_attr(
406        &mut self,
407        attr: SmolStr,
408        val: RestrictedExpr,
409        extensions: &Extensions<'_>,
410    ) -> Result<(), EvaluationError> {
411        let val = RestrictedEvaluator::new(extensions).partial_interpret(val.as_borrowed())?;
412        self.attrs.insert(attr, val.into());
413        Ok(())
414    }
415
416    /// Mark the given `UID` as an ancestor of this `Entity`.
417    // When fuzzing, `add_ancestor()` is fully `pub`.
418    #[cfg(not(fuzzing))]
419    pub(crate) fn add_ancestor(&mut self, uid: EntityUID) {
420        self.ancestors.insert(uid);
421    }
422    /// Mark the given `UID` as an ancestor of this `Entity`
423    #[cfg(fuzzing)]
424    pub fn add_ancestor(&mut self, uid: EntityUID) {
425        self.ancestors.insert(uid);
426    }
427
428    /// Consume the entity and return the entity's owned Uid, attributes and parents.
429    pub fn into_inner(
430        self,
431    ) -> (
432        EntityUID,
433        HashMap<SmolStr, PartialValue>,
434        HashSet<EntityUID>,
435    ) {
436        let Self {
437            uid,
438            attrs,
439            ancestors,
440        } = self;
441        (
442            uid,
443            attrs.into_iter().map(|(k, v)| (k, v.0)).collect(),
444            ancestors,
445        )
446    }
447
448    /// Write the entity to a json document
449    pub fn write_to_json(&self, f: impl std::io::Write) -> Result<(), EntitiesError> {
450        let ejson = EntityJson::from_entity(self)?;
451        serde_json::to_writer_pretty(f, &ejson).map_err(JsonSerializationError::from)?;
452        Ok(())
453    }
454
455    /// write the entity to a json value
456    pub fn to_json_value(&self) -> Result<serde_json::Value, EntitiesError> {
457        let ejson = EntityJson::from_entity(self)?;
458        let v = serde_json::to_value(ejson).map_err(JsonSerializationError::from)?;
459        Ok(v)
460    }
461
462    /// write the entity to a json string
463    pub fn to_json_string(&self) -> Result<String, EntitiesError> {
464        let ejson = EntityJson::from_entity(self)?;
465        let string = serde_json::to_string(&ejson).map_err(JsonSerializationError::from)?;
466        Ok(string)
467    }
468}
469
470impl PartialEq for Entity {
471    fn eq(&self, other: &Self) -> bool {
472        self.uid() == other.uid()
473    }
474}
475
476impl Eq for Entity {}
477
478impl StaticallyTyped for Entity {
479    fn type_of(&self) -> Type {
480        self.uid.type_of()
481    }
482}
483
484impl TCNode<EntityUID> for Entity {
485    fn get_key(&self) -> EntityUID {
486        self.uid().clone()
487    }
488
489    fn add_edge_to(&mut self, k: EntityUID) {
490        self.add_ancestor(k)
491    }
492
493    fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
494        Box::new(self.ancestors())
495    }
496
497    fn has_edge_to(&self, e: &EntityUID) -> bool {
498        self.is_descendant_of(e)
499    }
500}
501
502impl std::fmt::Display for Entity {
503    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504        write!(
505            f,
506            "{}:\n  attrs:{}\n  ancestors:{}",
507            self.uid,
508            self.attrs
509                .iter()
510                .map(|(k, v)| format!("{}: {}", k, v))
511                .join("; "),
512            self.ancestors.iter().join(", ")
513        )
514    }
515}
516
517/// `PartialValue`, but serialized as a `RestrictedExpr`.
518///
519/// (Extension values can't be directly serialized, but can be serialized as
520/// `RestrictedExpr`)
521#[serde_as]
522#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
523pub struct PartialValueSerializedAsExpr(
524    #[serde_as(as = "TryFromInto<RestrictedExpr>")] PartialValue,
525);
526
527impl AsRef<PartialValue> for PartialValueSerializedAsExpr {
528    fn as_ref(&self) -> &PartialValue {
529        &self.0
530    }
531}
532
533impl std::ops::Deref for PartialValueSerializedAsExpr {
534    type Target = PartialValue;
535    fn deref(&self) -> &Self::Target {
536        &self.0
537    }
538}
539
540impl From<PartialValue> for PartialValueSerializedAsExpr {
541    fn from(value: PartialValue) -> PartialValueSerializedAsExpr {
542        PartialValueSerializedAsExpr(value)
543    }
544}
545
546impl From<PartialValueSerializedAsExpr> for PartialValue {
547    fn from(value: PartialValueSerializedAsExpr) -> PartialValue {
548        value.0
549    }
550}
551
552impl std::fmt::Display for PartialValueSerializedAsExpr {
553    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
554        write!(f, "{}", self.0)
555    }
556}
557
558/// Error type for evaluation errors when evaluating an entity attribute.
559/// Contains some extra contextual information and the underlying
560/// `EvaluationError`.
561#[derive(Debug, Diagnostic, Error)]
562#[error("failed to evaluate attribute `{attr}` of `{uid}`: {err}")]
563pub struct EntityAttrEvaluationError {
564    /// UID of the entity where the error was encountered
565    pub uid: EntityUID,
566    /// Attribute of the entity where the error was encountered
567    pub attr: SmolStr,
568    /// Underlying evaluation error
569    #[diagnostic(transparent)]
570    pub err: EvaluationError,
571}
572
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::*;

    /// An `EntityUID` displays as `type::"eid"`.
    #[test]
    fn display() {
        let euid = EntityUID::with_eid("eid");
        assert_eq!(euid.to_string(), "test_entity_type::\"eid\"");
    }

    /// Equality of `EntityUID`s depends on both the type and the EID.
    #[test]
    fn test_euid_equality() {
        let test_type_name =
            Name::parse_unqualified_name("test_entity_type").expect("should be a valid identifier");
        let unspecified_name =
            Name::parse_unqualified_name("Unspecified").expect("should be a valid identifier");

        let e1 = EntityUID::with_eid("foo");
        let e2 = EntityUID::from_components(test_type_name, Eid("foo".into()), None);
        let e3 = EntityUID::unspecified_from_eid(Eid("foo".into()));
        let e4 = EntityUID::unspecified_from_eid(Eid("bar".into()));
        let e5 = EntityUID::from_components(unspecified_name, Eid("foo".into()), None);

        // every EUID equals itself
        assert_eq!(e1, e1);
        assert_eq!(e2, e2);
        assert_eq!(e3, e3);

        // `with_eid` and `from_components` produce equal EUIDs
        assert_eq!(e1, e2);

        // all remaining pairs differ
        assert!(e1 != e3);
        assert!(e1 != e4);
        assert!(e1 != e5);
        assert!(e3 != e4);
        assert!(e3 != e5);
        assert!(e4 != e5);

        // the unspecified type and a type named `Unspecified` display differently
        assert!(e3.to_string() != e5.to_string());
    }

    /// Only EUIDs whose type's basename is `Action` are actions.
    #[test]
    fn action_checker() {
        let cases = [
            ("Action::\"view\"", true),
            ("Foo::Action::\"view\"", true),
            ("Foo::\"view\"", false),
            ("Action::Foo::\"view\"", false),
        ];
        for (src, expected) in cases {
            let euid = EntityUID::from_str(src).unwrap();
            assert_eq!(euid.is_action(), expected, "{src}");
        }
    }
}