1#![allow(clippy::many_single_char_names)]
36
37use std::convert::TryFrom;
38use std::io::*;
39
/// Bit reader over a single 128-bit word.
///
/// Bits are consumed starting from the least-significant end of `data`.
struct InputBitStream {
    /// Bits not consumed yet, aligned so the next bit to read is bit 0.
    data: u128,
    /// Total number of bits consumed since construction.
    bits_read: u32,
}

impl InputBitStream {
    /// Creates a stream positioned at the least-significant bit of `data`.
    fn new(data: u128) -> InputBitStream {
        InputBitStream { data, bits_read: 0 }
    }

    /// Returns how many bits have been consumed so far.
    fn get_bits_read(&self) -> u32 {
        self.bits_read
    }

    /// Reads a single bit and returns it as 0 or 1.
    fn read_bit(&mut self) -> u32 {
        self.read_bits(1)
    }

    /// Reads `n_bits` (at most 32) and returns them in the low bits of a `u32`.
    ///
    /// # Panics
    /// Panics if `n_bits > 32` or if the read would go past bit 128.
    fn read_bits(&mut self, n_bits: u32) -> u32 {
        assert!(n_bits <= 32);
        self.read_bits128(n_bits) as u32
    }

    /// Reads `n_bits` (0 to 128 inclusive) and returns them in the low bits of a `u128`.
    ///
    /// # Panics
    /// Panics if the read would consume more than 128 bits in total.
    fn read_bits128(&mut self, n_bits: u32) -> u128 {
        // `bits_read` never exceeds 128, so the subtraction cannot underflow.
        assert!(n_bits <= 128 - self.bits_read);
        self.bits_read += n_bits;

        // A plain `1 << 128` / `>> 128` overflows the shift amount; the checked
        // variants let a full-width read return every bit and drain the stream.
        let mask = 1u128.checked_shl(n_bits).map_or(u128::MAX, |bit| bit - 1);
        let ret = self.data & mask;
        self.data = self.data.checked_shr(n_bits).unwrap_or(0);
        ret
    }
}
71
/// Wrapper around a `u32` offering bit-field style accessors.
struct Bits(u32);

impl Bits {
    /// Returns bit `pos` (0 = least significant) as 0 or 1.
    fn get(&self, pos: u32) -> u32 {
        let word = self.0;
        (word >> pos) & 1
    }

    /// Returns the inclusive bit range `start..=end`, shifted down so that
    /// bit `start` lands at bit 0.
    fn range(&self, start: u32, end: u32) -> u32 {
        let width = end - start + 1;
        let field_mask = (1 << width) - 1;
        let word = self.0;
        field_mask & (word >> start)
    }
}
84
/// How a sequence of bounded integers is packed.
///
/// The variant order is significant: discriminants are compared numerically
/// when the encoding tables are built.
#[derive(Clone, Copy, PartialEq, Eq)]
enum IntegerEncodingType {
    /// Each value is stored as plain bits only.
    JustBits,
    /// Each value is a base-5 digit plus plain bits.
    Quint,
    /// Each value is a base-3 digit plus plain bits.
    Trit,
}

/// One packing scheme: the digit kind plus the number of plain bits per value.
#[derive(Clone, Copy, PartialEq, Eq)]
struct IntegerEncoding {
    encoding: IntegerEncodingType,
    num_bits: u32,
}

