use glam::{Vec3A, Vec4};
/// Newtype over [`Vec3A`] that forwards `+`, `+=`, `-` and scaling by an `f32` to the
/// wrapped vector.
///
/// NOTE(review): presumably tags a color expressed in some working color space, as opposed
/// to a plain vector — confirm against the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
#[repr(transparent)]
pub(crate) struct ColorSpace(pub Vec3A);

impl std::ops::Add for ColorSpace {
    type Output = Self;

    /// Adds the wrapped vectors.
    fn add(self, rhs: Self) -> Self {
        let (Self(left), Self(right)) = (self, rhs);
        Self(left + right)
    }
}

impl std::ops::AddAssign for ColorSpace {
    /// Adds `rhs` onto the wrapped vector in place.
    fn add_assign(&mut self, rhs: Self) {
        let Self(inner) = self;
        *inner += rhs.0;
    }
}

impl std::ops::Sub for ColorSpace {
    type Output = Self;

    /// Subtracts the wrapped vectors.
    fn sub(self, rhs: Self) -> Self {
        let (Self(left), Self(right)) = (self, rhs);
        Self(left - right)
    }
}

impl std::ops::Mul<f32> for ColorSpace {
    type Output = Self;

    /// Scales the wrapped vector by `rhs`.
    fn mul(self, rhs: f32) -> Self {
        let Self(inner) = self;
        Self(inner * rhs)
    }
}

impl From<ColorSpace> for Vec3A {
    /// Unwraps the inner vector.
    #[inline(always)]
    fn from(value: ColorSpace) -> Self {
        let ColorSpace(inner) = value;
        inner
    }
}
/// Dithers a 4x4 block by diffusing each pixel's quantization error onto pixels that have
/// not been visited yet.
///
/// Pixels are visited in row-major order (`index = y * 4 + x`). For every pixel,
/// `get_closest(index, value)` receives the source pixel plus the error accumulated for it
/// so far and must return the value that was actually chosen. The difference between the
/// two is then spread to the right neighbor and to the three neighbors in the next row.
///
/// Interior pixels use the weights 7/16 (right), 3/16 (below-left), 5/16 (below) and
/// 1/16 (below-right). Pixels on the block border use adjusted weights; these do not always
/// sum to one, so part of the error is dropped at the edges.
pub(crate) fn block_dither<T>(block: impl Block4x4<T>, mut get_closest: impl FnMut(usize, T) -> T)
where
    T: Copy
        + Default
        + std::ops::Add<Output = T>
        + std::ops::AddAssign
        + std::ops::Sub<Output = T>
        + std::ops::Mul<f32, Output = T>,
{
    // Error carried over to each pixel from the pixels processed before it.
    let mut carried: [T; 16] = Default::default();

    for index in 0..16usize {
        let (x, y) = (index % 4, index / 4);

        let wanted = block.get_pixel_at(index) + carried[index];
        let chosen = get_closest(index, wanted);
        let residual = wanted - chosen;

        // (right, below-left, below, below-right). The last row takes precedence over the
        // column rules because there is no next row to push error into.
        let (w_right, w_below_left, w_below, w_below_right): (f32, f32, f32, f32) = match (x, y) {
            (_, 3) => (8. / 16., 0.0, 0.0, 0.0),
            (0, _) => (7. / 16., 0.0, 6. / 16., 2. / 16.),
            (3, _) => (0.0, 5. / 16., 7. / 16., 0.0),
            _ => (7. / 16., 3. / 16., 5. / 16., 1. / 16.),
        };

        let has_right = x < 3;
        let has_left = x > 0;
        let has_next_row = y < 3;

        if has_right {
            carried[index + 1] += residual * w_right;
        }
        if has_next_row {
            if has_left {
                carried[index + 3] += residual * w_below_left;
            }
            carried[index + 4] += residual * w_below;
            if has_right {
                carried[index + 5] += residual * w_below_right;
            }
        }
    }
}

/// Source of the 16 pixels of a 4x4 block, addressed by a single index in `0..16`.
pub(crate) trait Block4x4<T> {
    /// Returns the pixel stored at `index`.
    fn get_pixel_at(&self, index: usize) -> T;
}

/// A block backed by an array of 16 individual pixels.
impl<T: Copy> Block4x4<T> for &[T; 16] {
    #[inline(always)]
    fn get_pixel_at(&self, index: usize) -> T {
        (**self)[index]
    }
}

