1use crate::{hash, hkdf_extract_expand};
35use curve25519_dalek::{
36 constants::RISTRETTO_BASEPOINT_POINT,
37 ristretto::{CompressedRistretto, RistrettoPoint},
38 scalar::Scalar,
39};
40use rand::Rng as _;
41use serde::{Deserialize, Serialize};
42use thiserror::Error;
43use zeroize::{Zeroize, ZeroizeOnDrop};
44
45#[derive(Error, Debug)]
47pub enum Spake2Error {
48 #[error("Invalid message format")]
49 InvalidMessage,
50 #[error("Protocol not in correct state")]
51 InvalidState,
52 #[error("Shared secret derivation failed")]
53 DerivationFailed,
54 #[error("Point decompression failed")]
55 DecompressionFailed,
56}
57
58pub type Spake2Result<T> = Result<T, Spake2Error>;
60
/// Role a party plays in the exchange.
///
/// The two peers must take opposite roles. `Alice` blinds her message with the
/// constant `M` and unblinds Bob's with `N`; `Bob` does the reverse.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Spake2Side {
    /// Party whose outgoing message is blinded with `M`.
    Alice,
    /// Party whose outgoing message is blinded with `N`.
    Bob,
}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct Spake2Message {
71 point: [u8; 32],
72}
73
74impl Spake2Message {
75 fn new(point: &RistrettoPoint) -> Self {
77 Self {
78 point: point.compress().to_bytes(),
79 }
80 }
81
82 fn to_point(&self) -> Spake2Result<RistrettoPoint> {
84 CompressedRistretto::from_slice(&self.point)
85 .map_err(|_| Spake2Error::InvalidMessage)?
86 .decompress()
87 .ok_or(Spake2Error::DecompressionFailed)
88 }
89}
90
91#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
93pub struct Spake2SharedSecret {
94 secret: Vec<u8>,
95}
96
97impl Spake2SharedSecret {
98 pub fn as_bytes(&self) -> &[u8] {
100 &self.secret
101 }
102
103 pub fn derive_key(&self, info: &[u8], len: usize) -> Spake2Result<Vec<u8>> {
105 let mut output = vec![0u8; len];
106 let expanded = hkdf_extract_expand(&self.secret, b"", info);
107 output[..len.min(32)].copy_from_slice(&expanded[..len.min(32)]);
108 if len > 32 {
109 for i in (32..len).step_by(32) {
111 let mut info_extended = info.to_vec();
112 info_extended.extend_from_slice(&[i as u8]);
113 let expanded = hkdf_extract_expand(&self.secret, b"", &info_extended);
114 let end = (i + 32).min(len);
115 output[i..end].copy_from_slice(&expanded[..(end - i)]);
116 }
117 }
118 Ok(output)
119 }
120}
121
122impl PartialEq for Spake2SharedSecret {
123 fn eq(&self, other: &Self) -> bool {
124 use subtle::ConstantTimeEq;
125 self.secret.ct_eq(&other.secret).into()
126 }
127}
128
129impl Eq for Spake2SharedSecret {}
130
131pub struct Spake2 {
133 side: Spake2Side,
134 password_scalar: Scalar,
135 secret_scalar: Scalar,
136 public_point: RistrettoPoint,
137}
138
139impl Spake2 {
140 fn constant_m() -> RistrettoPoint {
143 let hash1 = hash(b"chie-spake2-M");
144 let hash2 = hash(b"chie-spake2-M-2");
145 let mut bytes = [0u8; 64];
146 bytes[..32].copy_from_slice(&hash1);
147 bytes[32..].copy_from_slice(&hash2);
148 RistrettoPoint::from_uniform_bytes(&bytes)
149 }
150
151 fn constant_n() -> RistrettoPoint {
152 let hash1 = hash(b"chie-spake2-N");
153 let hash2 = hash(b"chie-spake2-N-2");
154 let mut bytes = [0u8; 64];
155 bytes[..32].copy_from_slice(&hash1);
156 bytes[32..].copy_from_slice(&hash2);
157 RistrettoPoint::from_uniform_bytes(&bytes)
158 }
159
160 pub fn start(side: Spake2Side, password: &[u8]) -> (Self, Spake2Message) {
164 let password_hash = hash(password);
166 let password_scalar = Scalar::from_bytes_mod_order(password_hash);
167
168 let mut rng = rand::thread_rng();
170 let secret_bytes: [u8; 32] = {
171 let mut arr = [0u8; 32];
172 rng.fill(&mut arr);
173 arr
174 };
175 let secret_scalar = Scalar::from_bytes_mod_order(secret_bytes);
176
177 let base_point = secret_scalar * RISTRETTO_BASEPOINT_POINT;
179 let password_point = match side {
180 Spake2Side::Alice => password_scalar * Self::constant_m(),
181 Spake2Side::Bob => password_scalar * Self::constant_n(),
182 };
183 let public_point = base_point + password_point;
184
185 let message = Spake2Message::new(&public_point);
186
187 let state = Self {
188 side,
189 password_scalar,
190 secret_scalar,
191 public_point,
192 };
193
194 (state, message)
195 }
196
197 pub fn finish(self, other_message: &Spake2Message) -> Spake2Result<Spake2SharedSecret> {
201 let received_point = other_message.to_point()?;
203
204 let password_component = match self.side {
206 Spake2Side::Alice => self.password_scalar * Self::constant_n(),
208 Spake2Side::Bob => self.password_scalar * Self::constant_m(),
210 };
211
212 let shared_point = received_point - password_component;
213
214 let key_point = self.secret_scalar * shared_point;
216
217 let transcript = self.compute_transcript(&received_point);
219 let key_material = key_point.compress().to_bytes();
220
221 let secret = hkdf_extract_expand(&key_material, &transcript, b"SPAKE2 Key").to_vec();
223
224 Ok(Spake2SharedSecret { secret })
225 }
226
227 fn compute_transcript(&self, other_point: &RistrettoPoint) -> Vec<u8> {
229 let mut transcript = Vec::new();
230
231 let (alice_point, bob_point) = match self.side {
233 Spake2Side::Alice => (self.public_point, *other_point),
234 Spake2Side::Bob => (*other_point, self.public_point),
235 };
236
237 transcript.extend_from_slice(&alice_point.compress().to_bytes());
238 transcript.extend_from_slice(&bob_point.compress().to_bytes());
239
240 transcript
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spake2_basic() {
        let password = b"shared-secret-password";

        let (alice, alice_msg) = Spake2::start(Spake2Side::Alice, password);
        let (bob, bob_msg) = Spake2::start(Spake2Side::Bob, password);

        let alice_secret = alice.finish(&bob_msg).unwrap();
        let bob_secret = bob.finish(&alice_msg).unwrap();

        assert_eq!(alice_secret, bob_secret);
    }

    #[test]
    fn test_spake2_different_passwords_fail() {
        let alice_password = b"password1";
        let bob_password = b"password2";

        let (alice, alice_msg) = Spake2::start(Spake2Side::Alice, alice_password);
        let (bob, bob_msg) = Spake2::start(Spake2Side::Bob, bob_password);

        let alice_secret = alice.finish(&bob_msg).unwrap();
        let bob_secret = bob.finish(&alice_msg).unwrap();

        assert_ne!(alice_secret, bob_secret);
    }

    /// Two independent sessions with the same password must not produce the same
    /// secret, because each session uses a fresh ephemeral scalar.
    #[test]
    fn test_spake2_fresh_randomness_per_session() {
        let password = b"test-password";

        let (alice1, _alice_msg1) = Spake2::start(Spake2Side::Alice, password);
        let (_bob1, bob_msg1) = Spake2::start(Spake2Side::Bob, password);

        let (alice2, _alice_msg2) = Spake2::start(Spake2Side::Alice, password);
        let (_bob2, bob_msg2) = Spake2::start(Spake2Side::Bob, password);

        let secret1 = alice1.finish(&bob_msg1).unwrap();
        let secret2 = alice2.finish(&bob_msg2).unwrap();

        assert_ne!(secret1, secret2);
    }

    #[test]
    fn test_spake2_key_derivation() {
        let password = b"shared-secret";

        let (alice, alice_msg) = Spake2::start(Spake2Side::Alice, password);
        let (bob, bob_msg) = Spake2::start(Spake2Side::Bob, password);

        let alice_secret = alice.finish(&bob_msg).unwrap();
        let bob_secret = bob.finish(&alice_msg).unwrap();

        let alice_key = alice_secret.derive_key(b"app-key", 32).unwrap();
        let bob_key = bob_secret.derive_key(b"app-key", 32).unwrap();

        assert_eq!(alice_key, bob_key);
    }

    /// Longer derivations agree between peers, have the requested length, and
    /// extend shorter ones.
    #[test]
    fn test_spake2_derive_key_prefix_consistent() {
        let password = b"prefix-password";

        let (alice, alice_msg) = Spake2::start(Spake2Side::Alice, password);
        let (bob, bob_msg) = Spake2::start(Spake2Side::Bob, password);

        let alice_secret = alice.finish(&bob_msg).unwrap();
        let bob_secret = bob.finish(&alice_msg).unwrap();

        let long_a = alice_secret.derive_key(b"stream", 100).unwrap();
        let long_b = bob_secret.derive_key(b"stream", 100).unwrap();
        assert_eq!(long_a.len(), 100);
        assert_eq!(long_a, long_b);

        let short = alice_secret.derive_key(b"stream", 32).unwrap();
        assert_eq!(&long_a[..32], &short[..]);

        assert!(alice_secret.derive_key(b"stream", 0).unwrap().is_empty());
    }

    #[test]
    fn test_spake2_derive_key_distinct_info() {
        let password = b"info-password";

        let (alice, _alice_msg) = Spake2::start(Spake2Side::Alice, password);
        let (_bob, bob_msg) = Spake2::start(Spake2Side::Bob, password);
        let secret = alice.finish(&bob_msg).unwrap();

        let key_a = secret.derive_key(b"label-a", 32).unwrap();
        let key_b = secret.derive_key(b"label-b", 32).unwrap();
        assert_ne!(key_a, key_b);
    }

    #[test]
    fn test_spake2_message_serialization() {
        let password = b"test";
        let (_alice, alice_msg) = Spake2::start(Spake2Side::Alice, password);

        let serialized = crate::codec::encode(&alice_msg).unwrap();
        let deserialized: Spake2Message = crate::codec::decode(&serialized).unwrap();

        assert!(deserialized.to_point().is_ok());
    }

    /// Bytes that are not a canonical Ristretto encoding must be rejected rather
    /// than fed into key derivation.
    #[test]
    fn test_spake2_rejects_invalid_point() {
        let (alice, _alice_msg) = Spake2::start(Spake2Side::Alice, b"password");
        let bogus = Spake2Message { point: [0xFF; 32] };

        let result = alice.finish(&bogus);
        assert!(matches!(result, Err(Spake2Error::DecompressionFailed)));
    }

    #[test]
    fn test_spake2_wrong_side_fails() {
        let password = b"password";

        let (alice1, alice_msg1) = Spake2::start(Spake2Side::Alice, password);
        let (alice2, alice_msg2) = Spake2::start(Spake2Side::Alice, password);

        let secret1 = alice1.finish(&alice_msg2).unwrap();
        let secret2 = alice2.finish(&alice_msg1).unwrap();

        assert_ne!(secret1, secret2);
    }

    #[test]
    fn test_spake2_multiple_sessions() {
        let password = b"shared-password";

        let (alice1, alice_msg1) = Spake2::start(Spake2Side::Alice, password);
        let (bob1, bob_msg1) = Spake2::start(Spake2Side::Bob, password);
        let secret1_a = alice1.finish(&bob_msg1).unwrap();
        let secret1_b = bob1.finish(&alice_msg1).unwrap();
        assert_eq!(secret1_a, secret1_b);

        let (alice2, alice_msg2) = Spake2::start(Spake2Side::Alice, password);
        let (bob2, bob_msg2) = Spake2::start(Spake2Side::Bob, password);
        let secret2_a = alice2.finish(&bob_msg2).unwrap();
        let secret2_b = bob2.finish(&alice_msg2).unwrap();
        assert_eq!(secret2_a, secret2_b);

        assert_ne!(secret1_a, secret2_a);
    }

    #[test]
    fn test_spake2_empty_password() {
        let password = b"";

        let (alice, alice_msg) = Spake2::start(Spake2Side::Alice, password);
        let (bob, bob_msg) = Spake2::start(Spake2Side::Bob, password);

        let alice_secret = alice.finish(&bob_msg).unwrap();
        let bob_secret = bob.finish(&alice_msg).unwrap();

        assert_eq!(alice_secret, bob_secret);
    }

    #[test]
    fn test_spake2_long_password() {
        let password =
            b"this-is-a-very-long-password-with-many-characters-to-test-long-input-handling";

        let (alice, alice_msg) = Spake2::start(Spake2Side::Alice, password);
        let (bob, bob_msg) = Spake2::start(Spake2Side::Bob, password);

        let alice_secret = alice.finish(&bob_msg).unwrap();
        let bob_secret = bob.finish(&alice_msg).unwrap();

        assert_eq!(alice_secret, bob_secret);
    }

    #[test]
    fn test_spake2_binary_password() {
        let password: Vec<u8> = (0..=255).collect();

        let (alice, alice_msg) = Spake2::start(Spake2Side::Alice, &password);
        let (bob, bob_msg) = Spake2::start(Spake2Side::Bob, &password);

        let alice_secret = alice.finish(&bob_msg).unwrap();
        let bob_secret = bob.finish(&alice_msg).unwrap();

        assert_eq!(alice_secret, bob_secret);
    }

    #[test]
    fn test_spake2_constants_different() {
        let m = Spake2::constant_m();
        let n = Spake2::constant_n();

        assert_ne!(m.compress().to_bytes(), n.compress().to_bytes());
    }
}