// pinpointer/labeling.rs

1use geo::{BooleanOps, Contains, CoordsIter, Intersects, MultiPolygon, Point, Rect};
2use plotters::{
3    prelude::{BitMapBackend, ChartBuilder, IntoDrawingArea},
4    series::LineSeries,
5    style::{BLACK, RED, WHITE},
6};
7use std::{collections::HashMap, hash::Hash, path::Path};
8
9/// A struct representing a labeled partition tree.
10///
11/// This structure is used for performing fast point-in-polygon queries by recursively checking 
12/// bounding boxes before performing the final point-in-polygon check.
13#[derive(serde::Serialize, serde::Deserialize)]
14pub struct LabeledPartitionTree<T: Eq + Hash> {
15    children: Box<Vec<LabeledPartitionTree<T>>>,
16    polygons: HashMap<T, MultiPolygon>,
17    bbox: Rect,
18}
19
20impl<T: Clone + Eq + Hash> LabeledPartitionTree<T> {
21    /// Constructs a labeled partition tree from a set of labeled polygons.
22    ///
23    /// # Arguments
24    /// * `selected` - The labels of the polygons to be included in the tree.
25    /// * `polygons` - A map of labels to their corresponding polygons.
26    /// * `bbox` - The bounding box for the current partition.
27    /// * `max_depth` - The maximum depth of the tree. Deeper trees tend to result in faster queries, 
28    ///                 but take much longer to construct.
29    /// * `depth` - The current depth during recursion.
30    pub fn from_labeled_polygons(
31        selected: &Vec<T>,
32        polygons: &HashMap<T, MultiPolygon>,
33        bbox: Rect,
34        max_depth: usize,
35        depth: usize,
36    ) -> LabeledPartitionTree<T> {
37        let (children, inner_polygons) = if depth == max_depth {
38            (
39                Box::new(vec![]),
40                selected
41                    .iter()
42                    .map(|label| {
43                        (
44                            label.clone(),
45                            polygons
46                                .get(label)
47                                .unwrap()
48                                .intersection(&MultiPolygon::from(bbox)), // TODO this intersection is slow
49                        )
50                    })
51                    .collect(),
52            )
53        } else if selected.len() == 0 {
54            (Box::new(vec![]), HashMap::new())
55        } else if selected.len() == 1 && polygons.get(&selected[0]).unwrap().contains(&bbox) {
56            // TODO the check for this is slow
57            (
58                Box::new(vec![]),
59                vec![(selected[0].clone(), MultiPolygon::from(bbox))]
60                    .into_iter()
61                    .collect(),
62            )
63        } else {
64            // TODO check if a different branching factor can speed things up
65            let [ab, cd] = bbox.split_x();
66            let [a, b] = ab.split_y();
67            let [c, d] = cd.split_y();
68            let bboxes = vec![a, b, c, d];
69
70            let bbox_selected_polygons: Vec<Vec<T>> = bboxes
71                .iter()
72                .map(|bbox| {
73                    // TODO it might be possible to speed up this intersection check
74                    selected
75                        .iter()
76                        .filter(|&label| bbox.intersects(polygons.get(label).unwrap()))
77                        .cloned()
78                        .collect()
79                })
80                .collect();
81
82            (
83                Box::new(
84                    bbox_selected_polygons
85                        .iter()
86                        .zip(bboxes)
87                        .map(|(selected, bbox)| {
88                            LabeledPartitionTree::from_labeled_polygons(
89                                selected,
90                                polygons,
91                                bbox,
92                                max_depth,
93                                depth + 1,
94                            )
95                        })
96                        .collect(),
97                ),
98                HashMap::new(),
99            )
100        };
101
102        LabeledPartitionTree {
103            children,
104            bbox,
105            polygons: inner_polygons,
106        }
107    }
108
109    /// Returns the label of the partition that contains the given point.
110    ///
111    /// This method recursively searches for the leaf node that contains the point and returns its label.
112    /// If no leaf node contains the point, `None` is returned.
113    /// 
114    /// # Arguments
115    /// * `point` - The point to check.
116    pub fn label(&self, point: &Point) -> Option<T> {
117        if self.children.is_empty() {
118            self.polygons.iter().find_map(|(label, polygon)| {
119                if polygon.contains(point) {
120                    Some(label.clone())
121                } else {
122                    None
123                }
124            })
125        } else {
126            self.children
127                .iter()
128                .filter(|child| child.bbox.contains(point))
129                .find_map(|child| child.label(point))
130        }
131    }
132
133    /// Returns the number of leaf nodes in the tree.
134    pub fn size(&self) -> usize {
135        if self.children.is_empty() {
136            1
137        } else {
138            self.children.iter().map(|child| child.size()).sum()
139        }
140    }
141
142    /// Plots the labeled partition tree and saves the image to the specified path.
143    ///
144    /// # Arguments
145    /// * `out_path` - The path where the resulting image will be saved.
146    pub fn plot(&self, out_path: &Path) -> Result<(), Box<dyn std::error::Error>> {
147        let root = BitMapBackend::new(out_path, (4000, 3000)).into_drawing_area();
148        root.fill(&WHITE)?;
149        let mut chart = ChartBuilder::on(&root)
150            .margin(5)
151            .x_label_area_size(30)
152            .y_label_area_size(30)
153            .build_cartesian_2d(-180f32..180f32, -90f32..90f32)?;
154
155        chart.configure_mesh().draw()?;
156
157        let bboxes = self.bboxes();
158        bboxes.iter().for_each(|bbox| {
159            chart
160                .draw_series(LineSeries::new(
161                    bbox.coords_iter()
162                        .map(|coord| (coord.x as f32, coord.y as f32)),
163                    &RED,
164                ))
165                .unwrap();
166        });
167
168        chart
169            .configure_series_labels()
170            .background_style(&WHITE)
171            .border_style(&BLACK)
172            .draw()?;
173
174        root.present()?;
175        Ok(())
176    }
177
178    /// Returns a vector of bounding boxes for all leaf nodes in the labeled partition tree.
179    fn bboxes(&self) -> Vec<Rect> {
180        if self.children.is_empty() {
181            vec![self.bbox]
182        } else {
183            self.children
184                .iter()
185                .map(|child| child.bboxes())
186                .flatten()
187                .collect()
188        }
189    }
190}
191