use std::marker::PhantomData;
use digest::{Digest, Output};
use num_bigint::BigUint;
use subtle::ConstantTimeEq;
use crate::types::{SrpAuthError, SrpGroup};
use crate::utils::{compute_k, compute_m1, compute_m2, compute_u};
/// Server side of an SRP exchange.
///
/// Holds a borrowed reference to the group parameters and is generic over the
/// digest `D`, which the `impl` block passes to the hashing helpers
/// (`compute_k`, `compute_u`, `compute_m1`, `compute_m2`).
pub struct SrpServer<'a, D>
where
    D: Digest,
{
    /// Group parameters; the `impl` block reads the modulus `n` and generator `g`.
    params: &'a SrpGroup,
    /// Zero-sized marker that records the digest type without storing a value.
    d: PhantomData<D>,
}
/// Result of a successful `SrpServer::process_reply`.
///
/// Stores the client proof the server expects, the server's own proof, and the
/// derived key material.
pub struct SrpServerVerifier<D>
where
    D: Digest,
{
    /// Expected client proof (`M1`); checked by `verify_client`.
    m1: Output<D>,
    /// Server proof (`M2`); exposed through `proof`.
    m2: Output<D>,
    /// Big-endian bytes of the premaster secret; exposed through `key`.
    /// NOTE(review): this is secret material and is not zeroized on drop.
    key: Vec<u8>,
}
impl<'a, D: Digest> SrpServer<'a, D> {
pub fn new(params: &'a SrpGroup) -> Self {
Self {
params,
d: Default::default(),
}
}
pub fn compute_b_pub(&self, b: &BigUint, k: &BigUint, v: &BigUint) -> BigUint {
let inter = (k * v) % &self.params.n;
(inter + self.params.g.modpow(b, &self.params.n)) % &self.params.n
}
pub fn compute_premaster_secret(
&self,
a_pub: &BigUint,
v: &BigUint,
u: &BigUint,
b: &BigUint,
) -> BigUint {
let base = (a_pub * v.modpow(u, &self.params.n)) % &self.params.n;
base.modpow(b, &self.params.n)
}
pub fn compute_public_ephemeral(&self, b: &[u8], v: &[u8]) -> Vec<u8> {
self.compute_b_pub(
&BigUint::from_bytes_be(b),
&compute_k::<D>(self.params),
&BigUint::from_bytes_be(v),
)
.to_bytes_be()
}
pub fn process_reply(
&self,
b: &[u8],
v: &[u8],
a_pub: &[u8],
) -> Result<SrpServerVerifier<D>, SrpAuthError> {
let b = BigUint::from_bytes_be(b);
let v = BigUint::from_bytes_be(v);
let a_pub = BigUint::from_bytes_be(a_pub);
let k = compute_k::<D>(self.params);
let b_pub = self.compute_b_pub(&b, &k, &v);
if &a_pub % &self.params.n == BigUint::default() {
return Err(SrpAuthError::IllegalParameter("a_pub".to_owned()));
}
let u = compute_u::<D>(&a_pub.to_bytes_be(), &b_pub.to_bytes_be());
let key = self.compute_premaster_secret(&a_pub, &v, &u, &b);
let m1 = compute_m1::<D>(
&a_pub.to_bytes_be(),
&b_pub.to_bytes_be(),
&key.to_bytes_be(),
);
let m2 = compute_m2::<D>(&a_pub.to_bytes_be(), &m1, &key.to_bytes_be());
Ok(SrpServerVerifier {
m1,
m2,
key: key.to_bytes_be(),
})
}
}
impl<D: Digest> SrpServerVerifier<D> {
    /// Returns the derived key: big-endian bytes of the premaster secret.
    pub fn key(&self) -> &[u8] {
        self.key.as_slice()
    }

    /// Returns the server proof `M2`.
    pub fn proof(&self) -> &[u8] {
        &self.m2
    }

    /// Checks the client's proof `reply` against the expected `M1`.
    ///
    /// Uses `subtle`'s constant-time equality, so the comparison does not
    /// short-circuit on the first differing byte.
    ///
    /// # Errors
    ///
    /// Returns `SrpAuthError::BadRecordMac("client")` when `reply` does not
    /// match `M1`.
    pub fn verify_client(&self, reply: &[u8]) -> Result<(), SrpAuthError> {
        let proof_matches = bool::from(self.m1.ct_eq(reply));
        if proof_matches {
            Ok(())
        } else {
            Err(SrpAuthError::BadRecordMac("client".to_owned()))
        }
    }
}