1use std::{fmt, num::NonZeroU32};
22
23use ring::pbkdf2;
24use zeroize::Zeroize;
25
26use crate::{
27 aes::{self, AesMasterKey},
28 rng::Crng,
29};
30
/// The PBKDF2 variant used to stretch a password into an AES key
/// (HMAC-SHA256).
///
/// This is part of the key-derivation parameters: changing it changes every
/// derived key and breaks decryption of existing ciphertexts (see the
/// `decryption_compatibility` test vectors).
static PBKDF2_ALGORITHM: pbkdf2::Algorithm = pbkdf2::PBKDF2_HMAC_SHA256;
/// Number of PBKDF2 iterations used by `derive_aes_key`.
///
/// Like [`PBKDF2_ALGORITHM`], this is baked into every existing ciphertext;
/// changing it makes previously encrypted data undecryptable.
const PBKDF2_ITERATIONS: NonZeroU32 = NonZeroU32::new(600_000).unwrap();

/// Length in bytes of the key produced by PBKDF2 and handed to
/// `AesMasterKey::new`; equal to the SHA-256 output length (32 bytes).
const AES_KEY_LEN: usize = ring::digest::SHA256_OUTPUT_LEN;

/// Minimum accepted password length, counted in `char`s (Unicode scalar
/// values), enforced by [`validate_password_len`].
pub const MIN_PASSWORD_LENGTH: usize = 12;
/// Maximum accepted password length, counted in `char`s (Unicode scalar
/// values), enforced by [`validate_password_len`].
pub const MAX_PASSWORD_LENGTH: usize = 512;
// Compile-time sanity check that the accepted length range is non-empty.
lexe_std::const_assert!(MIN_PASSWORD_LENGTH < MAX_PASSWORD_LENGTH);
48
49#[derive(Clone, Debug)]
50pub enum Error {
51 PasswordTooShort,
52 PasswordTooLong,
53 AesDecrypt(aes::DecryptError),
54}
55
56impl std::error::Error for Error {}
57impl fmt::Display for Error {
58 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59 match self {
60 Self::PasswordTooShort => write!(
61 f,
62 "Password must have at least {MIN_PASSWORD_LENGTH} characters"
63 ),
64 Self::PasswordTooLong => write!(
65 f,
66 "Password cannot have more than {MAX_PASSWORD_LENGTH} characters"
67 ),
68 Self::AesDecrypt(err) => err.fmt(f),
69 }
70 }
71}
72impl From<aes::DecryptError> for Error {
73 fn from(err: aes::DecryptError) -> Self {
74 Self::AesDecrypt(err)
75 }
76}
77
78pub fn encrypt(
98 rng: &mut impl Crng,
99 password: &str,
100 salt: &[u8; 32],
101 data: &[u8],
102) -> Result<Vec<u8>, Error> {
103 validate_password_len(password)?;
104
105 let aes_key = derive_aes_key(password, salt);
107
108 let aad = &[salt.as_slice()];
110 let data_size_hint = Some(data.len());
111 let write_data_cb = |buf: &mut Vec<u8>| buf.extend_from_slice(data);
114 let ciphertext = aes_key.encrypt(rng, aad, data_size_hint, &write_data_cb);
115
116 Ok(ciphertext)
117}
118
119pub fn decrypt(
121 password: &str,
122 salt: &[u8; 32],
123 ciphertext: Vec<u8>,
124) -> Result<Vec<u8>, Error> {
125 validate_password_len(password)?;
127
128 let aes_key = derive_aes_key(password, salt);
130
131 let aad = &[salt.as_slice()];
133 let data = aes_key.decrypt(aad, ciphertext)?;
134
135 Ok(data)
136}
137
138pub fn validate_password_len(password: &str) -> Result<(), Error> {
142 let password_length = password.chars().count();
143 if password_length < MIN_PASSWORD_LENGTH {
144 return Err(Error::PasswordTooShort);
145 }
146 if password_length > MAX_PASSWORD_LENGTH {
147 return Err(Error::PasswordTooLong);
148 }
149 Ok(())
150}
151
152fn derive_aes_key(password: &str, salt: &[u8; 32]) -> AesMasterKey {
155 let mut aes_key_buf = [0u8; AES_KEY_LEN];
156 pbkdf2::derive(
157 PBKDF2_ALGORITHM,
158 PBKDF2_ITERATIONS,
159 salt,
160 password.as_bytes(),
161 &mut aes_key_buf,
162 );
163 let aes_key = AesMasterKey::new(&aes_key_buf);
164 aes_key_buf.zeroize();
166 aes_key
167}
168
#[cfg(test)]
mod test {
    use lexe_hex::hex;
    use proptest::{
        arbitrary::any, proptest, strategy::Strategy, test_runner::Config,
    };

    use super::*;
    use crate::rng::FastRng;

