pathrouter/
nfa.rs

1use std::collections::BTreeMap;
2
/// Separator between path segments, e.g. the `/`s in `/api/v1/post`.
const CHAR_PATH_SEP: char = '/';
/// Leading character that turns a route segment into a named parameter (`:user`).
const CHAR_PARAM: char = ':';
/// Leading character that turns a route segment into a named wildcard (`*rest`).
const CHAR_WILDCARD: char = '*';
6
7#[derive(Debug, Clone)]
8struct Entry {
9    pat: Pattern,
10    index: usize,
11}
12
13impl Entry {
14    fn new(pat: Pattern, index: usize) -> Self {
15        Entry { pat, index }
16    }
17}
18
19#[derive(Debug, Clone)]
20struct Transitions {
21    static_entries: BTreeMap<String, usize>,
22    dynamic_entries: Vec<Entry>,
23}
24
25impl Transitions {
26    fn new() -> Self {
27        Transitions {
28            static_entries: BTreeMap::new(),
29            dynamic_entries: Vec::new(),
30        }
31    }
32
33    fn get(&self, pat: &Pattern) -> Option<usize> {
34        match pat {
35            Pattern::Static(p) => self.static_entries.get(p).cloned(),
36            _ => {
37                for entry in &self.dynamic_entries {
38                    if &entry.pat == pat {
39                        return Some(entry.index);
40                    }
41                }
42                None
43            }
44        }
45    }
46
47    fn push(&mut self, pat: Pattern, index: usize) {
48        match pat {
49            Pattern::Static(p) => {
50                self.static_entries.insert(p, index);
51            }
52            p => {
53                self.dynamic_entries.push(Entry::new(p, index));
54            }
55        }
56    }
57
58    fn entries(&self) -> Vec<Entry> {
59        let mut ret = Vec::new();
60
61        for (k, v) in self.static_entries.iter() {
62            ret.push(Entry::new(Pattern::Static(k.to_owned()), *v))
63        }
64
65        ret.extend_from_slice(&self.dynamic_entries);
66
67        ret
68    }
69
70    fn capture<'a: 'b, 'b>(&'b self, seg: &'a str, path: &'a str) -> Vec<(Capture, usize)> {
71        let mut captures = Vec::new();
72
73        if let Some(index) = self.static_entries.get(seg) {
74            captures.push((Capture::Static, *index));
75        }
76
77        for entry in &self.dynamic_entries {
78            match &entry.pat {
79                Pattern::Param(name) => {
80                    captures.push((Capture::Param(name, seg), entry.index));
81                }
82                Pattern::Wildcard(name) => {
83                    captures.push((Capture::Wildcard(name, path), entry.index));
84                }
85                _ => unreachable!(),
86            }
87        }
88
89        captures
90    }
91
92    fn capture_static(&self, seg: &str) -> Option<usize> {
93        self.static_entries.get(seg).copied()
94    }
95}
96
97#[derive(Debug, Clone)]
98struct State {
99    index: usize,
100    transitions: Transitions,
101}
102
103impl State {
104    fn new(index: usize) -> Self {
105        State {
106            index,
107            transitions: Transitions::new(),
108        }
109    }
110}
111
112#[derive(Debug, Clone)]
113enum Pattern {
114    Static(String),
115    Param(String),
116    Wildcard(String),
117}
118
119impl PartialEq for Pattern {
120    fn eq(&self, other: &Self) -> bool {
121        match (self, other) {
122            (Self::Static(l0), Self::Static(r0)) => l0 == r0,
123            (Self::Param(_l0), Self::Param(_r0)) => true,
124            (Self::Wildcard(_l0), Self::Wildcard(_r0)) => true,
125            _ => false,
126        }
127    }
128}
129
130impl Pattern {
131    fn from_str(pat: impl AsRef<str>) -> Self {
132        let pat = pat.as_ref();
133        match pat.chars().next() {
134            Some(CHAR_PARAM) => Pattern::Param(pat[1..].to_owned()),
135            Some(CHAR_WILDCARD) => Pattern::Wildcard(pat[1..].to_owned()),
136            _ => Pattern::Static(pat.to_owned()),
137        }
138    }
139}
140
141#[derive(Debug, Clone)]
142pub struct Nfa {
143    states: Vec<State>,
144    acceptances: Vec<bool>,
145}
146
147impl Nfa {
148    pub fn new() -> Self {
149        let mut this = Nfa {
150            states: Vec::new(),
151            acceptances: Vec::new(),
152        };
153
154        this.new_state();
155
156        this
157    }
158
159    fn new_state(&mut self) -> usize {
160        let new_index = self.states.len();
161
162        let new_state = State::new(new_index);
163
164        self.states.push(new_state);
165        self.acceptances.push(false);
166
167        new_index
168    }
169
170    pub(crate) fn start_state(&self) -> usize {
171        self.states.first().expect("first state not exist").index
172    }
173
174    fn get_state(&self, index: usize) -> &State {
175        &self.states[index]
176    }
177
178    fn get_state_mut(&mut self, index: usize) -> &mut State {
179        self.states.get_mut(index).expect("state not exist")
180    }
181
182    fn get_acceptance(&self, state: usize) -> bool {
183        self.acceptances[state]
184    }
185
186    pub fn locate(&mut self, path: &str) -> usize {
187        let path = path.trim_start_matches(CHAR_PATH_SEP);
188        let segs = path.split(CHAR_PATH_SEP);
189
190        let mut index = self.start_state();
191
192        for seg in segs {
193            let pat = Pattern::from_str(seg);
194
195            let next = self.get_state(index).transitions.get(&pat);
196
197            match next {
198                Some(s) => {
199                    index = s;
200                }
201                None => {
202                    let new_state = self.new_state();
203                    self.get_state_mut(index).transitions.push(pat, new_state);
204
205                    index = new_state;
206                }
207            }
208        }
209
210        index
211    }
212
213    pub fn accept(&mut self, state: usize) {
214        if state != self.start_state() {
215            self.acceptances[state] = true;
216        }
217    }
218
219    pub fn insert(&mut self, path: &str) -> usize {
220        let state = self.locate(path);
221        self.accept(state);
222        state
223    }
224
225    pub fn search<'a: 'b, 'b>(&'a self, path: &'b str) -> Option<Match<'b>> {
226        let mut path = path.trim_start_matches(CHAR_PATH_SEP);
227
228        // try fast path, only match static transition
229        if let Some(ret) = self.fast_path_search(path) {
230            return Some(ret);
231        }
232
233        let mut roads = vec![Road::new(self.start_state(), Vec::new())];
234        while let Some((seg, reminder)) = path.split_once(CHAR_PATH_SEP) {
235            roads = self.process_seg(roads, seg, path);
236            path = reminder;
237        }
238
239        roads = self.process_seg(roads, path, path);
240
241        let roads = roads
242            .into_iter()
243            .filter(|road| self.get_acceptance(road.state));
244
245        // detect longest path
246        let found = roads.fold(None, |prev, curr| match prev {
247            Some(item) => {
248                if item < curr {
249                    Some(curr)
250                } else {
251                    Some(item)
252                }
253            }
254            None => Some(curr),
255        });
256
257        found.map(|found| {
258            let mut params = Vec::new();
259            for capture in found.captures {
260                match capture {
261                    Capture::Param(n, v) => {
262                        params.push((n, v));
263                    }
264                    Capture::Wildcard(n, v) => {
265                        params.push((n, v));
266                    }
267                    Capture::Static => {}
268                }
269            }
270
271            Match::new(found.state, params)
272        })
273    }
274
275    fn fast_path_search(&self, path: &str) -> Option<Match> {
276        let mut road = Road::new(self.start_state(), Vec::new());
277        for seg in path.split(CHAR_PATH_SEP) {
278            match self.process_static_seg(seg, road) {
279                Some(r) => {
280                    road = r;
281                }
282                None => {
283                    return None;
284                }
285            }
286        }
287
288        Some(Match::new(road.state, Vec::new()))
289    }
290
291    fn process_static_seg<'a: 'b, 'b>(&'a self, seg: &str, mut road: Road<'b>) -> Option<Road<'b>> {
292        self.get_state(road.state)
293            .transitions
294            .capture_static(seg)
295            .map(|next| {
296                road.state = next;
297                road
298            })
299    }
300
301    fn process_seg<'a: 'b, 'b>(
302        &'a self,
303        roads: Vec<Road<'a>>,
304        seg: &'b str,
305        path: &'b str,
306    ) -> Vec<Road<'b>> {
307        let mut returned = Vec::with_capacity(roads.len());
308
309        for r in roads {
310            // while into wildcard, skip it
311            if r.wildcard {
312                returned.push(r);
313                continue;
314            }
315
316            let Road {
317                state, captures, ..
318            } = r;
319
320            for (capture, next) in self.get_state(state).transitions.capture(seg, path) {
321                let mut new_captures = captures.clone();
322                match capture {
323                    Capture::Wildcard(_name, _param) => {
324                        new_captures.push(capture);
325                        let mut road = Road::new(next, new_captures);
326                        road.set_wildcard(true);
327                        returned.push(road);
328                    }
329                    _ => {
330                        new_captures.push(capture);
331                        returned.push(Road::new(next, new_captures));
332                    }
333                }
334            }
335        }
336
337        returned
338    }
339
340    pub(crate) fn merge(&mut self, left: usize, other: &Self, right: usize) -> Vec<(usize, usize)> {
341        let mut returned = Vec::new();
342
343        for Entry { pat, index: old } in other.get_state(right).transitions.entries() {
344            let new_state = self.new_state();
345            if other.get_acceptance(old) {
346                self.accept(new_state);
347            }
348            self.get_state_mut(left)
349                .transitions
350                .push(pat.clone(), new_state);
351
352            returned.push((new_state, old));
353
354            returned.extend(self.merge(new_state, other, old));
355        }
356
357        returned
358    }
359}
360
/// Result of [`Nfa::search`]: the state the path resolved to plus the
/// `(name, value)` pairs captured by parameter and wildcard segments.
#[derive(Debug)]
pub struct Match<'a> {
    /// Index of the state the searched path ended at.
    pub state: usize,
    /// Captured `(name, value)` pairs, in path order.
    pub params: Vec<(&'a str, &'a str)>,
}

