Skip to main content

cedar_policy_core/ast/
entity.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::ast::*;
18use crate::entities::{err::EntitiesError, json::err::JsonSerializationError, EntityJson};
19use crate::evaluator::{EvaluationError, RestrictedEvaluator};
20use crate::extensions::Extensions;
21use crate::parser::err::ParseErrors;
22use crate::parser::Loc;
23use crate::transitive_closure::TCNode;
24use crate::FromNormalizedStr;
25use educe::Educe;
26use itertools::Itertools;
27use miette::Diagnostic;
28use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
29use smol_str::SmolStr;
30use std::collections::{BTreeMap, HashMap, HashSet};
31use std::str::FromStr;
32use std::sync::Arc;
33use thiserror::Error;
34
/// Placeholder [`Name`] returned for `EntityType::ErrorEntityType` by `EntityType::name`,
/// `EntityType::into_name`, and `AsRef<Name>`.
#[cfg(feature = "tolerant-ast")]
static ERROR_NAME: std::sync::LazyLock<Name> = std::sync::LazyLock::new(|| {
    Name(InternalName::from(Id::new_unchecked_const(
        "EntityTypeError",
    )))
});

/// Text used for `Eid::ErrorEid` when it is serialized, escaped, or viewed as a `&str`.
///
/// A `const` rather than a `static`: no fixed address is needed for a string literal.
#[cfg(feature = "tolerant-ast")]
const EID_ERROR_STR: &str = "Eid::Error";

/// Text used for `EntityType::ErrorEntityType` when it is serialized or displayed.
#[cfg(feature = "tolerant-ast")]
const ENTITY_TYPE_ERROR_STR: &str = "EntityType::Error";

/// Text used for `EntityUID::Error` when it is serialized.
#[cfg(feature = "tolerant-ast")]
const ENTITY_UID_ERROR_STR: &str = "EntityUID::Error";

