cedar_policy_core/ast/
entity.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::ast::*;
18use crate::entities::{err::EntitiesError, json::err::JsonSerializationError, EntityJson};
19use crate::evaluator::{EvaluationError, RestrictedEvaluator};
20use crate::extensions::Extensions;
21use crate::parser::err::ParseErrors;
22use crate::parser::Loc;
23use crate::transitive_closure::TCNode;
24use crate::FromNormalizedStr;
25use educe::Educe;
26use itertools::Itertools;
27use miette::Diagnostic;
28use serde::{de::Deserializer, ser::Serializer, Deserialize, Serialize};
29use smol_str::SmolStr;
30use std::collections::{BTreeMap, HashMap, HashSet};
31use std::str::FromStr;
32use std::sync::Arc;
33use thiserror::Error;
34
/// Placeholder [`Name`] returned for [`EntityType::ErrorEntityType`].
///
/// Built lazily on first access.
#[cfg(feature = "tolerant-ast")]
static ERROR_NAME: std::sync::LazyLock<Name> = std::sync::LazyLock::new(|| {
    let placeholder_id = Id::new_unchecked("EntityTypeError");
    Name(InternalName::from(placeholder_id))
});

/// Textual stand-in used when an [`Eid::ErrorEid`] must be rendered as a string.
#[cfg(feature = "tolerant-ast")]
static EID_ERROR_STR: &str = "Eid::Error";

/// Textual stand-in used when an [`EntityType::ErrorEntityType`] is displayed or serialized.
#[cfg(feature = "tolerant-ast")]
static ENTITY_TYPE_ERROR_STR: &str = "EntityType::Error";

/// Textual stand-in used when an [`EntityUID::Error`] is serialized.
#[cfg(feature = "tolerant-ast")]
static ENTITY_UID_ERROR_STR: &str = "EntityUID::Error";

