Skip to main content

geo_coding/
tree.rs

1use alloc::collections::VecDeque;
2use alloc::string::String;
3use alloc::vec::Vec;
4
5#[cfg(feature = "std")]
6mod io;
7#[cfg(feature = "std")]
8mod read;
9#[cfg(feature = "std")]
10mod write;
11
/// Sentinel child index meaning "no subtree here".
///
/// Real node indices are 1-based, so `0` never refers to an actual node.
const EMPTY: u32 = u32::MIN;
13
/// One entry of a [`Tree2D`]: a location, its value, and links to the two child subtrees.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
struct Node<C, V> {
    /// Position of this node; also defines the splitting line for its children.
    location: [C; 2],
    /// Payload stored at `location`.
    value: V,
    /// 1-based index of the child subtree on the "less than or equal" side of the split,
    /// or `EMPTY` when there is none.
    lesser_index: u32,
    /// 1-based index of the child subtree on the "greater than or equal" side of the split,
    /// or `EMPTY` when there is none.
    greater_index: u32,
}
21
22/// Two-dimensional tree that maps a location given by `[i64; 2]` to a string.
23pub type NamesTree = Tree2D<i64, String>;
24
25/// Two-dimensional tree that maps a location given by `[C; 2]` to a value `V`.
26///
27/// # References
28///
29/// - <https://en.wikipedia.org/wiki/K-d_tree>
30#[derive(Debug, PartialEq, Eq)]
31pub struct Tree2D<C, V> {
32    nodes: Vec<Node<C, V>>,
33}
34
35impl<C: Ord + Copy + Default, V: Default> Tree2D<C, V> {
36    /// Create a new tree from the given nodes.
37    ///
38    /// The nodes are recursively subdividded into two groups: one that is behind and one that is
39    /// in front of the plane that goes through the median node.
40    /// The plane alternates between _x = 0_ and _y = 0_ for each layer of the tree.
41    ///
42    /// The values are moved from the vector without copying.
43    pub fn from_nodes(mut nodes: Vec<([C; 2], V)>) -> Self {
44        assert!(nodes.len() < u32::MAX as usize);
45        let mut output_nodes = Vec::with_capacity(nodes.len());
46        for _ in 0..nodes.len() {
47            output_nodes.push(Node {
48                location: Default::default(),
49                value: Default::default(),
50                lesser_index: EMPTY,
51                greater_index: EMPTY,
52            });
53        }
54        let mut output_node_index: u32 = 1;
55        let mut next_output_node_index = || {
56            let i = output_node_index;
57            output_node_index += 1;
58            i
59        };
60        let mut queue = VecDeque::new();
61        queue.push_back((0, next_output_node_index(), nodes.as_mut_slice()));
62        while let Some((coord_index, i, nodes)) = queue.pop_front() {
63            let next_coord_index = (coord_index + 1) % 2;
64            let nodes_len = nodes.len();
65            if nodes_len == 0 {
66                break;
67            }
68            if nodes_len == 1 {
69                output_nodes[(i - 1) as usize] = Node {
70                    location: nodes[0].0,
71                    value: core::mem::take(&mut nodes[0].1),
72                    lesser_index: EMPTY,
73                    greater_index: EMPTY,
74                };
75                continue;
76            }
77            let (lesser_nodes, median, greater_nodes) = nodes
78                .select_nth_unstable_by(nodes_len / 2, |a, b| {
79                    a.0[coord_index].cmp(&b.0[coord_index])
80                });
81            let lesser_index = if !lesser_nodes.is_empty() {
82                let i = next_output_node_index();
83                queue.push_back((next_coord_index, i, lesser_nodes));
84                i
85            } else {
86                EMPTY
87            };
88            let greater_index = if !greater_nodes.is_empty() {
89                let i = next_output_node_index();
90                queue.push_back((next_coord_index, i, greater_nodes));
91                i
92            } else {
93                EMPTY
94            };
95            output_nodes[(i - 1) as usize] = Node {
96                location: median.0,
97                value: core::mem::take(&mut median.1),
98                lesser_index,
99                greater_index,
100            };
101        }
102        Self {
103            nodes: output_nodes,
104        }
105    }
106
107    /// Returns up to `max_neighbours` nodes within `max_distance` that are closest to the `location`.
108    ///
109    /// The distance between nodes is computed using `calc_distance`.
110    pub fn find_nearest<D>(
111        &self,
112        location: &[C; 2],
113        mut max_distance: D,
114        max_neighbours: usize,
115        mut calc_distance: impl FnMut(&[C; 2], &[C; 2]) -> D,
116    ) -> Vec<(D, &[C; 2], &V)>
117    where
118        D: Ord + Copy + core::fmt::Display,
119    {
120        let mut neighbours = Vec::new();
121        if max_neighbours == 0 {
122            return neighbours;
123        }
124        // TODO optimize for max_neighbours == 1
125        let Some(root) = self.nodes.first() else {
126            return neighbours;
127        };
128        let mut queue = VecDeque::new();
129        queue.push_back((0, root));
130        while let Some((coord_index, node)) = queue.pop_front() {
131            let d = calc_distance(&node.location, location);
132            let mut lesser = false;
133            let mut greater = false;
134            if d.le(&max_distance) {
135                match neighbours.binary_search_by(|(distance, ..)| distance.cmp(&d)) {
136                    Err(i) if i == max_neighbours => {}
137                    Ok(i) | Err(i) => {
138                        if neighbours.len() == max_neighbours {
139                            neighbours.pop();
140                        }
141                        neighbours.insert(i, (d, &node.location, &node.value));
142                    }
143                }
144                if neighbours.len() == max_neighbours {
145                    // We've already found enough neighbours; now we can limit our search to the
146                    // ones that are closer than the closest one found so far.
147                    max_distance = neighbours[0].0;
148                }
149                lesser = true;
150                greater = true;
151            } else if location[coord_index] < node.location[coord_index] {
152                lesser = true;
153            } else {
154                greater = true;
155            }
156            let next_coord_index = (coord_index + 1) % 2;
157            if lesser && node.lesser_index != EMPTY {
158                queue.push_back((
159                    next_coord_index,
160                    &self.nodes[(node.lesser_index - 1) as usize],
161                ));
162            }
163            if greater && node.greater_index != EMPTY {
164                queue.push_back((
165                    next_coord_index,
166                    &self.nodes[(node.greater_index - 1) as usize],
167                ));
168            }
169        }
170        neighbours
171    }
172
173    /// Returns an iterator over nodes.
174    pub fn iter(&self) -> impl Iterator<Item = (&[C; 2], &V)> {
175        self.nodes.iter().map(|node| (&node.location, &node.value))
176    }
177
178    /// Returns the number of nodes.
179    pub fn len(&self) -> usize {
180        self.nodes.len()
181    }
182
183    /// Returns `true` if the tree is empty.
184    pub fn is_empty(&self) -> bool {
185        self.nodes.is_empty()
186    }
187}
188
189#[cfg(test)]
190mod tests {
191    use super::*;
192    use crate::euclidean_distance_squared;
193    use alloc::vec;
194
195    #[test]
196    fn tree_works() {
197        let tree = Tree2D::from_nodes(vec![
198            ([0_i64, 0], ()), //
199            ([-1, 0], ()),    //
200            ([1, 0], ()),     //
201            ([2, 0], ()),     //
202            ([3, 0], ()),     //
203        ]);
204        let neighbours = tree.find_nearest(&[5, 0], 25_u64, 1, euclidean_distance_squared);
205        assert_eq!(vec![(4, &[3, 0], &())], neighbours);
206    }
207}