use glam::{UVec4, Vec4};
use crate::{n8, s8, util::unlikely_branch};
use super::bcn_util::{self, Block4x4};
/// Settings controlling how [`compress_bc4_block`] encodes a block.
#[derive(Debug, Clone, Copy)]
pub(crate) struct Bc4Options {
    /// Pick palette indexes with `bcn_util::block_dither` instead of nearest-entry selection.
    /// Also disables the `brute_force` path.
    pub dither: bool,
    /// Encode endpoints as signed bytes (`s8`) instead of unsigned bytes (`n8`).
    /// Also disables the `brute_force` path.
    pub snorm: bool,
    /// Exhaustively search unorm endpoint pairs around the block's range
    /// (`reference_brute_force`). Only honoured when both `dither` and `snorm` are off.
    pub brute_force: bool,
    /// Also try the inter4 mode (4 interpolated values plus fixed 0.0 and 1.0) for non-flat
    /// blocks and keep whichever of inter4/inter6 has the lower error. When false, non-flat
    /// blocks are encoded with inter6 only (single-color blocks are handled separately).
    pub use_inter4: bool,
    /// When `use_inter4` is set, skip inter4 for blocks whose value range, widened by 1/7 of
    /// itself on both sides, still lies strictly inside (0, 1).
    pub use_inter4_heuristic: bool,
    /// How refined float endpoints are snapped to representable values (see `refine_endpoints`).
    pub quantization: Bc4Quantization,
    /// Use the faster refinement schedule: initial step 10% of the range (instead of 15%) and a
    /// minimum step of 1/255 (instead of 1/510).
    pub fast_iter: bool,
    /// Maximum number of endpoint-refinement iterations; 0 disables refinement entirely.
    pub max_refine_iter: u8,
    /// Additionally try an inter6 endpoint range widened by 25% on one side and keep it if its
    /// error is lower.
    pub size_variations: bool,
}
/// Strategy used by `refine_endpoints` to turn float endpoints into representable ones.
#[derive(Debug, Clone, Copy)]
pub(crate) enum Bc4Quantization {
    /// Return the float endpoints unchanged; they are rounded to the nearest representable
    /// value when the final `EndPoints` are built.
    Round,
    /// Evaluate the rounded pair plus one alternative in which a single endpoint is shifted by
    /// one quantization step before rounding; keep whichever has the lower error.
    MediumQuality,
    /// Evaluate the rounded pair plus three alternatives (`min` moved up, `max` moved down, or
    /// both, by one quantization step); keep whichever has the lowest error.
    HighQuality,
}
/// A 4x4 block of single-channel pixel values stored as four `Vec4`s.
///
/// `b[i][j]` holds pixel `4 * i + j`. When built through `Block::from_raw` (the only
/// constructor here), every value is clamped to `[0, 1]`.
struct Block {
    b: [Vec4; 4],
}
impl Block {
    /// Packs 16 pixel values into four `Vec4`s (pixel `4 * i + j` goes to `b[i][j]`),
    /// clamping each value to `[0, 1]`.
    fn from_raw(block: [f32; 16]) -> Self {
        let row = |i: usize| {
            let base = 4 * i;
            Vec4::new(block[base], block[base + 1], block[base + 2], block[base + 3])
                .clamp(Vec4::ZERO, Vec4::ONE)
        };
        Self {
            b: [row(0), row(1), row(2), row(3)],
        }
    }

    /// Returns the smallest and largest pixel value in the block.
    fn min_max(&self) -> (f32, f32) {
        let [first, rest @ ..] = self.b;
        let (lo, hi) = rest
            .iter()
            .fold((first, first), |(lo, hi), &v| (lo.min(v), hi.max(v)));
        (lo.min_element(), hi.max_element())
    }

