Skip to main content

cedar_policy_core/ast/
id.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use serde::{Deserialize, Deserializer, Serialize};
18use smol_str::SmolStr;
19
20use crate::{
21    ast::is_normalized_ident, default_from_normalized_str, parser::err::ParseErrors,
22    FromNormalizedStr,
23};
24
25use super::{InternalName, ReservedNameError};
26
27const RESERVED_ID: &str = "__cedar";
28
29/// Identifiers. Anything in `Id` should be a valid identifier, this means it
30/// does not contain, for instance, spaces or characters like '+'; and also is
31/// not one of the Cedar reserved identifiers (at time of writing,
32/// `true | false | if | then | else | in | is | like | has`).
33//
34// For now, internally, `Id`s are just owned `SmolString`s.
35#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
36pub struct Id(SmolStr);
37
38impl Id {
39    /// Create a new `Id` from a `String`, where it is the caller's
40    /// responsibility to ensure that the string is indeed a valid identifier.
41    ///
42    /// When possible, callers should not use this, and instead use `s.parse()`,
43    /// which checks that `s` is a valid identifier, and returns a parse error
44    /// if not.
45    ///
46    /// This method was created for the `From<cst::Ident> for Id` impl to use.
47    /// Since `parser::parse_ident()` implicitly uses that `From` impl itself,
48    /// if we tried to make that `From` impl go through `.parse()` like everyone
49    /// else, we'd get infinite recursion.  And, we assert that `cst::Ident` is
50    /// always already checked to contain a valid identifier, otherwise it would
51    /// never have been created.
52    pub(crate) fn new_unchecked(s: impl Into<SmolStr>) -> Id {
53        Id(s.into())
54    }
55
56    /// Get the underlying string
57    pub fn into_smolstr(self) -> SmolStr {
58        self.0
59    }
60
61    /// Return if the `Id` is reserved (i.e., `__cedar`)
62    /// Note that it does not test if the `Id` string is a reserved keyword
63    /// as the parser already ensures that it is not
64    pub fn is_reserved(&self) -> bool {
65        self.as_ref() == RESERVED_ID
66    }
67}
68
69impl AsRef<str> for Id {
70    fn as_ref(&self) -> &str {
71        &self.0
72    }
73}
74
75impl std::fmt::Display for Id {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        write!(f, "{}", &self.0)
78    }
79}
80
81// allow `.parse()` on a string to make an `Id`
82impl std::str::FromStr for Id {
83    type Err = ParseErrors;
84
85    fn from_str(s: &str) -> Result<Self, Self::Err> {
86        crate::parser::parse_ident(s)
87    }
88}
89
90impl FromNormalizedStr for Id {
91    fn describe_self() -> &'static str {
92        "Id"
93    }
94
95    // Specialized implementation of `from_normalized_str()` that
96    // uses the optimized implementation `is_normalized_ident()`
97    fn from_normalized_str(s: &str) -> Result<Self, ParseErrors> {
98        if is_normalized_ident(s) {
99            Ok(Self::new_unchecked(s))
100        } else {
101            // Fall back on the default (unoptimized) implementation
102            // to get a nice error message
103            default_from_normalized_str(s, Self::describe_self)
104        }
105    }
106}
107
108/// An `Id` that is not equal to `__cedar`, as specified by RFC 52
109#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
110#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
111#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
112pub struct UnreservedId(#[cfg_attr(feature = "wasm", tsify(type = "string"))] pub(crate) Id);
113
114impl From<UnreservedId> for Id {
115    fn from(value: UnreservedId) -> Self {
116        value.0
117    }
118}
119
120impl TryFrom<Id> for UnreservedId {
121    type Error = ReservedNameError;
122    fn try_from(value: Id) -> Result<Self, Self::Error> {
123        if value.is_reserved() {
124            Err(ReservedNameError(InternalName::unqualified_name(
125                value, None,
126            )))
127        } else {
128            Ok(Self(value))
129        }
130    }
131}
132
133impl AsRef<Id> for UnreservedId {
134    fn as_ref(&self) -> &Id {
135        &self.0
136    }
137}
138
139impl AsRef<str> for UnreservedId {
140    fn as_ref(&self) -> &str {
141        self.0.as_ref()
142    }
143}
144
145impl std::fmt::Display for UnreservedId {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        self.0.fmt(f)
148    }
149}
150
151impl std::str::FromStr for UnreservedId {
152    type Err = ParseErrors;
153    fn from_str(s: &str) -> Result<Self, Self::Err> {
154        Id::from_str(s).and_then(|id| id.try_into().map_err(ParseErrors::singleton))
155    }
156}
157
158impl FromNormalizedStr for UnreservedId {
159    fn describe_self() -> &'static str {
160        "Unreserved Id"
161    }
162}
163
164impl UnreservedId {
165    /// Create an [`UnreservedId`] from an empty string
166    pub(crate) fn empty() -> Self {
167        #[expect(clippy::unwrap_used, reason = "\"\" does not contain `__cedar`")]
168        Id("".into()).try_into().unwrap()
169    }
170
171    /// Get the underlying string
172    pub fn into_smolstr(self) -> SmolStr {
173        self.0.into_smolstr()
174    }
175}
176
177struct IdVisitor;
178
179impl serde::de::Visitor<'_> for IdVisitor {
180    type Value = Id;
181
182    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
183        formatter.write_str("a valid id")
184    }
185
186    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
187    where
188        E: serde::de::Error,
189    {
190        Id::from_normalized_str(value)
191            .map_err(|err| serde::de::Error::custom(format!("invalid id `{value}`: {err}")))
192    }
193}
194
195/// Deserialize an `Id` using `from_normalized_str`.
196/// This deserialization implementation is used in the JSON schema format.
197impl<'de> Deserialize<'de> for Id {
198    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
199    where
200        D: Deserializer<'de>,
201    {
202        deserializer.deserialize_str(IdVisitor)
203    }
204}
205
206/// Deserialize a [`UnreservedId`] using `from_normalized_str`
207/// This deserialization implementation is used in the JSON schema format.
208impl<'de> Deserialize<'de> for UnreservedId {
209    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
210    where
211        D: Deserializer<'de>,
212    {
213        deserializer
214            .deserialize_str(IdVisitor)
215            .and_then(|n| n.try_into().map_err(serde::de::Error::custom))
216    }
217}
218
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Id {
    /// Generate an identifier of 1 to 17 characters from the identifier
    /// alphabet, appending `_` if the parser rejects the candidate.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // identifier syntax:
        // IDENT     := ['_''a'-'z''A'-'Z']['_''a'-'z''A'-'Z''0'-'9']* - RESERVED
        // BOOL      := 'true' | 'false'
        // RESERVED  := BOOL | 'if' | 'then' | 'else' | 'in' | 'is' | 'like' | 'has'

        // characters allowed in the first position: A-Z, a-z, then `_`
        let head_letters: Vec<char> = ('A'..='Z')
            .chain('a'..='z')
            .chain(std::iter::once('_'))
            .collect();
        // characters allowed after the first position: digits, then the head set
        let tail_letters: Vec<char> = ('0'..='9')
            .chain(head_letters.iter().copied())
            .collect();

        // number of characters following the first one
        let tail_len = u.int_in_range(0..=16)?;
        let mut ident = String::new();
        ident.push(*u.choose(&head_letters)?);
        for _ in 0..tail_len {
            ident.push(*u.choose(&tail_letters)?);
        }

        // If parsing fails, the candidate should be a reserved word;
        // appending `_` turns it into a valid Id.
        if crate::parser::parse_ident(&ident).is_err() {
            ident.push('_');
        }
        Ok(Self::new_unchecked(ident))
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        // bytes consumed choosing the length
        let length_hint = <usize as arbitrary::Arbitrary>::size_hint(depth);
        // bytes consumed by the sequence of character choices; the hint for a
        // `Vec<u8>` serves as an underestimate
        let choices_hint = <Vec<u8> as arbitrary::Arbitrary>::size_hint(depth);
        arbitrary::size_hint::and_all(&[length_hint, choices_hint])
    }
}
260
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for UnreservedId {
    /// Generate an arbitrary `Id` and turn it into an `UnreservedId`.
    ///
    /// If the generated `Id` is the reserved identifier, an extra `_` is
    /// prepended so the result is no longer reserved. The reservedness check
    /// is done up front with `Id::is_reserved`, the same test `TryFrom<Id>`
    /// uses. That avoids cloning the `Id` on every call just to keep it for
    /// the rare fallback branch, and avoids re-parsing and unwrapping there.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let id: Id = u.arbitrary()?;
        if id.is_reserved() {
            // `___cedar` is a valid unreserved id
            Ok(Self(Id::new_unchecked(format!("_{id}"))))
        } else {
            Ok(Self(id))
        }
    }

    /// Same byte-consumption hint as `Id`, since generation delegates to it.
    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        <Id as arbitrary::Arbitrary>::size_hint(depth)
    }
}
279
280/// Like `Id`, except this specifically _can_ contain Cedar reserved identifiers.
281/// (It still can't contain, for instance, spaces or characters like '+'.)
282//
283// For now, internally, `AnyId`s are just owned `SmolString`s.
284#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
285#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
286#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
287pub struct AnyId(SmolStr);
288
289impl AnyId {
290    /// Create a new `AnyId` from a `String`, where it is the caller's
291    /// responsibility to ensure that the string is indeed a valid `AnyId`.
292    ///
293    /// When possible, callers should not use this, and instead use `s.parse()`,
294    /// which checks that `s` is a valid `AnyId`, and returns a parse error
295    /// if not.
296    ///
297    /// This method was created for the `From<cst::Ident> for AnyId` impl to use.
298    /// See notes on `Id::new_unchecked()`.
299    pub(crate) fn new_unchecked(s: impl Into<SmolStr>) -> AnyId {
300        AnyId(s.into())
301    }
302
303    /// Get the underlying string
304    pub fn into_smolstr(self) -> SmolStr {
305        self.0
306    }
307}
308
309struct AnyIdVisitor;
310
311impl serde::de::Visitor<'_> for AnyIdVisitor {
312    type Value = AnyId;
313
314    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315        formatter.write_str("any id")
316    }
317
318    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
319    where
320        E: serde::de::Error,
321    {
322        AnyId::from_normalized_str(value)
323            .map_err(|err| serde::de::Error::custom(format!("invalid id `{value}`: {err}")))
324    }
325}
326
327/// Deserialize an `AnyId` using `from_normalized_str`.
328/// This deserialization implementation is used in the JSON policy format.
329impl<'de> Deserialize<'de> for AnyId {
330    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
331    where
332        D: Deserializer<'de>,
333    {
334        deserializer.deserialize_str(AnyIdVisitor)
335    }
336}
337
338impl AsRef<str> for AnyId {
339    fn as_ref(&self) -> &str {
340        &self.0
341    }
342}
343
344impl std::fmt::Display for AnyId {
345    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346        write!(f, "{}", &self.0)
347    }
348}
349
350// allow `.parse()` on a string to make an `AnyId`
351impl std::str::FromStr for AnyId {
352    type Err = ParseErrors;
353
354    fn from_str(s: &str) -> Result<Self, Self::Err> {
355        crate::parser::parse_anyid(s)
356    }
357}
358
359impl FromNormalizedStr for AnyId {
360    fn describe_self() -> &'static str {
361        "AnyId"
362    }
363}
364
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for AnyId {
    /// Generate an `AnyId` of 1 to 17 characters from the identifier alphabet.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // AnyId syntax:
        // ['_''a'-'z''A'-'Z']['_''a'-'z''A'-'Z''0'-'9']*

        // characters allowed in the first position: A-Z, a-z, then `_`
        let head_letters: Vec<char> = ('A'..='Z')
            .chain('a'..='z')
            .chain(std::iter::once('_'))
            .collect();
        // characters allowed after the first position: digits, then the head set
        let tail_letters: Vec<char> = ('0'..='9')
            .chain(head_letters.iter().copied())
            .collect();

        // number of characters following the first one
        let tail_len = u.int_in_range(0..=16)?;
        let mut text = String::new();
        text.push(*u.choose(&head_letters)?);
        for _ in 0..tail_len {
            text.push(*u.choose(&tail_letters)?);
        }

        debug_assert!(
            crate::parser::parse_anyid(&text).is_ok(),
            "all strings constructed this way should be valid AnyIds, but this one is not: {text:?}"
        );
        Ok(Self::new_unchecked(text))
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        // bytes consumed choosing the length
        let length_hint = <usize as arbitrary::Arbitrary>::size_hint(depth);
        // bytes consumed by the sequence of character choices; the hint for a
        // `Vec<u8>` serves as an underestimate
        let choices_hint = <Vec<u8> as arbitrary::Arbitrary>::size_hint(depth);
        arbitrary::size_hint::and_all(&[length_hint, choices_hint])
    }
}
403
#[cfg(test)]
#[expect(clippy::panic, reason = "unit-test code")]
mod test {
    use super::*;

    /// `Id::from_normalized_str` accepts a bare identifier and rejects
    /// qualified names, surrounding whitespace, and trailing comments.
    #[test]
    fn normalized_id() {
        Id::from_normalized_str("foo").expect("should be OK");

        let rejected = [
            "foo::bar",
            r#"foo::"bar""#,
            " foo",
            "foo ",
            "foo\n",
            "foo//comment",
        ];
        for input in rejected {
            Id::from_normalized_str(input).expect_err("shouldn't be OK");
        }
    }
}