blake2s_simd/
guts.rs

1use crate::*;
2use arrayref::array_ref;
3use core::cmp;
4
/// The largest number of inputs compressed in parallel by any implementation
/// available on this target.
///
/// On x86/x86_64 this matches the 8 jobs taken by
/// `Implementation::compress8_loop`. Every other architecture only has the
/// portable implementation, which compresses one input at a time.
pub const MAX_DEGREE: usize = if cfg!(any(target_arch = "x86", target_arch = "x86_64")) {
    8
} else {
    1
};
10
// Variants other than Portable are unreachable in no_std, unless CPU features
// are explicitly enabled for the build with e.g. RUSTFLAGS="-C target-feature=avx2".
// This might change in the future if is_x86_feature_detected moves into libcore.
/// The backend that an `Implementation` dispatches to.
// dead_code is allowed because, per the note above, the SIMD variants may
// never be constructed in some builds.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Platform {
    /// Plain Rust code that compresses one input at a time (degree 1).
    Portable,
    /// SSE4.1 code. Its single-input and 4-way loops are also used when the
    /// platform is AVX2.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    SSE41,
    /// AVX2 code, which provides the 8-way `compress8_loop`.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    AVX2,
}
23
24#[derive(Clone, Copy, Debug)]
25pub struct Implementation(Platform);
26
27impl Implementation {
28    pub fn detect() -> Self {
29        // Try the different implementations in order of how fast/modern they
30        // are. Currently on non-x86, everything just uses portable.
31        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
32        {
33            if let Some(avx2_impl) = Self::avx2_if_supported() {
34                return avx2_impl;
35            }
36        }
37        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
38        {
39            if let Some(sse41_impl) = Self::sse41_if_supported() {
40                return sse41_impl;
41            }
42        }
43        Self::portable()
44    }
45
46    pub fn portable() -> Self {
47        Implementation(Platform::Portable)
48    }
49
50    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
51    #[allow(unreachable_code)]
52    pub fn sse41_if_supported() -> Option<Self> {
53        // Check whether SSE4.1 support is assumed by the build.
54        #[cfg(target_feature = "sse4.1")]
55        {
56            return Some(Implementation(Platform::SSE41));
57        }
58        // Otherwise dynamically check for support if we can.
59        #[cfg(feature = "std")]
60        {
61            if is_x86_feature_detected!("sse4.1") {
62                return Some(Implementation(Platform::SSE41));
63            }
64        }
65        None
66    }
67
68    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
69    #[allow(unreachable_code)]
70    pub fn avx2_if_supported() -> Option<Self> {
71        // Check whether AVX2 support is assumed by the build.
72        #[cfg(target_feature = "avx2")]
73        {
74            return Some(Implementation(Platform::AVX2));
75        }
76        // Otherwise dynamically check for support if we can.
77        #[cfg(feature = "std")]
78        {
79            if is_x86_feature_detected!("avx2") {
80                return Some(Implementation(Platform::AVX2));
81            }
82        }
83        None
84    }
85
86    pub fn degree(&self) -> usize {
87        match self.0 {
88            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
89            Platform::AVX2 => avx2::DEGREE,
90            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
91            Platform::SSE41 => sse41::DEGREE,
92            Platform::Portable => 1,
93        }
94    }
95
96    pub fn compress1_loop(
97        &self,
98        input: &[u8],
99        words: &mut [Word; 8],
100        count: Count,
101        last_node: LastNode,
102        finalize: Finalize,
103        stride: Stride,
104    ) {
105        match self.0 {
106            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
107            Platform::AVX2 | Platform::SSE41 => unsafe {
108                sse41::compress1_loop(input, words, count, last_node, finalize, stride);
109            },
110            Platform::Portable => {
111                portable::compress1_loop(input, words, count, last_node, finalize, stride);
112            }
113        }
114    }
115
116    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
117    pub fn compress4_loop(&self, jobs: &mut [Job; 4], finalize: Finalize, stride: Stride) {
118        match self.0 {
119            Platform::AVX2 | Platform::SSE41 => unsafe {
120                sse41::compress4_loop(jobs, finalize, stride)
121            },
122            _ => panic!("unsupported"),
123        }
124    }
125
126    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
127    pub fn compress8_loop(&self, jobs: &mut [Job; 8], finalize: Finalize, stride: Stride) {
128        match self.0 {
129            Platform::AVX2 => unsafe { avx2::compress8_loop(jobs, finalize, stride) },
130            _ => panic!("unsupported"),
131        }
132    }
133}
134
/// One input's worth of work for the multi-input loops
/// (`Implementation::compress4_loop` / `Implementation::compress8_loop`).
pub struct Job<'a, 'b> {
    /// The input bytes to compress for this job.
    pub input: &'a [u8],
    /// The state words, which the compression loop updates in place.
    pub words: &'b mut [Word; 8],
    /// The counter value for the first block of `input`. Presumably this is
    /// the number of bytes already hashed before it; confirm with callers.
    pub count: Count,
    /// The last-node flag passed through to compression for this job.
    pub last_node: LastNode,
}
141
142impl<'a, 'b> core::fmt::Debug for Job<'a, 'b> {
143    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
144        // NB: Don't print the words. Leaking them would allow length extension.
145        write!(
146            f,
147            "Job {{ input_len: {}, count: {}, last_node: {} }}",
148            self.input.len(),
149            self.count,
150            self.last_node.yes(),
151        )
152    }
153}
154
/// Whether a compression call should finalize the hash state.
///
/// This could just be a `bool`, but a named enum reads better at callsites.
#[derive(Clone, Copy, Debug)]
pub enum Finalize {
    Yes,
    No,
}

