use std::marker::PhantomData;
use digest::{Digest, Output};
use num_bigint::BigUint;
use subtle::ConstantTimeEq;
use crate::types::{SrpAuthError, SrpGroup};
use crate::utils::{compute_k, compute_m1, compute_m2, compute_u};
/// SRP client bound to one set of group parameters.
///
/// `D` is the hash function used for every digest computed by this client.
/// It exists only at the type level: no digest instance is stored.
pub struct SrpClient<'a, D>
where
    D: Digest,
{
    /// Group parameters; the client reads the modulus `n` and generator `g`.
    params: &'a SrpGroup,
    /// Zero-sized marker tying the digest type `D` to this client.
    d: PhantomData<D>,
}
/// Outcome of processing a server reply.
///
/// Holds the client proof, the server proof the client expects to receive,
/// and the shared premaster secret.
pub struct SrpClientVerifier<D>
where
    D: Digest,
{
    /// Client proof `M1` (exposed through `proof()`).
    m1: Output<D>,
    /// Expected server proof `M2`, checked by `verify_server()`.
    m2: Output<D>,
    /// Premaster secret `S`, serialized as big-endian bytes.
    key: Vec<u8>,
}
impl<'a, D: Digest> SrpClient<'a, D> {
    /// Creates a client that performs all group arithmetic with `params`.
    pub fn new(params: &'a SrpGroup) -> Self {
        Self {
            params,
            d: Default::default(),
        }
    }

    /// Computes the client public ephemeral `A = g^a mod N`.
    pub fn compute_a_pub(&self, a: &BigUint) -> BigUint {
        self.params.g.modpow(a, &self.params.n)
    }

    /// Computes the identity hash `H(username || ":" || password)` with digest `D`.
    pub fn compute_identity_hash(username: &[u8], password: &[u8]) -> Output<D> {
        let mut d = D::new();
        d.update(username);
        d.update(b":");
        d.update(password);
        d.finalize()
    }

    /// Computes the private key `x = H(salt || identity_hash)`.
    ///
    /// The digest output is read as a big-endian integer.
    pub fn compute_x(identity_hash: &[u8], salt: &[u8]) -> BigUint {
        let mut x = D::new();
        x.update(salt);
        x.update(identity_hash);
        BigUint::from_bytes_be(&x.finalize())
    }

    /// Computes the premaster secret `S = (B - k * g^x) ^ (a + u * x) mod N`.
    ///
    /// This function performs no validation of `b_pub`. The subtraction is done as
    /// `(N + B - (k * g^x mod N)) mod N`, which never underflows in unsigned
    /// arithmetic because the subtrahend is already reduced below `N`.
    pub fn compute_premaster_secret(
        &self,
        b_pub: &BigUint,
        k: &BigUint,
        x: &BigUint,
        a: &BigUint,
        u: &BigUint,
    ) -> BigUint {
        let base = (k * (self.params.g.modpow(x, &self.params.n))) % &self.params.n;
        let base = ((&self.params.n + b_pub) - &base) % &self.params.n;
        let exp = (u * x) + a;
        base.modpow(&exp, &self.params.n)
    }

    /// Computes the password verifier `v = g^x mod N`.
    pub fn compute_v(&self, x: &BigUint) -> BigUint {
        self.params.g.modpow(x, &self.params.n)
    }

    /// Derives the verifier `v` for a username/password/salt triple.
    ///
    /// Returns `v` as big-endian bytes.
    pub fn compute_verifier(&self, username: &[u8], password: &[u8], salt: &[u8]) -> Vec<u8> {
        let identity_hash = Self::compute_identity_hash(username, password);
        let x = Self::compute_x(identity_hash.as_slice(), salt);
        self.compute_v(&x).to_bytes_be()
    }

    /// Computes `A = g^a mod N` from the big-endian private ephemeral `a`.
    ///
    /// Returns `A` as big-endian bytes.
    pub fn compute_public_ephemeral(&self, a: &[u8]) -> Vec<u8> {
        self.compute_a_pub(&BigUint::from_bytes_be(a)).to_bytes_be()
    }

    /// Processes the server's public ephemeral `B` and derives the session proofs.
    ///
    /// # Parameters
    /// - `a`: client private ephemeral, big-endian.
    /// - `username`, `password`, `salt`: inputs to the private key `x`.
    /// - `b_pub`: server public ephemeral `B`, big-endian.
    ///
    /// The returned [`SrpClientVerifier`] holds:
    /// - the client proof `M1`;
    /// - the expected server proof `M2`;
    /// - the premaster secret `S` as big-endian bytes.
    ///
    /// # Errors
    /// - `SrpAuthError::IllegalParameter("b_pub")` if `B ≡ 0 (mod N)`.
    /// - `SrpAuthError::IllegalParameter("u")` if the scrambling parameter
    ///   `u = H(A, B)` is zero.
    ///
    /// SRP-6a requires the client to abort in both cases.
    pub fn process_reply(
        &self,
        a: &[u8],
        username: &[u8],
        password: &[u8],
        salt: &[u8],
        b_pub: &[u8],
    ) -> Result<SrpClientVerifier<D>, SrpAuthError> {
        let a = BigUint::from_bytes_be(a);
        let a_pub = self.compute_a_pub(&a);
        let b_pub = BigUint::from_bytes_be(b_pub);

        // Safeguard against a malicious server sending B ≡ 0 (mod N).
        if &b_pub % &self.params.n == BigUint::default() {
            return Err(SrpAuthError::IllegalParameter("b_pub".to_owned()));
        }

        // Serialize once: each value feeds several hash computations below.
        let a_pub_bytes = a_pub.to_bytes_be();
        let b_pub_bytes = b_pub.to_bytes_be();

        let u = compute_u::<D>(&a_pub_bytes, &b_pub_bytes);
        // SRP-6a: the client must also abort if u == 0. With u == 0,
        // S = (B - k*v)^a no longer mixes the password in through the u*x term.
        if u == BigUint::default() {
            return Err(SrpAuthError::IllegalParameter("u".to_owned()));
        }

        let k = compute_k::<D>(self.params);
        let identity_hash = Self::compute_identity_hash(username, password);
        let x = Self::compute_x(identity_hash.as_slice(), salt);

        let key = self
            .compute_premaster_secret(&b_pub, &k, &x, &a, &u)
            .to_bytes_be();

        let m1 = compute_m1::<D>(&a_pub_bytes, &b_pub_bytes, &key);
        let m2 = compute_m2::<D>(&a_pub_bytes, &m1, &key);

        Ok(SrpClientVerifier { m1, m2, key })
    }
}
impl<D> SrpClientVerifier<D>
where
    D: Digest,
{
    /// Returns the shared premaster secret as big-endian bytes.
    pub fn key(&self) -> &[u8] {
        self.key.as_slice()
    }

    /// Returns the client proof `M1`.
    pub fn proof(&self) -> &[u8] {
        self.m1.as_slice()
    }

    /// Checks the server's proof against the expected `M2`.
    ///
    /// The comparison goes through `subtle::ConstantTimeEq`.
    ///
    /// # Errors
    /// Returns `SrpAuthError::BadRecordMac("server")` when `reply` does not match.
    pub fn verify_server(&self, reply: &[u8]) -> Result<(), SrpAuthError> {
        match self.m2.ct_eq(reply).unwrap_u8() {
            1 => Ok(()),
            _ => Err(SrpAuthError::BadRecordMac("server".to_owned())),
        }
    }
}