1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
use crate::graph::{ImageEdge, ImageGraph};
use crate::segmentation::{Distance, NodeMerging};
use opencv::core::{Scalar, Vec3b, CV_32SC1};
use opencv::prelude::*;

/// Implementation of graph-based image segmentation as described in the
/// paper "Efficient Graph-Based Image Segmentation" by Felzenszwalb and
/// Huttenlocher.
///
/// The type is parameterized over the pixel distance `D` (used to weight the
/// edges between neighboring pixels) and the merge criterion `M` (used to
/// decide whether two components joined by an edge are merged).
#[derive(Debug)]
pub struct Segmentation<D, M>
where
    D: Distance,
    M: NodeMerging,
{
    /// Image height in pixels (rows); zero until an image has been processed.
    height: usize,
    /// Image width in pixels (columns); zero until an image has been processed.
    width: usize,
    /// The constructed and segmented image graph; one node per pixel.
    graph: ImageGraph,
    /// The underlying distance used to compute edge weights between
    /// neighboring pixel nodes.
    distance: D,
    /// The merge criterion ("the magic part" of graph segmentation): asked
    /// for every edge whether the two components it connects should be merged.
    magic: M,
    /// The minimum size of the segments, in pixels, as passed to the constructor.
    #[allow(dead_code)]
    segment_size: usize,
}

