use crate::big_number::{BigNumber, Zero};
use crate::hash::*;
use crate::{error::Srp6Error, Result};
/// Byte length of the strong session key `K`: two hash digests interleaved
/// byte by byte (see `calculate_session_key_hash_interleave_K`).
const STRONG_SESSION_KEY_LENGTH: usize = HASH_LENGTH * 2;
/// The group modulus `N`; the modular exponentiations in this module reduce by it.
#[doc(alias = "N")]
pub type PrimeModulus = BigNumber;
/// The generator `g`, used as the base when deriving public keys and the verifier.
#[doc(alias = "g")]
pub type Generator = BigNumber;
/// The salt `s` hashed into the private key `x` and into the proof `M`.
#[doc(alias = "s")]
pub type Salt = BigNumber;
/// An ephemeral public key: `A` (client side) or `B` (host side).
#[doc(alias("A", "B"))]
pub type PublicKey = BigNumber;
/// A secret exponent: the ephemeral `a` / `b`, or the password-derived `x`.
#[doc(alias("a", "b", "x"))]
pub type PrivateKey = BigNumber;
/// A public key paired with its private key, in that order.
pub type KeyPair = (PublicKey, PrivateKey);
/// The password verifier `v = g^x mod N`, consumed by the host-side computations.
#[doc(alias = "v")]
pub type PasswordVerifier = BigNumber;
/// The multiplier `k` used when computing the host public key `B`.
#[doc(alias = "k")]
pub type MultiplierParameter = BigNumber;
/// The raw shared secret `S` computed by both parties.
#[doc(alias = "S")]
pub type SessionKey = BigNumber;
/// The strong session key `K`, derived from `S` by interleaved hashing.
#[doc(alias = "K")]
pub type StrongSessionKey = BigNumber;
/// The proof `M` (a.k.a. `M1`) computed over the handshake values and `K`.
#[doc(alias("M", "M1"))]
pub type Proof = BigNumber;
/// The proof `M2 = H(A | M | K)`.
#[doc(alias = "M2")]
pub type StrongProof = BigNumber;
/// The user identity `I` (owned).
#[doc(alias = "I")]
pub type Username = String;
/// Borrowed form of [`Username`].
pub type UsernameRef<'a> = &'a str;
/// The user's cleartext password `P` (owned).
#[doc(alias("P", "p"))]
pub type ClearTextPassword = String;
/// Borrowed form of [`ClearTextPassword`].
pub type ClearTextPasswordRef<'a> = &'a str;
/// Username and cleartext password, borrowed for the duration of a login or
/// registration.
///
/// `Debug` is implemented by hand so that formatting this value with `{:?}`
/// (e.g. in logs or panic messages) never reveals the password.
#[derive(Clone)]
pub struct UserCredentials<'a> {
    /// The identity `I`.
    pub username: UsernameRef<'a>,
    /// The cleartext password `p`; never printed by `Debug`.
    pub password: ClearTextPasswordRef<'a>,
}

