lightning_path/
lib.rs

1use std::collections::{BTreeMap, HashSet};
2
3use crate::CharacterClass::{Ascii, InvalidChars, ValidChars};
4
/// A set of characters split into two bit masks plus a hash set.
///
/// Membership is keyed on the code point minus one (the "shifted value"):
/// shifted values `0..=63` map to `low_mask`, `64..=127` map to `high_mask`,
/// and everything else is looked up in `non_ascii`.
#[derive(PartialEq, Eq, Clone, Default, Debug)]
pub struct CharSet {
    /// Bit `n` marks the character whose shifted value is `n`. Being a `u32`,
    /// it can only represent shifted values `0..=31`.
    pub low_mask: u32,
    /// Bit `n` marks the character whose shifted value is `n + 64`.
    pub high_mask: u64,
    /// Members whose shifted value is greater than 127.
    pub non_ascii: HashSet<char>,
}

impl CharSet {
    /// Returns `true` when `char` is a member of this set.
    ///
    /// Never panics: NUL and characters that `low_mask` cannot represent are
    /// handled explicitly instead of overflowing.
    pub fn contains(&self, char: char) -> bool {
        // NUL has no predecessor; wrapping sends it to `u32::MAX`, i.e. to the
        // `non_ascii` lookup, which is what a release build already did.
        let val = (char as u32).wrapping_sub(1);

        if val > 127 {
            self.non_ascii.contains(&char)
        } else if val > 63 {
            self.high_mask & (1u64 << (val - 64)) != 0
        } else {
            // `low_mask` is only 32 bits wide, so shifted values 32..=63 have no
            // bit to test; `checked_shl` yields `None` for them instead of
            // overflowing the shift.
            1u32.checked_shl(val)
                .map_or(false, |bit| self.low_mask & bit != 0)
        }
    }
}
27
/// Named parameters extracted from a recognized path, keyed by parameter name.
///
/// Two `Params` compare equal exactly when their underlying maps are equal.
#[derive(Debug, PartialEq)]
pub struct Params {
    /// Parameter name to captured value.
    pub map: BTreeMap<String, String>,
}

impl Params {
    /// Creates an empty parameter collection.
    pub fn new() -> Params {
        Self {
            map: BTreeMap::default(),
        }
    }

    /// Stores `value` under `key`, replacing any value already stored there.
    pub fn insert(&mut self, key: String, value: String) {
        self.map.insert(key, value);
    }

    /// Looks up the value stored under `key`, if any.
    pub fn find(&self, key: &str) -> Option<&str> {
        self.map.get(key).map(String::as_str)
    }
}
54
/// One path of execution through the NFA, together with the byte ranges it
/// has captured so far.
#[derive(Clone, Debug)]
pub struct Thread {
    /// Index of the state this thread currently sits in.
    pub state: usize,
    /// Finished captures as `(start, end)` byte offsets into the input.
    pub captures: Vec<(usize, usize)>,
    /// Start offset of the capture in progress, if one is open.
    pub capture_begin: Option<usize>,
}

impl Thread {
    /// Creates a thread at state `0` with no captures.
    pub fn new() -> Self {
        Thread {
            capture_begin: None,
            captures: vec![],
            state: 0,
        }
    }

    /// Opens a capture beginning at byte offset `start`.
    pub fn start_capture(&mut self, start: usize) {
        self.capture_begin = Some(start);
    }

    /// Closes the open capture at byte offset `end` and records it.
    ///
    /// # Panics
    ///
    /// Panics if no capture is currently open.
    pub fn end_capture(&mut self, end: usize) {
        let begin = self.capture_begin.take().unwrap();
        self.captures.push((begin, end));
    }

