cedar_policy_core/ast/
id.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use serde::{Deserialize, Deserializer, Serialize};
18use smol_str::SmolStr;
19
20use crate::{parser::err::ParseErrors, FromNormalizedStr};
21
22use super::{InternalName, ReservedNameError};
23
24const RESERVED_ID: &str = "__cedar";
25
26/// Identifiers. Anything in `Id` should be a valid identifier, this means it
27/// does not contain, for instance, spaces or characters like '+'; and also is
28/// not one of the Cedar reserved identifiers (at time of writing,
29/// `true | false | if | then | else | in | is | like | has`).
30//
31// For now, internally, `Id`s are just owned `SmolString`s.
32#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
33pub struct Id(SmolStr);
34
35impl Id {
36    /// Create a new `Id` from a `String`, where it is the caller's
37    /// responsibility to ensure that the string is indeed a valid identifier.
38    ///
39    /// When possible, callers should not use this, and instead use `s.parse()`,
40    /// which checks that `s` is a valid identifier, and returns a parse error
41    /// if not.
42    ///
43    /// This method was created for the `From<cst::Ident> for Id` impl to use.
44    /// Since `parser::parse_ident()` implicitly uses that `From` impl itself,
45    /// if we tried to make that `From` impl go through `.parse()` like everyone
46    /// else, we'd get infinite recursion.  And, we assert that `cst::Ident` is
47    /// always already checked to contain a valid identifier, otherwise it would
48    /// never have been created.
49    pub(crate) fn new_unchecked(s: impl Into<SmolStr>) -> Id {
50        Id(s.into())
51    }
52
53    /// Get the underlying string
54    pub fn into_smolstr(self) -> SmolStr {
55        self.0
56    }
57
58    /// Return if the `Id` is reserved (i.e., `__cedar`)
59    /// Note that it does not test if the `Id` string is a reserved keyword
60    /// as the parser already ensures that it is not
61    pub fn is_reserved(&self) -> bool {
62        self.as_ref() == RESERVED_ID
63    }
64}
65
66impl AsRef<str> for Id {
67    fn as_ref(&self) -> &str {
68        &self.0
69    }
70}
71
72impl std::fmt::Display for Id {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        write!(f, "{}", &self.0)
75    }
76}
77
78// allow `.parse()` on a string to make an `Id`
79impl std::str::FromStr for Id {
80    type Err = ParseErrors;
81
82    fn from_str(s: &str) -> Result<Self, Self::Err> {
83        crate::parser::parse_ident(s)
84    }
85}
86
87impl FromNormalizedStr for Id {
88    fn describe_self() -> &'static str {
89        "Id"
90    }
91}
92
93/// An `Id` that is not equal to `__cedar`, as specified by RFC 52
94#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
95#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
96#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
97pub struct UnreservedId(#[cfg_attr(feature = "wasm", tsify(type = "string"))] pub(crate) Id);
98
99impl From<UnreservedId> for Id {
100    fn from(value: UnreservedId) -> Self {
101        value.0
102    }
103}
104
105impl TryFrom<Id> for UnreservedId {
106    type Error = ReservedNameError;
107    fn try_from(value: Id) -> Result<Self, Self::Error> {
108        if value.is_reserved() {
109            Err(ReservedNameError(InternalName::unqualified_name(
110                value, None,
111            )))
112        } else {
113            Ok(Self(value))
114        }
115    }
116}
117
118impl AsRef<Id> for UnreservedId {
119    fn as_ref(&self) -> &Id {
120        &self.0
121    }
122}
123
124impl AsRef<str> for UnreservedId {
125    fn as_ref(&self) -> &str {
126        self.0.as_ref()
127    }
128}
129
130impl std::fmt::Display for UnreservedId {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        self.0.fmt(f)
133    }
134}
135
136impl std::str::FromStr for UnreservedId {
137    type Err = ParseErrors;
138    fn from_str(s: &str) -> Result<Self, Self::Err> {
139        Id::from_str(s).and_then(|id| id.try_into().map_err(ParseErrors::singleton))
140    }
141}
142
143impl FromNormalizedStr for UnreservedId {
144    fn describe_self() -> &'static str {
145        "Unreserved Id"
146    }
147}
148
149impl UnreservedId {
150    /// Create an [`UnreservedId`] from an empty string
151    pub(crate) fn empty() -> Self {
152        // PANIC SAFETY: "" does not contain `__cedar`
153        #[allow(clippy::unwrap_used)]
154        Id("".into()).try_into().unwrap()
155    }
156}
157
158struct IdVisitor;
159
160impl serde::de::Visitor<'_> for IdVisitor {
161    type Value = Id;
162
163    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        formatter.write_str("a valid id")
165    }
166
167    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
168    where
169        E: serde::de::Error,
170    {
171        Id::from_normalized_str(value)
172            .map_err(|err| serde::de::Error::custom(format!("invalid id `{value}`: {err}")))
173    }
174}
175
176/// Deserialize an `Id` using `from_normalized_str`.
177/// This deserialization implementation is used in the JSON schema format.
178impl<'de> Deserialize<'de> for Id {
179    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
180    where
181        D: Deserializer<'de>,
182    {
183        deserializer.deserialize_str(IdVisitor)
184    }
185}
186
187/// Deserialize a [`UnreservedId`] using `from_normalized_str`
188/// This deserialization implementation is used in the JSON schema format.
189impl<'de> Deserialize<'de> for UnreservedId {
190    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
191    where
192        D: Deserializer<'de>,
193    {
194        deserializer
195            .deserialize_str(IdVisitor)
196            .and_then(|n| n.try_into().map_err(serde::de::Error::custom))
197    }
198}
199
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Id {
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // Identifier grammar:
        //   IDENT     := ['_''a'-'z''A'-'Z']['_''a'-'z''A'-'Z''0'-'9']* - RESERVED
        //   BOOL      := 'true' | 'false'
        //   RESERVED  := BOOL | 'if' | 'then' | 'else' | 'in' | 'is' | 'like' | 'has'
        const FIRST_CHARS: &str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_";
        const DIGITS: &str = "0123456789";

        // Characters allowed in the first position.
        let first_chars: Vec<char> = FIRST_CHARS.chars().collect();
        // Characters allowed afterwards: digits first, then the first-position set.
        let rest_chars: Vec<char> = DIGITS.chars().chain(FIRST_CHARS.chars()).collect();

        // Number of characters after the first one (drawn before any character).
        let tail_len = u.int_in_range(0..=16)?;
        let mut ident = String::new();
        ident.push(*u.choose(&first_chars)?);
        for _ in 0..tail_len {
            ident.push(*u.choose(&rest_chars)?);
        }

        // If parsing fails, the string must be a reserved word; appending `_`
        // turns any reserved word into a valid identifier.
        if crate::parser::parse_ident(&ident).is_err() {
            ident.push('_');
        }
        Ok(Self::new_unchecked(ident))
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        // A `usize` hint covers the length. A `Vec<u8>` hint is used as an
        // underestimate of the bytes consumed by the sequence of character choices.
        let length_hint = <usize as arbitrary::Arbitrary>::size_hint(depth);
        let choices_hint = <Vec<u8> as arbitrary::Arbitrary>::size_hint(depth);
        arbitrary::size_hint::and(length_hint, choices_hint)
    }
}
241
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for UnreservedId {
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // Generate an arbitrary `Id`. If it happens to be the reserved `__cedar`,
        // prefix it with `_` so that it becomes `___cedar`.
        let id: Id = u.arbitrary()?;
        if id.is_reserved() {
            // `id` is exactly `__cedar` here. `___cedar` matches the identifier
            // grammar and is not `__cedar`, so it can be wrapped directly.
            // The previous version cloned `id` and re-parsed the string with
            // `.parse().unwrap()`; neither is needed.
            Ok(Self(Id::new_unchecked(format!("_{id}"))))
        } else {
            Ok(Self(id))
        }
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        <Id as arbitrary::Arbitrary>::size_hint(depth)
    }
}
261
262/// Like `Id`, except this specifically _can_ contain Cedar reserved identifiers.
263/// (It still can't contain, for instance, spaces or characters like '+'.)
264//
265// For now, internally, `AnyId`s are just owned `SmolString`s.
266#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
267#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
268#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
269pub struct AnyId(SmolStr);
270
271impl AnyId {
272    /// Create a new `AnyId` from a `String`, where it is the caller's
273    /// responsibility to ensure that the string is indeed a valid `AnyId`.
274    ///
275    /// When possible, callers should not use this, and instead use `s.parse()`,
276    /// which checks that `s` is a valid `AnyId`, and returns a parse error
277    /// if not.
278    ///
279    /// This method was created for the `From<cst::Ident> for AnyId` impl to use.
280    /// See notes on `Id::new_unchecked()`.
281    pub(crate) fn new_unchecked(s: impl Into<SmolStr>) -> AnyId {
282        AnyId(s.into())
283    }
284
285    /// Get the underlying string
286    pub fn into_smolstr(self) -> SmolStr {
287        self.0
288    }
289}
290
291struct AnyIdVisitor;
292
293impl serde::de::Visitor<'_> for AnyIdVisitor {
294    type Value = AnyId;
295
296    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297        formatter.write_str("any id")
298    }
299
300    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
301    where
302        E: serde::de::Error,
303    {
304        AnyId::from_normalized_str(value)
305            .map_err(|err| serde::de::Error::custom(format!("invalid id `{value}`: {err}")))
306    }
307}
308
309/// Deserialize an `AnyId` using `from_normalized_str`.
310/// This deserialization implementation is used in the JSON policy format.
311impl<'de> Deserialize<'de> for AnyId {
312    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
313    where
314        D: Deserializer<'de>,
315    {
316        deserializer.deserialize_str(AnyIdVisitor)
317    }
318}
319
320impl AsRef<str> for AnyId {
321    fn as_ref(&self) -> &str {
322        &self.0
323    }
324}
325
326impl std::fmt::Display for AnyId {
327    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328        write!(f, "{}", &self.0)
329    }
330}
331
332// allow `.parse()` on a string to make an `AnyId`
333impl std::str::FromStr for AnyId {
334    type Err = ParseErrors;
335
336    fn from_str(s: &str) -> Result<Self, Self::Err> {
337        crate::parser::parse_anyid(s)
338    }
339}
340
341impl FromNormalizedStr for AnyId {
342    fn describe_self() -> &'static str {
343        "AnyId"
344    }
345}
346
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for AnyId {
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // AnyId grammar (reserved words are allowed):
        //   ['_''a'-'z''A'-'Z']['_''a'-'z''A'-'Z''0'-'9']*
        const FIRST_CHARS: &str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_";
        const DIGITS: &str = "0123456789";

        // Characters allowed in the first position.
        let first_chars: Vec<char> = FIRST_CHARS.chars().collect();
        // Characters allowed afterwards: digits first, then the first-position set.
        let rest_chars: Vec<char> = DIGITS.chars().chain(FIRST_CHARS.chars()).collect();

        // Number of characters after the first one (drawn before any character).
        let tail_len = u.int_in_range(0..=16)?;
        let mut s = String::new();
        s.push(*u.choose(&first_chars)?);
        for _ in 0..tail_len {
            s.push(*u.choose(&rest_chars)?);
        }
        debug_assert!(
            crate::parser::parse_anyid(&s).is_ok(),
            "all strings constructed this way should be valid AnyIds, but this one is not: {s:?}"
        );
        Ok(Self::new_unchecked(s))
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        // A `usize` hint covers the length. A `Vec<u8>` hint is used as an
        // underestimate of the bytes consumed by the sequence of character choices.
        let length_hint = <usize as arbitrary::Arbitrary>::size_hint(depth);
        let choices_hint = <Vec<u8> as arbitrary::Arbitrary>::size_hint(depth);
        arbitrary::size_hint::and(length_hint, choices_hint)
    }
}
385
// PANIC SAFETY: unit-test code
#[allow(clippy::panic)]
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn normalized_id() {
        Id::from_normalized_str("foo").expect("should be OK");
        Id::from_normalized_str("foo::bar").expect_err("shouldn't be OK");
        Id::from_normalized_str(r#"foo::"bar""#).expect_err("shouldn't be OK");
        Id::from_normalized_str(" foo").expect_err("shouldn't be OK");
        Id::from_normalized_str("foo ").expect_err("shouldn't be OK");
        Id::from_normalized_str("foo\n").expect_err("shouldn't be OK");
        Id::from_normalized_str("foo//comment").expect_err("shouldn't be OK");
    }

    #[test]
    fn normalized_unreserved_id() {
        UnreservedId::from_normalized_str("foo").expect("should be OK");
        // `__cedar` is a valid `Id`, but it is rejected as an `UnreservedId`.
        Id::from_normalized_str("__cedar").expect("should be OK");
        UnreservedId::from_normalized_str("__cedar").expect_err("shouldn't be OK");
        // Only the exact name `__cedar` is reserved.
        UnreservedId::from_normalized_str("___cedar").expect("should be OK");
        UnreservedId::from_normalized_str(" foo").expect_err("shouldn't be OK");
    }

    #[test]
    fn normalized_any_id() {
        AnyId::from_normalized_str("foo").expect("should be OK");
        // Unlike `Id`, `AnyId` accepts reserved keywords.
        AnyId::from_normalized_str("if").expect("should be OK");
        Id::from_normalized_str("if").expect_err("shouldn't be OK");
        AnyId::from_normalized_str("foo::bar").expect_err("shouldn't be OK");
        AnyId::from_normalized_str("foo ").expect_err("shouldn't be OK");
    }
}