/// The entity type that Actions must have.
///
/// Only the basename of an entity type is compared against this value (see
/// `EntityType::is_action`), so namespaced types such as `Foo::Action` also qualify.
// Kept as a `pub static` (not `const`) to leave the public API unchanged.
pub static ACTION_ENTITY_TYPE: &str = "Action";
53
54#[derive(PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
55#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
56/// Entity type - can be an error type when 'tolerant-ast' feature is enabled
57pub enum EntityType {
58    /// Entity type names are just [`Name`]s, but we have some operations on them specific to entity types.
59    EntityType(Name),
60    #[cfg(feature = "tolerant-ast")]
61    /// Represents an error node of an entity that failed to parse
62    ErrorEntityType,
63}
64
65impl<'de> Deserialize<'de> for EntityType {
66    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
67    where
68        D: Deserializer<'de>,
69    {
70        let name = Name::deserialize(deserializer)?;
71        Ok(EntityType::EntityType(name))
72    }
73}
74
75impl Serialize for EntityType {
76    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
77    where
78        S: Serializer,
79    {
80        match self {
81            EntityType::EntityType(name) => name.serialize(serializer),
82            #[cfg(feature = "tolerant-ast")]
83            EntityType::ErrorEntityType => serializer.serialize_str(ENTITY_TYPE_ERROR_STR),
84        }
85    }
86}
87
88impl EntityType {
89    /// Is this an Action entity type?
90    /// Returns true when an entity type is an action entity type. This compares the
91    /// base name for the type, so this will return true for any entity type named
92    /// `Action` regardless of namespaces.
93    pub fn is_action(&self) -> bool {
94        match self {
95            EntityType::EntityType(name) => {
96                name.as_ref().basename() == &Id::new_unchecked_const(ACTION_ENTITY_TYPE)
97            }
98            #[cfg(feature = "tolerant-ast")]
99            EntityType::ErrorEntityType => false,
100        }
101    }
102
103    /// The name of this entity type
104    pub fn name(&self) -> &Name {
105        match self {
106            EntityType::EntityType(name) => name,
107            #[cfg(feature = "tolerant-ast")]
108            EntityType::ErrorEntityType => &ERROR_NAME,
109        }
110    }
111
112    /// Consumes this entity type and returns the owned name. Clones the statically defined name
113    /// for entity errors when this entity type is ErrorEntityType.
114    pub fn into_name(self) -> Name {
115        match self {
116            EntityType::EntityType(name) => name,
117            #[cfg(feature = "tolerant-ast")]
118            EntityType::ErrorEntityType => ERROR_NAME.clone(),
119        }
120    }
121
122    /// The source location of this entity type
123    pub fn loc(&self) -> Option<&Loc> {
124        match self {
125            EntityType::EntityType(name) => name.as_ref().loc(),
126            #[cfg(feature = "tolerant-ast")]
127            EntityType::ErrorEntityType => None,
128        }
129    }
130
131    /// Create a clone of this EntityType with given loc
132    pub fn with_loc(&self, loc: Option<&Loc>) -> Self {
133        match self {
134            EntityType::EntityType(name) => EntityType::EntityType(Name(InternalName {
135                id: name.0.id.clone(),
136                path: name.0.path.clone(),
137                loc: loc.cloned(),
138            })),
139            #[cfg(feature = "tolerant-ast")]
140            EntityType::ErrorEntityType => self.clone(),
141        }
142    }
143
144    /// Calls [`Name::qualify_with_name`] on the underlying [`Name`]
145    pub fn qualify_with(&self, namespace: Option<&Name>) -> Self {
146        match self {
147            EntityType::EntityType(name) => Self::EntityType(name.qualify_with_name(namespace)),
148            #[cfg(feature = "tolerant-ast")]
149            EntityType::ErrorEntityType => Self::ErrorEntityType,
150        }
151    }
152
153    /// Wraps [`Name::from_normalized_str`]
154    pub fn from_normalized_str(src: &str) -> Result<Self, ParseErrors> {
155        Name::from_normalized_str(src).map(Into::into)
156    }
157}
158
159impl From<Name> for EntityType {
160    fn from(n: Name) -> Self {
161        Self::EntityType(n)
162    }
163}
164
165impl From<EntityType> for Name {
166    fn from(ty: EntityType) -> Name {
167        match ty {
168            EntityType::EntityType(name) => name,
169            #[cfg(feature = "tolerant-ast")]
170            EntityType::ErrorEntityType => ERROR_NAME.clone(),
171        }
172    }
173}
174
175impl AsRef<Name> for EntityType {
176    fn as_ref(&self) -> &Name {
177        match self {
178            EntityType::EntityType(name) => name,
179            #[cfg(feature = "tolerant-ast")]
180            EntityType::ErrorEntityType => &ERROR_NAME,
181        }
182    }
183}
184
185impl FromStr for EntityType {
186    type Err = ParseErrors;
187
188    fn from_str(s: &str) -> Result<Self, Self::Err> {
189        s.parse().map(Self::EntityType)
190    }
191}
192
193impl std::fmt::Display for EntityType {
194    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195        match self {
196            EntityType::EntityType(name) => write!(f, "{name}"),
197            #[cfg(feature = "tolerant-ast")]
198            EntityType::ErrorEntityType => write!(f, "{ENTITY_TYPE_ERROR_STR}"),
199        }
200    }
201}
202
203/// Unique ID for an entity. These represent entities in the AST.
204#[derive(Educe, Serialize, Deserialize, Debug, Clone)]
205#[serde(rename = "EntityUID")]
206#[educe(PartialEq, Eq, Hash, PartialOrd, Ord)]
207pub struct EntityUIDImpl {
208    /// Typename of the entity
209    ty: EntityType,
210    /// EID of the entity
211    eid: Eid,
212    /// Location of the entity in policy source
213    #[serde(skip)]
214    #[educe(PartialEq(ignore))]
215    #[educe(Hash(ignore))]
216    #[educe(PartialOrd(ignore))]
217    loc: Option<Loc>,
218}
219
220impl EntityUIDImpl {
221    /// The source location of this entity
222    pub fn loc(&self) -> Option<Loc> {
223        self.loc.clone()
224    }
225}
226
227/// Unique ID for an entity. These represent entities in the AST.
228#[derive(Educe, Debug, Clone)]
229#[educe(PartialEq, Eq, Hash, PartialOrd, Ord)]
230pub enum EntityUID {
231    /// Unique ID for an entity. These represent entities in the AST
232    EntityUID(EntityUIDImpl),
233    #[cfg(feature = "tolerant-ast")]
234    /// Represents the ID of an error that failed to parse
235    Error,
236}
237
238impl<'de> Deserialize<'de> for EntityUID {
239    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
240    where
241        D: Deserializer<'de>,
242    {
243        let uid_impl = EntityUIDImpl::deserialize(deserializer)?;
244        Ok(EntityUID::EntityUID(uid_impl))
245    }
246}
247
248impl Serialize for EntityUID {
249    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
250    where
251        S: Serializer,
252    {
253        match self {
254            EntityUID::EntityUID(uid_impl) => uid_impl.serialize(serializer),
255            #[cfg(feature = "tolerant-ast")]
256            EntityUID::Error => serializer.serialize_str(ENTITY_UID_ERROR_STR),
257        }
258    }
259}
260
261impl StaticallyTyped for EntityUID {
262    fn type_of(&self) -> Type {
263        match self {
264            EntityUID::EntityUID(entity_uid) => Type::Entity {
265                ty: entity_uid.ty.clone(),
266            },
267            #[cfg(feature = "tolerant-ast")]
268            EntityUID::Error => Type::Entity {
269                ty: EntityType::ErrorEntityType,
270            },
271        }
272    }
273}
274
#[cfg(test)]
impl EntityUID {
    /// Test helper: a UID of type [`EntityUID::test_entity_type`] with `eid` as its EID and
    /// no source location.
    pub(crate) fn with_eid(eid: &str) -> Self {
        Self::from_components(Self::test_entity_type(), Eid::new(eid), None)
    }

