Skip to main content

pow_buster/solver/
avx512.rs

1use sha2::digest::generic_array::GenericArray;
2
3use crate::{
4    Align16, Align64, SWAP_DWORD_BYTE_ORDER, decompose_blocks_mut,
5    message::{
6        BinaryMessage, CerberusMessage, DecimalMessage, DoubleBlockMessage, GoAwayMessage,
7        SingleBlockMessage,
8    },
9};
10use core::arch::x86_64::*;
11
// Runtime CPU feature detection for AVX-512F; the generated `avx512f` module is queried through
// `avx512f::get()` by the `CpuIDToken` implementation below.
cpufeatures::new!(avx512f, "avx512f");
13
/// Required features for AVX-512 solver.
///
/// Zero-sized token type; it carries no data of its own.
#[derive(Clone, Copy, Debug)]
pub struct RequiredFeatures;
17
18impl Default for RequiredFeatures {
19    fn default() -> Self {
20        Self
21    }
22}
23
24impl crate::solver::CpuIDToken for RequiredFeatures {
25    fn get() -> bool {
26        avx512f::get()
27    }
28}
29
30static LANE_ID_MSB_STR: Align16<[u8; 5 * 16]> =
31    Align16(*b"11111111112222222222333333333344444444445555555555666666666677777777778888888888");
32
33static LANE_ID_MSB_STR_0: Align16<[u8; 6 * 16]> =
34    Align16(*b"000000000011111111112222222222333333333344444444445555555555666666666677777777778888888888999999");
35
36static LANE_ID_LSB_STR: Align16<[u8; 5 * 16]> =
37    Align16(*b"01234567890123456789012345678901234567890123456789012345678901234567890123456789");
38
39static LANE_ID_LSB_STR_0: Align16<[u8; 6 * 16]> =
40    Align16(*b"012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345");
41
42static LANE_ID_STR_COMBINED_LE_HI: Align64<[u32; 1000 / 16 * 16]> = {
43    let mut out = [0; 1000 / 16 * 16];
44    let mut i = 0;
45    while i < 1000 / 16 * 16 {
46        let mut copy = i;
47        let mut ds = [0; 4];
48        let mut j = 0;
49        while j < 3 {
50            ds[j] = (copy % 10) as u8 + b'0';
51            copy /= 10;
52            j += 1;
53        }
54        out[i] = u32::from_be_bytes(ds);
55        i += 1;
56    }
57    Align64(out)
58};
59
60#[expect(dead_code)]
61mod static_asserts {
62    use super::*;
63
64    const ASSERT_LANE_ID_STR_COMBINED_LE_HI_0: [(); 1] =
65        [(); (LANE_ID_STR_COMBINED_LE_HI.0[0] == u32::from_be_bytes(*b"000\x00")) as usize];
66
67    const ASSERT_LANE_ID_STR_COMBINED_LE_HI_1: [(); 1] =
68        [(); (LANE_ID_STR_COMBINED_LE_HI.0[1] == u32::from_be_bytes(*b"100\x00")) as usize];
69
70    const ASSERT_LANE_ID_STR_COMBINED_LE_HI_123: [(); 1] =
71        [(); (LANE_ID_STR_COMBINED_LE_HI.0[123] == u32::from_be_bytes(*b"321\x00")) as usize];
72}
73
/// Maps a bit position of the combined 64-bit comparison mask back to the 32-bit SIMD lane it
/// came from.
///
/// The 64-bit compare path interleaves two state vectors with `_mm512_unpacklo_epi32` /
/// `_mm512_unpackhi_epi32`, which operate within each 128-bit lane: the "lo" result (mask bits
/// 0-7) covers source lanes 0, 1, 4, 5, 8, 9, 12, 13 and the "hi" result (mask bits 8-15) covers
/// source lanes 2, 3, 6, 7, 10, 11, 14, 15.
#[cfg(feature = "compare-64bit")]
const INDEX_REMAP_PUNPCKLDQ: [usize; 16] = [0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15];
76
77#[inline(always)]
78fn load_lane_id_epi32<const N: usize>(src: &Align16<[u8; N]>, set_idx: usize) -> __m512i {
79    debug_assert!(set_idx * 16 < N);
80    unsafe { _mm512_cvtepi8_epi32(_mm_load_si128(src.as_ptr().add(set_idx * 16).cast())) }
81}
82
83/// AVX-512 decimal nonce single block solver.
84///
85///
86/// Current implementation: 16 way SIMD with 1-round hotstart granularity.
87pub struct SingleBlockSolver {
88    message: SingleBlockMessage,
89
90    attempted_nonces: u64,
91
92    limit: u64,
93}
94
95impl From<super::safe::SingleBlockSolver> for SingleBlockSolver {
96    fn from(solver: super::safe::SingleBlockSolver) -> Self {
97        Self {
98            message: solver.message,
99            attempted_nonces: solver.attempted_nonces,
100            limit: solver.limit,
101        }
102    }
103}
104
105impl From<SingleBlockMessage> for SingleBlockSolver {
106    fn from(message: SingleBlockMessage) -> Self {
107        Self {
108            message,
109            attempted_nonces: 0,
110            limit: u64::MAX,
111        }
112    }
113}
114
// Bit flags selecting how `solve_impl` rewrites the seven-digit inner key between iterations.