    /// Like [`Block::min_max`], but pixels below `threshold` do not count towards the minimum
    /// and pixels above `1 - threshold` do not count towards the maximum.
    ///
    /// If every pixel is excluded from one side, that side falls back to 1.0 (minimum) or
    /// 0.0 (maximum).
    fn min_max_with_threshold(&self, threshold: f32) -> (f32, f32) {
        let floor = Vec4::splat(threshold);
        let ceiling = Vec4::splat(1.0 - threshold);
        let (lo, hi) = self
            .b
            .iter()
            .fold((Vec4::ONE, Vec4::ZERO), |(lo, hi), &v| {
                // Excluded lanes are replaced by a neutral value that can never win.
                let lo_candidate = Vec4::select(v.cmpge(floor), v, Vec4::ONE);
                let hi_candidate = Vec4::select(v.cmple(ceiling), v, Vec4::ZERO);
                (lo.min(lo_candidate), hi.max(hi_candidate))
            });
        (lo.min_element(), hi.max_element())
    }
}
impl Block4x4<f32> for &Block {
    /// Returns pixel `index` (0..16), stored at `b[index / 4][index % 4]`.
    #[inline(always)]
    fn get_pixel_at(&self, index: usize) -> f32 {
        debug_assert!(index < 16);
        // Masking keeps both lookups within 0..4 even if the debug assertion is compiled out.
        let row = (index >> 2) & 3;
        let lane = index & 3;
        self.b[row][lane]
    }
}
/// Threshold used when choosing the inter4 endpoint range (`compress_inter4`): pixels below this
/// value or above `1 - BC4_MIN_VALUE` are ignored, presumably because the inter4 palette's fixed
/// 0.0 / 1.0 entries already cover them. NOTE(review): the choice of `1 / (255 * 7)` is not
/// explained in this file.
const BC4_MIN_VALUE: f32 = 1. / (255. * 7.);
/// Tolerance below which a block's value range is treated as a single flat value
/// (`compress_bc4_block`), and below which `single_color` treats a value as exactly
/// representable by an endpoint.
const BC4_EPSILON: f32 = 1. / (65536.);
/// Compresses one 4x4 block of single-channel values into an 8-byte BC4 block.
///
/// Input values are clamped to `[0, 1]`. The brute-force path is used when requested and
/// supported. Otherwise, nearly-flat blocks take the single-color path. All remaining blocks are
/// encoded in inter6 mode and, unless disabled by `use_inter4` / `use_inter4_heuristic`, also in
/// inter4 mode, keeping the encoding with the lower error.
pub(crate) fn compress_bc4_block(block: [f32; 16], options: Bc4Options) -> [u8; 8] {
    let block = Block::from_raw(block);
    let (lo, hi) = block.min_max();
    let spread = hi - lo;

    // The exhaustive search only supports plain (non-dithered) unorm output.
    if options.brute_force && !(options.dither || options.snorm) {
        return reference_brute_force(block);
    }
    if spread < BC4_EPSILON {
        return single_color((lo + hi) * 0.5, options);
    }

    // If the range stays clear of 0 and 1 even after widening it by 1/7 on each side,
    // the fixed 0.0/1.0 entries of inter4 are unlikely to help.
    const INTER6_THRESHOLD: f32 = 1. / 7.;
    let fits_inside =
        0. < lo - spread * INTER6_THRESHOLD && hi + spread * INTER6_THRESHOLD < 1.;
    let skip_inter4 = !options.use_inter4 || (options.use_inter4_heuristic && fits_inside);

    let (inter6, error6) = compress_inter6(&block, lo, hi, options);
    if skip_inter4 {
        return inter6;
    }
    let (inter4, error4) = compress_inter4(&block, options);
    if error6 < error4 {
        inter6
    } else {
        inter4
    }
}
/// Exhaustively searches unorm endpoint pairs `c0 > c1` around the block's value range.
///
/// Each pair is tried both as an inter6 encoding and, with the endpoints swapped, as an inter4
/// encoding. Returns the encoding with the lowest squared error.
fn reference_brute_force(block: Block) -> [u8; 8] {
    let (block_min, block_max) = block.min_max();
    // Search window: `c1` stays below ~`block_max + 2/255`, `c0` starts at ~`block_min - 1/255`
    // (both casts saturate into 0..=255).
    let c1_end = (block_max * 255. + 2.) as u8;
    let c0_start = (block_min * 255. - 1.) as u8;

    let mut best_error = f32::INFINITY;
    let mut best = [0_u8; 8];
    for c1 in 0..c1_end {
        for c0 in c0_start.max(c1 + 1)..=255 {
            let inter6 = EndPoints::new_inter6_unorm(c0, c1);
            let inter6_palette = Inter6Palette::from_endpoints(&inter6);
            let inter6_error = inter6_palette.block_closest_error_sq(&block);
            if inter6_error < best_error {
                best_error = inter6_error;
                best = inter6.with_indexes(inter6_palette.block_closest(&block).0);
            }

            let inter4 = inter6.inter6_to_inter4();
            let inter4_palette = Inter4Palette::from_endpoints(&inter4);
            let inter4_error = inter4_palette.block_closest_error_sq(&block);
            if inter4_error < best_error {
                best_error = inter4_error;
                best = inter4.with_indexes(inter4_palette.block_closest(&block).0);
            }
        }
    }
    best
}
/// Encodes a block whose pixels all have (approximately) the same `value`.
///
/// Steps, in order:
/// 1. If `value` is representable as an endpoint, every pixel selects `c0`.
/// 2. Otherwise, with dithering, pixels are dithered between the inter6 palette entries
///    around `value`.
/// 3. Otherwise, the single nearest entry from the inter4 or inter6 palette is used for all
///    pixels, whichever is closer.
fn single_color(value: f32, options: Bc4Options) -> [u8; 8] {
    let exact = EndPoints::new_closest(value, options.snorm);
    if (exact.c0_f - value).abs() < BC4_EPSILON {
        return exact.with_indexes(IndexList::new_all(0));
    }

    let endpoints6 = EndPoints::new_inter6(value, value, options.snorm);
    let palette6 = Inter6Palette::from_endpoints(&endpoints6);
    if options.dither {
        let dithered = palette6.block_dither(value).0;
        return endpoints6.with_indexes(dithered);
    }

    let endpoints4 = EndPoints::new_inter4(value, value, options.snorm);
    let palette4 = Inter4Palette::from_endpoints(&endpoints4);
    let (index4, _, error4) = palette4.closest(value);
    let (index6, _, error6) = palette6.closest(value);
    if error4 < error6 {
        endpoints4.with_indexes(IndexList::new_all(index4))
    } else {
        endpoints6.with_indexes(IndexList::new_all(index6))
    }
}
/// Optionally refines the float endpoint pair `(min, max)` with `bcn_util::refine_endpoints`,
/// then applies the quantization strategy selected by `options.quantization`.
///
/// `compute_error` scores a candidate endpoint pair (lower is better). The quantized pairs
/// passed to it are ordered larger-first (that is how `EndPoints::quantize` returns them).
/// The metric should therefore not depend on pair order; the palettes built by
/// `refinement_error_metric` are symmetric in their endpoints.
///
/// Returns, for `Round`, the refined float pair in `min <= max` order. For the other
/// strategies it returns the float values of the chosen quantized endpoints, larger first.
fn refine_endpoints(
    mut min: f32,
    mut max: f32,
    mut compute_error: impl Copy + FnMut((f32, f32)) -> f32,
    options: Bc4Options,
) -> (f32, f32) {
    if options.max_refine_iter > 0 {
        let refinement = if options.fast_iter {
            bcn_util::RefinementOptions {
                step_initial: 0.1 * (max - min),
                step_decay: 0.5,
                step_min: 1. / 255.,
                max_iter: options.max_refine_iter as u32,
            }
        } else {
            bcn_util::RefinementOptions {
                step_initial: 0.15 * (max - min),
                step_decay: 0.5,
                step_min: 1. / 255. / 2.,
                max_iter: options.max_refine_iter as u32,
            }
        };
        (min, max) = bcn_util::refine_endpoints(min, max, refinement, compute_error);
    }
    // Both quality strategies below reason about "the lower" and "the upper" endpoint,
    // so make sure the pair is ordered (a no-op when it already is).
    if min > max {
        (min, max) = (max, min);
    }

    // Slightly more than one snorm level (1/254), so shifting an endpoint by this amount moves
    // it to at least the neighbouring level on both the unorm (1/255) and snorm grids.
    const QUANT_STEP: f32 = 1. / 254. + 0.0001;
    match options.quantization {
        Bc4Quantization::Round => (min, max),
        Bc4Quantization::MediumQuality => {
            let base_quantized = EndPoints::quantize((min, max), options.snorm);
            // `quantize` returns `(c0_f, c1_f)`, i.e. the LARGER value first. Sort it so each
            // endpoint is compared with its own quantized counterpart.
            let quantized_min = base_quantized.0.min(base_quantized.1);
            let quantized_max = base_quantized.0.max(base_quantized.1);
            let min_diff = min - quantized_min;
            let max_diff = max - quantized_max;
            // Shift the endpoint with the larger rounding error one step back towards where it
            // was rounded away from, so it lands on the neighbouring level instead.
            let other = if min_diff.abs() > max_diff.abs() {
                let step = if min_diff > 0.0 {
                    QUANT_STEP
                } else {
                    -QUANT_STEP
                };
                (min + step, max)
            } else {
                let step = if max_diff > 0.0 {
                    QUANT_STEP
                } else {
                    -QUANT_STEP
                };
                (min, max + step)
            };
            let other_quantized = EndPoints::quantize(other, options.snorm);
            let base_error = compute_error(base_quantized);
            let other_error = compute_error(other_quantized);
            if other_error < base_error {
                other_quantized
            } else {
                base_quantized
            }
        }
        Bc4Quantization::HighQuality => {
            let mut best = EndPoints::quantize((min, max), options.snorm);
            let mut error = compute_error(best);
            // Try pulling each endpoint (and both) one level inwards.
            for pair in [
                (min + QUANT_STEP, max),
                (min, max - QUANT_STEP),
                (min + QUANT_STEP, max - QUANT_STEP),
            ] {
                let q = EndPoints::quantize(pair, options.snorm);
                let new_error = compute_error(q);
                if new_error < error {
                    error = new_error;
                    best = q;
                }
            }
            best
        }
    }
}
/// Returns a closure scoring an endpoint pair by the squared error of `block` against the
/// palette `P` spanned by that pair. `_options` is currently unused.
fn refinement_error_metric<P: Palette>(
    block: &Block,
    _options: Bc4Options,
) -> impl Copy + Fn((f32, f32)) -> f32 + '_ {
    move |endpoints: (f32, f32)| P::new(endpoints.0, endpoints.1).block_closest_error_sq(block)
}
/// Encodes `block` in the inter6 mode, given the block's value range `[min, max]`.
///
/// Both bounds are first pulled 5% towards the block mean. With `options.size_variations`,
/// a candidate whose range is widened by 25% (on the side with more room before 0 or 1) is
/// also encoded. The lower-error result is kept. Returns the encoded block and its squared
/// error.
fn compress_inter6(
    block: &Block,
    min: f32,
    max: f32,
    options: Bc4Options,
) -> ([u8; 8], f32) {
    let [b0, b1, b2, b3] = block.b;
    let sum = b0 + b1 + b2 + b3;
    let mean = ((sum.x + sum.y) + (sum.z + sum.w)) * (1.0 / 16.0);

    const NUDGE: f32 = 0.95;
    let lo = mean + (min - mean) * NUDGE;
    let hi = mean + (max - mean) * NUDGE;

    let primary = compress_inter6_impl(block, lo, hi, options);
    if !options.size_variations {
        return primary;
    }

    let widen = (hi - lo) * 0.25;
    let wide_lo = lo - widen;
    let wide_hi = hi + widen;
    // Grow towards whichever boundary (0 below, 1 above) leaves more room.
    let (alt_lo, alt_hi) = if wide_lo > 1.0 - wide_hi {
        (wide_lo.max(0.0), hi)
    } else {
        (lo, wide_hi.min(1.0))
    };
    let alternative = compress_inter6_impl(block, alt_lo, alt_hi, options);
    if alternative.1 < primary.1 {
        alternative
    } else {
        primary
    }
}
/// Fits inter6 endpoints starting from `[min, max]`, refines and quantizes them, and encodes
/// the block. Returns the encoded block and its squared error.
fn compress_inter6_impl(
    block: &Block,
    min: f32,
    max: f32,
    options: Bc4Options,
) -> ([u8; 8], f32) {
    let mut fit = (min, max);
    // Two rounds of: assign each pixel a palette weight, then re-fit the endpoints to those
    // weights with `bcn_util::least_squares_weights_f32_vec4`.
    for _ in 0..2 {
        let weights = Inter6Palette::new(fit.0, fit.1).block_closest_weights(block);
        let (a, b) = bcn_util::least_squares_weights_f32_vec4(&block.b, &weights);
        fit = (a.clamp(0.0, 1.0), b.clamp(0.0, 1.0));
    }

    let metric = refinement_error_metric::<Inter6Palette>(block, options);
    let (lo, hi) = refine_endpoints(fit.0, fit.1, metric, options);

    let encoded = EndPoints::new_inter6(lo, hi, options.snorm);
    let palette = Inter6Palette::from_endpoints(&encoded);
    let (indexes, error) = if !options.dither {
        palette.block_closest(block)
    } else {
        palette.block_dither(block)
    };
    (encoded.with_indexes(indexes), error)
}
/// Encodes `block` in the inter4 mode (four interpolated values plus fixed 0.0 and 1.0).
///
/// The initial endpoint range ignores pixels within `BC4_MIN_VALUE` of 0 or 1. Returns the
/// encoded block and its squared error.
fn compress_inter4(block: &Block, options: Bc4Options) -> ([u8; 8], f32) {
    let (lo, hi) = block.min_max_with_threshold(BC4_MIN_VALUE);
    let metric = refinement_error_metric::<Inter4Palette>(block, options);
    let (lo, hi) = refine_endpoints(lo, hi, metric, options);

    let encoded = EndPoints::new_inter4(lo, hi, options.snorm);
    let palette = Inter4Palette::from_endpoints(&encoded);
    let (indexes, error) = if !options.dither {
        palette.block_closest(block)
    } else {
        palette.block_dither(block)
    };
    (encoded.with_indexes(indexes), error)
}
/// A pair of BC4 endpoints in both encoded and float form.
struct EndPoints {
    /// First endpoint byte; written as byte 0 of the encoded block (`with_indexes`).
    c0: u8,
    /// Second endpoint byte; written as byte 1 of the encoded block (`with_indexes`).
    c1: u8,
    /// Float value of `c0` (via `n8::f32` or `s8::uf32`), used to build palettes.
    c0_f: f32,
    /// Float value of `c1` (via `n8::f32` or `s8::uf32`), used to build palettes.
    c1_f: f32,
}
impl EndPoints {
    /// Endpoints for a flat block. `c0` is the representable value nearest to `value`, rounded
    /// on a 255-step (unorm) or 254-step (snorm) grid. `c1` encodes 0.0.
    fn new_closest(value: f32, snorm: bool) -> Self {
        if snorm {
            let closest_s8_norm = (254.0 * value + 0.5) as u8;
            let c0 = s8::from_norm(closest_s8_norm);
            let c1 = s8::from_norm(0);
            let c0_f = s8::uf32(c0);
            let c1_f = s8::uf32(c1);
            debug_assert!(c1_f == 0.0);
            Self { c0, c1, c0_f, c1_f }
        } else {
            let closest = (255.0 * value + 0.5) as u8;
            let c0 = closest;
            let c1 = 0;
            let c0_f = n8::f32(c0);
            let c1_f = 0.0;
            Self { c0, c1, c0_f, c1_f }
        }
    }

