#![allow(clippy::many_single_char_names)]
use std::convert::TryFrom;
use std::io::*;
/// Little-endian bit reader over one 128-bit ASTC block.
///
/// Bits are consumed from the least-significant end of `data`; `bits_read`
/// counts how many have been consumed and may never exceed 128.
struct InputBitStream {
    /// Bits not yet consumed, shifted down so the next bit is bit 0.
    data: u128,
    /// Total number of bits consumed so far.
    bits_read: u32,
}
impl InputBitStream {
    /// Creates a reader positioned at bit 0 of `data`.
    fn new(data: u128) -> InputBitStream {
        InputBitStream { data, bits_read: 0 }
    }
    /// Returns how many bits have been consumed so far.
    fn get_bits_read(&self) -> u32 {
        self.bits_read
    }
    /// Reads a single bit (0 or 1).
    fn read_bit(&mut self) -> u32 {
        self.read_bits(1)
    }
    /// Reads `n_bits` (at most 32) bits as an unsigned integer.
    ///
    /// # Panics
    /// Panics if `n_bits > 32` or if the read would go past bit 128.
    fn read_bits(&mut self, n_bits: u32) -> u32 {
        assert!(n_bits <= 32);
        self.read_bits128(n_bits) as u32
    }
    /// Reads `n_bits` (at most 128) bits as an unsigned integer.
    ///
    /// # Panics
    /// Panics if the read would go past bit 128.
    fn read_bits128(&mut self, n_bits: u32) -> u128 {
        self.bits_read += n_bits;
        assert!(self.bits_read <= 128);
        // A plain `1 << 128` or `data >> 128` overflows, so the checked forms
        // are used to keep a full-width (128-bit) read well defined.
        let mask = 1u128.checked_shl(n_bits).map_or(u128::MAX, |bit| bit - 1);
        let value = self.data & mask;
        self.data = self.data.checked_shr(n_bits).unwrap_or(0);
        value
    }
}
/// Small helper for extracting single bits and inclusive bit fields from a `u32`.
struct Bits(u32);
impl Bits {
    /// Returns bit `pos` (0 = least significant) as 0 or 1.
    fn get(&self, pos: u32) -> u32 {
        self.range(pos, pos)
    }
    /// Returns the inclusive field of bits `start..=end`, shifted down to bit 0.
    fn range(&self, start: u32, end: u32) -> u32 {
        let width = end - start + 1;
        let mask = (1u32 << width) - 1;
        mask & (self.0 >> start)
    }
}
/// Packing scheme of an integer sequence in ASTC's integer sequence encoding.
#[derive(Clone, Copy, PartialEq, Eq)]
enum IntegerEncodingType {
    /// Plain binary: `num_bits` bits per value.
    JustBits,
    /// A base-5 digit plus `num_bits` plain bits; three values share a block.
    Quint,
    /// A base-3 digit plus `num_bits` plain bits; five values share a block.
    Trit,
}
/// An integer-sequence encoding: the packing scheme plus plain bits per value.
#[derive(Clone, Copy, PartialEq, Eq)]
struct IntegerEncoding {
    encoding: IntegerEncodingType,
    num_bits: u32,
}
impl IntegerEncoding {
    /// Number of bits needed to store `n_vals` values with this encoding.
    ///
    /// On top of the plain bits, trits cost 8 bits per 5 values and quints
    /// 7 bits per 3 values, rounded up for partial blocks.
    fn get_bit_length(&self, n_vals: u32) -> u32 {
        let packed_digits = match self.encoding {
            IntegerEncodingType::JustBits => 0,
            IntegerEncodingType::Trit => (8 * n_vals + 4) / 5,
            IntegerEncodingType::Quint => (7 * n_vals + 2) / 3,
        };
        self.num_bits * n_vals + packed_digits
    }
}
/// Final step of color-endpoint unquantization for trit/quint values.
///
/// The function computes `T = (d * c + b) ^ a` and returns `(a & 0x80) | (T >> 2)`.
/// The callers pass these arguments:
/// - `a` is an all-zeros or all-ones 9-bit mask taken from the lowest plain bit.
/// - `b` is the replicated plain-bit pattern.
/// - `c` is the scale applied to the digit `d`.
///
/// # Panics
/// Panics if the result does not fit in a `u8`.
fn decode_color(a: u32, b: u32, c: u32, d: u32) -> u8 {
    let mixed = (c * d + b) ^ a;
    let top_bit = a & 0x80;
    u8::try_from(top_bit | (mixed >> 2)).unwrap()
}
/// Final step of texel-weight unquantization for trit/quint values.
///
/// The function computes `T = (d * c + b) ^ a` and returns `(a & 0x20) | (T >> 2)`.
/// The callers pass these arguments:
/// - `a` is an all-zeros or all-ones 7-bit mask taken from the lowest plain bit.
/// - `b` is the replicated plain-bit pattern.
/// - `c` is the scale applied to the digit `d`.
fn decode_weight(a: u32, b: u32, c: u32, d: u32) -> u32 {
    let mixed = (c * d + b) ^ a;
    let top_bit = a & 0x20;
    top_bit | (mixed >> 2)
}
/// One value decoded from a trit block: a base-3 digit plus its plain low bits.
struct Trit {
    /// The base-3 digit (0, 1 or 2).
    trit_value: u32,
    /// The plain bits stored alongside the digit.
    bit_value: u32,
}
impl Trit {
    /// Unquantizes this value to an 8-bit color endpoint component.
    ///
    /// `bitlen` is the number of plain bits per value (1..=6). The `(C, B)`
    /// pairs follow the trit color-unquantization table of the ASTC
    /// specification. Any other `bitlen` is a logic error and panics.
    fn decode_color(self, bitlen: u32) -> u8 {
        let bits = self.bit_value;
        // Everything above the lowest plain bit feeds the replicated `B` pattern.
        let hi = bits >> 1;
        let (c, b) = match bitlen {
            1 => (204, 0),
            2 => (93, (hi << 8) | (hi << 4) | (hi << 2) | (hi << 1)),
            3 => (44, (hi << 7) | (hi << 2) | hi),
            4 => (22, (hi << 6) | hi),
            5 => (11, (hi << 5) | (hi >> 2)),
            6 => (5, (hi << 4) | (hi >> 4)),
            _ => unreachable!("Invalid trit encoding for color values"),
        };
        // The lowest plain bit selects the all-zeros / all-ones 9-bit mask.
        let a = if bits & 1 == 1 { 0x1FF } else { 0 };
        decode_color(a, b, c, self.trit_value)
    }
    /// Unquantizes this value to a texel weight on the 0..=63 scale.
    ///
    /// `bitlen` is the number of plain bits per value (0..=3). With no plain
    /// bits, the digit maps directly onto the weights 0, 32 and 63.
    fn decode_weight(self, bitlen: u32) -> u32 {
        let bits = self.bit_value;
        let hi = bits >> 1;
        let (c, b) = match bitlen {
            0 => return [0, 32, 63][self.trit_value as usize],
            1 => (50, 0),
            2 => (23, (hi << 6) | (hi << 2) | hi),
            3 => (11, (hi << 5) | hi),
            _ => unreachable!("Invalid trit encoding for texel weight"),
        };
        // The lowest plain bit selects the all-zeros / all-ones 7-bit mask.
        let a = if bits & 1 == 1 { 0x7F } else { 0 };
        decode_weight(a, b, c, self.trit_value)
    }
}
/// One value decoded from a quint block: a base-5 digit plus its plain low bits.
struct Quint {
    /// The base-5 digit (0..=4).
    quint_value: u32,
    /// The plain bits stored alongside the digit.
    bit_value: u32,
}
impl Quint {
    /// Unquantizes this value to an 8-bit color endpoint component.
    ///
    /// `bitlen` is the number of plain bits per value (1..=5). The `(C, B)`
    /// pairs follow the quint color-unquantization table of the ASTC
    /// specification. Any other `bitlen` is a logic error and panics.
    fn decode_color(self, bitlen: u32) -> u8 {
        let bits = self.bit_value;
        // Everything above the lowest plain bit feeds the replicated `B` pattern.
        let hi = bits >> 1;
        let (c, b) = match bitlen {
            1 => (113, 0),
            2 => (54, (hi << 8) | (hi << 3) | (hi << 2)),
            3 => (26, (hi << 7) | (hi << 1) | (hi >> 1)),
            4 => (13, (hi << 6) | (hi >> 1)),
            5 => (6, (hi << 5) | (hi >> 3)),
            _ => unreachable!("Invalid quint encoding for color values"),
        };
        // The lowest plain bit selects the all-zeros / all-ones 9-bit mask.
        let a = if bits & 1 == 1 { 0x1FF } else { 0 };
        decode_color(a, b, c, self.quint_value)
    }
    /// Unquantizes this value to a texel weight on the 0..=63 scale.
    ///
    /// `bitlen` is the number of plain bits per value (0..=2). With no plain
    /// bits, the digit maps directly onto the weights 0, 16, 32, 47 and 63.
    fn decode_weight(self, bitlen: u32) -> u32 {
        let bits = self.bit_value;
        let hi = bits >> 1;
        let (c, b) = match bitlen {
            0 => return [0, 16, 32, 47, 63][self.quint_value as usize],
            1 => (28, 0),
            2 => (13, (hi << 6) | (hi << 1)),
            _ => unreachable!("Invalid quint encoding for texel weight"),
        };
        // The lowest plain bit selects the all-zeros / all-ones 7-bit mask.
        let a = if bits & 1 == 1 { 0x7F } else { 0 };
        decode_weight(a, b, c, self.quint_value)
    }
}
/// Decodes one trit block: five values, each a base-3 digit (`trit_value`)
/// plus `bits_per_value` plain low bits (`bit_value`).
///
/// Consumes `5 * bits_per_value + 8` bits from `bits`. The stream interleaves
/// the plain bits `m[i]` with the 8 packed trit bits `T` in this order:
/// m0, T[1:0], m1, T[3:2], m2, T[4], m3, T[6:5], m4, T[7].
///
/// The packed byte is then expanded into five digits by the bit tests below.
/// These follow the trit-decoding procedure of the ASTC specification.
fn decode_trit_block(bits: &mut InputBitStream, bits_per_value: u32) -> impl Iterator<Item = Trit> {
let mut m = [0u32; 5];
let mut t = [0u32; 5];
let mut tt: u32;
// Read the plain bits (m) and the interleaved packed trit bits (tt).
m[0] = bits.read_bits(bits_per_value);
tt = bits.read_bits(2);
m[1] = bits.read_bits(bits_per_value);
tt |= bits.read_bits(2) << 2;
m[2] = bits.read_bits(bits_per_value);
tt |= (bits.read_bit()) << 4;
m[3] = bits.read_bits(bits_per_value);
tt |= bits.read_bits(2) << 5;
m[4] = bits.read_bits(bits_per_value);
tt |= (bits.read_bit()) << 7;
// Unpack the two highest digits, t[4] and t[3].
// `c` keeps the 5-bit code from which t[2], t[1] and t[0] are derived.
let c: u32;
let tb = Bits(tt);
if tb.range(2, 4) == 7 {
c = (tb.range(5, 7) << 2) | tb.range(0, 1);
t[3] = 2;
t[4] = 2;
} else {
c = tb.range(0, 4);
if tb.range(5, 6) == 3 {
t[4] = 2;
t[3] = tb.get(7);
} else {
t[4] = tb.get(7);
t[3] = tb.range(5, 6);
}
}
// Unpack t[2], t[1] and t[0] from the 5-bit code `c`.
let cb = Bits(c);
if cb.range(0, 1) == 3 {
t[2] = 2;
t[1] = cb.get(4);
t[0] = (cb.get(3) << 1) | (cb.get(2) & !cb.get(3));
} else if cb.range(2, 3) == 3 {
t[2] = 2;
t[1] = 2;
t[0] = cb.range(0, 1);
} else {
t[2] = cb.get(4);
t[1] = cb.range(2, 3);
t[0] = (cb.get(1) << 1) | (cb.get(0) & !cb.get(1));
}
// Pair each digit with its plain bits, in value order.
IntoIterator::into_iter(m)
.zip(t)
.map(|(bit_value, trit_value)| Trit {
trit_value,
bit_value,
})
}
/// Decodes one quint block: three values, each a base-5 digit (`quint_value`)
/// plus `bits_per_value` plain low bits (`bit_value`).
///
/// Consumes `3 * bits_per_value + 7` bits from `bits`. The stream interleaves
/// the plain bits `m[i]` with the 7 packed quint bits `Q` in this order:
/// m0, Q[2:0], m1, Q[4:3], m2, Q[6:5].
///
/// The packed bits are then expanded into three digits by the bit tests below.
/// These follow the quint-decoding procedure of the ASTC specification.
fn decode_quint_block(
bits: &mut InputBitStream,
bits_per_value: u32,
) -> impl Iterator<Item = Quint> {
let mut m = [0u32; 3];
let mut q = [0u32; 3];
let mut qq: u32;
// Read the plain bits (m) and the interleaved packed quint bits (qq).
m[0] = bits.read_bits(bits_per_value);
qq = bits.read_bits(3);
m[1] = bits.read_bits(bits_per_value);
qq |= bits.read_bits(2) << 3;
m[2] = bits.read_bits(bits_per_value);
qq |= bits.read_bits(2) << 5;
let qb = Bits(qq);
// Special pattern Q[2:1] == 11 and Q[6:5] == 00: q0 = q1 = 4, q2 from the rest.
if qb.range(1, 2) == 3 && qb.range(5, 6) == 0 {
q[0] = 4;
q[1] = 4;
q[2] = (qb.get(0) << 2) | ((qb.get(4) & !qb.get(0)) << 1) | (qb.get(3) & !qb.get(0));
} else {
// Determine q2 and build the 5-bit code `c` that holds q1 and q0.
let c;
if qb.range(1, 2) == 3 {
q[2] = 4;
c = (qb.range(3, 4) << 3) | ((!qb.range(5, 6) & 3) << 1) | qb.get(0);
} else {
q[2] = qb.range(5, 6);
c = qb.range(0, 4);
}
// Unpack q1 and q0 from `c`.
let cb = Bits(c);
if cb.range(0, 2) == 5 {
q[1] = 4;
q[0] = cb.range(3, 4);
} else {
q[1] = cb.range(3, 4);
q[0] = cb.range(0, 2);
}
}
// Pair each digit with its plain bits, in value order.
IntoIterator::into_iter(m)
.zip(q)
.map(|(bit_value, quint_value)| Quint {
quint_value,
bit_value,
})
}
/// Returns the encoding with the largest representable range whose maximum
/// value does not exceed `max_val`.
///
/// A range of `n` levels is representable in three ways:
/// - `n` is `2^k`: plain bits.
/// - `n` is `3 * 2^k`: a trit plus `k` bits.
/// - `n` is `5 * 2^k`: a quint plus `k` bits.
///
/// The search walks downward from `max_val`. `max_val == 0` yields the
/// zero-bit plain encoding.
const fn create_encoding(max_val: u32) -> IntegerEncoding {
    let mut top = max_val;
    while top != 0 {
        let levels = top + 1;
        if levels.is_power_of_two() {
            return IntegerEncoding {
                encoding: IntegerEncodingType::JustBits,
                num_bits: levels.trailing_zeros(),
            };
        }
        if levels % 3 == 0 && (levels / 3).is_power_of_two() {
            return IntegerEncoding {
                encoding: IntegerEncodingType::Trit,
                num_bits: (levels / 3).trailing_zeros(),
            };
        }
        if levels % 5 == 0 && (levels / 5).is_power_of_two() {
            return IntegerEncoding {
                encoding: IntegerEncodingType::Quint,
                num_bits: (levels / 5).trailing_zeros(),
            };
        }
        top -= 1;
    }
    IntegerEncoding {
        encoding: IntegerEncodingType::JustBits,
        num_bits: 0,
    }
}
/// `ENCODING_MAP[v]` is the encoding [`create_encoding`] selects for a
/// maximum value of `v`, for every `v` in `0..256`.
static ENCODING_MAP: [IntegerEncoding; 256] = {
    let placeholder = IntegerEncoding {
        encoding: IntegerEncodingType::JustBits,
        num_bits: 0,
    };
    let mut table = [placeholder; 256];
    let mut max_val: u32 = 0;
    while max_val < 256 {
        table[max_val as usize] = create_encoding(max_val);
        max_val += 1;
    }
    table
};
/// The distinct encodings of [`ENCODING_MAP`], in order of first appearance.
///
/// Because `ENCODING_MAP` is scanned by increasing maximum value, this is
/// also the order of increasing range. `.1` is the number of valid entries;
/// the remainder of `.0` is zero-bit padding.
static ENCODING_SEQ: ([IntegerEncoding; 256], usize) = {
    let placeholder = IntegerEncoding {
        encoding: IntegerEncodingType::JustBits,
        num_bits: 0,
    };
    let mut seq = [placeholder; 256];
    let mut len = 0;
    let mut i = 0;
    while i < 256 {
        let next = ENCODING_MAP[i];
        // `PartialEq` is not usable in const context, so compare field-wise.
        let is_new = if len == 0 {
            true
        } else {
            let last = seq[len - 1];
            next.encoding as u32 != last.encoding as u32 || next.num_bits != last.num_bits
        };
        if is_new {
            seq[len] = next;
            len += 1;
        }
        i += 1;
    }
    (seq, len)
};
/// Weight-grid parameters decoded from a block's mode bits.
///
/// `Default` gives an all-zero, all-`false` value, which is the starting
/// point that the block-mode decoder fills in.
#[derive(Default)]
struct TexelWeightParams {
    /// Weight-grid width in grid points.
    width: u32,
    /// Weight-grid height in grid points.
    height: u32,
    /// Whether two weight planes are stored, interleaved.
    is_dual_plane: bool,
    /// Largest quantized weight value (the weight range is `0..=max_weight`).
    max_weight: u32,
    /// Set for reserved or malformed block modes.
    is_error: bool,
    /// Set for LDR void-extent (constant color) blocks.
    void_extent_ldr: bool,
    /// Set for HDR void-extent blocks.
    void_extent_hdr: bool,
}
impl TexelWeightParams {
    /// Number of bits the encoded weights occupy at the top of the block.
    fn get_packed_bit_size(&self) -> u32 {
        let encoding = ENCODING_MAP[self.max_weight as usize];
        encoding.get_bit_length(self.get_num_weight_values())
    }
    /// Total number of stored weights, counting both planes when dual-plane.
    fn get_num_weight_values(&self) -> u32 {
        let planes = if self.is_dual_plane { 2 } else { 1 };
        self.width * self.height * planes
    }
}
/// Decodes the block-mode header of an ASTC block into weight-grid parameters.
///
/// Reads the 11-bit block mode from `strm`, plus one extra bit for
/// void-extent blocks. On success the result holds:
/// - the weight-grid `width` and `height`;
/// - the largest quantized weight value, `max_weight`;
/// - `is_dual_plane`.
///
/// Void-extent blocks set `void_extent_ldr` or `void_extent_hdr`. Reserved
/// or malformed modes set `is_error`. Both cases return early and leave the
/// grid fields at their defaults.
fn decode_block_info(strm: &mut InputBitStream) -> TexelWeightParams {
let mut params = TexelWeightParams::default();
let mode_bits = strm.read_bits(11);
// Void-extent (constant color) block: the low nine bits are 1_1111_1100.
// Bit 9 chooses HDR vs LDR. Bit 10 and the next stream bit must both be 1.
if (mode_bits & 0x01FF) == 0x1FC {
if mode_bits & 0x200 != 0 {
params.void_extent_hdr = true;
} else {
params.void_extent_ldr = true;
}
if (mode_bits & 0x400) == 0 || strm.read_bit() == 0 {
params.is_error = true;
}
return params;
}
// Reserved encoding: the low four bits are all zero.
if (mode_bits & 0xF) == 0 {
params.is_error = true;
return params;
}
// Reserved encoding: bits 0..=1 are clear while bits 6..=8 are all set.
if (mode_bits & 0x3) == 0 && (mode_bits & 0x1C0) == 0x1C0 {
params.is_error = true;
return params;
}
// Select one of ten layouts of the mode bits.
// Layouts 0..=4 have a non-zero low bit pair and are told apart by bits 2, 3 and 8.
// Layouts 5..=9 have bits 0..=1 clear and are told apart by bits 5, 7 and 8.
let layout;
if (mode_bits & 0x1) != 0 || (mode_bits & 0x2) != 0 {
if (mode_bits & 0x8) != 0 {
if (mode_bits & 0x4) != 0 {
if (mode_bits & 0x100) != 0 {
layout = 4;
} else {
layout = 3;
}
} else {
layout = 2;
}
} else {
if (mode_bits & 0x4) != 0 {
layout = 1;
} else {
layout = 0;
}
}
} else {
if (mode_bits & 0x100) != 0 {
if (mode_bits & 0x80) != 0 {
// Guaranteed by the reserved-encoding check above (bits 6..=8 are not all set).
assert!((mode_bits & 0x40) == 0);
if (mode_bits & 0x20) != 0 {
layout = 8;
} else {
layout = 7;
}
} else {
layout = 9;
}
} else {
if (mode_bits & 0x80) != 0 {
layout = 6;
} else {
layout = 5;
}
}
}
// Weight range index R. Bit 4 is its low bit. The two high bits come from
// bits 0..=1 (layouts 0..=4) or bits 2..=3 (layouts 5..=9). Those two bits
// are never both zero at this point, so R >= 2.
let mut r = (mode_bits & 0x10) >> 4;
if layout < 5 {
r |= (mode_bits & 0x3) << 1;
} else {
r |= (mode_bits & 0xC) >> 1;
}
assert!((2..=7).contains(&r));
// Weight-grid dimensions. `a` and `b` are 1- or 2-bit fields whose position
// depends on the layout. Layouts 7 and 8 have fixed sizes.
match layout {
0 => {
let a = (mode_bits >> 5) & 0x3;
let b = (mode_bits >> 7) & 0x3;
params.width = b + 4;
params.height = a + 2;
}
1 => {
let a = (mode_bits >> 5) & 0x3;
let b = (mode_bits >> 7) & 0x3;
params.width = b + 8;
params.height = a + 2;
}
2 => {
let a = (mode_bits >> 5) & 0x3;
let b = (mode_bits >> 7) & 0x3;
params.width = a + 2;
params.height = b + 8;
}
3 => {
let a = (mode_bits >> 5) & 0x3;
let b = (mode_bits >> 7) & 0x1;
params.width = a + 2;
params.height = b + 6;
}
4 => {
let a = (mode_bits >> 5) & 0x3;
let b = (mode_bits >> 7) & 0x1;
params.width = b + 2;
params.height = a + 2;
}
5 => {
let a = (mode_bits >> 5) & 0x3;
params.width = 12;
params.height = a + 2;
}
6 => {
let a = (mode_bits >> 5) & 0x3;
params.width = a + 2;
params.height = 12;
}
7 => {
params.width = 6;
params.height = 10;
}
8 => {
params.width = 10;
params.height = 6;
}
9 => {
let a = (mode_bits >> 5) & 0x3;
let b = (mode_bits >> 9) & 0x3;
params.width = a + 6;
params.height = b + 6;
}
_ => unreachable!("Impossible layout"),
}
// Bit 10 is the dual-plane flag and bit 9 selects the high-precision weight ranges.
// Layout 9 stores `b` in those bits, so it is always single-plane with the
// low-precision ranges.
let dp = (layout != 9) && (mode_bits & 0x400) != 0;
let p = (layout != 9) && (mode_bits & 0x200) != 0;
// Largest quantized weight value for R = 2..=7.
let max_weights = if p {
[9, 11, 15, 19, 23, 31]
} else {
[1, 2, 3, 4, 5, 7]
};
params.max_weight = max_weights[(r - 2) as usize];
params.is_dual_plane = dp;
params
}
/// Paints a whole block with the constant color of an LDR void-extent block.
///
/// `strm` must be positioned just after the block-mode bits (bit 12). The
/// function skips four 13-bit fields, then reads four 16-bit channel values
/// (R, G, B, A) and keeps the top 8 bits of each. Every texel of the
/// `block_width` x `block_height` block is written in row-major order.
fn fill_void_extent_ldr<F: FnMut(u32, u32, [u8; 4])>(
    strm: &mut InputBitStream,
    writer: &mut F,
    block_width: u32,
    block_height: u32,
) {
    // The 13-bit fields hold the void-extent texture coordinates, which are
    // not needed to paint a single block.
    for _ in 0..4 {
        let _ = strm.read_bits(13);
    }
    let mut color = [0u8; 4];
    for channel in color.iter_mut() {
        *channel = (strm.read_bits(16) >> 8) as u8;
    }
    for y in 0..block_height {
        for x in 0..block_width {
            writer(x, y, color);
        }
    }
}
/// Fills a `block_width` x `block_height` block with the magenta error color.
///
/// Texels are written in row-major order (all of row 0, then row 1, ...).
fn fill_error<F: FnMut(u32, u32, [u8; 4])>(writer: &mut F, block_width: u32, block_height: u32) {
    const MAGENTA: [u8; 4] = [0xFF, 0, 0xFF, 0xFF];
    (0..block_height)
        .flat_map(|y| (0..block_width).map(move |x| (x, y)))
        .for_each(|(x, y)| writer(x, y, MAGENTA));
}
/// Expands a `num_bits`-wide value to `to_bit` bits by bit replication.
///
/// The value is placed in the top `num_bits` bits of the result and then
/// repeated downward until every lower bit is filled. For example, `0b101`
/// expanded to 8 bits becomes `0b1011_0110`. Zero maps to zero and an
/// all-ones input maps to an all-ones output.
///
/// Edge cases:
/// - If `num_bits >= to_bit`, the value is truncated to its top `to_bit` bits.
/// - If either width is zero, the result is 0.
///
/// `to_bit` must be at most 32, and `val` should fit in `num_bits` bits.
fn replicate(val: u32, num_bits: u32, to_bit: u32) -> u32 {
    if num_bits == 0 || to_bit == 0 {
        return 0;
    }
    if num_bits >= to_bit {
        // Nothing to replicate: keep the most significant `to_bit` bits.
        return val.checked_shr(num_bits - to_bit).unwrap_or(0);
    }
    let mut res = val << (to_bit - num_bits);
    let mut shift = num_bits;
    // Each pass doubles the replicated prefix. `checked_shr` ends the loop
    // once the shift reaches the word size (e.g. when widening to 32 bits).
    // A plain `>>` would overflow there.
    while let Some(next) = res.checked_shr(shift) {
        if next == 0 {
            break;
        }
        res |= next;
        shift *= 2;
    }
    res
}
fn decode_color_values(data: u128, n_values: u32, n_bits_for_color_data: u32) -> [u8; 18] {
let mut out = [0; 18];
let out_range = &mut out[0..n_values as usize];
let encoding_i = ENCODING_SEQ.0[0..ENCODING_SEQ.1]
.partition_point(|v| v.get_bit_length(n_values) <= n_bits_for_color_data);
let encoding = ENCODING_SEQ.0[encoding_i - 1];
let mut color_stream = InputBitStream::new(data);
match encoding.encoding {
IntegerEncodingType::JustBits => {
for out in out_range {
*out = replicate(
color_stream.read_bits(encoding.num_bits),
encoding.num_bits,
8,
) as u8;
}
}
IntegerEncodingType::Trit => {
for (out, result) in out_range
.iter_mut()
.zip((0..).flat_map(|_| decode_trit_block(&mut color_stream, encoding.num_bits)))
{
*out = result.decode_color(encoding.num_bits);
}
}
IntegerEncodingType::Quint => {
for (out, result) in out_range
.iter_mut()
.zip((0..).flat_map(|_| decode_quint_block(&mut color_stream, encoding.num_bits)))
{
*out = result.decode_color(encoding.num_bits);
}
}
}
out
}
/// Decodes the block's texel weights and infills them to one weight per texel.
///
/// The function works in three steps:
/// 1. Reads `params.get_num_weight_values()` quantized weights from `weights_stream`.
/// 2. Unquantizes them to `0..=64`.
/// 3. Bilinearly interpolates the `params.width` x `params.height` weight grid
///    up to the `block_width` x `block_height` block.
///
/// Returns one row-major weight array per plane; the second plane is only
/// filled for dual-plane blocks. Assumes `block_width`/`block_height` are at
/// least 2 and that the block has at most 144 texels.
fn unquantize_texel_weights(
    weights_stream: &mut InputBitStream,
    params: &TexelWeightParams,
    block_width: u32,
    block_height: u32,
) -> [[u32; 144]; 2] {
    let plane_count: usize = if params.is_dual_plane { 2 } else { 1 };
    let mut grid = [0u32; 96];
    let values = &mut grid[..params.get_num_weight_values() as usize];
    let encoding = ENCODING_MAP[params.max_weight as usize];
    let bits = encoding.num_bits;
    match encoding.encoding {
        IntegerEncodingType::JustBits => {
            for w in values.iter_mut() {
                *w = replicate(weights_stream.read_bits(bits), bits, 6);
            }
        }
        IntegerEncodingType::Trit => {
            // One trit block carries five weights.
            for group in values.chunks_mut(5) {
                let block = decode_trit_block(weights_stream, bits);
                for (w, trit) in group.iter_mut().zip(block) {
                    *w = trit.decode_weight(bits);
                }
            }
        }
        IntegerEncodingType::Quint => {
            // One quint block carries three weights.
            for group in values.chunks_mut(3) {
                let block = decode_quint_block(weights_stream, bits);
                for (w, quint) in group.iter_mut().zip(block) {
                    *w = quint.decode_weight(bits);
                }
            }
        }
    }
    // Stretch 0..=63 to 0..=64 so that a full weight selects the second endpoint.
    for w in values.iter_mut() {
        assert!(*w < 64);
        if *w > 32 {
            *w += 1;
        }
    }
    let grid_w = params.width;
    let grid_h = params.height;
    let grid_len = grid_w * grid_h;
    // Per-texel step across the grid, in 1/1024 units, rounded.
    let ds = (1024 + (block_width / 2)) / (block_width - 1);
    let dt = (1024 + (block_height / 2)) / (block_height - 1);
    let mut out = [[0u32; 144]; 2];
    for (plane, plane_out) in out.iter_mut().enumerate().take(plane_count) {
        // Planes are interleaved in `grid`; positions past the grid read as 0.
        let sample = |index: u32| -> u32 {
            if index < grid_len {
                grid[plane + plane_count * index as usize]
            } else {
                0
            }
        };
        for t in 0..block_height {
            let gt = (dt * t * (grid_h - 1) + 32) >> 6;
            let (jt, ft) = (gt >> 4, gt & 0xF);
            for s in 0..block_width {
                let gs = (ds * s * (grid_w - 1) + 32) >> 6;
                let (js, fs) = (gs >> 4, gs & 0xF);
                let w11 = (fs * ft + 8) >> 4;
                let w10 = ft - w11;
                let w01 = fs - w11;
                let w00 = 16 + w11 - fs - ft;
                let v0 = js + jt * grid_w;
                let p00 = sample(v0);
                let p01 = sample(v0 + 1);
                let p10 = sample(v0 + grid_w);
                let p11 = sample(v0 + grid_w + 1);
                plane_out[(t * block_width + s) as usize] =
                    (p00 * w00 + p01 * w01 + p10 * w10 + p11 * w11 + 8) >> 4;
            }
        }
    }
    out
}
/// Moves the top bit of `a` into the top bit of `b` while halving both.
///
/// Afterwards `b` is `(b >> 1) | (a & 0x80)`. `a` becomes bits 1..=6 of the
/// original, interpreted as a signed 6-bit value in `-32..=31`.
fn bit_transfer_signed(a: &mut i32, b: &mut i32) {
    let moved_bit = *a & 0x80;
    *b = (*b >> 1) | moved_bit;
    let six_bits = (*a >> 1) & 0x3F;
    *a = if six_bits & 0x20 != 0 {
        six_bits - 0x40
    } else {
        six_bits
    };
}
/// Integer hash used to derive the partition pattern from a seed.
fn hash52(p: u32) -> u32 {
    let mut h = p;
    h ^= h >> 15;
    h = h.wrapping_sub(h << 17);
    h = h.wrapping_add(h << 7);
    h = h.wrapping_add(h << 4);
    h ^= h >> 5;
    h = h.wrapping_add(h << 16);
    h ^= h >> 7;
    h ^= h >> 3;
    h ^= h << 6;
    h ^= h >> 17;
    h
}
/// Returns which of the `partition_count` partitions texel `(x, y, z)` is in.
///
/// `seed` is the block's partition index. `small_block` doubles the
/// coordinates first. The result is always `< partition_count`; a single
/// partition short-circuits to 0.
fn select_partition(
    seed: u32,
    x: u32,
    y: u32,
    z: u32,
    partition_count: usize,
    small_block: bool,
) -> usize {
    if partition_count == 1 {
        return 0;
    }
    let scale = u32::from(small_block);
    let (x, y, z) = (x << scale, y << scale, z << scale);
    let seed = seed + (partition_count as u32 - 1) * 1024;
    let rnum = hash52(seed);

    // Twelve 4-bit sub-seeds taken from `rnum` at these bit offsets, each
    // squared. The last one wraps around from bit 30 to bit 1.
    const OFFSETS: [u32; 12] = [0, 4, 8, 12, 16, 20, 24, 28, 18, 22, 26, 30];
    let mut k = [0u32; 12];
    for (slot, &offset) in k.iter_mut().zip(OFFSETS.iter()) {
        let nibble = rnum.rotate_right(offset) & 0xF;
        *slot = nibble * nibble;
    }

    let from_seed = if seed & 2 != 0 { 4 } else { 5 };
    let from_count = if partition_count == 3 { 6 } else { 5 };
    let (sh1, sh2) = if seed & 1 != 0 {
        (from_seed, from_count)
    } else {
        (from_count, from_seed)
    };
    let sh3 = if seed & 0x10 != 0 { sh1 } else { sh2 };
    // Even sub-seeds of the first eight use sh1, odd ones sh2, the last four sh3.
    for (i, v) in k.iter_mut().enumerate() {
        *v >>= if i >= 8 {
            sh3
        } else if i % 2 == 0 {
            sh1
        } else {
            sh2
        };
    }

    let mix = |kx: u32, ky: u32, kz: u32, tail: u32| -> u32 {
        (kx * x + ky * y + kz * z + (rnum >> tail)) & 0x3F
    };
    let a = mix(k[0], k[1], k[10], 14);
    let b = mix(k[2], k[3], k[11], 10);
    let c = mix(k[4], k[5], k[8], 6);
    let d = mix(k[6], k[7], k[9], 2);
    // Unused partitions are forced to lose every comparison.
    let c = if partition_count < 3 { 0 } else { c };
    let d = if partition_count < 4 { 0 } else { d };

    if a >= b && a >= c && a >= d {
        0
    } else if b >= c && b >= d {
        1
    } else if c >= d {
        2
    } else {
        3
    }
}
/// 2D form of [`select_partition`], with `z` fixed at 0.
fn select_2d_partition(
    seed: u32,
    x: u32,
    y: u32,
    partition_count: usize,
    small_block: bool,
) -> usize {
    let z = 0;
    select_partition(seed, x, y, z, partition_count, small_block)
}
/// Saturates each channel to `0..=255` and packs the result as RGBA8.
fn clamp_color(r: i32, g: i32, b: i32, a: i32) -> [u8; 4] {
    let saturate = |channel: i32| channel.clamp(0, 255) as u8;
    [saturate(r), saturate(g), saturate(b), saturate(a)]
}
/// Applies blue contraction and saturates the result to RGBA8.
///
/// Red and green become the average of themselves and blue. Blue and alpha
/// are kept, and every channel is clamped to `0..=255`.
fn blue_contract(r: i32, g: i32, b: i32, a: i32) -> [u8; 4] {
    let saturate = |channel: i32| channel.clamp(0, 255) as u8;
    let toward_blue = |channel: i32| (channel + b) >> 1;
    [
        saturate(toward_blue(r)),
        saturate(toward_blue(g)),
        saturate(b),
        saturate(a),
    ]
}
/// Builds the two RGBA8 endpoints of one partition from its color values.
///
/// `endpoint_mods` is the partition's color endpoint mode (0..=15). Each mode
/// owns `((endpoint_mods >> 2) + 1) * 2` consecutive values. `color_values`
/// is advanced past exactly that many, so calling this once per partition
/// walks the decoded values in order.
///
/// Only the LDR modes are decoded. The remaining (HDR) modes produce the
/// magenta error color, but still consume their values so that later
/// partitions read their own endpoints.
///
/// # Panics
/// Panics if `color_values` is shorter than an LDR mode requires.
fn compute_endpoints(color_values: &mut &[u8], endpoint_mods: u32) -> [[u8; 4]; 2] {
    // Pops the next `$N` values off `color_values`, widened to `i32`.
    macro_rules! read_int_values {
        ($N:expr) => {{
            let remaining: &[u8] = *color_values;
            let (head, rest) = remaining.split_at($N);
            *color_values = rest;
            let mut v = [0i32; $N];
            for (dst, &src) in v.iter_mut().zip(head) {
                *dst = i32::from(src);
            }
            v
        }};
    }
    // Applies `bit_transfer_signed` to `$v[$a]` (receives the sign) and `$v[$b]`.
    macro_rules! bts {
        ($v:ident, $a:expr, $b: expr) => {{
            let mut a = $v[$a];
            let mut b = $v[$b];
            bit_transfer_signed(&mut a, &mut b);
            $v[$a] = a;
            $v[$b] = b;
        }};
    }
    match endpoint_mods {
        // Luminance, direct.
        0 => {
            let v = read_int_values!(2);
            [
                clamp_color(v[0], v[0], v[0], 0xFF),
                clamp_color(v[1], v[1], v[1], 0xFF),
            ]
        }
        // Luminance, base + offset.
        1 => {
            let v = read_int_values!(2);
            let l0 = (v[0] >> 2) | (v[1] & 0xC0);
            let l1 = std::cmp::min(l0 + (v[1] & 0x3F), 0xFF);
            [clamp_color(l0, l0, l0, 0xFF), clamp_color(l1, l1, l1, 0xFF)]
        }
        // Luminance + alpha, direct.
        4 => {
            let v = read_int_values!(4);
            [
                clamp_color(v[0], v[0], v[0], v[2]),
                clamp_color(v[1], v[1], v[1], v[3]),
            ]
        }
        // Luminance + alpha, base + offset.
        5 => {
            let mut v = read_int_values!(4);
            bts!(v, 1, 0);
            bts!(v, 3, 2);
            [
                clamp_color(v[0], v[0], v[0], v[2]),
                clamp_color(v[0] + v[1], v[0] + v[1], v[0] + v[1], v[2] + v[3]),
            ]
        }
        // RGB, base + scale.
        6 => {
            let v = read_int_values!(4);
            [
                clamp_color(
                    (v[0] * v[3]) >> 8,
                    (v[1] * v[3]) >> 8,
                    (v[2] * v[3]) >> 8,
                    0xFF,
                ),
                clamp_color(v[0], v[1], v[2], 0xFF),
            ]
        }
        // RGB, direct (blue-contracted and swapped when the sums decrease).
        8 => {
            let v = read_int_values!(6);
            if v[1] + v[3] + v[5] >= v[0] + v[2] + v[4] {
                [
                    clamp_color(v[0], v[2], v[4], 0xFF),
                    clamp_color(v[1], v[3], v[5], 0xFF),
                ]
            } else {
                [
                    blue_contract(v[1], v[3], v[5], 0xFF),
                    blue_contract(v[0], v[2], v[4], 0xFF),
                ]
            }
        }
        // RGB, base + offset.
        9 => {
            let mut v = read_int_values!(6);
            bts!(v, 1, 0);
            bts!(v, 3, 2);
            bts!(v, 5, 4);
            if v[1] + v[3] + v[5] >= 0 {
                [
                    clamp_color(v[0], v[2], v[4], 0xFF),
                    clamp_color(v[0] + v[1], v[2] + v[3], v[4] + v[5], 0xFF),
                ]
            } else {
                [
                    blue_contract(v[0] + v[1], v[2] + v[3], v[4] + v[5], 0xFF),
                    blue_contract(v[0], v[2], v[4], 0xFF),
                ]
            }
        }
        // RGB base + scale, plus two alpha values.
        10 => {
            let v = read_int_values!(6);
            [
                clamp_color(
                    (v[0] * v[3]) >> 8,
                    (v[1] * v[3]) >> 8,
                    (v[2] * v[3]) >> 8,
                    v[4],
                ),
                clamp_color(v[0], v[1], v[2], v[5]),
            ]
        }
        // RGBA, direct.
        12 => {
            let v = read_int_values!(8);
            if v[1] + v[3] + v[5] >= v[0] + v[2] + v[4] {
                [
                    clamp_color(v[0], v[2], v[4], v[6]),
                    clamp_color(v[1], v[3], v[5], v[7]),
                ]
            } else {
                [
                    blue_contract(v[1], v[3], v[5], v[7]),
                    blue_contract(v[0], v[2], v[4], v[6]),
                ]
            }
        }
        // RGBA, base + offset.
        13 => {
            let mut v = read_int_values!(8);
            bts!(v, 1, 0);
            bts!(v, 3, 2);
            bts!(v, 5, 4);
            bts!(v, 7, 6);
            if v[1] + v[3] + v[5] >= 0 {
                [
                    clamp_color(v[0], v[2], v[4], v[6]),
                    clamp_color(v[0] + v[1], v[2] + v[3], v[4] + v[5], v[6] + v[7]),
                ]
            } else {
                [
                    blue_contract(v[0] + v[1], v[2] + v[3], v[4] + v[5], v[6] + v[7]),
                    blue_contract(v[0], v[2], v[4], v[6]),
                ]
            }
        }
        // Unsupported (HDR) modes. Their values still occupy the stream, so
        // skip them; otherwise the following partitions would decode
        // misaligned endpoints.
        _ => {
            let owned = (((endpoint_mods >> 2) + 1) * 2) as usize;
            let remaining: &[u8] = *color_values;
            *color_values = &remaining[owned.min(remaining.len())..];
            [[0xFF, 0, 0xFF, 0xFF]; 2]
        }
    }
}
/// Block footprint of a 2D ASTC texture: the texel dimensions of one
/// 128-bit block.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Footprint {
    block_width: u32,
    block_height: u32,
}
impl Footprint {
    /// 4x4 texel blocks (8.00 bits per texel).
    pub const ASTC_4X4: Footprint = Footprint::of(4, 4);
    /// 5x4 texel blocks (6.40 bits per texel).
    pub const ASTC_5X4: Footprint = Footprint::of(5, 4);
    /// 5x5 texel blocks (5.12 bits per texel).
    pub const ASTC_5X5: Footprint = Footprint::of(5, 5);
    /// 6x5 texel blocks (4.27 bits per texel).
    pub const ASTC_6X5: Footprint = Footprint::of(6, 5);
    /// 6x6 texel blocks (3.56 bits per texel).
    pub const ASTC_6X6: Footprint = Footprint::of(6, 6);
    /// 8x5 texel blocks (3.20 bits per texel).
    pub const ASTC_8X5: Footprint = Footprint::of(8, 5);
    /// 8x6 texel blocks (2.67 bits per texel).
    pub const ASTC_8X6: Footprint = Footprint::of(8, 6);
    /// 10x5 texel blocks (2.56 bits per texel).
    pub const ASTC_10X5: Footprint = Footprint::of(10, 5);
    /// 10x6 texel blocks (2.13 bits per texel).
    pub const ASTC_10X6: Footprint = Footprint::of(10, 6);
    /// 8x8 texel blocks (2.00 bits per texel).
    pub const ASTC_8X8: Footprint = Footprint::of(8, 8);
    /// 10x8 texel blocks (1.60 bits per texel).
    pub const ASTC_10X8: Footprint = Footprint::of(10, 8);
    /// 10x10 texel blocks (1.28 bits per texel).
    pub const ASTC_10X10: Footprint = Footprint::of(10, 10);
    /// 12x10 texel blocks (1.07 bits per texel).
    pub const ASTC_12X10: Footprint = Footprint::of(12, 10);
    /// 12x12 texel blocks (0.89 bits per texel).
    pub const ASTC_12X12: Footprint = Footprint::of(12, 12);