impl std::fmt::Debug for UserCredentials<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UserCredentials")
            .field("username", &self.username)
            // Redacted: a derived Debug would leak the cleartext password.
            .field("password", &"<redacted>")
            .finish()
    }
}
/// The per-user registration record: identity `I`, salt `s` and verifier `v`.
#[derive(Clone, Debug)]
pub struct UserSecrets {
    /// The identity `I`.
    pub username: Username,
    /// The salt `s` that was hashed into the private key `x`.
    pub salt: Salt,
    /// The password verifier `v = g^x mod N`.
    pub verifier: PasswordVerifier,
}
/// Host side of the key agreement: `S = (A * v^u)^b mod N`, where `u = H(PAD(A) | PAD(B))`.
///
/// # Errors
/// Returns [`Srp6Error::InvalidPublicKey`] carrying `A` when `A ≡ 0 (mod N)`.
#[allow(non_snake_case)]
pub fn calculate_session_key_S_for_host<const N_BYTE_LEN: usize>(
    N: &PrimeModulus,
    A: &PublicKey,
    B: &PublicKey,
    b: &PrivateKey,
    v: &PasswordVerifier,
) -> Result<SessionKey> {
    // A client key that is a multiple of N would force S to a known value.
    let client_key_reduced = A % N;
    if client_key_reduced.is_zero() {
        return Err(Srp6Error::InvalidPublicKey(A.clone()));
    }

    let scrambler = calculate_u::<N_BYTE_LEN>(A, B);
    let v_to_the_u = v.modpow(&scrambler, N);
    let shared_base = A * &v_to_the_u;
    Ok(shared_base.modpow(b, N))
}
/// Client side of the key agreement: `S = (B - k * g^x)^(a + u * x) mod N`,
/// where `u = H(PAD(A) | PAD(B))`.
///
/// `B - k * g^x` is usually negative as a plain integer. It is therefore computed
/// as the congruent non-negative value `(B + k * (N - (g^x mod N))) mod N`. This
/// avoids unsigned underflow and any dependence on how `modpow` treats a
/// negative base.
///
/// # Errors
/// Returns [`Srp6Error::InvalidPublicKey`] carrying `B` when either
/// - `B ≡ 0 (mod N)`, or
/// - the scrambling parameter `u` is zero.
///
/// SRP-6a (RFC 5054 §2.6) requires the client to abort in both cases.
#[allow(non_snake_case)]
#[allow(clippy::many_single_char_names)]
pub fn calculate_session_key_S_for_client<const N_BYTE_LEN: usize>(
    N: &PrimeModulus,
    k: &MultiplierParameter,
    g: &Generator,
    B: &PublicKey,
    A: &PublicKey,
    a: &PrivateKey,
    x: &PrivateKey,
) -> Result<SessionKey> {
    if (B % N).is_zero() {
        return Err(Srp6Error::InvalidPublicKey(B.clone()));
    }
    let u = &calculate_u::<N_BYTE_LEN>(A, B);
    // With u == 0 the password-dependent term drops out of the exponent.
    if u.is_zero() {
        return Err(Srp6Error::InvalidPublicKey(B.clone()));
    }
    let exp: BigNumber = a + &(u * x);
    let g_mod_x = &g.modpow(x, N);
    // g_mod_x < N, so N - g_mod_x > 0 and k * (N - g_mod_x) ≡ -k * g^x (mod N).
    let base = &(B + &(k * &(N - g_mod_x))) % N;
    let S = base.modpow(&exp, N);
    Ok(S)
}
/// Derives the strong session key `K` from `S` with the interleaved hash
/// construction (`SHA_Interleave`, RFC 2945 §3.1).
///
/// Steps:
/// 1. `S` is serialised with `to_array_pad_zero` to `N_BYTE_LEN` bytes.
/// 2. The even-indexed and odd-indexed bytes are hashed separately.
/// 3. The two digests are interleaved byte by byte into
///    `STRONG_SESSION_KEY_LENGTH` (`2 * HASH_LENGTH`) bytes.
///
/// When the serialised length is odd, the first byte is dropped (as RFC 2945
/// prescribes) so both halves have the same length. For even lengths the input
/// is used unchanged.
///
/// NOTE(review): unlike RFC 2945, leading zero bytes are not stripped; the
/// fixed-width padded encoding is hashed.
#[allow(non_snake_case)]
pub fn calculate_session_key_hash_interleave_K<const N_BYTE_LEN: usize>(
    S: &SessionKey,
) -> StrongSessionKey {
    let S = S.to_array_pad_zero::<N_BYTE_LEN>();
    // Skip one leading byte on odd lengths so the even and odd halves match in size.
    let bytes = &S[S.len() % 2..];
    let even_half: Vec<u8> = bytes.iter().step_by(2).copied().collect();
    let odd_half: Vec<u8> = bytes.iter().skip(1).step_by(2).copied().collect();
    let even_half_of_S_hash = HashFunc::new().chain(even_half.as_slice()).finalize();
    let odd_half_of_S_hash = HashFunc::new().chain(odd_half.as_slice()).finalize();
    let mut vK = [0_u8; STRONG_SESSION_KEY_LENGTH];
    for (pair, (even_byte, odd_byte)) in vK
        .chunks_exact_mut(2)
        .zip(even_half_of_S_hash.iter().zip(odd_half_of_S_hash.iter()))
    {
        pair[0] = *even_byte;
        pair[1] = *odd_byte;
    }
    StrongSessionKey::from_bytes_be(&vK)
}
/// Computes the proof `M = H(H(N) xor H(g) | H(I) | s | A | B | K)`.
///
/// Padding of the inputs:
/// - `s` is zero-padded to `SALT_LENGTH` bytes.
/// - `A` and `B` are zero-padded to `N_BYTE_LEN` bytes.
/// - `K` is zero-padded to `STRONG_SESSION_KEY_LENGTH` bytes.
#[allow(non_snake_case)]
pub fn calculate_proof_M<const N_BYTE_LEN: usize, const SALT_LENGTH: usize>(
    N: &PrimeModulus,
    g: &Generator,
    I: UsernameRef,
    s: &Salt,
    A: &PublicKey,
    B: &PublicKey,
    K: &StrongSessionKey,
) -> Proof {
    let group_hash: Hash = calculate_hash_N_xor_g::<N_BYTE_LEN>(N, g);
    let identity_hash = HashFunc::new().chain(I.as_bytes()).finalize();

    let salt_bytes = s.to_array_pad_zero::<SALT_LENGTH>();
    let client_key_bytes = A.to_array_pad_zero::<N_BYTE_LEN>();
    let host_key_bytes = B.to_array_pad_zero::<N_BYTE_LEN>();
    let session_key_bytes = K.to_array_pad_zero::<STRONG_SESSION_KEY_LENGTH>();

    let digest = HashFunc::new()
        .chain(group_hash)
        .chain(identity_hash)
        .chain(salt_bytes)
        .chain(client_key_bytes)
        .chain(host_key_bytes)
        .chain(session_key_bytes)
        .finalize();
    Proof::from_bytes_be(digest.as_slice())
}
/// Computes the proof `M2 = H(A | M | K)`.
///
/// Padding of the inputs:
/// - `A` is zero-padded to `KEY_LENGTH` bytes.
/// - `M` is zero-padded to `HASH_LENGTH` bytes.
/// - `K` is zero-padded to `STRONG_SESSION_KEY_LENGTH` bytes.
#[allow(non_snake_case)]
pub fn calculate_strong_proof_M2<const KEY_LENGTH: usize>(
    A: &PublicKey,
    M: &Proof,
    K: &StrongSessionKey,
) -> StrongProof {
    let client_key_bytes = A.to_array_pad_zero::<KEY_LENGTH>();
    let proof_bytes = M.to_array_pad_zero::<HASH_LENGTH>();
    let session_key_bytes = K.to_array_pad_zero::<STRONG_SESSION_KEY_LENGTH>();
    HashFunc::new()
        .chain(client_key_bytes)
        .chain(proof_bytes)
        .chain(session_key_bytes)
        .into()
}
/// Computes `H(N) xor H(g)` byte by byte.
///
/// `N` is hashed zero-padded to `KEY_LENGTH` bytes. `g` is hashed from its
/// unpadded `to_vec()` encoding.
#[allow(non_snake_case)]
fn calculate_hash_N_xor_g<const KEY_LENGTH: usize>(N: &PrimeModulus, g: &Generator) -> Hash {
    let generator_digest = HashFunc::new().chain(g.to_vec().as_slice()).finalize();
    let mut combined = HashFunc::new()
        .chain(N.to_array_pad_zero::<KEY_LENGTH>())
        .finalize();
    combined
        .iter_mut()
        .zip(generator_digest.iter())
        .for_each(|(lhs, rhs)| *lhs ^= *rhs);
    combined.into()
}
#[allow(non_snake_case)]
pub fn calculate_password_verifier_v(
N: &PrimeModulus,
g: &Generator,
x: &PrivateKey,
) -> PasswordVerifier {
g.modpow(x, N)
}
/// Computes the scrambling parameter `u` from both public keys.
///
/// The work is delegated to `hash_w_pad` with both keys padded to
/// `N_BYTE_LEN` bytes; the exact padding is defined by that helper.
#[allow(non_snake_case)]
pub fn calculate_u<const N_BYTE_LEN: usize>(A: &PublicKey, B: &PublicKey) -> BigNumber {
    hash_w_pad::<{ N_BYTE_LEN }>(A, B)
}
#[allow(non_snake_case)]
pub fn calculate_pubkey_A(N: &PrimeModulus, g: &Generator, a: &PrivateKey) -> PublicKey {
g.modpow(a, N)
}
/// Computes the host public key `B = (k * v + g^b) mod N`.
#[allow(non_snake_case)]
pub fn calculate_pubkey_B(
    N: &PrimeModulus,
    k: &MultiplierParameter,
    g: &Generator,
    v: &PasswordVerifier,
    b: &PrivateKey,
) -> PublicKey {
    let g_to_the_b = g.modpow(b, N);
    let blinded_verifier = k * v;
    &(blinded_verifier + g_to_the_b) % N
}
/// Derives the private key `x = H(s | H(I ":" p))`.
///
/// The inner hash comes from `calculate_p_hash`, whose exact form depends on
/// the `wow` feature.
///
/// NOTE(review): the salt is hashed via `to_vec()` (no fixed-width padding),
/// whereas `calculate_proof_M` pads it to `SALT_LENGTH`. Confirm this
/// asymmetry is intended.
#[allow(non_snake_case)]
#[allow(dead_code)]
pub fn calculate_private_key_x(I: UsernameRef, p: ClearTextPasswordRef, s: &Salt) -> PrivateKey {
    let credentials_digest = calculate_p_hash(I, p);
    let salted_digest = HashFunc::new()
        .chain(s.to_vec())
        .chain(credentials_digest)
        .finalize();
    PrivateKey::from_bytes_be(salted_digest.as_slice())
}
/// Hashes the credentials as `H(I ":" p)` using the bytes exactly as given.
#[allow(non_snake_case)]
#[cfg(not(feature = "wow"))]
pub fn calculate_p_hash(I: UsernameRef, p: ClearTextPasswordRef) -> Hash {
    let hasher = HashFunc::new()
        .chain(I.as_bytes())
        .chain(":".as_bytes())
        .chain(p.as_bytes());
    hasher.finalize().into()
}
/// Hashes the credentials as `H(upper(I) ":" upper(p))` (`wow` feature).
///
/// Both parts are uppercased with `str::to_uppercase` (full Unicode mapping)
/// before hashing.
#[allow(non_snake_case)]
#[cfg(feature = "wow")]
pub fn calculate_p_hash(I: UsernameRef, p: ClearTextPasswordRef) -> Hash {
    let upper_username = I.to_uppercase();
    let upper_password = p.to_uppercase();
    let hasher = HashFunc::new()
        .chain(upper_username.as_bytes())
        .chain(":".as_bytes())
        .chain(upper_password.as_bytes());
    hasher.finalize().into()
}
/// Computes the multiplier `k = H(N | PAD(g))`.
///
/// `N` is hashed unpadded via `to_vec()`; `g` is zero-padded to `KEY_LENGTH` bytes.
#[allow(non_snake_case)]
#[cfg(not(feature = "wow"))]
pub fn calculate_k<const KEY_LENGTH: usize>(
    N: &PrimeModulus,
    g: &Generator,
) -> MultiplierParameter {
    let modulus_bytes = N.to_vec();
    let padded_generator = g.to_array_pad_zero::<KEY_LENGTH>();
    let hasher = HashFunc::new()
        .chain(modulus_bytes.as_slice())
        .chain(padded_generator);
    hasher.into()
}
/// Returns the fixed multiplier `k = 3` used under the `wow` feature.
///
/// `N` and `g` are accepted only for signature parity with the non-`wow` build
/// and are ignored.
#[cfg(feature = "wow")]
pub fn calculate_k<const KEY_LENGTH: usize>(
    _modulus: &PrimeModulus,
    _generator: &Generator,
) -> MultiplierParameter {
    MultiplierParameter::from(3)
}
pub fn generate_private_key<const KEY_LENGTH: usize>() -> PrivateKey {
PrivateKey::new_rand(KEY_LENGTH)
}
pub fn generate_salt<const SALT_LENGTH: usize>() -> Salt {
Salt::new_rand(SALT_LENGTH)
}