1#[cfg(test)]
2mod tests;
3mod util;
4
5use crate::util::SortedMap;
6
/// Panics with a diagnostic listing the likely causes of an inconsistent
/// key or tree.
///
/// Called whenever a traversal reaches a state that well-formed keys and
/// consistently used symbols should never produce.
///
/// # Panics
///
/// Always; this function never returns.
fn report_bad_state() -> ! {
    panic!(concat!(
        "bad state detected - this could mean:\n",
        " 1. inserted keys are not proper traversals of well-formed terms, or\n",
        " 2. identical symbols have been used with different arities",
    ))
}
15
/// A symbol that can appear as an item of a flattened term key.
///
/// Symbols must be totally ordered because branches look their children up
/// in a map keyed by symbol.
pub trait Symbol
where
    Self: Ord,
{
    /// Number of argument subterms that follow this symbol in a key.
    ///
    /// The same symbol must always report the same arity; otherwise key
    /// traversal ends up in a bad state.
    fn arity(&self) -> usize;
}
19
20fn find_next_term<S: Symbol>(key: &[Option<S>], mut index: usize) -> usize {
21 let mut remaining = 1;
22 while remaining > 0 {
23 remaining +=
24 key[index].as_ref().map(|s| s.arity()).unwrap_or_default();
25 remaining -= 1;
26 index += 1;
27 if index > key.len() {
28 report_bad_state()
29 }
30 }
31 index
32}
33
/// Terminal node of the tree: the value stored under one complete key.
#[derive(Debug, Clone)]
struct Leaf<T> {
    /// The stored value.
    data: T,
}
38
39#[derive(Debug, Clone)]
40struct Branch<S, T> {
41 symbols: SortedMap<S, Node<S, T>>,
42 variable: Option<Box<Node<S, T>>>,
43}
44
45impl<S, T> Default for Branch<S, T> {
46 fn default() -> Self {
47 Self {
48 symbols: SortedMap::default(),
49 variable: None,
50 }
51 }
52}
53
54impl<S: Ord, T> Branch<S, T> {
55 pub(crate) fn get(&self, item: &Option<S>) -> Option<&Node<S, T>> {
56 if let Some(symbol) = item {
57 self.symbols.get(symbol)
58 } else {
59 self.variable.as_deref()
60 }
61 }
62
63 pub(crate) fn get_or_insert_empty_branch(
64 &mut self,
65 item: Option<S>,
66 ) -> (bool, &mut Node<S, T>) {
67 let mut inserted = false;
68 let mut empty_branch = || {
69 inserted = true;
70 Node::Branch(Branch::default())
71 };
72 let node = if let Some(symbol) = item {
73 self.symbols.get_or_insert_with(symbol, empty_branch)
74 } else {
75 self.variable
76 .get_or_insert_with(|| Box::new(empty_branch()))
77 };
78 (inserted, node)
79 }
80}
81
82#[derive(Debug, Clone)]
83enum Node<S, T> {
84 Leaf(Leaf<T>),
85 Branch(Branch<S, T>),
86}
87
88struct Results<'a, S, T> {
89 found: Vec<&'a T>,
90 todo: Vec<(&'a Branch<S, T>, usize)>,
91 skip: Vec<(&'a Branch<S, T>, usize, usize)>,
92 generalise: bool,
93 instantiate: bool,
94}
95
96impl<'a, S: Symbol, T> Results<'a, S, T> {
97 fn add_todo(
98 &mut self,
99 key: &[Option<S>],
100 node: &'a Node<S, T>,
101 index: usize,
102 ) {
103 match node {
104 Node::Leaf(leaf) => {
105 if index != key.len() {
106 report_bad_state()
107 }
108 self.found.push(&leaf.data);
109 }
110 Node::Branch(branch) => self.todo.push((branch, index)),
111 }
112 }
113
114 fn add_skip(
115 &mut self,
116 key: &[Option<S>],
117 node: &'a Node<S, T>,
118 index: usize,
119 remaining: usize,
120 ) {
121 if remaining == 0 {
122 self.add_todo(key, node, index);
123 } else if let Node::Branch(branch) = node {
124 self.skip.push((branch, index, remaining));
125 } else {
126 report_bad_state()
127 }
128 }
129
130 fn do_skip_symbols(
131 &mut self,
132 key: &[Option<S>],
133 branch: &'a Branch<S, T>,
134 index: usize,
135 remaining: usize,
136 ) {
137 let remaining = remaining - 1;
138 for (symbol, node) in branch.symbols.iter() {
139 let remaining = remaining + symbol.arity();
140 self.add_skip(key, node, index, remaining);
141 }
142 }
143
144 fn do_skip(
145 &mut self,
146 key: &[Option<S>],
147 branch: &'a Branch<S, T>,
148 index: usize,
149 remaining: usize,
150 ) {
151 self.do_skip_symbols(key, branch, index, remaining);
152 if let Some(variable) = branch.variable.as_deref() {
153 self.add_skip(key, variable, index, remaining - 1);
154 }
155 }
156
157 fn do_todo(
158 &mut self,
159 key: &[Option<S>],
160 branch: &'a Branch<S, T>,
161 index: usize,
162 ) {
163 if index >= key.len() {
164 report_bad_state()
165 }
166 let head = &key[index];
167 if let Some(node) = branch.get(head) {
169 self.add_todo(key, node, index + 1);
170 }
171 if self.generalise && head.is_some() {
173 if let Some(node) = branch.variable.as_deref() {
174 self.add_todo(key, node, find_next_term(key, index));
175 }
176 }
177 if self.instantiate && head.is_none() {
179 self.do_skip_symbols(key, branch, index + 1, 1);
180 }
181 }
182
183 fn next(&mut self, key: &[Option<S>]) -> Option<&'a T> {
184 loop {
185 if let Some(found) = self.found.pop() {
186 return Some(found);
187 }
188 if let Some((branch, index)) = self.todo.pop() {
189 self.do_todo(key, branch, index);
190 } else if let Some((branch, index, remaining)) = self.skip.pop() {
191 self.do_skip(key, branch, index, remaining);
192 } else {
193 return None;
194 }
195 }
196 }
197}
198
199#[derive(Debug, Clone)]
200pub struct DiscriminationTree<S, T> {
201 root: Branch<S, T>,
202}
203
204impl<S, T> Default for DiscriminationTree<S, T> {
205 fn default() -> Self {
206 Self {
207 root: Branch::default(),
208 }
209 }
210}
211
212impl<S: Symbol, T> DiscriminationTree<S, T> {
213 pub fn get_or_insert_with<
214 I: IntoIterator<Item = Option<S>>,
215 F: FnOnce() -> T,
216 >(
217 &mut self,
218 key: I,
219 insert: F,
220 ) -> &mut T {
221 let mut current = &mut self.root;
222 let mut remaining = 1;
223 for item in key {
224 remaining -= 1;
225 remaining += item.as_ref().map(|s| s.arity()).unwrap_or_default();
226
227 let (inserted, node) = current.get_or_insert_empty_branch(item);
228 if remaining == 0 {
229 if inserted {
230 *node = Node::Leaf(Leaf { data: insert() });
231 if let Node::Leaf(leaf) = node {
232 return &mut leaf.data;
233 } else {
234 unreachable!();
235 }
236 } else if let Node::Leaf(leaf) = node {
237 return &mut leaf.data;
238 } else {
239 report_bad_state()
240 }
241 }
242 if let Node::Branch(branch) = node {
243 current = branch;
244 } else {
245 report_bad_state()
246 }
247 }
248 report_bad_state()
249 }
250
251 pub fn query<I: IntoIterator<Item = Option<S>>>(
252 &self,
253 key: I,
254 generalise: bool,
255 instantiate: bool,
256 ) -> impl Iterator<Item = &T> {
257 let key = key.into_iter().collect::<Vec<_>>();
258 let mut results = Results {
259 found: vec![],
260 todo: vec![(&self.root, 0)],
261 skip: vec![],
262 generalise,
263 instantiate,
264 };
265 std::iter::from_fn(move || results.next(&key))
266 }
267}