/// No flag set: the inner key digits are written byte by byte into the message.
const MUTATION_TYPE_UNALIGNED: u8 = 0b00;
/// The inner key is kept in a separate buffer whose two words replace whole message words.
const MUTATION_TYPE_ALIGNED: u8 = 0b01;
/// The inner key uses the digits `1..=8` (octal value plus one) instead of `0..=9`.
const MUTATION_TYPE_OCTAL: u8 = 0b10;
/// Word-aligned inner key with octal digits.
const MUTATION_TYPE_ALIGNED_OCTAL: u8 = MUTATION_TYPE_OCTAL | MUTATION_TYPE_ALIGNED;
/// Byte-wise inner key with octal digits.
const MUTATION_TYPE_UNALIGNED_OCTAL: u8 = MUTATION_TYPE_OCTAL | MUTATION_TYPE_UNALIGNED;
120
121impl SingleBlockSolver {
122    #[inline(never)]
123    #[target_feature(enable = "avx512f")]
124    fn solve_impl<
125        const DIGIT_WORD_IDX0: usize,
126        const DIGIT_WORD_IDX1_INCREMENT: bool,
127        const TYPE: u8,
128        const MUTATION_TYPE: u8,
129    >(
130        &mut self,
131        target: u64,
132        mask: u64,
133    ) -> Option<u64> {
134        let mut partial_state = self.message.prefix_state;
135        crate::sha256::ingest_message_prefix::<DIGIT_WORD_IDX0>(
136            &mut partial_state,
137            core::array::from_fn(|i| self.message.message[i]),
138        );
139
140        // zero out the nonce portion to prevent incorrect results if solvers are reused
141        for (ix, i) in (self.message.digit_index..).take(9).enumerate() {
142            let message = decompose_blocks_mut(&mut self.message.message);
143            message[SWAP_DWORD_BYTE_ORDER[i]] =
144                if ix >= 2 && MUTATION_TYPE & MUTATION_TYPE_OCTAL != 0 {
145                    b'1'
146                } else {
147                    b'0'
148                };
149        }
150
151        let lane_id_0_byte_idx = self.message.digit_index % 4;
152        let lane_id_1_byte_idx = (self.message.digit_index + 1) % 4;
153
154        for prefix_set_index in 0..(if MUTATION_TYPE & MUTATION_TYPE_OCTAL != 0 {
155            6
156        } else {
157            5
158        }) {
159            if self.attempted_nonces >= self.limit {
160                return None;
161            }
162
163            let mut inner_key_buf = if MUTATION_TYPE & MUTATION_TYPE_OCTAL != 0 {
164                Align16(*b"1111\x80111")
165            } else {
166                Align16(*b"0000\x80000")
167            };
168
169            unsafe {
170                let (lane_id_0_or_value, lane_id_1_or_value) =
171                    if MUTATION_TYPE & MUTATION_TYPE_OCTAL != 0 {
172                        let lane_id_0_or_value = _mm512_sll_epi32(
173                            load_lane_id_epi32(&LANE_ID_MSB_STR_0, prefix_set_index),
174                            _mm_cvtsi64x_si128(((3 - lane_id_0_byte_idx) * 8) as _),
175                        );
176                        let lane_id_1_or_value = _mm512_sll_epi32(
177                            load_lane_id_epi32(&LANE_ID_LSB_STR_0, prefix_set_index),
178                            _mm_cvtsi64x_si128(((3 - lane_id_1_byte_idx) * 8) as _),
179                        );
180
181                        (lane_id_0_or_value, lane_id_1_or_value)
182                    } else {
183                        let lane_id_0_or_value = _mm512_sll_epi32(
184                            load_lane_id_epi32(&LANE_ID_MSB_STR, prefix_set_index),
185                            _mm_cvtsi64x_si128(((3 - lane_id_0_byte_idx) * 8) as _),
186                        );
187                        let lane_id_1_or_value = _mm512_sll_epi32(
188                            load_lane_id_epi32(&LANE_ID_LSB_STR, prefix_set_index),
189                            _mm_cvtsi64x_si128(((3 - lane_id_1_byte_idx) * 8) as _),
190                        );
191
192                        (lane_id_0_or_value, lane_id_1_or_value)
193                    };
194
195                let lane_id_0_or_value_v = if !DIGIT_WORD_IDX1_INCREMENT {
196                    _mm512_or_epi32(lane_id_0_or_value, lane_id_1_or_value)
197                } else {
198                    lane_id_0_or_value
199                };
200
201                let mut inner_iteration_end = if MUTATION_TYPE & MUTATION_TYPE_OCTAL != 0 {
202                    0o10_000_000
203                } else {
204                    10_000_000
205                };
206
207                // clamp it to the number of remaining iterations
208                inner_iteration_end = self
209                    .limit
210                    .saturating_sub(self.attempted_nonces)
211                    .div_ceil(16)
212                    .min(inner_iteration_end as u64) as u32;
213
214                // soft pipeline this to compute the new message after the hash
215                // LLVM seems to handle cases where high register pressure work happens first better
216                // so this prevents some needless register spills
217                // doesn't seem to affect performance on my Zen4 but dirty so avoid
218                // on the last iteration simd_itoa(10_000_000) is unit-tested to convert to 0000\x80000
219                // so no fixup is needed-saves a branch on LLVM codegen
220                for next_inner_key in 1..=inner_iteration_end {
221                    macro_rules! fetch_msg {
222                        ($idx:expr) => {
223                            if $idx == DIGIT_WORD_IDX0 {
224                                _mm512_or_epi32(
225                                    _mm512_set1_epi32(self.message.message[$idx] as _),
226                                    lane_id_0_or_value_v,
227                                )
228                            } else if DIGIT_WORD_IDX1_INCREMENT && $idx == DIGIT_WORD_IDX0 + 1 {
229                                _mm512_or_epi32(
230                                    _mm512_set1_epi32(self.message.message[$idx] as _),
231                                    lane_id_1_or_value,
232                                )
233                            } else if (MUTATION_TYPE_ALIGNED & MUTATION_TYPE != 0)
234                                && $idx == DIGIT_WORD_IDX0 + 1
235                            {
236                                _mm512_set1_epi32(
237                                    (inner_key_buf.as_ptr().cast::<u32>().read()) as _,
238                                )
239                            } else if (MUTATION_TYPE_ALIGNED & MUTATION_TYPE != 0)
240                                && $idx == DIGIT_WORD_IDX0 + 2
241                            {
242                                _mm512_set1_epi32(
243                                    (inner_key_buf.as_ptr().add(4).cast::<u32>().read()) as _,
244                                )
245                            } else {
246                                _mm512_set1_epi32(self.message.message[$idx] as _)
247                            }
248                        };
249                    }
250                    let mut blocks = [
251                        fetch_msg!(0),
252                        fetch_msg!(1),
253                        fetch_msg!(2),
254                        fetch_msg!(3),
255                        fetch_msg!(4),
256                        fetch_msg!(5),
257                        fetch_msg!(6),
258                        fetch_msg!(7),
259                        fetch_msg!(8),
260                        fetch_msg!(9),
261                        fetch_msg!(10),
262                        fetch_msg!(11),
263                        fetch_msg!(12),
264                        fetch_msg!(13),
265                        fetch_msg!(14),
266                        fetch_msg!(15),
267                    ];
268
269                    let mut state =
270                        core::array::from_fn(|i| _mm512_set1_epi32(partial_state[i] as _));
271
272                    // do 16-way SHA-256 without feedback so as not to force the compiler to save 8 registers
273                    // we already have them in scalar form, this allows more registers to be reused in the next iteration
274                    crate::sha256::avx512::multiway_arx::<DIGIT_WORD_IDX0>(&mut state, &mut blocks);
275
276                    state[0] = _mm512_add_epi32(
277                        state[0],
278                        _mm512_set1_epi32(self.message.prefix_state[0] as _),
279                    );
280
281                    #[cfg(feature = "compare-64bit")]
282                    {
283                        state[1] = _mm512_add_epi32(
284                            state[1],
285                            _mm512_set1_epi32(self.message.prefix_state[1] as _),
286                        );
287                    }
288
289                    #[cfg(feature = "compare-64bit")]
290                    let result_ab_lo = _mm512_unpacklo_epi32(state[1], state[0]);
291                    #[cfg(feature = "compare-64bit")]
292                    let result_ab_hi = _mm512_unpackhi_epi32(state[1], state[0]);
293
294                    // A 64-bit compare solution is provided for completeness but almost never needed for realistic challenges.
295
296                    #[cfg(not(feature = "compare-64bit"))]
297                    let cmp_fn = |x: __m512i, y: __m512i| {
298                        if TYPE == crate::solver::SOLVE_TYPE_GT {
299                            _mm512_cmpgt_epu32_mask(x, y)
300                        } else if TYPE == crate::solver::SOLVE_TYPE_LT {
301                            _mm512_cmplt_epu32_mask(x, y)
302                        } else {
303                            _mm512_cmpeq_epu32_mask(
304                                _mm512_and_si512(x, _mm512_set1_epi32((mask >> 32) as _)),
305                                y,
306                            )
307                        }
308                    };
309
310                    #[cfg(feature = "compare-64bit")]
311                    let cmp64_fn = |x: __m512i, y: __m512i| {
312                        if TYPE == crate::solver::SOLVE_TYPE_GT {
313                            _mm512_cmpgt_epu64_mask(x, y)
314                        } else if TYPE == crate::solver::SOLVE_TYPE_LT {
315                            _mm512_cmplt_epu64_mask(x, y)
316                        } else {
317                            _mm512_cmpeq_epu64_mask(
318                                _mm512_and_si512(x, _mm512_set1_epi64(mask as _)),
319                                y,
320                            )
321                        }
322                    };
323
324                    #[cfg(not(feature = "compare-64bit"))]
325                    let met_target = cmp_fn(state[0], _mm512_set1_epi32((target >> 32) as _));
326
327                    #[cfg(feature = "compare-64bit")]
328                    let (met_target_high, met_target_lo) = {
329                        let ab_met_target_lo =
330                            cmp64_fn(result_ab_lo, _mm512_set1_epi64(target as _)) as u16;
331
332                        let ab_met_target_high =
333                            cmp64_fn(result_ab_hi, _mm512_set1_epi64(target as _)) as u16;
334
335                        (ab_met_target_high, ab_met_target_lo)
336                    };
337                    #[cfg(feature = "compare-64bit")]
338                    let met_target_test = met_target_high != 0 || met_target_lo != 0;
339                    #[cfg(not(feature = "compare-64bit"))]
340                    let met_target_test = met_target != 0;
341
342                    if met_target_test {
343                        crate::unlikely();
344
345                        #[cfg(not(feature = "compare-64bit"))]
346                        let success_lane_idx = met_target.trailing_zeros() as usize;
347
348                        // remap the indices according to unpacking order
349                        #[cfg(feature = "compare-64bit")]
350                        let success_lane_idx = INDEX_REMAP_PUNPCKLDQ
351                            [(met_target_high << 8 | met_target_lo).trailing_zeros() as usize];
352
353                        let mut nonce_prefix = 16 * prefix_set_index + success_lane_idx;
354                        if MUTATION_TYPE & MUTATION_TYPE_OCTAL == 0 {
355                            nonce_prefix += 10;
356                        }
357
358                        if MUTATION_TYPE & MUTATION_TYPE_ALIGNED != 0 {
359                            self.message.message[DIGIT_WORD_IDX0 + 1] =
360                                inner_key_buf.as_ptr().cast::<u32>().read();
361                            self.message.message[DIGIT_WORD_IDX0 + 2] =
362                                inner_key_buf.as_ptr().add(4).cast::<u32>().read();
363                        }
364
365                        // stamp the lane ID back onto the message
366                        {
367                            let message_bytes = decompose_blocks_mut(&mut self.message.message);
368                            *message_bytes.get_unchecked_mut(
369                                *SWAP_DWORD_BYTE_ORDER.get_unchecked(self.message.digit_index),
370                            ) = (nonce_prefix / 10) as u8 + b'0';
371                            *message_bytes.get_unchecked_mut(
372                                *SWAP_DWORD_BYTE_ORDER.get_unchecked(self.message.digit_index + 1),
373                            ) = (nonce_prefix % 10) as u8 + b'0';
374                        }
375
376                        let mut decimal_inner_key = next_inner_key as u64 - 1;
377                        if MUTATION_TYPE & MUTATION_TYPE_OCTAL != 0 {
378                            decimal_inner_key = 0;
379                            let mut key_octal = next_inner_key - 1;
380                            for m in (0..7u32).map(|i| 10u64.pow(i)) {
381                                let output = (key_octal % 8) + 1;
382                                key_octal /= 8;
383                                decimal_inner_key += output as u64 * m;
384                            }
385                            let mut message_be = [0u8; 64];
386                            for i in 0..16 {
387                                message_be[i * 4..][..4]
388                                    .copy_from_slice(&self.message.message[i].to_be_bytes());
389                            }
390                        }
391
392                        // the nonce is the 7 digits in the message, plus the first two digits recomputed from the lane index
393                        return Some(nonce_prefix as u64 * 10u64.pow(7) + decimal_inner_key);
394                    }
395
396                    self.attempted_nonces += 16;
397
398                    if MUTATION_TYPE == MUTATION_TYPE_ALIGNED_OCTAL {
399                        crate::strings::to_octal_7::<true, 0x80, 1>(
400                            &mut inner_key_buf,
401                            next_inner_key,
402                        )
403                    } else if MUTATION_TYPE == MUTATION_TYPE_ALIGNED {
404                        crate::strings::simd_itoa8::<7, true, 0x80>(
405                            &mut inner_key_buf,
406                            next_inner_key,
407                        );
408                    } else if MUTATION_TYPE == MUTATION_TYPE_UNALIGNED_OCTAL {
409                        let message_bytes = decompose_blocks_mut(&mut self.message.message);
410                        let mut key_copy = next_inner_key;
411
412                        for i in (0..7).rev() {
413                            let output = key_copy % 8;
414                            key_copy /= 8;
415                            *message_bytes.get_unchecked_mut(
416                                *SWAP_DWORD_BYTE_ORDER
417                                    .get_unchecked(self.message.digit_index + i + 2),
418                            ) = output as u8 + b'1';
419                        }
420                    } else {
421                        let message_bytes = decompose_blocks_mut(&mut self.message.message);
422                        let mut key_copy = next_inner_key;
423
424                        for i in (0..7).rev() {
425                            let output = key_copy % 10;
426                            key_copy /= 10;
427                            *message_bytes.get_unchecked_mut(
428                                *SWAP_DWORD_BYTE_ORDER
429                                    .get_unchecked(self.message.digit_index + i + 2),
430                            ) = output as u8 + b'0';
431                        }
432                    }
433                }
434            }
435        }
436
437        crate::unlikely();
438        None
439    }
440}
441
442impl crate::solver::Solver for SingleBlockSolver {
443    fn set_limit(&mut self, limit: u64) {
444        self.limit = limit;
445    }
446
447    fn get_attempted_nonces(&self) -> u64 {
448        self.attempted_nonces
449    }
450
451    fn solve_nonce_only<const TYPE: u8>(&mut self, target: u64, mask: u64) -> Option<u64> {
452        if self.attempted_nonces >= self.limit {
453            return None;
454        }
455        let target = target & mask;
456
457        // the official default difficulty is 5e6, so we design for 1e8
458        // and there should almost always be a valid solution within our supported solution space
459        // pgeom(5 * 16e7, 1/5e7, lower=F) = 0.03%
460        // pgeom(16e7, 1/5e7, lower=F) = 20%, which is too much so we need the prefix to change as well
461
462        // pre-compute an OR to apply to the message to add the lane ID
463        let lane_id_0_word_idx = self.message.digit_index / 4;
464        let lane_id_1_word_idx = (self.message.digit_index + 1) / 4;
465
466        macro_rules! dispatch {
467            ($idx0:literal, $idx1_inc:literal) => {
468                unsafe {
469                    if self.message.digit_index % 4 == 2 {
470                        // if we have to much search space it doesn't matter
471                        // use the octal kernel
472                        if self.message.no_trailing_zeros
473                            || self.message.approx_working_set_count.get() >= 100
474                        {
475                            self.solve_impl::<$idx0, $idx1_inc, TYPE, MUTATION_TYPE_ALIGNED_OCTAL>(
476                                target, mask,
477                            )
478                        } else {
479                            self.solve_impl::<$idx0, $idx1_inc, TYPE, MUTATION_TYPE_ALIGNED>(
480                                target, mask,
481                            )
482                        }
483                    } else if self.message.no_trailing_zeros {
484                        self.solve_impl::<$idx0, $idx1_inc, TYPE, MUTATION_TYPE_UNALIGNED_OCTAL>(
485                            target, mask,
486                        )
487                    } else {
488                        self.solve_impl::<$idx0, $idx1_inc, TYPE, MUTATION_TYPE_UNALIGNED>(
489                            target, mask,
490                        )
491                    }
492                }
493            };
494            ($idx0:literal) => {
495                if lane_id_0_word_idx == lane_id_1_word_idx {
496                    dispatch!($idx0, false)
497                } else {
498                    dispatch!($idx0, true)
499                }
500            };
501        }
502
503        let nonce = match lane_id_0_word_idx {
504            0 => dispatch!(0),
505            1 => dispatch!(1),
506            2 => dispatch!(2),
507            3 => dispatch!(3),
508            4 => dispatch!(4),
509            5 => dispatch!(5),
510            6 => dispatch!(6),
511            7 => dispatch!(7),
512            8 => dispatch!(8),
513            9 => dispatch!(9),
514            10 => dispatch!(10),
515            11 => dispatch!(11),
516            12 => dispatch!(12),
517            13 => dispatch!(13),
518            _ => unsafe { core::hint::unreachable_unchecked() },
519        }?;
520
521        Some(nonce + self.message.nonce_addend)
522    }
523
524    fn solve<const TYPE: u8>(&mut self, target: u64, mask: u64) -> Option<(u64, [u32; 8])> {
525        let nonce = self.solve_nonce_only::<TYPE>(target, mask)?;
526
527        // recompute the hash from the beginning
528        // this prevents the compiler from having to compute the final B-H registers alive in tight loops
529        let mut final_sha_state = self.message.prefix_state;
530        crate::sha256::digest_block(&mut final_sha_state, &self.message.message);
531
532        Some((nonce, final_sha_state))
533    }
534}
535
536/// AVX-512 decimal nonce double block solver.
537///
538///
539/// Current implementation: 16 way SIMD with 1-round hotstart granularity.
540pub struct DoubleBlockSolver {
541    message: DoubleBlockMessage,
542    attempted_nonces: u64,
543
544    limit: u64,
545}
546
547impl From<super::safe::DoubleBlockSolver> for DoubleBlockSolver {
548    fn from(solver: super::safe::DoubleBlockSolver) -> Self {
549        Self {
550            message: solver.message,
551            attempted_nonces: solver.attempted_nonces,
552            limit: solver.limit,
553        }
554    }
555}
556
557impl From<DoubleBlockMessage> for DoubleBlockSolver {
558    fn from(message: DoubleBlockMessage) -> Self {
559        Self {
560            message,
561            attempted_nonces: 0,
562            limit: u64::MAX,
563        }
564    }
565}
566
567impl crate::solver::Solver for DoubleBlockSolver {
568    fn set_limit(&mut self, limit: u64) {
569        self.limit = limit;
570    }
571
572    fn get_attempted_nonces(&self) -> u64 {
573        self.attempted_nonces
574    }
575
576    #[inline]
577    fn solve<const TYPE: u8>(&mut self, target: u64, mask: u64) -> Option<(u64, [u32; 8])> {
578        unsafe { self.solve_impl::<TYPE>(target, mask) }
579    }
580}
581
582impl DoubleBlockSolver {
583    #[inline(never)]
584    #[target_feature(enable = "avx512f")]
585    fn solve_impl<const TYPE: u8>(&mut self, target: u64, mask: u64) -> Option<(u64, [u32; 8])> {
586        let target = target & mask;
587
588        if self.attempted_nonces >= self.limit {
589            return None;
590        }
591
592        for (ix, i) in (DoubleBlockMessage::DIGIT_IDX as usize..)
593            .take(9)
594            .enumerate()
595        {
596            let message = decompose_blocks_mut(&mut self.message.message);
597            message[SWAP_DWORD_BYTE_ORDER[i]] = b'0';
598            if ix >= 2 {
599                message[SWAP_DWORD_BYTE_ORDER[i]] = b'1';
600            }
601        }
602
603        let mut partial_state = self.message.prefix_state;
604        crate::sha256::sha2_arx::<0>(&mut partial_state, &self.message.message[..13]);
605
606        let mut terminal_message_schedule = Align16([0; 64]);
607        terminal_message_schedule[14] = ((self.message.message_length * 8) >> 32) as u32;
608        terminal_message_schedule[15] = (self.message.message_length * 8) as u32;
609        crate::sha256::do_message_schedule_k_w(&mut terminal_message_schedule);
610
611        let mut itoa_buf = Align16(*b"1111\x80111");
612        // the addend is definitely not zero for double block solver, so we can start at 0
613        // to recoup some lost search space from using octal digits
614        for prefix_set_index in 0..(LANE_ID_LSB_STR.len() / 16) {
615            unsafe {
616                let lane_id_0_or_value =
617                    _mm512_slli_epi32(load_lane_id_epi32(&LANE_ID_MSB_STR, prefix_set_index), 8);
618                let lane_id_1_or_value = load_lane_id_epi32(&LANE_ID_LSB_STR, prefix_set_index);
619
620                let lane_index_value_v = _mm512_or_epi32(
621                    _mm512_set1_epi32(self.message.message[13] as _),
622                    _mm512_or_epi32(lane_id_0_or_value, lane_id_1_or_value),
623                );
624
625                for next_inner_key in 1..=0o10_000_000 {
626                    let cum0 = itoa_buf.as_ptr().cast::<u32>().read();
627                    let cum1 = itoa_buf.as_ptr().add(4).cast::<u32>().read();
628
629                    let mut state =
630                        core::array::from_fn(|i| _mm512_set1_epi32(partial_state[i] as _));
631
632                    {
633                        let mut blocks = [
634                            _mm512_set1_epi32(self.message.message[0] as _),
635                            _mm512_set1_epi32(self.message.message[1] as _),
636                            _mm512_set1_epi32(self.message.message[2] as _),
637                            _mm512_set1_epi32(self.message.message[3] as _),
638                            _mm512_set1_epi32(self.message.message[4] as _),
639                            _mm512_set1_epi32(self.message.message[5] as _),
640                            _mm512_set1_epi32(self.message.message[6] as _),
641                            _mm512_set1_epi32(self.message.message[7] as _),
642                            _mm512_set1_epi32(self.message.message[8] as _),
643                            _mm512_set1_epi32(self.message.message[9] as _),
644                            _mm512_set1_epi32(self.message.message[10] as _),
645                            _mm512_set1_epi32(self.message.message[11] as _),
646                            _mm512_set1_epi32(self.message.message[12] as _),
647                            lane_index_value_v,
648                            _mm512_set1_epi32(cum0 as _),
649                            _mm512_set1_epi32(cum1 as _),
650                        ];
651
652                        crate::sha256::avx512::multiway_arx::<13>(&mut state, &mut blocks);
653
654                        // we have to do feedback now
655                        state
656                            .iter_mut()
657                            .zip(self.message.prefix_state.iter())
658                            .for_each(|(state, prefix_state)| {
659                                *state =
660                                    _mm512_add_epi32(*state, _mm512_set1_epi32(*prefix_state as _));
661                            });
662                    }
663
664                    // save only A register for comparison
665                    let save_a = state[0];
666
667                    #[cfg(feature = "compare-64bit")]
668                    let save_b = state[1];
669
670                    crate::sha256::avx512::bcst_multiway_arx::<14>(
671                        &mut state,
672                        &terminal_message_schedule,
673                    );
674
675                    #[cfg(not(feature = "compare-64bit"))]
676                    let cmp_fn = |x: __m512i, y: __m512i| {
677                        if TYPE == crate::solver::SOLVE_TYPE_GT {
678                            _mm512_cmpgt_epu32_mask(x, y)
679                        } else if TYPE == crate::solver::SOLVE_TYPE_LT {
680                            _mm512_cmplt_epu32_mask(x, y)
681                        } else {
682                            _mm512_cmpeq_epu32_mask(
683                                _mm512_and_si512(x, _mm512_set1_epi32((mask >> 32) as _)),
684                                y,
685                            )
686                        }
687                    };
688
689                    #[cfg(feature = "compare-64bit")]
690                    let cmp64_fn = |x: __m512i, y: __m512i| {
691                        if TYPE == crate::solver::SOLVE_TYPE_GT {
692                            _mm512_cmpgt_epu64_mask(x, y)
693                        } else if TYPE == crate::solver::SOLVE_TYPE_LT {
694                            _mm512_cmplt_epu64_mask(x, y)
695                        } else {
696                            _mm512_cmpeq_epu64_mask(
697                                _mm512_and_si512(x, _mm512_set1_epi64(mask as _)),
698                                y,
699                            )
700                        }
701                    };
702
703                    state[0] = _mm512_add_epi32(state[0], save_a);
704
705                    #[cfg(feature = "compare-64bit")]
706                    {
707                        state[1] = _mm512_add_epi32(state[1], save_b);
708                    }
709
710                    #[cfg(not(feature = "compare-64bit"))]
711                    let met_target = (cmp_fn)(state[0], _mm512_set1_epi32((target >> 32) as _));
712
713                    #[cfg(feature = "compare-64bit")]
714                    let result_ab_lo = _mm512_unpacklo_epi32(state[1], state[0]);
715                    #[cfg(feature = "compare-64bit")]
716                    let result_ab_hi = _mm512_unpackhi_epi32(state[1], state[0]);
717                    #[cfg(feature = "compare-64bit")]
718                    let (met_target_high, met_target_lo) = {
719                        let ab_met_target_lo =
720                            cmp64_fn(result_ab_lo, _mm512_set1_epi64(target as _)) as u16;
721                        let ab_met_target_high =
722                            cmp64_fn(result_ab_hi, _mm512_set1_epi64(target as _)) as u16;
723                        (ab_met_target_high, ab_met_target_lo)
724                    };
725
726                    #[cfg(feature = "compare-64bit")]
727                    let met_target_test = met_target_high != 0 || met_target_lo != 0;
728
729                    #[cfg(not(feature = "compare-64bit"))]
730                    let met_target_test = met_target != 0;
731
732                    if met_target_test {
733                        crate::unlikely();
734
735                        #[cfg(not(feature = "compare-64bit"))]
736                        let success_lane_idx = met_target.trailing_zeros() as usize;
737
738                        #[cfg(feature = "compare-64bit")]
739                        let success_lane_idx = INDEX_REMAP_PUNPCKLDQ
740                            [(met_target_high << 8 | met_target_lo).trailing_zeros() as usize];
741
742                        let nonce_prefix = 10 + 16 * prefix_set_index + success_lane_idx;
743
744                        self.message.message[14] = cum0;
745                        self.message.message[15] = cum1;
746                        {
747                            let message_bytes = decompose_blocks_mut(&mut self.message.message);
748                            *message_bytes.get_unchecked_mut(
749                                *SWAP_DWORD_BYTE_ORDER
750                                    .get_unchecked(DoubleBlockMessage::DIGIT_IDX as usize),
751                            ) = (nonce_prefix / 10) as u8 + b'0';
752                            *message_bytes.get_unchecked_mut(
753                                *SWAP_DWORD_BYTE_ORDER
754                                    .get_unchecked(DoubleBlockMessage::DIGIT_IDX as usize + 1),
755                            ) = (nonce_prefix % 10) as u8 + b'0';
756                        }
757
758                        // recompute the hash from the beginning
759                        // this prevents the compiler from having to compute the final B-H registers alive in tight loops
760                        let mut final_sha_state = self.message.prefix_state;
761                        crate::sha256::digest_block(&mut final_sha_state, &self.message.message);
762                        let mut terminal_message = [0; 16];
763                        terminal_message[14] = ((self.message.message_length * 8) >> 32) as u32;
764                        terminal_message[15] = (self.message.message_length * 8) as u32;
765                        crate::sha256::digest_block(&mut final_sha_state, &terminal_message);
766
767                        let mut decimal_inner_key = 0;
768                        let mut key_octal = next_inner_key - 1;
769                        for m in (0..7u32).map(|i| 10u64.pow(i)) {
770                            let output = (key_octal % 8) + 1;
771                            key_octal /= 8;
772                            decimal_inner_key += output as u64 * m;
773                        }
774
775                        let computed_nonce = nonce_prefix as u64 * 10u64.pow(7)
776                            + decimal_inner_key
777                            + self.message.nonce_addend;
778
779                        // the nonce is the 8 digits in the message, plus the first two digits recomputed from the lane index
780                        return Some((computed_nonce, *final_sha_state));
781                    }
782
783                    self.attempted_nonces += 16;
784
785                    if self.attempted_nonces >= self.limit {
786                        return None;
787                    }
788
789                    crate::strings::to_octal_7::<true, 0x80, 1>(&mut itoa_buf, next_inner_key);
790                }
791            }
792        }
793
794        crate::unlikely();
795
796        None
797    }
798}
799
800#[macro_use]
801#[path = "impl_decimal_solver.rs"]
802mod impl_decimal_solver;
803
804impl_decimal_solver!(
805    [SingleBlockSolver, DoubleBlockSolver] => DecimalSolver
806);
807
808/// AVX-512 binary nonce solver.
809///
810/// Output: nonce in little endian order
811///
812/// Current implementation: 16 way SIMD with 1-round hotstart granularity.
813pub struct BinarySolver {
814    message: BinaryMessage,
815    attempted_nonces: u64,
816    limit: u64,
817}
818
819impl From<BinaryMessage> for BinarySolver {
820    fn from(message: BinaryMessage) -> Self {
821        Self {
822            message,
823            attempted_nonces: 0,
824            limit: u64::MAX,
825        }
826    }
827}
828
829impl From<crate::solver::safe::BinarySolver> for BinarySolver {
830    fn from(solver: crate::solver::safe::BinarySolver) -> Self {
831        Self {
832            message: solver.message,
833            attempted_nonces: solver.attempted_nonces,
834            limit: solver.limit,
835        }
836    }
837}
838
839impl BinarySolver {
840    #[inline(never)]
841    #[target_feature(enable = "avx512f")]
842    fn solve_impl<
843        const TYPE: u8,
844        const FIRST_NONCE_WORD_IDX: usize,
845        const NEED_SECOND_BLOCK: bool,
846    >(
847        &mut self,
848        prefix_state: [u32; 8],
849        first_block: &Align64<[u32; 16]>,
850        second_block_schedule: &[u32; 64],
851        nonce_byte_offset: usize,
852        nonce_byte_count: core::num::NonZeroU8,
853        target: u64,
854        mask: u64,
855    ) -> Option<u64> {
856        let target = target & mask;
857
858        // at most 3 words may need to be patched (i.e. 96 bits)
859        // from which word to poke the nonce?
860        let poke_word_base = FIRST_NONCE_WORD_IDX.min(16 - 4);
861        let poke_word_byte_base = poke_word_base * 4;
862        let mut poke_word_tbl = Align16([!0u8; 16]);
863        let nonce_byte_count_decr = nonce_byte_count.get() as usize - 1;
864        for (i, ix) in ((nonce_byte_offset + 1)..)
865            .take(nonce_byte_count_decr)
866            .enumerate()
867        {
868            poke_word_tbl.0[unsafe {
869                *crate::SWAP_DWORD_BYTE_ORDER.get_unchecked(ix - poke_word_byte_base)
870            }] = i as u8;
871        }
872        let poke_word_tbl = unsafe { _mm_load_si128(poke_word_tbl.as_ptr().cast()) };
873        let lane_id_byte_remainder = 3 - nonce_byte_offset % 4;
874        let mut lane_id_base = Align64([0u32; 16]);
875        for i in 0..16 {
876            lane_id_base[i as usize] = i << (lane_id_byte_remainder * 8);
877        }
878        let lane_id_iterand = 16 << (lane_id_byte_remainder * 8);
879
880        let mut memo_state = prefix_state;
881        crate::sha256::ingest_message_prefix::<FIRST_NONCE_WORD_IDX>(
882            &mut memo_state,
883            first_block[..FIRST_NONCE_WORD_IDX].try_into().unwrap(),
884        );
885
886        for x in 0..(self
887            .limit
888            .min(256u64.saturating_pow(nonce_byte_count_decr as u32))
889            .max(1))
890        {
891            unsafe {
892                let mut block_tpl = *first_block;
893                let xm = _mm_cvtsi64x_si128(x as _);
894                let xmd = _mm_shuffle_epi8(xm, poke_word_tbl);
895                let loadd = _mm_loadu_si128(block_tpl.as_ptr().add(poke_word_base).cast());
896                _mm_storeu_si128(
897                    block_tpl.as_mut_ptr().add(poke_word_base).cast(),
898                    _mm_or_si128(loadd, xmd),
899                );
900                let mut lane_id_v = _mm512_load_si512(lane_id_base.as_ptr().cast());
901
902                for lane_id_set_idx in 0..(256 / 16) {
903                    macro_rules! get_msg {
904                        ($idx:expr) => {
905                            if $idx == FIRST_NONCE_WORD_IDX {
906                                _mm512_or_epi32(_mm512_set1_epi32(block_tpl[$idx] as _), lane_id_v)
907                            } else {
908                                _mm512_set1_epi32(block_tpl[$idx] as _)
909                            }
910                        };
911                    }
912
913                    let mut state = core::array::from_fn(|i| _mm512_set1_epi32(memo_state[i] as _));
914                    let mut msg = [
915                        get_msg!(0),
916                        get_msg!(1),
917                        get_msg!(2),
918                        get_msg!(3),
919                        get_msg!(4),
920                        get_msg!(5),
921                        get_msg!(6),
922                        get_msg!(7),
923                        get_msg!(8),
924                        get_msg!(9),
925                        get_msg!(10),
926                        get_msg!(11),
927                        get_msg!(12),
928                        get_msg!(13),
929                        get_msg!(14),
930                        get_msg!(15),
931                    ];
932
933                    crate::sha256::avx512::multiway_arx::<FIRST_NONCE_WORD_IDX>(
934                        &mut state, &mut msg,
935                    );
936
937                    if NEED_SECOND_BLOCK {
938                        for i in 0..8 {
939                            state[i] =
940                                _mm512_add_epi32(state[i], _mm512_set1_epi32(prefix_state[i] as _));
941                        }
942                        let save_a = state[0];
943                        #[cfg(feature = "compare-64bit")]
944                        let save_b = state[1];
945
946                        crate::sha256::avx512::bcst_multiway_arx::<0>(
947                            &mut state,
948                            second_block_schedule,
949                        );
950                        state[0] = _mm512_add_epi32(state[0], save_a);
951                        #[cfg(feature = "compare-64bit")]
952                        {
953                            state[1] = _mm512_add_epi32(state[1], save_b);
954                        }
955                    } else {
956                        state[0] =
957                            _mm512_add_epi32(state[0], _mm512_set1_epi32(prefix_state[0] as _));
958                        #[cfg(feature = "compare-64bit")]
959                        {
960                            state[1] =
961                                _mm512_add_epi32(state[1], _mm512_set1_epi32(prefix_state[1] as _));
962                        }
963                    }
964
965                    #[cfg(not(feature = "compare-64bit"))]
966                    let cmp_fn = |x: __m512i, y: __m512i| {
967                        if TYPE == crate::solver::SOLVE_TYPE_GT {
968                            _mm512_cmpgt_epu32_mask(x, y)
969                        } else if TYPE == crate::solver::SOLVE_TYPE_LT {
970                            _mm512_cmplt_epu32_mask(x, y)
971                        } else {
972                            _mm512_cmpeq_epu32_mask(
973                                _mm512_and_si512(x, _mm512_set1_epi32((mask >> 32) as _)),
974                                y,
975                            )
976                        }
977                    };
978
979                    #[cfg(feature = "compare-64bit")]
980                    let cmp64_fn = |x: __m512i, y: __m512i| {
981                        if TYPE == crate::solver::SOLVE_TYPE_GT {
982                            _mm512_cmpgt_epu64_mask(x, y)
983                        } else if TYPE == crate::solver::SOLVE_TYPE_LT {
984                            _mm512_cmplt_epu64_mask(x, y)
985                        } else {
986                            _mm512_cmpeq_epu64_mask(
987                                _mm512_and_si512(x, _mm512_set1_epi64(mask as _)),
988                                y,
989                            )
990                        }
991                    };
992
993                    #[cfg(not(feature = "compare-64bit"))]
994                    let met_target = (cmp_fn)(state[0], _mm512_set1_epi32((target >> 32) as _));
995
996                    #[cfg(feature = "compare-64bit")]
997                    let result_ab_lo = _mm512_unpacklo_epi32(state[1], state[0]);
998                    #[cfg(feature = "compare-64bit")]
999                    let result_ab_hi = _mm512_unpackhi_epi32(state[1], state[0]);
1000                    #[cfg(feature = "compare-64bit")]
1001                    let (met_target_high, met_target_lo) = {
1002                        let ab_met_target_lo =
1003                            cmp64_fn(result_ab_lo, _mm512_set1_epi64(target as _)) as u16;
1004                        let ab_met_target_high =
1005                            cmp64_fn(result_ab_hi, _mm512_set1_epi64(target as _)) as u16;
1006                        (ab_met_target_high, ab_met_target_lo)
1007                    };
1008
1009                    #[cfg(feature = "compare-64bit")]
1010                    let met_target_test = met_target_high != 0 || met_target_lo != 0;
1011
1012                    #[cfg(not(feature = "compare-64bit"))]
1013                    let met_target_test = met_target != 0;
1014
1015                    if met_target_test {
1016                        crate::unlikely();
1017
1018                        #[cfg(not(feature = "compare-64bit"))]
1019                        let success_lane_idx = met_target.trailing_zeros() as usize;
1020                        #[cfg(feature = "compare-64bit")]
1021                        let success_lane_idx = INDEX_REMAP_PUNPCKLDQ
1022                            [(met_target_high << 8 | met_target_lo).trailing_zeros() as usize];
1023
1024                        let nonce_addend = 16 * lane_id_set_idx + success_lane_idx;
1025
1026                        let nonce = x << 8 | nonce_addend as u64;
1027
1028                        return Some(nonce);
1029                    }
1030
1031                    lane_id_v = _mm512_add_epi32(lane_id_v, _mm512_set1_epi32(lane_id_iterand));
1032                    self.attempted_nonces += 16;
1033                }
1034
1035                if self.attempted_nonces >= self.limit {
1036                    return None;
1037                }
1038            }
1039        }
1040
1041        None
1042    }
1043}
1044
1045impl crate::solver::Solver for BinarySolver {
1046    fn set_limit(&mut self, limit: u64) {
1047        self.limit = limit;
1048    }
1049
1050    fn get_attempted_nonces(&self) -> u64 {
1051        self.attempted_nonces
1052    }
1053
1054    fn solve<const TYPE: u8>(&mut self, target: u64, mask: u64) -> Option<(u64, [u32; 8])> {
1055        if (self.message.nonce_byte_count.get() == 1) // edge case not worth optimizing, bail out
1056            || (self.message.salt_residual_len + self.message.nonce_byte_count.get() as usize > 64)
1057        // TODO: optimize edge case where nonce itself cross block boundary
1058        {
1059            crate::unlikely();
1060            let mut solver = crate::solver::safe::BinarySolver::from(self.message.clone());
1061            return solver.solve::<TYPE>(target, mask);
1062        }
1063
1064        let salt = &self.message.salt_residual[..self.message.salt_residual_len];
1065        let mut blocks = [GenericArray::default(); 2];
1066        blocks[0][..salt.len()].copy_from_slice(salt);
1067        let mut ptr = salt.len();
1068        let mut cur_block = 0;
1069
1070        for _ in 0..self.message.nonce_byte_count.get() {
1071            blocks[cur_block][ptr] = 0;
1072            ptr += 1;
1073            if ptr == 64 {
1074                cur_block = 1;
1075                ptr = 0;
1076            }
1077        }
1078        blocks[cur_block][ptr] = 0x80;
1079        ptr += 1;
1080        if ptr + 8 > 64 {
1081            cur_block = 1;
1082        }
1083        blocks[cur_block][(64 - 8)..]
1084            .copy_from_slice(&(self.message.message_length * 8).to_be_bytes());
1085
1086        let used_blocks = &mut blocks[..=cur_block];
1087
1088        let mut block_template_be = Align64([0; 16]);
1089        for i in 0..16 {
1090            block_template_be[i] =
1091                u32::from_be_bytes(used_blocks[0][i * 4..][..4].try_into().unwrap());
1092        }
1093
1094        let mut second_block_schedule = [0; 64];
1095        if cur_block == 1 {
1096            for i in 0..16 {
1097                second_block_schedule[i] =
1098                    u32::from_be_bytes(used_blocks[1][i * 4..][..4].try_into().unwrap());
1099            }
1100            crate::sha256::do_message_schedule_k_w(&mut second_block_schedule);
1101        }
1102
1103        macro_rules! dispatch {
1104            ($skipped_rounds:expr) => {
1105                unsafe {
1106                    if cur_block == 1 {
1107                        if let Some(nonce) = self.solve_impl::<TYPE, { $skipped_rounds }, true>(
1108                            *self.message.prefix_state,
1109                            &block_template_be,
1110                            &second_block_schedule,
1111                            self.message.salt_residual_len,
1112                            self.message.nonce_byte_count,
1113                            target,
1114                            mask,
1115                        ) {
1116                            let mut final_sha_state = self.message.prefix_state;
1117                            for i in 0..self.message.nonce_byte_count.get() as usize {
1118                                used_blocks[0][self.message.salt_residual_len + i] =
1119                                    nonce.to_le_bytes()[i];
1120                            }
1121                            sha2::compress256(&mut final_sha_state, &used_blocks);
1122                            return Some((nonce, final_sha_state.0));
1123                        }
1124                    } else {
1125                        if let Some(nonce) = self.solve_impl::<TYPE, { $skipped_rounds }, false>(
1126                            *self.message.prefix_state,
1127                            &block_template_be,
1128                            &[0; 64],
1129                            self.message.salt_residual_len,
1130                            self.message.nonce_byte_count,
1131                            target,
1132                            mask,
1133                        ) {
1134                            let mut final_sha_state = self.message.prefix_state;
1135                            for i in 0..self.message.nonce_byte_count.get() as usize {
1136                                used_blocks[0][self.message.salt_residual_len + i] =
1137                                    nonce.to_le_bytes()[i];
1138                            }
1139                            sha2::compress256(&mut final_sha_state, &used_blocks);
1140                            return Some((nonce, final_sha_state.0));
1141                        }
1142                    }
1143                }
1144            };
1145        }
1146
1147        match self.message.salt_residual_len / 4 {
1148            0 => dispatch!(0),
1149            1 => dispatch!(1),
1150            2 => dispatch!(2),
1151            3 => dispatch!(3),
1152            4 => dispatch!(4),
1153            5 => dispatch!(5),
1154            6 => dispatch!(6),
1155            7 => dispatch!(7),
1156            8 => dispatch!(8),
1157            9 => dispatch!(9),
1158            10 => dispatch!(10),
1159            11 => dispatch!(11),
1160            12 => dispatch!(12),
1161            13 => dispatch!(13),
1162            14 => dispatch!(14),
1163            15 => dispatch!(15),
1164            _ => unreachable!(),
1165        }
1166
1167        crate::unlikely();
1168
1169        None
1170    }
1171}
1172
1173/// AVX-512 GoAway solver.
1174///
1175///
1176/// Current implementation: 16 way SIMD with 1-round hotstart granularity.
1177pub struct GoAwaySolver {
1178    message: GoAwayMessage,
1179    attempted_nonces: u64,
1180    limit: u64,
1181}
1182
1183impl From<super::safe::GoAwaySolver> for GoAwaySolver {
1184    fn from(solver: super::safe::GoAwaySolver) -> Self {
1185        Self {
1186            message: solver.message,
1187            attempted_nonces: solver.attempted_nonces,
1188            limit: solver.limit,
1189        }
1190    }
1191}
1192
1193impl From<GoAwayMessage> for GoAwaySolver {
1194    fn from(message: GoAwayMessage) -> Self {
1195        Self {
1196            message,
1197            attempted_nonces: 0,
1198            limit: u64::MAX,
1199        }
1200    }
1201}
1202
1203impl GoAwaySolver {
1204    const MSG_LEN: u32 = 10 * 4 * 8;
1205
1206    #[inline(never)]
1207    #[target_feature(enable = "avx512f")]
1208    unsafe fn solve_nonce_only_impl<const TYPE: u8>(
1209        &mut self,
1210        target: u64,
1211        mask: u64,
1212    ) -> Option<u64> {
1213        let lane_id_v = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
1214
1215        let target = target & mask;
1216
1217        let mut prefix_state = crate::sha256::IV;
1218        crate::sha256::ingest_message_prefix(&mut prefix_state, self.message.challenge);
1219
1220        let remaining_limit = self.limit.min(u32::MAX as u64) as u32;
1221
1222        {
1223            let mut partial_state = prefix_state;
1224            crate::sha256::sha2_arx::<8>(&mut partial_state, &[self.message.high_word]);
1225
1226            for low_word in (0..=remaining_limit).step_by(16) {
1227                let mut state = core::array::from_fn(|i| _mm512_set1_epi32(partial_state[i] as _));
1228
1229                let mut msg = [
1230                    _mm512_set1_epi32(self.message.challenge[0] as _),
1231                    _mm512_set1_epi32(self.message.challenge[1] as _),
1232                    _mm512_set1_epi32(self.message.challenge[2] as _),
1233                    _mm512_set1_epi32(self.message.challenge[3] as _),
1234                    _mm512_set1_epi32(self.message.challenge[4] as _),
1235                    _mm512_set1_epi32(self.message.challenge[5] as _),
1236                    _mm512_set1_epi32(self.message.challenge[6] as _),
1237                    _mm512_set1_epi32(self.message.challenge[7] as _),
1238                    _mm512_set1_epi32(self.message.high_word as _),
1239                    _mm512_or_epi32(_mm512_set1_epi32(low_word as _), lane_id_v),
1240                    _mm512_set1_epi32(u32::from_be_bytes([0x80, 0, 0, 0]) as _),
1241                    _mm512_setzero_epi32(),
1242                    _mm512_setzero_epi32(),
1243                    _mm512_setzero_epi32(),
1244                    _mm512_setzero_epi32(),
1245                    _mm512_set1_epi32(Self::MSG_LEN as _),
1246                ];
1247                crate::sha256::avx512::multiway_arx::<9>(&mut state, &mut msg);
1248
1249                state[0] = _mm512_add_epi32(state[0], _mm512_set1_epi32(crate::sha256::IV[0] as _));
1250
1251                #[cfg(feature = "compare-64bit")]
1252                {
1253                    state[1] =
1254                        _mm512_add_epi32(state[1], _mm512_set1_epi32(crate::sha256::IV[1] as _));
1255                }
1256
1257                #[cfg(not(feature = "compare-64bit"))]
1258                let cmp_fn = |x: __m512i, y: __m512i| {
1259                    if TYPE == crate::solver::SOLVE_TYPE_GT {
1260                        _mm512_cmpgt_epu32_mask(x, y)
1261                    } else if TYPE == crate::solver::SOLVE_TYPE_LT {
1262                        _mm512_cmplt_epu32_mask(x, y)
1263                    } else {
1264                        _mm512_cmpeq_epu32_mask(
1265                            _mm512_and_si512(x, _mm512_set1_epi32((mask >> 32) as _)),
1266                            y,
1267                        )
1268                    }
1269                };
1270
1271                #[cfg(feature = "compare-64bit")]
1272                let cmp64_fn = |x: __m512i, y: __m512i| {
1273                    if TYPE == crate::solver::SOLVE_TYPE_GT {
1274                        _mm512_cmpgt_epu64_mask(x, y)
1275                    } else if TYPE == crate::solver::SOLVE_TYPE_LT {
1276                        _mm512_cmplt_epu64_mask(x, y)
1277                    } else {
1278                        _mm512_cmpeq_epu64_mask(
1279                            _mm512_and_si512(x, _mm512_set1_epi64(mask as _)),
1280                            y,
1281                        )
1282                    }
1283                };
1284
1285                #[cfg(not(feature = "compare-64bit"))]
1286                let met_target = cmp_fn(state[0], _mm512_set1_epi32((target >> 32) as _));
1287
1288                #[cfg(feature = "compare-64bit")]
1289                let result_ab_lo = _mm512_unpacklo_epi32(state[1], state[0]);
1290                #[cfg(feature = "compare-64bit")]
1291                let result_ab_hi = _mm512_unpackhi_epi32(state[1], state[0]);
1292                #[cfg(feature = "compare-64bit")]
1293                let (met_target_high, met_target_lo) = {
1294                    let ab_met_target_lo =
1295                        cmp64_fn(result_ab_lo, _mm512_set1_epi64(target as _)) as u16;
1296                    let ab_met_target_high =
1297                        cmp64_fn(result_ab_hi, _mm512_set1_epi64(target as _)) as u16;
1298                    (ab_met_target_high, ab_met_target_lo)
1299                };
1300
1301                #[cfg(feature = "compare-64bit")]
1302                let met_target_test = met_target_high != 0 || met_target_lo != 0;
1303                #[cfg(not(feature = "compare-64bit"))]
1304                let met_target_test = met_target != 0;
1305
1306                self.attempted_nonces += 16;
1307
1308                if met_target_test {
1309                    crate::unlikely();
1310
1311                    #[cfg(not(feature = "compare-64bit"))]
1312                    let success_lane_idx = met_target.trailing_zeros();
1313
1314                    #[cfg(feature = "compare-64bit")]
1315                    let success_lane_idx = INDEX_REMAP_PUNPCKLDQ
1316                        [(met_target_high << 8 | met_target_lo).trailing_zeros() as usize];
1317
1318                    let final_low_word = low_word | (success_lane_idx as u32);
1319
1320                    return Some((self.message.high_word as u64) << 32 | final_low_word as u64);
1321                }
1322
1323                if self.attempted_nonces >= self.limit {
1324                    return None;
1325                }
1326            }
1327        }
1328        None
1329    }
1330}
1331
1332impl crate::solver::Solver for GoAwaySolver {
1333    fn set_limit(&mut self, limit: u64) {
1334        self.limit = limit;
1335    }
1336
1337    fn get_attempted_nonces(&self) -> u64 {
1338        self.attempted_nonces
1339    }
1340
1341    #[inline(always)]
1342    fn solve_nonce_only<const TYPE: u8>(&mut self, target: u64, mask: u64) -> Option<u64> {
1343        unsafe { self.solve_nonce_only_impl::<TYPE>(target, mask) }
1344    }
1345
1346    fn solve<const TYPE: u8>(&mut self, target: u64, mask: u64) -> Option<(u64, [u32; 8])> {
1347        let mut output_msg = [0; 16];
1348        let nonce = self.solve_nonce_only::<TYPE>(target, mask)?;
1349        output_msg[..8].copy_from_slice(&self.message.challenge);
1350        output_msg[8] = (nonce >> 32) as u32;
1351        output_msg[9] = nonce as u32;
1352        output_msg[10] = u32::from_be_bytes([0x80, 0, 0, 0]);
1353        output_msg[15] = Self::MSG_LEN as _;
1354
1355        let mut final_sha_state = crate::sha256::IV;
1356        crate::sha256::digest_block(&mut final_sha_state, &output_msg);
1357
1358        Some((nonce, final_sha_state))
1359    }
1360}
1361
1362/// AVX-512 Cerberus solver.
1363///
1364/// Current implementation: 9-digit out-of-order kernel with dual-wavefront 16 way SIMD and quarter-round hotstart granularity.
1365pub struct CerberusSolver {
1366    message: CerberusMessage,
1367    attempted_nonces: u64,
1368    limit: u64,
1369}
1370
1371impl From<CerberusMessage> for CerberusSolver {
1372    fn from(message: CerberusMessage) -> Self {
1373        Self {
1374            message,
1375            attempted_nonces: 0,
1376            limit: !0,
1377        }
1378    }
1379}
1380
1381impl CerberusSolver {
1382    #[inline(never)]
1383    #[target_feature(enable = "avx512f")]
1384    fn solve_decimal_impl<
1385        const CENTER_WORD_IDX: usize,
1386        const LANE_ID_WORD_IDX: usize,
1387        const CONSTANT_WORD_COUNT: usize,
1388    >(
1389        &mut self,
1390        msg_tpl: Align64<[u32; 16]>,
1391        target: u64,
1392        mask: u64,
1393    ) -> Option<(u64, u64)> {
1394        debug_assert_eq!(target, 0);
1395
1396        let CerberusMessage::Decimal(message) = &self.message else {
1397            return None;
1398        };
1399
1400        // inform LLVM that padding is guaranteed to be zero
1401        let mut msg = Align64([0u32; 16]);
1402        msg.0[..=CENTER_WORD_IDX + 1].copy_from_slice(&msg_tpl.0[..=CENTER_WORD_IDX + 1]);
1403        let prepared_state = crate::blake3::ingest_message_prefix(
1404            *message.prefix_state,
1405            &msg[..CONSTANT_WORD_COUNT],
1406            0,
1407            message.salt_residual_len as u32 + 9,
1408            message.flags,
1409        );
1410
1411        for lane_id_idx in 0..(LANE_ID_STR_COMBINED_LE_HI.len() / 16) {
1412            if self.attempted_nonces >= self.limit {
1413                return None;
1414            }
1415            unsafe {
1416                let mut lane_id_value = _mm512_load_si512(
1417                    LANE_ID_STR_COMBINED_LE_HI
1418                        .as_ptr()
1419                        .add(lane_id_idx * 16)
1420                        .cast(),
1421                );
1422                if CENTER_WORD_IDX < LANE_ID_WORD_IDX {
1423                    lane_id_value = _mm512_srli_epi32(lane_id_value, 8);
1424                }
1425
1426                let state_base =
1427                    core::array::from_fn(|i| _mm512_set1_epi32(prepared_state[i] as _));
1428                let patch =
1429                    _mm512_or_epi32(_mm512_set1_epi32(msg[LANE_ID_WORD_IDX] as _), lane_id_value);
1430                let maskv = _mm512_set1_epi32((mask >> 32) as _);
1431
1432                for (i, word) in crate::strings::DIGIT_LUT_10000_LE_EVEN.iter().enumerate() {
1433                    msg[CENTER_WORD_IDX] = *word;
1434
1435                    let mut state = state_base;
1436
1437                    crate::blake3::avx512::compress_mb16::<CONSTANT_WORD_COUNT, LANE_ID_WORD_IDX>(
1438                        &mut state, &msg, patch,
1439                    );
1440
1441                    let s0 = state[0];
1442
1443                    msg[CENTER_WORD_IDX] |= u32::from_be_bytes([1, 0, 0, 0]);
1444
1445                    state = state_base;
1446                    crate::blake3::avx512::compress_mb16::<CONSTANT_WORD_COUNT, LANE_ID_WORD_IDX>(
1447                        &mut state, &msg, patch,
1448                    );
1449                    let s1 = state[0];
1450
1451                    let hit0 = _mm512_testn_epi32_mask(s0, maskv);
1452                    let hit1 = _mm512_testn_epi32_mask(s1, maskv);
1453
1454                    self.attempted_nonces += 32;
1455
1456                    if hit0 != 0 || hit1 != 0 {
1457                        crate::unlikely();
1458
1459                        let success_lane_idx0 = hit0.trailing_zeros();
1460                        let success_lane_idx1 = hit1.trailing_zeros();
1461
1462                        if success_lane_idx0 < success_lane_idx1 {
1463                            return Some((
1464                                i as u64 * 2,
1465                                lane_id_idx as u64 * 16 + success_lane_idx0 as u64,
1466                            ));
1467                        } else {
1468                            return Some((
1469                                i as u64 * 2 + 1,
1470                                lane_id_idx as u64 * 16 + success_lane_idx1 as u64,
1471                            ));
1472                        }
1473                    }
1474                }
1475            }
1476        }
1477        None
1478    }
1479
1480    #[inline(never)]
1481    #[target_feature(enable = "avx512f")]
1482    unsafe fn solve_binary_impl(&mut self, _target: u64, mask: u64) -> Option<u64> {
1483        let CerberusMessage::Binary(message) = &self.message else {
1484            return None;
1485        };
1486        let mut msg = [0; 16];
1487        msg[0] = message.first_word;
1488        let prepared_state = crate::blake3::ingest_message_prefix(
1489            *message.midstate,
1490            &msg[..1],
1491            0,
1492            8,
1493            crate::blake3::FLAG_CHUNK_END | crate::blake3::FLAG_ROOT,
1494        );
1495        let state_base = core::array::from_fn(|i| _mm512_set1_epi32(prepared_state[i] as _));
1496        let mut nonce = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
1497        let increment_nonce = _mm512_set1_epi32(16);
1498        let maskv = _mm512_set1_epi32((mask >> 32) as _);
1499        for rep in 0..=(u32::MAX / 16) {
1500            let mut state = state_base;
1501            crate::blake3::avx512::compress_mb16::<1, 1>(&mut state, &msg, nonce);
1502            self.attempted_nonces += 16;
1503
1504            let hit = _mm512_testn_epi32_mask(state[0], maskv);
1505            if hit != 0 {
1506                crate::unlikely();
1507
1508                let success_lane_idx = hit.trailing_zeros();
1509
1510                return Some(
1511                    (rep * 16 + success_lane_idx) as u64 | (message.first_word as u64) << 32,
1512                );
1513            }
1514            nonce = _mm512_add_epi32(nonce, increment_nonce);
1515            if self.attempted_nonces >= self.limit {
1516                return None;
1517            }
1518        }
1519        None
1520    }
1521}
1522
1523impl crate::solver::Solver for CerberusSolver {
1524    fn set_limit(&mut self, limit: u64) {
1525        self.limit = limit;
1526    }
1527
1528    fn get_attempted_nonces(&self) -> u64 {
1529        self.attempted_nonces
1530    }
1531
1532    fn solve_nonce_only<const TYPE: u8>(&mut self, target: u64, mask: u64) -> Option<u64> {
1533        match &self.message {
1534            CerberusMessage::Binary(_) => unsafe { self.solve_binary_impl(target, mask) },
1535            CerberusMessage::Decimal(message) => {
1536                // two digits as lane ID, N=\x00, ? is prefix
1537                // position % 4 =0: |1234|5678|NNN9
1538                // position % 4 =1: |123?|4567|NN89
1539                // position % 4 =2: |12??|3456|N789
1540                // position % 4 =3: |1???|2345|6789
1541
1542                let center_word_idx = message.salt_residual_len / 4 + 1;
1543                let position_mod = message.salt_residual_len % 4;
1544                let nonce_addend = message.nonce_addend;
1545                let salt_residual = message.salt_residual;
1546                let salt_residual_len = message.salt_residual_len;
1547
1548                for resid0 in 0..10u64 {
1549                    for resid1 in 0..10u64 {
1550                        if self.attempted_nonces >= self.limit {
1551                            return None;
1552                        }
1553                        let mut msg = salt_residual;
1554
1555                        match position_mod {
1556                            0 => {
1557                                msg[salt_residual_len] = resid0 as u8 + b'0';
1558                                msg[salt_residual_len + 8] = resid1 as u8 + b'0';
1559                            }
1560                            1 => {
1561                                msg[salt_residual_len + 7] = resid0 as u8 + b'0';
1562                                msg[salt_residual_len + 8] = resid1 as u8 + b'0';
1563                            }
1564                            2 => {
1565                                msg[salt_residual_len] = resid0 as u8 + b'0';
1566                                msg[salt_residual_len + 1] = resid1 as u8 + b'0';
1567                            }
1568                            3 => {
1569                                msg[salt_residual_len] = resid0 as u8 + b'0';
1570                                msg[salt_residual_len + 8] = resid1 as u8 + b'0';
1571                            }
1572                            _ => unreachable!(),
1573                        }
1574
1575                        let msg = Align64(core::array::from_fn(|i| {
1576                            u32::from_le_bytes([
1577                                msg[i * 4],
1578                                msg[i * 4 + 1],
1579                                msg[i * 4 + 2],
1580                                msg[i * 4 + 3],
1581                            ])
1582                        }));
1583
1584                        macro_rules! dispatch {
1585                            ($center_word_idx:literal) => {
1586                                unsafe {
1587                                    if position_mod < 2 {
1588                                        self.solve_decimal_impl::<$center_word_idx, { $center_word_idx - 1 }, {$center_word_idx - 1}>(
1589                                            msg, target, mask,
1590                                        )
1591                                    } else {
1592                                        self.solve_decimal_impl::<$center_word_idx, { $center_word_idx + 1 }, $center_word_idx>(
1593                                            msg, target, mask,
1594                                        )
1595                                    }
1596                                }
1597                            };
1598                        }
1599
1600                        if let Some((middle_word, success_lane_idx)) = match center_word_idx {
1601                            1 => dispatch!(1),
1602                            2 => dispatch!(2),
1603                            3 => dispatch!(3),
1604                            4 => dispatch!(4),
1605                            5 => dispatch!(5),
1606                            6 => dispatch!(6),
1607                            7 => dispatch!(7),
1608                            8 => dispatch!(8),
1609                            9 => dispatch!(9),
1610                            10 => dispatch!(10),
1611                            11 => dispatch!(11),
1612                            12 => dispatch!(12),
1613                            13 => dispatch!(13),
1614                            14 => dispatch!(14),
1615                            15 => dispatch!(15),
1616                            _ => unreachable!(),
1617                        } {
1618                            let output_nonce = nonce_addend
1619                                + match position_mod {
1620                                    0 => {
1621                                        10 * middle_word
1622                                            + 100_000 * success_lane_idx
1623                                            + 100_000_000 * resid0
1624                                            + resid1
1625                                    }
1626                                    1 => {
1627                                        100 * middle_word
1628                                            + 1_000_000 * success_lane_idx
1629                                            + 10 * resid0
1630                                            + resid1
1631                                    }
1632                                    2 => {
1633                                        1000 * middle_word
1634                                            + success_lane_idx
1635                                            + 100_000_000 * resid0
1636                                            + 10_000_000 * resid1
1637                                    }
1638                                    3 => {
1639                                        10000 * middle_word
1640                                            + 10 * success_lane_idx
1641                                            + 100_000_000 * resid0
1642                                            + resid1
1643                                    }
1644                                    _ => unreachable!(),
1645                                };
1646
1647                            return Some(output_nonce);
1648                        }
1649                    }
1650                }
1651                None
1652            }
1653        }
1654    }
1655
1656    fn solve<const TYPE: u8>(&mut self, target: u64, mask: u64) -> Option<(u64, [u32; 8])> {
1657        if let Some(nonce) = self.solve_nonce_only::<TYPE>(target, mask) {
1658            match &self.message {
1659                CerberusMessage::Decimal(message) => {
1660                    let mut msg = message.salt_residual;
1661
1662                    let mut nonce_copy = nonce;
1663                    for i in (0..9).rev() {
1664                        msg[message.salt_residual_len + i] = (nonce_copy % 10) as u8 + b'0';
1665                        nonce_copy /= 10;
1666                    }
1667
1668                    let mut msg = core::array::from_fn(|i| {
1669                        u32::from_le_bytes([
1670                            msg[i * 4],
1671                            msg[i * 4 + 1],
1672                            msg[i * 4 + 2],
1673                            msg[i * 4 + 3],
1674                        ])
1675                    });
1676
1677                    let hash = crate::blake3::compress8(
1678                        &message.prefix_state,
1679                        &mut msg,
1680                        0,
1681                        message.salt_residual_len as u32 + 9,
1682                        message.flags,
1683                    );
1684
1685                    Some((nonce, hash))
1686                }
1687                CerberusMessage::Binary(message) => {
1688                    let mut msg = [0; 16];
1689                    msg[0] = message.first_word;
1690                    msg[1] = nonce as u32;
1691                    let hash = crate::blake3::compress8(
1692                        &message.midstate,
1693                        &msg,
1694                        0,
1695                        8,
1696                        crate::blake3::FLAG_CHUNK_END | crate::blake3::FLAG_ROOT,
1697                    );
1698                    Some((msg[1] as u64 | (msg[0] as u64) << 32, hash))
1699                }
1700            }
1701        } else {
1702            None
1703        }
1704    }
1705}
1706
// These tests are compiled only when the build itself enables AVX-512F.
#[cfg(target_feature = "avx512f")]
#[cfg(test)]
mod tests {
    use crate::message::{CerberusBinaryMessage, CerberusDecimalMessage};