impl IntegerEncoding {
    /// Returns the number of bits needed to store `n_vals` values with this
    /// encoding: the plain bits plus `ceil(8n/5)` for trits or `ceil(7n/3)`
    /// for quints.
    fn get_bit_length(&self, n_vals: u32) -> u32 {
        let plain_bits = self.num_bits * n_vals;
        let digit_bits = match self.encoding {
            IntegerEncodingType::JustBits => 0,
            IntegerEncodingType::Trit => (n_vals * 8 + 4) / 5,
            IntegerEncodingType::Quint => (n_vals * 7 + 2) / 3,
        };
        plain_bits + digit_bits
    }
}
110
/// Combines the unquantization terms `a`, `b`, `c`, `d` into an 8-bit colour
/// value: `t = (d * c + b) ^ a`, result `(a & 0x80) | (t >> 2)`.
///
/// # Panics
/// Panics if the combined value does not fit in a `u8`.
fn decode_color(a: u32, b: u32, c: u32, d: u32) -> u8 {
    let scrambled = a ^ (d * c + b);
    let top_bit = a & 0x80;
    let value = top_bit | (scrambled >> 2);
    u8::try_from(value).unwrap()
}
115
/// Combines the unquantization terms `a`, `b`, `c`, `d` into a weight:
/// `t = (d * c + b) ^ a`, result `(a & 0x20) | (t >> 2)`.
fn decode_weight(a: u32, b: u32, c: u32, d: u32) -> u32 {
    let scrambled = a ^ (d * c + b);
    let top_bit = a & 0x20;
    top_bit | (scrambled >> 2)
}
120
121struct Trit {
122 trit_value: u32,
123 bit_value: u32,
124}
125
126impl Trit {
127 fn decode_color(self, bitlen: u32) -> u8 {
128 let bitval = self.bit_value;
129 let a = (bitval & 1) * 0x1FF;
130 let x = bitval >> 1;
131 let b;
132 let c;
133 match bitlen {
134 1 => {
135 c = 204;
136 b = 0;
137 }
138
139 2 => {
140 c = 93;
141 b = (x << 8) | (x << 4) | (x << 2) | (x << 1);
143 }
144
145 3 => {
146 c = 44;
147 b = (x << 7) | (x << 2) | x;
149 }
150
151 4 => {
152 c = 22;
153 b = (x << 6) | x;
155 }
156
157 5 => {
158 c = 11;
159 b = (x << 5) | (x >> 2);
161 }
162
163 6 => {
164 c = 5;
165 b = (x << 4) | (x >> 4);
167 }
168
169 _ => unreachable!("Invalid trit encoding for color values"),
170 }
171 decode_color(a, b, c, self.trit_value)
172 }
173
174 fn decode_weight(self, bitlen: u32) -> u32 {
175 let bitval = self.bit_value;
176 let a = (bitval & 1) * 0x7F;
177 let x = bitval >> 1;
178 let b;
179 let c;
180 match bitlen {
181 0 => {
182 return [0, 32, 63][self.trit_value as usize];
183 }
184
185 1 => {
186 c = 50;
187 b = 0;
188 }
189
190 2 => {
191 c = 23;
192 b = (x << 6) | (x << 2) | x;
193 }
194
195 3 => {
196 c = 11;
197 b = (x << 5) | x;
198 }
199
200 _ => unreachable!("Invalid trit encoding for texel weight"),
201 }
202 decode_weight(a, b, c, self.trit_value)
203 }
204}
205
206struct Quint {
207 quint_value: u32,
208 bit_value: u32,
209}
210
211impl Quint {
212 fn decode_color(self, bitlen: u32) -> u8 {
213 let bitval = self.bit_value;
214 let a = (bitval & 1) * 0x1FF;
215 let x = bitval >> 1;
216 let b;
217 let c;
218 match bitlen {
219 1 => {
220 c = 113;
221 b = 0;
222 }
223
224 2 => {
225 c = 54;
226 b = (x << 8) | (x << 3) | (x << 2);
228 }
229
230 3 => {
231 c = 26;
232 b = (x << 7) | (x << 1) | (x >> 1);
234 }
235
236 4 => {
237 c = 13;
238 b = (x << 6) | (x >> 1);
240 }
241
242 5 => {
243 c = 6;
244 b = (x << 5) | (x >> 3);
246 }
247
248 _ => unreachable!("Invalid quint encoding for color values"),
249 }
250 decode_color(a, b, c, self.quint_value)
251 }
252
253 fn decode_weight(self, bitlen: u32) -> u32 {
254 let bitval = self.bit_value;
255 let a = (bitval & 1) * 0x7F;
256 let b;
257 let c;
258 match bitlen {
259 0 => {
260 return [0, 16, 32, 47, 63][self.quint_value as usize];
261 }
262
263 1 => {
264 c = 28;
265 b = 0;
266 }
267
268 2 => {
269 c = 13;
270 let x = bitval >> 1;
271 b = (x << 6) | (x << 1);
272 }
273
274 _ => unreachable!("Invalid quint encoding for texel weight"),
275 }
276 decode_weight(a, b, c, self.quint_value)
277 }
278}
279
/// Reads one trit block from `bits` and yields its five values.
///
/// A block stores five values. Each value has `bits_per_value` plain bits, and
/// the five base-3 digits are packed together into 8 bits that are interleaved
/// between the plain-bit fields. All bits for the block are consumed from the
/// stream before the iterator is returned.
fn decode_trit_block(bits: &mut InputBitStream, bits_per_value: u32) -> impl Iterator<Item = Trit> {
    // m: plain bits of each value; t: decoded trit of each value.
    let mut m = [0u32; 5];
    let mut t = [0u32; 5];
    // tt: the 8 packed trit bits, gathered in pieces of 2, 2, 1, 2 and 1 bits.
    let mut tt: u32;

    // The read order below defines the bitstream layout; do not reorder.
    m[0] = bits.read_bits(bits_per_value);
    tt = bits.read_bits(2);
    m[1] = bits.read_bits(bits_per_value);
    tt |= bits.read_bits(2) << 2;
    m[2] = bits.read_bits(bits_per_value);
    tt |= (bits.read_bit()) << 4;
    m[3] = bits.read_bits(bits_per_value);
    tt |= bits.read_bits(2) << 5;
    m[4] = bits.read_bits(bits_per_value);
    tt |= (bits.read_bit()) << 7;

    // c: intermediate 5-bit field from which t[0..=2] are derived.
    let c: u32;

    // First stage: split the packed byte into t[3], t[4] and `c`.
    let tb = Bits(tt);
    if tb.range(2, 4) == 7 {
        c = (tb.range(5, 7) << 2) | tb.range(0, 1);
        t[3] = 2;
        t[4] = 2;
    } else {
        c = tb.range(0, 4);
        if tb.range(5, 6) == 3 {
            t[4] = 2;
            t[3] = tb.get(7);
        } else {
            t[4] = tb.get(7);
            t[3] = tb.range(5, 6);
        }
    }

    // Second stage: derive t[0], t[1] and t[2] from `c`.
    let cb = Bits(c);
    if cb.range(0, 1) == 3 {
        t[2] = 2;
        t[1] = cb.get(4);
        t[0] = (cb.get(3) << 1) | (cb.get(2) & !cb.get(3));
    } else if cb.range(2, 3) == 3 {
        t[2] = 2;
        t[1] = 2;
        t[0] = cb.range(0, 1);
    } else {
        t[2] = cb.get(4);
        t[1] = cb.range(2, 3);
        t[0] = (cb.get(1) << 1) | (cb.get(0) & !cb.get(1));
    }

    // Pair each value's plain bits with its trit.
    IntoIterator::into_iter(m)
        .zip(t)
        .map(|(bit_value, trit_value)| Trit {
            trit_value,
            bit_value,
        })
}
339
/// Reads one quint block from `bits` and yields its three values.
///
/// A block stores three values. Each value has `bits_per_value` plain bits, and
/// the three base-5 digits are packed together into 7 bits that are interleaved
/// between the plain-bit fields. All bits for the block are consumed from the
/// stream before the iterator is returned.
fn decode_quint_block(
    bits: &mut InputBitStream,
    bits_per_value: u32,
) -> impl Iterator<Item = Quint> {
    // m: plain bits of each value; q: decoded quint of each value.
    let mut m = [0u32; 3];
    let mut q = [0u32; 3];
    // qq: the 7 packed quint bits, gathered in pieces of 3, 2 and 2 bits.
    let mut qq: u32;

    // The read order below defines the bitstream layout; do not reorder.
    m[0] = bits.read_bits(bits_per_value);
    qq = bits.read_bits(3);
    m[1] = bits.read_bits(bits_per_value);
    qq |= bits.read_bits(2) << 3;
    m[2] = bits.read_bits(bits_per_value);
    qq |= bits.read_bits(2) << 5;

    let qb = Bits(qq);
    if qb.range(1, 2) == 3 && qb.range(5, 6) == 0 {
        // Special pattern: the first two quints are both 4.
        q[0] = 4;
        q[1] = 4;
        q[2] = (qb.get(0) << 2) | ((qb.get(4) & !qb.get(0)) << 1) | (qb.get(3) & !qb.get(0));
    } else {
        // c: intermediate 5-bit field from which q[0] and q[1] are derived.
        let c;
        if qb.range(1, 2) == 3 {
            q[2] = 4;
            c = (qb.range(3, 4) << 3) | ((!qb.range(5, 6) & 3) << 1) | qb.get(0);
        } else {
            q[2] = qb.range(5, 6);
            c = qb.range(0, 4);
        }

        let cb = Bits(c);
        if cb.range(0, 2) == 5 {
            q[1] = 4;
            q[0] = cb.range(3, 4);
        } else {
            q[1] = cb.range(3, 4);
            q[0] = cb.range(0, 2);
        }
    }

    // Pair each value's plain bits with its quint.
    IntoIterator::into_iter(m)
        .zip(q)
        .map(|(bit_value, quint_value)| Quint {
            quint_value,
            bit_value,
        })
}
390
391const fn create_encoding(mut max_val: u32) -> IntegerEncoding {
393 while max_val > 0 {
394 let check = max_val + 1;
395
396 if (check & (check - 1)) == 0 {
398 return IntegerEncoding {
399 encoding: IntegerEncodingType::JustBits,
400 num_bits: max_val.count_ones(),
401 };
402 }
403
404 if (check % 3 == 0) && ((check / 3) & ((check / 3) - 1)) == 0 {
406 return IntegerEncoding {
407 encoding: IntegerEncodingType::Trit,
408 num_bits: (check / 3 - 1).count_ones(),
409 };
410 }
411
412 if (check % 5 == 0) && ((check / 5) & ((check / 5) - 1)) == 0 {
414 return IntegerEncoding {
415 encoding: IntegerEncodingType::Quint,
416 num_bits: (check / 5 - 1).count_ones(),
417 };
418 }
419
420 max_val -= 1;
423 }
424 IntegerEncoding {
425 encoding: IntegerEncodingType::JustBits,
426 num_bits: 0,
427 }
428}
429
430static ENCODING_MAP: [IntegerEncoding; 256] = {
431 let mut result = [IntegerEncoding {
432 encoding: IntegerEncodingType::JustBits,
433 num_bits: 0,
434 }; 256];
435 let mut i = 0;
436 while i < 256 {
437 result[i as usize] = create_encoding(i);
438 i += 1;
439 }
440 result
441};
442
443static ENCODING_SEQ: ([IntegerEncoding; 256], usize) = {
444 let mut result = [IntegerEncoding {
445 encoding: IntegerEncodingType::JustBits,
446 num_bits: 0,
447 }; 256];
448 let mut len = 1;
449 result[0] = ENCODING_MAP[0];
450 let mut i = 1;
451 while i < 256 {
452 let encoding = ENCODING_MAP[i];
453 let previous = result[len - 1];
454 if encoding.encoding as u32 != previous.encoding as u32
456 || encoding.num_bits != previous.num_bits
457 {
458 result[len] = encoding;
459 len += 1;
460 }
461 i += 1;
462 }
463 (result, len)
464};
465
/// Block-mode information decoded from the start of a block.
///
/// The derived `Default` gives zero for every number and `false` for every flag.
#[derive(Default)]
struct TexelWeightParams {
    /// Width of the stored weight grid.
    width: u32,
    /// Height of the stored weight grid.
    height: u32,
    /// Whether two weight planes are stored.
    is_dual_plane: bool,
    /// Largest quantized weight value; indexes `ENCODING_MAP`.
    max_weight: u32,
    /// Set when the block mode is invalid or reserved.
    is_error: bool,
    /// Set for a void-extent block with the HDR flag clear.
    void_extent_ldr: bool,
    /// Set for a void-extent block with the HDR flag set.
    void_extent_hdr: bool,
}
489
490impl TexelWeightParams {
491 fn get_packed_bit_size(&self) -> u32 {
492 ENCODING_MAP[self.max_weight as usize].get_bit_length(self.get_num_weight_values())
493 }
494
495 fn get_num_weight_values(&self) -> u32 {
496 let mut ret = self.width * self.height;
497 if self.is_dual_plane {
498 ret *= 2;
499 }
500 ret
501 }
502}
503
/// Decodes the block mode at the start of `strm`.
///
/// Consumes the first 11 bits of the stream (plus one more bit for
/// void-extent blocks) and returns the weight grid size, the maximum weight
/// value, the dual-plane flag, and the void-extent / error flags. When
/// `is_error` or a void-extent flag is set, the remaining fields keep their
/// default values.
fn decode_block_info(strm: &mut InputBitStream) -> TexelWeightParams {
    let mut params = TexelWeightParams::default();

    let mode_bits = strm.read_bits(11);

    // Void-extent block: low nine bits are 1_1111_1100.
    if (mode_bits & 0x01FF) == 0x1FC {
        // Bit 9 selects HDR versus LDR void extent.
        if mode_bits & 0x200 != 0 {
            params.void_extent_hdr = true;
        } else {
            params.void_extent_ldr = true;
        }

        // Bit 10 and the following stream bit must both be set.
        if (mode_bits & 0x400) == 0 || strm.read_bit() == 0 {
            params.is_error = true;
        }

        return params;
    }

    // Low four bits all zero: treated as an invalid mode.
    if (mode_bits & 0xF) == 0 {
        params.is_error = true;
        return params;
    }

    // Low two bits zero with bits 6..=8 all set: treated as an invalid mode.
    if (mode_bits & 0x3) == 0 && (mode_bits & 0x1C0) == 0x1C0 {
        params.is_error = true;
        return params;
    }

    // `layout` identifies which bit layout the mode uses; it decides below
    // how width, height and the range selector are extracted.
    let layout;

    if (mode_bits & 0x1) != 0 || (mode_bits & 0x2) != 0 {
        // Low two bits non-zero: layouts 0..=4, selected by bits 2, 3 and 8.
        if (mode_bits & 0x8) != 0 {
            if (mode_bits & 0x4) != 0 {
                if (mode_bits & 0x100) != 0 {
                    layout = 4;
                } else {
                    layout = 3;
                }
            } else {
                layout = 2;
            }
        } else {
            if (mode_bits & 0x4) != 0 {
                layout = 1;
            } else {
                layout = 0;
            }
        }
    } else {
        // Low two bits zero: layouts 5..=9, selected by bits 5, 7 and 8.
        if (mode_bits & 0x100) != 0 {
            if (mode_bits & 0x80) != 0 {
                // Bit 6 set here was already rejected by the 0x1C0 check above.
                assert!((mode_bits & 0x40) == 0);
                if (mode_bits & 0x20) != 0 {
                    layout = 8;
                } else {
                    layout = 7;
                }
            } else {
                layout = 9;
            }
        } else {
            if (mode_bits & 0x80) != 0 {
                layout = 6;
            } else {
                layout = 5;
            }
        }
    }

    // Range selector `r`: bit 4 is its low bit; the two high bits come from
    // bits 0..=1 (layouts 0..=4) or bits 2..=3 (layouts 5..=9).
    let mut r = (mode_bits & 0x10) >> 4;
    if layout < 5 {
        r |= (mode_bits & 0x3) << 1;
    } else {
        r |= (mode_bits & 0xC) >> 1;
    }
    assert!((2..=7).contains(&r));

    // Weight grid dimensions for each layout; `a` and `b` are small fields
    // taken from the mode bits.
    match layout {
        0 => {
            let a = (mode_bits >> 5) & 0x3;
            let b = (mode_bits >> 7) & 0x3;
            params.width = b + 4;
            params.height = a + 2;
        }

        1 => {
            let a = (mode_bits >> 5) & 0x3;
            let b = (mode_bits >> 7) & 0x3;
            params.width = b + 8;
            params.height = a + 2;
        }

        2 => {
            let a = (mode_bits >> 5) & 0x3;
            let b = (mode_bits >> 7) & 0x3;
            params.width = a + 2;
            params.height = b + 8;
        }

        3 => {
            let a = (mode_bits >> 5) & 0x3;
            let b = (mode_bits >> 7) & 0x1;
            params.width = a + 2;
            params.height = b + 6;
        }

        4 => {
            let a = (mode_bits >> 5) & 0x3;
            let b = (mode_bits >> 7) & 0x1;
            params.width = b + 2;
            params.height = a + 2;
        }

        5 => {
            let a = (mode_bits >> 5) & 0x3;
            params.width = 12;
            params.height = a + 2;
        }

        6 => {
            let a = (mode_bits >> 5) & 0x3;
            params.width = a + 2;
            params.height = 12;
        }

        7 => {
            params.width = 6;
            params.height = 10;
        }

        8 => {
            params.width = 10;
            params.height = 6;
        }

        9 => {
            // Layout 9 reuses bits 9..=10 for `b`, so it has no dual-plane
            // or precision bits (see `dp` and `p` below).
            let a = (mode_bits >> 5) & 0x3;
            let b = (mode_bits >> 9) & 0x3;
            params.width = a + 6;
            params.height = b + 6;
        }

        _ => unreachable!("Impossible layout"),
    }

    // dp: dual-plane flag (bit 10); p: selects the larger weight ranges (bit 9).
    let dp = (layout != 9) && (mode_bits & 0x400) != 0;
    let p = (layout != 9) && (mode_bits & 0x200) != 0;

    // Maximum weight value for each range selector r = 2..=7.
    let max_weights = if p {
        [9, 11, 15, 19, 23, 31]
    } else {
        [1, 2, 3, 4, 5, 7]
    };
    params.max_weight = max_weights[(r - 2) as usize];

    params.is_dual_plane = dp;

    params
}
686
687fn fill_void_extent_ldr<F: FnMut(u32, u32, [u8; 4])>(
688 strm: &mut InputBitStream,
689 writer: &mut F,
690 block_width: u32,
691 block_height: u32,
692) {
693 for _ in 0..4 {
695 strm.read_bits(13);
696 }
697
698 let r = strm.read_bits(16) >> 8;
700 let g = strm.read_bits(16) >> 8;
701 let b = strm.read_bits(16) >> 8;
702 let a = strm.read_bits(16) >> 8;
703
704 for j in 0..block_height {
705 for i in 0..block_width {
706 writer(i, j, [r as u8, g as u8, b as u8, a as u8]);
707 }
708 }
709}
710
/// Writes the error colour (opaque magenta) to every texel of a block.
///
/// `writer` is called once per texel in row-major order with `(x, y, rgba)`.
fn fill_error<F: FnMut(u32, u32, [u8; 4])>(writer: &mut F, block_width: u32, block_height: u32) {
    const ERROR_COLOR: [u8; 4] = [0xFF, 0, 0xFF, 0xFF];

    (0..block_height)
        .flat_map(|row| (0..block_width).map(move |col| (col, row)))
        .for_each(|(col, row)| writer(col, row, ERROR_COLOR));
}
718
/// Expands a `num_bits`-wide value to `to_bit` bits by repeating its bit
/// pattern from the top downwards.
///
/// Returns 0 when either width is 0. Callers are expected to pass
/// `num_bits <= to_bit`.
fn replicate(val: u32, num_bits: u32, to_bit: u32) -> u32 {
    if num_bits == 0 || to_bit == 0 {
        return 0;
    }

    // Start with the value in the top bits, then keep OR-ing in shifted
    // copies; each pass doubles the length of the filled pattern.
    let mut filled = val << (to_bit - num_bits);
    let mut stride = num_bits;
    while filled >> stride != 0 {
        filled |= filled >> stride;
        stride *= 2;
    }
    filled
}
740
741fn decode_color_values(data: u128, n_values: u32, n_bits_for_color_data: u32) -> [u8; 18] {
742 let mut out = [0; 18];
743 let out_range = &mut out[0..n_values as usize];
744
745 let encoding_i = ENCODING_SEQ.0[0..ENCODING_SEQ.1]
748 .partition_point(|v| v.get_bit_length(n_values) <= n_bits_for_color_data);
749 let encoding = ENCODING_SEQ.0[encoding_i - 1];
750
751 let mut color_stream = InputBitStream::new(data);
753 match encoding.encoding {
754 IntegerEncodingType::JustBits => {
755 for out in out_range {
756 *out = replicate(
757 color_stream.read_bits(encoding.num_bits),
758 encoding.num_bits,
759 8,
760 ) as u8;
761 }
762 }
763 IntegerEncodingType::Trit => {
764 for (out, result) in out_range
765 .iter_mut()
766 .zip((0..).flat_map(|_| decode_trit_block(&mut color_stream, encoding.num_bits)))
767 {
768 *out = result.decode_color(encoding.num_bits);
769 }
770 }
771 IntegerEncodingType::Quint => {
772 for (out, result) in out_range
773 .iter_mut()
774 .zip((0..).flat_map(|_| decode_quint_block(&mut color_stream, encoding.num_bits)))
775 {
776 *out = result.decode_color(encoding.num_bits);
777 }
778 }
779 }
780
781 out
782}
783
/// Decodes the stored weight grid and interpolates it up to one weight per
/// texel of the block.
///
/// Returns two planes of `block_width * block_height` weights in row-major
/// order; the second plane is only filled for dual-plane blocks. Each weight
/// lies in `0..=64`.
fn unquantize_texel_weights(
    weights_stream: &mut InputBitStream,
    params: &TexelWeightParams,
    block_width: u32,
    block_height: u32,
) -> [[u32; 144]; 2] {
    let mut out = [[0; 144]; 2];
    // Weights as stored in the grid; for dual-plane blocks the two planes
    // are interleaved.
    let mut unquantized = [0; 96];

    let plane_scale = if params.is_dual_plane { 2 } else { 1 };

    let unquantized_range = &mut unquantized[0..params.get_num_weight_values() as usize];
    let encoding = ENCODING_MAP[params.max_weight as usize];

    // Step 1: decode each stored weight into the range 0..=63.
    match encoding.encoding {
        IntegerEncodingType::JustBits => {
            for out in unquantized_range.iter_mut() {
                *out = replicate(
                    weights_stream.read_bits(encoding.num_bits),
                    encoding.num_bits,
                    6,
                )
            }
        }
        IntegerEncodingType::Trit => {
            for (out, result) in unquantized_range
                .iter_mut()
                .zip((0..).flat_map(|_| decode_trit_block(weights_stream, encoding.num_bits)))
            {
                *out = result.decode_weight(encoding.num_bits);
            }
        }
        IntegerEncodingType::Quint => {
            for (out, result) in unquantized_range
                .iter_mut()
                .zip((0..).flat_map(|_| decode_quint_block(weights_stream, encoding.num_bits)))
            {
                *out = result.decode_weight(encoding.num_bits);
            }
        }
    }

    // Step 2: stretch 0..=63 to 0..=64 by bumping every value above 32.
    for weight in unquantized_range {
        assert!(*weight < 64);
        if *weight > 32 {
            *weight += 1
        }
    }

    // Scale factors mapping a texel coordinate onto 0..=1024 across the block.
    let ds = (1024 + (block_width / 2)) / (block_width - 1);
    let dt = (1024 + (block_height / 2)) / (block_height - 1);

    // Step 3: bilinear interpolation of the grid for every texel.
    for plane in 0..plane_scale {
        for t in 0..block_height {
            for s in 0..block_width {
                let cs = ds * s;
                let ct = dt * t;

                // Position within the weight grid, with four fractional bits.
                let gs = (cs * (params.width - 1) + 32) >> 6;
                let gt = (ct * (params.height - 1) + 32) >> 6;

                // Integer grid cell (js, jt) and fraction (fs, ft) within it.
                let js = gs >> 4;
                let fs = gs & 0xF;

                let jt = gt >> 4;
                let ft = gt & 0x0F;

                // Bilinear blend factors; the four sum to 16.
                let w11 = (fs * ft + 8) >> 4;
                let w10 = ft - w11;
                let w01 = fs - w11;
                let w00 = 16 + w11 - fs - ft;

                // Index of the top-left grid weight.
                let v0 = js + jt * params.width;

                // Neighbours that fall outside the grid contribute zero.
                let mut p00 = 0;
                let mut p01 = 0;
                let mut p10 = 0;
                let mut p11 = 0;

                if v0 < (params.width * params.height) {
                    p00 = unquantized[plane + plane_scale * (v0 as usize)];
                }

                if v0 + 1 < (params.width * params.height) {
                    p01 = unquantized[plane + plane_scale * ((v0 + 1) as usize)];
                }

                if v0 + params.width < (params.width * params.height) {
                    p10 = unquantized[plane + plane_scale * ((v0 + params.width) as usize)];
                }

                if v0 + params.width + 1 < (params.width * params.height) {
                    p11 = unquantized[plane + plane_scale * ((v0 + params.width + 1) as usize)];
                }

                out[plane][(t * block_width + s) as usize] =
                    (p00 * w00 + p01 * w01 + p10 * w10 + p11 * w11 + 8) >> 4;
            }
        }
    }
    out
}
888
/// Moves the top bit (0x80) of `a` into the top bit of `b`, and turns the
/// rest of `a` into a signed 6-bit offset.
///
/// After the call `b` holds `(b >> 1) | (a & 0x80)` and `a` holds bits 1..=6
/// of its old value, sign-extended from bit 5 (range -32..=31).
fn bit_transfer_signed(a: &mut i32, b: &mut i32) {
    let (old_a, old_b) = (*a, *b);

    *b = (old_b >> 1) | (old_a & 0x80);

    let offset = (old_a >> 1) & 0x3F;
    *a = if offset & 0x20 != 0 {
        offset - 0x40
    } else {
        offset
    };
}
899
/// Integer hash used to derive the partition seeds.
///
/// Every addition and subtraction wraps modulo 2^32.
fn hash52(p: u32) -> u32 {
    let mut v = p;
    v ^= v >> 15;
    v = v.wrapping_sub(v << 17);
    v = v.wrapping_add(v << 7);
    v = v.wrapping_add(v << 4);
    v ^= v >> 5;
    v = v.wrapping_add(v << 16);
    v ^= v >> 7;
    v ^= v >> 3;
    v ^= v << 6;
    v ^ (v >> 17)
}
916
/// Returns the partition (0-based) that texel `(x, y, z)` belongs to.
///
/// `seed` is the partition index taken from the block, `partition_count` is
/// the number of partitions (callers in this file pass 1..=4), and
/// `small_block` doubles the coordinates before hashing. With a single
/// partition the result is always 0.
fn select_partition(
    mut seed: u32,
    mut x: u32,
    mut y: u32,
    mut z: u32,
    partition_count: usize,
    small_block: bool,
) -> usize {
    if 1 == partition_count {
        return 0;
    }

    if small_block {
        x <<= 1;
        y <<= 1;
        z <<= 1;
    }

    // Fold the partition count into the seed so each count hashes differently.
    seed += (partition_count as u32 - 1) * 1024;

    // Twelve 4-bit seeds taken from (partly overlapping) nibbles of the hash.
    let rnum = hash52(seed);
    let mut seed1 = (rnum & 0xF) as u8;
    let mut seed2 = ((rnum >> 4) & 0xF) as u8;
    let mut seed3 = ((rnum >> 8) & 0xF) as u8;
    let mut seed4 = ((rnum >> 12) & 0xF) as u8;
    let mut seed5 = ((rnum >> 16) & 0xF) as u8;
    let mut seed6 = ((rnum >> 20) & 0xF) as u8;
    let mut seed7 = ((rnum >> 24) & 0xF) as u8;
    let mut seed8 = ((rnum >> 28) & 0xF) as u8;
    let mut seed9 = ((rnum >> 18) & 0xF) as u8;
    let mut seed10 = ((rnum >> 22) & 0xF) as u8;
    let mut seed11 = ((rnum >> 26) & 0xF) as u8;
    let mut seed12 = (((rnum >> 30) | (rnum << 2)) & 0xF) as u8;

    // Square each seed; the largest result is 15 * 15 = 225, which fits in a u8.
    seed1 = seed1 * seed1;
    seed2 = seed2 * seed2;
    seed3 = seed3 * seed3;
    seed4 = seed4 * seed4;
    seed5 = seed5 * seed5;
    seed6 = seed6 * seed6;
    seed7 = seed7 * seed7;
    seed8 = seed8 * seed8;
    seed9 = seed9 * seed9;
    seed10 = seed10 * seed10;
    seed11 = seed11 * seed11;
    seed12 = seed12 * seed12;

    // Shift amounts depend on the low seed bits and on whether there are
    // three partitions.
    let sh1: i32;
    let sh2: i32;
    let sh3: i32;
    if seed & 1 != 0 {
        sh1 = if seed & 2 != 0 { 4 } else { 5 };
        sh2 = if partition_count == 3 { 6 } else { 5 };
    } else {
        sh1 = if partition_count == 3 { 6 } else { 5 };
        sh2 = if seed & 2 != 0 { 4 } else { 5 };
    }
    sh3 = if seed & 0x10 != 0 { sh1 } else { sh2 };

    seed1 >>= sh1;
    seed2 >>= sh2;
    seed3 >>= sh1;
    seed4 >>= sh2;
    seed5 >>= sh1;
    seed6 >>= sh2;
    seed7 >>= sh1;
    seed8 >>= sh2;
    seed9 >>= sh3;
    seed10 >>= sh3;
    seed11 >>= sh3;
    seed12 >>= sh3;

    // One linear function of the coordinates per candidate partition.
    let mut a = seed1 as u32 * x + seed2 as u32 * y + seed11 as u32 * z + (rnum >> 14);
    let mut b = seed3 as u32 * x + seed4 as u32 * y + seed12 as u32 * z + (rnum >> 10);
    let mut c = seed5 as u32 * x + seed6 as u32 * y + seed9 as u32 * z + (rnum >> 6);
    let mut d = seed7 as u32 * x + seed8 as u32 * y + seed10 as u32 * z + (rnum >> 2);

    a &= 0x3F;
    b &= 0x3F;
    c &= 0x3F;
    d &= 0x3F;

    // Zero the candidates for partitions that do not exist.
    if partition_count < 4 {
        d = 0;
    }

    if partition_count < 3 {
        c = 0;
    }

    // The largest value wins; ties go to the lower partition number.
    if a >= b && a >= c && a >= d {
        0
    } else if b >= c && b >= d {
        1
    } else if c >= d {
        2
    } else {
        3
    }
}
1017
1018fn select_2d_partition(
1019 seed: u32,
1020 x: u32,
1021 y: u32,
1022 partition_count: usize,
1023 small_block: bool,
1024) -> usize {
1025 select_partition(seed, x, y, 0, partition_count, small_block)
1026}
1027
/// Saturates each channel to `0..=255` and packs the result as `[r, g, b, a]`.
fn clamp_color(r: i32, g: i32, b: i32, a: i32) -> [u8; 4] {
    let saturate = |channel: i32| channel.max(0).min(255) as u8;
    [saturate(r), saturate(g), saturate(b), saturate(a)]
}
1036
/// Applies blue contraction: red and green are averaged with blue, then all
/// four channels are saturated to `0..=255` and packed as `[r, g, b, a]`.
fn blue_contract(r: i32, g: i32, b: i32, a: i32) -> [u8; 4] {
    let saturate = |channel: i32| channel.max(0).min(255) as u8;
    let red = (r + b) >> 1;
    let green = (g + b) >> 1;
    [saturate(red), saturate(green), saturate(b), saturate(a)]
}
1047
1048fn compute_endpoints(color_values: &mut &[u8], endpoint_mods: u32) -> [[u8; 4]; 2] {
1050 let ep1: [u8; 4];
1051 let ep2: [u8; 4];
1052 macro_rules! read_int_values {
1053 ($N:expr) => {{
1054 let mut v = [0; $N];
1055 for i in 0..$N {
1056 v[i] = color_values[0] as i32;
1057 *color_values = &color_values[1..];
1058 }
1059 v
1060 }};
1061 }
1062
1063 macro_rules! bts {
1064 ($v:ident, $a:expr, $b: expr) => {{
1065 let mut a = $v[$a];
1066 let mut b = $v[$b];
1067 bit_transfer_signed(&mut a, &mut b);
1068 $v[$a] = a;
1069 $v[$b] = b;
1070 }};
1071 }
1072
1073 match endpoint_mods {
1074 0 => {
1075 let v = read_int_values!(2);
1076 ep1 = clamp_color(v[0], v[0], v[0], 0xFF);
1077 ep2 = clamp_color(v[1], v[1], v[1], 0xFF);
1078 }
1079
1080 1 => {
1081 let v = read_int_values!(2);
1082 let l0 = (v[0] >> 2) | (v[1] & 0xC0);
1083 let l1 = std::cmp::min(l0 + (v[1] & 0x3F), 0xFF);
1084 ep1 = clamp_color(l0, l0, l0, 0xFF);
1085 ep2 = clamp_color(l1, l1, l1, 0xFF);
1086 }
1087
1088 4 => {
1089 let v = read_int_values!(4);
1090 ep1 = clamp_color(v[0], v[0], v[0], v[2]);
1091 ep2 = clamp_color(v[1], v[1], v[1], v[3]);
1092 }
1093
1094 5 => {
1095 let mut v = read_int_values!(4);
1096 bts!(v, 1, 0);
1097 bts!(v, 3, 2);
1098 ep1 = clamp_color(v[0], v[0], v[0], v[2]);
1099 ep2 = clamp_color(v[0] + v[1], v[0] + v[1], v[0] + v[1], v[2] + v[3]);
1100 }
1101
1102 6 => {
1103 let v = read_int_values!(4);
1104 ep1 = clamp_color(
1105 (v[0] * v[3]) >> 8,
1106 (v[1] * v[3]) >> 8,
1107 (v[2] * v[3]) >> 8,
1108 0xFF,
1109 );
1110 ep2 = clamp_color(v[0], v[1], v[2], 0xFF);
1111 }
1112
1113 8 => {
1114 let v = read_int_values!(6);
1115 if v[1] + v[3] + v[5] >= v[0] + v[2] + v[4] {
1116 ep1 = clamp_color(v[0], v[2], v[4], 0xFF);
1117 ep2 = clamp_color(v[1], v[3], v[5], 0xFF);
1118 } else {
1119 ep1 = blue_contract(v[1], v[3], v[5], 0xFF);
1120 ep2 = blue_contract(v[0], v[2], v[4], 0xFF);
1121 }
1122 }
1123
1124 9 => {
1125 let mut v = read_int_values!(6);
1126 bts!(v, 1, 0);
1127 bts!(v, 3, 2);
1128 bts!(v, 5, 4);
1129 if v[1] + v[3] + v[5] >= 0 {
1130 ep1 = clamp_color(v[0], v[2], v[4], 0xFF);
1131 ep2 = clamp_color(v[0] + v[1], v[2] + v[3], v[4] + v[5], 0xFF);
1132 } else {
1133 ep1 = blue_contract(v[0] + v[1], v[2] + v[3], v[4] + v[5], 0xFF);
1134 ep2 = blue_contract(v[0], v[2], v[4], 0xFF);
1135 }
1136 }
1137
1138 10 => {
1139 let v = read_int_values!(6);
1140 ep1 = clamp_color(
1141 (v[0] * v[3]) >> 8,
1142 (v[1] * v[3]) >> 8,
1143 (v[2] * v[3]) >> 8,
1144 v[4],
1145 );
1146 ep2 = clamp_color(v[0], v[1], v[2], v[5]);
1147 }
1148
1149 12 => {
1150 let v = read_int_values!(8);
1151 if v[1] + v[3] + v[5] >= v[0] + v[2] + v[4] {
1152 ep1 = clamp_color(v[0], v[2], v[4], v[6]);
1153 ep2 = clamp_color(v[1], v[3], v[5], v[7]);
1154 } else {
1155 ep1 = blue_contract(v[1], v[3], v[5], v[7]);
1156 ep2 = blue_contract(v[0], v[2], v[4], v[6]);
1157 }
1158 }
1159
1160 13 => {
1161 let mut v = read_int_values!(8);
1162 bts!(v, 1, 0);
1163 bts!(v, 3, 2);
1164 bts!(v, 5, 4);
1165 bts!(v, 7, 6);
1166 if v[1] + v[3] + v[5] >= 0 {
1167 ep1 = clamp_color(v[0], v[2], v[4], v[6]);
1168 ep2 = clamp_color(v[0] + v[1], v[2] + v[3], v[4] + v[5], v[6] + v[7]);
1169 } else {
1170 ep1 = blue_contract(v[0] + v[1], v[2] + v[3], v[4] + v[5], v[6] + v[7]);
1171 ep2 = blue_contract(v[0], v[2], v[4], v[6]);
1172 }
1173 }
1174
1175 _ => {
1176 ep1 = [0xFF, 0, 0xFF, 0xFF];
1178 ep2 = [0xFF, 0, 0xFF, 0xFF];
1179 }
1180 }
1181 [ep1, ep2]
1182}
1183
/// Size, in texels, of the area covered by one compressed block.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Footprint {
    block_width: u32,
    block_height: u32,
}