/// The entity type that Actions must have
pub static ACTION_ENTITY_TYPE: &str = "Action";
50
#[derive(PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
/// Entity type - can be an error type when 'tolerant-ast' feature is enabled
///
/// `Serialize` and `Deserialize` are implemented by hand for this type
/// rather than derived.
pub enum EntityType {
    /// Entity type names are just [`Name`]s, but we have some operations on them specific to entity types.
    EntityType(Name),
    #[cfg(feature = "tolerant-ast")]
    /// Represents an error node of an entity that failed to parse
    ErrorEntityType,
}
61
62impl<'de> Deserialize<'de> for EntityType {
63    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
64    where
65        D: Deserializer<'de>,
66    {
67        let name = Name::deserialize(deserializer)?;
68        Ok(EntityType::EntityType(name))
69    }
70}
71
72impl Serialize for EntityType {
73    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
74    where
75        S: Serializer,
76    {
77        match self {
78            EntityType::EntityType(name) => name.serialize(serializer),
79            #[cfg(feature = "tolerant-ast")]
80            EntityType::ErrorEntityType => serializer.serialize_str(ENTITY_TYPE_ERROR_STR),
81        }
82    }
83}
84
85impl EntityType {
86    /// Is this an Action entity type?
87    /// Returns true when an entity type is an action entity type. This compares the
88    /// base name for the type, so this will return true for any entity type named
89    /// `Action` regardless of namespaces.
90    pub fn is_action(&self) -> bool {
91        match self {
92            EntityType::EntityType(name) => {
93                name.as_ref().basename() == &Id::new_unchecked(ACTION_ENTITY_TYPE)
94            }
95            #[cfg(feature = "tolerant-ast")]
96            EntityType::ErrorEntityType => false,
97        }
98    }
99
100    /// The name of this entity type
101    pub fn name(&self) -> &Name {
102        match self {
103            EntityType::EntityType(name) => name,
104            #[cfg(feature = "tolerant-ast")]
105            EntityType::ErrorEntityType => &ERROR_NAME,
106        }
107    }
108
109    /// The source location of this entity type
110    pub fn loc(&self) -> Option<&Loc> {
111        match self {
112            EntityType::EntityType(name) => name.as_ref().loc(),
113            #[cfg(feature = "tolerant-ast")]
114            EntityType::ErrorEntityType => None,
115        }
116    }
117
118    /// Create a clone of this EntityType with given loc
119    pub fn with_loc(&self, loc: Option<&Loc>) -> Self {
120        match self {
121            EntityType::EntityType(name) => EntityType::EntityType(Name(InternalName {
122                id: name.0.id.clone(),
123                path: name.0.path.clone(),
124                loc: loc.cloned(),
125            })),
126            #[cfg(feature = "tolerant-ast")]
127            EntityType::ErrorEntityType => self.clone(),
128        }
129    }
130
131    /// Calls [`Name::qualify_with_name`] on the underlying [`Name`]
132    pub fn qualify_with(&self, namespace: Option<&Name>) -> Self {
133        match self {
134            EntityType::EntityType(name) => Self::EntityType(name.qualify_with_name(namespace)),
135            #[cfg(feature = "tolerant-ast")]
136            EntityType::ErrorEntityType => Self::ErrorEntityType,
137        }
138    }
139
140    /// Wraps [`Name::from_normalized_str`]
141    pub fn from_normalized_str(src: &str) -> Result<Self, ParseErrors> {
142        Name::from_normalized_str(src).map(Into::into)
143    }
144}
145
146impl From<Name> for EntityType {
147    fn from(n: Name) -> Self {
148        Self::EntityType(n)
149    }
150}
151
152impl From<EntityType> for Name {
153    fn from(ty: EntityType) -> Name {
154        match ty {
155            EntityType::EntityType(name) => name,
156            #[cfg(feature = "tolerant-ast")]
157            EntityType::ErrorEntityType => ERROR_NAME.clone(),
158        }
159    }
160}
161
162impl AsRef<Name> for EntityType {
163    fn as_ref(&self) -> &Name {
164        match self {
165            EntityType::EntityType(name) => name,
166            #[cfg(feature = "tolerant-ast")]
167            EntityType::ErrorEntityType => &ERROR_NAME,
168        }
169    }
170}
171
172impl FromStr for EntityType {
173    type Err = ParseErrors;
174
175    fn from_str(s: &str) -> Result<Self, Self::Err> {
176        s.parse().map(Self::EntityType)
177    }
178}
179
180impl std::fmt::Display for EntityType {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        match self {
183            EntityType::EntityType(name) => write!(f, "{name}"),
184            #[cfg(feature = "tolerant-ast")]
185            EntityType::ErrorEntityType => write!(f, "{ENTITY_TYPE_ERROR_STR}"),
186        }
187    }
188}
189
/// Unique ID for an entity. These represent entities in the AST.
///
/// This is the payload carried by the non-error variant of [`EntityUID`].
/// It is (de)serialized under the name `EntityUID`.
#[derive(Educe, Serialize, Deserialize, Debug, Clone)]
#[serde(rename = "EntityUID")]
#[educe(PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct EntityUIDImpl {
    /// Typename of the entity
    ty: EntityType,
    /// EID of the entity
    eid: Eid,
    /// Location of the entity in policy source
    ///
    /// Skipped by serde, and marked `ignore` for the `PartialEq`, `Hash`
    /// and `PartialOrd` implementations generated by `educe`.
    #[serde(skip)]
    #[educe(PartialEq(ignore))]
    #[educe(Hash(ignore))]
    #[educe(PartialOrd(ignore))]
    loc: Option<Loc>,
}
206
207impl EntityUIDImpl {
208    /// The source location of this entity
209    pub fn loc(&self) -> Option<Loc> {
210        self.loc.clone()
211    }
212}
213
/// Unique ID for an entity. These represent entities in the AST.
///
/// `Serialize` and `Deserialize` are implemented by hand for this type
/// rather than derived.
#[derive(Educe, Debug, Clone)]
#[educe(PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum EntityUID {
    /// Unique ID for an entity. These represent entities in the AST
    EntityUID(EntityUIDImpl),
    #[cfg(feature = "tolerant-ast")]
    /// Represents the ID of an error that failed to parse
    Error,
}
224
225impl<'de> Deserialize<'de> for EntityUID {
226    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
227    where
228        D: Deserializer<'de>,
229    {
230        let uid_impl = EntityUIDImpl::deserialize(deserializer)?;
231        Ok(EntityUID::EntityUID(uid_impl))
232    }
233}
234
235impl Serialize for EntityUID {
236    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
237    where
238        S: Serializer,
239    {
240        match self {
241            EntityUID::EntityUID(uid_impl) => uid_impl.serialize(serializer),
242            #[cfg(feature = "tolerant-ast")]
243            EntityUID::Error => serializer.serialize_str(ENTITY_UID_ERROR_STR),
244        }
245    }
246}
247
248impl StaticallyTyped for EntityUID {
249    fn type_of(&self) -> Type {
250        match self {
251            EntityUID::EntityUID(entity_uid) => Type::Entity {
252                ty: entity_uid.ty.clone(),
253            },
254            #[cfg(feature = "tolerant-ast")]
255            EntityUID::Error => Type::Entity {
256                ty: EntityType::ErrorEntityType,
257            },
258        }
259    }
260}
261
#[cfg(test)]
impl EntityUID {
    /// Create an `EntityUID` with the given string as its EID.
    /// Useful for testing.
    ///
    /// The entity type is the one returned by `test_entity_type()`, and no
    /// source location is attached.
    pub(crate) fn with_eid(eid: &str) -> Self {
        let ty = Self::test_entity_type();
        Self::from_components(ty, Eid::new(eid), None)
    }