    /// Rounds `min <= max` (clamped to `[0, 1]`) onto the integer grid `0..=scale` and returns
    /// `(min_level, max_level)`, guaranteeing `min_level < max_level`.
    ///
    /// Round-to-nearest is tried first. If both land on the same level, `min` is rounded down
    /// and `max` rounded up instead. If they still coincide, one is moved a level apart.
    /// The clamp does not change results for the unorm grid, because `as u8` already saturates.
    fn quantize_levels(min: f32, max: f32, scale: u8) -> (u8, u8) {
        let grid = f32::from(scale);
        let min = min.clamp(0.0, 1.0);
        let max = max.clamp(0.0, 1.0);
        let mut lo = (grid * min + 0.5) as u8;
        let mut hi = (grid * max + 0.5) as u8;
        if lo == hi {
            unlikely_branch();
            lo = (grid * min) as u8;
            hi = scale - (grid * (1.0 - max)) as u8;
            if lo == hi {
                if lo == 0 {
                    hi = 1;
                } else {
                    lo -= 1;
                }
            }
        }
        debug_assert!(lo < hi);
        (lo, hi)
    }

    /// Quantizes an endpoint pair (in either order) to two distinct representable values and
    /// returns their floats as `(c0_f, c1_f)`, i.e. the LARGER value first.
    fn quantize((e0, e1): (f32, f32), snorm: bool) -> (f32, f32) {
        let min = e0.min(e1);
        let max = e0.max(e1);
        if snorm {
            unlikely_branch();
            let (min_s8_norm, max_s8_norm) = Self::quantize_levels(min, max, 254);
            let c0 = s8::from_norm(max_s8_norm);
            let c1 = s8::from_norm(min_s8_norm);
            debug_assert!(c0 != c1);
            (s8::uf32(c0), s8::uf32(c1))
        } else {
            let (min_u8, max_u8) = Self::quantize_levels(min, max, 255);
            (n8::f32(max_u8), n8::f32(min_u8))
        }
    }

