1use std::cmp;
2
3use bytemuck::cast_slice_mut;
4use geo_traits::{CoordTrait, PointTrait};
5
6use crate::error::Result;
7use crate::indices::MutableIndices;
8use crate::kdtree::constants::{KDBUSH_HEADER_SIZE, KDBUSH_MAGIC, KDBUSH_VERSION};
9use crate::kdtree::index::KDTreeMetadata;
10use crate::kdtree::KDTree;
11use crate::r#type::IndexableNum;
12use crate::GeoIndexError;
13
/// Leaf node size used by `KDTreeBuilder::new` when no explicit size is given.
pub const DEFAULT_KDTREE_NODE_SIZE: u16 = 64;
16
17#[derive(Debug)]
19pub struct KDTreeBuilder<N: IndexableNum> {
20 data: Vec<u8>,
22 metadata: KDTreeMetadata<N>,
23 pos: usize,
24}
25
26impl<N: IndexableNum> KDTreeBuilder<N> {
27 pub fn new(num_items: u32) -> Self {
29 Self::new_with_node_size(num_items, DEFAULT_KDTREE_NODE_SIZE)
30 }
31
32 pub fn new_with_node_size(num_items: u32, node_size: u16) -> Self {
34 let metadata = KDTreeMetadata::new(num_items, node_size);
35 Self::from_metadata(metadata)
36 }
37
38 pub fn from_metadata(metadata: KDTreeMetadata<N>) -> Self {
40 let data_buffer_length = metadata.data_buffer_length();
41 let mut data = vec![0; data_buffer_length];
42
43 data[0] = KDBUSH_MAGIC;
45 data[1] = (KDBUSH_VERSION << 4) + N::TYPE_INDEX;
46 cast_slice_mut(&mut data[2..4])[0] = metadata.node_size();
47 cast_slice_mut(&mut data[4..8])[0] = metadata.num_items();
48
49 Self {
50 data,
51 pos: 0,
52 metadata,
53 }
54 }
55
56 pub fn metadata(&self) -> &KDTreeMetadata<N> {
58 &self.metadata
59 }
60
61 #[inline]
65 pub fn add(&mut self, x: N, y: N) -> u32 {
66 let index = self.pos >> 1;
67 let (coords, mut ids) = split_data_borrow(&mut self.data, self.metadata);
68
69 ids.set(index, index);
70 coords[self.pos] = x;
71 self.pos += 1;
72 coords[self.pos] = y;
73 self.pos += 1;
74
75 index.try_into().unwrap()
76 }
77
78 #[inline]
82 pub fn add_coord(&mut self, coord: &impl CoordTrait<T = N>) -> u32 {
83 self.add(coord.x(), coord.y())
84 }
85
86 #[inline]
94 pub fn add_point(&mut self, point: &impl PointTrait<T = N>) -> Result<u32> {
95 let coord = point.coord().ok_or(GeoIndexError::General(
96 "Unable to add empty point to KDTree".to_string(),
97 ))?;
98 Ok(self.add_coord(&coord))
99 }
100
101 pub fn finish(mut self) -> KDTree<N> {
103 assert_eq!(
104 self.pos >> 1,
105 self.metadata.num_items() as usize,
106 "Added {} items when expected {}.",
107 self.pos >> 1,
108 self.metadata.num_items()
109 );
110
111 let (coords, mut ids) = split_data_borrow::<N>(&mut self.data, self.metadata);
112
113 sort(
115 &mut ids,
116 coords,
117 self.metadata.node_size() as usize,
118 0,
119 self.metadata.num_items() as usize - 1,
120 0,
121 );
122
123 KDTree {
124 buffer: self.data,
125 metadata: self.metadata,
126 }
127 }
128}
129
130fn split_data_borrow<N: IndexableNum>(
132 data: &mut [u8],
133 metadata: KDTreeMetadata<N>,
134) -> (&mut [N], MutableIndices) {
135 let (ids_buf, padded_coords_buf) =
136 data[KDBUSH_HEADER_SIZE..].split_at_mut(metadata.indices_byte_size);
137 let coords_buf = &mut padded_coords_buf[metadata.pad_coords_byte_size..];
138 debug_assert_eq!(coords_buf.len(), metadata.coords_byte_size);
139
140 let ids = if metadata.num_items() < 65536 {
141 MutableIndices::U16(cast_slice_mut(ids_buf))
142 } else {
143 MutableIndices::U32(cast_slice_mut(ids_buf))
144 };
145 let coords = cast_slice_mut(coords_buf);
146
147 (coords, ids)
148}
149
150fn sort<N: IndexableNum>(
151 ids: &mut MutableIndices,
152 coords: &mut [N],
153 node_size: usize,
154 left: usize,
155 right: usize,
156 axis: usize,
157) {
158 if right - left <= node_size {
159 return;
160 }
161
162 let m = (left + right) >> 1;
164
165 select(ids, coords, m, left, right, axis);
168
169 sort(ids, coords, node_size, left, m - 1, 1 - axis);
171 sort(ids, coords, node_size, m + 1, right, 1 - axis);
172}
173
/// Floyd–Rivest selection over the items in `left..=right`, ordered by coordinate `axis`.
///
/// On return, position `k` holds the item that would be there if the range were
/// fully sorted along `axis`. Items in `left..k` are `<=` it and items in
/// `k + 1..=right` are `>=` it along that axis.
///
/// `coords` stores two values per item: coordinate `axis` of item `i` is at
/// `coords[2 * i + axis]`. Every move goes through `swap_item`, so `ids` stays
/// aligned with `coords`.
#[inline]
fn select<N: IndexableNum>(
    ids: &mut MutableIndices,
    coords: &mut [N],
    k: usize,
    mut left: usize,
    mut right: usize,
    axis: usize,
) {
    while right > left {
        // For large ranges, first select `k` inside a smaller window around its
        // expected position. This puts a pivot close to the true k-th value at
        // `k` before the full range is partitioned below.
        if right - left > 600 {
            let n = (right - left + 1) as f64;
            let m = (k - left + 1) as f64;
            let z = f64::ln(n);
            let s = 0.5 * f64::exp((2.0 * z) / 3.0);
            let sd = 0.5
                * f64::sqrt((z * s * (n - s)) / n)
                * (if m - n / 2.0 < 0.0 { -1.0 } else { 1.0 });
            // `as usize` saturates negative values to 0. The max/min clamp the
            // window to the current range.
            let new_left = cmp::max(left, f64::floor(k as f64 - (m * s) / n + sd) as usize);
            let new_right = cmp::min(
                right,
                f64::floor(k as f64 + ((n - m) * s) / n + sd) as usize,
            );
            select(ids, coords, k, new_left, new_right, axis);
        }

        // Pivot value: the item currently at position `k`, along `axis`.
        let t = coords[2 * k + axis];
        let mut i = left;
        let mut j = right;

        // Move the pivot to `left`. If the right end is larger than the pivot,
        // swap the two ends. After the first swap inside the loop, the left end
        // is then <= t and the right end is >= t, which bounds the inner scans.
        swap_item(ids, coords, left, k);
        if coords[2 * right + axis] > t {
            swap_item(ids, coords, left, right);
        }

        // Hoare-style partition around `t`.
        while i < j {
            swap_item(ids, coords, i, j);
            i += 1;
            j -= 1;
            while coords[2 * i + axis] < t {
                i += 1;
            }
            while coords[2 * j + axis] > t {
                j -= 1;
            }
        }

        // A value equal to `t` now sits at `left` or at `right`, depending on
        // the initial swaps. Move it into its final slot; afterwards the pivot
        // is at `j`.
        if coords[2 * left + axis] == t {
            swap_item(ids, coords, left, j);
        } else {
            j += 1;
            swap_item(ids, coords, j, right);
        }

        // Continue only on the side that contains `k`. When `j == k`, both
        // updates apply, `right < left`, and the loop ends.
        // NOTE(review): `j - 1` underflows if `j == 0`, which requires `k == 0`.
        // `sort` only passes `k >= 1` when node_size >= 1. Confirm that the
        // metadata never allows a node size of 0.
        if j <= k {
            left = j + 1;
        }
        if k <= j {
            right = j - 1;
        }
    }
}
238
239#[inline]
240fn swap_item<N: IndexableNum>(ids: &mut MutableIndices, coords: &mut [N], i: usize, j: usize) {
241 ids.swap(i, j);
242 coords.swap(2 * i, 2 * j);
243 coords.swap(2 * i + 1, 2 * j + 1);
244}