1use chacha20poly1305::{
4 ChaCha20Poly1305, Nonce,
5 aead::{Aead, KeyInit},
6};
7use hkdf::Hkdf;
8use sha2::Sha256;
9use zeroize::{Zeroize, Zeroizing};
10
11use crate::error::ProtocolError;
12
13pub const SAS_EMOJI: [&str; 256] = [
15 "ðķ", "ðą", "ð", "ðđ", "ð°", "ðĶ", "ðŧ", "ðž", "ðĻ", "ðŊ", "ðĶ", "ðŪ", "ð·", "ðļ", "ðĩ", "ð",
16 "ð§", "ðĶ", "ðĶ", "ðĶ
", "ðĶ", "ðš", "ð", "ðī", "ðĶ", "ð", "ð", "ðĶ", "ð", "ð", "ð", "ðŠē",
17 "ðĒ", "ð", "ðĶ", "ðĶ", "ð", "ðĶ", "ðĶ", "ðĶ", "ð ", "ðĄ", "ðŽ", "ðĶ", "ðģ", "ð", "ð", "ð",
18 "ð
", "ðĶ", "ðĶ", "ðͧ", "ð", "ðĶ", "ðĶ", "ðŠ", "ðĶ", "ðĶ", "ðĶŽ", "ð", "ð", "ð", "ð", "ð",
19 "ð", "ð", "ð", "ðĶ", "ð", "ðĐ", "ðĶŪ", "ð", "ð", "ðĶ", "ðĶĪ", "ðĶ", "ðĶ", "ðĶĒ", "ðĶĐ", "ðïļ",
20 "ð", "ðĶ", "ðĶĻ", "ðĶĄ", "ðĶŦ", "ðĶĶ", "ðĶĨ", "ð", "ð", "ðŋïļ", "ðĶ", "ðĩ", "ð", "ðē", "ðģ", "ðī",
21 "ðŠĩ", "ðą", "ðŋ", "âïļ", "ð", "ð", "ðŠī", "ð", "ð", "ð", "ð", "ðū", "ðš", "ðŧ", "ðđ", "ðĨ",
22 "ð·", "ðž", "ð", "ð", "ð°", "ð", "ð", "ð", "ð", "ð", "ð", "ð", "ð", "ð", "ð", "ð",
23 "ð", "ð", "â", "ð", "ðŦ", "âĻ", "âïļ", "ðĪïļ", "â
", "ðĨïļ", "ðĶïļ", "ð§ïļ", "âïļ", "ðĐïļ", "ðĻïļ", "âïļ",
24 "âïļ", "â", "ðŽïļ", "ðĻ", "ðŠïļ", "ðŦïļ", "ð", "ð§", "ðĶ", "ðĨ", "ðŊ", "ð", "ð", "âū", "ðĨ", "ðū",
25 "ð", "ð", "ðĨ", "ðą", "ð", "ðļ", "ð", "ðĨ", "ðŋ", "â·ïļ", "ð", "ðŠ", "ðïļ", "ðĪļ", "âđïļ", "ðĪš",
26 "ð", "ð§", "ð", "ð", "ðĢ", "ð§", "ðī", "ð", "ðĨ", "ðĨ", "ðĨ", "ð
", "ðïļ", "ðŠ", "ðĻ", "ð",
27 "ðđ", "ðĨ", "ð·", "ðš", "ðļ", "ðŠ", "ðŧ", "ðŽ", "ðŪ", "ðđïļ", "ðē", "ð§Đ", "ðŪ", "ðŠ", "ð§ŋ", "ð°",
28 "ð", "âïļ", "ðļ", "ð", "ðķ", "âĩ", "ðĪ", "ðĨïļ", "ð", "ð", "ð", "ð
", "ð", "ð", "ð", "ð",
29 "ð ", "ðĄ", "ðĒ", "ðĢ", "ðĪ", "ðĨ", "ðĶ", "ðĻ", "ðĐ", "ðŠ", "ðŦ", "ðŽ", "ð", "ðŊ", "ð°", "ð",
30 "ðž", "ð―", "âŠ", "ð", "ð", "ð", "âĐïļ", "ð", "âē", "âš", "ð", "ðŧ", "ð", "ðū", "ðïļ", "ð ",
31];
32
/// ChaCha20-Poly1305 nonce length in bytes (a 96-bit nonce).
const NONCE_LEN: usize = 96 / 8;
/// Poly1305 authentication tag length in bytes (a 128-bit tag).
const TAG_LEN: usize = 128 / 8;
35
36pub fn derive_sas(
49 shared_secret: &[u8; 32],
50 initiator_pub: &[u8; 32],
51 responder_pub: &[u8; 32],
52 short_code: &str,
53) -> [u8; 8] {
54 let salt = build_salt(initiator_pub, responder_pub);
55 let info = build_info(b"auths-pairing-sas-v1", short_code);
56
57 let hk = Hkdf::<Sha256>::new(Some(&salt), shared_secret);
58 let mut out = [0u8; 8];
59 let _ = hk.expand(&info, &mut out);
61 out
62}
63
64pub fn format_sas_emoji(sas_bytes: &[u8; 8]) -> String {
66 sas_bytes[..4]
67 .iter()
68 .map(|&b| SAS_EMOJI[b as usize])
69 .collect::<Vec<_>>()
70 .join(" ")
71}
72
/// Renders the last four SAS bytes as a six-digit code of the form `DDD-DDD`.
///
/// Bytes `4..8` are read as a big-endian `u32` and reduced modulo 1,000,000.
/// The result is zero-padded and split into two groups of three digits.
pub fn format_sas_numeric(sas_bytes: &[u8; 8]) -> String {
    let [_, _, _, _, b4, b5, b6, b7] = *sas_bytes;
    let code = u32::from_be_bytes([b4, b5, b6, b7]) % 1_000_000;
    let (high, low) = (code / 1000, code % 1000);
    format!("{high:03}-{low:03}")
}
79
80pub struct TransportKey(Zeroizing<[u8; 32]>);
85
86impl TransportKey {
87 pub fn new(key: [u8; 32]) -> Self {
88 Self(Zeroizing::new(key))
89 }
90
91 pub fn encrypt(mut self, plaintext: &[u8]) -> Result<Vec<u8>, ProtocolError> {
95 let cipher = ChaCha20Poly1305::new_from_slice(&*self.0)
96 .map_err(|_| ProtocolError::EncryptionFailed("invalid key".into()))?;
97 let nonce_bytes: [u8; NONCE_LEN] = rand::random();
98 let nonce = Nonce::from(nonce_bytes);
99 let ciphertext = cipher
100 .encrypt(&nonce, plaintext)
101 .map_err(|_| ProtocolError::EncryptionFailed("encryption failed".into()))?;
102
103 self.0.zeroize();
104
105 let mut out = Vec::with_capacity(NONCE_LEN + ciphertext.len());
106 out.extend_from_slice(&nonce_bytes);
107 out.extend_from_slice(&ciphertext);
108 Ok(out)
109 }
110
111 pub fn as_bytes(&self) -> &[u8; 32] {
113 &self.0
114 }
115}
116
117pub fn decrypt_from_transport(
123 ciphertext: &[u8],
124 transport_key: &[u8; 32],
125) -> Result<Vec<u8>, ProtocolError> {
126 if ciphertext.len() < NONCE_LEN + TAG_LEN {
127 return Err(ProtocolError::DecryptionFailed(
128 "ciphertext too short".into(),
129 ));
130 }
131 let (nonce_bytes, ct) = ciphertext.split_at(NONCE_LEN);
132 let nonce = Nonce::from_slice(nonce_bytes);
133 let cipher = ChaCha20Poly1305::new_from_slice(transport_key)
134 .map_err(|_| ProtocolError::DecryptionFailed("invalid key".into()))?;
135 cipher
136 .decrypt(nonce, ct)
137 .map_err(|_| ProtocolError::DecryptionFailed("decryption failed".into()))
138}
139
140pub fn derive_transport_key(
151 shared_secret: &[u8; 32],
152 initiator_pub: &[u8; 32],
153 responder_pub: &[u8; 32],
154 short_code: &str,
155) -> TransportKey {
156 let salt = build_salt(initiator_pub, responder_pub);
157 let info = build_info(b"auths-pairing-transport-v1", short_code);
158
159 let hk = Hkdf::<Sha256>::new(Some(&salt), shared_secret);
160 let mut key = [0u8; 32];
161 let _ = hk.expand(&info, &mut key);
163 TransportKey::new(key)
164}
165
/// Builds the 64-byte HKDF salt `initiator_pub || responder_pub`.
///
/// Order matters: swapping the roles produces a different salt, and therefore
/// different derived values.
fn build_salt(initiator_pub: &[u8; 32], responder_pub: &[u8; 32]) -> [u8; 64] {
    let mut buf = [0u8; 64];
    let (head, tail) = buf.split_at_mut(initiator_pub.len());
    head.copy_from_slice(initiator_pub);
    tail.copy_from_slice(responder_pub);
    buf
}
172
/// Builds the HKDF info string `domain || short_code` (UTF-8 bytes of the code).
///
/// No separator or length prefix is inserted between the two parts.
fn build_info(domain: &[u8], short_code: &str) -> Vec<u8> {
    [domain, short_code.as_bytes()].concat()
}
179
180#[cfg(test)]
181#[allow(clippy::disallowed_methods)]
182mod tests {
183 use super::*;
184 use std::collections::HashSet;
185
186 const TEST_SECRET: [u8; 32] = [0x42; 32];
187 const TEST_INIT_PUB: [u8; 32] = [0x01; 32];
188 const TEST_RESP_PUB: [u8; 32] = [0x02; 32];
189 const TEST_SHORT_CODE: &str = "ABC123";
190
191 #[test]
192 fn sas_determinism() {
193 let a = derive_sas(
194 &TEST_SECRET,
195 &TEST_INIT_PUB,
196 &TEST_RESP_PUB,
197 TEST_SHORT_CODE,
198 );
199 let b = derive_sas(
200 &TEST_SECRET,
201 &TEST_INIT_PUB,
202 &TEST_RESP_PUB,
203 TEST_SHORT_CODE,
204 );
205 assert_eq!(a, b);
206 }
207
208 #[test]
209 fn sas_divergence_different_secret() {
210 let a = derive_sas(
211 &TEST_SECRET,
212 &TEST_INIT_PUB,
213 &TEST_RESP_PUB,
214 TEST_SHORT_CODE,
215 );
216 let b = derive_sas(&[0xFF; 32], &TEST_INIT_PUB, &TEST_RESP_PUB, TEST_SHORT_CODE);
217 assert_ne!(a, b);
218 }
219
220 #[test]
221 fn sas_divergence_different_pubkeys() {
222 let a = derive_sas(
223 &TEST_SECRET,
224 &TEST_INIT_PUB,
225 &TEST_RESP_PUB,
226 TEST_SHORT_CODE,
227 );
228 let b = derive_sas(&TEST_SECRET, &[0x03; 32], &TEST_RESP_PUB, TEST_SHORT_CODE);
229 assert_ne!(a, b);
230 }
231
232 #[test]
233 fn domain_separation() {
234 let sas = derive_sas(
235 &TEST_SECRET,
236 &TEST_INIT_PUB,
237 &TEST_RESP_PUB,
238 TEST_SHORT_CODE,
239 );
240 let tk = derive_transport_key(
241 &TEST_SECRET,
242 &TEST_INIT_PUB,
243 &TEST_RESP_PUB,
244 TEST_SHORT_CODE,
245 );
246 assert_ne!(&sas[..], &tk.as_bytes()[..8]);
247 }
248
249 #[test]
250 fn emoji_format() {
251 let sas = derive_sas(
252 &TEST_SECRET,
253 &TEST_INIT_PUB,
254 &TEST_RESP_PUB,
255 TEST_SHORT_CODE,
256 );
257 let emoji = format_sas_emoji(&sas);
258 let parts: Vec<&str> = emoji.split(" ").collect();
259 assert_eq!(parts.len(), 4);
260 for part in &parts {
261 assert!(SAS_EMOJI.contains(part), "emoji {part} not in wordlist");
262 }
263 }
264
265 #[test]
266 fn numeric_format() {
267 let sas = derive_sas(
268 &TEST_SECRET,
269 &TEST_INIT_PUB,
270 &TEST_RESP_PUB,
271 TEST_SHORT_CODE,
272 );
273 let numeric = format_sas_numeric(&sas);
274 let re = regex_lite::Regex::new(r"^\d{3}-\d{3}$").unwrap();
275 assert!(re.is_match(&numeric), "numeric format wrong: {numeric}");
276 }
277
278 #[test]
279 fn emoji_wordlist_integrity() {
280 assert_eq!(SAS_EMOJI.len(), 256);
281 let set: HashSet<&str> = SAS_EMOJI.iter().copied().collect();
282 assert_eq!(set.len(), 256, "duplicate emoji in wordlist");
283 }
284
285 #[test]
286 fn transport_encryption_roundtrip() {
287 let tk = derive_transport_key(
288 &TEST_SECRET,
289 &TEST_INIT_PUB,
290 &TEST_RESP_PUB,
291 TEST_SHORT_CODE,
292 );
293 let key_bytes = *tk.as_bytes();
294 let plaintext = b"test attestation payload";
295 let ciphertext = tk.encrypt(plaintext).unwrap();
296 assert_eq!(ciphertext.len(), NONCE_LEN + plaintext.len() + TAG_LEN);
297 let decrypted = decrypt_from_transport(&ciphertext, &key_bytes).unwrap();
298 assert_eq!(decrypted, plaintext);
299 }
300
301 #[test]
302 fn transport_encryption_wrong_key() {
303 let tk = derive_transport_key(
304 &TEST_SECRET,
305 &TEST_INIT_PUB,
306 &TEST_RESP_PUB,
307 TEST_SHORT_CODE,
308 );
309 let ciphertext = tk.encrypt(b"secret").unwrap();
310 let result = decrypt_from_transport(&ciphertext, &[0xFF; 32]);
311 assert!(matches!(result, Err(ProtocolError::DecryptionFailed(_))));
312 }
313
314 #[test]
315 fn sas_test_vector() {
316 let sas = derive_sas(
319 &TEST_SECRET,
320 &TEST_INIT_PUB,
321 &TEST_RESP_PUB,
322 TEST_SHORT_CODE,
323 );
324 assert_eq!(sas, [189, 58, 161, 90, 151, 221, 243, 229]);
325 assert_eq!(format_sas_emoji(&sas), "ðŠ ðĶŽ ð ðĶ");
326 assert_eq!(format_sas_numeric(&sas), "905-509");
327 }
328
329 #[test]
330 fn mitm_simulation_produces_different_sas() {
331 let real_shared = [0x42u8; 32];
335 let attacker_shared_a = [0xAA; 32];
336 let attacker_shared_b = [0xBB; 32];
337
338 let sas_real = derive_sas(
339 &real_shared,
340 &TEST_INIT_PUB,
341 &TEST_RESP_PUB,
342 TEST_SHORT_CODE,
343 );
344 let sas_mitm_a = derive_sas(
345 &attacker_shared_a,
346 &TEST_INIT_PUB,
347 &[0x03; 32],
348 TEST_SHORT_CODE,
349 );
350 let sas_mitm_b = derive_sas(
351 &attacker_shared_b,
352 &[0x04; 32],
353 &TEST_RESP_PUB,
354 TEST_SHORT_CODE,
355 );
356
357 assert_ne!(sas_real, sas_mitm_a);
358 assert_ne!(sas_real, sas_mitm_b);
359 assert_ne!(sas_mitm_a, sas_mitm_b);
360 }
361
362 #[test]
363 fn transport_key_decrypt_short_ciphertext() {
364 let result = decrypt_from_transport(&[0u8; 10], &[0u8; 32]);
365 assert!(matches!(result, Err(ProtocolError::DecryptionFailed(_))));
366 }
367}