    /// Builds inter6-mode endpoints from a float pair (in either order): `c0` is the quantized
    /// larger value and `c1` the quantized smaller one, always distinct.
    fn new_inter6(e0: f32, e1: f32, snorm: bool) -> Self {
        let min = e0.min(e1);
        let max = e0.max(e1);
        if snorm {
            let (min_s8_norm, max_s8_norm) = Self::quantize_levels(min, max, 254);
            let mut c0 = s8::from_norm(max_s8_norm);
            let mut c1 = s8::from_norm(min_s8_norm);
            debug_assert!(c0 != c1);
            // The snorm bytes are compared as signed values; make sure c0 > c1 there.
            if c0 as i8 <= c1 as i8 {
                std::mem::swap(&mut c0, &mut c1);
            }
            let c0_f = s8::uf32(c0);
            let c1_f = s8::uf32(c1);
            Self { c0, c1, c0_f, c1_f }
        } else {
            let (min_u8, max_u8) = Self::quantize_levels(min, max, 255);
            let c0 = max_u8;
            let c1 = min_u8;
            let c0_f = n8::f32(c0);
            let c1_f = n8::f32(c1);
            Self { c0, c1, c0_f, c1_f }
        }
    }

    /// Builds inter4-mode endpoints: the inter6 endpoints with `c0` and `c1` swapped.
    fn new_inter4(e0: f32, e1: f32, snorm: bool) -> Self {
        Self::new_inter6(e0, e1, snorm).inter6_to_inter4()
    }

