bee_crypto/ternary/sponge/curlp/batched/
bct_curlp.rs

// Copyright 2020-2021 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0
3
4use crate::ternary::{
5    sponge::{
6        curlp::batched::{
7            bct::{BcTrit, BcTritArr, BcTrits},
8            HIGH_BITS,
9        },
10        CurlPRounds,
11    },
12    HASH_LENGTH,
13};
14
15pub(crate) struct BctCurlP {
16    rounds: CurlPRounds,
17    state: BcTritArr<{ 3 * HASH_LENGTH }>,
18    scratch_pad: BcTritArr<{ 3 * HASH_LENGTH }>,
19}
20
21impl BctCurlP {
22    #[allow(clippy::assertions_on_constants)]
23    pub(crate) fn new(rounds: CurlPRounds) -> Self {
24        // Ensure that changing the hash length will not cause undefined behaviour.
25        assert!(3 * HASH_LENGTH > 728);
26        Self {
27            rounds,
28            state: BcTritArr::filled(HIGH_BITS),
29            scratch_pad: BcTritArr::filled(HIGH_BITS),
30        }
31    }
32
33    pub(crate) fn reset(&mut self) {
34        self.state.fill(HIGH_BITS);
35    }
36
37    pub(crate) fn transform(&mut self) {
38        let mut scratch_pad_index = 0;
39
40        // All the unchecked accesses here are guaranteed to be safe by the assertion inside `new`.
41        for _round in 0..self.rounds as usize {
42            self.scratch_pad.copy_from_slice(&self.state);
43
44            let BcTrit(mut alpha, mut beta) = unsafe { *self.scratch_pad.get_unchecked(scratch_pad_index) };
45
46            scratch_pad_index += 364;
47
48            let mut temp = unsafe { *self.scratch_pad.get_unchecked(scratch_pad_index) };
49
50            let delta = beta ^ temp.lo();
51
52            *unsafe { self.state.get_unchecked_mut(0) } = BcTrit(!(delta & alpha), delta | (alpha ^ temp.hi()));
53
54            let mut state_index = 1;
55
56            while state_index < self.state.len() {
57                scratch_pad_index += 364;
58
59                alpha = temp.lo();
60                beta = temp.hi();
61                temp = unsafe { *self.scratch_pad.get_unchecked(scratch_pad_index) };
62
63                let delta = beta ^ temp.lo();
64
65                *unsafe { self.state.get_unchecked_mut(state_index) } =
66                    BcTrit(!(delta & alpha), delta | (alpha ^ temp.hi()));
67
68                state_index += 1;
69
70                scratch_pad_index -= 365;
71
72                alpha = temp.lo();
73                beta = temp.hi();
74                temp = unsafe { *self.scratch_pad.get_unchecked(scratch_pad_index) };
75
76                let delta = beta ^ temp.lo();
77
78                *unsafe { self.state.get_unchecked_mut(state_index) } =
79                    BcTrit(!(delta & alpha), delta | (alpha ^ temp.hi()));
80
81                state_index += 1;
82            }
83        }
84    }
85
86    pub(crate) fn absorb(&mut self, bc_trits: &BcTrits) {
87        let mut length = bc_trits.len();
88        let mut offset = 0;
89
90        loop {
91            let length_to_copy = if length < HASH_LENGTH { length } else { HASH_LENGTH };
92            // This is safe as `length_to_copy <= HASH_LENGTH`.
93            unsafe { self.state.get_unchecked_mut(0..length_to_copy) }
94                .copy_from_slice(unsafe { bc_trits.get_unchecked(offset..offset + length_to_copy) });
95
96            self.transform();
97
98            if length <= length_to_copy {
99                break;
100            } else {
101                offset += length_to_copy;
102                length -= length_to_copy;
103            }
104        }
105    }
106
107    // This method shouldn't assume that `result` has any particular content, just that it has an
108    // adequate size.
109    pub(crate) fn squeeze_into(&mut self, result: &mut BcTrits) {
110        let trit_count = result.len();
111
112        let hash_count = trit_count / HASH_LENGTH;
113
114        for i in 0..hash_count {
115            unsafe { result.get_unchecked_mut(i * HASH_LENGTH..(i + 1) * HASH_LENGTH) }
116                .copy_from_slice(unsafe { self.state.get_unchecked(0..HASH_LENGTH) });
117
118            self.transform();
119        }
120
121        let last = trit_count - hash_count * HASH_LENGTH;
122
123        unsafe { result.get_unchecked_mut(trit_count - last..trit_count) }
124            .copy_from_slice(unsafe { self.state.get_unchecked(0..last) });
125
126        if trit_count % HASH_LENGTH != 0 {
127            self.transform();
128        }
129    }
130}