dup_crypto/
scrypt.rs

1//  Copyright (C) 2020  Éloïs SANCHEZ.
2//
3// This program is free software: you can redistribute it and/or modify
4// it under the terms of the GNU Affero General Public License as
5// published by the Free Software Foundation, either version 3 of the
6// License, or (at your option) any later version.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11// GNU Affero General Public License for more details.
12//
13// You should have received a copy of the GNU Affero General Public License
14// along with this program.  If not, see <https://www.gnu.org/licenses/>.
15
//! Manage cryptographic operations for DUniter Protocols and the Duniter eco-system more broadly.
17//!
18//! Scrypt
19
20pub mod params;
21
22use std::{iter::repeat, mem::MaybeUninit, num::NonZeroU32, ptr};
23
/// Iteration count handed to PBKDF2 by both `pbkdf2` implementations below.
///
/// Built with the checked constructor so that no `unsafe` is needed; the
/// `None` arm can only be reached if the literal is changed to zero, in which
/// case compilation fails instead of producing an invalid `NonZeroU32`.
const ITERATIONS: NonZeroU32 = match NonZeroU32::new(1) {
    Some(iterations) => iterations,
    None => panic!("PBKDF2 iteration count must be non-zero"),
};
25
/// PBKDF2-HMAC-SHA256 for `wasm32` targets, backed by `cryptoxide`.
///
/// Fills the whole of `output` with key material derived from `password`
/// and `salt`, using `ITERATIONS` rounds.
#[cfg(target_arch = "wasm32")]
#[cfg(not(tarpaulin_include))]
fn pbkdf2(salt: &[u8], password: &[u8], output: &mut [u8]) {
    let hasher = cryptoxide::sha2::Sha256::new();
    let mut mac = cryptoxide::hmac::Hmac::new(hasher, password);
    cryptoxide::pbkdf2::pbkdf2(&mut mac, salt, ITERATIONS.get(), output);
}
32#[cfg(not(target_arch = "wasm32"))]
33fn pbkdf2(salt: &[u8], password: &[u8], output: &mut [u8]) {
34    ring::pbkdf2::derive(
35        ring::pbkdf2::PBKDF2_HMAC_SHA256,
36        ITERATIONS,
37        salt,
38        password,
39        output,
40    );
41}
42
43/**
44 * The scrypt key derivation function.
45 *
46 * # Arguments
47 *
48 * * `password` - The password to process as a byte vector
49 * * `salt` - The salt value to use as a byte vector
50 * * `params` - The `ScryptParams` to use
51 * * `output` - The resulting derived key is returned in this byte vector.
52 *
53 */
54pub fn scrypt(password: &[u8], salt: &[u8], params: &params::ScryptParams, output: &mut [u8]) {
55    // This check required by Scrypt:
56    // check output.len() > 0 && output.len() <= (2^32 - 1) * 32
57    assert!(!output.is_empty());
58    assert!(output.len() / 32 <= 0xffff_ffff);
59
60    // The checks in the ScryptParams constructor guarantee that the following is safe:
61    let n = 1 << params.log_n;
62    let r128 = (params.r as usize) * 128;
63    let pr128 = (params.p as usize) * r128;
64    let nr128 = n * r128;
65
66    let mut b: Vec<u8> = repeat(0).take(pr128).collect();
67
68    pbkdf2(salt, password, b.as_mut_slice());
69
70    let mut v: Vec<u8> = repeat(0).take(nr128).collect();
71    let mut t: Vec<u8> = repeat(0).take(r128).collect();
72
73    for chunk in b.as_mut_slice().chunks_mut(r128) {
74        scrypt_ro_mix(chunk, v.as_mut_slice(), t.as_mut_slice(), n);
75    }
76
77    pbkdf2(&b, password, output);
78}
79
80// Execute the ROMix operation in-place.
81// b - the data to operate on
82// v - a temporary variable to store the vector V
83// t - a temporary variable to store the result of the xor
84// n - the scrypt parameter N
85#[allow(clippy::many_single_char_names)]
86fn scrypt_ro_mix(b: &mut [u8], v: &mut [u8], t: &mut [u8], n: usize) {
87    fn integerify(x: &[u8], n: usize) -> usize {
88        // n is a power of 2, so n - 1 gives us a bitmask that we can use to perform a calculation
89        // mod n using a simple bitwise and.
90        let mask = n - 1;
91        // This cast is safe since we're going to get the value mod n (which is a power of 2), so we
92        // don't have to care about truncating any of the high bits off
93        (read_u32_le(&x[x.len() - 64..x.len() - 60]) as usize) & mask
94    }
95
96    let len = b.len();
97
98    for chunk in v.chunks_mut(len) {
99        copy_memory(b, chunk);
100        scrypt_block_mix(chunk, b);
101    }
102
103    for _ in 0..n {
104        let j = integerify(b, n);
105        xor(b, &v[j * len..(j + 1) * len], t);
106        scrypt_block_mix(t, b);
107    }
108}
109
/// Copy all of `src` into the first `src.len()` bytes of `dst`.
///
/// Bytes of `dst` past `src.len()` are left untouched.
///
/// # Panics
///
/// Panics if `dst` is shorter than `src`.
#[inline]
fn copy_memory(src: &[u8], dst: &mut [u8]) {
    assert!(dst.len() >= src.len());
    // Safe equivalent of the former raw-pointer copy; the two slices cannot
    // overlap because one is borrowed shared and the other exclusively.
    dst[..src.len()].copy_from_slice(src);
}
120
/// Decode a little-endian `u32` from a slice of exactly 4 bytes.
///
/// # Panics
///
/// Panics if `input.len() != 4`.
fn read_u32_le(input: &[u8]) -> u32 {
    assert_eq!(input.len(), 4);
    // A zero-initialised buffer replaces the former
    // `MaybeUninit::uninit().assume_init()`, which is undefined behaviour for
    // integer arrays.
    let mut bytes = [0u8; 4];
    bytes.copy_from_slice(input);
    u32::from_le_bytes(bytes)
}
130
/// Decode `input` as consecutive little-endian `u32` words into `dst`.
///
/// Word `i` of `dst` is read from bytes `4 * i .. 4 * i + 4` of `input`.
/// Empty slices are accepted and leave nothing to do.
///
/// # Panics
///
/// Panics if `input.len() != dst.len() * 4`.
fn read_u32v_le(dst: &mut [u32], input: &[u8]) {
    assert_eq!(dst.len() * 4, input.len());
    // Safe iteration replaces the former raw-pointer walk, which used an
    // uninitialised temporary and `get_unchecked(0)` even on empty slices.
    for (word, chunk) in dst.iter_mut().zip(input.chunks_exact(4)) {
        let mut bytes = [0u8; 4];
        bytes.copy_from_slice(chunk);
        *word = u32::from_le_bytes(bytes);
    }
}
146
/// Encode `input` as 4 little-endian bytes into `dst`.
///
/// # Panics
///
/// Panics if `dst.len() != 4`.
fn write_u32_le(dst: &mut [u8], input: u32) {
    assert_eq!(dst.len(), 4);
    // `to_le_bytes` yields the same bytes as the former `to_le` + `transmute`,
    // without any `unsafe`.
    dst.copy_from_slice(&input.to_le_bytes());
}
155
/// Write the byte-wise xor of `x` and `y` into `output`.
///
/// Only as many bytes as the shortest of the three slices are processed;
/// anything beyond that in `output` is left untouched.
fn xor(x: &[u8], y: &[u8], output: &mut [u8]) {
    x.iter()
        .zip(y)
        .zip(output.iter_mut())
        .for_each(|((&lhs, &rhs), out)| *out = lhs ^ rhs);
}
161
162// Execute the BlockMix operation
163// input - the input vector. The length must be a multiple of 128.
164// output - the output vector. Must be the same length as input.
165fn scrypt_block_mix(input: &[u8], output: &mut [u8]) {
166    let mut x = [0u8; 64];
167    copy_memory(&input[input.len() - 64..], &mut x);
168
169    let mut t = [0u8; 64];
170
171    for (i, chunk) in input.chunks(64).enumerate() {
172        xor(&x, chunk, &mut t);
173        salsa20_8(&t, &mut x);
174        let pos = if i % 2 == 0 {
175            (i / 2) * 64
176        } else {
177            (i / 2) * 64 + input.len() / 2
178        };
179        copy_memory(&x, &mut output[pos..pos + 64]);
180    }
181}
182
183// The salsa20/8 core function.
184fn salsa20_8(input: &[u8], output: &mut [u8]) {
185    let mut x = [0u32; 16];
186    read_u32v_le(&mut x, input);
187
188    let rounds = 8;
189
190    macro_rules! run_round (
191        ($($set_idx:expr, $idx_a:expr, $idx_b:expr, $rot:expr);*) => { {
192            $( x[$set_idx] ^= x[$idx_a].wrapping_add(x[$idx_b]).rotate_left($rot); )*
193        } }
194    );
195
196    for _ in 0..rounds / 2 {
197        run_round!(
198            0x4, 0x0, 0xc, 7;
199            0x8, 0x4, 0x0, 9;
200            0xc, 0x8, 0x4, 13;
201            0x0, 0xc, 0x8, 18;
202            0x9, 0x5, 0x1, 7;
203            0xd, 0x9, 0x5, 9;
204            0x1, 0xd, 0x9, 13;
205            0x5, 0x1, 0xd, 18;
206            0xe, 0xa, 0x6, 7;
207            0x2, 0xe, 0xa, 9;
208            0x6, 0x2, 0xe, 13;
209            0xa, 0x6, 0x2, 18;
210            0x3, 0xf, 0xb, 7;
211            0x7, 0x3, 0xf, 9;
212            0xb, 0x7, 0x3, 13;
213            0xf, 0xb, 0x7, 18;
214            0x1, 0x0, 0x3, 7;
215            0x2, 0x1, 0x0, 9;
216            0x3, 0x2, 0x1, 13;
217            0x0, 0x3, 0x2, 18;
218            0x6, 0x5, 0x4, 7;
219            0x7, 0x6, 0x5, 9;
220            0x4, 0x7, 0x6, 13;
221            0x5, 0x4, 0x7, 18;
222            0xb, 0xa, 0x9, 7;
223            0x8, 0xb, 0xa, 9;
224            0x9, 0x8, 0xb, 13;
225            0xa, 0x9, 0x8, 18;
226            0xc, 0xf, 0xe, 7;
227            0xd, 0xc, 0xf, 9;
228            0xe, 0xd, 0xc, 13;
229            0xf, 0xe, 0xd, 18
230        )
231    }
232
233    for i in 0..16 {
234        write_u32_le(
235            &mut output[i * 4..(i + 1) * 4],
236            x[i].wrapping_add(read_u32_le(&input[i * 4..(i + 1) * 4])),
237        );
238    }
239}
240
#[cfg(test)]
mod tests {