impl Finalize {
    /// Returns `true` for `Finalize::Yes` and `false` for `Finalize::No`.
    pub fn yes(&self) -> bool {
        matches!(self, Finalize::Yes)
    }
}
170
/// Whether the input being compressed is the last node.
///
/// Like `Finalize`, this is a named enum instead of a `bool` so that
/// callsites are easier to read.
#[derive(Clone, Copy, Debug)]
pub enum LastNode {
    Yes,
    No,
}

impl LastNode {
    /// Returns `true` for `LastNode::Yes` and `false` for `LastNode::No`.
    pub fn yes(&self) -> bool {
        matches!(self, LastNode::Yes)
    }
}
186
187#[derive(Clone, Copy, Debug)]
188pub enum Stride {
189    Serial,   // BLAKE2b/BLAKE2s
190    Parallel, // BLAKE2bp/BLAKE2sp
191}
192
193impl Stride {
194    pub fn padded_blockbytes(&self) -> usize {
195        match self {
196            Stride::Serial => BLOCKBYTES,
197            Stride::Parallel => blake2sp::DEGREE * BLOCKBYTES,
198        }
199    }
200}
201
202pub(crate) fn count_low(count: Count) -> Word {
203    count as Word
204}
205
206pub(crate) fn count_high(count: Count) -> Word {
207    (count >> 8 * size_of::<Word>()) as Word
208}
209
210#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
211pub(crate) fn assemble_count(low: Word, high: Word) -> Count {
212    low as Count + ((high as Count) << 8 * size_of::<Word>())
213}
214
215pub(crate) fn flag_word(flag: bool) -> Word {
216    if flag {
217        !0
218    } else {
219        0
220    }
221}
222
// Return the block of input that starts at `offset`, together with:
// - the number of bytes of real input in it, and
// - whether the input can be finalized with it (i.e. whether no bytes
//   remain at this stride's next block position,
//   `offset + stride.padded_blockbytes()`).
//
// If a full block is available, an array reference is taken straight from
// the input. If there's only a partial block (possibly empty), it's copied to
// the front of the provided buffer, and a reference to that buffer is returned
// instead. An `offset` past the end of the input is treated as the end.
//
// Note that this is written so that the optimizer can elide bounds checks,
// see: https://godbolt.org/z/0hH2bC
pub fn final_block<'a>(
    input: &'a [u8],
    offset: usize,
    buffer: &'a mut [u8; BLOCKBYTES],
    stride: Stride,
) -> (&'a [u8; BLOCKBYTES], usize, bool) {
    let capped_offset = cmp::min(offset, input.len());
    let offset_slice = &input[capped_offset..];
    if offset_slice.len() >= BLOCKBYTES {
        let block = array_ref!(offset_slice, 0, BLOCKBYTES);
        // Finalize only if nothing remains at the next block position for
        // this stride (for Parallel, later bytes may belong to other lanes).
        let should_finalize = offset_slice.len() <= stride.padded_blockbytes();
        (block, BLOCKBYTES, should_finalize)
    } else {
        // Copy the final block to the front of the block buffer. The rest of
        // the buffer is assumed to be initialized to zero.
        buffer[..offset_slice.len()].copy_from_slice(offset_slice);
        (buffer, offset_slice.len(), true)
    }
}
249
250pub fn input_debug_asserts(input: &[u8], finalize: Finalize) {
251    // If we're not finalizing, the input must not be empty, and it must be an
252    // even multiple of the block size.
253    if !finalize.yes() {
254        debug_assert!(!input.is_empty());
255        debug_assert_eq!(0, input.len() % BLOCKBYTES);
256    }
257}
258
#[cfg(test)]
mod test {
    use super::*;
    use core::mem::size_of;

