ib_matcher/syntax/regex/hir/
fold.rs

1use std::iter;
2
3use regex_syntax::{
4    hir::{Hir, HirKind},
5    Error,
6};
7
8pub fn parse_and_fold_literal(
9    pattern: &str,
10) -> Result<(Hir, Vec<Box<[u8]>>), Error> {
11    let (mut hirs, literals) =
12        fold_literal(iter::once(regex_syntax::parse(pattern)?));
13    Ok((hirs.pop().unwrap(), literals))
14}
15
16pub fn parse_and_fold_literal_utf8(
17    pattern: &str,
18) -> Result<(Hir, Vec<String>), Error> {
19    let (mut hirs, literals) =
20        fold_literal_utf8(iter::once(regex_syntax::parse(pattern)?));
21    Ok((hirs.pop().unwrap(), literals))
22}
23
24/// Fold the first 256 literals into single byte literals.
25pub fn fold_literal(
26    hirs: impl Iterator<Item = Hir>,
27) -> (Vec<Hir>, Vec<Box<[u8]>>) {
28    fold_literal_common(hirs, Ok)
29}
30
31/// Fold the first 256 UTF-8 literals into single byte literals.
32pub fn fold_literal_utf8(
33    hirs: impl Iterator<Item = Hir>,
34) -> (Vec<Hir>, Vec<String>) {
35    fold_literal_common(hirs, |b| String::from_utf8(b.to_vec()).map_err(|_| b))
36}
37
38fn fold_literal_common<T>(
39    hirs: impl Iterator<Item = Hir>,
40    try_into: impl Fn(Box<[u8]>) -> Result<T, Box<[u8]>>,
41) -> (Vec<Hir>, Vec<T>) {
42    fn fold_literal<T>(
43        hir: Hir,
44        literals: &mut Vec<T>,
45        f: &impl Fn(Box<[u8]>) -> Result<T, Box<[u8]>>,
46    ) -> Hir {
47        match hir.kind() {
48            HirKind::Empty | HirKind::Class(_) | HirKind::Look(_) => hir,
49            HirKind::Literal(_) => {
50                let i = literals.len();
51                if i > u8::MAX as usize {
52                    // Too many literals
53                    return hir;
54                }
55
56                let literal = match hir.into_kind() {
57                    HirKind::Literal(literal) => literal,
58                    _ => unreachable!(),
59                };
60                match f(literal.0) {
61                    Ok(literal) => {
62                        literals.push(literal);
63                        // maximum_len is only used by meta
64                        // minimum_len is also used by c_at_least(), but only to test > 0
65                        // utf8 is not used
66                        Hir::literal([i as u8])
67                    }
68                    Err(literal) => Hir::literal(literal),
69                }
70            }
71            HirKind::Repetition(_) => {
72                let mut repetition = match hir.into_kind() {
73                    HirKind::Repetition(repetition) => repetition,
74                    _ => unreachable!(),
75                };
76                repetition.sub =
77                    fold_literal(*repetition.sub, literals, f).into();
78                Hir::repetition(repetition)
79            }
80            HirKind::Capture(_) => {
81                let mut capture = match hir.into_kind() {
82                    HirKind::Capture(capture) => capture,
83                    _ => unreachable!(),
84                };
85                capture.sub = fold_literal(*capture.sub, literals, f).into();
86                Hir::capture(capture)
87            }
88            HirKind::Concat(_) => {
89                let subs = match hir.into_kind() {
90                    HirKind::Concat(subs) => subs,
91                    _ => unreachable!(),
92                }
93                .into_iter()
94                .map(|sub| fold_literal(sub, literals, f))
95                .collect();
96                Hir::concat(subs)
97            }
98            HirKind::Alternation(_) => {
99                // let all_literal = subs
100                //     .iter()
101                //     .all(|sub| matches!(sub.kind(), HirKind::Literal(_)));
102                let all_literal = hir.properties().is_alternation_literal();
103                let it = match hir.into_kind() {
104                    HirKind::Alternation(subs) => subs,
105                    _ => unreachable!(),
106                }
107                .into_iter()
108                .map(|sub| fold_literal(sub, literals, f));
109                let subs = if all_literal {
110                    // Bypass Hir::alternation() and c_alt_slice()
111                    it.chain(iter::once(Hir::fail())).collect()
112                } else {
113                    it.collect()
114                };
115                Hir::alternation(subs)
116            }
117        }
118    }
119    let mut literals = Vec::new();
120    (
121        hirs.map(|hir| fold_literal(hir, &mut literals, &try_into)).collect(),
122        literals,
123    )
124}
125
#[cfg(test)]
mod tests {
    use regex_syntax::{hir::Hir, parse};

    use super::*;

    /// Builds the one-byte literal that stands in for table entry `index`.
    fn placeholder(index: u8) -> Hir {
        Hir::literal([index])
    }

    #[test]
    fn fold_literal_test() {
        // A pattern consisting of one literal becomes placeholder 0.
        let (folded, table) = parse_and_fold_literal_utf8("abc").unwrap();
        assert_eq!(folded, placeholder(0));
        assert_eq!(table, vec![String::from("abc")]);

        // Literals are numbered in the order they appear; everything that is
        // not a literal is kept as parsed.
        let (folded, table) =
            parse_and_fold_literal_utf8("abc.*def").unwrap();
        let expected = Hir::concat(vec![
            placeholder(0),
            parse(".*").unwrap(),
            placeholder(1),
        ]);
        assert_eq!(folded, expected);
        assert_eq!(table, vec![String::from("abc"), String::from("def")]);
    }
}