cedar_policy_core/ast/
entity.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::ast::*;
18use crate::entities::{err::EntitiesError, json::err::JsonSerializationError, EntityJson};
19use crate::evaluator::{EvaluationError, RestrictedEvaluator};
20use crate::extensions::Extensions;
21use crate::parser::err::ParseErrors;
22use crate::parser::Loc;
23use crate::transitive_closure::TCNode;
24use crate::FromNormalizedStr;
25use educe::Educe;
26use itertools::Itertools;
27use miette::Diagnostic;
28use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
29use smol_str::SmolStr;
30use std::collections::{BTreeMap, HashMap, HashSet};
31use std::str::FromStr;
32use std::sync::Arc;
33use thiserror::Error;
34
// Sentinel values standing in for nodes that failed to parse. They only exist
// when the `tolerant-ast` feature is enabled.

/// Placeholder [`Name`] reported for `EntityType::ErrorEntityType`.
#[cfg(feature = "tolerant-ast")]
static ERROR_NAME: std::sync::LazyLock<Name> =
    std::sync::LazyLock::new(|| Name(InternalName::from(Id::new_unchecked("EntityTypeError"))));

/// [`SmolStr`] form of `Eid::ErrorEid`.
///
/// Built from `EID_ERROR_STR` so that the `AsRef<SmolStr>` and `AsRef<str>`
/// views of `Eid::ErrorEid` (and its escaped/serialized forms) all agree.
#[cfg(feature = "tolerant-ast")]
static ERROR_EID_SMOL_STR: std::sync::LazyLock<SmolStr> =
    std::sync::LazyLock::new(|| SmolStr::from(EID_ERROR_STR));

/// String form of `Eid::ErrorEid`.
#[cfg(feature = "tolerant-ast")]
static EID_ERROR_STR: &str = "Eid::Error";

/// String form of `EntityType::ErrorEntityType`, used for display and serialization.
#[cfg(feature = "tolerant-ast")]
static ENTITY_TYPE_ERROR_STR: &str = "EntityType::Error";