impl Footprint {
    /// 4x4 texels per block.
    pub const ASTC_4X4: Footprint = Self::of(4, 4);
    /// 5x4 texels per block.
    pub const ASTC_5X4: Footprint = Self::of(5, 4);
    /// 5x5 texels per block.
    pub const ASTC_5X5: Footprint = Self::of(5, 5);
    /// 6x5 texels per block.
    pub const ASTC_6X5: Footprint = Self::of(6, 5);
    /// 6x6 texels per block.
    pub const ASTC_6X6: Footprint = Self::of(6, 6);
    /// 8x5 texels per block.
    pub const ASTC_8X5: Footprint = Self::of(8, 5);
    /// 8x6 texels per block.
    pub const ASTC_8X6: Footprint = Self::of(8, 6);
    /// 10x5 texels per block.
    pub const ASTC_10X5: Footprint = Self::of(10, 5);
    /// 10x6 texels per block.
    pub const ASTC_10X6: Footprint = Self::of(10, 6);
    /// 8x8 texels per block.
    pub const ASTC_8X8: Footprint = Self::of(8, 8);
    /// 10x8 texels per block.
    pub const ASTC_10X8: Footprint = Self::of(10, 8);
    /// 10x10 texels per block.
    pub const ASTC_10X10: Footprint = Self::of(10, 10);
    /// 12x10 texels per block.
    pub const ASTC_12X10: Footprint = Self::of(12, 10);
    /// 12x12 texels per block.
    pub const ASTC_12X12: Footprint = Self::of(12, 12);

