/// A location in the plane, stored as a pair of `f64` coordinates.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Point2D {
    /// X coordinate.
    pub x: f64,
    /// Y coordinate.
    pub y: f64,
}

impl Point2D {
    /// Builds a point from its `x` and `y` coordinates.
    #[inline]
    pub fn new(x: f64, y: f64) -> Self {
        Point2D { x, y }
    }

    /// Squared Euclidean distance between `self` and `other`.
    ///
    /// The square root is never taken, which keeps comparisons cheap; callers needing
    /// the true distance must apply `sqrt` themselves.
    #[inline]
    pub fn distance_squared(&self, other: &Point2D) -> f64 {
        let (delta_x, delta_y) = (self.x - other.x, self.y - other.y);
        delta_x * delta_x + delta_y * delta_y
    }
}
/// A single entry of a `KdTree`: the stored point, its payload, and the indices
/// (into the tree's node vector) of its two child subtrees, if present.
struct KdNode<T> {
    /// Location of the entry; its coordinate on the node's axis is the split value.
    point: Point2D,
    /// Caller-supplied payload associated with `point`.
    data: T,
    /// Subtree holding entries that sorted before this one on the splitting axis.
    left: Option<usize>,
    /// Subtree holding entries that sorted after this one on the splitting axis.
    right: Option<usize>,
}
/// A two-dimensional k-d tree mapping points to payloads, answering nearest-neighbour
/// queries.
///
/// Nodes live in a flat `Vec` and refer to their children by index. The splitting axis
/// alternates with depth: x at even depths, y at odd depths. Build a populated tree with
/// [`KdTree::from_items`]; [`KdTree::new`] and [`Default`] give an empty one.
pub struct KdTree<T> {
    /// Node storage. During construction, children are pushed before their parent.
    nodes: Vec<KdNode<T>>,
    /// Index of the root in `nodes`, or `None` when the tree is empty.
    root: Option<usize>,
}

impl<T> KdTree<T> {
    /// Creates an empty tree; every query on it returns `None`.
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            root: None,
        }
    }

    /// Returns the payload of the point nearest to `query`, or `None` if the tree is empty.
    pub fn nearest_neighbor(&self, query: &Point2D) -> Option<&T> {
        self.nearest_neighbor_with_distance(query)
            .map(|(data, _)| data)
    }

    /// Returns the payload of the point nearest to `query` together with the **squared**
    /// Euclidean distance to it, or `None` if the tree is empty.
    ///
    /// Points whose distance is NaN (NaN coordinates in the point or the query) rank behind
    /// every point with a real distance. Such a point is returned only when no point has a
    /// real distance. Among equally distant points, the first one visited wins.
    pub fn nearest_neighbor_with_distance(&self, query: &Point2D) -> Option<(&T, f64)> {
        let root = self.root?;
        let mut best: Option<(usize, f64)> = None;
        self.search_nearest(root, query, 0, &mut best);
        best.map(|(idx, dist)| (&self.nodes[idx].data, dist))
    }

    /// Whether squared distance `candidate` beats `incumbent`.
    ///
    /// NaN counts as worse than any real value, so a NaN incumbent (for example a root
    /// with a NaN coordinate) is always displaced by a real distance. A plain `<` would
    /// leave it in place forever.
    fn is_closer(candidate: f64, incumbent: f64) -> bool {
        candidate < incumbent || (incumbent.is_nan() && !candidate.is_nan())
    }

    /// Branch-and-bound descent from `node_idx`, keeping the best
    /// `(node index, squared distance)` seen so far in `best`.
    fn search_nearest(
        &self,
        node_idx: usize,
        query: &Point2D,
        depth: usize,
        best: &mut Option<(usize, f64)>,
    ) {
        let node = &self.nodes[node_idx];
        let dist = query.distance_squared(&node.point);
        let improves = match *best {
            None => true,
            Some((_, best_dist)) => Self::is_closer(dist, best_dist),
        };
        if improves {
            *best = Some((node_idx, dist));
        }

        let (query_val, node_val) = if depth % 2 == 0 {
            (query.x, node.point.x)
        } else {
            (query.y, node.point.y)
        };
        let diff = query_val - node_val;
        // Descend first into the side of the splitting line that contains the query.
        let (near, far) = if diff <= 0.0 {
            (node.left, node.right)
        } else {
            (node.right, node.left)
        };
        if let Some(child) = near {
            self.search_nearest(child, query, depth + 1, best);
        }
        if let Some(child) = far {
            // Every point across the splitting line is at least `diff²` away. Skip that side
            // only when this bound provably cannot beat the current best. Comparisons with
            // NaN are false, so NaN never causes a subtree to be skipped.
            let plane_dist = diff * diff;
            let provably_farther = match *best {
                Some((_, best_dist)) => plane_dist >= best_dist,
                None => false,
            };
            if !provably_farther {
                self.search_nearest(child, query, depth + 1, best);
            }
        }
    }
}

