use std::path::Path;
use anyhow::{Context, Result};
use image::{DynamicImage, GenericImageView, imageops::FilterType};
use rand::Rng;
use super::scheme::{ColorScheme, Rgb};
/// User-tunable parameters for [`ColorExtractor::extract`] and
/// [`ColorExtractor::extract_from_image`].
#[derive(Debug, Clone, PartialEq)]
pub struct ExtractionOptions {
    /// Number of k-means clusters (dominant colours) requested from the image.
    /// Fewer may come back when the sampled image has fewer pixels than this.
    pub color_count: usize,
    /// `Some(true)` forces a dark scheme, `Some(false)` a light one; `None`
    /// chooses dark when the average luminance of the extracted colours is
    /// below 0.5.
    pub prefers_dark: Option<bool>,
    /// Strength of the accent-colour correction. Higher values raise the
    /// luminance thresholds and adjustment amounts used to lighten accents in
    /// dark schemes and darken them in light schemes.
    pub contrast_ratio: f32,
    /// Amount by which the darkest (dark scheme) or lightest (light scheme)
    /// extracted colour is darkened/lightened to produce the background.
    pub background_intensity: f32,
}
impl Default for ExtractionOptions {
fn default() -> Self {
Self {
color_count: 16,
prefers_dark: None,
contrast_ratio: 3.0,
background_intensity: 0.6,
}
}
}
/// Derives a colour scheme from an image by downscaling it, sampling its
/// pixels and clustering them with k-means.
#[derive(Debug, Clone)]
pub struct ColorExtractor {
    /// Longest side, in pixels, that an image is downscaled to before sampling.
    max_dimension: u32,
    /// Only every `sample_step`-th pixel along each axis is sampled.
    sample_step: u32,
    /// Upper bound on k-means refinement iterations.
    max_iterations: usize,
}
impl Default for ColorExtractor {
fn default() -> Self {
Self::new()
}
}
impl ColorExtractor {
pub fn new() -> Self {
Self {
max_dimension: 200,
sample_step: 4,
max_iterations: 20,
}
}
pub fn extract<P: AsRef<Path>>(&self, image_path: P, options: &ExtractionOptions) -> Result<ColorScheme> {
let path = image_path.as_ref();
let img = image::open(path).context("Failed to open image")?;
self.extract_from_image(&img, path.to_string_lossy().to_string(), options)
}
pub fn extract_from_image(&self, image: &DynamicImage, wallpaper_path: String, options: &ExtractionOptions) -> Result<ColorScheme> {
let resized = self.resize_image(image);
let pixels = self.sample_pixels(&resized);
if pixels.is_empty() {
anyhow::bail!("No valid pixels found in image");
}
let mut centroids = self.kmeans(&pixels, options.color_count);
centroids.sort_by(|a, b| a.luminance().partial_cmp(&b.luminance()).unwrap());
Ok(self.generate_scheme(wallpaper_path, centroids, options))
}
fn resize_image(&self, image: &DynamicImage) -> DynamicImage {
let (width, height) = image.dimensions();
if width <= self.max_dimension && height <= self.max_dimension {
return image.clone();
}
let scale = self.max_dimension as f32 / width.max(height) as f32;
let new_width = (width as f32 * scale) as u32;
let new_height = (height as f32 * scale) as u32;
image.resize(new_width, new_height, FilterType::Triangle)
}
fn sample_pixels(&self, image: &DynamicImage) -> Vec<Rgb> {
let (width, height) = image.dimensions();
let rgba = image.to_rgba8();
let mut pixels = Vec::with_capacity((width * height / 16) as usize);
for y in (0..height).step_by(self.sample_step as usize) {
for x in (0..width).step_by(self.sample_step as usize) {
let pixel = rgba.get_pixel(x, y);
let [r, g, b, a] = pixel.0;
if a < 200 {
continue;
}
let rf = r as f32 / 255.0;
let gf = g as f32 / 255.0;
let bf = b as f32 / 255.0;
let brightness = (rf + gf + bf) / 3.0;
if brightness > 0.08 && brightness < 0.92 {
pixels.push(Rgb::new(rf, gf, bf));
}
}
}
if pixels.len() < 100 {
pixels.clear();
for y in (0..height).step_by(self.sample_step as usize) {
for x in (0..width).step_by(self.sample_step as usize) {
let pixel = rgba.get_pixel(x, y);
let [r, g, b, _] = pixel.0;
pixels.push(Rgb::from_u8(r, g, b));
}
}
}
pixels
}
fn kmeans(&self, pixels: &[Rgb], k: usize) -> Vec<Rgb> {
if pixels.len() <= k {
return pixels.to_vec();
}
let mut centroids = self.kmeans_plus_plus_init(pixels, k);
let mut assignments = vec![0usize; pixels.len()];
for _ in 0..self.max_iterations {
let mut changed = false;
for (i, pixel) in pixels.iter().enumerate() {
let mut min_dist = f32::MAX;
let mut min_idx = 0;
for (j, centroid) in centroids.iter().enumerate() {
let dist = pixel.distance_squared(centroid);
if dist < min_dist {
min_dist = dist;
min_idx = j;
}
}
if assignments[i] != min_idx {
assignments[i] = min_idx;
changed = true;
}
}
if !changed {
break;
}
let mut sums = vec![(0.0f32, 0.0f32, 0.0f32); k];
let mut counts = vec![0usize; k];
for (i, pixel) in pixels.iter().enumerate() {
let c = assignments[i];
sums[c].0 += pixel.r;
sums[c].1 += pixel.g;
sums[c].2 += pixel.b;
counts[c] += 1;
}
for (c, centroid) in centroids.iter_mut().enumerate() {
if counts[c] > 0 {
let count = counts[c] as f32;
*centroid = Rgb::new(sums[c].0 / count, sums[c].1 / count, sums[c].2 / count);
}
}
}
centroids
}
fn kmeans_plus_plus_init(&self, pixels: &[Rgb], k: usize) -> Vec<Rgb> {
let mut rng = rand::thread_rng();
let mut centroids = Vec::with_capacity(k);
let first_idx = rng.r#gen_range(0..pixels.len());
centroids.push(pixels[first_idx]);
let mut min_distances = vec![f32::MAX; pixels.len()];
for _ in 1..k {
let mut total_dist = 0.0f32;
for (i, pixel) in pixels.iter().enumerate() {
let dist = pixel.distance_squared(centroids.last().unwrap());
if dist < min_distances[i] {
min_distances[i] = dist;
}
total_dist += min_distances[i];
}
let threshold = rng.r#gen::<f32>() * total_dist;
let mut cumulative = 0.0f32;
let mut selected_idx = 0;
for (i, &dist) in min_distances.iter().enumerate() {
cumulative += dist;
if cumulative >= threshold {
selected_idx = i;
break;
}
}
centroids.push(pixels[selected_idx]);
}
centroids
}
fn generate_scheme(&self, wallpaper: String, dominant_colors: Vec<Rgb>, options: &ExtractionOptions) -> ColorScheme {
let is_dark = options.prefers_dark.unwrap_or_else(|| {
let avg_luminance: f32 = dominant_colors.iter().map(|c| c.luminance()).sum::<f32>() / dominant_colors.len() as f32;
avg_luminance < 0.5
});
let (background, foreground) = if is_dark {
let bg = dominant_colors
.first()
.map(|c| c.darkened(options.background_intensity))
.unwrap_or(Rgb::new(0.1, 0.1, 0.1));
(bg, Rgb::new(0.9, 0.9, 0.9))
} else {
let bg = dominant_colors
.last()
.map(|c| c.lightened(options.background_intensity))
.unwrap_or(Rgb::new(0.95, 0.95, 0.95));
(bg, Rgb::new(0.1, 0.1, 0.1))
};
let mut colors = Vec::with_capacity(16);
colors.push(background);
let selected = self.select_terminal_colors(&dominant_colors, 6, is_dark, options.contrast_ratio);
colors.extend(selected.iter().cloned());
colors.push(foreground);
colors.push(background.lightened(0.15));
for color in &selected {
if is_dark {
colors.push(color.saturated(1.2).lightened(0.15));
} else {
colors.push(color.saturated(1.1));
}
}
colors.push(foreground);
let cursor = dominant_colors.iter().find(|c| c.saturation() > 0.3).cloned().unwrap_or(foreground);
ColorScheme::new(wallpaper, is_dark, background, foreground, cursor, colors)
}
fn select_terminal_colors(&self, colors: &[Rgb], count: usize, is_dark: bool, contrast_ratio: f32) -> Vec<Rgb> {
let mut saturated: Vec<Rgb> = colors.iter().filter(|c| c.saturation() > 0.2).cloned().collect();
if saturated.len() < count {
saturated = colors.to_vec();
}
saturated.sort_by(|a, b| a.hue().partial_cmp(&b.hue()).unwrap());
let normalized = (contrast_ratio - 1.5) / 3.0;
let dark_threshold = 0.15 + normalized * 0.30;
let dark_adjustment = 0.10 + normalized * 0.25;
let light_threshold = 0.85 - normalized * 0.30;
let light_adjustment = 0.20 + normalized * 0.30;
let mut selected = Vec::with_capacity(count);
let step = (saturated.len() / count).max(1);
for i in (0..saturated.len().min(count * step)).step_by(step) {
let mut color = saturated[i];
if color.saturation() < 0.4 {
color = color.saturated(1.5);
}
if is_dark && color.luminance() < dark_threshold {
color = color.lightened(dark_adjustment);
} else if !is_dark && color.luminance() > light_threshold {
color = color.darkened(light_adjustment);
}
selected.push(color);
}
let defaults = [
Rgb::new(0.8, 0.2, 0.2), Rgb::new(0.2, 0.8, 0.2), Rgb::new(0.8, 0.8, 0.2), Rgb::new(0.2, 0.4, 0.8), Rgb::new(0.8, 0.2, 0.8), Rgb::new(0.2, 0.8, 0.8), ];
while selected.len() < count {
selected.push(defaults[selected.len() % defaults.len()]);
}
selected.truncate(count);
selected
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every documented default, including `background_intensity`, is pinned.
    #[test]
    fn test_extraction_options_default() {
        let opts = ExtractionOptions::default();
        assert_eq!(opts.color_count, 16);
        assert_eq!(opts.prefers_dark, None);
        assert!((opts.contrast_ratio - 3.0).abs() < 0.001);
        assert!((opts.background_intensity - 0.6).abs() < 0.001);
    }

    /// Six well-separated points clustered into three groups yield three centroids.
    #[test]
    fn test_kmeans_simple() {
        let extractor = ColorExtractor::new();
        let pixels = vec![
            Rgb::new(1.0, 0.0, 0.0),
            Rgb::new(1.0, 0.1, 0.0),
            Rgb::new(0.0, 1.0, 0.0),
            Rgb::new(0.0, 1.0, 0.1),
            Rgb::new(0.0, 0.0, 1.0),
            Rgb::new(0.1, 0.0, 1.0),
        ];
        let centroids = extractor.kmeans(&pixels, 3);
        assert_eq!(centroids.len(), 3);
    }

    /// With no more pixels than requested clusters, the pixels come back unchanged in count.
    #[test]
    fn test_kmeans_fewer_pixels_than_clusters() {
        let extractor = ColorExtractor::new();
        let pixels = vec![Rgb::new(1.0, 0.0, 0.0), Rgb::new(0.0, 0.0, 1.0)];
        let centroids = extractor.kmeans(&pixels, 3);
        assert_eq!(centroids.len(), 2);
    }
}