    /// Resolves every recorded capture to the matching slice of `source`.
    ///
    /// # Panics
    ///
    /// Panics if a recorded range is not a valid slice range of `source`.
    pub fn extract<'a>(&self, source: &'a str) -> Vec<&'a str> {
        let mut slices = Vec::with_capacity(self.captures.len());
        for &(start, end) in &self.captures {
            slices.push(&source[start..end]);
        }
        slices
    }
}
87
88#[derive(PartialEq, Eq, Clone, Debug)]
89pub enum CharacterClass {
90    Ascii(u64, u64, bool),
91    ValidChars(CharSet),
92    InvalidChars(CharSet),
93}
94
95impl CharacterClass {
96    pub fn any() -> CharacterClass {
97        Ascii(u64::MAX, u64::MAX, false)
98    }
99
100    pub fn valid_char(char: char) -> Self {
101        let val = char as u32 - 1;
102
103        if val > 127 {
104            ValidChars(Self::char_to_set(char))
105        } else if val > 63 {
106            Ascii(1 << (val - 64), 0, false)
107        } else {
108            Ascii(0, 1 << val, false)
109        }
110    }
111
112    pub fn invalid_char(char: char) -> Self {
113        let val = char as u32 - 1;
114
115        if val > 127 {
116            InvalidChars(Self::char_to_set(char))
117        } else if val > 63 {
118            Ascii(u64::MAX ^ (1 << (val - 64)), u64::MAX, true)
119        } else {
120            Ascii(u64::MAX, u64::MAX ^ (1 << val), true)
121        }
122    }
123
124    pub fn char_to_set(char: char) -> CharSet {
125        let mut set = CharSet::default();
126        set.non_ascii.insert(char);
127        set
128    }
129
130    pub fn matches(&self, char: char) -> bool {
131        match *self {
132            ValidChars(ref valid) => valid.contains(char),
133            InvalidChars(ref valid) => !valid.contains(char),
134            Ascii(high, low, unicode) => {
135                let val = char as u32 - 1;
136                if val > 127 {
137                    unicode
138                } else if val > 63 {
139                    high & (1 << (val - 64)) != 0
140                } else {
141                    low & (1 << val) != 0
142                }
143            }
144        }
145    }
146}
147
/// Per-route bookkeeping attached to the accepting state of a route.
#[derive(Debug)]
pub struct Metadata {
    /// Number of static segments in the route.
    pub statics: u32,
    /// Number of `:name` segments in the route.
    pub dynamics: u32,
    /// Number of `*name` segments in the route.
    pub wildcards: u32,
    /// Parameter names in the order their segments appear in the route.
    pub param_names: Vec<String>,
}

impl Metadata {
    /// Creates metadata with all counters at zero and no parameter names.
    pub fn new() -> Metadata {
        let param_names = Vec::new();

        Self {
            statics: 0,
            dynamics: 0,
            wildcards: 0,
            param_names,
        }
    }
}
166
167#[derive(Debug)]
168pub struct State<T> {
169    pub index: usize,
170    pub chars: CharacterClass,
171    pub next_states: Vec<usize>,
172    pub acceptance: bool,
173    pub start_capture: bool,
174    pub end_capture: bool,
175    pub metadata: Option<T>,
176}
177
178#[derive(Debug)]
179pub struct NFA<T> {
180    pub states: Vec<State<T>>,
181    pub start_capture: Vec<bool>,
182    pub end_capture: Vec<bool>,
183    pub acceptance: Vec<bool>,
184}
185
186impl<T> NFA<T> {
187    pub fn put(&mut self, index: usize, chars: CharacterClass) -> usize {
188        {
189            // Check if the state already exists
190            // If it does, return just the index of it
191            // So we don't have to create a new state
192            let state = self.get(index);
193
194            for &index in &state.next_states {
195                let state = self.get(index);
196
197                if state.chars == chars {
198                    return index;
199                }
200            }
201        }
202
203        // If the state doesn't exist, we create a new one
204        // And add it to the next states of the current state
205        let state = self.new_state(chars);
206        self.get_mut(index).next_states.push(state);
207
208        state
209    }
210
211    pub fn get_mut(&mut self, index: usize) -> &mut State<T> {
212        &mut self.states[index]
213    }
214
215    pub fn new_state(&mut self, chars: CharacterClass) -> usize {
216        // The index of the new state is the length of the states vector
217        // Example:
218        // [0: 'a', 1: 'b', 2: 'c']
219        // The index is 3, so the new state will be at index 3
220        let index = self.states.len();
221        let state = State::new(index, chars);
222        self.states.push(state);
223
224        self.acceptance.push(false);
225        self.start_capture.push(false);
226        self.end_capture.push(false);
227
228        index
229    }
230
231    pub fn get(&self, index: usize) -> &State<T> {
232        &self.states[index]
233    }
234
235    pub fn acceptance(&mut self, index: usize) {
236        // Set the acceptance of the state at the given index to true
237        self.get_mut(index).acceptance = true;
238        self.acceptance[index] = true;
239    }
240
241    pub fn metadata(&mut self, index: usize, metadata: T) {
242        self.get_mut(index).metadata = Some(metadata);
243    }
244
245    pub fn start_capture(&mut self, index: usize) {
246        self.get_mut(index).start_capture = true;
247        self.start_capture[index] = true;
248    }
249
250    pub fn end_capture(&mut self, index: usize) {
251        self.get_mut(index).end_capture = true;
252        self.end_capture[index] = true;
253    }
254
255    pub fn put_state(&mut self, index: usize, child: usize) {
256        if !self.get(index).next_states.contains(&child) {
257            self.get_mut(index).next_states.push(child);
258        }
259    }
260}
261
262impl<T> State<T> {
263    pub fn new(index: usize, chars: CharacterClass) -> Self {
264        Self {
265            index,
266            chars,
267            next_states: Vec::new(),
268            acceptance: false,
269            start_capture: false,
270            end_capture: false,
271            metadata: None,
272        }
273    }
274}
275
/// Result of a successful NFA run.
#[derive(Debug)]
pub struct Match<'a> {
    /// Index of the accepting state the run ended in.
    pub state: usize,
    /// Captured slices of the input, in capture order.
    pub captures: Vec<&'a str>,
}