/// A uniform block: a single value stands in for every pixel, the index is ignored.
impl<T: Copy> Block4x4<T> for T {
    #[inline(always)]
    fn get_pixel_at(&self, _index: usize) -> T {
        let uniform: T = *self;
        uniform
    }
}
/// Tuning knobs for [`refine_endpoints`].
#[derive(Debug, Clone)]
pub(crate) struct RefinementOptions {
    /// Step size used in the first iteration.
    pub step_initial: f32,
    /// Factor the step size is multiplied by after every iteration.
    pub step_decay: f32,
    /// The search stops as soon as the step size is no longer greater than this value.
    pub step_min: f32,
    /// Upper bound on the number of iterations.
    pub max_iter: u32,
}

/// Improves an endpoint pair with a simple local search.
///
/// Each iteration asks [`RefinementSteps::for_each_endpoint`] for candidate pairs around the
/// best pair known at the start of that iteration, and keeps whichever candidate yields a
/// strictly smaller `compute_error`. The step size is then multiplied by
/// `options.step_decay`. The search ends once the step size is no longer greater than
/// `options.step_min` or `options.max_iter` iterations have run.
///
/// Returns the best pair found. If not even one iteration would run, `(min, max)` is
/// returned without calling `compute_error` at all.
pub(crate) fn refine_endpoints<T: RefinementSteps, E: PartialOrd>(
    min: T,
    max: T,
    options: RefinementOptions,
    mut compute_error: impl FnMut((T, T)) -> E,
) -> (T, T) {
    let RefinementOptions {
        step_initial,
        step_decay,
        step_min,
        max_iter,
    } = options;

    let mut best = (min, max);
    let mut step = step_initial;

    // Written as a negated `>` so that a NaN step size also takes the early exit.
    if max_iter == 0 || !(step > step_min) {
        return best;
    }

    let mut best_error = compute_error(best);
    for _ in 0..max_iter {
        if !(step > step_min) {
            break;
        }

        // Candidates are generated around the best pair as of the start of this round.
        let origin = best;
        T::for_each_endpoint(origin, step, |candidate| {
            let candidate_error = compute_error(candidate);
            if candidate_error < best_error {
                best_error = candidate_error;
                best = candidate;
            }
        });

        step *= step_decay;
    }
    best
}

/// Endpoint types that can enumerate neighboring endpoint pairs for [`refine_endpoints`].
pub(crate) trait RefinementSteps
where
    Self: Copy + Sized,
{
    /// Calls `f` once for every candidate pair derived from `start` using a step of size `step`.
    fn for_each_endpoint(start: (Self, Self), step: f32, f: impl FnMut((Self, Self)));
}

