// graph_based_image_segmentation/segmentation/segmentation.rs

1use crate::graph::{ImageEdge, ImageGraph, ImageNode, ImageNodeColor};
2use crate::segmentation::{Distance, NodeMerging};
3use opencv::core::{Scalar, Vec3b, CV_32SC1};
4use opencv::prelude::*;
5
6/// Implementation of graph based image segmentation as described in the
7/// paper by Felzenswalb and Huttenlocher.
8#[derive(Debug)]
9pub struct Segmentation<D, M>
10where
11    D: Distance,
12    M: NodeMerging,
13{
14    /// Image height.
15    height: usize,
16    /// Image width.
17    width: usize,
18    /// The constructed and segmented image graph.
19    graph: ImageGraph,
20    /// The underlying distance to use.
21    distance: D,
22    /// The magic part of graph segmentation.
23    magic: M,
24    /// The minimum size of the segments, in pixels.
25    #[allow(dead_code)]
26    segment_size: usize,
27}
28
29impl<D, M> Segmentation<D, M>
30where
31    D: Distance,
32    M: NodeMerging,
33{
34    pub fn new(distance: D, magic: M, segment_size: usize) -> Self {
35        Self {
36            distance,
37            magic,
38            height: 0,
39            width: 0,
40            segment_size,
41            graph: ImageGraph::default(),
42        }
43    }
44
45    /// Build the graph based on the image, i.e. compute the weights
46    /// between pixels using the underlying distance.
47    ///
48    /// # Arguments
49    ///
50    /// * `image` - The image to oversegment.
51    ///
52    /// # Returns
53    ///
54    /// A tuple consisting of
55    /// - The matrix in `CV_32SC1` format containing the labels for each pixel.
56    /// - The number of segments / components.
57    pub fn segment_image(&mut self, image: &Mat) -> (Mat, usize) {
58        self.build_graph(&image);
59        self.oversegment_graph();
60        self.enforce_minimum_segment_size(10);
61        let segmentation = self.derive_labels();
62        let num_nodes = self.graph.num_components();
63        (segmentation, num_nodes)
64    }
65
66    /// Build the graph based on the image, i.e. compute the weights
67    /// between pixels using the underlying distance.
68    ///
69    /// # Arguments
70    ///
71    /// * `image` - The image to oversegment.
72    fn build_graph(&mut self, image: &Mat) {
73        assert_eq!(image.empty(), false, "image must not be empty");
74        self.height = image.rows() as usize;
75        self.width = image.cols() as usize;
76        self.graph = self.init_graph_nodes(&image);
77        self.init_graph_edges();
78    }
79
80    /// Initializes the graph nodes from the image.
81    fn init_graph_nodes(&mut self, image: &Mat) -> ImageGraph {
82        debug_assert_ne!(self.height, 0);
83        debug_assert_ne!(self.width, 0);
84        let width = self.width;
85        let height = self.height;
86        let node_count = height * width;
87        let graph = ImageGraph::new_with_nodes(node_count);
88
89        for i in 0..height {
90            for j in 0..width {
91                let node_index = width * i + j;
92                let node = graph.node_at(node_index);
93                let node_color = graph.node_color_at(node_index);
94
95                let bgr = image.at_2d::<Vec3b>(i as i32, j as i32).unwrap().0;
96                node_color.set(ImageNodeColor {
97                    b: bgr[0],
98                    g: bgr[1],
99                    r: bgr[2],
100                });
101
102                // Initialize label
103                node.set(ImageNode {
104                    label: node_index,
105                    id: node_index,
106                    n: 1,
107                    ..Default::default()
108                });
109            }
110        }
111
112        graph
113    }
114
115    /// Initializes the edges between the nodes in the prepared graph.
116    fn init_graph_edges(&mut self) {
117        debug_assert_ne!(self.height, 0);
118        debug_assert_ne!(self.width, 0);
119        let height = self.height;
120        let width = self.width;
121        let graph = &mut self.graph;
122        let distance = &self.distance;
123
124        let mut edges = Vec::new();
125
126        for i in 0..(height - 1) {
127            for j in 0..(width - 1) {
128                let node_index = width * i + j;
129                let node = graph.node_color_at(node_index).get();
130
131                // Test right neighbor.
132                let other_index = width * i + (j + 1);
133                let other = graph.node_color_at(other_index).get();
134                let weight = distance.distance(&node, &other);
135                let edge = ImageEdge::new(node_index, other_index, weight);
136                edges.push(edge);
137
138                // Test bottom neighbor.
139                let other_index = width * (i + 1) + j;
140                let other = graph.node_color_at(other_index).get();
141                let weight = distance.distance(&node, &other);
142                let edge = ImageEdge::new(node_index, other_index, weight);
143                edges.push(edge);
144            }
145        }
146
147        graph.clear_edges();
148        graph.add_edges(edges.into_iter());
149    }
150
151    /// Oversegment the given graph.
152    fn oversegment_graph(&mut self) {
153        let graph = &mut self.graph;
154        assert_ne!(graph.num_edges(), 0, "number of edges must be nonzero");
155
156        graph.sort_edges();
157
158        for e in 0..graph.num_edges() {
159            debug_assert_eq!(e % graph.num_edges(), e);
160            let edge = graph.edge_at(e).get();
161
162            let s_n_idx = graph.find_node_component_at(edge.n);
163            let s_m_idx = graph.find_node_component_at(edge.m);
164
165            if s_n_idx == s_m_idx {
166                continue;
167            }
168
169            let mut s_n = graph.node_at(s_n_idx);
170            let mut s_m = graph.node_at(s_m_idx);
171
172            // Are the nodes in different components?
173            let should_merge = self.magic.should_merge(&s_n, &s_m, &edge);
174            if should_merge {
175                graph.merge(&mut s_n, &mut s_m, &edge);
176            }
177        }
178    }
179
180    /// Enforces the given minimum segment size.
181    ///
182    /// # Arguments
183    ///
184    /// * `segment_size` - Minimum segment size in pixels.
185    fn enforce_minimum_segment_size(&mut self, segment_size: usize) {
186        let graph = &mut self.graph;
187        assert_ne!(graph.num_nodes(), 0, "number of nodes must be nonzero");
188
189        for e in 0..graph.num_edges() {
190            let edge = graph.edge_at(e).get();
191
192            let s_n_idx = graph.find_node_component_at(edge.n);
193            let s_m_idx = graph.find_node_component_at(edge.m);
194
195            if s_n_idx == s_m_idx {
196                continue;
197            }
198
199            let mut s_n = graph.node_at(s_n_idx);
200            let mut s_m = graph.node_at(s_m_idx);
201
202            let lhs = s_n.get();
203            let rhs = s_m.get();
204
205            // Neighboring segments must have different labels.
206            debug_assert_ne!(lhs.label, rhs.label);
207
208            let segment_too_small = lhs.n < segment_size || rhs.n < segment_size;
209            if segment_too_small {
210                graph.merge(&mut s_n, &mut s_m, &edge);
211            }
212        }
213    }
214
215    /// Derive labels from the produced oversegmentation.
216    ///
217    /// # Returns
218    ///
219    /// Labels as an integer matrix.
220    fn derive_labels(&self) -> Mat {
221        let mut labels = Mat::new_rows_cols_with_default(
222            self.height as i32,
223            self.width as i32,
224            CV_32SC1,
225            Scalar::from(0f64),
226        )
227        .unwrap();
228
229        for i in 0..self.height {
230            for j in 0..self.width {
231                let n = self.width * i + j;
232
233                let index = self.graph.find_node_component_at(n);
234                let id = self.graph.node_id_at(index) as i32;
235
236                *(labels.at_2d_mut(i as i32, j as i32).unwrap()) = id;
237            }
238        }
239
240        labels
241    }
242}