    /// The type of entities created with the above `with_eid()`.
    pub(crate) fn test_entity_type() -> EntityType {
        Name::parse_unqualified_name("test_entity_type")
            .map(EntityType::EntityType)
            .expect("test_entity_type should be a valid identifier")
    }
}
281
282impl EntityUID {
283    /// Create an `EntityUID` with the given (unqualified) typename, and the given string as its EID.
284    pub fn with_eid_and_type(typename: &str, eid: &str) -> Result<Self, ParseErrors> {
285        Ok(Self::EntityUID(EntityUIDImpl {
286            ty: EntityType::EntityType(Name::parse_unqualified_name(typename)?),
287            eid: Eid::Eid(eid.into()),
288            loc: None,
289        }))
290    }
291
292    /// Split into the `EntityType` representing the entity type, and the `Eid`
293    /// representing its name
294    pub fn components(self) -> (EntityType, Eid) {
295        match self {
296            EntityUID::EntityUID(entity_uid) => (entity_uid.ty, entity_uid.eid),
297            #[cfg(feature = "tolerant-ast")]
298            EntityUID::Error => (EntityType::ErrorEntityType, Eid::ErrorEid),
299        }
300    }
301
302    /// Get the source location for this `EntityUID`.
303    pub fn loc(&self) -> Option<&Loc> {
304        match self {
305            EntityUID::EntityUID(entity_uid) => entity_uid.loc.as_ref(),
306            #[cfg(feature = "tolerant-ast")]
307            EntityUID::Error => None,
308        }
309    }
310
311    /// Create an [`EntityUID`] with the given typename and [`Eid`]
312    pub fn from_components(ty: EntityType, eid: Eid, loc: Option<Loc>) -> Self {
313        Self::EntityUID(EntityUIDImpl { ty, eid, loc })
314    }
315
316    /// Get the type component.
317    pub fn entity_type(&self) -> &EntityType {
318        match self {
319            EntityUID::EntityUID(entity_uid) => &entity_uid.ty,
320            #[cfg(feature = "tolerant-ast")]
321            EntityUID::Error => &EntityType::ErrorEntityType,
322        }
323    }
324
325    /// Get the Eid component.
326    pub fn eid(&self) -> &Eid {
327        match self {
328            EntityUID::EntityUID(entity_uid) => &entity_uid.eid,
329            #[cfg(feature = "tolerant-ast")]
330            EntityUID::Error => &Eid::ErrorEid,
331        }
332    }
333
334    /// Does this EntityUID refer to an action entity?
335    pub fn is_action(&self) -> bool {
336        self.entity_type().is_action()
337    }
338}
339
340impl std::fmt::Display for EntityUID {
341    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
342        write!(f, "{}::\"{}\"", self.entity_type(), self.eid().escaped())
343    }
344}
345
346// allow `.parse()` on a string to make an `EntityUID`
347impl std::str::FromStr for EntityUID {
348    type Err = ParseErrors;
349
350    fn from_str(s: &str) -> Result<Self, Self::Err> {
351        crate::parser::parse_euid(s)
352    }
353}
354
impl FromNormalizedStr for EntityUID {
    /// Short human-readable description of what this type is.
    fn describe_self() -> &'static str {
        "Entity UID"
    }
}
360
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for EntityUID {
    /// Generate an arbitrary type and EID (in that order); the result never
    /// carries a source location.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let ty: EntityType = u.arbitrary()?;
        let eid: Eid = u.arbitrary()?;
        Ok(Self::from_components(ty, eid, None))
    }
}
371
/// The `Eid` type represents the id of an `Entity`, without the typename.
/// Together with the typename it comprises an `EntityUID`.
/// For example, in `User::"alice"`, the `Eid` is `alice`.
///
/// `Eid` does not implement `Display`, partly because it is unclear whether
/// `Display` should produce an escaped representation or an unescaped representation
/// (see [#884](https://github.com/cedar-policy/cedar/issues/884)).
/// To get an escaped representation, use `.escaped()`.
/// To get an unescaped representation, use `.as_ref()`.
///
/// `Serialize` and `Deserialize` are implemented by hand for this type
/// rather than derived.
#[derive(PartialEq, Eq, Debug, Clone, Hash, PartialOrd, Ord)]
pub enum Eid {
    /// Actual Eid
    Eid(SmolStr),
    #[cfg(feature = "tolerant-ast")]
    /// Represents an Eid of an entity that failed to parse
    ErrorEid,
}
389
390impl<'de> Deserialize<'de> for Eid {
391    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
392    where
393        D: Deserializer<'de>,
394    {
395        let value = String::deserialize(deserializer)?;
396        Ok(Eid::Eid(SmolStr::from(value)))
397    }
398}
399
400impl Serialize for Eid {
401    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
402    where
403        S: Serializer,
404    {
405        match self {
406            Eid::Eid(s) => s.serialize(serializer),
407            #[cfg(feature = "tolerant-ast")]
408            Eid::ErrorEid => serializer.serialize_str(EID_ERROR_STR),
409        }
410    }
411}
412
413impl Eid {
414    /// Construct an Eid
415    pub fn new(eid: impl Into<SmolStr>) -> Self {
416        Eid::Eid(eid.into())
417    }
418
419    /// Get the contents of the `Eid` as an escaped string
420    pub fn escaped(&self) -> SmolStr {
421        match self {
422            Eid::Eid(smol_str) => smol_str.escape_debug().collect(),
423            #[cfg(feature = "tolerant-ast")]
424            Eid::ErrorEid => SmolStr::new_static(EID_ERROR_STR),
425        }
426    }
427
428    /// Get the underlying smolstr for this `Eid`
429    pub fn into_smolstr(self) -> SmolStr {
430        match self {
431            Eid::Eid(smol_str) => smol_str,
432            #[cfg(feature = "tolerant-ast")]
433            Eid::ErrorEid => SmolStr::new_static(EID_ERROR_STR),
434        }
435    }
436}
437
438impl AsRef<str> for Eid {
439    fn as_ref(&self) -> &str {
440        match self {
441            Eid::Eid(smol_str) => smol_str,
442            #[cfg(feature = "tolerant-ast")]
443            Eid::ErrorEid => EID_ERROR_STR,
444        }
445    }
446}
447
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Eid {
    /// Generate an arbitrary `String` and use it as the EID.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let raw = u.arbitrary::<String>()?;
        Ok(Self::Eid(SmolStr::from(raw)))
    }
}
455
/// Entity datatype
///
/// An entity consists of a UID, attributes, tags, and its ancestors in the
/// entity hierarchy (split into direct parents and indirect ancestors).
///
/// Note that the `PartialEq` and `Hash` implementations for this type look
/// at the UID only; use `deep_eq()` for a structural comparison.
#[derive(Debug, Clone)]
pub struct Entity {
    /// UID
    uid: EntityUID,

