use crate::color::{Color, ColorPalette};
use crate::error::{DominantColorError, Result};
use crate::Config;
/// Depth of the octree's leaf level (the root is depth 0).
///
/// Branching at depth `d` uses bit `MAX_DEPTH - d` of each channel (see
/// `octant_index`), so with the value 7 the tree splits on bits 7 down to 1
/// and bit 0 of each channel never affects which leaf a pixel lands in.
const MAX_DEPTH: usize = 7;
/// Index of a node in the `Octree::nodes` arena.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct NodeId(usize);
/// One octree node, stored in `Octree::nodes` and addressed by a `NodeId`.
///
/// A node accumulates the channel sums of every pixel routed into it (or folded
/// into it by a reduction); its palette color is the rounded `sum / pixel_count`.
#[derive(Clone)]
struct Node {
    /// Sum of the red values of all pixels held by this node.
    r_sum: u64,
    /// Sum of the green values of all pixels held by this node.
    g_sum: u64,
    /// Sum of the blue values of all pixels held by this node.
    b_sum: u64,
    /// Number of pixels held by this node.
    pixel_count: u64,
    /// Whether this node currently contributes a palette color.
    is_leaf: bool,
    /// Whether this node is currently queued in `Octree::reducible`.
    in_reducible: bool,
    /// Parent node; `None` only for the root.
    parent: Option<NodeId>,
    /// Child slots, indexed by `octant_index`.
    children: [Option<NodeId>; 8],
}

impl Node {
    /// Builds an empty, non-leaf, unqueued node with no children under `parent`.
    fn new(parent: Option<NodeId>) -> Self {
        let children = [None; 8];
        Node {
            parent,
            children,
            pixel_count: 0,
            r_sum: 0,
            g_sum: 0,
            b_sum: 0,
            in_reducible: false,
            is_leaf: false,
        }
    }
}
/// Arena-allocated octree used for color quantization.
struct Octree {
    /// Every node ever allocated, indexed by `NodeId`. Nodes detached during a
    /// reduction stay in the vector but are no longer reachable from `root`.
    nodes: Vec<Node>,
    /// The root node (depth 0); `Octree::new` places it at index 0.
    root: NodeId,
    /// Number of reachable nodes currently flagged `is_leaf`, i.e. the number
    /// of palette entries the tree would produce right now.
    leaf_count: usize,
    /// Reduction candidates bucketed by depth: `reducible[d]` holds nodes at depth `d`.
    reducible: [Vec<NodeId>; MAX_DEPTH],
}
impl Octree {
    /// Creates a tree that holds only the root node (depth 0) and no leaves.
    fn new() -> Self {
        let mut nodes = Vec::with_capacity(8192);
        nodes.push(Node::new(None));
        Self {
            nodes,
            root: NodeId(0),
            leaf_count: 0,
            reducible: std::array::from_fn(|_| Vec::new()),
        }
    }

    /// Appends an empty node whose parent is `parent` and returns its id.
    ///
    /// The new node is not linked into `parent.children`; the caller does that.
    fn alloc_child(&mut self, parent: NodeId) -> NodeId {
        let id = NodeId(self.nodes.len());
        self.nodes.push(Node::new(Some(parent)));
        id
    }

    /// Queues `id`, a node at `depth`, as a reduction candidate.
    ///
    /// The `in_reducible` flag makes this idempotent: a node that is already
    /// queued is not pushed a second time.
    fn register(&mut self, id: NodeId, depth: usize) {
        let node = &mut self.nodes[id.0];
        if !node.in_reducible {
            node.in_reducible = true;
            self.reducible[depth].push(id);
        }
    }

    /// Routes `pixel` from the root down to its depth-`MAX_DEPTH` leaf, creating
    /// any missing nodes on the way, and adds the pixel to that leaf's sums.
    fn insert(&mut self, pixel: [u8; 3]) {
        let mut id = self.root;
        for depth in 0..MAX_DEPTH {
            let idx = octant_index(pixel, depth);
            let existing = self.nodes[id.0].children[idx];
            id = match existing {
                Some(child) => child,
                None => {
                    let child = self.alloc_child(id);
                    self.nodes[id.0].children[idx] = Some(child);
                    // Parents of leaf-level nodes are the initial reduction candidates.
                    if depth + 1 == MAX_DEPTH {
                        self.register(id, depth);
                    }
                    child
                }
            };
        }

        let leaf = &mut self.nodes[id.0];
        leaf.r_sum += u64::from(pixel[0]);
        leaf.g_sum += u64::from(pixel[1]);
        leaf.b_sum += u64::from(pixel[2]);
        leaf.pixel_count += 1;
        if !leaf.is_leaf {
            leaf.is_leaf = true;
            self.leaf_count += 1;
        }
    }