impl<T> Default for KdTree<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> KdTree<T> {
    /// Builds a balanced tree from `(point, payload)` pairs.
    ///
    /// Each level sorts its slice on the current axis and makes the median the subtree
    /// root, so sibling subtrees differ in size by at most one. An empty input yields an
    /// empty tree.
    pub fn from_items(items: Vec<(Point2D, T)>) -> Self {
        if items.is_empty() {
            return Self::new();
        }
        let mut indexed: Vec<(usize, Point2D)> = items
            .iter()
            .enumerate()
            .map(|(i, (p, _))| (i, *p))
            .collect();
        let mut tree = Self {
            nodes: Vec::with_capacity(items.len()),
            root: None,
        };
        // Wrap each item in `Option` so the recursive build can move it out exactly once.
        let mut items: Vec<Option<(Point2D, T)>> = items.into_iter().map(Some).collect();
        tree.root = tree.build_from_indexed(&mut indexed, &mut items, 0);
        tree
    }

    /// Recursively builds the subtree for `indexed` (pairs of index into `items` and point).
    /// It moves each payload out of `items` into `self.nodes` and returns the subtree root's
    /// index, or `None` for an empty slice.
    fn build_from_indexed(
        &mut self,
        indexed: &mut [(usize, Point2D)],
        items: &mut [Option<(Point2D, T)>],
        depth: usize,
    ) -> Option<usize> {
        if indexed.is_empty() {
            return None;
        }
        let axis = depth % 2;
        indexed.sort_by(|a, b| {
            let (va, vb) = if axis == 0 {
                (a.1.x, b.1.x)
            } else {
                (a.1.y, b.1.y)
            };
            // Real values ascend as usual and NaN sorts after all of them. This keeps the
            // comparator a consistent total order. Treating NaN as `Equal` to everything
            // would break transitivity, mis-sort real values around it, and lets
            // `sort_by` panic on newer toolchains.
            va.partial_cmp(&vb)
                .unwrap_or_else(|| va.is_nan().cmp(&vb.is_nan()))
        });
        let mid = indexed.len() / 2;
        let (left_slice, rest) = indexed.split_at_mut(mid);
        let (mid_item, right_slice) = rest.split_first_mut().expect("not empty");
        let orig_idx = mid_item.0;
        let (point, data) = items[orig_idx].take().expect("item already taken");
        let left = self.build_from_indexed(left_slice, items, depth + 1);
        let right = self.build_from_indexed(right_slice, items, depth + 1);
        let node_idx = self.nodes.len();
        self.nodes.push(KdNode {
            point,
            data,
            left,
            right,
        });
        Some(node_idx)
    }
}
/// A directed line segment running from `from` to `to`.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Segment {
    /// Start point (parameter `t = 0`).
    pub from: Point2D,
    /// End point (parameter `t = 1`).
    pub to: Point2D,
}

