pub mod color;
pub mod weights;
use rand::{distributions::WeightedIndex, prelude::*};
use color::LAB;
use weights::WeightFn;
#[cfg(target_arch = "wasm32")]
use {
wasm_bindgen::{prelude::*, JsCast},
web_sys::{CanvasRenderingContext2d, HtmlCanvasElement},
weights::{Mood, resolve_mood},
color::{RGB, HSL},
serde_derive::Serialize,
};
#[cfg(not(target_arch = "wasm32"))]
use {
std::cmp,
crossbeam_utils::thread,
};
pub type Pixels = Vec<LAB>;
fn recal_means(colors: &Vec<LAB>, weight: WeightFn) -> LAB {
let mut new_color = LAB {
l: 0.0,
a: 0.0,
b: 0.0
};
let mut w_sum = 0.0;
for col in colors.iter() {
let w = weight(col);
w_sum += w;
new_color.l += w * col.l;
new_color.a += w * col.a;
new_color.b += w * col.b;
}
new_color.l /= w_sum;
new_color.a /= w_sum;
new_color.b /= w_sum;
return new_color;
}
#[cfg(target_arch = "wasm32")]
fn find_clusters(pixels: &Pixels, means: &Pixels, k: usize) -> Vec<Pixels> {
let mut clusters: Vec<Pixels> = vec![Vec::new(); k];
for color in pixels.iter() {
clusters[color.nearest(means).0].push(color.clone());
}
return clusters;
}
#[cfg(not(target_arch = "wasm32"))]
fn find_clusters(pixels: &Pixels, means: &Pixels, k: usize) -> Vec<Pixels> {
const NUM_THREADS: usize = 5;
let num_pixels = pixels.len();
let sample_size = num_pixels / NUM_THREADS;
let mut clusters: Vec<Pixels> = vec![Vec::new(); k as usize];
thread::scope(|s| {
let mut threads = vec![];
for i in 0..NUM_THREADS {
let start = i * sample_size;
let end = cmp::min(start + sample_size, num_pixels);
threads.push(
s.spawn(move |_| {
let mut clusters: Vec<Vec<LAB>> = vec![Vec::new(); k];
for pixel_idx in start..end {
let color = &pixels[pixel_idx];
clusters[color.nearest(&means).0].push(color.clone());
}
return clusters;
})
);
}
for t in threads {
let mut mid_clusters = t.join().unwrap();
for (i, cluster) in mid_clusters.iter_mut().enumerate() {
clusters[i].append(cluster);
}
}
}).unwrap();
return clusters;
}
pub fn pigments_pixels(pixels: &Pixels, k: u8, weight: WeightFn, max_iter: Option<u16>) -> Vec<(LAB, f32)> {
const TOLERANCE: f32 = 1e-4;
const MAX_ITER: u16 = 300;
let mut rng = rand::thread_rng();
let k = k as usize;
let i: usize = rng.gen_range(0, pixels.len());
let mut means: Pixels = vec![pixels[i].clone()];
for _ in 0..(k - 1) {
let distances: Vec<f32> = pixels
.iter()
.map(|color| (color.nearest(&means).1).powi(2))
.collect();
let dist = match WeightedIndex::new(&distances) {
Ok(t) => t,
Err(_) => {
let mut palette: Vec<(LAB, f32)> = means.iter().map(|c| (c.clone(), 0.0)).collect();
let len = pixels.len() as f32;
for color in pixels.iter() {
let near = color.nearest(&means).0;
palette[near].1 += 1.0 / len;
}
return palette;
}
};
means.push(pixels[dist.sample(&mut rng)].clone());
}
let mut clusters: Vec<Vec<LAB>>;
let mut iters_left = max_iter.unwrap_or(MAX_ITER);
loop {
clusters = find_clusters(pixels, &means, k);
let mut changed: bool = false;
for i in 0..clusters.len() {
let new_mean = recal_means(&clusters[i], weight);
if means[i].distance(&new_mean) > TOLERANCE {
changed = true;
}
means[i] = new_mean;
}
iters_left -= 1;
if !changed || iters_left <= 0 {
break;
}
}
return clusters
.iter()
.enumerate()
.map(|(i, cluster)| {
(
means[i].clone(),
cluster.len() as f32 / pixels.len() as f32,
)
})
.collect();
}
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
pub fn pigments(canvas: HtmlCanvasElement, k: u8, mood: Mood, batch_size: Option<u32>) -> JsValue {
#[derive(Serialize)]
struct PaletteColor {
pub dominance: f32,
pub hex: String,
pub rgb: RGB,
pub hsl: HSL,
}
let ctx = canvas
.get_context("2d")
.unwrap()
.unwrap()
.dyn_into::<CanvasRenderingContext2d>()
.unwrap();
let data = ctx
.get_image_data(0.0, 0.0, canvas.width() as f64, canvas.height() as f64)
.unwrap()
.data();
let mut pixels: Pixels = (0..data.len())
.step_by(4)
.map(|i| LAB::from_rgb(data[i], data[i+1], data[i+2]))
.collect();
let batch = batch_size.unwrap_or(0);
if batch != 0 && batch < canvas.width() * canvas.height() && batch > k.into() {
let mut rng = rand::thread_rng();
pixels = pixels
.choose_multiple(&mut rng, batch as usize)
.cloned()
.collect();
}
let weight: WeightFn = resolve_mood(&mood);
let palettes: Vec<PaletteColor> = pigments_pixels(&pixels, k, weight, None)
.iter()
.map(|(color, dominance)| {
let rgb = RGB::from(color);
PaletteColor {
dominance: *dominance,
hex: rgb.hex(),
rgb: rgb,
hsl: HSL::from(color),
}
})
.collect();
return JsValue::from_serde(&palettes).unwrap();
}