Skip to main content

cedar_policy_core/pst/
expr.rs

/*
 * Copyright Cedar Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
//! Expression types for PST.
//!
//! This module defines the expression tree used in Cedar policy conditions
//! (`when` / `unless` clauses). The tree is recursive: child expressions are held
//! behind [`Arc<Expr>`].
21
22use super::err::{
23    error_body::{self},
24    PstConstructionError,
25};
26use crate::ast;
27use crate::expr_builder::ExprBuilder;
28use crate::extensions::Extensions;
29use smol_str::{SmolStr, ToSmolStr};
30use std::collections::{BTreeMap, HashSet};
31use std::fmt::Display;
32use std::str::FromStr;
33use std::sync::Arc;
34
/// Surface-syntax spellings of core Cedar operators.
///
/// These operators exist only in the concrete syntax (the AST has no operator value
/// of its own to `Display` them), so their text is spelled out here. They are plain
/// compile-time strings with no need for a fixed address, hence `const` rather than
/// `static`.
mod constants {
    /// Inequality: `left != right`
    pub const NOT_EQ_STR: &str = "!=";
    /// Strictly greater than: `left > right`
    pub const GREATER_STR: &str = ">";
    /// Greater than or equal: `left >= right`
    pub const GREATER_EQ_STR: &str = ">=";
    /// Logical conjunction: `left && right`
    pub const AND_STR: &str = "&&";
    /// Logical disjunction: `left || right`
    pub const OR_STR: &str = "||";
}
44
45/// A validated Cedar identifier.
46///
47/// Wraps a [`SmolStr`] that has been checked to be a valid Cedar identifier
48/// (not a reserved keyword, no special characters, etc.).
49///
50/// The only way to create an `Id` is through [`Id::new()`] (which validates
51/// that the input is a valid identifier) or through conversion from other
52/// validated identifier representations.
53/// Accessing the inner string is free via [`as_str()`](Id::as_str) or
54/// [`into_smolstr()`](Id::into_smolstr).
55///
56/// ```
57/// # use cedar_policy_core::pst::Id;
58/// let id = Id::new("userName").expect("valid identifier");
59/// assert_eq!(id.as_str(), "userName");
60///
61/// // Reserved keywords are rejected:
62/// assert!(Id::new("if").is_err());
63/// assert!(Id::new("true").is_err());
64/// ```
65#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
66pub struct Id(SmolStr);
67
68impl Id {
69    /// Create a new `Id`, validating that the string is a legal Cedar identifier.
70    pub fn new(s: impl AsRef<str>) -> Result<Self, PstConstructionError> {
71        let ast_id = ast::Id::from_str(s.as_ref())?;
72        Ok(Self(ast_id.into_smolstr()))
73    }
74
75    /// Get the underlying string as a `&str`. Zero-cost.
76    pub fn as_str(&self) -> &str {
77        &self.0
78    }
79
80    /// Consume the `Id` and return the underlying `SmolStr`. Zero-cost.
81    pub fn into_smolstr(self) -> SmolStr {
82        self.0
83    }
84}
85
86impl AsRef<str> for Id {
87    fn as_ref(&self) -> &str {
88        &self.0
89    }
90}
91
92impl Display for Id {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        write!(f, "{}", &self.0)
95    }
96}
97
98/// Infallible: `ast::Id` is already validated.
99impl From<ast::Id> for Id {
100    fn from(id: ast::Id) -> Self {
101        Id(id.into_smolstr())
102    }
103}
104
/// Identifies a placeholder ("slot") in a template policy.
///
/// Templates contain slots written `?principal` or `?resource`; instantiating the
/// template replaces each slot with a concrete value, producing an ordinary policy.
///
/// ```cedar
/// permit (
///   principal == ?principal,
///   action == Action::"view",
///   resource in ?resource
/// );
/// ```
///
/// This enum is `#[non_exhaustive]`; code outside this crate must include a wildcard
/// arm when matching on it.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[non_exhaustive]
pub enum SlotId {
    /// The `?principal` placeholder
    Principal,
    /// The `?resource` placeholder
    Resource,
}
127
128impl Display for SlotId {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        let b: ast::SlotId = (*self).into();
131        write!(f, "{}", b)
132    }
133}
134
135/// A qualified name (e.g., `Namespace::Type`).
136///
137/// Represents entity types, action names, and other identifiers in Cedar.
138/// Names consist of a basename and optional namespace components.
139///
140/// ```cedar
141/// // Unqualified: just a basename
142/// User
143/// Photo
144///
145/// // Qualified: namespace components followed by basename
146/// MyApp::User
147/// AWS::EC2::Instance
148/// ```
149#[derive(Debug, Clone, PartialEq, Eq, Hash)]
150pub struct Name {
151    /// Basename (the final component of the name)
152    pub id: Id,
153    /// Namespace components (empty for unqualified names)
154    pub namespace: Arc<Vec<Id>>,
155}
156
157impl Name {
158    /// Constructs an unqualified name. This is a convenience constructor that validates
159    /// that `id` is a legal Cedar identifier.
160    ///
161    /// If you have an `Id` (which is `AsRef<str>`), you can infallibly construct the name
162    /// yourself.
163    pub fn unqualified(id: impl AsRef<str>) -> Result<Self, PstConstructionError> {
164        Ok(Name {
165            id: Id::new(id)?,
166            namespace: Arc::new(vec![]),
167        })
168    }
169
170    /// Constructs a qualified name. Validates that all components are legal Cedar identifiers.
171    ///
172    /// If you have an `Id` and a namespace in the form of a `Vec<Id>`, you can infallibly
173    /// construct the name yourself.
174    pub fn qualified<I, T>(namespace: I, id: impl AsRef<str>) -> Result<Self, PstConstructionError>
175    where
176        I: IntoIterator<Item = T>,
177        T: AsRef<str>,
178    {
179        let ns: Result<Vec<Id>, _> = namespace.into_iter().map(|s| Id::new(s)).collect();
180        Ok(Name {
181            id: Id::new(id)?,
182            namespace: Arc::new(ns?),
183        })
184    }
185}
186
187impl Display for Name {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        for elem in self.namespace.as_ref() {
190            write!(f, "{elem}::")?;
191        }
192        write!(f, "{}", self.id)?;
193        Ok(())
194    }
195}
196
197/// Entity type name.
198///
199/// Represents the type of an entity in Cedar.
200///
201/// ```cedar
202/// User            // unqualified
203/// MyApp::Photo    // qualified with namespace
204/// ```
205#[derive(Debug, Clone, PartialEq, Eq, Hash)]
206pub struct EntityType(pub Name);
207
208impl EntityType {
209    /// Create an entity type from a name
210    pub fn from_name(name: impl Into<Name>) -> Self {
211        EntityType(name.into())
212    }
213}
214
215impl Display for EntityType {
216    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217        let ast_et: ast::EntityType = self.clone().into();
218        write!(f, "{}", ast_et)
219    }
220}
221
222/// Entity unique identifier (UID).
223///
224/// Represents a specific entity instance in Cedar, written as `Type::"id"`.
225///
226/// ```cedar
227/// User::"alice"
228/// Photo::"vacation.jpg"
229/// MyApp::Action::"readFile"
230/// ```
231#[derive(Debug, Clone, PartialEq, Eq, Hash)]
232pub struct EntityUID {
233    /// Type of the entity
234    pub ty: EntityType,
235    /// Entity identifier (EID)
236    pub eid: SmolStr,
237}
238
239impl Display for EntityUID {
240    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241        write!(f, "{}::\"{}\"", self.ty, self.eid.as_str().escape_default())
242    }
243}
244
/// The built-in request variables usable inside Cedar expressions.
///
/// Cedar exposes exactly four, each referring to part of the authorization request:
///
/// ```cedar
/// principal       // who is making the request
/// action          // what is being requested
/// resource        // what the request is about
/// context         // the request's context record
/// ```
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Var {
    /// `principal`: the entity making the request
    Principal,
    /// `action`: the action being requested
    Action,
    /// `resource`: the entity the action targets
    Resource,
    /// `context`: the request's context record
    Context,
}
266
/// One-argument operators in Cedar expressions.
///
/// Covers the built-in unary operators together with every single-argument
/// extension function (constructors and methods).
///
/// This enum is `#[non_exhaustive]`; code outside this crate must include a wildcard
/// arm when matching on it.
///
/// ```cedar
/// // Built-in operators
/// !context.is_admin           // Not
/// -(1)                        // Neg
/// [].isEmpty()                // IsEmpty
///
/// // Extension constructors
/// decimal("1.23")             // Decimal
/// ip("10.0.0.1")              // Ip
/// datetime("2024-01-01")      // Datetime
/// duration("1h30m")           // Duration
///
/// // IP address methods
/// ip("10.0.0.1").isIpv4()        // IsIPv4
/// ip("::1").isIpv6()             // IsIPV6
/// ip("127.0.0.1").isLoopback()   // IsLoopback
/// ip("224.0.0.1").isMulticast()  // IsMulticast
///
/// // Datetime / duration methods
/// datetime("2024-01-01").toDate()           // ToDate
/// datetime("2024-01-01T12:00:00Z").toTime() // ToTime
/// duration("1h30m").toMilliseconds()        // ToMilliseconds
/// duration("1h30m").toSeconds()             // ToSeconds
/// duration("1h30m").toMinutes()             // ToMinutes
/// duration("1h30m").toHours()               // ToHours
/// duration("30d").toDays()                  // ToDays
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum UnaryOp {
    /// Logical negation, `!expr`
    Not,
    /// Arithmetic negation, `-(expr)`
    Neg,
    /// Emptiness test, `expr.isEmpty()`
    IsEmpty,
    /// Datetime constructor, `datetime("...")`
    Datetime,
    /// Decimal constructor, `decimal("...")`
    Decimal,
    /// Duration constructor, `duration("...")`
    Duration,
    /// IP address constructor, `ip("...")`
    Ip,
    /// IPv4 test, `expr.isIpv4()`
    IsIPv4,
    /// IPv6 test, `expr.isIpv6()`
    IsIPV6,
    /// Loopback test, `expr.isLoopback()`
    IsLoopback,
    /// Multicast test, `expr.isMulticast()`
    IsMulticast,
    /// Date part of a datetime, `expr.toDate()`
    ToDate,
    /// Time-of-day part of a datetime, `expr.toTime()`
    ToTime,
    /// Duration in milliseconds, `expr.toMilliseconds()`
    ToMilliseconds,
    /// Duration in seconds, `expr.toSeconds()`
    ToSeconds,
    /// Duration in minutes, `expr.toMinutes()`
    ToMinutes,
    /// Duration in hours, `expr.toHours()`
    ToHours,
    /// Duration in days, `expr.toDays()`
    ToDays,
}
340
341impl UnaryOp {
342    pub(crate) fn to_name(self) -> Option<&'static ast::Name> {
343        // We get the names of the extension functions from where they are defined: we don't duplicate
344        // name definitions.
345        use crate::extensions;
346        match self {
347            UnaryOp::IsEmpty | UnaryOp::Neg | UnaryOp::Not => None,
348            UnaryOp::Datetime => Some(&extensions::datetime::constants::DATETIME_CONSTRUCTOR_NAME),
349            UnaryOp::Decimal => Some(&extensions::decimal::constants::DECIMAL_FROM_STR_NAME),
350            UnaryOp::Duration => Some(&extensions::datetime::constants::DURATION_CONSTRUCTOR_NAME),
351            UnaryOp::Ip => Some(&extensions::ipaddr::names::IP_FROM_STR_NAME),
352            UnaryOp::IsIPv4 => Some(&extensions::ipaddr::names::IS_IPV4),
353            UnaryOp::IsIPV6 => Some(&extensions::ipaddr::names::IS_IPV6),
354            UnaryOp::IsLoopback => Some(&extensions::ipaddr::names::IS_LOOPBACK),
355            UnaryOp::IsMulticast => Some(&extensions::ipaddr::names::IS_MULTICAST),
356            UnaryOp::ToDate => Some(&extensions::datetime::constants::TO_DATE_NAME),
357            UnaryOp::ToTime => Some(&extensions::datetime::constants::TO_TIME_NAME),
358            UnaryOp::ToMilliseconds => Some(&extensions::datetime::constants::TO_MILLISECONDS_NAME),
359            UnaryOp::ToSeconds => Some(&extensions::datetime::constants::TO_SECONDS_NAME),
360            UnaryOp::ToMinutes => Some(&extensions::datetime::constants::TO_MINUTES_NAME),
361            UnaryOp::ToHours => Some(&extensions::datetime::constants::TO_HOURS_NAME),
362            UnaryOp::ToDays => Some(&extensions::datetime::constants::TO_DAYS_NAME),
363        }
364    }
365
366    /// Parse a unary operator from a function name
367    pub(crate) fn from_function_name(name: &str) -> Option<Self> {
368        match name {
369            "decimal" => Some(UnaryOp::Decimal),
370            "datetime" => Some(UnaryOp::Datetime),
371            "duration" => Some(UnaryOp::Duration),
372            "ip" => Some(UnaryOp::Ip),
373            "isIpv4" => Some(UnaryOp::IsIPv4),
374            "isIpv6" => Some(UnaryOp::IsIPV6),
375            "isLoopback" => Some(UnaryOp::IsLoopback),
376            "isMulticast" => Some(UnaryOp::IsMulticast),
377            "toDate" => Some(UnaryOp::ToDate),
378            "toTime" => Some(UnaryOp::ToTime),
379            "toMilliseconds" => Some(UnaryOp::ToMilliseconds),
380            "toSeconds" => Some(UnaryOp::ToSeconds),
381            "toMinutes" => Some(UnaryOp::ToMinutes),
382            "toHours" => Some(UnaryOp::ToHours),
383            "toDays" => Some(UnaryOp::ToDays),
384            _ => None,
385        }
386    }
387}
388
389impl Display for UnaryOp {
390    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
391        match self {
392            UnaryOp::Not => write!(f, "{}", ast::UnaryOp::Not),
393            UnaryOp::Neg => write!(f, "{}", ast::UnaryOp::Neg),
394            UnaryOp::IsEmpty => write!(f, "{}", ast::UnaryOp::IsEmpty),
395            // Extension functions - use their name
396            _ => match self.to_name() {
397                Some(name) => write!(f, "{}", name),
398                None => write!(f, "<impossible operator>"),
399            },
400        }
401    }
402}
403
/// Two-argument operators in Cedar expressions.
///
/// Covers the built-in binary operators together with every two-argument
/// extension method.
///
/// This enum is `#[non_exhaustive]`; code outside this crate must include a wildcard
/// arm when matching on it.
///
/// ```cedar
/// // Comparison
/// principal == User::"alice"          // Eq
/// principal != User::"bob"            // NotEq
/// context.age < 18                    // Less
/// context.age <= 21                   // LessEq
/// context.age > 13                    // Greater
/// context.age >= 65                   // GreaterEq
///
/// // Boolean logic
/// true && false                       // And
/// true || false                       // Or
///
/// // Arithmetic
/// context.x + 1                       // Add
/// context.x - 1                       // Sub
/// context.x * 2                       // Mul
///
/// // Hierarchy and sets
/// principal in Group::"admins"        // In
/// [1, 2, 3].contains(2)               // Contains
/// [1, 2].containsAll([1])             // ContainsAll
/// [1, 2].containsAny([2, 3])          // ContainsAny
///
/// // Entity tags
/// resource.hasTag("env")              // HasTag
/// resource.getTag("env")              // GetTag
///
/// // IP address extension
/// ip("10.0.0.1").isInRange(ip("10.0.0.0/24"))  // IsInRange
///
/// // Datetime extension
/// datetime("2024-01-01").offset(duration("1d")) // Offset
/// datetime("2024-01-02").durationSince(datetime("2024-01-01")) // DurationSince
/// ```
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum BinaryOp {
    /// Equality, `left == right`
    Eq,
    /// Inequality, `left != right`
    NotEq,
    /// Strictly less, `left < right`
    Less,
    /// Less or equal, `left <= right`
    LessEq,
    /// Strictly greater, `left > right`
    Greater,
    /// Greater or equal, `left >= right`
    GreaterEq,
    /// Conjunction, `left && right`
    And,
    /// Disjunction, `left || right`
    Or,
    /// Addition, `left + right`
    Add,
    /// Subtraction, `left - right`
    Sub,
    /// Multiplication, `left * right`
    Mul,
    /// Hierarchy membership, `left in right`
    In,
    /// Set membership, `left.contains(right)`
    Contains,
    /// Superset test, `left.containsAll(right)`
    ContainsAll,
    /// Overlap test, `left.containsAny(right)`
    ContainsAny,
    /// Tag lookup, `left.getTag(right)`
    GetTag,
    /// Tag presence, `left.hasTag(right)`
    HasTag,
    /// IP range membership, `left.isInRange(right)`
    IsInRange,
    /// Datetime shifted by a duration, `left.offset(right)`
    Offset,
    /// Duration between datetimes, `left.durationSince(right)`
    DurationSince,
    /// Decimal comparison, `left.lessThan(right)`
    DecimalLessThan,
    /// Decimal comparison, `left.lessThanOrEqual(right)`
    DecimalLessEq,
    /// Decimal comparison, `left.greaterThan(right)`
    DecimalGreater,
    /// Decimal comparison, `left.greaterThanOrEqual(right)`
    DecimalGreaterEq,
}
497
498impl BinaryOp {
499    pub(crate) fn to_name(self) -> Option<&'static ast::Name> {
500        use crate::extensions;
501        match self {
502            BinaryOp::IsInRange => Some(&extensions::ipaddr::names::IS_IN_RANGE),
503            BinaryOp::Offset => Some(&extensions::datetime::constants::OFFSET_METHOD_NAME),
504            BinaryOp::DurationSince => Some(&extensions::datetime::constants::DURATION_SINCE_NAME),
505            BinaryOp::DecimalLessThan => Some(&extensions::decimal::constants::LESS_THAN),
506            BinaryOp::DecimalLessEq => Some(&extensions::decimal::constants::LESS_THAN_OR_EQUAL),
507            BinaryOp::DecimalGreater => Some(&extensions::decimal::constants::GREATER_THAN),
508            BinaryOp::DecimalGreaterEq => {
509                Some(&extensions::decimal::constants::GREATER_THAN_OR_EQUAL)
510            }
511            // those are operators, not names
512            BinaryOp::Eq
513            | BinaryOp::NotEq
514            | BinaryOp::And
515            | BinaryOp::Or
516            | BinaryOp::Less
517            | BinaryOp::LessEq
518            | BinaryOp::Greater
519            | BinaryOp::GreaterEq
520            | BinaryOp::Add
521            | BinaryOp::Sub
522            | BinaryOp::Mul
523            | BinaryOp::In
524            | BinaryOp::Contains
525            | BinaryOp::ContainsAll
526            | BinaryOp::ContainsAny
527            | BinaryOp::GetTag
528            | BinaryOp::HasTag => None,
529        }
530    }
531
532    /// Parse a binary operator from a function name
533    pub(crate) fn from_function_name(name: &str) -> Option<Self> {
534        match name {
535            "lessThan" => Some(BinaryOp::DecimalLessThan),
536            "lessThanOrEqual" => Some(BinaryOp::DecimalLessEq),
537            "greaterThan" => Some(BinaryOp::DecimalGreater),
538            "greaterThanOrEqual" => Some(BinaryOp::DecimalGreaterEq),
539            "isInRange" => Some(BinaryOp::IsInRange),
540            "offset" => Some(BinaryOp::Offset),
541            "durationSince" => Some(BinaryOp::DurationSince),
542            _ => None,
543        }
544    }
545}
546
547impl Display for BinaryOp {
548    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
549        match self {
550            BinaryOp::Eq => write!(f, "{}", ast::BinaryOp::Eq),
551            BinaryOp::NotEq => write!(f, "{}", &constants::NOT_EQ_STR),
552            BinaryOp::Less => write!(f, "{}", ast::BinaryOp::Less),
553            BinaryOp::LessEq => write!(f, "{}", ast::BinaryOp::LessEq),
554            BinaryOp::Greater => write!(f, "{}", &constants::GREATER_STR),
555            BinaryOp::GreaterEq => write!(f, "{}", &constants::GREATER_EQ_STR),
556            BinaryOp::And => write!(f, "{}", &constants::AND_STR),
557            BinaryOp::Or => write!(f, "{}", &constants::OR_STR),
558            BinaryOp::Add => write!(f, "{}", ast::BinaryOp::Add),
559            BinaryOp::Sub => write!(f, "{}", ast::BinaryOp::Sub),
560            BinaryOp::Mul => write!(f, "{}", ast::BinaryOp::Mul),
561            BinaryOp::In => write!(f, "{}", ast::BinaryOp::In),
562            BinaryOp::Contains => write!(f, "{}", ast::BinaryOp::Contains),
563            BinaryOp::ContainsAll => write!(f, "{}", ast::BinaryOp::ContainsAll),
564            BinaryOp::ContainsAny => write!(f, "{}", ast::BinaryOp::ContainsAny),
565            BinaryOp::GetTag => write!(f, "{}", ast::BinaryOp::GetTag),
566            BinaryOp::HasTag => write!(f, "{}", ast::BinaryOp::HasTag),
567            // Extension functions - use their name
568            _ => match self.to_name() {
569                Some(name) => write!(f, "{}", name),
570                None => write!(f, "<impossible operator>"),
571            },
572        }
573    }
574}
575
576/// Literal values in Cedar expressions.
577///
578/// This enum is `#[non_exhaustive]`; match arms must include a wildcard.
579///
580/// ```cedar
581/// true                    // Bool
582/// 42                      // Long
583/// "hello"                 // String
584/// User::"alice"           // EntityUID
585/// ```
586#[derive(Debug, Clone, PartialEq, Eq, Hash)]
587#[non_exhaustive]
588pub enum Literal {
589    /// `true` or `false`
590    Bool(bool),
591    /// Integer literal (e.g., `42`, `-1`)
592    Long(i64),
593    /// String literal (e.g., `"hello"`)
594    String(SmolStr),
595    /// Entity UID literal (e.g., `User::"alice"`)
596    EntityUID(EntityUID),
597}
598
/// One element of a `like` pattern.
///
/// A `like` pattern is a sequence of literal characters and `*` wildcards:
///
/// ```cedar
/// resource.name like "*.jpg"      // Wildcard, then Char('.'), Char('j'), ...
/// resource.name like "photo_*"    // Char('p'), ..., then Wildcard
/// ```
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum PatternElem {
    /// A character that must match literally
    Char(char),
    /// `*`: matches any run of characters, including the empty one
    Wildcard,
}
614
615/// PST Expression — the core expression type for Cedar policy conditions.
616///
617/// This enum is `#[non_exhaustive]`; match arms must include a wildcard.
618///
619/// Each variant corresponds to a Cedar syntax construct. See individual variant docs
620/// for the Cedar syntax each one represents.
621#[derive(Debug, Clone, PartialEq, Eq)]
622#[non_exhaustive]
623pub enum Expr {
624    /// A literal value: `true`, `42`, `"hello"`, or `User::"alice"`.
625    Literal(Literal),
626    /// A built-in variable: `principal`, `action`, `resource`, or `context`.
627    Var(Var),
628    /// A template slot: `?principal` or `?resource`.
629    Slot(SlotId),
630    /// A unary operation.
631    ///
632    /// ```cedar
633    /// !expr           // UnaryOp::Not
634    /// -(expr)         // UnaryOp::Neg
635    /// expr.isEmpty()  // UnaryOp::IsEmpty
636    /// decimal("1.0")  // UnaryOp::Decimal
637    /// ```
638    UnaryOp {
639        /// The operator
640        op: UnaryOp,
641        /// The operand
642        expr: Arc<Expr>,
643    },
644    /// A binary operation.
645    ///
646    /// ```cedar
647    /// context.age >= 18                   // BinaryOp::GreaterEq
648    /// principal in Group::"admins"        // BinaryOp::In
649    /// [1, 2].contains(1)                 // BinaryOp::Contains
650    /// ```
651    BinaryOp {
652        /// The operator
653        op: BinaryOp,
654        /// Left operand
655        left: Arc<Expr>,
656        /// Right operand
657        right: Arc<Expr>,
658    },
659    /// Attribute access.
660    ///
661    /// ```cedar
662    /// principal.name
663    /// context.request.ip
664    /// ```
665    GetAttr {
666        /// Expression to get attribute from
667        expr: Arc<Expr>,
668        /// Attribute name
669        attr: SmolStr,
670    },
671    /// Attribute existence check. Can check nested attributes.
672    ///
673    /// ```cedar
674    /// principal has name
675    /// principal has "0notACedarIdent"
676    /// principal has address.street
677    /// ```
678    /// If there are more than one attribute, all attributes must be valid Cedar identifiers.
679    HasAttr {
680        /// Expression to check for attribute
681        expr: Arc<Expr>,
682        /// Attribute path (non-empty; multiple elements for nested checks)
683        attrs: nonempty::NonEmpty<SmolStr>,
684    },
685    /// Pattern matching with the `like` operator.
686    ///
687    /// ```cedar
688    /// resource.name like "*.jpg"
689    /// ```
690    Like {
691        /// Expression to match
692        expr: Arc<Expr>,
693        /// Pattern to match against
694        pattern: Vec<PatternElem>,
695    },
696    /// Entity type test, optionally combined with a hierarchy check.
697    ///
698    /// ```cedar
699    /// principal is User
700    /// principal is User in Group::"admins"
701    /// ```
702    Is {
703        /// Expression to test
704        expr: Arc<Expr>,
705        /// Entity type to test for
706        entity_type: EntityType,
707        /// Optional `in` hierarchy parent
708        in_expr: Option<Arc<Expr>>,
709    },
710    /// Conditional expression.
711    ///
712    /// ```cedar
713    /// if context.is_admin then "yes" else "no"
714    /// ```
715    IfThenElse {
716        /// Condition
717        cond: Arc<Expr>,
718        /// Then branch
719        then_expr: Arc<Expr>,
720        /// Else branch
721        else_expr: Arc<Expr>,
722    },
723    /// Set literal.
724    ///
725    /// ```cedar
726    /// [1, 2, 3]
727    /// [User::"alice", User::"bob"]
728    /// ```
729    Set(Vec<Arc<Expr>>),
730    /// Record literal.
731    ///
732    /// ```cedar
733    /// {"key": "value", "count": 42}
734    /// ```
735    Record(BTreeMap<String, Arc<Expr>>),
736    /// An unknown value for partial evaluation (not part of Cedar surface syntax).
737    Unknown {
738        /// Name of the unknown
739        name: SmolStr,
740    },
741    /// A TPE residual error node: indicates that this subexpression would
742    /// produce an evaluation error if reached at runtime.
743    ///
744    /// This is distinct from tolerant-ast parse errors — `ResidualError`
745    /// represents a semantically meaningful result from type-aware partial
746    /// evaluation (e.g., arithmetic overflow).
747    #[cfg(feature = "tpe")]
748    ResidualError,
749}
750
751impl Expr {
752    /// Transform a function call with arguments into a PST expression given the [`ast::Name`] of
753    /// the function. Clones the string representation of the `ast::Name` given.
754    pub(crate) fn from_function_ast_name_and_args(
755        name: &ast::Name,
756        args: Vec<Arc<Expr>>,
757    ) -> Result<Expr, PstConstructionError> {
758        Self::from_function_names_and_args(name.to_smolstr(), name, args)
759    }
760
761    /// Transform a function call with arguments into a PST expression given the [`ast::Name`] of
762    /// the function, and its [SmolStr] name.
763    /// Assumes the two names's representation as strings are equivalent, and does not clone.
764    fn from_function_names_and_args(
765        name: SmolStr,
766        ast_name: &ast::Name,
767        args: Vec<Arc<Expr>>,
768    ) -> Result<Expr, PstConstructionError> {
769        // TPE residual error nodes are represented as `error()` in the AST.
770        // Intercept them here before the extension function lookup fails.
771        #[cfg(feature = "tpe")]
772        if *ast_name == *crate::tpe::residual::ERROR_NAME {
773            return Ok(Expr::ResidualError);
774        }
775
776        let extension = Extensions::all_available().func(ast_name)?;
777
778        let expected = extension.arg_types().len();
779        let got = args.len();
780
781        if expected != got {
782            return Err(error_body::WrongArityError::new(name.into(), expected, got).into());
783        }
784        Ok(match args.len() {
785            1 => {
786                #[expect(clippy::unwrap_used, reason = "length = 1 checked in arm")]
787                let expr = args.into_iter().next().unwrap();
788                // Special case: the unknown function
789                if ast_name.to_string() == "unknown" {
790                    return Ok(Expr::Unknown {
791                        name: format!("{}", expr).into(),
792                    });
793                }
794                let op = UnaryOp::from_function_name(&ast_name.to_string())
795                    .ok_or_else(|| error_body::UnknownFunctionError::new(name.clone()))?;
796                Expr::UnaryOp { op, expr }
797            }
798            2 => {
799                let op = BinaryOp::from_function_name(&ast_name.to_string())
800                    .ok_or_else(|| error_body::UnknownFunctionError::new(name.clone()))?;
801                let mut iter = args.into_iter();
802                Expr::BinaryOp {
803                    op,
804                    #[expect(clippy::unwrap_used, reason = "length = 2 checked in match arm")]
805                    left: iter.next().unwrap(),
806                    #[expect(clippy::unwrap_used, reason = "length = 2 checked in match arm")]
807                    right: iter.next().unwrap(),
808                }
809            }
810            _ => return Err(error_body::UnknownFunctionError::new(name).into()),
811        })
812    }
813
814    // === Expression reduction functions ===
815
816    /// Recursively accumulate a value over this expression tree.
817    ///
818    /// At each node, `f` is called first. If it returns `Some(t)`, that value is returned
819    /// immediately without recursing into children. Otherwise, the results of recursing into
820    /// all child expressions are merged pairwise with `op`. If a node has no children,
821    /// `zero` is returned.
822    pub fn reduce<T: Clone + Sized>(
823        &self,
824        f: &dyn Fn(&Self) -> Option<T>,
825        op: &dyn Fn(T, T) -> T,
826        zero: T,
827    ) -> T {
828        if let Some(t) = f(self) {
829            return t;
830        }
831        let recurse = |e: &Arc<Self>| e.reduce(f, op, zero.clone());
832        match self {
833            Expr::Literal(_) | Expr::Var(_) | Expr::Slot(_) | Expr::Unknown { .. } => zero,
834            #[cfg(feature = "tpe")]
835            Expr::ResidualError => zero,
836            Expr::UnaryOp { expr, .. }
837            | Expr::GetAttr { expr, .. }
838            | Expr::HasAttr { expr, .. }
839            | Expr::Like { expr, .. } => recurse(expr),
840            Expr::BinaryOp { left, right, .. } => op(recurse(left), recurse(right)),
841            Expr::Is { expr, in_expr, .. } => match in_expr {
842                Some(e) => op(recurse(expr), recurse(e)),
843                None => recurse(expr),
844            },
845            Expr::IfThenElse {
846                cond,
847                then_expr,
848                else_expr,
849            } => op(op(recurse(cond), recurse(then_expr)), recurse(else_expr)),
850            Expr::Set(exprs) => {
851                let mut iter = exprs.iter();
852                match iter.next() {
853                    None => zero,
854                    Some(first) => iter.fold(recurse(first), |acc, e| op(acc, recurse(e))),
855                }
856            }
857            Expr::Record(map) => {
858                let mut iter = map.values();
859                match iter.next() {
860                    None => zero,
861                    Some(first) => iter.fold(recurse(first), |acc, e| op(acc, recurse(e))),
862                }
863            }
864        }
865    }
866
867    /// Does this expression contain any slots?
868    pub fn has_slots(&self) -> bool {
869        self.reduce::<bool>(
870            &|e| match e {
871                Expr::Slot(_) => Some(true),
872                _ => None,
873            },
874            &|a, b| a || b,
875            false,
876        )
877    }
878
879    /// Does this expression contain any [`Expr::Unknown`] nodes?
880    pub fn has_unknowns(&self) -> bool {
881        self.reduce::<bool>(
882            &|e| match e {
883                Expr::Unknown { .. } => Some(true),
884                _ => None,
885            },
886            &|a, b| a || b,
887            false,
888        )
889    }
890
891    /// Return the slots used in this expression
892    pub fn slots(&self) -> HashSet<SlotId> {
893        self.reduce::<HashSet<SlotId>>(
894            &|e| match e {
895                Expr::Slot(id) => Some(HashSet::from([*id])),
896                _ => None,
897            },
898            &|a, b| a.union(&b).copied().collect(),
899            HashSet::new(),
900        )
901    }
902
903    /// Does this expression contain any [`Expr::ResidualError`] nodes?
904    ///
905    /// Returns `true` if any subexpression represents a statically-known
906    /// evaluation error (as determined by TPE).
907    #[cfg(feature = "tpe")]
908    pub fn has_error(&self) -> bool {
909        self.reduce::<bool>(
910            &|e| match e {
911                Expr::ResidualError => Some(true),
912                _ => None,
913            },
914            &|a, b| a || b,
915            false,
916        )
917    }
918}
919
/// Crate-internal [`ExprBuilder`] that produces PST [`Expr`] trees.
///
/// Unlike the expression-building functions, this builder performs no
/// validation on the inputs it receives. It is meant for internal use only.
#[derive(Debug, Clone)]
pub(crate) struct PstBuilder;
925
926impl ExprBuilder for PstBuilder {
927    type Expr = Expr;
928    type Data = ();
929    type BuildError = PstConstructionError;
930
931    #[cfg(feature = "tolerant-ast")]
932    type ErrorType = crate::parser::err::ParseErrors;
933
934    fn with_data(_data: Self::Data) -> Self {
935        Self
936    }
937
938    fn with_maybe_source_loc(self, _: Option<&crate::parser::Loc>) -> Self {
939        // PST doesn't store source locations
940        self
941    }
942
943    fn loc(&self) -> Option<&crate::parser::Loc> {
944        None
945    }
946
947    fn data(&self) -> &Self::Data {
948        &()
949    }
950
951    fn val(self, lit: impl Into<ast::Literal>) -> Expr {
952        Expr::Literal(From::<ast::Literal>::from(lit.into()))
953    }
954
955    fn var(self, var: ast::Var) -> Expr {
956        Expr::Var(var.into())
957    }
958
959    fn unknown(self, u: ast::Unknown) -> Expr {
960        Expr::Unknown { name: u.name }
961    }
962
963    fn slot(self, s: ast::SlotId) -> Expr {
964        Expr::Slot(s.into())
965    }
966
967    fn ite_arc(self, cond: Arc<Expr>, then_expr: Arc<Expr>, else_expr: Arc<Expr>) -> Expr {
968        Expr::IfThenElse {
969            cond,
970            then_expr,
971            else_expr,
972        }
973    }
974
975    fn not(self, e: Expr) -> Expr {
976        Expr::UnaryOp {
977            op: UnaryOp::Not,
978            expr: Arc::new(e),
979        }
980    }
981
982    fn is_eq(self, e1: Expr, e2: Expr) -> Expr {
983        Expr::BinaryOp {
984            op: BinaryOp::Eq,
985            left: Arc::new(e1),
986            right: Arc::new(e2),
987        }
988    }
989
990    fn noteq(self, e1: Expr, e2: Expr) -> Expr {
991        Expr::BinaryOp {
992            op: BinaryOp::NotEq,
993            left: Arc::new(e1),
994            right: Arc::new(e2),
995        }
996    }
997
998    fn and(self, e1: Expr, e2: Expr) -> Expr {
999        Expr::BinaryOp {
1000            op: BinaryOp::And,
1001            left: Arc::new(e1),
1002            right: Arc::new(e2),
1003        }
1004    }
1005
1006    fn or(self, e1: Expr, e2: Expr) -> Expr {
1007        Expr::BinaryOp {
1008            op: BinaryOp::Or,
1009            left: Arc::new(e1),
1010            right: Arc::new(e2),
1011        }
1012    }
1013
1014    fn less(self, e1: Expr, e2: Expr) -> Expr {
1015        Expr::BinaryOp {
1016            op: BinaryOp::Less,
1017            left: Arc::new(e1),
1018            right: Arc::new(e2),
1019        }
1020    }
1021
1022    fn lesseq(self, e1: Expr, e2: Expr) -> Expr {
1023        Expr::BinaryOp {
1024            op: BinaryOp::LessEq,
1025            left: Arc::new(e1),
1026            right: Arc::new(e2),
1027        }
1028    }
1029
1030    fn greater(self, e1: Expr, e2: Expr) -> Expr {
1031        Expr::BinaryOp {
1032            op: BinaryOp::Greater,
1033            left: Arc::new(e1),
1034            right: Arc::new(e2),
1035        }
1036    }
1037
1038    fn greatereq(self, e1: Expr, e2: Expr) -> Expr {
1039        Expr::BinaryOp {
1040            op: BinaryOp::GreaterEq,
1041            left: Arc::new(e1),
1042            right: Arc::new(e2),
1043        }
1044    }
1045
1046    fn add(self, e1: Expr, e2: Expr) -> Expr {
1047        Expr::BinaryOp {
1048            op: BinaryOp::Add,
1049            left: Arc::new(e1),
1050            right: Arc::new(e2),
1051        }
1052    }
1053
1054    fn sub(self, e1: Expr, e2: Expr) -> Expr {
1055        Expr::BinaryOp {
1056            op: BinaryOp::Sub,
1057            left: Arc::new(e1),
1058            right: Arc::new(e2),
1059        }
1060    }
1061
1062    fn mul(self, e1: Expr, e2: Expr) -> Expr {
1063        Expr::BinaryOp {
1064            op: BinaryOp::Mul,
1065            left: Arc::new(e1),
1066            right: Arc::new(e2),
1067        }
1068    }
1069
1070    fn neg(self, e: Expr) -> Expr {
1071        Expr::UnaryOp {
1072            op: UnaryOp::Neg,
1073            expr: Arc::new(e),
1074        }
1075    }
1076
1077    fn is_in_arc(self, left: Arc<Expr>, right: Arc<Expr>) -> Expr {
1078        Expr::BinaryOp {
1079            op: BinaryOp::In,
1080            left,
1081            right,
1082        }
1083    }
1084
1085    fn contains(self, e1: Expr, e2: Expr) -> Expr {
1086        Expr::BinaryOp {
1087            op: BinaryOp::Contains,
1088            left: Arc::new(e1),
1089            right: Arc::new(e2),
1090        }
1091    }
1092
1093    fn contains_all(self, e1: Expr, e2: Expr) -> Expr {
1094        Expr::BinaryOp {
1095            op: BinaryOp::ContainsAll,
1096            left: Arc::new(e1),
1097            right: Arc::new(e2),
1098        }
1099    }
1100
1101    fn contains_any(self, e1: Expr, e2: Expr) -> Expr {
1102        Expr::BinaryOp {
1103            op: BinaryOp::ContainsAny,
1104            left: Arc::new(e1),
1105            right: Arc::new(e2),
1106        }
1107    }
1108
1109    fn is_empty(self, expr: Expr) -> Expr {
1110        Expr::UnaryOp {
1111            op: UnaryOp::IsEmpty,
1112            expr: Arc::new(expr),
1113        }
1114    }
1115
1116    fn get_tag(self, expr: Expr, tag: Expr) -> Expr {
1117        Expr::BinaryOp {
1118            op: BinaryOp::GetTag,
1119            left: Arc::new(expr),
1120            right: Arc::new(tag),
1121        }
1122    }
1123
1124    fn has_tag(self, expr: Expr, tag: Expr) -> Expr {
1125        Expr::BinaryOp {
1126            op: BinaryOp::HasTag,
1127            left: Arc::new(expr),
1128            right: Arc::new(tag),
1129        }
1130    }
1131
1132    fn set(self, exprs: impl IntoIterator<Item = Expr>) -> Expr {
1133        Expr::Set(exprs.into_iter().map(Arc::new).collect())
1134    }
1135
1136    fn record(
1137        self,
1138        pairs: impl IntoIterator<Item = (SmolStr, Expr)>,
1139    ) -> Result<Expr, ast::ExpressionConstructionError> {
1140        let mut map = BTreeMap::new();
1141        for (k, v) in pairs {
1142            if map.insert(k.to_string(), Arc::new(v)).is_some() {
1143                return Err(ast::expression_construction_errors::DuplicateKeyError {
1144                    key: k,
1145                    context: "in record literal",
1146                }
1147                .into());
1148            }
1149        }
1150        Ok(Expr::Record(map))
1151    }
1152
1153    fn call_extension_fn(
1154        self,
1155        fn_name: ast::Name,
1156        args: impl IntoIterator<Item = Expr>,
1157    ) -> Result<Expr, PstConstructionError> {
1158        Expr::from_function_ast_name_and_args(&fn_name, args.into_iter().map(Arc::new).collect())
1159    }
1160
1161    fn get_attr_arc(self, expr: Arc<Expr>, attr: SmolStr) -> Expr {
1162        Expr::GetAttr { expr, attr }
1163    }
1164
1165    fn has_attr_arc(self, expr: Arc<Expr>, attr: SmolStr) -> Expr {
1166        Expr::HasAttr {
1167            expr,
1168            attrs: nonempty::nonempty![attr],
1169        }
1170    }
1171
1172    fn extended_has_attr_arc(self, expr: Arc<Expr>, attrs: nonempty::NonEmpty<SmolStr>) -> Expr {
1173        Expr::HasAttr { expr, attrs }
1174    }
1175
1176    fn like(self, expr: Expr, pattern: ast::Pattern) -> Expr {
1177        Expr::Like {
1178            expr: Arc::new(expr),
1179            pattern: pattern.into(),
1180        }
1181    }
1182
1183    fn is_entity_type_arc(self, expr: Arc<Expr>, entity_type: ast::EntityType) -> Expr {
1184        Expr::Is {
1185            expr,
1186            entity_type: entity_type.into(),
1187            in_expr: None,
1188        }
1189    }
1190
1191    fn is_in_entity_type(self, e1: Expr, entity_type: ast::EntityType, e2: Expr) -> Expr {
1192        Expr::Is {
1193            expr: Arc::new(e1),
1194            entity_type: entity_type.into(),
1195            in_expr: Some(Arc::new(e2)),
1196        }
1197    }
1198
1199    #[cfg(feature = "tolerant-ast")]
1200    fn error(
1201        self,
1202        parse_errors: crate::parser::err::ParseErrors,
1203    ) -> Result<Self::Expr, Self::ErrorType> {
1204        // PST doesn't support error nodes for now, it will propagate parse errors
1205        Err(parse_errors)
1206    }
1207}
1208
1209impl std::fmt::Display for Expr {
1210    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1211        let est: crate::est::Expr = self.clone().into();
1212        write!(f, "{est}")
1213    }
1214}
1215
1216#[expect(
1217    clippy::fallible_impl_from,
1218    reason = "AST records cannot have duplicate keys, so builder.record() cannot fail"
1219)]
1220#[cfg(test)]
1221mod tests {
1222    use cool_asserts::{assert_matches, assertion_failure};
1223
1224    use super::*;
1225    use std::str::FromStr;
1226
1227    // --- Id tests ---
1228
1229    #[test]
1230    fn test_id_valid_identifiers() {
1231        // Simple identifiers
1232        assert!(Id::new("x").is_ok());
1233        assert!(Id::new("userName").is_ok());
1234        assert!(Id::new("_private").is_ok());
1235        assert!(Id::new("a1").is_ok());
1236        assert!(Id::new("ABC").is_ok());
1237    }
1238
1239    #[test]
1240    fn test_id_reserved_keywords_rejected() {
1241        for kw in [
1242            "if", "then", "else", "true", "false", "in", "is", "like", "has",
1243        ] {
1244            assert!(Id::new(kw).is_err(), "keyword `{kw}` should be rejected");
1245        }
1246    }
1247
1248    #[test]
1249    fn test_id_invalid_strings_rejected() {
1250        assert!(Id::new("").is_err());
1251        assert!(Id::new("1abc").is_err()); // starts with digit
1252        assert!(Id::new("a b").is_err()); // space
1253        assert!(Id::new("a+b").is_err()); // special char
1254        assert!(Id::new("::").is_err());
1255    }
1256
1257    #[test]
1258    fn test_id_accessors() {
1259        let id = Id::new("hello").unwrap();
1260        assert_eq!(id.as_str(), "hello");
1261        assert_eq!(id.as_ref(), "hello");
1262        assert_eq!(id.to_string(), "hello");
1263        assert_eq!(id.clone().into_smolstr(), SmolStr::from("hello"));
1264    }
1265
1266    #[test]
1267    fn test_id_equality_and_ordering() {
1268        let a = Id::new("aaa").unwrap();
1269        let b = Id::new("bbb").unwrap();
1270        let a2 = Id::new("aaa").unwrap();
1271        assert_eq!(a, a2);
1272        assert_ne!(a, b);
1273        assert!(a < b);
1274    }
1275
1276    #[test]
1277    fn test_id_from_ast_id() {
1278        let ast_id = crate::ast::Id::from_str("myIdent").unwrap();
1279        let pst_id = Id::from(ast_id);
1280        assert_eq!(pst_id.as_str(), "myIdent");
1281    }
1282
1283    // --- Name tests ---
1284
1285    #[test]
1286    fn test_name_unqualified() {
1287        let name = Name::unqualified("User").unwrap();
1288        assert_eq!(name.id.as_str(), "User");
1289        assert!(name.namespace.is_empty());
1290        assert_eq!(name.to_string(), "User");
1291    }
1292
1293    #[test]
1294    fn test_name_qualified() {
1295        let name = Name::qualified(["MyApp", "Auth"], "User").unwrap();
1296        assert_eq!(name.id.as_str(), "User");
1297        assert_eq!(name.namespace.len(), 2);
1298        assert_eq!(name.namespace[0].as_str(), "MyApp");
1299        assert_eq!(name.namespace[1].as_str(), "Auth");
1300        assert_eq!(name.to_string(), "MyApp::Auth::User");
1301    }
1302
1303    #[test]
1304    fn test_name_rejects_invalid_basename() {
1305        assert!(Name::unqualified("if").is_err());
1306        assert!(Name::unqualified("1bad").is_err());
1307        assert!(Name::qualified(["Good"], "if").is_err());
1308    }
1309
1310    #[test]
1311    fn test_name_rejects_invalid_namespace_component() {
1312        assert!(Name::qualified(["true"], "User").is_err());
1313        assert!(Name::qualified(["ok", "1bad"], "User").is_err());
1314    }
1315
1316    #[test]
1317    fn test_name_roundtrip_through_ast() {
1318        let pst_name = Name::qualified(["NS"], "Foo").unwrap();
1319        let ast_name: crate::ast::Name = pst_name.clone().into();
1320        let back: Name = ast_name.into();
1321        assert_eq!(pst_name, back);
1322    }
1323
1324    // --- EntityType / EntityUID with validated names ---
1325
1326    #[test]
1327    fn test_entity_type_display_with_valid_name() {
1328        let et = EntityType::from_name(Name::unqualified("User").unwrap());
1329        assert_eq!(et.to_string(), "User");
1330        let et = EntityType::from_name(Name::qualified(["App"], "Photo").unwrap());
1331        assert_eq!(et.to_string(), "App::Photo");
1332    }
1333
1334    #[test]
1335    fn test_entity_uid_roundtrip_through_ast() {
1336        let uid = EntityUID {
1337            ty: EntityType::from_name(Name::qualified(["NS"], "Type").unwrap()),
1338            eid: SmolStr::from("eid123"),
1339        };
1340        let ast_uid: crate::ast::EntityUID = uid.clone().into();
1341        let back: EntityUID = ast_uid.into();
1342        assert_eq!(uid, back);
1343    }
1344
1345    #[test]
1346    fn test_has_slots() {
1347        // Leaf with no slot
1348        assert!(!Expr::Literal(Literal::Long(1)).has_slots());
1349        // Var has no slot
1350        assert!(!Expr::Var(Var::Principal).has_slots());
1351        // Slot itself
1352        assert!(Expr::Slot(SlotId::Principal).has_slots());
1353        assert!(Expr::Slot(SlotId::Resource).has_slots());
1354        // Slot nested inside a BinaryOp
1355        let slot = Arc::new(Expr::Slot(SlotId::Principal));
1356        let lit = Arc::new(Expr::Literal(Literal::Long(42)));
1357        let binop = Expr::BinaryOp {
1358            op: BinaryOp::Eq,
1359            left: slot,
1360            right: lit.clone(),
1361        };
1362        assert!(binop.has_slots());
1363        // BinaryOp with no slots
1364        let binop_no_slot = Expr::BinaryOp {
1365            op: BinaryOp::Eq,
1366            left: lit.clone(),
1367            right: lit.clone(),
1368        };
1369        assert!(!binop_no_slot.has_slots());
1370        // Slot nested inside a Set
1371        let set_with_slot = Expr::Set(vec![lit.clone(), Arc::new(Expr::Slot(SlotId::Resource))]);
1372        assert!(set_with_slot.has_slots());
1373        // Empty set
1374        assert!(!Expr::Set(vec![]).has_slots());
1375        // IfThenElse with slot in else branch
1376        let ite = Expr::IfThenElse {
1377            cond: lit.clone(),
1378            then_expr: lit.clone(),
1379            else_expr: Arc::new(Expr::Slot(SlotId::Principal)),
1380        };
1381        assert!(ite.has_slots());
1382    }
1383
1384    #[test]
1385    fn test_from_function_unknown_function() {
1386        let name = ast::Name::parse_unqualified_name("unknownFunc").unwrap();
1387        let args = vec![Arc::new(Expr::Literal(Literal::Long(1)))];
1388
1389        let result = Expr::from_function_ast_name_and_args(&name, args);
1390        assert!(matches!(
1391            result,
1392            Err(PstConstructionError::UnknownFunction(..))
1393        ));
1394    }
1395
1396    #[test]
1397    fn test_from_function_wrong_arity() {
1398        let name = ast::Name::parse_unqualified_name("decimal").unwrap();
1399        let args = vec![
1400            Arc::new(Expr::Literal(Literal::Long(1))),
1401            Arc::new(Expr::Literal(Literal::Long(2))),
1402        ];
1403
1404        let result = Expr::from_function_ast_name_and_args(&name, args);
1405        assert_matches!(result, Err(PstConstructionError::WrongArity(..)));
1406    }
1407
1408    #[test]
1409    fn test_all_extension_functions_are_supported() {
1410        // This test ensures that all extension functions defined in Extensions
1411        // are properly mapped to PST operators (UnaryOp or BinaryOp)
1412        let extensions = Extensions::all_available();
1413
1414        for func in extensions.all_funcs() {
1415            let name = func.name().clone();
1416            let arity = func.arg_types().len();
1417
1418            // Create dummy "0" arguments based on arity, we don't typecheck here
1419            let args: Vec<Arc<Expr>> = (0..arity)
1420                .map(|_| Arc::new(Expr::Literal(Literal::Long(0))))
1421                .collect();
1422
1423            let result = Expr::from_function_ast_name_and_args(&name, args);
1424            assert!(
1425                result.is_ok(),
1426                "Function {} should be supported but got error: {:?}",
1427                name,
1428                result.err()
1429            );
1430            let actual = result.unwrap();
1431            print!("Expression: {}", actual);
1432            match arity {
1433                1 => {
1434                    if &name.to_string() == "unknown" {
1435                        assert!(
1436                            matches!(actual, Expr::Unknown { .. }),
1437                            "Expected unary unknown function to be Unknown expr",
1438                        );
1439                    } else {
1440                        match actual {
1441                            Expr::UnaryOp { op, .. } => {
1442                                let op_name = op.to_name();
1443                                assert!(
1444                                    op_name.is_some(),
1445                                    "UnaryOp from extension {} should have known ast::Name",
1446                                    name
1447                                );
1448                                assert_eq!(
1449                                    UnaryOp::from_function_name(&name.as_ref().to_string()),
1450                                    Some(op)
1451                                );
1452                            }
1453                            _ => {
1454                                assertion_failure!("Unary function  should produce BinaryOp", name:name)
1455                            }
1456                        }
1457                    }
1458                }
1459                2 => match actual {
1460                    Expr::BinaryOp { op, .. } => {
1461                        let op_name = op.to_name();
1462                        assert!(
1463                            op_name.is_some(),
1464                            "BinaryOp from extension {} should have known ast::Name",
1465                            name
1466                        );
1467                        assert_eq!(
1468                            BinaryOp::from_function_name(&name.as_ref().to_string()),
1469                            Some(op)
1470                        );
1471                    }
1472                    _ => assertion_failure!("Binary function  should produce BinaryOp", name:name),
1473                },
1474                _ => (),
1475            }
1476        }
1477    }
1478
1479    #[test]
1480    fn test_expr_construction_error_display() {
1481        let err: PstConstructionError =
1482            error_body::UnknownFunctionError::new("foo".to_smolstr()).into();
1483        assert!(err.to_string().contains("foo"));
1484
1485        let err: PstConstructionError =
1486            error_body::WrongArityError::new("bar".to_string(), 2, 1).into();
1487        assert!(err.to_string().contains("bar"));
1488        assert!(err.to_string().contains("2"));
1489        assert!(err.to_string().contains("1"));
1490    }
1491
1492    #[test]
1493    fn test_builder_additional_methods() {
1494        // Test unknown
1495        let expr = PstBuilder::new().unknown(ast::Unknown::new_untyped("test"));
1496        assert_matches!(expr, Expr::Unknown { .. });
1497
1498        // Test like
1499        let base = PstBuilder::new().val("test");
1500        let pattern = ast::Pattern::from(vec![ast::PatternElem::Char('a')]);
1501        let expr = PstBuilder::new().like(base, pattern);
1502        assert_matches!(expr, Expr::Like { .. });
1503
1504        // Test is_in_entity_type
1505        let base = PstBuilder::new().var(ast::Var::Principal);
1506        let entity_type = EntityType::from_name(ast::Name::parse_unqualified_name("User").unwrap());
1507        let uid = ast::EntityUID::from_components(
1508            ast::EntityType::from(ast::Name::parse_unqualified_name("User").unwrap()),
1509            ast::Eid::new("alice"),
1510            None,
1511        );
1512        let in_expr = PstBuilder::new().val(uid);
1513        let expr = PstBuilder::new().is_in_entity_type(
1514            base,
1515            entity_type.clone().try_into().unwrap(),
1516            in_expr,
1517        );
1518        if let Expr::Is {
1519            entity_type: et,
1520            in_expr: Some(_),
1521            ..
1522        } = expr
1523        {
1524            assert_eq!(et, entity_type);
1525        } else {
1526            panic!("Expected Is with in_expr");
1527        }
1528    }
1529
1530    #[test]
1531    fn test_builder_record_duplicate_keys() {
1532        let pairs = vec![
1533            (SmolStr::new("key"), PstBuilder::new().val(1i64)),
1534            (SmolStr::new("key"), PstBuilder::new().val(2i64)),
1535        ];
1536        let result = PstBuilder::new().record(pairs);
1537        assert!(matches!(
1538            result,
1539            Err(ast::ExpressionConstructionError::DuplicateKey { .. })
1540        ));
1541    }
1542
1543    mod display_tests {
1544        use super::*;
1545        use smol_str::SmolStr;
1546
1547        #[test]
1548        fn invalid_name_rejected_at_construction() {
1549            let name = "!__Cedar!";
1550            assert!(Name::unqualified(name).is_err());
1551        }
1552
1553        // NOTE: These tests verify Display output for expressions constructed via the
1554        // ExprBuilder trait (internal builder). Some operators are desugared during
1555        // construction (e.g., != becomes !(==), > becomes !(<=), && and || may become
1556        // if-then-else in AST but remain as BinaryOp in PST).
1557        //
1558        // Once a public expression builder API is implemented that constructs PST
1559        // directly without desugaring, Display will show all operators in their
1560        // original form (!=, >, >=, &&, ||, etc.).
1561
1562        fn builder() -> PstBuilder {
1563            PstBuilder::new()
1564        }
1565
1566        #[test]
1567        fn test_builder_display() {
1568            let cases = vec![
1569                // Literals
1570                (builder().val(true), "true"),
1571                (builder().val(false), "false"),
1572                (builder().val(42i64), "42"),
1573                (builder().val(-123i64), "(-123)"),
1574                (builder().val("hello"), "\"hello\""),
1575                (
1576                    builder().val(ast::EntityUID::from_components(
1577                        ast::Name::from_str("Photo").unwrap().into(),
1578                        ast::Eid::new("abc123"),
1579                        None,
1580                    )),
1581                    "Photo::\"abc123\"",
1582                ),
1583                // Variables
1584                (builder().var(ast::Var::Principal), "principal"),
1585                (builder().var(ast::Var::Action), "action"),
1586                (builder().var(ast::Var::Resource), "resource"),
1587                (builder().var(ast::Var::Context), "context"),
1588                // Slots
1589                (builder().slot(ast::SlotId::principal()), "?principal"),
1590                (builder().slot(ast::SlotId::resource()), "?resource"),
1591                // Basic unary ops
1592                (builder().not(builder().val(true)), "!true"),
1593                (builder().neg(builder().val(42i64)), "-(42)"),
1594                // Binary ops - comparison
1595                (
1596                    builder().is_eq(builder().val(1i64), builder().val(2i64)),
1597                    "1 == 2",
1598                ),
1599                (
1600                    builder().noteq(builder().val(1i64), builder().val(2i64)),
1601                    "1 != 2",
1602                ),
1603                (
1604                    builder().less(builder().val(1i64), builder().val(2i64)),
1605                    "1 < 2",
1606                ),
1607                (
1608                    builder().lesseq(builder().val(1i64), builder().val(2i64)),
1609                    "1 <= 2",
1610                ),
1611                (
1612                    builder().greater(builder().val(1i64), builder().val(2i64)),
1613                    "1 > 2",
1614                ),
1615                (
1616                    builder().greatereq(builder().val(1i64), builder().val(2i64)),
1617                    "1 >= 2",
1618                ),
1619                // Binary ops - logical
1620                (
1621                    builder().and(builder().val(true), builder().val(false)),
1622                    "true && false",
1623                ),
1624                (
1625                    builder().or(builder().val(true), builder().val(false)),
1626                    "true || false",
1627                ),
1628                // Binary ops - arithmetic
1629                (
1630                    builder().add(builder().val(1i64), builder().val(2i64)),
1631                    "1 + 2",
1632                ),
1633                (
1634                    builder().sub(builder().val(5i64), builder().val(3i64)),
1635                    "5 - 3",
1636                ),
1637                (
1638                    builder().mul(builder().val(2i64), builder().val(3i64)),
1639                    "2 * 3",
1640                ),
1641                // Binary ops - set/hierarchy
1642                (
1643                    builder().is_in(
1644                        builder().var(ast::Var::Principal),
1645                        builder().var(ast::Var::Resource),
1646                    ),
1647                    "principal in resource",
1648                ),
1649                (
1650                    builder().contains(builder().set([builder().val(1i64)]), builder().val(1i64)),
1651                    "[1].contains(1)",
1652                ),
1653                (
1654                    builder().contains_all(
1655                        builder().set([builder().val(1i64)]),
1656                        builder().set([builder().val(1i64)]),
1657                    ),
1658                    "[1].containsAll([1])",
1659                ),
1660                (
1661                    builder().contains_any(
1662                        builder().set([builder().val(1i64)]),
1663                        builder().set([builder().val(1i64)]),
1664                    ),
1665                    "[1].containsAny([1])",
1666                ),
1667                // Attribute access
1668                (
1669                    builder().get_attr(builder().var(ast::Var::Principal), SmolStr::from("name")),
1670                    "principal.name",
1671                ),
1672                (
1673                    builder().has_attr(builder().var(ast::Var::Principal), SmolStr::from("name")),
1674                    "principal has name",
1675                ),
1676                (
1677                    builder().is_entity_type(
1678                        builder().var(ast::Var::Resource),
1679                        ast::Name::from_str("Photo").unwrap().into(),
1680                    ),
1681                    "resource is Photo",
1682                ),
1683                // If-then-else
1684                (
1685                    builder().ite(
1686                        builder().val(true),
1687                        builder().val(1i64),
1688                        builder().val(2i64),
1689                    ),
1690                    "if true then 1 else 2",
1691                ),
1692                // Sets
1693                (builder().set([]), "[]"),
1694                (builder().set([builder().val(1i64)]), "[1]"),
1695                (
1696                    builder().set([
1697                        builder().val(1i64),
1698                        builder().val(2i64),
1699                        builder().val(3i64),
1700                    ]),
1701                    "[1, 2, 3]",
1702                ),
1703                // Records
1704                (builder().record([]).unwrap(), "{}"),
1705                (
1706                    builder()
1707                        .record([(SmolStr::from("a"), builder().val(1i64))])
1708                        .unwrap(),
1709                    "{a: 1}",
1710                ),
1711                (
1712                    builder()
1713                        .record([
1714                            (SmolStr::from("a"), builder().val(1i64)),
1715                            (SmolStr::from("b"), builder().val(2i64)),
1716                        ])
1717                        .unwrap(),
1718                    "{a: 1, b: 2}",
1719                ),
1720                // Tags
1721                (
1722                    builder().has_tag(builder().var(ast::Var::Action), builder().val("tag")),
1723                    "action.hasTag(\"tag\")",
1724                ),
1725                (
1726                    builder().get_tag(builder().var(ast::Var::Action), builder().val("tag")),
1727                    "action.getTag(\"tag\")",
1728                ),
1729                // Like
1730                (
1731                    builder().like(
1732                        builder().val("hello"),
1733                        ast::Pattern::from(vec![
1734                            ast::PatternElem::Char('h'),
1735                            ast::PatternElem::Wildcard,
1736                        ]),
1737                    ),
1738                    "\"hello\" like \"h*\"",
1739                ),
1740                // Function calls
1741                (
1742                    builder()
1743                        .call_extension_fn(
1744                            Name::unqualified("decimal").unwrap().into(),
1745                            vec![builder().val("1.23")],
1746                        )
1747                        .unwrap(),
1748                    "decimal(\"1.23\")",
1749                ),
1750            ];
1751
1752            for (expr, expected) in cases {
1753                assert_eq!(expr.to_string(), expected, "Failed for: {}", expected);
1754            }
1755
1756            let fail_func = builder().call_extension_fn(
1757                Name::unqualified("notAFunc").unwrap().into(),
1758                vec![builder().val("12.3")],
1759            );
1760            assert!(fail_func.is_err());
1761        }
1762
1763        #[test]
1764        fn test_complex_expressions() {
1765            // Nested binary ops
1766            let nested = builder().is_eq(
1767                builder().add(builder().val(1i64), builder().val(2i64)),
1768                builder().val(3i64),
1769            );
1770            assert_eq!(nested.to_string(), "(1 + 2) == 3");
1771
1772            // Complex if-then-else
1773            let complex = builder().ite(
1774                builder().greater(
1775                    builder().get_attr(builder().var(ast::Var::Principal), SmolStr::from("age")),
1776                    builder().val(18i64),
1777                ),
1778                builder().get_attr(builder().var(ast::Var::Principal), SmolStr::from("name")),
1779                builder().val("unknown"),
1780            );
1781            assert_eq!(
1782                complex.to_string(),
1783                "if ((principal.age) > 18) then (principal.name) else \"unknown\""
1784            );
1785
1786            // isEmpty
1787            let is_empty = builder().is_empty(builder().set([]));
1788            assert_eq!(is_empty.to_string(), "[].isEmpty()");
1789        }
1790
1791        #[test]
1792        fn test_unary_op_display_no_impossible_operator() {
1793            // Test that all UnaryOp variants display without showing "<impossible operator>"
1794            let ops = [
1795                UnaryOp::Not,
1796                UnaryOp::Neg,
1797                UnaryOp::IsEmpty,
1798                UnaryOp::Datetime,
1799                UnaryOp::Decimal,
1800                UnaryOp::Duration,
1801                UnaryOp::Ip,
1802                UnaryOp::IsIPv4,
1803                UnaryOp::IsIPV6,
1804                UnaryOp::IsLoopback,
1805                UnaryOp::IsMulticast,
1806                UnaryOp::ToDate,
1807                UnaryOp::ToTime,
1808                UnaryOp::ToMilliseconds,
1809                UnaryOp::ToSeconds,
1810                UnaryOp::ToMinutes,
1811                UnaryOp::ToHours,
1812                UnaryOp::ToDays,
1813            ];
1814
1815            for op in ops {
1816                let display = op.to_string();
1817                assert_ne!(
1818                    display, "<impossible operator>",
1819                    "UnaryOp::{:?} should not display as impossible operator",
1820                    op
1821                );
1822            }
1823        }
1824
1825        #[test]
1826        fn test_binary_op_display_no_impossible_operator() {
1827            // Test that all BinaryOp variants display without showing "<impossible operator>"
1828            let ops = [
1829                BinaryOp::Eq,
1830                BinaryOp::NotEq,
1831                BinaryOp::Less,
1832                BinaryOp::LessEq,
1833                BinaryOp::Greater,
1834                BinaryOp::GreaterEq,
1835                BinaryOp::And,
1836                BinaryOp::Or,
1837                BinaryOp::Add,
1838                BinaryOp::Sub,
1839                BinaryOp::Mul,
1840                BinaryOp::In,
1841                BinaryOp::Contains,
1842                BinaryOp::ContainsAll,
1843                BinaryOp::ContainsAny,
1844                BinaryOp::GetTag,
1845                BinaryOp::HasTag,
1846                BinaryOp::IsInRange,
1847                BinaryOp::Offset,
1848                BinaryOp::DurationSince,
1849            ];
1850
1851            for op in ops {
1852                let display = op.to_string();
1853                assert_ne!(
1854                    display, "<impossible operator>",
1855                    "BinaryOp::{:?} should not display as impossible operator",
1856                    op
1857                );
1858            }
1859        }
1860    }
1861}