cedar_policy_core/ast/
id.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use serde::{Deserialize, Deserializer, Serialize};
18use smol_str::SmolStr;
19
20use crate::{parser::err::ParseErrors, FromNormalizedStr};
21
22use super::{InternalName, ReservedNameError};
23
24const RESERVED_ID: &str = "__cedar";
25
26/// Identifiers. Anything in `Id` should be a valid identifier, this means it
27/// does not contain, for instance, spaces or characters like '+'; and also is
28/// not one of the Cedar reserved identifiers (at time of writing,
29/// `true | false | if | then | else | in | is | like | has`).
30//
31// For now, internally, `Id`s are just owned `SmolString`s.
32#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
33pub struct Id(SmolStr);
34
35impl Id {
36    /// Create a new `Id` from a `String`, where it is the caller's
37    /// responsibility to ensure that the string is indeed a valid identifier.
38    ///
39    /// When possible, callers should not use this, and instead use `s.parse()`,
40    /// which checks that `s` is a valid identifier, and returns a parse error
41    /// if not.
42    ///
43    /// This method was created for the `From<cst::Ident> for Id` impl to use.
44    /// Since `parser::parse_ident()` implicitly uses that `From` impl itself,
45    /// if we tried to make that `From` impl go through `.parse()` like everyone
46    /// else, we'd get infinite recursion.  And, we assert that `cst::Ident` is
47    /// always already checked to contain a valid identifier, otherwise it would
48    /// never have been created.
49    pub(crate) fn new_unchecked(s: impl Into<SmolStr>) -> Id {
50        Id(s.into())
51    }
52
53    /// Get the underlying string
54    pub fn into_smolstr(self) -> SmolStr {
55        self.0
56    }
57
58    /// Return if the `Id` is reserved (i.e., `__cedar`)
59    /// Note that it does not test if the `Id` string is a reserved keyword
60    /// as the parser already ensures that it is not
61    pub fn is_reserved(&self) -> bool {
62        self.as_ref() == RESERVED_ID
63    }
64}
65
66impl AsRef<str> for Id {
67    fn as_ref(&self) -> &str {
68        &self.0
69    }
70}
71
72impl std::fmt::Display for Id {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        write!(f, "{}", &self.0)
75    }
76}
77
78// allow `.parse()` on a string to make an `Id`
79impl std::str::FromStr for Id {
80    type Err = ParseErrors;
81
82    fn from_str(s: &str) -> Result<Self, Self::Err> {
83        crate::parser::parse_ident(s)
84    }
85}
86
87impl FromNormalizedStr for Id {
88    fn describe_self() -> &'static str {
89        "Id"
90    }
91}
92
93/// An `Id` that is not equal to `__cedar`, as specified by RFC 52
94#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
95#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
96#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
97pub struct UnreservedId(#[cfg_attr(feature = "wasm", tsify(type = "string"))] pub(crate) Id);
98
99impl From<UnreservedId> for Id {
100    fn from(value: UnreservedId) -> Self {
101        value.0
102    }
103}
104
105impl TryFrom<Id> for UnreservedId {
106    type Error = ReservedNameError;
107    fn try_from(value: Id) -> Result<Self, Self::Error> {
108        if value.is_reserved() {
109            Err(ReservedNameError(InternalName::unqualified_name(
110                value, None,
111            )))
112        } else {
113            Ok(Self(value))
114        }
115    }
116}
117
118impl AsRef<Id> for UnreservedId {
119    fn as_ref(&self) -> &Id {
120        &self.0
121    }
122}
123
124impl AsRef<str> for UnreservedId {
125    fn as_ref(&self) -> &str {
126        self.0.as_ref()
127    }
128}
129
130impl std::fmt::Display for UnreservedId {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        self.0.fmt(f)
133    }
134}
135
136impl std::str::FromStr for UnreservedId {
137    type Err = ParseErrors;
138    fn from_str(s: &str) -> Result<Self, Self::Err> {
139        Id::from_str(s).and_then(|id| id.try_into().map_err(ParseErrors::singleton))
140    }
141}
142
143impl FromNormalizedStr for UnreservedId {
144    fn describe_self() -> &'static str {
145        "Unreserved Id"
146    }
147}
148
149impl UnreservedId {
150    /// Create an [`UnreservedId`] from an empty string
151    pub(crate) fn empty() -> Self {
152        // PANIC SAFETY: "" does not contain `__cedar`
153        #[allow(clippy::unwrap_used)]
154        Id("".into()).try_into().unwrap()
155    }
156
157    /// Get the underlying string
158    pub fn into_smolstr(self) -> SmolStr {
159        self.0.into_smolstr()
160    }
161}
162
163struct IdVisitor;
164
165impl serde::de::Visitor<'_> for IdVisitor {
166    type Value = Id;
167
168    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        formatter.write_str("a valid id")
170    }
171
172    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
173    where
174        E: serde::de::Error,
175    {
176        Id::from_normalized_str(value)
177            .map_err(|err| serde::de::Error::custom(format!("invalid id `{value}`: {err}")))
178    }
179}
180
181/// Deserialize an `Id` using `from_normalized_str`.
182/// This deserialization implementation is used in the JSON schema format.
183impl<'de> Deserialize<'de> for Id {
184    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
185    where
186        D: Deserializer<'de>,
187    {
188        deserializer.deserialize_str(IdVisitor)
189    }
190}
191
192/// Deserialize a [`UnreservedId`] using `from_normalized_str`
193/// This deserialization implementation is used in the JSON schema format.
194impl<'de> Deserialize<'de> for UnreservedId {
195    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
196    where
197        D: Deserializer<'de>,
198    {
199        deserializer
200            .deserialize_str(IdVisitor)
201            .and_then(|n| n.try_into().map_err(serde::de::Error::custom))
202    }
203}
204
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Id {
    /// Generate an `Id` of 1 to 17 characters drawn from the identifier
    /// alphabet, patching the result if it is not accepted by the parser.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // identifier syntax:
        // IDENT     := ['_''a'-'z''A'-'Z']['_''a'-'z''A'-'Z''0'-'9']* - RESERVED
        // BOOL      := 'true' | 'false'
        // RESERVED  := BOOL | 'if' | 'then' | 'else' | 'in' | 'is' | 'like' | 'has'

        // characters allowed in first position: upper case, lower case, then `_`
        let head_letters: Vec<char> = ('A'..='Z')
            .chain('a'..='z')
            .chain(std::iter::once('_'))
            .collect();
        // characters allowed afterwards: the digits, followed by the head set
        let tail_letters: Vec<char> = ('0'..='9')
            .chain(head_letters.iter().copied())
            .collect();

        // number of characters following the first one
        let remaining_length = u.int_in_range(0..=16)?;
        let mut s = String::new();
        s.push(*u.choose(&head_letters)?);
        for _ in 0..remaining_length {
            s.push(*u.choose(&tail_letters)?);
        }

        // If parsing fails, the string should be a reserved word.
        // Appending a `_` turns it into a valid `Id`.
        if crate::parser::parse_ident(&s).is_err() {
            s.push('_');
        }
        Ok(Self::new_unchecked(s))
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        // for the arbitrary length
        let length_hint = <usize as arbitrary::Arbitrary>::size_hint(depth);
        // for the arbitrary choices: the size hint of a vector of `u8` serves as an
        // underestimate of the bytes required by the sequence of choices.
        let choices_hint = <Vec<u8> as arbitrary::Arbitrary>::size_hint(depth);
        arbitrary::size_hint::and_all(&[length_hint, choices_hint])
    }
}
246
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for UnreservedId {
    /// Generate an arbitrary `Id` and use it as is unless it is the reserved
    /// identifier, in which case a `_` is prepended to make it unreserved.
    ///
    /// The reservation check is done with `Id::is_reserved` so the generated
    /// `Id` is moved into the result rather than cloned, and no error value
    /// is constructed just to be thrown away.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let id: Id = u.arbitrary()?;
        if !id.is_reserved() {
            return Ok(Self(id));
        }
        // PANIC SAFETY: `___cedar` is a valid unreserved id
        #[allow(clippy::unwrap_used)]
        let new_id = format!("_{id}").parse().unwrap();
        Ok(new_id)
    }

    /// Same hint as for `Id`, since generation starts from an arbitrary `Id`.
    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        <Id as arbitrary::Arbitrary>::size_hint(depth)
    }
}
266
267/// Like `Id`, except this specifically _can_ contain Cedar reserved identifiers.
268/// (It still can't contain, for instance, spaces or characters like '+'.)
269//
270// For now, internally, `AnyId`s are just owned `SmolString`s.
271#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
272#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
273#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
274pub struct AnyId(SmolStr);
275
276impl AnyId {
277    /// Create a new `AnyId` from a `String`, where it is the caller's
278    /// responsibility to ensure that the string is indeed a valid `AnyId`.
279    ///
280    /// When possible, callers should not use this, and instead use `s.parse()`,
281    /// which checks that `s` is a valid `AnyId`, and returns a parse error
282    /// if not.
283    ///
284    /// This method was created for the `From<cst::Ident> for AnyId` impl to use.
285    /// See notes on `Id::new_unchecked()`.
286    pub(crate) fn new_unchecked(s: impl Into<SmolStr>) -> AnyId {
287        AnyId(s.into())
288    }
289
290    /// Get the underlying string
291    pub fn into_smolstr(self) -> SmolStr {
292        self.0
293    }
294}
295
296struct AnyIdVisitor;
297
298impl serde::de::Visitor<'_> for AnyIdVisitor {
299    type Value = AnyId;
300
301    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302        formatter.write_str("any id")
303    }
304
305    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
306    where
307        E: serde::de::Error,
308    {
309        AnyId::from_normalized_str(value)
310            .map_err(|err| serde::de::Error::custom(format!("invalid id `{value}`: {err}")))
311    }
312}
313
314/// Deserialize an `AnyId` using `from_normalized_str`.
315/// This deserialization implementation is used in the JSON policy format.
316impl<'de> Deserialize<'de> for AnyId {
317    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
318    where
319        D: Deserializer<'de>,
320    {
321        deserializer.deserialize_str(AnyIdVisitor)
322    }
323}
324
325impl AsRef<str> for AnyId {
326    fn as_ref(&self) -> &str {
327        &self.0
328    }
329}
330
331impl std::fmt::Display for AnyId {
332    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
333        write!(f, "{}", &self.0)
334    }
335}
336
337// allow `.parse()` on a string to make an `AnyId`
338impl std::str::FromStr for AnyId {
339    type Err = ParseErrors;
340
341    fn from_str(s: &str) -> Result<Self, Self::Err> {
342        crate::parser::parse_anyid(s)
343    }
344}
345
346impl FromNormalizedStr for AnyId {
347    fn describe_self() -> &'static str {
348        "AnyId"
349    }
350}
351
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for AnyId {
    /// Generate an `AnyId` of 1 to 17 characters drawn from the identifier
    /// alphabet. Reserved words are acceptable here, so no patching is done.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // AnyId syntax:
        // ['_''a'-'z''A'-'Z']['_''a'-'z''A'-'Z''0'-'9']*

        // characters allowed in first position: upper case, lower case, then `_`
        let head_letters: Vec<char> = ('A'..='Z')
            .chain('a'..='z')
            .chain(std::iter::once('_'))
            .collect();
        // characters allowed afterwards: the digits, followed by the head set
        let tail_letters: Vec<char> = ('0'..='9')
            .chain(head_letters.iter().copied())
            .collect();

        // number of characters following the first one
        let remaining_length = u.int_in_range(0..=16)?;
        let mut s = String::new();
        s.push(*u.choose(&head_letters)?);
        for _ in 0..remaining_length {
            s.push(*u.choose(&tail_letters)?);
        }

        debug_assert!(
            crate::parser::parse_anyid(&s).is_ok(),
            "all strings constructed this way should be valid AnyIds, but this one is not: {s:?}"
        );
        Ok(Self::new_unchecked(s))
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        // for the arbitrary length
        let length_hint = <usize as arbitrary::Arbitrary>::size_hint(depth);
        // for the arbitrary choices: the size hint of a vector of `u8` serves as an
        // underestimate of the bytes required by the sequence of choices.
        let choices_hint = <Vec<u8> as arbitrary::Arbitrary>::size_hint(depth);
        arbitrary::size_hint::and_all(&[length_hint, choices_hint])
    }
}
390
// PANIC SAFETY: unit-test code
#[allow(clippy::panic)]
#[cfg(test)]
mod test {
    use super::*;

    /// `from_normalized_str` accepts a bare identifier and rejects paths,
    /// surrounding whitespace, trailing newlines and trailing comments.
    #[test]
    fn normalized_id() {
        Id::from_normalized_str("foo").expect("should be OK");

        let rejected = [
            "foo::bar",
            r#"foo::"bar""#,
            " foo",
            "foo ",
            "foo\n",
            "foo//comment",
        ];
        for input in rejected {
            Id::from_normalized_str(input).expect_err("shouldn't be OK");
        }
    }
}
407}