impl<'a> Match<'a> {
    /// Bundles an accepting state index with the captured slices.
    pub fn new(state: usize, captures: Vec<&'a str>) -> Self {
        Match { captures, state }
    }
}
287
288#[derive(Debug)]
289pub struct RouterMatch<T> {
290    pub handler: T,
291    pub params: Params,
292}
293
294impl<T> RouterMatch<T> {
295    pub fn new(handler: T, params: Params) -> Self {
296        Self { handler, params }
297    }
298}
299
300impl<T> NFA<T> {
301    pub fn new() -> NFA<T> {
302        let root = State::new(0, CharacterClass::any());
303
304        NFA {
305            states: vec![root],
306            start_capture: vec![false],
307            end_capture: vec![false],
308            acceptance: vec![false],
309        }
310    }
311
312    pub fn process<'a>(&self, string: &'a str) -> Result<Match<'a>, String> {
313        let mut threads = vec![Thread::new()];
314
315        for (i, char) in string.char_indices() {
316            let next_threads = self.process_char(threads, char, i);
317
318            if next_threads.is_empty() {
319                return Err(format!("No match found for {}", string));
320            }
321
322            threads = next_threads;
323        }
324
325        let mut returned = threads
326            .into_iter()
327            .filter(|thread| self.get(thread.state).acceptance);
328
329        let thread = returned.next();
330
331        match thread {
332            None => Err(format!("No match found for {}", string)),
333            Some(mut thread) => {
334                if thread.capture_begin.is_some() {
335                    thread.end_capture(string.len());
336                }
337
338                let state = self.get(thread.state);
339                Ok(Match::new(state.index, thread.extract(string)))
340            }
341        }
342    }
343
344    pub fn process_char(&self, threads: Vec<Thread>, char: char, pos: usize) -> Vec<Thread> {
345        let mut returned = Vec::with_capacity(threads.len());
346
347        for mut thread in threads {
348            let current_state = self.get(thread.state);
349
350            let mut count = 0;
351            let mut found_state = 0;
352
353            for &index in &current_state.next_states {
354                let state = &self.get(index);
355
356                if state.chars.matches(char) {
357                    count += 1;
358                    found_state = index;
359                }
360            }
361
362            if count == 1 {
363                thread.state = found_state;
364                capture(self, &mut thread, current_state.index, found_state, pos);
365                returned.push(thread);
366                continue;
367            }
368
369            for &index in &current_state.next_states {
370                let state = &self.get(index);
371
372                if state.chars.matches(char) {
373                    let mut thread = fork_thread(&thread, state);
374                    capture(self, &mut thread, current_state.index, index, pos);
375                    returned.push(thread);
376                }
377            }
378        }
379
380        returned
381    }
382}
383
384fn fork_thread<T>(thread: &Thread, state: &State<T>) -> Thread {
385    let mut new_thread = thread.clone();
386    new_thread.state = state.index;
387    new_thread
388}
389
390fn capture<T>(
391    nfa: &NFA<T>,
392    thread: &mut Thread,
393    current_state: usize,
394    next_state: usize,
395    pos: usize,
396) {
397    if thread.capture_begin.is_none() && nfa.start_capture[next_state] {
398        thread.start_capture(pos);
399    }
400
401    if thread.capture_begin.is_some()
402        && nfa.end_capture[current_state]
403        && next_state > current_state
404    {
405        thread.end_capture(pos);
406    }
407}
408
409#[derive(Debug)]
410pub struct Router<T> {
411    pub nfa: NFA<Metadata>,
412    pub handlers: BTreeMap<usize, T>,
413}
414
/// Splits `route` into `(separator, segment)` pairs.
///
/// A separator is `.` or `/`. Each pair holds the separator that introduced
/// the segment (`None` when the route does not start with one) and the text up
/// to the next separator. Consecutive separators yield empty segments.
///
/// Works on byte offsets derived from the characters themselves, so routes
/// containing multi-byte characters are split correctly.
fn segments(route: &str) -> Vec<(Option<char>, &str)> {
    let is_separator = |c: char| c == '.' || c == '/';

    let mut segments = Vec::new();
    let mut rest = route;

    while let Some(first) = rest.chars().next() {
        // Peel off the leading separator, if the remaining text starts with one.
        let (separator, tail) = if is_separator(first) {
            (Some(first), &rest[first.len_utf8()..])
        } else {
            (None, rest)
        };

        // The segment runs up to (not including) the next separator.
        let end = tail.find(is_separator).unwrap_or(tail.len());
        segments.push((separator, &tail[..end]));
        rest = &tail[end..];
    }

    segments
}
438
/// Returns the first byte of `s`.
///
/// # Panics
///
/// Panics if `s` is empty.
fn first_byte(s: &str) -> u8 {
    let bytes = s.as_bytes();
    bytes[0]
}
442
443impl<T> Router<T> {
444    pub fn new() -> Router<T> {
445        Router {
446            nfa: NFA::new(),
447            handlers: BTreeMap::new(),
448        }
449    }
450
451    pub fn recognize(&self, mut path: &str) -> Result<RouterMatch<&T>, String> {
452        if first_byte(path) == b'/' {
453            path = &path[1..];
454        }
455
456        let nfa = &self.nfa;
457        let result = nfa.process(path);
458
459        return result.map(|m| {
460            let mut map = Params::new();
461            let state = &nfa.get(m.state);
462            let metadata = state.metadata.as_ref().unwrap();
463            let param_names = metadata.param_names.clone();
464
465            for (i, capture) in m.captures.iter().enumerate() {
466                if !param_names[i].is_empty() {
467                    map.insert(param_names[i].to_string(), capture.to_string());
468                }
469            }
470
471            let handler = self.handlers.get(&m.state).unwrap();
472            RouterMatch::new(handler, map)
473        });
474    }
475
476    pub fn add(&mut self, mut route: &str, destiny: T) {
477        if route.is_empty() {
478            return;
479        }
480
481        // Remove leading slash if exists
482        if first_byte(route) == b'/' {
483            route = &route[1..];
484        }
485
486        let nfa = &mut self.nfa;
487        let mut state = 0;
488        let mut metadata = Metadata::new();
489
490        for (separator, segment) in segments(route) {
491            // If we have a separator,
492            // we need to add a transition to the current state
493            if let Some(separator) = separator {
494                state = nfa.put(state, CharacterClass::valid_char(separator));
495            }
496
497            if segment.is_empty() {
498                continue;
499            }
500
501            match first_byte(segment) {
502                b':' => {
503                    state = process_star_state(nfa, state);
504                    metadata.dynamics += 1;
505                    metadata.param_names.push(
506                        // Add the param key without ':'
507                        segment[1..].to_string(),
508                    );
509                }
510                b'*' => {
511                    state = process_star_state(nfa, state);
512                    metadata.wildcards += 1;
513                    metadata.param_names.push(
514                        // Add the param key without '*'
515                        segment[1..].to_string(),
516                    );
517                }
518                _ => {
519                    state = process_static_segment(segment, nfa, state);
520                    metadata.statics += 1;
521                }
522            }
523        }
524
525        // Mark the state as an acceptance state
526        nfa.acceptance(state);
527
528        // Add the metadata to the state
529        nfa.metadata(state, metadata);
530
531        // Add the handler to the handlers map
532        self.handlers.insert(state, destiny);
533    }
534}
535
536fn process_star_state<T>(nfa: &mut NFA<T>, mut state: usize) -> usize {
537    state = nfa.put(state, CharacterClass::invalid_char('/'));
538    nfa.put_state(state, state);
539    nfa.start_capture(state);
540    nfa.end_capture(state);
541
542    state
543}
544
545fn process_static_segment<T>(segment: &str, nfa: &mut NFA<T>, mut state: usize) -> usize {
546    // When we are processing a static segment
547    // we just need to add a transition for each character
548    // to our current state
549    for char in segment.chars() {
550        state = nfa.put(state, CharacterClass::valid_char(char));
551    }
552
553    state
554}
555
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a router holding exactly one route.
    fn router_with(route: &str, handler: &'static str) -> Router<&'static str> {
        let mut router = Router::new();
        router.add(route, handler);
        router
    }

    /// Shorthand for a segment introduced by `/`.
    fn slash(segment: &str) -> (Option<char>, &str) {
        (Some('/'), segment)
    }

    #[test]
    fn test_segments() {
        assert_eq!(
            segments("users/:id"),
            vec![(None, "users"), slash(":id")]
        );

        assert_eq!(
            segments("/users/:id/posts"),
            vec![slash("users"), slash(":id"), slash("posts")]
        );

        assert_eq!(
            segments("/users/:id/posts/:post_id"),
            vec![
                slash("users"),
                slash(":id"),
                slash("posts"),
                slash(":post_id"),
            ]
        );

        assert_eq!(
            segments("/users/:id/posts/:post_id/comments"),
            vec![
                slash("users"),
                slash(":id"),
                slash("posts"),
                slash(":post_id"),
                slash("comments"),
            ]
        );
    }

    #[test]
    fn test_add_static_routes() {
        let router = router_with("users", "users");

        // Root state plus one state per character of "users".
        assert_eq!(router.nfa.states.len(), 6);
        assert_eq!(router.handlers.len(), 1);

        // The handler is stored under the last state of the chain.
        assert_eq!(router.handlers.get(&5), Some(&"users"));
    }

    #[test]
    fn test_add_routes_with_star_wildcards() {
        let router = router_with("users/*-profile", "users-wildcard");

        assert_eq!(router.handlers.len(), 1);
        // Root + five states for "users" + one for '/' + one parameter state.
        assert_eq!(router.nfa.states.len(), 8);
    }

    #[test]
    fn test_add_routes_with_colon_wildcards() {
        let router = router_with("user/:id", "users-wildcard");

        assert_eq!(router.handlers.len(), 1);
        // Root + four states for "user" + one for '/' + one parameter state.
        assert_eq!(router.nfa.states.len(), 7);

        // The last state accepts the route and carries its metadata.
        let metadata = router.nfa.states[6].metadata.as_ref().unwrap();
        assert_eq!(metadata.param_names, vec!["id"]);
    }

    #[test]
    fn test_route_recognize_static_route() {
        let router = router_with("/users", "users");

        let found = router.recognize("/users").unwrap();

        assert_eq!(*found.handler, "users");
        assert_eq!(found.params, Params::new());
    }

    #[test]
    fn test_route_recognize_colon_wildcard() {
        let router = router_with("/user/:id", "user");

        let found = router.recognize("/user/1").unwrap();

        assert_eq!(*found.handler, "user");
        assert_eq!(found.params.find("id"), Some("1"));
    }

    #[test]
    fn test_route_recognize_colon_wildcard_multiple_params() {
        let router = router_with("/user/:id/posts/:post_id", "user-post");

        let found = router.recognize("/user/9/posts/10").unwrap();

        assert_eq!(*found.handler, "user-post");
        assert_eq!(found.params.find("id"), Some("9"));
        assert_eq!(found.params.find("post_id"), Some("10"));
    }

    #[test]
    fn test_route_recognize_colon_wildcard_fail() {
        let router = router_with("/user/:id", "user");

        assert!(router.recognize("/user").is_err());
    }

    #[test]
    fn test_route_recognize_star_wildcard() {
        let router = router_with("/fs/*path", "fs");

        let found = router.recognize("/fs/random-file-path").unwrap();

        assert_eq!(*found.handler, "fs");
        assert_eq!(found.params.find("path"), Some("random-file-path"));
    }
}