use crate::color::{ColorError, ColorResult};
use crate::core::{Pix, PixColormap, PixelDepth, pixel};
/// Maximum number of clustering attempts `color_segment_cluster` makes before
/// giving up with a quantization error.
const MAX_ALLOWED_ITERATIONS: u32 = 20;
/// Multiplier applied to the clustering distance after an attempt fails
/// because it needed more clusters than allowed.
const DIST_EXPAND_FACTOR: f32 = 1.3;
/// Tuning parameters for [`color_segment`].
#[derive(Debug, Clone)]
pub struct ColorSegmentOptions {
/// Initial maximum Euclidean RGB distance between a pixel and a cluster's
/// seed color for the pixel to join that cluster. The clustering phase may
/// enlarge this value when it retries.
pub max_dist: u32,
/// Maximum number of clusters allowed during the clustering phase.
/// `color_segment` rejects values outside `1..=256`.
pub max_colors: u32,
/// NOTE(review): not read by any code visible in this file; presumably the
/// size of a structuring element for a cleanup phase — confirm before relying on it.
pub sel_size: u32,
/// Maximum number of colors kept in the final result.
/// `color_segment` rejects values outside `1..=max_colors`.
pub final_colors: u32,
}
impl Default for ColorSegmentOptions {
fn default() -> Self {
Self {
max_dist: 75,
max_colors: 10,
sel_size: 4,
final_colors: 5,
}
}
}
impl ColorSegmentOptions {
    /// Largest cluster count the clustering phase accepts; `color_segment`
    /// and `color_segment_cluster` reject `max_colors` above this value.
    const MAX_CLUSTER_COLORS: u32 = 256;

