cedar_policy_core/ast/
pattern.rs

1/*
2 * Copyright Cedar Contributors
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      https://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use std::sync::Arc;
18
/// One element of a pattern literal (the right-hand side of a `like` expression).
///
/// An element is either a literal character or the wildcard `*`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum PatternElem {
    /// A literal character
    Char(char),
    /// The wildcard `*`
    Wildcard,
}
28
29/// Represent a pattern literal (the RHS of the like operator)
30/// Also provides an implementation of the Display trait as well as a wildcard matching method.
31#[derive(Debug, Clone, Hash, Eq, PartialEq)]
32pub struct Pattern {
33    /// A vector of pattern elements
34    elems: Arc<Vec<PatternElem>>,
35}
36
37impl Pattern {
38    /// Explicitly create a pattern literal out of a shared vector of pattern elements
39    fn new(elems: Arc<Vec<PatternElem>>) -> Self {
40        Self { elems }
41    }
42
43    /// Getter to the wrapped vector
44    pub fn get_elems(&self) -> &[PatternElem] {
45        &self.elems
46    }
47
48    /// Iterate over pattern elements
49    pub fn iter(&self) -> impl Iterator<Item = &PatternElem> {
50        self.elems.iter()
51    }
52
53    /// Length of elems vector
54    pub fn len(&self) -> usize {
55        self.elems.len()
56    }
57
58    /// Is this an empty pattern
59    pub fn is_empty(&self) -> bool {
60        self.elems.is_empty()
61    }
62}
63
64impl From<Arc<Vec<PatternElem>>> for Pattern {
65    fn from(value: Arc<Vec<PatternElem>>) -> Self {
66        Self::new(value)
67    }
68}
69
70impl From<Vec<PatternElem>> for Pattern {
71    fn from(value: Vec<PatternElem>) -> Self {
72        Self::new(Arc::new(value))
73    }
74}
75
76impl FromIterator<PatternElem> for Pattern {
77    fn from_iter<T: IntoIterator<Item = PatternElem>>(iter: T) -> Self {
78        Self::new(Arc::new(iter.into_iter().collect()))
79    }
80}
81
82impl std::fmt::Display for Pattern {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        for pc in self.elems.as_ref() {
85            match pc {
86                PatternElem::Char('*') => write!(f, r#"\*"#)?,
87                PatternElem::Char(c) => write!(f, "{}", c.escape_debug())?,
88                PatternElem::Wildcard => write!(f, r#"*"#)?,
89            }
90        }
91        Ok(())
92    }
93}
94
95impl PatternElem {
96    fn match_char(self, text_char: char) -> bool {
97        match self {
98            PatternElem::Char(c) => text_char == c,
99            PatternElem::Wildcard => true,
100        }
101    }
102    fn is_wildcard(self) -> bool {
103        matches!(self, PatternElem::Wildcard)
104    }
105}
106
107impl Pattern {
108    /// Find if the argument text matches the pattern
109    pub fn wildcard_match(&self, text: &str) -> bool {
110        let pattern = self.get_elems();
111        if pattern.is_empty() {
112            return text.is_empty();
113        }
114
115        // Copying the strings into vectors requires extra space, but has two benefits:
116        // 1. It makes accessing elements more efficient. The alternative (i.e.,
117        //    chars().nth()) needs to re-scan the string for each invocation. Note
118        //    that a simple iterator will not work here since we move both forward
119        //    and backward through the string.
120        // 2. It provides an unambiguous length. In general for a string s,
121        //    s.len() is not the same as s.chars().count(). The length of these
122        //    created vectors will match .chars().count()
123        let text: Vec<char> = text.chars().collect();
124
125        let mut i: usize = 0; // index into text
126        let mut j: usize = 0; // index into pattern
127        let mut star_idx: usize = 0; // index in pattern (j) of the most recent *
128        let mut tmp_idx: usize = 0; // index in text (i) of the most recent *
129        let mut contains_star: bool = false; // does the pattern contain *?
130
131        let text_len = text.len();
132        let pattern_len = pattern.len();
133
134        while i < text_len && (!contains_star || star_idx != pattern_len - 1) {
135            // PANIC SAFETY `j` is checked to be less than length
136            #[allow(clippy::indexing_slicing)]
137            if j < pattern_len && pattern[j].is_wildcard() {
138                contains_star = true;
139                star_idx = j;
140                tmp_idx = i;
141                j += 1;
142            } else if j < pattern_len && pattern[j].match_char(text[i]) {
143                i += 1;
144                j += 1;
145            } else if contains_star {
146                j = star_idx + 1;
147                i = tmp_idx + 1;
148                tmp_idx = i;
149            } else {
150                return false;
151            }
152        }
153
154        // PANIC SAFETY `j` is checked to be less than length
155        #[allow(clippy::indexing_slicing)]
156        while j < pattern_len && pattern[j].is_wildcard() {
157            j += 1;
158        }
159
160        j == pattern_len
161    }
162}
163
#[cfg(test)]
mod test {
    use super::*;

    // Concatenate two patterns; lets the tests build patterns with `+`
    impl std::ops::Add for Pattern {
        type Output = Pattern;
        fn add(self, rhs: Self) -> Self::Output {
            self.iter().chain(rhs.iter()).copied().collect()
        }
    }

    // Turn every character of `text` into a literal `PatternElem::Char`
    fn string_map(text: &str) -> Pattern {
        let literals: Vec<PatternElem> = text.chars().map(PatternElem::Char).collect();
        Pattern::from(literals)
    }

    // A pattern consisting of a single wildcard
    fn star() -> Pattern {
        std::iter::once(PatternElem::Wildcard).collect()
    }

    // A pattern with no elements
    fn empty() -> Pattern {
        Pattern::from(Vec::new())
    }

    // Assert that every pattern in `patterns` matches `text`
    fn assert_all_match(patterns: &[Pattern], text: &str) {
        for pattern in patterns {
            assert!(
                pattern.wildcard_match(text),
                "pattern `{}` should match {:?}",
                pattern,
                text
            );
        }
    }

    // Assert that no pattern in `patterns` matches `text`
    fn assert_none_match(patterns: &[Pattern], text: &str) {
        for pattern in patterns {
            assert!(
                !pattern.wildcard_match(text),
                "pattern `{}` should not match {:?}",
                pattern,
                text
            );
        }
    }

    #[test]
    fn test_wildcard_match_basic() {
        let foo_bar = "foo bar";

        // Patterns that match "foo bar"
        assert_all_match(
            &[
                string_map("foo") + star(),
                star() + string_map("bar"),
                star() + string_map("o b") + star(),
                string_map("f") + star() + string_map(" bar"),
                string_map("f") + star() + star() + string_map("r"),
                star() + string_map("f") + star() + star() + star(),
            ],
            foo_bar,
        );

        // Patterns that do not match "foo bar"
        assert_none_match(
            &[
                star() + string_map("foo"),
                string_map("bar") + star(),
                star() + string_map("bo") + star(),
                string_map("f") + star() + string_map("br"),
                star() + string_map("x") + star() + star() + star(),
                empty(),
            ],
            foo_bar,
        );

        // Patterns that match ""
        assert_all_match(&[empty(), star()], "");

        // Patterns that do not match ""
        assert_none_match(&[string_map("foo bar")], "");

        // Patterns that match "*"
        assert_all_match(&[string_map("*"), star()], "*");

        // Patterns that do not match "*"
        assert_none_match(
            &[string_map("\u{0000}"), string_map(r"\u{0000}")],
            "*",
        );
    }

    #[test]
    fn test_wildcard_match_unicode() {
        // A base letter followed by a combining mark
        let y_breve = "y̆";

        // Patterns that match `y_breve`
        assert_all_match(
            &[string_map("y") + star(), string_map(y_breve)],
            y_breve,
        );

        // Patterns that do not match `y_breve`
        assert_none_match(&[star() + string_map("p") + star()], y_breve);

        // "example" with several combining marks stacked on each letter
        let zalgo = "ḛ̶͑͝x̶͔͛a̵̰̯͛m̴͉̋́p̷̠͂l̵͇̍̔ȩ̶̣͝";

        // Patterns that match `zalgo`
        assert_all_match(
            &[
                star() + string_map("p") + star(),
                star() + string_map("a̵̰̯͛m̴͉̋́") + star(),
            ],
            zalgo,
        );

        // Patterns that do not match `zalgo`
        assert_none_match(&[string_map("y") + star()], zalgo);
    }
}