1use crate::{hash, hkdf_extract_expand};
41use curve25519_dalek::{
42 constants::RISTRETTO_BASEPOINT_POINT,
43 ristretto::{CompressedRistretto, RistrettoPoint},
44 scalar::Scalar,
45};
46use rand::Rng as _;
47use serde::{Deserialize, Serialize};
48use thiserror::Error;
49use zeroize::{Zeroize, ZeroizeOnDrop};
50
/// Errors produced by SRP verifier handling and by the client/server key
/// computations in this module.
#[derive(Error, Debug)]
pub enum SrpError {
    /// Stored verifier bytes could not be decoded (returned by
    /// `SrpVerifier::from_bytes` and `SrpVerifier::verifier_point`).
    #[error("Invalid verifier")]
    InvalidVerifier,
    /// A peer's public value could not be accepted (returned by
    /// `SrpPublicKey::to_point`).
    #[error("Invalid public key")]
    InvalidPublicKey,
    /// Generic failure while computing derived values.
    #[error("Computation failed")]
    ComputationFailed,
    /// 32 bytes were not a canonical compressed Ristretto encoding
    /// (`CompressedRistretto::decompress` returned `None`).
    #[error("Point decompression failed")]
    DecompressionFailed,
}
63
/// Convenience alias for results whose error type is [`SrpError`].
pub type SrpResult<T> = Result<T, SrpError>;
66
/// Server-side record for one user.
///
/// Holds a random salt and the compressed Ristretto point `v = x·G`. The
/// exponent `x` is derived from the salt, username and password in
/// `SrpVerifier::generate`.
///
/// NOTE(review): `to_bytes`/`from_bytes` serialize this through serde via
/// `crate::codec`. If that codec is positional, field order is part of the
/// stored format and must not change — confirm before reordering.
/// NOTE(review): the derived `Debug` prints the verifier bytes. A verifier
/// enables offline password guessing, so consider redacting it.
#[derive(Debug, Clone, Serialize, Deserialize, Zeroize, ZeroizeOnDrop)]
pub struct SrpVerifier {
    /// Per-user random salt, handed to clients through `salt()`. It is
    /// excluded from zeroization.
    #[zeroize(skip)]
    salt: [u8; 32],
    /// Compressed encoding of `v = x·G`; wiped on drop.
    verifier: [u8; 32],
}
74
75impl SrpVerifier {
76 pub fn generate(username: &[u8], password: &[u8]) -> Self {
80 let mut rng = rand::thread_rng();
82 let salt: [u8; 32] = {
83 let mut arr = [0u8; 32];
84 rng.fill(&mut arr);
85 arr
86 };
87
88 let mut identity = Vec::new();
90 identity.extend_from_slice(username);
91 identity.push(b':');
92 identity.extend_from_slice(password);
93 let identity_hash = hash(&identity);
94
95 let mut x_input = Vec::new();
96 x_input.extend_from_slice(&salt);
97 x_input.extend_from_slice(&identity_hash);
98 let x_hash = hash(&x_input);
99 let x = Scalar::from_bytes_mod_order(x_hash);
100
101 let v_point = x * RISTRETTO_BASEPOINT_POINT;
103 let verifier = v_point.compress().to_bytes();
104
105 Self { salt, verifier }
106 }
107
108 pub fn salt(&self) -> &[u8; 32] {
110 &self.salt
111 }
112
113 fn verifier_point(&self) -> SrpResult<RistrettoPoint> {
115 CompressedRistretto::from_slice(&self.verifier)
116 .map_err(|_| SrpError::InvalidVerifier)?
117 .decompress()
118 .ok_or(SrpError::DecompressionFailed)
119 }
120
121 pub fn to_bytes(&self) -> Vec<u8> {
123 crate::codec::encode(self).unwrap()
124 }
125
126 pub fn from_bytes(bytes: &[u8]) -> SrpResult<Self> {
128 crate::codec::decode(bytes).map_err(|_| SrpError::InvalidVerifier)
129 }
130}
131
132#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
134pub struct SrpSessionKey {
135 key: Vec<u8>,
136}
137
138impl SrpSessionKey {
139 pub fn as_bytes(&self) -> &[u8] {
141 &self.key
142 }
143
144 pub fn derive_key(&self, info: &[u8], len: usize) -> SrpResult<Vec<u8>> {
146 let mut output = vec![0u8; len];
147 let expanded = hkdf_extract_expand(&self.key, b"", info);
148 output[..len.min(32)].copy_from_slice(&expanded[..len.min(32)]);
149 if len > 32 {
150 for i in (32..len).step_by(32) {
152 let mut info_extended = info.to_vec();
153 info_extended.extend_from_slice(&[i as u8]);
154 let expanded = hkdf_extract_expand(&self.key, b"", &info_extended);
155 let end = (i + 32).min(len);
156 output[i..end].copy_from_slice(&expanded[..(end - i)]);
157 }
158 }
159 Ok(output)
160 }
161}
162
163impl PartialEq for SrpSessionKey {
164 fn eq(&self, other: &Self) -> bool {
165 use subtle::ConstantTimeEq;
166 self.key.ct_eq(&other.key).into()
167 }
168}
169
170impl Eq for SrpSessionKey {}
171
/// Public ephemeral value exchanged during the handshake: the client's `A` or
/// the server's `B`, stored as a 32-byte compressed Ristretto encoding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SrpPublicKey {
    /// Compressed point bytes; validated only when decoded via `to_point`.
    point: [u8; 32],
}
177
178impl SrpPublicKey {
179 fn new(point: &RistrettoPoint) -> Self {
180 Self {
181 point: point.compress().to_bytes(),
182 }
183 }
184
185 fn to_point(&self) -> SrpResult<RistrettoPoint> {
186 CompressedRistretto::from_slice(&self.point)
187 .map_err(|_| SrpError::InvalidPublicKey)?
188 .decompress()
189 .ok_or(SrpError::DecompressionFailed)
190 }
191}
192
/// Client-side state for one SRP exchange.
///
/// Created by `SrpClient::new` and consumed by `SrpClient::compute_key`.
///
/// NOTE(review): the secret scalars `x` and `a` are not wiped on drop. If
/// curve25519-dalek's `zeroize` feature is enabled, consider an `impl Drop`
/// that zeroizes them — confirm the feature set first.
pub struct SrpClient {
    // Kept for reference only; no protocol step reads it yet.
    #[allow(dead_code)]
    username: Vec<u8>,
    // Salt used to derive `x`; not read again after construction.
    #[allow(dead_code)]
    salt: [u8; 32],
    /// Password-derived exponent `x = H(salt || H(username || ":" || password))`.
    x: Scalar,
    /// Ephemeral secret exponent `a`.
    a: Scalar,
    /// Client public value `A = a·G`.
    big_a: RistrettoPoint,
}
203
204impl SrpClient {
205 pub fn new(username: &[u8], password: &[u8], salt: &[u8; 32]) -> (Self, SrpPublicKey) {
209 let mut identity = Vec::new();
211 identity.extend_from_slice(username);
212 identity.push(b':');
213 identity.extend_from_slice(password);
214 let identity_hash = hash(&identity);
215
216 let mut x_input = Vec::new();
217 x_input.extend_from_slice(salt);
218 x_input.extend_from_slice(&identity_hash);
219 let x_hash = hash(&x_input);
220 let x = Scalar::from_bytes_mod_order(x_hash);
221
222 let mut rng = rand::thread_rng();
224 let a_bytes: [u8; 32] = {
225 let mut arr = [0u8; 32];
226 rng.fill(&mut arr);
227 arr
228 };
229 let a = Scalar::from_bytes_mod_order(a_bytes);
230
231 let big_a = a * RISTRETTO_BASEPOINT_POINT;
233
234 let public_key = SrpPublicKey::new(&big_a);
235
236 let client = Self {
237 username: username.to_vec(),
238 salt: *salt,
239 x,
240 a,
241 big_a,
242 };
243
244 (client, public_key)
245 }
246
247 pub fn compute_key(self, server_public: &SrpPublicKey) -> SrpResult<SrpSessionKey> {
249 let big_b = server_public.to_point()?;
250
251 let mut u_input = Vec::new();
253 u_input.extend_from_slice(&self.big_a.compress().to_bytes());
254 u_input.extend_from_slice(&big_b.compress().to_bytes());
255 let u_hash = hash(&u_input);
256 let u = Scalar::from_bytes_mod_order(u_hash);
257
258 let k_hash = hash(&RISTRETTO_BASEPOINT_POINT.compress().to_bytes());
260 let k = Scalar::from_bytes_mod_order(k_hash);
261
262 let g_x = self.x * RISTRETTO_BASEPOINT_POINT;
264
265 let base = big_b - (k * g_x);
267 let exponent = self.a + (u * self.x);
268 let s_point = exponent * base;
269
270 let s_bytes = s_point.compress().to_bytes();
272 let key = hkdf_extract_expand(&s_bytes, b"", b"SRP Session Key").to_vec();
273
274 Ok(SrpSessionKey { key })
275 }
276}
277
/// Server-side state for one SRP exchange.
///
/// Created by `SrpServer::new` and consumed by `SrpServer::compute_key`.
///
/// NOTE(review): the secret scalar `b` is not wiped on drop. If
/// curve25519-dalek's `zeroize` feature is enabled, consider an `impl Drop`
/// that zeroizes it — confirm the feature set first.
pub struct SrpServer {
    // Kept for reference only; no protocol step reads it yet.
    #[allow(dead_code)]
    username: Vec<u8>,
    /// Decoded verifier point `v = x·G`.
    v: RistrettoPoint,
    /// Ephemeral secret exponent `b`.
    b: Scalar,
    /// Server public value `B = k·v + b·G`.
    big_b: RistrettoPoint,
}
286
287impl SrpServer {
288 pub fn new(username: &[u8], verifier: &SrpVerifier) -> (Self, SrpPublicKey) {
292 let v = verifier.verifier_point().expect("Invalid verifier");
293
294 let mut rng = rand::thread_rng();
296 let b_bytes: [u8; 32] = {
297 let mut arr = [0u8; 32];
298 rng.fill(&mut arr);
299 arr
300 };
301 let b = Scalar::from_bytes_mod_order(b_bytes);
302
303 let k_hash = hash(&RISTRETTO_BASEPOINT_POINT.compress().to_bytes());
305 let k = Scalar::from_bytes_mod_order(k_hash);
306
307 let g_b = b * RISTRETTO_BASEPOINT_POINT;
309 let big_b = (k * v) + g_b;
310
311 let public_key = SrpPublicKey::new(&big_b);
312
313 let server = Self {
314 username: username.to_vec(),
315 v,
316 b,
317 big_b,
318 };
319
320 (server, public_key)
321 }
322
323 pub fn compute_key(self, client_public: &SrpPublicKey) -> SrpResult<SrpSessionKey> {
325 let big_a = client_public.to_point()?;
326
327 let mut u_input = Vec::new();
329 u_input.extend_from_slice(&big_a.compress().to_bytes());
330 u_input.extend_from_slice(&self.big_b.compress().to_bytes());
331 let u_hash = hash(&u_input);
332 let u = Scalar::from_bytes_mod_order(u_hash);
333
334 let v_u = u * self.v;
336 let base = big_a + v_u;
337 let s_point = self.b * base;
338
339 let s_bytes = s_point.compress().to_bytes();
341 let key = hkdf_extract_expand(&s_bytes, b"", b"SRP Session Key").to_vec();
342
343 Ok(SrpSessionKey { key })
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs one complete handshake against `verifier`.
    ///
    /// The client logs in as `username` with `login_password`. Returns
    /// `(client_key, server_key)`; the two are expected to match only when
    /// `login_password` is the password the verifier was generated from.
    fn run_exchange(
        username: &[u8],
        login_password: &[u8],
        verifier: &SrpVerifier,
    ) -> (SrpSessionKey, SrpSessionKey) {
        let (client, a_pub) = SrpClient::new(username, login_password, verifier.salt());
        let (server, b_pub) = SrpServer::new(username, verifier);
        let client_key = client.compute_key(&b_pub).unwrap();
        let server_key = server.compute_key(&a_pub).unwrap();
        (client_key, server_key)
    }

    #[test]
    fn test_srp_basic() {
        let (user, pass) = (b"alice", b"secure-password");
        let verifier = SrpVerifier::generate(user, pass);

        let (client_key, server_key) = run_exchange(user, pass, &verifier);
        assert_eq!(client_key, server_key);
    }

    #[test]
    fn test_srp_wrong_password() {
        let user = b"alice";
        let verifier = SrpVerifier::generate(user, b"correct-password");

        let (client_key, server_key) = run_exchange(user, b"wrong-password", &verifier);
        assert_ne!(client_key, server_key);
    }

    #[test]
    fn test_srp_multiple_sessions() {
        let (user, pass) = (b"bob", b"secret");
        let verifier = SrpVerifier::generate(user, pass);

        let (first_client, first_server) = run_exchange(user, pass, &verifier);
        assert_eq!(first_client, first_server);

        let (second_client, second_server) = run_exchange(user, pass, &verifier);
        assert_eq!(second_client, second_server);

        // Fresh ephemeral secrets must produce a fresh session key.
        assert_ne!(first_client, second_client);
    }

    #[test]
    fn test_srp_verifier_serialization() {
        let original = SrpVerifier::generate(b"test", b"password");

        let encoded = original.to_bytes();
        let restored = SrpVerifier::from_bytes(&encoded).unwrap();

        assert_eq!(original.salt, restored.salt);
        assert_eq!(original.verifier, restored.verifier);
    }

    #[test]
    fn test_srp_key_derivation() {
        let (user, pass) = (b"user", b"pass");
        let verifier = SrpVerifier::generate(user, pass);
        let (client_key, server_key) = run_exchange(user, pass, &verifier);

        let enc_client = client_key.derive_key(b"encryption", 32).unwrap();
        let enc_server = server_key.derive_key(b"encryption", 32).unwrap();
        assert_eq!(enc_client, enc_server);

        let mac_client = client_key.derive_key(b"mac", 32).unwrap();
        assert_ne!(enc_client, mac_client);
    }

    #[test]
    fn test_srp_different_usernames() {
        let shared_password = b"same-password";

        let for_alice = SrpVerifier::generate(b"alice", shared_password);
        let for_bob = SrpVerifier::generate(b"bob", shared_password);

        assert_ne!(for_alice.verifier, for_bob.verifier);
    }

    #[test]
    fn test_srp_empty_username() {
        let (user, pass) = (b"", b"password");
        let verifier = SrpVerifier::generate(user, pass);

        let (client_key, server_key) = run_exchange(user, pass, &verifier);
        assert_eq!(client_key, server_key);
    }

    #[test]
    fn test_srp_long_credentials() {
        let user = b"very-long-username-with-many-characters-for-testing";
        let pass = b"very-long-password-with-many-characters-for-testing-purposes";
        let verifier = SrpVerifier::generate(user, pass);

        let (client_key, server_key) = run_exchange(user, pass, &verifier);
        assert_eq!(client_key, server_key);
    }

    #[test]
    fn test_srp_binary_data() {
        let user: Vec<u8> = (0..32).collect();
        let pass: Vec<u8> = (32..64).collect();
        let verifier = SrpVerifier::generate(&user, &pass);

        let (client_key, server_key) = run_exchange(&user, &pass, &verifier);
        assert_eq!(client_key, server_key);
    }

    #[test]
    fn test_srp_public_key_serialization() {
        let (user, pass) = (b"test", b"test");
        let verifier = SrpVerifier::generate(user, pass);
        let (_client, a_pub) = SrpClient::new(user, pass, verifier.salt());

        let wire = crate::codec::encode(&a_pub).unwrap();
        let decoded: SrpPublicKey = crate::codec::decode(&wire).unwrap();

        assert!(decoded.to_point().is_ok());
    }

    #[test]
    fn test_srp_session_key_constant_time_eq() {
        let (user, pass) = (b"alice", b"password123");
        let verifier = SrpVerifier::generate(user, pass);

        let (key1, key2) = run_exchange(user, pass, &verifier);

        assert_eq!(key1, key2);
        // Exercise the `==` operator (the constant-time PartialEq) directly.
        assert!(key1 == key2);
    }
}