    /// Builds unorm inter6 endpoints directly from bytes; requires `c0 > c1` (debug-asserted).
    fn new_inter6_unorm(c0: u8, c1: u8) -> Self {
        debug_assert!(c0 > c1);
        let c0_f = n8::f32(c0);
        let c1_f = n8::f32(c1);
        Self { c0, c1, c0_f, c1_f }
    }

    /// Swaps the two endpoints (bytes and floats), turning an inter6-ordered pair into the
    /// inter4 ordering of the same values.
    fn inter6_to_inter4(&self) -> Self {
        Self {
            c0: self.c1,
            c1: self.c0,
            c0_f: self.c1_f,
            c1_f: self.c0_f,
        }
    }

    /// Serializes the block: `c0`, `c1`, then the low 48 bits of the packed indexes in
    /// little-endian byte order.
    fn with_indexes(&self, indexes: IndexList) -> [u8; 8] {
        let index_bytes = indexes.data.to_le_bytes();
        [
            self.c0,
            self.c1,
            index_bytes[0],
            index_bytes[1],
            index_bytes[2],
            index_bytes[3],
            index_bytes[4],
            index_bytes[5],
        ]
    }
}
/// The sixteen 3-bit palette indexes of a BC4 block, packed into the low 48 bits of `data`
/// (pixel `i` occupies bits `3 * i .. 3 * i + 3`).
struct IndexList {
    data: u64,
}
impl IndexList {
    /// Width in bits of one packed index.
    const BITS_PER_INDEX: usize = 3;

