1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
//! Implementation of [Firebase Scrypt](https://github.com/firebase/scrypt) in pure Rust.
//!
//! If you are only using the raw functions instead of the higher-level struct [`FirebaseScrypt`]
//! (which requires the `simple` feature), it's recommended to disable default features in your
//! `Cargo.toml`:
//!
//! ```toml
//! [dependencies]
//! firebase-scrypt = { version = "0.1", default-features = false }
//! ```
//!
//! # Usage (with ``simple`` feature)
//! ```
//! use firebase_scrypt::FirebaseScrypt;
//!
//! const SALT_SEPARATOR: &str = "Bw==";
//! const SIGNER_KEY: &str = "jxspr8Ki0RYycVU8zykbdLGjFQ3McFUH0uiiTvC8pVMXAn210wjLNmdZJzxUECKbm0QsEmYUSDzZvpjeJ9WmXA==";
//! const ROUNDS: u32 = 8;
//! const MEM_COST: u32 = 14;
//!
//! let firebase_scrypt = FirebaseScrypt::new(SALT_SEPARATOR, SIGNER_KEY, ROUNDS, MEM_COST);
//!
//! let password = "user1password";
//! let salt = "42xEC+ixf3L2lw==";
//! let password_hash ="lSrfV15cpx95/sZS2W9c9Kp6i/LVgQNDNC/qzrCnh1SAyZvqmZqAjTdn3aoItz+VHjoZilo78198JAdRuid5lQ==";
//!
//! assert!(firebase_scrypt.verify_password(password, salt, password_hash).unwrap())
//! ```

use crate::errors::{DerivedKeyError, EncryptError, GenerateHashError};
use aes::{
    cipher::{KeyIvInit, StreamCipher},
    Aes256,
};
use constant_time_eq::constant_time_eq;
use ctr::Ctr128BE;
use scrypt::Params;

pub mod errors;
#[cfg(feature = "simple")]
mod simple;

#[cfg(feature = "simple")]
pub use simple::FirebaseScrypt;

/// All-zero 16-byte initial counter block used for the AES-256-CTR keystream.
const IV: [u8; 16] = [0; 16];

/// Maps URL-safe Base64 alphabet characters to the standard alphabet.
///
/// Every `-` becomes `+` and every `_` becomes `/`; all other characters are copied unchanged.
fn clean(a: &str) -> String {
    a.chars()
        .map(|ch| match ch {
            '-' => '+',
            '_' => '/',
            other => other,
        })
        .collect()
}

/// Derives the 64-byte scrypt key for `password`.
///
/// The scrypt salt is the Base64-decoded `salt` followed by the Base64-decoded
/// `salt_separator`. `rounds` is passed to scrypt as the block size `r`, `mem_cost` as
/// `log2(N)`, and the parallelism `p` is fixed at 1.
///
/// # Errors
///
/// Returns [`DerivedKeyError`] if `salt` or `salt_separator` is not valid Base64, or if
/// [`Params::new`] rejects the resulting parameters (for example an out-of-range `mem_cost`).
fn generate_derived_key(
    password: &str,
    salt: &str,
    salt_separator: &str,
    rounds: u32,
    mem_cost: u32,
) -> Result<[u8; 64], DerivedKeyError> {
    // `mem_cost` already *is* log2(N); no float round-trip is needed. Values that don't fit
    // in a `u8` are clamped to `u8::MAX`, which `Params::new` rejects. Out-of-range input
    // therefore becomes an error instead of a debug-build panic, and it cannot wrap around to a
    // small (weak) cost factor the way a plain `as u8` cast would.
    let log2_n = u8::try_from(mem_cost).unwrap_or(u8::MAX);
    let p: u32 = 1;

    let mut salt = base64::decode(salt)?;
    salt.extend_from_slice(&base64::decode(salt_separator)?);

    let params = Params::new(log2_n, rounds, p)?;

    let mut result = [0u8; 64];
    scrypt::scrypt(password.as_bytes(), &salt, &params, &mut result)?;

    Ok(result)
}

