cedar_policy_core/ast/
id.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use serde::{Deserialize, Deserializer, Serialize};
18use smol_str::SmolStr;
19
20use crate::{
21    ast::is_normalized_ident, default_from_normalized_str, parser::err::ParseErrors,
22    FromNormalizedStr,
23};
24
25use super::{InternalName, ReservedNameError};
26
27const RESERVED_ID: &str = "__cedar";
28
29/// Identifiers. Anything in `Id` should be a valid identifier, this means it
30/// does not contain, for instance, spaces or characters like '+'; and also is
31/// not one of the Cedar reserved identifiers (at time of writing,
32/// `true | false | if | then | else | in | is | like | has`).
33//
34// For now, internally, `Id`s are just owned `SmolString`s.
35#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
36pub struct Id(SmolStr);
37
38impl Id {
39    /// Create a new `Id` from a `String`, where it is the caller's
40    /// responsibility to ensure that the string is indeed a valid identifier.
41    ///
42    /// When possible, callers should not use this, and instead use `s.parse()`,
43    /// which checks that `s` is a valid identifier, and returns a parse error
44    /// if not.
45    ///
46    /// This method was created for the `From<cst::Ident> for Id` impl to use.
47    /// Since `parser::parse_ident()` implicitly uses that `From` impl itself,
48    /// if we tried to make that `From` impl go through `.parse()` like everyone
49    /// else, we'd get infinite recursion.  And, we assert that `cst::Ident` is
50    /// always already checked to contain a valid identifier, otherwise it would
51    /// never have been created.
52    pub(crate) fn new_unchecked(s: impl Into<SmolStr>) -> Id {
53        Id(s.into())
54    }
55
56    /// Get the underlying string
57    pub fn into_smolstr(self) -> SmolStr {
58        self.0
59    }
60
61    /// Return if the `Id` is reserved (i.e., `__cedar`)
62    /// Note that it does not test if the `Id` string is a reserved keyword
63    /// as the parser already ensures that it is not
64    pub fn is_reserved(&self) -> bool {
65        self.as_ref() == RESERVED_ID
66    }
67}
68
69impl AsRef<str> for Id {
70    fn as_ref(&self) -> &str {
71        &self.0
72    }
73}
74
75impl std::fmt::Display for Id {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        write!(f, "{}", &self.0)
78    }
79}
80
81// allow `.parse()` on a string to make an `Id`
82impl std::str::FromStr for Id {
83    type Err = ParseErrors;
84
85    fn from_str(s: &str) -> Result<Self, Self::Err> {
86        crate::parser::parse_ident(s)
87    }
88}
89
90impl FromNormalizedStr for Id {
91    fn describe_self() -> &'static str {
92        "Id"
93    }
94
95    // Specialized implementation of `from_normalized_str()` that
96    // uses the optimized implementation `is_normalized_ident()`
97    fn from_normalized_str(s: &str) -> Result<Self, ParseErrors> {
98        if is_normalized_ident(s) {
99            Ok(Self::new_unchecked(s))
100        } else {
101            // Fall back on the default (unoptimized) implementation
102            // to get a nice error message
103            default_from_normalized_str(s, Self::describe_self)
104        }
105    }
106}
107
108/// An `Id` that is not equal to `__cedar`, as specified by RFC 52
109#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
110#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
111#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
112pub struct UnreservedId(#[cfg_attr(feature = "wasm", tsify(type = "string"))] pub(crate) Id);
113
114impl From<UnreservedId> for Id {
115    fn from(value: UnreservedId) -> Self {
116        value.0
117    }
118}
119
120impl TryFrom<Id> for UnreservedId {
121    type Error = ReservedNameError;
122    fn try_from(value: Id) -> Result<Self, Self::Error> {
123        if value.is_reserved() {
124            Err(ReservedNameError(InternalName::unqualified_name(
125                value, None,
126            )))
127        } else {
128            Ok(Self(value))
129        }
130    }
131}
132
133impl AsRef<Id> for UnreservedId {
134    fn as_ref(&self) -> &Id {
135        &self.0
136    }
137}
138
139impl AsRef<str> for UnreservedId {
140    fn as_ref(&self) -> &str {
141        self.0.as_ref()
142    }
143}
144
145impl std::fmt::Display for UnreservedId {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        self.0.fmt(f)
148    }
149}
150
151impl std::str::FromStr for UnreservedId {
152    type Err = ParseErrors;
153    fn from_str(s: &str) -> Result<Self, Self::Err> {
154        Id::from_str(s).and_then(|id| id.try_into().map_err(ParseErrors::singleton))
155    }
156}
157
158impl FromNormalizedStr for UnreservedId {
159    fn describe_self() -> &'static str {
160        "Unreserved Id"
161    }
162}
163
164impl UnreservedId {
165    /// Create an [`UnreservedId`] from an empty string
166    pub(crate) fn empty() -> Self {
167        // PANIC SAFETY: "" does not contain `__cedar`
168        #[allow(clippy::unwrap_used)]
169        Id("".into()).try_into().unwrap()
170    }
171
172    /// Get the underlying string
173    pub fn into_smolstr(self) -> SmolStr {
174        self.0.into_smolstr()
175    }
176}
177
178struct IdVisitor;
179
180impl serde::de::Visitor<'_> for IdVisitor {
181    type Value = Id;
182
183    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184        formatter.write_str("a valid id")
185    }
186
187    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
188    where
189        E: serde::de::Error,
190    {
191        Id::from_normalized_str(value)
192            .map_err(|err| serde::de::Error::custom(format!("invalid id `{value}`: {err}")))
193    }
194}
195
196/// Deserialize an `Id` using `from_normalized_str`.
197/// This deserialization implementation is used in the JSON schema format.
198impl<'de> Deserialize<'de> for Id {
199    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
200    where
201        D: Deserializer<'de>,
202    {
203        deserializer.deserialize_str(IdVisitor)
204    }
205}
206
207/// Deserialize a [`UnreservedId`] using `from_normalized_str`
208/// This deserialization implementation is used in the JSON schema format.
209impl<'de> Deserialize<'de> for UnreservedId {
210    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
211    where
212        D: Deserializer<'de>,
213    {
214        deserializer
215            .deserialize_str(IdVisitor)
216            .and_then(|n| n.try_into().map_err(serde::de::Error::custom))
217    }
218}
219
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Id {
    /// Generates a random valid identifier of 1 to 18 characters.
    ///
    /// The order in which bytes are drawn from `u` is: the tail length,
    /// then the first character, then each remaining character.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // identifier syntax:
        // IDENT     := ['_''a'-'z''A'-'Z']['_''a'-'z''A'-'Z''0'-'9']* - RESERVED
        // BOOL      := 'true' | 'false'
        // RESERVED  := BOOL | 'if' | 'then' | 'else' | 'in' | 'is' | 'like' | 'has'

        // the set of possible first characters of an identifier
        let head_letters: Vec<char> =
            "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_".chars().collect();
        // the set of possible remaining characters: digits, then the head set
        // (this order matters, since `choose` maps bytes to indices)
        let tail_letters: Vec<char> = "0123456789"
            .chars()
            .chain(head_letters.iter().copied())
            .collect();
        // identifier character count minus 1
        let remaining_length = u.int_in_range(0..=16)?;
        // Push characters straight into the output string instead of
        // collecting them into intermediate vectors first.
        let mut s = String::new();
        s.push(*u.choose(&head_letters)?);
        for _ in 0..remaining_length {
            s.push(*u.choose(&tail_letters)?);
        }
        // If parsing fails, the string must be a reserved word.
        // Appending a `_` turns it into a valid Id.
        if crate::parser::parse_ident(&s).is_err() {
            s.push('_');
        }
        Ok(Self::new_unchecked(s))
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        arbitrary::size_hint::and_all(&[
            // for arbitrary length
            <usize as arbitrary::Arbitrary>::size_hint(depth),
            // for arbitrary choices
            // we use the size hint of a vector of `u8` to get an underestimate of bytes required by the sequence of choices.
            <Vec<u8> as arbitrary::Arbitrary>::size_hint(depth),
        ])
    }
}
261
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for UnreservedId {
    /// Generates an arbitrary `Id`. When that `Id` happens to be the reserved
    /// `__cedar`, a `_` is prepended to produce `___cedar`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let id: Id = u.arbitrary()?;
        // Check reservation up front so `id` need not be cloned just to
        // survive a consuming `try_from` call.
        if id.is_reserved() {
            // `___cedar` is a valid identifier and is not the reserved id, so it
            // can be wrapped directly with no fallible parse/unwrap.
            Ok(Self(Id::new_unchecked(format!("_{id}"))))
        } else {
            Ok(Self(id))
        }
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        <Id as arbitrary::Arbitrary>::size_hint(depth)
    }
}
281
282/// Like `Id`, except this specifically _can_ contain Cedar reserved identifiers.
283/// (It still can't contain, for instance, spaces or characters like '+'.)
284//
285// For now, internally, `AnyId`s are just owned `SmolString`s.
286#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
287#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
288#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
289pub struct AnyId(SmolStr);
290
291impl AnyId {
292    /// Create a new `AnyId` from a `String`, where it is the caller's
293    /// responsibility to ensure that the string is indeed a valid `AnyId`.
294    ///
295    /// When possible, callers should not use this, and instead use `s.parse()`,
296    /// which checks that `s` is a valid `AnyId`, and returns a parse error
297    /// if not.
298    ///
299    /// This method was created for the `From<cst::Ident> for AnyId` impl to use.
300    /// See notes on `Id::new_unchecked()`.
301    pub(crate) fn new_unchecked(s: impl Into<SmolStr>) -> AnyId {
302        AnyId(s.into())
303    }
304
305    /// Get the underlying string
306    pub fn into_smolstr(self) -> SmolStr {
307        self.0
308    }
309}
310
311struct AnyIdVisitor;
312
313impl serde::de::Visitor<'_> for AnyIdVisitor {
314    type Value = AnyId;
315
316    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
317        formatter.write_str("any id")
318    }
319
320    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
321    where
322        E: serde::de::Error,
323    {
324        AnyId::from_normalized_str(value)
325            .map_err(|err| serde::de::Error::custom(format!("invalid id `{value}`: {err}")))
326    }
327}
328
329/// Deserialize an `AnyId` using `from_normalized_str`.
330/// This deserialization implementation is used in the JSON policy format.
331impl<'de> Deserialize<'de> for AnyId {
332    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
333    where
334        D: Deserializer<'de>,
335    {
336        deserializer.deserialize_str(AnyIdVisitor)
337    }
338}
339
340impl AsRef<str> for AnyId {
341    fn as_ref(&self) -> &str {
342        &self.0
343    }
344}
345
346impl std::fmt::Display for AnyId {
347    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
348        write!(f, "{}", &self.0)
349    }
350}
351
352// allow `.parse()` on a string to make an `AnyId`
353impl std::str::FromStr for AnyId {
354    type Err = ParseErrors;
355
356    fn from_str(s: &str) -> Result<Self, Self::Err> {
357        crate::parser::parse_anyid(s)
358    }
359}
360
361impl FromNormalizedStr for AnyId {
362    fn describe_self() -> &'static str {
363        "AnyId"
364    }
365}
366
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for AnyId {
    /// Generates a random `AnyId` of 1 to 18 characters.
    ///
    /// The order in which bytes are drawn from `u` is: the tail length,
    /// then the first character, then each remaining character.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // AnyId syntax (reserved words are allowed):
        // ['_''a'-'z''A'-'Z']['_''a'-'z''A'-'Z''0'-'9']*

        // characters permitted in the first position
        let head_letters: Vec<char> =
            "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_".chars().collect();
        // characters permitted afterwards: digits first, then the head set
        let tail_letters: Vec<char> = "0123456789"
            .chars()
            .chain(head_letters.iter().copied())
            .collect();
        // number of characters following the first one
        let remaining_length = u.int_in_range(0..=16)?;
        let mut s = String::new();
        s.push(*u.choose(&head_letters)?);
        for _ in 0..remaining_length {
            s.push(*u.choose(&tail_letters)?);
        }
        debug_assert!(
            crate::parser::parse_anyid(&s).is_ok(),
            "all strings constructed this way should be valid AnyIds, but this one is not: {s:?}"
        );
        Ok(Self::new_unchecked(s))
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        let length_hint = <usize as arbitrary::Arbitrary>::size_hint(depth);
        // A `Vec<u8>` size hint is used as an underestimate of the bytes
        // consumed by the sequence of character choices.
        let choices_hint = <Vec<u8> as arbitrary::Arbitrary>::size_hint(depth);
        arbitrary::size_hint::and_all(&[length_hint, choices_hint])
    }
}
405
// PANIC SAFETY: unit-test code
#[allow(clippy::panic)]
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn normalized_id() {
        Id::from_normalized_str("foo").expect("should be OK");

        // Inputs that are not a single normalized identifier.
        let rejected = [
            "foo::bar",
            r#"foo::"bar""#,
            " foo",
            "foo ",
            "foo\n",
            "foo//comment",
        ];
        for input in rejected {
            Id::from_normalized_str(input).expect_err("shouldn't be OK");
        }
    }
}