/// String form of `EntityUID::Error`, used for serialization.
#[cfg(feature = "tolerant-ast")]
static ENTITY_UID_ERROR_STR: &str = "EntityUID::Error";
51
/// The entity type that Actions must have.
///
/// [`EntityType::is_action`] compares only the basename of a type against
/// this value, so namespaced types such as `Foo::Action` also qualify.
pub static ACTION_ENTITY_TYPE: &str = "Action";
54
#[derive(PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
/// The type of an entity.
///
/// When the `tolerant-ast` feature is enabled, this can also be an error
/// placeholder standing in for an entity type that failed to parse.
pub enum EntityType {
    /// Entity type names are just [`Name`]s, but we have some operations on them specific to entity types.
    EntityType(Name),
    #[cfg(feature = "tolerant-ast")]
    /// Represents an error node of an entity that failed to parse
    ErrorEntityType,
}
65
66impl<'de> Deserialize<'de> for EntityType {
67    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
68    where
69        D: Deserializer<'de>,
70    {
71        let name = Name::deserialize(deserializer)?;
72        Ok(EntityType::EntityType(name))
73    }
74}
75
76impl Serialize for EntityType {
77    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
78    where
79        S: Serializer,
80    {
81        match self {
82            EntityType::EntityType(name) => name.serialize(serializer),
83            #[cfg(feature = "tolerant-ast")]
84            EntityType::ErrorEntityType => serializer.serialize_str(ENTITY_TYPE_ERROR_STR),
85        }
86    }
87}
88
89impl EntityType {
90    /// Is this an Action entity type?
91    /// Returns true when an entity type is an action entity type. This compares the
92    /// base name for the type, so this will return true for any entity type named
93    /// `Action` regardless of namespaces.
94    pub fn is_action(&self) -> bool {
95        match self {
96            EntityType::EntityType(name) => {
97                name.as_ref().basename() == &Id::new_unchecked(ACTION_ENTITY_TYPE)
98            }
99            #[cfg(feature = "tolerant-ast")]
100            EntityType::ErrorEntityType => false,
101        }
102    }
103
104    /// The name of this entity type
105    pub fn name(&self) -> &Name {
106        match self {
107            EntityType::EntityType(name) => name,
108            #[cfg(feature = "tolerant-ast")]
109            EntityType::ErrorEntityType => &ERROR_NAME,
110        }
111    }
112
113    /// The source location of this entity type
114    pub fn loc(&self) -> Option<&Loc> {
115        match self {
116            EntityType::EntityType(name) => name.as_ref().loc(),
117            #[cfg(feature = "tolerant-ast")]
118            EntityType::ErrorEntityType => None,
119        }
120    }
121
122    /// Create a clone of this EntityType with given loc
123    pub fn with_loc(&self, loc: Option<&Loc>) -> Self {
124        match self {
125            EntityType::EntityType(name) => EntityType::EntityType(Name(InternalName {
126                id: name.0.id.clone(),
127                path: name.0.path.clone(),
128                loc: loc.cloned(),
129            })),
130            #[cfg(feature = "tolerant-ast")]
131            EntityType::ErrorEntityType => self.clone(),
132        }
133    }
134
135    /// Calls [`Name::qualify_with_name`] on the underlying [`Name`]
136    pub fn qualify_with(&self, namespace: Option<&Name>) -> Self {
137        match self {
138            EntityType::EntityType(name) => Self::EntityType(name.qualify_with_name(namespace)),
139            #[cfg(feature = "tolerant-ast")]
140            EntityType::ErrorEntityType => Self::ErrorEntityType,
141        }
142    }
143
144    /// Wraps [`Name::from_normalized_str`]
145    pub fn from_normalized_str(src: &str) -> Result<Self, ParseErrors> {
146        Name::from_normalized_str(src).map(Into::into)
147    }
148}
149
150impl From<Name> for EntityType {
151    fn from(n: Name) -> Self {
152        Self::EntityType(n)
153    }
154}
155
156impl From<EntityType> for Name {
157    fn from(ty: EntityType) -> Name {
158        match ty {
159            EntityType::EntityType(name) => name,
160            #[cfg(feature = "tolerant-ast")]
161            EntityType::ErrorEntityType => ERROR_NAME.clone(),
162        }
163    }
164}
165
166impl AsRef<Name> for EntityType {
167    fn as_ref(&self) -> &Name {
168        match self {
169            EntityType::EntityType(name) => name,
170            #[cfg(feature = "tolerant-ast")]
171            EntityType::ErrorEntityType => &ERROR_NAME,
172        }
173    }
174}
175
176impl FromStr for EntityType {
177    type Err = ParseErrors;
178
179    fn from_str(s: &str) -> Result<Self, Self::Err> {
180        s.parse().map(Self::EntityType)
181    }
182}
183
184impl std::fmt::Display for EntityType {
185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186        match self {
187            EntityType::EntityType(name) => write!(f, "{}", name),
188            #[cfg(feature = "tolerant-ast")]
189            EntityType::ErrorEntityType => write!(f, "{ENTITY_TYPE_ERROR_STR}"),
190        }
191    }
192}
193
/// Unique ID for an entity. These represent entities in the AST.
///
/// Consists of an entity type and an [`Eid`], plus an optional source
/// location. Serialized under the name `EntityUID`.
#[derive(Educe, Serialize, Deserialize, Debug, Clone)]
#[serde(rename = "EntityUID")]
#[educe(PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct EntityUIDImpl {
    /// Typename of the entity
    ty: EntityType,
    /// EID of the entity
    eid: Eid,
    /// Location of the entity in policy source
    ///
    /// Skipped by serde, and ignored by the derived `PartialEq`, `Hash`, and
    /// `PartialOrd` implementations.
    // NOTE(review): there is no `#[educe(Ord(ignore))]` here — confirm that
    // the derived `Ord` also ignores `loc`, so it stays consistent with `Eq`.
    #[serde(skip)]
    #[educe(PartialEq(ignore))]
    #[educe(Hash(ignore))]
    #[educe(PartialOrd(ignore))]
    loc: Option<Loc>,
}
210
/// Unique ID for an entity. These represent entities in the AST.
///
/// When the `tolerant-ast` feature is enabled, this can also be an error
/// placeholder for an entity UID that failed to parse.
#[derive(Educe, Debug, Clone)]
#[educe(PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum EntityUID {
    /// Unique ID for an entity. These represent entities in the AST
    EntityUID(EntityUIDImpl),
    #[cfg(feature = "tolerant-ast")]
    /// Represents the ID of an error that failed to parse
    Error,
}
221
222impl<'de> Deserialize<'de> for EntityUID {
223    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
224    where
225        D: Deserializer<'de>,
226    {
227        let uid_impl = EntityUIDImpl::deserialize(deserializer)?;
228        Ok(EntityUID::EntityUID(uid_impl))
229    }
230}
231
232impl Serialize for EntityUID {
233    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
234    where
235        S: Serializer,
236    {
237        match self {
238            EntityUID::EntityUID(uid_impl) => uid_impl.serialize(serializer),
239            #[cfg(feature = "tolerant-ast")]
240            EntityUID::Error => serializer.serialize_str(ENTITY_UID_ERROR_STR),
241        }
242    }
243}
244
245impl StaticallyTyped for EntityUID {
246    fn type_of(&self) -> Type {
247        match self {
248            EntityUID::EntityUID(entity_uid) => Type::Entity {
249                ty: entity_uid.ty.clone(),
250            },
251            #[cfg(feature = "tolerant-ast")]
252            EntityUID::Error => Type::Entity {
253                ty: EntityType::ErrorEntityType,
254            },
255        }
256    }
257}
258
#[cfg(test)]
impl EntityUID {
    /// Build an `EntityUID` whose EID is the given string and whose type is
    /// [`EntityUID::test_entity_type`]. Only available in tests.
    pub(crate) fn with_eid(eid: &str) -> Self {
        Self::from_components(Self::test_entity_type(), Eid::new(eid), None)
    }

