cedar_policy_core/ast/
name.rs

/*
 * Copyright Cedar Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17use super::id::Id;
18use itertools::Itertools;
19use serde::{Deserialize, Deserializer, Serialize, Serializer};
20use smol_str::ToSmolStr;
21use std::sync::Arc;
22
23use crate::parser::err::ParseErrors;
24use crate::parser::Loc;
25use crate::FromNormalizedStr;
26
27use super::PrincipalOrResource;
28
29/// This is the `Name` type used to name types, functions, etc.
30/// The name can include namespaces.
31/// Clone is O(1).
32#[derive(Debug, Clone)]
33pub struct Name {
34    /// Basename
35    pub(crate) id: Id,
36    /// Namespaces
37    pub(crate) path: Arc<Vec<Id>>,
38    /// Location of the name in source
39    pub(crate) loc: Option<Loc>,
40}
41
42/// `PartialEq` implementation ignores the `loc`.
43impl PartialEq for Name {
44    fn eq(&self, other: &Self) -> bool {
45        self.id == other.id && self.path == other.path
46    }
47}
48impl Eq for Name {}
49
50impl std::hash::Hash for Name {
51    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
52        // hash the ty and eid, in line with the `PartialEq` impl which compares
53        // the ty and eid.
54        self.id.hash(state);
55        self.path.hash(state);
56    }
57}
58
59impl PartialOrd for Name {
60    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
61        Some(self.cmp(other))
62    }
63}
64impl Ord for Name {
65    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
66        self.id.cmp(&other.id).then(self.path.cmp(&other.path))
67    }
68}
69
70/// A shortcut for `Name::unqualified_name`
71impl From<Id> for Name {
72    fn from(value: Id) -> Self {
73        Self::unqualified_name(value)
74    }
75}
76
77/// Convert a `Name` to an `Id`
78/// The error type is the unit type because the reason the conversion fails
79/// is obvious
80impl TryFrom<Name> for Id {
81    type Error = ();
82    fn try_from(value: Name) -> Result<Self, Self::Error> {
83        if value.is_unqualified() {
84            Ok(value.id)
85        } else {
86            Err(())
87        }
88    }
89}
90
91impl Name {
92    /// A full constructor for `Name`
93    pub fn new(basename: Id, path: impl IntoIterator<Item = Id>, loc: Option<Loc>) -> Self {
94        Self {
95            id: basename,
96            path: Arc::new(path.into_iter().collect()),
97            loc,
98        }
99    }
100
101    /// Create a `Name` with no path (no namespaces).
102    pub fn unqualified_name(id: Id) -> Self {
103        Self {
104            id,
105            path: Arc::new(vec![]),
106            loc: None,
107        }
108    }
109
110    /// Create a `Name` with no path (no namespaces).
111    /// Returns an error if `s` is not a valid identifier.
112    pub fn parse_unqualified_name(s: &str) -> Result<Self, ParseErrors> {
113        Ok(Self {
114            id: s.parse()?,
115            path: Arc::new(vec![]),
116            loc: None,
117        })
118    }
119
120    /// Given a type basename and a namespace (as a `Name` itself),
121    /// return a `Name` representing the type's fully qualified name
122    pub fn type_in_namespace(basename: Id, namespace: Name, loc: Option<Loc>) -> Name {
123        let mut path = Arc::unwrap_or_clone(namespace.path);
124        path.push(namespace.id);
125        Name::new(basename, path, loc)
126    }
127
128    /// Get the source location
129    pub fn loc(&self) -> Option<&Loc> {
130        self.loc.as_ref()
131    }
132
133    /// Get the basename of the `Name` (ie, with namespaces stripped).
134    pub fn basename(&self) -> &Id {
135        &self.id
136    }
137
138    /// Get the namespace of the `Name`, as components
139    pub fn namespace_components(&self) -> impl Iterator<Item = &Id> {
140        self.path.iter()
141    }
142
143    /// Get the full namespace of the `Name`, as a single string.
144    ///
145    /// Examples:
146    /// - `foo::bar` --> the namespace is `"foo"`
147    /// - `bar` --> the namespace is `""`
148    /// - `foo::bar::baz` --> the namespace is `"foo::bar"`
149    pub fn namespace(&self) -> String {
150        self.path.iter().join("::")
151    }
152
153    /// Prefix the name with a optional namespace
154    /// When the name is not an `Id`, it doesn't make sense to prefix any
155    /// namespace and hence this method returns a copy of `self`
156    /// When the name is an `Id`, prefix it with the optional namespace
157    /// e.g., prefix `A::B`` with `Some(C)` or `None` produces `A::B`
158    /// prefix `A` with `Some(B::C)` yields `B::C::A`
159    pub fn prefix_namespace_if_unqualified(&self, namespace: Option<Name>) -> Name {
160        if self.is_unqualified() {
161            // Ideally, we want to implement `IntoIterator` for `Name`
162            match namespace {
163                Some(namespace) => Self::new(
164                    self.basename().clone(),
165                    namespace
166                        .namespace_components()
167                        .chain(std::iter::once(namespace.basename()))
168                        .cloned(),
169                    self.loc().cloned(),
170                ),
171                None => self.clone(),
172            }
173        } else {
174            self.clone()
175        }
176    }
177
178    /// Test if a `Name` is an `Id`
179    pub fn is_unqualified(&self) -> bool {
180        self.path.is_empty()
181    }
182}
183
184impl std::fmt::Display for Name {
185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186        for elem in self.path.as_ref() {
187            write!(f, "{}::", elem)?;
188        }
189        write!(f, "{}", self.id)?;
190        Ok(())
191    }
192}
193
194/// Serialize a `Name` using its `Display` implementation
195/// This serialization implementation is used in the JSON schema format.
196impl Serialize for Name {
197    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
198    where
199        S: Serializer,
200    {
201        self.to_smolstr().serialize(serializer)
202    }
203}
204
205// allow `.parse()` on a string to make a `Name`
206impl std::str::FromStr for Name {
207    type Err = ParseErrors;
208
209    fn from_str(s: &str) -> Result<Self, Self::Err> {
210        crate::parser::parse_name(s)
211    }
212}
213
214impl FromNormalizedStr for Name {
215    fn describe_self() -> &'static str {
216        "Name"
217    }
218}
219
220struct NameVisitor;
221
222impl<'de> serde::de::Visitor<'de> for NameVisitor {
223    type Value = Name;
224
225    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        formatter.write_str("a name consisting of an optional namespace and id")
227    }
228
229    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
230    where
231        E: serde::de::Error,
232    {
233        Name::from_normalized_str(value)
234            .map_err(|err| serde::de::Error::custom(format!("invalid name `{value}`: {err}")))
235    }
236}
237
238/// Deserialize a `Name` using `from_normalized_str`
239/// This deserialization implementation is used in the JSON schema format.
240impl<'de> Deserialize<'de> for Name {
241    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
242    where
243        D: Deserializer<'de>,
244    {
245        deserializer.deserialize_str(NameVisitor)
246    }
247}
248
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Name {
    /// Generates an arbitrary `Name` with no source location.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // Keep the draw order (basename, then path) stable so existing fuzz
        // inputs presumably keep mapping to the same names.
        let id = u.arbitrary()?;
        let path = u.arbitrary()?;
        Ok(Self {
            id,
            path,
            loc: None,
        })
    }
}
259
260/// Identifier for a slot
261/// Clone is O(1).
262// This simply wraps a separate enum -- currently `ValidSlotId` -- in case we
263// want to generalize later
264#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
265#[serde(transparent)]
266pub struct SlotId(pub(crate) ValidSlotId);
267
268impl SlotId {
269    /// Get the slot for `principal`
270    pub fn principal() -> Self {
271        Self(ValidSlotId::Principal)
272    }
273
274    /// Get the slot for `resource`
275    pub fn resource() -> Self {
276        Self(ValidSlotId::Resource)
277    }
278
279    /// Check if a slot represents a principal
280    pub fn is_principal(&self) -> bool {
281        matches!(self, Self(ValidSlotId::Principal))
282    }
283
284    /// Check if a slot represents a resource
285    pub fn is_resource(&self) -> bool {
286        matches!(self, Self(ValidSlotId::Resource))
287    }
288}
289
290impl From<PrincipalOrResource> for SlotId {
291    fn from(v: PrincipalOrResource) -> Self {
292        match v {
293            PrincipalOrResource::Principal => SlotId::principal(),
294            PrincipalOrResource::Resource => SlotId::resource(),
295        }
296    }
297}
298
299impl std::fmt::Display for SlotId {
300    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301        write!(f, "{}", self.0)
302    }
303}
304
305/// Two possible variants for Slots
306#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
307pub(crate) enum ValidSlotId {
308    #[serde(rename = "?principal")]
309    Principal,
310    #[serde(rename = "?resource")]
311    Resource,
312}
313
314impl std::fmt::Display for ValidSlotId {
315    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316        let s = match self {
317            ValidSlotId::Principal => "principal",
318            ValidSlotId::Resource => "resource",
319        };
320        write!(f, "?{s}")
321    }
322}
323
324/// [`SlotId`] plus a source location
325#[derive(Debug, Clone)]
326pub struct Slot {
327    /// [`SlotId`]
328    pub id: SlotId,
329    /// Source location, if available
330    pub loc: Option<Loc>,
331}
332
333/// `PartialEq` implementation ignores the `loc`. Slots are equal if their ids
334/// are equal.
335impl PartialEq for Slot {
336    fn eq(&self, other: &Self) -> bool {
337        self.id == other.id
338    }
339}
340impl Eq for Slot {}
341
342impl std::hash::Hash for Slot {
343    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
344        // hash only the id, in line with the `PartialEq` impl which compares
345        // only the id
346        self.id.hash(state);
347    }
348}
349
#[cfg(test)]
mod vars_test {
    use super::*;

    // Make sure the slot constructors produce the expected, distinct variants
    #[test]
    fn vars_correct() {
        let principal = SlotId::principal();
        let resource = SlotId::resource();
        assert!(principal.is_principal());
        assert!(!principal.is_resource());
        assert!(resource.is_resource());
        assert!(!resource.is_principal());
        assert_ne!(principal, resource);
    }

    // `PrincipalOrResource` maps onto the slot of the same name
    #[test]
    fn from_principal_or_resource() {
        assert_eq!(
            SlotId::from(PrincipalOrResource::Principal),
            SlotId::principal()
        );
        assert_eq!(
            SlotId::from(PrincipalOrResource::Resource),
            SlotId::resource()
        );
    }

    #[test]
    fn display() {
        assert_eq!(format!("{}", SlotId::principal()), "?principal");
        assert_eq!(format!("{}", SlotId::resource()), "?resource");
    }
}
365
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn normalized_name() {
        Name::from_normalized_str("foo").expect("should be OK");
        Name::from_normalized_str("foo::bar").expect("should be OK");
        Name::from_normalized_str(r#"foo::"bar""#).expect_err("shouldn't be OK");
        Name::from_normalized_str(" foo").expect_err("shouldn't be OK");
        Name::from_normalized_str("foo ").expect_err("shouldn't be OK");
        Name::from_normalized_str("foo\n").expect_err("shouldn't be OK");
        Name::from_normalized_str("foo//comment").expect_err("shouldn't be OK");
    }

    #[test]
    fn prefix_namespace() {
        assert_eq!(
            "foo::bar::baz",
            Name::from_normalized_str("baz")
                .unwrap()
                .prefix_namespace_if_unqualified(Some("foo::bar".parse().unwrap()))
                .to_smolstr()
        );
        assert_eq!(
            "C::D",
            Name::from_normalized_str("C::D")
                .unwrap()
                .prefix_namespace_if_unqualified(Some("A::B".parse().unwrap()))
                .to_smolstr()
        );
        assert_eq!(
            "A::B::C::D",
            Name::from_normalized_str("D")
                .unwrap()
                .prefix_namespace_if_unqualified(Some("A::B::C".parse().unwrap()))
                .to_smolstr()
        );
        assert_eq!(
            "B::C::D",
            Name::from_normalized_str("B::C::D")
                .unwrap()
                .prefix_namespace_if_unqualified(Some("A".parse().unwrap()))
                .to_smolstr()
        );
        // A `None` namespace leaves both unqualified and qualified names as-is
        assert_eq!(
            "foo",
            Name::from_normalized_str("foo")
                .unwrap()
                .prefix_namespace_if_unqualified(None)
                .to_smolstr()
        );
        assert_eq!(
            "A::B",
            Name::from_normalized_str("A::B")
                .unwrap()
                .prefix_namespace_if_unqualified(None)
                .to_smolstr()
        );
    }

    #[test]
    fn type_in_namespace() {
        assert_eq!(
            "A::B::C",
            Name::type_in_namespace("C".parse().unwrap(), "A::B".parse().unwrap(), None)
                .to_smolstr()
        );
    }

    #[test]
    fn namespace_and_qualification() {
        let qualified: Name = "foo::bar::baz".parse().unwrap();
        assert_eq!(qualified.namespace(), "foo::bar");
        assert!(!qualified.is_unqualified());
        assert!(Id::try_from(qualified).is_err());

        let unqualified: Name = "bar".parse().unwrap();
        assert_eq!(unqualified.namespace(), "");
        assert!(unqualified.is_unqualified());
        assert!(Id::try_from(unqualified).is_ok());
    }
}