use std::{cmp::Reverse, marker::PhantomData};
use crate::{
algorithm::Algorithm,
color::{rgb_to_xyz, xyz_to_lab, Color, Lab, D65},
error::Error,
image::ImageData,
math::{
clustering::{Cluster, ClusteringAlgorithm, DBSCAN},
denormalize,
normalize,
sampling::{DiversitySampling, SamplingAlgorithm, SamplingError, WeightedFarthestSampling},
DistanceMetric,
FloatNumber,
Point,
},
theme::Theme,
Swatch,
};
/// A collection of color swatches.
///
/// Palettes are usually produced by [`Palette::extract`] or [`Palette::builder`], but can also
/// be assembled directly from existing swatches with [`Palette::new`].
#[derive(Debug, Clone, PartialEq)]
pub struct Palette<T: FloatNumber> {
    /// The swatches held by this palette.
    swatches: Vec<Swatch<T>>,
}
impl<T> Palette<T>
where
    T: FloatNumber,
{
    /// Creates a palette holding `swatches`, kept in the order given.
    #[must_use]
    pub fn new(swatches: Vec<Swatch<T>>) -> Self {
        Self { swatches }
    }

    /// Returns the number of swatches in the palette.
    #[must_use]
    pub fn len(&self) -> usize {
        self.swatches.len()
    }

    /// Returns `true` if the palette contains no swatches.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.swatches.is_empty()
    }

    /// Returns the palette's swatches as a slice.
    #[must_use]
    pub fn swatches(&self) -> &[Swatch<T>] {
        &self.swatches
    }

    /// Selects up to `num_swatches` swatches, scoring each by its ratio and sampling with
    /// [`DiversitySampling`] over squared Euclidean color distance.
    ///
    /// The result is sorted by population, largest first.
    ///
    /// # Errors
    ///
    /// Returns [`Error::SwatchSelectionError`] if the sampler cannot be created or fails.
    pub fn find_swatches(&self, num_swatches: usize) -> Result<Vec<Swatch<T>>, Error> {
        self.find_swatches_internal(
            num_swatches,
            |swatch| swatch.ratio(),
            |scores| {
                // NOTE(review): 0.6 is the first argument of `DiversitySampling::new`;
                // its exact meaning is defined by that sampler — confirm before tuning.
                DiversitySampling::new(T::from_f64(0.6), scores, DistanceMetric::SquaredEuclidean)
            },
        )
    }

    /// Selects up to `num_swatches` swatches, scoring each with `theme` and sampling with
    /// [`WeightedFarthestSampling`] over squared Euclidean color distance.
    ///
    /// The result is sorted by population, largest first.
    ///
    /// # Errors
    ///
    /// Returns [`Error::SwatchSelectionError`] if the sampler cannot be created or fails.
    pub fn find_swatches_with_theme(
        &self,
        num_swatches: usize,
        theme: Theme,
    ) -> Result<Vec<Swatch<T>>, Error> {
        self.find_swatches_internal(
            num_swatches,
            |swatch| theme.score(swatch),
            |scores| WeightedFarthestSampling::new(scores, DistanceMetric::SquaredEuclidean),
        )
    }

    /// Selects up to `num_swatches` swatches with a caller-supplied scorer and sampler.
    ///
    /// Each swatch is scored with `score_fn`. The scores, in swatch order, are passed to
    /// `sampling_factory` to build the sampler. The sampler then picks indices from the
    /// swatches' `[l, a, b]` color points. The selected swatches are returned sorted by
    /// population, largest first.
    ///
    /// `num_swatches` is clamped to the palette size. An empty vector is returned, without
    /// building a sampler, when the palette is empty or the clamped count is zero.
    ///
    /// # Errors
    ///
    /// Returns [`Error::SwatchSelectionError`] if `sampling_factory` or the sampler fails.
    pub fn find_swatches_internal<S, F1, F2>(
        &self,
        num_swatches: usize,
        score_fn: F1,
        sampling_factory: F2,
    ) -> Result<Vec<Swatch<T>>, Error>
    where
        S: SamplingAlgorithm<T>,
        F1: Fn(&Swatch<T>) -> T,
        F2: FnOnce(Vec<T>) -> Result<S, SamplingError>,
    {
        let num_swatches = num_swatches.min(self.swatches.len());
        // Nothing to select: skip sampler construction so a zero-sized request cannot fail.
        if num_swatches == 0 {
            return Ok(Vec::new());
        }

        let (colors, scores): (Vec<Point<T, 3>>, Vec<T>) = self
            .swatches
            .iter()
            .map(|swatch| {
                let color = swatch.color();
                ([color.l, color.a, color.b], score_fn(swatch))
            })
            .unzip();

        let sampler =
            sampling_factory(scores).map_err(|cause| Error::SwatchSelectionError { cause })?;
        let sampled = sampler
            .sample(&colors, num_swatches)
            .map_err(|cause| Error::SwatchSelectionError { cause })?;

        // `colors` is index-aligned with `self.swatches`, so sampled indices map back directly.
        let mut found: Vec<_> = sampled.iter().map(|&index| self.swatches[index]).collect();
        found.sort_by_key(|swatch| Reverse(swatch.population()));
        Ok(found)
    }

    /// Returns a [`PaletteBuilder`] with default settings.
    #[must_use]
    pub fn builder() -> PaletteBuilder<T> {
        PaletteBuilder::new()
    }

    /// Extracts a palette from `image_data` using the default builder settings.
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`PaletteBuilder::build`].
    pub fn extract(image_data: &ImageData) -> Result<Self, Error> {
        Self::builder().build(image_data)
    }

    /// Extracts a palette from `image_data` using the given clustering `algorithm`.
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`PaletteBuilder::build`].
    #[deprecated(
        since = "0.8.0",
        note = "Use `Palette::extract` or `Palette::builder` instead."
    )]
    pub fn extract_with_algorithm(
        image_data: &ImageData,
        algorithm: Algorithm,
    ) -> Result<Self, Error> {
        Self::builder().algorithm(algorithm).build(image_data)
    }
}
/// Number of bytes per pixel in the raw image buffer (R, G, B, A).
const PIXEL_SIZE: usize = 4;