    /// Any valid password, salt and plaintext must survive an
    /// `encrypt` -> `decrypt` roundtrip.
    #[test]
    fn encryption_roundtrip() {
        // Each case runs PBKDF2 twice at `PBKDF2_ITERATIONS` rounds, which is
        // expensive, so keep the case count low.
        let config = Config::with_cases(4);

        // Inclusive upper bound: a password of exactly `MAX_PASSWORD_LENGTH`
        // characters is valid and must be generated too.
        let password_length_range = MIN_PASSWORD_LENGTH..=MAX_PASSWORD_LENGTH;
        let any_valid_password =
            proptest::collection::vec(any::<char>(), password_length_range)
                .prop_map(String::from_iter);
        proptest!(config, |(
            mut rng in any::<FastRng>(),
            password in any_valid_password,
            salt in any::<[u8; 32]>(),
            data1 in any::<Vec<u8>>(),
        )| {
            let ciphertext =
                encrypt(&mut rng, &password, &salt, &data1).unwrap();
            let data2 = decrypt(&password, &salt, ciphertext).unwrap();
            assert_eq!(data1, data2);
        })
    }

    /// Lengths at and just outside `[MIN_PASSWORD_LENGTH, MAX_PASSWORD_LENGTH]`
    /// are accepted / rejected, and length is counted in `char`s, not bytes.
    #[test]
    fn password_length_bounds() {
        let ascii = |n: usize| "a".repeat(n);
        assert!(matches!(
            validate_password_len(&ascii(MIN_PASSWORD_LENGTH - 1)),
            Err(Error::PasswordTooShort)
        ));
        assert!(validate_password_len(&ascii(MIN_PASSWORD_LENGTH)).is_ok());
        assert!(validate_password_len(&ascii(MAX_PASSWORD_LENGTH)).is_ok());
        assert!(matches!(
            validate_password_len(&ascii(MAX_PASSWORD_LENGTH + 1)),
            Err(Error::PasswordTooLong)
        ));

        // U+00E9 is two bytes in UTF-8 but a single `char`.
        let two_byte = |n: usize| "\u{e9}".repeat(n);
        assert!(matches!(
            validate_password_len(&two_byte(MIN_PASSWORD_LENGTH - 1)),
            Err(Error::PasswordTooShort)
        ));
        assert!(validate_password_len(&two_byte(MAX_PASSWORD_LENGTH)).is_ok());
    }

    /// Fixed test vectors: ciphertexts produced by earlier builds must still
    /// decrypt, which pins the KDF parameters and the ciphertext format.
    ///
    /// To add a vector, set `maybe_ciphertext: None`, run the test, and paste
    /// the printed hex back in as `Some(..)`.
    #[test]
    fn decryption_compatibility() {
        struct TestCase {
            password: String,
            salt: [u8; 32],
            data1: &'static [u8],
            maybe_ciphertext: Option<&'static str>,
        }

        let case0 = TestCase {
            password: "medium-length!123123".to_owned(),
            salt: [0u8; 32],
            data1: b"",
            maybe_ciphertext: Some(
                "00a9ebf955ed070fe7acefe66e5a007b2c4165d3c2c23efc6a91d60a37e3a7b6181e4156d15d513cb9cee00739a226466e",
            ),
        };
        let case1 = TestCase {
            password: "passwordword".to_owned(),
            salt: [69; 32],
            data1: b"*jaw drops* awooga! hummina hummina bazooing!",
            maybe_ciphertext: Some(
                "00a9ebf955ed070fe7acefe66e5a007b2c4165d3c2c23efc6a91d60a37e3a7b6180c0d3cd90616335f13f5de7c9df0a1d89a7aec282b8083089c2360962e22db1a57685e82aea236c053b88495021767e0c17e05b3f72a86cfbbffc3724a",
            ),
        };
        // Every char in U+0000..U+01FF: exactly 512 chars, mixing 1-byte and
        // 2-byte UTF-8 sequences (including control characters).
        let password = (0u32..512)
            .map(|i| char::from_u32(i).unwrap())
            .collect::<String>();
        let case2 = TestCase {
            password,
            salt: [69; 32],
            data1: b"*jaw drops* awooga! hummina hummina bazooing!",
            maybe_ciphertext: Some(
                "00a9ebf955ed070fe7acefe66e5a007b2c4165d3c2c23efc6a91d60a37e3a7b618cf7a8ff3ea628ed33fb32428930340557454454258dedc67c9a3a5e350c2408ad82e6a8ac02779fd9df3f513364b6351301271cfd2c515fdca0cd15de0",
            ),
        };

        for (i, case) in [case0, case1, case2].into_iter().enumerate() {
            let TestCase {
                password,
                salt,
                data1,
                maybe_ciphertext,
            } = case;

            match maybe_ciphertext {
                Some(cipherhext) => {
                    println!("Testing case {i}");
                    let ciphertext = hex::decode(cipherhext).unwrap();
                    let data2 = decrypt(&password, &salt, ciphertext).unwrap();
                    assert_eq!(data1, data2.as_slice());
                }
                None => {
                    // Deterministic RNG so a regenerated vector is reproducible.
                    let mut rng = FastRng::from_u64(20231016);
                    let ciphertext =
                        encrypt(&mut rng, &password, &salt, data1).unwrap();
                    let cipherhext = hex::display(&ciphertext);
                    println!("Case {i} ciphertext: {cipherhext}");
                }
            }
        }
    }
}