discrimination_tree/
lib.rs

1#[cfg(test)]
2mod tests;
3mod util;
4
5use crate::util::SortedMap;
6
/// Panics with a diagnostic listing the two likely causes of an inconsistent
/// key or tree: a key that is not a proper traversal of a well-formed term,
/// or a symbol that has been used with more than one arity.
///
/// This function never returns.
fn report_bad_state() -> ! {
    panic!(concat!(
        "bad state detected - this could mean:\n",
        "    1. inserted keys are not proper traversals of well-formed terms, or\n",
        "    2. identical symbols have been used with different arities",
    ))
}
15
/// A function symbol usable as a key item in the tree.
///
/// Symbols must be totally ordered (they are stored in a sorted map) and must
/// report how many argument terms follow them in a traversal.
pub trait Symbol
where
    Self: Ord,
{
    /// Number of argument sub-terms this symbol takes.
    fn arity(&self) -> usize;
}
19
20fn find_next_term<S: Symbol>(key: &[Option<S>], mut index: usize) -> usize {
21    let mut remaining = 1;
22    while remaining > 0 {
23        remaining +=
24            key[index].as_ref().map(|s| s.arity()).unwrap_or_default();
25        remaining -= 1;
26        index += 1;
27        if index > key.len() {
28            report_bad_state()
29        }
30    }
31    index
32}
33
/// Terminal node of the tree: holds the value stored under one complete key.
#[derive(Clone, Debug)]
struct Leaf<T> {
    /// The stored value.
    data: T,
}
38
39#[derive(Debug, Clone)]
40struct Branch<S, T> {
41    symbols: SortedMap<S, Node<S, T>>,
42    variable: Option<Box<Node<S, T>>>,
43}
44
45impl<S, T> Default for Branch<S, T> {
46    fn default() -> Self {
47        Self {
48            symbols: SortedMap::default(),
49            variable: None,
50        }
51    }
52}
53
54impl<S: Ord, T> Branch<S, T> {
55    pub(crate) fn get(&self, item: &Option<S>) -> Option<&Node<S, T>> {
56        if let Some(symbol) = item {
57            self.symbols.get(symbol)
58        } else {
59            self.variable.as_deref()
60        }
61    }
62
63    pub(crate) fn get_or_insert_empty_branch(
64        &mut self,
65        item: Option<S>,
66    ) -> (bool, &mut Node<S, T>) {
67        let mut inserted = false;
68        let mut empty_branch = || {
69            inserted = true;
70            Node::Branch(Branch::default())
71        };
72        let node = if let Some(symbol) = item {
73            self.symbols.get_or_insert_with(symbol, empty_branch)
74        } else {
75            self.variable
76                .get_or_insert_with(|| Box::new(empty_branch()))
77        };
78        (inserted, node)
79    }
80}
81
82#[derive(Debug, Clone)]
83enum Node<S, T> {
84    Leaf(Leaf<T>),
85    Branch(Branch<S, T>),
86}
87
88struct Results<'a, S, T> {
89    found: Vec<&'a T>,
90    todo: Vec<(&'a Branch<S, T>, usize)>,
91    skip: Vec<(&'a Branch<S, T>, usize, usize)>,
92    generalise: bool,
93    instantiate: bool,
94}
95
96impl<'a, S: Symbol, T> Results<'a, S, T> {
97    fn add_todo(
98        &mut self,
99        key: &[Option<S>],
100        node: &'a Node<S, T>,
101        index: usize,
102    ) {
103        match node {
104            Node::Leaf(leaf) => {
105                if index != key.len() {
106                    report_bad_state()
107                }
108                self.found.push(&leaf.data);
109            }
110            Node::Branch(branch) => self.todo.push((branch, index)),
111        }
112    }
113
114    fn add_skip(
115        &mut self,
116        key: &[Option<S>],
117        node: &'a Node<S, T>,
118        index: usize,
119        remaining: usize,
120    ) {
121        if remaining == 0 {
122            self.add_todo(key, node, index);
123        } else if let Node::Branch(branch) = node {
124            self.skip.push((branch, index, remaining));
125        } else {
126            report_bad_state()
127        }
128    }
129
130    fn do_skip_symbols(
131        &mut self,
132        key: &[Option<S>],
133        branch: &'a Branch<S, T>,
134        index: usize,
135        remaining: usize,
136    ) {
137        let remaining = remaining - 1;
138        for (symbol, node) in branch.symbols.iter() {
139            let remaining = remaining + symbol.arity();
140            self.add_skip(key, node, index, remaining);
141        }
142    }
143
144    fn do_skip(
145        &mut self,
146        key: &[Option<S>],
147        branch: &'a Branch<S, T>,
148        index: usize,
149        remaining: usize,
150    ) {
151        self.do_skip_symbols(key, branch, index, remaining);
152        if let Some(variable) = branch.variable.as_deref() {
153            self.add_skip(key, variable, index, remaining - 1);
154        }
155    }
156
157    fn do_todo(
158        &mut self,
159        key: &[Option<S>],
160        branch: &'a Branch<S, T>,
161        index: usize,
162    ) {
163        if index >= key.len() {
164            report_bad_state()
165        }
166        let head = &key[index];
167        // exact matches: f = f, * = *
168        if let Some(node) = branch.get(head) {
169            self.add_todo(key, node, index + 1);
170        }
171        // generalisations
172        if self.generalise && head.is_some() {
173            if let Some(node) = branch.variable.as_deref() {
174                self.add_todo(key, node, find_next_term(key, index));
175            }
176        }
177        // instantiations
178        if self.instantiate && head.is_none() {
179            self.do_skip_symbols(key, branch, index + 1, 1);
180        }
181    }
182
183    fn next(&mut self, key: &[Option<S>]) -> Option<&'a T> {
184        loop {
185            if let Some(found) = self.found.pop() {
186                return Some(found);
187            }
188            if let Some((branch, index)) = self.todo.pop() {
189                self.do_todo(key, branch, index);
190            } else if let Some((branch, index, remaining)) = self.skip.pop() {
191                self.do_skip(key, branch, index, remaining);
192            } else {
193                return None;
194            }
195        }
196    }
197}
198
199#[derive(Debug, Clone)]
200pub struct DiscriminationTree<S, T> {
201    root: Branch<S, T>,
202}
203
204impl<S, T> Default for DiscriminationTree<S, T> {
205    fn default() -> Self {
206        Self {
207            root: Branch::default(),
208        }
209    }
210}
211
212impl<S: Symbol, T> DiscriminationTree<S, T> {
213    pub fn get_or_insert_with<
214        I: IntoIterator<Item = Option<S>>,
215        F: FnOnce() -> T,
216    >(
217        &mut self,
218        key: I,
219        insert: F,
220    ) -> &mut T {
221        let mut current = &mut self.root;
222        let mut remaining = 1;
223        for item in key {
224            remaining -= 1;
225            remaining += item.as_ref().map(|s| s.arity()).unwrap_or_default();
226
227            let (inserted, node) = current.get_or_insert_empty_branch(item);
228            if remaining == 0 {
229                if inserted {
230                    *node = Node::Leaf(Leaf { data: insert() });
231                    if let Node::Leaf(leaf) = node {
232                        return &mut leaf.data;
233                    } else {
234                        unreachable!();
235                    }
236                } else if let Node::Leaf(leaf) = node {
237                    return &mut leaf.data;
238                } else {
239                    report_bad_state()
240                }
241            }
242            if let Node::Branch(branch) = node {
243                current = branch;
244            } else {
245                report_bad_state()
246            }
247        }
248        report_bad_state()
249    }
250
251    pub fn query<I: IntoIterator<Item = Option<S>>>(
252        &self,
253        key: I,
254        generalise: bool,
255        instantiate: bool,
256    ) -> impl Iterator<Item = &T> {
257        let key = key.into_iter().collect::<Vec<_>>();
258        let mut results = Results {
259            found: vec![],
260            todo: vec![(&self.root, 0)],
261            skip: vec![],
262            generalise,
263            instantiate,
264        };
265        std::iter::from_fn(move || results.next(&key))
266    }
267}