use log::debug;
use serde::Serialize;
use crate::big_number::{BigNumber, Zero};
use crate::hash::{hash, Digest, Hash, HashFunc, Update, HASH_LENGTH};
use crate::{Result, Srp6Error};
/// Byte length of the strong session key `K`: two hash digests interleaved
/// byte by byte, hence twice [`HASH_LENGTH`].
const STRONG_SESSION_KEY_LENGTH: usize = HASH_LENGTH * 2;
/// The prime modulus all modular arithmetic in this module is reduced by
/// (searchable as `N`).
#[doc(alias = "N")]
pub type PrimeModulus = BigNumber;
/// The generator used as base of the modular exponentiations (searchable as `g`).
#[doc(alias = "g")]
pub type Generator = BigNumber;
/// The per-user salt (searchable as `s`).
#[doc(alias = "s")]
pub type Salt = BigNumber;
/// A public key (searchable as `A` or `B`).
#[doc(alias("A", "B"))]
pub type PublicKey = BigNumber;
/// A private key (searchable as `a` or `b`).
#[doc(alias("a", "b"))]
pub type PrivateKey = BigNumber;
/// A pair of keys in the order `(public, private)`.
pub type KeyPair = (PublicKey, PrivateKey);
/// The password verifier (searchable as `v`).
#[doc(alias = "v")]
pub type PasswordVerifier = BigNumber;
/// The multiplier parameter (searchable as `k`).
#[doc(alias = "k")]
pub type MultiplierParameter = BigNumber;
/// The raw session key (searchable as `S`).
#[doc(alias = "S")]
pub type SessionKey = BigNumber;
/// The hashed ("strong") session key (searchable as `K`).
#[doc(alias = "K")]
pub type StrongSessionKey = BigNumber;
/// The proof value (searchable as `M` or `M1`).
#[doc(alias("M", "M1"))]
pub type Proof = BigNumber;
/// The strong proof value (searchable as `M2`).
#[doc(alias = "M2")]
pub type StrongProof = BigNumber;
/// An owned user name (searchable as `I`).
#[doc(alias = "I")]
pub type Username = String;
/// A borrowed user name.
pub type UsernameRef<'a> = &'a str;
/// A clear-text password (searchable as `p`).
///
/// This is the unsized `str` type, so it is always used behind a reference.
#[doc(alias = "p")]
pub type ClearTextPassword = str;
/// Borrowed user name and clear-text password of a user.
///
/// NOTE(review): `Debug` and `Serialize` are derived, so the clear-text
/// password is part of the debug output and of the serialized form —
/// confirm this is intended.
#[derive(Debug, Clone, Serialize)]
pub struct UserCredentials<'a> {
/// The user name, borrowed for `'a`.
pub username: UsernameRef<'a>,
/// The clear-text password, borrowed for `'a`.
pub password: &'a ClearTextPassword,
}
/// Owned per-user record consisting of user name, salt and password verifier.
///
/// Unlike [`UserCredentials`] it holds no clear-text password.
#[derive(Debug, Clone, Serialize)]
pub struct UserDetails {
/// The user name.
pub username: Username,
/// The salt belonging to this user.
pub salt: Salt,
/// The password verifier belonging to this user.
pub verifier: PasswordVerifier,
}
/// Computes the host-side session key `S = (A * v^u)^b mod N`.
///
/// The value `u` is derived from both public keys via [`calculate_u`].
///
/// # Errors
///
/// Returns [`Srp6Error::InvalidPublicKey`] holding a copy of `A` when
/// `A mod N` is zero.
#[allow(non_snake_case)]
pub(crate) fn calculate_session_key_S_for_host<const KEY_LENGTH: usize>(
    N: &PrimeModulus,
    A: &PublicKey,
    B: &PublicKey,
    b: &PrivateKey,
    v: &PasswordVerifier,
) -> Result<SessionKey> {
    // A public key that is a multiple of N would make the base below zero
    // modulo N and therefore S zero, independent of the verifier.
    let A_mod_N = A % N;
    if A_mod_N.is_zero() {
        return Err(Srp6Error::InvalidPublicKey(A.clone()));
    }

    let scrambler = calculate_u::<KEY_LENGTH>(A, B);
    let v_pow_u = v.modpow(&scrambler, N);
    let product = A * &v_pow_u;
    let S: BigNumber = product.modpow(b, N);
    debug!("S = {:?}", &S);
    Ok(S)
}
/// Computes the client-side session key `S = (B - k * g^x)^(a + u * x) mod N`.
///
/// The value `u` is derived from both public keys via [`calculate_u`].
///
/// # Errors
///
/// Returns [`Srp6Error::InvalidPublicKey`] holding a copy of `B` when
/// `B mod N` is zero.
#[allow(non_snake_case)]
#[allow(clippy::many_single_char_names)]
pub(crate) fn calculate_session_key_S_for_client<const KEY_LENGTH: usize>(
    N: &PrimeModulus,
    k: &MultiplierParameter,
    g: &Generator,
    B: &PublicKey,
    A: &PublicKey,
    a: &PrivateKey,
    x: &PrivateKey,
) -> Result<SessionKey> {
    // Refuse a host public key that is a multiple of N.
    let B_mod_N = B % N;
    if B_mod_N.is_zero() {
        return Err(Srp6Error::InvalidPublicKey(B.clone()));
    }

    let scrambler = calculate_u::<KEY_LENGTH>(A, B);

    // exponent = a + u * x (not reduced)
    let u_times_x = &scrambler * x;
    let exponent: BigNumber = a + &u_times_x;

    // base = B - k * (g^x mod N)
    // NOTE(review): the difference is not reduced modulo N before the
    // exponentiation; what happens for B < k * g^x depends on how
    // `BigNumber` implements `-` — confirm.
    let g_pow_x = g.modpow(x, N);
    let k_times_g_pow_x = k * &g_pow_x;
    let difference = B - &k_times_g_pow_x;

    let S = difference.modpow(&exponent, N);
    debug!("S = {:?}", &S);
    Ok(S)
}
/// Derives the strong session key `K` from the session key `S` by hashing
/// the interleaved halves of `S`.
///
/// `S` is zero-padded to `KEY_LENGTH` bytes. The bytes at even positions and
/// the bytes at odd positions are hashed separately (the first
/// `KEY_LENGTH / 2` collected bytes each), and the two digests are then
/// interleaved byte by byte: even digest bytes land on even output
/// positions, odd digest bytes on odd ones. The resulting buffer is read as
/// a little-endian number.
#[allow(non_snake_case)]
pub(crate) fn calculate_session_key_hash_interleave_K<const KEY_LENGTH: usize>(
    S: &SessionKey,
) -> StrongSessionKey {
    let session_key_bytes = S.to_array_pad_zero::<KEY_LENGTH>();

    // Hashes every second byte of the session key, starting at `offset`.
    let hash_every_second_byte = |offset: usize| {
        let mut half = [0_u8; KEY_LENGTH];
        let picked = session_key_bytes.iter().skip(offset).step_by(2);
        for (slot, byte) in half.iter_mut().zip(picked) {
            *slot = *byte;
        }
        HashFunc::new().chain(&half[..KEY_LENGTH / 2]).finalize()
    };
    let even_digest = hash_every_second_byte(0);
    let odd_digest = hash_every_second_byte(1);

    let mut interleaved = [0_u8; STRONG_SESSION_KEY_LENGTH];
    for (i, (even, odd)) in even_digest.iter().zip(odd_digest.iter()).enumerate() {
        let at = i * 2;
        interleaved[at] = *even;
        interleaved[at + 1] = *odd;
    }

    let K = BigNumber::from_bytes_le(&interleaved);
    debug!("K = {:?}", &K);
    K
}
/// Computes the proof `M = H(H(N) xor H(g) | H(I) | s | A | B | K)`.
///
/// `s` is zero-padded to `SALT_LENGTH` bytes, `A` and `B` to `KEY_LENGTH`
/// bytes and `K` to [`STRONG_SESSION_KEY_LENGTH`] bytes before hashing.
#[allow(non_snake_case)]
pub(crate) fn calculate_proof_M<const KEY_LENGTH: usize, const SALT_LENGTH: usize>(
    N: &PrimeModulus,
    g: &Generator,
    I: UsernameRef,
    s: &Salt,
    A: &PublicKey,
    B: &PublicKey,
    K: &StrongSessionKey,
) -> Proof {
    let xor_hash: Hash = calculate_hash_N_xor_g::<KEY_LENGTH>(N, g);

    let username_hash = HashFunc::new().chain(I.as_bytes()).finalize();
    debug!("H(I) = {:?}", &username_hash);

    // Fixed-width encodings of the numeric inputs.
    let salt_bytes = s.to_array_pad_zero::<SALT_LENGTH>();
    let client_key_bytes = A.to_array_pad_zero::<KEY_LENGTH>();
    let host_key_bytes = B.to_array_pad_zero::<KEY_LENGTH>();
    let strong_key_bytes = K.to_array_pad_zero::<STRONG_SESSION_KEY_LENGTH>();

    let hasher = HashFunc::new()
        .chain(xor_hash)
        .chain(username_hash)
        .chain(salt_bytes)
        .chain(client_key_bytes)
        .chain(host_key_bytes)
        .chain(strong_key_bytes);
    let M: Proof = hasher.into();
    debug!("M = {:?}", &M);
    M
}
/// Computes the strong proof `M2 = H(A | M | K)`.
///
/// `A` is zero-padded to `KEY_LENGTH` bytes, `M` to [`HASH_LENGTH`] bytes and
/// `K` to [`STRONG_SESSION_KEY_LENGTH`] bytes before hashing.
#[allow(non_snake_case)]
pub(crate) fn calculate_strong_proof_M2<const KEY_LENGTH: usize>(
    A: &PublicKey,
    M: &Proof,
    K: &StrongSessionKey,
) -> StrongProof {
    let public_key_bytes = A.to_array_pad_zero::<KEY_LENGTH>();
    let proof_bytes = M.to_array_pad_zero::<HASH_LENGTH>();
    let strong_key_bytes = K.to_array_pad_zero::<STRONG_SESSION_KEY_LENGTH>();

    let hasher = HashFunc::new()
        .chain(public_key_bytes)
        .chain(proof_bytes)
        .chain(strong_key_bytes);
    let M2: StrongProof = hasher.into();
    debug!("M2 = {:?}", &M2);
    M2
}
/// Computes `H(N) xor H(g)`.
///
/// `N` is hashed in its zero-padded `KEY_LENGTH` byte form, whereas `g` is
/// hashed using the unpadded bytes returned by `to_vec`.
#[allow(non_snake_case)]
fn calculate_hash_N_xor_g<const KEY_LENGTH: usize>(N: &PrimeModulus, g: &Generator) -> Hash {
    let padded_modulus = N.to_array_pad_zero::<KEY_LENGTH>();
    let mut digest = HashFunc::new().chain(padded_modulus).finalize();

    let generator_digest = HashFunc::new().chain(g.to_vec().as_slice()).finalize();

    // Fold H(g) into H(N) byte by byte.
    for (acc, generator_byte) in digest.iter_mut().zip(generator_digest.iter()) {
        *acc ^= *generator_byte;
    }

    let H_n_g: Hash = digest.into();
    debug!("H(N) xor H(g) = {:X?}", &H_n_g);
    H_n_g
}
/// Computes the password verifier `v = g^x mod N`.
#[allow(non_snake_case)]
pub(crate) fn calculate_password_verifier_v(
    N: &PrimeModulus,
    g: &Generator,
    x: &PrivateKey,
) -> PasswordVerifier {
    let verifier = g.modpow(x, N);
    verifier
}
/// Computes `u` by passing both public keys, `A` first and `B` second, to
/// [`hash`] with the given `KEY_LENGTH`.
#[allow(non_snake_case)]
pub(crate) fn calculate_u<const KEY_LENGTH: usize>(A: &PublicKey, B: &PublicKey) -> BigNumber {
    let scrambler = hash::<KEY_LENGTH>(A, B);
    debug!("u = {:?}", &scrambler);
    scrambler
}
/// Computes the public key `A = g^a mod N` from the private key `a`.
#[allow(non_snake_case)]
pub(crate) fn calculate_pubkey_A(N: &PrimeModulus, g: &Generator, a: &PrivateKey) -> PublicKey {
    let public_key = g.modpow(a, N);
    debug!("A = {:?}", &public_key);
    public_key
}
/// Computes the public key `B = (k * v + g^b) mod N` from the private key `b`
/// and the password verifier `v`.
#[allow(non_snake_case)]
pub(crate) fn calculate_pubkey_B(
    N: &PrimeModulus,
    k: &MultiplierParameter,
    g: &Generator,
    v: &PasswordVerifier,
    b: &PrivateKey,
) -> PublicKey {
    let g_pow_b = g.modpow(b, N);
    let k_times_v = k * v;
    let sum = k_times_v + g_pow_b;
    // Reduce once at the end so the result lies in the range of N.
    let B = &sum % N;
    debug!("B = {:?}", &B);
    B
}
/// Computes the private key `x = H(s | ph)`, where `ph` is the credential
/// hash produced by [`calculate_p_hash`] for `I` and `p`.
///
/// The salt is hashed using the unpadded bytes returned by `to_vec`.
///
/// NOTE(review): the derived `x` is written to the debug log — confirm this
/// is acceptable wherever debug logging may be enabled.
#[allow(non_snake_case)]
#[allow(dead_code)]
pub(crate) fn calculate_private_key_x(
    I: UsernameRef,
    p: &ClearTextPassword,
    s: &Salt,
) -> PrivateKey {
    let credentials_hash = calculate_p_hash(I, p);
    let salt_bytes = s.to_vec();

    let hasher = HashFunc::new()
        .chain(salt_bytes.as_slice())
        .chain(credentials_hash);
    let x: PrivateKey = hasher.into();
    debug!("x = {:?}", &x);
    x
}
/// Computes the credential hash `H(I | ":" | p)` over the user name and the
/// clear-text password exactly as given.
///
/// This variant is compiled when the `legacy` feature is disabled.
#[allow(non_snake_case)]
#[cfg(not(feature = "legacy"))]
pub(crate) fn calculate_p_hash(I: UsernameRef, p: &ClearTextPassword) -> Hash {
    let hasher = HashFunc::new()
        .chain(I.as_bytes())
        .chain(":".as_bytes())
        .chain(p.as_bytes());
    let digest = hasher.finalize();
    digest.into()
}
/// Computes the credential hash `H(upper(I) | ":" | upper(p))`, i.e. both the
/// user name and the clear-text password are upper-cased before hashing.
///
/// This variant is compiled when the `legacy` feature is enabled.
#[allow(non_snake_case)]
#[cfg(feature = "legacy")]
pub(crate) fn calculate_p_hash(I: UsernameRef, p: &ClearTextPassword) -> Hash {
    let upper_username = I.to_uppercase();
    let upper_password = p.to_uppercase();

    let hasher = HashFunc::new()
        .chain(upper_username.as_bytes())
        .chain(":".as_bytes())
        .chain(upper_password.as_bytes());
    let digest = hasher.finalize();
    digest.into()
}
/// Computes the multiplier parameter `k = H(N | g)`.
///
/// `N` is hashed using the unpadded bytes returned by `to_vec`, `g` in its
/// zero-padded `KEY_LENGTH` byte form.
///
/// This variant is compiled when the `legacy` feature is disabled.
#[allow(non_snake_case)]
#[cfg(not(feature = "legacy"))]
pub(crate) fn calculate_k<const KEY_LENGTH: usize>(
    N: &PrimeModulus,
    g: &Generator,
) -> MultiplierParameter {
    let modulus_bytes = N.to_vec();
    let padded_generator = g.to_array_pad_zero::<KEY_LENGTH>();

    let hasher = HashFunc::new()
        .chain(modulus_bytes.as_slice())
        .chain(padded_generator);
    hasher.into()
}
/// Returns the fixed multiplier parameter `k = 3`.
///
/// Both arguments are ignored; the signature mirrors the non-legacy variant.
/// This variant is compiled when the `legacy` feature is enabled.
#[cfg(feature = "legacy")]
pub(crate) fn calculate_k<const KEY_LENGTH: usize>(
    _: &PrimeModulus,
    _: &Generator,
) -> MultiplierParameter {
    let fixed_multiplier = MultiplierParameter::from(3);
    fixed_multiplier
}
/// Creates a fresh private key by calling `PrivateKey::new_rand` with
/// `KEY_LENGTH`.
pub(crate) fn generate_private_key<const KEY_LENGTH: usize>() -> PrivateKey {
    let private_key = PrivateKey::new_rand(KEY_LENGTH);
    private_key
}
/// Creates a fresh salt by calling `Salt::new_rand` with `SALT_LENGTH`.
pub(crate) fn generate_salt<const SALT_LENGTH: usize>() -> Salt {
    let salt = Salt::new_rand(SALT_LENGTH);
    salt
}
/// Unit tests for the SRP primitives of this module.
///
/// They are only compiled for test builds with the `legacy` feature enabled
/// and compare results against the fixtures provided by
/// `crate::api::host::tests::Mock`.
#[cfg(test)]
#[cfg(feature = "legacy")]
mod tests {
    use std::convert::TryInto;

