dcbor_pattern/pattern/meta/or_pattern.rs

1use dcbor::prelude::*;
2
3use crate::pattern::{Matcher, Path, Pattern, vm::Instr};
4
5/// A pattern that matches if any contained pattern matches.
6#[derive(Debug, Clone, PartialEq, Eq)]
7pub struct OrPattern(Vec<Pattern>);
8
9impl OrPattern {
10    /// Creates a new `OrPattern` with the given patterns.
11    pub fn new(patterns: Vec<Pattern>) -> Self { OrPattern(patterns) }
12
13    /// Returns the patterns contained in this OR pattern.
14    pub fn patterns(&self) -> &[Pattern] { &self.0 }
15}
16
17impl Matcher for OrPattern {
18    fn paths(&self, haystack: &CBOR) -> Vec<Path> {
19        if self.patterns().iter().any(|pattern| pattern.matches(haystack)) {
20            vec![vec![haystack.clone()]]
21        } else {
22            vec![]
23        }
24    }
25
26    fn paths_with_captures(
27        &self,
28        haystack: &CBOR,
29    ) -> (Vec<Path>, std::collections::HashMap<String, Vec<Path>>) {
30        let mut all_paths = Vec::new();
31        let mut all_captures = std::collections::HashMap::new();
32
33        // Try each pattern in the OR group
34        for pattern in self.patterns() {
35            let (paths, captures) = pattern.paths_with_captures(haystack);
36            all_paths.extend(paths);
37
38            // Merge captures
39            for (name, capture_paths) in captures {
40                all_captures
41                    .entry(name)
42                    .or_insert_with(Vec::new)
43                    .extend(capture_paths);
44            }
45        }
46
47        (all_paths, all_captures)
48    }
49
50    /// Compile into byte-code (OR = any can match).
51    fn compile(
52        &self,
53        code: &mut Vec<Instr>,
54        lits: &mut Vec<Pattern>,
55        captures: &mut Vec<String>,
56    ) {
57        if self.patterns().is_empty() {
58            return;
59        }
60
61        // For N patterns: Split(p1, Split(p2, ... Split(pN-1, pN)))
62        let mut splits = Vec::new();
63
64        // Generate splits for all but the last pattern
65        for _ in 0..self.patterns().len() - 1 {
66            splits.push(code.len());
67            code.push(Instr::Split { a: 0, b: 0 }); // Placeholder
68        }
69
70        // Now fill in the actual split targets
71        for (i, pattern) in self.patterns().iter().enumerate() {
72            let pattern_start = code.len();
73
74            // Compile this pattern
75            pattern.compile(code, lits, captures);
76
77            // This pattern will jump to the end if it matches
78            let jump_past_all = code.len();
79            code.push(Instr::Jump(0)); // Placeholder
80
81            // If there's a next pattern, update the split to point here
82            if i < self.patterns().len() - 1 {
83                let next_pattern = code.len();
84                code[splits[i]] =
85                    Instr::Split { a: pattern_start, b: next_pattern };
86            }
87
88            // Will patch this jump once we know where "past all" is
89            splits.push(jump_past_all);
90        }
91
92        // Now patch all the jumps to point past all the patterns
93        let past_all = code.len();
94        for &jump in &splits[self.patterns().len() - 1..] {
95            code[jump] = Instr::Jump(past_all);
96        }
97    }
98
99    fn collect_capture_names(&self, names: &mut Vec<String>) {
100        // Collect captures from all patterns
101        for pattern in self.patterns() {
102            pattern.collect_capture_names(names);
103        }
104    }
105
106    fn is_complex(&self) -> bool {
107        // The pattern is complex if it contains more than one pattern, or if
108        // the one pattern is complex itself.
109        self.patterns().len() > 1
110            || self.patterns().iter().any(|p| p.is_complex())
111    }
112}
113
114impl std::fmt::Display for OrPattern {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        write!(
117            f,
118            "{}",
119            self.patterns()
120                .iter()
121                .map(|p| p.to_string())
122                .collect::<Vec<_>>()
123                .join(" | ")
124        )
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// The two-alternative pattern (`5 | "hello"`) shared by the matching
    /// tests below.
    fn five_or_hello() -> OrPattern {
        OrPattern::new(vec![Pattern::number(5), Pattern::text("hello")])
    }

    #[test]
    fn test_or_pattern_display() {
        let alternatives =
            OrPattern::new(vec![Pattern::number(5), Pattern::text("hello")]);
        assert_eq!(alternatives.to_string(), r#"5 | "hello""#);
    }

    #[test]
    fn test_or_pattern_matches_when_any_pattern_matches() {
        let subject = five_or_hello();
        for value in [CBOR::from(5), CBOR::from("hello")] {
            assert!(subject.matches(&value));
        }
    }

    #[test]
    fn test_or_pattern_fails_when_no_pattern_matches() {
        let subject = five_or_hello();
        // Same kinds as the alternatives, but different values.
        for value in [CBOR::from(42), CBOR::from("world")] {
            assert!(!subject.matches(&value));
        }
    }

    #[test]
    fn test_or_pattern_empty_returns_false() {
        let no_alternatives = OrPattern::new(vec![]);
        assert!(!no_alternatives.matches(&CBOR::from("any value")));
    }
}
170}