use crate::{hash, hkdf_extract_expand};
use curve25519_dalek::{
constants::RISTRETTO_BASEPOINT_POINT,
ristretto::{CompressedRistretto, RistrettoPoint},
scalar::Scalar,
};
use rand::RngExt as _;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use zeroize::{Zeroize, ZeroizeOnDrop};
#[derive(Error, Debug)]
pub enum SrpError {
#[error("Invalid verifier")]
InvalidVerifier,
#[error("Invalid public key")]
InvalidPublicKey,
#[error("Computation failed")]
ComputationFailed,
#[error("Point decompression failed")]
DecompressionFailed,
}
pub type SrpResult<T> = Result<T, SrpError>;
#[derive(Debug, Clone, Serialize, Deserialize, Zeroize, ZeroizeOnDrop)]
pub struct SrpVerifier {
#[zeroize(skip)]
salt: [u8; 32],
verifier: [u8; 32],
}
impl SrpVerifier {
pub fn generate(username: &[u8], password: &[u8]) -> Self {
let mut rng = rand::rng();
let salt: [u8; 32] = {
let mut arr = [0u8; 32];
rng.fill(&mut arr);
arr
};
let mut identity = Vec::new();
identity.extend_from_slice(username);
identity.push(b':');
identity.extend_from_slice(password);
let identity_hash = hash(&identity);
let mut x_input = Vec::new();
x_input.extend_from_slice(&salt);
x_input.extend_from_slice(&identity_hash);
let x_hash = hash(&x_input);
let x = Scalar::from_bytes_mod_order(x_hash);
let v_point = x * RISTRETTO_BASEPOINT_POINT;
let verifier = v_point.compress().to_bytes();
Self { salt, verifier }
}
pub fn salt(&self) -> &[u8; 32] {
&self.salt
}
fn verifier_point(&self) -> SrpResult<RistrettoPoint> {
CompressedRistretto::from_slice(&self.verifier)
.map_err(|_| SrpError::InvalidVerifier)?
.decompress()
.ok_or(SrpError::DecompressionFailed)
}
pub fn to_bytes(&self) -> Vec<u8> {
crate::codec::encode(self).unwrap()
}
pub fn from_bytes(bytes: &[u8]) -> SrpResult<Self> {
crate::codec::decode(bytes).map_err(|_| SrpError::InvalidVerifier)
}
}
#[derive(Clone, Debug, Zeroize, ZeroizeOnDrop)]
pub struct SrpSessionKey {
key: Vec<u8>,
}
impl SrpSessionKey {
pub fn as_bytes(&self) -> &[u8] {
&self.key
}
pub fn derive_key(&self, info: &[u8], len: usize) -> SrpResult<Vec<u8>> {
let mut output = vec![0u8; len];
let expanded = hkdf_extract_expand(&self.key, b"", info);
output[..len.min(32)].copy_from_slice(&expanded[..len.min(32)]);
if len > 32 {
for i in (32..len).step_by(32) {
let mut info_extended = info.to_vec();
info_extended.extend_from_slice(&[i as u8]);
let expanded = hkdf_extract_expand(&self.key, b"", &info_extended);
let end = (i + 32).min(len);
output[i..end].copy_from_slice(&expanded[..(end - i)]);
}
}
Ok(output)
}
}
impl PartialEq for SrpSessionKey {
fn eq(&self, other: &Self) -> bool {
use subtle::ConstantTimeEq;
self.key.ct_eq(&other.key).into()
}
}
impl Eq for SrpSessionKey {}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SrpPublicKey {
point: [u8; 32],
}
impl SrpPublicKey {
fn new(point: &RistrettoPoint) -> Self {
Self {
point: point.compress().to_bytes(),
}
}
fn to_point(&self) -> SrpResult<RistrettoPoint> {
CompressedRistretto::from_slice(&self.point)
.map_err(|_| SrpError::InvalidPublicKey)?
.decompress()
.ok_or(SrpError::DecompressionFailed)
}
}
/// Client half of an SRP session.
///
/// Holds the password-derived scalar `x`, the ephemeral secret `a`, and the
/// public value `A = a·G` announced to the server.
pub struct SrpClient {
    // Retained from construction; no method in this file reads it.
    #[allow(dead_code)]
    username: Vec<u8>,
    // Retained from construction; no method in this file reads it.
    #[allow(dead_code)]
    salt: [u8; 32],
    x: Scalar,
    a: Scalar,
    big_a: RistrettoPoint,
}
impl SrpClient {
    /// Starts a client session from the user's credentials and the server-provided salt.
    ///
    /// Returns the client state together with `A`, which must be sent to the server.
    pub fn new(username: &[u8], password: &[u8], salt: &[u8; 32]) -> (Self, SrpPublicKey) {
        // x = H(salt || H(username ":" password))
        let credentials: Vec<u8> = [username, &b":"[..], password].concat();
        let credentials_digest = hash(&credentials);
        let salted: Vec<u8> = [&salt[..], &credentials_digest[..]].concat();
        let x = Scalar::from_bytes_mod_order(hash(&salted));

        // Fresh ephemeral secret a and its public image A = a·G.
        let mut ephemeral = [0u8; 32];
        let mut generator = rand::rng();
        generator.fill(&mut ephemeral);
        let a = Scalar::from_bytes_mod_order(ephemeral);
        let big_a = a * RISTRETTO_BASEPOINT_POINT;

        let announced = SrpPublicKey::new(&big_a);
        let state = Self {
            username: username.to_vec(),
            salt: *salt,
            x,
            a,
            big_a,
        };
        (state, announced)
    }

