chie_crypto/
spake2.rs

1//! SPAKE2 - Simple Password-Authenticated Key Exchange.
2//!
3//! SPAKE2 is a password-authenticated key exchange (PAKE) protocol that allows two parties
4//! who share a password to derive a strong shared secret. It provides protection against
5//! offline dictionary attacks.
6//!
7//! # Features
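//! - Symmetric (balanced) PAKE: both parties use the same password
//! - Protection against offline dictionary attacks
//! - Forward secrecy
//! - Simple and efficient
//!
//! Note that `Spake2::finish` performs no key confirmation. If the passwords differ, both
//! sides still obtain a secret, but the two secrets will not match. Applications must
//! confirm the key (for example, by exchanging a MAC over the transcript) before trusting
//! the peer.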
12//!
13//! # Example
14//! ```
15//! use chie_crypto::spake2::{Spake2, Spake2Side};
16//!
17//! // Alice and Bob share a password
18//! let password = b"shared-secret-password";
19//!
20//! // Alice starts the protocol
21//! let (alice, alice_msg) = Spake2::start(Spake2Side::Alice, password);
22//!
23//! // Bob starts the protocol
24//! let (bob, bob_msg) = Spake2::start(Spake2Side::Bob, password);
25//!
26//! // They exchange messages and derive the shared secret
27//! let alice_secret = alice.finish(&bob_msg).unwrap();
28//! let bob_secret = bob.finish(&alice_msg).unwrap();
29//!
30//! // Shared secrets match
31//! assert_eq!(alice_secret, bob_secret);
32//! ```
33
34use crate::{hash, hkdf_extract_expand};
35use curve25519_dalek::{
36    constants::RISTRETTO_BASEPOINT_POINT,
37    ristretto::{CompressedRistretto, RistrettoPoint},
38    scalar::Scalar,
39};
40use rand::Rng as _;
41use serde::{Deserialize, Serialize};
42use thiserror::Error;
43use zeroize::{Zeroize, ZeroizeOnDrop};
44
45/// SPAKE2 error types.
46#[derive(Error, Debug)]
47pub enum Spake2Error {
48    #[error("Invalid message format")]
49    InvalidMessage,
50    #[error("Protocol not in correct state")]
51    InvalidState,
52    #[error("Shared secret derivation failed")]
53    DerivationFailed,
54    #[error("Point decompression failed")]
55    DecompressionFailed,
56}
57
58/// SPAKE2 result type.
59pub type Spake2Result<T> = Result<T, Spake2Error>;
60
/// Role a party plays in the SPAKE2 protocol.
///
/// The two participants must take opposite roles: Alice blinds her message with the
/// constant `M` and Bob blinds his with `N`. Two parties that both act as Alice (or both
/// as Bob) end up with different secrets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Spake2Side {
    /// The party whose message is blinded with `M`.
    Alice,
    /// The party whose message is blinded with `N`.
    Bob,
}
67
68/// SPAKE2 protocol message.
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct Spake2Message {
71    point: [u8; 32],
72}
73
74impl Spake2Message {
75    /// Create message from a point.
76    fn new(point: &RistrettoPoint) -> Self {
77        Self {
78            point: point.compress().to_bytes(),
79        }
80    }
81
82    /// Decompress the point.
83    fn to_point(&self) -> Spake2Result<RistrettoPoint> {
84        CompressedRistretto::from_slice(&self.point)
85            .map_err(|_| Spake2Error::InvalidMessage)?
86            .decompress()
87            .ok_or(Spake2Error::DecompressionFailed)
88    }
89}
90
91/// Shared secret derived from SPAKE2.
92#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
93pub struct Spake2SharedSecret {
94    secret: Vec<u8>,
95}
96
97impl Spake2SharedSecret {
98    /// Get the shared secret as bytes.
99    pub fn as_bytes(&self) -> &[u8] {
100        &self.secret
101    }
102
103    /// Derive a key from the shared secret.
104    pub fn derive_key(&self, info: &[u8], len: usize) -> Spake2Result<Vec<u8>> {
105        let mut output = vec![0u8; len];
106        let expanded = hkdf_extract_expand(&self.secret, b"", info);
107        output[..len.min(32)].copy_from_slice(&expanded[..len.min(32)]);
108        if len > 32 {
109            // For longer keys, hash multiple times
110            for i in (32..len).step_by(32) {
111                let mut info_extended = info.to_vec();
112                info_extended.extend_from_slice(&[i as u8]);
113                let expanded = hkdf_extract_expand(&self.secret, b"", &info_extended);
114                let end = (i + 32).min(len);
115                output[i..end].copy_from_slice(&expanded[..(end - i)]);
116            }
117        }
118        Ok(output)
119    }
120}
121
122impl PartialEq for Spake2SharedSecret {
123    fn eq(&self, other: &Self) -> bool {
124        use subtle::ConstantTimeEq;
125        self.secret.ct_eq(&other.secret).into()
126    }
127}
128
129impl Eq for Spake2SharedSecret {}
130
131/// SPAKE2 protocol state machine.
132pub struct Spake2 {
133    side: Spake2Side,
134    password_scalar: Scalar,
135    secret_scalar: Scalar,
136    public_point: RistrettoPoint,
137}
138
139impl Spake2 {
140    // SPAKE2 constants M and N (nothing-up-my-sleeve values)
141    // These are derived from the string "chie-spake2-M" and "chie-spake2-N"
142    fn constant_m() -> RistrettoPoint {
143        let hash1 = hash(b"chie-spake2-M");
144        let hash2 = hash(b"chie-spake2-M-2");
145        let mut bytes = [0u8; 64];
146        bytes[..32].copy_from_slice(&hash1);
147        bytes[32..].copy_from_slice(&hash2);
148        RistrettoPoint::from_uniform_bytes(&bytes)
149    }
150
151    fn constant_n() -> RistrettoPoint {
152        let hash1 = hash(b"chie-spake2-N");
153        let hash2 = hash(b"chie-spake2-N-2");
154        let mut bytes = [0u8; 64];
155        bytes[..32].copy_from_slice(&hash1);
156        bytes[32..].copy_from_slice(&hash2);
157        RistrettoPoint::from_uniform_bytes(&bytes)
158    }
159
160    /// Start the SPAKE2 protocol.
161    ///
162    /// Returns the protocol state and the message to send to the other party.
163    pub fn start(side: Spake2Side, password: &[u8]) -> (Self, Spake2Message) {
164        // Hash password to scalar
165        let password_hash = hash(password);
166        let password_scalar = Scalar::from_bytes_mod_order(password_hash);
167
168        // Generate random secret scalar
169        let mut rng = rand::thread_rng();
170        let secret_bytes: [u8; 32] = {
171            let mut arr = [0u8; 32];
172            rng.fill(&mut arr);
173            arr
174        };
175        let secret_scalar = Scalar::from_bytes_mod_order(secret_bytes);
176
177        // Compute public point: X = x*G + w*M (Alice) or Y = y*G + w*N (Bob)
178        let base_point = secret_scalar * RISTRETTO_BASEPOINT_POINT;
179        let password_point = match side {
180            Spake2Side::Alice => password_scalar * Self::constant_m(),
181            Spake2Side::Bob => password_scalar * Self::constant_n(),
182        };
183        let public_point = base_point + password_point;
184
185        let message = Spake2Message::new(&public_point);
186
187        let state = Self {
188            side,
189            password_scalar,
190            secret_scalar,
191            public_point,
192        };
193
194        (state, message)
195    }
196
197    /// Finish the SPAKE2 protocol using the other party's message.
198    ///
199    /// Returns the shared secret.
200    pub fn finish(self, other_message: &Spake2Message) -> Spake2Result<Spake2SharedSecret> {
201        // Decompress received point
202        let received_point = other_message.to_point()?;
203
204        // Remove password component from received point
205        let password_component = match self.side {
206            // Alice computes: Z = Y - w*N
207            Spake2Side::Alice => self.password_scalar * Self::constant_n(),
208            // Bob computes: Z = X - w*M
209            Spake2Side::Bob => self.password_scalar * Self::constant_m(),
210        };
211
212        let shared_point = received_point - password_component;
213
214        // Compute shared secret: K = x*Z (Alice) or K = y*Z (Bob)
215        let key_point = self.secret_scalar * shared_point;
216
217        // Derive shared secret using transcript hash
218        let transcript = self.compute_transcript(&received_point);
219        let key_material = key_point.compress().to_bytes();
220
221        // Use HKDF to derive the shared secret
222        let secret = hkdf_extract_expand(&key_material, &transcript, b"SPAKE2 Key").to_vec();
223
224        Ok(Spake2SharedSecret { secret })
225    }
226
227    /// Compute protocol transcript for key derivation.
228    fn compute_transcript(&self, other_point: &RistrettoPoint) -> Vec<u8> {
229        let mut transcript = Vec::new();
230
231        // Include both public points in a canonical order (Alice's first, Bob's second)
232        let (alice_point, bob_point) = match self.side {
233            Spake2Side::Alice => (self.public_point, *other_point),
234            Spake2Side::Bob => (*other_point, self.public_point),
235        };
236
237        transcript.extend_from_slice(&alice_point.compress().to_bytes());
238        transcript.extend_from_slice(&bob_point.compress().to_bytes());
239
240        transcript
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    /// Run one complete exchange, with Alice using `alice_pw` and Bob using `bob_pw`.
    /// Returns `(alice_secret, bob_secret)`.
    fn exchange(alice_pw: &[u8], bob_pw: &[u8]) -> (Spake2SharedSecret, Spake2SharedSecret) {
        let (alice, alice_msg) = Spake2::start(Spake2Side::Alice, alice_pw);
        let (bob, bob_msg) = Spake2::start(Spake2Side::Bob, bob_pw);
        (
            alice.finish(&bob_msg).unwrap(),
            bob.finish(&alice_msg).unwrap(),
        )
    }

    #[test]
    fn test_spake2_basic() {
        let password = b"shared-secret-password";
        let (alice_secret, bob_secret) = exchange(password, password);
        assert_eq!(alice_secret, bob_secret);
    }

    #[test]
    fn test_spake2_different_passwords_fail() {
        // Mismatched passwords still complete, but the secrets must not agree.
        let (alice_secret, bob_secret) = exchange(b"password1", b"password2");
        assert_ne!(alice_secret, bob_secret);
    }

    #[test]
    fn test_spake2_deterministic_with_same_password() {
        // Two independent runs with the same password must produce different secrets,
        // because every run draws fresh ephemeral scalars.
        let password = b"test-password";

        let (first_alice, _first_alice_msg) = Spake2::start(Spake2Side::Alice, password);
        let (_first_bob, first_bob_msg) = Spake2::start(Spake2Side::Bob, password);
        let (second_alice, _second_alice_msg) = Spake2::start(Spake2Side::Alice, password);
        let (_second_bob, second_bob_msg) = Spake2::start(Spake2Side::Bob, password);

        let first = first_alice.finish(&first_bob_msg).unwrap();
        let second = second_alice.finish(&second_bob_msg).unwrap();
        assert_ne!(first, second);
    }

    #[test]
    fn test_spake2_key_derivation() {
        let password = b"shared-secret";
        let (alice_secret, bob_secret) = exchange(password, password);

        let alice_key = alice_secret.derive_key(b"app-key", 32).unwrap();
        let bob_key = bob_secret.derive_key(b"app-key", 32).unwrap();
        assert_eq!(alice_key, bob_key);
    }

    #[test]
    fn test_spake2_message_serialization() {
        let (_alice, alice_msg) = Spake2::start(Spake2Side::Alice, b"test");

        // A codec round trip must preserve a decodable point.
        let encoded = crate::codec::encode(&alice_msg).unwrap();
        let decoded: Spake2Message = crate::codec::decode(&encoded).unwrap();
        assert!(decoded.to_point().is_ok());
    }

    #[test]
    fn test_spake2_wrong_side_fails() {
        // Two parties that both claim to be Alice must not agree on a secret.
        let password = b"password";
        let (first, first_msg) = Spake2::start(Spake2Side::Alice, password);
        let (second, second_msg) = Spake2::start(Spake2Side::Alice, password);

        let first_secret = first.finish(&second_msg).unwrap();
        let second_secret = second.finish(&first_msg).unwrap();
        assert_ne!(first_secret, second_secret);
    }

    #[test]
    fn test_spake2_multiple_sessions() {
        let password = b"shared-password";

        let (session1_alice, session1_bob) = exchange(password, password);
        assert_eq!(session1_alice, session1_bob);

        let (session2_alice, session2_bob) = exchange(password, password);
        assert_eq!(session2_alice, session2_bob);

        // Fresh randomness per session means the keys must differ across sessions.
        assert_ne!(session1_alice, session2_alice);
    }

    #[test]
    fn test_spake2_empty_password() {
        let (alice_secret, bob_secret) = exchange(b"", b"");
        assert_eq!(alice_secret, bob_secret);
    }

    #[test]
    fn test_spake2_long_password() {
        let password =
            b"this-is-a-very-long-password-with-many-characters-to-test-long-input-handling";
        let (alice_secret, bob_secret) = exchange(password, password);
        assert_eq!(alice_secret, bob_secret);
    }

    #[test]
    fn test_spake2_binary_password() {
        let password: Vec<u8> = (0..=255).collect();
        let (alice_secret, bob_secret) = exchange(&password, &password);
        assert_eq!(alice_secret, bob_secret);
    }

    #[test]
    fn test_spake2_constants_different() {
        let m_bytes = Spake2::constant_m().compress().to_bytes();
        let n_bytes = Spake2::constant_n().compress().to_bytes();
        assert_ne!(m_bytes, n_bytes);
    }
}