/// One pixel as its raw channel bytes, in R, G, B, A order.
pub type Pixel = [u8; PIXEL_SIZE];

/// Boxed, thread-shareable pixel predicate; a pixel is skipped when a filter returns `true`.
type DynFilterFn = Box<dyn Fn(&Pixel) -> bool + Send + Sync + 'static>;
/// Configures and runs palette extraction from [`ImageData`].
///
/// Obtain one with [`Palette::builder`]. Adjust it with [`PaletteBuilder::algorithm`],
/// [`PaletteBuilder::filter`] and [`PaletteBuilder::max_swatches`], then call
/// [`PaletteBuilder::build`].
pub struct PaletteBuilder<T: FloatNumber> {
    /// Clustering algorithm applied to the pixel feature points.
    algorithm: Algorithm,
    /// Pixel predicates; a pixel is excluded when any of them returns `true`.
    filters: Vec<DynFilterFn>,
    /// Upper bound on the number of swatches kept in the built palette.
    max_swatches: usize,
    /// Binds the builder to the float type `T` without storing a value of it.
    _marker: PhantomData<T>,
}
impl<T> PaletteBuilder<T>
where
    T: FloatNumber,
{
    /// Default upper bound on the number of swatches kept by `build`.
    const DEFAULT_MAX_SWATCHES: usize = 256;

    /// Creates a builder with:
    ///
    /// * the default algorithm;
    /// * one filter that drops pixels whose alpha byte (index 3) is zero;
    /// * a swatch limit of `DEFAULT_MAX_SWATCHES`.
    #[must_use]
    fn new() -> PaletteBuilder<T> {
        PaletteBuilder {
            algorithm: Algorithm::default(),
            filters: vec![
                // Fully transparent pixels carry no visible color.
                Box::new(|pixel: &Pixel| pixel[3] == 0),
            ],
            max_swatches: Self::DEFAULT_MAX_SWATCHES,
            _marker: PhantomData,
        }
    }

    /// Sets the clustering algorithm used to group pixels.
    #[must_use]
    pub fn algorithm(mut self, algorithm: Algorithm) -> Self {
        self.algorithm = algorithm;
        self
    }

    /// Adds a pixel filter.
    ///
    /// Pixels for which any registered filter returns `true` are excluded. The default
    /// transparency filter stays in place; filters accumulate.
    #[must_use]
    pub fn filter<F>(mut self, filter: F) -> Self
    where
        F: Fn(&Pixel) -> bool + Send + Sync + 'static,
    {
        self.filters.push(Box::new(filter));
        self
    }

    /// Sets the maximum number of swatches kept in the built palette.
    #[must_use]
    pub fn max_swatches(mut self, max_swatches: usize) -> Self {
        self.max_swatches = max_swatches;
        self
    }

    /// Extracts a palette from `image_data` using this builder's configuration.
    ///
    /// The returned swatches are sorted by population, largest first, and truncated to the
    /// configured maximum.
    ///
    /// # Errors
    ///
    /// * [`Error::EmptyImageData`] if the image has no pixel data, or if every pixel is
    ///   removed by the filters.
    /// * Any error returned by the clustering algorithm.
    /// * [`Error::PaletteExtractionError`] if grouping clusters by color fails.
    pub fn build(self, image_data: &ImageData) -> Result<Palette<T>, Error> {
        let pixels = image_data.data();
        if pixels.is_empty() {
            return Err(Error::EmptyImageData);
        }

        let width = image_data.width();
        let height = image_data.height();
        let points = to_feature_points::<T>(width as usize, height as usize, pixels, &self.filters);
        // Every pixel was rejected by a filter (e.g. a fully transparent image).
        if points.is_empty() {
            return Err(Error::EmptyImageData);
        }

        let pixel_clusters = self.algorithm.cluster(width, height, &points)?;
        let color_clusters = cluster_to_color_groups(&pixel_clusters)?;
        let mut swatches = convert_swatches_from_clusters(
            T::from_u32(width),
            T::from_u32(height),
            &color_clusters,
            &pixel_clusters,
        );

        // Most populous swatches first, then keep only the configured maximum in place.
        swatches.sort_by_key(|swatch| Reverse(swatch.population()));
        swatches.truncate(self.max_swatches);
        Ok(Palette::new(swatches))
    }
}
fn to_feature_points<T>(
width: usize,
height: usize,
data: &[u8],
filters: &[DynFilterFn],
) -> Vec<Point<T, 5>>
where
T: FloatNumber,
{
let width_f = T::from_usize(width);
let height_f = T::from_usize(height);
data.chunks_exact(PIXEL_SIZE)
.enumerate()
.filter_map(|(index, chunk)| {
let r = chunk[0];
let g = chunk[1];
let b = chunk[2];
let a = chunk[3];
if filters.iter().any(|filter| filter(&[r, g, b, a])) {
return None;
}
let (x, y, z) = rgb_to_xyz::<T>(r, g, b);
let (l, a, b) = xyz_to_lab::<T, D65>(x, y, z);
let x = T::from_usize(index % width);
let y = T::from_usize(index / width);
Some([
normalize(l, Lab::<T>::min_l(), Lab::<T>::max_l()),
normalize(a, Lab::<T>::min_a(), Lab::<T>::max_a()),
normalize(b, Lab::<T>::min_b(), Lab::<T>::max_b()),
normalize(x, T::zero(), width_f),
normalize(y, T::zero(), height_f),
])
})
.collect()
}
/// DBSCAN neighborhood radius used when grouping pixel-cluster colors. It is applied with
/// Euclidean distance on denormalized Lab colors.
const COLOR_GROUP_DBSCAN_EPSILON: f64 = 2.5;
/// DBSCAN minimum-points setting used when grouping pixel-cluster colors.
const COLOR_GROUP_DBSCAN_MIN_POINTS: usize = 1;
fn cluster_to_color_groups<T>(pixel_clusters: &[Cluster<T, 5>]) -> Result<Vec<Cluster<T, 3>>, Error>
where
T: FloatNumber,
{
let colors: Vec<_> = pixel_clusters
.iter()
.map(|cluster| -> Point<T, 3> {
let centroid = cluster.centroid();
[
denormalize(centroid[0], Lab::<T>::min_l(), Lab::<T>::max_l()),
denormalize(centroid[1], Lab::<T>::min_a(), Lab::<T>::max_a()),
denormalize(centroid[2], Lab::<T>::min_b(), Lab::<T>::max_b()),
]
})
.collect();
let dbscan = DBSCAN::new(
COLOR_GROUP_DBSCAN_MIN_POINTS,
T::from_f64(COLOR_GROUP_DBSCAN_EPSILON),
DistanceMetric::Euclidean,
)
.map_err(|e| Error::PaletteExtractionError {
details: e.to_string(),
})?;
dbscan
.fit(&colors)
.map_err(|e| Error::PaletteExtractionError {
details: e.to_string(),
})
}
/// Converts groups of pixel clusters into swatches, producing one swatch per color group
/// that has at least one valid, non-empty member.
///
/// # Arguments
///
/// * `width`, `height` - Image dimensions. They map the normalized spatial centroid
///   components (indices 3 and 4) back to pixel coordinates, and their product is the
///   denominator of each swatch's ratio.
/// * `color_clusters` - Color groups whose member values are used as indices into
///   `pixel_clusters`; out-of-range indices are skipped.
/// * `pixel_clusters` - Pixel clusters whose centroids hold normalized `[l, a, b, x, y]`
///   components.
///
/// # Returns
///
/// For each group:
///
/// * the swatch color is a blend of its members' centroid colors, denormalized to Lab;
/// * the position is taken from a dominant member (see the inline comments);
/// * the population is the sum of the member cluster sizes;
/// * the ratio is that population divided by `width * height`.
///
/// Groups with no usable member are dropped.
#[must_use]
fn convert_swatches_from_clusters<T>(
    width: T,
    height: T,
    color_clusters: &[Cluster<T, 3>],
    pixel_clusters: &[Cluster<T, 5>],
) -> Vec<Swatch<T>>
where
    T: FloatNumber,
{
    color_clusters
        .iter()
        .filter_map(|color_cluster| {
            // Blended color, kept in normalized coordinates until the swatch is built.
            let mut best_color = [T::zero(); 3];
            let mut best_position = None;
            // Size of the member whose position was most recently adopted.
            let mut best_population = 0;
            let mut total_population = 0;
            for &member in color_cluster.members() {
                let Some(pixel_cluster) = pixel_clusters.get(member) else {
                    continue;
                };
                if pixel_cluster.is_empty() {
                    continue;
                }
                // NOTE(review): the denominator uses `best_population`, not
                // `total_population`. The blend is therefore not a plain population-weighted
                // mean, and the result depends on member order. Confirm this is intended
                // before changing it; the builder tests assert exact resulting hex colors.
                let fraction = T::from_usize(pixel_cluster.len())
                    / T::from_usize(pixel_cluster.len() + best_population);
                let centroid = pixel_cluster.centroid();
                // Move the blended color `fraction` of the way toward this centroid.
                best_color.iter_mut().enumerate().for_each(|(i, color)| {
                    *color += fraction * (centroid[i] - *color);
                });
                // `fraction >= 0.5` roughly means this member is at least as large as the
                // current best. When it holds, this member's centroid becomes the swatch
                // position. The first usable member always qualifies.
                if fraction >= T::from_f32(0.5) || best_population == 0 {
                    best_position = Some((
                        denormalize(centroid[3], T::zero(), width).trunc_to_u32(),
                        denormalize(centroid[4], T::zero(), height).trunc_to_u32(),
                    ));
                    best_population = pixel_cluster.len();
                }
                total_population += pixel_cluster.len();
            }
            // `best_position` is `None` only if no member was usable.
            if let Some(position) = best_position {
                let l = denormalize(best_color[0], Lab::<T>::min_l(), Lab::<T>::max_l());
                let a = denormalize(best_color[1], Lab::<T>::min_a(), Lab::<T>::max_a());
                let b = denormalize(best_color[2], Lab::<T>::min_b(), Lab::<T>::max_b());
                Some(Swatch::new(
                    Color::new(l, a, b),
                    position,
                    total_population,
                    T::from_usize(total_population) / (width * height),
                ))
            } else {
                None
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use rstest::rstest;

    use super::*;

    /// Six sample swatches, ordered by descending population.
    #[must_use]
    fn sample_swatches<T>() -> Vec<Swatch<T>>
    where
        T: FloatNumber,
    {
        vec![
            Swatch::new(
                Color::from_str("#FFFFFF").unwrap(),
                (159, 106),
                61228,
                T::from_f64(0.9214),
            ),
            Swatch::new(
                Color::from_str("#EE334E").unwrap(),
                (238, 89),
                1080,
                T::from_f64(0.0163),
            ),
            Swatch::new(
                Color::from_str("#0081C8").unwrap(),
                (82, 88),
                1064,
                T::from_f64(0.0160),
            ),
            Swatch::new(
                Color::from_str("#00A651").unwrap(),
                (197, 123),
                1037,
                T::from_f64(0.0156),
            ),
            Swatch::new(
                Color::from_str("#000000").unwrap(),
                (157, 95),
                1036,
                T::from_f64(0.0156),
            ),
            Swatch::new(
                Color::from_str("#FCB131").unwrap(),
                (119, 123),
                1005,
                T::from_f64(0.0151),
            ),
        ]
    }

    /// An empty swatch list.
    #[must_use]
    fn empty_swatches<T>() -> Vec<Swatch<T>>
    where
        T: FloatNumber,
    {
        vec![]
    }

    #[test]
    fn test_new() {
        let swatches = vec![
            Swatch::<f32>::new(Color::from_str("#FFFFFF").unwrap(), (5, 10), 256, 0.5714),
            Swatch::<f32>::new(Color::from_str("#C8102E").unwrap(), (15, 20), 128, 0.2857),
            Swatch::<f32>::new(Color::from_str("#012169").unwrap(), (30, 30), 64, 0.1429),
        ];
        let actual = Palette::new(swatches.clone());
        assert!(!actual.is_empty());
        assert_eq!(actual.len(), 3);
        assert_eq!(actual.swatches, swatches);
    }

    #[test]
    fn test_new_empty() {
        let swatches = vec![];
        let actual: Palette<f32> = Palette::new(swatches);
        assert!(actual.is_empty());
        assert_eq!(actual.len(), 0);
    }

    #[cfg(feature = "image")]
    #[rstest]
    #[case::kmeans("kmeans")]
    #[case::dbscan("dbscan")]
    #[case::dbscanpp("dbscan++")]
    fn test_builder_with_algorithm(#[case] name: &str) {
        let image_data = ImageData::load("../../gfx/olympic_logo.png").unwrap();
        let algorithm = Algorithm::from_str(name).unwrap();
        let actual: Palette<f32> = Palette::builder()
            .algorithm(algorithm)
            .build(&image_data)
            .unwrap();
        assert!(!actual.is_empty());
        assert!(actual.len() >= 5);
    }

    #[cfg(feature = "image")]
    #[test]
    fn test_builder_with_filter() {
        let image_data = ImageData::load("../../gfx/olympic_logo.png").unwrap();
        let actual: Palette<f32> = Palette::builder()
            .filter(|pixel| pixel[3] == 0x00)
            .filter(|pixel| pixel[0] == 0xff || pixel[1] == 0xff || pixel[2] == 0xff)
            .build(&image_data)
            .unwrap();
        assert!(!actual.is_empty());
        assert_eq!(actual.len(), 5);
        assert_eq!(actual.swatches[0].color().to_hex_string(), "#0081C8");
        assert_eq!(actual.swatches[1].color().to_hex_string(), "#EE334E");
        assert_eq!(actual.swatches[2].color().to_hex_string(), "#000000");
        assert_eq!(actual.swatches[3].color().to_hex_string(), "#01A651");
        assert_eq!(actual.swatches[4].color().to_hex_string(), "#FCB131");
    }

    #[cfg(feature = "image")]
    #[test]
    fn test_builder_max_swatches() {
        let image_data = ImageData::load("../../gfx/olympic_logo.png").unwrap();
        let actual: Palette<f32> = Palette::builder()
            .max_swatches(3)
            .build(&image_data)
            .unwrap();
        assert!(!actual.is_empty());
        assert_eq!(actual.len(), 3);
    }

    #[test]
    fn test_builder_empty_image_data() {
        let data: Vec<u8> = Vec::new();
        let image_data = ImageData::new(0, 0, &data).unwrap();
        let result: Result<Palette<f64>, _> = Palette::builder().build(&image_data);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "Image data is empty: no pixels to process"
        );
    }

    #[test]
    fn test_extract_transparent_image() {
        // 10x10 image where every RGBA byte (including alpha) is zero.
        let data: Vec<u8> = vec![0; 4 * 10 * 10];
        let image_data = ImageData::new(10, 10, &data).unwrap();
        let result: Result<Palette<f64>, _> = Palette::builder().build(&image_data);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "Image data is empty: no pixels to process"
        );
    }

    #[test]
    #[cfg(feature = "image")]
    fn test_extract() {
        let image_data = ImageData::load("../../gfx/olympic_logo.png").unwrap();
        let actual: Palette<f64> = Palette::extract(&image_data).unwrap();
        assert!(!actual.is_empty());
        assert!(actual.len() >= 3);
    }

    // Deliberately exercises the deprecated API, so silence the lint here.
    #[allow(deprecated)]
    #[test]
    #[cfg(feature = "image")]
    fn test_extract_with_algorithm() {
        let image_data = ImageData::load("../../gfx/olympic_logo.png").unwrap();
        let actual: Palette<f64> =
            Palette::extract_with_algorithm(&image_data, Algorithm::DBSCANpp).unwrap();
        assert!(!actual.is_empty());
        assert_eq!(actual.len(), 5);
    }

    #[test]
    fn test_find_swatches() {
        let swatches = sample_swatches::<f32>();
        let palette = Palette::new(swatches);
        let actual = palette.find_swatches(4);
        assert!(actual.is_ok());
        let actual = actual.unwrap();
        assert_eq!(actual.len(), 4);
        assert_eq!(actual[0].color().to_hex_string(), "#FFFFFF");
        assert_eq!(actual[1].color().to_hex_string(), "#EE334E");
        assert_eq!(actual[2].color().to_hex_string(), "#00A651");
        assert_eq!(actual[3].color().to_hex_string(), "#000000");
    }

    #[test]
    fn test_find_swatches_empty() {
        let swatches = empty_swatches::<f32>();
        let palette = Palette::new(swatches);
        let actual = palette.find_swatches(10);
        assert!(actual.is_ok());
        assert!(actual.unwrap().is_empty(), "Expected empty swatches");
    }

    #[rstest]
    #[case::colorful(Theme::Colorful, vec!["#EE334E", "#00A651"])]
    #[case::vivid(Theme::Vivid, vec!["#EE334E", "#00A651"])]
    #[case::muted(Theme::Muted, vec!["#0081C8", "#000000"])]
    #[case::light(Theme::Light, vec!["#0081C8", "#00A651"])]
    #[case::dark(Theme::Dark, vec!["#0081C8", "#000000"])]
    fn test_find_swatches_with_theme(#[case] theme: Theme, #[case] expected: Vec<&str>) {
        let swatches = sample_swatches::<f32>();
        let palette = Palette::new(swatches);
        let actual = palette.find_swatches_with_theme(2, theme).unwrap();
        assert_eq!(actual.len(), 2);
        assert_eq!(actual[0].color().to_hex_string(), expected[0]);
        assert_eq!(actual[1].color().to_hex_string(), expected[1]);
    }
}