impl RefinementSteps for f32 {
    /// Moves one endpoint at a time by `+step` or `-step`, clamps the result to `[0, 1]`
    /// and reports only the pairs whose first value is strictly smaller than the second.
    fn for_each_endpoint((min, max): (f32, f32), step: f32, mut f: impl FnMut((f32, f32))) {
        [(step, 0.0), (0.0, step), (-step, 0.0), (0.0, -step)]
            .iter()
            .map(|&(shift_min, shift_max)| {
                (
                    (min + shift_min).clamp(0.0, 1.0),
                    (max + shift_max).clamp(0.0, 1.0),
                )
            })
            .filter(|&(lo, hi)| lo < hi)
            .for_each(&mut f);
    }
}
impl RefinementSteps for Vec3A {
    /// Reports twelve candidate pairs: first `min` is moved while `max` stays fixed, then
    /// `max` is moved while `min` stays fixed.
    ///
    /// The six offsets are `+step` and `-step` along three mutually orthogonal unit
    /// directions. The first direction points from `max` to `min` (falling back to the X
    /// axis when the two coincide), and the other two complete an orthonormal basis. Every
    /// moved endpoint is clamped to the unit cube.
    fn for_each_endpoint((min, max): (Vec3A, Vec3A), step: f32, mut f: impl FnMut((Vec3A, Vec3A))) {
        let along = (min - max).try_normalize().unwrap_or(Vec3A::X);
        let (across_a, across_b) = along.any_orthonormal_pair();
        let basis = [along, across_a, across_b];

        // Positive offsets occupy slots 0..3, the matching negative offsets slots 3..6.
        let mut offsets = [Vec3A::ZERO; 6];
        for (slot, &axis) in basis.iter().enumerate() {
            offsets[slot] = axis * step;
            offsets[slot + 3] = axis * -step;
        }

        let shifted = |point: Vec3A, offset: Vec3A| (point + offset).clamp(Vec3A::ZERO, Vec3A::ONE);

        for &offset in &offsets {
            f((shifted(min, offset), max));
        }
        for &offset in &offsets {
            f((min, shifted(max, offset)));
        }
    }
}
impl RefinementSteps for ColorSpace {
fn for_each_endpoint(
(min, max): (ColorSpace, ColorSpace),
step: f32,
mut f: impl FnMut((ColorSpace, ColorSpace)),
) {
Vec3A::for_each_endpoint((min.0, max.0), step, move |(min, max)| {
f((ColorSpace(min), ColorSpace(max)));
});
}
}
/// Fits a [`ColorLine3`] to `colors` and returns two endpoints on that line.
///
/// The colors are projected onto the line to find the smallest and largest line parameter.
/// That interval is then scaled around its midpoint by `nudge_factor` (`1.0` keeps it
/// unchanged, smaller values pull the endpoints together) and the two resulting points on
/// the line are returned as `(min, max)`.
///
/// `colors` must not be empty; this is only checked in debug builds.
pub(crate) fn line3_fit_endpoints<C: Copy + Into<Vec3A>>(
    colors: &[C],
    nudge_factor: f32,
) -> (Vec3A, Vec3A) {
    debug_assert!(!colors.is_empty());
    let line = ColorLine3::new(colors);

    // Extent of the colors along the line.
    let (lowest, highest) = colors
        .iter()
        .map(|&color| line.project(color.into()))
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), t| {
            (lo.min(t), hi.max(t))
        });

    let center = (lowest + highest) * 0.5;
    let nudged = |t: f32| center + (t - center) * nudge_factor;

    let start = nudged(lowest);
    let end = nudged(highest);
    (line.at(start), line.at(end))
}
/// Fits a [`ColorLine4`] to `colors` and returns two endpoints on that line.
///
/// The colors are projected onto the line to find the smallest and largest line parameter.
/// That interval is then scaled around its midpoint by `nudge_factor` (`1.0` keeps it
/// unchanged, smaller values pull the endpoints together) and the two resulting points on
/// the line are returned as `(min, max)`.
///
/// `colors` must not be empty; this is only checked in debug builds.
pub(crate) fn line4_fit_endpoints<C: Copy + Into<Vec4>>(
    colors: &[C],
    nudge_factor: f32,
) -> (Vec4, Vec4) {
    debug_assert!(!colors.is_empty());
    let line = ColorLine4::new(colors);

    // Extent of the colors along the line.
    let (lowest, highest) = colors
        .iter()
        .map(|&color| line.project(color.into()))
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), t| {
            (lo.min(t), hi.max(t))
        });

    let center = (lowest + highest) * 0.5;
    let nudged = |t: f32| center + (t - center) * nudge_factor;

    let start = nudged(lowest);
    let end = nudged(highest);
    (line.at(start), line.at(end))
}
/// A straight line through 3-channel color space, fitted to a set of colors.
///
/// The line passes through `centroid` along direction `d`; points on it are addressed by a
/// scalar parameter `t` (see [`ColorLine3::at`] and [`ColorLine3::project`]).
pub(crate) struct ColorLine3 {
    /// Arithmetic mean of the colors passed to [`ColorLine3::new`].
    pub centroid: Vec3A,
    /// Estimated dominant direction of the colors: unit length, or zero when no direction
    /// could be determined (for example when all colors are identical).
    d: Vec3A,
}

