use crate::{
BoundedIndex, BoundedSlice, ColorComponents, ImageRef, IndexedImage, LengthOutOfRange,
MAX_PIXELS, PaletteBuf,
color_map::{NearestNeighborColorMap, simd_argmin_min_distance},
};
use alloc::{vec, vec::Vec};
use core::{array, num::NonZeroU32};
use num_traits::AsPrimitive;
use ordered_float::OrderedFloat;
use palette::cast::{self, AsArrays as _};
use rand::{SeedableRng as _, distr::Uniform, prelude::Distribution as _};
use rand_xoshiro::Xoroshiro128PlusPlus;
/// Options controlling how [`Kmeans`] samples the input colors.
///
/// Start from [`KmeansOptions::new`] (or [`Default::default`]) and adjust the
/// values with the builder-style setters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[must_use]
pub struct KmeansOptions {
    /// Multiplier applied to the number of input pixels to obtain the sample count.
    sampling_factor: OrderedFloat<f32>,
    /// Upper bound on the number of samples drawn.
    max_samples: u32,
    /// Number of samples per mini-batch in the parallel implementation.
    batch_size: u32,
    /// Seed of the random number generator used to draw samples.
    seed: u64,
}
impl KmeansOptions {
    /// Returns the default options.
    ///
    /// These are a sampling factor of `1.0`, at most `512 * 512` samples, a batch size
    /// of `4096`, and a seed of `0`.
    #[inline]
    pub const fn new() -> Self {
        Self {
            seed: 0,
            batch_size: 4096,
            max_samples: 512 * 512,
            sampling_factor: OrderedFloat(1.0),
        }
    }

    /// Sets the sampling factor.
    ///
    /// The number of samples is the number of input pixels multiplied by this factor
    /// (truncated toward zero), capped at the maximum sample count.
    #[inline]
    pub const fn sampling_factor(mut self, sampling_factor: f32) -> Self {
        self.sampling_factor = OrderedFloat(sampling_factor);
        self
    }

    /// Sets the maximum number of samples.
    ///
    /// A value of `0` disables sampling, so the initial centroids are returned unchanged.
    #[inline]
    pub const fn max_samples(mut self, max_samples: u32) -> Self {
        self.max_samples = max_samples;
        self
    }

    /// Sets the number of samples per mini-batch used by the parallel implementation.
    ///
    /// A value of `0` disables sampling for every implementation, so the initial
    /// centroids are returned unchanged.
    #[inline]
    pub const fn batch_size(mut self, batch_size: u32) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Sets the seed of the random number generator used to draw samples.
    #[inline]
    pub const fn seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Returns the sampling factor.
    #[inline]
    pub const fn get_sampling_factor(&self) -> f32 {
        self.sampling_factor.0
    }

    /// Returns the maximum number of samples.
    #[inline]
    pub const fn get_max_samples(&self) -> u32 {
        self.max_samples
    }

    /// Returns the mini-batch size.
    #[inline]
    pub const fn get_batch_size(&self) -> u32 {
        self.batch_size
    }

    /// Returns the random seed.
    #[inline]
    pub const fn get_seed(&self) -> u64 {
        self.seed
    }

