use std::{cmp::Reverse, marker::PhantomData};
use crate::{
algorithm::Algorithm,
color::{Color, Lab},
error::Error,
image::{
filter::{AlphaFilter, CompositeFilter, Filter},
segmentation::LabelImage,
ImageData,
},
math::{
clustering::{ClusteringAlgorithm, DBSCAN},
denormalize,
sampling::{DiversitySampling, SamplingAlgorithm, SamplingError, WeightedFarthestSampling},
DistanceMetric,
FloatNumber,
Point,
},
theme::Theme,
Swatch,
};
/// An ordered collection of [`Swatch`]es, typically extracted from an image.
///
/// `T` is the floating-point type used for the swatch colors and ratios.
#[derive(Debug, Clone, PartialEq)]
pub struct Palette<T: FloatNumber> {
    /// The swatches, stored in the order they were handed to [`Palette::new`].
    swatches: Vec<Swatch<T>>,
}
impl<T> Palette<T>
where
T: FloatNumber,
{
#[must_use]
pub fn new(swatches: Vec<Swatch<T>>) -> Self {
Self { swatches }
}
#[must_use]
pub fn len(&self) -> usize {
self.swatches.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.swatches.is_empty()
}
#[must_use]
pub fn swatches(&self) -> &[Swatch<T>] {
&self.swatches
}
pub fn find_swatches(&self, num_swatches: usize) -> Result<Vec<Swatch<T>>, Error> {
self.find_swatches_internal(
num_swatches,
|swatch| swatch.ratio(),
|scores| {
DiversitySampling::new(T::from_f64(0.6), scores, DistanceMetric::SquaredEuclidean)
},
)
}
pub fn find_swatches_with_theme(
&self,
num_swatches: usize,
theme: Theme,
) -> Result<Vec<Swatch<T>>, Error> {
self.find_swatches_internal(
num_swatches,
|swatch| theme.score(swatch),
|scores| WeightedFarthestSampling::new(scores, DistanceMetric::SquaredEuclidean),
)
}
pub fn find_swatches_internal<S, F1, F2>(
&self,
num_swatches: usize,
score_fn: F1,
sampling_factory: F2,
) -> Result<Vec<Swatch<T>>, Error>
where
S: SamplingAlgorithm<T>,
F1: Fn(&Swatch<T>) -> T,
F2: FnOnce(Vec<T>) -> Result<S, SamplingError>,
{
if self.swatches.is_empty() {
return Ok(vec![]);
}
let num_swatches = num_swatches.min(self.swatches.len());
let (colors, scores): (Vec<Point<T, 3>>, Vec<T>) = self
.swatches
.iter()
.map(|swatch| {
let color = swatch.color();
([color.l, color.a, color.b], score_fn(swatch))
})
.unzip();
let sampler =
sampling_factory(scores).map_err(|cause| Error::SwatchSelectionError { cause })?;
let sampled = sampler
.sample(&colors, num_swatches)
.map_err(|cause| Error::SwatchSelectionError { cause })?;
let mut found: Vec<_> = sampled.iter().map(|&index| self.swatches[index]).collect();
found.sort_by_key(|swatch| Reverse(swatch.population()));
Ok(found)
}
#[must_use]
pub fn builder() -> PaletteBuilder<T, AlphaFilter> {
PaletteBuilder::new()
}
pub fn extract(image_data: &ImageData) -> Result<Self, Error> {
Self::builder().build(image_data)
}
#[deprecated(
since = "0.8.0",
note = "Use `Palette::extract` or `Palette::builder` instead."
)]
pub fn extract_with_algorithm(
image_data: &ImageData,
algorithm: Algorithm,
) -> Result<Self, Error> {
Self::builder().algorithm(algorithm).build(image_data)
}
}
/// Configures and runs palette extraction.
///
/// Obtained from [`Palette::builder`]. `T` is the floating-point type of the
/// resulting palette and `F` is the filter handed to the segmentation step.
pub struct PaletteBuilder<T: FloatNumber, F: Filter> {
    /// Segmentation algorithm used when building.
    algorithm: Algorithm,
    /// Filter passed to the segmentation algorithm.
    filter: F,
    /// Maximum number of swatches kept in the built palette.
    max_swatches: usize,
    /// Ties the builder to `T` without storing a value of that type.
    _marker: PhantomData<T>,
}
impl<T: FloatNumber> PaletteBuilder<T, AlphaFilter> {
    /// Swatch limit used until `max_swatches` is called.
    const DEFAULT_MAX_SWATCHES: usize = 256;

