1use anyhow::{Result, anyhow, bail};
32use chacha20poly1305::{
33 ChaCha20Poly1305, Key, Nonce,
34 aead::{Aead, KeyInit},
35};
36use hkdf::Hkdf;
37use rand::{Rng, RngCore, rngs::OsRng};
38use sha2::{Digest, Sha256};
39use spake2::{Ed25519Group, Identity, Password, Spake2};
40use std::sync::Mutex;
41
/// Number of ASCII decimal digits before the `-` in a code phrase
/// (generated zero-padded via `{:02}`).
const CODE_DIGIT_LEN: usize = 2;
/// Number of characters after the `-` in a code phrase, each drawn from
/// [`BASE32_ALPHABET`].
const CODE_TOKEN_LEN: usize = 6;
/// The RFC 4648 base32 alphabet (uppercase `A`-`Z` and digits `2`-`7`).
/// Only these exact bytes are accepted in a code-phrase token; lowercase is rejected.
const BASE32_ALPHABET: &[u8; 32] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
48
49pub fn generate_code_phrase() -> String {
55 let mut rng = OsRng;
56 let digits = rng.gen_range(0..100);
57 let mut token = String::with_capacity(CODE_TOKEN_LEN);
58 for _ in 0..CODE_TOKEN_LEN {
59 let idx = rng.gen_range(0..BASE32_ALPHABET.len());
60 token.push(BASE32_ALPHABET[idx] as char);
61 }
62 format!("{:02}-{}", digits, token)
63}
64
65pub fn parse_code_phrase(s: &str) -> Result<&str> {
67 let s = s.trim();
68 let (digits, rest) = s
69 .split_once('-')
70 .ok_or_else(|| anyhow!("code phrase missing '-' separator: {s:?}"))?;
71 if digits.len() != CODE_DIGIT_LEN || !digits.chars().all(|c| c.is_ascii_digit()) {
72 bail!("code phrase digits must be {CODE_DIGIT_LEN} ASCII digits, got {digits:?}");
73 }
74 if rest.len() != CODE_TOKEN_LEN {
75 bail!(
76 "code phrase token must be {CODE_TOKEN_LEN} chars, got {} ({rest:?})",
77 rest.len()
78 );
79 }
80 if !rest.bytes().all(|b| BASE32_ALPHABET.contains(&b)) {
81 bail!("code phrase token has non-base32 char: {rest:?}");
82 }
83 Ok(s)
84}
85
/// One side of a symmetric SPAKE2 (Ed25519 group) exchange.
///
/// Deliver `msg_out` to the peer, then call `finish` with the peer's message
/// to obtain the 32-byte shared key.
pub struct PakeSide {
    // In-progress SPAKE2 state. `finish` `take()`s it out of the `Option`, so it
    // can be consumed at most once. The `Mutex` lets `finish` do that through `&self`.
    state: Mutex<Option<Spake2<Ed25519Group>>>,
    /// Outbound SPAKE2 message produced at construction; send this to the peer.
    pub msg_out: Vec<u8>,
}
95
96impl PakeSide {
97 pub fn new(code_phrase: &str, pair_id: &[u8]) -> Self {
101 let parsed = parse_code_phrase(code_phrase).expect("invalid code phrase");
102 let (state, msg_out) = Spake2::<Ed25519Group>::start_symmetric(
103 &Password::new(parsed.as_bytes()),
104 &Identity::new(pair_id),
105 );
106 Self {
107 state: Mutex::new(Some(state)),
108 msg_out,
109 }
110 }
111
112 pub fn from_seed(code_phrase: &str, pair_id: &[u8], seed: [u8; 32]) -> Self {
120 use rand_chacha::ChaCha20Rng;
121 use rand_chacha::rand_core::SeedableRng;
122 let parsed = parse_code_phrase(code_phrase).expect("invalid code phrase");
123 let rng = ChaCha20Rng::from_seed(seed);
124 let (state, msg_out) = Spake2::<Ed25519Group>::start_symmetric_with_rng(
125 &Password::new(parsed.as_bytes()),
126 &Identity::new(pair_id),
127 rng,
128 );
129 Self {
130 state: Mutex::new(Some(state)),
131 msg_out,
132 }
133 }
134
135 pub fn finish(&self, peer_msg: &[u8]) -> Result<[u8; 32]> {
138 let state = self
139 .state
140 .lock()
141 .expect("PakeSide mutex poisoned")
142 .take()
143 .ok_or_else(|| anyhow!("PakeSide.finish called twice"))?;
144 let key = state
145 .finish(peer_msg)
146 .map_err(|e| anyhow!("SPAKE2 finish failed: {e:?}"))?;
147 let mut out = [0u8; 32];
148 let n = key.len().min(32);
149 out[..n].copy_from_slice(&key[..n]);
150 Ok(out)
151 }
152}
153
154pub fn compute_sas_pake(spake_key: &[u8], pub_a: &[u8], pub_b: &[u8]) -> String {
162 let (lo, hi) = if pub_a <= pub_b {
163 (pub_a, pub_b)
164 } else {
165 (pub_b, pub_a)
166 };
167 let mut h = Sha256::new();
168 h.update(b"wire/v1 sas");
169 h.update(spake_key);
170 h.update(lo);
171 h.update(hi);
172 let digest = h.finalize();
173 let n = u32::from_be_bytes([digest[28], digest[29], digest[30], digest[31]]);
174 format!("{:06}", n % 1_000_000)
175}
176
177pub fn derive_aead_key(spake_key: &[u8], pair_id: &[u8]) -> [u8; 32] {
179 let hk = Hkdf::<Sha256>::new(Some(pair_id), spake_key);
180 let mut out = [0u8; 32];
181 hk.expand(b"wire/v1 bootstrap-aead", &mut out)
182 .expect("HKDF expand 32 bytes is infallible");
183 out
184}
185
186pub fn seal_bootstrap(aead_key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>> {
190 let cipher = ChaCha20Poly1305::new(Key::from_slice(aead_key));
191 let mut nonce_bytes = [0u8; 12];
192 OsRng.fill_bytes(&mut nonce_bytes);
193 let nonce = Nonce::from_slice(&nonce_bytes);
194 let ct = cipher
195 .encrypt(nonce, plaintext)
196 .map_err(|e| anyhow!("seal failed: {e:?}"))?;
197 let mut out = Vec::with_capacity(12 + ct.len());
198 out.extend_from_slice(&nonce_bytes);
199 out.extend_from_slice(&ct);
200 Ok(out)
201}
202
203pub fn open_bootstrap(aead_key: &[u8; 32], blob: &[u8]) -> Result<Vec<u8>> {
205 if blob.len() < 12 + 16 {
206 bail!("bootstrap blob too short: {} bytes", blob.len());
207 }
208 let cipher = ChaCha20Poly1305::new(Key::from_slice(aead_key));
209 let nonce = Nonce::from_slice(&blob[..12]);
210 cipher
211 .decrypt(nonce, &blob[12..])
212 .map_err(|e| anyhow!("open failed (auth tag mismatch?): {e:?}"))
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn seeded_pake_side_is_deterministic() {
        let seed = [42u8; 32];
        // Identically seeded sides must emit byte-identical messages...
        let a = PakeSide::from_seed("12-ABCDEF", b"pair-id-x", seed);
        let b = PakeSide::from_seed("12-ABCDEF", b"pair-id-x", seed);
        assert_eq!(a.msg_out, b.msg_out, "msg_out diverges across same seed");

        // ...and, finishing against the same peer message, derive identical keys.
        let bob = PakeSide::new("12-ABCDEF", b"pair-id-x");
        let key_a = a.finish(&bob.msg_out).expect("a.finish");
        let key_b = b.finish(&bob.msg_out).expect("b.finish");
        assert_eq!(key_a, key_b, "shared key diverges across same seed");
    }

    #[test]
    fn seeded_pake_side_changes_with_seed() {
        let a = PakeSide::from_seed("12-ABCDEF", b"pair-id-x", [1u8; 32]);
        let b = PakeSide::from_seed("12-ABCDEF", b"pair-id-x", [2u8; 32]);
        assert_ne!(a.msg_out, b.msg_out, "msg_out collides across distinct seeds");
    }

    #[test]
    fn code_phrase_has_expected_shape() {
        let code = generate_code_phrase();
        let parsed = parse_code_phrase(&code).unwrap();
        assert_eq!(parsed, code);
        assert_eq!(code.len(), CODE_DIGIT_LEN + 1 + CODE_TOKEN_LEN);
        assert_eq!(code.chars().nth(CODE_DIGIT_LEN), Some('-'));
    }

    #[test]
    fn many_code_phrases_are_distinct() {
        // The code space is 100 * 32^6 (~1.07e11), so a collision among 1000
        // draws has probability on the order of 1e-5.
        let mut seen = std::collections::HashSet::new();
        for _ in 0..1000 {
            let c = generate_code_phrase();
            assert!(seen.insert(c.clone()), "duplicate code phrase generated: {c}");
        }
    }

    #[test]
    fn parse_rejects_malformed_codes() {
        for bad in [
            "foo",          // no separator
            "12345-ABCDEF", // too many digits
            "1a-ABCDEF",    // non-digit in the prefix
            "12-ABC",       // token too short
            "12-ABCDEF1",   // token too long
            "12-abcdef",    // lowercase is rejected, not normalised
            "12-ABCDE8",    // '8' is outside the base32 alphabet
        ] {
            assert!(parse_code_phrase(bad).is_err(), "accepted malformed code {bad:?}");
        }
    }

    #[test]
    fn parse_trims_surrounding_whitespace() {
        assert_eq!(parse_code_phrase("  12-ABCDEF\n").unwrap(), "12-ABCDEF");
    }

    #[test]
    #[should_panic(expected = "invalid code phrase")]
    fn pake_new_panics_on_invalid_code() {
        let _ = PakeSide::new("not-a-code", b"x");
    }

    #[test]
    fn pake_two_sides_derive_same_secret() {
        let code = generate_code_phrase();
        let pair_id = b"pair-id-shared";
        let alice = PakeSide::new(&code, pair_id);
        let bob = PakeSide::new(&code, pair_id);
        let alice_secret = alice.finish(&bob.msg_out).unwrap();
        let bob_secret = bob.finish(&alice.msg_out).unwrap();
        assert_eq!(alice_secret, bob_secret, "SPAKE2 secrets diverged");
    }

    #[test]
    fn pake_wrong_code_diverges() {
        let pair_id = b"pair-id-same";
        let alice = PakeSide::new("11-ABCDEF", pair_id);
        let bob = PakeSide::new("99-ZZZZZZ", pair_id);
        let alice_result = alice.finish(&bob.msg_out);
        let bob_result = bob.finish(&alice.msg_out);
        // Either side erroring is also an acceptable way to "not agree".
        let mismatch = match (alice_result, bob_result) {
            (Ok(a), Ok(b)) => a != b,
            _ => true,
        };
        assert!(mismatch, "wrong code phrase should not produce matching secrets");
    }

    #[test]
    fn pake_different_pair_id_diverges() {
        let code = "42-WIRE45";
        let alice = PakeSide::new(code, b"pair-A");
        let bob = PakeSide::new(code, b"pair-B");
        let a = alice.finish(&bob.msg_out);
        let b = bob.finish(&alice.msg_out);
        let mismatch = match (a, b) {
            (Ok(x), Ok(y)) => x != y,
            _ => true,
        };
        assert!(mismatch, "different pair_id must NOT yield same secret");
    }

    #[test]
    fn pake_finish_called_twice_errors() {
        let code = generate_code_phrase();
        let alice = PakeSide::new(&code, b"x");
        let bob = PakeSide::new(&code, b"x");
        alice.finish(&bob.msg_out).unwrap();
        let err = alice.finish(&bob.msg_out).unwrap_err();
        assert!(err.to_string().contains("twice"), "got: {err}");
    }

    #[test]
    fn sas_is_6_digits_and_symmetric() {
        let key = [42u8; 32];
        let pub_a = [1u8; 32];
        let pub_b = [2u8; 32];
        let sas_ab = compute_sas_pake(&key, &pub_a, &pub_b);
        let sas_ba = compute_sas_pake(&key, &pub_b, &pub_a);
        assert_eq!(sas_ab.len(), 6);
        assert!(sas_ab.chars().all(|c| c.is_ascii_digit()));
        assert_eq!(sas_ab, sas_ba, "SAS must be symmetric in (pub_a, pub_b)");
    }

    #[test]
    fn sas_changes_with_spake_key() {
        let pub_a = [1u8; 32];
        let pub_b = [2u8; 32];
        let sas1 = compute_sas_pake(&[1u8; 32], &pub_a, &pub_b);
        let sas2 = compute_sas_pake(&[2u8; 32], &pub_a, &pub_b);
        assert_ne!(sas1, sas2);
    }

    #[test]
    fn sas_changes_with_pubkeys() {
        let key = [42u8; 32];
        let pub_a = [1u8; 32];
        let pub_b = [2u8; 32];
        let pub_c = [3u8; 32];
        assert_ne!(
            compute_sas_pake(&key, &pub_a, &pub_b),
            compute_sas_pake(&key, &pub_a, &pub_c)
        );
    }

    #[test]
    fn aead_seal_open_round_trip() {
        let key = derive_aead_key(&[42u8; 32], b"pair-id");
        let plaintext = b"some bootstrap payload bytes";
        let sealed = seal_bootstrap(&key, plaintext).unwrap();
        let opened = open_bootstrap(&key, &sealed).unwrap();
        assert_eq!(opened, plaintext);
    }

    #[test]
    fn aead_seal_uses_fresh_nonce() {
        let key = derive_aead_key(&[7u8; 32], b"x");
        let s1 = seal_bootstrap(&key, b"same").unwrap();
        let s2 = seal_bootstrap(&key, b"same").unwrap();
        assert_ne!(&s1[..12], &s2[..12], "nonce reused across seals");
    }

    #[test]
    fn aead_open_with_wrong_key_fails() {
        let key1 = derive_aead_key(&[1u8; 32], b"x");
        let key2 = derive_aead_key(&[2u8; 32], b"x");
        let sealed = seal_bootstrap(&key1, b"secret").unwrap();
        let result = open_bootstrap(&key2, &sealed);
        assert!(result.is_err(), "wrong key must fail AEAD auth");
    }

    #[test]
    fn aead_open_tampered_blob_fails() {
        let key = derive_aead_key(&[7u8; 32], b"x");
        let mut sealed = seal_bootstrap(&key, b"payload").unwrap();
        let last = sealed.len() - 1;
        sealed[last] ^= 0x01;
        assert!(open_bootstrap(&key, &sealed).is_err(), "tampered blob must fail AEAD auth");
    }

    #[test]
    fn aead_open_with_truncated_blob_fails() {
        let key = derive_aead_key(&[42u8; 32], b"x");
        let result = open_bootstrap(&key, b"too short");
        assert!(result.is_err());
    }

    #[test]
    fn full_pake_to_sealed_payload_round_trip() {
        let code = generate_code_phrase();
        let pair_id = b"e2e-pair";
        let paul = PakeSide::new(&code, pair_id);
        let willard = PakeSide::new(&code, pair_id);

        let paul_msg = paul.msg_out.clone();
        let willard_msg = willard.msg_out.clone();
        let paul_secret = paul.finish(&willard_msg).unwrap();
        let willard_secret = willard.finish(&paul_msg).unwrap();
        assert_eq!(paul_secret, willard_secret);

        let paul_aead_key = derive_aead_key(&paul_secret, pair_id);
        let willard_aead_key = derive_aead_key(&willard_secret, pair_id);
        assert_eq!(paul_aead_key, willard_aead_key);

        let paul_card_bytes = b"{\"did\":\"did:wire:paul\", ...}";
        let sealed = seal_bootstrap(&paul_aead_key, paul_card_bytes).unwrap();
        let opened = open_bootstrap(&willard_aead_key, &sealed).unwrap();
        assert_eq!(opened, paul_card_bytes);

        // Each side passes its own public key first; the SAS must still agree.
        let pub_a = [9u8; 32];
        let pub_b = [10u8; 32];
        let sas_paul = compute_sas_pake(&paul_secret, &pub_a, &pub_b);
        let sas_willard = compute_sas_pake(&willard_secret, &pub_b, &pub_a);
        assert_eq!(sas_paul, sas_willard);
    }
}