    /// The entity type used by `with_eid()`: the unqualified name
    /// `test_entity_type`.
    pub(crate) fn test_entity_type() -> EntityType {
        Name::parse_unqualified_name("test_entity_type")
            .map(EntityType::EntityType)
            .expect("test_entity_type should be a valid identifier")
    }
}
278
279impl EntityUID {
280    /// Create an `EntityUID` with the given (unqualified) typename, and the given string as its EID.
281    pub fn with_eid_and_type(typename: &str, eid: &str) -> Result<Self, ParseErrors> {
282        Ok(Self::EntityUID(EntityUIDImpl {
283            ty: EntityType::EntityType(Name::parse_unqualified_name(typename)?),
284            eid: Eid::Eid(eid.into()),
285            loc: None,
286        }))
287    }
288
289    /// Split into the `EntityType` representing the entity type, and the `Eid`
290    /// representing its name
291    pub fn components(self) -> (EntityType, Eid) {
292        match self {
293            EntityUID::EntityUID(entity_uid) => (entity_uid.ty, entity_uid.eid),
294            #[cfg(feature = "tolerant-ast")]
295            EntityUID::Error => (EntityType::ErrorEntityType, Eid::ErrorEid),
296        }
297    }
298
299    /// Get the source location for this `EntityUID`.
300    pub fn loc(&self) -> Option<&Loc> {
301        match self {
302            EntityUID::EntityUID(entity_uid) => entity_uid.loc.as_ref(),
303            #[cfg(feature = "tolerant-ast")]
304            EntityUID::Error => None,
305        }
306    }
307
308    /// Create an [`EntityUID`] with the given typename and [`Eid`]
309    pub fn from_components(ty: EntityType, eid: Eid, loc: Option<Loc>) -> Self {
310        Self::EntityUID(EntityUIDImpl { ty, eid, loc })
311    }
312
313    /// Get the type component.
314    pub fn entity_type(&self) -> &EntityType {
315        match self {
316            EntityUID::EntityUID(entity_uid) => &entity_uid.ty,
317            #[cfg(feature = "tolerant-ast")]
318            EntityUID::Error => &EntityType::ErrorEntityType,
319        }
320    }
321
322    /// Get the Eid component.
323    pub fn eid(&self) -> &Eid {
324        match self {
325            EntityUID::EntityUID(entity_uid) => &entity_uid.eid,
326            #[cfg(feature = "tolerant-ast")]
327            EntityUID::Error => &Eid::ErrorEid,
328        }
329    }
330
331    /// Does this EntityUID refer to an action entity?
332    pub fn is_action(&self) -> bool {
333        self.entity_type().is_action()
334    }
335}
336
337impl std::fmt::Display for EntityUID {
338    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339        write!(f, "{}::\"{}\"", self.entity_type(), self.eid().escaped())
340    }
341}
342
343// allow `.parse()` on a string to make an `EntityUID`
344impl std::str::FromStr for EntityUID {
345    type Err = ParseErrors;
346
347    fn from_str(s: &str) -> Result<Self, Self::Err> {
348        crate::parser::parse_euid(s)
349    }
350}
351
352impl FromNormalizedStr for EntityUID {
353    fn describe_self() -> &'static str {
354        "Entity UID"
355    }
356}
357
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for EntityUID {
    /// Generate an arbitrary (non-error) `EntityUID` with no source location.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // The type is drawn before the EID, so the input bytes are consumed in that order.
        let ty = u.arbitrary()?;
        let eid = u.arbitrary()?;
        Ok(Self::from_components(ty, eid, None))
    }
}
368
/// The `Eid` type represents the id of an `Entity`, without the typename.
/// Together with the typename it comprises an `EntityUID`.
/// For example, in `User::"alice"`, the `Eid` is `alice`.
///
/// `Eid` does not implement `Display`, partly because it is unclear whether
/// `Display` should produce an escaped representation or an unescaped representation
/// (see [#884](https://github.com/cedar-policy/cedar/issues/884)).
/// To get an escaped representation, use `.escaped()`.
/// To get an unescaped representation, use `.as_ref()`.
///
/// When the `tolerant-ast` feature is enabled, an `Eid` can also be an error
/// placeholder for an entity that failed to parse.
#[derive(PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
pub enum Eid {
    /// Actual Eid
    Eid(SmolStr),
    #[cfg(feature = "tolerant-ast")]
    /// Represents an Eid of an entity that failed to parse
    ErrorEid,
}
386
387impl<'de> Deserialize<'de> for Eid {
388    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
389    where
390        D: Deserializer<'de>,
391    {
392        let value = String::deserialize(deserializer)?;
393        Ok(Eid::Eid(SmolStr::from(value)))
394    }
395}
396
397impl Serialize for Eid {
398    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
399    where
400        S: Serializer,
401    {
402        match self {
403            Eid::Eid(s) => s.serialize(serializer),
404            #[cfg(feature = "tolerant-ast")]
405            Eid::ErrorEid => serializer.serialize_str(EID_ERROR_STR),
406        }
407    }
408}
409
410impl Eid {
411    /// Construct an Eid
412    pub fn new(eid: impl Into<SmolStr>) -> Self {
413        Eid::Eid(eid.into())
414    }
415
416    /// Get the contents of the `Eid` as an escaped string
417    pub fn escaped(&self) -> SmolStr {
418        match self {
419            Eid::Eid(smol_str) => smol_str.escape_debug().collect(),
420            #[cfg(feature = "tolerant-ast")]
421            Eid::ErrorEid => EID_ERROR_STR.into(),
422        }
423    }
424}
425
426impl AsRef<SmolStr> for Eid {
427    fn as_ref(&self) -> &SmolStr {
428        match self {
429            Eid::Eid(smol_str) => smol_str,
430            #[cfg(feature = "tolerant-ast")]
431            Eid::ErrorEid => &ERROR_EID_SMOL_STR,
432        }
433    }
434}
435
436impl AsRef<str> for Eid {
437    fn as_ref(&self) -> &str {
438        match self {
439            Eid::Eid(smol_str) => smol_str,
440            #[cfg(feature = "tolerant-ast")]
441            Eid::ErrorEid => EID_ERROR_STR,
442        }
443    }
444}
445
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Eid {
    /// Generate an arbitrary (non-error) `Eid` from an arbitrary `String`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let raw = u.arbitrary::<String>()?;
        Ok(Self::Eid(SmolStr::from(raw)))
    }
}
453
/// Entity datatype
///
/// An entity has a UID, attributes, tags, and its ancestors in the entity
/// hierarchy, stored as separate sets of direct parents and indirect
/// ancestors.
#[derive(Debug, Clone)]
pub struct Entity {
    /// UID
    uid: EntityUID,