    /// Unchecked constructor used for the predefined constants.
    const fn of(block_width: u32, block_height: u32) -> Footprint {
        Footprint {
            block_width,
            block_height,
        }
    }

    /// Creates a footprint with the given block dimensions.
    ///
    /// # Panics
    /// Panics with "Invalid block size" if either dimension is zero.
    pub fn new(block_width: u32, block_height: u32) -> Footprint {
        assert!(block_width != 0 && block_height != 0, "Invalid block size");
        Self::of(block_width, block_height)
    }

    /// Returns the block width in texels.
    pub fn block_width(&self) -> u32 {
        self.block_width
    }

    /// Returns the block height in texels.
    pub fn block_height(&self) -> u32 {
        self.block_height
    }
}
1274
/// Decodes a single 16-byte ASTC block.
///
/// `writer` is called once for every texel of the block with
/// `(x, y, [r, g, b, a])`, where `x` and `y` are relative to the block's
/// origin.
///
/// Returns `true` when the block decoded normally (including LDR void-extent
/// blocks). Returns `false` when the block is invalid or uses an unsupported
/// feature (such as HDR void extent); in that case every texel is written
/// with the error colour, opaque magenta.
pub fn astc_decode_block<F: FnMut(u32, u32, [u8; 4])>(
    input: &[u8; 16],
    footprint: Footprint,
    mut writer: F,
) -> bool {
    let block_width = footprint.block_width;
    let block_height = footprint.block_height;
    // The block is treated as one little-endian 128-bit word, read from bit 0.
    let mut strm = InputBitStream::new(u128::from_le_bytes(*input));
    let weight_params = decode_block_info(&mut strm);

    if weight_params.is_error {
        fill_error(&mut writer, block_width, block_height);
        return false;
    }

    if weight_params.void_extent_ldr {
        fill_void_extent_ldr(&mut strm, &mut writer, block_width, block_height);
        return true;
    }

    // HDR void-extent blocks are not supported.
    if weight_params.void_extent_hdr {
        fill_error(&mut writer, block_width, block_height);
        return false;
    }

    // The weight grid may not be larger than the block itself.
    if weight_params.width > block_width {
        fill_error(&mut writer, block_width, block_height);
        return false;
    }

    if weight_params.height > block_height {
        fill_error(&mut writer, block_width, block_height);
        return false;
    }

    if weight_params.get_num_weight_values() > 64 {
        fill_error(&mut writer, block_width, block_height);
        return false;
    }

    let n_weight_bits = weight_params.get_packed_bit_size();

    // Weight data must occupy between 24 and 96 bits.
    if !(24..=96).contains(&n_weight_bits) {
        fill_error(&mut writer, block_width, block_height);
        return false;
    }

    let n_partitions = (strm.read_bits(2) + 1) as usize;
    assert!(n_partitions <= 4);

    // Dual-plane mode cannot be combined with four partitions.
    if n_partitions == 4 && weight_params.is_dual_plane {
        fill_error(&mut writer, block_width, block_height);
        return false;
    }

    // plane_idx: colour channel that uses the second weight plane.
    let plane_idx;
    let partition_index;
    // One colour endpoint mode per partition.
    let mut endpoint_mods = [0, 0, 0, 0];
    let endpoint_mods = &mut endpoint_mods[0..n_partitions];

    let mut base_cem = 0;
    if n_partitions == 1 {
        // Single partition: the 4-bit endpoint mode is stored directly.
        endpoint_mods[0] = strm.read_bits(4);
        partition_index = 0;
    } else {
        partition_index = strm.read_bits(10);
        base_cem = strm.read_bits(6);
    }
    // Non-zero low two bits mean the partitions may use different modes and
    // extra mode bits are stored further up the block.
    let base_mode = base_cem & 3;

    // Everything that is not colour endpoint data: weights, header bits read
    // so far, extra mode bits and the plane selector.
    let mut non_color_bits = n_weight_bits + strm.get_bits_read();

    let mut extra_cem_bits = 0;
    if base_mode != 0 {
        match n_partitions {
            2 => extra_cem_bits += 2,
            3 => extra_cem_bits += 5,
            4 => extra_cem_bits += 8,
            _ => unreachable!(),
        }
    }
    non_color_bits += extra_cem_bits;

    let mut plane_selector_bits = 0;
    if weight_params.is_dual_plane {
        plane_selector_bits = 2;
    }
    non_color_bits += plane_selector_bits;

    if non_color_bits >= 128 {
        fill_error(&mut writer, block_width, block_height);
        return false;
    }

    // Colour endpoint data fills whatever remains of the 128 bits.
    let color_data_bits = 128 - non_color_bits;
    let endpoint_data = strm.read_bits128(color_data_bits);

    // Reading upwards from the colour data: plane selector, then extra mode bits.
    plane_idx = strm.read_bits(plane_selector_bits);

    if base_mode != 0 {
        let extra_cem = strm.read_bits(extra_cem_bits);
        // Join the extra bits with the base field and drop the 2-bit base mode.
        let mut cem = (extra_cem << 6) | base_cem;
        cem >>= 2;

        // One class-selector bit per partition ...
        let mut c = [false; 4];
        for c in &mut c[0..n_partitions] {
            *c = (cem & 1) != 0;
            cem >>= 1;
        }

        // ... followed by two mode bits per partition.
        let mut m = [0; 4];
        for m in &mut m[0..n_partitions] {
            *m = cem & 3;
            cem >>= 2;
        }

        // mode = (class << 2) | m, where class is base_mode or base_mode - 1.
        for (i, endpoint_mod) in endpoint_mods.iter_mut().enumerate() {
            *endpoint_mod = base_mode;
            if !c[i] {
                *endpoint_mod -= 1;
            }
            *endpoint_mod <<= 2;
            *endpoint_mod |= m[i];
        }
    } else if n_partitions > 1 {
        // All partitions share the mode stored in the upper four bits.
        let cem = base_cem >> 2;
        endpoint_mods[0..n_partitions].fill(cem);
    }

    for &endpoint_mod in endpoint_mods.iter() {
        assert!(endpoint_mod < 16);
    }
    // Every bit below the weight data has now been consumed.
    assert!(strm.get_bits_read() + weight_params.get_packed_bit_size() == 128);

    // Each mode uses 2 * ((mode >> 2) + 1) colour values.
    let n_values = endpoint_mods.iter().map(|m| ((m >> 2) + 1) << 1).sum();

    // Reject blocks with more than 18 values or too few bits to hold them.
    if n_values > 18 || (n_values * 13 + 4) / 5 > color_data_bits {
        fill_error(&mut writer, block_width, block_height);
        return false;
    }

    let color_values = decode_color_values(endpoint_data, n_values, color_data_bits);

    let mut endpoints = [[[0; 4]; 2]; 4];
    // Cursor that `compute_endpoints` advances as it consumes values.
    let mut color_values_ptr = &color_values[0..n_values as usize];
    for i in 0..n_partitions {
        endpoints[i] = compute_endpoints(&mut color_values_ptr, endpoint_mods[i]);
    }

    // Weights are stored from the top of the block downwards, so reverse the
    // word to read them from bit 0.
    let mut texel_weight_data = u128::from_le_bytes(*input).reverse_bits();

    // Keep only the weight bits.
    texel_weight_data &= (1 << weight_params.get_packed_bit_size()) - 1;

    let mut weight_stream = InputBitStream::new(texel_weight_data);
    let weights = unquantize_texel_weights(
        &mut weight_stream,
        &weight_params,
        block_width,
        block_height,
    );

    for j in 0..block_height {
        for i in 0..block_width {
            let partition = select_2d_partition(
                partition_index,
                i,
                j,
                n_partitions,
                (block_height * block_width) < 32,
            );
            assert!(partition < n_partitions);

            let mut p = [0; 4];
            for (c, p) in p.iter_mut().enumerate() {
                // Widen the 8-bit endpoints to 16 bits by byte replication.
                let c0 = endpoints[partition][0][c] as u32 * 0x101;
                let c1 = endpoints[partition][1][c] as u32 * 0x101;

                // In dual-plane mode the selected channel uses the second plane.
                let mut plane = 0;
                if weight_params.is_dual_plane && (plane_idx & 3 == c as u32) {
                    plane = 1;
                }

                // Blend with a weight in 0..=64, then scale back to 8 bits.
                let weight = weights[plane][(j * block_width + i) as usize];
                let color = (c0 * (64 - weight) + c1 * weight + 32) / 64;
                *p = u8::try_from(((color * 255) + 32767) / 65536).unwrap();
            }

            writer(i, j, p);
        }
    }

    true
}
1507
1508pub fn astc_decode<R: Read, F: FnMut(u32, u32, [u8; 4])>(
1529 mut input: R,
1530 width: u32,
1531 height: u32,
1532 footprint: Footprint,
1533 mut writer: F,
1534) -> Result<()> {
1535 let block_width = footprint.block_width;
1536 let block_height = footprint.block_height;
1537
1538 let block_w = (width.checked_add(block_width).unwrap() - 1) / block_width;
1539 let block_h = (height.checked_add(block_height).unwrap() - 1) / block_height;
1540
1541 for by in 0..block_h {
1542 for bx in 0..block_w {
1543 let mut block_buf = [0; 16];
1544 input.read_exact(&mut block_buf)?;
1545 astc_decode_block(&block_buf, footprint, |x, y, v| {
1546 let x = bx * block_width + x;
1547 let y = by * block_height + y;
1548 if x < width && y < height {
1549 writer(x, y, v)
1550 }
1551 });
1552 }
1553 }
1554
1555 Ok(())
1556}
1557
#[cfg(test)]
mod tests {
    use super::*;
    use image::Pixel;

