1use super::{point::Point, rect::Rect, sphere::Sphere};
2use crate::{measure::distance::Metric, SRTree};
3use ordered_float::{Float, OrderedFloat};
4
5impl<T, M> SRTree<T, M>
6where
7 T: Float + Send + Sync,
8 M: Metric<T>,
9{
10 pub fn reshape(&mut self, node_index: usize) {
11 let centroid = Point::new(self.calculate_mean(node_index), node_index);
12 let node = &self.nodes[node_index];
13
14 let mut max_distance = T::zero();
15 let mut low = centroid.coords.clone();
16 let mut high = centroid.coords.clone();
17 if node.is_leaf() {
18 let mut points = Vec::with_capacity(node.points().len());
19 for point_index in node.points() {
20 let point = &self.points[*point_index];
21 for i in 0..low.len() {
22 low[i] = low[i].min(point.coords[i]);
23 high[i] = high[i].max(point.coords[i]);
24 }
25 let distance_to_point = self.distance(¢roid, point);
26 max_distance = max_distance.max(distance_to_point);
27 points.push((distance_to_point, *point_index));
28 }
29
30 for (distance, point_index) in &points {
31 self.points[*point_index].radius = *distance;
32 self.points[*point_index].parent_index = node_index;
33 }
34
35 points.sort_by_key(|(distance, _)| -OrderedFloat(*distance));
36 let points: Vec<usize> = points.into_iter().map(|(_, index)| index).collect();
37 self.nodes[node_index].set_points(points);
38 } else {
39 node.children().iter().for_each(|child_index| {
40 let child = &self.nodes[*child_index];
41 for i in 0..self.params.dimension {
42 low[i] = low[i].min(child.rect.low[i]);
43 high[i] = high[i].max(child.rect.high[i]);
44 }
45 let distance = self.point_to_node_max_distance(¢roid, child);
46 max_distance = max_distance.max(distance);
47 });
48 }
49
50 let node = &mut self.nodes[node_index];
51 node.rect = Rect::new(low, high);
52 node.sphere = Sphere::new(centroid, max_distance);
53 }
54}
55
#[cfg(test)]
mod tests {
    // Presumably imported so that `sqrt` resolves on the untyped float literal below.
    use num_traits::Float;

    use crate::SRTree;

    /// Builds a Euclidean tree over five points on the diagonal y = x, then checks the
    /// node stored at index 0:
    /// * its bounding box spans (1, 1) to (5, 5);
    /// * its sphere is centred at the mean (3, 3);
    /// * the sphere radius reaches the end points, i.e. sqrt(2² + 2²).
    #[test]
    pub fn test_reshape() {
        let diagonal = vec![
            vec![1.0, 1.0],
            vec![2.0, 2.0],
            vec![3.0, 3.0],
            vec![4.0, 4.0],
            vec![5.0, 5.0],
        ];
        let tree = SRTree::euclidean(&diagonal).unwrap();
        let first = &tree.nodes[0];

        let expected_radius = (4. + 4.).sqrt();

        assert_eq!(first.rect.low, vec![1., 1.]);
        assert_eq!(first.rect.high, vec![5., 5.]);
        assert_eq!(first.sphere.center.coords, vec![3., 3.]);
        assert_eq!(first.sphere.radius, expected_radius);
    }
}