Skip to main content

cedar_policy_core/ast/
entity.rs

/*
 * Copyright Cedar Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17use crate::ast::*;
18use crate::entities::{err::EntitiesError, json::err::JsonSerializationError, EntityJson};
19use crate::evaluator::{EvaluationError, RestrictedEvaluator};
20use crate::extensions::Extensions;
21use crate::parser::err::ParseErrors;
22use crate::parser::Loc;
23use crate::transitive_closure::TCNode;
24use crate::FromNormalizedStr;
25use educe::Educe;
26use itertools::Itertools;
27use miette::Diagnostic;
28use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
29use smol_str::SmolStr;
30use std::collections::{BTreeMap, HashMap, HashSet};
31use std::str::FromStr;
32use std::sync::Arc;
33use thiserror::Error;
34
/// Placeholder [`Name`] handed out for [`EntityType::ErrorEntityType`]
/// wherever a `&Name` or an owned `Name` is required.
#[cfg(feature = "tolerant-ast")]
static ERROR_NAME: std::sync::LazyLock<Name> = std::sync::LazyLock::new(|| {
    let error_id = Id::new_unchecked("EntityTypeError");
    Name(InternalName::from(error_id))
});

/// Textual stand-in for [`Eid::ErrorEid`] (serialization, `escaped`,
/// `into_smolstr`, and `as_ref`).
#[cfg(feature = "tolerant-ast")]
static EID_ERROR_STR: &str = "Eid::Error";

/// Textual stand-in for [`EntityType::ErrorEntityType`] (serialization and
/// `Display`).
#[cfg(feature = "tolerant-ast")]
static ENTITY_TYPE_ERROR_STR: &str = "EntityType::Error";

/// Textual stand-in for [`EntityUID::Error`] when serialized.
#[cfg(feature = "tolerant-ast")]
static ENTITY_UID_ERROR_STR: &str = "EntityUID::Error";

/// The entity type that Actions must have.
///
/// Only the basename of a type is compared against this value
/// (see [`EntityType::is_action`]).
pub static ACTION_ENTITY_TYPE: &str = "Action";
50
// NOTE: variant order determines the derived `PartialOrd`/`Ord`.
#[derive(PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
/// Entity type - can be an error type when 'tolerant-ast' feature is enabled
///
/// Serialized as a plain string: a well-formed type as its [`Name`], and the
/// error type (under `tolerant-ast`) as `"EntityType::Error"`. Deserialization
/// only ever produces the [`EntityType::EntityType`] variant.
pub enum EntityType {
    /// Entity type names are just [`Name`]s, but we have some operations on them specific to entity types.
    EntityType(Name),
    #[cfg(feature = "tolerant-ast")]
    /// Represents an error node of an entity that failed to parse
    ErrorEntityType,
}
61
62impl<'de> Deserialize<'de> for EntityType {
63    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
64    where
65        D: Deserializer<'de>,
66    {
67        let name = Name::deserialize(deserializer)?;
68        Ok(EntityType::EntityType(name))
69    }
70}
71
72impl Serialize for EntityType {
73    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
74    where
75        S: Serializer,
76    {
77        match self {
78            EntityType::EntityType(name) => name.serialize(serializer),
79            #[cfg(feature = "tolerant-ast")]
80            EntityType::ErrorEntityType => serializer.serialize_str(ENTITY_TYPE_ERROR_STR),
81        }
82    }
83}
84
85impl EntityType {
86    /// Is this an Action entity type?
87    /// Returns true when an entity type is an action entity type. This compares the
88    /// base name for the type, so this will return true for any entity type named
89    /// `Action` regardless of namespaces.
90    pub fn is_action(&self) -> bool {
91        match self {
92            EntityType::EntityType(name) => {
93                name.as_ref().basename() == &Id::new_unchecked(ACTION_ENTITY_TYPE)
94            }
95            #[cfg(feature = "tolerant-ast")]
96            EntityType::ErrorEntityType => false,
97        }
98    }
99
100    /// The name of this entity type
101    pub fn name(&self) -> &Name {
102        match self {
103            EntityType::EntityType(name) => name,
104            #[cfg(feature = "tolerant-ast")]
105            EntityType::ErrorEntityType => &ERROR_NAME,
106        }
107    }
108
109    /// The source location of this entity type
110    pub fn loc(&self) -> Option<&Loc> {
111        match self {
112            EntityType::EntityType(name) => name.as_ref().loc(),
113            #[cfg(feature = "tolerant-ast")]
114            EntityType::ErrorEntityType => None,
115        }
116    }
117
118    /// Create a clone of this EntityType with given loc
119    pub fn with_loc(&self, loc: Option<&Loc>) -> Self {
120        match self {
121            EntityType::EntityType(name) => EntityType::EntityType(Name(InternalName {
122                id: name.0.id.clone(),
123                path: name.0.path.clone(),
124                loc: loc.cloned(),
125            })),
126            #[cfg(feature = "tolerant-ast")]
127            EntityType::ErrorEntityType => self.clone(),
128        }
129    }
130
131    /// Calls [`Name::qualify_with_name`] on the underlying [`Name`]
132    pub fn qualify_with(&self, namespace: Option<&Name>) -> Self {
133        match self {
134            EntityType::EntityType(name) => Self::EntityType(name.qualify_with_name(namespace)),
135            #[cfg(feature = "tolerant-ast")]
136            EntityType::ErrorEntityType => Self::ErrorEntityType,
137        }
138    }
139
140    /// Wraps [`Name::from_normalized_str`]
141    pub fn from_normalized_str(src: &str) -> Result<Self, ParseErrors> {
142        Name::from_normalized_str(src).map(Into::into)
143    }
144}
145
146impl From<Name> for EntityType {
147    fn from(n: Name) -> Self {
148        Self::EntityType(n)
149    }
150}
151
152impl From<EntityType> for Name {
153    fn from(ty: EntityType) -> Name {
154        match ty {
155            EntityType::EntityType(name) => name,
156            #[cfg(feature = "tolerant-ast")]
157            EntityType::ErrorEntityType => ERROR_NAME.clone(),
158        }
159    }
160}
161
162impl AsRef<Name> for EntityType {
163    fn as_ref(&self) -> &Name {
164        match self {
165            EntityType::EntityType(name) => name,
166            #[cfg(feature = "tolerant-ast")]
167            EntityType::ErrorEntityType => &ERROR_NAME,
168        }
169    }
170}
171
172impl FromStr for EntityType {
173    type Err = ParseErrors;
174
175    fn from_str(s: &str) -> Result<Self, Self::Err> {
176        s.parse().map(Self::EntityType)
177    }
178}
179
180impl std::fmt::Display for EntityType {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        match self {
183            EntityType::EntityType(name) => write!(f, "{name}"),
184            #[cfg(feature = "tolerant-ast")]
185            EntityType::ErrorEntityType => write!(f, "{ENTITY_TYPE_ERROR_STR}"),
186        }
187    }
188}
189
/// Unique ID for an entity. These represent entities in the AST.
///
/// This is the payload of [`EntityUID::EntityUID`]. `PartialEq`, `Hash`, and
/// `PartialOrd` consider only `ty` and `eid`; `loc` is ignored by them and is
/// not serialized.
#[derive(Educe, Serialize, Deserialize, Debug, Clone)]
#[serde(rename = "EntityUID")]
#[educe(PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct EntityUIDImpl {
    /// Typename of the entity
    ty: EntityType,
    /// EID of the entity
    eid: Eid,
    /// Location of the entity in policy source
    ///
    /// Skipped by serde, so a deserialized value has no location.
    // NOTE(review): `Ord(ignore)` is not listed next to the other ignores;
    // confirm that educe's derived `Ord` also skips this field.
    #[serde(skip)]
    #[educe(PartialEq(ignore))]
    #[educe(Hash(ignore))]
    #[educe(PartialOrd(ignore))]
    loc: Option<Loc>,
}
206
207impl EntityUIDImpl {
208    /// The source location of this entity
209    pub fn loc(&self) -> Option<Loc> {
210        self.loc.clone()
211    }
212}
213
/// Unique ID for an entity. These represent entities in the AST.
///
/// Comparisons and hashing delegate to the [`EntityUIDImpl`] payload. Serialized
/// through [`EntityUIDImpl`]; under `tolerant-ast` the error variant serializes
/// as the string `"EntityUID::Error"`. Deserialization only produces the
/// [`EntityUID::EntityUID`] variant.
#[derive(Educe, Debug, Clone)]
#[educe(PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum EntityUID {
    /// Unique ID for an entity. These represent entities in the AST
    EntityUID(EntityUIDImpl),
    #[cfg(feature = "tolerant-ast")]
    /// Represents the ID of an error that failed to parse
    Error,
}
224
225impl<'de> Deserialize<'de> for EntityUID {
226    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
227    where
228        D: Deserializer<'de>,
229    {
230        let uid_impl = EntityUIDImpl::deserialize(deserializer)?;
231        Ok(EntityUID::EntityUID(uid_impl))
232    }
233}
234
235impl Serialize for EntityUID {
236    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
237    where
238        S: Serializer,
239    {
240        match self {
241            EntityUID::EntityUID(uid_impl) => uid_impl.serialize(serializer),
242            #[cfg(feature = "tolerant-ast")]
243            EntityUID::Error => serializer.serialize_str(ENTITY_UID_ERROR_STR),
244        }
245    }
246}
247
248impl StaticallyTyped for EntityUID {
249    fn type_of(&self) -> Type {
250        match self {
251            EntityUID::EntityUID(entity_uid) => Type::Entity {
252                ty: entity_uid.ty.clone(),
253            },
254            #[cfg(feature = "tolerant-ast")]
255            EntityUID::Error => Type::Entity {
256                ty: EntityType::ErrorEntityType,
257            },
258        }
259    }
260}
261
#[cfg(test)]
impl EntityUID {
    /// Test helper: a UID of type [`EntityUID::test_entity_type`] whose EID is
    /// `eid`, with no source location.
    pub(crate) fn with_eid(eid: &str) -> Self {
        Self::from_components(Self::test_entity_type(), Eid::new(eid), None)
    }

    /// Test helper: the entity type used by [`EntityUID::with_eid`].
    pub(crate) fn test_entity_type() -> EntityType {
        Name::parse_unqualified_name("test_entity_type")
            .map(EntityType::EntityType)
            .expect("test_entity_type should be a valid identifier")
    }
}
281
282impl EntityUID {
283    /// Create an `EntityUID` with the given (unqualified) typename, and the given string as its EID.
284    pub fn with_eid_and_type(typename: &str, eid: &str) -> Result<Self, ParseErrors> {
285        Ok(Self::EntityUID(EntityUIDImpl {
286            ty: EntityType::EntityType(Name::parse_unqualified_name(typename)?),
287            eid: Eid::Eid(eid.into()),
288            loc: None,
289        }))
290    }
291
292    /// Split into the `EntityType` representing the entity type, and the `Eid`
293    /// representing its name
294    pub fn components(self) -> (EntityType, Eid) {
295        match self {
296            EntityUID::EntityUID(entity_uid) => (entity_uid.ty, entity_uid.eid),
297            #[cfg(feature = "tolerant-ast")]
298            EntityUID::Error => (EntityType::ErrorEntityType, Eid::ErrorEid),
299        }
300    }
301
302    /// Get the source location for this `EntityUID`.
303    pub fn loc(&self) -> Option<&Loc> {
304        match self {
305            EntityUID::EntityUID(entity_uid) => entity_uid.loc.as_ref(),
306            #[cfg(feature = "tolerant-ast")]
307            EntityUID::Error => None,
308        }
309    }
310
311    /// Create an [`EntityUID`] with the given typename and [`Eid`]
312    pub fn from_components(ty: EntityType, eid: Eid, loc: Option<Loc>) -> Self {
313        Self::EntityUID(EntityUIDImpl { ty, eid, loc })
314    }
315
316    /// Get the type component.
317    pub fn entity_type(&self) -> &EntityType {
318        match self {
319            EntityUID::EntityUID(entity_uid) => &entity_uid.ty,
320            #[cfg(feature = "tolerant-ast")]
321            EntityUID::Error => &EntityType::ErrorEntityType,
322        }
323    }
324
325    /// Get the Eid component.
326    pub fn eid(&self) -> &Eid {
327        match self {
328            EntityUID::EntityUID(entity_uid) => &entity_uid.eid,
329            #[cfg(feature = "tolerant-ast")]
330            EntityUID::Error => &Eid::ErrorEid,
331        }
332    }
333
334    /// Does this EntityUID refer to an action entity?
335    pub fn is_action(&self) -> bool {
336        self.entity_type().is_action()
337    }
338}
339
340impl std::fmt::Display for EntityUID {
341    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342        write!(f, "{}::\"{}\"", self.entity_type(), self.eid().escaped())
343    }
344}
345
346// allow `.parse()` on a string to make an `EntityUID`
347impl std::str::FromStr for EntityUID {
348    type Err = ParseErrors;
349
350    fn from_str(s: &str) -> Result<Self, Self::Err> {
351        crate::parser::parse_euid(s)
352    }
353}
354
355impl FromNormalizedStr for EntityUID {
356    fn describe_self() -> &'static str {
357        "Entity UID"
358    }
359}
360
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for EntityUID {
    /// Draws the entity type first, then the EID; the location is always `None`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let ty: EntityType = u.arbitrary()?;
        let eid: Eid = u.arbitrary()?;
        Ok(Self::from_components(ty, eid, None))
    }
}
371
/// The `Eid` type represents the id of an `Entity`, without the typename.
/// Together with the typename it comprises an `EntityUID`.
/// For example, in `User::"alice"`, the `Eid` is `alice`.
///
/// `Eid` does not implement `Display`, partly because it is unclear whether
/// `Display` should produce an escaped representation or an unescaped representation
/// (see [#884](https://github.com/cedar-policy/cedar/issues/884)).
/// To get an escaped representation, use `.escaped()`.
/// To get an unescaped representation, use `.as_ref()`.
///
/// Serialized as a plain string; deserialization always produces [`Eid::Eid`].
// NOTE: variant order determines the derived `PartialOrd`/`Ord`.
#[derive(PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
pub enum Eid {
    /// Actual Eid
    Eid(SmolStr),
    #[cfg(feature = "tolerant-ast")]
    /// Represents an Eid of an entity that failed to parse; viewed as text it
    /// reads `Eid::Error`.
    ErrorEid,
}
389
390impl<'de> Deserialize<'de> for Eid {
391    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
392    where
393        D: Deserializer<'de>,
394    {
395        let value = String::deserialize(deserializer)?;
396        Ok(Eid::Eid(SmolStr::from(value)))
397    }
398}
399
400impl Serialize for Eid {
401    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
402    where
403        S: Serializer,
404    {
405        match self {
406            Eid::Eid(s) => s.serialize(serializer),
407            #[cfg(feature = "tolerant-ast")]
408            Eid::ErrorEid => serializer.serialize_str(EID_ERROR_STR),
409        }
410    }
411}
412
413impl Eid {
414    /// Construct an Eid
415    pub fn new(eid: impl Into<SmolStr>) -> Self {
416        Eid::Eid(eid.into())
417    }
418
419    /// Get the contents of the `Eid` as an escaped string
420    pub fn escaped(&self) -> SmolStr {
421        match self {
422            Eid::Eid(smol_str) => smol_str.escape_debug().collect(),
423            #[cfg(feature = "tolerant-ast")]
424            Eid::ErrorEid => SmolStr::new_static(EID_ERROR_STR),
425        }
426    }
427
428    /// Get the underlying smolstr for this `Eid`
429    pub fn into_smolstr(self) -> SmolStr {
430        match self {
431            Eid::Eid(smol_str) => smol_str,
432            #[cfg(feature = "tolerant-ast")]
433            Eid::ErrorEid => SmolStr::new_static(EID_ERROR_STR),
434        }
435    }
436}
437
438impl AsRef<str> for Eid {
439    fn as_ref(&self) -> &str {
440        match self {
441            Eid::Eid(smol_str) => smol_str,
442            #[cfg(feature = "tolerant-ast")]
443            Eid::ErrorEid => EID_ERROR_STR,
444        }
445    }
446}
447
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Eid {
    /// Draws an arbitrary `String` and wraps it; never yields the error EID.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let raw = u.arbitrary::<String>()?;
        Ok(Self::new(raw))
    }
}
455
/// Entity datatype
///
/// `PartialEq`, `Eq`, and `Hash` for `Entity` consider only the UID; use
/// `Entity::deep_eq` for a structural comparison.
#[derive(Debug, Clone)]
pub struct Entity {
    /// UID
    uid: EntityUID,

    /// Internal `BTreeMap` of attributes.
    ///
    /// We use a `BTreeMap` so that the keys have a deterministic order.
    attrs: BTreeMap<SmolStr, PartialValue>,

    /// Set of indirect ancestors of this `Entity` as UIDs
    indirect_ancestors: HashSet<EntityUID>,

    /// Set of direct ancestors (i.e., parents) as UIDs
    ///
    /// `indirect_ancestors` and `parents` should be disjoint: a UID that is a
    /// parent is kept only in `parents`, even if it is also reachable
    /// indirectly through a different parent.
    parents: HashSet<EntityUID>,

    /// Tags on this entity (RFC 82)
    ///
    /// Like for `attrs`, we use a `BTreeMap` so that the tags have a
    /// deterministic order.
    tags: BTreeMap<SmolStr, PartialValue>,
}
483
484impl std::hash::Hash for Entity {
485    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
486        self.uid.hash(state);
487    }
488}
489
490impl Entity {
491    /// Create a new `Entity` with this UID, attributes, ancestors, and tags
492    ///
493    /// # Errors
494    /// - Will error if any of the [`RestrictedExpr]`s in `attrs` or `tags` error when evaluated
495    pub fn new(
496        uid: EntityUID,
497        attrs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
498        indirect_ancestors: HashSet<EntityUID>,
499        parents: HashSet<EntityUID>,
500        tags: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
501        extensions: &Extensions<'_>,
502    ) -> Result<Self, EntityAttrEvaluationError> {
503        let evaluator = RestrictedEvaluator::new(extensions);
504        let evaluate_kvs = |(k, v): (SmolStr, RestrictedExpr), was_attr: bool| {
505            let attr_val = evaluator
506                .partial_interpret(v.as_borrowed())
507                .map_err(|err| EntityAttrEvaluationError {
508                    uid: uid.clone(),
509                    attr_or_tag: k.clone(),
510                    was_attr,
511                    err,
512                })?;
513            Ok((k, attr_val))
514        };
515        let evaluated_attrs = attrs
516            .into_iter()
517            .map(|kv| evaluate_kvs(kv, true))
518            .collect::<Result<_, EntityAttrEvaluationError>>()?;
519        let evaluated_tags = tags
520            .into_iter()
521            .map(|kv| evaluate_kvs(kv, false))
522            .collect::<Result<_, EntityAttrEvaluationError>>()?;
523        Ok(Entity {
524            uid,
525            attrs: evaluated_attrs,
526            indirect_ancestors,
527            parents,
528            tags: evaluated_tags,
529        })
530    }
531
532    /// Create a new [`Entity`] with this UID, attributes, ancestors, and tags
533    ///
534    /// Unlike in `Entity::new()`, in this constructor, attributes and tags are
535    /// expressed as `PartialValue`.
536    pub fn new_with_attr_partial_value(
537        uid: EntityUID,
538        attrs: impl IntoIterator<Item = (SmolStr, PartialValue)>,
539        indirect_ancestors: HashSet<EntityUID>,
540        parents: HashSet<EntityUID>,
541        tags: impl IntoIterator<Item = (SmolStr, PartialValue)>,
542    ) -> Self {
543        Self {
544            uid,
545            attrs: attrs.into_iter().collect(),
546            indirect_ancestors,
547            parents,
548            tags: tags.into_iter().collect(),
549        }
550    }
551
552    /// Get the UID of this entity
553    pub fn uid(&self) -> &EntityUID {
554        &self.uid
555    }
556
557    /// Get the value for the given attribute, or `None` if not present
558    pub fn get(&self, attr: &str) -> Option<&PartialValue> {
559        self.attrs.get(attr)
560    }
561
562    /// Get the value for the given tag, or `None` if not present
563    pub fn get_tag(&self, tag: &str) -> Option<&PartialValue> {
564        self.tags.get(tag)
565    }
566
567    /// Is this `Entity` a (direct or indirect) descendant of `e` in the entity hierarchy?
568    pub fn is_descendant_of(&self, e: &EntityUID) -> bool {
569        self.parents.contains(e) || self.indirect_ancestors.contains(e)
570    }
571
572    /// Is this `Entity` a an indirect descendant of `e` in the entity hierarchy?
573    pub fn is_indirect_descendant_of(&self, e: &EntityUID) -> bool {
574        self.indirect_ancestors.contains(e)
575    }
576
577    /// Is this `Entity` a direct decendant (child) of `e` in the entity hierarchy?
578    pub fn is_child_of(&self, e: &EntityUID) -> bool {
579        self.parents.contains(e)
580    }
581
582    /// Iterate over this entity's (direct or indirect) ancestors
583    pub fn ancestors(&self) -> impl Iterator<Item = &EntityUID> {
584        self.parents.iter().chain(self.indirect_ancestors.iter())
585    }
586
587    /// Iterate over this entity's indirect ancestors
588    pub fn indirect_ancestors(&self) -> impl Iterator<Item = &EntityUID> {
589        self.indirect_ancestors.iter()
590    }
591
592    /// Iterate over this entity's direct ancestors (parents)
593    pub fn parents(&self) -> impl Iterator<Item = &EntityUID> {
594        self.parents.iter()
595    }
596
597    /// Get the number of attributes on this entity
598    pub fn attrs_len(&self) -> usize {
599        self.attrs.len()
600    }
601
602    /// Get the number of tags on this entity
603    pub fn tags_len(&self) -> usize {
604        self.tags.len()
605    }
606
607    /// Iterate over this entity's attribute names
608    pub fn keys(&self) -> impl Iterator<Item = &SmolStr> {
609        self.attrs.keys()
610    }
611
612    /// Iterate over this entity's tag names
613    pub fn tag_keys(&self) -> impl Iterator<Item = &SmolStr> {
614        self.tags.keys()
615    }
616
617    /// Iterate over this entity's attributes
618    pub fn attrs(&self) -> impl Iterator<Item = (&SmolStr, &PartialValue)> {
619        self.attrs.iter()
620    }
621
622    /// Iterate over this entity's tags
623    pub fn tags(&self) -> impl Iterator<Item = (&SmolStr, &PartialValue)> {
624        self.tags.iter()
625    }
626
627    /// Create an `Entity` with the given UID, no attributes, no parents, and no tags.
628    pub fn with_uid(uid: EntityUID) -> Self {
629        Self {
630            uid,
631            attrs: BTreeMap::new(),
632            indirect_ancestors: HashSet::new(),
633            parents: HashSet::new(),
634            tags: BTreeMap::new(),
635        }
636    }
637
638    /// Test if two `Entity` objects are deep/structurally equal.
639    /// That is, not only do they have the same UID, but also the same
640    /// attributes, attribute values, and ancestors/parents.
641    ///
642    /// Does not test that they have the same _direct_ parents, only that they have the same overall ancestor set.
643    pub fn deep_eq(&self, other: &Self) -> bool {
644        self.uid == other.uid
645            && self.attrs == other.attrs
646            && self.tags == other.tags
647            && (self.ancestors().collect::<HashSet<_>>())
648                == (other.ancestors().collect::<HashSet<_>>())
649    }
650
651    /// Mark the given `UID` as an indirect ancestor of this `Entity`
652    ///
653    /// The given `UID` will not be added as an indirecty ancestor if
654    /// it is already a direct ancestor (parent) of this `Entity`
655    /// The caller of this code is responsible for maintaining
656    /// transitive closure of hierarchy.
657    pub fn add_indirect_ancestor(&mut self, uid: EntityUID) {
658        if !self.parents.contains(&uid) {
659            self.indirect_ancestors.insert(uid);
660        }
661    }
662
663    /// Mark the given `UID` as a (direct) parent of this `Entity`, and
664    /// remove the UID from indirect ancestors
665    /// if it was previously added as an indirect ancestor
666    /// The caller of this code is responsible for maintaining
667    /// transitive closure of hierarchy.
668    pub fn add_parent(&mut self, uid: EntityUID) {
669        self.indirect_ancestors.remove(&uid);
670        self.parents.insert(uid);
671    }
672
673    /// Remove the given `UID` as an indirect ancestor of this `Entity`.
674    ///
675    /// No effect if the `UID` is a direct parent.
676    /// The caller of this code is responsible for maintaining
677    /// transitive closure of hierarchy.
678    pub fn remove_indirect_ancestor(&mut self, uid: &EntityUID) {
679        self.indirect_ancestors.remove(uid);
680    }
681
682    /// Remove the given `UID` as a (direct) parent of this `Entity`.
683    ///
684    /// No effect on the `Entity`'s indirect ancestors.
685    /// The caller of this code is responsible for maintaining
686    /// transitive closure of hierarchy.
687    pub fn remove_parent(&mut self, uid: &EntityUID) {
688        self.parents.remove(uid);
689    }
690
691    /// Remove all indirect ancestors of this `Entity`.
692    ///
693    /// The caller of this code is responsible for maintaining
694    /// transitive closure of hierarchy.
695    pub fn remove_all_indirect_ancestors(&mut self) {
696        self.indirect_ancestors.clear();
697    }
698
699    /// Consume the entity and return the entity's owned Uid, attributes, ancestors, parents, and tags.
700    #[expect(
701        clippy::type_complexity,
702        reason = "needs to return a 5-tuple by design"
703    )]
704    pub fn into_inner(
705        self,
706    ) -> (
707        EntityUID,
708        HashMap<SmolStr, PartialValue>,
709        HashSet<EntityUID>,
710        HashSet<EntityUID>,
711        HashMap<SmolStr, PartialValue>,
712    ) {
713        (
714            self.uid,
715            self.attrs.into_iter().collect(),
716            self.indirect_ancestors,
717            self.parents,
718            self.tags.into_iter().collect(),
719        )
720    }
721
722    /// Write the entity to a json document
723    pub fn write_to_json(&self, f: impl std::io::Write) -> Result<(), EntitiesError> {
724        let ejson = EntityJson::from_entity(self)?;
725        serde_json::to_writer_pretty(f, &ejson).map_err(JsonSerializationError::from)?;
726        Ok(())
727    }
728
729    /// write the entity to a json value
730    pub fn to_json_value(&self) -> Result<serde_json::Value, EntitiesError> {
731        let ejson = EntityJson::from_entity(self)?;
732        let v = serde_json::to_value(ejson).map_err(JsonSerializationError::from)?;
733        Ok(v)
734    }
735
736    /// write the entity to a json string
737    pub fn to_json_string(&self) -> Result<String, EntitiesError> {
738        let ejson = EntityJson::from_entity(self)?;
739        let string = serde_json::to_string(&ejson).map_err(JsonSerializationError::from)?;
740        Ok(string)
741    }
742}
743
744/// `Entity`s are equal if their UIDs are equal
745impl PartialEq for Entity {
746    fn eq(&self, other: &Self) -> bool {
747        self.uid() == other.uid()
748    }
749}
750
751impl Eq for Entity {}
752
753impl StaticallyTyped for Entity {
754    fn type_of(&self) -> Type {
755        self.uid.type_of()
756    }
757}
758
759impl TCNode<EntityUID> for Entity {
760    fn get_key(&self) -> EntityUID {
761        self.uid().clone()
762    }
763
764    fn add_edge_to(&mut self, k: EntityUID) {
765        self.add_indirect_ancestor(k);
766    }
767
768    fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
769        Box::new(self.ancestors())
770    }
771
772    fn has_edge_to(&self, e: &EntityUID) -> bool {
773        self.is_descendant_of(e)
774    }
775
776    fn reset_edges(&mut self) {
777        self.remove_all_indirect_ancestors()
778    }
779
780    fn direct_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
781        Box::new(self.parents())
782    }
783}
784
785impl TCNode<EntityUID> for Arc<Entity> {
786    fn get_key(&self) -> EntityUID {
787        self.uid().clone()
788    }
789
790    fn add_edge_to(&mut self, k: EntityUID) {
791        // Use Arc::make_mut to get a mutable reference to the inner value
792        Arc::make_mut(self).add_indirect_ancestor(k)
793    }
794
795    fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
796        Box::new(self.ancestors())
797    }
798
799    fn has_edge_to(&self, e: &EntityUID) -> bool {
800        self.is_descendant_of(e)
801    }
802
803    fn reset_edges(&mut self) {
804        // Use Arc::make_mut to get a mutable reference to the inner value
805        Arc::make_mut(self).remove_all_indirect_ancestors()
806    }
807
808    fn direct_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
809        Box::new(self.parents())
810    }
811}
812
813impl std::fmt::Display for Entity {
814    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
815        write!(
816            f,
817            "{}:\n  attrs:{}\n  ancestors:{}",
818            self.uid,
819            self.attrs
820                .iter()
821                .map(|(k, v)| format!("{k}: {v}"))
822                .join("; "),
823            self.ancestors().join(", ")
824        )
825    }
826}
827
/// Error type for evaluation errors when evaluating an entity attribute or tag.
/// Contains some extra contextual information and the underlying
/// `EvaluationError`.
///
/// Its message reads `failed to evaluate attribute|tag `<key>` of `<uid>`: <err>`;
/// diagnostic details (help, labels, ...) are forwarded from `err`.
//
// This is NOT a publicly exported error type.
#[derive(Debug, Diagnostic, Error)]
#[error("failed to evaluate {} `{attr_or_tag}` of `{uid}`: {err}", if *.was_attr { "attribute" } else { "tag" })]
pub struct EntityAttrEvaluationError {
    /// UID of the entity where the error was encountered
    pub uid: EntityUID,
    /// Attribute or tag of the entity where the error was encountered
    pub attr_or_tag: SmolStr,
    /// If `attr_or_tag` was an attribute (`true`) or tag (`false`)
    pub was_attr: bool,
    /// Underlying evaluation error
    #[diagnostic(transparent)]
    pub err: EvaluationError,
}
846
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::*;

    #[test]
    fn display() {
        let e = EntityUID::with_eid("eid");
        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
    }

    #[test]
    fn test_euid_equality() {
        let e1 = EntityUID::with_eid("foo");
        let e2 = EntityUID::from_components(
            Name::parse_unqualified_name("test_entity_type")
                .expect("should be a valid identifier")
                .into(),
            Eid::Eid("foo".into()),
            None,
        );
        let e3 = EntityUID::from_components(
            Name::parse_unqualified_name("Unspecified")
                .expect("should be a valid identifier")
                .into(),
            Eid::Eid("foo".into()),
            None,
        );

        // an EUID is equal to itself
        assert_eq!(e1, e1);
        assert_eq!(e2, e2);

        // constructing with `with_eid` or `from_components` is the same
        assert_eq!(e1, e2);

        // other pairs are not equal
        assert_ne!(e1, e3);
    }

    #[test]
    fn action_checker() {
        let euid = EntityUID::from_str("Action::\"view\"").unwrap();
        assert!(euid.is_action());
        let euid = EntityUID::from_str("Foo::Action::\"view\"").unwrap();
        assert!(euid.is_action());
        let euid = EntityUID::from_str("Foo::\"view\"").unwrap();
        assert!(!euid.is_action());
        let euid = EntityUID::from_str("Action::Foo::\"view\"").unwrap();
        assert!(!euid.is_action());
    }

    #[test]
    fn action_type_is_valid_id() {
        Id::from_normalized_str(ACTION_ENTITY_TYPE).unwrap();
    }

    #[cfg(feature = "tolerant-ast")]
    #[test]
    fn error_entity() {
        use cool_asserts::assert_matches;

        let e = EntityUID::Error;
        assert_matches!(e.eid(), Eid::ErrorEid);
        assert_matches!(e.entity_type(), EntityType::ErrorEntityType);
        assert!(!e.is_action());
        assert_matches!(e.loc(), None);

        let error_eid = Eid::ErrorEid;
        assert_eq!(error_eid.escaped(), "Eid::Error");

        let error_type = EntityType::ErrorEntityType;
        assert!(!error_type.is_action());
        assert_eq!(error_type.qualify_with(None), EntityType::ErrorEntityType);
        assert_eq!(
            error_type.qualify_with(Some(&Name(InternalName::from(Id::new_unchecked(
                "EntityTypeError"
            ))))),
            EntityType::ErrorEntityType
        );

        assert_eq!(
            error_type.name(),
            &Name(InternalName::from(Id::new_unchecked("EntityTypeError")))
        );
        assert_eq!(error_type.loc(), None)
    }

    #[test]
    fn entity_type_deserialization() {
        let json = r#""some_entity_type""#;
        let entity_type: EntityType = serde_json::from_str(json).unwrap();
        assert_eq!(
            entity_type.name().0.to_string(),
            "some_entity_type".to_string()
        )
    }

    #[test]
    fn entity_type_serialization() {
        let entity_type = EntityType::EntityType(Name(InternalName::from(Id::new_unchecked(
            "some_entity_type",
        ))));
        let serialized = serde_json::to_string(&entity_type).unwrap();

        assert_eq!(serialized, r#""some_entity_type""#);
    }
}