cedar_policy_core/ast/
entity.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::ast::*;
18use crate::entities::{err::EntitiesError, json::err::JsonSerializationError, EntityJson};
19use crate::evaluator::{EvaluationError, RestrictedEvaluator};
20use crate::extensions::Extensions;
21use crate::parser::err::ParseErrors;
22use crate::parser::Loc;
23use crate::transitive_closure::TCNode;
24use crate::FromNormalizedStr;
25use educe::Educe;
26use itertools::Itertools;
27use miette::Diagnostic;
28use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
29use smol_str::SmolStr;
30use std::collections::{BTreeMap, HashMap, HashSet};
31use std::str::FromStr;
32use std::sync::Arc;
33use thiserror::Error;
34
/// Placeholder [`Name`] reported for `EntityType::ErrorEntityType`.
///
/// Kept as a `static` (not a `const`) because callers receive `&Name`
/// borrowed from it.
#[cfg(feature = "tolerant-ast")]
static ERROR_NAME: std::sync::LazyLock<Name> =
    std::sync::LazyLock::new(|| Name(InternalName::from(Id::new_unchecked("EntityTypeError"))));

/// Placeholder text emitted for `Eid::ErrorEid`.
#[cfg(feature = "tolerant-ast")]
const EID_ERROR_STR: &str = "Eid::Error";

/// Placeholder text emitted for `EntityType::ErrorEntityType`.
#[cfg(feature = "tolerant-ast")]
const ENTITY_TYPE_ERROR_STR: &str = "EntityType::Error";

/// Placeholder text emitted for `EntityUID::Error`.
#[cfg(feature = "tolerant-ast")]
const ENTITY_UID_ERROR_STR: &str = "EntityUID::Error";

