1use crate::*;
2use arrayref::array_ref;
3use core::cmp;
4
/// Largest number of inputs any backend on this target compresses at once.
///
/// On x86/x86_64 this is the width of the widest SIMD backend compiled in.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub const MAX_DEGREE: usize = 8;

/// Largest number of inputs any backend on this target compresses at once.
///
/// Only the portable backend exists off x86, and it handles one input at a time.
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
pub const MAX_DEGREE: usize = 1;
10
/// Instruction-set backends the compression functions can dispatch to.
///
/// The SIMD variants only exist when compiling for x86 or x86_64.
// NOTE(review): `dead_code` is presumably allowed because, for some target /
// feature combinations, the SIMD variants are never constructed — confirm.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Platform {
    /// Plain Rust, no SIMD; available everywhere.
    Portable,
    /// SSE4.1 vector instructions.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    SSE41,
    /// AVX2 vector instructions.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    AVX2,
}
23
24#[derive(Clone, Copy, Debug)]
25pub struct Implementation(pub Platform);
26
27impl Implementation {
28 pub fn detect() -> Self {
29 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
32 {
33 if let Some(avx2_impl) = Self::avx2_if_supported() {
34 return avx2_impl;
35 }
36 }
37 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
38 {
39 if let Some(sse41_impl) = Self::sse41_if_supported() {
40 return sse41_impl;
41 }
42 }
43 Self::portable()
44 }
45
46 pub const fn portable() -> Self {
47 Implementation(Platform::Portable)
48 }
49
50 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
51 #[allow(unreachable_code)]
52 pub const fn sse41_if_supported() -> Option<Self> {
53 #[cfg(target_feature = "sse4.1")]
55 {
56 return Some(Implementation(Platform::SSE41));
57 }
58
59 None
60 }
61
62 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
63 #[allow(unreachable_code)]
64 pub const fn avx2_if_supported() -> Option<Self> {
65 #[cfg(target_feature = "avx2")]
67 {
68 return Some(Implementation(Platform::AVX2));
69 }
70
71 None
72 }
73
74 pub fn degree(&self) -> usize {
75 match self.0 {
76 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
77 Platform::AVX2 => avx2::DEGREE,
78 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
79 Platform::SSE41 => sse41::DEGREE,
80 Platform::Portable => 1,
81 }
82 }
83
84 pub fn compress1_loop(
85 &self,
86 input: &[u8],
87 words: &mut [Word; 8],
88 count: Count,
89 last_node: LastNode,
90 finalize: Finalize,
91 stride: Stride,
92 ) {
93 match self.0 {
94 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
95 Platform::AVX2 | Platform::SSE41 => unsafe {
96 sse41::compress1_loop(input, words, count, last_node, finalize, stride);
97 },
98 Platform::Portable => {
99 portable::compress1_loop(input, words, count, last_node, finalize, stride);
100 }
101 }
102 }
103
104 pub fn compress4_loop(&self, jobs: &mut [Job; 4], finalize: Finalize, stride: Stride) {
105 match self.0 {
106 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
107 Platform::AVX2 | Platform::SSE41 => unsafe {
108 sse41::compress4_loop(jobs, finalize, stride)
109 },
110 _ => panic!("unsupported"),
111 }
112 }
113
114 pub fn compress8_loop(&self, jobs: &mut [Job; 8], finalize: Finalize, stride: Stride) {
115 match self.0 {
116 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
117 Platform::AVX2 => unsafe { avx2::compress8_loop(jobs, finalize, stride) },
118 _ => panic!("unsupported"),
119 }
120 }
121}
122
123pub struct Job<'a, 'b> {
124 pub input: &'a [u8],
125 pub words: &'b mut [Word; 8],
126 pub count: Count,
127 pub last_node: LastNode,
128}
129
130impl<'a, 'b> core::fmt::Debug for Job<'a, 'b> {
131 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
132 write!(
134 f,
135 "Job {{ input_len: {}, count: {}, last_node: {} }}",
136 self.input.len(),
137 self.count,
138 self.last_node.yes(),
139 )
140 }
141}
142
/// Whether a compression loop should finalize after the last block it processes.
#[derive(Clone, Copy, Debug)]
pub enum Finalize {
    Yes,
    No,
}

