// blake2s_const/guts.rs

1use crate::*;
2use arrayref::array_ref;
3use core::cmp;
4
// Maximum SIMD degree (number of inputs compressed in parallel) any
// implementation in this crate can use. On x86/x86_64 the AVX2 kernel is
// 8-way; see compress8_loop below.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub const MAX_DEGREE: usize = 8;

// Non-x86 targets only have the portable, one-input-at-a-time implementation.
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
pub const MAX_DEGREE: usize = 1;
10
// Variants other than Portable are unreachable in no_std, unless CPU features
// are explicitly enabled for the build with e.g. RUSTFLAGS="-C target-feature=avx2".
// This might change in the future if is_x86_feature_detected moves into libcore.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Platform {
    // Pure-Rust fallback, available on every target.
    Portable,
    // x86/x86_64 SSE4.1 kernels (used by compress1_loop and compress4_loop).
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    SSE41,
    // x86/x86_64 AVX2 kernel (compress8_loop); narrower loops fall back to
    // the SSE4.1 kernels, see the dispatch in Implementation below.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    AVX2,
}
23
// Thin wrapper around a Platform that dispatches the compression loops to the
// matching backend (portable, SSE4.1, or AVX2).
#[derive(Clone, Copy, Debug)]
pub struct Implementation(pub Platform);
26
27impl Implementation {
28    pub fn detect() -> Self {
29        // Try the different implementations in order of how fast/modern they
30        // are. Currently on non-x86, everything just uses portable.
31        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
32        {
33            if let Some(avx2_impl) = Self::avx2_if_supported() {
34                return avx2_impl;
35            }
36        }
37        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
38        {
39            if let Some(sse41_impl) = Self::sse41_if_supported() {
40                return sse41_impl;
41            }
42        }
43        Self::portable()
44    }
45
46    pub const fn portable() -> Self {
47        Implementation(Platform::Portable)
48    }
49
50    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
51    #[allow(unreachable_code)]
52    pub const fn sse41_if_supported() -> Option<Self> {
53        // Check whether SSE4.1 support is assumed by the build.
54        #[cfg(target_feature = "sse4.1")]
55        {
56            return Some(Implementation(Platform::SSE41));
57        }
58
59        None
60    }
61
62    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
63    #[allow(unreachable_code)]
64    pub const fn avx2_if_supported() -> Option<Self> {
65        // Check whether AVX2 support is assumed by the build.
66        #[cfg(target_feature = "avx2")]
67        {
68            return Some(Implementation(Platform::AVX2));
69        }
70
71        None
72    }
73
74    pub fn degree(&self) -> usize {
75        match self.0 {
76            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
77            Platform::AVX2 => avx2::DEGREE,
78            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
79            Platform::SSE41 => sse41::DEGREE,
80            Platform::Portable => 1,
81        }
82    }
83
84    pub fn compress1_loop(
85        &self,
86        input: &[u8],
87        words: &mut [Word; 8],
88        count: Count,
89        last_node: LastNode,
90        finalize: Finalize,
91        stride: Stride,
92    ) {
93        match self.0 {
94            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
95            Platform::AVX2 | Platform::SSE41 => unsafe {
96                sse41::compress1_loop(input, words, count, last_node, finalize, stride);
97            },
98            Platform::Portable => {
99                portable::compress1_loop(input, words, count, last_node, finalize, stride);
100            }
101        }
102    }
103
104    pub fn compress4_loop(&self, jobs: &mut [Job; 4], finalize: Finalize, stride: Stride) {
105        match self.0 {
106            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
107            Platform::AVX2 | Platform::SSE41 => unsafe {
108                sse41::compress4_loop(jobs, finalize, stride)
109            },
110            _ => panic!("unsupported"),
111        }
112    }
113
114    pub fn compress8_loop(&self, jobs: &mut [Job; 8], finalize: Finalize, stride: Stride) {
115        match self.0 {
116            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
117            Platform::AVX2 => unsafe { avx2::compress8_loop(jobs, finalize, stride) },
118            _ => panic!("unsupported"),
119        }
120    }
121}
122
// One unit of work for the multi-stream compression loops: an input slice,
// the state words to update in place, the starting byte count for that
// stream, and its last-node flag.
pub struct Job<'a, 'b> {
    pub input: &'a [u8],
    pub words: &'b mut [Word; 8],
    pub count: Count,
    pub last_node: LastNode,
}
129
130impl<'a, 'b> core::fmt::Debug for Job<'a, 'b> {
131    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
132        // NB: Don't print the words. Leaking them would allow length extension.
133        write!(
134            f,
135            "Job {{ input_len: {}, count: {}, last_node: {} }}",
136            self.input.len(),
137            self.count,
138            self.last_node.yes(),
139        )
140    }
141}
142
// Finalize could just be a bool, but this is easier to read at callsites.
#[derive(Clone, Copy, Debug)]
pub enum Finalize {
    Yes,
    No,
}

