astc_decode/
lib.rs

1// Copyright 2021 Weiyi Wang
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// The code is ported from an ASTC decoder in C++ with a few bug fixes.
16// Here is the copyright notice from the original code:
17//
18// Copyright 2016 The University of North Carolina at Chapel Hill
19//
20// Licensed under the Apache License, Version 2.0 (the "License");
21// you may not use this file except in compliance with the License.
22// You may obtain a copy of the License at
23//
24//    http://www.apache.org/licenses/LICENSE-2.0
25//
26// Unless required by applicable law or agreed to in writing, software
27// distributed under the License is distributed on an "AS IS" BASIS,
28// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
29// See the License for the specific language governing permissions and
30// limitations under the License.
31//
32// Please send all BUG REPORTS to <pavel@cs.unc.edu>.
33// <http://gamma.cs.unc.edu/FasTC/>
34
35#![allow(clippy::many_single_char_names)]
36
37use std::convert::TryFrom;
38use std::io::*;
39
/// Sequential bit reader over one 128-bit block.
///
/// Bits are consumed starting from the least significant bit of `data`;
/// every read shifts the consumed bits out so the next unread bit is
/// always bit 0.
struct InputBitStream {
    /// Bits that have not been consumed yet, right-aligned.
    data: u128,
    /// Total number of bits consumed so far (never exceeds 128).
    bits_read: u32,
}

impl InputBitStream {
    /// Creates a reader positioned at the least significant bit of `data`.
    fn new(data: u128) -> InputBitStream {
        InputBitStream { data, bits_read: 0 }
    }

    /// Returns how many bits have been consumed since construction.
    fn get_bits_read(&self) -> u32 {
        self.bits_read
    }

    /// Reads a single bit and returns it as 0 or 1.
    fn read_bit(&mut self) -> u32 {
        self.read_bits(1)
    }

    /// Reads `n_bits` (at most 32) and returns them right-aligned.
    ///
    /// # Panics
    /// Panics if `n_bits > 32` or if the read would go past bit 128.
    fn read_bits(&mut self, n_bits: u32) -> u32 {
        assert!(n_bits <= 32);
        self.read_bits128(n_bits) as u32
    }

    /// Reads `n_bits` (at most 128) and returns them right-aligned.
    ///
    /// # Panics
    /// Panics if the read would consume more than 128 bits in total.
    fn read_bits128(&mut self, n_bits: u32) -> u128 {
        // Validate before mutating so the counter can neither overflow nor
        // be left in an inconsistent state. `bits_read <= 128` is an
        // invariant, so the subtraction cannot underflow.
        assert!(
            n_bits <= 128 - self.bits_read,
            "attempted to read past the end of the 128-bit block"
        );
        self.bits_read += n_bits;

        // A full-width read cannot be expressed with `1 << n_bits` or
        // `data >>= n_bits`: shifting a u128 by 128 is an overflow.
        if n_bits == 128 {
            let all = self.data;
            self.data = 0;
            return all;
        }

        let ret = self.data & ((1u128 << n_bits) - 1);
        self.data >>= n_bits;
        ret
    }
}
71
/// Thin wrapper around a 32-bit word for picking out single bits and
/// inclusive bit ranges.
struct Bits(u32);

impl Bits {
    /// Returns bit `pos` (0 is the least significant bit) as 0 or 1.
    fn get(&self, pos: u32) -> u32 {
        let shifted = self.0 >> pos;
        shifted & 1
    }

    /// Returns bits `start..=end`, right-aligned.
    fn range(&self, start: u32, end: u32) -> u32 {
        let width = end - start + 1;
        let field = self.0 >> start;
        field & ((1 << width) - 1)
    }
}
84
/// How the values of an integer sequence are packed.
#[derive(Clone, Copy, PartialEq, Eq)]
enum IntegerEncodingType {
    /// Plain binary: every value is `num_bits` bits.
    JustBits,
    /// `num_bits` plain bits per value plus a shared base-5 digit block.
    Quint,
    /// `num_bits` plain bits per value plus a shared base-3 digit block.
    Trit,
}

/// An encoding kind together with the number of plain bits stored per value.
#[derive(Clone, Copy, PartialEq, Eq)]
struct IntegerEncoding {
    encoding: IntegerEncodingType,
    num_bits: u32,
}

