auto_palette/
algorithm.rs

1use std::{
2    fmt::{Display, Formatter},
3    str::FromStr,
4};
5
6use crate::{
7    error::Error,
8    image::{
9        segmentation::{
10            DbscanSegmentation,
11            FastDbscanSegmentation,
12            KmeansSegmentation,
13            LabelImage,
14            Segmentation,
15            SlicSegmentation,
16            SnicSegmentation,
17        },
18        Pixel,
19    },
20    math::{DistanceMetric, FloatNumber},
21    Filter,
22    ImageData,
23};
24
/// The clustering algorithm to use for color palette extraction.
///
/// [`Algorithm::DBSCAN`] is the default variant.
// The variant names deliberately keep the algorithm acronyms (DBSCAN, SLIC, SNIC) in the
// casing used by the literature, so clippy's acronym-casing lint is silenced for this type.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub enum Algorithm {
    /// K-means clustering algorithm.
    KMeans,

    /// DBSCAN (Density-Based Spatial Clustering of Applications with Noise) algorithm.
    #[default]
    DBSCAN,

    /// DBSCAN++ (Fast DBSCAN) clustering algorithm.
    DBSCANpp,

    /// SLIC (Simple Linear Iterative Clustering) algorithm.
    SLIC,

    /// SNIC (Simple Non-Iterative Clustering) algorithm.
    SNIC,
}
45
46impl Algorithm {
47    /// The number of segments to use for segmentation.
48    const SEGMENTS: usize = 128;
49
50    /// The maximum number of iterations for the K-means algorithm.
51    const KMEANS_MAX_ITER: usize = 50;
52
53    /// The tolerance for convergence conditions in the K-means algorithm.
54    const KMEANS_TOLERANCE: f64 = 1e-3;
55
56    /// The minimum number of points for the DBSCAN algorithm.
57    const DBSCAN_MIN_POINTS: usize = 10;
58
59    /// The epsilon value for the DBSCAN algorithm.
60    const DBSCAN_EPSILON: f64 = 0.03;
61
62    /// The probability for the Fast DBSCAN (DBSCAN++) algorithm.
63    const FASTDBSCAN_PROBABILITY: f64 = 0.1;
64
65    /// The minimum number of points for the Fast DBSCAN (DBSCAN++) algorithm.
66    const FASTDBSCAN_MIN_POINTS: usize = 10;
67
68    /// The epsilon value for the Fast DBSCAN (DBSCAN++) algorithm.
69    const FASTDBSCAN_EPSILON: f64 = 0.04;
70
71    /// The compactness value for the SLIC algorithm.
72    const SLIC_COMPACTNESS: f64 = 0.0225; // 0.15^2
73
74    /// The maximum number of iterations for the SLIC algorithm.
75    const SLIC_MAX_ITER: usize = 10;
76
77    /// The tolerance for convergence conditions in the SLIC algorithm.
78    const SLIC_TOLERANCE: f64 = 1e-3;
79
80    /// Clusters the given pixels using the algorithm.
81    ///
82    /// # Arguments
83    /// * `width` - The width of the image.
84    /// * `height` - The height of the image.
85    /// * `pixels` - The pixels to cluster.
86    ///
87    /// # Returns
88    /// The clusters found by the algorithm.
89    pub(crate) fn segment<T, F>(
90        &self,
91        image_data: &ImageData,
92        filter: &F,
93    ) -> Result<LabelImage<T>, Error>
94    where
95        T: FloatNumber,
96        F: Filter,
97    {
98        match self {
99            Self::KMeans => segment_internal(image_data, filter, || {
100                KmeansSegmentation::builder()
101                    .segments(Self::SEGMENTS)
102                    .max_iter(Self::KMEANS_MAX_ITER)
103                    .tolerance(T::from_f64(Self::KMEANS_TOLERANCE))
104                    .metric(DistanceMetric::SquaredEuclidean)
105                    .build()
106            }),
107            Self::DBSCAN => segment_internal(image_data, filter, || {
108                DbscanSegmentation::builder()
109                    .segments(Self::SEGMENTS)
110                    .min_pixels(Self::DBSCAN_MIN_POINTS)
111                    .epsilon(T::from_f64(Self::DBSCAN_EPSILON.powi(2))) // Squared epsilon for squared euclidean distance
112                    .metric(DistanceMetric::SquaredEuclidean)
113                    .build()
114            }),
115            Self::DBSCANpp => segment_internal(image_data, filter, || {
116                FastDbscanSegmentation::builder()
117                    .min_pixels(Self::FASTDBSCAN_MIN_POINTS)
118                    .probability(T::from_f64(Self::FASTDBSCAN_PROBABILITY))
119                    .epsilon(T::from_f64(Self::FASTDBSCAN_EPSILON).powi(2))
120                    .metric(DistanceMetric::SquaredEuclidean)
121                    .build()
122            }),
123            Self::SLIC => segment_internal(image_data, filter, || {
124                SlicSegmentation::builder()
125                    .segments(Self::SEGMENTS)
126                    .max_iter(Self::SLIC_MAX_ITER)
127                    .compactness(T::from_f64(Self::SLIC_COMPACTNESS))
128                    .tolerance(T::from_f64(Self::SLIC_TOLERANCE))
129                    .metric(DistanceMetric::SquaredEuclidean)
130                    .build()
131            }),
132            Self::SNIC => segment_internal(image_data, filter, || {
133                SnicSegmentation::<T>::builder()
134                    .segments(Self::SEGMENTS)
135                    .metric(DistanceMetric::Euclidean)
136                    .build()
137            }),
138        }
139    }
140}
141
142impl FromStr for Algorithm {
143    type Err = Error;
144
145    fn from_str(s: &str) -> Result<Self, Self::Err> {
146        match s.to_lowercase().as_str() {
147            "kmeans" => Ok(Self::KMeans),
148            "dbscan" => Ok(Self::DBSCAN),
149            "dbscan++" => Ok(Self::DBSCANpp),
150            "slic" => Ok(Self::SLIC),
151            "snic" => Ok(Self::SNIC),
152            _ => Err(Error::UnsupportedAlgorithm {
153                name: s.to_string(),
154            }),
155        }
156    }
157}
158
159impl Display for Algorithm {
160    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
161        match self {
162            Self::KMeans => write!(f, "kmeans"),
163            Self::DBSCAN => write!(f, "dbscan"),
164            Self::DBSCANpp => write!(f, "dbscan++"),
165            Self::SLIC => write!(f, "slic"),
166            Self::SNIC => write!(f, "snic"),
167        }
168    }
169}
170
171/// Segments the image data using the specified filter and algorithm.
172///
173/// # Type Parameters
174/// * `T` - The floating point type.
175/// * `F` - The filter function.
176/// * `B` - The builder function.
177/// * `S` - The segmentation algorithm.
178/// * `E` - The error type for the segmentation algorithm.
179///
180/// # Arguments
181/// * `image_data` - The image data to segment.
182/// * `filter` - The filter to apply to the image data.
183/// * `builder` - The builder function to create the segmentation algorithm.
184///
185/// # Returns
186/// A vector of segments found by the segmentation algorithm.
187fn segment_internal<T, F, B, S, E>(
188    image_data: &ImageData,
189    filter: &F,
190    builder: B,
191) -> Result<LabelImage<T>, Error>
192where
193    T: FloatNumber,
194    F: Filter,
195    B: FnOnce() -> Result<S, E>,
196    S: Segmentation<T, Err = E>,
197    E: Display,
198{
199    let segmentation = builder().map_err(|e| Error::PaletteExtractionError {
200        details: e.to_string(),
201    })?;
202
203    let width = image_data.width() as usize;
204    let height = image_data.height() as usize;
205    let (pixels, mask) = collect_pixels_and_mask(image_data, filter);
206    segmentation
207        .segment_with_mask(width, height, &pixels, &mask)
208        .map_err(|e| Error::PaletteExtractionError {
209            details: e.to_string(),
210        })
211}
212
213/// Collects the pixels and mask from the image data.
214///
215/// # Type Parameters
216/// * `T` - The floating point type.
217/// * `F` - The filter type.
218///
219/// # Arguments
220/// * `image_data` - The image data to collect pixels from.
221/// * `filter` - The filter to apply to the pixels.
222///
223/// # Returns
224/// A tuple containing a vector of pixels and a vector of masks.
225#[must_use]
226fn collect_pixels_and_mask<T, F>(image_data: &ImageData, filter: &F) -> (Vec<Pixel<T>>, Vec<bool>)
227where
228    T: FloatNumber,
229    F: Filter,
230{
231    let width = image_data.width() as usize;
232    let height = image_data.height() as usize;
233    let (pixels, mask) = image_data.pixels_with_filter(filter).fold(
234        (
235            Vec::with_capacity(width * height),
236            Vec::with_capacity(width * height),
237        ),
238        |(mut pixels, mut mask), (p, m)| {
239            pixels.push(p);
240            mask.push(m);
241            (pixels, mask)
242        },
243    );
244    (pixels, mask)
245}
246
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::Rgba;

    #[rstest]
    #[case::kmeans("kmeans", Algorithm::KMeans)]
    #[case::dbscan("dbscan", Algorithm::DBSCAN)]
    #[case::dbscanpp("dbscan++", Algorithm::DBSCANpp)]
    #[case::slic("slic", Algorithm::SLIC)]
    #[case::snic("snic", Algorithm::SNIC)]
    #[case::kmeans_upper("KMEANS", Algorithm::KMeans)]
    #[case::dbscan_upper("DBSCAN", Algorithm::DBSCAN)]
    #[case::dbscanpp_upper("DBSCAN++", Algorithm::DBSCANpp)]
    #[case::slic_upper("SLIC", Algorithm::SLIC)]
    #[case::snic_upper("SNIC", Algorithm::SNIC)]
    #[case::kmeans_capitalized("Kmeans", Algorithm::KMeans)]
    #[case::dbscan_capitalized("Dbscan", Algorithm::DBSCAN)]
    #[case::dbscanpp_capitalized("Dbscan++", Algorithm::DBSCANpp)]
    #[case::slic_capitalized("Slic", Algorithm::SLIC)]
    #[case::snic_capitalized("Snic", Algorithm::SNIC)]
    fn test_from_str(#[case] input: &str, #[case] expected: Algorithm) {
        // Act
        let actual = Algorithm::from_str(input).unwrap();

        // Assert
        assert_eq!(actual, expected);
    }

    #[rstest]
    #[case::empty("")]
    #[case::invalid("unknown")]
    fn test_from_str_error(#[case] input: &str) {
        // Act
        let actual = Algorithm::from_str(input);

        // Assert
        assert!(actual.is_err());
        assert_eq!(
            actual.unwrap_err().to_string(),
            format!("Unsupported algorithm specified: '{}'", input)
        );
    }

    #[rstest]
    #[case::kmeans(Algorithm::KMeans, "kmeans")]
    #[case::dbscan(Algorithm::DBSCAN, "dbscan")]
    #[case::dbscanpp(Algorithm::DBSCANpp, "dbscan++")]
    #[case::slic(Algorithm::SLIC, "slic")]
    #[case::snic(Algorithm::SNIC, "snic")]
    fn test_fmt(#[case] algorithm: Algorithm, #[case] expected: &str) {
        // Act
        let actual = format!("{}", algorithm);

        // Assert
        assert_eq!(actual, expected);
    }

    #[rstest]
    #[case::kmeans(Algorithm::KMeans)]
    #[case::dbscan(Algorithm::DBSCAN)]
    #[case::dbscanpp(Algorithm::DBSCANpp)]
    #[case::slic(Algorithm::SLIC)]
    #[case::snic(Algorithm::SNIC)]
    fn test_display_from_str_round_trip(#[case] algorithm: Algorithm) {
        // Act
        let actual = Algorithm::from_str(&algorithm.to_string()).unwrap();

        // Assert
        assert_eq!(actual, algorithm);
    }

    #[test]
    fn test_default() {
        // Act
        let actual = Algorithm::default();

        // Assert
        assert_eq!(actual, Algorithm::DBSCAN);
    }

    #[rstest]
    #[case::kmeans(Algorithm::KMeans)]
    #[case::dbscan(Algorithm::DBSCAN)]
    #[case::dbscanpp(Algorithm::DBSCANpp)]
    #[case::slic(Algorithm::SLIC)]
    #[case::snic(Algorithm::SNIC)]
    fn test_segment_empty(#[case] algorithm: Algorithm) {
        // Arrange
        let pixels: Vec<_> = Vec::new();
        let image_data = ImageData::new(0, 0, &pixels).expect("Failed to create empty image data");

        // Act
        let actual = algorithm.segment(&image_data, &|rgba: &Rgba| rgba[0] != 0);

        // Assert
        assert!(actual.is_ok());

        let label_image: LabelImage<f64> = actual.unwrap();
        assert_eq!(label_image.width(), 0);
        assert_eq!(label_image.height(), 0);
    }
}