1use std::path::Path;
2
3use image::{DynamicImage, RgbImage, imageops::FilterType};
4use imgref::Img;
5use kmeans_colors::get_kmeans;
6use mss_saliency::maximum_symmetric_surround_saliency;
7use palette::{FromColor, Hsv, IntoColor, Lab, Srgb};
8
9use crate::{
10 Cache,
11 error::Error,
12 extract::{
13 ExtractionConfig, ExtractStep,
14 scores::{ScoredCluster, ScoredPixel},
15 },
16};
17
18#[derive(Debug, Clone)]
34pub struct Extractor<'a> {
35 config: &'a ExtractionConfig,
36 width: usize,
37 height: usize,
38}
39
40impl<'a> Extractor<'a> {
41 pub fn extract<P>(
45 path: P,
46 config: &'a ExtractionConfig,
47 ) -> Result<Option<(u8, u8, u8)>, Error>
48 where
49 P: AsRef<Path>,
50 {
51 Self::inner_extract(path, config, |_| {})
52 }
53
54 pub fn extract_cached<P>(
62 path: P,
63 config: &'a ExtractionConfig,
64 cache_path: Option<&Path>,
65 ) -> Result<Option<(u8, u8, u8)>, Error>
66 where
67 P: AsRef<Path>,
68 {
69 Self::extract_cached_with_progress(path, config, cache_path, |_| {})
70 }
71
72 pub fn extract_with_progress<P, F>(
77 path: P,
78 config: &'a ExtractionConfig,
79 progress_callback: F,
80 ) -> Result<Option<(u8, u8, u8)>, Error>
81 where
82 P: AsRef<Path>,
83 F: FnMut(ExtractStep),
84 {
85 Self::inner_extract(path, config, progress_callback)
86 }
87
88 pub fn extract_cached_with_progress<P, F>(
96 path: P,
97 config: &'a ExtractionConfig,
98 cache_path: Option<&Path>,
99 mut progress_callback: F,
100 ) -> Result<Option<(u8, u8, u8)>, Error>
101 where
102 P: AsRef<Path>,
103 F: FnMut(ExtractStep),
104 {
105 progress_callback(ExtractStep::CheckingCache);
106 let cache_file =
107 cache_path.map(|v| v.to_owned()).unwrap_or_else(Cache::file);
108 let mut cache = Cache::load(&cache_file);
109 let key = Cache::key(config, path.as_ref())
110 .unwrap_or("fallback".to_string());
111
112 if let Some(&color) = cache.entries.get(&key) {
113 progress_callback(ExtractStep::FinishedWithCache);
114 return Ok(Some(color));
115 }
116
117 let color = Self::inner_extract(path, config, progress_callback)?;
118 if let Some(col) = color {
119 cache.entries.insert(key, col);
120 _ = cache.save(&cache_file);
121 }
122
123 Ok(color)
124 }
125
126 fn inner_extract<P, F>(
127 path: P,
128 config: &'a ExtractionConfig,
129 mut progress_callback: F,
130 ) -> Result<Option<(u8, u8, u8)>, Error>
131 where
132 P: AsRef<Path>,
133 F: FnMut(ExtractStep),
134 {
135 let mut extractor = Self {
136 config,
137 width: 0,
138 height: 0,
139 };
140
141 progress_callback(ExtractStep::OpeningImage);
142 let img = image::open(path)?;
143 progress_callback(ExtractStep::ResizingImage);
144 let img = extractor.prep_img(img);
145
146 progress_callback(ExtractStep::ExtractingColors);
147 let (sal_map, is_sal_worth) = extractor.gen_saliency(&img);
148 #[cfg(debug_assertions)]
149 extractor.save_saliency(&sal_map);
150
151 let rgb_img = img.to_rgb8();
152
153 let candids =
154 extractor.get_candidates(&rgb_img, &sal_map, is_sal_worth);
155
156 progress_callback(ExtractStep::Clustering);
157 let col = extractor.get_best_col(candids);
158 progress_callback(ExtractStep::Finished);
159 Ok(col)
160 }
161
162 fn prep_img(&mut self, img: DynamicImage) -> DynamicImage {
164 let tw = self.config.res_w.unwrap_or(img.width());
165 let th = self.config.res_h.unwrap_or(img.height());
166 self.width = tw as usize;
167 self.height = th as usize;
168
169 if tw == img.width() && th == img.height() {
170 return img;
171 }
172 img.resize_exact(tw, th, FilterType::Triangle)
173 }
174
175 fn gen_saliency(&self, img: &DynamicImage) -> (Vec<u8>, bool) {
177 let luma = img.to_luma8();
178 let luma_img =
179 Img::new(luma.as_raw().as_slice(), self.width, self.height);
180 let sal_map = maximum_symmetric_surround_saliency(luma_img);
181
182 let max_sal = *sal_map.buf().iter().max().unwrap_or(&1);
183 let sal_map: Vec<u8> = sal_map
184 .buf()
185 .iter()
186 .map(|&v| ((v as f32 / max_sal as f32) * 255.) as u8)
187 .collect();
188
189 let total_sal: u32 = sal_map.iter().map(|&p| p as u32).sum();
190 let avg_sal = total_sal as f32 / (self.width * self.height) as f32;
191 let is_sal_worth = avg_sal >= self.config.sal_thresh;
192
193 (sal_map, is_sal_worth)
194 }
195
196 fn get_candidates(
198 &self,
199 rgb_img: &RgbImage,
200 sal_map: &[u8],
201 is_worth: bool,
202 ) -> Vec<ScoredPixel> {
203 let mut candidates = Vec::new();
204
205 for (x, y, pixel) in rgb_img.enumerate_pixels() {
206 let r = pixel[0] as f32 / 255.;
207 let g = pixel[1] as f32 / 255.;
208 let b = pixel[2] as f32 / 255.;
209
210 let srgb = Srgb::new(r, g, b);
211 let hsv = Hsv::from_color(srgb);
212 if hsv.value < self.config.val_thresh
213 || hsv.saturation < self.config.sat_thresh
214 {
215 continue;
216 }
217
218 let mut score = hsv.saturation * hsv.value;
219 if is_worth {
220 let id = y as usize * self.width + x as usize;
221 let sal_val = sal_map[id] as f32 / 255.;
222 score *= 1.0 + sal_val * self.config.sal_bonus;
223 }
224
225 let hue = hsv.hue.into_positive_degrees();
226 let warmth = 1.0 - (hue.min(360. - hue) / 180.);
227 score *= 1.0 + warmth * self.config.warmth_bonus;
228
229 let lab: Lab = srgb.into_color();
230 candidates.push(ScoredPixel::new(lab, pixel, score));
231 }
232 candidates
233 }
234
235 fn get_best_col(&self, candids: Vec<ScoredPixel>) -> Option<(u8, u8, u8)> {
241 let clusters = self.get_clusters(candids);
242 let min_size = ((self.width * self.height) as f32 * 0.001) as usize;
243 let max_cnt = clusters.iter().map(|c| c.cnt).max().unwrap_or(1) as f32;
244
245 let mut best = None;
246 let mut max_score = -1.;
247 for cluster in clusters {
248 if cluster.cnt < min_size {
249 continue;
250 }
251
252 let avg_score = cluster.score / cluster.cnt as f32;
253 let mass_score = (cluster.cnt as f32 / max_cnt).sqrt();
254 let mut score = avg_score * mass_score * self.config.dom_bonus;
255
256 let sab = cluster.best_lab.a.powi(2) + cluster.best_lab.b.powi(2);
257 let mut vibr_score = sab / 10000.;
258
259 let r = cluster.best_rgb.0 as f32 / 255.;
260 let g = cluster.best_rgb.1 as f32 / 255.;
261 let b = cluster.best_rgb.2 as f32 / 255.;
262 vibr_score *= r.max(g.max(b));
263
264 score += vibr_score * self.config.vibr_bonus;
265 if score > max_score {
266 max_score = score;
267 best = Some(cluster.best_rgb);
268 }
269 }
270 best
271 }
272
273 fn get_clusters(&self, candids: Vec<ScoredPixel>) -> Vec<ScoredCluster> {
275 let labs: Vec<Lab> = candids.iter().map(|c| c.lab).collect();
276 let k = self.config.clusters.min(labs.len());
277
278 let res = get_kmeans(k, 20, 10.0, false, &labs, 0);
279
280 let mut clusters = vec![ScoredCluster::default(); k];
281 for (i, &cid) in res.indices.iter().enumerate() {
282 clusters[cid as usize].push(&candids[i]);
283 }
284 clusters
285 }
286
287 #[cfg(debug_assertions)]
288 fn save_saliency(&self, sal_img: &[u8]) {
289 if let Some(img) = image::GrayImage::from_raw(
290 self.width as u32,
291 self.height as u32,
292 sal_img.to_owned(),
293 ) {
294 img.save("debug_saliency.png")
295 .expect("Failed to save debug image");
296 }
297 }
298}