impl Finalize {
    /// True exactly when this is `Finalize::Yes`.
    pub fn yes(&self) -> bool {
        matches!(self, Finalize::Yes)
    }
}
158
// Like Finalize, this is easier to read at callsites.
#[derive(Clone, Copy, Debug)]
pub enum LastNode {
    Yes,
    No,
}

impl LastNode {
    /// True exactly when this is `LastNode::Yes`.
    pub fn yes(&self) -> bool {
        matches!(self, LastNode::Yes)
    }
}
174
175#[derive(Clone, Copy, Debug)]
176pub enum Stride {
177    Serial,   // BLAKE2b/BLAKE2s
178    Parallel, // BLAKE2bp/BLAKE2sp
179}
180
181impl Stride {
182    pub fn padded_blockbytes(&self) -> usize {
183        match self {
184            Stride::Serial => BLOCKBYTES,
185            Stride::Parallel => blake2sp::DEGREE * BLOCKBYTES,
186        }
187    }
188}
189
// Low word of the two-word byte counter (truncating cast).
pub(crate) fn count_low(count: Count) -> Word {
    count as Word
}
193
194pub(crate) fn count_high(count: Count) -> Word {
195    (count >> 8 * size_of::<Word>()) as Word
196}
197
198pub(crate) fn assemble_count(low: Word, high: Word) -> Count {
199    low as Count + ((high as Count) << 8 * size_of::<Word>())
200}
201
202pub(crate) fn flag_word(flag: bool) -> Word {
203    if flag {
204        !0
205    } else {
206        0
207    }
208}
209
// Pull an array reference at the given offset straight from the input, if
// there's a full block of input available. If there's only a partial block,
// copy it into the provided buffer, and return an array reference to that.
// Along with the array, return the number of bytes of real input, and whether
// the input can be finalized (i.e. whether there aren't any more bytes after
// this block). Note that this is written so that the optimizer can elide
// bounds checks, see: https://godbolt.org/z/0hH2bC
pub fn final_block<'a>(
    input: &'a [u8],
    offset: usize,
    buffer: &'a mut [u8; BLOCKBYTES],
    stride: Stride,
) -> (&'a [u8; BLOCKBYTES], usize, bool) {
    // Clamping the offset (instead of indexing with it directly) keeps the
    // slice below provably in-bounds; an out-of-range offset yields an empty
    // slice rather than a panic.
    let capped_offset = cmp::min(offset, input.len());
    let offset_slice = &input[capped_offset..];
    if offset_slice.len() >= BLOCKBYTES {
        let block = array_ref!(offset_slice, 0, BLOCKBYTES);
        // Finalize only if no further block lies one stride beyond this one.
        let should_finalize = offset_slice.len() <= stride.padded_blockbytes();
        (block, BLOCKBYTES, should_finalize)
    } else {
        // Copy the final block to the front of the block buffer. The rest of
        // the buffer is assumed to be initialized to zero.
        buffer[..offset_slice.len()].copy_from_slice(offset_slice);
        (buffer, offset_slice.len(), true)
    }
}
236
237pub fn input_debug_asserts(input: &[u8], finalize: Finalize) {
238    // If we're not finalizing, the input must not be empty, and it must be an
239    // even multiple of the block size.
240    if !finalize.yes() {
241        debug_assert!(!input.is_empty());
242        debug_assert_eq!(0, input.len() % BLOCKBYTES);
243    }
244}
245
#[cfg(test)]
mod test {
    use super::*;
    use arrayvec::ArrayVec;
    use core::mem::size_of;

    #[test]
    fn test_detection() {
        assert_eq!(Platform::Portable, Implementation::portable().0);

        // NOTE(review): the *_if_supported constructors only reflect
        // compile-time target features, so the branches below assume the
        // build's enabled features match the CPU's runtime capabilities —
        // verify on machines where they differ.
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        #[cfg(feature = "std")]
        {
            if is_x86_feature_detected!("avx2") {
                assert_eq!(Platform::AVX2, Implementation::detect().0);
                assert_eq!(
                    Platform::AVX2,
                    Implementation::avx2_if_supported().unwrap().0
                );
                assert_eq!(
                    Platform::SSE41,
                    Implementation::sse41_if_supported().unwrap().0
                );
            } else if is_x86_feature_detected!("sse4.1") {
                assert_eq!(Platform::SSE41, Implementation::detect().0);
                assert!(Implementation::avx2_if_supported().is_none());
                assert_eq!(
                    Platform::SSE41,
                    Implementation::sse41_if_supported().unwrap().0
                );
            } else {
                assert_eq!(Platform::Portable, Implementation::detect().0);
                assert!(Implementation::avx2_if_supported().is_none());
                assert!(Implementation::sse41_if_supported().is_none());
            }
        }
    }