impl ColorLine3 {
    /// Fits a line to `colors`.
    ///
    /// The line goes through the mean of the colors and follows an estimate of the dominant
    /// eigenvector of their covariance matrix, obtained with two rounds of power iteration.
    ///
    /// `colors` must not be empty; this is only checked in debug builds.
    pub fn new<C: Copy + Into<Vec3A>>(colors: &[C]) -> Self {
        /// Arithmetic mean of all colors.
        fn mean<C: Copy + Into<Vec3A>>(colors: &[C]) -> Vec3A {
            let mut mean = Vec3A::ZERO;
            for &color in colors {
                let color: Vec3A = color.into();
                mean += color;
            }
            mean * (1. / colors.len() as f32)
        }

        /// Covariance matrix of the colors around `centroid`, one vector per row.
        fn covariance_matrix<C: Copy + Into<Vec3A>>(colors: &[C], centroid: Vec3A) -> [Vec3A; 3] {
            let mut cov = [Vec3A::ZERO; 3];
            for &p in colors {
                let p: Vec3A = p.into();
                let d = p - centroid;
                cov[0] += d * d.x;
                cov[1] += d * d.y;
                cov[2] += d * d.z;
            }
            let n_r = 1.0 / colors.len() as f32;
            cov[0] *= n_r;
            cov[1] *= n_r;
            cov[2] *= n_r;
            cov
        }

        /// Two rounds of power iteration starting at `seed`. Returns a unit vector, or zero
        /// if the iterate collapsed to zero along the way.
        fn power_iteration(matrix: &[Vec3A; 3], seed: Vec3A) -> Vec3A {
            let mut v = seed;
            for _ in 0..2 {
                let r = matrix[0].dot(v);
                let g = matrix[1].dot(v);
                let b = matrix[2].dot(v);
                v = Vec3A::new(r, g, b).normalize_or_zero();
            }
            v
        }

        /// Approximates the eigenvector belonging to the largest eigenvalue of `matrix`.
        ///
        /// The all-ones seed is tried first. If the matrix maps it to zero (the colors are
        /// spread exactly along a direction whose components sum to zero, e.g. a pure
        /// red-to-green ramp), the iteration would yield a zero direction even though the
        /// colors clearly have one, so the axis seeds are tried next. A non-zero matrix
        /// cannot map all three axes to zero; the result is therefore zero only when the
        /// matrix itself is (numerically) zero.
        fn largest_eigenvector(matrix: [Vec3A; 3]) -> Vec3A {
            for seed in [Vec3A::ONE, Vec3A::X, Vec3A::Y, Vec3A::Z] {
                let v = power_iteration(&matrix, seed);
                if v != Vec3A::ZERO {
                    return v;
                }
            }
            Vec3A::ZERO
        }

        debug_assert!(!colors.is_empty());
        let centroid = mean(colors);
        let covariance = covariance_matrix(colors, centroid);
        let eigenvector = largest_eigenvector(covariance);
        Self {
            centroid,
            d: eigenvector,
        }
    }

    /// Point on the line at parameter `t`; `t == 0` is the centroid.
    pub fn at(&self, t: f32) -> Vec3A {
        self.centroid + self.d * t
    }

    /// Line parameter of the orthogonal projection of `color` onto the line.
    pub fn project(&self, color: Vec3A) -> f32 {
        let diff = color - self.centroid;
        diff.dot(self.d)
    }

    /// Sum of the squared distances of `colors` from the line.
    pub fn sum_dist_sq(&self, colors: &[Vec3A]) -> f32 {
        let mut sum = Vec3A::ZERO;
        for &color in colors {
            let diff = color - self.centroid;
            let t = self.d.dot(diff);
            // Component of `diff` that is perpendicular to the line.
            let dist = diff - self.d * t;
            sum += dist * dist;
        }
        sum.x + sum.y + sum.z
    }
}
/// A straight line through 4-channel color space, fitted to a set of colors.
///
/// The line passes through `centroid` along direction `d`; points on it are addressed by a
/// scalar parameter `t` (see [`ColorLine4::at`] and [`ColorLine4::project`]).
pub(crate) struct ColorLine4 {
    /// Arithmetic mean of the colors passed to [`ColorLine4::new`].
    pub centroid: Vec4,
    /// Estimated dominant direction of the colors: unit length, or zero when no direction
    /// could be determined (for example when all colors are identical).
    d: Vec4,
}

