use crate::color::{mean_color, ColorPalette};
use crate::error::{DominantColorError, Result};
use crate::Config;
pub fn extract(pixels: &[[u8; 3]], config: &Config) -> Result<ColorPalette> {
let k = config.max_colors.min(pixels.len());
if pixels.is_empty() {
return Err(DominantColorError::EmptyImage);
}
let mut buckets: Vec<Vec<[u8; 3]>> = vec![pixels.to_vec()];
while buckets.len() < k {
let split_idx = buckets
.iter()
.enumerate()
.max_by_key(|(_, b)| b.len())
.map(|(i, _)| i)
.ok_or_else(|| DominantColorError::internal("桶列表意外为空"))?;
if buckets[split_idx].len() < 2 {
break; }
let bucket = buckets.remove(split_idx);
let (left, right) = split_bucket(bucket);
buckets.push(left);
buckets.push(right);
}
let total = pixels.len() as f32;
let palette: ColorPalette = buckets
.iter()
.filter(|b| !b.is_empty())
.filter_map(|b| mean_color(b, b.len() as f32 / total))
.collect();
if palette.is_empty() {
return Err(DominantColorError::internal("中位切分后所有桶均为空"));
}
Ok(palette)
}
/// Index of the red channel in an `[r, g, b]` pixel.
const R: usize = 0;
/// Index of the green channel in an `[r, g, b]` pixel.
const G: usize = 1;
/// Index of the blue channel in an `[r, g, b]` pixel.
const B: usize = 2;

