Skip to main content

nodedb_spatial/rtree/
search.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! R-tree range search and nearest-neighbor queries.
4
5use nodedb_types::BoundingBox;
6use std::collections::BinaryHeap;
7
8use super::node::{EntryId, Node, NodeKind, RTreeEntry};
9
10/// Result from nearest-neighbor search.
11#[derive(Debug, Clone)]
12pub struct NnResult {
13    pub entry_id: EntryId,
14    pub bbox: BoundingBox,
15    /// Minimum distance in degrees (approximate).
16    pub distance: f64,
17}
18
19/// Recursive range search.
20pub(crate) fn search_node<'a>(
21    nodes: &'a [Node],
22    node_idx: usize,
23    query: &BoundingBox,
24    results: &mut Vec<&'a RTreeEntry>,
25) {
26    let node = &nodes[node_idx];
27    if !node.bbox.intersects(query) {
28        return;
29    }
30    match &node.kind {
31        NodeKind::Leaf { entries } => {
32            for entry in entries {
33                if entry.bbox.intersects(query) {
34                    results.push(entry);
35                }
36            }
37        }
38        NodeKind::Internal { children } => {
39            for child in children {
40                if child.bbox.intersects(query) {
41                    search_node(nodes, child.node_idx, query, results);
42                }
43            }
44        }
45    }
46}
47
48/// Nearest-neighbor search via priority queue (min-heap).
49pub(crate) fn nearest(
50    nodes: &[Node],
51    root: usize,
52    query_lng: f64,
53    query_lat: f64,
54    k: usize,
55    is_empty: bool,
56) -> Vec<NnResult> {
57    if k == 0 || is_empty {
58        return Vec::new();
59    }
60
61    let mut heap: BinaryHeap<HeapItem> = BinaryHeap::new();
62    let mut results: Vec<NnResult> = Vec::with_capacity(k);
63
64    heap.push(HeapItem {
65        dist: min_dist_point_bbox(query_lng, query_lat, &nodes[root].bbox),
66        node_idx: root,
67    });
68
69    while let Some(item) = heap.pop() {
70        if results.len() >= k && item.dist > results[k - 1].distance {
71            continue;
72        }
73        let node = &nodes[item.node_idx];
74        match &node.kind {
75            NodeKind::Internal { children } => {
76                for child in children {
77                    let d = min_dist_point_bbox(query_lng, query_lat, &child.bbox);
78                    if results.len() < k || d <= results[results.len() - 1].distance {
79                        heap.push(HeapItem {
80                            dist: d,
81                            node_idx: child.node_idx,
82                        });
83                    }
84                }
85            }
86            NodeKind::Leaf { entries } => {
87                for entry in entries {
88                    let d = min_dist_point_bbox(query_lng, query_lat, &entry.bbox);
89                    if results.len() < k || d < results[results.len() - 1].distance {
90                        let nn = NnResult {
91                            entry_id: entry.id,
92                            bbox: entry.bbox,
93                            distance: d,
94                        };
95                        insert_sorted(&mut results, nn, k);
96                    }
97                }
98            }
99        }
100    }
101
102    results
103}
104
/// Pending node in the best-first nearest-neighbor traversal.
///
/// Ordered so that [`BinaryHeap`] (a max-heap) pops the item with the *smallest* `dist` first.
/// Equality and ordering both go through [`f64::total_cmp`], so `Eq` and `Ord` agree and form a
/// total order even if a distance is NaN (as `BinaryHeap` requires). `node_idx` does not take
/// part in comparisons.
#[derive(Debug)]
struct HeapItem {
    /// Minimum distance from the query point to this node's bounding box.
    dist: f64,
    /// Index of the node in the `nodes` slice being searched.
    node_idx: usize,
}

impl PartialEq for HeapItem {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so equality can never disagree with the ordering.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for HeapItem {}

impl PartialOrd for HeapItem {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapItem {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Operands are reversed (other vs. self) to turn the max-heap into a min-heap on `dist`.
        other.dist.total_cmp(&self.dist)
    }
}
134
135/// Minimum distance from point to bbox (in degrees, approximate).
136fn min_dist_point_bbox(lng: f64, lat: f64, bbox: &BoundingBox) -> f64 {
137    let dlat = if lat < bbox.min_lat {
138        bbox.min_lat - lat
139    } else if lat > bbox.max_lat {
140        lat - bbox.max_lat
141    } else {
142        0.0
143    };
144
145    let dlng = if bbox.crosses_antimeridian() {
146        if lng >= bbox.min_lng || lng <= bbox.max_lng {
147            0.0
148        } else {
149            (bbox.min_lng - lng).min(lng - bbox.max_lng).max(0.0)
150        }
151    } else if lng < bbox.min_lng {
152        bbox.min_lng - lng
153    } else if lng > bbox.max_lng {
154        lng - bbox.max_lng
155    } else {
156        0.0
157    };
158
159    (dlat * dlat + dlng * dlng).sqrt()
160}
161
162fn insert_sorted(results: &mut Vec<NnResult>, item: NnResult, k: usize) {
163    let pos = results
164        .binary_search_by(|r| {
165            r.distance
166                .partial_cmp(&item.distance)
167                .unwrap_or(std::cmp::Ordering::Equal)
168        })
169        .unwrap_or_else(|pos| pos);
170    results.insert(pos, item);
171    if results.len() > k {
172        results.truncate(k);
173    }
174}