    use crate::api::host::tests::Mock;
    use crate::defaults::Srp6_256;

    use super::*;

    const KEY_LENGTH: usize = 32;
    const SALT_LENGTH: usize = 32;

    #[test]
    fn should_generate_a_salt_of_n() {
        let s = generate_salt::<32>();
        assert_eq!(s.to_vec().len(), 32);
    }

    #[test]
    fn should_calculate_users_private_key_x_and_password_verifier() {
        let params = Srp6_256::default();
        let x = &calculate_private_key_x(Mock::I(), Mock::p(), &Mock::s());
        let v = calculate_password_verifier_v(&params.N, &params.g, x);
        assert_eq!(&v, &Mock::v());
    }

    #[test]
    #[allow(non_snake_case)]
    fn should_calculate_servers_pubkey_B() {
        let params = Srp6_256::default();
        let B = calculate_pubkey_B(&params.N, &params.k, &params.g, &Mock::v(), &Mock::b());
        assert_eq!(B, Mock::B())
    }

    #[test]
    fn should_calculate_session_key_on_host_side() {
        let params = Srp6_256::default();
        let session_key_s = calculate_session_key_S_for_host::<KEY_LENGTH>(
            &params.N,
            &Mock::A(),
            &Mock::B(),
            &Mock::b(),
            &Mock::v(),
        )
        .unwrap();
        assert!(!session_key_s.is_zero());
        assert_eq!(&session_key_s, &Mock::S())
    }