    // Run the given closure over every interesting combination of stride,
    // input length, last-node flag, finalization, and starting count.
    fn exercise_cases<F>(mut f: F)
    where
        F: FnMut(Stride, usize, LastNode, Finalize, Count),
    {
        // Choose counts to hit the relevant overflow cases.
        let counts = &[
            (0 as Count),
            ((1 as Count) << (8 * size_of::<Word>())) - BLOCKBYTES as Count,
            (0 as Count).wrapping_sub(BLOCKBYTES as Count),
        ];
        for &stride in &[Stride::Serial, Stride::Parallel] {
            // Lengths bracketing both the plain block size and the (possibly
            // larger) stride-padded block size.
            let lengths = [
                0,
                1,
                BLOCKBYTES - 1,
                BLOCKBYTES,
                BLOCKBYTES + 1,
                2 * BLOCKBYTES - 1,
                2 * BLOCKBYTES,
                2 * BLOCKBYTES + 1,
                stride.padded_blockbytes() - 1,
                stride.padded_blockbytes(),
                stride.padded_blockbytes() + 1,
                2 * stride.padded_blockbytes() - 1,
                2 * stride.padded_blockbytes(),
                2 * stride.padded_blockbytes() + 1,
            ];
            for &length in &lengths {
                for &last_node in &[LastNode::No, LastNode::Yes] {
                    for &finalize in &[Finalize::No, Finalize::Yes] {
                        if !finalize.yes() && (length == 0 || length % BLOCKBYTES != 0) {
                            // Skip these cases, they're invalid.
                            continue;
                        }
                        for &count in counts {
                            // eprintln!("\ncase -----");
                            // dbg!(stride);
                            // dbg!(length);
                            // dbg!(last_node);
                            // dbg!(finalize);
                            // dbg!(count);

                            f(stride, length, last_node, finalize, count);
                        }
                    }
                }
            }
        }
    }

    // Fresh state words for test input number `input_index`, using the node
    // offset parameter to make each input's starting state distinct.
    fn initial_test_words(input_index: usize) -> [Word; 8] {
        crate::Params::new()
            .node_offset(input_index as u64)
            .to_words()
    }

    // Use the portable implementation, one block at a time, to compute the
    // final state words expected for a given test case.
    fn reference_compression(
        input: &[u8],
        stride: Stride,
        last_node: LastNode,
        finalize: Finalize,
        mut count: Count,
        input_index: usize,
    ) -> [Word; 8] {
        let mut words = initial_test_words(input_index);
        let mut offset = 0;
        // The `offset == 0` clause guarantees at least one compression, so an
        // empty input still gets finalized.
        while offset == 0 || offset < input.len() {
            let block_size = cmp::min(BLOCKBYTES, input.len() - offset);
            let maybe_finalize = if offset + stride.padded_blockbytes() < input.len() {
                Finalize::No
            } else {
                finalize
            };
            portable::compress1_loop(
                &input[offset..][..block_size],
                &mut words,
                count,
                last_node,
                maybe_finalize,
                Stride::Serial,
            );
            offset += stride.padded_blockbytes();
            count = count.wrapping_add(BLOCKBYTES as Count);
        }
        words
    }

    // For various loop lengths and finalization parameters, make sure that the
    // implementation gives the same answer as the portable implementation does
    // when invoked one block at a time. (So even the portable implementation
    // itself is being tested here, to make sure its loop is correct.) Note
    // that this doesn't include any fixed test vectors; those are taken from
    // the blake2-kat.json file (copied from upstream) and tested elsewhere.
    fn exercise_compress1_loop(implementation: Implementation) {
        let mut input = [0; 100 * BLOCKBYTES];
        paint_test_input(&mut input);

        exercise_cases(|stride, length, last_node, finalize, count| {
            let reference_words =
                reference_compression(&input[..length], stride, last_node, finalize, count, 0);

            let mut test_words = initial_test_words(0);
            implementation.compress1_loop(
                &input[..length],
                &mut test_words,
                count,
                last_node,
                finalize,
                stride,
            );
            assert_eq!(reference_words, test_words);
        });
    }