    // Checks that `detect` and the `*_if_supported` constructors agree with
    // runtime CPU feature detection. The detailed checks only run on x86 with
    // the `std` feature, where `is_x86_feature_detected!` is available.
    #[test]
    fn test_detection() {
        assert_eq!(Platform::Portable, Implementation::portable().0);

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        #[cfg(feature = "std")]
        {
            if is_x86_feature_detected!("avx2") {
                assert_eq!(Platform::AVX2, Implementation::detect().0);
                assert_eq!(
                    Platform::AVX2,
                    Implementation::avx2_if_supported().unwrap().0
                );
                assert_eq!(
                    Platform::SSE41,
                    Implementation::sse41_if_supported().unwrap().0
                );
            } else if is_x86_feature_detected!("sse4.1") {
                assert_eq!(Platform::SSE41, Implementation::detect().0);
                assert!(Implementation::avx2_if_supported().is_none());
                assert_eq!(
                    Platform::SSE41,
                    Implementation::sse41_if_supported().unwrap().0
                );
            } else {
                assert_eq!(Platform::Portable, Implementation::detect().0);
                assert!(Implementation::avx2_if_supported().is_none());
                assert!(Implementation::sse41_if_supported().is_none());
            }
        }
    }

    // Calls `f` once for every combination of stride, input length, last-node
    // flag, finalize flag, and starting count. Non-finalizing cases whose
    // length is zero or not a multiple of BLOCKBYTES are skipped, because
    // those inputs are invalid (see input_debug_asserts).
    fn exercise_cases<F>(mut f: F)
    where
        F: FnMut(Stride, usize, LastNode, Finalize, Count),
    {
        // Choose counts to hit the relevant overflow cases: zero, a count one
        // block below overflow of the low word, and a count one block below
        // wraparound of the whole Count.
        let counts = &[
            (0 as Count),
            ((1 as Count) << (8 * size_of::<Word>())) - BLOCKBYTES as Count,
            (0 as Count).wrapping_sub(BLOCKBYTES as Count),
        ];
        for &stride in &[Stride::Serial, Stride::Parallel] {
            let lengths = [
                0,
                1,
                BLOCKBYTES - 1,
                BLOCKBYTES,
                BLOCKBYTES + 1,
                2 * BLOCKBYTES - 1,
                2 * BLOCKBYTES,
                2 * BLOCKBYTES + 1,
                stride.padded_blockbytes() - 1,
                stride.padded_blockbytes(),
                stride.padded_blockbytes() + 1,
                2 * stride.padded_blockbytes() - 1,
                2 * stride.padded_blockbytes(),
                2 * stride.padded_blockbytes() + 1,
            ];
            for &length in &lengths {
                for &last_node in &[LastNode::No, LastNode::Yes] {
                    for &finalize in &[Finalize::No, Finalize::Yes] {
                        if !finalize.yes() && (length == 0 || length % BLOCKBYTES != 0) {
                            // Skip these cases, they're invalid.
                            continue;
                        }
                        for &count in counts {
                            f(stride, length, last_node, finalize, count);
                        }
                    }
                }
            }
        }
    }