/// The entity type that Actions must have
pub static ACTION_ENTITY_TYPE: &str = "Action";
50
51#[derive(PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
52#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
53/// Entity type - can be an error type when 'tolerant-ast' feature is enabled
54pub enum EntityType {
55    /// Entity type names are just [`Name`]s, but we have some operations on them specific to entity types.
56    EntityType(Name),
57    #[cfg(feature = "tolerant-ast")]
58    /// Represents an error node of an entity that failed to parse
59    ErrorEntityType,
60}
61
62impl<'de> Deserialize<'de> for EntityType {
63    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
64    where
65        D: Deserializer<'de>,
66    {
67        let name = Name::deserialize(deserializer)?;
68        Ok(EntityType::EntityType(name))
69    }
70}
71
72impl Serialize for EntityType {
73    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
74    where
75        S: Serializer,
76    {
77        match self {
78            EntityType::EntityType(name) => name.serialize(serializer),
79            #[cfg(feature = "tolerant-ast")]
80            EntityType::ErrorEntityType => serializer.serialize_str(ENTITY_TYPE_ERROR_STR),
81        }
82    }
83}
84
85impl EntityType {
86    /// Is this an Action entity type?
87    /// Returns true when an entity type is an action entity type. This compares the
88    /// base name for the type, so this will return true for any entity type named
89    /// `Action` regardless of namespaces.
90    pub fn is_action(&self) -> bool {
91        match self {
92            EntityType::EntityType(name) => {
93                name.as_ref().basename() == &Id::new_unchecked(ACTION_ENTITY_TYPE)
94            }
95            #[cfg(feature = "tolerant-ast")]
96            EntityType::ErrorEntityType => false,
97        }
98    }
99
100    /// The name of this entity type
101    pub fn name(&self) -> &Name {
102        match self {
103            EntityType::EntityType(name) => name,
104            #[cfg(feature = "tolerant-ast")]
105            EntityType::ErrorEntityType => &ERROR_NAME,
106        }
107    }
108
109    /// The source location of this entity type
110    pub fn loc(&self) -> Option<&Loc> {
111        match self {
112            EntityType::EntityType(name) => name.as_ref().loc(),
113            #[cfg(feature = "tolerant-ast")]
114            EntityType::ErrorEntityType => None,
115        }
116    }
117
118    /// Create a clone of this EntityType with given loc
119    pub fn with_loc(&self, loc: Option<&Loc>) -> Self {
120        match self {
121            EntityType::EntityType(name) => EntityType::EntityType(Name(InternalName {
122                id: name.0.id.clone(),
123                path: name.0.path.clone(),
124                loc: loc.cloned(),
125            })),
126            #[cfg(feature = "tolerant-ast")]
127            EntityType::ErrorEntityType => self.clone(),
128        }
129    }
130
131    /// Calls [`Name::qualify_with_name`] on the underlying [`Name`]
132    pub fn qualify_with(&self, namespace: Option<&Name>) -> Self {
133        match self {
134            EntityType::EntityType(name) => Self::EntityType(name.qualify_with_name(namespace)),
135            #[cfg(feature = "tolerant-ast")]
136            EntityType::ErrorEntityType => Self::ErrorEntityType,
137        }
138    }
139
140    /// Wraps [`Name::from_normalized_str`]
141    pub fn from_normalized_str(src: &str) -> Result<Self, ParseErrors> {
142        Name::from_normalized_str(src).map(Into::into)
143    }
144}
145
146impl From<Name> for EntityType {
147    fn from(n: Name) -> Self {
148        Self::EntityType(n)
149    }
150}
151
152impl From<EntityType> for Name {
153    fn from(ty: EntityType) -> Name {
154        match ty {
155            EntityType::EntityType(name) => name,
156            #[cfg(feature = "tolerant-ast")]
157            EntityType::ErrorEntityType => ERROR_NAME.clone(),
158        }
159    }
160}
161
162impl AsRef<Name> for EntityType {
163    fn as_ref(&self) -> &Name {
164        match self {
165            EntityType::EntityType(name) => name,
166            #[cfg(feature = "tolerant-ast")]
167            EntityType::ErrorEntityType => &ERROR_NAME,
168        }
169    }
170}
171
172impl FromStr for EntityType {
173    type Err = ParseErrors;
174
175    fn from_str(s: &str) -> Result<Self, Self::Err> {
176        s.parse().map(Self::EntityType)
177    }
178}
179
180impl std::fmt::Display for EntityType {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        match self {
183            EntityType::EntityType(name) => write!(f, "{name}"),
184            #[cfg(feature = "tolerant-ast")]
185            EntityType::ErrorEntityType => write!(f, "{ENTITY_TYPE_ERROR_STR}"),
186        }
187    }
188}
189
190/// Unique ID for an entity. These represent entities in the AST.
191#[derive(Educe, Serialize, Deserialize, Debug, Clone)]
192#[serde(rename = "EntityUID")]
193#[educe(PartialEq, Eq, Hash, PartialOrd, Ord)]
194pub struct EntityUIDImpl {
195    /// Typename of the entity
196    ty: EntityType,
197    /// EID of the entity
198    eid: Eid,
199    /// Location of the entity in policy source
200    #[serde(skip)]
201    #[educe(PartialEq(ignore))]
202    #[educe(Hash(ignore))]
203    #[educe(PartialOrd(ignore))]
204    loc: Option<Loc>,
205}
206
207impl EntityUIDImpl {
208    /// The source location of this entity
209    pub fn loc(&self) -> Option<Loc> {
210        self.loc.clone()
211    }
212}
213
214/// Unique ID for an entity. These represent entities in the AST.
215#[derive(Educe, Debug, Clone)]
216#[educe(PartialEq, Eq, Hash, PartialOrd, Ord)]
217pub enum EntityUID {
218    /// Unique ID for an entity. These represent entities in the AST
219    EntityUID(EntityUIDImpl),
220    #[cfg(feature = "tolerant-ast")]
221    /// Represents the ID of an error that failed to parse
222    Error,
223}
224
225impl<'de> Deserialize<'de> for EntityUID {
226    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
227    where
228        D: Deserializer<'de>,
229    {
230        let uid_impl = EntityUIDImpl::deserialize(deserializer)?;
231        Ok(EntityUID::EntityUID(uid_impl))
232    }
233}
234
235impl Serialize for EntityUID {
236    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
237    where
238        S: Serializer,
239    {
240        match self {
241            EntityUID::EntityUID(uid_impl) => uid_impl.serialize(serializer),
242            #[cfg(feature = "tolerant-ast")]
243            EntityUID::Error => serializer.serialize_str(ENTITY_UID_ERROR_STR),
244        }
245    }
246}
247
248impl StaticallyTyped for EntityUID {
249    fn type_of(&self) -> Type {
250        match self {
251            EntityUID::EntityUID(entity_uid) => Type::Entity {
252                ty: entity_uid.ty.clone(),
253            },
254            #[cfg(feature = "tolerant-ast")]
255            EntityUID::Error => Type::Entity {
256                ty: EntityType::ErrorEntityType,
257            },
258        }
259    }
260}
261
#[cfg(test)]
impl EntityUID {
    /// Test helper: builds a UID of type [`EntityUID::test_entity_type`]
    /// whose id is `eid`, with no source location.
    pub(crate) fn with_eid(eid: &str) -> Self {
        Self::from_components(Self::test_entity_type(), Eid::new(eid), None)
    }