impl Finalize {
    /// `true` exactly for [`Finalize::Yes`].
    pub fn yes(&self) -> bool {
        matches!(self, Finalize::Yes)
    }
}
158
/// Whether the last-node flag should be set for an input
/// (presumably BLAKE2's tree-hashing last-node flag).
#[derive(Clone, Copy, Debug)]
pub enum LastNode {
    Yes,
    No,
}

impl LastNode {
    /// `true` exactly for [`LastNode::Yes`].
    pub fn yes(&self) -> bool {
        matches!(self, LastNode::Yes)
    }
}
174
175#[derive(Clone, Copy, Debug)]
176pub enum Stride {
177 Serial, Parallel, }
180
181impl Stride {
182 pub fn padded_blockbytes(&self) -> usize {
183 match self {
184 Stride::Serial => BLOCKBYTES,
185 Stride::Parallel => blake2sp::DEGREE * BLOCKBYTES,
186 }
187 }
188}
189
190pub(crate) fn count_low(count: Count) -> Word {
191 count as Word
192}
193
194pub(crate) fn count_high(count: Count) -> Word {
195 (count >> 8 * size_of::<Word>()) as Word
196}
197
198pub(crate) fn assemble_count(low: Word, high: Word) -> Count {
199 low as Count + ((high as Count) << 8 * size_of::<Word>())
200}
201
202pub(crate) fn flag_word(flag: bool) -> Word {
203 if flag {
204 !0
205 } else {
206 0
207 }
208}
209
210pub fn final_block<'a>(
218 input: &'a [u8],
219 offset: usize,
220 buffer: &'a mut [u8; BLOCKBYTES],
221 stride: Stride,
222) -> (&'a [u8; BLOCKBYTES], usize, bool) {
223 let capped_offset = cmp::min(offset, input.len());
224 let offset_slice = &input[capped_offset..];
225 if offset_slice.len() >= BLOCKBYTES {
226 let block = array_ref!(offset_slice, 0, BLOCKBYTES);
227 let should_finalize = offset_slice.len() <= stride.padded_blockbytes();
228 (block, BLOCKBYTES, should_finalize)
229 } else {
230 buffer[..offset_slice.len()].copy_from_slice(offset_slice);
233 (buffer, offset_slice.len(), true)
234 }
235}
236
237pub fn input_debug_asserts(input: &[u8], finalize: Finalize) {
238 if !finalize.yes() {
241 debug_assert!(!input.is_empty());
242 debug_assert_eq!(0, input.len() % BLOCKBYTES);
243 }
244}
245
#[cfg(test)]
mod test {
    use super::*;
    use arrayvec::ArrayVec;
    use core::mem::size_of;

    #[test]
    fn test_detection() {
        assert_eq!(Platform::Portable, Implementation::portable().0);

        // Runtime feature detection is only available with `std`.
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        #[cfg(feature = "std")]
        {
            let detected = Implementation::detect().0;
            let avx2 = Implementation::avx2_if_supported().map(|imp| imp.0);
            let sse41 = Implementation::sse41_if_supported().map(|imp| imp.0);

            if is_x86_feature_detected!("avx2") {
                // AVX2 wins, and SSE4.1 must be reported alongside it.
                assert_eq!(Platform::AVX2, detected);
                assert_eq!(Some(Platform::AVX2), avx2);
                assert_eq!(Some(Platform::SSE41), sse41);
            } else if is_x86_feature_detected!("sse4.1") {
                assert_eq!(Platform::SSE41, detected);
                assert_eq!(None, avx2);
                assert_eq!(Some(Platform::SSE41), sse41);
            } else {
                assert_eq!(Platform::Portable, detected);
                assert_eq!(None, avx2);
                assert_eq!(None, sse41);
            }
        }
    }