    /// Asserts that two channel values differ by at most 1.
    fn dist(a: u8, b: u8) {
        assert!((a as i32 - b as i32).abs() <= 1)
    }

    /// Decodes `astc` and compares every texel with the reference image
    /// `bmp`, allowing a difference of 1 per channel.
    fn test_case(astc: &[u8], bmp: &[u8], block_width: u32, block_height: u32) {
        let bmp = image::load_from_memory(bmp).unwrap().to_rgba8();
        let width = bmp.width();
        let height = bmp.height();
        astc_decode(
            // Skip the first 16 bytes of the file (presumably the .astc
            // container header - TODO confirm).
            &astc[16..],
            width,
            height,
            Footprint::new(block_width, block_height),
            |x, y, v| {
                // The reference image is compared with its rows flipped.
                let y = height - y - 1;
                let p = bmp.get_pixel(x as u32, y as u32).channels();
                dist(p[0], v[0]);
                dist(p[1], v[1]);
                dist(p[2], v[2]);
                dist(p[3], v[3]);
            },
        )
        .unwrap();
    }

    /// Runs `test_case` on `test-data/<name>_<bw>x<bh>.astc` against the
    /// `.bmp` file with the same stem.
    macro_rules! tc {
        ($name:literal, $bw:literal, $bh:literal) => {
            test_case(
                include_bytes!(concat!("test-data/", $name, '_', $bw, 'x', $bh, ".astc")),
                include_bytes!(concat!("test-data/", $name, '_', $bw, 'x', $bh, ".bmp")),
                $bw,
                $bh,
            );
        };
    }