impl Segment {
    /// Creates the segment from `from` to `to`.
    #[inline]
    pub fn new(from: Point2D, to: Point2D) -> Self {
        Self { from, to }
    }

    /// Midpoint of the segment.
    #[inline]
    pub fn centroid(&self) -> Point2D {
        Point2D {
            x: (self.from.x + self.to.x) / 2.0,
            y: (self.from.y + self.to.y) / 2.0,
        }
    }

    /// Projects `point` onto the segment.
    ///
    /// Returns the closest point on the segment and its parameter `t` in `[0, 1]`, where
    /// `t = 0` is `from` and `t = 1` is `to`. A zero-length segment projects every point
    /// onto `from` with `t = 0`.
    pub fn project_point(&self, point: &Point2D) -> (Point2D, f64) {
        let dx = self.to.x - self.from.x;
        let dy = self.to.y - self.from.y;
        let len_sq = dx * dx + dy * dy;
        // Only an exactly zero length is degenerate (the division below would be 0/0).
        // An absolute cutoff such as `f64::EPSILON` is scale-dependent: it would collapse
        // every segment shorter than about 1.5e-8 units onto `from`. Any positive `len_sq`
        // divides safely, and an oversized quotient is clamped.
        if len_sq == 0.0 {
            return (self.from, 0.0);
        }
        let t = ((point.x - self.from.x) * dx + (point.y - self.from.y) * dy) / len_sq;
        let t = t.clamp(0.0, 1.0);
        let proj = Point2D {
            x: self.from.x + t * dx,
            y: self.from.y + t * dy,
        };
        (proj, t)
    }
}
/// Axis-aligned box described by its componentwise-minimum corner `min` and
/// componentwise-maximum corner `max`.
#[derive(Debug, Copy, Clone)]
struct BoundingBox2D {
    min: Point2D,
    max: Point2D,
}

impl BoundingBox2D {
    /// Smallest box containing both endpoints of `segment`.
    fn from_segment(segment: &Segment) -> Self {
        let (start, end) = (segment.from, segment.to);
        let min = Point2D::new(start.x.min(end.x), start.y.min(end.y));
        let max = Point2D::new(start.x.max(end.x), start.y.max(end.y));
        Self { min, max }
    }

    /// Smallest box containing both `self` and `other`.
    fn union(self, other: Self) -> Self {
        let lo = Point2D::new(self.min.x.min(other.min.x), self.min.y.min(other.min.y));
        let hi = Point2D::new(self.max.x.max(other.max.x), self.max.y.max(other.max.y));
        Self { min: lo, max: hi }
    }

    /// Squared distance from `point` to the closest point of the box; `0.0` when the point
    /// lies within the box's extent on both axes.
    fn distance_squared_to_point(&self, point: &Point2D) -> f64 {
        let gap_x = Self::axis_gap(point.x, self.min.x, self.max.x);
        let gap_y = Self::axis_gap(point.y, self.min.y, self.max.y);
        gap_x * gap_x + gap_y * gap_y
    }

