use dcbor::prelude::*;
use crate::pattern::{Matcher, Path, Pattern, vm::Instr};
/// A pattern that matches when any one of its alternative sub-patterns
/// matches; displayed as `a | b | …`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrPattern(Vec<Pattern>);

impl OrPattern {
    /// Creates an alternation over `patterns`, keeping them in the given order.
    pub fn new(patterns: Vec<Pattern>) -> Self {
        Self(patterns)
    }

    /// Returns the alternatives in the order they were supplied to [`OrPattern::new`].
    pub fn patterns(&self) -> &[Pattern] {
        self.0.as_slice()
    }
}
impl Matcher for OrPattern {
    /// Returns the haystack itself as a single one-element path when at least
    /// one alternative matches it, and no paths otherwise.
    ///
    /// An empty alternation never matches.
    fn paths(&self, haystack: &CBOR) -> Vec<Path> {
        if self
            .patterns()
            .iter()
            .any(|pattern| pattern.matches(haystack))
        {
            vec![vec![haystack.clone()]]
        } else {
            vec![]
        }
    }

    /// Runs every alternative against `haystack` and concatenates the results.
    ///
    /// Paths from all alternatives are appended in pattern order, and captures
    /// sharing a name are merged into one list. There is no short-circuiting
    /// or de-duplication (unlike `paths`, which yields at most one path), so a
    /// haystack matched by several alternatives yields paths from each of them.
    fn paths_with_captures(
        &self,
        haystack: &CBOR,
    ) -> (Vec<Path>, std::collections::HashMap<String, Vec<Path>>) {
        let mut all_paths = Vec::new();
        let mut all_captures = std::collections::HashMap::new();
        for pattern in self.patterns() {
            let (paths, captures) = pattern.paths_with_captures(haystack);
            all_paths.extend(paths);
            for (name, capture_paths) in captures {
                all_captures
                    .entry(name)
                    .or_default()
                    .extend(capture_paths);
            }
        }
        (all_paths, all_captures)
    }

    /// Compiles the alternation into VM code as a chain of splits.
    ///
    /// For alternatives `p0 | p1 | … | pN` the emitted layout is:
    ///
    /// ```text
    /// L0:  Split(P0, L1)
    /// P0:  <p0>
    ///      Jump(END)
    /// L1:  Split(P1, L2)
    /// P1:  <p1>
    ///      Jump(END)
    ///      …
    /// PN:  <pN>
    /// END:
    /// ```
    ///
    /// Each split forks between one alternative and the rest of the chain, so
    /// every alternative is reachable. An empty alternation emits nothing.
    fn compile(
        &self,
        code: &mut Vec<Instr>,
        lits: &mut Vec<Pattern>,
        captures: &mut Vec<String>,
    ) {
        // NOTE(review): emitting no code makes an empty alternation a no-op
        // step inside the VM, whereas `paths` reports no match for it. Making
        // the two agree would need a dedicated fail instruction — confirm
        // whether the VM provides one.
        let (last, init) = match self.patterns().split_last() {
            Some(parts) => parts,
            None => return,
        };

        // Indices of placeholder jumps to patch once END is known.
        let mut jumps_to_end = Vec::with_capacity(init.len());

        for pattern in init {
            let split = code.len();
            code.push(Instr::Split { a: 0, b: 0 }); // patched below

            let start = code.len();
            pattern.compile(code, lits, captures);

            jumps_to_end.push(code.len());
            code.push(Instr::Jump(0)); // patched to END below

            // The other branch must continue at whatever is emitted next:
            // the following split, or the final alternative's body. Pointing
            // it past the following split would break the chain and leave
            // every alternative after the second unreachable.
            let rest = code.len();
            code[split] = Instr::Split { a: start, b: rest };
        }

        // The final alternative needs no split; it falls through to END.
        last.compile(code, lits, captures);

        let end = code.len();
        for jump in jumps_to_end {
            code[jump] = Instr::Jump(end);
        }
    }

    /// Appends the capture names of every alternative, in pattern order.
    fn collect_capture_names(&self, names: &mut Vec<String>) {
        for pattern in self.patterns() {
            pattern.collect_capture_names(names);
        }
    }

    /// Reports the alternation as complex when it has more than one
    /// alternative or when any alternative is itself complex.
    fn is_complex(&self) -> bool {
        self.patterns().len() > 1
            || self.patterns().iter().any(|p| p.is_complex())
    }
}
impl std::fmt::Display for OrPattern {
    /// Renders each alternative with its own `Display` form, joined by
    /// `" | "` (for example `5 | "hello"`). An empty alternation renders as
    /// an empty string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let alternatives: Vec<String> =
            self.patterns().iter().map(ToString::to_string).collect();
        let rendered = alternatives.join(" | ");
        f.write_str(&rendered)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `5 | "hello"` alternation shared by the tests below.
    fn five_or_hello() -> OrPattern {
        OrPattern::new(vec![Pattern::number(5), Pattern::text("hello")])
    }

    #[test]
    fn test_or_pattern_display() {
        let alternation = five_or_hello();
        assert_eq!(alternation.to_string(), r#"5 | "hello""#);
    }

    #[test]
    fn test_or_pattern_matches_when_any_pattern_matches() {
        let alternation = five_or_hello();
        for value in [CBOR::from(5), CBOR::from("hello")] {
            assert!(alternation.matches(&value));
        }
    }

    #[test]
    fn test_or_pattern_fails_when_no_pattern_matches() {
        let alternation = five_or_hello();
        for value in [CBOR::from(42), CBOR::from("world")] {
            assert!(!alternation.matches(&value));
        }
    }

    #[test]
    fn test_or_pattern_empty_returns_false() {
        let empty = OrPattern::new(Vec::new());
        assert!(!empty.matches(&CBOR::from("any value")));
    }
}