    use super::*;

    // Salt and password shared by every test below.
    const SALT: &str = "JhxtHB7UcsDbA9wMSyMKXUzBZUQvqVyB32KwzS9SWoLkjrUhHV";
    const PASSWORD: &str = "JhxtHB7UcsDbA9wMSyMKXUzBZUQvqVyB32KwzS9SWoLkjrUhHV_";

    // First 32 bytes derived from the inputs above with the default parameters.
    const EXPECTED_SEED: [u8; 32] = [
        144, 5, 70, 118, 105, 47, 173, 220, 39, 152, 105, 7, 211, 34, 125, 146, 183, 172, 41, 54,
        154, 134, 125, 97, 1, 125, 241, 95, 96, 6, 79, 150,
    ];

    // Fill `seed` using the default scrypt parameters and print how long it took.
    fn derive_timed(seed: &mut [u8]) {
        let started = std::time::Instant::now();
        scrypt(
            PASSWORD.as_bytes(),
            SALT.as_bytes(),
            &params::ScryptParams::default(),
            seed,
        );
        println!("{} ms", started.elapsed().as_millis());
    }

    #[test]
    fn test_scrypt_ring() {
        let mut seed = [0u8; 32];
        derive_timed(&mut seed);

        assert_eq!(seed, EXPECTED_SEED);
    }

    // A 26-byte output must equal the first 26 bytes of the 32-byte seed.
    #[test]
    fn test_scrypt_short_seed() {
        let mut seed = [0u8; 26];
        derive_timed(&mut seed);

        assert_eq!(seed[..], EXPECTED_SEED[..26]);
    }

    // A 68-byte output must start with the 32-byte seed.
    #[test]
    fn test_scrypt_long_seed() {
        let mut seed = [0u8; 68];
        derive_timed(&mut seed);

        assert_eq!(seed[..32], EXPECTED_SEED);

        println!("{:?}", seed);
    }
}