impl IntegerEncoding {
    /// Returns the number of bits required to encode `n_vals` values.
    fn get_bit_length(&self, n_vals: u32) -> u32 {
        // Bits spent on the packed trit/quint digits, rounded up:
        // 8 bits per 5 trits, 7 bits per 3 quints.
        let digit_bits = match self.encoding {
            IntegerEncodingType::JustBits => 0,
            IntegerEncodingType::Trit => (n_vals * 8 + 4) / 5,
            IntegerEncodingType::Quint => (n_vals * 7 + 2) / 3,
        };
        self.num_bits * n_vals + digit_bits
    }
}
110
/// Combines the unquantization terms `a`, `b`, `c` and the trit/quint digit
/// `d` into an 8-bit color value.
///
/// # Panics
/// Panics if the combined value does not fit in a `u8`.
fn decode_color(a: u32, b: u32, c: u32, d: u32) -> u8 {
    let mixed = (d * c + b) ^ a;
    // Bit 7 comes straight from `a`; the rest is the mixed term with its
    // two lowest bits dropped.
    let value = (mixed >> 2) | (a & 0x80);
    u8::try_from(value).unwrap()
}
115
/// Combines the unquantization terms `a`, `b`, `c` and the trit/quint digit
/// `d` into a texel weight.
fn decode_weight(a: u32, b: u32, c: u32, d: u32) -> u32 {
    let mixed = (d * c + b) ^ a;
    // Bit 5 comes straight from `a`; the rest is the mixed term with its
    // two lowest bits dropped.
    (mixed >> 2) | (a & 0x20)
}
120
121struct Trit {
122    trit_value: u32,
123    bit_value: u32,
124}
125
126impl Trit {
127    fn decode_color(self, bitlen: u32) -> u8 {
128        let bitval = self.bit_value;
129        let a = (bitval & 1) * 0x1FF;
130        let x = bitval >> 1;
131        let b;
132        let c;
133        match bitlen {
134            1 => {
135                c = 204;
136                b = 0;
137            }
138
139            2 => {
140                c = 93;
141                // b = b000b0bb0
142                b = (x << 8) | (x << 4) | (x << 2) | (x << 1);
143            }
144
145            3 => {
146                c = 44;
147                // b = cb000cbcb
148                b = (x << 7) | (x << 2) | x;
149            }
150
151            4 => {
152                c = 22;
153                // b = dcb000dcb
154                b = (x << 6) | x;
155            }
156
157            5 => {
158                c = 11;
159                // b = edcb000ed
160                b = (x << 5) | (x >> 2);
161            }
162
163            6 => {
164                c = 5;
165                // b = fedcb000f
166                b = (x << 4) | (x >> 4);
167            }
168
169            _ => unreachable!("Invalid trit encoding for color values"),
170        }
171        decode_color(a, b, c, self.trit_value)
172    }
173
174    fn decode_weight(self, bitlen: u32) -> u32 {
175        let bitval = self.bit_value;
176        let a = (bitval & 1) * 0x7F;
177        let x = bitval >> 1;
178        let b;
179        let c;
180        match bitlen {
181            0 => {
182                return [0, 32, 63][self.trit_value as usize];
183            }
184
185            1 => {
186                c = 50;
187                b = 0;
188            }
189
190            2 => {
191                c = 23;
192                b = (x << 6) | (x << 2) | x;
193            }
194
195            3 => {
196                c = 11;
197                b = (x << 5) | x;
198            }
199
200            _ => unreachable!("Invalid trit encoding for texel weight"),
201        }
202        decode_weight(a, b, c, self.trit_value)
203    }
204}
205
206struct Quint {
207    quint_value: u32,
208    bit_value: u32,
209}
210
211impl Quint {
212    fn decode_color(self, bitlen: u32) -> u8 {
213        let bitval = self.bit_value;
214        let a = (bitval & 1) * 0x1FF;
215        let x = bitval >> 1;
216        let b;
217        let c;
218        match bitlen {
219            1 => {
220                c = 113;
221                b = 0;
222            }
223
224            2 => {
225                c = 54;
226                // b = b0000bb00
227                b = (x << 8) | (x << 3) | (x << 2);
228            }
229
230            3 => {
231                c = 26;
232                // b = cb0000cbc
233                b = (x << 7) | (x << 1) | (x >> 1);
234            }
235
236            4 => {
237                c = 13;
238                // b = dcb0000dc
239                b = (x << 6) | (x >> 1);
240            }
241
242            5 => {
243                c = 6;
244                // b = edcb0000e
245                b = (x << 5) | (x >> 3);
246            }
247
248            _ => unreachable!("Invalid quint encoding for color values"),
249        }
250        decode_color(a, b, c, self.quint_value)
251    }
252
253    fn decode_weight(self, bitlen: u32) -> u32 {
254        let bitval = self.bit_value;
255        let a = (bitval & 1) * 0x7F;
256        let b;
257        let c;
258        match bitlen {
259            0 => {
260                return [0, 16, 32, 47, 63][self.quint_value as usize];
261            }
262
263            1 => {
264                c = 28;
265                b = 0;
266            }
267
268            2 => {
269                c = 13;
270                let x = bitval >> 1;
271                b = (x << 6) | (x << 1);
272            }
273
274            _ => unreachable!("Invalid quint encoding for texel weight"),
275        }
276        decode_weight(a, b, c, self.quint_value)
277    }
278}
279
280fn decode_trit_block(bits: &mut InputBitStream, bits_per_value: u32) -> impl Iterator<Item = Trit> {
281    // Implement the algorithm in section c.2.12
282    let mut m = [0u32; 5];
283    let mut t = [0u32; 5];
284    let mut tt: u32;
285
286    // Read the trit encoded block according to
287    // table c.2.14
288    m[0] = bits.read_bits(bits_per_value);
289    tt = bits.read_bits(2);
290    m[1] = bits.read_bits(bits_per_value);
291    tt |= bits.read_bits(2) << 2;
292    m[2] = bits.read_bits(bits_per_value);
293    tt |= (bits.read_bit()) << 4;
294    m[3] = bits.read_bits(bits_per_value);
295    tt |= bits.read_bits(2) << 5;
296    m[4] = bits.read_bits(bits_per_value);
297    tt |= (bits.read_bit()) << 7;
298
299    let c: u32;
300
301    let tb = Bits(tt);
302    if tb.range(2, 4) == 7 {
303        c = (tb.range(5, 7) << 2) | tb.range(0, 1);
304        t[3] = 2;
305        t[4] = 2;
306    } else {
307        c = tb.range(0, 4);
308        if tb.range(5, 6) == 3 {
309            t[4] = 2;
310            t[3] = tb.get(7);
311        } else {
312            t[4] = tb.get(7);
313            t[3] = tb.range(5, 6);
314        }
315    }
316
317    let cb = Bits(c);
318    if cb.range(0, 1) == 3 {
319        t[2] = 2;
320        t[1] = cb.get(4);
321        t[0] = (cb.get(3) << 1) | (cb.get(2) & !cb.get(3));
322    } else if cb.range(2, 3) == 3 {
323        t[2] = 2;
324        t[1] = 2;
325        t[0] = cb.range(0, 1);
326    } else {
327        t[2] = cb.get(4);
328        t[1] = cb.range(2, 3);
329        t[0] = (cb.get(1) << 1) | (cb.get(0) & !cb.get(1));
330    }
331
332    IntoIterator::into_iter(m)
333        .zip(t)
334        .map(|(bit_value, trit_value)| Trit {
335            trit_value,
336            bit_value,
337        })
338}
339
340fn decode_quint_block(
341    bits: &mut InputBitStream,
342    bits_per_value: u32,
343) -> impl Iterator<Item = Quint> {
344    // Implement the algorithm in section c.2.12
345    let mut m = [0u32; 3];
346    let mut q = [0u32; 3];
347    let mut qq: u32;
348
349    // Read the trit encoded block according to
350    // table c.2.15
351    m[0] = bits.read_bits(bits_per_value);
352    qq = bits.read_bits(3);
353    m[1] = bits.read_bits(bits_per_value);
354    qq |= bits.read_bits(2) << 3;
355    m[2] = bits.read_bits(bits_per_value);
356    qq |= bits.read_bits(2) << 5;
357
358    let qb = Bits(qq);
359    if qb.range(1, 2) == 3 && qb.range(5, 6) == 0 {
360        q[0] = 4;
361        q[1] = 4;
362        q[2] = (qb.get(0) << 2) | ((qb.get(4) & !qb.get(0)) << 1) | (qb.get(3) & !qb.get(0));
363    } else {
364        let c;
365        if qb.range(1, 2) == 3 {
366            q[2] = 4;
367            c = (qb.range(3, 4) << 3) | ((!qb.range(5, 6) & 3) << 1) | qb.get(0);
368        } else {
369            q[2] = qb.range(5, 6);
370            c = qb.range(0, 4);
371        }
372
373        let cb = Bits(c);
374        if cb.range(0, 2) == 5 {
375            q[1] = 4;
376            q[0] = cb.range(3, 4);
377        } else {
378            q[1] = cb.range(3, 4);
379            q[0] = cb.range(0, 2);
380        }
381    }
382
383    IntoIterator::into_iter(m)
384        .zip(q)
385        .map(|(bit_value, quint_value)| Quint {
386            quint_value,
387            bit_value,
388        })
389}
390
391// Returns IntegerEncoding that can take no more than maxval values
392const fn create_encoding(mut max_val: u32) -> IntegerEncoding {
393    while max_val > 0 {
394        let check = max_val + 1;
395
396        // Is max_val a power of two?
397        if (check & (check - 1)) == 0 {
398            return IntegerEncoding {
399                encoding: IntegerEncodingType::JustBits,
400                num_bits: max_val.count_ones(),
401            };
402        }
403
404        // Is max_val of the type 3*2^n - 1?
405        if (check % 3 == 0) && ((check / 3) & ((check / 3) - 1)) == 0 {
406            return IntegerEncoding {
407                encoding: IntegerEncodingType::Trit,
408                num_bits: (check / 3 - 1).count_ones(),
409            };
410        }
411
412        // Is max_val of the type 5*2^n - 1?
413        if (check % 5 == 0) && ((check / 5) & ((check / 5) - 1)) == 0 {
414            return IntegerEncoding {
415                encoding: IntegerEncodingType::Quint,
416                num_bits: (check / 5 - 1).count_ones(),
417            };
418        }
419
420        // Apparently it can't be represented with a bounded integer sequence...
421        // just iterate.
422        max_val -= 1;
423    }
424    IntegerEncoding {
425        encoding: IntegerEncodingType::JustBits,
426        num_bits: 0,
427    }
428}
429
430static ENCODING_MAP: [IntegerEncoding; 256] = {
431    let mut result = [IntegerEncoding {
432        encoding: IntegerEncodingType::JustBits,
433        num_bits: 0,
434    }; 256];
435    let mut i = 0;
436    while i < 256 {
437        result[i as usize] = create_encoding(i);
438        i += 1;
439    }
440    result
441};
442
443static ENCODING_SEQ: ([IntegerEncoding; 256], usize) = {
444    let mut result = [IntegerEncoding {
445        encoding: IntegerEncodingType::JustBits,
446        num_bits: 0,
447    }; 256];
448    let mut len = 1;
449    result[0] = ENCODING_MAP[0];
450    let mut i = 1;
451    while i < 256 {
452        let encoding = ENCODING_MAP[i];
453        let previous = result[len - 1];
454        // HACK: should be encoding != previous, but Eq in const is not usable right now
455        if encoding.encoding as u32 != previous.encoding as u32
456            || encoding.num_bits != previous.num_bits
457        {
458            result[len] = encoding;
459            len += 1;
460        }
461        i += 1;
462    }
463    (result, len)
464};
465
/// Block-mode information decoded from the header of a block.
///
/// The default value has every dimension set to zero and every flag cleared.
#[derive(Default)]
struct TexelWeightParams {
    /// Weight grid width.
    width: u32,
    /// Weight grid height.
    height: u32,
    /// Whether two weights are stored per grid position.
    is_dual_plane: bool,
    /// Largest value a stored weight can take.
    max_weight: u32,
    /// Set when the block mode is reserved or malformed.
    is_error: bool,
    /// Set for an LDR void-extent block.
    void_extent_ldr: bool,
    /// Set for an HDR void-extent block.
    void_extent_hdr: bool,
}
489
490impl TexelWeightParams {
491    fn get_packed_bit_size(&self) -> u32 {
492        ENCODING_MAP[self.max_weight as usize].get_bit_length(self.get_num_weight_values())
493    }
494
495    fn get_num_weight_values(&self) -> u32 {
496        let mut ret = self.width * self.height;
497        if self.is_dual_plane {
498            ret *= 2;
499        }
500        ret
501    }
502}
503
504fn decode_block_info(strm: &mut InputBitStream) -> TexelWeightParams {
505    let mut params = TexelWeightParams::default();
506
507    // Read the entire block mode all at once
508    let mode_bits = strm.read_bits(11);
509
510    // Does this match the void extent block mode?
511    if (mode_bits & 0x01FF) == 0x1FC {
512        if mode_bits & 0x200 != 0 {
513            params.void_extent_hdr = true;
514        } else {
515            params.void_extent_ldr = true;
516        }
517
518        // Next two bits must be one.
519        if (mode_bits & 0x400) == 0 || strm.read_bit() == 0 {
520            params.is_error = true;
521        }
522
523        return params;
524    }
525
526    // First check if the last four bits are zero
527    if (mode_bits & 0xF) == 0 {
528        params.is_error = true;
529        return params;
530    }
531
532    // If the last two bits are zero, then if bits
533    // [6-8] are all ones, this is also reserved.
534    if (mode_bits & 0x3) == 0 && (mode_bits & 0x1C0) == 0x1C0 {
535        params.is_error = true;
536        return params;
537    }
538
539    // Otherwise, there is no error... Figure out the layout
540    // of the block mode. Layout is determined by a number
541    // between 0 and 9 corresponding to table c.2.8 of the
542    // ASTC spec.
543    let layout;
544
545    if (mode_bits & 0x1) != 0 || (mode_bits & 0x2) != 0 {
546        // layout is in [0-4]
547        if (mode_bits & 0x8) != 0 {
548            // layout is in [2-4]
549            if (mode_bits & 0x4) != 0 {
550                // layout is in [3-4]
551                if (mode_bits & 0x100) != 0 {
552                    layout = 4;
553                } else {
554                    layout = 3;
555                }
556            } else {
557                layout = 2;
558            }
559        } else {
560            // layout is in [0-1]
561            if (mode_bits & 0x4) != 0 {
562                layout = 1;
563            } else {
564                layout = 0;
565            }
566        }
567    } else {
568        // layout is in [5-9]
569        if (mode_bits & 0x100) != 0 {
570            // layout is in [7-9]
571            if (mode_bits & 0x80) != 0 {
572                // layout is in [7-8]
573                assert!((mode_bits & 0x40) == 0);
574                if (mode_bits & 0x20) != 0 {
575                    layout = 8;
576                } else {
577                    layout = 7;
578                }
579            } else {
580                layout = 9;
581            }
582        } else {
583            // layout is in [5-6]
584            if (mode_bits & 0x80) != 0 {
585                layout = 6;
586            } else {
587                layout = 5;
588            }
589        }
590    }
591
592    // Determine R
593    let mut r = (mode_bits & 0x10) >> 4;
594    if layout < 5 {
595        r |= (mode_bits & 0x3) << 1;
596    } else {
597        r |= (mode_bits & 0xC) >> 1;
598    }
599    assert!((2..=7).contains(&r));
600
601    // Determine width & height
602    match layout {
603        0 => {
604            let a = (mode_bits >> 5) & 0x3;
605            let b = (mode_bits >> 7) & 0x3;
606            params.width = b + 4;
607            params.height = a + 2;
608        }
609
610        1 => {
611            let a = (mode_bits >> 5) & 0x3;
612            let b = (mode_bits >> 7) & 0x3;
613            params.width = b + 8;
614            params.height = a + 2;
615        }
616
617        2 => {
618            let a = (mode_bits >> 5) & 0x3;
619            let b = (mode_bits >> 7) & 0x3;
620            params.width = a + 2;
621            params.height = b + 8;
622        }
623
624        3 => {
625            let a = (mode_bits >> 5) & 0x3;
626            let b = (mode_bits >> 7) & 0x1;
627            params.width = a + 2;
628            params.height = b + 6;
629        }
630
631        4 => {
632            let a = (mode_bits >> 5) & 0x3;
633            let b = (mode_bits >> 7) & 0x1;
634            params.width = b + 2;
635            params.height = a + 2;
636        }
637
638        5 => {
639            let a = (mode_bits >> 5) & 0x3;
640            params.width = 12;
641            params.height = a + 2;
642        }
643
644        6 => {
645            let a = (mode_bits >> 5) & 0x3;
646            params.width = a + 2;
647            params.height = 12;
648        }
649
650        7 => {
651            params.width = 6;
652            params.height = 10;
653        }
654
655        8 => {
656            params.width = 10;
657            params.height = 6;
658        }
659
660        9 => {
661            let a = (mode_bits >> 5) & 0x3;
662            let b = (mode_bits >> 9) & 0x3;
663            params.width = a + 6;
664            params.height = b + 6;
665        }
666
667        _ => unreachable!("Impossible layout"),
668    }
669
670    // Determine whether or not we're using dual planes
671    // and/or high precision layouts.
672    let dp = (layout != 9) && (mode_bits & 0x400) != 0;
673    let p = (layout != 9) && (mode_bits & 0x200) != 0;
674
675    let max_weights = if p {
676        [9, 11, 15, 19, 23, 31]
677    } else {
678        [1, 2, 3, 4, 5, 7]
679    };
680    params.max_weight = max_weights[(r - 2) as usize];
681
682    params.is_dual_plane = dp;
683
684    params
685}
686
687fn fill_void_extent_ldr<F: FnMut(u32, u32, [u8; 4])>(
688    strm: &mut InputBitStream,
689    writer: &mut F,
690    block_width: u32,
691    block_height: u32,
692) {
693    // Don't actually care about the void extent, just read the bits...
694    for _ in 0..4 {
695        strm.read_bits(13);
696    }
697
698    // Decode the RGBA components and renormalize them to the range [0, 255]
699    let r = strm.read_bits(16) >> 8;
700    let g = strm.read_bits(16) >> 8;
701    let b = strm.read_bits(16) >> 8;
702    let a = strm.read_bits(16) >> 8;
703
704    for j in 0..block_height {
705        for i in 0..block_width {
706            writer(i, j, [r as u8, g as u8, b as u8, a as u8]);
707        }
708    }
709}
710
/// Fills a whole block with the opaque magenta error color, calling
/// `writer(x, y, rgba)` for every texel, row by row.
fn fill_error<F: FnMut(u32, u32, [u8; 4])>(writer: &mut F, block_width: u32, block_height: u32) {
    const ERROR_COLOR: [u8; 4] = [0xFF, 0, 0xFF, 0xFF];

    (0..block_height)
        .flat_map(|y| (0..block_width).map(move |x| (x, y)))
        .for_each(|(x, y)| writer(x, y, ERROR_COLOR));
}
718
/// Expands a `num_bits`-wide value to `to_bit` bits by bit replication.
///
/// The value is placed in the top `num_bits` of the result and its pattern
/// is repeated downwards until all `to_bit` bits are filled. Returns 0 when
/// either width is zero.
fn replicate(val: u32, num_bits: u32, to_bit: u32) -> u32 {
    if num_bits == 0 || to_bit == 0 {
        return 0;
    }

    let mut pattern = val << (to_bit - num_bits);
    // Each pass copies the filled part below itself, doubling its length.
    let mut filled = num_bits;
    while (pattern >> filled) != 0 {
        pattern |= pattern >> filled;
        filled *= 2;
    }
    pattern
}
740
741fn decode_color_values(data: u128, n_values: u32, n_bits_for_color_data: u32) -> [u8; 18] {
742    let mut out = [0; 18];
743    let out_range = &mut out[0..n_values as usize];
744
745    // Based on the number of values and the remaining number of bits,
746    // figure out the encoding...
747    let encoding_i = ENCODING_SEQ.0[0..ENCODING_SEQ.1]
748        .partition_point(|v| v.get_bit_length(n_values) <= n_bits_for_color_data);
749    let encoding = ENCODING_SEQ.0[encoding_i - 1];
750
751    // We now have enough to decode our integer sequence.
752    let mut color_stream = InputBitStream::new(data);
753    match encoding.encoding {
754        IntegerEncodingType::JustBits => {
755            for out in out_range {
756                *out = replicate(
757                    color_stream.read_bits(encoding.num_bits),
758                    encoding.num_bits,
759                    8,
760                ) as u8;
761            }
762        }
763        IntegerEncodingType::Trit => {
764            for (out, result) in out_range
765                .iter_mut()
766                .zip((0..).flat_map(|_| decode_trit_block(&mut color_stream, encoding.num_bits)))
767            {
768                *out = result.decode_color(encoding.num_bits);
769            }
770        }
771        IntegerEncodingType::Quint => {
772            for (out, result) in out_range
773                .iter_mut()
774                .zip((0..).flat_map(|_| decode_quint_block(&mut color_stream, encoding.num_bits)))
775            {
776                *out = result.decode_color(encoding.num_bits);
777            }
778        }
779    }
780
781    out
782}
783
784fn unquantize_texel_weights(
785    weights_stream: &mut InputBitStream,
786    params: &TexelWeightParams,
787    block_width: u32,
788    block_height: u32,
789) -> [[u32; 144]; 2] {
790    // Blocks can be at most 12x12, so we can have as many as 144 weights
791    let mut out = [[0; 144]; 2];
792    let mut unquantized = [0; 96];
793
794    let plane_scale = if params.is_dual_plane { 2 } else { 1 };
795
796    let unquantized_range = &mut unquantized[0..params.get_num_weight_values() as usize];
797    let encoding = ENCODING_MAP[params.max_weight as usize];
798
799    match encoding.encoding {
800        IntegerEncodingType::JustBits => {
801            for out in unquantized_range.iter_mut() {
802                *out = replicate(
803                    weights_stream.read_bits(encoding.num_bits),
804                    encoding.num_bits,
805                    6,
806                )
807            }
808        }
809        IntegerEncodingType::Trit => {
810            for (out, result) in unquantized_range
811                .iter_mut()
812                .zip((0..).flat_map(|_| decode_trit_block(weights_stream, encoding.num_bits)))
813            {
814                *out = result.decode_weight(encoding.num_bits);
815            }
816        }
817        IntegerEncodingType::Quint => {
818            for (out, result) in unquantized_range
819                .iter_mut()
820                .zip((0..).flat_map(|_| decode_quint_block(weights_stream, encoding.num_bits)))
821            {
822                *out = result.decode_weight(encoding.num_bits);
823            }
824        }
825    }
826
827    for weight in unquantized_range {
828        assert!(*weight < 64);
829        if *weight > 32 {
830            *weight += 1
831        }
832    }
833
834    // Do infill if necessary (Section c.2.18) ...
835    let ds = (1024 + (block_width / 2)) / (block_width - 1);
836    let dt = (1024 + (block_height / 2)) / (block_height - 1);
837
838    for plane in 0..plane_scale {
839        for t in 0..block_height {
840            for s in 0..block_width {
841                let cs = ds * s;
842                let ct = dt * t;
843
844                let gs = (cs * (params.width - 1) + 32) >> 6;
845                let gt = (ct * (params.height - 1) + 32) >> 6;
846
847                let js = gs >> 4;
848                let fs = gs & 0xF;
849
850                let jt = gt >> 4;
851                let ft = gt & 0x0F;
852
853                let w11 = (fs * ft + 8) >> 4;
854                let w10 = ft - w11;
855                let w01 = fs - w11;
856                let w00 = 16 + w11 - fs - ft;
857
858                let v0 = js + jt * params.width;
859
860                let mut p00 = 0;
861                let mut p01 = 0;
862                let mut p10 = 0;
863                let mut p11 = 0;
864
865                if v0 < (params.width * params.height) {
866                    p00 = unquantized[plane + plane_scale * (v0 as usize)];
867                }
868
869                if v0 + 1 < (params.width * params.height) {
870                    p01 = unquantized[plane + plane_scale * ((v0 + 1) as usize)];
871                }
872
873                if v0 + params.width < (params.width * params.height) {
874                    p10 = unquantized[plane + plane_scale * ((v0 + params.width) as usize)];
875                }
876
877                if v0 + params.width + 1 < (params.width * params.height) {
878                    p11 = unquantized[plane + plane_scale * ((v0 + params.width + 1) as usize)];
879                }
880
881                out[plane][(t * block_width + s) as usize] =
882                    (p00 * w00 + p01 * w01 + p10 * w10 + p11 * w11 + 8) >> 4;
883            }
884        }
885    }
886    out
887}
888
/// Transfers a bit from `a` to `b` as described in c.2.14.
///
/// Bit 7 of `a` becomes bit 7 of the halved `b`; `a` is replaced by its
/// bits 1-6 interpreted as a signed 6-bit number (-32..=31).
fn bit_transfer_signed(a: &mut i32, b: &mut i32) {
    let (old_a, old_b) = (*a, *b);

    *b = (old_b >> 1) | (old_a & 0x80);

    let six_bits = (old_a >> 1) & 0x3F;
    *a = if (six_bits & 0x20) != 0 {
        // Sign bit set: map 32..=63 onto -32..=-1.
        six_bits - 0x40
    } else {
        six_bits
    };
}
899
// Partition selection functions as specified in
// c.2.21

