1use std::{
2 fmt::{Display, Formatter},
3 str::FromStr,
4};
5
6use crate::{
7 error::Error,
8 image::{
9 segmentation::{
10 DbscanSegmentation,
11 FastDbscanSegmentation,
12 KmeansSegmentation,
13 LabelImage,
14 Segmentation,
15 SlicSegmentation,
16 SnicSegmentation,
17 },
18 Pixel,
19 },
20 math::{DistanceMetric, FloatNumber},
21 Filter,
22 ImageData,
23};
24
/// The segmentation algorithm used to split an image into labelled regions.
///
/// Every variant has a lowercase textual name. The `FromStr` implementation accepts that name
/// in any letter case, and the `Display` implementation prints it back.
///
/// `DBSCAN` is the default variant.
// Allowed because the variant names spell the algorithm acronyms in upper case
// (`DBSCAN`, `SLIC`, `SNIC`), which clippy would otherwise reject.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum Algorithm {
    /// K-means clustering, dispatched to `KmeansSegmentation`; named `"kmeans"`.
    KMeans,

    /// DBSCAN clustering, dispatched to `DbscanSegmentation`; named `"dbscan"`.
    #[default]
    DBSCAN,

    /// DBSCAN++ clustering, dispatched to `FastDbscanSegmentation`; named `"dbscan++"`.
    DBSCANpp,

    /// SLIC segmentation, dispatched to `SlicSegmentation`; named `"slic"`.
    SLIC,

    /// SNIC segmentation, dispatched to `SnicSegmentation`; named `"snic"`.
    SNIC,
}
45
46impl Algorithm {
47 const SEGMENTS: usize = 128;
49
50 const KMEANS_MAX_ITER: usize = 50;
52
53 const KMEANS_TOLERANCE: f64 = 1e-3;
55
56 const DBSCAN_MIN_POINTS: usize = 10;
58
59 const DBSCAN_EPSILON: f64 = 0.03;
61
62 const FASTDBSCAN_PROBABILITY: f64 = 0.1;
64
65 const FASTDBSCAN_MIN_POINTS: usize = 10;
67
68 const FASTDBSCAN_EPSILON: f64 = 0.04;
70
71 const SLIC_COMPACTNESS: f64 = 0.0225; const SLIC_MAX_ITER: usize = 10;
76
77 const SLIC_TOLERANCE: f64 = 1e-3;
79
80 pub(crate) fn segment<T, F>(
90 &self,
91 image_data: &ImageData,
92 filter: &F,
93 ) -> Result<LabelImage<T>, Error>
94 where
95 T: FloatNumber,
96 F: Filter,
97 {
98 match self {
99 Self::KMeans => segment_internal(image_data, filter, || {
100 KmeansSegmentation::builder()
101 .segments(Self::SEGMENTS)
102 .max_iter(Self::KMEANS_MAX_ITER)
103 .tolerance(T::from_f64(Self::KMEANS_TOLERANCE))
104 .metric(DistanceMetric::SquaredEuclidean)
105 .build()
106 }),
107 Self::DBSCAN => segment_internal(image_data, filter, || {
108 DbscanSegmentation::builder()
109 .segments(Self::SEGMENTS)
110 .min_pixels(Self::DBSCAN_MIN_POINTS)
111 .epsilon(T::from_f64(Self::DBSCAN_EPSILON.powi(2))) .metric(DistanceMetric::SquaredEuclidean)
113 .build()
114 }),
115 Self::DBSCANpp => segment_internal(image_data, filter, || {
116 FastDbscanSegmentation::builder()
117 .min_pixels(Self::FASTDBSCAN_MIN_POINTS)
118 .probability(T::from_f64(Self::FASTDBSCAN_PROBABILITY))
119 .epsilon(T::from_f64(Self::FASTDBSCAN_EPSILON).powi(2))
120 .metric(DistanceMetric::SquaredEuclidean)
121 .build()
122 }),
123 Self::SLIC => segment_internal(image_data, filter, || {
124 SlicSegmentation::builder()
125 .segments(Self::SEGMENTS)
126 .max_iter(Self::SLIC_MAX_ITER)
127 .compactness(T::from_f64(Self::SLIC_COMPACTNESS))
128 .tolerance(T::from_f64(Self::SLIC_TOLERANCE))
129 .metric(DistanceMetric::SquaredEuclidean)
130 .build()
131 }),
132 Self::SNIC => segment_internal(image_data, filter, || {
133 SnicSegmentation::<T>::builder()
134 .segments(Self::SEGMENTS)
135 .metric(DistanceMetric::Euclidean)
136 .build()
137 }),
138 }
139 }
140}
141
142impl FromStr for Algorithm {
143 type Err = Error;
144
145 fn from_str(s: &str) -> Result<Self, Self::Err> {
146 match s.to_lowercase().as_str() {
147 "kmeans" => Ok(Self::KMeans),
148 "dbscan" => Ok(Self::DBSCAN),
149 "dbscan++" => Ok(Self::DBSCANpp),
150 "slic" => Ok(Self::SLIC),
151 "snic" => Ok(Self::SNIC),
152 _ => Err(Error::UnsupportedAlgorithm {
153 name: s.to_string(),
154 }),
155 }
156 }
157}
158
159impl Display for Algorithm {
160 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
161 match self {
162 Self::KMeans => write!(f, "kmeans"),
163 Self::DBSCAN => write!(f, "dbscan"),
164 Self::DBSCANpp => write!(f, "dbscan++"),
165 Self::SLIC => write!(f, "slic"),
166 Self::SNIC => write!(f, "snic"),
167 }
168 }
169}
170
171fn segment_internal<T, F, B, S, E>(
188 image_data: &ImageData,
189 filter: &F,
190 builder: B,
191) -> Result<LabelImage<T>, Error>
192where
193 T: FloatNumber,
194 F: Filter,
195 B: FnOnce() -> Result<S, E>,
196 S: Segmentation<T, Err = E>,
197 E: Display,
198{
199 let segmentation = builder().map_err(|e| Error::PaletteExtractionError {
200 details: e.to_string(),
201 })?;
202
203 let width = image_data.width() as usize;
204 let height = image_data.height() as usize;
205 let (pixels, mask) = collect_pixels_and_mask(image_data, filter);
206 segmentation
207 .segment_with_mask(width, height, &pixels, &mask)
208 .map_err(|e| Error::PaletteExtractionError {
209 details: e.to_string(),
210 })
211}
212
213#[must_use]
226fn collect_pixels_and_mask<T, F>(image_data: &ImageData, filter: &F) -> (Vec<Pixel<T>>, Vec<bool>)
227where
228 T: FloatNumber,
229 F: Filter,
230{
231 let width = image_data.width() as usize;
232 let height = image_data.height() as usize;
233 let (pixels, mask) = image_data.pixels_with_filter(filter).fold(
234 (
235 Vec::with_capacity(width * height),
236 Vec::with_capacity(width * height),
237 ),
238 |(mut pixels, mut mask), (p, m)| {
239 pixels.push(p);
240 mask.push(m);
241 (pixels, mask)
242 },
243 );
244 (pixels, mask)
245}
246
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::Rgba;

    // Parsing must accept every supported name regardless of letter case.
    #[rstest]
    #[case::kmeans("kmeans", Algorithm::KMeans)]
    #[case::dbscan("dbscan", Algorithm::DBSCAN)]
    #[case::dbscanpp("dbscan++", Algorithm::DBSCANpp)]
    #[case::slic("slic", Algorithm::SLIC)]
    #[case::snic("snic", Algorithm::SNIC)]
    #[case::kmeans_upper("KMEANS", Algorithm::KMeans)]
    #[case::dbscan_upper("DBSCAN", Algorithm::DBSCAN)]
    #[case::dbscanpp_upper("DBSCAN++", Algorithm::DBSCANpp)]
    #[case::slic_upper("SLIC", Algorithm::SLIC)]
    #[case::snic_upper("SNIC", Algorithm::SNIC)]
    #[case::kmeans_capitalized("Kmeans", Algorithm::KMeans)]
    #[case::dbscan_capitalized("Dbscan", Algorithm::DBSCAN)]
    #[case::dbscanpp_capitalized("Dbscan++", Algorithm::DBSCANpp)]
    #[case::slic_capitalized("Slic", Algorithm::SLIC)]
    #[case::snic_capitalized("Snic", Algorithm::SNIC)]
    fn test_from_str(#[case] input: &str, #[case] expected: Algorithm) {
        // Act
        let actual = Algorithm::from_str(input).unwrap();

        // Assert
        assert_eq!(actual, expected);
    }

    // Unknown names must be rejected, and the message must echo the original input.
    #[rstest]
    #[case::empty("")]
    #[case::invalid("unknown")]
    fn test_from_str_error(#[case] input: &str) {
        // Act
        let actual = Algorithm::from_str(input);

        // Assert
        assert!(actual.is_err());
        assert_eq!(
            actual.unwrap_err().to_string(),
            format!("Unsupported algorithm specified: '{}'", input)
        );
    }

    // Every variant must print the lowercase name that `from_str` accepts.
    #[rstest]
    #[case::kmeans(Algorithm::KMeans, "kmeans")]
    #[case::dbscan(Algorithm::DBSCAN, "dbscan")]
    #[case::dbscanpp(Algorithm::DBSCANpp, "dbscan++")]
    #[case::slic(Algorithm::SLIC, "slic")]
    #[case::snic(Algorithm::SNIC, "snic")]
    fn test_fmt(#[case] algorithm: Algorithm, #[case] expected: &str) {
        // Act
        let actual = format!("{}", algorithm);

        // Assert
        assert_eq!(actual, expected);
    }

    // Segmenting an empty image must succeed and yield an empty label image.
    #[rstest]
    #[case::kmeans(Algorithm::KMeans)]
    #[case::dbscan(Algorithm::DBSCAN)]
    #[case::dbscanpp(Algorithm::DBSCANpp)]
    #[case::slic(Algorithm::SLIC)]
    #[case::snic(Algorithm::SNIC)]
    fn test_segment_empty(#[case] algorithm: Algorithm) {
        // Arrange
        let pixels: Vec<_> = Vec::new();
        let image_data = ImageData::new(0, 0, &pixels).expect("Failed to create empty image data");

        // Act
        let actual = algorithm.segment(&image_data, &|rgba: &Rgba| rgba[0] != 0);

        // Assert
        assert!(actual.is_ok());

        let label_image: LabelImage<f64> = actual.unwrap();
        assert_eq!(label_image.width(), 0);
        assert_eq!(label_image.height(), 0);
    }
}