    use super::*;

    /// Runs the shared Cerberus decimal validator with the second constructor argument
    /// set to 0 and then to 1.
    #[test]
    fn test_solve_cerberus_decimal() {
        for n in [0, 1] {
            crate::solver::tests::test_cerberus_decimal_validator::<CerberusSolver, _>(|prefix| {
                let message = CerberusDecimalMessage::new(prefix, n)?;
                Some(CerberusMessage::Decimal(message).into())
            });
        }
    }

    /// Runs the shared Cerberus binary validator with the second constructor argument
    /// set to 0 and then to 1.
    #[test]
    fn test_solve_cerberus_binary() {
        for n in [0, 1] {
            crate::solver::tests::test_cerberus_binary_validator::<CerberusSolver, _>(|prefix| {
                let message = CerberusBinaryMessage::new(prefix, n);
                Some(CerberusMessage::Binary(message).into())
            });
        }
    }

    /// Runs the shared decimal validator, using a single-block message when one can be
    /// built and a double-block message otherwise.
    #[test]
    fn test_solve_decimal() {
        crate::solver::tests::test_decimal_validator::<DecimalSolver, _>(|prefix, search_space| {
            match SingleBlockMessage::new(prefix, search_space) {
                Some(single) => Some(DecimalSolver::SingleBlock(single.into())),
                None => DoubleBlockMessage::new(prefix, search_space).map(Into::into),
            }
        });
    }