    /// All sixteen indexes set to 0.
    fn new_empty() -> Self {
        Self { data: 0 }
    }

    /// All sixteen indexes set to `value` (must be < 8).
    fn new_all(value: u8) -> Self {
        debug_assert!(value < 8);
        // A `1` in the lowest bit of each 3-bit field: one octal digit per field.
        const ONE_PER_FIELD: u64 = 0o1111_1111_1111_1111;
        Self {
            data: u64::from(value) * ONE_PER_FIELD,
        }
    }

    /// Reads the index of pixel `index` (0..16).
    fn get(&self, index: usize) -> u8 {
        debug_assert!(index < 16);
        let shift = index * Self::BITS_PER_INDEX;
        ((self.data >> shift) & 0b111) as u8
    }

    /// Writes the index of pixel `index` (0..16). Fields are OR-ed in, so each one may only be
    /// set once (debug-asserted).
    fn set(&mut self, index: usize, value: u8) {
        debug_assert!(index < 16);
        debug_assert!(value < 8);
        debug_assert!(self.get(index) == 0, "Cannot set an index twice.");
        let shift = index * Self::BITS_PER_INDEX;
        self.data |= u64::from(value) << shift;
    }
}
/// A BC4 palette: maps pixel values to BC4 indexes and measures the resulting error.
trait Palette {
    /// Builds the palette spanned by the endpoint values `c0` and `c1`.
    fn new(c0: f32, c1: f32) -> Self
    where
        Self: Sized;