/// Scrambles `p` with a fixed series of shift/add/xor steps.
///
/// All arithmetic wraps around at 32 bits.
fn hash52(p: u32) -> u32 {
    let mut h = p;
    h ^= h >> 15;
    h = h.wrapping_sub(h << 17);
    h = h.wrapping_add(h << 7);
    h = h.wrapping_add(h << 4);
    h ^= h >> 5;
    h = h.wrapping_add(h << 16);
    h ^= h >> 7;
    h ^= h >> 3;
    h ^= h << 6;
    h ^= h >> 17;
    h
}
916
917fn select_partition(
918    mut seed: u32,
919    mut x: u32,
920    mut y: u32,
921    mut z: u32,
922    partition_count: usize,
923    small_block: bool,
924) -> usize {
925    if 1 == partition_count {
926        return 0;
927    }
928
929    if small_block {
930        x <<= 1;
931        y <<= 1;
932        z <<= 1;
933    }
934
935    seed += (partition_count as u32 - 1) * 1024;
936
937    let rnum = hash52(seed);
938    let mut seed1 = (rnum & 0xF) as u8;
939    let mut seed2 = ((rnum >> 4) & 0xF) as u8;
940    let mut seed3 = ((rnum >> 8) & 0xF) as u8;
941    let mut seed4 = ((rnum >> 12) & 0xF) as u8;
942    let mut seed5 = ((rnum >> 16) & 0xF) as u8;
943    let mut seed6 = ((rnum >> 20) & 0xF) as u8;
944    let mut seed7 = ((rnum >> 24) & 0xF) as u8;
945    let mut seed8 = ((rnum >> 28) & 0xF) as u8;
946    let mut seed9 = ((rnum >> 18) & 0xF) as u8;
947    let mut seed10 = ((rnum >> 22) & 0xF) as u8;
948    let mut seed11 = ((rnum >> 26) & 0xF) as u8;
949    let mut seed12 = (((rnum >> 30) | (rnum << 2)) & 0xF) as u8;
950
951    seed1 = seed1 * seed1;
952    seed2 = seed2 * seed2;
953    seed3 = seed3 * seed3;
954    seed4 = seed4 * seed4;
955    seed5 = seed5 * seed5;
956    seed6 = seed6 * seed6;
957    seed7 = seed7 * seed7;
958    seed8 = seed8 * seed8;
959    seed9 = seed9 * seed9;
960    seed10 = seed10 * seed10;
961    seed11 = seed11 * seed11;
962    seed12 = seed12 * seed12;
963
964    let sh1: i32;
965    let sh2: i32;
966    let sh3: i32;
967    if seed & 1 != 0 {
968        sh1 = if seed & 2 != 0 { 4 } else { 5 };
969        sh2 = if partition_count == 3 { 6 } else { 5 };
970    } else {
971        sh1 = if partition_count == 3 { 6 } else { 5 };
972        sh2 = if seed & 2 != 0 { 4 } else { 5 };
973    }
974    sh3 = if seed & 0x10 != 0 { sh1 } else { sh2 };
975
976    seed1 >>= sh1;
977    seed2 >>= sh2;
978    seed3 >>= sh1;
979    seed4 >>= sh2;
980    seed5 >>= sh1;
981    seed6 >>= sh2;
982    seed7 >>= sh1;
983    seed8 >>= sh2;
984    seed9 >>= sh3;
985    seed10 >>= sh3;
986    seed11 >>= sh3;
987    seed12 >>= sh3;
988
989    let mut a = seed1 as u32 * x + seed2 as u32 * y + seed11 as u32 * z + (rnum >> 14);
990    let mut b = seed3 as u32 * x + seed4 as u32 * y + seed12 as u32 * z + (rnum >> 10);
991    let mut c = seed5 as u32 * x + seed6 as u32 * y + seed9 as u32 * z + (rnum >> 6);
992    let mut d = seed7 as u32 * x + seed8 as u32 * y + seed10 as u32 * z + (rnum >> 2);
993
994    a &= 0x3F;
995    b &= 0x3F;
996    c &= 0x3F;
997    d &= 0x3F;
998
999    if partition_count < 4 {
1000        d = 0;
1001    }
1002
1003    if partition_count < 3 {
1004        c = 0;
1005    }
1006
1007    if a >= b && a >= c && a >= d {
1008        0
1009    } else if b >= c && b >= d {
1010        1
1011    } else if c >= d {
1012        2
1013    } else {
1014        3
1015    }
1016}
1017
1018fn select_2d_partition(
1019    seed: u32,
1020    x: u32,
1021    y: u32,
1022    partition_count: usize,
1023    small_block: bool,
1024) -> usize {
1025    select_partition(seed, x, y, 0, partition_count, small_block)
1026}
1027
/// Packs four channel values into an RGBA color, saturating each channel
/// to the range 0..=255.
fn clamp_color(r: i32, g: i32, b: i32, a: i32) -> [u8; 4] {
    let saturate = |channel: i32| channel.clamp(0, 255) as u8;
    [saturate(r), saturate(g), saturate(b), saturate(a)]
}
1036
/// Applies the blue-contraction transform described in c.2.14.
///
/// Red and green are each averaged with blue (rounding down); blue and
/// alpha pass through. Every channel is then saturated to 0..=255.
fn blue_contract(r: i32, g: i32, b: i32, a: i32) -> [u8; 4] {
    let saturate = |channel: i32| channel.clamp(0, 255) as u8;
    let red = (r + b) >> 1;
    let green = (g + b) >> 1;
    [saturate(red), saturate(green), saturate(b), saturate(a)]
}
1047
1048// Section c.2.14
1049fn compute_endpoints(color_values: &mut &[u8], endpoint_mods: u32) -> [[u8; 4]; 2] {
1050    let ep1: [u8; 4];
1051    let ep2: [u8; 4];
1052    macro_rules! read_int_values {
1053        ($N:expr) => {{
1054            let mut v = [0; $N];
1055            for i in 0..$N {
1056                v[i] = color_values[0] as i32;
1057                *color_values = &color_values[1..];
1058            }
1059            v
1060        }};
1061    }
1062
1063    macro_rules! bts {
1064        ($v:ident, $a:expr, $b: expr) => {{
1065            let mut a = $v[$a];
1066            let mut b = $v[$b];
1067            bit_transfer_signed(&mut a, &mut b);
1068            $v[$a] = a;
1069            $v[$b] = b;
1070        }};
1071    }
1072
1073    match endpoint_mods {
1074        0 => {
1075            let v = read_int_values!(2);
1076            ep1 = clamp_color(v[0], v[0], v[0], 0xFF);
1077            ep2 = clamp_color(v[1], v[1], v[1], 0xFF);
1078        }
1079
1080        1 => {
1081            let v = read_int_values!(2);
1082            let l0 = (v[0] >> 2) | (v[1] & 0xC0);
1083            let l1 = std::cmp::min(l0 + (v[1] & 0x3F), 0xFF);
1084            ep1 = clamp_color(l0, l0, l0, 0xFF);
1085            ep2 = clamp_color(l1, l1, l1, 0xFF);
1086        }
1087
1088        4 => {
1089            let v = read_int_values!(4);
1090            ep1 = clamp_color(v[0], v[0], v[0], v[2]);
1091            ep2 = clamp_color(v[1], v[1], v[1], v[3]);
1092        }
1093
1094        5 => {
1095            let mut v = read_int_values!(4);
1096            bts!(v, 1, 0);
1097            bts!(v, 3, 2);
1098            ep1 = clamp_color(v[0], v[0], v[0], v[2]);
1099            ep2 = clamp_color(v[0] + v[1], v[0] + v[1], v[0] + v[1], v[2] + v[3]);
1100        }
1101
1102        6 => {
1103            let v = read_int_values!(4);
1104            ep1 = clamp_color(
1105                (v[0] * v[3]) >> 8,
1106                (v[1] * v[3]) >> 8,
1107                (v[2] * v[3]) >> 8,
1108                0xFF,
1109            );
1110            ep2 = clamp_color(v[0], v[1], v[2], 0xFF);
1111        }
1112
1113        8 => {
1114            let v = read_int_values!(6);
1115            if v[1] + v[3] + v[5] >= v[0] + v[2] + v[4] {
1116                ep1 = clamp_color(v[0], v[2], v[4], 0xFF);
1117                ep2 = clamp_color(v[1], v[3], v[5], 0xFF);
1118            } else {
1119                ep1 = blue_contract(v[1], v[3], v[5], 0xFF);
1120                ep2 = blue_contract(v[0], v[2], v[4], 0xFF);
1121            }
1122        }
1123
1124        9 => {
1125            let mut v = read_int_values!(6);
1126            bts!(v, 1, 0);
1127            bts!(v, 3, 2);
1128            bts!(v, 5, 4);
1129            if v[1] + v[3] + v[5] >= 0 {
1130                ep1 = clamp_color(v[0], v[2], v[4], 0xFF);
1131                ep2 = clamp_color(v[0] + v[1], v[2] + v[3], v[4] + v[5], 0xFF);
1132            } else {
1133                ep1 = blue_contract(v[0] + v[1], v[2] + v[3], v[4] + v[5], 0xFF);
1134                ep2 = blue_contract(v[0], v[2], v[4], 0xFF);
1135            }
1136        }
1137
1138        10 => {
1139            let v = read_int_values!(6);
1140            ep1 = clamp_color(
1141                (v[0] * v[3]) >> 8,
1142                (v[1] * v[3]) >> 8,
1143                (v[2] * v[3]) >> 8,
1144                v[4],
1145            );
1146            ep2 = clamp_color(v[0], v[1], v[2], v[5]);
1147        }
1148
1149        12 => {
1150            let v = read_int_values!(8);
1151            if v[1] + v[3] + v[5] >= v[0] + v[2] + v[4] {
1152                ep1 = clamp_color(v[0], v[2], v[4], v[6]);
1153                ep2 = clamp_color(v[1], v[3], v[5], v[7]);
1154            } else {
1155                ep1 = blue_contract(v[1], v[3], v[5], v[7]);
1156                ep2 = blue_contract(v[0], v[2], v[4], v[6]);
1157            }
1158        }
1159
1160        13 => {
1161            let mut v = read_int_values!(8);
1162            bts!(v, 1, 0);
1163            bts!(v, 3, 2);
1164            bts!(v, 5, 4);
1165            bts!(v, 7, 6);
1166            if v[1] + v[3] + v[5] >= 0 {
1167                ep1 = clamp_color(v[0], v[2], v[4], v[6]);
1168                ep2 = clamp_color(v[0] + v[1], v[2] + v[3], v[4] + v[5], v[6] + v[7]);
1169            } else {
1170                ep1 = blue_contract(v[0] + v[1], v[2] + v[3], v[4] + v[5], v[6] + v[7]);
1171                ep2 = blue_contract(v[0], v[2], v[4], v[6]);
1172            }
1173        }
1174
1175        _ => {
1176            // unsupported HDR modes
1177            ep1 = [0xFF, 0, 0xFF, 0xFF];
1178            ep2 = [0xFF, 0, 0xFF, 0xFF];
1179        }
1180    }
1181    [ep1, ep2]
1182}
1183
/// Block footprint (texel dimensions of one ASTC block).
///
/// The footprints defined by the specification are available as associated
/// constants, and [`Footprint::new`] builds a custom one. Prefer the predefined
/// constants; the constructor is meant for dimensions that arrive as plain
/// numbers, e.g. from another file format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Footprint {
    block_width: u32,
    block_height: u32,
}