    /// The entity type (`test_entity_type`) used by `with_eid()`.
    pub(crate) fn test_entity_type() -> EntityType {
        Name::parse_unqualified_name("test_entity_type")
            .map(EntityType::EntityType)
            .expect("test_entity_type should be a valid identifier")
    }
}
281
282impl EntityUID {
283    /// Create an `EntityUID` with the given (unqualified) typename, and the given string as its EID.
284    pub fn with_eid_and_type(typename: &str, eid: &str) -> Result<Self, ParseErrors> {
285        Ok(Self::EntityUID(EntityUIDImpl {
286            ty: EntityType::EntityType(Name::parse_unqualified_name(typename)?),
287            eid: Eid::Eid(eid.into()),
288            loc: None,
289        }))
290    }
291
292    /// Split into the `EntityType` representing the entity type, and the `Eid`
293    /// representing its name
294    pub fn components(self) -> (EntityType, Eid) {
295        match self {
296            EntityUID::EntityUID(entity_uid) => (entity_uid.ty, entity_uid.eid),
297            #[cfg(feature = "tolerant-ast")]
298            EntityUID::Error => (EntityType::ErrorEntityType, Eid::ErrorEid),
299        }
300    }
301
302    /// Get the source location for this `EntityUID`.
303    pub fn loc(&self) -> Option<&Loc> {
304        match self {
305            EntityUID::EntityUID(entity_uid) => entity_uid.loc.as_ref(),
306            #[cfg(feature = "tolerant-ast")]
307            EntityUID::Error => None,
308        }
309    }
310
311    /// Create an [`EntityUID`] with the given typename and [`Eid`]
312    pub fn from_components(ty: EntityType, eid: Eid, loc: Option<Loc>) -> Self {
313        Self::EntityUID(EntityUIDImpl { ty, eid, loc })
314    }
315
316    /// Get the type component.
317    pub fn entity_type(&self) -> &EntityType {
318        match self {
319            EntityUID::EntityUID(entity_uid) => &entity_uid.ty,
320            #[cfg(feature = "tolerant-ast")]
321            EntityUID::Error => &EntityType::ErrorEntityType,
322        }
323    }
324
325    /// Get the Eid component.
326    pub fn eid(&self) -> &Eid {
327        match self {
328            EntityUID::EntityUID(entity_uid) => &entity_uid.eid,
329            #[cfg(feature = "tolerant-ast")]
330            EntityUID::Error => &Eid::ErrorEid,
331        }
332    }
333
334    /// Does this EntityUID refer to an action entity?
335    pub fn is_action(&self) -> bool {
336        self.entity_type().is_action()
337    }
338}
339
340impl std::fmt::Display for EntityUID {
341    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342        write!(f, "{}::\"{}\"", self.entity_type(), self.eid().escaped())
343    }
344}
345
346// allow `.parse()` on a string to make an `EntityUID`
347impl std::str::FromStr for EntityUID {
348    type Err = ParseErrors;
349
350    fn from_str(s: &str) -> Result<Self, Self::Err> {
351        crate::parser::parse_euid(s)
352    }
353}
354
355impl FromNormalizedStr for EntityUID {
356    fn describe_self() -> &'static str {
357        "Entity UID"
358    }
359}
360
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for EntityUID {
    /// Generates a well-formed UID with no source location.
    ///
    /// The type is drawn before the id.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let ty: EntityType = u.arbitrary()?;
        let eid: Eid = u.arbitrary()?;
        Ok(Self::from_components(ty, eid, None))
    }
}
371
372/// The `Eid` type represents the id of an `Entity`, without the typename.
373/// Together with the typename it comprises an `EntityUID`.
374/// For example, in `User::"alice"`, the `Eid` is `alice`.
375///
376/// `Eid` does not implement `Display`, partly because it is unclear whether
377/// `Display` should produce an escaped representation or an unescaped representation
378/// (see [#884](https://github.com/cedar-policy/cedar/issues/884)).
379/// To get an escaped representation, use `.escaped()`.
380/// To get an unescaped representation, use `.as_ref()`.
381#[derive(PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
382pub enum Eid {
383    /// Actual Eid
384    Eid(SmolStr),
385    #[cfg(feature = "tolerant-ast")]
386    /// Represents an Eid of an entity that failed to parse
387    ErrorEid,
388}
389
390impl<'de> Deserialize<'de> for Eid {
391    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
392    where
393        D: Deserializer<'de>,
394    {
395        let value = String::deserialize(deserializer)?;
396        Ok(Eid::Eid(SmolStr::from(value)))
397    }
398}
399
400impl Serialize for Eid {
401    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
402    where
403        S: Serializer,
404    {
405        match self {
406            Eid::Eid(s) => s.serialize(serializer),
407            #[cfg(feature = "tolerant-ast")]
408            Eid::ErrorEid => serializer.serialize_str(EID_ERROR_STR),
409        }
410    }
411}
412
413impl Eid {
414    /// Construct an Eid
415    pub fn new(eid: impl Into<SmolStr>) -> Self {
416        Eid::Eid(eid.into())
417    }
418
419    /// Get the contents of the `Eid` as an escaped string
420    pub fn escaped(&self) -> SmolStr {
421        match self {
422            Eid::Eid(smol_str) => smol_str.escape_debug().collect(),
423            #[cfg(feature = "tolerant-ast")]
424            Eid::ErrorEid => SmolStr::new_static(EID_ERROR_STR),
425        }
426    }
427
428    /// Get the underlying smolstr for this `Eid`
429    pub fn into_smolstr(self) -> SmolStr {
430        match self {
431            Eid::Eid(smol_str) => smol_str,
432            #[cfg(feature = "tolerant-ast")]
433            Eid::ErrorEid => SmolStr::new_static(EID_ERROR_STR),
434        }
435    }
436}
437
438impl AsRef<str> for Eid {
439    fn as_ref(&self) -> &str {
440        match self {
441            Eid::Eid(smol_str) => smol_str,
442            #[cfg(feature = "tolerant-ast")]
443            Eid::ErrorEid => EID_ERROR_STR,
444        }
445    }
446}
447
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Eid {
    /// Generates a concrete `Eid` from an arbitrary `String`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        u.arbitrary::<String>().map(|raw| Self::Eid(raw.into()))
    }
}
455
456/// Entity datatype
457#[derive(Debug, Clone)]
458pub struct Entity {
459    /// UID
460    uid: EntityUID,
461
462    /// Internal `BTreeMap` of attributes.
463    ///
464    /// We use a `BTreeMap` so that the keys have a deterministic order.
465    attrs: BTreeMap<SmolStr, PartialValue>,
466
467    /// Set of indirect ancestors of this `Entity` as UIDs
468    indirect_ancestors: HashSet<EntityUID>,
469
470    /// Set of direct ancestors (i.e., parents) as UIDs
471    ///
472    /// indirect_ancestors and parents should be disjoint
473    /// even if a parent is also an indirect parent through
474    /// a different parent
475    parents: HashSet<EntityUID>,
476
477    /// Tags on this entity (RFC 82)
478    ///
479    /// Like for `attrs`, we use a `BTreeMap` so that the tags have a
480    /// deterministic order.
481    tags: BTreeMap<SmolStr, PartialValue>,
482}
483
484impl std::hash::Hash for Entity {
485    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
486        self.uid.hash(state);
487    }
488}
489
490impl Entity {
491    /// Create a new `Entity` with this UID, attributes, ancestors, and tags
492    ///
493    /// # Errors
494    /// - Will error if any of the [`RestrictedExpr]`s in `attrs` or `tags` error when evaluated
495    pub fn new(
496        uid: EntityUID,
497        attrs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
498        indirect_ancestors: HashSet<EntityUID>,
499        parents: HashSet<EntityUID>,
500        tags: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
501        extensions: &Extensions<'_>,
502    ) -> Result<Self, EntityAttrEvaluationError> {
503        let evaluator = RestrictedEvaluator::new(extensions);
504        let evaluate_kvs = |(k, v): (SmolStr, RestrictedExpr), was_attr: bool| {
505            let attr_val = evaluator
506                .partial_interpret(v.as_borrowed())
507                .map_err(|err| EntityAttrEvaluationError {
508                    uid: uid.clone(),
509                    attr_or_tag: k.clone(),
510                    was_attr,
511                    err,
512                })?;
513            Ok((k, attr_val))
514        };
515        let evaluated_attrs = attrs
516            .into_iter()
517            .map(|kv| evaluate_kvs(kv, true))
518            .collect::<Result<_, EntityAttrEvaluationError>>()?;
519        let evaluated_tags = tags
520            .into_iter()
521            .map(|kv| evaluate_kvs(kv, false))
522            .collect::<Result<_, EntityAttrEvaluationError>>()?;
523        Ok(Entity {
524            uid,
525            attrs: evaluated_attrs,
526            indirect_ancestors,
527            parents,
528            tags: evaluated_tags,
529        })
530    }
531
532    /// Create a new [`Entity`] with this UID, attributes, ancestors, and tags
533    ///
534    /// Unlike in `Entity::new()`, in this constructor, attributes and tags are
535    /// expressed as `PartialValue`.
536    ///
537    /// Callers should consider directly using [`Entity::new_with_attr_partial_value_serialized_as_expr`]
538    /// if they would call this method by first building a map, as it will
539    /// deconstruct and re-build the map perhaps unnecessarily.
540    pub fn new_with_attr_partial_value(
541        uid: EntityUID,
542        attrs: impl IntoIterator<Item = (SmolStr, PartialValue)>,
543        indirect_ancestors: HashSet<EntityUID>,
544        parents: HashSet<EntityUID>,
545        tags: impl IntoIterator<Item = (SmolStr, PartialValue)>,
546    ) -> Self {
547        Self {
548            uid,
549            attrs: attrs.into_iter().collect(),
550            indirect_ancestors,
551            parents,
552            tags: tags.into_iter().collect(),
553        }
554    }
555
556    /// Get the UID of this entity
557    pub fn uid(&self) -> &EntityUID {
558        &self.uid
559    }
560
561    /// Get the value for the given attribute, or `None` if not present
562    pub fn get(&self, attr: &str) -> Option<&PartialValue> {
563        self.attrs.get(attr)
564    }
565
566    /// Get the value for the given tag, or `None` if not present
567    pub fn get_tag(&self, tag: &str) -> Option<&PartialValue> {
568        self.tags.get(tag)
569    }
570
571    /// Is this `Entity` a (direct or indirect) descendant of `e` in the entity hierarchy?
572    pub fn is_descendant_of(&self, e: &EntityUID) -> bool {
573        self.parents.contains(e) || self.indirect_ancestors.contains(e)
574    }
575
576    /// Is this `Entity` a an indirect descendant of `e` in the entity hierarchy?
577    pub fn is_indirect_descendant_of(&self, e: &EntityUID) -> bool {
578        self.indirect_ancestors.contains(e)
579    }
580
581    /// Is this `Entity` a direct decendant (child) of `e` in the entity hierarchy?
582    pub fn is_child_of(&self, e: &EntityUID) -> bool {
583        self.parents.contains(e)
584    }
585
586    /// Iterate over this entity's (direct or indirect) ancestors
587    pub fn ancestors(&self) -> impl Iterator<Item = &EntityUID> {
588        self.parents.iter().chain(self.indirect_ancestors.iter())
589    }
590
591    /// Iterate over this entity's indirect ancestors
592    pub fn indirect_ancestors(&self) -> impl Iterator<Item = &EntityUID> {
593        self.indirect_ancestors.iter()
594    }
595
596    /// Iterate over this entity's direct ancestors (parents)
597    pub fn parents(&self) -> impl Iterator<Item = &EntityUID> {
598        self.parents.iter()
599    }
600
601    /// Get the number of attributes on this entity
602    pub fn attrs_len(&self) -> usize {
603        self.attrs.len()
604    }
605
606    /// Get the number of tags on this entity
607    pub fn tags_len(&self) -> usize {
608        self.tags.len()
609    }
610
611    /// Iterate over this entity's attribute names
612    pub fn keys(&self) -> impl Iterator<Item = &SmolStr> {
613        self.attrs.keys()
614    }
615
616    /// Iterate over this entity's tag names
617    pub fn tag_keys(&self) -> impl Iterator<Item = &SmolStr> {
618        self.tags.keys()
619    }
620
621    /// Iterate over this entity's attributes
622    pub fn attrs(&self) -> impl Iterator<Item = (&SmolStr, &PartialValue)> {
623        self.attrs.iter()
624    }
625
626    /// Iterate over this entity's tags
627    pub fn tags(&self) -> impl Iterator<Item = (&SmolStr, &PartialValue)> {
628        self.tags.iter()
629    }
630
631    /// Create an `Entity` with the given UID, no attributes, no parents, and no tags.
632    pub fn with_uid(uid: EntityUID) -> Self {
633        Self {
634            uid,
635            attrs: BTreeMap::new(),
636            indirect_ancestors: HashSet::new(),
637            parents: HashSet::new(),
638            tags: BTreeMap::new(),
639        }
640    }
641
642    /// Test if two `Entity` objects are deep/structurally equal.
643    /// That is, not only do they have the same UID, but also the same
644    /// attributes, attribute values, and ancestors/parents.
645    ///
646    /// Does not test that they have the same _direct_ parents, only that they have the same overall ancestor set.
647    pub fn deep_eq(&self, other: &Self) -> bool {
648        self.uid == other.uid
649            && self.attrs == other.attrs
650            && self.tags == other.tags
651            && (self.ancestors().collect::<HashSet<_>>())
652                == (other.ancestors().collect::<HashSet<_>>())
653    }
654
655    /// Mark the given `UID` as an indirect ancestor of this `Entity`
656    ///
657    /// The given `UID` will not be added as an indirecty ancestor if
658    /// it is already a direct ancestor (parent) of this `Entity`
659    /// The caller of this code is responsible for maintaining
660    /// transitive closure of hierarchy.
661    pub fn add_indirect_ancestor(&mut self, uid: EntityUID) {
662        if !self.parents.contains(&uid) {
663            self.indirect_ancestors.insert(uid);
664        }
665    }
666
667    /// Mark the given `UID` as a (direct) parent of this `Entity`, and
668    /// remove the UID from indirect ancestors
669    /// if it was previously added as an indirect ancestor
670    /// The caller of this code is responsible for maintaining
671    /// transitive closure of hierarchy.
672    pub fn add_parent(&mut self, uid: EntityUID) {
673        self.indirect_ancestors.remove(&uid);
674        self.parents.insert(uid);
675    }
676
677    /// Remove the given `UID` as an indirect ancestor of this `Entity`.
678    ///
679    /// No effect if the `UID` is a direct parent.
680    /// The caller of this code is responsible for maintaining
681    /// transitive closure of hierarchy.
682    pub fn remove_indirect_ancestor(&mut self, uid: &EntityUID) {
683        self.indirect_ancestors.remove(uid);
684    }
685
686    /// Remove the given `UID` as a (direct) parent of this `Entity`.
687    ///
688    /// No effect on the `Entity`'s indirect ancestors.
689    /// The caller of this code is responsible for maintaining
690    /// transitive closure of hierarchy.
691    pub fn remove_parent(&mut self, uid: &EntityUID) {
692        self.parents.remove(uid);
693    }
694
695    /// Remove all indirect ancestors of this `Entity`.
696    ///
697    /// The caller of this code is responsible for maintaining
698    /// transitive closure of hierarchy.
699    pub fn remove_all_indirect_ancestors(&mut self) {
700        self.indirect_ancestors.clear();
701    }
702
703    /// Consume the entity and return the entity's owned Uid, attributes, ancestors, parents, and tags.
704    #[allow(clippy::type_complexity)]
705    pub fn into_inner(
706        self,
707    ) -> (
708        EntityUID,
709        HashMap<SmolStr, PartialValue>,
710        HashSet<EntityUID>,
711        HashSet<EntityUID>,
712        HashMap<SmolStr, PartialValue>,
713    ) {
714        (
715            self.uid,
716            self.attrs.into_iter().collect(),
717            self.indirect_ancestors,
718            self.parents,
719            self.tags.into_iter().collect(),
720        )
721    }
722
723    /// Write the entity to a json document
724    pub fn write_to_json(&self, f: impl std::io::Write) -> Result<(), EntitiesError> {
725        let ejson = EntityJson::from_entity(self)?;
726        serde_json::to_writer_pretty(f, &ejson).map_err(JsonSerializationError::from)?;
727        Ok(())
728    }
729
730    /// write the entity to a json value
731    pub fn to_json_value(&self) -> Result<serde_json::Value, EntitiesError> {
732        let ejson = EntityJson::from_entity(self)?;
733        let v = serde_json::to_value(ejson).map_err(JsonSerializationError::from)?;
734        Ok(v)
735    }
736
737    /// write the entity to a json string
738    pub fn to_json_string(&self) -> Result<String, EntitiesError> {
739        let ejson = EntityJson::from_entity(self)?;
740        let string = serde_json::to_string(&ejson).map_err(JsonSerializationError::from)?;
741        Ok(string)
742    }
743}
744
745/// `Entity`s are equal if their UIDs are equal
746impl PartialEq for Entity {
747    fn eq(&self, other: &Self) -> bool {
748        self.uid() == other.uid()
749    }
750}
751
752impl Eq for Entity {}
753
754impl StaticallyTyped for Entity {
755    fn type_of(&self) -> Type {
756        self.uid.type_of()
757    }
758}
759
760impl TCNode<EntityUID> for Entity {
761    fn get_key(&self) -> EntityUID {
762        self.uid().clone()
763    }
764
765    fn add_edge_to(&mut self, k: EntityUID) {
766        self.add_indirect_ancestor(k);
767    }
768
769    fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
770        Box::new(self.ancestors())
771    }
772
773    fn has_edge_to(&self, e: &EntityUID) -> bool {
774        self.is_descendant_of(e)
775    }
776
777    fn reset_edges(&mut self) {
778        self.remove_all_indirect_ancestors()
779    }
780
781    fn direct_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
782        Box::new(self.parents())
783    }
784}
785
786impl TCNode<EntityUID> for Arc<Entity> {
787    fn get_key(&self) -> EntityUID {
788        self.uid().clone()
789    }
790
791    fn add_edge_to(&mut self, k: EntityUID) {
792        // Use Arc::make_mut to get a mutable reference to the inner value
793        Arc::make_mut(self).add_indirect_ancestor(k)
794    }
795
796    fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
797        Box::new(self.ancestors())
798    }
799
800    fn has_edge_to(&self, e: &EntityUID) -> bool {
801        self.is_descendant_of(e)
802    }
803
804    fn reset_edges(&mut self) {
805        // Use Arc::make_mut to get a mutable reference to the inner value
806        Arc::make_mut(self).remove_all_indirect_ancestors()
807    }
808
809    fn direct_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
810        Box::new(self.parents())
811    }
812}
813
814impl std::fmt::Display for Entity {
815    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
816        write!(
817            f,
818            "{}:\n  attrs:{}\n  ancestors:{}",
819            self.uid,
820            self.attrs
821                .iter()
822                .map(|(k, v)| format!("{k}: {v}"))
823                .join("; "),
824            self.ancestors().join(", ")
825        )
826    }
827}
828
829/// Error type for evaluation errors when evaluating an entity attribute or tag.
830/// Contains some extra contextual information and the underlying
831/// `EvaluationError`.
832//
833// This is NOT a publicly exported error type.
834#[derive(Debug, Diagnostic, Error)]
835#[error("failed to evaluate {} `{attr_or_tag}` of `{uid}`: {err}", if *.was_attr { "attribute" } else { "tag" })]
836pub struct EntityAttrEvaluationError {
837    /// UID of the entity where the error was encountered
838    pub uid: EntityUID,
839    /// Attribute or tag of the entity where the error was encountered
840    pub attr_or_tag: SmolStr,
841    /// If `attr_or_tag` was an attribute (`true`) or tag (`false`)
842    pub was_attr: bool,
843    /// Underlying evaluation error
844    #[diagnostic(transparent)]
845    pub err: EvaluationError,
846}
847
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::*;

    #[test]
    fn display() {
        let e = EntityUID::with_eid("eid");
        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
    }

    #[test]
    fn test_euid_equality() {
        let e1 = EntityUID::with_eid("foo");
        let e2 = EntityUID::from_components(
            Name::parse_unqualified_name("test_entity_type")
                .expect("should be a valid identifier")
                .into(),
            Eid::Eid("foo".into()),
            None,
        );
        let e3 = EntityUID::from_components(
            Name::parse_unqualified_name("Unspecified")
                .expect("should be a valid identifier")
                .into(),
            Eid::Eid("foo".into()),
            None,
        );

        // an EUID is equal to itself
        assert_eq!(e1, e1);
        assert_eq!(e2, e2);

        // constructing with `with_eid` or `from_components` is the same
        assert_eq!(e1, e2);

        // other pairs are not equal
        assert_ne!(e1, e3);
    }

    #[test]
    fn action_checker() {
        let euid = EntityUID::from_str("Action::\"view\"").unwrap();
        assert!(euid.is_action());
        let euid = EntityUID::from_str("Foo::Action::\"view\"").unwrap();
        assert!(euid.is_action());
        let euid = EntityUID::from_str("Foo::\"view\"").unwrap();
        assert!(!euid.is_action());
        let euid = EntityUID::from_str("Action::Foo::\"view\"").unwrap();
        assert!(!euid.is_action());
    }

    #[test]
    fn action_type_is_valid_id() {
        assert!(Id::from_normalized_str(ACTION_ENTITY_TYPE).is_ok());
    }

    #[cfg(feature = "tolerant-ast")]
    #[test]
    fn error_entity() {
        use cool_asserts::assert_matches;

        let e = EntityUID::Error;
        assert_matches!(e.eid(), Eid::ErrorEid);
        assert_matches!(e.entity_type(), EntityType::ErrorEntityType);
        assert!(!e.is_action());
        assert_matches!(e.loc(), None);

        let error_eid = Eid::ErrorEid;
        assert_eq!(error_eid.escaped(), "Eid::Error");

        let error_type = EntityType::ErrorEntityType;
        assert!(!error_type.is_action());
        assert_eq!(error_type.qualify_with(None), EntityType::ErrorEntityType);
        assert_eq!(
            error_type.qualify_with(Some(&Name(InternalName::from(Id::new_unchecked(
                "EntityTypeError"
            ))))),
            EntityType::ErrorEntityType
        );

        assert_eq!(
            error_type.name(),
            &Name(InternalName::from(Id::new_unchecked("EntityTypeError")))
        );
        assert_eq!(error_type.loc(), None)
    }

    #[test]
    fn entity_type_deserialization() {
        let json = r#""some_entity_type""#;
        let entity_type: EntityType = serde_json::from_str(json).unwrap();
        assert_eq!(
            entity_type.name().0.to_string(),
            "some_entity_type".to_string()
        )
    }

    #[test]
    fn entity_type_serialization() {
        let entity_type = EntityType::EntityType(Name(InternalName::from(Id::new_unchecked(
            "some_entity_type",
        ))));
        let serialized = serde_json::to_string(&entity_type).unwrap();

        assert_eq!(serialized, r#""some_entity_type""#);
    }
}