impl<D, M> Segmentation<D, M>
where
    D: Distance,
    M: NodeMerging,
{
    /// Creates a segmenter that has not processed any image yet.
    ///
    /// # Arguments
    ///
    /// * `distance` - Distance used to weight the edge between two neighboring pixels.
    /// * `magic` - Criterion deciding whether two components are merged along an edge.
    /// * `segment_size` - Minimum size of a segment, in pixels.
    pub fn new(distance: D, magic: M, segment_size: usize) -> Self {
        Self {
            distance,
            magic,
            // Dimensions and graph are placeholders until `build_graph` runs.
            height: 0,
            width: 0,
            segment_size,
            graph: ImageGraph::default(),
        }
    }

    /// Segments the given image.
    ///
    /// Builds the pixel graph, merges components along the sorted edges using
    /// the merge criterion, then merges every segment that is smaller than the
    /// minimum segment size configured in [`Segmentation::new`], and finally
    /// derives a label per pixel.
    ///
    /// # Arguments
    ///
    /// * `image` - The image to oversegment. Pixels are read as `Vec3b`
    ///   in B, G, R order.
    ///
    /// # Returns
    ///
    /// A tuple consisting of
    /// - The matrix in `CV_32SC1` format containing the labels for each pixel.
    /// - The number of segments / components.
    ///
    /// # Panics
    ///
    /// Panics if the image is empty, if a pixel cannot be read as `Vec3b`,
    /// or if the constructed graph has no edges (no pixel has a right or
    /// bottom neighbor, e.g. a single-pixel image).
    pub fn segment_image(&mut self, image: &Mat) -> (Mat, usize) {
        self.build_graph(image);
        self.oversegment_graph();

        // Use the configured minimum size rather than a hard-coded constant.
        let minimum_size = self.segment_size;
        self.enforce_minimum_segment_size(minimum_size);

        let labels = self.derive_labels();
        let num_components = self.graph.num_components();
        (labels, num_components)
    }

    /// Build the graph based on the image, i.e. compute the weights
    /// between pixels using the underlying distance.
    ///
    /// Also records the image dimensions for the later passes.
    ///
    /// # Arguments
    ///
    /// * `image` - The image to oversegment.
    ///
    /// # Panics
    ///
    /// Panics if the image is empty.
    fn build_graph(&mut self, image: &Mat) {
        assert!(!image.empty(), "image must not be empty");
        self.height = image.rows() as usize;
        self.width = image.cols() as usize;
        self.graph = self.init_graph_nodes(image);
        self.init_graph_edges();
    }

    /// Initializes the graph nodes from the image.
    ///
    /// Creates one node per pixel in row-major order (`index = width * row + col`).
    /// Every node starts as its own component of size one, carrying the
    /// pixel's color channels.
    ///
    /// # Panics
    ///
    /// Panics if a pixel cannot be read as `Vec3b`.
    fn init_graph_nodes(&mut self, image: &Mat) -> ImageGraph {
        debug_assert_ne!(self.height, 0);
        debug_assert_ne!(self.width, 0);
        let (height, width) = (self.height, self.width);
        let graph = ImageGraph::new_with_nodes(height * width);

        for row in 0..height {
            for col in 0..width {
                let node_index = width * row + col;
                let bgr = image
                    .at_2d::<Vec3b>(row as i32, col as i32)
                    .expect("pixel must be readable as Vec3b")
                    .0;

                let mut node = graph.node_at(node_index).borrow_mut();
                node.b = bgr[0];
                node.g = bgr[1];
                node.r = bgr[2];

                // Each pixel initially forms its own single-pixel component.
                node.label = node_index;
                node.id = node_index;
                node.n = 1;
            }
        }

        graph
    }

    /// Initializes the edges between the nodes in the prepared graph.
    ///
    /// Every pixel is connected to its right and bottom neighbor (where
    /// present); the edge weight is the configured distance between the
    /// two pixel nodes. Any previously stored edges are discarded.
    fn init_graph_edges(&mut self) {
        debug_assert_ne!(self.height, 0);
        debug_assert_ne!(self.width, 0);
        let height = self.height;
        let width = self.width;
        let graph = &mut self.graph;
        let distance = &self.distance;

        // Exactly one horizontal edge per pixel outside the last column and one
        // vertical edge per pixel outside the last row.
        let edge_count = height * width.saturating_sub(1) + height.saturating_sub(1) * width;
        let mut edges = Vec::with_capacity(edge_count);

        for i in 0..height {
            for j in 0..width {
                let node_index = width * i + j;
                let node = graph.node_at(node_index);

                // Right neighbor, unless this is the last column.
                if j + 1 < width {
                    let other_index = node_index + 1;
                    let other = graph.node_at(other_index);

                    let weight = distance.distance(&node.borrow(), &other.borrow());
                    edges.push(ImageEdge::new(node_index, other_index, weight));
                }

                // Bottom neighbor, unless this is the last row.
                if i + 1 < height {
                    let other_index = node_index + width;
                    let other = graph.node_at(other_index);

                    let weight = distance.distance(&node.borrow(), &other.borrow());
                    edges.push(ImageEdge::new(node_index, other_index, weight));
                }
            }
        }

        graph.clear_edges();
        graph.add_edges(edges.into_iter());
    }

    /// Oversegment the given graph.
    ///
    /// Sorts the edges and visits each one once; whenever an edge connects two
    /// different components, the merge criterion decides whether they are merged.
    ///
    /// # Panics
    ///
    /// Panics if the graph has no edges.
    fn oversegment_graph(&mut self) {
        let graph = &mut self.graph;
        assert_ne!(graph.num_edges(), 0, "number of edges must be nonzero");

        graph.sort_edges();

        for e in 0..graph.num_edges() {
            debug_assert_eq!(e % graph.num_edges(), e);

            // A plain shared borrow suffices: the same lookup and merge calls
            // are made under a shared edge borrow in
            // `enforce_minimum_segment_size`, so no `unsafe` is needed.
            let edge = graph.edge_at(e).borrow();

            let s_n_idx = graph.find_node_component_at(edge.n);
            let s_m_idx = graph.find_node_component_at(edge.m);

            // Both endpoints already belong to the same component.
            if s_n_idx == s_m_idx {
                continue;
            }

            // The indices differ, so these are two distinct cells and both
            // mutable borrows can be held at the same time.
            let mut s_n = graph.node_at(s_n_idx).borrow_mut();
            let mut s_m = graph.node_at(s_m_idx).borrow_mut();

            // Are the nodes in different components?
            debug_assert_ne!(s_m.id, s_n.id);
            if self.magic.should_merge(&s_n, &s_m, &edge) {
                graph.merge(&mut s_n, &mut s_m, &edge);
            }
        }
    }

    /// Enforces the given minimum segment size.
    ///
    /// Visits every edge once more and merges the two components it connects
    /// whenever at least one of them has fewer than `segment_size` pixels.
    ///
    /// # Arguments
    ///
    /// * `segment_size` - Minimum segment size in pixels.
    ///
    /// # Panics
    ///
    /// Panics if the graph has no nodes.
    fn enforce_minimum_segment_size(&mut self, segment_size: usize) {
        let graph = &mut self.graph;
        assert_ne!(graph.num_nodes(), 0, "number of nodes must be nonzero");

        for e in 0..graph.num_edges() {
            let edge = graph.edge_at(e).borrow();

            let s_n_idx = graph.find_node_component_at(edge.n);
            let s_m_idx = graph.find_node_component_at(edge.m);

            // Nothing to do if the edge lies inside a single component.
            if s_n_idx == s_m_idx {
                continue;
            }

            let mut s_n = graph.node_at(s_n_idx).borrow_mut();
            let mut s_m = graph.node_at(s_m_idx).borrow_mut();

            // Neighboring segments must have different labels.
            debug_assert_ne!(s_m.label, s_n.label);

            let segment_too_small = s_n.n < segment_size || s_m.n < segment_size;
            if segment_too_small {
                graph.merge(&mut s_n, &mut s_m, &edge);
            }
        }
    }

    /// Derive labels from the produced oversegmentation.
    ///
    /// Each pixel is labeled with the id of the node representing the
    /// component the pixel belongs to.
    ///
    /// # Returns
    ///
    /// Labels as an integer matrix.
    fn derive_labels(&self) -> Mat {
        let mut labels = Mat::new_rows_cols_with_default(
            self.height as i32,
            self.width as i32,
            CV_32SC1,
            Scalar::from(0f64),
        )
        .unwrap();

        for row in 0..self.height {
            for col in 0..self.width {
                let pixel_index = self.width * row + col;

                let component = self.graph.find_node_component_at(pixel_index);
                // NOTE: ids are stored as 32-bit labels; the cast truncates if an
                // id does not fit into `i32`.
                let id = self.graph.node_id_at(component) as i32;

                *(labels.at_2d_mut(row as i32, col as i32).unwrap()) = id;
            }
        }

        labels
    }
}