    /// Calls `f` with every combination of stride, input length, last-node flag,
    /// finalize flag and starting counter that the compression loops must handle.
    ///
    /// Non-finalizing combinations are skipped unless the length is a nonzero
    /// whole number of blocks, matching `input_debug_asserts`.
    fn exercise_cases<F>(mut f: F)
    where
        F: FnMut(Stride, usize, LastNode, Finalize, Count),
    {
        // Starting counters: zero, one block short of overflowing the low word,
        // and one block short of wrapping the entire counter.
        let word_bits = 8 * size_of::<Word>();
        let starting_counts = [
            0 as Count,
            ((1 as Count) << word_bits) - BLOCKBYTES as Count,
            (0 as Count).wrapping_sub(BLOCKBYTES as Count),
        ];
        for &stride in [Stride::Serial, Stride::Parallel].iter() {
            let padded = stride.padded_blockbytes();
            let lengths = [
                0,
                1,
                BLOCKBYTES - 1,
                BLOCKBYTES,
                BLOCKBYTES + 1,
                2 * BLOCKBYTES - 1,
                2 * BLOCKBYTES,
                2 * BLOCKBYTES + 1,
                padded - 1,
                padded,
                padded + 1,
                2 * padded - 1,
                2 * padded,
                2 * padded + 1,
            ];
            for &length in lengths.iter() {
                let whole_blocks = length != 0 && length % BLOCKBYTES == 0;
                for &last_node in [LastNode::No, LastNode::Yes].iter() {
                    for &finalize in [Finalize::No, Finalize::Yes].iter() {
                        if finalize.yes() || whole_blocks {
                            for &count in starting_counts.iter() {
                                f(stride, length, last_node, finalize, count);
                            }
                        }
                    }
                }
            }
        }
    }

    /// Starting chaining words for lane `input_index`, distinct per lane.
    fn initial_test_words(input_index: usize) -> [Word; 8] {
        crate::Params::new()
            .node_offset(input_index as u64)
            .to_words()
    }

    /// Reference result computed one block at a time with the portable loop.
    ///
    /// Blocks start every `stride.padded_blockbytes()` bytes, the counter grows
    /// by `BLOCKBYTES` per block, and `finalize` only applies once no further
    /// block starts within `input`. At least one (possibly empty) block is
    /// always compressed.
    fn reference_compression(
        input: &[u8],
        stride: Stride,
        last_node: LastNode,
        finalize: Finalize,
        count: Count,
        input_index: usize,
    ) -> [Word; 8] {
        let step = stride.padded_blockbytes();
        let mut words = initial_test_words(input_index);
        let mut block_start = 0;
        let mut block_count = count;
        loop {
            let block_end = block_start + cmp::min(BLOCKBYTES, input.len() - block_start);
            let more_blocks_follow = block_start + step < input.len();
            portable::compress1_loop(
                &input[block_start..block_end],
                &mut words,
                block_count,
                last_node,
                if more_blocks_follow { Finalize::No } else { finalize },
                Stride::Serial,
            );
            block_start += step;
            block_count = block_count.wrapping_add(BLOCKBYTES as Count);
            if block_start >= input.len() {
                break;
            }
        }
        words
    }

    /// Checks `implementation.compress1_loop` against the reference for every case.
    fn exercise_compress1_loop(implementation: Implementation) {
        let mut data = [0; 100 * BLOCKBYTES];
        paint_test_input(&mut data);

        exercise_cases(|stride, length, last_node, finalize, count| {
            let message = &data[..length];
            let expected =
                reference_compression(message, stride, last_node, finalize, count, 0);

            let mut actual = initial_test_words(0);
            implementation.compress1_loop(
                message,
                &mut actual,
                count,
                last_node,
                finalize,
                stride,
            );
            assert_eq!(expected, actual);
        });
    }