    /// Computes how many samples to draw from `len` input pixels.
    ///
    /// Returns `None` for empty input, a zero batch size, or a computed sample count of
    /// zero. In each of those cases no k-means iteration takes place.
    #[inline]
    fn num_samples(&self, len: u32) -> Option<NonZeroU32> {
        if len == 0 || self.batch_size == 0 {
            return None;
        }
        let scaled = f64::from(len) * f64::from(self.sampling_factor.0);
        // Float-to-int `as` casts saturate (NaN becomes 0), so any factor is tolerated.
        #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
        let scaled = scaled as u32;
        NonZeroU32::new(self.max_samples.min(scaled))
    }
}
impl Default for KmeansOptions {
#[inline]
fn default() -> Self {
Self::new()
}
}
/// Mutable state of an in-progress k-means run.
struct State<Color, Component, const N: usize> {
    /// Number of samples assigned to each centroid so far, indexed like the palette.
    counts: PaletteBuf<u32>,
    /// The current centroids, held in the nearest-neighbor lookup structure.
    nearest: NearestNeighborColorMap<Color, Component, N>,
}
impl<Color, Component, const N: usize> State<Color, Component, N>
where
    Color: ColorComponents<Component, N>,
    Component: Copy + Into<f32> + 'static,
    f32: AsPrimitive<Component>,
{
    /// Builds the initial state from `centroids`, with every count set to zero.
    #[inline]
    fn new(centroids: PaletteBuf<Color>) -> Self {
        let zeros = vec![0; centroids.len()];
        Self {
            counts: PaletteBuf::new_unchecked(zeros),
            nearest: NearestNeighborColorMap::new(centroids),
        }
    }

    /// Moves the centroid stored at SIMD `chunk` / `lane` toward `color` and increments
    /// its count.
    ///
    /// The step size is `1 / sqrt(count)`, where `count` already includes this sample.
    #[inline]
    fn add_sample_to(&mut self, chunk: u8, lane: u8, color: [f32; N]) {
        // Centroids are laid out in chunks of 8 lanes.
        let index = usize::from(chunk * 8 + lane);
        let count = self.counts[index] + 1;

        let n = f64::from(count);
        #[cfg(feature = "std")]
        let root = n.sqrt();
        #[cfg(not(feature = "std"))]
        let root = libm::sqrt(n);
        #[allow(clippy::cast_possible_truncation)]
        let rate = (1.0 / root) as f32;

        let lane = usize::from(lane);
        for (component, target) in self.nearest.data[usize::from(chunk)].iter_mut().zip(color) {
            let value = &mut component.as_mut_array()[lane];
            *value += rate * (target - *value);
        }
        self.counts[index] = count;
    }

    /// Assigns `color` to its nearest centroid and pulls that centroid toward it.
    #[inline]
    fn add_sample(&mut self, color: [Component; N]) {
        let point: [f32; N] = color.map(Into::into);
        let (chunk, lane) = simd_argmin_min_distance(&self.nearest.data, point).0;
        self.add_sample_to(chunk, lane, point);
    }

    /// Sequential online k-means.
    ///
    /// Draws `samples` pixel indices uniformly from `0..num_pixels` with an RNG seeded by
    /// `options.seed`. The corresponding colors are fed one at a time through
    /// [`Self::add_sample`]. Colors are drawn into a reusable buffer, at most 256 at a
    /// time, before being processed.
    fn online_kmeans(
        &mut self,
        num_pixels: u32,
        index_to_color: impl Fn(u32) -> Color,
        samples: NonZeroU32,
        options: KmeansOptions,
    ) {
        const BATCH_SIZE: u32 = 256;
        #[allow(clippy::expect_used)]
        let distribution = Uniform::new(0, num_pixels).expect("num_pixels != 0");
        let mut rng = Xoroshiro128PlusPlus::seed_from_u64(options.seed);
        let mut buffer = Vec::with_capacity(BATCH_SIZE as usize);

        let mut remaining = samples.get();
        while remaining != 0 {
            let take = remaining.min(BATCH_SIZE);
            buffer.extend((0..take).map(|_| index_to_color(distribution.sample(&mut rng))));
            for &color in buffer.as_arrays() {
                self.add_sample(color);
            }
            buffer.clear();
            remaining -= take;
        }
    }
}
/// The outcome of running k-means over a set of colors.
///
/// Convert it into a palette or color map with the `into_*` methods.
#[must_use]
pub struct Kmeans<Color, Component, const N: usize> {
    // `Ok` holds the refined state; `Err` holds the untouched initial centroids when
    // no samples were drawn (see `Kmeans::run`).
    result: Result<State<Color, Component, N>, PaletteBuf<Color>>,
}
impl<Color, Component, const N: usize> Kmeans<Color, Component, N>
where
    Color: ColorComponents<Component, N>,
    Component: Copy + Into<f32> + 'static,
    f32: AsPrimitive<Component>,
{
    /// Shared driver for all entry points.
    ///
    /// Computes the sample count from `len` and `options`. If it is zero, `centroids`
    /// is kept unchanged. Otherwise `f` runs on a fresh [`State`] built from them.
    fn run<T>(
        len: u32,
        index_to_color: T,
        centroids: PaletteBuf<Color>,
        options: KmeansOptions,
        f: impl FnOnce(&mut State<Color, Component, N>, u32, T, NonZeroU32, KmeansOptions),
    ) -> Self {
        let Some(samples) = options.num_samples(len) else {
            return Self {
                result: Err(centroids),
            };
        };
        let mut state = State::new(centroids);
        f(&mut state, len, index_to_color, samples, options);
        Self { result: Ok(state) }
    }

    /// Runs sequential (online) k-means over a [`BoundedSlice`], using its `length`
    /// as the pixel count.
    pub(crate) fn run_slice_bounded(
        colors: &BoundedSlice<Color>,
        centroids: PaletteBuf<Color>,
        options: KmeansOptions,
    ) -> Self {
        let len = colors.length();
        let lookup = |i: u32| colors[i as usize];
        Self::run(len, lookup, centroids, options, State::online_kmeans)
    }

    /// Runs sequential (online) k-means over `colors`, refining `centroids`.
    ///
    /// # Errors
    ///
    /// Returns [`LengthOutOfRange`] if `colors.len()` is outside the bounds `0` to
    /// [`MAX_PIXELS`] checked by `LengthOutOfRange::check_u32`.
    pub fn run_slice(
        colors: &[Color],
        centroids: PaletteBuf<Color>,
        options: KmeansOptions,
    ) -> Result<Self, LengthOutOfRange> {
        let len = LengthOutOfRange::check_u32(colors, 0, MAX_PIXELS)?;
        let lookup = |i: u32| colors[i as usize];
        Ok(Self::run(len, lookup, centroids, options, State::online_kmeans))
    }

    /// Runs sequential (online) k-means over the pixels of `image`, refining `centroids`.
    pub fn run_image(
        image: ImageRef<'_, Color>,
        centroids: PaletteBuf<Color>,
        options: KmeansOptions,
    ) -> Self {
        let num_pixels = image.num_pixels();
        let pixels = image.as_slice();
        let lookup = |i: u32| pixels[i as usize];
        Self::run(num_pixels, lookup, centroids, options, State::online_kmeans)
    }

    /// Runs sequential (online) k-means over the pixels of an indexed image.
    ///
    /// Each sampled pixel index is resolved through the image's palette.
    pub fn run_indexed_image<Index>(
        image: &IndexedImage<Color, Index>,
        centroids: PaletteBuf<Color>,
        options: KmeansOptions,
    ) -> Self
    where
        Index: BoundedIndex + Into<u32>,
        Index::Length: Into<u32>,
    {
        let (palette, indices) = (image.palette(), image.indices());
        let lookup = |i: u32| palette[indices[i as usize].as_()];
        Self::run(
            image.num_pixels(),
            lookup,
            centroids,
            options,
            State::online_kmeans,
        )
    }

    /// Copies the refined floating-point centroids (`nearest.data`) back into
    /// `nearest.palette`, then passes the result to `f`.
    ///
    /// This only happens when k-means actually ran.
    fn finalize<T>(
        self,
        f: impl FnOnce(Result<State<Color, Component, N>, PaletteBuf<Color>>) -> T,
    ) -> T {
        let mut result = self.result;
        if let Ok(state) = &mut result {
            let nearest = &mut state.nearest;
            for (dst, lanes) in nearest.palette.chunks_mut(8).zip(&nearest.data) {
                // NOTE(review): `as_()` is an `as` cast, so integer components are
                // truncated rather than rounded. Confirm this is intended.
                let converted = array::from_fn::<Color, 8, _>(|lane| {
                    cast::from_array(lanes.map(|x| x.as_array()[lane].as_()))
                });
                // The last chunk may be shorter than 8 lanes.
                dst.copy_from_slice(&converted[..dst.len()]);
            }
        }
        f(result)
    }

    /// Returns the final palette.
    ///
    /// If no samples were drawn, this is the initial centroids.
    #[must_use]
    pub fn into_palette(self) -> PaletteBuf<Color> {
        self.finalize(|result| match result {
            Err(palette) => palette,
            Ok(state) => state.nearest.into_palette(),
        })
    }

    /// Returns the final palette along with the number of samples assigned to each entry.
    ///
    /// If no samples were drawn, the counts are all zero.
    #[must_use]
    pub fn into_palette_and_counts(self) -> (PaletteBuf<Color>, PaletteBuf<u32>) {
        self.finalize(|result| {
            result.map_or_else(
                |palette| {
                    let counts = PaletteBuf::new_unchecked(vec![0; palette.len()]);
                    (palette, counts)
                },
                |State { nearest, counts, .. }| (nearest.into_palette(), counts),
            )
        })
    }

    /// Returns the final nearest-neighbor color map along with per-entry sample counts.
    ///
    /// If no samples were drawn, the counts are all zero.
    #[must_use]
    pub fn into_color_map_and_counts(
        self,
    ) -> (
        NearestNeighborColorMap<Color, Component, N>,
        PaletteBuf<u32>,
    ) {
        self.finalize(|result| {
            result.map_or_else(
                |palette| {
                    let counts = PaletteBuf::new_unchecked(vec![0; palette.len()]);
                    (NearestNeighborColorMap::new(palette), counts)
                },
                |State { nearest, counts, .. }| (nearest, counts),
            )
        })
    }

    /// Returns the final nearest-neighbor color map.
    #[must_use]
    pub fn into_color_map(self) -> NearestNeighborColorMap<Color, Component, N> {
        self.finalize(|result| match result {
            Err(palette) => NearestNeighborColorMap::new(palette),
            Ok(state) => state.nearest,
        })
    }
}
#[cfg(feature = "threads")]
mod parallel {
    use super::{Kmeans, KmeansOptions, State};
    use crate::{
        BoundedIndex, ColorComponents, ImageRef, IndexedImage, LengthOutOfRange, MAX_PIXELS,
        PaletteBuf,
        color_map::{NearestNeighborParallelColorMap, simd_argmin_min_distance},
    };
    use alloc::vec;
    use core::num::NonZeroU32;
    use num_traits::AsPrimitive;
    use palette::cast::{self, AsArrays as _};
    use rand::{SeedableRng as _, distr::Uniform, prelude::Distribution as _};
    use rand_xoshiro::Xoroshiro128PlusPlus;
    use rayon::prelude::*;