    /// Finishes the exchange using the server's public value `B`.
    ///
    /// # Errors
    ///
    /// Propagates any error from decoding `server_public`.
    pub fn compute_key(self, server_public: &SrpPublicKey) -> SrpResult<SrpSessionKey> {
        let big_b = server_public.to_point()?;

        // Scrambling parameter u = H(A || B).
        let transcript = [self.big_a.compress().to_bytes(), big_b.compress().to_bytes()].concat();
        let u = Scalar::from_bytes_mod_order(hash(&transcript));

        // Multiplier k = H(G).
        let k = Scalar::from_bytes_mod_order(hash(&RISTRETTO_BASEPOINT_POINT.compress().to_bytes()));

        // S = (a + u·x) · (B − k·x·G)
        let unmasked = big_b - k * (self.x * RISTRETTO_BASEPOINT_POINT);
        let shared_secret = (self.a + u * self.x) * unmasked;
        let secret_bytes = shared_secret.compress().to_bytes();

        Ok(SrpSessionKey {
            key: hkdf_extract_expand(&secret_bytes, b"", b"SRP Session Key").to_vec(),
        })
    }
}
/// Server half of an SRP session.
///
/// Holds the user's verifier point `v`, the ephemeral secret `b`, and the
/// public value `B = k·v + b·G` announced to the client.
pub struct SrpServer {
    // Retained from construction; no method in this file reads it.
    #[allow(dead_code)]
    username: Vec<u8>,
    v: RistrettoPoint,
    b: Scalar,
    big_b: RistrettoPoint,
}
impl SrpServer {
    /// Starts a server session for `username` from its stored `verifier`.
    ///
    /// Returns the server state and `B`, which must be sent to the client.
    ///
    /// # Panics
    ///
    /// Panics with `"Invalid verifier"` if the verifier bytes do not decode to a
    /// Ristretto point. Use [`SrpServer::try_new`] to handle that case without panicking.
    pub fn new(username: &[u8], verifier: &SrpVerifier) -> (Self, SrpPublicKey) {
        Self::try_new(username, verifier).expect("Invalid verifier")
    }

    /// Fallible counterpart of [`SrpServer::new`].
    ///
    /// # Errors
    ///
    /// Returns the error produced while decoding the verifier point. In practice
    /// that is `DecompressionFailed`.
    pub fn try_new(username: &[u8], verifier: &SrpVerifier) -> SrpResult<(Self, SrpPublicKey)> {
        let v = verifier.verifier_point()?;
        let mut rng = rand::rng();
        let b_bytes: [u8; 32] = {
            let mut arr = [0u8; 32];
            rng.fill(&mut arr);
            arr
        };
        let b = Scalar::from_bytes_mod_order(b_bytes);
        // Multiplier k = H(G); must match the client's derivation.
        let k_hash = hash(&RISTRETTO_BASEPOINT_POINT.compress().to_bytes());
        let k = Scalar::from_bytes_mod_order(k_hash);
        // B = k·v + b·G
        let big_b = (k * v) + b * RISTRETTO_BASEPOINT_POINT;
        let public_key = SrpPublicKey::new(&big_b);
        let server = Self {
            username: username.to_vec(),
            v,
            b,
            big_b,
        };
        Ok((server, public_key))
    }

