use crate::Rgb;
/// Strategy used by [`GamutCompression`] to pull RGB values back toward `[0, 1]`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GamutCompressionMethod {
    /// Per-channel rational curve: values at or below `threshold` are untouched,
    /// values above it are squeezed toward (but, for `strength > 0`, not onto) 1.0.
    SoftClip,
    /// When the largest channel exceeds 1.0, pulls all channels toward their
    /// mean by a `strength`-dependent factor and clamps to `[0, 1]`; otherwise
    /// the color is returned unchanged. Does not use `threshold`.
    Desaturate,
    /// Currently behaves exactly like [`GamutCompressionMethod::Desaturate`].
    PreserveLightness,
    /// Per-channel smoothstep remap between `threshold` and 1.0; values `>= 1.0`
    /// are hard-clipped to 1.0. Does not use `strength`.
    RollOff,
}

/// Gamut compressor for RGB triples and interleaved 8-bit RGB buffers.
///
/// `threshold` and `strength` are interpreted per [`GamutCompressionMethod`];
/// they are stored exactly as given, without range validation.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct GamutCompression {
    method: GamutCompressionMethod,
    /// Channel value above which `SoftClip` / `RollOff` start compressing.
    threshold: f64,
    /// Curve / blend parameter used by `SoftClip` and `Desaturate`.
    strength: f64,
}

impl GamutCompression {
    /// Creates a compressor from an explicit method, threshold and strength.
    ///
    /// No validation is performed. `SoftClip` and `RollOff` use `1.0 - threshold`
    /// as their compression range, so `threshold >= 1.0` degenerates them; with
    /// `threshold > 1.0`, values in `(1.0, threshold]` pass through out of gamut.
    #[must_use]
    pub fn new(method: GamutCompressionMethod, threshold: f64, strength: f64) -> Self {
        Self {
            method,
            threshold,
            strength,
        }
    }

    /// Soft-clip preset: [`GamutCompressionMethod::SoftClip`] with
    /// `threshold = 0.8` and `strength = 0.5`.
    #[must_use]
    pub fn soft_clip() -> Self {
        Self::new(GamutCompressionMethod::SoftClip, 0.8, 0.5)
    }

    /// Compresses a single RGB triple with the configured method.
    #[must_use]
    pub fn compress(&self, rgb: &Rgb) -> Rgb {
        match self.method {
            GamutCompressionMethod::SoftClip => self.soft_clip_compress(rgb),
            GamutCompressionMethod::Desaturate => self.desaturate_compress(rgb),
            GamutCompressionMethod::PreserveLightness => self.preserve_lightness_compress(rgb),
            GamutCompressionMethod::RollOff => self.rolloff_compress(rgb),
        }
    }

    /// Applies [`Self::soft_clip_channel`] to each channel independently.
    fn soft_clip_compress(&self, rgb: &Rgb) -> Rgb {
        [
            self.soft_clip_channel(rgb[0]),
            self.soft_clip_channel(rgb[1]),
            self.soft_clip_channel(rgb[2]),
        ]
    }

    /// Rational soft-clip of one channel.
    ///
    /// Values `<= threshold` are returned unchanged. Above it, the excess `e` is
    /// mapped to `threshold + range * e / (e + range * strength)` with
    /// `range = 1.0 - threshold`, which approaches 1.0 as `e` grows.
    ///
    /// NOTE(review): the slope just above `threshold` is `1 / strength`. So with
    /// `strength < 1`, values slightly above the threshold are pushed *up*
    /// before the curve flattens, and `strength == 0` is effectively a hard clip
    /// to 1.0.
    fn soft_clip_channel(&self, value: f64) -> f64 {
        if value <= self.threshold {
            return value;
        }
        let excess = value - self.threshold;
        let range = 1.0 - self.threshold;
        let t = excess / (excess + range * self.strength);
        self.threshold + t * range
    }

    /// Logistic curve centred at 0.5 with steepness 10.
    // No compression method in this impl calls this; the unit tests in this file
    // do, so the dead-code lint is silenced only outside test builds.
    #[cfg_attr(not(test), allow(dead_code))]
    fn sigmoid(&self, x: f64) -> f64 {
        1.0 / (1.0 + (-10.0 * (x - 0.5)).exp())
    }