impl ColorLine4 {
    /// Fits a line to `colors`.
    ///
    /// The line goes through the mean of the colors and follows an estimate of the dominant
    /// eigenvector of their covariance matrix, obtained with two rounds of power iteration.
    ///
    /// `colors` must not be empty; this is only checked in debug builds.
    pub fn new<C: Copy + Into<Vec4>>(colors: &[C]) -> Self {
        /// Arithmetic mean of all colors.
        fn mean<C: Copy + Into<Vec4>>(colors: &[C]) -> Vec4 {
            let mut mean = Vec4::ZERO;
            for &color in colors {
                let color: Vec4 = color.into();
                mean += color;
            }
            mean * (1. / colors.len() as f32)
        }

        /// Covariance matrix of the colors around `centroid`, one vector per row.
        fn covariance_matrix<C: Copy + Into<Vec4>>(colors: &[C], centroid: Vec4) -> [Vec4; 4] {
            let mut cov = [Vec4::ZERO; 4];
            for &p in colors {
                let p: Vec4 = p.into();
                let d = p - centroid;
                cov[0] += d * d.x;
                cov[1] += d * d.y;
                cov[2] += d * d.z;
                cov[3] += d * d.w;
            }
            let n_r = 1.0 / colors.len() as f32;
            cov[0] *= n_r;
            cov[1] *= n_r;
            cov[2] *= n_r;
            cov[3] *= n_r;
            cov
        }

        /// Two rounds of power iteration starting at `seed`. Returns a unit vector, or zero
        /// if the iterate collapsed to zero along the way.
        fn power_iteration(matrix: &[Vec4; 4], seed: Vec4) -> Vec4 {
            let mut v = seed;
            for _ in 0..2 {
                let r = matrix[0].dot(v);
                let g = matrix[1].dot(v);
                let b = matrix[2].dot(v);
                let a = matrix[3].dot(v);
                v = Vec4::new(r, g, b, a).normalize_or_zero();
            }
            v
        }

        /// Approximates the eigenvector belonging to the largest eigenvalue of `matrix`.
        ///
        /// The all-ones seed is tried first. If the matrix maps it to zero (the colors are
        /// spread exactly along a direction whose components sum to zero), the iteration
        /// would yield a zero direction even though the colors clearly have one, so the
        /// axis seeds are tried next. A non-zero matrix cannot map all four axes to zero;
        /// the result is therefore zero only when the matrix itself is (numerically) zero.
        fn largest_eigenvector(matrix: [Vec4; 4]) -> Vec4 {
            for seed in [Vec4::ONE, Vec4::X, Vec4::Y, Vec4::Z, Vec4::W] {
                let v = power_iteration(&matrix, seed);
                if v != Vec4::ZERO {
                    return v;
                }
            }
            Vec4::ZERO
        }

        debug_assert!(!colors.is_empty());
        let centroid = mean(colors);
        let covariance = covariance_matrix(colors, centroid);
        let eigenvector = largest_eigenvector(covariance);
        Self {
            centroid,
            d: eigenvector,
        }
    }

    /// Point on the line at parameter `t`; `t == 0` is the centroid.
    pub fn at(&self, t: f32) -> Vec4 {
        self.centroid + self.d * t
    }

    /// Line parameter of the orthogonal projection of `color` onto the line.
    pub fn project(&self, color: Vec4) -> f32 {
        let diff = color - self.centroid;
        diff.dot(self.d)
    }

    /// Sum of the squared distances of `colors` from the line.
    pub fn sum_dist_sq(&self, colors: &[Vec4]) -> f32 {
        let mut sum = Vec4::ZERO;
        for &color in colors {
            let diff = color - self.centroid;
            let t = self.d.dot(diff);
            // Component of `diff` that is perpendicular to the line.
            let dist = diff - self.d * t;
            sum += dist * dist;
        }
        (sum.x + sum.y) + (sum.z + sum.w)
    }
}
/// Coefficients that turn an interpolation weight `w` into the two per-color factors used
/// by the least-squares endpoint solvers:
///
/// * factor for the first endpoint: `e01 + e00_01 * w`
/// * factor for the second endpoint: `e11 + e10_11 * w`
///
/// `e00`, `e01` and `e11` are the entries of the inverse of the symmetric 2x2 matrix
/// `[[a, b], [b, c]]`; `e00_01` stores `e00 - e01` and `e10_11` stores `e01 - e11`.
struct LeastSquaresWeightMatrix {
    pub e01: f32,
    pub e11: f32,
    pub e00_01: f32,
    pub e10_11: f32,
}

impl LeastSquaresWeightMatrix {
    /// Fallback coefficients that give every color the factor `1 / color_count`, so both
    /// endpoints come out as the plain mean of the colors.
    fn mean(color_count: usize) -> Self {
        let reciprocal = 1.0 / (color_count as f32);
        Self {
            e01: reciprocal,
            e11: reciprocal,
            e00_01: 0.0,
            e10_11: 0.0,
        }
    }

