// srtree/shape/reshape.rs

1use super::{point::Point, rect::Rect, sphere::Sphere};
2use crate::{measure::distance::Metric, SRTree};
3use ordered_float::{Float, OrderedFloat};
4
5impl<T, M> SRTree<T, M>
6where
7    T: Float + Send + Sync,
8    M: Metric<T>,
9{
10    pub fn reshape(&mut self, node_index: usize) {
11        let centroid = Point::new(self.calculate_mean(node_index), node_index);
12        let node = &self.nodes[node_index];
13
14        let mut max_distance = T::zero();
15        let mut low = centroid.coords.clone();
16        let mut high = centroid.coords.clone();
17        if node.is_leaf() {
18            let mut points = Vec::with_capacity(node.points().len());
19            for point_index in node.points() {
20                let point = &self.points[*point_index];
21                for i in 0..low.len() {
22                    low[i] = low[i].min(point.coords[i]);
23                    high[i] = high[i].max(point.coords[i]);
24                }
25                let distance_to_point = self.distance(&centroid, point);
26                max_distance = max_distance.max(distance_to_point);
27                points.push((distance_to_point, *point_index));
28            }
29
30            for (distance, point_index) in &points {
31                self.points[*point_index].radius = *distance;
32                self.points[*point_index].parent_index = node_index;
33            }
34
35            points.sort_by_key(|(distance, _)| -OrderedFloat(*distance));
36            let points: Vec<usize> = points.into_iter().map(|(_, index)| index).collect();
37            self.nodes[node_index].set_points(points);
38        } else {
39            node.children().iter().for_each(|child_index| {
40                let child = &self.nodes[*child_index];
41                for i in 0..self.params.dimension {
42                    low[i] = low[i].min(child.rect.low[i]);
43                    high[i] = high[i].max(child.rect.high[i]);
44                }
45                let distance = self.point_to_node_max_distance(&centroid, child);
46                max_distance = max_distance.max(distance);
47            });
48        }
49
50        let node = &mut self.nodes[node_index];
51        node.rect = Rect::new(low, high);
52        node.sphere = Sphere::new(centroid, max_distance);
53    }
54}
55
#[cfg(test)]
mod tests {
    // Brings `Float::sqrt` into scope for the untyped float literal below.
    use num_traits::Float;

    use crate::SRTree;

    /// Builds a tree over five points on the diagonal, (1,1) through (5,5).
    /// Checks that the root is bounded by the two end points, is centred on
    /// their mean (3,3), and has a radius equal to the distance from that
    /// centre to either end.
    #[test]
    pub fn test_reshape() {
        let diagonal: Vec<Vec<f64>> = (1..=5).map(|k| vec![k as f64; 2]).collect();
        let tree = SRTree::euclidean(&diagonal).unwrap();
        let root = &tree.nodes[0];

        assert_eq!(root.rect.low, vec![1., 1.]);
        assert_eq!(root.rect.high, vec![5., 5.]);
        assert_eq!(root.sphere.center.coords, vec![3., 3.]);
        // The end points lie at offset (2, 2) from the centre.
        let expected_radius = (4. + 4.).sqrt();
        assert_eq!(root.sphere.radius, expected_radius);
    }
}