Skip to main content

cedar_policy_core/ast/
name.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use super::id::Id;
18use itertools::Itertools;
19use serde::{Deserialize, Deserializer, Serialize, Serializer};
20use smol_str::ToSmolStr;
21use std::sync::Arc;
22
23use crate::parser::err::ParseErrors;
24use crate::parser::Loc;
25use crate::FromNormalizedStr;
26
27use super::PrincipalOrResource;
28
/// This is the `Name` type used to name types, functions, etc.
/// The name can include namespaces.
/// Clone is O(1).
#[derive(Debug, Clone)]
pub struct Name {
    /// Basename: the final `::`-separated component of the name.
    pub(crate) id: Id,
    /// Namespaces, in the order they are written before the basename
    /// (e.g. `foo`, `bar` for `foo::bar::baz`). Shared behind an `Arc`.
    pub(crate) path: Arc<Vec<Id>>,
    /// Location of the name in source. Not considered by the `PartialEq`,
    /// `Hash`, or `Ord` implementations for `Name`.
    pub(crate) loc: Option<Loc>,
}
41
42/// `PartialEq` implementation ignores the `loc`.
43impl PartialEq for Name {
44    fn eq(&self, other: &Self) -> bool {
45        self.id == other.id && self.path == other.path
46    }
47}
48impl Eq for Name {}
49
50impl std::hash::Hash for Name {
51    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
52        // hash the ty and eid, in line with the `PartialEq` impl which compares
53        // the ty and eid.
54        self.id.hash(state);
55        self.path.hash(state);
56    }
57}
58
59impl PartialOrd for Name {
60    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
61        Some(self.cmp(other))
62    }
63}
64impl Ord for Name {
65    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
66        self.id.cmp(&other.id).then(self.path.cmp(&other.path))
67    }
68}
69
70/// A shortcut for `Name::unqualified_name`
71impl From<Id> for Name {
72    fn from(value: Id) -> Self {
73        Self::unqualified_name(value)
74    }
75}
76
77/// Convert a `Name` to an `Id`
78/// The error type is the unit type because the reason the conversion fails
79/// is obvious
80impl TryFrom<Name> for Id {
81    type Error = ();
82    fn try_from(value: Name) -> Result<Self, Self::Error> {
83        if value.is_unqualified() {
84            Ok(value.id)
85        } else {
86            Err(())
87        }
88    }
89}
90
91impl Name {
92    /// A full constructor for `Name`
93    pub fn new(basename: Id, path: impl IntoIterator<Item = Id>, loc: Option<Loc>) -> Self {
94        Self {
95            id: basename,
96            path: Arc::new(path.into_iter().collect()),
97            loc,
98        }
99    }
100
101    /// Create a `Name` with no path (no namespaces).
102    pub fn unqualified_name(id: Id) -> Self {
103        Self {
104            id,
105            path: Arc::new(vec![]),
106            loc: None,
107        }
108    }
109
110    /// Create a `Name` with no path (no namespaces).
111    /// Returns an error if `s` is not a valid identifier.
112    pub fn parse_unqualified_name(s: &str) -> Result<Self, ParseErrors> {
113        Ok(Self {
114            id: s.parse()?,
115            path: Arc::new(vec![]),
116            loc: None,
117        })
118    }
119
120    /// Given a type basename and a namespace (as a `Name` itself),
121    /// return a `Name` representing the type's fully qualified name
122    pub fn type_in_namespace(basename: Id, namespace: Name, loc: Option<Loc>) -> Name {
123        let mut path = Arc::unwrap_or_clone(namespace.path);
124        path.push(namespace.id);
125        Name::new(basename, path, loc)
126    }
127
128    /// Get the source location
129    pub fn loc(&self) -> Option<&Loc> {
130        self.loc.as_ref()
131    }
132
133    /// Get the basename of the `Name` (ie, with namespaces stripped).
134    pub fn basename(&self) -> &Id {
135        &self.id
136    }
137
138    /// Get the namespace of the `Name`, as components
139    pub fn namespace_components(&self) -> impl Iterator<Item = &Id> {
140        self.path.iter()
141    }
142
143    /// Get the full namespace of the `Name`, as a single string.
144    ///
145    /// Examples:
146    /// - `foo::bar` --> the namespace is `"foo"`
147    /// - `bar` --> the namespace is `""`
148    /// - `foo::bar::baz` --> the namespace is `"foo::bar"`
149    pub fn namespace(&self) -> String {
150        self.path.iter().join("::")
151    }
152
153    /// Prefix the name with a optional namespace
154    /// When the name is not an `Id`, it doesn't make sense to prefix any
155    /// namespace and hence this method returns a copy of `self`
156    /// When the name is an `Id`, prefix it with the optional namespace
157    /// e.g., prefix `A::B`` with `Some(C)` or `None` produces `A::B`
158    /// prefix `A` with `Some(B::C)` yields `B::C::A`
159    pub fn prefix_namespace_if_unqualified(&self, namespace: Option<Name>) -> Name {
160        if self.is_unqualified() {
161            // Ideally, we want to implement `IntoIterator` for `Name`
162            match namespace {
163                Some(namespace) => Self::new(
164                    self.basename().clone(),
165                    namespace
166                        .namespace_components()
167                        .chain(std::iter::once(namespace.basename()))
168                        .cloned(),
169                    self.loc().cloned(),
170                ),
171                None => self.clone(),
172            }
173        } else {
174            self.clone()
175        }
176    }
177
178    /// Test if a `Name` is an `Id`
179    pub fn is_unqualified(&self) -> bool {
180        self.path.is_empty()
181    }
182}
183
184impl std::fmt::Display for Name {
185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186        for elem in self.path.as_ref() {
187            write!(f, "{}::", elem)?;
188        }
189        write!(f, "{}", self.id)?;
190        Ok(())
191    }
192}
193
194/// Serialize a `Name` using its `Display` implementation
195/// This serialization implementation is used in the JSON schema format.
196impl Serialize for Name {
197    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
198    where
199        S: Serializer,
200    {
201        self.to_smolstr().serialize(serializer)
202    }
203}
204
205// allow `.parse()` on a string to make a `Name`
206impl std::str::FromStr for Name {
207    type Err = ParseErrors;
208
209    fn from_str(s: &str) -> Result<Self, Self::Err> {
210        crate::parser::parse_name(s)
211    }
212}
213
214impl FromNormalizedStr for Name {
215    fn describe_self() -> &'static str {
216        "Name"
217    }
218}
219
220struct NameVisitor;
221
222impl<'de> serde::de::Visitor<'de> for NameVisitor {
223    type Value = Name;
224
225    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        formatter.write_str("a name consisting of an optional namespace and id")
227    }
228
229    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
230    where
231        E: serde::de::Error,
232    {
233        Name::from_normalized_str(value)
234            .map_err(|err| serde::de::Error::custom(format!("invalid name `{value}`: {err}")))
235    }
236}
237
238/// Deserialize a `Name` using `from_normalized_str`
239/// This deserialization implementation is used in the JSON schema format.
240impl<'de> Deserialize<'de> for Name {
241    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
242    where
243        D: Deserializer<'de>,
244    {
245        deserializer.deserialize_str(NameVisitor)
246    }
247}
248
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Name {
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // Computing the hash of long id strings can be expensive, so the
        // number of namespace components is capped to keep DRT from
        // reporting slow units.
        let component_count = u.int_in_range(0..=8)?;
        let mut path: Vec<Id> = Vec::new();
        for _ in 0..component_count {
            path.push(u.arbitrary()?);
        }
        // The basename is drawn only after the path, as before.
        let id = u.arbitrary()?;
        Ok(Self {
            id,
            path: Arc::new(path),
            loc: None,
        })
    }
}
265
266/// Identifier for a slot
267/// Clone is O(1).
268// This simply wraps a separate enum -- currently `ValidSlotId` -- in case we
269// want to generalize later
270#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
271#[serde(transparent)]
272pub struct SlotId(pub(crate) ValidSlotId);
273
274impl SlotId {
275    /// Get the slot for `principal`
276    pub fn principal() -> Self {
277        Self(ValidSlotId::Principal)
278    }
279
280    /// Get the slot for `resource`
281    pub fn resource() -> Self {
282        Self(ValidSlotId::Resource)
283    }
284
285    /// Check if a slot represents a principal
286    pub fn is_principal(&self) -> bool {
287        matches!(self, Self(ValidSlotId::Principal))
288    }
289
290    /// Check if a slot represents a resource
291    pub fn is_resource(&self) -> bool {
292        matches!(self, Self(ValidSlotId::Resource))
293    }
294}
295
296impl From<PrincipalOrResource> for SlotId {
297    fn from(v: PrincipalOrResource) -> Self {
298        match v {
299            PrincipalOrResource::Principal => SlotId::principal(),
300            PrincipalOrResource::Resource => SlotId::resource(),
301        }
302    }
303}
304
305impl std::fmt::Display for SlotId {
306    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307        write!(f, "{}", self.0)
308    }
309}
310
/// Two possible variants for Slots
///
/// The derived `Ord` orders variants by declaration order
/// (`Principal` before `Resource`).
#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub(crate) enum ValidSlotId {
    /// The principal slot; serialized as the string `"?principal"`.
    #[serde(rename = "?principal")]
    Principal,
    /// The resource slot; serialized as the string `"?resource"`.
    #[serde(rename = "?resource")]
    Resource,
}
319
320impl std::fmt::Display for ValidSlotId {
321    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
322        let s = match self {
323            ValidSlotId::Principal => "principal",
324            ValidSlotId::Resource => "resource",
325        };
326        write!(f, "?{s}")
327    }
328}
329
/// [`SlotId`] plus a source location
///
/// Equality and hashing for `Slot` consider only `id`; `loc` is ignored.
#[derive(Debug, Clone)]
pub struct Slot {
    /// [`SlotId`]
    pub id: SlotId,
    /// Source location, if available
    pub loc: Option<Loc>,
}
338
339/// `PartialEq` implementation ignores the `loc`. Slots are equal if their ids
340/// are equal.
341impl PartialEq for Slot {
342    fn eq(&self, other: &Self) -> bool {
343        self.id == other.id
344    }
345}
346impl Eq for Slot {}
347
348impl std::hash::Hash for Slot {
349    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
350        // hash only the id, in line with the `PartialEq` impl which compares
351        // only the id
352        self.id.hash(state);
353    }
354}
355
#[cfg(test)]
mod vars_test {
    use super::*;