    /// Finishes the exchange using the client's public value `A`.
    ///
    /// # Errors
    ///
    /// Propagates any error from decoding `client_public`.
    pub fn compute_key(self, client_public: &SrpPublicKey) -> SrpResult<SrpSessionKey> {
        let big_a = client_public.to_point()?;
        // Scrambling parameter u = H(A || B); same input order as the client.
        let mut u_input = Vec::with_capacity(64);
        u_input.extend_from_slice(&big_a.compress().to_bytes());
        u_input.extend_from_slice(&self.big_b.compress().to_bytes());
        let u_hash = hash(&u_input);
        let u = Scalar::from_bytes_mod_order(u_hash);
        // S = b · (A + u·v)
        let base = big_a + u * self.v;
        let s_point = self.b * base;
        let s_bytes = s_point.compress().to_bytes();
        let key = hkdf_extract_expand(&s_bytes, b"", b"SRP Session Key").to_vec();
        Ok(SrpSessionKey { key })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs one complete handshake and returns `(client_key, server_key)`.
    ///
    /// The client authenticates with `client_password` against `verifier`.
    fn handshake(
        username: &[u8],
        client_password: &[u8],
        verifier: &SrpVerifier,
    ) -> (SrpSessionKey, SrpSessionKey) {
        let (client, client_public) = SrpClient::new(username, client_password, verifier.salt());
        let (server, server_public) = SrpServer::new(username, verifier);
        let client_key = client.compute_key(&server_public).unwrap();
        let server_key = server.compute_key(&client_public).unwrap();
        (client_key, server_key)
    }

    #[test]
    fn test_srp_basic() {
        let (user, pass) = (b"alice", b"secure-password");
        let verifier = SrpVerifier::generate(user, pass);
        let (client_key, server_key) = handshake(user, pass, &verifier);
        assert_eq!(client_key, server_key);
    }

    #[test]
    fn test_srp_wrong_password() {
        let user = b"alice";
        let verifier = SrpVerifier::generate(user, b"correct-password");
        let (client_key, server_key) = handshake(user, b"wrong-password", &verifier);
        assert_ne!(client_key, server_key);
    }

    #[test]
    fn test_srp_multiple_sessions() {
        let (user, pass) = (b"bob", b"secret");
        let verifier = SrpVerifier::generate(user, pass);
        let (first_client, first_server) = handshake(user, pass, &verifier);
        assert_eq!(first_client, first_server);
        let (second_client, second_server) = handshake(user, pass, &verifier);
        assert_eq!(second_client, second_server);
        assert_ne!(first_client, second_client);
    }

    #[test]
    fn test_srp_verifier_serialization() {
        let verifier = SrpVerifier::generate(b"test", b"password");
        let encoded = verifier.to_bytes();
        let restored = SrpVerifier::from_bytes(&encoded).unwrap();
        assert_eq!(verifier.salt, restored.salt);
        assert_eq!(verifier.verifier, restored.verifier);
    }

    #[test]
    fn test_srp_key_derivation() {
        let (user, pass) = (b"user", b"pass");
        let verifier = SrpVerifier::generate(user, pass);
        let (client_key, server_key) = handshake(user, pass, &verifier);
        let enc_client = client_key.derive_key(b"encryption", 32).unwrap();
        let enc_server = server_key.derive_key(b"encryption", 32).unwrap();
        assert_eq!(enc_client, enc_server);
        let mac_client = client_key.derive_key(b"mac", 32).unwrap();
        assert_ne!(enc_client, mac_client);
    }

    #[test]
    fn test_srp_different_usernames() {
        let shared_password = b"same-password";
        let for_alice = SrpVerifier::generate(b"alice", shared_password);
        let for_bob = SrpVerifier::generate(b"bob", shared_password);
        assert_ne!(for_alice.verifier, for_bob.verifier);
    }

    #[test]
    fn test_srp_empty_username() {
        let user: &[u8] = b"";
        let pass = b"password";
        let verifier = SrpVerifier::generate(user, pass);
        let (client_key, server_key) = handshake(user, pass, &verifier);
        assert_eq!(client_key, server_key);
    }

    #[test]
    fn test_srp_long_credentials() {
        let user = b"very-long-username-with-many-characters-for-testing";
        let pass = b"very-long-password-with-many-characters-for-testing-purposes";
        let verifier = SrpVerifier::generate(user, pass);
        let (client_key, server_key) = handshake(user, pass, &verifier);
        assert_eq!(client_key, server_key);
    }

    #[test]
    fn test_srp_binary_data() {
        let user: Vec<u8> = (0u8..32).collect();
        let pass: Vec<u8> = (32u8..64).collect();
        let verifier = SrpVerifier::generate(&user, &pass);
        let (client_key, server_key) = handshake(&user, &pass, &verifier);
        assert_eq!(client_key, server_key);
    }

    #[test]
    fn test_srp_public_key_serialization() {
        let verifier = SrpVerifier::generate(b"test", b"test");
        let (_client, client_public) = SrpClient::new(b"test", b"test", verifier.salt());
        let encoded = crate::codec::encode(&client_public).unwrap();
        let decoded: SrpPublicKey = crate::codec::decode(&encoded).unwrap();
        assert!(decoded.to_point().is_ok());
    }

    #[test]
    fn test_srp_session_key_constant_time_eq() {
        let (user, pass) = (b"alice", b"password123");
        let verifier = SrpVerifier::generate(user, pass);
        let (key1, key2) = handshake(user, pass, &verifier);
        assert_eq!(key1, key2);
        assert!(key1 == key2);
    }
}