    /// Builds the coefficients from the matrix `[[a, b], [b, c]]`.
    ///
    /// Returns `None` when the absolute determinant is below `f32::EPSILON`, i.e. the
    /// matrix is treated as not invertible.
    fn from_d(a: f32, b: f32, c: f32) -> Option<Self> {
        let determinant = a * c - b * b;
        if determinant.abs() < f32::EPSILON {
            return None;
        }

        let inv_det = 1.0 / determinant;
        let e00 = c * inv_det;
        let e01 = -b * inv_det;
        let e11 = a * inv_det;
        Some(Self {
            e01,
            e11,
            e00_01: e00 - e01,
            e10_11: e01 - e11,
        })
    }
}

/// Least-squares fit of two endpoints to `colors`, given each color's interpolation weight.
///
/// Finds `(x0, x1)` minimizing the squared error of approximating `colors[i]` by
/// `weights[i] * x0 + (1 - weights[i]) * x1`. When the system is singular (for example
/// when all weights are equal), both endpoints are set to the mean of the colors.
///
/// # Panics
///
/// Panics if `colors` and `weights` differ in length.
pub(crate) fn least_squares_weights<
    R: Copy + Default + std::ops::Mul<f32, Output = R> + std::ops::AddAssign,
    C: Copy + Into<R>,
>(
    colors: &[C],
    weights: &[f32],
) -> (R, R) {
    assert_eq!(weights.len(), colors.len());

    // Entries of the normal-equation matrix: sums of w², w·(1-w) and (1-w)².
    let (a, b, c) = weights
        .iter()
        .fold((0.0f32, 0.0f32, 0.0f32), |(a, b, c), &w| {
            let w_inv = 1.0 - w;
            (a + w * w, b + w * w_inv, c + w_inv * w_inv)
        });

    let matrix = LeastSquaresWeightMatrix::from_d(a, b, c)
        .unwrap_or_else(|| LeastSquaresWeightMatrix::mean(weights.len()));

    let mut x0 = R::default();
    let mut x1 = R::default();
    for (&color, &w) in colors.iter().zip(weights) {
        let color: R = color.into();
        let factor0 = matrix.e01 + matrix.e00_01 * w;
        let factor1 = matrix.e11 + matrix.e10_11 * w;
        x0 += color * factor0;
        x1 += color * factor1;
    }
    (x0, x1)
}
/// Variant of the scalar least-squares endpoint fit for exactly 16 samples that are packed
/// into four [`Vec4`]s; lane `k` of `colors[i]` is paired with lane `k` of `weights[i]`.
///
/// Returns the two fitted endpoints. When the system is singular, both are the mean of the
/// 16 samples.
pub(crate) fn least_squares_weights_f32_vec4(
    colors: &[Vec4; 4],
    weights: &[Vec4; 4],
) -> (f32, f32) {
    // Horizontal sum with a fixed pairing so the rounding behavior stays stable.
    let lane_sum = |v: Vec4| -> f32 { (v.x + v.y) + (v.z + v.w) };

    let [w0, w1, w2, w3] = *weights;
    let (inv0, inv1, inv2, inv3) = (1.0 - w0, 1.0 - w1, 1.0 - w2, 1.0 - w3);

    // Entries of the normal-equation matrix: sums of w², w·(1-w) and (1-w)².
    let a = lane_sum(w0 * w0 + w1 * w1 + w2 * w2 + w3 * w3);
    let b = lane_sum(w0 * inv0 + w1 * inv1 + w2 * inv2 + w3 * inv3);
    let c = lane_sum(inv0 * inv0 + inv1 * inv1 + inv2 * inv2 + inv3 * inv3);

    let matrix = LeastSquaresWeightMatrix::from_d(a, b, c)
        .unwrap_or_else(|| LeastSquaresWeightMatrix::mean(16));

    let [c0, c1, c2, c3] = *colors;
    // Weighted sum of all 16 samples, each scaled by `base + slope * weight`.
    let solve = |base: f32, slope: f32| -> f32 {
        lane_sum(
            (c0 * (base + slope * w0))
                + (c1 * (base + slope * w1))
                + (c2 * (base + slope * w2))
                + (c3 * (base + slope * w3)),
        )
    };

    let x0 = solve(matrix.e01, matrix.e00_01);
    let x1 = solve(matrix.e11, matrix.e10_11);
    (x0, x1)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With identical weights the system is singular, so both endpoints must fall back to
    /// the mean of the colors.
    #[test]
    fn test_least_square_weights_all_the_same() {
        const TOLERANCE: f32 = 1e-6;
        const EXPECTED_MEAN: f32 = 2.5;

        let colors = [1.0f32, 2.0, 3.0, 4.0];
        let weights = [0.5f32; 4];

        let (min, max): (f32, f32) = least_squares_weights(&colors[..], &weights[..]);

        assert!((min - max).abs() < TOLERANCE);
        assert!((max - EXPECTED_MEAN).abs() < TOLERANCE);
    }
}
/// A quantized value with `u8` channels that is paired with a floating-point vector type.
///
/// NOTE(review): no implementation is visible in this part of the file; the rounding
/// behavior described below is inferred from the method names — confirm against the impls.
pub(crate) trait Quantized: Copy + Sized + WithChannels<E = u8> {
/// Floating-point vector type (with `f32` channels) that this type is quantized from.
type V: VectorType + WithChannels<E = f32>;
/// Quantizes `v`, presumably to the nearest representable value.
fn round(v: Self::V) -> Self;
/// Quantizes `v`, presumably rounding every channel down.
fn floor(v: Self::V) -> Self;
/// Quantizes `v`, presumably rounding every channel up.
fn ceil(v: Self::V) -> Self;
/// Converts the quantized value back into the vector type.
fn to_vec(self) -> Self::V;
}
/// Marker trait for the floating-point types usable as [`Quantized::V`]: copyable values
/// that can be subtracted from one another.
pub(crate) trait VectorType: Copy + Sized + std::ops::Sub<Output = Self> {}
// The scalar and glam vector types that qualify.
impl VectorType for f32 {}
impl VectorType for glam::Vec3A {}
impl VectorType for glam::Vec4 {}
/// Uniform, index-based access to the channels of a scalar or vector value.
pub(crate) trait WithChannels {
    /// Type of a single channel.
    type E;
    /// Number of channels; valid channel indices are `0..CHANNELS`.
    const CHANNELS: usize;
    /// Returns the value of the given channel.
    fn get(&self, channel: usize) -> Self::E;
    /// Overwrites the given channel with `value`.
    fn set(&mut self, channel: usize, value: Self::E);
}

