b64_ct/decode/
avx2.rs

1/* Copyright (c) Fortanix, Inc.
2 *
3 * This Source Code Form is subject to the terms of the Mozilla Public
4 * License, v. 2.0. If a copy of the MPL was not distributed with this
5 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
6
7#[cfg(target_arch = "x86")]
8use core::arch::x86::*;
9#[cfg(target_arch = "x86_64")]
10use core::arch::x86_64::*;
11
12use crate::avx2::*;
13
/// Classifies and translates 32 input bytes in parallel.
///
/// Returns `(values, invalid_mask, valid_mask)`. Bit `i` of each mask refers
/// to byte `i` of `input`:
///
/// * `values` holds the 6-bit value of every byte that is a base64 alphabet
///   character: `A-Z`, `a-z`, `0-9`, and both `+`/`/` and `-`/`_` (62/63).
///   Bytes that are whitespace or invalid receive meaningless values that must
///   not be used.
/// * `invalid_mask` has a bit set for every byte that is neither an alphabet
///   character nor accepted whitespace. The accepted whitespace is space,
///   `\t`, `\n`, `\x0C` and `\r`. `=` appears in neither table, so padding
///   bytes are reported as invalid here.
/// * `valid_mask` has a bit set for every alphabet character (whitespace
///   excluded), i.e. every byte whose entry in `values` should be kept.
///
/// # Safety
/// The caller should ensure the requisite CPU features are enabled.
#[target_feature(enable = "avx2,bmi1,sse4.2,popcnt")]
unsafe fn decode_avx2(input: __m256i) -> (__m256i, u32, u32) {
    // Step 0. Split input bytes into nibbles. The 16-bit shift drags bits from
    // the neighbouring byte in, so both results are masked to 4 bits.
    let higher_nibble = _mm256_and_si256(_mm256_srli_epi16(input, 4), _mm256_set1_epi8(0x0f));
    let lower_nibble = _mm256_and_si256(input, _mm256_set1_epi8(0x0f));

    // Step 1. Find invalid characters. Steps 2 & 3 will compute invalid 6-bit
    // values for invalid characters. The result of the computation should only
    // be used if no invalid characters are found.

    // This table contains 128 bits, one bit for each of the lower 128 ASCII
    // characters. A set bit indicates that the character is in the base64
    // character set (the character is valid) or the character is considered
    // ASCII whitespace. This table is indexed by ASCII low nibble: bit `h` of
    // entry `l` stands for the character `(h << 4) | l`.
    #[rustfmt::skip]
    let row_lut = dup_mm_setr_epu8([
        0b1010_1100, 0b1111_1000, 0b1111_1000, 0b1111_1000,
        0b1111_1000, 0b1111_1000, 0b1111_1000, 0b1111_1000,
        0b1111_1000, 0b1111_1001, 0b1111_0001, 0b0101_0100,
        0b0101_0001, 0b0101_0101, 0b0101_0000, 0b0111_0100,
    ]);

    // This table maps an ASCII high nibble `h` to the one-hot mask `1 << h`
    // that selects its bit in a `row_lut` entry. High nibbles 8-15 (bytes
    // >= 0x80) map to 0, so such bytes are always classified as invalid.
    #[rustfmt::skip]
    let column_lut = dup_mm_setr_epu8([
        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
           0,    0,    0,    0,    0,    0,    0,    0,
    ]);

    // Lookup table row
    let row = _mm256_shuffle_epi8(row_lut, lower_nibble);
    // Lookup column offset
    let column = _mm256_shuffle_epi8(column_lut, higher_nibble);
    // Lookup valid characters
    let valid = _mm256_and_si256(row, column);
    // Compute invalid character mask: 0xff where the selected bit was clear
    let non_match = _mm256_cmpeq_epi8(valid, _mm256_setzero_si256());
    // Transform mask to u32 (bit i = top bit of byte i)
    let invalid_mask = _mm256_movemask_epi8(non_match);

    // Step 2. Numbers & letters: compute 6-bit value for the 3 different
    // ranges by simply adjusting the ASCII value.

    // This table contains the offsets for the alphanumerical ASCII ranges.
    // This table is indexed by ASCII high nibble.
    #[rustfmt::skip]
    let shift_lut = dup_mm_setr_epi8([
        0, 0, 0,
        // '0' through '9'
        4,
        // 'A' through 'Z'
        -65, -65,
        // 'a' through 'z'
        -71, -71,
        0, 0, 0, 0, 0, 0, 0, 0,
    ]);

    // Get offset
    let shift = _mm256_shuffle_epi8(shift_lut, higher_nibble);
    // Compute 6-bit value
    let shifted = _mm256_add_epi8(input, shift);

    // Step 3. Special characters: look up the 6-bit value in a table.

    // This table specifies the ASCII ranges that contain valid special
    // characters (high nibbles 2 and 5). This table is indexed by ASCII high
    // nibble.
    #[rustfmt::skip]
    let spcrange_lut = dup_mm_setr_epu8([
        0, 0, 0xff, 0, 0, 0xff, 0, 0,
        0, 0,    0, 0, 0,    0, 0, 0,
    ]);

    // This table specifies the (inverted) 6-bit values for the special
    // characters. The values in this table act as both a value and a blend
    // mask: !62 (0xC1) and !63 (0xC0) have their top bit set, which is what
    // the blend below tests. This table is indexed by the difference between
    // ASCII low and high nibble: '+' = 0xB-2 = 9, '_' = 0xF-5 = 10,
    // '-' = 0xD-2 = 11, '/' = 0xF-2 = 13.
    #[rustfmt::skip]
    let spcchar_lut = dup_mm_setr_epu8([
        0,   0,   0,   0, 0,   0, 0, 0,
        // '+', '_', '-',    '/'
        0, !62, !63, !62, 0, !63, 0, 0,
    ]);

    // Check if character is in the range for special characters
    let sel_range = _mm256_shuffle_epi8(spcrange_lut, higher_nibble);
    // Compute difference between ASCII low and high nibble
    let lo_sub_hi = _mm256_sub_epi8(lower_nibble, higher_nibble);
    // Lookup special character 6-bit value
    let specials = _mm256_shuffle_epi8(spcchar_lut, lo_sub_hi);
    // Combine blend masks from range and value: only bytes that are both in a
    // special range and hit a non-zero table entry keep their top bit.
    let sel_spec = _mm256_and_si256(sel_range, specials);

    // Combine the results of step 2 (alphanumerics) and step 3 (special
    // characters): bytes whose blend mask has the top bit set take the
    // un-inverted special-character value, all others keep the shifted value.
    let result = _mm256_blendv_epi8(shifted, _mm256_not_si256(specials), sel_spec);

    // Step 4. Compute mask for valid non-whitespace bytes. The mask will be
    // used to copy only relevant bytes into the output.

    // This table specifies the character ranges which should be decoded. The
    // format is a range table for the PCMPESTRM instruction: inclusive
    // (low, high) byte pairs. Only the first 14 bytes (7 ranges) are used, as
    // passed to `_mm_cmpestrm` below; the trailing zero pair is padding.
    #[rustfmt::skip]
    let valid_nonws_set = _mm_setr_epi8(
        b'A' as _, b'Z' as _,
        b'a' as _, b'z' as _,
        b'0' as _, b'9' as _,
        b'+' as _, b'+' as _,
        b'/' as _, b'/' as _,
        b'-' as _, b'-' as _,
        b'_' as _, b'_' as _,
        0, 0,
    );

    // Split input into 128-bit values
    let lane0 = _mm256_extracti128_si256(input, 0);
    let lane1 = _mm256_extracti128_si256(input, 1);
    // Compute bitmask for each 128-bit value (explicit lengths, so NUL bytes
    // in the input are compared like any other byte)
    const CMP_FLAGS: i32 = _SIDD_UBYTE_OPS | _SIDD_CMP_RANGES | _SIDD_BIT_MASK;
    let mask0 = _mm_cmpestrm(valid_nonws_set, 14, lane0, 16, CMP_FLAGS);
    let mask1 = _mm_cmpestrm(valid_nonws_set, 14, lane1, 16, CMP_FLAGS);

    // Combine bitmasks into integer value: lane 0 supplies bits 0-15 and
    // lane 1 supplies bits 16-31.
    let first = _mm_extract_epi16(mask0, 0) as u16;
    let second = _mm_extract_epi16(mask1, 0) as u16;
    let valid_mask = first as u32 + ((second as u32) << 16);

    (result, invalid_mask as _, valid_mask as _)
}
145
146/// # Safety
147/// The caller should ensure the requisite CPU features are enabled.
148#[target_feature(enable = "avx2,bmi1,sse4.2,popcnt")]
149unsafe fn decode_block(block: &mut <Avx2 as super::Decoder>::Block) -> super::BlockResult {
150    let input = array_as_m256i(*block);
151
152    let (unpacked, invalid_mask, mut valid_mask) = decode_avx2(input);
153
154    let unpacked = m256i_as_array(unpacked);
155
156    let first_invalid = match invalid_mask.trailing_zeros() {
157        32 => None,
158        v => Some(v as _),
159    };
160    let out_length = valid_mask.count_ones() as _;
161
162    let mut out_iter = block.iter_mut();
163    // TODO: Optimize loop (https://github.com/fortanix/b64-ct/issues/2)
164    for &val in unpacked.iter() {
165        if (valid_mask & 1) == 1 {
166            *out_iter.next().unwrap() = val;
167        }
168        valid_mask >>= 1;
169    }
170
171    super::BlockResult {
172        out_length,
173        first_invalid,
174    }
175}
176
177/// # Safety
178/// The caller should ensure the requisite CPU features are enabled.
179#[target_feature(enable = "avx2,bmi1,sse4.2,popcnt")]
180unsafe fn pack_block(input: &<Avx2 as super::Packer>::Input, output: &mut [u8]) {
181    assert_eq!(output.len(), <Avx2 as super::Packer>::OUT_BUF_LEN);
182
183    let unpacked = array_as_m256i(*input);
184
185    // Pack 32× 6-bit values into 16× 12-bit values
186    let packed1 = _mm256_maddubs_epi16(unpacked, _mm256_set1_epi16(0x0140));
187    // Pack 16× 12-bit values into 8× 3-byte values
188    let packed2 = _mm256_madd_epi16(packed1, _mm256_set1_epi32(0x00011000));
189    // Pack 8× 3-byte values into 2× 12-byte values
190    #[rustfmt::skip]
191    let packed3 = _mm256_shuffle_epi8(packed2, dup_mm_setr_epu8([
192           2,  1,  0,
193           6,  5,  4,
194          10,  9,  8,
195          14, 13, 12,
196          0xff, 0xff, 0xff, 0xff,
197    ]));
198
199    _mm_storeu_si128(
200        output.as_mut_ptr() as _,
201        _mm256_extracti128_si256(packed3, 0),
202    );
203    _mm_storeu_si128(
204        output.as_mut_ptr().offset(12) as _,
205        _mm256_extracti128_si256(packed3, 1),
206    );
207}
208
209#[derive(Copy, Clone)]
210pub(super) struct Avx2 {
211    _private: (),
212}
213
214impl Avx2 {
215    /// # Safety
216    /// The caller should ensure the requisite CPU features are enabled.
217    #[target_feature(enable = "avx2,bmi1,sse4.2,popcnt")]
218    pub(super) unsafe fn new() -> Avx2 {
219        Avx2 { _private: () }
220    }
221}
222
223impl super::Decoder for Avx2 {
224    type Block = [u8; 32];
225
226    #[inline]
227    fn decode_block(self, block: &mut Self::Block) -> super::BlockResult {
228        // safe: `self` was given as a witness that the features are available
229        unsafe { decode_block(block) }
230    }
231
232    #[inline(always)]
233    fn zero_block() -> Self::Block {
234        [b' '; 32]
235    }
236}
237
238impl super::Packer for Avx2 {
239    type Input = [u8; 32];
240    const OUT_BUF_LEN: usize = 28;
241
242    fn pack_block(self, input: &Self::Input, output: &mut [u8]) {
243        // safe: `self` was given as a witness that the features are available
244        unsafe { pack_block(input, output) }
245    }
246}