    /// The entity type used by [`EntityUID::with_eid`]: the unqualified name `test_entity_type`.
    pub(crate) fn test_entity_type() -> EntityType {
        Name::parse_unqualified_name("test_entity_type")
            .map(EntityType::EntityType)
            .expect("test_entity_type should be a valid identifier")
    }
}
294
295impl EntityUID {
296    /// Create an `EntityUID` with the given (unqualified) typename, and the given string as its EID.
297    pub fn with_eid_and_type(typename: &str, eid: &str) -> Result<Self, ParseErrors> {
298        Ok(Self::EntityUID(EntityUIDImpl {
299            ty: EntityType::EntityType(Name::parse_unqualified_name(typename)?),
300            eid: Eid::Eid(eid.into()),
301            loc: None,
302        }))
303    }
304
305    /// Split into the `EntityType` representing the entity type, and the `Eid`
306    /// representing its name
307    pub fn components(self) -> (EntityType, Eid) {
308        match self {
309            EntityUID::EntityUID(entity_uid) => (entity_uid.ty, entity_uid.eid),
310            #[cfg(feature = "tolerant-ast")]
311            EntityUID::Error => (EntityType::ErrorEntityType, Eid::ErrorEid),
312        }
313    }
314
315    /// Get the source location for this `EntityUID`.
316    pub fn loc(&self) -> Option<&Loc> {
317        match self {
318            EntityUID::EntityUID(entity_uid) => entity_uid.loc.as_ref(),
319            #[cfg(feature = "tolerant-ast")]
320            EntityUID::Error => None,
321        }
322    }
323
324    /// Create an [`EntityUID`] with the given typename and [`Eid`]
325    pub fn from_components(ty: EntityType, eid: Eid, loc: Option<Loc>) -> Self {
326        Self::EntityUID(EntityUIDImpl { ty, eid, loc })
327    }
328
329    /// Get the type component.
330    pub fn entity_type(&self) -> &EntityType {
331        match self {
332            EntityUID::EntityUID(entity_uid) => &entity_uid.ty,
333            #[cfg(feature = "tolerant-ast")]
334            EntityUID::Error => &EntityType::ErrorEntityType,
335        }
336    }
337
338    /// Get the Eid component.
339    pub fn eid(&self) -> &Eid {
340        match self {
341            EntityUID::EntityUID(entity_uid) => &entity_uid.eid,
342            #[cfg(feature = "tolerant-ast")]
343            EntityUID::Error => &Eid::ErrorEid,
344        }
345    }
346
347    /// Does this EntityUID refer to an action entity?
348    pub fn is_action(&self) -> bool {
349        self.entity_type().is_action()
350    }
351}
352
353impl std::fmt::Display for EntityUID {
354    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
355        write!(f, "{}::\"{}\"", self.entity_type(), self.eid().escaped())
356    }
357}
358
359// allow `.parse()` on a string to make an `EntityUID`
360impl std::str::FromStr for EntityUID {
361    type Err = ParseErrors;
362
363    fn from_str(s: &str) -> Result<Self, Self::Err> {
364        crate::parser::parse_euid(s)
365    }
366}
367
368impl FromNormalizedStr for EntityUID {
369    fn describe_self() -> &'static str {
370        "Entity UID"
371    }
372}
373
374#[cfg(feature = "arbitrary")]
375impl<'a> arbitrary::Arbitrary<'a> for EntityUID {
376    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
377        Ok(Self::EntityUID(EntityUIDImpl {
378            ty: u.arbitrary()?,
379            eid: u.arbitrary()?,
380            loc: None,
381        }))
382    }
383}
384
385/// The `Eid` type represents the id of an `Entity`, without the typename.
386/// Together with the typename it comprises an `EntityUID`.
387/// For example, in `User::"alice"`, the `Eid` is `alice`.
388///
389/// `Eid` does not implement `Display`, partly because it is unclear whether
390/// `Display` should produce an escaped representation or an unescaped representation
391/// (see [#884](https://github.com/cedar-policy/cedar/issues/884)).
392/// To get an escaped representation, use `.escaped()`.
393/// To get an unescaped representation, use `.as_ref()`.
394#[derive(PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
395pub enum Eid {
396    /// Actual Eid
397    Eid(SmolStr),
398    #[cfg(feature = "tolerant-ast")]
399    /// Represents an Eid of an entity that failed to parse
400    ErrorEid,
401}
402
403impl<'de> Deserialize<'de> for Eid {
404    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
405    where
406        D: Deserializer<'de>,
407    {
408        let value = String::deserialize(deserializer)?;
409        Ok(Eid::Eid(SmolStr::from(value)))
410    }
411}
412
413impl Serialize for Eid {
414    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
415    where
416        S: Serializer,
417    {
418        match self {
419            Eid::Eid(s) => s.serialize(serializer),
420            #[cfg(feature = "tolerant-ast")]
421            Eid::ErrorEid => serializer.serialize_str(EID_ERROR_STR),
422        }
423    }
424}
425
426impl Eid {
427    /// Construct an Eid
428    pub fn new(eid: impl Into<SmolStr>) -> Self {
429        Eid::Eid(eid.into())
430    }
431
432    /// Get the contents of the `Eid` as an escaped string
433    pub fn escaped(&self) -> SmolStr {
434        match self {
435            Eid::Eid(smol_str) => smol_str.escape_debug().collect(),
436            #[cfg(feature = "tolerant-ast")]
437            Eid::ErrorEid => SmolStr::new_static(EID_ERROR_STR),
438        }
439    }
440
441    /// Get the underlying smolstr for this `Eid`
442    pub fn into_smolstr(self) -> SmolStr {
443        match self {
444            Eid::Eid(smol_str) => smol_str,
445            #[cfg(feature = "tolerant-ast")]
446            Eid::ErrorEid => SmolStr::new_static(EID_ERROR_STR),
447        }
448    }
449}
450
451impl AsRef<str> for Eid {
452    fn as_ref(&self) -> &str {
453        match self {
454            Eid::Eid(smol_str) => smol_str,
455            #[cfg(feature = "tolerant-ast")]
456            Eid::ErrorEid => EID_ERROR_STR,
457        }
458    }
459}
460
461#[cfg(feature = "arbitrary")]
462impl<'a> arbitrary::Arbitrary<'a> for Eid {
463    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
464        let x: String = u.arbitrary()?;
465        Ok(Self::Eid(x.into()))
466    }
467}
468
469/// Entity datatype
470#[derive(Debug, Clone)]
471pub struct Entity {
472    /// UID
473    uid: EntityUID,
474
475    /// Internal `BTreeMap` of attributes.
476    ///
477    /// We use a `BTreeMap` so that the keys have a deterministic order.
478    attrs: BTreeMap<SmolStr, PartialValue>,
479
480    /// Set of indirect ancestors of this `Entity` as UIDs
481    indirect_ancestors: HashSet<EntityUID>,
482
483    /// Set of direct ancestors (i.e., parents) as UIDs
484    ///
485    /// indirect_ancestors and parents should be disjoint
486    /// even if a parent is also an indirect parent through
487    /// a different parent
488    parents: HashSet<EntityUID>,
489
490    /// Tags on this entity (RFC 82)
491    ///
492    /// Like for `attrs`, we use a `BTreeMap` so that the tags have a
493    /// deterministic order.
494    tags: BTreeMap<SmolStr, PartialValue>,
495}
496
497impl std::hash::Hash for Entity {
498    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
499        self.uid.hash(state);
500    }
501}
502
503impl Entity {
504    /// Create a new `Entity` with this UID, attributes, ancestors, and tags
505    ///
506    /// # Errors
507    /// - Will error if any of the [`RestrictedExpr]`s in `attrs` or `tags` error when evaluated
508    pub fn new(
509        uid: EntityUID,
510        attrs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
511        indirect_ancestors: HashSet<EntityUID>,
512        parents: HashSet<EntityUID>,
513        tags: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
514        extensions: &Extensions<'_>,
515    ) -> Result<Self, EntityAttrEvaluationError> {
516        let evaluator = RestrictedEvaluator::new(extensions);
517        let evaluate_kvs = |(k, v): (SmolStr, RestrictedExpr), was_attr: bool| {
518            let attr_val = evaluator
519                .partial_interpret(v.as_borrowed())
520                .map_err(|err| EntityAttrEvaluationError {
521                    uid: uid.clone(),
522                    attr_or_tag: k.clone(),
523                    was_attr,
524                    err,
525                })?;
526            Ok((k, attr_val))
527        };
528        let evaluated_attrs = attrs
529            .into_iter()
530            .map(|kv| evaluate_kvs(kv, true))
531            .collect::<Result<_, EntityAttrEvaluationError>>()?;
532        let evaluated_tags = tags
533            .into_iter()
534            .map(|kv| evaluate_kvs(kv, false))
535            .collect::<Result<_, EntityAttrEvaluationError>>()?;
536        Ok(Entity {
537            uid,
538            attrs: evaluated_attrs,
539            indirect_ancestors,
540            parents,
541            tags: evaluated_tags,
542        })
543    }
544
545    /// Create a new [`Entity`] with this UID, attributes, ancestors, and tags
546    ///
547    /// Unlike in `Entity::new()`, in this constructor, attributes and tags are
548    /// expressed as `PartialValue`.
549    pub fn new_with_attr_partial_value(
550        uid: EntityUID,
551        attrs: impl IntoIterator<Item = (SmolStr, PartialValue)>,
552        indirect_ancestors: HashSet<EntityUID>,
553        parents: HashSet<EntityUID>,
554        tags: impl IntoIterator<Item = (SmolStr, PartialValue)>,
555    ) -> Self {
556        Self {
557            uid,
558            attrs: attrs.into_iter().collect(),
559            indirect_ancestors,
560            parents,
561            tags: tags.into_iter().collect(),
562        }
563    }
564
565    /// Get the UID of this entity
566    pub fn uid(&self) -> &EntityUID {
567        &self.uid
568    }
569
570    /// Get the value for the given attribute, or `None` if not present
571    pub fn get(&self, attr: &str) -> Option<&PartialValue> {
572        self.attrs.get(attr)
573    }
574
575    /// Get the value for the given tag, or `None` if not present
576    pub fn get_tag(&self, tag: &str) -> Option<&PartialValue> {
577        self.tags.get(tag)
578    }
579
580    /// Is this `Entity` a (direct or indirect) descendant of `e` in the entity hierarchy?
581    pub fn is_descendant_of(&self, e: &EntityUID) -> bool {
582        self.parents.contains(e) || self.indirect_ancestors.contains(e)
583    }
584
585    /// Is this `Entity` a an indirect descendant of `e` in the entity hierarchy?
586    pub fn is_indirect_descendant_of(&self, e: &EntityUID) -> bool {
587        self.indirect_ancestors.contains(e)
588    }
589
590    /// Is this `Entity` a direct decendant (child) of `e` in the entity hierarchy?
591    pub fn is_child_of(&self, e: &EntityUID) -> bool {
592        self.parents.contains(e)
593    }
594
595    /// Iterate over this entity's (direct or indirect) ancestors
596    pub fn ancestors(&self) -> impl Iterator<Item = &EntityUID> {
597        self.parents.iter().chain(self.indirect_ancestors.iter())
598    }
599
600    /// Iterate over this entity's indirect ancestors
601    pub fn indirect_ancestors(&self) -> impl Iterator<Item = &EntityUID> {
602        self.indirect_ancestors.iter()
603    }
604
605    /// Iterate over this entity's direct ancestors (parents)
606    pub fn parents(&self) -> impl Iterator<Item = &EntityUID> {
607        self.parents.iter()
608    }
609
610    /// Get the number of attributes on this entity
611    pub fn attrs_len(&self) -> usize {
612        self.attrs.len()
613    }
614
615    /// Get the number of tags on this entity
616    pub fn tags_len(&self) -> usize {
617        self.tags.len()
618    }
619
620    /// Iterate over this entity's attribute names
621    pub fn keys(&self) -> impl Iterator<Item = &SmolStr> {
622        self.attrs.keys()
623    }
624
625    /// Iterate over this entity's tag names
626    pub fn tag_keys(&self) -> impl Iterator<Item = &SmolStr> {
627        self.tags.keys()
628    }
629
630    /// Iterate over this entity's attributes
631    pub fn attrs(&self) -> impl Iterator<Item = (&SmolStr, &PartialValue)> {
632        self.attrs.iter()
633    }
634
635    /// Iterate over this entity's tags
636    pub fn tags(&self) -> impl Iterator<Item = (&SmolStr, &PartialValue)> {
637        self.tags.iter()
638    }
639
640    /// Create an `Entity` with the given UID, no attributes, no parents, and no tags.
641    pub fn with_uid(uid: EntityUID) -> Self {
642        Self {
643            uid,
644            attrs: BTreeMap::new(),
645            indirect_ancestors: HashSet::new(),
646            parents: HashSet::new(),
647            tags: BTreeMap::new(),
648        }
649    }
650
651    /// Test if two `Entity` objects are deep/structurally equal.
652    /// That is, not only do they have the same UID, but also the same
653    /// attributes, attribute values, and ancestors/parents.
654    ///
655    /// Does not test that they have the same _direct_ parents, only that they have the same overall ancestor set.
656    pub fn deep_eq(&self, other: &Self) -> bool {
657        self.uid == other.uid
658            && self.attrs == other.attrs
659            && self.tags == other.tags
660            && (self.ancestors().collect::<HashSet<_>>())
661                == (other.ancestors().collect::<HashSet<_>>())
662    }
663
664    /// Mark the given `UID` as an indirect ancestor of this `Entity`
665    ///
666    /// The given `UID` will not be added as an indirecty ancestor if
667    /// it is already a direct ancestor (parent) of this `Entity`
668    /// The caller of this code is responsible for maintaining
669    /// transitive closure of hierarchy.
670    pub fn add_indirect_ancestor(&mut self, uid: EntityUID) {
671        if !self.parents.contains(&uid) {
672            self.indirect_ancestors.insert(uid);
673        }
674    }
675
676    /// Mark the given `UID` as a (direct) parent of this `Entity`, and
677    /// remove the UID from indirect ancestors
678    /// if it was previously added as an indirect ancestor
679    /// The caller of this code is responsible for maintaining
680    /// transitive closure of hierarchy.
681    pub fn add_parent(&mut self, uid: EntityUID) {
682        self.indirect_ancestors.remove(&uid);
683        self.parents.insert(uid);
684    }
685
686    /// Remove the given `UID` as an indirect ancestor of this `Entity`.
687    ///
688    /// No effect if the `UID` is a direct parent.
689    /// The caller of this code is responsible for maintaining
690    /// transitive closure of hierarchy.
691    pub fn remove_indirect_ancestor(&mut self, uid: &EntityUID) {
692        self.indirect_ancestors.remove(uid);
693    }
694
695    /// Remove the given `UID` as a (direct) parent of this `Entity`.
696    ///
697    /// No effect on the `Entity`'s indirect ancestors.
698    /// The caller of this code is responsible for maintaining
699    /// transitive closure of hierarchy.
700    pub fn remove_parent(&mut self, uid: &EntityUID) {
701        self.parents.remove(uid);
702    }
703
704    /// Remove all indirect ancestors of this `Entity`.
705    ///
706    /// The caller of this code is responsible for maintaining
707    /// transitive closure of hierarchy.
708    pub fn remove_all_indirect_ancestors(&mut self) {
709        self.indirect_ancestors.clear();
710    }
711
712    /// Consume the entity and return the entity's owned Uid, attributes, ancestors, parents, and tags.
713    #[expect(
714        clippy::type_complexity,
715        reason = "needs to return a 5-tuple by design"
716    )]
717    pub fn into_inner(
718        self,
719    ) -> (
720        EntityUID,
721        HashMap<SmolStr, PartialValue>,
722        HashSet<EntityUID>,
723        HashSet<EntityUID>,
724        HashMap<SmolStr, PartialValue>,
725    ) {
726        (
727            self.uid,
728            self.attrs.into_iter().collect(),
729            self.indirect_ancestors,
730            self.parents,
731            self.tags.into_iter().collect(),
732        )
733    }
734
735    /// Write the entity to a json document
736    pub fn write_to_json(&self, f: impl std::io::Write) -> Result<(), EntitiesError> {
737        let ejson = EntityJson::from_entity(self)?;
738        serde_json::to_writer_pretty(f, &ejson).map_err(JsonSerializationError::from)?;
739        Ok(())
740    }
741
742    /// write the entity to a json value
743    pub fn to_json_value(&self) -> Result<serde_json::Value, EntitiesError> {
744        let ejson = EntityJson::from_entity(self)?;
745        let v = serde_json::to_value(ejson).map_err(JsonSerializationError::from)?;
746        Ok(v)
747    }
748
749    /// write the entity to a json string
750    pub fn to_json_string(&self) -> Result<String, EntitiesError> {
751        let ejson = EntityJson::from_entity(self)?;
752        let string = serde_json::to_string(&ejson).map_err(JsonSerializationError::from)?;
753        Ok(string)
754    }
755}
756
757/// `Entity`s are equal if their UIDs are equal
758impl PartialEq for Entity {
759    fn eq(&self, other: &Self) -> bool {
760        self.uid() == other.uid()
761    }
762}
763
764impl Eq for Entity {}
765
766impl StaticallyTyped for Entity {
767    fn type_of(&self) -> Type {
768        self.uid.type_of()
769    }
770}
771
772impl TCNode<EntityUID> for Entity {
773    fn get_key(&self) -> EntityUID {
774        self.uid().clone()
775    }
776
777    fn add_edge_to(&mut self, k: EntityUID) {
778        self.add_indirect_ancestor(k);
779    }
780
781    fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
782        Box::new(self.ancestors())
783    }
784
785    fn has_edge_to(&self, e: &EntityUID) -> bool {
786        self.is_descendant_of(e)
787    }
788
789    fn reset_edges(&mut self) {
790        self.remove_all_indirect_ancestors()
791    }
792
793    fn direct_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
794        Box::new(self.parents())
795    }
796}
797
798impl TCNode<EntityUID> for Arc<Entity> {
799    fn get_key(&self) -> EntityUID {
800        self.uid().clone()
801    }
802
803    fn add_edge_to(&mut self, k: EntityUID) {
804        // Use Arc::make_mut to get a mutable reference to the inner value
805        Arc::make_mut(self).add_indirect_ancestor(k)
806    }
807
808    fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
809        Box::new(self.ancestors())
810    }
811
812    fn has_edge_to(&self, e: &EntityUID) -> bool {
813        self.is_descendant_of(e)
814    }
815
816    fn reset_edges(&mut self) {
817        // Use Arc::make_mut to get a mutable reference to the inner value
818        Arc::make_mut(self).remove_all_indirect_ancestors()
819    }
820
821    fn direct_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
822        Box::new(self.parents())
823    }
824}
825
826impl std::fmt::Display for Entity {
827    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
828        write!(
829            f,
830            "{}:\n  attrs:{}\n  ancestors:{}",
831            self.uid,
832            self.attrs
833                .iter()
834                .map(|(k, v)| format!("{k}: {v}"))
835                .join("; "),
836            self.ancestors().join(", ")
837        )
838    }
839}
840
841/// Error type for evaluation errors when evaluating an entity attribute or tag.
842/// Contains some extra contextual information and the underlying
843/// `EvaluationError`.
844//
845// This is NOT a publicly exported error type.
846#[derive(Debug, Diagnostic, Error)]
847#[error("failed to evaluate {} `{attr_or_tag}` of `{uid}`: {err}", if *.was_attr { "attribute" } else { "tag" })]
848pub struct EntityAttrEvaluationError {
849    /// UID of the entity where the error was encountered
850    pub uid: EntityUID,
851    /// Attribute or tag of the entity where the error was encountered
852    pub attr_or_tag: SmolStr,
853    /// If `attr_or_tag` was an attribute (`true`) or tag (`false`)
854    pub was_attr: bool,
855    /// Underlying evaluation error
856    #[diagnostic(transparent)]
857    pub err: EvaluationError,
858}
859
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::*;

    #[test]
    fn display() {
        let e = EntityUID::with_eid("eid");
        assert_eq!(format!("{e}"), "test_entity_type::\"eid\"");
    }

    #[test]
    fn test_euid_equality() {
        let e1 = EntityUID::with_eid("foo");
        let e2 = EntityUID::from_components(
            Name::parse_unqualified_name("test_entity_type")
                .expect("should be a valid identifier")
                .into(),
            Eid::Eid("foo".into()),
            None,
        );
        let e3 = EntityUID::from_components(
            Name::parse_unqualified_name("Unspecified")
                .expect("should be a valid identifier")
                .into(),
            Eid::Eid("foo".into()),
            None,
        );

        // an EUID is equal to itself
        assert_eq!(e1, e1);
        assert_eq!(e2, e2);

        // constructing with `with_eid` or `from_components` is the same
        assert_eq!(e1, e2);

        // other pairs are not equal
        assert_ne!(e1, e3);
        assert_ne!(e2, e3);
    }

    #[test]
    fn action_checker() {
        let euid = EntityUID::from_str("Action::\"view\"").unwrap();
        assert!(euid.is_action());
        let euid = EntityUID::from_str("Foo::Action::\"view\"").unwrap();
        assert!(euid.is_action());
        let euid = EntityUID::from_str("Foo::\"view\"").unwrap();
        assert!(!euid.is_action());
        let euid = EntityUID::from_str("Action::Foo::\"view\"").unwrap();
        assert!(!euid.is_action());
    }

    #[test]
    fn action_type_is_valid_id() {
        Id::from_normalized_str(ACTION_ENTITY_TYPE).unwrap();
    }

    #[cfg(feature = "tolerant-ast")]
    #[test]
    fn error_entity() {
        use cool_asserts::assert_matches;

        let e = EntityUID::Error;
        assert_matches!(e.eid(), Eid::ErrorEid);
        assert_matches!(e.entity_type(), EntityType::ErrorEntityType);
        assert!(!e.is_action());
        assert_matches!(e.loc(), None);

        let error_eid = Eid::ErrorEid;
        assert_eq!(error_eid.escaped(), "Eid::Error");

        let error_type = EntityType::ErrorEntityType;
        assert!(!error_type.is_action());
        assert_eq!(error_type.qualify_with(None), EntityType::ErrorEntityType);
        assert_eq!(
            error_type.qualify_with(Some(&Name(InternalName::from(Id::new_unchecked_const(
                "EntityTypeError"
            ))))),
            EntityType::ErrorEntityType
        );

        assert_eq!(
            error_type.name(),
            &Name(InternalName::from(Id::new_unchecked_const(
                "EntityTypeError"
            )))
        );
        assert_eq!(error_type.loc(), None)
    }

    #[test]
    fn entity_type_deserialization() {
        let json = r#""some_entity_type""#;
        let entity_type: EntityType = serde_json::from_str(json).unwrap();
        assert_eq!(
            entity_type.name().0.to_string(),
            "some_entity_type".to_string()
        )
    }

    #[test]
    fn entity_type_serialization() {
        let entity_type = EntityType::EntityType(Name(InternalName::from(
            Id::new_unchecked_const("some_entity_type"),
        )));
        let serialized = serde_json::to_string(&entity_type).unwrap();

        assert_eq!(serialized, r#""some_entity_type""#);
    }
}