    // Make sure the slot constructors produce the expected variants
    #[test]
    fn vars_correct() {
        assert!(SlotId::principal().is_principal());
        assert!(!SlotId::principal().is_resource());
        assert!(SlotId::resource().is_resource());
        assert!(!SlotId::resource().is_principal());
    }

    #[test]
    fn display() {
        assert_eq!(format!("{}", SlotId::principal()), "?principal");
        assert_eq!(format!("{}", SlotId::resource()), "?resource");
    }

    // Each scope position converts to the slot of the same name
    #[test]
    fn from_principal_or_resource() {
        assert_eq!(
            SlotId::from(PrincipalOrResource::Principal),
            SlotId::principal()
        );
        assert_eq!(
            SlotId::from(PrincipalOrResource::Resource),
            SlotId::resource()
        );
    }
}
371
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn normalized_name() {
        Name::from_normalized_str("foo").expect("should be OK");
        Name::from_normalized_str("foo::bar").expect("should be OK");
        Name::from_normalized_str(r#"foo::"bar""#).expect_err("shouldn't be OK");
        Name::from_normalized_str(" foo").expect_err("shouldn't be OK");
        Name::from_normalized_str("foo ").expect_err("shouldn't be OK");
        Name::from_normalized_str("foo\n").expect_err("shouldn't be OK");
        Name::from_normalized_str("foo//comment").expect_err("shouldn't be OK");
    }

    #[test]
    fn prefix_namespace() {
        assert_eq!(
            "foo::bar::baz",
            Name::from_normalized_str("baz")
                .unwrap()
                .prefix_namespace_if_unqualified(Some("foo::bar".parse().unwrap()))
                .to_smolstr()
        );
        assert_eq!(
            "C::D",
            Name::from_normalized_str("C::D")
                .unwrap()
                .prefix_namespace_if_unqualified(Some("A::B".parse().unwrap()))
                .to_smolstr()
        );
        assert_eq!(
            "A::B::C::D",
            Name::from_normalized_str("D")
                .unwrap()
                .prefix_namespace_if_unqualified(Some("A::B::C".parse().unwrap()))
                .to_smolstr()
        );
        assert_eq!(
            "B::C::D",
            Name::from_normalized_str("B::C::D")
                .unwrap()
                .prefix_namespace_if_unqualified(Some("A".parse().unwrap()))
                .to_smolstr()
        );
        // No namespace to add: the name is returned unchanged
        assert_eq!(
            "baz",
            Name::from_normalized_str("baz")
                .unwrap()
                .prefix_namespace_if_unqualified(None)
                .to_smolstr()
        );
    }

    #[test]
    fn display_round_trip() {
        for s in ["foo", "foo::bar", "foo::bar::baz"] {
            assert_eq!(s, Name::from_normalized_str(s).unwrap().to_smolstr());
        }
    }

    #[test]
    fn namespace_accessors() {
        let qualified = Name::from_normalized_str("foo::bar::baz").unwrap();
        assert_eq!(qualified.namespace(), "foo::bar");
        assert_eq!("baz", qualified.basename().to_smolstr());
        assert_eq!(qualified.namespace_components().count(), 2);
        assert!(!qualified.is_unqualified());

        let unqualified = Name::from_normalized_str("baz").unwrap();
        assert_eq!(unqualified.namespace(), "");
        assert_eq!(unqualified.namespace_components().count(), 0);
        assert!(unqualified.is_unqualified());
    }

    #[test]
    fn type_in_namespace_qualifies() {
        let basename = Name::parse_unqualified_name("C")
            .unwrap()
            .basename()
            .clone();
        let namespace: Name = "A::B".parse().unwrap();
        assert_eq!(
            "A::B::C",
            Name::type_in_namespace(basename, namespace, None).to_smolstr()
        );
    }

    #[test]
    fn try_into_id() {
        let id = Id::try_from(Name::from_normalized_str("foo").unwrap())
            .expect("unqualified name should convert to an Id");
        assert_eq!("foo", id.to_smolstr());
        assert!(Id::try_from(Name::from_normalized_str("foo::bar").unwrap()).is_err());
    }
}