pinpointer/labeling.rs
1use geo::{BooleanOps, Contains, CoordsIter, Intersects, MultiPolygon, Point, Rect};
2use plotters::{
3 prelude::{BitMapBackend, ChartBuilder, IntoDrawingArea},
4 series::LineSeries,
5 style::{BLACK, RED, WHITE},
6};
7use std::{collections::HashMap, hash::Hash, path::Path};
8
/// A struct representing a labeled partition tree.
///
/// This structure is used for performing fast point-in-polygon queries by recursively checking
/// bounding boxes before performing the final point-in-polygon check.
///
/// A node is either an inner node (non-empty `children`, empty `polygons`) or a leaf
/// (empty `children`); the query and traversal methods tell the two apart by checking whether
/// `children` is empty.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct LabeledPartitionTree<T: Eq + Hash> {
    /// Sub-partitions of this node's bounding box. Empty for leaf nodes; the constructor
    /// creates four children (one per quadrant) for every inner node.
    children: Box<Vec<LabeledPartitionTree<T>>>,
    /// Geometry stored at a leaf, keyed by label. The constructor fills this with each selected
    /// polygon clipped to `bbox`, or with `bbox` itself when a single polygon covers the whole
    /// cell. Inner nodes keep this map empty.
    polygons: HashMap<T, MultiPolygon>,
    /// The rectangular region covered by this node.
    bbox: Rect,
}
19
20impl<T: Clone + Eq + Hash> LabeledPartitionTree<T> {
21 /// Constructs a labeled partition tree from a set of labeled polygons.
22 ///
23 /// # Arguments
24 /// * `selected` - The labels of the polygons to be included in the tree.
25 /// * `polygons` - A map of labels to their corresponding polygons.
26 /// * `bbox` - The bounding box for the current partition.
27 /// * `max_depth` - The maximum depth of the tree. Deeper trees tend to result in faster queries,
28 /// but take much longer to construct.
29 /// * `depth` - The current depth during recursion.
30 pub fn from_labeled_polygons(
31 selected: &Vec<T>,
32 polygons: &HashMap<T, MultiPolygon>,
33 bbox: Rect,
34 max_depth: usize,
35 depth: usize,
36 ) -> LabeledPartitionTree<T> {
37 let (children, inner_polygons) = if depth == max_depth {
38 (
39 Box::new(vec![]),
40 selected
41 .iter()
42 .map(|label| {
43 (
44 label.clone(),
45 polygons
46 .get(label)
47 .unwrap()
48 .intersection(&MultiPolygon::from(bbox)), // TODO this intersection is slow
49 )
50 })
51 .collect(),
52 )
53 } else if selected.len() == 0 {
54 (Box::new(vec![]), HashMap::new())
55 } else if selected.len() == 1 && polygons.get(&selected[0]).unwrap().contains(&bbox) {
56 // TODO the check for this is slow
57 (
58 Box::new(vec![]),
59 vec![(selected[0].clone(), MultiPolygon::from(bbox))]
60 .into_iter()
61 .collect(),
62 )
63 } else {
64 // TODO check if a different branching factor can speed things up
65 let [ab, cd] = bbox.split_x();
66 let [a, b] = ab.split_y();
67 let [c, d] = cd.split_y();
68 let bboxes = vec![a, b, c, d];
69
70 let bbox_selected_polygons: Vec<Vec<T>> = bboxes
71 .iter()
72 .map(|bbox| {
73 // TODO it might be possible to speed up this intersection check
74 selected
75 .iter()
76 .filter(|&label| bbox.intersects(polygons.get(label).unwrap()))
77 .cloned()
78 .collect()
79 })
80 .collect();
81
82 (
83 Box::new(
84 bbox_selected_polygons
85 .iter()
86 .zip(bboxes)
87 .map(|(selected, bbox)| {
88 LabeledPartitionTree::from_labeled_polygons(
89 selected,
90 polygons,
91 bbox,
92 max_depth,
93 depth + 1,
94 )
95 })
96 .collect(),
97 ),
98 HashMap::new(),
99 )
100 };
101
102 LabeledPartitionTree {
103 children,
104 bbox,
105 polygons: inner_polygons,
106 }
107 }
108
109 /// Returns the label of the partition that contains the given point.
110 ///
111 /// This method recursively searches for the leaf node that contains the point and returns its label.
112 /// If no leaf node contains the point, `None` is returned.
113 ///
114 /// # Arguments
115 /// * `point` - The point to check.
116 pub fn label(&self, point: &Point) -> Option<T> {
117 if self.children.is_empty() {
118 self.polygons.iter().find_map(|(label, polygon)| {
119 if polygon.contains(point) {
120 Some(label.clone())
121 } else {
122 None
123 }
124 })
125 } else {
126 self.children
127 .iter()
128 .filter(|child| child.bbox.contains(point))
129 .find_map(|child| child.label(point))
130 }
131 }
132
133 /// Returns the number of leaf nodes in the tree.
134 pub fn size(&self) -> usize {
135 if self.children.is_empty() {
136 1
137 } else {
138 self.children.iter().map(|child| child.size()).sum()
139 }
140 }
141
142 /// Plots the labeled partition tree and saves the image to the specified path.
143 ///
144 /// # Arguments
145 /// * `out_path` - The path where the resulting image will be saved.
146 pub fn plot(&self, out_path: &Path) -> Result<(), Box<dyn std::error::Error>> {
147 let root = BitMapBackend::new(out_path, (4000, 3000)).into_drawing_area();
148 root.fill(&WHITE)?;
149 let mut chart = ChartBuilder::on(&root)
150 .margin(5)
151 .x_label_area_size(30)
152 .y_label_area_size(30)
153 .build_cartesian_2d(-180f32..180f32, -90f32..90f32)?;
154
155 chart.configure_mesh().draw()?;
156
157 let bboxes = self.bboxes();
158 bboxes.iter().for_each(|bbox| {
159 chart
160 .draw_series(LineSeries::new(
161 bbox.coords_iter()
162 .map(|coord| (coord.x as f32, coord.y as f32)),
163 &RED,
164 ))
165 .unwrap();
166 });
167
168 chart
169 .configure_series_labels()
170 .background_style(&WHITE)
171 .border_style(&BLACK)
172 .draw()?;
173
174 root.present()?;
175 Ok(())
176 }
177
178 /// Returns a vector of bounding boxes for all leaf nodes in the labeled partition tree.
179 fn bboxes(&self) -> Vec<Rect> {
180 if self.children.is_empty() {
181 vec![self.bbox]
182 } else {
183 self.children
184 .iter()
185 .map(|child| child.bboxes())
186 .flatten()
187 .collect()
188 }
189 }
190}
191