    #[test]
    fn test_compress1_loop_portable() {
        exercise_compress1_loop(Implementation::portable());
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress1_loop_sse41() {
        if let Some(imp) = Implementation::sse41_if_supported() {
            exercise_compress1_loop(imp);
        }
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress1_loop_avx2() {
        // Currently this just falls back to SSE4.1, but we test it anyway.
        if let Some(imp) = Implementation::avx2_if_supported() {
            exercise_compress1_loop(imp);
        }
    }

    // I use ArrayVec everywhere in here because currently these tests pass
    // under no_std. I might decide that's not worth maintaining at some point,
    // since really all we care about with no_std is that the library builds,
    // but for now it's here. Everything is keyed off of this N constant so
    // that it's easy to copy the code to exercise_compress8_loop.
    fn exercise_compress4_loop(implementation: Implementation) {
        const N: usize = 4;

        let mut input_buffer = [0; 100 * BLOCKBYTES];
        paint_test_input(&mut input_buffer);
        // Offset each input by one byte so the four streams differ.
        let mut inputs = ArrayVec::<[_; N]>::new();
        for i in 0..N {
            inputs.push(&input_buffer[i..]);
        }

        exercise_cases(|stride, length, last_node, finalize, count| {
            let mut reference_words = ArrayVec::<[_; N]>::new();
            for i in 0..N {
                let words = reference_compression(
                    &inputs[i][..length],
                    stride,
                    last_node,
                    finalize,
                    count.wrapping_add((i * BLOCKBYTES) as Count),
                    i,
                );
                reference_words.push(words);
            }

            let mut test_words = ArrayVec::<[_; N]>::new();
            for i in 0..N {
                test_words.push(initial_test_words(i));
            }
            let mut jobs = ArrayVec::<[_; N]>::new();
            for (i, words) in test_words.iter_mut().enumerate() {
                jobs.push(Job {
                    input: &inputs[i][..length],
                    words,
                    count: count.wrapping_add((i * BLOCKBYTES) as Count),
                    last_node,
                });
            }
            let mut jobs = jobs.into_inner().expect("full");
            implementation.compress4_loop(&mut jobs, finalize, stride);

            for i in 0..N {
                assert_eq!(reference_words[i], test_words[i], "words {} unequal", i);
            }
        });
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress4_loop_sse41() {
        if let Some(imp) = Implementation::sse41_if_supported() {
            exercise_compress4_loop(imp);
        }
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress4_loop_avx2() {
        // Currently this just falls back to SSE4.1, but we test it anyway.
        if let Some(imp) = Implementation::avx2_if_supported() {
            exercise_compress4_loop(imp);
        }
    }

    // Copied from exercise_compress4_loop, with a different value of N and an
    // interior call to compress8_loop.
    fn exercise_compress8_loop(implementation: Implementation) {
        const N: usize = 8;

        let mut input_buffer = [0; 100 * BLOCKBYTES];
        paint_test_input(&mut input_buffer);
        // Offset each input by one byte so the eight streams differ.
        let mut inputs = ArrayVec::<[_; N]>::new();
        for i in 0..N {
            inputs.push(&input_buffer[i..]);
        }

        exercise_cases(|stride, length, last_node, finalize, count| {
            let mut reference_words = ArrayVec::<[_; N]>::new();
            for i in 0..N {
                let words = reference_compression(
                    &inputs[i][..length],
                    stride,
                    last_node,
                    finalize,
                    count.wrapping_add((i * BLOCKBYTES) as Count),
                    i,
                );
                reference_words.push(words);
            }

            let mut test_words = ArrayVec::<[_; N]>::new();
            for i in 0..N {
                test_words.push(initial_test_words(i));
            }
            let mut jobs = ArrayVec::<[_; N]>::new();
            for (i, words) in test_words.iter_mut().enumerate() {
                jobs.push(Job {
                    input: &inputs[i][..length],
                    words,
                    count: count.wrapping_add((i * BLOCKBYTES) as Count),
                    last_node,
                });
            }
            let mut jobs = jobs.into_inner().expect("full");
            implementation.compress8_loop(&mut jobs, finalize, stride);

            for i in 0..N {
                assert_eq!(reference_words[i], test_words[i], "words {} unequal", i);
            }
        });
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress8_loop_avx2() {
        if let Some(imp) = Implementation::avx2_if_supported() {
            exercise_compress8_loop(imp);
        }
    }

    #[test]
    fn sanity_check_count_size() {
        // The count/assemble_count helpers above assume Count is exactly two
        // Words wide.
        assert_eq!(size_of::<Count>(), 2 * size_of::<Word>());
    }
}