    /// Decodes the bundled sample images and checks them against references.
    #[test]
    fn real_image() {
        tc!("atlas_small", 4, 4);
        tc!("atlas_small", 5, 5);
        tc!("atlas_small", 6, 6);
        tc!("atlas_small", 8, 8);
        tc!("footprint", 4, 4);
        tc!("footprint", 5, 4);
        tc!("footprint", 5, 5);
        tc!("footprint", 6, 5);
        tc!("footprint", 6, 6);
        tc!("footprint", 8, 5);
        tc!("footprint", 8, 6);
        tc!("footprint", 8, 8);
        tc!("footprint", 10, 5);
        tc!("footprint", 10, 6);
        tc!("footprint", 10, 8);
        tc!("footprint", 10, 10);
        tc!("footprint", 12, 10);
        tc!("footprint", 12, 12);
        tc!("rgb", 4, 4);
        tc!("rgb", 5, 4);
        tc!("rgb", 6, 6);
        tc!("rgb", 8, 8);
        tc!("rgb", 12, 12);
    }

    /// Decodes 10000 random blocks with the given footprint; the only
    /// requirement is that decoding does not panic.
    fn fuzz_fp(w: u32, h: u32) {
        let footprint = Footprint::new(w, h);
        for _ in 0..10000 {
            let block = rand::random();
            astc_decode_block(&block, footprint, |_, _, _| {});
        }
    }

    /// Fuzzes the block decoder for each of the fourteen footprint sizes
    /// that have a predefined constant.
    #[test]
    fn fuzzing() {
        fuzz_fp(4, 4);
        fuzz_fp(5, 4);
        fuzz_fp(5, 5);
        fuzz_fp(6, 5);
        fuzz_fp(6, 6);
        fuzz_fp(8, 5);
        fuzz_fp(8, 6);
        fuzz_fp(8, 8);
        fuzz_fp(10, 5);
        fuzz_fp(10, 6);
        fuzz_fp(10, 8);
        fuzz_fp(10, 10);
        fuzz_fp(12, 10);
        fuzz_fp(12, 12);
    }
}