cel_parser/references.rs
1use std::collections::HashSet;
2
3use crate::ast::{Expr, IdedExpr};
4
/// A collection of all the references that an expression makes to variables and functions.
///
/// Produced by [`IdedExpr::references`]. The stored names borrow from the expression they
/// were collected from, so a value of this type cannot outlive that expression.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExpressionReferences<'expr> {
    /// Distinct names of the variables (identifiers) referenced by the expression.
    variables: HashSet<&'expr str>,
    /// Distinct names of the functions called by the expression, including operators such
    /// as `_>_`.
    functions: HashSet<&'expr str>,
}
10
11impl ExpressionReferences<'_> {
12    /// Returns true if the expression references the provided variable name.
13    ///
14    /// # Example
15    /// ```rust
16    /// # use cel_parser::Parser;
17    /// let expression = Parser::new().parse("foo.bar == true").unwrap();
18    /// let references = expression.references();
19    /// assert!(references.has_variable("foo"));
20    /// ```
21    pub fn has_variable(&self, name: impl AsRef<str>) -> bool {
22        self.variables.contains(name.as_ref())
23    }
24
25    /// Returns true if the expression references the provided function name.
26    ///
27    /// # Example
28    /// ```rust
29    /// # use cel_parser::Parser;
30    /// let expression = Parser::new().parse("size(foo) > 0").unwrap();
31    /// let references = expression.references();
32    /// assert!(references.has_function("size"));
33    /// ```
34    pub fn has_function(&self, name: impl AsRef<str>) -> bool {
35        self.functions.contains(name.as_ref())
36    }
37
38    /// Returns a list of all variables referenced in the expression.
39    ///
40    /// # Example
41    /// ```rust
42    /// # use cel_parser::Parser;
43    /// let expression = Parser::new().parse("foo.bar == true").unwrap();
44    /// let references = expression.references();
45    /// assert_eq!(vec!["foo"], references.variables());
46    /// ```
47    pub fn variables(&self) -> Vec<&str> {
48        self.variables.iter().copied().collect()
49    }
50
51    /// Returns a list of all functions referenced in the expression.
52    ///
53    /// # Example
54    /// ```rust
55    /// # use cel_parser::Parser;
56    /// let expression = Parser::new().parse("size(foo) > 0").unwrap();
57    /// let references = expression.references();
58    /// assert!(references.functions().contains(&"_>_"));
59    /// assert!(references.functions().contains(&"size"));
60    /// ```
61    pub fn functions(&self) -> Vec<&str> {
62        self.functions.iter().copied().collect()
63    }
64}
65
66impl IdedExpr {
67    /// Returns a set of all variables and functions referenced in the expression.
68    ///
69    /// # Example
70    /// ```rust
71    /// # use cel_parser::Parser;
72    /// let expression = Parser::new().parse("foo && size(foo) > 0").unwrap();
73    /// let references = expression.references();
74    ///
75    /// assert!(references.has_variable("foo"));
76    /// assert!(references.has_function("size"));
77    /// ```
78    pub fn references(&self) -> ExpressionReferences {
79        let mut variables = HashSet::new();
80        let mut functions = HashSet::new();
81        self._references(&mut variables, &mut functions);
82        ExpressionReferences {
83            variables,
84            functions,
85        }
86    }
87
88    /// Internal recursive function to collect all variable and function references in the expression.
89    fn _references<'expr>(
90        &'expr self,
91        variables: &mut HashSet<&'expr str>,
92        functions: &mut HashSet<&'expr str>,
93    ) {
94        match &self.expr {
95            Expr::Unspecified => {}
96            Expr::Call(call) => {
97                functions.insert(&call.func_name);
98                if let Some(target) = &call.target {
99                    target._references(variables, functions);
100                }
101                for arg in &call.args {
102                    arg._references(variables, functions);
103                }
104            }
105            Expr::Comprehension(comp) => {
106                comp.iter_range._references(variables, functions);
107                comp.accu_init._references(variables, functions);
108                comp.loop_cond._references(variables, functions);
109                comp.loop_step._references(variables, functions);
110                comp.result._references(variables, functions);
111            }
112            Expr::Ident(name) => {
113                // todo! Might want to make this "smarter" (are we in a comprehension?) and better encode these in const
114                if !name.starts_with('@') {
115                    variables.insert(name);
116                }
117            }
118            Expr::List(list) => {
119                for elem in &list.elements {
120                    elem._references(variables, functions);
121                }
122            }
123            Expr::Literal(_) => {}
124            Expr::Map(map) => {
125                for entry in &map.entries {
126                    match &entry.expr {
127                        crate::ast::EntryExpr::StructField(field) => {
128                            field.value._references(variables, functions);
129                        }
130                        crate::ast::EntryExpr::MapEntry(map_entry) => {
131                            map_entry.key._references(variables, functions);
132                            map_entry.value._references(variables, functions);
133                        }
134                    }
135                }
136            }
137            Expr::Select(select) => {
138                select.operand._references(variables, functions);
139            }
140            Expr::Struct(struct_expr) => {
141                for entry in &struct_expr.entries {
142                    match &entry.expr {
143                        crate::ast::EntryExpr::StructField(field) => {
144                            field.value._references(variables, functions);
145                        }
146                        crate::ast::EntryExpr::MapEntry(map_entry) => {
147                            map_entry.key._references(variables, functions);
148                            map_entry.value._references(variables, functions);
149                        }
150                    }
151                }
152            }
153        }
154    }
155}