Skip to main content

gecol_core/extract/
extractor.rs

1use std::path::Path;
2
3use image::{DynamicImage, RgbImage, imageops::FilterType};
4use imgref::Img;
5use kmeans_colors::get_kmeans;
6use mss_saliency::maximum_symmetric_surround_saliency;
7use palette::{FromColor, Hsv, IntoColor, Lab, Srgb};
8
9use crate::{
10    Cache,
11    error::Error,
12    extract::{
13        ExtractionConfig, ExtractStep,
14        scores::{ScoredCluster, ScoredPixel},
15    },
16};
17
18/// A utility for extracting accent color from an image.
19///
20/// This structs acts as a temporary state container during the extraction
21/// process and provides static methods for executing the extraction.
22///
23/// There are multiple extraction variants, mainly with
24/// ([`Extractor::extract`]) and without ([`Extractor::extract_cached`]) using
25/// the [`Cache`]. It is highly recommended using the [`Cache`]. High
26/// resolution images can take a while to just open, so thanks to the [`Cache`]
27/// the result for repeated extraction will be pretty much instant.
28///
29/// You can also use variant with progress reporting
30/// ([`Extractor::extract_with_progress`] and
31/// [`Extractor::extract_cached_with_progress`]), which is useful for having
32/// a loading screen, for example.
33#[derive(Debug, Clone)]
34pub struct Extractor<'a> {
35    config: &'a ExtractionConfig,
36    width: usize,
37    height: usize,
38}
39
40impl<'a> Extractor<'a> {
41    /// Extracts the accent color from the image at the given path.
42    ///
43    /// When no sufficient color is found, it returns `None`.
44    pub fn extract<P>(
45        path: P,
46        config: &'a ExtractionConfig,
47    ) -> Result<Option<(u8, u8, u8)>, Error>
48    where
49        P: AsRef<Path>,
50    {
51        Self::inner_extract(path, config, |_| {})
52    }
53
54    /// Extracts the accent color from the image at the given path and uses
55    /// the cache.
56    ///
57    /// It checks if the cache already contains the color for the given image,
58    /// otherwise it saves the extracted color into the cache.
59    ///
60    /// When no sufficient color is found, it returns `None`.
61    pub fn extract_cached<P>(
62        path: P,
63        config: &'a ExtractionConfig,
64        cache_path: Option<&Path>,
65    ) -> Result<Option<(u8, u8, u8)>, Error>
66    where
67        P: AsRef<Path>,
68    {
69        Self::extract_cached_with_progress(path, config, cache_path, |_| {})
70    }
71
72    /// Extracts the accent color from the image at the given path with the
73    /// progress reporting.
74    ///
75    /// When no sufficient color is found, it returns `None`.
76    pub fn extract_with_progress<P, F>(
77        path: P,
78        config: &'a ExtractionConfig,
79        progress_callback: F,
80    ) -> Result<Option<(u8, u8, u8)>, Error>
81    where
82        P: AsRef<Path>,
83        F: FnMut(ExtractStep),
84    {
85        Self::inner_extract(path, config, progress_callback)
86    }
87
88    /// Extracts the accent color from the image at the given path with the
89    /// progress reporting and uses the cache.
90    ///
91    /// It checks if the cache already contains the color for the given image,
92    /// otherwise it saves the extracted color into the cache.
93    ///
94    /// When no sufficient color is found, it returns `None`.
95    pub fn extract_cached_with_progress<P, F>(
96        path: P,
97        config: &'a ExtractionConfig,
98        cache_path: Option<&Path>,
99        mut progress_callback: F,
100    ) -> Result<Option<(u8, u8, u8)>, Error>
101    where
102        P: AsRef<Path>,
103        F: FnMut(ExtractStep),
104    {
105        progress_callback(ExtractStep::CheckingCache);
106        let cache_file =
107            cache_path.map(|v| v.to_owned()).unwrap_or_else(Cache::file);
108        let mut cache = Cache::load(&cache_file);
109        let key = Cache::key(config, path.as_ref())
110            .unwrap_or("fallback".to_string());
111
112        if let Some(&color) = cache.entries.get(&key) {
113            progress_callback(ExtractStep::FinishedWithCache);
114            return Ok(Some(color));
115        }
116
117        let color = Self::inner_extract(path, config, progress_callback)?;
118        if let Some(col) = color {
119            cache.entries.insert(key, col);
120            _ = cache.save(&cache_file);
121        }
122
123        Ok(color)
124    }
125
126    fn inner_extract<P, F>(
127        path: P,
128        config: &'a ExtractionConfig,
129        mut progress_callback: F,
130    ) -> Result<Option<(u8, u8, u8)>, Error>
131    where
132        P: AsRef<Path>,
133        F: FnMut(ExtractStep),
134    {
135        let mut extractor = Self {
136            config,
137            width: 0,
138            height: 0,
139        };
140
141        progress_callback(ExtractStep::OpeningImage);
142        let img = image::open(path)?;
143        progress_callback(ExtractStep::ResizingImage);
144        let img = extractor.prep_img(img);
145
146        progress_callback(ExtractStep::ExtractingColors);
147        let (sal_map, is_sal_worth) = extractor.gen_saliency(&img);
148        #[cfg(debug_assertions)]
149        extractor.save_saliency(&sal_map);
150
151        let rgb_img = img.to_rgb8();
152
153        let candids =
154            extractor.get_candidates(&rgb_img, &sal_map, is_sal_worth);
155
156        progress_callback(ExtractStep::Clustering);
157        let col = extractor.get_best_col(candids);
158        progress_callback(ExtractStep::Finished);
159        Ok(col)
160    }
161
162    /// Resizes the image only if new dimensions are provided.
163    fn prep_img(&mut self, img: DynamicImage) -> DynamicImage {
164        let tw = self.config.res_w.unwrap_or(img.width());
165        let th = self.config.res_h.unwrap_or(img.height());
166        self.width = tw as usize;
167        self.height = th as usize;
168
169        if tw == img.width() && th == img.height() {
170            return img;
171        }
172        img.resize_exact(tw, th, FilterType::Triangle)
173    }
174
175    /// Generates the normalized u8 saliency map
176    fn gen_saliency(&self, img: &DynamicImage) -> (Vec<u8>, bool) {
177        let luma = img.to_luma8();
178        let luma_img =
179            Img::new(luma.as_raw().as_slice(), self.width, self.height);
180        let sal_map = maximum_symmetric_surround_saliency(luma_img);
181
182        let max_sal = *sal_map.buf().iter().max().unwrap_or(&1);
183        let sal_map: Vec<u8> = sal_map
184            .buf()
185            .iter()
186            .map(|&v| ((v as f32 / max_sal as f32) * 255.) as u8)
187            .collect();
188
189        let total_sal: u32 = sal_map.iter().map(|&p| p as u32).sum();
190        let avg_sal = total_sal as f32 / (self.width * self.height) as f32;
191        let is_sal_worth = avg_sal >= self.config.sal_thresh;
192
193        (sal_map, is_sal_worth)
194    }
195
196    /// Finds the best color based on saliency and HSV.
197    fn get_candidates(
198        &self,
199        rgb_img: &RgbImage,
200        sal_map: &[u8],
201        is_worth: bool,
202    ) -> Vec<ScoredPixel> {
203        let mut candidates = Vec::new();
204
205        for (x, y, pixel) in rgb_img.enumerate_pixels() {
206            let r = pixel[0] as f32 / 255.;
207            let g = pixel[1] as f32 / 255.;
208            let b = pixel[2] as f32 / 255.;
209
210            let srgb = Srgb::new(r, g, b);
211            let hsv = Hsv::from_color(srgb);
212            if hsv.value < self.config.val_thresh
213                || hsv.saturation < self.config.sat_thresh
214            {
215                continue;
216            }
217
218            let mut score = hsv.saturation * hsv.value;
219            if is_worth {
220                let id = y as usize * self.width + x as usize;
221                let sal_val = sal_map[id] as f32 / 255.;
222                score *= 1.0 + sal_val * self.config.sal_bonus;
223            }
224
225            let hue = hsv.hue.into_positive_degrees();
226            let warmth = 1.0 - (hue.min(360. - hue) / 180.);
227            score *= 1.0 + warmth * self.config.warmth_bonus;
228
229            let lab: Lab = srgb.into_color();
230            candidates.push(ScoredPixel::new(lab, pixel, score));
231        }
232        candidates
233    }
234
235    /// Gets best color from the candidate pixels.
236    ///
237    /// It uses k-means clustering in order to find the best color, where
238    /// it picks cluster with the highest average value (which suits the
239    /// cluster size requirement) and picks the pixel with the highest score.
240    fn get_best_col(&self, candids: Vec<ScoredPixel>) -> Option<(u8, u8, u8)> {
241        let clusters = self.get_clusters(candids);
242        let min_size = ((self.width * self.height) as f32 * 0.001) as usize;
243        let max_cnt = clusters.iter().map(|c| c.cnt).max().unwrap_or(1) as f32;
244
245        let mut best = None;
246        let mut max_score = -1.;
247        for cluster in clusters {
248            if cluster.cnt < min_size {
249                continue;
250            }
251
252            let avg_score = cluster.score / cluster.cnt as f32;
253            let mass_score = (cluster.cnt as f32 / max_cnt).sqrt();
254            let mut score = avg_score * mass_score * self.config.dom_bonus;
255
256            let sab = cluster.best_lab.a.powi(2) + cluster.best_lab.b.powi(2);
257            let mut vibr_score = sab / 10000.;
258
259            let r = cluster.best_rgb.0 as f32 / 255.;
260            let g = cluster.best_rgb.1 as f32 / 255.;
261            let b = cluster.best_rgb.2 as f32 / 255.;
262            vibr_score *= r.max(g.max(b));
263
264            score += vibr_score * self.config.vibr_bonus;
265            if score > max_score {
266                max_score = score;
267                best = Some(cluster.best_rgb);
268            }
269        }
270        best
271    }
272
273    /// Gets [`ScoredCluster`]s from using k-means clustring.
274    fn get_clusters(&self, candids: Vec<ScoredPixel>) -> Vec<ScoredCluster> {
275        let labs: Vec<Lab> = candids.iter().map(|c| c.lab).collect();
276        let k = self.config.clusters.min(labs.len());
277
278        let res = get_kmeans(k, 20, 10.0, false, &labs, 0);
279
280        let mut clusters = vec![ScoredCluster::default(); k];
281        for (i, &cid) in res.indices.iter().enumerate() {
282            clusters[cid as usize].push(&candids[i]);
283        }
284        clusters
285    }
286
287    #[cfg(debug_assertions)]
288    fn save_saliency(&self, sal_img: &[u8]) {
289        if let Some(img) = image::GrayImage::from_raw(
290            self.width as u32,
291            self.height as u32,
292            sal_img.to_owned(),
293        ) {
294            img.save("debug_saliency.png")
295                .expect("Failed to save debug image");
296        }
297    }
298}