dcbor_pattern/pattern/meta/or_pattern.rs

1use dcbor::prelude::*;
2
3use crate::pattern::{Matcher, Path, Pattern, vm::Instr};
4
5/// A pattern that matches if any contained pattern matches.
6#[derive(Debug, Clone, PartialEq, Eq)]
7pub struct OrPattern(Vec<Pattern>);
8
9impl OrPattern {
10    /// Creates a new `OrPattern` with the given patterns.
11    pub fn new(patterns: Vec<Pattern>) -> Self { OrPattern(patterns) }
12
13    /// Returns the patterns contained in this OR pattern.
14    pub fn patterns(&self) -> &[Pattern] { &self.0 }
15}
16
17impl Matcher for OrPattern {
18    fn paths(&self, haystack: &CBOR) -> Vec<Path> {
19        if self
20            .patterns()
21            .iter()
22            .any(|pattern| pattern.matches(haystack))
23        {
24            vec![vec![haystack.clone()]]
25        } else {
26            vec![]
27        }
28    }
29
30    fn paths_with_captures(
31        &self,
32        haystack: &CBOR,
33    ) -> (Vec<Path>, std::collections::HashMap<String, Vec<Path>>) {
34        let mut all_paths = Vec::new();
35        let mut all_captures = std::collections::HashMap::new();
36
37        // Try each pattern in the OR group
38        for pattern in self.patterns() {
39            let (paths, captures) = pattern.paths_with_captures(haystack);
40            all_paths.extend(paths);
41
42            // Merge captures
43            for (name, capture_paths) in captures {
44                all_captures
45                    .entry(name)
46                    .or_insert_with(Vec::new)
47                    .extend(capture_paths);
48            }
49        }
50
51        (all_paths, all_captures)
52    }
53
54    /// Compile into byte-code (OR = any can match).
55    fn compile(
56        &self,
57        code: &mut Vec<Instr>,
58        lits: &mut Vec<Pattern>,
59        captures: &mut Vec<String>,
60    ) {
61        if self.patterns().is_empty() {
62            return;
63        }
64
65        // For N patterns: Split(p1, Split(p2, ... Split(pN-1, pN)))
66        let mut splits = Vec::new();
67
68        // Generate splits for all but the last pattern
69        for _ in 0..self.patterns().len() - 1 {
70            splits.push(code.len());
71            code.push(Instr::Split { a: 0, b: 0 }); // Placeholder
72        }
73
74        // Now fill in the actual split targets
75        for (i, pattern) in self.patterns().iter().enumerate() {
76            let pattern_start = code.len();
77
78            // Compile this pattern
79            pattern.compile(code, lits, captures);
80
81            // This pattern will jump to the end if it matches
82            let jump_past_all = code.len();
83            code.push(Instr::Jump(0)); // Placeholder
84
85            // If there's a next pattern, update the split to point here
86            if i < self.patterns().len() - 1 {
87                let next_pattern = code.len();
88                code[splits[i]] =
89                    Instr::Split { a: pattern_start, b: next_pattern };
90            }
91
92            // Will patch this jump once we know where "past all" is
93            splits.push(jump_past_all);
94        }
95
96        // Now patch all the jumps to point past all the patterns
97        let past_all = code.len();
98        for &jump in &splits[self.patterns().len() - 1..] {
99            code[jump] = Instr::Jump(past_all);
100        }
101    }
102
103    fn collect_capture_names(&self, names: &mut Vec<String>) {
104        // Collect captures from all patterns
105        for pattern in self.patterns() {
106            pattern.collect_capture_names(names);
107        }
108    }
109
110    fn is_complex(&self) -> bool {
111        // The pattern is complex if it contains more than one pattern, or if
112        // the one pattern is complex itself.
113        self.patterns().len() > 1
114            || self.patterns().iter().any(|p| p.is_complex())
115    }
116}
117
118impl std::fmt::Display for OrPattern {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        write!(
121            f,
122            "{}",
123            self.patterns()
124                .iter()
125                .map(|p| p.to_string())
126                .collect::<Vec<_>>()
127                .join(" | ")
128        )
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_or_pattern_display() {
        let pattern1 = Pattern::number(5);
        let pattern2 = Pattern::text("hello");
        let or_pattern = OrPattern::new(vec![pattern1, pattern2]);
        assert_eq!(or_pattern.to_string(), r#"5 | "hello""#);
    }

    #[test]
    fn test_or_pattern_matches_when_any_pattern_matches() {
        let pattern =
            OrPattern::new(vec![Pattern::number(5), Pattern::text("hello")]);

        let cbor_5 = CBOR::from(5);
        assert!(pattern.matches(&cbor_5));

        let cbor_hello = CBOR::from("hello");
        assert!(pattern.matches(&cbor_hello));
    }

    #[test]
    fn test_or_pattern_fails_when_no_pattern_matches() {
        let pattern =
            OrPattern::new(vec![Pattern::number(5), Pattern::text("hello")]);

        let cbor_42 = CBOR::from(42); // Not 5
        assert!(!pattern.matches(&cbor_42));

        let cbor_world = CBOR::from("world"); // Not "hello"
        assert!(!pattern.matches(&cbor_world));
    }

    #[test]
    fn test_or_pattern_empty_returns_false() {
        let pattern = OrPattern::new(vec![]);
        let cbor = CBOR::from("any value");
        assert!(!pattern.matches(&cbor));
    }

    /// With three or more alternatives, a match on the last one must still
    /// be reported, as a single path containing just the haystack.
    #[test]
    fn test_or_pattern_paths_with_three_alternatives() {
        let pattern = OrPattern::new(vec![
            Pattern::number(1),
            Pattern::number(2),
            Pattern::text("three"),
        ]);

        let cbor_three = CBOR::from("three");
        assert_eq!(pattern.paths(&cbor_three), vec![vec![cbor_three.clone()]]);

        let cbor_four = CBOR::from(4);
        assert!(pattern.paths(&cbor_four).is_empty());
    }

    #[test]
    fn test_or_pattern_accessors_and_complexity() {
        let alternatives = vec![Pattern::number(5), Pattern::text("hello")];
        let pattern = OrPattern::new(alternatives.clone());
        assert_eq!(pattern.patterns(), alternatives.as_slice());
        // More than one alternative is always considered complex.
        assert!(pattern.is_complex());

        let empty = OrPattern::new(vec![]);
        assert!(empty.patterns().is_empty());
        assert!(!empty.is_complex());
    }
}
174}