    /// Distance from `value` to the interval `[lo, hi]` on a single axis; `0.0` inside it.
    fn axis_gap(value: f64, lo: f64, hi: f64) -> f64 {
        if value < lo {
            lo - value
        } else if value > hi {
            value - hi
        } else {
            0.0
        }
    }
}
/// A single entry of a `SegmentIndex`: a stored segment with its payload, the indices of
/// its child subtrees, and a box enclosing every segment in the subtree rooted here.
struct SegmentNode<T> {
    /// The indexed segment.
    segment: Segment,
    /// Caller-supplied payload for `segment`.
    data: T,
    /// Box covering this node's segment and all segments beneath it; used for pruning.
    bounds: BoundingBox2D,
    /// Subtree of entries whose centroids sorted before this one on the splitting axis.
    left: Option<usize>,
    /// Subtree of entries whose centroids sorted after this one on the splitting axis.
    right: Option<usize>,
}
pub struct SegmentIndex<T> {
nodes: Vec<SegmentNode<T>>,
root: Option<usize>,
}
impl<T> SegmentIndex<T> {
pub fn bulk_load(segments: Vec<(Segment, T)>) -> Self {
let mut indexed: Vec<(usize, Point2D)> = segments
.iter()
.enumerate()
.map(|(i, (segment, _))| (i, segment.centroid()))
.collect();
let mut items: Vec<Option<(Segment, T)>> = segments.into_iter().map(Some).collect();
let mut nodes = Vec::with_capacity(items.len());
let root = Self::build_nodes(&mut indexed, &mut items, 0, &mut nodes);
Self { nodes, root }
}
pub fn nearest_segment(&self, query: &Point2D) -> Option<(&Segment, &T, Point2D, f64)> {
let root = self.root?;
let mut best: Option<(usize, Point2D, f64)> = None;
self.search_nearest(root, query, 0, &mut best);
let (idx, projection, dist) = best?;
let node = &self.nodes[idx];
Some((&node.segment, &node.data, projection, dist))
}
fn build_nodes(
indexed: &mut [(usize, Point2D)],
items: &mut [Option<(Segment, T)>],
depth: usize,
nodes: &mut Vec<SegmentNode<T>>,
) -> Option<usize> {
if indexed.is_empty() {
return None;
}
let axis = depth % 2;
indexed.sort_by(|a, b| {
let va = if axis == 0 { a.1.x } else { a.1.y };
let vb = if axis == 0 { b.1.x } else { b.1.y };
va.partial_cmp(&vb).unwrap_or(std::cmp::Ordering::Equal)
});
let mid = indexed.len() / 2;
let (left_slice, rest) = indexed.split_at_mut(mid);
let (mid_item, right_slice) = rest.split_first_mut().expect("not empty");
let left = Self::build_nodes(left_slice, items, depth + 1, nodes);
let right = Self::build_nodes(right_slice, items, depth + 1, nodes);
let (segment, data) = items[mid_item.0].take().expect("item already taken");
let bounds = BoundingBox2D::from_segment(&segment);
let idx = nodes.len();
nodes.push(SegmentNode {
segment,
data,
bounds,
left,
right,
});
let mut bounds = BoundingBox2D::from_segment(&nodes[idx].segment);
if let Some(left_idx) = left {
bounds = bounds.union(nodes[left_idx].bounds);
}
if let Some(right_idx) = right {
bounds = bounds.union(nodes[right_idx].bounds);
}
nodes[idx].bounds = bounds;
Some(idx)
}
fn search_nearest(
&self,
node_idx: usize,
query: &Point2D,
depth: usize,
best: &mut Option<(usize, Point2D, f64)>,
) {
let node = &self.nodes[node_idx];
let (projection, _) = node.segment.project_point(query);
let dist = query.distance_squared(&projection);
match best {
Some((_, _, best_dist)) if dist < *best_dist => {
*best = Some((node_idx, projection, dist));
}
None => {
*best = Some((node_idx, projection, dist));
}
_ => {}
}
let axis = depth % 2;
let centroid = node.segment.centroid();
let query_val = if axis == 0 { query.x } else { query.y };
let node_val = if axis == 0 { centroid.x } else { centroid.y };
let (first, second) = if query_val <= node_val {
(node.left, node.right)
} else {
(node.right, node.left)
};
if let Some(child) = first {
self.search_child(child, query, depth + 1, best);
}
if let Some(child) = second {
self.search_child(child, query, depth + 1, best);
}
}
fn search_child(
&self,
child_idx: usize,
query: &Point2D,
depth: usize,
best: &mut Option<(usize, Point2D, f64)>,
) {
if let Some((_, _, best_dist)) = best {
let child_dist = self.nodes[child_idx]
.bounds
.distance_squared_to_point(query);
if child_dist > *best_dist {
return;
}
}
self.search_nearest(child_idx, query, depth, best);
}
}
#[cfg(test)]
#[path = "spatial_tests.rs"]
mod tests;