impl<'a> Match<'a> {
    /// Bundles a resolved state with its captured parameters.
    fn new(state: usize, params: Vec<(&'a str, &'a str)>) -> Self {
        Self { params, state }
    }
}
372
/// What a single path segment matched while walking the automaton.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Capture<'a> {
    /// Took a static transition; nothing is captured.
    Static,
    /// Took a `:name` transition: `(name, segment)`.
    Param(&'a str, &'a str),
    /// Took a `*name` transition: `(name, captured text)`.
    Wildcard(&'a str, &'a str),
}

/// One candidate walk through the automaton during a search.
#[derive(Debug, PartialEq)]
struct Road<'a> {
    /// State reached so far.
    state: usize,
    /// One entry per segment consumed, up to and including a wildcard.
    captures: Vec<Capture<'a>>,
    /// Set once a wildcard has matched; such a road absorbs the remaining
    /// segments.
    wildcard: bool,
}

impl<'a> Road<'a> {
    /// Starts a non-wildcard road at `state` with the given captures.
    fn new(state: usize, captures: Vec<Capture<'a>>) -> Self {
        Self {
            wildcard: false,
            captures,
            state,
        }
    }

    /// Sets whether this road has matched a wildcard.
    fn set_wildcard(&mut self, wildcard: bool) {
        self.wildcard = wildcard;
    }
}