    /// Internal `BTreeMap` of attributes.
    ///
    /// We use a `BTreeMap` so that the keys have a deterministic order.
    attrs: BTreeMap<SmolStr, PartialValue>,

    /// Set of indirect ancestors of this `Entity` as UIDs
    indirect_ancestors: HashSet<EntityUID>,

    /// Set of direct ancestors (i.e., parents) as UIDs
    ///
    /// indirect_ancestors and parents should be disjoint
    /// even if a parent is also an indirect parent through
    /// a different parent
    parents: HashSet<EntityUID>,

    /// Tags on this entity (RFC 82)
    ///
    /// Like for `attrs`, we use a `BTreeMap` so that the tags have a
    /// deterministic order.
    tags: BTreeMap<SmolStr, PartialValue>,
}
483
484impl std::hash::Hash for Entity {
485    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
486        self.uid.hash(state);
487    }
488}
489
490impl Entity {
491    /// Create a new `Entity` with this UID, attributes, ancestors, and tags
492    ///
493    /// # Errors
494    /// - Will error if any of the [`RestrictedExpr]`s in `attrs` or `tags` error when evaluated
495    pub fn new(
496        uid: EntityUID,
497        attrs: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
498        indirect_ancestors: HashSet<EntityUID>,
499        parents: HashSet<EntityUID>,
500        tags: impl IntoIterator<Item = (SmolStr, RestrictedExpr)>,
501        extensions: &Extensions<'_>,
502    ) -> Result<Self, EntityAttrEvaluationError> {
503        let evaluator = RestrictedEvaluator::new(extensions);
504        let evaluate_kvs = |(k, v): (SmolStr, RestrictedExpr), was_attr: bool| {
505            let attr_val = evaluator
506                .partial_interpret(v.as_borrowed())
507                .map_err(|err| EntityAttrEvaluationError {
508                    uid: uid.clone(),
509                    attr_or_tag: k.clone(),
510                    was_attr,
511                    err,
512                })?;
513            Ok((k, attr_val))
514        };
515        let evaluated_attrs = attrs
516            .into_iter()
517            .map(|kv| evaluate_kvs(kv, true))
518            .collect::<Result<_, EntityAttrEvaluationError>>()?;
519        let evaluated_tags = tags
520            .into_iter()
521            .map(|kv| evaluate_kvs(kv, false))
522            .collect::<Result<_, EntityAttrEvaluationError>>()?;
523        Ok(Entity {
524            uid,
525            attrs: evaluated_attrs,
526            indirect_ancestors,
527            parents,
528            tags: evaluated_tags,
529        })
530    }
531
532    /// Create a new [`Entity`] with this UID, attributes, ancestors, and tags
533    ///
534    /// Unlike in `Entity::new()`, in this constructor, attributes and tags are
535    /// expressed as `PartialValue`.
536    pub fn new_with_attr_partial_value(
537        uid: EntityUID,
538        attrs: impl IntoIterator<Item = (SmolStr, PartialValue)>,
539        indirect_ancestors: HashSet<EntityUID>,
540        parents: HashSet<EntityUID>,
541        tags: impl IntoIterator<Item = (SmolStr, PartialValue)>,
542    ) -> Self {
543        Self {
544            uid,
545            attrs: attrs.into_iter().collect(),
546            indirect_ancestors,
547            parents,
548            tags: tags.into_iter().collect(),
549        }
550    }
551
552    /// Get the UID of this entity
553    pub fn uid(&self) -> &EntityUID {
554        &self.uid
555    }
556
557    /// Get the value for the given attribute, or `None` if not present
558    pub fn get(&self, attr: &str) -> Option<&PartialValue> {
559        self.attrs.get(attr)
560    }
561
562    /// Get the value for the given tag, or `None` if not present
563    pub fn get_tag(&self, tag: &str) -> Option<&PartialValue> {
564        self.tags.get(tag)
565    }
566
567    /// Is this `Entity` a (direct or indirect) descendant of `e` in the entity hierarchy?
568    pub fn is_descendant_of(&self, e: &EntityUID) -> bool {
569        self.parents.contains(e) || self.indirect_ancestors.contains(e)
570    }
571
572    /// Is this `Entity` a an indirect descendant of `e` in the entity hierarchy?
573    pub fn is_indirect_descendant_of(&self, e: &EntityUID) -> bool {
574        self.indirect_ancestors.contains(e)
575    }
576
577    /// Is this `Entity` a direct decendant (child) of `e` in the entity hierarchy?
578    pub fn is_child_of(&self, e: &EntityUID) -> bool {
579        self.parents.contains(e)
580    }
581
582    /// Iterate over this entity's (direct or indirect) ancestors
583    pub fn ancestors(&self) -> impl Iterator<Item = &EntityUID> {
584        self.parents.iter().chain(self.indirect_ancestors.iter())
585    }
586
587    /// Iterate over this entity's indirect ancestors
588    pub fn indirect_ancestors(&self) -> impl Iterator<Item = &EntityUID> {
589        self.indirect_ancestors.iter()
590    }
591
592    /// Iterate over this entity's direct ancestors (parents)
593    pub fn parents(&self) -> impl Iterator<Item = &EntityUID> {
594        self.parents.iter()
595    }
596
597    /// Get the number of attributes on this entity
598    pub fn attrs_len(&self) -> usize {
599        self.attrs.len()
600    }
601
602    /// Get the number of tags on this entity
603    pub fn tags_len(&self) -> usize {
604        self.tags.len()
605    }
606
607    /// Iterate over this entity's attribute names
608    pub fn keys(&self) -> impl Iterator<Item = &SmolStr> {
609        self.attrs.keys()
610    }
611
612    /// Iterate over this entity's tag names
613    pub fn tag_keys(&self) -> impl Iterator<Item = &SmolStr> {
614        self.tags.keys()
615    }
616
617    /// Iterate over this entity's attributes
618    pub fn attrs(&self) -> impl Iterator<Item = (&SmolStr, &PartialValue)> {
619        self.attrs.iter()
620    }
621
622    /// Iterate over this entity's tags
623    pub fn tags(&self) -> impl Iterator<Item = (&SmolStr, &PartialValue)> {
624        self.tags.iter()
625    }
626
627    /// Create an `Entity` with the given UID, no attributes, no parents, and no tags.
628    pub fn with_uid(uid: EntityUID) -> Self {
629        Self {
630            uid,
631            attrs: BTreeMap::new(),
632            indirect_ancestors: HashSet::new(),
633            parents: HashSet::new(),
634            tags: BTreeMap::new(),
635        }
636    }
637
638    /// Test if two `Entity` objects are deep/structurally equal.
639    /// That is, not only do they have the same UID, but also the same
640    /// attributes, attribute values, and ancestors/parents.
641    ///
642    /// Does not test that they have the same _direct_ parents, only that they have the same overall ancestor set.
643    pub fn deep_eq(&self, other: &Self) -> bool {
644        self.uid == other.uid
645            && self.attrs == other.attrs
646            && self.tags == other.tags
647            && (self.ancestors().collect::<HashSet<_>>())
648                == (other.ancestors().collect::<HashSet<_>>())
649    }
650
651    /// Mark the given `UID` as an indirect ancestor of this `Entity`
652    ///
653    /// The given `UID` will not be added as an indirecty ancestor if
654    /// it is already a direct ancestor (parent) of this `Entity`
655    /// The caller of this code is responsible for maintaining
656    /// transitive closure of hierarchy.
657    pub fn add_indirect_ancestor(&mut self, uid: EntityUID) {
658        if !self.parents.contains(&uid) {
659            self.indirect_ancestors.insert(uid);
660        }
661    }
662
663    /// Mark the given `UID` as a (direct) parent of this `Entity`, and
664    /// remove the UID from indirect ancestors
665    /// if it was previously added as an indirect ancestor
666    /// The caller of this code is responsible for maintaining
667    /// transitive closure of hierarchy.
668    pub fn add_parent(&mut self, uid: EntityUID) {
669        self.indirect_ancestors.remove(&uid);
670        self.parents.insert(uid);
671    }
672
673    /// Remove the given `UID` as an indirect ancestor of this `Entity`.
674    ///
675    /// No effect if the `UID` is a direct parent.
676    /// The caller of this code is responsible for maintaining
677    /// transitive closure of hierarchy.
678    pub fn remove_indirect_ancestor(&mut self, uid: &EntityUID) {
679        self.indirect_ancestors.remove(uid);
680    }
681
682    /// Remove the given `UID` as a (direct) parent of this `Entity`.
683    ///
684    /// No effect on the `Entity`'s indirect ancestors.
685    /// The caller of this code is responsible for maintaining
686    /// transitive closure of hierarchy.
687    pub fn remove_parent(&mut self, uid: &EntityUID) {
688        self.parents.remove(uid);
689    }
690
691    /// Remove all indirect ancestors of this `Entity`.
692    ///
693    /// The caller of this code is responsible for maintaining
694    /// transitive closure of hierarchy.
695    pub fn remove_all_indirect_ancestors(&mut self) {
696        self.indirect_ancestors.clear();
697    }
698
699    /// Consume the entity and return the entity's owned Uid, attributes, ancestors, parents, and tags.
700    #[allow(clippy::type_complexity)]
701    pub fn into_inner(
702        self,
703    ) -> (
704        EntityUID,
705        HashMap<SmolStr, PartialValue>,
706        HashSet<EntityUID>,
707        HashSet<EntityUID>,
708        HashMap<SmolStr, PartialValue>,
709    ) {
710        (
711            self.uid,
712            self.attrs.into_iter().collect(),
713            self.indirect_ancestors,
714            self.parents,
715            self.tags.into_iter().collect(),
716        )
717    }
718
719    /// Write the entity to a json document
720    pub fn write_to_json(&self, f: impl std::io::Write) -> Result<(), EntitiesError> {
721        let ejson = EntityJson::from_entity(self)?;
722        serde_json::to_writer_pretty(f, &ejson).map_err(JsonSerializationError::from)?;
723        Ok(())
724    }
725
726    /// write the entity to a json value
727    pub fn to_json_value(&self) -> Result<serde_json::Value, EntitiesError> {
728        let ejson = EntityJson::from_entity(self)?;
729        let v = serde_json::to_value(ejson).map_err(JsonSerializationError::from)?;
730        Ok(v)
731    }
732
733    /// write the entity to a json string
734    pub fn to_json_string(&self) -> Result<String, EntitiesError> {
735        let ejson = EntityJson::from_entity(self)?;
736        let string = serde_json::to_string(&ejson).map_err(JsonSerializationError::from)?;
737        Ok(string)
738    }
739}
740
741/// `Entity`s are equal if their UIDs are equal
742impl PartialEq for Entity {
743    fn eq(&self, other: &Self) -> bool {
744        self.uid() == other.uid()
745    }
746}
747
748impl Eq for Entity {}
749
750impl StaticallyTyped for Entity {
751    fn type_of(&self) -> Type {
752        self.uid.type_of()
753    }
754}
755
756impl TCNode<EntityUID> for Entity {
757    fn get_key(&self) -> EntityUID {
758        self.uid().clone()
759    }
760
761    fn add_edge_to(&mut self, k: EntityUID) {
762        self.add_indirect_ancestor(k);
763    }
764
765    fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
766        Box::new(self.ancestors())
767    }
768
769    fn has_edge_to(&self, e: &EntityUID) -> bool {
770        self.is_descendant_of(e)
771    }
772
773    fn reset_edges(&mut self) {
774        self.remove_all_indirect_ancestors()
775    }
776
777    fn direct_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
778        Box::new(self.parents())
779    }
780}
781
782impl TCNode<EntityUID> for Arc<Entity> {
783    fn get_key(&self) -> EntityUID {
784        self.uid().clone()
785    }
786
787    fn add_edge_to(&mut self, k: EntityUID) {
788        // Use Arc::make_mut to get a mutable reference to the inner value
789        Arc::make_mut(self).add_indirect_ancestor(k)
790    }
791
792    fn out_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
793        Box::new(self.ancestors())
794    }
795
796    fn has_edge_to(&self, e: &EntityUID) -> bool {
797        self.is_descendant_of(e)
798    }
799
800    fn reset_edges(&mut self) {
801        // Use Arc::make_mut to get a mutable reference to the inner value
802        Arc::make_mut(self).remove_all_indirect_ancestors()
803    }
804
805    fn direct_edges(&self) -> Box<dyn Iterator<Item = &EntityUID> + '_> {
806        Box::new(self.parents())
807    }
808}
809
810impl std::fmt::Display for Entity {
811    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
812        write!(
813            f,
814            "{}:\n  attrs:{}\n  ancestors:{}",
815            self.uid,
816            self.attrs
817                .iter()
818                .map(|(k, v)| format!("{k}: {v}"))
819                .join("; "),
820            self.ancestors().join(", ")
821        )
822    }
823}
824
/// Error type for evaluation errors when evaluating an entity attribute or tag.
/// Contains some extra contextual information and the underlying
/// `EvaluationError`.
///
/// The rendered message names the entity, the attribute or tag, and whether
/// it was an attribute or a tag, followed by the underlying error.
//
// This is NOT a publicly exported error type.
#[derive(Debug, Diagnostic, Error)]
#[error("failed to evaluate {} `{attr_or_tag}` of `{uid}`: {err}", if *.was_attr { "attribute" } else { "tag" })]
pub struct EntityAttrEvaluationError {
    /// UID of the entity where the error was encountered
    pub uid: EntityUID,
    /// Attribute or tag of the entity where the error was encountered
    pub attr_or_tag: SmolStr,
    /// If `attr_or_tag` was an attribute (`true`) or tag (`false`)
    pub was_attr: bool,
    /// Underlying evaluation error
    #[diagnostic(transparent)]
    pub err: EvaluationError,
}
843
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::*;

    /// An `EntityUID` prints as `<type>::"<eid>"`.
    #[test]
    fn display() {
        let uid = EntityUID::with_eid("eid");
        assert_eq!(uid.to_string(), "test_entity_type::\"eid\"");
    }

    /// Equality of `EntityUID`s depends on both the type and the EID.
    #[test]
    fn test_euid_equality() {
        // Builds an EUID of the given (unqualified) type with EID `foo`
        let foo_of_type = |typename: &str| {
            EntityUID::from_components(
                Name::parse_unqualified_name(typename)
                    .expect("should be a valid identifier")
                    .into(),
                Eid::Eid("foo".into()),
                None,
            )
        };
        let e1 = EntityUID::with_eid("foo");
        let e2 = foo_of_type("test_entity_type");
        let e3 = foo_of_type("Unspecified");

        // an EUID is equal to itself
        assert_eq!(e1, e1);
        assert_eq!(e2, e2);

        // constructing with `with_eid` or `from_components` is the same
        assert_eq!(e1, e2);

        // other pairs are not equal
        assert_ne!(e1, e3);
    }

    /// Only the basename of the entity type decides whether a UID is an action.
    #[test]
    fn action_checker() {
        let cases = [
            ("Action::\"view\"", true),
            ("Foo::Action::\"view\"", true),
            ("Foo::\"view\"", false),
            ("Action::Foo::\"view\"", false),
        ];
        for (src, expected) in cases {
            let euid = EntityUID::from_str(src).unwrap();
            assert_eq!(euid.is_action(), expected);
        }
    }

    #[test]
    fn action_type_is_valid_id() {
        let parsed = Id::from_normalized_str(ACTION_ENTITY_TYPE);
        assert!(parsed.is_ok());
    }

    /// Accessors on the error variants return the error placeholders.
    #[cfg(feature = "tolerant-ast")]
    #[test]
    fn error_entity() {
        use cool_asserts::assert_matches;

        let placeholder_name = Name(InternalName::from(Id::new_unchecked("EntityTypeError")));

        let error_uid = EntityUID::Error;
        assert_matches!(error_uid.eid(), Eid::ErrorEid);
        assert_matches!(error_uid.entity_type(), EntityType::ErrorEntityType);
        assert!(!error_uid.is_action());
        assert_matches!(error_uid.loc(), None);

        assert_eq!(Eid::ErrorEid.escaped(), "Eid::Error");

        let error_type = EntityType::ErrorEntityType;
        assert!(!error_type.is_action());
        assert_eq!(error_type.qualify_with(None), EntityType::ErrorEntityType);
        assert_eq!(
            error_type.qualify_with(Some(&placeholder_name)),
            EntityType::ErrorEntityType
        );

        assert_eq!(error_type.name(), &placeholder_name);
        assert_eq!(error_type.loc(), None)
    }

    #[test]
    fn entity_type_deserialization() {
        let entity_type: EntityType = serde_json::from_str(r#""some_entity_type""#).unwrap();
        assert_eq!(entity_type.name().0.to_string(), "some_entity_type")
    }

    #[test]
    fn entity_type_serialization() {
        let name = Name(InternalName::from(Id::new_unchecked("some_entity_type")));
        let serialized = serde_json::to_string(&EntityType::EntityType(name)).unwrap();

        assert_eq!(serialized, r#""some_entity_type""#);
    }
}