    /// Internal `BTreeMap` of attributes.
    ///
    /// We use a `BTreeMap` so that the keys have a deterministic order.
    attrs: BTreeMap<SmolStr, PartialValue>,

    /// Set of indirect ancestors of this `Entity` as UIDs
    indirect_ancestors: HashSet<EntityUID>,

    /// Set of direct ancestors (i.e., parents) as UIDs
    ///
    /// `indirect_ancestors` and `parents` should be disjoint: a UID that is
    /// a direct parent is stored only here, even if it is also reachable
    /// indirectly through a different parent.
    parents: HashSet<EntityUID>,

    /// Tags on this entity (RFC 82)
    ///
    /// Like for `attrs`, we use a `BTreeMap` so that the tags have a
    /// deterministic order.
    tags: BTreeMap<SmolStr, PartialValue>,
}
481
482impl std::hash::Hash for Entity {
483    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
484        self.uid.hash(state);
485    }
486}
487
488impl Entity {
489    /// Create a new `Entity` with this UID, attributes, ancestors, and tags
490    ///
491    /// # Errors
492    /// - Will error if any of the [`RestrictedExpr]`s in `attrs` or `tags` error when evaluated
493    pub fn new(
494        uid: EntityUID,
495        attrs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
496        indirect_ancestors: HashSet<EntityUID>,
497        parents: HashSet<EntityUID>,
498        tags: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
499        extensions: &Extensions<'_>,
500    ) -> Result<Self, EntityAttrEvaluationError> {
501        let evaluator = RestrictedEvaluator::new(extensions);
502        let evaluate_kvs = |(k, v): (SmolStr, RestrictedExpr), was_attr: bool| {
503            let attr_val = evaluator
504                .partial_interpret(v.as_borrowed())
505                .map_err(|err| EntityAttrEvaluationError {
506                    uid: uid.clone(),
507                    attr_or_tag: k.clone(),
508                    was_attr,
509                    err,
510                })?;
511            Ok((k, attr_val))
512        };
513        let evaluated_attrs = attrs
514            .into_iter()
515            .map(|kv| evaluate_kvs(kv, true))
516            .collect::<Result<_, EntityAttrEvaluationError>>()?;
517        let evaluated_tags = tags
518            .into_iter()
519            .map(|kv| evaluate_kvs(kv, false))
520            .collect::<Result<_, EntityAttrEvaluationError>>()?;
521        Ok(Entity {
522            uid,
523            attrs: evaluated_attrs,
524            indirect_ancestors,
525            parents,
526            tags: evaluated_tags,
527        })
528    }
529
530    /// Create a new `Entity` with this UID, attributes, ancestors, and tags
531    ///
532    /// Unlike in `Entity::new()`, in this constructor, attributes and tags are
533    /// expressed as `PartialValue`.
534    ///
535    /// Callers should consider directly using [`Entity::new_with_attr_partial_value_serialized_as_expr`]
536    /// if they would call this method by first building a map, as it will
537    /// deconstruct and re-build the map perhaps unnecessarily.
538    pub fn new_with_attr_partial_value(
539        uid: EntityUID,
540        attrs: impl IntoIterator<Item = (SmolStr, PartialValue)>,
541        indirect_ancestors: HashSet<EntityUID>,
542        parents: HashSet<EntityUID>,
543        tags: impl IntoIterator<Item = (SmolStr, PartialValue)>,
544    ) -> Self {
545        Self {
546            uid,
547            attrs: attrs.into_iter().collect(),
548            indirect_ancestors,
549            parents,
550            tags: tags.into_iter().collect(),
551        }
552    }
553
554    /// Get the UID of this entity
555    pub fn uid(&self) -> &EntityUID {
556        &self.uid
557    }
558
559    /// Get the value for the given attribute, or `None` if not present
560    pub fn get(&self, attr: &str) -> Option<&PartialValue> {
561        self.attrs.get(attr)
562    }
563
564    /// Get the value for the given tag, or `None` if not present
565    pub fn get_tag(&self, tag: &str) -> Option<&PartialValue> {
566        self.tags.get(tag)
567    }
568
569    /// Is this `Entity` a (direct or indirect) descendant of `e` in the entity hierarchy?
570    pub fn is_descendant_of(&self, e: &EntityUID) -> bool {
571        self.parents.contains(e) || self.indirect_ancestors.contains(e)
572    }
573
574    /// Is this `Entity` a an indirect descendant of `e` in the entity hierarchy?
575    pub fn is_indirect_descendant_of(&self, e: &EntityUID) -> bool {
576        self.indirect_ancestors.contains(e)
577    }
578
579    /// Is this `Entity` a direct decendant (child) of `e` in the entity hierarchy?
580    pub fn is_child_of(&self, e: &EntityUID) -> bool {
581        self.parents.contains(e)
582    }
583
584    /// Iterate over this entity's (direct or indirect) ancestors
585    pub fn ancestors(&self) -> impl Iterator<Item = &EntityUID> {
586        self.parents.iter().chain(self.indirect_ancestors.iter())
587    }
588
589    /// Iterate over this entity's indirect ancestors
590    pub fn indirect_ancestors(&self) -> impl Iterator<Item = &EntityUID> {
591        self.indirect_ancestors.iter()
592    }
593
594    /// Iterate over this entity's direct ancestors (parents)
595    pub fn parents(&self) -> impl Iterator<Item = &EntityUID> {
596        self.parents.iter()
597    }
598
599    /// Get the number of attributes on this entity
600    pub fn attrs_len(&self) -> usize {
601        self.attrs.len()
602    }
603
604    /// Get the number of tags on this entity
605    pub fn tags_len(&self) -> usize {
606        self.tags.len()
607    }
608
609    /// Iterate over this entity's attribute names
610    pub fn keys(&self) -> impl Iterator<Item = &SmolStr> {
611        self.attrs.keys()
612    }
613
614    /// Iterate over this entity's tag names
615    pub fn tag_keys(&self) -> impl Iterator<Item = &SmolStr> {
616        self.tags.keys()
617    }
618
619    /// Iterate over this entity's attributes
620    pub fn attrs(&self) -> impl Iterator<Item = (&SmolStr, &PartialValue)> {
621        self.attrs.iter()
622    }
623
624    /// Iterate over this entity's tags
625    pub fn tags(&self) -> impl Iterator<Item = (&SmolStr, &PartialValue)> {
626        self.tags.iter()
627    }
628
629    /// Create an `Entity` with the given UID, no attributes, no parents, and no tags.
630    pub fn with_uid(uid: EntityUID) -> Self {
631        Self {
632            uid,
633            attrs: BTreeMap::new(),
634            indirect_ancestors: HashSet::new(),
635            parents: HashSet::new(),
636            tags: BTreeMap::new(),
637        }
638    }
639
640    /// Test if two `Entity` objects are deep/structurally equal.
641    /// That is, not only do they have the same UID, but also the same
642    /// attributes, attribute values, and ancestors/parents.
643    ///
644    /// Does not test that they have the same _direct_ parents, only that they have the same overall ancestor set.
645    pub(crate) fn deep_eq(&self, other: &Self) -> bool {
646        self.uid == other.uid
647            && self.attrs == other.attrs
648            && (self.ancestors().collect::<HashSet<_>>())
649                == (other.ancestors().collect::<HashSet<_>>())
650    }
651
652    /// Mark the given `UID` as an indirect ancestor of this `Entity`
653    ///
654    /// The given `UID` will not be added as an indirecty ancestor if
655    /// it is already a direct ancestor (parent) of this `Entity`
656    /// The caller of this code is responsible for maintaining
657    /// transitive closure of hierarchy.
658    pub fn add_indirect_ancestor(&mut self, uid: EntityUID) {
659        if !self.parents.contains(&uid) {
660            self.indirect_ancestors.insert(uid);
661        }
662    }
663
664    /// Mark the given `UID` as a (direct) parent of this `Entity`, and
665    /// remove the UID from indirect ancestors
666    /// if it was previously added as an indirect ancestor
667    /// The caller of this code is responsible for maintaining
668    /// transitive closure of hierarchy.
669    pub fn add_parent(&mut self, uid: EntityUID) {
670        self.indirect_ancestors.remove(&uid);
671        self.parents.insert(uid);
672    }
673
674    /// Remove the given `UID` as an indirect ancestor of this `Entity`.
675    ///
676    /// No effect if the `UID` is a direct parent.
677    /// The caller of this code is responsible for maintaining
678    /// transitive closure of hierarchy.
679    pub fn remove_indirect_ancestor(&mut self, uid: &EntityUID) {
680        self.indirect_ancestors.remove(uid);
681    }
682
683    /// Remove the given `UID` as a (direct) parent of this `Entity`.
684    ///
685    /// No effect on the `Entity`'s indirect ancestors.
686    /// The caller of this code is responsible for maintaining
687    /// transitive closure of hierarchy.
688    pub fn remove_parent(&mut self, uid: &EntityUID) {
689        self.parents.remove(uid);
690    }
691
692    /// Consume the entity and return the entity's owned Uid, attributes, ancestors, parents, and tags.
693    #[allow(clippy::type_complexity)]
694    pub fn into_inner(
695        self,
696    ) -> (
697        EntityUID,
698        HashMap<SmolStr, PartialValue>,
699        HashSet<EntityUID>,
700        HashSet<EntityUID>,
701        HashMap<SmolStr, PartialValue>,
702    ) {
703        (
704            self.uid,
705            self.attrs.into_iter().collect(),
706            self.indirect_ancestors,
707            self.parents,
708            self.tags.into_iter().collect(),
709        )
710    }
711
712    /// Write the entity to a json document
713    pub fn write_to_json(&self, f: impl std::io::Write) -> Result<(), EntitiesError> {
714        let ejson = EntityJson::from_entity(self)?;
715        serde_json::to_writer_pretty(f, &ejson).map_err(JsonSerializationError::from)?;
716        Ok(())
717    }
718
719    /// write the entity to a json value
720    pub fn to_json_value(&self) -> Result<serde_json::Value, EntitiesError> {
721        let ejson = EntityJson::from_entity(self)?;
722        let v = serde_json::to_value(ejson).map_err(JsonSerializationError::from)?;
723        Ok(v)
724    }
725
726    /// write the entity to a json string
727    pub fn to_json_string(&self) -> Result<String, EntitiesError> {
728        let ejson = EntityJson::from_entity(self)?;
729        let string = serde_json::to_string(&ejson).map_err(JsonSerializationError::from)?;
730        Ok(string)
731    }
732}
733
734/// `Entity`s are equal if their UIDs are equal
735impl PartialEq for Entity {
736    fn eq(&self, other: &Self) -> bool {
737        self.uid() == other.uid()
738    }
739}
740
// Marker impl: `Entity`'s `PartialEq` compares UIDs only.
impl Eq for Entity {}
742
743impl StaticallyTyped for Entity {
744    fn type_of(&self) -> Type {
745        self.uid.type_of()
746    }
747}
748
749impl TCNode<EntityUID> for Entity {
750    fn get_key(&self) -> EntityUID {
751        self.uid().clone()
752    }
753
754    fn add_edge_to(&mut self, k: EntityUID) {
755        self.add_indirect_ancestor(k);
756    }
757
758    fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
759        Box::new(self.ancestors())
760    }
761
762    fn has_edge_to(&self, e: &EntityUID) -> bool {
763        self.is_descendant_of(e)
764    }
765}
766
767impl TCNode<EntityUID> for Arc<Entity> {
768    fn get_key(&self) -> EntityUID {
769        self.uid().clone()
770    }
771
772    fn add_edge_to(&mut self, k: EntityUID) {
773        // Use Arc::make_mut to get a mutable reference to the inner value
774        Arc::make_mut(self).add_indirect_ancestor(k)
775    }
776
777    fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
778        Box::new(self.ancestors())
779    }
780
781    fn has_edge_to(&self, e: &EntityUID) -> bool {
782        self.is_descendant_of(e)
783    }
784}
785
786impl std::fmt::Display for Entity {
787    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
788        write!(
789            f,
790            "{}:\n  attrs:{}\n  ancestors:{}",
791            self.uid,
792            self.attrs
793                .iter()
794                .map(|(k, v)| format!("{}: {}", k, v))
795                .join("; "),
796            self.ancestors().join(", ")
797        )
798    }
799}
800
/// Error type for evaluation errors when evaluating an entity attribute or tag.
/// Contains some extra contextual information and the underlying
/// `EvaluationError`.
///
/// The error message names the entity, the attribute or tag, and whether it
/// was an attribute or a tag, followed by the underlying error.
//
// This is NOT a publicly exported error type.
#[derive(Debug, Diagnostic, Error)]
#[error("failed to evaluate {} `{attr_or_tag}` of `{uid}`: {err}", if *.was_attr { "attribute" } else { "tag" })]
pub struct EntityAttrEvaluationError {
    /// UID of the entity where the error was encountered
    pub uid: EntityUID,
    /// Attribute or tag of the entity where the error was encountered
    pub attr_or_tag: SmolStr,
    /// If `attr_or_tag` was an attribute (`true`) or tag (`false`)
    pub was_attr: bool,
    /// Underlying evaluation error
    #[diagnostic(transparent)]
    pub err: EvaluationError,
}
819
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::*;

    /// An `EntityUID` displays as `Type::"eid"`.
    #[test]
    fn display() {
        let uid = EntityUID::with_eid("eid");
        assert_eq!(uid.to_string(), "test_entity_type::\"eid\"");
    }

    /// Equality of `EntityUID`s depends on both the type and the EID.
    #[test]
    fn test_euid_equality() {
        let same_type: EntityType = Name::parse_unqualified_name("test_entity_type")
            .expect("should be a valid identifier")
            .into();
        let other_type: EntityType = Name::parse_unqualified_name("Unspecified")
            .expect("should be a valid identifier")
            .into();

        let e1 = EntityUID::with_eid("foo");
        let e2 = EntityUID::from_components(same_type, Eid::Eid("foo".into()), None);
        let e3 = EntityUID::from_components(other_type, Eid::Eid("foo".into()), None);

        // an EUID is equal to itself
        assert_eq!(e1, e1);
        assert_eq!(e2, e2);

        // constructing with `with_eid` or `from_components` is the same
        assert_eq!(e1, e2);

        // other pairs are not equal
        assert_ne!(e1, e3);
    }

    /// Only the basename of the entity type decides whether a UID is an action.
    #[test]
    fn action_checker() {
        let cases = [
            ("Action::\"view\"", true),
            ("Foo::Action::\"view\"", true),
            ("Foo::\"view\"", false),
            ("Action::Foo::\"view\"", false),
        ];
        for (src, expected) in cases {
            let euid = EntityUID::from_str(src).unwrap();
            assert_eq!(euid.is_action(), expected);
        }
    }

    /// `ACTION_ENTITY_TYPE` must be usable with `Id::new_unchecked`.
    #[test]
    fn action_type_is_valid_id() {
        let parsed = Id::from_normalized_str(ACTION_ENTITY_TYPE);
        assert!(parsed.is_ok());
    }

    /// The error placeholders behave consistently across all accessors.
    #[cfg(feature = "tolerant-ast")]
    #[test]
    fn error_entity() {
        use cool_asserts::assert_matches;

        let placeholder_name = Name(InternalName::from(Id::new_unchecked("EntityTypeError")));

        let e = EntityUID::Error;
        assert_matches!(e.eid(), Eid::ErrorEid);
        assert_matches!(e.entity_type(), EntityType::ErrorEntityType);
        assert!(!e.is_action());
        assert_matches!(e.loc(), None);

        let error_eid = Eid::ErrorEid;
        assert_eq!(error_eid.escaped(), "Eid::Error");

        let error_type = EntityType::ErrorEntityType;
        assert!(!error_type.is_action());
        assert_eq!(error_type.qualify_with(None), EntityType::ErrorEntityType);
        assert_eq!(
            error_type.qualify_with(Some(&placeholder_name)),
            EntityType::ErrorEntityType
        );

        assert_eq!(error_type.name(), &placeholder_name);
        assert_eq!(error_type.loc(), None)
    }

    /// An entity type deserializes from a plain JSON string.
    #[test]
    fn entity_type_deserialization() {
        let json = r#""some_entity_type""#;
        let entity_type: EntityType = serde_json::from_str(json).unwrap();
        let rendered = entity_type.name().0.to_string();
        assert_eq!(rendered, "some_entity_type".to_string())
    }

    /// An entity type serializes to a plain JSON string.
    #[test]
    fn entity_type_serialization() {
        let ty_name = Name(InternalName::from(Id::new_unchecked("some_entity_type")));
        let entity_type = EntityType::EntityType(ty_name);
        let serialized = serde_json::to_string(&entity_type).unwrap();

        assert_eq!(serialized, r#""some_entity_type""#);
    }
}
928}