    // Passing N as the public key A makes `A mod N` zero, so the function
    // returns an error and the `unwrap` panics.
    #[test]
    #[should_panic]
    #[allow(non_snake_case)]
    fn should_fail_for_public_key_mod_N_is_zero() {
        let params = Srp6_256::default();
        calculate_session_key_S_for_host::<KEY_LENGTH>(
            &params.N,
            &params.N,
            &Mock::B(),
            &Mock::b(),
            &Mock::v(),
        )
        .unwrap();
    }

    #[test]
    fn should_calculate_hash_of_a_session_key() {
        let hash_of_session_key = calculate_session_key_hash_interleave_K::<KEY_LENGTH>(&Mock::S());
        assert_eq!(&hash_of_session_key, &Mock::K())
    }

    #[test]
    fn should_calculate_the_xor_hash_right() {
        use hex_literal::hex;

        let params = Srp6_256::default();
        let h = calculate_hash_N_xor_g::<KEY_LENGTH>(&params.N, &params.g);
        const EXPECTED_HASH_LE: Hash =
            hex!("DD 7B B0 3A 38 AC 73 11 03 98 7C 5A 50 6F CA 96 6C 7B C2 A7");
        // The same value written as a hex string, in the reverse byte order.
        let bn: BigNumber = "A7C27B6C96CA6F505A7C98031173AC383AB07BDD"
            .try_into()
            .unwrap();
        assert_eq!(bn.to_vec().as_slice(), &EXPECTED_HASH_LE);
        assert_eq!(h, EXPECTED_HASH_LE);
    }

    #[test]
    fn should_calculate_proof() {
        let params = Srp6_256::default();
        let proof_m = calculate_proof_M::<KEY_LENGTH, SALT_LENGTH>(
            &params.N,
            &params.g,
            Mock::I(),
            &Mock::s(),
            &Mock::A(),
            &Mock::B(),
            &Mock::K(),
        );
        assert_eq!(&proof_m, &Mock::M())
    }

    // Despite its name this test does not panic: with N passed for every
    // argument, `B mod N` is zero and an `InvalidPublicKey` error is returned.
    #[test]
    fn should_panic_client_key_calc_for_mod_zero_public_server_key() {
        let params = Srp6_256::default();
        let res = calculate_session_key_S_for_client::<32>(
            &params.N, &params.N, &params.N, &params.N, &params.N, &params.N, &params.N,
        );
        assert!(res.is_err());
        assert_eq!(res.err().unwrap(), Srp6Error::InvalidPublicKey(params.N));
    }
}