/// Encrypts `signer_key` with AES-256 in big-endian 128-bit counter mode under `key`.
///
/// The counter starts from the all-zero [`IV`]. The returned ciphertext has the same length as
/// `signer_key`.
///
/// # Errors
///
/// Returns [`EncryptError`] if the keystream cannot cover the whole input.
fn encrypt(signer_key: &[u8], key: [u8; 32]) -> Result<Vec<u8>, EncryptError> {
    let mut ciphertext = Vec::new();
    ciphertext.resize(signer_key.len(), 0u8);

    Ctr128BE::<Aes256>::new(&key.into(), &IV.into())
        .apply_keystream_b2b(signer_key, &mut ciphertext)?;

    Ok(ciphertext)
}

/// Checks `password` against a previously stored Firebase scrypt hash.
///
/// `known_hash` is Base64-encoded; URL-safe `-`/`_` characters are accepted as well. The
/// freshly computed hash is compared to the decoded `known_hash` in constant time.
///
/// If the salt separator, signer key, number of rounds and memory cost stay fixed at runtime,
/// the [`FirebaseScrypt`] struct can manage them for you.
///
/// # Errors
///
/// Returns [`GenerateHashError`] if computing the hash fails (see [`generate_raw_hash`]) or if
/// `known_hash` is not valid Base64.
///
/// # Example
/// ```
/// use firebase_scrypt::verify_password;
///
/// const SALT_SEPARATOR: &str = "Bw==";
/// const SIGNER_KEY: &str = "jxspr8Ki0RYycVU8zykbdLGjFQ3McFUH0uiiTvC8pVMXAn210wjLNmdZJzxUECKbm0QsEmYUSDzZvpjeJ9WmXA==";
/// const ROUNDS: u32 = 8;
/// const MEM_COST: u32 = 14;
///
/// let password = "user1password";
/// let salt = "42xEC+ixf3L2lw==";
/// let password_hash ="lSrfV15cpx95/sZS2W9c9Kp6i/LVgQNDNC/qzrCnh1SAyZvqmZqAjTdn3aoItz+VHjoZilo78198JAdRuid5lQ==";
///
/// let is_valid = verify_password(
///     password,
///     password_hash,
///     salt,
///     SALT_SEPARATOR,
///     SIGNER_KEY,
///     ROUNDS,
///     MEM_COST,
/// ).unwrap();
///
/// assert!(is_valid)
/// ```
pub fn verify_password(
    password: &str,
    known_hash: &str,
    salt: &str,
    salt_separator: &str,
    signer_key: &str,
    rounds: u32,
    mem_cost: u32,
) -> Result<bool, GenerateHashError> {
    // Hash first, then decode the stored value: this keeps the original error precedence.
    let computed =
        generate_raw_hash(password, salt, salt_separator, signer_key, rounds, mem_cost)?;
    let expected = base64::decode(clean(known_hash))?;

    Ok(constant_time_eq(&computed, &expected))
}

