//! kerl.rs — Kerl sponge construction (Keccak-384 over balanced ternary).
1use crate::keccak::Keccak;
2use crate::Sponge;
3use crate::constants::*;
4
/// Kerl: a ternary sponge construction wrapped around Keccak-384.
///
/// All hashing state lives in the inner `Keccak`; `Kerl` only adds the
/// trit <-> byte conversions at the absorb/squeeze boundary.
#[derive(Clone, Copy)]
pub struct Kerl(Keccak);
7
8impl Default for Kerl {
9    fn default() -> Kerl {
10        Kerl(Keccak::new_keccak384())
11    }
12}
13
impl Sponge for Kerl
where
    // NOTE(review): this bound is trivially satisfied by `Kerl` itself;
    // presumably it documents a requirement of the `Sponge` consumers —
    // confirm against the trait's users.
    Self: Send + 'static,
{
    type Item = Trit;

    /// Absorbs `trits` into the sponge, one `TRIT_LENGTH`-trit chunk at a
    /// time. Each chunk is converted to its `BYTE_LENGTH`-byte form and fed
    /// to the underlying Keccak state.
    ///
    /// # Panics
    /// Panics if `trits.len()` is not a multiple of `TRIT_LENGTH`.
    fn absorb(&mut self, trits: &[Self::Item]) {
        assert_eq!(trits.len() % TRIT_LENGTH, 0);
        let mut bytes: [u8; BYTE_LENGTH] = [0; BYTE_LENGTH];

        for chunk in trits.chunks(TRIT_LENGTH) {
            trits_to_bytes(chunk, &mut bytes);
            self.0.update(&bytes);
        }
    }

    /// Squeezes trits out of the sponge, one `TRIT_LENGTH`-trit chunk at a
    /// time. After each extracted digest, the state is reset and the
    /// bit-flipped digest is re-absorbed — Kerl's inter-chunk chaining step.
    ///
    /// # Panics
    /// Panics if `out.len()` is not a multiple of `TRIT_LENGTH`.
    fn squeeze(&mut self, out: &mut [Self::Item]) {
        assert_eq!(out.len() % TRIT_LENGTH, 0);
        let mut bytes: [u8; BYTE_LENGTH] = [0; BYTE_LENGTH];

        for chunk in out.chunks_mut(TRIT_LENGTH) {
            // Finalize the current permutation and read out one digest.
            // (The exact pad/fill_block semantics live in crate::keccak.)
            self.0.pad();
            self.0.fill_block();
            self.0.squeeze(&mut bytes);
            // Start from a fresh Keccak state before chaining.
            self.reset();
            // Deliberately pass a temporary copy: bytes_to_trits reverses
            // its input in place, and the un-reversed digest is still
            // needed for the flip-and-reabsorb step below.
            bytes_to_trits(&mut bytes.to_vec(), chunk);
            // Flip every bit of the digest and absorb it so subsequent
            // squeezes produce fresh output.
            for b in bytes.iter_mut() {
                *b = *b ^ 0xFF;
            }
            self.0.update(&bytes);
        }
    }

    /// Discards all absorbed data by replacing the inner Keccak-384 state.
    fn reset(&mut self) {
        self.0 = Keccak::new_keccak384();
    }
}
51
52fn trits_to_bytes(trits: &[Trit], bytes: &mut [u8]) {
53    assert_eq!(trits.len(), TRIT_LENGTH);
54    assert_eq!(bytes.len(), BYTE_LENGTH);
55
56    // We _know_ that the sizes match.
57    // So this is safe enough to do and saves us a few allocations.
58    let base: &mut [u32] =
59        unsafe { core::slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut u32, 12) };
60
61    base.clone_from_slice(&[0; 12]);
62
63    let mut size = 1;
64    let mut all_minus_1 = true;
65
66    for t in trits[0..TRIT_LENGTH - 1].iter() {
67        if *t != -1 {
68            all_minus_1 = false;
69            break;
70        }
71    }
72
73    if all_minus_1 {
74        base.clone_from_slice(&HALF_3);
75        bigint_not(base);
76        bigint_add_small(base, 1_u32);
77    } else {
78        for t in trits[0..TRIT_LENGTH - 1].iter().rev() {
79            // multiply by radix
80            {
81                let sz = size;
82                let mut carry: u32 = 0;
83
84                for j in 0..sz {
85                    let v = (base[j] as u64) * (RADIX as u64) + (carry as u64);
86                    let (newcarry, newbase) = ((v >> 32) as u32, v as u32);
87                    carry = newcarry;
88                    base[j] = newbase;
89                }
90
91                if carry > 0 {
92                    base[sz] = carry;
93                    size += 1;
94                }
95            }
96
97            let trit = (t + 1) as u32;
98            // addition
99            {
100                let sz = bigint_add_small(base, trit);
101                if sz > size {
102                    size = sz;
103                }
104            }
105        }
106
107        if !is_null(base) {
108            if bigint_cmp(&HALF_3, base) <= 0 {
109                // base >= HALF_3
110                // just do base - HALF_3
111                bigint_sub(base, &HALF_3);
112            } else {
113                // we don't have a wrapping sub.
114                // so let's use some bit magic to achieve it
115                let mut tmp = HALF_3.clone();
116                bigint_sub(&mut tmp, base);
117                bigint_not(&mut tmp);
118                bigint_add_small(&mut tmp, 1_u32);
119                base.clone_from_slice(&tmp);
120            }
121        }
122    }
123
124    bytes.reverse();
125}
126
127    /// This will consume the input bytes slice and write to trits.
128fn bytes_to_trits(bytes: &mut [u8], trits: &mut [Trit]) {
129    assert_eq!(bytes.len(), BYTE_LENGTH);
130    assert_eq!(trits.len(), TRIT_LENGTH);
131
132    trits[TRIT_LENGTH - 1] = 0;
133
134    bytes.reverse();
135    // We _know_ that the sizes match.
136    // So this is safe enough to do and saves us a few allocations.
137    let base: &mut [u32] =
138        unsafe { core::slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut u32, 12) };
139
140    if is_null(base) {
141        trits.clone_from_slice(&[0; TRIT_LENGTH]);
142        return;
143    }
144
145    let mut flip_trits = false;
146
147    if base[INT_LENGTH - 1] >> 31 == 0 {
148        // positive number
149        // we need to add HALF_3 to move it into positvie unsigned space
150        bigint_add(base, &HALF_3);
151    } else {
152        // negative number
153        bigint_not(base);
154        if bigint_cmp(base, &HALF_3) > 0 {
155            bigint_sub(base, &HALF_3);
156            flip_trits = true;
157        } else {
158            bigint_add_small(base, 1 as u32);
159            let mut tmp = HALF_3.clone();
160            bigint_sub(&mut tmp, base);
161            base.clone_from_slice(&tmp);
162        }
163    }
164
165    let mut rem;
166    for i in 0..TRIT_LENGTH - 1 {
167        rem = 0;
168        for j in (0..INT_LENGTH).rev() {
169            let lhs = ((rem as u64) << 32) | (base[j] as u64);
170            let rhs = RADIX as u64;
171            let q = (lhs / rhs) as u32;
172            let r = (lhs % rhs) as u32;
173
174            base[j] = q;
175            rem = r;
176        }
177        trits[i] = rem as i8 - 1;
178    }
179
180    if flip_trits {
181        for v in trits.iter_mut() {
182            *v = -*v;
183        }
184    }
185}
186
/// Bitwise-negates every limb of `base` in place (one's complement).
fn bigint_not(base: &mut [u32]) {
    base.iter_mut().for_each(|limb| *limb = !*limb);
}
192
193fn bigint_add_small(base: &mut [u32], other: u32) -> usize {
194    let (mut carry, v) = full_add(base[0], other, false);
195    base[0] = v;
196
197    let mut i = 1;
198    while carry {
199        let (c, v) = full_add(base[i], 0, carry);
200        base[i] = v;
201        carry = c;
202        i += 1;
203    }
204
205    i
206}
207
208fn bigint_add(base: &mut [u32], rh: &[u32]) {
209    let mut carry = false;
210
211    for (a, b) in base.iter_mut().zip(rh.iter()) {
212        let (c, v) = full_add(*a, *b, carry);
213        *a = v;
214        carry = c;
215    }
216}
217
/// Lexicographically compares two little-endian big integers, starting at
/// the most significant limb. Returns -1, 0, or 1.
fn bigint_cmp(lh: &[u32], rh: &[u32]) -> i8 {
    for (a, b) in lh.iter().rev().zip(rh.iter().rev()) {
        match a.cmp(b) {
            core::cmp::Ordering::Less => return -1,
            core::cmp::Ordering::Greater => return 1,
            core::cmp::Ordering::Equal => {}
        }
    }
    0
}
228
229fn bigint_sub(base: &mut [u32], rh: &[u32]) {
230    let mut noborrow = true;
231    for (a, b) in base.iter_mut().zip(rh) {
232        let (c, v) = full_add(*a, !*b, noborrow);
233        *a = v;
234        noborrow = c;
235    }
236    assert!(noborrow);
237}
238
/// Returns true when every limb of `base` is zero (vacuously true for an
/// empty slice).
fn is_null(base: &[u32]) -> bool {
    base.iter().all(|limb| *limb == 0)
}
247
/// One-limb full adder: computes `lh + rh + carry` and returns
/// `(carry_out, sum)`. The two partial additions cannot both overflow, so a
/// simple OR of the flags is the correct carry-out.
fn full_add(lh: u32, rh: u32, carry: bool) -> (bool, u32) {
    let (partial, overflow_a) = lh.overflowing_add(rh);
    let (sum, overflow_b) = partial.overflowing_add(carry as u32);
    (overflow_a || overflow_b, sum)
}
265}