    /// Lists `(slot, child id, pixel count)` for every child of `id` that is
    /// currently a leaf, in ascending slot order.
    fn leaf_children(&self, id: NodeId) -> Vec<(usize, NodeId, u64)> {
        self.nodes[id.0]
            .children
            .iter()
            .enumerate()
            .filter_map(|(slot, &child)| {
                let child_id = child?;
                let child_node = &self.nodes[child_id.0];
                child_node
                    .is_leaf
                    .then_some((slot, child_id, child_node.pixel_count))
            })
            .collect()
    }

    /// Performs one reduction step toward a target of `k` leaves.
    ///
    /// Pops the most recently queued candidate from the deepest non-empty level
    /// of `reducible` and folds its leaf children into it, smallest pixel count
    /// first. Only as many children are folded as needed to bring `leaf_count`
    /// down to `k` (at least one when the node is already a leaf). If some leaf
    /// children remain, the node is queued again. After a node is reduced, its
    /// parent (if not already a leaf) becomes a candidate one level up.
    ///
    /// Returns `false` when no candidates are left, `true` otherwise, including
    /// when the popped node had no leaf children and nothing changed.
    fn reduce(&mut self, k: usize) -> bool {
        let depth = match self.reducible.iter().rposition(|level| !level.is_empty()) {
            Some(depth) => depth,
            None => return false,
        };
        let node_id = self.reducible[depth]
            .pop()
            .expect("rposition returned a non-empty level");
        self.nodes[node_id.0].in_reducible = false;
        let already_leaf = self.nodes[node_id.0].is_leaf;
        // Number of leaves that must disappear to reach the target.
        let budget = self.leaf_count.saturating_sub(k);

        let mut leaves = self.leaf_children(node_id);
        if leaves.is_empty() {
            return true;
        }
        // Stable sort: equal pixel counts keep slot order, so results are deterministic.
        leaves.sort_by_key(|&(_, _, pixel_count)| pixel_count);

        // Folding m leaves into a node that is not yet a leaf removes m - 1 leaves
        // (the node itself becomes one); folding into an existing leaf removes m.
        let max_merge = if already_leaf { budget.max(1) } else { budget + 1 };
        let merge_count = leaves.len().min(max_merge);

        let (mut r, mut g, mut b, mut pixels) = (0u64, 0u64, 0u64, 0u64);
        for &(slot, child_id, _) in &leaves[..merge_count] {
            let child = &self.nodes[child_id.0];
            r += child.r_sum;
            g += child.g_sum;
            b += child.b_sum;
            pixels += child.pixel_count;
            // Detach the child; its node stays in the arena but becomes unreachable.
            self.nodes[node_id.0].children[slot] = None;
        }

        let parent = {
            let node = &mut self.nodes[node_id.0];
            node.r_sum += r;
            node.g_sum += g;
            node.b_sum += b;
            node.pixel_count += pixels;
            node.is_leaf = true;
            node.parent
        };
        self.leaf_count -= merge_count;
        if !already_leaf {
            self.leaf_count += 1;
        }

        if let Some(parent_id) = parent {
            if depth > 0 && !self.nodes[parent_id.0].is_leaf {
                self.register(parent_id, depth - 1);
            }
        }
        // Partially reduced: the remaining leaf children keep this node a candidate.
        if merge_count < leaves.len() {
            self.register(node_id, depth);
        }
        true
    }

    /// Appends one color per reachable non-empty leaf to `palette`, in
    /// depth-first slot order; each color's share is its pixel count / `total`.
    fn collect_leaves(&self, palette: &mut ColorPalette, total: f32) {
        self.collect_recursive(self.root, palette, total);
    }

