prominence/
lib.rs

1// Copyright 2022 Spanfile
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! A library to extract prominent colors from an image.
//!
//! This library is a reimplementation of the Palette library in Android Jetpack. Android Jetpack is
//! Copyright 2018 The Android Open Source Project. Android Jetpack is licensed under the Apache
//! License, Version 2.0.
//!
//! [Original source.](https://github.com/androidx/androidx/tree/f4eca2c46040cab36ebf7f34e68bdd973110e4a5/palette/palette/src/main/java/androidx/palette/graphics)
//!
//! [Android Jetpack license.](https://github.com/androidx/androidx/blob/7b7922489f9a7572f4462558691bf5550dd65c26/LICENSE.txt)
24
25mod color_cut_quantizer;
26mod filter;
27mod swatch;
28mod target;
29
/// The default maximum number of colors to calculate while quantizing an image.
pub const DEFAULT_CALCULATE_NUMBER_COLORS: usize = 16;
/// The default area, in pixels, to shrink the given image to before quantizing (112 × 112).
pub const DEFAULT_RESIZE_IMAGE_AREA: u32 = 112 * 112;
34
35use std::collections::{HashMap, HashSet};
36
37pub use image;
38use image::{math::Rect, GenericImageView, ImageBuffer};
39
40use crate::color_cut_quantizer::ColorCutQuantizer;
41pub use crate::{
42    filter::{DefaultFilter, Filter},
43    swatch::Swatch,
44    target::Target,
45};
46
/// A color palette derived from an image.
///
/// Produced by [`PaletteBuilder::generate`]. It holds every quantized swatch together with the
/// swatch (if any) that was selected for each target.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Palette {
    /// Every swatch produced by quantizing the image.
    swatches: Vec<Swatch>,
    /// The targets this palette was generated for.
    targets: Vec<Target>,
    /// The swatch selected for each target, keyed by the target's id. `None` means no swatch was
    /// selected for that target.
    selected_swatches: HashMap<u64, Option<Swatch>>,
}
55
/// A builder for a new [`Palette`].
///
/// Create one with [`Palette::from_image`] or [`PaletteBuilder::from_image`], configure it with the
/// chained builder methods, then call [`PaletteBuilder::generate`].
pub struct PaletteBuilder<P>
where
    P: image::Pixel<Subpixel = u8> + 'static + std::cmp::Eq + std::hash::Hash,
{
    /// The source image. It may be replaced by a shrunk copy while generating the palette.
    image: ImageBuffer<P, Vec<<P as image::Pixel>::Subpixel>>,
    /// The targets to select swatches for.
    targets: Vec<Target>,
    /// The maximum color count handed to the quantizer.
    maximum_color_count: usize,
    /// The area to shrink the image to before quantizing, or `None` to disable shrinking.
    resize_area: Option<u32>,
    /// An optional region of the original image to restrict palette generation to.
    region: Option<Rect>,
    /// Filters used to reject colors, in insertion order.
    filters: Vec<Box<dyn Filter>>,
}
68
69impl Palette {
70    /// Return a new [`PaletteBuilder`] from a given image buffer.
71    pub fn from_image<P>(
72        image: ImageBuffer<P, Vec<<P as image::Pixel>::Subpixel>>,
73    ) -> PaletteBuilder<P>
74    where
75        P: image::Pixel<Subpixel = u8> + 'static + std::cmp::Eq + std::hash::Hash,
76    {
77        PaletteBuilder::from_image(image)
78    }
79
80    /// Returns the swatches in this palette.
81    pub fn swatches(&self) -> &[Swatch] {
82        &self.swatches
83    }
84
85    /// Returns the targets in this palette.
86    pub fn targets(&self) -> &[Target] {
87        &self.targets
88    }
89
90    /// Returns the swatch corresponding to the preset light vibrant target, if it exists.
91    pub fn light_vibrant_swatch(&self) -> Option<Swatch> {
92        self.get_swatch_for_target(Target::light_vibrant())
93    }
94
95    /// Returns the swatch corresponding to the preset vibrant target, if it exists.
96    pub fn vibrant_swatch(&self) -> Option<Swatch> {
97        self.get_swatch_for_target(Target::vibrant())
98    }
99
100    /// Returns the swatch corresponding to the preset dark vibrant target, if it exists.
101    pub fn dark_vibrant_swatch(&self) -> Option<Swatch> {
102        self.get_swatch_for_target(Target::dark_vibrant())
103    }
104
105    /// Returns the swatch corresponding to the preset light muted target, if it exists.
106    pub fn light_muted_swatch(&self) -> Option<Swatch> {
107        self.get_swatch_for_target(Target::light_muted())
108    }
109
110    /// Returns the swatch corresponding to the preset muted target, if it exists.
111    pub fn muted_swatch(&self) -> Option<Swatch> {
112        self.get_swatch_for_target(Target::muted())
113    }
114
115    /// Returns the swatch corresponding to the preset dark muted target, if it exists.
116    pub fn dark_muted_swatch(&self) -> Option<Swatch> {
117        self.get_swatch_for_target(Target::dark_muted())
118    }
119
120    /// Returns the color corresponding to the preset light vibrant target, if it exists.
121    pub fn light_vibrant_color(&self) -> Option<(u8, u8, u8)> {
122        self.get_swatch_for_target(Target::light_vibrant())
123            .map(Swatch::rgb)
124    }
125
126    /// Returns the color corresponding to the preset vibrant target, if it exists.
127    pub fn vibrant_color(&self) -> Option<(u8, u8, u8)> {
128        self.get_swatch_for_target(Target::vibrant())
129            .map(Swatch::rgb)
130    }
131
132    /// Returns the color corresponding to the preset dark vibrant target, if it exists.
133    pub fn dark_vibrant_color(&self) -> Option<(u8, u8, u8)> {
134        self.get_swatch_for_target(Target::dark_vibrant())
135            .map(Swatch::rgb)
136    }
137
138    /// Returns the color corresponding to the preset light muted target, if it exists.
139    pub fn light_muted_color(&self) -> Option<(u8, u8, u8)> {
140        self.get_swatch_for_target(Target::light_muted())
141            .map(Swatch::rgb)
142    }
143
144    /// Returns the color corresponding to the preset muted target, if it exists.
145    pub fn muted_color(&self) -> Option<(u8, u8, u8)> {
146        self.get_swatch_for_target(Target::muted()).map(Swatch::rgb)
147    }
148
149    /// Returns the color corresponding to the preset dark vibrant target, if it exists.
150    pub fn dark_muted_color(&self) -> Option<(u8, u8, u8)> {
151        self.get_swatch_for_target(Target::dark_muted())
152            .map(Swatch::rgb)
153    }
154
155    /// Returns the swatch corresponding to a given target, if it exists.
156    pub fn get_swatch_for_target(&self, target: Target) -> Option<Swatch> {
157        self.selected_swatches.get(&target.id()).copied().flatten()
158    }
159
160    /// Returns the most prominent color in the palette, which is the swatch with the largest
161    /// population.
162    pub fn most_prominent_color(&self) -> Option<(u8, u8, u8)> {
163        self.swatches
164            .iter()
165            .max_by_key(|swatch| swatch.population())
166            .map(|swatch| swatch.rgb())
167    }
168}
169
170impl<P> PaletteBuilder<P>
171where
172    P: image::Pixel<Subpixel = u8> + 'static + std::cmp::Eq + std::hash::Hash,
173{
174    /// Returns a new [`PaletteBuilder`] from a given image buffer.
175    pub fn from_image(image: ImageBuffer<P, Vec<<P as image::Pixel>::Subpixel>>) -> Self {
176        Self {
177            image,
178            targets: Target::default_targets().to_vec(),
179            maximum_color_count: DEFAULT_CALCULATE_NUMBER_COLORS,
180            resize_area: Some(DEFAULT_RESIZE_IMAGE_AREA),
181            region: None,
182            filters: vec![Box::new(DefaultFilter)],
183        }
184    }
185
186    pub fn from_swatches() -> Self {
187        unimplemented!()
188    }
189
190    /// Set the desired area to shrink the image to before quantizing. Set to `None` to disable
191    /// shrinking.
192    ///
193    /// By default the image will be shrunk to an area of 112 by 112 pixels, as defined in the
194    /// [`DEFAULT_RESIZE_IMAGE_AREA`] constant. The image will not be grown if it is already smaller
195    /// than the desired area.
196    pub fn resize_image_area(self, resize_area: Option<u32>) -> Self {
197        Self {
198            resize_area,
199            ..self
200        }
201    }
202
203    /// Set a custom region to focus the palette generation on.
204    ///
205    /// The region is based on the original image. If the image is shrunk before quantizing (see
206    /// [`PaletteBuilder::resize_image_area`]), the given region will be scaled accordingly to still
207    /// cover a similar area in the shrunk image. By default, the entire image is used to
208    /// generate the palette.
209    pub fn region(self, x: u32, y: u32, width: u32, height: u32) -> Self {
210        Self {
211            region: Some(Rect {
212                x,
213                y,
214                width,
215                height,
216            }),
217            ..self
218        }
219    }
220
221    /// Add a custom target to the palette.
222    ///
223    /// By default, a set of preset targets are included in every palette. See
224    /// [`Target::default_targets()`].
225    pub fn add_target(mut self, target: Target) -> Self {
226        if !self.targets.contains(&target) {
227            self.targets.push(target);
228        }
229
230        self
231    }
232
233    /// Add a custom filter to the palette. Multiple filters may be added. Filters will be evaluated
234    /// in order of insertion.
235    ///
236    /// A filter is used to reject certain colors from being included in the palette generation. A
237    /// [`DefaultFilter`] is included in every builder by default. It can be removed from the
238    /// builder with [`PaletteBuilder::clear_filters`].
239    pub fn add_filter<F>(mut self, filter: F) -> Self
240    where
241        F: Filter + 'static,
242    {
243        self.filters.push(Box::new(filter));
244        self
245    }
246
247    /// Clears the set region.
248    pub fn clear_region(self) -> Self {
249        Self {
250            region: None,
251            ..self
252        }
253    }
254
255    /// Removes all targets in the builder, including the presets.
256    pub fn clear_targets(self) -> Self {
257        Self {
258            targets: Vec::new(),
259            ..self
260        }
261    }
262
263    /// Removes all filters in the builder, including the default filter.
264    pub fn clear_filters(self) -> Self {
265        Self {
266            filters: Vec::new(),
267            ..self
268        }
269    }
270
271    /// Consume the builder and generate a new [`Palette`].
272    pub fn generate(mut self) -> Palette {
273        // scale down the image if requested
274        if self.scale_image_down() {
275            if let Some(mut region) = self.region {
276                // scale down the region to match the new scaled image
277                let scale = self.image.width() as f32 / self.image.height() as f32;
278
279                region.x = (region.x as f32 * scale).floor() as u32;
280                region.y = (region.y as f32 * scale).floor() as u32;
281                region.width = ((region.width as f32 * scale) as u32 + region.x)
282                    .min(self.image.width() - region.x);
283                region.height = ((region.height as f32 * scale) as u32 + region.y)
284                    .min(self.image.height() - region.y);
285
286                self.region = Some(region);
287            }
288        }
289
290        // get pixels in the requested region, or in the entire image
291        let pixels = if let Some(region) = self.region {
292            self.image
293                .view(region.x, region.y, region.width, region.height)
294                .pixels()
295                .map(|(_, _, p)| p)
296                .collect()
297        } else {
298            self.image.pixels().copied().collect()
299        };
300
301        // quantize pixels, get swatches
302        let quantizer = ColorCutQuantizer::new(pixels, self.maximum_color_count, self.filters);
303        let swatches = quantizer.get_quantized_colors();
304
305        // try to pick swatches for each target
306        let mut used_colors = HashSet::new();
307        let selected_swatches = self
308            .targets
309            .iter_mut()
310            .map(|target| {
311                target.normalize_weights();
312                (
313                    target.id(),
314                    generate_scored_target(&swatches, *target, &mut used_colors),
315                )
316            })
317            .collect();
318
319        Palette {
320            swatches,
321            targets: self.targets,
322            selected_swatches,
323        }
324    }
325
326    fn scale_image_down(&mut self) -> bool
327    where
328        <P as image::Pixel>::Subpixel: 'static,
329    {
330        let (width, height) = self.image.dimensions();
331        let area = width * height;
332
333        let scale_ratio = match self.resize_area {
334            Some(resize_area) if resize_area > 0 && area > resize_area => {
335                (resize_area as f32 / area as f32).sqrt()
336            }
337            _ => 0.0,
338        };
339
340        if scale_ratio > 0.0 {
341            self.image = image::imageops::resize(
342                &self.image,
343                (width as f32 * scale_ratio).ceil() as u32,
344                (height as f32 * scale_ratio).ceil() as u32,
345                image::imageops::FilterType::Nearest,
346            );
347
348            true
349        } else {
350            false
351        }
352    }
353}
354
355fn generate_scored_target(
356    swatches: &[Swatch],
357    target: Target,
358    used_colors: &mut HashSet<(u8, u8, u8)>,
359) -> Option<Swatch> {
360    if target.is_exclusive() {
361        if let Some(max_scored_swatch) =
362            get_max_scored_swatch_for_target(swatches, target, used_colors)
363        {
364            used_colors.insert(max_scored_swatch.rgb());
365            return Some(max_scored_swatch);
366        }
367    }
368
369    None
370}
371
372fn get_max_scored_swatch_for_target(
373    swatches: &[Swatch],
374    target: Target,
375    used_colors: &HashSet<(u8, u8, u8)>,
376) -> Option<Swatch> {
377    let dominant_swatch = swatches
378        .iter()
379        .copied()
380        .max_by_key(|swatch| swatch.population());
381
382    swatches
383        .iter()
384        .copied()
385        .filter(|swatch| should_be_scored_for_target(*swatch, target, used_colors))
386        .max_by(|lhs, rhs| {
387            generate_score(*lhs, dominant_swatch, target)
388                .partial_cmp(&generate_score(*rhs, dominant_swatch, target))
389                .unwrap()
390        })
391}
392
393fn should_be_scored_for_target(
394    swatch: Swatch,
395    target: Target,
396    used_colors: &HashSet<(u8, u8, u8)>,
397) -> bool {
398    let (_, s, l) = swatch.hsl();
399
400    (target.minimum_saturation()..=target.maximum_saturation()).contains(&s)
401        && (target.minimum_lightness()..=target.maximum_lightness()).contains(&l)
402        && !used_colors.contains(&swatch.rgb())
403}
404
405fn generate_score(swatch: Swatch, dominant_swatch: Option<Swatch>, target: Target) -> f32 {
406    let (_, saturation, lightness) = swatch.hsl();
407
408    let max_population = if let Some(dominant_swatch) = dominant_swatch {
409        dominant_swatch.population() as f32
410    } else {
411        1.0
412    };
413
414    // calculate scores for saturation and luminance based on how close to the target values they
415    // are, weighted by the target
416    let saturation_score =
417        target.saturation_weight() * (1.0 - (saturation - target.target_saturation()).abs());
418    let lightness_score =
419        target.lightness_weight() * (1.0 - (lightness - target.target_lightness()).abs());
420
421    // calculate score for the population based on how large it is compared to the dominant swatch,
422    // weighted by the target
423    let population_score =
424        target.population_weight() * (swatch.population() as f32 / max_population);
425
426    saturation_score + lightness_score + population_score
427}
428
/// Converts an 8-bit RGB color to HSL.
///
/// Returns `(hue, saturation, lightness)`. Hue is in degrees in `[0, 360)`; saturation and
/// lightness are in `[0, 1]`. Achromatic colors (where all three channels are equal) get a hue
/// and saturation of zero.
///
/// Adapted from <https://stackoverflow.com/a/39147465> (thank you SO).
fn rgb_to_hsl((r, g, b): (u8, u8, u8)) -> (f32, f32, f32) {
    let normalize = |channel: u8| f32::from(channel) / 255.0;
    let (red, green, blue) = (normalize(r), normalize(g), normalize(b));

    let max = red.max(green).max(blue);
    let min = red.min(green).min(blue);
    let chroma = max - min;
    let lightness = (max + min) / 2.0;

    if chroma == 0.0 {
        // Achromatic: there is no dominant channel to derive a hue from.
        return (0.0, 0.0, lightness);
    }

    let saturation = chroma / (1.0 - (2.0 * lightness - 1.0).abs());

    // Hue sector in [0, 6), chosen by the largest channel. Red wins ties, then green.
    let sector = if max == red {
        let raw = (green - blue) / chroma;
        let wrap = if raw < 0.0 { 6.0 } else { 0.0 };
        raw + wrap
    } else if max == green {
        (blue - red) / chroma + 2.0
    } else {
        (red - green) / chroma + 4.0
    };

    (sector * 60.0, saturation, lightness)
}