    // Starting state words for test input number `input_index`. The index is
    // used as the node offset, so parallel jobs don't all start from the same
    // params.
    fn initial_test_words(input_index: usize) -> [Word; 8] {
        crate::Params::new()
            .node_offset(input_index as u64)
            .to_words()
    }

    // Use the portable implementation, one block at a time, to compute the
    // final state words expected for a given test case. Blocks are read
    // `stride.padded_blockbytes()` apart, the count advances by BLOCKBYTES per
    // block, and the loop runs at least once so that empty input still gets
    // compressed.
    fn reference_compression(
        input: &[u8],
        stride: Stride,
        last_node: LastNode,
        finalize: Finalize,
        mut count: Count,
        input_index: usize,
    ) -> [Word; 8] {
        let mut words = initial_test_words(input_index);
        let mut offset = 0;
        while offset == 0 || offset < input.len() {
            let block_size = cmp::min(BLOCKBYTES, input.len() - offset);
            // Only the last block for this stride may be finalized.
            let maybe_finalize = if offset + stride.padded_blockbytes() < input.len() {
                Finalize::No
            } else {
                finalize
            };
            portable::compress1_loop(
                &input[offset..][..block_size],
                &mut words,
                count,
                last_node,
                maybe_finalize,
                Stride::Serial,
            );
            offset += stride.padded_blockbytes();
            count = count.wrapping_add(BLOCKBYTES as Count);
        }
        words
    }

    // For various loop lengths and finalization parameters, make sure that the
    // implementation gives the same answer as the portable implementation does
    // when invoked one block at a time. (So even the portable implementation
    // itself is being tested here, to make sure its loop is correct.) Note
    // that this doesn't include any fixed test vectors; those are taken from
    // the blake2-kat.json file (copied from upstream) and tested elsewhere.
    fn exercise_compress1_loop(implementation: Implementation) {
        let mut input = [0; 100 * BLOCKBYTES];
        paint_test_input(&mut input);

        exercise_cases(|stride, length, last_node, finalize, count| {
            let reference_words =
                reference_compression(&input[..length], stride, last_node, finalize, count, 0);

            let mut test_words = initial_test_words(0);
            implementation.compress1_loop(
                &input[..length],
                &mut test_words,
                count,
                last_node,
                finalize,
                stride,
            );
            assert_eq!(reference_words, test_words);
        });
    }