impl Footprint {
    pub const ASTC_4X4: Footprint = Footprint::fixed(4, 4);
    pub const ASTC_5X4: Footprint = Footprint::fixed(5, 4);
    pub const ASTC_5X5: Footprint = Footprint::fixed(5, 5);
    pub const ASTC_6X5: Footprint = Footprint::fixed(6, 5);
    pub const ASTC_6X6: Footprint = Footprint::fixed(6, 6);
    pub const ASTC_8X5: Footprint = Footprint::fixed(8, 5);
    pub const ASTC_8X6: Footprint = Footprint::fixed(8, 6);
    pub const ASTC_10X5: Footprint = Footprint::fixed(10, 5);
    pub const ASTC_10X6: Footprint = Footprint::fixed(10, 6);
    pub const ASTC_8X8: Footprint = Footprint::fixed(8, 8);
    pub const ASTC_10X8: Footprint = Footprint::fixed(10, 8);
    pub const ASTC_10X10: Footprint = Footprint::fixed(10, 10);
    pub const ASTC_12X10: Footprint = Footprint::fixed(12, 10);
    pub const ASTC_12X12: Footprint = Footprint::fixed(12, 12);

    /// Compile-time constructor backing the predefined constants.
    /// Performs no validation; every caller passes non-zero literals.
    const fn fixed(block_width: u32, block_height: u32) -> Footprint {
        Footprint {
            block_width,
            block_height,
        }
    }

