Skip to main content

auths_pairing_protocol/
sas.rs

1//! SAS (Short Authentication String) derivation and transport encryption.
2
3use chacha20poly1305::{
4    ChaCha20Poly1305, Nonce,
5    aead::{Aead, KeyInit},
6};
7use hkdf::Hkdf;
8use sha2::Sha256;
9use zeroize::{Zeroize, Zeroizing};
10
11use crate::error::ProtocolError;
12
/// 256-emoji wordlist — visually distinct, renders on macOS/Windows/Linux terminals.
///
/// Indexed directly by a SAS byte (see `format_sas_emoji`), so it must hold exactly
/// 256 distinct entries; the `emoji_wordlist_integrity` test checks both properties.
/// Reordering or replacing an entry changes the emoji displayed for that byte value
/// (and breaks the hardcoded SAS test vector).
pub const SAS_EMOJI: [&str; 256] = [
    "ðŸķ", "ðŸą", "🐭", "ðŸđ", "🐰", "ðŸĶŠ", "ðŸŧ", "🐞", "ðŸĻ", "ðŸŊ", "ðŸĶ", "ðŸŪ", "🐷", "ðŸļ", "ðŸĩ", "🐔",
    "🐧", "ðŸĶ", "ðŸĶ†", "ðŸĶ…", "ðŸĶ‰", "🐚", "🐗", "ðŸī", "ðŸĶ„", "🐝", "🐛", "ðŸĶ‹", "🐌", "🐞", "🐜", "ðŸŠē",
    "ðŸĒ", "🐍", "ðŸĶŽ", "ðŸĶ‚", "🐙", "ðŸĶ‘", "ðŸĶ", "ðŸĶž", "🐠", "ðŸĄ", "🐎", "ðŸĶˆ", "ðŸģ", "🐋", "🐊", "🐆",
    "🐅", "ðŸĶ“", "ðŸĶ", "ðŸĶ§", "🐘", "ðŸĶ›", "ðŸĶ", "🐊", "ðŸĶ’", "ðŸĶ˜", "ðŸĶŽ", "🐃", "🐂", "🐄", "🐎", "🐖",
    "🐏", "🐑", "🐐", "ðŸĶŒ", "🐕", "ðŸĐ", "ðŸĶŪ", "🐈", "🐓", "ðŸĶƒ", "ðŸĶĪ", "ðŸĶš", "ðŸĶœ", "ðŸĶĒ", "ðŸĶĐ", "🕊ïļ",
    "🐇", "ðŸĶ", "ðŸĶĻ", "ðŸĶĄ", "ðŸĶŦ", "ðŸĶĶ", "ðŸĶĨ", "🐁", "🐀", "ðŸŋïļ", "ðŸĶ”", "ðŸŒĩ", "🎄", "ðŸŒē", "ðŸŒģ", "ðŸŒī",
    "ðŸŠĩ", "ðŸŒą", "ðŸŒŋ", "☘ïļ", "🍀", "🎍", "ðŸŠī", "🎋", "🍃", "🍂", "🍁", "ðŸŒū", "🌚", "ðŸŒŧ", "ðŸŒđ", "ðŸĨ€",
    "🌷", "🌞", "💐", "🍄", "🌰", "🎃", "🌎", "🌍", "🌏", "🌕", "🌖", "🌗", "🌘", "🌑", "🌒", "🌓",
    "🌔", "🌙", "⭐", "🌟", "ðŸ’Ŧ", "âœĻ", "☀ïļ", "ðŸŒĪïļ", "⛅", "ðŸŒĨïļ", "ðŸŒĶïļ", "🌧ïļ", "⛈ïļ", "ðŸŒĐïļ", "ðŸŒĻïļ", "❄ïļ",
    "☃ïļ", "⛄", "🌎ïļ", "ðŸ’Ļ", "🌊ïļ", "ðŸŒŦïļ", "🌊", "💧", "ðŸ’Ķ", "ðŸ”Ĩ", "ðŸŽŊ", "🏀", "🏈", "âšū", "ðŸĨŽ", "ðŸŽū",
    "🏐", "🏉", "ðŸĨ", "ðŸŽą", "🏓", "ðŸļ", "🏒", "ðŸĨŠ", "ðŸŽŋ", "⛷ïļ", "🏂", "🊂", "🏋ïļ", "ðŸĪļ", "â›đïļ", "ðŸĪš",
    "🏇", "🧘", "🏄", "🏊", "ðŸšĢ", "🧗", "ðŸšī", "🏆", "ðŸĨ‡", "ðŸĨˆ", "ðŸĨ‰", "🏅", "🎖ïļ", "🎊", "ðŸŽĻ", "🎭",
    "ðŸŽđ", "ðŸĨ", "🎷", "🎚", "ðŸŽļ", "🊕", "ðŸŽŧ", "🎎", "ðŸŽŪ", "ðŸ•đïļ", "ðŸŽē", "ðŸ§Đ", "ðŸ”Ū", "🊄", "ðŸ§ŋ", "🎰",
    "🚀", "✈ïļ", "ðŸ›ļ", "🚁", "ðŸ›ķ", "â›ĩ", "ðŸšĪ", "ðŸ›Ĩïļ", "🚂", "🚃", "🚄", "🚅", "🚆", "🚇", "🚈", "🚊",
    "🏠", "ðŸĄ", "ðŸĒ", "ðŸĢ", "ðŸĪ", "ðŸĨ", "ðŸĶ", "ðŸĻ", "ðŸĐ", "🏊", "ðŸŦ", "🏎", "🏭", "ðŸŊ", "🏰", "💒",
    "🗞", "ðŸ—―", "⛩", "🕌", "🛕", "🕍", "â›Đïļ", "🕋", "â›ē", "⛹", "🌁", "ðŸ—ŧ", "🌋", "ðŸ—ū", "🏕ïļ", "🎠",
];
32
/// Length in bytes of the ChaCha20-Poly1305 nonce prepended to every transport blob.
const NONCE_LEN: usize = 12;
/// Length in bytes of the Poly1305 authentication tag the AEAD appends to the ciphertext.
const TAG_LEN: usize = 16;
35
36/// Derive 8 SAS bytes from the ECDH shared secret, both ephemeral public keys, and short code.
37///
38/// Args:
39/// * `shared_secret`: The 32-byte X25519 shared secret.
40/// * `initiator_pub`: The initiator's X25519 ephemeral public key.
41/// * `responder_pub`: The responder's X25519 ephemeral public key.
42/// * `short_code`: The session's short code (binds SAS to session).
43///
44/// Usage:
45/// ```ignore
46/// let sas = derive_sas(&shared_secret, &init_pub, &resp_pub, "ABC123");
47/// ```
48pub fn derive_sas(
49    shared_secret: &[u8; 32],
50    initiator_pub: &[u8; 32],
51    responder_pub: &[u8; 32],
52    short_code: &str,
53) -> [u8; 8] {
54    let salt = build_salt(initiator_pub, responder_pub);
55    let info = build_info(b"auths-pairing-sas-v1", short_code);
56
57    let hk = Hkdf::<Sha256>::new(Some(&salt), shared_secret);
58    let mut out = [0u8; 8];
59    // 8 bytes is always within HKDF-SHA256 output limit (max 8160 bytes)
60    let _ = hk.expand(&info, &mut out);
61    out
62}
63
64/// Format SAS bytes 0-3 as 4 emoji separated by double spaces.
65pub fn format_sas_emoji(sas_bytes: &[u8; 8]) -> String {
66    sas_bytes[..4]
67        .iter()
68        .map(|&b| SAS_EMOJI[b as usize])
69        .collect::<Vec<_>>()
70        .join("  ")
71}
72
/// Render SAS bytes 4-7 as a six-digit decimal code formatted `XXX-XXX`.
///
/// The four bytes are read as a big-endian `u32` and reduced modulo 1,000,000. The
/// result is printed as two zero-padded groups of three digits.
pub fn format_sas_numeric(sas_bytes: &[u8; 8]) -> String {
    let word = sas_bytes[4..]
        .iter()
        .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte));
    let code = word % 1_000_000;
    let (high, low) = (code / 1000, code % 1000);
    format!("{high:03}-{low:03}")
}
79
80/// Single-use transport encryption key derived from the ECDH shared secret.
81///
82/// Wraps a 32-byte key in `Zeroizing` and enforces single use via move semantics.
83/// `encrypt()` takes `self` by value — a second call is a compile error.
84pub struct TransportKey(Zeroizing<[u8; 32]>);
85
86impl TransportKey {
87    pub fn new(key: [u8; 32]) -> Self {
88        Self(Zeroizing::new(key))
89    }
90
91    /// Encrypt plaintext with ChaCha20-Poly1305. Consumes the key (single use).
92    ///
93    /// Output format: `[nonce:12][ciphertext+tag]`
94    pub fn encrypt(mut self, plaintext: &[u8]) -> Result<Vec<u8>, ProtocolError> {
95        let cipher = ChaCha20Poly1305::new_from_slice(&*self.0)
96            .map_err(|_| ProtocolError::EncryptionFailed("invalid key".into()))?;
97        let nonce_bytes: [u8; NONCE_LEN] = rand::random();
98        let nonce = Nonce::from(nonce_bytes);
99        let ciphertext = cipher
100            .encrypt(&nonce, plaintext)
101            .map_err(|_| ProtocolError::EncryptionFailed("encryption failed".into()))?;
102
103        self.0.zeroize();
104
105        let mut out = Vec::with_capacity(NONCE_LEN + ciphertext.len());
106        out.extend_from_slice(&nonce_bytes);
107        out.extend_from_slice(&ciphertext);
108        Ok(out)
109    }
110
111    /// Access the raw key bytes (for the responder side that needs them for decryption).
112    pub fn as_bytes(&self) -> &[u8; 32] {
113        &self.0
114    }
115}
116
117/// Decrypt ciphertext produced by `TransportKey::encrypt()`.
118///
119/// Args:
120/// * `ciphertext`: The `[nonce:12][ciphertext+tag]` blob.
121/// * `transport_key`: The 32-byte transport key.
122pub fn decrypt_from_transport(
123    ciphertext: &[u8],
124    transport_key: &[u8; 32],
125) -> Result<Vec<u8>, ProtocolError> {
126    if ciphertext.len() < NONCE_LEN + TAG_LEN {
127        return Err(ProtocolError::DecryptionFailed(
128            "ciphertext too short".into(),
129        ));
130    }
131    let (nonce_bytes, ct) = ciphertext.split_at(NONCE_LEN);
132    let nonce = Nonce::from_slice(nonce_bytes);
133    let cipher = ChaCha20Poly1305::new_from_slice(transport_key)
134        .map_err(|_| ProtocolError::DecryptionFailed("invalid key".into()))?;
135    cipher
136        .decrypt(nonce, ct)
137        .map_err(|_| ProtocolError::DecryptionFailed("decryption failed".into()))
138}
139
140/// Derive a single-use transport key from the ECDH shared secret.
141///
142/// Uses the same HKDF salt (both ephemeral public keys) but a different info string
143/// for domain separation from the SAS derivation.
144///
145/// Args:
146/// * `shared_secret`: The 32-byte X25519 shared secret.
147/// * `initiator_pub`: The initiator's X25519 ephemeral public key.
148/// * `responder_pub`: The responder's X25519 ephemeral public key.
149/// * `short_code`: The session's short code.
150pub fn derive_transport_key(
151    shared_secret: &[u8; 32],
152    initiator_pub: &[u8; 32],
153    responder_pub: &[u8; 32],
154    short_code: &str,
155) -> TransportKey {
156    let salt = build_salt(initiator_pub, responder_pub);
157    let info = build_info(b"auths-pairing-transport-v1", short_code);
158
159    let hk = Hkdf::<Sha256>::new(Some(&salt), shared_secret);
160    let mut key = [0u8; 32];
161    // 32 bytes is always within HKDF-SHA256 output limit (max 8160 bytes)
162    let _ = hk.expand(&info, &mut key);
163    TransportKey::new(key)
164}
165
/// Build the 64-byte HKDF salt `initiator_pub || responder_pub`.
///
/// The concatenation order is fixed. Both parties must supply the keys in the same
/// roles to obtain the same salt.
fn build_salt(initiator_pub: &[u8; 32], responder_pub: &[u8; 32]) -> [u8; 64] {
    let mut joined = [0u8; 64];
    let (head, tail) = joined.split_at_mut(initiator_pub.len());
    head.copy_from_slice(initiator_pub);
    tail.copy_from_slice(responder_pub);
    joined
}
172
/// Build an HKDF info string: the domain label immediately followed by the short
/// code's UTF-8 bytes, with no separator or length prefix.
fn build_info(domain: &[u8], short_code: &str) -> Vec<u8> {
    [domain, short_code.as_bytes()].concat()
}
179
// Unit tests for SAS derivation, SAS formatting, and transport encryption.
// NOTE(review): `clippy::disallowed_methods` is allowed here presumably because the
// crate's clippy config forbids methods such as `unwrap()` outside tests, and these
// tests call `unwrap()` — confirm against the project's clippy configuration.
#[cfg(test)]
#[allow(clippy::disallowed_methods)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    // Fixed, obviously synthetic inputs shared by the tests below; they are not real
    // X25519 secrets or public keys.
    const TEST_SECRET: [u8; 32] = [0x42; 32];
    const TEST_INIT_PUB: [u8; 32] = [0x01; 32];
    const TEST_RESP_PUB: [u8; 32] = [0x02; 32];
    const TEST_SHORT_CODE: &str = "ABC123";

    // Identical inputs must always produce the identical SAS.
    #[test]
    fn sas_determinism() {
        let a = derive_sas(
            &TEST_SECRET,
            &TEST_INIT_PUB,
            &TEST_RESP_PUB,
            TEST_SHORT_CODE,
        );
        let b = derive_sas(
            &TEST_SECRET,
            &TEST_INIT_PUB,
            &TEST_RESP_PUB,
            TEST_SHORT_CODE,
        );
        assert_eq!(a, b);
    }

    // A different shared secret must change the SAS.
    #[test]
    fn sas_divergence_different_secret() {
        let a = derive_sas(
            &TEST_SECRET,
            &TEST_INIT_PUB,
            &TEST_RESP_PUB,
            TEST_SHORT_CODE,
        );
        let b = derive_sas(&[0xFF; 32], &TEST_INIT_PUB, &TEST_RESP_PUB, TEST_SHORT_CODE);
        assert_ne!(a, b);
    }

    // Changing an ephemeral public key (here the initiator's) changes the HKDF salt and
    // therefore the SAS.
    #[test]
    fn sas_divergence_different_pubkeys() {
        let a = derive_sas(
            &TEST_SECRET,
            &TEST_INIT_PUB,
            &TEST_RESP_PUB,
            TEST_SHORT_CODE,
        );
        let b = derive_sas(&TEST_SECRET, &[0x03; 32], &TEST_RESP_PUB, TEST_SHORT_CODE);
        assert_ne!(a, b);
    }

    // The SAS and the transport key use different HKDF info labels, so the SAS must not
    // match the first 8 bytes of the transport key.
    #[test]
    fn domain_separation() {
        let sas = derive_sas(
            &TEST_SECRET,
            &TEST_INIT_PUB,
            &TEST_RESP_PUB,
            TEST_SHORT_CODE,
        );
        let tk = derive_transport_key(
            &TEST_SECRET,
            &TEST_INIT_PUB,
            &TEST_RESP_PUB,
            TEST_SHORT_CODE,
        );
        assert_ne!(&sas[..], &tk.as_bytes()[..8]);
    }

    // The emoji rendering is four wordlist entries separated by two spaces.
    #[test]
    fn emoji_format() {
        let sas = derive_sas(
            &TEST_SECRET,
            &TEST_INIT_PUB,
            &TEST_RESP_PUB,
            TEST_SHORT_CODE,
        );
        let emoji = format_sas_emoji(&sas);
        let parts: Vec<&str> = emoji.split("  ").collect();
        assert_eq!(parts.len(), 4);
        for part in &parts {
            assert!(SAS_EMOJI.contains(part), "emoji {part} not in wordlist");
        }
    }

    // The numeric rendering is exactly three digits, a dash, and three digits.
    #[test]
    fn numeric_format() {
        let sas = derive_sas(
            &TEST_SECRET,
            &TEST_INIT_PUB,
            &TEST_RESP_PUB,
            TEST_SHORT_CODE,
        );
        let numeric = format_sas_numeric(&sas);
        let re = regex_lite::Regex::new(r"^\d{3}-\d{3}$").unwrap();
        assert!(re.is_match(&numeric), "numeric format wrong: {numeric}");
    }

    // Every byte value must map to a distinct emoji: 256 entries, no duplicates.
    #[test]
    fn emoji_wordlist_integrity() {
        assert_eq!(SAS_EMOJI.len(), 256);
        let set: HashSet<&str> = SAS_EMOJI.iter().copied().collect();
        assert_eq!(set.len(), 256, "duplicate emoji in wordlist");
    }

    // The output is nonce || ciphertext || tag, and decrypting with the same key recovers
    // the plaintext. The key bytes are copied out before `encrypt` consumes `tk`.
    #[test]
    fn transport_encryption_roundtrip() {
        let tk = derive_transport_key(
            &TEST_SECRET,
            &TEST_INIT_PUB,
            &TEST_RESP_PUB,
            TEST_SHORT_CODE,
        );
        let key_bytes = *tk.as_bytes();
        let plaintext = b"test attestation payload";
        let ciphertext = tk.encrypt(plaintext).unwrap();
        assert_eq!(ciphertext.len(), NONCE_LEN + plaintext.len() + TAG_LEN);
        let decrypted = decrypt_from_transport(&ciphertext, &key_bytes).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    // Decrypting with a different key must fail authentication.
    #[test]
    fn transport_encryption_wrong_key() {
        let tk = derive_transport_key(
            &TEST_SECRET,
            &TEST_INIT_PUB,
            &TEST_RESP_PUB,
            TEST_SHORT_CODE,
        );
        let ciphertext = tk.encrypt(b"secret").unwrap();
        let result = decrypt_from_transport(&ciphertext, &[0xFF; 32]);
        assert!(matches!(result, Err(ProtocolError::DecryptionFailed(_))));
    }

    #[test]
    fn sas_test_vector() {
        // Hardcoded test vector to prevent implementation drift.
        // If this fails, the HKDF inputs or info string changed.
        let sas = derive_sas(
            &TEST_SECRET,
            &TEST_INIT_PUB,
            &TEST_RESP_PUB,
            TEST_SHORT_CODE,
        );
        assert_eq!(sas, [189, 58, 161, 90, 151, 221, 243, 229]);
        assert_eq!(format_sas_emoji(&sas), "🎊  ðŸĶŽ  🏉  ðŸĶ”");
        assert_eq!(format_sas_numeric(&sas), "905-509");
    }

    #[test]
    fn mitm_simulation_produces_different_sas() {
        // MITM has two separate shared secrets (one with each party).
        // Even with the same short_code, the SAS values diverge because
        // the shared secrets are different.
        let real_shared = [0x42u8; 32];
        let attacker_shared_a = [0xAA; 32];
        let attacker_shared_b = [0xBB; 32];

        let sas_real = derive_sas(
            &real_shared,
            &TEST_INIT_PUB,
            &TEST_RESP_PUB,
            TEST_SHORT_CODE,
        );
        let sas_mitm_a = derive_sas(
            &attacker_shared_a,
            &TEST_INIT_PUB,
            &[0x03; 32],
            TEST_SHORT_CODE,
        );
        let sas_mitm_b = derive_sas(
            &attacker_shared_b,
            &[0x04; 32],
            &TEST_RESP_PUB,
            TEST_SHORT_CODE,
        );

        assert_ne!(sas_real, sas_mitm_a);
        assert_ne!(sas_real, sas_mitm_b);
        assert_ne!(sas_mitm_a, sas_mitm_b);
    }

    // A 10-byte blob is shorter than nonce + tag and must be rejected up front.
    #[test]
    fn transport_key_decrypt_short_ciphertext() {
        let result = decrypt_from_transport(&[0u8; 10], &[0u8; 32]);
        assert!(matches!(result, Err(ProtocolError::DecryptionFailed(_))));
    }
}