    #[test]
    fn test_compress1_loop_portable() {
        exercise_compress1_loop(Implementation::portable());
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress1_loop_sse41() {
        if let Some(imp) = Implementation::sse41_if_supported() {
            exercise_compress1_loop(imp);
        }
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress1_loop_avx2() {
        // Currently this just falls back to SSE4.1, but we test it anyway.
        if let Some(imp) = Implementation::avx2_if_supported() {
            exercise_compress1_loop(imp);
        }
    }

    // I use ArrayVec everywhere in here because currently these tests pass
    // under no_std. I might decide that's not worth maintaining at some point,
    // since really all we care about with no_std is that the library builds,
    // but for now it's here. Everything is keyed off of this N constant so
    // that it's easy to copy the code to exercise_compress8_loop.
    //
    // Each of the N inputs starts i bytes into the buffer and gets its own
    // starting count and node offset, so the lanes don't all see identical data.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn exercise_compress4_loop(implementation: Implementation) {
        const N: usize = 4;

        let mut input_buffer = [0; 100 * BLOCKBYTES];
        paint_test_input(&mut input_buffer);
        let mut inputs = arrayvec::ArrayVec::<_, N>::new();
        for i in 0..N {
            inputs.push(&input_buffer[i..]);
        }

        exercise_cases(|stride, length, last_node, finalize, count| {
            let mut reference_words = arrayvec::ArrayVec::<_, N>::new();
            for i in 0..N {
                let words = reference_compression(
                    &inputs[i][..length],
                    stride,
                    last_node,
                    finalize,
                    count.wrapping_add((i * BLOCKBYTES) as Count),
                    i,
                );
                reference_words.push(words);
            }

            let mut test_words = arrayvec::ArrayVec::<_, N>::new();
            for i in 0..N {
                test_words.push(initial_test_words(i));
            }
            let mut jobs = arrayvec::ArrayVec::<_, N>::new();
            for (i, words) in test_words.iter_mut().enumerate() {
                jobs.push(Job {
                    input: &inputs[i][..length],
                    words,
                    count: count.wrapping_add((i * BLOCKBYTES) as Count),
                    last_node,
                });
            }
            let mut jobs = jobs.into_inner().expect("full");
            implementation.compress4_loop(&mut jobs, finalize, stride);

            for i in 0..N {
                assert_eq!(reference_words[i], test_words[i], "words {} unequal", i);
            }
        });
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress4_loop_sse41() {
        if let Some(imp) = Implementation::sse41_if_supported() {
            exercise_compress4_loop(imp);
        }
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress4_loop_avx2() {
        // Currently this just falls back to SSE4.1, but we test it anyway.
        if let Some(imp) = Implementation::avx2_if_supported() {
            exercise_compress4_loop(imp);
        }
    }

    // Copied from exercise_compress4_loop, with a different value of N and an
    // interior call to compress8_loop.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn exercise_compress8_loop(implementation: Implementation) {
        const N: usize = 8;

        let mut input_buffer = [0; 100 * BLOCKBYTES];
        paint_test_input(&mut input_buffer);
        let mut inputs = arrayvec::ArrayVec::<_, N>::new();
        for i in 0..N {
            inputs.push(&input_buffer[i..]);
        }

        exercise_cases(|stride, length, last_node, finalize, count| {
            let mut reference_words = arrayvec::ArrayVec::<_, N>::new();
            for i in 0..N {
                let words = reference_compression(
                    &inputs[i][..length],
                    stride,
                    last_node,
                    finalize,
                    count.wrapping_add((i * BLOCKBYTES) as Count),
                    i,
                );
                reference_words.push(words);
            }

            let mut test_words = arrayvec::ArrayVec::<_, N>::new();
            for i in 0..N {
                test_words.push(initial_test_words(i));
            }
            let mut jobs = arrayvec::ArrayVec::<_, N>::new();
            for (i, words) in test_words.iter_mut().enumerate() {
                jobs.push(Job {
                    input: &inputs[i][..length],
                    words,
                    count: count.wrapping_add((i * BLOCKBYTES) as Count),
                    last_node,
                });
            }
            let mut jobs = jobs.into_inner().expect("full");
            implementation.compress8_loop(&mut jobs, finalize, stride);

            for i in 0..N {
                assert_eq!(reference_words[i], test_words[i], "words {} unequal", i);
            }
        });
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress8_loop_avx2() {
        if let Some(imp) = Implementation::avx2_if_supported() {
            exercise_compress8_loop(imp);
        }
    }

    // count_low/count_high/assemble_count split and rebuild Count as two
    // Words, which only works if Count is exactly twice as wide as Word.
    #[test]
    fn sanity_check_count_size() {
        assert_eq!(size_of::<Count>(), 2 * size_of::<Word>());
    }
}