/// A scalar is its own single channel. The channel index is only checked in debug builds.
impl WithChannels for f32 {
    type E = f32;
    const CHANNELS: usize = 1;

    #[inline(always)]
    fn get(&self, channel: usize) -> f32 {
        debug_assert!(channel == 0);
        let scalar: f32 = *self;
        scalar
    }

    #[inline(always)]
    fn set(&mut self, channel: usize, value: f32) {
        debug_assert!(channel == 0);
        let scalar: &mut f32 = self;
        *scalar = value;
    }
}
/// Channels map directly onto the vector's components via its index operator.
impl WithChannels for glam::Vec3A {
    type E = f32;
    const CHANNELS: usize = 3;

    #[inline(always)]
    fn get(&self, channel: usize) -> f32 {
        let component: f32 = self[channel];
        component
    }

    #[inline(always)]
    fn set(&mut self, channel: usize, value: f32) {
        let component: &mut f32 = &mut self[channel];
        *component = value;
    }
}
/// Channels map directly onto the vector's components via its index operator.
impl WithChannels for glam::Vec4 {
    type E = f32;
    const CHANNELS: usize = 4;

    #[inline(always)]
    fn get(&self, channel: usize) -> f32 {
        let component: f32 = self[channel];
        component
    }

    #[inline(always)]
    fn set(&mut self, channel: usize, value: f32) {
        let component: &mut f32 = &mut self[channel];
        *component = value;
    }
}
/// Strategy used by [`Quantization::pick_best`] to turn two floating-point endpoints into
/// quantized ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Quantization {
    /// Round both endpoints and use the result as is.
    Round,
    /// Start from the rounded endpoints, then search every channel over the full
    /// floor-to-ceil range of both endpoints.
    ChannelWise,
    /// Like `ChannelWise`, but the search range of a channel collapses to a single value
    /// when the endpoint already lies close to its floor or its ceil.
    ChannelWiseOptimized,
}

impl Quantization {
    /// Quantizes both endpoints with `Q::round`.
    pub fn round<Q: Quantized>(c0: Q::V, c1: Q::V) -> (Q, Q) {
        let first = Q::round(c0);
        let second = Q::round(c1);
        (first, second)
    }