    /// Constructs a custom footprint.
    ///
    /// # Panics
    ///
    /// Panics if either dimension is zero.
    pub fn new(block_width: u32, block_height: u32) -> Footprint {
        assert!(block_width != 0 && block_height != 0, "Invalid block size");
        Footprint::fixed(block_width, block_height)
    }

    /// Returns the block width.
    pub fn block_width(&self) -> u32 {
        self.block_width
    }

    /// Returns the block height.
    pub fn block_height(&self) -> u32 {
        self.block_height
    }
}
1274
1275/// Decode an ASTC block from a 16-byte buffer.
1276///
1277///  - `input` contains the raw ASTC data for one block.
1278///  - `footprint` specifies the footprint size.
1279///  - `writer` is a function with signature `FnMut(x: u32, y: u32, color: [u8; 4])`.
1280///     It is used to output decoded pixels.
1281///     This function will be called once for each pixel `(x, y)` in the rectangle
1282///     `[0, footprint.width()) * [0, footprint.height())`.
1283///     Each element in `color` represents the R, G, B, and A channel, respectively.
1284///  - returns `true` if the block is successfully decoded; `false` if the block contains illegal encoding.
1285/// ```
1286/// let footprint = astc_decode::Footprint::new(6, 6);
1287/// // Exmaple input. Not necessarily a valid ASTC block
1288/// let input = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
1289/// let mut output = [[0; 4]; 6 * 6];
1290/// astc_decode::astc_decode_block(&input, footprint, |x, y, color| {
1291///     output[(x + y * 6) as usize] = color;
1292/// });
1293/// ```
1294pub fn astc_decode_block<F: FnMut(u32, u32, [u8; 4])>(
1295    input: &[u8; 16],
1296    footprint: Footprint,
1297    mut writer: F,
1298) -> bool {
1299    let block_width = footprint.block_width;
1300    let block_height = footprint.block_height;
1301    let mut strm = InputBitStream::new(u128::from_le_bytes(*input));
1302    let weight_params = decode_block_info(&mut strm);
1303
1304    // Was there an error?
1305    if weight_params.is_error {
1306        fill_error(&mut writer, block_width, block_height);
1307        return false;
1308    }
1309
1310    if weight_params.void_extent_ldr {
1311        fill_void_extent_ldr(&mut strm, &mut writer, block_width, block_height);
1312        return true;
1313    }
1314
1315    if weight_params.void_extent_hdr {
1316        fill_error(&mut writer, block_width, block_height);
1317        return false;
1318    }
1319
1320    if weight_params.width > block_width {
1321        fill_error(&mut writer, block_width, block_height);
1322        return false;
1323    }
1324
1325    if weight_params.height > block_height {
1326        fill_error(&mut writer, block_width, block_height);
1327        return false;
1328    }
1329
1330    if weight_params.get_num_weight_values() > 64 {
1331        fill_error(&mut writer, block_width, block_height);
1332        return false;
1333    }
1334
1335    let n_weight_bits = weight_params.get_packed_bit_size();
1336
1337    if !(24..=96).contains(&n_weight_bits) {
1338        fill_error(&mut writer, block_width, block_height);
1339        return false;
1340    }
1341
1342    // Read num partitions
1343    let n_partitions = (strm.read_bits(2) + 1) as usize;
1344    assert!(n_partitions <= 4);
1345
1346    if n_partitions == 4 && weight_params.is_dual_plane {
1347        fill_error(&mut writer, block_width, block_height);
1348        return false;
1349    }
1350
1351    // Based on the number of partitions, read the color endpoint32 mode for
1352    // each partition.
1353
1354    // Determine partitions, partition index, and color endpoint32 modes
1355    let plane_idx;
1356    let partition_index;
1357    let mut endpoint_mods = [0, 0, 0, 0];
1358    let endpoint_mods = &mut endpoint_mods[0..n_partitions];
1359
1360    // Read extra config data...
1361    let mut base_cem = 0;
1362    if n_partitions == 1 {
1363        endpoint_mods[0] = strm.read_bits(4);
1364        partition_index = 0;
1365    } else {
1366        partition_index = strm.read_bits(10);
1367        base_cem = strm.read_bits(6);
1368    }
1369    let base_mode = base_cem & 3;
1370
1371    let mut non_color_bits = n_weight_bits + strm.get_bits_read();
1372
1373    // Consider extra bits prior to texel data...
1374    let mut extra_cem_bits = 0;
1375    if base_mode != 0 {
1376        match n_partitions {
1377            2 => extra_cem_bits += 2,
1378            3 => extra_cem_bits += 5,
1379            4 => extra_cem_bits += 8,
1380            _ => unreachable!(),
1381        }
1382    }
1383    non_color_bits += extra_cem_bits;
1384
1385    // Do we have a dual plane situation?
1386    let mut plane_selector_bits = 0;
1387    if weight_params.is_dual_plane {
1388        plane_selector_bits = 2;
1389    }
1390    non_color_bits += plane_selector_bits;
1391
1392    if non_color_bits >= 128 {
1393        fill_error(&mut writer, block_width, block_height);
1394        return false;
1395    }
1396
1397    // Read color data...
1398    let color_data_bits = 128 - non_color_bits;
1399    let endpoint_data = strm.read_bits128(color_data_bits);
1400
1401    // Read the plane selection bits
1402    plane_idx = strm.read_bits(plane_selector_bits);
1403
1404    // Read the rest of the cem
1405    if base_mode != 0 {
1406        let extra_cem = strm.read_bits(extra_cem_bits);
1407        let mut cem = (extra_cem << 6) | base_cem;
1408        cem >>= 2;
1409
1410        let mut c = [false; 4];
1411        for c in &mut c[0..n_partitions] {
1412            *c = (cem & 1) != 0;
1413            cem >>= 1;
1414        }
1415
1416        let mut m = [0; 4];
1417        for m in &mut m[0..n_partitions] {
1418            *m = cem & 3;
1419            cem >>= 2;
1420        }
1421
1422        for (i, endpoint_mod) in endpoint_mods.iter_mut().enumerate() {
1423            *endpoint_mod = base_mode;
1424            if !c[i] {
1425                *endpoint_mod -= 1;
1426            }
1427            *endpoint_mod <<= 2;
1428            *endpoint_mod |= m[i];
1429        }
1430    } else if n_partitions > 1 {
1431        let cem = base_cem >> 2;
1432        endpoint_mods[0..n_partitions].fill(cem);
1433    }
1434
1435    // Make sure everything up till here is sane.
1436    for &endpoint_mod in endpoint_mods.iter() {
1437        assert!(endpoint_mod < 16);
1438    }
1439    assert!(strm.get_bits_read() + weight_params.get_packed_bit_size() == 128);
1440
1441    // Figure out how many color values we have
1442    let n_values = endpoint_mods.iter().map(|m| ((m >> 2) + 1) << 1).sum();
1443
1444    if n_values > 18 || (n_values * 13 + 4) / 5 > color_data_bits {
1445        fill_error(&mut writer, block_width, block_height);
1446        return false;
1447    }
1448
1449    // Decode color data
1450    let color_values = decode_color_values(endpoint_data, n_values, color_data_bits);
1451
1452    let mut endpoints = [[[0; 4]; 2]; 4];
1453    let mut color_values_ptr = &color_values[0..n_values as usize];
1454    for i in 0..n_partitions {
1455        endpoints[i] = compute_endpoints(&mut color_values_ptr, endpoint_mods[i]);
1456    }
1457
1458    // Read the texel weight data
1459    let mut texel_weight_data = u128::from_le_bytes(*input).reverse_bits();
1460
1461    // Make sure that higher non-texel bits are set to zero
1462    texel_weight_data &= (1 << weight_params.get_packed_bit_size()) - 1;
1463
1464    // Decode weight data
1465    let mut weight_stream = InputBitStream::new(texel_weight_data);
1466    let weights = unquantize_texel_weights(
1467        &mut weight_stream,
1468        &weight_params,
1469        block_width,
1470        block_height,
1471    );
1472
1473    // Now that we have endpoints and weights, we can interpolate and generate
1474    // the proper decoding...
1475    for j in 0..block_height {
1476        for i in 0..block_width {
1477            let partition = select_2d_partition(
1478                partition_index,
1479                i,
1480                j,
1481                n_partitions,
1482                (block_height * block_width) < 32,
1483            );
1484            assert!(partition < n_partitions);
1485
1486            let mut p = [0; 4];
1487            for (c, p) in p.iter_mut().enumerate() {
1488                let c0 = endpoints[partition][0][c] as u32 * 0x101;
1489                let c1 = endpoints[partition][1][c] as u32 * 0x101;
1490
1491                let mut plane = 0;
1492                if weight_params.is_dual_plane && (plane_idx & 3 == c as u32) {
1493                    plane = 1;
1494                }
1495
1496                let weight = weights[plane][(j * block_width + i) as usize];
1497                let color = (c0 * (64 - weight) + c1 * weight + 32) / 64;
1498                *p = u8::try_from(((color * 255) + 32767) / 65536).unwrap();
1499            }
1500
1501            writer(i, j, p);
1502        }
1503    }
1504
1505    true
1506}
1507
1508/// Decode an ASTC image, assuming linear block layout.
1509///
1510///  - `input` us used to read raw ASTC data.
1511///  - `width` and `height` specify the image dimensions.
1512///  - `footprint` specifies the footprint size.
1513///  - `writer` is a function with signature `FnMut(x: u32, y: u32, color: [u8; 4])`.
1514///     It is used to output decoded pixels.
1515///     This function will be called once for each pixel `(x, y)` in the rectangle
1516///     `[0, width) * [0, height)`.
1517///     Each element in `color` represents the R, G, B, and A channel, respectively.
1518///  - Returns success or IO error encountered when reading `input`.
1519/// ```
1520/// let footprint = astc_decode::Footprint::new(6, 6);
1521/// // Exmaple input. Not necessarily valid ASTC image
1522/// let input = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
1523/// let mut output = [[0; 4]; 5 * 3];
1524/// astc_decode::astc_decode(&input[..], 5, 3, footprint, |x, y, color| {
1525///     output[(x + y * 5) as usize] = color;
1526/// }).unwrap();
1527/// ```
1528pub fn astc_decode<R: Read, F: FnMut(u32, u32, [u8; 4])>(
1529    mut input: R,
1530    width: u32,
1531    height: u32,
1532    footprint: Footprint,
1533    mut writer: F,
1534) -> Result<()> {
1535    let block_width = footprint.block_width;
1536    let block_height = footprint.block_height;
1537
1538    let block_w = (width.checked_add(block_width).unwrap() - 1) / block_width;
1539    let block_h = (height.checked_add(block_height).unwrap() - 1) / block_height;
1540
1541    for by in 0..block_h {
1542        for bx in 0..block_w {
1543            let mut block_buf = [0; 16];
1544            input.read_exact(&mut block_buf)?;
1545            astc_decode_block(&block_buf, footprint, |x, y, v| {
1546                let x = bx * block_width + x;
1547                let y = by * block_height + y;
1548                if x < width && y < height {
1549                    writer(x, y, v)
1550                }
1551            });
1552        }
1553    }
1554
1555    Ok(())
1556}
1557
#[cfg(test)]
mod tests {
    use super::*;
    use image::Pixel;