    /// Builds the palette for already-quantized endpoints, using their float values.
    fn from_endpoints(endpoints: &EndPoints) -> Self
    where
        Self: Sized,
    {
        let (first, second) = (endpoints.c0_f, endpoints.c1_f);
        Self::new(first, second)
    }

    /// Finds the entry nearest to `pixel`; returns `(bc4_index, entry_value, absolute_error)`.
    fn closest(&self, pixel: f32) -> (u8, f32, f32);

    /// Applies [`Palette::closest`] to each lane, returning the indexes and absolute errors.
    fn closest_4(&self, pixels: Vec4) -> (UVec4, Vec4) {
        let [x, y, z, w] = pixels.to_array().map(|p| self.closest(p));
        let indexes = UVec4::new(u32::from(x.0), u32::from(y.0), u32::from(z.0), u32::from(w.0));
        let errors = Vec4::new(x.2, y.2, z.2, w.2);
        (indexes, errors)
    }

    /// Assigns every pixel its nearest entry; returns the packed indexes and the sum of squared
    /// errors.
    fn block_closest(&self, block: &Block) -> (IndexList, f32) {
        let mut indexes = IndexList::new_empty();
        let mut error_sum = 0.0;
        for (row, &pixels) in block.b.iter().enumerate() {
            let (row_indexes, row_errors) = self.closest_4(pixels);
            for (column, &index) in row_indexes.to_array().iter().enumerate() {
                indexes.set(row * 4 + column, index as u8);
            }
            error_sum += row_errors.dot(row_errors);
        }
        (indexes, error_sum)
    }

    /// Sum of squared distances from each pixel to its nearest palette entry.
    fn block_closest_error_sq(&self, block: &Block) -> f32;

    /// Chooses indexes via `bcn_util::block_dither`, quantizing each pixel it visits with
    /// [`Palette::closest`]; returns the packed indexes and the sum of squared errors.
    fn block_dither(&self, block: impl Block4x4<f32>) -> (IndexList, f32) {
        let mut indexes = IndexList::new_empty();
        let mut squared_error = 0.0;
        bcn_util::block_dither(block, |pixel_index, pixel| {
            let (index, value, error) = self.closest(pixel);
            indexes.set(pixel_index, index);
            squared_error += error * error;
            value
        });
        (indexes, squared_error)
    }
}
/// Palette for the 8-entry (inter6) BC4 mode: `c0`, `c1` and six values evenly spaced between
/// them.
struct Inter6Palette {
    /// Endpoint selected by BC4 index 0.
    c0: f32,
    /// Endpoint selected by BC4 index 1.
    c1: f32,
    /// `7 / (c0 - c1)`: converts a distance from `c1` into sevenths of the range.
    factor1: f32,
    /// `(c0 - c1) / 7`: the distance between neighbouring palette entries.
    factor2: f32,
    /// `0.5 - c1 * factor1`: offset so that truncating `pixel * factor1 + add1` yields the
    /// nearest step from `c1`.
    add1: f32,
}
impl Inter6Palette {
    /// Maps "number of steps from `c1` towards `c0`" (0..=7) to the BC4 index selecting that
    /// entry: step 0 is `c1` (index 1), step 7 is `c0` (index 0), and steps 1..=6 are indexes
    /// 7..=2.
    const INDEX_MAP: [u8; 8] = [1, 7, 6, 5, 4, 3, 2, 0];

    /// For four pixels, returns the weight of `c1` in each pixel's nearest palette entry
    /// (1.0 at `c1`, 0.0 at `c0`, in steps of 1/7).
    fn closest_weights4(&self, pixels: Vec4) -> Vec4 {
        // Clamp on both sides, matching `closest`, where `as u8` saturates negatives to 0.
        // Without the lower bound, pixels more than ~1.5 steps beyond `c1` produced weights
        // above 1.0, which correspond to no palette entry.
        let blend = (pixels * self.factor1 + self.add1).clamp(Vec4::ZERO, Vec4::splat(7.0));
        1.0 - blend.trunc() * (1.0 / 7.0)
    }

