huddle_protocol/crypto/
passphrase.rs1use argon2::{Algorithm, Argon2, Params, Version};
14use base64::Engine;
15use chacha20poly1305::aead::{Aead, AeadCore, KeyInit, OsRng};
16use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce};
17use rand::RngCore;
18use zeroize::Zeroizing;
19
20use crate::error::{ProtocolError, Result};
21
/// Length in bytes of the Argon2 salt produced by [`random_salt`] (128 bits).
pub const SALT_LEN: usize = 128 / 8;
/// Length in bytes of the derived wrapping key (256 bits, the ChaCha20-Poly1305 key size).
pub const KEY_LEN: usize = 256 / 8;
/// Length in bytes of the ChaCha20-Poly1305 nonce stored in front of each wrapped blob (96 bits).
pub const NONCE_LEN: usize = 96 / 8;
25
26pub fn random_salt() -> [u8; SALT_LEN] {
28 let mut salt = [0u8; SALT_LEN];
29 OsRng.fill_bytes(&mut salt);
30 salt
31}
32
33pub fn derive_key(passphrase: &str, salt: &[u8]) -> Result<[u8; KEY_LEN]> {
38 let zeroizing = derive_key_zeroizing(passphrase, salt)?;
39 Ok(*zeroizing)
43}
44
45pub fn derive_key_zeroizing(passphrase: &str, salt: &[u8]) -> Result<Zeroizing<[u8; KEY_LEN]>> {
49 let params = Params::new(65_536, 3, 4, Some(KEY_LEN))
50 .map_err(|e| ProtocolError::Session(format!("argon2 params: {e}")))?;
51 let argon = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
52 let mut out = Zeroizing::new([0u8; KEY_LEN]);
53 argon
54 .hash_password_into(passphrase.as_bytes(), salt, out.as_mut_slice())
55 .map_err(|e| ProtocolError::Session(format!("argon2 derive: {e}")))?;
56 Ok(out)
57}
58
59pub fn wrap(plaintext: &[u8], passphrase_key: &[u8; KEY_LEN]) -> Result<String> {
62 let cipher = ChaCha20Poly1305::new(Key::from_slice(passphrase_key));
63 let nonce = ChaCha20Poly1305::generate_nonce(&mut OsRng);
64 let ciphertext = cipher
65 .encrypt(&nonce, plaintext)
66 .map_err(|e| ProtocolError::Session(format!("wrap failed: {e}")))?;
67 let mut combined = Vec::with_capacity(NONCE_LEN + ciphertext.len());
68 combined.extend_from_slice(&nonce);
69 combined.extend_from_slice(&ciphertext);
70 Ok(base64::engine::general_purpose::STANDARD.encode(&combined))
71}
72
73pub fn unwrap(encoded: &str, passphrase_key: &[u8; KEY_LEN]) -> Result<Vec<u8>> {
75 let bytes = base64::engine::general_purpose::STANDARD
76 .decode(encoded)
77 .map_err(|e| ProtocolError::Session(format!("bad base64: {e}")))?;
78 if bytes.len() < NONCE_LEN + 16 {
79 return Err(ProtocolError::Session("wrapped key too short".into()));
80 }
81 let (nonce_bytes, ciphertext) = bytes.split_at(NONCE_LEN);
82 let cipher = ChaCha20Poly1305::new(Key::from_slice(passphrase_key));
83 let nonce = Nonce::from_slice(nonce_bytes);
84 cipher
85 .decrypt(nonce, ciphertext)
86 .map_err(|e| ProtocolError::Session(format!("unwrap failed (wrong passphrase?): {e}")))
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::STANDARD;
    use base64::Engine as _;

    /// Fixed key for tests that exercise only `wrap`/`unwrap`, so they skip the
    /// deliberately expensive Argon2 derivation.
    const TEST_KEY: [u8; KEY_LEN] = [7u8; KEY_LEN];

    #[test]
    fn derive_is_deterministic() {
        let salt = [42u8; SALT_LEN];
        let k1 = derive_key("hunter2", &salt).unwrap();
        let k2 = derive_key("hunter2", &salt).unwrap();
        assert_eq!(k1, k2);
    }

    #[test]
    fn different_passphrases_different_keys() {
        let salt = [42u8; SALT_LEN];
        let k1 = derive_key("hunter2", &salt).unwrap();
        let k2 = derive_key("hunter3", &salt).unwrap();
        assert_ne!(k1, k2);
    }

    #[test]
    fn different_salts_different_keys() {
        let k1 = derive_key("same", &[1u8; SALT_LEN]).unwrap();
        let k2 = derive_key("same", &[2u8; SALT_LEN]).unwrap();
        assert_ne!(k1, k2);
    }

    #[test]
    fn zeroizing_and_plain_derivations_agree() {
        let salt = [42u8; SALT_LEN];
        let plain = derive_key("hunter2", &salt).unwrap();
        let guarded = derive_key_zeroizing("hunter2", &salt).unwrap();
        assert_eq!(*guarded, plain);
    }

    #[test]
    fn wrap_unwrap_round_trip() {
        let salt = random_salt();
        let key = derive_key("hunter2", &salt).unwrap();
        let secret = b"this is a megolm session key";
        let wrapped = wrap(secret, &key).unwrap();
        let recovered = unwrap(&wrapped, &key).unwrap();
        assert_eq!(recovered, secret);
    }

    #[test]
    fn empty_plaintext_round_trips() {
        let wrapped = wrap(b"", &TEST_KEY).unwrap();
        assert!(unwrap(&wrapped, &TEST_KEY).unwrap().is_empty());
    }

    #[test]
    fn wrapped_layout_is_nonce_then_ciphertext_and_tag() {
        let secret = b"hello";
        let wrapped = wrap(secret, &TEST_KEY).unwrap();
        let raw = STANDARD.decode(&wrapped).unwrap();
        // nonce + ciphertext (same length as plaintext) + 16-byte Poly1305 tag
        assert_eq!(raw.len(), NONCE_LEN + secret.len() + 16);
    }

    #[test]
    fn wrong_passphrase_fails_unwrap() {
        let salt = random_salt();
        let right_key = derive_key("hunter2", &salt).unwrap();
        let wrong_key = derive_key("hunter3", &salt).unwrap();
        let wrapped = wrap(b"secret", &right_key).unwrap();
        assert!(unwrap(&wrapped, &wrong_key).is_err());
    }

    #[test]
    fn tampered_ciphertext_fails_unwrap() {
        let wrapped = wrap(b"secret", &TEST_KEY).unwrap();
        let mut raw = STANDARD.decode(&wrapped).unwrap();
        let last = raw.len() - 1;
        raw[last] ^= 0x01; // flip one bit of the authentication tag
        let tampered = STANDARD.encode(&raw);
        assert!(unwrap(&tampered, &TEST_KEY).is_err());
    }

    #[test]
    fn truncated_input_is_rejected() {
        let short = STANDARD.encode([0u8; NONCE_LEN + 15]);
        assert!(unwrap(&short, &TEST_KEY).is_err());
    }

    #[test]
    fn invalid_base64_is_rejected() {
        assert!(unwrap("not base64!", &TEST_KEY).is_err());
    }

    #[test]
    fn wrapped_output_is_nondeterministic() {
        let salt = random_salt();
        let key = derive_key("hunter2", &salt).unwrap();
        let w1 = wrap(b"hello", &key).unwrap();
        let w2 = wrap(b"hello", &key).unwrap();
        assert_ne!(w1, w2, "nonce should differ each time");
    }
}
143}