    /// Asserts that two channel values differ by at most one.
    fn dist(a: u8, b: u8) {
        assert!((a as i32 - b as i32).abs() <= 1)
    }

    /// Decodes `astc` with the given footprint and compares every pixel with
    /// the reference image `bmp`.
    ///
    /// The first 16 bytes of `astc` are skipped before decoding (presumably the
    /// `.astc` container header), and the reference image is read with its
    /// rows in reverse order.
    fn test_case(astc: &[u8], bmp: &[u8], block_width: u32, block_height: u32) {
        let reference = image::load_from_memory(bmp).unwrap().to_rgba8();
        let (width, height) = (reference.width(), reference.height());
        let payload = &astc[16..];
        let footprint = Footprint::new(block_width, block_height);

        let check_pixel = |x: u32, y: u32, decoded: [u8; 4]| {
            let flipped_y = height - y - 1;
            let expected = reference.get_pixel(x, flipped_y).channels();
            for (&want, &got) in expected.iter().zip(decoded.iter()) {
                dist(want, got);
            }
        };

        astc_decode(payload, width, height, footprint, check_pixel).unwrap();
    }

    // Runs `test_case` on the fixture pair
    // `test-data/<name>_<bw>x<bh>.astc` / `.bmp`.
    macro_rules! tc {
        ($name:literal, $bw:literal, $bh:literal) => {{
            let astc = include_bytes!(concat!("test-data/", $name, '_', $bw, 'x', $bh, ".astc"));
            let bmp = include_bytes!(concat!("test-data/", $name, '_', $bw, 'x', $bh, ".bmp"));
            test_case(astc, bmp, $bw, $bh);
        }};
    }