/// Ranks candidate roads.
///
/// A road with more captures is greater. With equal lengths, the first
/// position whose capture kinds differ decides, using
/// `Static > Param > Wildcard`. Roads whose kinds agree at every position
/// yield `None`.
///
/// NOTE(review): this disagrees with the derived `PartialEq` (equal roads give
/// `None`, and `Some`/`None` ignore `state`), which `PartialOrd`'s contract
/// forbids. In this file it is only used through `<` when picking the best
/// road.
impl<'a> PartialOrd for Road<'a> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        use std::cmp::Ordering;

        // Higher rank = more specific match.
        fn rank(capture: &Capture<'_>) -> u8 {
            match capture {
                Capture::Static => 2,
                Capture::Param(..) => 1,
                Capture::Wildcard(..) => 0,
            }
        }

        match self.captures.len().cmp(&other.captures.len()) {
            Ordering::Equal => self
                .captures
                .iter()
                .zip(&other.captures)
                .map(|(mine, theirs)| rank(mine).cmp(&rank(theirs)))
                .find(|ordering| *ordering != Ordering::Equal),
            unequal => Some(unequal),
        }
    }
}
429
#[cfg(test)]
mod test {
    use super::*;

    /// Static, parameter and wildcard routes sharing a prefix resolve to the
    /// most specific matching route.
    #[test]
    fn test_nfa() {
        let mut nfa = Nfa::new();

        let static_v1 = nfa.insert("/api/v1/post/tom/daily");
        let static_v2 = nfa.insert("/api/v2/post/tom/daily");
        let param = nfa.insert("/api/v1/post/:user/daily");
        let wildcard = nfa.insert("/api/v1/post/*any");

        // Fully static paths match without captures.
        let ret = nfa.search("/api/v1/post/tom/daily").expect("static v1 route");
        assert_eq!(ret.state, static_v1);
        assert!(ret.params.is_empty());

        let ret = nfa.search("/api/v2/post/tom/daily").expect("static v2 route");
        assert_eq!(ret.state, static_v2);
        assert!(ret.params.is_empty());

        // Both `:user/daily` and `*any` match; the parameter route is more specific.
        let ret = nfa.search("/api/v1/post/jerry/daily").expect("param route");
        assert_eq!(ret.state, param);
        assert_eq!(ret.params, vec![("user", "jerry")]);

        // Only the wildcard matches, and it captures the rest of the path.
        let ret = nfa.search("/api/v1/post/jerry/weekly").expect("wildcard route");
        assert_eq!(ret.state, wildcard);
        assert_eq!(ret.params, vec![("any", "jerry/weekly")]);

        // Nothing is registered under v3.
        assert!(nfa.search("/api/v3/post/tom/daily").is_none());
    }