    /// Depth-first walk below `id`: emits the node if it is a non-empty leaf
    /// (using the rounded mean of its channel sums), then visits its children in
    /// slot order. A partially reduced node can be a leaf and still have children.
    fn collect_recursive(&self, id: NodeId, palette: &mut ColorPalette, total: f32) {
        let node = &self.nodes[id.0];
        if node.is_leaf && node.pixel_count > 0 {
            let n = node.pixel_count as f64;
            let mean = |sum: u64| (sum as f64 / n).round() as u8;
            palette.push(Color::new(
                mean(node.r_sum),
                mean(node.g_sum),
                mean(node.b_sum),
                node.pixel_count as f32 / total,
            ));
        }
        for &child in node.children.iter().flatten() {
            self.collect_recursive(child, palette, total);
        }
    }
}
/// Extracts the dominant colors of `pixels` using octree quantization.
///
/// Every pixel is inserted into an `Octree`. The tree is then reduced until it
/// has at most `config.max_colors` leaves, or until no further reduction is
/// possible. Each remaining leaf becomes one palette entry, whose `percentage`
/// is its share of `pixels`.
///
/// # Errors
///
/// - `DominantColorError::EmptyImage` if `pixels` is empty.
/// - An internal error if the tree unexpectedly yields no leaves.
pub fn extract(pixels: &[[u8; 3]], config: &Config) -> Result<ColorPalette> {
    if pixels.is_empty() {
        return Err(DominantColorError::EmptyImage);
    }
    let target = config.max_colors;

    let mut tree = Octree::new();
    pixels.iter().for_each(|&px| tree.insert(px));

    // Keep reducing while above the target; stop early once no candidate remains.
    while tree.leaf_count > target && tree.reduce(target) {}

    let mut palette = ColorPalette::new();
    tree.collect_leaves(&mut palette, pixels.len() as f32);

    if palette.is_empty() {
        Err(DominantColorError::internal("八叉树未产生任何叶节点"))
    } else {
        Ok(palette)
    }
}
/// Returns the child slot (0..8) that `pixel` falls into at `depth`.
///
/// The slot is built from bit `MAX_DEPTH - depth` of each channel, packed as
/// `r g b` from most to least significant bit. Requires `depth <= MAX_DEPTH`.
#[inline]
fn octant_index(pixel: [u8; 3], depth: usize) -> usize {
    let shift = MAX_DEPTH - depth;
    pixel
        .iter()
        .fold(0, |slot, &channel| (slot << 1) | usize::from((channel >> shift) & 1))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config asking for at most `k` colors with `sample_size(None)`
    /// (presumably disabling sampling so every pixel is used; verify in `Config`).
    fn cfg(k: usize) -> Config {
        Config::default().max_colors(k).sample_size(None)
    }

    #[test]
    fn test_empty_returns_error() {
        assert_eq!(extract(&[], &cfg(4)), Err(DominantColorError::EmptyImage));
    }

    #[test]
    fn test_single_pixel() {
        let pixels = vec![[200u8, 100, 50]];
        let palette = extract(&pixels, &cfg(4)).unwrap();
        assert_eq!(palette.len(), 1);
        assert!((palette[0].percentage - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_two_colors_separated() {
        let mut pixels = vec![[255u8, 0, 0]; 60];
        pixels.extend(std::iter::repeat([0u8, 0, 255]).take(40));
        let palette = extract(&pixels, &cfg(2)).unwrap();
        assert_eq!(palette.len(), 2);
        assert!(palette.iter().any(|c| c.r > 200 && c.b < 50), "缺少红色");
        assert!(palette.iter().any(|c| c.b > 200 && c.r < 50), "缺少蓝色");
    }

    #[test]
    fn test_percentages_sum_to_one() {
        let pixels: Vec<[u8; 3]> = (0..128u8)
            .map(|i| [i, i.wrapping_mul(2), 255 - i])
            .collect();
        let palette = extract(&pixels, &cfg(6)).unwrap();
        let total: f32 = palette.iter().map(|c| c.percentage).sum();
        assert!((total - 1.0).abs() < 1e-4, "占比之和 = {total}");
    }

    #[test]
    fn test_leaf_count_respects_k() {
        let pixels: Vec<[u8; 3]> = (0..=255u8).map(|i| [i, i, i]).collect();
        let k = 5;
        let palette = extract(&pixels, &cfg(k)).unwrap();
        assert!(palette.len() <= k, "期望 ≤{k},实际 {}", palette.len());
    }

    #[test]
    fn test_exactly_k_distinct_colors() {
        // Every pair of these colors differs within the two most significant bits of
        // some channel, so they occupy 8 distinct leaves and no reduction is needed.
        let colors: [[u8; 3]; 8] = [
            [255, 0, 0],
            [0, 255, 0],
            [0, 0, 255],
            [255, 255, 0],
            [255, 0, 255],
            [0, 255, 255],
            [128, 0, 0],
            [0, 128, 0],
        ];
        let mut pixels = Vec::new();
        for (i, &color) in colors.iter().enumerate() {
            pixels.extend(std::iter::repeat(color).take(200 + i * 50));
        }
        let total = pixels.len() as f32;
        let palette = extract(&pixels, &cfg(8)).unwrap();
        assert_eq!(palette.len(), 8, "颜色数量 {} 不等于预期的 8", palette.len());
        for (i, color) in colors.iter().enumerate() {
            let expected = (200 + i * 50) as f32 / total;
            assert!(
                palette
                    .iter()
                    .any(|c| [c.r, c.g, c.b] == *color && (c.percentage - expected).abs() < 1e-6),
                "缺少颜色 {color:?}"
            );
        }
    }

    #[test]
    fn test_partial_merge_keeps_largest_child() {
        // All four colors share one octree path down to depth 6 and differ only in
        // bit 1, so reducing 4 leaves to 2 folds the three smallest groups into their
        // common parent and leaves the largest group (black) untouched.
        let mut pixels = vec![[0u8, 0, 0]; 5];
        pixels.push([2, 0, 0]);
        pixels.extend(std::iter::repeat([0u8, 2, 0]).take(2));
        pixels.extend(std::iter::repeat([0u8, 0, 2]).take(3));
        let palette = extract(&pixels, &cfg(2)).unwrap();
        assert_eq!(palette.len(), 2);
        assert!(palette.iter().any(|c| (c.r, c.g, c.b) == (0, 0, 0)), "缺少最大的颜色组");
        assert!(palette.iter().any(|c| (c.r, c.g, c.b) == (0, 1, 1)), "缺少合并后的颜色");
        let total: f32 = palette.iter().map(|c| c.percentage).sum();
        assert!((total - 1.0).abs() < 1e-4, "占比之和 = {total}");
    }

    #[test]
    fn test_no_data_loss() {
        let pixels: Vec<[u8; 3]> = (0..=255u8)
            .flat_map(|i| std::iter::repeat([i, 255 - i, i / 2]).take(3))
            .collect();
        let palette = extract(&pixels, &cfg(8)).unwrap();
        let total: f32 = palette.iter().map(|c| c.percentage).sum();
        assert!(
            (total - 1.0).abs() < 1e-4,
            "占比之和 = {total},疑似数据丢失"
        );
    }

    #[test]
    fn test_octant_index_range() {
        for r in [0u8, 127, 255] {
            for g in [0u8, 127, 255] {
                for b in [0u8, 127, 255] {
                    for depth in 0..=MAX_DEPTH {
                        assert!(octant_index([r, g, b], depth) < 8);
                    }
                }
            }
        }
    }

    #[test]
    fn test_deterministic() {
        let pixels: Vec<[u8; 3]> = (0..200u8).map(|i| [i, 200 - i, i / 2]).collect();
        let p1 = extract(&pixels, &cfg(6)).unwrap();
        let p2 = extract(&pixels, &cfg(6)).unwrap();
        assert_eq!(p1.len(), p2.len());
        for (a, b) in p1.iter().zip(p2.iter()) {
            assert_eq!((a.r, a.g, a.b), (b.r, b.g, b.b));
        }
    }

    #[test]
    fn test_k1_returns_one_color() {
        let pixels: Vec<[u8; 3]> = (0..50u8).map(|i| [i * 5, i, 100]).collect();
        let palette = extract(&pixels, &cfg(1)).unwrap();
        assert_eq!(palette.len(), 1);
        assert!((palette[0].percentage - 1.0).abs() < 1e-4);
    }

    #[test]
    fn test_gradient() {
        let pixels: Vec<[u8; 3]> = (0..=255u8).map(|i| [i, 0, 255 - i]).collect();
        let palette = extract(&pixels, &cfg(8)).unwrap();
        assert!(!palette.is_empty() && palette.len() <= 8);
        assert!(palette.iter().any(|c| c.r > 180 && c.b < 80), "缺偏红色");
        assert!(palette.iter().any(|c| c.b > 180 && c.r < 80), "缺偏蓝色");
    }
}