    /// Creates a footprint of `block_width` x `block_height` texels.
    ///
    /// # Panics
    /// Panics with "Invalid block size" if either dimension is zero.
    pub fn new(block_width: u32, block_height: u32) -> Footprint {
        assert!(block_width != 0 && block_height != 0, "Invalid block size");
        Self::of(block_width, block_height)
    }
    /// Width of a block, in texels.
    pub fn block_width(&self) -> u32 {
        self.block_width
    }
    /// Height of a block, in texels.
    pub fn block_height(&self) -> u32 {
        self.block_height
    }
    /// Unchecked const constructor backing the predefined footprints.
    const fn of(block_width: u32, block_height: u32) -> Footprint {
        Footprint {
            block_width,
            block_height,
        }
    }
}
/// Decodes a single 128-bit ASTC block (LDR, 2D) into RGBA8 texels.
///
/// `input` is the 16-byte block, read as a little-endian `u128`, and
/// `footprint` gives the block size in texels. `writer(x, y, rgba)` is called
/// once per texel of the block, in row-major order, with `x`/`y` relative to
/// the block's top-left corner.
///
/// Returns `true` on success. Returns `false` for blocks this decoder rejects,
/// and in that case writes every texel in magenta (`[0xFF, 0, 0xFF, 0xFF]`).
/// Rejected blocks are:
/// - reserved or invalid encodings;
/// - HDR void-extent blocks;
/// - weight grids larger than the footprint or with an out-of-range size;
/// - four partitions combined with dual-plane weights;
/// - blocks with too few color bits.
///
/// Partitions that use unsupported (HDR) color endpoint modes are also drawn
/// in magenta, but the function still returns `true`.
pub fn astc_decode_block<F: FnMut(u32, u32, [u8; 4])>(
input: &[u8; 16],
footprint: Footprint,
mut writer: F,
) -> bool {
let block_width = footprint.block_width;
let block_height = footprint.block_height;
let mut strm = InputBitStream::new(u128::from_le_bytes(*input));
// Block mode: weight-grid size, weight range and dual-plane flag.
let weight_params = decode_block_info(&mut strm);
if weight_params.is_error {
fill_error(&mut writer, block_width, block_height);
return false;
}
if weight_params.void_extent_ldr {
fill_void_extent_ldr(&mut strm, &mut writer, block_width, block_height);
return true;
}
if weight_params.void_extent_hdr {
fill_error(&mut writer, block_width, block_height);
return false;
}
// The weight grid may not be denser than the block itself.
if weight_params.width > block_width {
fill_error(&mut writer, block_width, block_height);
return false;
}
if weight_params.height > block_height {
fill_error(&mut writer, block_width, block_height);
return false;
}
if weight_params.get_num_weight_values() > 64 {
fill_error(&mut writer, block_width, block_height);
return false;
}
let n_weight_bits = weight_params.get_packed_bit_size();
if !(24..=96).contains(&n_weight_bits) {
fill_error(&mut writer, block_width, block_height);
return false;
}
// Bits 11..=12 hold the partition count minus one.
let n_partitions = (strm.read_bits(2) + 1) as usize;
assert!(n_partitions <= 4);
if n_partitions == 4 && weight_params.is_dual_plane {
fill_error(&mut writer, block_width, block_height);
return false;
}
let plane_idx;
let partition_index;
let mut endpoint_mods = [0, 0, 0, 0];
let endpoint_mods = &mut endpoint_mods[0..n_partitions];
let mut base_cem = 0;
// One partition: a 4-bit color endpoint mode (CEM).
// Several partitions: a 10-bit partition seed, then a 6-bit CEM field.
if n_partitions == 1 {
endpoint_mods[0] = strm.read_bits(4);
partition_index = 0;
} else {
partition_index = strm.read_bits(10);
base_cem = strm.read_bits(6);
}
// The low two bits of the CEM field decide the layout:
// - 0: all partitions share the mode in bits 2..=5;
// - otherwise: a per-partition class base, with extra CEM bits stored
//   elsewhere in the block.
let base_mode = base_cem & 3;
// Non-color bits are the header read so far, the weights, the extra CEM
// bits and the dual-plane selector.
let mut non_color_bits = n_weight_bits + strm.get_bits_read();
let mut extra_cem_bits = 0;
if base_mode != 0 {
match n_partitions {
2 => extra_cem_bits += 2,
3 => extra_cem_bits += 5,
4 => extra_cem_bits += 8,
_ => unreachable!(),
}
}
non_color_bits += extra_cem_bits;
let mut plane_selector_bits = 0;
if weight_params.is_dual_plane {
plane_selector_bits = 2;
}
non_color_bits += plane_selector_bits;
if non_color_bits >= 128 {
fill_error(&mut writer, block_width, block_height);
return false;
}
let color_data_bits = 128 - non_color_bits;
// Bit layout, from low to high after the header:
// color endpoint data, the 2-bit plane selector, then the extra CEM bits.
// The weights occupy the remaining top bits, as the assert below checks.
let endpoint_data = strm.read_bits128(color_data_bits);
plane_idx = strm.read_bits(plane_selector_bits);
if base_mode != 0 {
// Reassemble the full CEM. Dropping the 2-bit base class leaves
// one class-offset bit per partition, then a 2-bit mode per partition.
let extra_cem = strm.read_bits(extra_cem_bits);
let mut cem = (extra_cem << 6) | base_cem;
cem >>= 2;
let mut c = [false; 4];
for c in &mut c[0..n_partitions] {
*c = (cem & 1) != 0;
cem >>= 1;
}
let mut m = [0; 4];
for m in &mut m[0..n_partitions] {
*m = cem & 3;
cem >>= 2;
}
for (i, endpoint_mod) in endpoint_mods.iter_mut().enumerate() {
*endpoint_mod = base_mode;
if !c[i] {
*endpoint_mod -= 1;
}
*endpoint_mod <<= 2;
*endpoint_mod |= m[i];
}
} else if n_partitions > 1 {
let cem = base_cem >> 2;
endpoint_mods[0..n_partitions].fill(cem);
}
for &endpoint_mod in endpoint_mods.iter() {
assert!(endpoint_mod < 16);
}
assert!(strm.get_bits_read() + weight_params.get_packed_bit_size() == 128);
// Each CEM needs ((cem >> 2) + 1) * 2 values. (13n + 4) / 5 is the bit cost
// of n values in the trit-plus-one-bit (0..=5) encoding, the smallest range
// the color decoder can select. Reject the block if even that does not fit.
let n_values = endpoint_mods.iter().map(|m| ((m >> 2) + 1) << 1).sum();
if n_values > 18 || (n_values * 13 + 4) / 5 > color_data_bits {
fill_error(&mut writer, block_width, block_height);
return false;
}
let color_values = decode_color_values(endpoint_data, n_values, color_data_bits);
let mut endpoints = [[[0; 4]; 2]; 4];
let mut color_values_ptr = &color_values[0..n_values as usize];
for i in 0..n_partitions {
endpoints[i] = compute_endpoints(&mut color_values_ptr, endpoint_mods[i]);
}
// Weights are stored bit-reversed from the top of the block. Reverse the
// whole block and keep the low `packed size` bits.
let mut texel_weight_data = u128::from_le_bytes(*input).reverse_bits();
texel_weight_data &= (1 << weight_params.get_packed_bit_size()) - 1;
let mut weight_stream = InputBitStream::new(texel_weight_data);
let weights = unquantize_texel_weights(
&mut weight_stream,
&weight_params,
block_width,
block_height,
);
for j in 0..block_height {
for i in 0..block_width {
let partition = select_2d_partition(
partition_index,
i,
j,
n_partitions,
(block_height * block_width) < 32,
);
assert!(partition < n_partitions);
let mut p = [0; 4];
for (c, p) in p.iter_mut().enumerate() {
// Expand the 8-bit endpoints to 16 bits (x * 0x101) and
// interpolate with the 0..=64 weight. Then scale back to 8 bits
// with rounding.
let c0 = endpoints[partition][0][c] as u32 * 0x101;
let c1 = endpoints[partition][1][c] as u32 * 0x101;
// With dual-plane weights, the channel chosen by the plane
// selector uses the second weight plane.
let mut plane = 0;
if weight_params.is_dual_plane && (plane_idx & 3 == c as u32) {
plane = 1;
}
let weight = weights[plane][(j * block_width + i) as usize];
let color = (c0 * (64 - weight) + c1 * weight + 32) / 64;
*p = u8::try_from(((color * 255) + 32767) / 65536).unwrap();
}
writer(i, j, p);
}
}
true
}
/// Decodes a whole ASTC image, block by block, from `input`.
///
/// Blocks are read as consecutive 16-byte chunks in row-major block order
/// (left to right, then top to bottom). `width` and `height` are rounded up
/// to whole blocks. `writer(x, y, rgba)` is called for every texel inside
/// the `width` x `height` image; texels of partial edge blocks that fall
/// outside it are dropped. Blocks that fail to decode are written in magenta
/// instead of being reported as errors.
///
/// # Errors
/// Returns any I/O error from `input`, such as `UnexpectedEof` when it holds
/// fewer blocks than the image needs.
pub fn astc_decode<R: Read, F: FnMut(u32, u32, [u8; 4])>(
    mut input: R,
    width: u32,
    height: u32,
    footprint: Footprint,
    mut writer: F,
) -> Result<()> {
    let block_width = footprint.block_width;
    let block_height = footprint.block_height;
    // Ceiling division that cannot overflow. The previous
    // `width + block_width - 1` panicked for sizes near `u32::MAX`.
    let blocks_x = width / block_width + u32::from(width % block_width != 0);
    let blocks_y = height / block_height + u32::from(height % block_height != 0);
    for by in 0..blocks_y {
        // Block origins always lie inside the image, so the products below
        // and the `width - x0` / `height - y0` differences cannot overflow.
        let y0 = by * block_height;
        for bx in 0..blocks_x {
            let x0 = bx * block_width;
            let mut block_buf = [0; 16];
            input.read_exact(&mut block_buf)?;
            astc_decode_block(&block_buf, footprint, |x, y, v| {
                if x < width - x0 && y < height - y0 {
                    writer(x0 + x, y0 + y, v)
                }
            });
        }
    }
    Ok(())
}
/// Regression tests: decode reference `.astc` files and compare them against
/// bitmaps, plus random-block fuzzing of `astc_decode_block`.
#[cfg(test)]
mod tests {
use super::*;
use image::Pixel;
/// Asserts that two 8-bit channel values differ by at most 1.
/// Despite the name, it returns nothing.
fn dist(a: u8, b: u8) {
assert!((a as i32 - b as i32).abs() <= 1)
}
/// Decodes `astc` and checks every texel against the bitmap `bmp`.
///
/// The first 16 bytes of `astc` are skipped (presumably the `.astc` file
/// header — TODO confirm). Rows are compared vertically flipped
/// (`height - y - 1`), presumably because the fixtures were produced with
/// opposite row order.
fn test_case(astc: &[u8], bmp: &[u8], block_width: u32, block_height: u32) {
let bmp = image::load_from_memory(bmp).unwrap().to_rgba8();
let width = bmp.width();
let height = bmp.height();
astc_decode(
&astc[16..],
width,
height,
Footprint::new(block_width, block_height),
|x, y, v| {
let y = height - y - 1;
let p = bmp.get_pixel(x as u32, y as u32).channels();
dist(p[0], v[0]);
dist(p[1], v[1]);
dist(p[2], v[2]);
dist(p[3], v[3]);
},
)
.unwrap();
}
/// `tc!(name, w, h)` runs `test_case` on `test-data/<name>_<w>x<h>.astc`
/// and the matching `.bmp`, with a `w` x `h` footprint.
macro_rules! tc {
($name:literal, $bw:literal, $bh:literal) => {
test_case(
include_bytes!(concat!("test-data/", $name, '_', $bw, 'x', $bh, ".astc")),
include_bytes!(concat!("test-data/", $name, '_', $bw, 'x', $bh, ".bmp")),
$bw,
$bh,
);
};
}
#[test]
fn real_image() {
tc!("atlas_small", 4, 4);
tc!("atlas_small", 5, 5);
tc!("atlas_small", 6, 6);
tc!("atlas_small", 8, 8);
tc!("footprint", 4, 4);
tc!("footprint", 5, 4);
tc!("footprint", 5, 5);
tc!("footprint", 6, 5);
tc!("footprint", 6, 6);
tc!("footprint", 8, 5);
tc!("footprint", 8, 6);
tc!("footprint", 8, 8);
tc!("footprint", 10, 5);
tc!("footprint", 10, 6);
tc!("footprint", 10, 8);
tc!("footprint", 10, 10);
tc!("footprint", 12, 10);
tc!("footprint", 12, 12);
tc!("rgb", 4, 4);
tc!("rgb", 5, 4);
tc!("rgb", 6, 6);
tc!("rgb", 8, 8);
tc!("rgb", 12, 12);
}
/// Decodes 10 000 random blocks with a `w` x `h` footprint.
/// The only check is that decoding never panics.
fn fuzz_fp(w: u32, h: u32) {
let footprint = Footprint::new(w, h);
for _ in 0..10000 {
let block = rand::random();
astc_decode_block(&block, footprint, |_, _, _| {});
}
}
#[test]
fn fuzzing() {
fuzz_fp(4, 4);
fuzz_fp(5, 4);
fuzz_fp(5, 5);
fuzz_fp(6, 5);
fuzz_fp(6, 6);
fuzz_fp(8, 5);
fuzz_fp(8, 6);
fuzz_fp(8, 8);
fuzz_fp(10, 5);
fuzz_fp(10, 6);
fuzz_fp(10, 8);
fuzz_fp(10, 10);
fuzz_fp(12, 10);
fuzz_fp(12, 12);
}
}