cedar_policy_core/ast/
entity.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::ast::*;
18use crate::entities::{err::EntitiesError, json::err::JsonSerializationError, EntityJson};
19use crate::evaluator::{EvaluationError, RestrictedEvaluator};
20use crate::extensions::Extensions;
21use crate::parser::err::ParseErrors;
22use crate::parser::Loc;
23use crate::transitive_closure::TCNode;
24use crate::FromNormalizedStr;
25use itertools::Itertools;
26use miette::Diagnostic;
27use ref_cast::RefCast;
28use serde::{Deserialize, Serialize};
29use serde_with::{serde_as, TryFromInto};
30use smol_str::SmolStr;
31use std::collections::{BTreeMap, HashMap, HashSet};
32use std::str::FromStr;
33use thiserror::Error;
34
35/// The entity type that Actions must have
36pub static ACTION_ENTITY_TYPE: &str = "Action";
37
38/// Entity type names are just [`Name`]s, but we have some operations on them specific to entity types.
39#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord, RefCast)]
40#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
41#[serde(transparent)]
42#[repr(transparent)]
43pub struct EntityType(Name);
44
45impl EntityType {
46    /// Is this an Action entity type?
47    /// Returns true when an entity type is an action entity type. This compares the
48    /// base name for the type, so this will return true for any entity type named
49    /// `Action` regardless of namespaces.
50    pub fn is_action(&self) -> bool {
51        self.0.as_ref().basename() == &Id::new_unchecked(ACTION_ENTITY_TYPE)
52    }
53
54    /// The name of this entity type
55    pub fn name(&self) -> &Name {
56        &self.0
57    }
58
59    /// The source location of this entity type
60    pub fn loc(&self) -> Option<&Loc> {
61        self.0.as_ref().loc()
62    }
63
64    /// Calls [`Name::qualify_with_name`] on the underlying [`Name`]
65    pub fn qualify_with(&self, namespace: Option<&Name>) -> Self {
66        Self(self.0.qualify_with_name(namespace))
67    }
68
69    /// Wraps [`Name::from_normalized_str`]
70    pub fn from_normalized_str(src: &str) -> Result<Self, ParseErrors> {
71        Name::from_normalized_str(src).map(Into::into)
72    }
73}
74
75impl From<Name> for EntityType {
76    fn from(n: Name) -> Self {
77        Self(n)
78    }
79}
80
81impl From<EntityType> for Name {
82    fn from(ty: EntityType) -> Name {
83        ty.0
84    }
85}
86
87impl AsRef<Name> for EntityType {
88    fn as_ref(&self) -> &Name {
89        &self.0
90    }
91}
92
93impl FromStr for EntityType {
94    type Err = ParseErrors;
95
96    fn from_str(s: &str) -> Result<Self, Self::Err> {
97        s.parse().map(Self)
98    }
99}
100
101impl std::fmt::Display for EntityType {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        write!(f, "{}", self.0)
104    }
105}
106
107/// Unique ID for an entity. These represent entities in the AST.
108#[derive(Serialize, Deserialize, Debug, Clone)]
109pub struct EntityUID {
110    /// Typename of the entity
111    ty: EntityType,
112    /// EID of the entity
113    eid: Eid,
114    /// Location of the entity in policy source
115    #[serde(skip)]
116    loc: Option<Loc>,
117}
118
119/// `PartialEq` implementation ignores the `loc`.
120impl PartialEq for EntityUID {
121    fn eq(&self, other: &Self) -> bool {
122        self.ty == other.ty && self.eid == other.eid
123    }
124}
125impl Eq for EntityUID {}
126
127impl std::hash::Hash for EntityUID {
128    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
129        // hash the ty and eid, in line with the `PartialEq` impl which compares
130        // the ty and eid.
131        self.ty.hash(state);
132        self.eid.hash(state);
133    }
134}
135
136impl PartialOrd for EntityUID {
137    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
138        Some(self.cmp(other))
139    }
140}
141impl Ord for EntityUID {
142    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
143        self.ty.cmp(&other.ty).then(self.eid.cmp(&other.eid))
144    }
145}
146
147impl StaticallyTyped for EntityUID {
148    fn type_of(&self) -> Type {
149        Type::Entity {
150            ty: self.ty.clone(),
151        }
152    }
153}
154
155impl EntityUID {
156    /// Create an `EntityUID` with the given string as its EID.
157    /// Useful for testing.
158    #[cfg(test)]
159    pub(crate) fn with_eid(eid: &str) -> Self {
160        Self {
161            ty: Self::test_entity_type(),
162            eid: Eid(eid.into()),
163            loc: None,
164        }
165    }
166    // by default, Coverlay does not track coverage for lines after a line
167    // containing #[cfg(test)].
168    // we use the following sentinel to "turn back on" coverage tracking for
169    // remaining lines of this file, until the next #[cfg(test)]
170    // GRCOV_BEGIN_COVERAGE
171
172    /// The type of entities created with the above `with_eid()`.
173    #[cfg(test)]
174    pub(crate) fn test_entity_type() -> EntityType {
175        let name = Name::parse_unqualified_name("test_entity_type")
176            .expect("test_entity_type should be a valid identifier");
177        EntityType(name)
178    }
179    // by default, Coverlay does not track coverage for lines after a line
180    // containing #[cfg(test)].
181    // we use the following sentinel to "turn back on" coverage tracking for
182    // remaining lines of this file, until the next #[cfg(test)]
183    // GRCOV_BEGIN_COVERAGE
184
185    /// Create an `EntityUID` with the given (unqualified) typename, and the given string as its EID.
186    pub fn with_eid_and_type(typename: &str, eid: &str) -> Result<Self, ParseErrors> {
187        Ok(Self {
188            ty: EntityType(Name::parse_unqualified_name(typename)?),
189            eid: Eid(eid.into()),
190            loc: None,
191        })
192    }
193
194    /// Split into the `EntityType` representing the entity type, and the `Eid`
195    /// representing its name
196    pub fn components(self) -> (EntityType, Eid) {
197        (self.ty, self.eid)
198    }
199
200    /// Get the source location for this `EntityUID`.
201    pub fn loc(&self) -> Option<&Loc> {
202        self.loc.as_ref()
203    }
204
205    /// Create an [`EntityUID`] with the given typename and [`Eid`]
206    pub fn from_components(ty: EntityType, eid: Eid, loc: Option<Loc>) -> Self {
207        Self { ty, eid, loc }
208    }
209
210    /// Get the type component.
211    pub fn entity_type(&self) -> &EntityType {
212        &self.ty
213    }
214
215    /// Get the Eid component.
216    pub fn eid(&self) -> &Eid {
217        &self.eid
218    }
219
220    /// Does this EntityUID refer to an action entity?
221    pub fn is_action(&self) -> bool {
222        self.entity_type().is_action()
223    }
224}
225
226impl std::fmt::Display for EntityUID {
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        write!(f, "{}::\"{}\"", self.entity_type(), self.eid.escaped())
229    }
230}
231
232// allow `.parse()` on a string to make an `EntityUID`
233impl std::str::FromStr for EntityUID {
234    type Err = ParseErrors;
235
236    fn from_str(s: &str) -> Result<Self, Self::Err> {
237        crate::parser::parse_euid(s)
238    }
239}
240
241impl FromNormalizedStr for EntityUID {
242    fn describe_self() -> &'static str {
243        "Entity UID"
244    }
245}
246
247#[cfg(feature = "arbitrary")]
248impl<'a> arbitrary::Arbitrary<'a> for EntityUID {
249    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
250        Ok(Self {
251            ty: u.arbitrary()?,
252            eid: u.arbitrary()?,
253            loc: None,
254        })
255    }
256}
257
258/// The `Eid` type represents the id of an `Entity`, without the typename.
259/// Together with the typename it comprises an `EntityUID`.
260/// For example, in `User::"alice"`, the `Eid` is `alice`.
261///
262/// `Eid` does not implement `Display`, partly because it is unclear whether
263/// `Display` should produce an escaped representation or an unescaped representation
264/// (see [#884](https://github.com/cedar-policy/cedar/issues/884)).
265/// To get an escaped representation, use `.escaped()`.
266/// To get an unescaped representation, use `.as_ref()`.
267#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
268pub struct Eid(SmolStr);
269
270impl Eid {
271    /// Construct an Eid
272    pub fn new(eid: impl Into<SmolStr>) -> Self {
273        Eid(eid.into())
274    }
275
276    /// Get the contents of the `Eid` as an escaped string
277    pub fn escaped(&self) -> SmolStr {
278        self.0.escape_debug().collect()
279    }
280}
281
282impl AsRef<SmolStr> for Eid {
283    fn as_ref(&self) -> &SmolStr {
284        &self.0
285    }
286}
287
288impl AsRef<str> for Eid {
289    fn as_ref(&self) -> &str {
290        &self.0
291    }
292}
293
294#[cfg(feature = "arbitrary")]
295impl<'a> arbitrary::Arbitrary<'a> for Eid {
296    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
297        let x: String = u.arbitrary()?;
298        Ok(Self(x.into()))
299    }
300}
301
302/// Entity datatype
303#[derive(Debug, Clone, Serialize)]
304pub struct Entity {
305    /// UID
306    uid: EntityUID,
307
308    /// Internal BTreMap of attributes.
309    /// We use a btreemap so that the keys have a determenistic order.
310    ///
311    /// In the serialized form of `Entity`, attribute values appear as
312    /// `RestrictedExpr`s, for mostly historical reasons.
313    attrs: BTreeMap<SmolStr, PartialValueSerializedAsExpr>,
314
315    /// Set of ancestors of this `Entity` (i.e., all direct and transitive
316    /// parents), as UIDs
317    ancestors: HashSet<EntityUID>,
318}
319
320impl std::hash::Hash for Entity {
321    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
322        self.uid.hash(state);
323    }
324}
325
326impl Entity {
327    /// Create a new `Entity` with this UID, attributes, and ancestors
328    ///
329    /// # Errors
330    /// - Will error if any of the [`RestrictedExpr]`s in `attrs` error when evaluated
331    pub fn new(
332        uid: EntityUID,
333        attrs: HashMap<SmolStr, RestrictedExpr>,
334        ancestors: HashSet<EntityUID>,
335        extensions: &Extensions<'_>,
336    ) -> Result<Self, EntityAttrEvaluationError> {
337        let evaluator = RestrictedEvaluator::new(extensions);
338        let evaluated_attrs = attrs
339            .into_iter()
340            .map(|(k, v)| {
341                let attr_val = evaluator
342                    .partial_interpret(v.as_borrowed())
343                    .map_err(|err| EntityAttrEvaluationError {
344                        uid: uid.clone(),
345                        attr: k.clone(),
346                        err,
347                    })?;
348                Ok((k, attr_val.into()))
349            })
350            .collect::<Result<_, EntityAttrEvaluationError>>()?;
351        Ok(Entity {
352            uid,
353            attrs: evaluated_attrs,
354            ancestors,
355        })
356    }
357
358    /// Create a new `Entity` with this UID, attributes, and ancestors.
359    ///
360    /// Unlike in `Entity::new()`, in this constructor, attributes are expressed
361    /// as `PartialValue`.
362    pub fn new_with_attr_partial_value(
363        uid: EntityUID,
364        attrs: HashMap<SmolStr, PartialValue>,
365        ancestors: HashSet<EntityUID>,
366    ) -> Self {
367        Entity {
368            uid,
369            attrs: attrs.into_iter().map(|(k, v)| (k, v.into())).collect(), // TODO(#540): can we do this without disassembling and reassembling the HashMap
370            ancestors,
371        }
372    }
373
374    /// Create a new `Entity` with this UID, attributes, and ancestors.
375    ///
376    /// Unlike in `Entity::new()`, in this constructor, attributes are expressed
377    /// as `PartialValueSerializedAsExpr`.
378    pub fn new_with_attr_partial_value_serialized_as_expr(
379        uid: EntityUID,
380        attrs: BTreeMap<SmolStr, PartialValueSerializedAsExpr>,
381        ancestors: HashSet<EntityUID>,
382    ) -> Self {
383        Entity {
384            uid,
385            attrs,
386            ancestors,
387        }
388    }
389
390    /// Get the UID of this entity
391    pub fn uid(&self) -> &EntityUID {
392        &self.uid
393    }
394
395    /// Get the value for the given attribute, or `None` if not present
396    pub fn get(&self, attr: &str) -> Option<&PartialValue> {
397        self.attrs.get(attr).map(|v| v.as_ref())
398    }
399
400    /// Is this `Entity` a descendant of `e` in the entity hierarchy?
401    pub fn is_descendant_of(&self, e: &EntityUID) -> bool {
402        self.ancestors.contains(e)
403    }
404
405    /// Iterate over this entity's ancestors
406    pub fn ancestors(&self) -> impl Iterator<Item = &EntityUID> {
407        self.ancestors.iter()
408    }
409
410    /// Get the number of attributes on this entity
411    pub fn attrs_len(&self) -> usize {
412        self.attrs.len()
413    }
414
415    /// Iterate over this entity's attribute names
416    pub fn keys(&self) -> impl Iterator<Item = &SmolStr> {
417        self.attrs.keys()
418    }
419
420    /// Iterate over this entity's attributes
421    pub fn attrs(&self) -> impl Iterator<Item = (&SmolStr, &PartialValue)> {
422        self.attrs.iter().map(|(k, v)| (k, v.as_ref()))
423    }
424
425    /// Create an `Entity` with the given UID, no attributes, and no parents.
426    pub fn with_uid(uid: EntityUID) -> Self {
427        Self {
428            uid,
429            attrs: BTreeMap::new(),
430            ancestors: HashSet::new(),
431        }
432    }
433
434    /// Test if two `Entity` objects are deep/structurally equal.
435    /// That is, not only do they have the same UID, but also the same
436    /// attributes, attribute values, and ancestors.
437    pub(crate) fn deep_eq(&self, other: &Self) -> bool {
438        self.uid == other.uid && self.attrs == other.attrs && self.ancestors == other.ancestors
439    }
440
441    /// Set the given attribute to the given value.
442    // Only used for convenience in some tests and when fuzzing
443    #[cfg(any(test, fuzzing))]
444    pub fn set_attr(
445        &mut self,
446        attr: SmolStr,
447        val: RestrictedExpr,
448        extensions: &Extensions<'_>,
449    ) -> Result<(), EvaluationError> {
450        let val = RestrictedEvaluator::new(extensions).partial_interpret(val.as_borrowed())?;
451        self.attrs.insert(attr, val.into());
452        Ok(())
453    }
454
455    /// Mark the given `UID` as an ancestor of this `Entity`.
456    // When fuzzing, `add_ancestor()` is fully `pub`.
457    #[cfg(not(fuzzing))]
458    pub(crate) fn add_ancestor(&mut self, uid: EntityUID) {
459        self.ancestors.insert(uid);
460    }
461    /// Mark the given `UID` as an ancestor of this `Entity`
462    #[cfg(fuzzing)]
463    pub fn add_ancestor(&mut self, uid: EntityUID) {
464        self.ancestors.insert(uid);
465    }
466
467    /// Consume the entity and return the entity's owned Uid, attributes and parents.
468    pub fn into_inner(
469        self,
470    ) -> (
471        EntityUID,
472        HashMap<SmolStr, PartialValue>,
473        HashSet<EntityUID>,
474    ) {
475        let Self {
476            uid,
477            attrs,
478            ancestors,
479        } = self;
480        (
481            uid,
482            attrs.into_iter().map(|(k, v)| (k, v.0)).collect(),
483            ancestors,
484        )
485    }
486
487    /// Write the entity to a json document
488    pub fn write_to_json(&self, f: impl std::io::Write) -> Result<(), EntitiesError> {
489        let ejson = EntityJson::from_entity(self)?;
490        serde_json::to_writer_pretty(f, &ejson).map_err(JsonSerializationError::from)?;
491        Ok(())
492    }
493
494    /// write the entity to a json value
495    pub fn to_json_value(&self) -> Result<serde_json::Value, EntitiesError> {
496        let ejson = EntityJson::from_entity(self)?;
497        let v = serde_json::to_value(ejson).map_err(JsonSerializationError::from)?;
498        Ok(v)
499    }
500
501    /// write the entity to a json string
502    pub fn to_json_string(&self) -> Result<String, EntitiesError> {
503        let ejson = EntityJson::from_entity(self)?;
504        let string = serde_json::to_string(&ejson).map_err(JsonSerializationError::from)?;
505        Ok(string)
506    }
507}
508
509impl PartialEq for Entity {
510    fn eq(&self, other: &Self) -> bool {
511        self.uid() == other.uid()
512    }
513}
514
515impl Eq for Entity {}
516
517impl StaticallyTyped for Entity {
518    fn type_of(&self) -> Type {
519        self.uid.type_of()
520    }
521}
522
523impl TCNode<EntityUID> for Entity {
524    fn get_key(&self) -> EntityUID {
525        self.uid().clone()
526    }
527
528    fn add_edge_to(&mut self, k: EntityUID) {
529        self.add_ancestor(k)
530    }
531
532    fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
533        Box::new(self.ancestors())
534    }
535
536    fn has_edge_to(&self, e: &EntityUID) -> bool {
537        self.is_descendant_of(e)
538    }
539}
540
541impl std::fmt::Display for Entity {
542    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
543        write!(
544            f,
545            "{}:\n  attrs:{}\n  ancestors:{}",
546            self.uid,
547            self.attrs
548                .iter()
549                .map(|(k, v)| format!("{}: {}", k, v))
550                .join("; "),
551            self.ancestors.iter().join(", ")
552        )
553    }
554}
555
556/// `PartialValue`, but serialized as a `RestrictedExpr`.
557///
558/// (Extension values can't be directly serialized, but can be serialized as
559/// `RestrictedExpr`)
560#[serde_as]
561#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
562pub struct PartialValueSerializedAsExpr(
563    #[serde_as(as = "TryFromInto<RestrictedExpr>")] PartialValue,
564);
565
566impl AsRef<PartialValue> for PartialValueSerializedAsExpr {
567    fn as_ref(&self) -> &PartialValue {
568        &self.0
569    }
570}
571
572impl std::ops::Deref for PartialValueSerializedAsExpr {
573    type Target = PartialValue;
574    fn deref(&self) -> &Self::Target {
575        &self.0
576    }
577}
578
579impl From<PartialValue> for PartialValueSerializedAsExpr {
580    fn from(value: PartialValue) -> PartialValueSerializedAsExpr {
581        PartialValueSerializedAsExpr(value)
582    }
583}
584
585impl From<PartialValueSerializedAsExpr> for PartialValue {
586    fn from(value: PartialValueSerializedAsExpr) -> PartialValue {
587        value.0
588    }
589}
590
591impl std::fmt::Display for PartialValueSerializedAsExpr {
592    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
593        write!(f, "{}", self.0)
594    }
595}
596
597/// Error type for evaluation errors when evaluating an entity attribute.
598/// Contains some extra contextual information and the underlying
599/// `EvaluationError`.
600#[derive(Debug, Diagnostic, Error)]
601#[error("failed to evaluate attribute `{attr}` of `{uid}`: {err}")]
602pub struct EntityAttrEvaluationError {
603    /// UID of the entity where the error was encountered
604    pub uid: EntityUID,
605    /// Attribute of the entity where the error was encountered
606    pub attr: SmolStr,
607    /// Underlying evaluation error
608    #[diagnostic(transparent)]
609    pub err: EvaluationError,
610}
611
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::*;

    #[test]
    fn display() {
        let e = EntityUID::with_eid("eid");
        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
    }

    #[test]
    fn display_escapes_eid() {
        // a double quote inside the EID must be escaped so the output
        // remains a well-formed `Type::"eid"` literal
        let e = EntityUID::with_eid("a\"b");
        assert_eq!(format!("{e}"), "test_entity_type::\"a\\\"b\"");
    }

    #[test]
    fn test_euid_equality() {
        let e1 = EntityUID::with_eid("foo");
        let e2 = EntityUID::from_components(
            Name::parse_unqualified_name("test_entity_type")
                .expect("should be a valid identifier")
                .into(),
            Eid("foo".into()),
            None,
        );
        let e3 = EntityUID::from_components(
            Name::parse_unqualified_name("Unspecified")
                .expect("should be a valid identifier")
                .into(),
            Eid("foo".into()),
            None,
        );

        // an EUID is equal to itself
        assert_eq!(e1, e1);
        assert_eq!(e2, e2);

        // constructing with `with_eid` or `from_components` is the same
        assert_eq!(e1, e2);

        // other pairs are not equal
        assert_ne!(e1, e3);
        assert_ne!(e2, e3);
    }

    #[test]
    fn action_checker() {
        let euid = EntityUID::from_str("Action::\"view\"").unwrap();
        assert!(euid.is_action());
        let euid = EntityUID::from_str("Foo::Action::\"view\"").unwrap();
        assert!(euid.is_action());
        let euid = EntityUID::from_str("Foo::\"view\"").unwrap();
        assert!(!euid.is_action());
        let euid = EntityUID::from_str("Action::Foo::\"view\"").unwrap();
        assert!(!euid.is_action());
    }

    #[test]
    fn action_type_is_valid_id() {
        assert!(Id::from_normalized_str(ACTION_ENTITY_TYPE).is_ok());
    }
}