    /// When the static branch dead-ends, the parameter branch still matches.
    #[test]
    fn test_nfa2() {
        let mut nfa = Nfa::new();

        let param = nfa.insert("/posts/:post_id/comments/100");
        nfa.insert("/posts/100/comments/10");

        let ret = nfa.search("/posts/100/comments/100").expect("param route");
        assert_eq!(ret.state, param);
        assert_eq!(ret.params, vec![("post_id", "100")]);
    }

    /// Routes grafted from another automaton become searchable, and the
    /// existing routes stay reachable.
    #[test]
    fn test_nfa_merge() {
        let mut nfa = Nfa::new();

        nfa.insert("/a/b/c");
        nfa.insert("/a/b/d");
        nfa.insert("/a/b/e");

        let mut other = Nfa::new();

        other.insert("/h/i/j");
        other.insert("/h/i/k");
        other.insert("/h/i/l");

        let sub = nfa.locate("/a");

        nfa.merge(sub, &other, other.start_state());

        // `locate` on an existing path returns its state without adding any.
        let merged_k = nfa.locate("/a/h/i/k");
        let original_c = nfa.locate("/a/b/c");

        let ret = nfa.search("/a/h/i/k").expect("merged route");
        assert_eq!(ret.state, merged_k);
        assert!(ret.params.is_empty());

        let ret = nfa.search("/a/b/c").expect("pre-existing route");
        assert_eq!(ret.state, original_c);

        assert!(nfa.search("/a/h/i/x").is_none());
    }
}