1use argon2::{Algorithm, Argon2, Params, Version};
13use hkdf::Hkdf;
14use rand::RngCore;
15use sha2::Sha256;
16use zeroize::{Zeroize, Zeroizing};
17
18use crate::Error;
19
20#[derive(Clone, Debug)]
22pub struct Salt([u8; 32]);
23
24impl Salt {
25 pub fn as_bytes(&self) -> &[u8; 32] {
26 &self.0
27 }
28
29 pub fn to_hex(&self) -> String {
30 hex::encode(self.0)
31 }
32
33 pub fn from_hex(s: &str) -> Result<Self, Error> {
34 let bytes = hex::decode(s).map_err(|e| Error::InvalidHex(e.to_string()))?;
35 let arr: [u8; 32] = bytes
36 .try_into()
37 .map_err(|v: Vec<u8>| Error::InvalidLength {
38 expected: 32,
39 got: v.len(),
40 })?;
41 Ok(Self(arr))
42 }
43
44 pub fn from_bytes(bytes: [u8; 32]) -> Self {
45 Self(bytes)
46 }
47}
48
49impl Drop for Salt {
50 fn drop(&mut self) {
51 self.0.zeroize();
52 }
53}
54
55pub struct DerivedKey(Zeroizing<[u8; 32]>);
57
58impl DerivedKey {
59 pub fn as_bytes(&self) -> &[u8; 32] {
60 &self.0
61 }
62
63 pub fn from_bytes(bytes: [u8; 32]) -> Self {
66 Self(Zeroizing::new(bytes))
67 }
68}
69
70pub fn generate_salt() -> Salt {
72 let mut bytes = [0u8; 32];
73 rand::thread_rng().fill_bytes(&mut bytes);
74 Salt(bytes)
75}
76
77pub fn derive_argon2id(passphrase: &str, salt: &Salt) -> Result<DerivedKey, Error> {
82 #[cfg(any(feature = "fast-kdf", debug_assertions))]
83 let params = Params::new(256, 1, 1, Some(32)).map_err(|e| Error::Kdf(e.to_string()))?;
84 #[cfg(not(any(feature = "fast-kdf", debug_assertions)))]
85 let params = Params::new(65536, 3, 4, Some(32)).map_err(|e| Error::Kdf(e.to_string()))?;
86 let argon2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
87
88 let mut output = Zeroizing::new([0u8; 32]);
89 argon2
90 .hash_password_into(passphrase.as_bytes(), salt.as_bytes(), output.as_mut())
91 .map_err(|e| Error::Kdf(e.to_string()))?;
92
93 Ok(DerivedKey(output))
94}
95
96pub fn derive_hkdf_sha256(ikm: &[u8], salt: &[u8], info: &[u8]) -> [u8; 32] {
102 let hk = Hkdf::<Sha256>::new(Some(salt), ikm);
103 let mut out = [0u8; 32];
104 hk.expand(info, &mut out)
105 .expect("HKDF-SHA256 expand to 32 bytes never fails");
106 out
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_PASSPHRASE: &str = "correct horse battery staple extra words";

    #[test]
    fn salt_is_random() {
        let s1 = generate_salt();
        let s2 = generate_salt();
        assert_ne!(s1.as_bytes(), s2.as_bytes());
    }

    #[test]
    fn salt_hex_roundtrip() {
        let salt = generate_salt();
        let hex = salt.to_hex();
        let parsed = Salt::from_hex(&hex).unwrap();
        assert_eq!(salt.as_bytes(), parsed.as_bytes());
    }

    #[test]
    fn salt_invalid_length_rejected() {
        assert!(matches!(
            Salt::from_hex("00").unwrap_err(),
            Error::InvalidLength { expected: 32, .. }
        ));
    }

    #[test]
    fn salt_overlong_rejected() {
        // 33 bytes: valid hex, but must still fail the exact-length check.
        let too_long = "00".repeat(33);
        assert!(matches!(
            Salt::from_hex(&too_long).unwrap_err(),
            Error::InvalidLength { expected: 32, got: 33 }
        ));
    }

    #[test]
    fn salt_non_hex_rejected() {
        // Odd length and non-hex digits are decode errors, reported before the
        // length check in `Salt::from_hex`.
        assert!(matches!(Salt::from_hex("abc").unwrap_err(), Error::InvalidHex(_)));
        assert!(matches!(Salt::from_hex("zz").unwrap_err(), Error::InvalidHex(_)));
    }

    #[test]
    fn argon2_deterministic() {
        let salt =
            Salt::from_hex("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")
                .unwrap();
        let k1 = derive_argon2id(TEST_PASSPHRASE, &salt).unwrap();
        let k2 = derive_argon2id(TEST_PASSPHRASE, &salt).unwrap();
        assert_eq!(k1.as_bytes(), k2.as_bytes());
    }

    #[test]
    fn argon2_different_salt() {
        let s1 = generate_salt();
        let s2 = generate_salt();
        let k1 = derive_argon2id(TEST_PASSPHRASE, &s1).unwrap();
        let k2 = derive_argon2id(TEST_PASSPHRASE, &s2).unwrap();
        assert_ne!(k1.as_bytes(), k2.as_bytes());
    }

    #[test]
    fn argon2_different_passphrase() {
        let salt = generate_salt();
        let k1 = derive_argon2id("one two three four five six", &salt).unwrap();
        let k2 = derive_argon2id("seven eight nine ten eleven twelve", &salt).unwrap();
        assert_ne!(k1.as_bytes(), k2.as_bytes());
    }

    #[test]
    fn hkdf_deterministic() {
        let a = derive_hkdf_sha256(b"ikm", b"salt", b"info");
        let b = derive_hkdf_sha256(b"ikm", b"salt", b"info");
        assert_eq!(a, b);
    }

    #[test]
    fn hkdf_domain_separated_by_info() {
        let a = derive_hkdf_sha256(b"ikm", b"salt", b"context-a");
        let b = derive_hkdf_sha256(b"ikm", b"salt", b"context-b");
        assert_ne!(a, b);
    }

    #[test]
    fn hkdf_responds_to_salt() {
        let a = derive_hkdf_sha256(b"ikm", b"salt-a", b"info");
        let b = derive_hkdf_sha256(b"ikm", b"salt-b", b"info");
        assert_ne!(a, b);
    }

    #[test]
    fn hkdf_matches_rfc5869_test_case_1() {
        // RFC 5869 Appendix A.1 (SHA-256). The RFC's OKM is 42 bytes; HKDF
        // output is a prefix-stable stream, so the 32 bytes produced here must
        // equal the first 32 bytes of that OKM.
        let ikm = [0x0b_u8; 22];
        let salt: Vec<u8> = (0x00..=0x0c).collect();
        let info: Vec<u8> = (0xf0..=0xf9).collect();
        let okm = derive_hkdf_sha256(&ikm, &salt, &info);
        assert_eq!(
            hex::encode(okm),
            "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf"
        );
    }
}