    /// Over-aligns its contents to 64 bytes.
    ///
    /// Presumably this keeps the per-thread RNGs, stored side by side in a `Vec`, on
    /// separate cache lines (assuming 64-byte lines) to avoid false sharing.
    #[repr(align(64))]
    struct Align64<T>(T);

    impl<Color, Component, const N: usize> State<Color, Component, N>
    where
        Color: ColorComponents<Component, N>,
        Component: Copy + Into<f32> + 'static + Send + Sync,
        f32: AsPrimitive<Component>,
    {
        /// Parallel mini-batch k-means driver shared by the slice and indexed-image paths.
        ///
        /// Each batch is split into at most one chunk per rayon thread. Each chunk
        /// draws its colors with `sample_color` using its own RNG, derived from
        /// `options.seed` by successive `jump`s. It then finds the nearest centroids
        /// for those colors in parallel. Centroid updates are applied sequentially, in
        /// batch order, after the parallel phase.
        fn minibatch_kmeans_with(
            &mut self,
            samples: NonZeroU32,
            options: KmeansOptions,
            sample_color: impl Fn(&mut Xoroshiro128PlusPlus) -> [Component; N] + Sync,
        ) {
            let samples = samples.get();
            let KmeansOptions { batch_size, seed, .. } = options;
            // A batch never needs to hold more than `samples` colors. Clamping avoids
            // eagerly allocating a huge buffer (e.g. for `batch_size(u32::MAX)`) and
            // produces the same batches as before: a single batch of `samples` colors
            // split into chunks of `samples.div_ceil(threads)`. The lower bound guards
            // the divisions below against a zero batch size.
            let batch_size = batch_size.clamp(1, samples);
            let threads = rayon::current_num_threads();
            let chunk_size = (batch_size as usize).div_ceil(threads);
            let mut rng = (0..threads)
                .scan(Xoroshiro128PlusPlus::seed_from_u64(seed), |rng, _| {
                    rng.jump();
                    Some(Align64(rng.clone()))
                })
                .collect::<Vec<_>>();
            let mut batch = vec![[0.0_f32.as_(); N]; batch_size as usize];
            let mut assignments = vec![(0, 0); batch_size as usize];

            let mut run = |state: &mut State<Color, Component, N>,
                           batch: &mut [[Component; N]],
                           assignments: &mut [(u8, u8)],
                           chunk_size: usize| {
                batch
                    .par_chunks_mut(chunk_size)
                    .zip(assignments.par_chunks_mut(chunk_size))
                    .zip(&mut rng)
                    .for_each(|((chunk_colors, chunk_assignments), Align64(rng))| {
                        for color in &mut *chunk_colors {
                            *color = sample_color(rng);
                        }
                        for (color, center) in chunk_colors.iter().zip(chunk_assignments) {
                            *center = simd_argmin_min_distance(
                                &state.nearest.data,
                                color.map(Into::into),
                            )
                            .0;
                        }
                    });
                for (color, &(chunk, lane)) in batch.iter().zip(&*assignments) {
                    state.add_sample_to(chunk, lane, color.map(Into::into));
                }
            };

            for _ in 0..(samples / batch_size) {
                run(self, &mut batch, &mut assignments, chunk_size);
            }
            let remainder = (samples % batch_size) as usize;
            if remainder != 0 {
                run(
                    self,
                    &mut batch[..remainder],
                    &mut assignments[..remainder],
                    remainder.div_ceil(threads),
                );
            }
        }

        /// Parallel mini-batch k-means that samples uniformly from `colors`.
        fn minibatch_kmeans(
            &mut self,
            _num_pixels: u32,
            colors: &[Color],
            samples: NonZeroU32,
            options: KmeansOptions,
        ) {
            #[allow(clippy::expect_used)]
            let distribution = Uniform::new(0, colors.len()).expect("num_pixels != 0");
            let colors = colors.as_arrays();
            self.minibatch_kmeans_with(samples, options, |rng| colors[distribution.sample(rng)]);
        }

        /// Parallel mini-batch k-means that samples uniformly from the pixels of an
        /// indexed image, resolving each sampled index through the image's palette.
        fn minibatch_kmeans_indexed<Index: BoundedIndex>(
            &mut self,
            _num_pixels: u32,
            image: &IndexedImage<Color, Index>,
            samples: NonZeroU32,
            options: KmeansOptions,
        ) {
            #[allow(clippy::expect_used)]
            let distribution = Uniform::new(0, image.num_pixels()).expect("num_pixels != 0");
            let colors = image.palette();
            let indices = image.indices();
            self.minibatch_kmeans_with(samples, options, |rng| {
                let index = indices[distribution.sample(rng) as usize];
                cast::into_array(colors[index.as_()])
            });
        }
    }