/// Returns the channel index (`R`, `G` or `B`) whose values span the widest range in `pixels`.
///
/// The span of a channel is `max - min` over all pixels. When several channels share the
/// widest span, the one with the highest index wins (B over G over R). An empty slice
/// therefore yields `B`, because every span is zero.
fn longest_axis(pixels: &[[u8; 3]]) -> usize {
    // Per-channel running minimum and maximum, gathered in a single pass.
    let (lo, hi) = pixels.iter().fold(
        ([u8::MAX; 3], [u8::MIN; 3]),
        |(mut lo, mut hi), px| {
            for ch in [R, G, B] {
                lo[ch] = lo[ch].min(px[ch]);
                hi[ch] = hi[ch].max(px[ch]);
            }
            (lo, hi)
        },
    );
    // For an empty slice lo = 255 and hi = 0, so saturating_sub yields a span of 0.
    let spread = |ch: usize| hi[ch].saturating_sub(lo[ch]);

    // `>=` lets a later channel take over on a tie, matching "last maximum wins".
    let mut widest = R;
    for ch in [G, B] {
        if spread(ch) >= spread(widest) {
            widest = ch;
        }
    }
    widest
}
/// Splits a bucket into two halves at the median of its widest channel.
///
/// The pixels are ordered (unstably) by the channel chosen by [`longest_axis`]. The first
/// `len / 2` pixels form the left half and the remainder forms the right half. For an odd
/// count the right half is one pixel larger, and a single pixel ends up alone on the right.
fn split_bucket(mut pixels: Vec<[u8; 3]>) -> (Vec<[u8; 3]>, Vec<[u8; 3]>) {
    let channel = longest_axis(&pixels);
    let ordered = &mut pixels;
    ordered.sort_unstable_by_key(|px| px[channel]);
    let half = ordered.len() / 2;
    let upper = ordered.split_off(half);
    (pixels, upper)
}
// Unit tests for `extract` and its helpers `longest_axis` and `split_bucket`.
#[cfg(test)]
mod tests {
use super::*;
// Builds a test `Config` with `max_colors = k` and `sample_size(None)`.
// NOTE(review): `sample_size(None)` presumably disables pixel sampling so every input
// pixel is used — confirm against `Config`.
fn cfg(k: usize) -> Config {
Config::default().max_colors(k).sample_size(None)
}
// An empty pixel slice must be rejected with `EmptyImage`.
#[test]
fn test_extract_empty() {
let result = extract(&[], &cfg(4));
assert_eq!(result, Err(DominantColorError::EmptyImage));
}
// A single pixel caps k at 1, so the palette holds exactly one color covering 100%.
#[test]
fn test_single_pixel() {
let pixels = vec![[128u8, 64, 32]];
let palette = extract(&pixels, &cfg(3)).unwrap();
assert_eq!(palette.len(), 1);
assert!((palette[0].percentage - 1.0).abs() < 1e-5);
}
// 50 blue + 50 red pixels with k = 2 should split into two colors of roughly equal share.
// (Message: "each color should account for about 50%".)
#[test]
fn test_two_distinct_colors() {
let mut pixels = vec![[0u8, 0, 255]; 50];
pixels.extend(vec![[255u8, 0, 0]; 50]);
let palette = extract(&pixels, &cfg(2)).unwrap();
assert_eq!(palette.len(), 2);
for c in &palette {
assert!((c.percentage - 0.5).abs() < 0.05, "各色应占约 50%");
}
}
// The percentages of all palette entries must add up to 1.
// (Message: "sum of percentages = {total}".)
#[test]
fn test_percentages_sum_to_one() {
let pixels: Vec<[u8; 3]> = (0..200u8).map(|i| [i, i.wrapping_add(10), 100]).collect();
let palette = extract(&pixels, &cfg(8)).unwrap();
let total: f32 = palette.iter().map(|c| c.percentage).sum();
assert!((total - 1.0).abs() < 1e-5, "占比之和 = {total}");
}
// Three identical red pixels with k = 10; k is capped to the pixel count (3).
// (Message: "the number of result colors should not exceed the number of unique pixels".)
// NOTE(review): the input has only ONE unique color, but the assertion bounds the palette
// by the total pixel count (3). It is weaker than its message: a palette holding the same
// red three times still passes. `palette.len() == 1` would express the stated intent.
#[test]
fn test_k_exceeds_unique_pixels() {
let pixels = vec![[255u8, 0, 0]; 3];
let palette = extract(&pixels, &cfg(10)).unwrap();
assert!(palette.len() <= 3, "结果颜色数不应超过唯一像素数");
}
// Red has the widest spread (0..255), so R is selected.
#[test]
fn test_longest_axis_red() {
let pixels = vec![[0u8, 5, 5], [255, 5, 5]];
assert_eq!(longest_axis(&pixels), R);
}
// Green has the widest spread (0..200), so G is selected.
#[test]
fn test_longest_axis_green() {
let pixels = vec![[5u8, 0, 5], [5, 200, 5]];
assert_eq!(longest_axis(&pixels), G);
}
// Blue has the widest spread (10..250), so B is selected.
#[test]
fn test_longest_axis_blue() {
let pixels = vec![[5u8, 5, 10], [5, 5, 250]];
assert_eq!(longest_axis(&pixels), B);
}
// An even-sized bucket is split into two halves of equal size.
#[test]
fn test_split_bucket_even() {
let pixels = vec![[0u8, 0, 0], [100, 0, 0], [200, 0, 0], [255, 0, 0]];
let (left, right) = split_bucket(pixels);
assert_eq!(left.len(), 2);
assert_eq!(right.len(), 2);
}
// Two runs on the same input must produce the same colors in the same order.
// (Message: "the same input should give consistent results".)
#[test]
fn test_deterministic() {
let pixels: Vec<[u8; 3]> = (0..100u8).map(|i| [i, 255 - i, i / 2]).collect();
let p1 = extract(&pixels, &cfg(5)).unwrap();
let p2 = extract(&pixels, &cfg(5)).unwrap();
assert_eq!(p1.len(), p2.len());
for (a, b) in p1.iter().zip(p2.iter()) {
assert_eq!((a.r, a.g, a.b), (b.r, b.g, b.b), "相同输入结果应一致");
}
}
// 256 distinct red values with k = 8 must produce exactly 8 colors.
// (Message: "the gradient image should be split into 8 buckets".)
#[test]
fn test_gradient_image_color_count() {
let pixels: Vec<[u8; 3]> = (0..=255u8).map(|i| [i, 0, 0]).collect();
let palette = extract(&pixels, &cfg(8)).unwrap();
assert_eq!(palette.len(), 8, "渐变图应被切分为 8 个桶");
}
}