// dcbor_pattern/pattern/meta/or_pattern.rs
use dcbor::prelude::*;

use crate::pattern::{Matcher, Path, Pattern, vm::Instr};
4
5#[derive(Debug, Clone, PartialEq, Eq)]
7pub struct OrPattern(Vec<Pattern>);
8
9impl OrPattern {
10    pub fn new(patterns: Vec<Pattern>) -> Self { OrPattern(patterns) }
12
13    pub fn patterns(&self) -> &[Pattern] { &self.0 }
15}
16
17impl Matcher for OrPattern {
18    fn paths(&self, haystack: &CBOR) -> Vec<Path> {
19        if self.patterns().iter().any(|pattern| pattern.matches(haystack)) {
20            vec![vec![haystack.clone()]]
21        } else {
22            vec![]
23        }
24    }
25
26    fn paths_with_captures(
27        &self,
28        haystack: &CBOR,
29    ) -> (Vec<Path>, std::collections::HashMap<String, Vec<Path>>) {
30        let mut all_paths = Vec::new();
31        let mut all_captures = std::collections::HashMap::new();
32
33        for pattern in self.patterns() {
35            let (paths, captures) = pattern.paths_with_captures(haystack);
36            all_paths.extend(paths);
37
38            for (name, capture_paths) in captures {
40                all_captures
41                    .entry(name)
42                    .or_insert_with(Vec::new)
43                    .extend(capture_paths);
44            }
45        }
46
47        (all_paths, all_captures)
48    }
49
50    fn compile(
52        &self,
53        code: &mut Vec<Instr>,
54        lits: &mut Vec<Pattern>,
55        captures: &mut Vec<String>,
56    ) {
57        if self.patterns().is_empty() {
58            return;
59        }
60
61        let mut splits = Vec::new();
63
64        for _ in 0..self.patterns().len() - 1 {
66            splits.push(code.len());
67            code.push(Instr::Split { a: 0, b: 0 }); }
69
70        for (i, pattern) in self.patterns().iter().enumerate() {
72            let pattern_start = code.len();
73
74            pattern.compile(code, lits, captures);
76
77            let jump_past_all = code.len();
79            code.push(Instr::Jump(0)); if i < self.patterns().len() - 1 {
83                let next_pattern = code.len();
84                code[splits[i]] =
85                    Instr::Split { a: pattern_start, b: next_pattern };
86            }
87
88            splits.push(jump_past_all);
90        }
91
92        let past_all = code.len();
94        for &jump in &splits[self.patterns().len() - 1..] {
95            code[jump] = Instr::Jump(past_all);
96        }
97    }
98
99    fn collect_capture_names(&self, names: &mut Vec<String>) {
100        for pattern in self.patterns() {
102            pattern.collect_capture_names(names);
103        }
104    }
105
106    fn is_complex(&self) -> bool {
107        self.patterns().len() > 1
110            || self.patterns().iter().any(|p| p.is_complex())
111    }
112}
113
114impl std::fmt::Display for OrPattern {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        write!(
117            f,
118            "{}",
119            self.patterns()
120                .iter()
121                .map(|p| p.to_string())
122                .collect::<Vec<_>>()
123                .join(" | ")
124        )
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_or_pattern_display() {
        let alternatives = vec![Pattern::number(5), Pattern::text("hello")];
        let or_pattern = OrPattern::new(alternatives);
        assert_eq!(or_pattern.to_string(), r#"5 | "hello""#);
    }

    #[test]
    fn test_or_pattern_matches_when_any_pattern_matches() {
        let pattern =
            OrPattern::new(vec![Pattern::number(5), Pattern::text("hello")]);

        let five = CBOR::from(5);
        assert!(pattern.matches(&five));

        let hello = CBOR::from("hello");
        assert!(pattern.matches(&hello));
    }

    #[test]
    fn test_or_pattern_fails_when_no_pattern_matches() {
        let pattern =
            OrPattern::new(vec![Pattern::number(5), Pattern::text("hello")]);

        let forty_two = CBOR::from(42);
        assert!(!pattern.matches(&forty_two));

        let world = CBOR::from("world");
        assert!(!pattern.matches(&world));
    }

    #[test]
    fn test_or_pattern_empty_returns_false() {
        let empty = OrPattern::new(vec![]);
        let value = CBOR::from("any value");
        assert!(!empty.matches(&value));
    }
}