    impl<Color, Component, const N: usize> Kmeans<Color, Component, N>
    where
        Color: ColorComponents<Component, N>,
        Component: Copy + Into<f32> + 'static + Send + Sync,
        f32: AsPrimitive<Component>,
    {
        /// Runs parallel mini-batch k-means over `colors` without checking its length.
        ///
        /// Callers must pass a slice whose length fits in a `u32`, such as the pixels of
        /// an `ImageRef`, whose `num_pixels` is a `u32`. Longer lengths would be
        /// truncated by the cast below.
        #[allow(clippy::cast_possible_truncation)]
        pub(crate) fn run_slice_par_unchecked(
            colors: &[Color],
            centroids: PaletteBuf<Color>,
            options: KmeansOptions,
        ) -> Self {
            Self::run(
                colors.len() as u32,
                colors,
                centroids,
                options,
                State::minibatch_kmeans,
            )
        }

        /// Runs parallel mini-batch k-means over `colors`, refining `centroids`.
        ///
        /// # Errors
        ///
        /// Returns [`LengthOutOfRange`] if `colors.len()` is outside the bounds `0` to
        /// [`MAX_PIXELS`] checked by `LengthOutOfRange::check_u32`.
        #[inline]
        pub fn run_slice_par(
            colors: &[Color],
            centroids: PaletteBuf<Color>,
            options: KmeansOptions,
        ) -> Result<Self, LengthOutOfRange> {
            LengthOutOfRange::check_u32(colors, 0, MAX_PIXELS)
                .map(|len| Self::run(len, colors, centroids, options, State::minibatch_kmeans))
        }

        /// Runs parallel mini-batch k-means over the pixels of `image`.
        #[inline]
        pub fn run_image_par(
            image: ImageRef<'_, Color>,
            centroids: PaletteBuf<Color>,
            options: KmeansOptions,
        ) -> Self {
            Self::run_slice_par_unchecked(image.as_slice(), centroids, options)
        }

        /// Runs parallel mini-batch k-means over the pixels of an indexed image.
        pub fn run_indexed_image_par<Index>(
            image: &IndexedImage<Color, Index>,
            centroids: PaletteBuf<Color>,
            options: KmeansOptions,
        ) -> Self
        where
            Index: BoundedIndex + Into<u32>,
            Index::Length: Into<u32>,
        {
            Self::run(
                image.num_pixels(),
                image,
                centroids,
                options,
                State::minibatch_kmeans_indexed,
            )
        }

        /// Returns the final parallel color map along with per-entry sample counts.
        #[must_use]
        #[inline]
        pub fn into_parallel_color_map_and_counts(
            self,
        ) -> (
            NearestNeighborParallelColorMap<Color, Component, N>,
            PaletteBuf<u32>,
        ) {
            let (color_map, counts) = self.into_color_map_and_counts();
            (color_map.into(), counts)
        }

        /// Returns the final parallel color map.
        #[must_use]
        #[inline]
        pub fn into_parallel_color_map(
            self,
        ) -> NearestNeighborParallelColorMap<Color, Component, N> {
            self.into_color_map().into()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::*;
    use palette::Srgb;

    /// A 249-color palette. 249 is not a multiple of the 8-lane chunk size, so the last
    /// chunk is only partially filled.
    fn test_palette() -> PaletteBuf<Srgb<u8>> {
        let mut centroids = test_data_256();
        centroids.truncate(249u8.try_into().unwrap());
        centroids
    }

    /// `max_samples(0)` means no samples are drawn, so the centroids must be untouched.
    #[test]
    fn no_samples_gives_initial_centroids() {
        let colors = test_data_1024();
        let centroids = test_palette();
        let options = KmeansOptions::new().max_samples(0);
        let actual = Kmeans::run_slice(&colors, centroids.clone(), options)
            .unwrap()
            .into_palette_and_counts();
        let expected = (
            centroids.clone(),
            PaletteBuf::new_unchecked(vec![0; centroids.len()]),
        );
        assert_eq!(actual, expected);
        #[cfg(feature = "threads")]
        {
            let actual = Kmeans::run_slice_par(&colors, centroids.clone(), options.batch_size(64))
                .unwrap()
                .into_palette_and_counts();
            assert_eq!(actual, expected);
        }
    }

    /// A zero batch size must not panic (division by zero) and leaves the centroids
    /// untouched, even though the sample count would otherwise be nonzero.
    #[cfg(feature = "threads")]
    #[test]
    fn zero_batch_size_gives_initial_centroids() {
        let colors = test_data_1024();
        let centroids = test_palette();
        let options = KmeansOptions::new().batch_size(0);
        let expected = (
            centroids.clone(),
            PaletteBuf::new_unchecked(vec![0; centroids.len()]),
        );
        let actual = Kmeans::run_slice_par(&colors, centroids.clone(), options)
            .unwrap()
            .into_palette_and_counts();
        assert_eq!(actual, expected);
    }

    /// Empty input leaves the centroids untouched for every entry point. Default options
    /// are used so that the `len == 0` path is exercised, rather than a zero sample cap.
    #[test]
    fn empty_input_gives_initial_centroids() {
        let centroids = test_palette();
        let options = KmeansOptions::new();
        let actual = Kmeans::run_slice(&[], centroids.clone(), options)
            .unwrap()
            .into_palette_and_counts();
        let expected = (
            centroids.clone(),
            PaletteBuf::new_unchecked(vec![0; centroids.len()]),
        );
        assert_eq!(actual, expected);
        let actual = Kmeans::run_image(ImageRef::default(), centroids.clone(), options)
            .into_palette_and_counts();
        assert_eq!(actual, expected);
        let actual = Kmeans::run_indexed_image(
            &IndexedImage::<_, u8>::default(),
            centroids.clone(),
            options,
        )
        .into_palette_and_counts();
        assert_eq!(actual, expected);
        #[cfg(feature = "threads")]
        {
            let actual = Kmeans::run_slice_par(&[], centroids.clone(), options.batch_size(64))
                .unwrap()
                .into_palette_and_counts();
            assert_eq!(actual, expected);
            let actual = Kmeans::run_image_par(ImageRef::default(), centroids.clone(), options)
                .into_palette_and_counts();
            assert_eq!(actual, expected);
            let actual = Kmeans::run_indexed_image_par(
                &IndexedImage::<_, u8>::default(),
                centroids.clone(),
                options,
            )
            .into_palette_and_counts();
            assert_eq!(actual, expected);
        }
    }

    /// An image made only of palette colors must leave the palette unchanged, and every
    /// drawn sample must be counted exactly once.
    #[test]
    fn exact_match_image_unaffected() {
        let centroids = test_palette();
        let indices = {
            #[allow(clippy::cast_possible_truncation)]
            let indices = (0..centroids.len()).map(|i| i as u8).collect::<Vec<_>>();
            let mut indices = [indices.as_slice(); 4].concat();
            indices.rotate_right(7);
            indices
        };
        let colors = indices.iter().map(|&i| centroids[i]).collect::<Vec<_>>();
        let samples = 505;
        let options = KmeansOptions::new().max_samples(samples);
        let (palette, counts) = Kmeans::run_slice(&colors, centroids.clone(), options)
            .unwrap()
            .into_palette_and_counts();
        assert_eq!(palette, centroids);
        assert_eq!(counts.len(), centroids.len());
        assert_eq!(counts.into_iter().sum::<u32>(), samples);
        #[cfg(feature = "threads")]
        {
            let (palette, counts) =
                Kmeans::run_slice_par(&colors, centroids.clone(), options.batch_size(64))
                    .unwrap()
                    .into_palette_and_counts();
            assert_eq!(palette, centroids);
            assert_eq!(counts.len(), centroids.len());
            assert_eq!(counts.into_iter().sum::<u32>(), samples);
        }
    }
}