geo_index/kdtree/
builder.rs

1use std::cmp;
2
3use bytemuck::cast_slice_mut;
4use geo_traits::{CoordTrait, PointTrait};
5
6use crate::error::Result;
7use crate::indices::MutableIndices;
8use crate::kdtree::constants::{KDBUSH_HEADER_SIZE, KDBUSH_MAGIC, KDBUSH_VERSION};
9use crate::kdtree::index::KDTreeMetadata;
10use crate::kdtree::KDTree;
11use crate::r#type::IndexableNum;
12use crate::GeoIndexError;
13
/// Node size that [`KDTreeBuilder::new`] uses when the caller does not pick one explicitly.
pub const DEFAULT_KDTREE_NODE_SIZE: u16 = 64;
16
17/// A builder to create an [`KDTree`].
18#[derive(Debug)]
19pub struct KDTreeBuilder<N: IndexableNum> {
20    /// data buffer
21    data: Vec<u8>,
22    metadata: KDTreeMetadata<N>,
23    pos: usize,
24}
25
26impl<N: IndexableNum> KDTreeBuilder<N> {
27    /// Create a new builder with the provided number of items and the default node size.
28    pub fn new(num_items: u32) -> Self {
29        Self::new_with_node_size(num_items, DEFAULT_KDTREE_NODE_SIZE)
30    }
31
32    /// Create a new builder with the provided number of items and node size.
33    pub fn new_with_node_size(num_items: u32, node_size: u16) -> Self {
34        let metadata = KDTreeMetadata::new(num_items, node_size);
35        Self::from_metadata(metadata)
36    }
37
38    /// Create a new builder with the provided metadata
39    pub fn from_metadata(metadata: KDTreeMetadata<N>) -> Self {
40        let data_buffer_length = metadata.data_buffer_length();
41        let mut data = vec![0; data_buffer_length];
42
43        // Set data header;
44        data[0] = KDBUSH_MAGIC;
45        data[1] = (KDBUSH_VERSION << 4) + N::TYPE_INDEX;
46        cast_slice_mut(&mut data[2..4])[0] = metadata.node_size();
47        cast_slice_mut(&mut data[4..8])[0] = metadata.num_items();
48
49        Self {
50            data,
51            pos: 0,
52            metadata,
53        }
54    }
55
56    /// Access the underlying [KDTreeMetadata] of this instance.
57    pub fn metadata(&self) -> &KDTreeMetadata<N> {
58        &self.metadata
59    }
60
61    /// Add a point to the KDTree.
62    ///
63    /// This returns a positional index that provides a lookup back into the original data.
64    #[inline]
65    pub fn add(&mut self, x: N, y: N) -> u32 {
66        let index = self.pos >> 1;
67        let (coords, mut ids) = split_data_borrow(&mut self.data, self.metadata);
68
69        ids.set(index, index);
70        coords[self.pos] = x;
71        self.pos += 1;
72        coords[self.pos] = y;
73        self.pos += 1;
74
75        index.try_into().unwrap()
76    }
77
78    /// Add a coord to the KDTree.
79    ///
80    /// This returns a positional index that provides a lookup back into the original data.
81    #[inline]
82    pub fn add_coord(&mut self, coord: &impl CoordTrait<T = N>) -> u32 {
83        self.add(coord.x(), coord.y())
84    }
85
86    /// Add a point to the KDTree.
87    ///
88    /// This returns a positional index that provides a lookup back into the original data.
89    ///
90    /// ## Errors
91    ///
92    /// - If the point is empty.
93    #[inline]
94    pub fn add_point(&mut self, point: &impl PointTrait<T = N>) -> Result<u32> {
95        let coord = point.coord().ok_or(GeoIndexError::General(
96            "Unable to add empty point to KDTree".to_string(),
97        ))?;
98        Ok(self.add_coord(&coord))
99    }
100
101    /// Consume this builder, perfoming the k-d sort and generating a KDTree ready for queries.
102    pub fn finish(mut self) -> KDTree<N> {
103        assert_eq!(
104            self.pos >> 1,
105            self.metadata.num_items() as usize,
106            "Added {} items when expected {}.",
107            self.pos >> 1,
108            self.metadata.num_items()
109        );
110
111        let (coords, mut ids) = split_data_borrow::<N>(&mut self.data, self.metadata);
112
113        // kd-sort both arrays for efficient search
114        sort(
115            &mut ids,
116            coords,
117            self.metadata.node_size() as usize,
118            0,
119            self.metadata.num_items() as usize - 1,
120            0,
121        );
122
123        KDTree {
124            buffer: self.data,
125            metadata: self.metadata,
126        }
127    }
128}
129
130/// Mutable borrow of coords and ids
131fn split_data_borrow<N: IndexableNum>(
132    data: &mut [u8],
133    metadata: KDTreeMetadata<N>,
134) -> (&mut [N], MutableIndices) {
135    let (ids_buf, padded_coords_buf) =
136        data[KDBUSH_HEADER_SIZE..].split_at_mut(metadata.indices_byte_size);
137    let coords_buf = &mut padded_coords_buf[metadata.pad_coords_byte_size..];
138    debug_assert_eq!(coords_buf.len(), metadata.coords_byte_size);
139
140    let ids = if metadata.num_items() < 65536 {
141        MutableIndices::U16(cast_slice_mut(ids_buf))
142    } else {
143        MutableIndices::U32(cast_slice_mut(ids_buf))
144    };
145    let coords = cast_slice_mut(coords_buf);
146
147    (coords, ids)
148}
149
150fn sort<N: IndexableNum>(
151    ids: &mut MutableIndices,
152    coords: &mut [N],
153    node_size: usize,
154    left: usize,
155    right: usize,
156    axis: usize,
157) {
158    if right - left <= node_size {
159        return;
160    }
161
162    // middle index
163    let m = (left + right) >> 1;
164
165    // sort ids and coords around the middle index so that the halves lie either left/right or
166    // top/bottom correspondingly (taking turns)
167    select(ids, coords, m, left, right, axis);
168
169    // recursively kd-sort first half and second half on the opposite axis
170    sort(ids, coords, node_size, left, m - 1, 1 - axis);
171    sort(ids, coords, node_size, m + 1, right, 1 - axis);
172}
173
174/// Custom Floyd-Rivest selection algorithm: sort ids and coords so that [left..k-1] items are
175/// smaller than k-th item (on either x or y axis)
176#[inline]
177fn select<N: IndexableNum>(
178    ids: &mut MutableIndices,
179    coords: &mut [N],
180    k: usize,
181    mut left: usize,
182    mut right: usize,
183    axis: usize,
184) {
185    while right > left {
186        if right - left > 600 {
187            let n = (right - left + 1) as f64;
188            let m = (k - left + 1) as f64;
189            let z = f64::ln(n);
190            let s = 0.5 * f64::exp((2.0 * z) / 3.0);
191            let sd = 0.5
192                * f64::sqrt((z * s * (n - s)) / n)
193                * (if m - n / 2.0 < 0.0 { -1.0 } else { 1.0 });
194            let new_left = cmp::max(left, f64::floor(k as f64 - (m * s) / n + sd) as usize);
195            let new_right = cmp::min(
196                right,
197                f64::floor(k as f64 + ((n - m) * s) / n + sd) as usize,
198            );
199            select(ids, coords, k, new_left, new_right, axis);
200        }
201
202        let t = coords[2 * k + axis];
203        let mut i = left;
204        let mut j = right;
205
206        swap_item(ids, coords, left, k);
207        if coords[2 * right + axis] > t {
208            swap_item(ids, coords, left, right);
209        }
210
211        while i < j {
212            swap_item(ids, coords, i, j);
213            i += 1;
214            j -= 1;
215            while coords[2 * i + axis] < t {
216                i += 1;
217            }
218            while coords[2 * j + axis] > t {
219                j -= 1;
220            }
221        }
222
223        if coords[2 * left + axis] == t {
224            swap_item(ids, coords, left, j);
225        } else {
226            j += 1;
227            swap_item(ids, coords, j, right);
228        }
229
230        if j <= k {
231            left = j + 1;
232        }
233        if k <= j {
234            right = j - 1;
235        }
236    }
237}
238
239#[inline]
240fn swap_item<N: IndexableNum>(ids: &mut MutableIndices, coords: &mut [N], i: usize, j: usize) {
241    ids.swap(i, j);
242    coords.swap(2 * i, 2 * j);
243    coords.swap(2 * i + 1, 2 * j + 1);
244}