1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
use std::collections::BTreeMap;

/// Trie (prefix tree) holding a set of strings.
///
/// Keys are stored byte by byte: each node maps the next byte of a key's
/// UTF-8 encoding to a child node, and `leaf` marks the node where a stored
/// key ends. Any `&str`, including the empty string, can be stored.
///
/// # Examples
///
/// ```
/// use librualg::trie::Trie;
///
/// let mut trie = Trie::new();
/// trie.insert("abab");
/// trie.insert("abcc");
/// trie.insert("ddvbn");
///
/// assert!(trie.contains("abab"));
/// assert!(!trie.contains("ababa"));
/// assert!(trie.contains("abcc"));
/// assert!(!trie.contains("abc"));
///
/// trie.remove("abcc");
/// assert!(!trie.contains("abcc"));
/// assert!(trie.contains("abab"));
/// ```
#[derive(Debug, Clone, Default)]
pub struct Trie {
    /// Child nodes, keyed by the next byte of the key.
    children: BTreeMap<u8, Trie>,
    /// `true` if a stored key ends at this node.
    leaf: bool,
}

impl Trie {
    /// Creates an empty trie.
    pub fn new() -> Self {
        Trie::default()
    }

    /// Adds `s` to the trie. Inserting a key that is already stored is a no-op.
    pub fn insert(&mut self, s: &str) {
        let mut node = self;
        for b in s.bytes() {
            // One lookup per byte: reuse the existing child or create an empty one.
            node = node.children.entry(b).or_default();
        }
        node.leaf = true;
    }

    /// Returns `true` if `p` has been inserted and not removed since.
    ///
    /// A mere prefix of a stored key (one that was never inserted itself)
    /// is not reported as contained.
    pub fn contains(&self, p: &str) -> bool {
        matches!(self.find(p.as_bytes()), Some(node) if node.leaf)
    }

    /// Removes `p` from the trie. Does nothing if `p` is not stored.
    ///
    /// Nodes that no longer lead to any stored key are dropped, so the branch
    /// that belonged exclusively to `p` is freed rather than left dangling.
    pub fn remove(&mut self, p: &str) {
        let key = p.as_bytes();

        // Read-only pass: confirm `p` is stored and record `cut`, the depth of
        // the deepest node on its path that must survive. That is the root, a
        // node ending another key, or a node branching towards other keys.
        // Every node below `cut` on this path exists only for `p`.
        let mut cut = 0;
        let mut node: &Trie = self;
        for (depth, b) in key.iter().enumerate() {
            if node.leaf || node.children.len() > 1 {
                cut = depth;
            }
            match node.children.get(b) {
                Some(child) => node = child,
                None => return,
            }
        }
        if !node.leaf {
            return;
        }

        if key.is_empty() || !node.children.is_empty() {
            // Longer keys continue past `p` (or `p` is the root's empty key):
            // keep the nodes and only unmark the end of `p`.
            self.descend_mut(key).leaf = false;
        } else {
            // Detach the whole branch that belonged to `p` alone.
            self.descend_mut(&key[..cut]).children.remove(&key[cut]);
        }
    }

    /// Follows `key` from `self`; returns the node it ends at, or `None`
    /// if some byte has no matching child.
    fn find(&self, key: &[u8]) -> Option<&Trie> {
        let mut node = self;
        for b in key {
            node = node.children.get(b)?;
        }
        Some(node)
    }

    /// Mutable walk along a path the caller has already verified to exist.
    ///
    /// # Panics
    ///
    /// Panics if some byte of `key` has no matching child.
    fn descend_mut(&mut self, key: &[u8]) -> &mut Trie {
        let mut node = self;
        for b in key {
            node = node
                .children
                .get_mut(b)
                .expect("path was verified by the read-only pass in Trie::remove");
        }
        node
    }
}

/// End-to-end check of `Trie` insert / contains / remove semantics,
/// including removal of absent keys, keys that prefix other keys,
/// keys extended by other keys, and the empty-string key.
#[test]
fn test_trie() {
    let mut trie = Trie::new();
    for word in &["abab", "abc", "abccc", "ddvbn"] {
        trie.insert(word);
    }

    assert!(trie.contains("abab"));
    assert!(!trie.contains("ababa"));
    assert!(trie.contains("abccc"));
    assert!(!trie.contains("abcc"));
    assert!(trie.contains("abc"));

    // "ab" was never inserted (it is only a prefix), so removing it is a no-op.
    trie.remove("ab");
    assert!(trie.contains("abab"));
    assert!(trie.contains("abc"));
    assert!(trie.contains("abccc"));

    // Removing a key that longer keys extend must keep those longer keys.
    trie.remove("abc");
    assert!(!trie.contains("abc"));
    assert!(trie.contains("abccc"));
    assert!(trie.contains("abab"));
    assert!(trie.contains("ddvbn"));

    // Removing a key that extends a shorter key must keep the shorter one.
    trie = Trie::new();
    trie.insert("abc");
    trie.insert("abccc");
    assert!(trie.contains("abccc"));
    assert!(trie.contains("abc"));

    trie.remove("abccc");
    assert!(!trie.contains("abccc"));
    assert!(trie.contains("abc"));

    // The empty string is an ordinary key.
    trie.insert("");
    assert!(trie.contains(""));
    trie.remove("");
    assert!(!trie.contains(""));
    assert!(trie.contains("abc"));

    // A key can be removed and inserted again.
    trie.remove("abc");
    assert!(!trie.contains("abc"));
    assert!(!trie.contains("a"));
    trie.insert("abc");
    assert!(trie.contains("abc"));
}