    /// Builds options tuned for a desired number of final colors.
    ///
    /// The cluster budget and distance come from a small lookup table:
    ///
    /// | `final_colors` | `max_colors`                    | `max_dist` |
    /// |----------------|---------------------------------|------------|
    /// | 1..=3          | 6                               | 100        |
    /// | 4              | 8                               | 90         |
    /// | 5              | 10                              | 75         |
    /// | anything else  | `2 * final_colors`, capped at 256 | 60       |
    ///
    /// `sel_size` is always 4.
    ///
    /// The doubling saturates instead of overflowing, and the result is capped
    /// at 256 so that requests for 129..=256 final colors produce options that
    /// `color_segment` can actually accept. A `final_colors` of 0 still yields
    /// `max_colors == 0`, which `color_segment` rejects.
    pub fn for_colors(final_colors: u32) -> Self {
        let (max_colors, max_dist) = match final_colors {
            1..=3 => (6, 100),
            4 => (8, 90),
            5 => (10, 75),
            _ => (
                final_colors
                    .saturating_mul(2)
                    .min(Self::MAX_CLUSTER_COLORS),
                60,
            ),
        };
        Self {
            max_dist,
            max_colors,
            sel_size: 4,
            final_colors,
        }
    }
}
/// Segments a 32 bpp image into a small set of colors and returns an 8 bpp
/// colormapped image.
///
/// Steps, in order:
/// 1. Cluster the pixels with [`color_segment_cluster`] using
///    `options.max_dist` and `options.max_colors`.
/// 2. Re-label every pixel with the nearest color of the clustered colormap
///    via [`assign_to_nearest_color`], collecting per-color pixel counts.
/// 3. Reduce the colormap to at most `options.final_colors` entries.
///
/// `options.sel_size` is not consulted by this function.
///
/// # Errors
/// * `ColorError::UnsupportedDepth` if `pix` is not 32 bpp.
/// * `ColorError::InvalidParameters` if `max_colors` is outside `1..=256`,
///   if `final_colors` is outside `1..=max_colors`, or if the clustered image
///   carries no colormap.
/// * Any error propagated from the individual phases.
///
/// # Panics
/// Panics if the freshly allocated working image cannot be converted into its
/// mutable form.
pub fn color_segment(pix: &Pix, options: &ColorSegmentOptions) -> ColorResult<Pix> {
    let depth = pix.depth();
    if depth != PixelDepth::Bit32 {
        return Err(ColorError::UnsupportedDepth {
            expected: "32 bpp",
            actual: depth.bits(),
        });
    }
    if !(1..=256).contains(&options.max_colors) {
        return Err(ColorError::InvalidParameters(
            "max_colors must be between 1 and 256".into(),
        ));
    }
    if !(1..=options.max_colors).contains(&options.final_colors) {
        return Err(ColorError::InvalidParameters(
            "final_colors must be between 1 and max_colors".into(),
        ));
    }

    // Phase 1: greedy clustering produces the candidate palette.
    let clustered = color_segment_cluster(pix, options.max_dist, options.max_colors)?;

    // Phase 2: relabel every pixel against that palette and count usage.
    let mut labels = Pix::new(pix.width(), pix.height(), PixelDepth::Bit8)?
        .try_into_mut()
        .unwrap();
    let palette = clustered
        .colormap()
        .ok_or_else(|| ColorError::InvalidParameters("clustered result has no colormap".into()))?;
    labels.set_colormap(Some(palette.clone()))?;
    let usage = assign_to_nearest_color(&mut labels, pix, &clustered, None)?;
    let labels: Pix = labels.into();

    // Phase 3: keep only the most used colors.
    color_segment_remove_colors(&labels, pix, options.final_colors, &usage)
}
/// Convenience wrapper around [`color_segment`] that derives all options from
/// the desired number of final colors via [`ColorSegmentOptions::for_colors`].
///
/// # Errors
/// Returns whatever [`color_segment`] returns for the derived options.
pub fn color_segment_simple(pix: &Pix, final_colors: u32) -> ColorResult<Pix> {
    color_segment(pix, &ColorSegmentOptions::for_colors(final_colors))
}
/// Runs the clustering phase on a 32 bpp image and returns an 8 bpp
/// colormapped image with at most `max_colors` colors.
///
/// Clustering is attempted with `max_dist`; whenever an attempt needs more
/// than `max_colors` clusters, the distance is enlarged by
/// `DIST_EXPAND_FACTOR` and the attempt is repeated, up to
/// `MAX_ALLOWED_ITERATIONS` attempts in total.
///
/// # Errors
/// * `ColorError::UnsupportedDepth` if `pix` is not 32 bpp.
/// * `ColorError::InvalidParameters` if `max_colors` is outside `1..=256`.
/// * `ColorError::QuantizationError` if no attempt succeeds.
/// * Any other error reported by a clustering attempt.
pub fn color_segment_cluster(pix: &Pix, max_dist: u32, max_colors: u32) -> ColorResult<Pix> {
    if pix.depth() != PixelDepth::Bit32 {
        return Err(ColorError::UnsupportedDepth {
            expected: "32 bpp",
            actual: pix.depth().bits(),
        });
    }
    if max_colors == 0 || max_colors > 256 {
        return Err(ColorError::InvalidParameters(
            "max_colors must be between 1 and 256".into(),
        ));
    }
    let mut current_dist = max_dist;
    for _ in 0..MAX_ALLOWED_ITERATIONS {
        match cluster_try(pix, current_dist, max_colors) {
            Ok(result) => return Ok(result),
            Err(ClusterError::TooManyColors) => {
                // The float-to-int cast truncates (and saturates at u32::MAX),
                // so for small distances (0..=3 with a factor of 1.3) the
                // scaled value equals the old one. Force a step of at least 1
                // so every retry really uses a larger distance.
                let scaled = (current_dist as f32 * DIST_EXPAND_FACTOR) as u32;
                current_dist = scaled.max(current_dist.saturating_add(1));
            }
            Err(ClusterError::Other(e)) => return Err(e),
        }
    }
    Err(ColorError::QuantizationError(format!(
        "failed to cluster after {} iterations (final dist={})",
        MAX_ALLOWED_ITERATIONS, current_dist
    )))
}
pub fn assign_to_nearest_color(
dest: &mut crate::core::PixMut,
src: &Pix,
reference: &Pix,
mask: Option<&Pix>,
) -> ColorResult<Vec<u32>> {
let colormap = reference
.colormap()
.ok_or_else(|| ColorError::InvalidParameters("reference image has no colormap".into()))?;
let w = src.width();
let h = src.height();
if dest.width() != w || dest.height() != h {
return Err(ColorError::InvalidParameters(
"source and dest dimensions must match".into(),
));
}
if let Some(m) = mask
&& (m.width() != w || m.height() != h)
{
return Err(ColorError::InvalidParameters(
"mask dimensions must match source".into(),
));
}
let mut counts = vec![0u32; colormap.len()];
for y in 0..h {
for x in 0..w {
if let Some(m) = mask {
let mask_val = m.get_pixel_unchecked(x, y);
if mask_val == 0 {
continue;
}
}
let pixel = src.get_pixel_unchecked(x, y);
let (r, g, b) = pixel::extract_rgb(pixel);
let idx = colormap.find_nearest(r, g, b).unwrap_or(0);
dest.set_pixel_unchecked(x, y, idx as u32);
counts[idx] += 1;
}
}
Ok(counts)
}
/// Failure modes of a single clustering attempt (`cluster_try`).
enum ClusterError {
/// A pixel needed a new cluster but the allowed number of clusters was
/// already reached; the caller retries with a larger distance.
TooManyColors,
/// Any other failure, carrying the underlying [`ColorError`].
Other(ColorError),
}
/// Performs one greedy clustering pass over `pix`.
///
/// Pixels are visited row by row. Each pixel joins the first existing cluster
/// whose *seed* color (the color of the pixel that created the cluster) lies
/// within `max_dist` in Euclidean RGB distance; otherwise it seeds a new
/// cluster. The returned 8 bpp image stores the cluster index of every pixel
/// and carries a colormap holding each cluster's average color.
///
/// # Errors
/// * `ClusterError::TooManyColors` as soon as a pixel would need a cluster
///   beyond `max_colors`.
/// * `ClusterError::Other` for failures creating the image or colormap.
///
/// # Panics
/// Panics if the freshly allocated output image cannot be converted into its
/// mutable form.
fn cluster_try(pix: &Pix, max_dist: u32, max_colors: u32) -> Result<Pix, ClusterError> {
    // Per-cluster bookkeeping: the fixed seed used for distance tests and the
    // running channel sums used to compute the average color at the end.
    struct Cluster {
        seed: (u8, u8, u8),
        sum: (u64, u64, u64),
        count: u64,
    }

    let w = pix.width();
    let h = pix.height();
    // Saturate instead of overflowing: squaring a `max_dist` near u32::MAX
    // does not fit in i64. Any saturated threshold still accepts every pixel,
    // because squared RGB distances never exceed 3 * 255^2.
    let max_dist_sq = i64::from(max_dist).saturating_mul(i64::from(max_dist));
    let cluster_limit = max_colors as usize;

    let pix_out = Pix::new(w, h, PixelDepth::Bit8).map_err(|e| ClusterError::Other(e.into()))?;
    let mut out_mut = pix_out.try_into_mut().unwrap();
    let mut clusters: Vec<Cluster> = Vec::with_capacity(cluster_limit);

    for y in 0..h {
        for x in 0..w {
            let (r, g, b) = pixel::extract_rgb(pix.get_pixel_unchecked(x, y));

            // First cluster (in creation order) whose seed is close enough.
            let hit = clusters.iter().position(|c| {
                let dr = i64::from(r) - i64::from(c.seed.0);
                let dg = i64::from(g) - i64::from(c.seed.1);
                let db = i64::from(b) - i64::from(c.seed.2);
                dr * dr + dg * dg + db * db <= max_dist_sq
            });

            let idx = match hit {
                Some(k) => {
                    let cluster = &mut clusters[k];
                    cluster.sum.0 += u64::from(r);
                    cluster.sum.1 += u64::from(g);
                    cluster.sum.2 += u64::from(b);
                    cluster.count += 1;
                    k
                }
                None => {
                    if clusters.len() >= cluster_limit {
                        return Err(ClusterError::TooManyColors);
                    }
                    clusters.push(Cluster {
                        seed: (r, g, b),
                        sum: (u64::from(r), u64::from(g), u64::from(b)),
                        count: 1,
                    });
                    clusters.len() - 1
                }
            };
            out_mut.set_pixel_unchecked(x, y, idx as u32);
        }
    }

    let mut colormap = PixColormap::new(8).map_err(|e| ClusterError::Other(e.into()))?;
    for cluster in &clusters {
        // Every cluster is created together with its first member, so
        // `count >= 1` and colormap indices line up with cluster indices.
        let avg_r = (cluster.sum.0 / cluster.count) as u8;
        let avg_g = (cluster.sum.1 / cluster.count) as u8;
        let avg_b = (cluster.sum.2 / cluster.count) as u8;
        colormap
            .add_rgb(avg_r, avg_g, avg_b)
            .map_err(|e| ClusterError::Other(e.into()))?;
    }
    out_mut
        .set_colormap(Some(colormap))
        .map_err(|e| ClusterError::Other(e.into()))?;
    Ok(out_mut.into())
}
/// Reduces the colormap of `pix_dest` to its `final_colors` most used entries.
///
/// Colors are ranked by `counts` (highest first; ties keep their original
/// index order). The top `final_colors` entries form the new colormap in rank
/// order, and every remaining color is redirected to the kept color that is
/// closest in squared RGB distance (the earliest one on ties). If the colormap
/// already has `final_colors` entries or fewer, a clone of `pix_dest` is
/// returned unchanged.
///
/// `_pix_src` is accepted but not used. `counts` is indexed by colormap index.
///
/// # Errors
/// `ColorError::InvalidParameters` if `pix_dest` has no colormap, plus any
/// error from creating the output image or colormap.
///
/// # Panics
/// Panics if `counts` is shorter than the colormap, if a colormap lookup for
/// an in-range index fails, or if the new image cannot be made mutable.
fn color_segment_remove_colors(
    pix_dest: &Pix,
    _pix_src: &Pix,
    final_colors: u32,
    counts: &[u32],
) -> ColorResult<Pix> {
    let colormap = pix_dest
        .colormap()
        .ok_or_else(|| ColorError::InvalidParameters("dest image has no colormap".into()))?;
    let ncolors = colormap.len();
    let keep = final_colors as usize;
    if ncolors <= keep {
        return Ok(pix_dest.clone());
    }

    // Rank colormap indices by descending usage; the sort is stable.
    let mut ranked: Vec<usize> = (0..ncolors).collect();
    ranked.sort_by_key(|&i| std::cmp::Reverse(counts[i]));
    let (kept, dropped) = ranked.split_at(keep);

    // Survivors get consecutive slots in the new colormap.
    let mut remap = vec![0u8; ncolors];
    let mut new_colormap = PixColormap::new(8)?;
    let mut kept_rgb = Vec::with_capacity(keep);
    for (slot, &old_idx) in kept.iter().enumerate() {
        let (r, g, b) = colormap.get_rgb(old_idx).unwrap();
        new_colormap.add_rgb(r, g, b)?;
        remap[old_idx] = slot as u8;
        kept_rgb.push((r, g, b));
    }

    // Each discarded color is folded into its nearest survivor.
    for &old_idx in dropped {
        let (r, g, b) = colormap.get_rgb(old_idx).unwrap();
        let nearest_slot = kept_rgb
            .iter()
            .enumerate()
            .min_by_key(|&(_, &(kr, kg, kb))| {
                let dr = r as i64 - kr as i64;
                let dg = g as i64 - kg as i64;
                let db = b as i64 - kb as i64;
                dr * dr + dg * dg + db * db
            })
            .map_or(0, |(slot, _)| slot);
        remap[old_idx] = nearest_slot as u8;
    }

    let mut out = Pix::new(pix_dest.width(), pix_dest.height(), PixelDepth::Bit8)?
        .try_into_mut()
        .unwrap();
    out.set_colormap(Some(new_colormap))?;
    for y in 0..pix_dest.height() {
        for x in 0..pix_dest.width() {
            let old_idx = pix_dest.get_pixel_unchecked(x, y) as usize;
            // Indices outside the old colormap fall back to slot 0.
            let new_idx = remap.get(old_idx).copied().unwrap_or(0);
            out.set_pixel_unchecked(x, y, new_idx as u32);
        }
    }
    Ok(out.into())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 60x60 32 bpp image made of three 20-pixel-wide vertical bands:
    /// pure red, pure green, pure blue.
    fn create_test_image() -> Pix {
        let bands = [
            pixel::compose_rgb(255, 0, 0),
            pixel::compose_rgb(0, 255, 0),
            pixel::compose_rgb(0, 0, 255),
        ];
        let mut canvas = Pix::new(60, 60, PixelDepth::Bit32)
            .unwrap()
            .try_into_mut()
            .unwrap();
        for y in 0..60 {
            for x in 0..60 {
                canvas.set_pixel_unchecked(x, y, bands[(x / 20) as usize]);
            }
        }
        canvas.into()
    }

    /// 64x64 32 bpp image whose red channel grows with x, green channel grows
    /// with y, and blue channel is fixed at 128.
    fn create_gradient_image() -> Pix {
        let mut canvas = Pix::new(64, 64, PixelDepth::Bit32)
            .unwrap()
            .try_into_mut()
            .unwrap();
        for y in 0..64 {
            let green = (y * 4) as u8;
            for x in 0..64 {
                let red = (x * 4) as u8;
                canvas.set_pixel_unchecked(x, y, pixel::compose_rgb(red, green, 128));
            }
        }
        canvas.into()
    }

    /// Asserts that `result` is 8 bpp and has a colormap with at most
    /// `limit` entries.
    fn assert_indexed_with_at_most(result: &Pix, limit: usize) {
        assert_eq!(result.depth(), PixelDepth::Bit8);
        assert!(result.colormap().is_some());
        let cmap = result.colormap().unwrap();
        assert!(cmap.len() <= limit);
    }

    #[test]
    fn test_color_segment_simple_colors() {
        let segmented = color_segment_simple(&create_test_image(), 3).unwrap();
        assert_indexed_with_at_most(&segmented, 3);
    }

    #[test]
    fn test_color_segment_gradient() {
        let segmented = color_segment_simple(&create_gradient_image(), 5).unwrap();
        assert_indexed_with_at_most(&segmented, 5);
    }

    #[test]
    fn test_cluster_phase_only() {
        let clustered = color_segment_cluster(&create_test_image(), 100, 10).unwrap();
        assert_indexed_with_at_most(&clustered, 10);
    }

    #[test]
    fn test_wrong_depth() {
        let gray = Pix::new(10, 10, PixelDepth::Bit8).unwrap();
        assert!(color_segment_simple(&gray, 5).is_err());
        assert!(color_segment_cluster(&gray, 75, 10).is_err());
    }

    #[test]
    fn test_invalid_params() {
        let image = create_test_image();
        for bad_max_colors in [0, 257] {
            assert!(color_segment_cluster(&image, 75, bad_max_colors).is_err());
        }
    }

    #[test]
    fn test_options_for_colors() {
        // (final_colors, expected max_colors, expected max_dist)
        for (final_colors, max_colors, max_dist) in [(3, 6, 100), (5, 10, 75)] {
            let opts = ColorSegmentOptions::for_colors(final_colors);
            assert_eq!(opts.final_colors, final_colors);
            assert_eq!(opts.max_colors, max_colors);
            assert_eq!(opts.max_dist, max_dist);
        }
    }

    #[test]
    fn test_default_options() {
        let ColorSegmentOptions {
            max_dist,
            max_colors,
            sel_size,
            final_colors,
        } = ColorSegmentOptions::default();
        assert_eq!(final_colors, 5);
        assert_eq!(max_colors, 10);
        assert_eq!(max_dist, 75);
        assert_eq!(sel_size, 4);
    }
}