    #[test]
    fn test_compress1_loop_portable() {
        exercise_compress1_loop(Implementation::portable());
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress1_loop_sse41() {
        if let Some(sse41) = Implementation::sse41_if_supported() {
            exercise_compress1_loop(sse41);
        }
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress1_loop_avx2() {
        if let Some(avx2) = Implementation::avx2_if_supported() {
            exercise_compress1_loop(avx2);
        }
    }

    /// Checks `implementation.compress4_loop` lane by lane against the reference.
    fn exercise_compress4_loop(implementation: Implementation) {
        const N: usize = 4;

        let mut buffer = [0; 100 * BLOCKBYTES];
        paint_test_input(&mut buffer);
        // Lane `i` reads from byte offset `i`, so every lane sees different input.
        let lane_inputs: ArrayVec<[&[u8]; N]> = (0..N).map(|lane| &buffer[lane..]).collect();

        exercise_cases(|stride, length, last_node, finalize, count| {
            // Each lane's counter starts one block further along than the previous lane's.
            let lane_count = |lane: usize| count.wrapping_add((lane * BLOCKBYTES) as Count);

            let expected: ArrayVec<[[Word; 8]; N]> = (0..N)
                .map(|lane| {
                    reference_compression(
                        &lane_inputs[lane][..length],
                        stride,
                        last_node,
                        finalize,
                        lane_count(lane),
                        lane,
                    )
                })
                .collect();

            let mut actual: ArrayVec<[[Word; 8]; N]> = (0..N).map(initial_test_words).collect();
            let jobs: ArrayVec<[Job; N]> = actual
                .iter_mut()
                .enumerate()
                .map(|(lane, words)| Job {
                    input: &lane_inputs[lane][..length],
                    words,
                    count: lane_count(lane),
                    last_node,
                })
                .collect();
            let mut jobs = jobs.into_inner().expect("full");
            implementation.compress4_loop(&mut jobs, finalize, stride);

            for (lane, (want, got)) in expected.iter().zip(actual.iter()).enumerate() {
                assert_eq!(want, got, "words {} unequal", lane);
            }
        });
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress4_loop_sse41() {
        if let Some(sse41) = Implementation::sse41_if_supported() {
            exercise_compress4_loop(sse41);
        }
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress4_loop_avx2() {
        if let Some(avx2) = Implementation::avx2_if_supported() {
            exercise_compress4_loop(avx2);
        }
    }

    /// Checks `implementation.compress8_loop` lane by lane against the reference.
    fn exercise_compress8_loop(implementation: Implementation) {
        const N: usize = 8;

        let mut buffer = [0; 100 * BLOCKBYTES];
        paint_test_input(&mut buffer);
        // Lane `i` reads from byte offset `i`, so every lane sees different input.
        let lane_inputs: ArrayVec<[&[u8]; N]> = (0..N).map(|lane| &buffer[lane..]).collect();

        exercise_cases(|stride, length, last_node, finalize, count| {
            // Each lane's counter starts one block further along than the previous lane's.
            let lane_count = |lane: usize| count.wrapping_add((lane * BLOCKBYTES) as Count);

            let expected: ArrayVec<[[Word; 8]; N]> = (0..N)
                .map(|lane| {
                    reference_compression(
                        &lane_inputs[lane][..length],
                        stride,
                        last_node,
                        finalize,
                        lane_count(lane),
                        lane,
                    )
                })
                .collect();

            let mut actual: ArrayVec<[[Word; 8]; N]> = (0..N).map(initial_test_words).collect();
            let jobs: ArrayVec<[Job; N]> = actual
                .iter_mut()
                .enumerate()
                .map(|(lane, words)| Job {
                    input: &lane_inputs[lane][..length],
                    words,
                    count: lane_count(lane),
                    last_node,
                })
                .collect();
            let mut jobs = jobs.into_inner().expect("full");
            implementation.compress8_loop(&mut jobs, finalize, stride);

            for (lane, (want, got)) in expected.iter().zip(actual.iter()).enumerate() {
                assert_eq!(want, got, "words {} unequal", lane);
            }
        });
    }

    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    fn test_compress8_loop_avx2() {
        if let Some(avx2) = Implementation::avx2_if_supported() {
            exercise_compress8_loop(avx2);
        }
    }

    #[test]
    fn sanity_check_count_size() {
        // The counter helpers assume a `Count` is exactly two `Word`s wide.
        assert_eq!(size_of::<Count>(), 2 * size_of::<Word>());
    }
}