geo_index/kdtree/trait.rs

1use geo_traits::{CoordTrait, RectTrait};
2use tinyvec::TinyVec;
3
4use crate::indices::Indices;
5use crate::kdtree::{KDTree, KDTreeMetadata, KDTreeRef, Node};
6use crate::r#type::IndexableNum;
7
8/// A trait for searching and accessing data out of a KDTree.
9pub trait KDTreeIndex<N: IndexableNum>: Sized {
10    /// The underlying raw coordinate buffer of this tree
11    fn coords(&self) -> &[N];
12
13    /// The underlying raw indices buffer of this tree
14    fn indices(&self) -> Indices;
15
16    /// Access the metadata describing this KDTree
17    fn metadata(&self) -> &KDTreeMetadata<N>;
18
19    /// The number of items in this KDTree
20    fn num_items(&self) -> u32 {
21        self.metadata().num_items()
22    }
23
24    /// The node size of this KDTree
25    fn node_size(&self) -> u16 {
26        self.metadata().node_size()
27    }
28
29    /// Search the index for items within a given bounding box.
30    ///
31    /// - min_x: bbox
32    /// - min_y: bbox
33    /// - max_x: bbox
34    /// - max_y: bbox
35    ///
36    /// Returns indices of found items
37    fn range(&self, min_x: N, min_y: N, max_x: N, max_y: N) -> Vec<u32> {
38        let indices = self.indices();
39        let coords = self.coords();
40        let node_size = self.node_size();
41
42        // Use TinyVec to avoid heap allocations
43        let mut stack: TinyVec<[usize; 33]> = TinyVec::new();
44        stack.push(0);
45        stack.push(indices.len() - 1);
46        stack.push(0);
47
48        let mut result: Vec<u32> = vec![];
49
50        // recursively search for items in range in the kd-sorted arrays
51        while !stack.is_empty() {
52            let axis = stack.pop().unwrap_or(0);
53            let right = stack.pop().unwrap_or(0);
54            let left = stack.pop().unwrap_or(0);
55
56            // if we reached "tree node", search linearly
57            if right - left <= node_size as usize {
58                for i in left..right + 1 {
59                    let x = coords[2 * i];
60                    let y = coords[2 * i + 1];
61                    if x >= min_x && x <= max_x && y >= min_y && y <= max_y {
62                        result.push(indices.get(i).try_into().unwrap());
63                    }
64                }
65                continue;
66            }
67
68            // otherwise find the middle index
69            let m = (left + right) >> 1;
70
71            // include the middle item if it's in range
72            let x = coords[2 * m];
73            let y = coords[2 * m + 1];
74            if x >= min_x && x <= max_x && y >= min_y && y <= max_y {
75                result.push(indices.get(m).try_into().unwrap());
76            }
77
78            // queue search in halves that intersect the query
79            let lte = if axis == 0 { min_x <= x } else { min_y <= y };
80            if lte {
81                // Note: these are pushed in backwards order to what gets popped
82                stack.push(left);
83                stack.push(m - 1);
84                stack.push(1 - axis);
85            }
86
87            let gte = if axis == 0 { max_x >= x } else { max_y >= y };
88            if gte {
89                // Note: these are pushed in backwards order to what gets popped
90                stack.push(m + 1);
91                stack.push(right);
92                stack.push(1 - axis);
93            }
94        }
95
96        result
97    }
98
99    /// Search the index for items within a given bounding box.
100    ///
101    /// Returns indices of found items
102    fn range_rect(&self, rect: &impl RectTrait<T = N>) -> Vec<u32> {
103        self.range(
104            rect.min().x(),
105            rect.min().y(),
106            rect.max().x(),
107            rect.max().y(),
108        )
109    }
110
111    /// Search the index for items within a given radius.
112    ///
113    /// - qx: x value of query point
114    /// - qy: y value of query point
115    /// - r: radius
116    ///
117    /// Returns indices of found items
118    fn within(&self, qx: N, qy: N, r: N) -> Vec<u32> {
119        let indices = self.indices();
120        let coords = self.coords();
121        let node_size = self.node_size();
122
123        // Use TinyVec to avoid heap allocations
124        let mut stack: TinyVec<[usize; 33]> = TinyVec::new();
125        stack.push(0);
126        stack.push(indices.len() - 1);
127        stack.push(0);
128
129        let mut result: Vec<u32> = vec![];
130        let r2 = r * r;
131
132        // recursively search for items within radius in the kd-sorted arrays
133        while !stack.is_empty() {
134            let axis = stack.pop().unwrap_or(0);
135            let right = stack.pop().unwrap_or(0);
136            let left = stack.pop().unwrap_or(0);
137
138            // if we reached "tree node", search linearly
139            if right - left <= node_size as usize {
140                for i in left..right + 1 {
141                    if sq_dist(coords[2 * i], coords[2 * i + 1], qx, qy) <= r2 {
142                        result.push(indices.get(i).try_into().unwrap());
143                    }
144                }
145                continue;
146            }
147
148            // otherwise find the middle index
149            let m = (left + right) >> 1;
150
151            // include the middle item if it's in range
152            let x = coords[2 * m];
153            let y = coords[2 * m + 1];
154            if sq_dist(x, y, qx, qy) <= r2 {
155                result.push(indices.get(m).try_into().unwrap());
156            }
157
158            // queue search in halves that intersect the query
159            let lte = if axis == 0 { qx - r <= x } else { qy - r <= y };
160            if lte {
161                stack.push(left);
162                stack.push(m - 1);
163                stack.push(1 - axis);
164            }
165
166            let gte = if axis == 0 { qx + r >= x } else { qy + r >= y };
167            if gte {
168                stack.push(m + 1);
169                stack.push(right);
170                stack.push(1 - axis);
171            }
172        }
173        result
174    }
175
176    /// Search the index for items within a given radius.
177    ///
178    /// - coord: coordinate of query point
179    /// - r: radius
180    ///
181    /// Returns indices of found items
182    fn within_coord(&self, coord: &impl CoordTrait<T = N>, r: N) -> Vec<u32> {
183        self.within(coord.x(), coord.y(), r)
184    }
185
186    /// Access the root node of the KDTree for manual traversal.
187    fn root(&self) -> Node<'_, N, Self> {
188        Node::from_root(self)
189    }
190}
191
192impl<N: IndexableNum> KDTreeIndex<N> for KDTree<N> {
193    fn coords(&self) -> &[N] {
194        self.metadata.coords_slice(&self.buffer)
195    }
196
197    fn indices(&self) -> Indices {
198        self.metadata.indices_slice(&self.buffer)
199    }
200
201    fn metadata(&self) -> &KDTreeMetadata<N> {
202        &self.metadata
203    }
204}
205
206impl<N: IndexableNum> KDTreeIndex<N> for KDTreeRef<'_, N> {
207    fn coords(&self) -> &[N] {
208        self.coords
209    }
210
211    fn indices(&self) -> Indices {
212        self.indices
213    }
214
215    fn metadata(&self) -> &KDTreeMetadata<N> {
216        &self.metadata
217    }
218}
219
220#[inline]
221pub(crate) fn sq_dist<N: IndexableNum>(ax: N, ay: N, bx: N, by: N) -> N {
222    let dx = ax - bx;
223    let dy = ay - by;
224    dx * dx + dy * dy
225}