    /// [`Inter6Palette::closest_weights4`] applied to every row of the block.
    fn block_closest_weights(&self, block: &Block) -> [Vec4; 4] {
        block.b.map(|pixels| self.closest_weights4(pixels))
    }
}
impl Palette for Inter6Palette {
    /// Precomputes the scale/offset used to map a pixel onto the 7 steps between `c1` and `c0`.
    fn new(c0: f32, c1: f32) -> Self {
        debug_assert!(c0 != c1);
        let span = c0 - c1;
        let factor1 = 7.0 / span;
        Self {
            c0,
            c1,
            factor1,
            factor2: (1.0 / 7.0) * span,
            add1: 0.5 - c1 * factor1,
        }
    }

    /// Rounds `pixel` to the nearest step (0..=7) from `c1` towards `c0`; returns the BC4 index,
    /// the entry value and the absolute error.
    fn closest(&self, pixel: f32) -> (u8, f32, f32) {
        let step = ((pixel * self.factor1 + self.add1) as u8).min(7);
        let value = step as f32 * self.factor2 + self.c1;
        (Self::INDEX_MAP[usize::from(step)], value, (pixel - value).abs())
    }

    /// Sum of squared distances to the nearest of the 8 entries, by brute force over all entries.
    fn block_closest_error_sq(&self, block: &Block) -> f32 {
        const FACTORS: [f32; 7] = [1. / 7., 2. / 7., 3. / 7., 4. / 7., 5. / 7., 6. / 7., 1.];
        let start = Vec4::splat(self.c0);
        let delta = Vec4::splat(self.c1 - self.c0);

        let mut nearest = block.b.map(|pixels| (start - pixels).abs());
        for f in FACTORS {
            let entry = start + delta * f;
            for (err, pixels) in nearest.iter_mut().zip(block.b.iter()) {
                *err = err.min((entry - *pixels).abs());
            }
        }

        let [e0, e1, e2, e3] = nearest;
        let squared = e0 * e0 + e1 * e1 + e2 * e2 + e3 * e3;
        squared.x + squared.y + squared.z + squared.w
    }
}
/// Palette for the 6-value-plus-extremes (inter4) BC4 mode, indexed directly by BC4 index:
/// `colors[0]`/`colors[1]` are the endpoints, `colors[2..6]` the four interpolated values,
/// `colors[6]` is 0.0 and `colors[7]` is 1.0.
struct Inter4Palette {
    colors: [f32; 8],
}
impl Palette for Inter4Palette {
    /// Builds the 8 entries: both endpoints, four evenly spaced blends, then fixed 0.0 and 1.0.
    fn new(c0: f32, c1: f32) -> Self {
        Self {
            colors: [
                c0,
                c1,
                c0 * 0.8 + c1 * 0.2,
                c0 * 0.6 + c1 * 0.4,
                c0 * 0.4 + c1 * 0.6,
                c0 * 0.2 + c1 * 0.8,
                0.0,
                1.0,
            ],
        }
    }

    /// Returns `(bc4_index, entry_value, absolute_error)` for the entry nearest to `pixel`.
    /// Ties keep the earlier candidate (the nearer fixed extreme first, then entries 0..6).
    fn closest(&self, pixel: f32) -> (u8, f32, f32) {
        // Start from whichever fixed extreme (index 6 = 0.0, index 7 = 1.0) is nearer.
        let (mut index_value, mut min_error) = if pixel >= 0.5 {
            (7_u8, 1.0 - pixel)
        } else {
            (6_u8, pixel)
        };
        for (i, &color) in self.colors[..6].iter().enumerate() {
            let error = (pixel - color).abs();
            if error < min_error {
                min_error = error;
                index_value = i as u8;
            }
        }
        (index_value, self.colors[index_value as usize], min_error)
    }

    /// Sum of squared distances to the nearest entry (pixels are assumed to lie in `[0, 1]`,
    /// so the distance to the fixed 0.0 / 1.0 entries is `min(b, 1 - b)`).
    fn block_closest_error_sq(&self, block: &Block) -> f32 {
        let [b0, b1, b2, b3] = block.b;
        let mut e0 = b0.min(1.0 - b0);
        let mut e1 = b1.min(1.0 - b1);
        let mut e2 = b2.min(1.0 - b2);
        let mut e3 = b3.min(1.0 - b3);
        for &color in &self.colors[..6] {
            let c = Vec4::splat(color);
            e0 = e0.min((c - b0).abs());
            e1 = e1.min((c - b1).abs());
            e2 = e2.min((c - b2).abs());
            e3 = e3.min((c - b3).abs());
        }
        let e = e0 * e0 + e1 * e1 + e2 * e2 + e3 * e3;
        e.x + e.y + e.z + e.w
    }
}