graph_based_image_segmentation/segmentation/
segmentation.rs1use crate::graph::{ImageEdge, ImageGraph, ImageNode, ImageNodeColor};
2use crate::segmentation::{Distance, NodeMerging};
3use opencv::core::{Scalar, Vec3b, CV_32SC1};
4use opencv::prelude::*;
5
6#[derive(Debug)]
9pub struct Segmentation<D, M>
10where
11 D: Distance,
12 M: NodeMerging,
13{
14 height: usize,
16 width: usize,
18 graph: ImageGraph,
20 distance: D,
22 magic: M,
24 #[allow(dead_code)]
26 segment_size: usize,
27}
28
29impl<D, M> Segmentation<D, M>
30where
31 D: Distance,
32 M: NodeMerging,
33{
34 pub fn new(distance: D, magic: M, segment_size: usize) -> Self {
35 Self {
36 distance,
37 magic,
38 height: 0,
39 width: 0,
40 segment_size,
41 graph: ImageGraph::default(),
42 }
43 }
44
45 pub fn segment_image(&mut self, image: &Mat) -> (Mat, usize) {
58 self.build_graph(&image);
59 self.oversegment_graph();
60 self.enforce_minimum_segment_size(10);
61 let segmentation = self.derive_labels();
62 let num_nodes = self.graph.num_components();
63 (segmentation, num_nodes)
64 }
65
66 fn build_graph(&mut self, image: &Mat) {
73 assert_eq!(image.empty(), false, "image must not be empty");
74 self.height = image.rows() as usize;
75 self.width = image.cols() as usize;
76 self.graph = self.init_graph_nodes(&image);
77 self.init_graph_edges();
78 }
79
80 fn init_graph_nodes(&mut self, image: &Mat) -> ImageGraph {
82 debug_assert_ne!(self.height, 0);
83 debug_assert_ne!(self.width, 0);
84 let width = self.width;
85 let height = self.height;
86 let node_count = height * width;
87 let graph = ImageGraph::new_with_nodes(node_count);
88
89 for i in 0..height {
90 for j in 0..width {
91 let node_index = width * i + j;
92 let node = graph.node_at(node_index);
93 let node_color = graph.node_color_at(node_index);
94
95 let bgr = image.at_2d::<Vec3b>(i as i32, j as i32).unwrap().0;
96 node_color.set(ImageNodeColor {
97 b: bgr[0],
98 g: bgr[1],
99 r: bgr[2],
100 });
101
102 node.set(ImageNode {
104 label: node_index,
105 id: node_index,
106 n: 1,
107 ..Default::default()
108 });
109 }
110 }
111
112 graph
113 }
114
115 fn init_graph_edges(&mut self) {
117 debug_assert_ne!(self.height, 0);
118 debug_assert_ne!(self.width, 0);
119 let height = self.height;
120 let width = self.width;
121 let graph = &mut self.graph;
122 let distance = &self.distance;
123
124 let mut edges = Vec::new();
125
126 for i in 0..(height - 1) {
127 for j in 0..(width - 1) {
128 let node_index = width * i + j;
129 let node = graph.node_color_at(node_index).get();
130
131 let other_index = width * i + (j + 1);
133 let other = graph.node_color_at(other_index).get();
134 let weight = distance.distance(&node, &other);
135 let edge = ImageEdge::new(node_index, other_index, weight);
136 edges.push(edge);
137
138 let other_index = width * (i + 1) + j;
140 let other = graph.node_color_at(other_index).get();
141 let weight = distance.distance(&node, &other);
142 let edge = ImageEdge::new(node_index, other_index, weight);
143 edges.push(edge);
144 }
145 }
146
147 graph.clear_edges();
148 graph.add_edges(edges.into_iter());
149 }
150
151 fn oversegment_graph(&mut self) {
153 let graph = &mut self.graph;
154 assert_ne!(graph.num_edges(), 0, "number of edges must be nonzero");
155
156 graph.sort_edges();
157
158 for e in 0..graph.num_edges() {
159 debug_assert_eq!(e % graph.num_edges(), e);
160 let edge = graph.edge_at(e).get();
161
162 let s_n_idx = graph.find_node_component_at(edge.n);
163 let s_m_idx = graph.find_node_component_at(edge.m);
164
165 if s_n_idx == s_m_idx {
166 continue;
167 }
168
169 let mut s_n = graph.node_at(s_n_idx);
170 let mut s_m = graph.node_at(s_m_idx);
171
172 let should_merge = self.magic.should_merge(&s_n, &s_m, &edge);
174 if should_merge {
175 graph.merge(&mut s_n, &mut s_m, &edge);
176 }
177 }
178 }
179
180 fn enforce_minimum_segment_size(&mut self, segment_size: usize) {
186 let graph = &mut self.graph;
187 assert_ne!(graph.num_nodes(), 0, "number of nodes must be nonzero");
188
189 for e in 0..graph.num_edges() {
190 let edge = graph.edge_at(e).get();
191
192 let s_n_idx = graph.find_node_component_at(edge.n);
193 let s_m_idx = graph.find_node_component_at(edge.m);
194
195 if s_n_idx == s_m_idx {
196 continue;
197 }
198
199 let mut s_n = graph.node_at(s_n_idx);
200 let mut s_m = graph.node_at(s_m_idx);
201
202 let lhs = s_n.get();
203 let rhs = s_m.get();
204
205 debug_assert_ne!(lhs.label, rhs.label);
207
208 let segment_too_small = lhs.n < segment_size || rhs.n < segment_size;
209 if segment_too_small {
210 graph.merge(&mut s_n, &mut s_m, &edge);
211 }
212 }
213 }
214
215 fn derive_labels(&self) -> Mat {
221 let mut labels = Mat::new_rows_cols_with_default(
222 self.height as i32,
223 self.width as i32,
224 CV_32SC1,
225 Scalar::from(0f64),
226 )
227 .unwrap();
228
229 for i in 0..self.height {
230 for j in 0..self.width {
231 let n = self.width * i + j;
232
233 let index = self.graph.find_node_component_at(n);
234 let id = self.graph.node_id_at(index) as i32;
235
236 *(labels.at_2d_mut(i as i32, j as i32).unwrap()) = id;
237 }
238 }
239
240 labels
241 }
242}