    /// Pulls an over-range color toward the mean of its channels.
    ///
    /// If no channel exceeds 1.0, the input is returned unchanged (channels below
    /// 0.0 are *not* corrected). Otherwise each channel's offset from the mean is
    /// scaled by `(1 - strength) + strength / max`, and the result is clamped to `[0, 1]`.
    ///
    /// NOTE(review): the scale factor alone does not in general bring the largest
    /// channel down to 1.0 (even at `strength == 1`), so the final clamp still
    /// does part of the work. `threshold` is not consulted.
    fn desaturate_compress(&self, rgb: &Rgb) -> Rgb {
        let max_val = rgb[0].max(rgb[1]).max(rgb[2]);
        if max_val <= 1.0 {
            return *rgb;
        }
        // Plain channel mean, used as the neutral point colors are pulled toward.
        let lightness = (rgb[0] + rgb[1] + rgb[2]) / 3.0;
        let scale = (1.0 - self.strength) + self.strength * (1.0 / max_val);
        [
            (lightness + (rgb[0] - lightness) * scale).clamp(0.0, 1.0),
            (lightness + (rgb[1] - lightness) * scale).clamp(0.0, 1.0),
            (lightness + (rgb[2] - lightness) * scale).clamp(0.0, 1.0),
        ]
    }

    /// Currently identical to [`Self::desaturate_compress`].
    ///
    /// Because that method clamps to `[0, 1]`, the channel mean is not necessarily
    /// preserved (e.g. `[1.5, 1.5, 1.5]` becomes `[1.0, 1.0, 1.0]`).
    fn preserve_lightness_compress(&self, rgb: &Rgb) -> Rgb {
        self.desaturate_compress(rgb)
    }

    /// Applies [`Self::rolloff_channel`] to each channel independently.
    fn rolloff_compress(&self, rgb: &Rgb) -> Rgb {
        [
            self.rolloff_channel(rgb[0]),
            self.rolloff_channel(rgb[1]),
            self.rolloff_channel(rgb[2]),
        ]
    }

    /// Smoothstep remap of one channel.
    ///
    /// Values `<= threshold` pass through, and values `>= 1.0` are clipped to 1.0.
    /// In between, `t = (value - threshold) / (1 - threshold)` is replaced by
    /// `t^2 * (3 - 2t)`, which keeps both endpoints fixed.
    ///
    /// NOTE(review): smoothstep exceeds the identity for `t > 0.5`, so in-gamut
    /// values close to 1.0 are pushed *up* rather than compressed. Out-of-range
    /// values are hard-clipped, and `strength` is not consulted.
    fn rolloff_channel(&self, value: f64) -> f64 {
        if value <= self.threshold {
            value
        } else if value >= 1.0 {
            1.0
        } else {
            let t = (value - self.threshold) / (1.0 - self.threshold);
            let compressed = t * t * (3.0 - 2.0 * t);
            self.threshold + compressed * (1.0 - self.threshold)
        }
    }

