1mod color_cut_quantizer;
26mod filter;
27mod swatch;
28mod target;
29
/// Default maximum number of colors handed to the quantizer; `PaletteBuilder::from_image`
/// uses it as the builder's initial `maximum_color_count`.
30pub const DEFAULT_CALCULATE_NUMBER_COLORS: usize = 16;
/// Default area, in pixels (`width * height`), that an image is scaled down to before
/// quantization; `PaletteBuilder::from_image` uses it as the builder's initial `resize_area`.
32pub const DEFAULT_RESIZE_IMAGE_AREA: u32 = 112 * 112;
34
35use std::collections::{HashMap, HashSet};
36
37pub use image;
38use image::{math::Rect, GenericImageView, ImageBuffer};
39
40use crate::color_cut_quantizer::ColorCutQuantizer;
41pub use crate::{
42 filter::{DefaultFilter, Filter},
43 swatch::Swatch,
44 target::Target,
45};
46
47#[derive(Debug)]
49#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
50pub struct Palette {
51 swatches: Vec<Swatch>,
52 targets: Vec<Target>,
53 selected_swatches: HashMap<u64, Option<Swatch>>,
54}
55
56pub struct PaletteBuilder<P>
58where
59 P: image::Pixel<Subpixel = u8> + 'static + std::cmp::Eq + std::hash::Hash,
60{
61 image: ImageBuffer<P, Vec<<P as image::Pixel>::Subpixel>>,
62 targets: Vec<Target>,
63 maximum_color_count: usize,
64 resize_area: Option<u32>,
65 region: Option<Rect>,
66 filters: Vec<Box<dyn Filter>>,
67}
68
69impl Palette {
70 pub fn from_image<P>(
72 image: ImageBuffer<P, Vec<<P as image::Pixel>::Subpixel>>,
73 ) -> PaletteBuilder<P>
74 where
75 P: image::Pixel<Subpixel = u8> + 'static + std::cmp::Eq + std::hash::Hash,
76 {
77 PaletteBuilder::from_image(image)
78 }
79
80 pub fn swatches(&self) -> &[Swatch] {
82 &self.swatches
83 }
84
85 pub fn targets(&self) -> &[Target] {
87 &self.targets
88 }
89
90 pub fn light_vibrant_swatch(&self) -> Option<Swatch> {
92 self.get_swatch_for_target(Target::light_vibrant())
93 }
94
95 pub fn vibrant_swatch(&self) -> Option<Swatch> {
97 self.get_swatch_for_target(Target::vibrant())
98 }
99
100 pub fn dark_vibrant_swatch(&self) -> Option<Swatch> {
102 self.get_swatch_for_target(Target::dark_vibrant())
103 }
104
105 pub fn light_muted_swatch(&self) -> Option<Swatch> {
107 self.get_swatch_for_target(Target::light_muted())
108 }
109
110 pub fn muted_swatch(&self) -> Option<Swatch> {
112 self.get_swatch_for_target(Target::muted())
113 }
114
115 pub fn dark_muted_swatch(&self) -> Option<Swatch> {
117 self.get_swatch_for_target(Target::dark_muted())
118 }
119
120 pub fn light_vibrant_color(&self) -> Option<(u8, u8, u8)> {
122 self.get_swatch_for_target(Target::light_vibrant())
123 .map(Swatch::rgb)
124 }
125
126 pub fn vibrant_color(&self) -> Option<(u8, u8, u8)> {
128 self.get_swatch_for_target(Target::vibrant())
129 .map(Swatch::rgb)
130 }
131
132 pub fn dark_vibrant_color(&self) -> Option<(u8, u8, u8)> {
134 self.get_swatch_for_target(Target::dark_vibrant())
135 .map(Swatch::rgb)
136 }
137
138 pub fn light_muted_color(&self) -> Option<(u8, u8, u8)> {
140 self.get_swatch_for_target(Target::light_muted())
141 .map(Swatch::rgb)
142 }
143
144 pub fn muted_color(&self) -> Option<(u8, u8, u8)> {
146 self.get_swatch_for_target(Target::muted()).map(Swatch::rgb)
147 }
148
149 pub fn dark_muted_color(&self) -> Option<(u8, u8, u8)> {
151 self.get_swatch_for_target(Target::dark_muted())
152 .map(Swatch::rgb)
153 }
154
155 pub fn get_swatch_for_target(&self, target: Target) -> Option<Swatch> {
157 self.selected_swatches.get(&target.id()).copied().flatten()
158 }
159
160 pub fn most_prominent_color(&self) -> Option<(u8, u8, u8)> {
163 self.swatches
164 .iter()
165 .max_by_key(|swatch| swatch.population())
166 .map(|swatch| swatch.rgb())
167 }
168}
169
170impl<P> PaletteBuilder<P>
171where
172 P: image::Pixel<Subpixel = u8> + 'static + std::cmp::Eq + std::hash::Hash,
173{
174 pub fn from_image(image: ImageBuffer<P, Vec<<P as image::Pixel>::Subpixel>>) -> Self {
176 Self {
177 image,
178 targets: Target::default_targets().to_vec(),
179 maximum_color_count: DEFAULT_CALCULATE_NUMBER_COLORS,
180 resize_area: Some(DEFAULT_RESIZE_IMAGE_AREA),
181 region: None,
182 filters: vec![Box::new(DefaultFilter)],
183 }
184 }
185
186 pub fn from_swatches() -> Self {
187 unimplemented!()
188 }
189
190 pub fn resize_image_area(self, resize_area: Option<u32>) -> Self {
197 Self {
198 resize_area,
199 ..self
200 }
201 }
202
203 pub fn region(self, x: u32, y: u32, width: u32, height: u32) -> Self {
210 Self {
211 region: Some(Rect {
212 x,
213 y,
214 width,
215 height,
216 }),
217 ..self
218 }
219 }
220
221 pub fn add_target(mut self, target: Target) -> Self {
226 if !self.targets.contains(&target) {
227 self.targets.push(target);
228 }
229
230 self
231 }
232
233 pub fn add_filter<F>(mut self, filter: F) -> Self
240 where
241 F: Filter + 'static,
242 {
243 self.filters.push(Box::new(filter));
244 self
245 }
246
247 pub fn clear_region(self) -> Self {
249 Self {
250 region: None,
251 ..self
252 }
253 }
254
255 pub fn clear_targets(self) -> Self {
257 Self {
258 targets: Vec::new(),
259 ..self
260 }
261 }
262
263 pub fn clear_filters(self) -> Self {
265 Self {
266 filters: Vec::new(),
267 ..self
268 }
269 }
270
271 pub fn generate(mut self) -> Palette {
273 if self.scale_image_down() {
275 if let Some(mut region) = self.region {
276 let scale = self.image.width() as f32 / self.image.height() as f32;
278
279 region.x = (region.x as f32 * scale).floor() as u32;
280 region.y = (region.y as f32 * scale).floor() as u32;
281 region.width = ((region.width as f32 * scale) as u32 + region.x)
282 .min(self.image.width() - region.x);
283 region.height = ((region.height as f32 * scale) as u32 + region.y)
284 .min(self.image.height() - region.y);
285
286 self.region = Some(region);
287 }
288 }
289
290 let pixels = if let Some(region) = self.region {
292 self.image
293 .view(region.x, region.y, region.width, region.height)
294 .pixels()
295 .map(|(_, _, p)| p)
296 .collect()
297 } else {
298 self.image.pixels().copied().collect()
299 };
300
301 let quantizer = ColorCutQuantizer::new(pixels, self.maximum_color_count, self.filters);
303 let swatches = quantizer.get_quantized_colors();
304
305 let mut used_colors = HashSet::new();
307 let selected_swatches = self
308 .targets
309 .iter_mut()
310 .map(|target| {
311 target.normalize_weights();
312 (
313 target.id(),
314 generate_scored_target(&swatches, *target, &mut used_colors),
315 )
316 })
317 .collect();
318
319 Palette {
320 swatches,
321 targets: self.targets,
322 selected_swatches,
323 }
324 }
325
326 fn scale_image_down(&mut self) -> bool
327 where
328 <P as image::Pixel>::Subpixel: 'static,
329 {
330 let (width, height) = self.image.dimensions();
331 let area = width * height;
332
333 let scale_ratio = match self.resize_area {
334 Some(resize_area) if resize_area > 0 && area > resize_area => {
335 (resize_area as f32 / area as f32).sqrt()
336 }
337 _ => 0.0,
338 };
339
340 if scale_ratio > 0.0 {
341 self.image = image::imageops::resize(
342 &self.image,
343 (width as f32 * scale_ratio).ceil() as u32,
344 (height as f32 * scale_ratio).ceil() as u32,
345 image::imageops::FilterType::Nearest,
346 );
347
348 true
349 } else {
350 false
351 }
352 }
353}
354
355fn generate_scored_target(
356 swatches: &[Swatch],
357 target: Target,
358 used_colors: &mut HashSet<(u8, u8, u8)>,
359) -> Option<Swatch> {
360 if target.is_exclusive() {
361 if let Some(max_scored_swatch) =
362 get_max_scored_swatch_for_target(swatches, target, used_colors)
363 {
364 used_colors.insert(max_scored_swatch.rgb());
365 return Some(max_scored_swatch);
366 }
367 }
368
369 None
370}
371
372fn get_max_scored_swatch_for_target(
373 swatches: &[Swatch],
374 target: Target,
375 used_colors: &HashSet<(u8, u8, u8)>,
376) -> Option<Swatch> {
377 let dominant_swatch = swatches
378 .iter()
379 .copied()
380 .max_by_key(|swatch| swatch.population());
381
382 swatches
383 .iter()
384 .copied()
385 .filter(|swatch| should_be_scored_for_target(*swatch, target, used_colors))
386 .max_by(|lhs, rhs| {
387 generate_score(*lhs, dominant_swatch, target)
388 .partial_cmp(&generate_score(*rhs, dominant_swatch, target))
389 .unwrap()
390 })
391}
392
393fn should_be_scored_for_target(
394 swatch: Swatch,
395 target: Target,
396 used_colors: &HashSet<(u8, u8, u8)>,
397) -> bool {
398 let (_, s, l) = swatch.hsl();
399
400 (target.minimum_saturation()..=target.maximum_saturation()).contains(&s)
401 && (target.minimum_lightness()..=target.maximum_lightness()).contains(&l)
402 && !used_colors.contains(&swatch.rgb())
403}
404
405fn generate_score(swatch: Swatch, dominant_swatch: Option<Swatch>, target: Target) -> f32 {
406 let (_, saturation, lightness) = swatch.hsl();
407
408 let max_population = if let Some(dominant_swatch) = dominant_swatch {
409 dominant_swatch.population() as f32
410 } else {
411 1.0
412 };
413
414 let saturation_score =
417 target.saturation_weight() * (1.0 - (saturation - target.target_saturation()).abs());
418 let lightness_score =
419 target.lightness_weight() * (1.0 - (lightness - target.target_lightness()).abs());
420
421 let population_score =
424 target.population_weight() * (swatch.population() as f32 / max_population);
425
426 saturation_score + lightness_score + population_score
427}
428
/// Converts an 8-bit RGB color to HSL.
///
/// Returns `(hue, saturation, lightness)` with the hue in degrees (`0.0..360.0`) and
/// saturation and lightness in `0.0..=1.0`. Gray colors (all channels equal) yield a hue and
/// saturation of `0.0`.
fn rgb_to_hsl((r, g, b): (u8, u8, u8)) -> (f32, f32, f32) {
    let r = f32::from(r) / 255.0;
    let g = f32::from(g) / 255.0;
    let b = f32::from(b) / 255.0;

    let max = r.max(g).max(b);
    let min = r.min(g).min(b);
    let chroma = max - min;

    let l = (max + min) / 2.0;

    if chroma == 0.0 {
        return (0.0, 0.0, l);
    }

    // Rounding in the denominator can push the quotient marginally above 1.0 (for example for
    // `(1, 0, 0)`), which would make fully saturated colors fail inclusive `..=1.0` range
    // checks. Cap the result at the mathematical maximum.
    let s = (chroma / (1.0 - (2.0 * l - 1.0).abs())).min(1.0);

    // Position on the color wheel in units of 60 degrees, in `0.0..6.0`.
    let sextant = if max == r {
        let segment = (g - b) / chroma;
        if segment < 0.0 {
            // Wrap negative positions (between magenta and red) around the wheel.
            segment + 6.0
        } else {
            segment
        }
    } else if max == g {
        (b - r) / chroma + 2.0
    } else {
        (r - g) / chroma + 4.0
    };

    (sextant * 60.0, s, l)
}