Skip to main content

cedar_policy_core/ast/
id.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use serde::{Deserialize, Deserializer, Serialize};
18use smol_str::SmolStr;
19
20use crate::{
21    ast::is_normalized_ident, default_from_normalized_str, parser::err::ParseErrors,
22    FromNormalizedStr,
23};
24
25use super::{InternalName, ReservedNameError};
26
/// The identifier string that `Id::is_reserved()` compares against.
const RESERVED_ID: &str = "__cedar";
28
29/// Identifiers. Anything in `Id` should be a valid identifier, this means it
30/// does not contain, for instance, spaces or characters like '+'; and also is
31/// not one of the Cedar reserved identifiers (at time of writing,
32/// `true | false | if | then | else | in | is | like | has`).
33//
34// For now, internally, `Id`s are just owned `SmolString`s.
35#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
36pub struct Id(SmolStr);
37
38impl Id {
39    /// Create a new `Id` from a `String`, where it is the caller's
40    /// responsibility to ensure that the string is indeed a valid identifier.
41    ///
42    /// When possible, callers should not use this, and instead use `s.parse()`,
43    /// which checks that `s` is a valid identifier, and returns a parse error
44    /// if not.
45    ///
46    /// This method was created for the `From<cst::Ident> for Id` impl to use.
47    /// Since `parser::parse_ident()` implicitly uses that `From` impl itself,
48    /// if we tried to make that `From` impl go through `.parse()` like everyone
49    /// else, we'd get infinite recursion.  And, we assert that `cst::Ident` is
50    /// always already checked to contain a valid identifier, otherwise it would
51    /// never have been created.
52    pub(crate) fn new_unchecked(s: impl Into<SmolStr>) -> Id {
53        Id(s.into())
54    }
55
56    /// Create a new `Id` from a static string
57    ///
58    /// This never allocates.
59    pub(crate) const fn new_unchecked_const(s: &'static str) -> Self {
60        Id(SmolStr::new_static(s))
61    }
62
63    /// Get the underlying string
64    pub fn into_smolstr(self) -> SmolStr {
65        self.0
66    }
67
68    /// Return if the `Id` is reserved (i.e., `__cedar`)
69    /// Note that it does not test if the `Id` string is a reserved keyword
70    /// as the parser already ensures that it is not
71    pub fn is_reserved(&self) -> bool {
72        self.as_ref() == RESERVED_ID
73    }
74
75    /// Return if the `Id` is heap-allocated
76    pub const fn is_heap_allocated(&self) -> bool {
77        self.0.is_heap_allocated()
78    }
79}
80
81impl AsRef<str> for Id {
82    fn as_ref(&self) -> &str {
83        &self.0
84    }
85}
86
87impl std::fmt::Display for Id {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        write!(f, "{}", &self.0)
90    }
91}
92
93// allow `.parse()` on a string to make an `Id`
94impl std::str::FromStr for Id {
95    type Err = ParseErrors;
96
97    fn from_str(s: &str) -> Result<Self, Self::Err> {
98        crate::parser::parse_ident(s)
99    }
100}
101
102impl FromNormalizedStr for Id {
103    fn describe_self() -> &'static str {
104        "Id"
105    }
106
107    // Specialized implementation of `from_normalized_str()` that
108    // uses the optimized implementation `is_normalized_ident()`
109    fn from_normalized_str(s: &str) -> Result<Self, ParseErrors> {
110        if is_normalized_ident(s) {
111            Ok(Self::new_unchecked(s))
112        } else {
113            // Fall back on the default (unoptimized) implementation
114            // to get a nice error message
115            default_from_normalized_str(s, Self::describe_self)
116        }
117    }
118}
119
120/// An `Id` that is not equal to `__cedar`, as specified by RFC 52
121#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
122#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
123#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
124pub struct UnreservedId(#[cfg_attr(feature = "wasm", tsify(type = "string"))] pub(crate) Id);
125
126impl From<UnreservedId> for Id {
127    fn from(value: UnreservedId) -> Self {
128        value.0
129    }
130}
131
132impl TryFrom<Id> for UnreservedId {
133    type Error = ReservedNameError;
134    fn try_from(value: Id) -> Result<Self, Self::Error> {
135        if value.is_reserved() {
136            Err(ReservedNameError(InternalName::unqualified_name(
137                value, None,
138            )))
139        } else {
140            Ok(Self(value))
141        }
142    }
143}
144
145impl AsRef<Id> for UnreservedId {
146    fn as_ref(&self) -> &Id {
147        &self.0
148    }
149}
150
151impl AsRef<str> for UnreservedId {
152    fn as_ref(&self) -> &str {
153        self.0.as_ref()
154    }
155}
156
157impl std::fmt::Display for UnreservedId {
158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159        self.0.fmt(f)
160    }
161}
162
163impl std::str::FromStr for UnreservedId {
164    type Err = ParseErrors;
165    fn from_str(s: &str) -> Result<Self, Self::Err> {
166        Id::from_str(s).and_then(|id| id.try_into().map_err(ParseErrors::singleton))
167    }
168}
169
// Only `describe_self()` is provided here; unlike the `Id` impl, this impl
// defines no specialized `from_normalized_str()`.
impl FromNormalizedStr for UnreservedId {
    /// Human-readable name of this type.
    fn describe_self() -> &'static str {
        "Unreserved Id"
    }
}
175
176impl UnreservedId {
177    /// Create an [`UnreservedId`] from an empty string
178    pub(crate) fn empty() -> Self {
179        #[expect(clippy::unwrap_used, reason = "\"\" does not contain `__cedar`")]
180        Id("".into()).try_into().unwrap()
181    }
182
183    /// Get the underlying string
184    pub fn into_smolstr(self) -> SmolStr {
185        self.0.into_smolstr()
186    }
187}
188
189struct IdVisitor;
190
191impl serde::de::Visitor<'_> for IdVisitor {
192    type Value = Id;
193
194    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195        formatter.write_str("a valid id")
196    }
197
198    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
199    where
200        E: serde::de::Error,
201    {
202        Id::from_normalized_str(value)
203            .map_err(|err| serde::de::Error::custom(format!("invalid id `{value}`: {err}")))
204    }
205}
206
207/// Deserialize an `Id` using `from_normalized_str`.
208/// This deserialization implementation is used in the JSON schema format.
209impl<'de> Deserialize<'de> for Id {
210    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
211    where
212        D: Deserializer<'de>,
213    {
214        deserializer.deserialize_str(IdVisitor)
215    }
216}
217
218/// Deserialize a [`UnreservedId`] using `from_normalized_str`
219/// This deserialization implementation is used in the JSON schema format.
220impl<'de> Deserialize<'de> for UnreservedId {
221    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
222    where
223        D: Deserializer<'de>,
224    {
225        deserializer
226            .deserialize_str(IdVisitor)
227            .and_then(|n| n.try_into().map_err(serde::de::Error::custom))
228    }
229}
230
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Id {
    /// Generate an identifier of 1 to 17 characters drawn from the identifier
    /// alphabet, repaired with a trailing `_` if it fails to parse.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // identifier syntax:
        // IDENT     := ['_''a'-'z''A'-'Z']['_''a'-'z''A'-'Z''0'-'9']* - RESERVED
        // BOOL      := 'true' | 'false'
        // RESERVED  := BOOL | 'if' | 'then' | 'else' | 'in' | 'is' | 'like' | 'has'

        // characters allowed in the first position of an identifier
        let head_letters: Vec<char> = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_"
            .chars()
            .collect();
        // characters allowed in every later position: digits, then the head set
        let tail_letters: Vec<char> = "0123456789"
            .chars()
            .chain(head_letters.iter().copied())
            .collect();
        // identifier character count minus 1
        let remaining_length = u.int_in_range(0..=16)?;
        let mut s = String::from(*u.choose(&head_letters)?);
        for _ in 0..remaining_length {
            s.push(*u.choose(&tail_letters)?);
        }
        // If parsing fails, the string is expected to be a reserved word;
        // appending a `_` turns it into a valid Id.
        if crate::parser::parse_ident(&s).is_err() {
            s.push('_');
        }
        Ok(Self::new_unchecked(s))
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        // for arbitrary length
        let length_hint = <usize as arbitrary::Arbitrary>::size_hint(depth);
        // for arbitrary choices: the size hint of a vector of `u8` gives an
        // underestimate of bytes required by the sequence of choices.
        let choices_hint = <Vec<u8> as arbitrary::Arbitrary>::size_hint(depth);
        arbitrary::size_hint::and_all(&[length_hint, choices_hint])
    }
}
272
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for UnreservedId {
    /// Generate an arbitrary [`Id`]; if it is reserved, prefix it with `_`
    /// so that it no longer is.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let id: Id = u.arbitrary()?;
        if !id.is_reserved() {
            return Ok(Self(id));
        }
        #[expect(clippy::unwrap_used, reason = "`___cedar` is a valid unreserved id")]
        let new_id = format!("_{id}").parse().unwrap();
        Ok(new_id)
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        // Same input requirements as the `Id` generator this builds on.
        <Id as arbitrary::Arbitrary>::size_hint(depth)
    }
}
291
292/// Like `Id`, except this specifically _can_ contain Cedar reserved identifiers.
293/// (It still can't contain, for instance, spaces or characters like '+'.)
294//
295// For now, internally, `AnyId`s are just owned `SmolString`s.
296#[derive(Serialize, Debug, PartialEq, Eq, Clone, Hash, PartialOrd, Ord)]
297#[cfg_attr(feature = "wasm", derive(tsify::Tsify))]
298#[cfg_attr(feature = "wasm", tsify(into_wasm_abi, from_wasm_abi))]
299pub struct AnyId(SmolStr);
300
301impl AnyId {
302    /// Create a new `AnyId` from a `String`, where it is the caller's
303    /// responsibility to ensure that the string is indeed a valid `AnyId`.
304    ///
305    /// When possible, callers should not use this, and instead use `s.parse()`,
306    /// which checks that `s` is a valid `AnyId`, and returns a parse error
307    /// if not.
308    ///
309    /// This method was created for the `From<cst::Ident> for AnyId` impl to use.
310    /// See notes on `Id::new_unchecked()`.
311    pub(crate) fn new_unchecked(s: impl Into<SmolStr>) -> AnyId {
312        AnyId(s.into())
313    }
314
315    /// Get the underlying string
316    pub fn into_smolstr(self) -> SmolStr {
317        self.0
318    }
319}
320
321struct AnyIdVisitor;
322
323impl serde::de::Visitor<'_> for AnyIdVisitor {
324    type Value = AnyId;
325
326    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
327        formatter.write_str("any id")
328    }
329
330    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
331    where
332        E: serde::de::Error,
333    {
334        AnyId::from_normalized_str(value)
335            .map_err(|err| serde::de::Error::custom(format!("invalid id `{value}`: {err}")))
336    }
337}
338
339/// Deserialize an `AnyId` using `from_normalized_str`.
340/// This deserialization implementation is used in the JSON policy format.
341impl<'de> Deserialize<'de> for AnyId {
342    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
343    where
344        D: Deserializer<'de>,
345    {
346        deserializer.deserialize_str(AnyIdVisitor)
347    }
348}
349
350impl AsRef<str> for AnyId {
351    fn as_ref(&self) -> &str {
352        &self.0
353    }
354}
355
356impl std::fmt::Display for AnyId {
357    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
358        write!(f, "{}", &self.0)
359    }
360}
361
362// allow `.parse()` on a string to make an `AnyId`
363impl std::str::FromStr for AnyId {
364    type Err = ParseErrors;
365
366    fn from_str(s: &str) -> Result<Self, Self::Err> {
367        crate::parser::parse_anyid(s)
368    }
369}
370
// Only `describe_self()` is provided here; unlike the `Id` impl, this impl
// defines no specialized `from_normalized_str()`.
impl FromNormalizedStr for AnyId {
    /// Human-readable name of this type.
    fn describe_self() -> &'static str {
        "AnyId"
    }
}
376
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for AnyId {
    /// Generate an `AnyId` of 1 to 17 characters drawn from the identifier
    /// alphabet.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        // AnyId syntax:
        // ['_''a'-'z''A'-'Z']['_''a'-'z''A'-'Z''0'-'9']*

        // characters allowed in the first position of an AnyId
        let head_letters: Vec<char> = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_"
            .chars()
            .collect();
        // characters allowed in every later position: digits, then the head set
        let tail_letters: Vec<char> = "0123456789"
            .chars()
            .chain(head_letters.iter().copied())
            .collect();
        // identifier character count minus 1
        let remaining_length = u.int_in_range(0..=16)?;
        let mut s = String::from(*u.choose(&head_letters)?);
        for _ in 0..remaining_length {
            s.push(*u.choose(&tail_letters)?);
        }
        debug_assert!(
            crate::parser::parse_anyid(&s).is_ok(),
            "all strings constructed this way should be valid AnyIds, but this one is not: {s:?}"
        );
        Ok(Self::new_unchecked(s))
    }

    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        // for arbitrary length
        let length_hint = <usize as arbitrary::Arbitrary>::size_hint(depth);
        // for arbitrary choices: the size hint of a vector of `u8` gives an
        // underestimate of bytes required by the sequence of choices.
        let choices_hint = <Vec<u8> as arbitrary::Arbitrary>::size_hint(depth);
        arbitrary::size_hint::and_all(&[length_hint, choices_hint])
    }
}
415
#[cfg(test)]
#[expect(clippy::panic, reason = "unit-test code")]
mod test {
    use super::*;

    /// `Id::from_normalized_str()` accepts a bare identifier and rejects
    /// qualified names, surrounding whitespace, and trailing comments.
    #[test]
    fn normalized_id() {
        Id::from_normalized_str("foo").expect("should be OK");

        let rejected = [
            "foo::bar",
            r#"foo::"bar""#,
            " foo",
            "foo ",
            "foo\n",
            "foo//comment",
        ];
        for input in rejected {
            Id::from_normalized_str(input).expect_err("shouldn't be OK");
        }
    }
}
431}