    /// Applies [`Self::compress`] to an interleaved 8-bit RGB buffer.
    ///
    /// Each complete 3-byte group is treated as one `[r, g, b]` pixel: it is scaled
    /// to `[0, 1]`, compressed, then rounded back to `u8`. Trailing bytes that do
    /// not form a whole pixel (when `image_data.len()` is not a multiple of 3)
    /// are copied through unchanged. The output therefore always has the same
    /// length as the input.
    #[must_use]
    pub fn compress_image(&self, image_data: &[u8]) -> Vec<u8> {
        // `f64 as u8` saturates at 0/255 (and maps NaN to 0), which is exactly
        // the clamping wanted when quantizing back to 8 bits.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let to_byte = |channel: f64| (channel * 255.0).round() as u8;

        let pixels = image_data.chunks_exact(3);
        // Bytes after the last whole pixel; appended verbatim below instead of
        // being dropped, so the output length matches the input.
        let trailing = pixels.remainder();

        let mut output = Vec::with_capacity(image_data.len());
        for pixel in pixels {
            let rgb = [
                f64::from(pixel[0]) / 255.0,
                f64::from(pixel[1]) / 255.0,
                f64::from(pixel[2]) / 255.0,
            ];
            output.extend(self.compress(&rgb).iter().copied().map(to_byte));
        }
        output.extend_from_slice(trailing);
        output
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparisons of values that go through floating-point arithmetic.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_gamut_compression_new() {
        let comp = GamutCompression::new(GamutCompressionMethod::SoftClip, 0.8, 0.5);
        assert_eq!(comp.method, GamutCompressionMethod::SoftClip);
        assert!((comp.threshold - 0.8).abs() < EPS);
        assert!((comp.strength - 0.5).abs() < EPS);
    }

    #[test]
    fn test_soft_clip() {
        // Every channel is at or below the 0.8 threshold, so all pass through exactly.
        let comp = GamutCompression::soft_clip();
        assert_eq!(comp.compress(&[0.5, 0.6, 0.7]), [0.5, 0.6, 0.7]);
    }

    #[test]
    fn test_soft_clip_over_threshold() {
        let comp = GamutCompression::soft_clip();
        let result = comp.compress(&[1.5, 1.2, 1.0]);
        // Squeezed into (threshold, 1.0) while keeping the channel ordering.
        for v in result.iter() {
            assert!(*v > 0.8 && *v < 1.0);
        }
        assert!(result[0] > result[1] && result[1] > result[2]);
        // 1.0 -> 0.8 + 0.2 * 0.2 / (0.2 + 0.2 * 0.5) = 0.8 + 0.2 * 2/3
        assert!((result[2] - (0.8 + 0.2 * 2.0 / 3.0)).abs() < EPS);
    }

    #[test]
    fn test_desaturate() {
        let comp = GamutCompression::new(GamutCompressionMethod::Desaturate, 0.8, 0.8);
        let result = comp.compress(&[1.5, 0.8, 0.6]);
        assert!(result.iter().all(|v| (0.0..=1.0).contains(v)));
        // The over-range channel ends up clipped to the top of the gamut.
        assert!((result[0] - 1.0).abs() < EPS);
    }

    #[test]
    fn test_desaturate_in_gamut_is_identity() {
        let comp = GamutCompression::new(GamutCompressionMethod::Desaturate, 0.8, 0.8);
        assert_eq!(comp.compress(&[0.2, 0.4, 0.6]), [0.2, 0.4, 0.6]);
    }

    #[test]
    fn test_rolloff() {
        let comp = GamutCompression::new(GamutCompressionMethod::RollOff, 0.8, 0.5);
        let result = comp.compress(&[0.9, 0.85, 0.7]);
        // 0.9 sits at t = 0.5, where smoothstep equals the identity.
        assert!((result[0] - 0.9).abs() < EPS);
        // 0.85: t = 0.25 -> 0.25^2 * (3 - 0.5) = 0.15625 -> 0.8 + 0.15625 * 0.2
        assert!((result[1] - 0.83125).abs() < EPS);
        // Below the threshold: untouched.
        assert_eq!(result[2], 0.7);
    }

    #[test]
    fn test_rolloff_clips_at_one() {
        let comp = GamutCompression::new(GamutCompressionMethod::RollOff, 0.8, 0.5);
        assert_eq!(comp.compress(&[1.0, 1.7, 0.0]), [1.0, 1.0, 0.0]);
    }

    #[test]
    fn test_compress_image() {
        let comp = GamutCompression::soft_clip();
        let image = vec![128, 128, 128, 255, 0, 0];
        let output = comp.compress_image(&image);
        assert_eq!(output.len(), image.len());
        // Mid-grey is below the threshold and survives the u8 round trip;
        // a full 255 is pulled down to round(255 * (0.8 + 0.2 * 2/3)) = 238.
        assert_eq!(output, vec![128, 128, 128, 238, 0, 0]);
    }

    #[test]
    fn test_compress_image_empty() {
        assert!(GamutCompression::soft_clip().compress_image(&[]).is_empty());
    }

    #[test]
    fn test_sigmoid() {
        let comp = GamutCompression::soft_clip();
        // Centre of the logistic curve is exactly 0.5.
        assert_eq!(comp.sigmoid(0.5), 0.5);
        let low = comp.sigmoid(0.0);
        let high = comp.sigmoid(1.0);
        assert!(low < 0.1);
        assert!(high > 0.9);
        // The curve is point-symmetric about (0.5, 0.5).
        assert!((low + high - 1.0).abs() < EPS);
    }
}