/// Computes the raw (not Base64-encoded) Firebase scrypt hash of `password`.
///
/// `salt`, `salt_separator` and `signer_key` are Base64 strings; `salt` may also use the
/// URL-safe `-`/`_` characters. If you want the same Base64 representation Firebase uses,
/// encode the result yourself (see the example) or use the [`FirebaseScrypt`] struct to get
/// the Base64 hash directly.
///
/// # Errors
///
/// Returns [`GenerateHashError`] if any Base64 input is invalid, if the scrypt parameters
/// (`rounds`, `mem_cost`) are rejected, or if encrypting the signer key fails.
///
/// # Example (generate Base64 hash)
/// ```
/// // Base64 crate for encoding the hash
/// use base64::encode;
/// use firebase_scrypt::generate_raw_hash;
///
/// const SALT_SEPARATOR: &str = "Bw==";
/// const SIGNER_KEY: &str = "jxspr8Ki0RYycVU8zykbdLGjFQ3McFUH0uiiTvC8pVMXAn210wjLNmdZJzxUECKbm0QsEmYUSDzZvpjeJ9WmXA==";
/// const ROUNDS: u32 = 8;
/// const MEM_COST: u32 = 14;
///
/// let password = "user1password";
/// let salt = "42xEC+ixf3L2lw==";
/// let password_hash ="lSrfV15cpx95/sZS2W9c9Kp6i/LVgQNDNC/qzrCnh1SAyZvqmZqAjTdn3aoItz+VHjoZilo78198JAdRuid5lQ==";
///
/// let hash = encode(generate_raw_hash(
///     password,
///     salt,
///     SALT_SEPARATOR,
///     SIGNER_KEY,
///     ROUNDS,
///     MEM_COST,
/// ).unwrap());
///
/// assert_eq!(hash, password_hash);
/// ```
pub fn generate_raw_hash(
    password: &str,
    salt: &str,
    salt_separator: &str,
    signer_key: &str,
    rounds: u32,
    mem_cost: u32,
) -> Result<Vec<u8>, GenerateHashError> {
    let derived_key =
        generate_derived_key(password, &clean(salt), salt_separator, rounds, mem_cost)?;
    let signer_key = base64::decode(signer_key)?;

    // Only the first 32 bytes of the 64-byte derived key serve as the AES-256 key.
    // `copy_from_slice` cannot fail here, because both sides are exactly 32 bytes.
    let mut aes_key = [0u8; 32];
    aes_key.copy_from_slice(&derived_key[..32]);

    // The ciphertext already is the raw hash. The previous Base64 encode/decode round-trip
    // was an identity transform that only added allocations and an unreachable error path.
    let hash = encrypt(&signer_key, aes_key)?;
    Ok(hash)
}

#[cfg(test)]
mod tests {
    use super::*;

    // Firebase project hash parameters shared by every test below.
    const SALT_SEPARATOR: &str = "Bw==";
    const SIGNER_KEY: &str =
        "jxspr8Ki0RYycVU8zykbdLGjFQ3McFUH0uiiTvC8pVMXAn210wjLNmdZJzxUECKbm0QsEmYUSDzZvpjeJ9WmXA==";
    const ROUNDS: u32 = 8;
    const MEM_COST: u32 = 14;

    // A known user record: plaintext password, its salt and the expected Base64 hash.
    const PASSWORD: &str = "user1password";
    const SALT: &str = "42xEC+ixf3L2lw==";
    const PASSWORD_HASH: &str =
        "lSrfV15cpx95/sZS2W9c9Kp6i/LVgQNDNC/qzrCnh1SAyZvqmZqAjTdn3aoItz+VHjoZilo78198JAdRuid5lQ==";

    #[test]
    fn verify_password_works() {
        let is_valid = verify_password(
            PASSWORD,
            PASSWORD_HASH,
            SALT,
            SALT_SEPARATOR,
            SIGNER_KEY,
            ROUNDS,
            MEM_COST,
        )
        .unwrap();

        assert!(is_valid);
    }

    #[test]
    fn generate_hash_works() {
        let raw_hash =
            generate_raw_hash(PASSWORD, SALT, SALT_SEPARATOR, SIGNER_KEY, ROUNDS, MEM_COST)
                .unwrap();

        assert_eq!(base64::encode(raw_hash), PASSWORD_HASH);
    }

    #[test]
    fn encrypt_works() {
        let plaintext = b"randomrandomrandomrandomrandomrandomrandom";
        let key = b"12345678901234567890123456789012";

        let ciphertext = encrypt(plaintext, *key).unwrap();

        assert_eq!(
            hex::encode(ciphertext),
            "09f509fa3d09cde568f80709416681e4ed5d9677ca8b4807a932869ba3fd057be3606c2940877850ed96"
        );
    }

    #[test]
    fn generate_derived_key_works() {
        let derived =
            generate_derived_key(PASSWORD, SALT, SALT_SEPARATOR, ROUNDS, MEM_COST).unwrap();

        assert_eq!(
            hex::encode(derived),
            "e87fa22d9b4e3be6bbd41214f2f98f8c78b694bd17e12c2b73501054a2099ce11fe896483c68a443c6cf9ff8a8dfe1dfe2adaa4be6c8ca1b7686687a26f48831"
        );
    }
}