    #[test]
    fn real_image() {
        tc!("atlas_small", 4, 4);
        tc!("atlas_small", 5, 5);
        tc!("atlas_small", 6, 6);
        tc!("atlas_small", 8, 8);

        tc!("footprint", 4, 4);
        tc!("footprint", 5, 4);
        tc!("footprint", 5, 5);
        tc!("footprint", 6, 5);
        tc!("footprint", 6, 6);
        tc!("footprint", 8, 5);
        tc!("footprint", 8, 6);
        tc!("footprint", 8, 8);
        tc!("footprint", 10, 5);
        tc!("footprint", 10, 6);
        tc!("footprint", 10, 8);
        tc!("footprint", 10, 10);
        tc!("footprint", 12, 10);
        tc!("footprint", 12, 12);

        tc!("rgb", 4, 4);
        tc!("rgb", 5, 4);
        tc!("rgb", 6, 6);
        tc!("rgb", 8, 8);
        tc!("rgb", 12, 12);
    }

    /// Feeds 10000 random blocks through the block decoder; the only
    /// expectation is that decoding never panics.
    fn fuzz_fp(w: u32, h: u32) {
        let footprint = Footprint::new(w, h);
        for _ in 0..10000 {
            let block: [u8; 16] = rand::random();
            astc_decode_block(&block, footprint, |_, _, _| {});
        }
    }

    #[test]
    fn fuzzing() {
        const FOOTPRINTS: [(u32, u32); 14] = [
            (4, 4),
            (5, 4),
            (5, 5),
            (6, 5),
            (6, 6),
            (8, 5),
            (8, 6),
            (8, 8),
            (10, 5),
            (10, 6),
            (10, 8),
            (10, 10),
            (12, 10),
            (12, 12),
        ];

        for &(w, h) in FOOTPRINTS.iter() {
            fuzz_fp(w, h);
        }
    }
}