1use argon2::{Algorithm, Argon2, Params, Version};
10use rand::RngCore;
11use zeroize::{Zeroize, Zeroizing};
12
13use crate::error::JoyError;
14
/// A 32-byte salt used as input to passphrase-based key derivation.
///
/// Construct one with [`generate_salt`] or parse a stored value with
/// [`Salt::from_hex`]. The bytes are wiped by this type's `Drop`
/// implementation; each clone owns (and wipes) its own copy.
#[derive(Clone)]
pub struct Salt([u8; 32]);
18
19impl Salt {
20 pub fn as_bytes(&self) -> &[u8; 32] {
21 &self.0
22 }
23
24 pub fn to_hex(&self) -> String {
25 hex::encode(self.0)
26 }
27
28 pub fn from_hex(s: &str) -> Result<Self, JoyError> {
29 let bytes =
30 hex::decode(s).map_err(|e| JoyError::AuthFailed(format!("invalid salt: {e}")))?;
31 let arr: [u8; 32] = bytes
32 .try_into()
33 .map_err(|_| JoyError::AuthFailed("salt must be 32 bytes".into()))?;
34 Ok(Self(arr))
35 }
36}
37
/// A 32-byte key produced by [`derive_key`].
///
/// The bytes are held in a `Zeroizing` wrapper so the buffer is zeroed when
/// the key is dropped.
pub struct DerivedKey(Zeroizing<[u8; 32]>);
40
41impl DerivedKey {
42 pub fn as_bytes(&self) -> &[u8; 32] {
43 &self.0
44 }
45}
46
47pub fn generate_salt() -> Salt {
49 let mut bytes = [0u8; 32];
50 rand::thread_rng().fill_bytes(&mut bytes);
51 Salt(bytes)
52}
53
54pub fn derive_key(passphrase: &str, salt: &Salt) -> Result<DerivedKey, JoyError> {
59 #[cfg(any(feature = "fast-kdf", debug_assertions))]
60 let params = Params::new(256, 1, 1, Some(32))
61 .map_err(|e| JoyError::AuthFailed(format!("argon2 params: {e}")))?;
62 #[cfg(not(any(feature = "fast-kdf", debug_assertions)))]
63 let params = Params::new(65536, 3, 4, Some(32))
64 .map_err(|e| JoyError::AuthFailed(format!("argon2 params: {e}")))?;
65 let argon2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
66
67 let mut output = Zeroizing::new([0u8; 32]);
68 argon2
69 .hash_password_into(passphrase.as_bytes(), salt.as_bytes(), output.as_mut())
70 .map_err(|e| JoyError::AuthFailed(format!("key derivation failed: {e}")))?;
71
72 Ok(DerivedKey(output))
73}
74
75pub fn validate_passphrase(passphrase: &str) -> Result<(), JoyError> {
77 let word_count = passphrase.split_whitespace().count();
78 if word_count < 6 {
79 return Err(JoyError::PassphraseTooShort);
80 }
81 Ok(())
82}
83
84impl Drop for Salt {
85 fn drop(&mut self) {
86 self.0.zeroize();
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Six-word passphrase shared by the derivation tests.
    const TEST_PASSPHRASE: &str = "correct horse battery staple extra words";

    /// Two independently generated salts should not collide.
    #[test]
    fn salt_is_random() {
        let s1 = generate_salt();
        let s2 = generate_salt();
        assert_ne!(s1.as_bytes(), s2.as_bytes());
    }

    /// `to_hex` followed by `from_hex` yields the original bytes.
    #[test]
    fn salt_hex_roundtrip() {
        let salt = generate_salt();
        let hex = salt.to_hex();
        let parsed = Salt::from_hex(&hex).unwrap();
        assert_eq!(salt.as_bytes(), parsed.as_bytes());
    }

    /// `from_hex` rejects input that is not hex or not exactly 32 bytes long.
    #[test]
    fn salt_from_hex_rejects_invalid_input() {
        // Right length, but not hexadecimal digits.
        assert!(Salt::from_hex(&"zz".repeat(32)).is_err());
        // Valid hex that decodes to the wrong number of bytes.
        assert!(Salt::from_hex("").is_err());
        assert!(Salt::from_hex("abcd").is_err());
        assert!(Salt::from_hex(&"ab".repeat(31)).is_err());
        assert!(Salt::from_hex(&"ab".repeat(33)).is_err());
    }

    /// The same passphrase and salt always produce the same key.
    #[test]
    fn derive_deterministic() {
        let salt =
            Salt::from_hex("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")
                .unwrap();
        let k1 = derive_key(TEST_PASSPHRASE, &salt).unwrap();
        let k2 = derive_key(TEST_PASSPHRASE, &salt).unwrap();
        assert_eq!(k1.as_bytes(), k2.as_bytes());
    }

    /// Changing only the salt changes the derived key.
    #[test]
    fn derive_different_salt() {
        let s1 = generate_salt();
        let s2 = generate_salt();
        let k1 = derive_key(TEST_PASSPHRASE, &s1).unwrap();
        let k2 = derive_key(TEST_PASSPHRASE, &s2).unwrap();
        assert_ne!(k1.as_bytes(), k2.as_bytes());
    }

    /// Changing only the passphrase changes the derived key.
    #[test]
    fn derive_different_passphrase() {
        let salt = generate_salt();
        let k1 = derive_key("one two three four five six", &salt).unwrap();
        let k2 = derive_key("seven eight nine ten eleven twelve", &salt).unwrap();
        assert_ne!(k1.as_bytes(), k2.as_bytes());
    }

    /// Fewer than six words is rejected, including empty and blank input.
    #[test]
    fn passphrase_too_short() {
        assert!(validate_passphrase("").is_err());
        assert!(validate_passphrase("   \t\n").is_err());
        assert!(validate_passphrase("one two three").is_err());
        assert!(validate_passphrase("one two three four five").is_err());
    }

    /// Six or more words is accepted, regardless of the whitespace between them.
    #[test]
    fn passphrase_valid() {
        assert!(validate_passphrase("one two three four five six").is_ok());
        assert!(validate_passphrase("a b c d e f g h").is_ok());
        assert!(validate_passphrase("  one\ttwo\nthree   four five six  ").is_ok());
        // The shared fixture must itself satisfy the rule it is used alongside.
        assert!(validate_passphrase(TEST_PASSPHRASE).is_ok());
    }
}