    /// Creates a builder using `Algorithm::default()`, `AlphaFilter::default()`
    /// and the default swatch limit.
    #[must_use]
    fn new() -> Self {
        let algorithm = Algorithm::default();
        let filter = AlphaFilter::default();
        Self {
            algorithm,
            filter,
            max_swatches: Self::DEFAULT_MAX_SWATCHES,
            _marker: PhantomData,
        }
    }
}
impl<T, F> PaletteBuilder<T, F>
where
    T: FloatNumber,
    F: Filter,
{
    /// Sets the segmentation algorithm used by `build`.
    #[must_use]
    pub fn algorithm(mut self, algorithm: Algorithm) -> Self {
        self.algorithm = algorithm;
        self
    }

    /// Adds `filter`, composing it with the filter configured so far.
    ///
    /// The result is a builder whose filter is the composite of the previous
    /// filter and `filter`. The algorithm and swatch limit are carried over.
    #[must_use]
    pub fn filter<F2>(self, filter: F2) -> PaletteBuilder<T, CompositeFilter<F, F2>>
    where
        F2: Filter,
    {
        PaletteBuilder {
            algorithm: self.algorithm,
            filter: self.filter.composite(filter),
            max_swatches: self.max_swatches,
            _marker: PhantomData,
        }
    }

    /// Sets the maximum number of swatches kept in the built palette.
    #[must_use]
    pub fn max_swatches(mut self, max_swatches: usize) -> Self {
        self.max_swatches = max_swatches;
        self
    }

    /// Extracts a palette from `image_data` with the configured settings.
    ///
    /// The image is segmented with the configured algorithm and filter. The
    /// segments are grouped into swatches, and the swatches are sorted by
    /// population in descending order. At most `max_swatches` of them are kept.
    ///
    /// # Errors
    ///
    /// Returns an error if segmentation fails or if the segments cannot be
    /// grouped into swatches.
    pub fn build(self, image_data: &ImageData) -> Result<Palette<T>, Error> {
        let label_image = self.algorithm.segment(image_data, &self.filter)?;
        let mut swatches = to_swatches(&label_image)?;
        // Stable sort: swatches with equal population keep their relative order.
        swatches.sort_by_key(|swatch| Reverse(swatch.population()));
        // Drop the least populous swatches in place instead of re-collecting.
        swatches.truncate(self.max_swatches);
        Ok(Palette::new(swatches))
    }
}
/// Minimum number of points for a DBSCAN cluster when grouping segment colors.
const COLOR_GROUP_DBSCAN_MIN_POINTS: usize = 1;
/// Neighborhood radius for grouping segment colors.
///
/// The colors being clustered are denormalized `[l, a, b]` values compared with
/// Euclidean distance, so this is a distance in Lab units.
const COLOR_GROUP_DBSCAN_EPSILON: f64 = 2.5;