    /// Same as `test_solve_decimal`, but goes through the `new_f64` constructor and passes
    /// its extra value on to the validator (`None` for the double-block fallback).
    #[test]
    fn test_solve_decimal_f64() {
        crate::solver::tests::test_decimal_validator_f64_safe::<DecimalSolver, _>(
            |prefix, search_space| match SingleBlockMessage::new_f64(prefix, search_space) {
                Some((single, p)) => Some((DecimalSolver::SingleBlock(single.into()), p)),
                None => DoubleBlockMessage::new(prefix, search_space)
                    .map(|double| (DecimalSolver::DoubleBlock(double.into()), None)),
            },
        );
    }

    /// Runs the shared binary validator against `BinarySolver`.
    #[test]
    fn test_solve_binary() {
        crate::solver::tests::test_binary_validator::<BinarySolver, _>(
            |prefix, nonce_byte_count| {
                let message = BinaryMessage::new(prefix, nonce_byte_count);
                BinarySolver::from(message)
            },
        );
    }

    /// Runs the shared GoAway validator, packing the prefix bytes into big-endian words.
    #[test]
    fn test_solve_goaway() {
        crate::solver::tests::test_goaway_validator::<GoAwaySolver, _>(|prefix| {
            let words = core::array::from_fn(|i| {
                let at = i * 4;
                u32::from_be_bytes([prefix[at], prefix[at + 1], prefix[at + 2], prefix[at + 3]])
            });
            GoAwaySolver::from(GoAwayMessage::new(words, 0))
        });
    }
}