cedar_policy_core/ast/
id.rs

/*
 * Copyright Cedar Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17use serde::{Deserialize, Deserializer, Serialize};
18use smol_str::SmolStr;
19
20use crate::{parser::err::ParseErrors, FromNormalizedStr};
21
/// Identifiers. Anything in `Id` should be a valid identifier: this means it
/// does not contain, for instance, spaces or characters like '+', and it is
/// also not one of the Cedar reserved identifiers (at time of writing,
/// `true | false | if | then | else | in | is | like | has`).
//
// For now, internally, an `Id` is just an owned `SmolStr`.
// `Serialize` is derived, while `Deserialize` is implemented by hand so that
// incoming strings go through `from_normalized_str`.
#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
pub struct Id(SmolStr);
30
31impl Id {
32    /// Create a new `Id` from a `String`, where it is the caller's
33    /// responsibility to ensure that the string is indeed a valid identifier.
34    ///
35    /// When possible, callers should not use this, and instead use `s.parse()`,
36    /// which checks that `s` is a valid identifier, and returns a parse error
37    /// if not.
38    ///
39    /// This method was created for the `From<cst::Ident> for Id` impl to use.
40    /// Since `parser::parse_ident()` implicitly uses that `From` impl itself,
41    /// if we tried to make that `From` impl go through `.parse()` like everyone
42    /// else, we'd get infinite recursion.  And, we assert that `cst::Ident` is
43    /// always already checked to contain a valid identifier, otherwise it would
44    /// never have been created.
45    pub(crate) fn new_unchecked(s: impl Into<SmolStr>) -> Id {
46        Id(s.into())
47    }
48
49    /// Get the underlying string
50    pub fn into_smolstr(self) -> SmolStr {
51        self.0
52    }
53}
54
55impl AsRef<str> for Id {
56    fn as_ref(&self) -> &str {
57        &self.0
58    }
59}
60
61impl std::fmt::Display for Id {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        write!(f, "{}", &self.0)
64    }
65}
66
67// allow `.parse()` on a string to make an `Id`
68impl std::str::FromStr for Id {
69    type Err = ParseErrors;
70
71    fn from_str(s: &str) -> Result<Self, Self::Err> {
72        crate::parser::parse_ident(s)
73    }
74}
75
// Only `describe_self` is provided here; `Id::from_normalized_str` (used by
// the deserializer in this file) therefore comes from the trait itself.
impl FromNormalizedStr for Id {
    /// Name of this type as a string: always the literal `"Id"`.
    // NOTE(review): presumably consumed by the trait when building error
    // messages -- confirm against the `FromNormalizedStr` definition.
    fn describe_self() -> &'static str {
        "Id"
    }
}
81
82struct IdVisitor;
83
84impl<'de> serde::de::Visitor<'de> for IdVisitor {
85    type Value = Id;
86
87    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        formatter.write_str("a valid id")
89    }
90
91    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
92    where
93        E: serde::de::Error,
94    {
95        Id::from_normalized_str(value)
96            .map_err(|err| serde::de::Error::custom(format!("invalid id `{value}`: {err}")))
97    }
98}
99
100/// Deserialize an `Id` using `from_normalized_str`.
101/// This deserialization implementation is used in the JSON schema format.
102impl<'de> Deserialize<'de> for Id {
103    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
104    where
105        D: Deserializer<'de>,
106    {
107        deserializer.deserialize_str(IdVisitor)
108    }
109}
110
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Id {
    /// Build a random identifier of 1 to 17 characters drawn from the
    /// identifier alphabet; one more `_` is appended when the draw does not
    /// parse as an identifier.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // identifier syntax:
        // IDENT     := ['_''a'-'z''A'-'Z']['_''a'-'z''A'-'Z''0'-'9']* - RESERVED
        // BOOL      := 'true' | 'false'
        // RESERVED  := BOOL | 'if' | 'then' | 'else' | 'in' | 'is' | 'like' | 'has'

        // characters allowed in the first position: upper case, lower case, `_`
        let first_chars: Vec<char> = ('A'..='Z')
            .chain('a'..='z')
            .chain(std::iter::once('_'))
            .collect();
        // characters allowed in every later position: digits, then `first_chars`
        let later_chars: Vec<char> = ('0'..='9').chain(first_chars.iter().copied()).collect();

        // how many characters follow the first one
        let extra_len = u.int_in_range(0..=16)?;
        let mut ident = String::from(*u.choose(&first_chars)?);
        for _ in 0..extra_len {
            ident.push(*u.choose(&later_chars)?);
        }

        // If parsing fails, the string should be a reserved word.
        // Appending a `_` turns it into a valid Id.
        if crate::parser::parse_ident(&ident).is_err() {
            ident.push('_');
        }
        Ok(Self::new_unchecked(ident))
    }

    /// Combine the hint for the length draw with the hint for the character
    /// draws.
    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        // for the arbitrary length
        let length_hint = <usize as arbitrary::Arbitrary>::size_hint(depth);
        // for the arbitrary choices: the size hint of a vector of `u8` is used
        // as an underestimate of the bytes required by the sequence of choices
        let choices_hint = <Vec<u8> as arbitrary::Arbitrary>::size_hint(depth);
        arbitrary::size_hint::and_all(&[length_hint, choices_hint])
    }
}
152
/// Like `Id`, except this specifically _can_ contain Cedar reserved identifiers.
/// (It still can't contain, for instance, spaces or characters like '+'.)
//
// For now, internally, an `AnyId` is just an owned `SmolStr`.
// `Serialize` is derived, while `Deserialize` is implemented by hand so that
// incoming strings go through `from_normalized_str`.
#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
pub struct AnyId(SmolStr);
159
160impl AnyId {
161    /// Create a new `AnyId` from a `String`, where it is the caller's
162    /// responsibility to ensure that the string is indeed a valid `AnyId`.
163    ///
164    /// When possible, callers should not use this, and instead use `s.parse()`,
165    /// which checks that `s` is a valid `AnyId`, and returns a parse error
166    /// if not.
167    ///
168    /// This method was created for the `From<cst::Ident> for AnyId` impl to use.
169    /// See notes on `Id::new_unchecked()`.
170    pub(crate) fn new_unchecked(s: impl Into<SmolStr>) -> AnyId {
171        AnyId(s.into())
172    }
173
174    /// Get the underlying string
175    pub fn into_smolstr(self) -> SmolStr {
176        self.0
177    }
178}
179
180struct AnyIdVisitor;
181
182impl<'de> serde::de::Visitor<'de> for AnyIdVisitor {
183    type Value = AnyId;
184
185    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186        formatter.write_str("any id")
187    }
188
189    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
190    where
191        E: serde::de::Error,
192    {
193        AnyId::from_normalized_str(value)
194            .map_err(|err| serde::de::Error::custom(format!("invalid id `{value}`: {err}")))
195    }
196}
197
198/// Deserialize an `AnyId` using `from_normalized_str`.
199/// This deserialization implementation is used in the JSON policy format.
200impl<'de> Deserialize<'de> for AnyId {
201    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
202    where
203        D: Deserializer<'de>,
204    {
205        deserializer.deserialize_str(AnyIdVisitor)
206    }
207}
208
209impl AsRef<str> for AnyId {
210    fn as_ref(&self) -> &str {
211        &self.0
212    }
213}
214
215impl std::fmt::Display for AnyId {
216    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217        write!(f, "{}", &self.0)
218    }
219}
220
221// allow `.parse()` on a string to make an `AnyId`
222impl std::str::FromStr for AnyId {
223    type Err = ParseErrors;
224
225    fn from_str(s: &str) -> Result<Self, Self::Err> {
226        crate::parser::parse_anyid(s)
227    }
228}
229
// Only `describe_self` is provided here; `AnyId::from_normalized_str` (used
// by the deserializer in this file) therefore comes from the trait itself.
impl FromNormalizedStr for AnyId {
    /// Name of this type as a string: always the literal `"AnyId"`.
    // NOTE(review): presumably consumed by the trait when building error
    // messages -- confirm against the `FromNormalizedStr` definition.
    fn describe_self() -> &'static str {
        "AnyId"
    }
}
235
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for AnyId {
    /// Build a random `AnyId` of 1 to 17 characters drawn from the
    /// identifier alphabet. Reserved words are acceptable here, so no
    /// post-processing is applied.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // AnyId syntax:
        // ['_''a'-'z''A'-'Z']['_''a'-'z''A'-'Z''0'-'9']*

        // characters allowed in the first position: upper case, lower case, `_`
        let first_chars: Vec<char> = ('A'..='Z')
            .chain('a'..='z')
            .chain(std::iter::once('_'))
            .collect();
        // characters allowed in every later position: digits, then `first_chars`
        let later_chars: Vec<char> = ('0'..='9').chain(first_chars.iter().copied()).collect();

        // how many characters follow the first one
        let extra_len = u.int_in_range(0..=16)?;
        let mut s = String::from(*u.choose(&first_chars)?);
        for _ in 0..extra_len {
            s.push(*u.choose(&later_chars)?);
        }

        // Sanity check (debug builds only) that the construction is sound
        debug_assert!(
            crate::parser::parse_anyid(&s).is_ok(),
            "all strings constructed this way should be valid AnyIds, but this one is not: {s:?}"
        );
        Ok(Self::new_unchecked(s))
    }

    /// Combine the hint for the length draw with the hint for the character
    /// draws.
    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        // for the arbitrary length
        let length_hint = <usize as arbitrary::Arbitrary>::size_hint(depth);
        // for the arbitrary choices: the size hint of a vector of `u8` is used
        // as an underestimate of the bytes required by the sequence of choices
        let choices_hint = <Vec<u8> as arbitrary::Arbitrary>::size_hint(depth);
        arbitrary::size_hint::and_all(&[length_hint, choices_hint])
    }
}
274
// PANIC SAFETY: unit-test code
#[allow(clippy::panic)]
#[cfg(test)]
mod test {
    use super::*;

    /// `Id::from_normalized_str` accepts a bare identifier and rejects
    /// `::`-separated paths, surrounding whitespace, a trailing newline, and
    /// a trailing comment.
    #[test]
    fn normalized_id() {
        Id::from_normalized_str("foo").expect("should be OK");

        let rejected = [
            "foo::bar",
            r#"foo::"bar""#,
            " foo",
            "foo ",
            "foo\n",
            "foo//comment",
        ];
        for input in rejected {
            Id::from_normalized_str(input).expect_err("shouldn't be OK");
        }
    }
}