/// Converts the segments of a label image into swatches.
///
/// Segment center colors are clustered with DBSCAN so that segments with
/// nearly identical colors are merged into a single swatch. For each cluster:
/// - the color is the population-weighted mean of the member segments' center colors,
/// - the position is taken from the most populous member segment,
/// - the population is the sum of the member segments' populations, and
/// - the ratio is that population divided by the image area.
///
/// Clusters with no non-empty member segment produce no swatch.
///
/// # Errors
///
/// Returns [`Error::PaletteExtractionError`] if DBSCAN cannot be constructed or
/// fitting fails.
fn to_swatches<T>(label_image: &LabelImage<T>) -> Result<Vec<Swatch<T>>, Error>
where
    T: FloatNumber,
{
    // Pair every segment with its denormalized Lab center so the DBSCAN input
    // is index-aligned with `segments`.
    let (segments, colors): (Vec<_>, Vec<_>) = label_image
        .segments()
        .map(|segment| {
            let center_pixel = segment.center();
            (
                segment,
                [
                    Lab::<T>::denormalize_l(center_pixel[0]),
                    Lab::<T>::denormalize_a(center_pixel[1]),
                    Lab::<T>::denormalize_b(center_pixel[2]),
                ],
            )
        })
        .unzip();

    let dbscan = DBSCAN::new(
        COLOR_GROUP_DBSCAN_MIN_POINTS,
        T::from_f64(COLOR_GROUP_DBSCAN_EPSILON),
        DistanceMetric::Euclidean,
    )
    .map_err(|e| Error::PaletteExtractionError {
        details: e.to_string(),
    })?;
    let clusters = dbscan
        .fit(&colors)
        .map_err(|e| Error::PaletteExtractionError {
            details: e.to_string(),
        })?;

    let width = T::from_usize(label_image.width());
    let height = T::from_usize(label_image.height());
    let area = width * height;

    let swatches: Vec<_> = clusters
        .iter()
        .filter_map(|cluster| {
            // Running population-weighted mean of the members' normalized center colors.
            let mut mean_color = [T::zero(); 3];
            let mut best_position = None;
            let mut best_population = 0;
            let mut total_population = 0;
            for &member in cluster.members() {
                let Some(segment) = segments.get(member) else {
                    continue;
                };
                if segment.is_empty() {
                    continue;
                }
                let population = segment.len();
                let center_pixel = segment.center();

                // Incremental weighted mean: the new center's weight is its share of
                // the population accumulated so far, including this segment. The
                // first member therefore gets fraction 1 and initializes the mean.
                total_population += population;
                let fraction = T::from_usize(population) / T::from_usize(total_population);
                mean_color.iter_mut().enumerate().for_each(|(i, color)| {
                    *color += fraction * (center_pixel[i] - *color);
                });

                // The most populous member decides the swatch position; on a tie
                // the later member wins.
                if population >= best_population {
                    best_position = Some((
                        denormalize(center_pixel[3], T::zero(), width).trunc_to_u32(),
                        denormalize(center_pixel[4], T::zero(), height).trunc_to_u32(),
                    ));
                    best_population = population;
                }
            }

            best_position.map(|position| {
                let l = Lab::<T>::denormalize_l(mean_color[0]);
                let a = Lab::<T>::denormalize_a(mean_color[1]);
                let b = Lab::<T>::denormalize_b(mean_color[2]);
                Swatch::new(
                    Color::new(l, a, b),
                    position,
                    total_population,
                    T::from_usize(total_population) / area,
                )
            })
        })
        .collect();
    Ok(swatches)
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use rstest::rstest;

    use super::*;
    use crate::{assert_color_eq, Rgba};

    /// Swatches resembling the Olympic logo: a dominant white background plus
    /// the five ring colors.
    #[must_use]
    fn sample_swatches<T>() -> Vec<Swatch<T>>
    where
        T: FloatNumber,
    {
        vec![
            Swatch::new(
                Color::from_str("#FFFFFF").unwrap(),
                (159, 106),
                61228,
                T::from_f64(0.9214),
            ),
            Swatch::new(
                Color::from_str("#EE334E").unwrap(),
                (238, 89),
                1080,
                T::from_f64(0.0163),
            ),
            Swatch::new(
                Color::from_str("#0081C8").unwrap(),
                (82, 88),
                1064,
                T::from_f64(0.0160),
            ),
            Swatch::new(
                Color::from_str("#00A651").unwrap(),
                (197, 123),
                1037,
                T::from_f64(0.0156),
            ),
            Swatch::new(
                Color::from_str("#000000").unwrap(),
                (157, 95),
                1036,
                T::from_f64(0.0156),
            ),
            Swatch::new(
                Color::from_str("#FCB131").unwrap(),
                (119, 123),
                1005,
                T::from_f64(0.0151),
            ),
        ]
    }

    /// Copies the palette's swatches and sorts them by descending population,
    /// breaking ties by RGB value, so assertions are order-independent.
    #[must_use]
    fn collect_sorted_swatches<T>(palette: &Palette<T>) -> Vec<Swatch<T>>
    where
        T: FloatNumber,
    {
        let mut swatches = palette.swatches().to_vec();
        swatches.sort_by(|a, b| {
            a.population()
                .cmp(&b.population())
                .reverse()
                .then(a.color().to_rgb_int().cmp(&b.color().to_rgb_int()))
        });
        swatches
    }

    #[test]
    fn test_new() {
        let swatches = vec![
            Swatch::<f64>::new(Color::from_str("#FFFFFF").unwrap(), (5, 10), 256, 0.5714),
            Swatch::<f64>::new(Color::from_str("#C8102E").unwrap(), (15, 20), 128, 0.2857),
            Swatch::<f64>::new(Color::from_str("#012169").unwrap(), (30, 30), 64, 0.1429),
        ];
        let actual = Palette::new(swatches.clone());
        assert!(!actual.is_empty());
        assert_eq!(actual.len(), 3);
        assert_eq!(actual.swatches, swatches);
    }

    #[test]
    fn test_new_empty() {
        let actual: Palette<f64> = Palette::new(Vec::new());
        assert!(actual.is_empty());
        assert_eq!(actual.len(), 0);
    }

    #[cfg(feature = "image")]
    #[rstest]
    #[case::dbscan("dbscan")]
    #[case::dbscanpp("dbscan++")]
    #[case::kmeans("kmeans")]
    #[case::slic("slic")]
    #[case::snic("snic")]
    fn test_builder_with_algorithm(#[case] name: &str) {
        let image_data = ImageData::load("../../gfx/olympic_logo.png").unwrap();
        let algorithm = Algorithm::from_str(name).unwrap();
        let actual: Palette<f64> = Palette::builder()
            .algorithm(algorithm)
            .build(&image_data)
            .expect("Failed to extract palette with algorithm");
        assert!(!actual.is_empty());
        assert!(actual.len() >= 5);
    }

    #[cfg(feature = "image")]
    #[test]
    fn test_builder_with_filter() {
        let image_data = ImageData::load("../../gfx/flags/np.png").unwrap();
        let actual: Palette<f64> = Palette::builder()
            .filter(|rgba: &Rgba| rgba[3] != 0)
            .build(&image_data)
            .expect("Failed to extract palette with filter");
        assert!(!actual.is_empty());
        assert!(actual.len() >= 3);

        let swatches = collect_sorted_swatches(&actual);
        assert_color_eq!(
            swatches[0].color(),
            Color::<f64>::from_str("#DC143C").expect("Invalid color format")
        );
        assert_color_eq!(
            swatches[1].color(),
            Color::<f64>::from_str("#003893").expect("Invalid color format")
        );
        assert_color_eq!(
            swatches[2].color(),
            Color::<f64>::from_str("#FFFFFF").expect("Invalid color format")
        );
    }

    #[cfg(feature = "image")]
    #[test]
    fn test_builder_max_swatches() {
        let image_data = ImageData::load("../../gfx/olympic_logo.png").unwrap();
        let actual: Palette<f64> = Palette::builder()
            .max_swatches(3)
            .build(&image_data)
            .expect("Failed to extract palette with max swatches");
        assert!(!actual.is_empty());
        assert_eq!(actual.len(), 3);

        let swatches = collect_sorted_swatches(&actual);
        assert_color_eq!(
            swatches[0].color(),
            Color::<f64>::from_str("#FFFFFF").expect("Invalid color format")
        );
        assert_color_eq!(
            swatches[1].color(),
            Color::<f64>::from_str("#0081C8").expect("Invalid color format")
        );
        assert_color_eq!(
            swatches[2].color(),
            Color::<f64>::from_str("#EE334E").expect("Invalid color format")
        );
    }

    #[test]
    fn test_builder_empty_image_data() {
        let data: Vec<u8> = Vec::new();
        let image_data = ImageData::new(0, 0, &data).unwrap();
        let actual = Palette::<f64>::builder().build(&image_data);
        assert!(actual.is_ok());

        let palette = actual.expect("Failed to extract palette from empty image data");
        assert!(palette.is_empty());
        assert_eq!(palette.len(), 0);
    }

    #[test]
    fn test_extract_transparent_image() {
        // 10x10 image with 4 bytes per pixel, every byte zero (alpha included).
        let data: Vec<u8> = vec![0; 4 * 10 * 10];
        let image_data = ImageData::new(10, 10, &data).unwrap();
        let actual: Result<Palette<f64>, _> = Palette::builder().build(&image_data);
        assert!(actual.is_ok());

        let palette = actual.expect("Failed to extract palette from transparent image");
        assert!(palette.is_empty());
        assert_eq!(palette.len(), 0);
    }

    #[test]
    #[cfg(feature = "image")]
    fn test_extract() {
        let image_data = ImageData::load("../../gfx/olympic_logo.png").unwrap();
        let actual: Palette<f64> =
            Palette::extract(&image_data).expect("Failed to extract palette");
        assert!(!actual.is_empty());
        assert_eq!(actual.len(), 6);

        let swatches = collect_sorted_swatches(&actual);
        assert_color_eq!(
            swatches[0].color(),
            Color::<f64>::from_str("#FFFFFF").expect("Invalid color format")
        );
        assert_color_eq!(
            swatches[1].color(),
            Color::<f64>::from_str("#0081C8").expect("Invalid color format")
        );
        assert_color_eq!(
            swatches[2].color(),
            Color::<f64>::from_str("#EE334E").expect("Invalid color format")
        );
        assert_color_eq!(
            swatches[3].color(),
            Color::<f64>::from_str("#000000").expect("Invalid color format")
        );
        assert_color_eq!(
            swatches[4].color(),
            Color::<f64>::from_str("#00A651").expect("Invalid color format")
        );
        assert_color_eq!(
            swatches[5].color(),
            Color::<f64>::from_str("#FCB131").expect("Invalid color format")
        );
    }

    #[test]
    #[cfg(feature = "image")]
    // Intentionally exercises the deprecated API; keep it building under `-D warnings`.
    #[allow(deprecated)]
    fn test_extract_with_algorithm() {
        let image_data = ImageData::load("../../gfx/olympic_logo.png").unwrap();
        let actual: Palette<f64> =
            Palette::extract_with_algorithm(&image_data, Algorithm::DBSCANpp).unwrap();
        assert!(!actual.is_empty());
        assert_eq!(actual.len(), 6);
    }

    #[test]
    fn test_find_swatches() {
        let palette = Palette::new(sample_swatches::<f64>());
        let actual = palette.find_swatches(4);
        assert!(actual.is_ok());

        let swatches = actual.expect("Failed to find swatches");
        assert_eq!(swatches.len(), 4);
        assert_eq!(swatches[0].color().to_hex_string(), "#FFFFFF");
        assert_eq!(swatches[1].color().to_hex_string(), "#EE334E");
        assert_eq!(swatches[2].color().to_hex_string(), "#00A651");
        assert_eq!(swatches[3].color().to_hex_string(), "#000000");
    }

    #[test]
    fn test_find_swatches_empty() {
        let palette = Palette::new(Vec::<Swatch<f64>>::new());
        let actual = palette.find_swatches(10);
        assert!(actual.is_ok());

        let swatches = actual.expect("Failed to find swatches in empty palette");
        assert!(swatches.is_empty());
    }

    #[rstest]
    #[case::colorful(Theme::Colorful, vec!["#EE334E", "#00A651"])]
    #[case::vivid(Theme::Vivid, vec!["#EE334E", "#00A651"])]
    #[case::muted(Theme::Muted, vec!["#0081C8", "#000000"])]
    #[case::light(Theme::Light, vec!["#0081C8", "#00A651"])]
    #[case::dark(Theme::Dark, vec!["#0081C8", "#000000"])]
    fn test_find_swatches_with_theme(#[case] theme: Theme, #[case] expected: Vec<&str>) {
        let palette = Palette::new(sample_swatches::<f64>());
        let actual = palette
            .find_swatches_with_theme(2, theme)
            .expect("Failed to find swatches with theme");
        assert_eq!(actual.len(), 2);
        assert_eq!(actual[0].color().to_hex_string(), expected[0]);
        assert_eq!(actual[1].color().to_hex_string(), expected[1]);
    }
}