Skip to main content

luna_core/
pattern.rs

1//! Lua pattern matching engine — a faithful port of lstrlib.c's matcher.
2//! Pure functions over byte slices (stone candidate: no runtime types).
3
/// Maximum number of captures one match may open; one more reports
/// "too many captures".
const MAX_CAPTURES: usize = 32;
/// Recursion budget for `do_match`; going deeper reports "pattern too complex".
const MAX_DEPTH: u32 = 220;
6
/// A single capture recorded by a successful pattern match.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Cap {
    /// Text capture covering the half-open byte range `start..end` of the source.
    Span(usize, usize),
    /// Position capture `()`: the 0-based byte offset where it was taken
    /// (Lua reports positions 1-based, so callers add 1).
    Pos(usize),
}
15
/// Failure reported by the pattern matcher: a malformed pattern (bad `%`
/// escape, unclosed `[set]`, bad `%b`/`%f`), an invalid capture reference,
/// too many captures, or recursion past the depth limit.
#[derive(Debug)]
pub struct PatError(
    /// Human-readable description of what went wrong.
    pub String,
);
23
24/// A successful match against a Lua pattern, with the captures it produced.
25pub struct Match {
26    /// whole-match span [start, end)
27    pub start: usize,
28    /// End offset of the whole match (exclusive).
29    pub end: usize,
30    /// Captured spans / positions, in pattern order.
31    pub caps: Vec<Cap>,
32}
33
/// Mutable bookkeeping for a single match attempt (PUC `MatchState`).
struct State<'a> {
    /// Subject bytes being matched.
    src: &'a [u8],
    /// Pattern body, with any leading `^` already removed by the caller.
    pat: &'a [u8],
    /// One `(start, len)` entry per opened capture. The `len` slot holds
    /// either a real length or one of the sentinels `CAP_UNFINISHED` /
    /// `CAP_POSITION`.
    caps: Vec<(usize, isize)>,
    /// Current `do_match` nesting depth, capped by `MAX_DEPTH`.
    depth: u32,
}
40
/// Length sentinel: the capture's `(` has been seen but not its `)`.
const CAP_UNFINISHED: isize = -1;
/// Length sentinel: a `()` position capture, which records no text.
const CAP_POSITION: isize = -2;
43
/// Detach a leading `^` anchor from a pattern.
///
/// Returns `(anchored, body)`: `anchored` tells whether the first byte was
/// `^`, and `body` is the remaining pattern. Only the first byte is
/// considered, so `"^^a"` yields `(true, "^a")`. The caller decides what the
/// anchor means (find/match try a single position; gsub/gmatch stop after the
/// first position).
pub fn anchor_split(pat: &[u8]) -> (bool, &[u8]) {
    if let Some(body) = pat.strip_prefix(&[b'^']) {
        (true, body)
    } else {
        (false, pat)
    }
}
53
54/// Try to match `pat_body` (already `^`-stripped) at exactly position `s`,
55/// with no forward scan. Returns the Match (whose `start == s`) or None.
56pub fn match_at(src: &[u8], pat_body: &[u8], s: usize) -> Result<Option<Match>, PatError> {
57    let mut st = State {
58        src,
59        pat: pat_body,
60        caps: Vec::new(),
61        depth: 0,
62    };
63    let Some(e) = do_match(&mut st, s, 0)? else {
64        return Ok(None);
65    };
66    if st.caps.iter().any(|&(_, l)| l == CAP_UNFINISHED) {
67        return Err(PatError("unfinished capture".into()));
68    }
69    let caps = st
70        .caps
71        .iter()
72        .map(|&(cs, cl)| {
73            if cl == CAP_POSITION {
74                Cap::Pos(cs)
75            } else {
76                Cap::Span(cs, cs + cl as usize)
77            }
78        })
79        .collect();
80    Ok(Some(Match {
81        start: s,
82        end: e,
83        caps,
84    }))
85}
86
87/// Scan from `init` for the first match (PUC str_find_aux without the plain
88/// fast path). A leading `^` anchors the search to `init`.
89pub fn find(src: &[u8], pat: &[u8], init: usize) -> Result<Option<Match>, PatError> {
90    if init > src.len() {
91        return Ok(None);
92    }
93    let (anchor, pat_body) = anchor_split(pat);
94    let mut s = init;
95    loop {
96        if let Some(m) = match_at(src, pat_body, s)? {
97            return Ok(Some(m));
98        }
99        if anchor || s >= src.len() {
100            return Ok(None);
101        }
102        s += 1;
103    }
104}
105
/// Whether byte `c` belongs to the `%cl` class (PUC `match_class`).
///
/// Lowercase class letters (`a c d g l p s u w x z`) select a class, and the
/// uppercase letter selects its complement. Any other `cl` is an escaped
/// literal (`%%`, `%.`, ...) and matches only itself.
fn class_match(c: u8, cl: u8) -> bool {
    /// Membership of `byte` in lowercase class `letter`, or `None` when
    /// `letter` does not name a class.
    fn in_class(letter: u8, byte: u8) -> Option<bool> {
        Some(match letter {
            b'a' => byte.is_ascii_alphabetic(),
            b'c' => byte.is_ascii_control(),
            b'd' => byte.is_ascii_digit(),
            b'g' => byte.is_ascii_graphic(),
            b'l' => byte.is_ascii_lowercase(),
            b'p' => byte.is_ascii_punctuation(),
            // C isspace set; `is_ascii_whitespace` would miss \v (0x0B).
            b's' => matches!(byte, b' ' | b'\t' | b'\n' | 0x0B | 0x0C | b'\r'),
            b'u' => byte.is_ascii_uppercase(),
            b'w' => byte.is_ascii_alphanumeric(),
            b'x' => byte.is_ascii_hexdigit(),
            // the \0 class, kept by PUC for compatibility
            b'z' => byte == 0,
            _ => return None,
        })
    }

    match in_class(cl.to_ascii_lowercase(), c) {
        // An uppercase class letter flips the result.
        Some(hit) => hit != cl.is_ascii_uppercase(),
        None => c == cl,
    }
}
123
124/// `[set]` matching; `pp` points at the '[' position, `ep` one past ']'.
125fn match_bracket(c: u8, pat: &[u8], pp: usize, ep: usize) -> bool {
126    let mut p = pp + 1;
127    let mut neg = false;
128    if pat.get(p) == Some(&b'^') {
129        neg = true;
130        p += 1;
131    }
132    let mut found = false;
133    while p < ep - 1 {
134        if pat[p] == b'%' && p + 1 < ep - 1 {
135            p += 1;
136            if class_match(c, pat[p]) {
137                found = true;
138            }
139            p += 1;
140        } else if p + 2 < ep - 1 && pat[p + 1] == b'-' {
141            if pat[p] <= c && c <= pat[p + 2] {
142                found = true;
143            }
144            p += 3;
145        } else {
146            if pat[p] == c {
147                found = true;
148            }
149            p += 1;
150        }
151    }
152    found != neg
153}
154
155/// One past the end of the class starting at `p` (PUC classEnd).
156fn class_end(st: &State, p: usize) -> Result<usize, PatError> {
157    let pat = st.pat;
158    match pat.get(p) {
159        None => Err(PatError("malformed pattern (ends with '%')".into())),
160        Some(b'%') => {
161            if p + 1 >= pat.len() {
162                return Err(PatError("malformed pattern (ends with '%')".into()));
163            }
164            Ok(p + 2)
165        }
166        Some(b'[') => {
167            // PUC classEnd: do-while consumes one char before checking ']',
168            // so a ']' right after '[' or '[^' is a literal set member
169            let mut q = p + 1;
170            if pat.get(q) == Some(&b'^') {
171                q += 1;
172            }
173            loop {
174                if q >= pat.len() {
175                    return Err(PatError("malformed pattern (missing ']')".into()));
176                }
177                let c = pat[q];
178                q += 1;
179                if c == b'%' {
180                    if q >= pat.len() {
181                        return Err(PatError("malformed pattern (ends with '%')".into()));
182                    }
183                    q += 1;
184                }
185                if pat.get(q) == Some(&b']') {
186                    return Ok(q + 1);
187                }
188            }
189        }
190        Some(_) => Ok(p + 1),
191    }
192}
193
194fn single_match(st: &State, s: usize, p: usize, ep: usize) -> bool {
195    let Some(&c) = st.src.get(s) else {
196        return false;
197    };
198    match st.pat[p] {
199        b'.' => true,
200        b'%' => class_match(c, st.pat[p + 1]),
201        b'[' => match_bracket(c, st.pat, p, ep),
202        pc => pc == c,
203    }
204}
205
206fn capture_to_close(st: &State) -> Result<usize, PatError> {
207    for i in (0..st.caps.len()).rev() {
208        if st.caps[i].1 == CAP_UNFINISHED {
209            return Ok(i);
210        }
211    }
212    Err(PatError("invalid pattern capture".into()))
213}
214
215fn do_match(st: &mut State, mut s: usize, mut p: usize) -> Result<Option<usize>, PatError> {
216    st.depth += 1;
217    if st.depth > MAX_DEPTH {
218        st.depth -= 1;
219        return Err(PatError("pattern too complex".into()));
220    }
221    let r = do_match_inner(st, &mut s, &mut p);
222    st.depth -= 1;
223    r
224}
225
226fn do_match_inner(st: &mut State, s: &mut usize, p: &mut usize) -> Result<Option<usize>, PatError> {
227    loop {
228        if *p >= st.pat.len() {
229            return Ok(Some(*s));
230        }
231        match st.pat[*p] {
232            b'(' => {
233                // position capture or start capture
234                return if st.pat.get(*p + 1) == Some(&b')') {
235                    if st.caps.len() >= MAX_CAPTURES {
236                        return Err(PatError("too many captures".into()));
237                    }
238                    st.caps.push((*s, CAP_POSITION));
239                    let r = do_match(st, *s, *p + 2)?;
240                    if r.is_none() {
241                        st.caps.pop();
242                    }
243                    Ok(r)
244                } else {
245                    if st.caps.len() >= MAX_CAPTURES {
246                        return Err(PatError("too many captures".into()));
247                    }
248                    st.caps.push((*s, CAP_UNFINISHED));
249                    let r = do_match(st, *s, *p + 1)?;
250                    if r.is_none() {
251                        st.caps.pop();
252                    }
253                    Ok(r)
254                };
255            }
256            b')' => {
257                let i = capture_to_close(st)?;
258                st.caps[i].1 = (*s - st.caps[i].0) as isize;
259                let r = do_match(st, *s, *p + 1)?;
260                if r.is_none() {
261                    st.caps[i].1 = CAP_UNFINISHED;
262                }
263                return Ok(r);
264            }
265            b'$' if *p + 1 == st.pat.len() => {
266                return Ok(if *s == st.src.len() { Some(*s) } else { None });
267            }
268            b'%' => match st.pat.get(*p + 1) {
269                Some(b'b') => {
270                    // balanced match %bxy
271                    let (Some(&x), Some(&y)) = (st.pat.get(*p + 2), st.pat.get(*p + 3)) else {
272                        return Err(PatError(
273                            "malformed pattern (missing arguments to '%b')".into(),
274                        ));
275                    };
276                    if st.src.get(*s) != Some(&x) {
277                        return Ok(None);
278                    }
279                    let mut bal = 1i32;
280                    let mut q = *s + 1;
281                    while q < st.src.len() {
282                        if st.src[q] == y {
283                            bal -= 1;
284                            if bal == 0 {
285                                return do_match(st, q + 1, *p + 4);
286                            }
287                        } else if st.src[q] == x {
288                            bal += 1;
289                        }
290                        q += 1;
291                    }
292                    return Ok(None);
293                }
294                Some(b'f') => {
295                    // frontier %f[set]
296                    if st.pat.get(*p + 2) != Some(&b'[') {
297                        return Err(PatError("missing '[' after '%f' in pattern".into()));
298                    }
299                    let ep = class_end(st, *p + 2)?;
300                    let prev = if *s == 0 { 0u8 } else { st.src[*s - 1] };
301                    let cur = st.src.get(*s).copied().unwrap_or(0);
302                    if !match_bracket(prev, st.pat, *p + 2, ep)
303                        && match_bracket(cur, st.pat, *p + 2, ep)
304                    {
305                        *p = ep;
306                        continue;
307                    }
308                    return Ok(None);
309                }
310                Some(&d @ b'0'..=b'9') => {
311                    // back-reference %1..%9 (%0 is invalid; guard before the
312                    // `d - b'1'` subtraction so it cannot underflow)
313                    if d == b'0' {
314                        return Err(PatError("invalid capture index %0".into()));
315                    }
316                    let idx = (d - b'1') as usize;
317                    if idx >= st.caps.len() || st.caps[idx].1 < 0 {
318                        return Err(PatError(format!("invalid capture index %{}", (d - b'0'))));
319                    }
320                    let (cs, cl) = st.caps[idx];
321                    let cl = cl as usize;
322                    if st.src.len() - *s >= cl && st.src[cs..cs + cl] == st.src[*s..*s + cl] {
323                        *s += cl;
324                        *p += 2;
325                        continue;
326                    }
327                    return Ok(None);
328                }
329                _ => { /* fall through to the default single-match path */ }
330            },
331            _ => {}
332        }
333        // default: single char class, possibly quantified
334        let ep = class_end(st, *p)?;
335        match st.pat.get(ep) {
336            Some(b'?') => {
337                if single_match(st, *s, *p, ep)
338                    && let Some(r) = do_match(st, *s + 1, ep + 1)?
339                {
340                    return Ok(Some(r));
341                }
342                *p = ep + 1;
343                continue;
344            }
345            Some(b'+') => {
346                return if single_match(st, *s, *p, ep) {
347                    max_expand(st, *s + 1, *p, ep)
348                } else {
349                    Ok(None)
350                };
351            }
352            Some(b'*') => {
353                return max_expand(st, *s, *p, ep);
354            }
355            Some(b'-') => {
356                return min_expand(st, *s, *p, ep);
357            }
358            _ => {
359                if single_match(st, *s, *p, ep) {
360                    *s += 1;
361                    *p = ep;
362                    continue;
363                }
364                return Ok(None);
365            }
366        }
367    }
368}
369
370fn max_expand(st: &mut State, s: usize, p: usize, ep: usize) -> Result<Option<usize>, PatError> {
371    let mut i = 0;
372    while single_match(st, s + i, p, ep) {
373        i += 1;
374    }
375    loop {
376        if let Some(r) = do_match(st, s + i, ep + 1)? {
377            return Ok(Some(r));
378        }
379        if i == 0 {
380            return Ok(None);
381        }
382        i -= 1;
383    }
384}
385
386fn min_expand(
387    st: &mut State,
388    mut s: usize,
389    p: usize,
390    ep: usize,
391) -> Result<Option<usize>, PatError> {
392    loop {
393        if let Some(r) = do_match(st, s, ep + 1)? {
394            return Ok(Some(r));
395        }
396        if single_match(st, s, p, ep) {
397            s += 1;
398        } else {
399            return Ok(None);
400        }
401    }
402}
403
/// Whether `pat` contains any byte that is special to the pattern matcher,
/// i.e. whether a gsub/find plain-text fast path must be skipped.
///
/// The checked set is `^ $ * + ? . ( ) [ ] % -`.
/// NOTE(review): PUC's `nospecials` (used only by `string.find`) checks
/// `^$*+?.([%-`, without `)` and `]`. PUC therefore plain-searches
/// `find("a)b", ")")`, whereas routing that pattern through the matcher
/// raises "invalid pattern capture". Confirm what each caller needs before
/// narrowing this set; gsub must keep `)`.
pub fn has_specials(pat: &[u8]) -> bool {
    const SPECIALS: &[u8] = b"^$*+?.()[]%-";
    pat.iter().any(|byte| SPECIALS.contains(byte))
}
414
/// Byte-wise substring search for `needle` in `hay`, starting at `init`
/// (`string.find` with `plain = true`).
///
/// Returns the absolute offset of the first occurrence at or after `init`.
/// An empty needle matches at `init` itself. An `init` past the end of `hay`
/// finds nothing.
pub fn plain_find(hay: &[u8], needle: &[u8], init: usize) -> Option<usize> {
    let tail = hay.get(init..)?;
    if needle.is_empty() {
        return Some(init);
    }
    let offset = tail
        .windows(needle.len())
        .position(|window| window == needle)?;
    Some(init + offset)
}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Whole-match span of the first match of `pat` in `src`, scanning from 0.
    fn m(src: &str, pat: &str) -> Option<(usize, usize)> {
        find(src.as_bytes(), pat.as_bytes(), 0)
            .unwrap()
            .map(|m| (m.start, m.end))
    }

    #[test]
    fn basics() {
        assert_eq!(m("hello", "l+"), Some((2, 4)));
        assert_eq!(m("hello", "^h"), Some((0, 1)));
        assert_eq!(m("hello", "^e"), None);
        assert_eq!(m("hello", "o$"), Some((4, 5)));
        assert_eq!(m("hello", "%a+"), Some((0, 5)));
        assert_eq!(m("a1b2", "%d"), Some((1, 2)));
        assert_eq!(m("abc", "a.c"), Some((0, 3)));
        assert_eq!(m("", ".*"), Some((0, 0)));
        assert_eq!(m("abc", "x*"), Some((0, 0)));
    }

    #[test]
    fn sets_and_quantifiers() {
        assert_eq!(m("hello world", "[aeiou]"), Some((1, 2)));
        assert_eq!(m("hello", "[^aeiou]+"), Some((0, 1)));
        assert_eq!(m("x123y", "[0-9]+"), Some((1, 4)));
        assert_eq!(m("aaa", "a-"), Some((0, 0)));
        assert_eq!(m("<a><b>", "<.->"), Some((0, 3)));
        assert_eq!(m("<a><b>", "<.*>"), Some((0, 6)));
        assert_eq!(m("abc", "ab?c"), Some((0, 3)));
        assert_eq!(m("ac", "ab?c"), Some((0, 2)));
        assert_eq!(m("aaa", "a-$"), Some((0, 3)));
    }

    #[test]
    fn bracket_edge_cases() {
        // a ']' right after '[' is a literal member of the set
        assert_eq!(m("a]b", "[]]"), Some((1, 2)));
        // an escaped '-' is a literal, not a range
        assert_eq!(m("x-y", "[%-]"), Some((1, 2)));
        assert_eq!(m("abc", "[^a]"), Some((1, 2)));
    }

    #[test]
    fn captures_and_specials() {
        let mm = find(b"key=value", b"(%w+)=(%w+)", 0).unwrap().unwrap();
        assert_eq!(mm.caps.len(), 2);
        assert_eq!(mm.caps[0], Cap::Span(0, 3));
        assert_eq!(mm.caps[1], Cap::Span(4, 9));
        // position capture
        let mm = find(b"abc", b"a()b", 0).unwrap().unwrap();
        assert_eq!(mm.caps[0], Cap::Pos(1));
        // position capture at the very end of the subject
        let mm = find(b"ab", b"()$", 0).unwrap().unwrap();
        assert_eq!(mm.caps, vec![Cap::Pos(2)]);
        // lazy capture between greedy whitespace runs (trim idiom)
        let mm = find(b"  x  ", b"^%s*(.-)%s*$", 0).unwrap().unwrap();
        assert_eq!((mm.start, mm.end), (0, 5));
        assert_eq!(mm.caps, vec![Cap::Span(2, 3)]);
        // balanced, including two in a row
        assert_eq!(m("(foo(bar))baz", "%b()"), Some((0, 10)));
        assert_eq!(m("[a][b]", "%b[]%b[]"), Some((0, 6)));
        // frontier, including the '\0' frontier at end of subject
        assert_eq!(m("THE (quick) fox", "%f[%a]%a+"), Some((0, 3)));
        assert_eq!(m("abc", "%f[%z]"), Some((3, 3)));
        // back-reference
        assert_eq!(m("abcabc", "(abc)%1"), Some((0, 6)));
        assert_eq!(m("abcabd", "(abc)%1"), None);
    }

    #[test]
    fn init_and_anchor() {
        let span = |mm: Match| (mm.start, mm.end);
        assert_eq!(find(b"abab", b"^ab", 2).unwrap().map(span), Some((2, 4)));
        assert_eq!(find(b"abab", b"^b", 2).unwrap().map(span), None);
        assert_eq!(find(b"abab", b"b", 2).unwrap().map(span), Some((3, 4)));
        assert_eq!(find(b"abc", b"", 3).unwrap().map(span), Some((3, 3)));
        assert!(find(b"abc", b"a", 4).unwrap().is_none());
        assert_eq!(match_at(b"abc", b"b", 1).unwrap().map(span), Some((1, 2)));
        assert!(match_at(b"abc", b"b", 0).unwrap().is_none());
    }

    #[test]
    fn plain_helpers() {
        assert_eq!(plain_find(b"hello", b"ll", 0), Some(2));
        assert_eq!(plain_find(b"hello", b"l", 3), Some(3));
        assert_eq!(plain_find(b"hello", b"", 5), Some(5));
        assert_eq!(plain_find(b"hello", b"x", 6), None);
        assert_eq!(plain_find(b"hi", b"hello", 0), None);
        assert!(!has_specials(b"plain text"));
        assert!(has_specials(b"a.b"));
        assert!(has_specials(b"50%"));
    }

    #[test]
    fn errors() {
        assert!(find(b"x", b"%", 0).is_err());
        assert!(find(b"x", b"[abc", 0).is_err());
        assert!(find(b"x", b"[%", 0).is_err()); // dangling escape in a set
        assert!(find(b"a", b"(a", 0).is_err()); // unfinished capture
        assert!(find(b"x", b")", 0).is_err()); // no capture to close
        assert!(find(b"x", b"%1", 0).is_err());
        assert!(find(b"x", b"%0", 0).is_err());
        assert!(find(b"x", b"%b(", 0).is_err()); // missing %b argument
        assert!(find(b"x", b"%fa", 0).is_err()); // %f without [set]
        assert!(find(b"", &b"()".repeat(MAX_CAPTURES + 1), 0).is_err()); // too many captures
    }
}