    /// Quantizes the endpoints so that, per channel, they are pushed away from each other:
    /// the smaller value uses its floor and the larger value uses its ceil.
    pub fn wide<Q: Quantized>(c0: Q::V, c1: Q::V) -> (Q, Q) {
        let (c0_down, c0_up) = (Q::floor(c0), Q::ceil(c0));
        let (c1_down, c1_up) = (Q::floor(c1), Q::ceil(c1));

        // Assume `c0 <= c1` in every channel, then patch the channels where that is false.
        let mut q0 = c0_down;
        let mut q1 = c1_up;
        for channel in 0..Q::V::CHANNELS {
            let reversed = c0.get(channel) > c1.get(channel);
            if reversed {
                q0.set(channel, c0_up.get(channel));
                q1.set(channel, c1_down.get(channel));
            }
        }
        (q0, q1)
    }

    /// Quantizes `c0` and `c1` following the strategy in `self`, using `error_metric` to
    /// compare candidate endpoint pairs (smaller is better).
    ///
    /// For the channel-wise strategies the channels are optimized one after another. For
    /// each channel, every combination of values from the two endpoints' ranges is tried
    /// on top of the best pair found so far, and a candidate replaces the best pair only
    /// when its error is strictly smaller.
    pub fn pick_best<Q: Quantized, E: PartialOrd>(
        self,
        c0: Q::V,
        c1: Q::V,
        mut error_metric: impl FnMut(Q, Q) -> E,
    ) -> (Q, Q) {
        let mut best: (Q, Q) = Self::round(c0, c1);
        if matches!(self, Quantization::Round) {
            return best;
        }
        let mut best_error = error_metric(best.0, best.1);

        let range_of: fn(Q::V) -> (Q, Q) = match self {
            Quantization::ChannelWiseOptimized => Self::optimized_range::<Q>,
            // `Round` has already returned above; it is listed only for exhaustiveness.
            Quantization::ChannelWise | Quantization::Round => Self::full_range::<Q>,
        };
        let (lo0, hi0) = range_of(c0);
        let (lo1, hi1) = range_of(c1);

        for channel in 0..Q::CHANNELS {
            // The combination that is the best pair on entry to this channel has already
            // been scored, so it is not evaluated again.
            let known = (best.0.get(channel), best.1.get(channel));

            for value0 in lo0.get(channel)..=hi0.get(channel) {
                for value1 in lo1.get(channel)..=hi1.get(channel) {
                    if (value0, value1) == known {
                        continue;
                    }

                    let mut candidate = best;
                    candidate.0.set(channel, value0);
                    candidate.1.set(channel, value1);

                    let candidate_error = error_metric(candidate.0, candidate.1);
                    if candidate_error < best_error {
                        best = candidate;
                        best_error = candidate_error;
                    }
                }
            }
        }
        best
    }

    /// Search range spanning from the floor to the ceil of `c` in every channel.
    fn full_range<Q: Quantized>(c: Q::V) -> (Q, Q) {
        let lower = Q::floor(c);
        let upper = Q::ceil(c);
        (lower, upper)
    }

    /// Fraction of the floor-to-ceil distance within which a value counts as "close enough"
    /// to one end for the other end to be dropped from the search.
    const CULL_THRESHOLD: f32 = 0.25;

    /// Like [`Self::full_range`], but per channel the range collapses onto the floor when
    /// `c` lies within the first `CULL_THRESHOLD` of the floor-to-ceil distance, and onto
    /// the ceil when it lies within the last `CULL_THRESHOLD` of it.
    fn optimized_range<Q: Quantized>(c: Q::V) -> (Q, Q) {
        const NEAR_FLOOR: f32 = Quantization::CULL_THRESHOLD;
        const NEAR_CEIL: f32 = 1.0 - Quantization::CULL_THRESHOLD;

        let mut lower = Q::floor(c);
        let mut upper = Q::ceil(c);

        let lower_vec = lower.to_vec();
        let upper_vec = upper.to_vec();
        let above_floor = c - lower_vec;
        let span = upper_vec - lower_vec;

        for channel in 0..Q::V::CHANNELS {
            let offset: f32 = above_floor.get(channel);
            let width: f32 = span.get(channel);

            if offset < NEAR_FLOOR * width {
                upper.set(channel, lower.get(channel));
            } else if offset > NEAR_CEIL * width {
                lower.set(channel, upper.get(channel));
            }
        }
        (lower, upper)
    }
}