use crate::{group::Group, utils};
use digest::Digest;
use num_bigint::BigUint;
use num_traits::Zero;
use rand::RngCore;
use rand::rngs::OsRng;
use thiserror::Error;
/// Client side of an SRP exchange over the parameters in `group`, hashing with `D`.
///
/// `D` is never stored; it only selects the hash function used by the methods.
pub struct Client<D>
where
    D: Digest,
{
    /// Group parameters (modulus `n`, generator `g`, `length_n`) used by every operation.
    pub group: Group,
    /// Multiplier `k`, computed from `group` by `utils::compute_k::<D>` when the client is built.
    pub k: BigUint,
    /// Zero-sized marker tying the hash type `D` to the struct.
    _phantom: std::marker::PhantomData<D>,
}
impl<D: Digest> Client<D> {
pub fn new(group: Group) -> Self {
let k = utils::compute_k::<D>(&group);
Self {
group,
k,
_phantom: std::marker::PhantomData,
}
}
pub fn compute_verifier(&self, username: &[u8], password: &[u8], salt: &[u8]) -> BigUint {
let mut credentials_hash = D::new();
credentials_hash.update(username);
credentials_hash.update(b":");
credentials_hash.update(password);
let hash_ip = credentials_hash.finalize();
let mut x_hash = D::new();
x_hash.update(salt);
x_hash.update(hash_ip);
let x_bytes = x_hash.finalize();
let x = BigUint::from_bytes_be(&x_bytes);
self.group.g.modpow(&x, &self.group.n)
}
pub fn generate_keypair(&self, length_key: usize) -> Result<(BigUint, BigUint), ClientError> {
if length_key == 0 {
return Err(ClientError::InvalidKeyLength(
"Key length must be greater than 0".to_string(),
));
}
if !length_key.is_multiple_of(8) {
return Err(ClientError::InvalidKeyLength(
"The key length must be a multiple of 8".to_string(),
));
}
let max_bits = self.group.n.bits() - 1;
if length_key > max_bits as usize {
return Err(ClientError::InvalidKeyLength(format!(
"Key length {} exceeds maximum allowed {} bits for this group",
length_key, max_bits
)));
}
let mut a_bytes = vec![0u8; length_key / 8];
let mut rng = OsRng;
rng.fill_bytes(&mut a_bytes);
self.generate_keypair_from_private_key(&a_bytes)
}
pub fn generate_keypair_from_private_key(
&self,
private_key_bytes: &[u8],
) -> Result<(BigUint, BigUint), ClientError> {
let private_key = BigUint::from_bytes_be(private_key_bytes);
if private_key.is_zero() || private_key >= self.group.n {
return Err(ClientError::InvalidPrivateKey(
"Private key out of range".to_string(),
));
}
let public_key = self.group.g.modpow(&private_key, &self.group.n);
if public_key.is_zero() {
return Err(ClientError::InvalidPublicKey(
"Public key A cannot be zero".to_string(),
));
}
Ok((private_key, public_key))
}
pub fn compute_premaster_key(
&self,
public_server_key: &BigUint,
private_client_key: &BigUint,
public_client_key: &BigUint,
salt: &[u8],
username: &[u8],
password: &[u8],
) -> Result<BigUint, ClientError> {
let x = {
let mut credentials_hash = D::new();
credentials_hash.update(username);
credentials_hash.update(b":");
credentials_hash.update(password);
let hash_ip = credentials_hash.finalize();
let mut x_hash = D::new();
x_hash.update(salt);
x_hash.update(hash_ip);
BigUint::from_bytes_be(&x_hash.finalize())
};
let a_bytes = utils::zero_pad(public_client_key.to_bytes_be(), self.group.length_n / 8);
let b_bytes = utils::zero_pad(public_server_key.to_bytes_be(), self.group.length_n / 8);
let mut u_hash = D::new();
u_hash.update(&a_bytes);
u_hash.update(&b_bytes);
let u = BigUint::from_bytes_be(&u_hash.finalize());
if u.is_zero() {
return Err(ClientError::InvalidParameter(
"u parameter is zero".to_string(),
));
}
let g_x = self.group.g.modpow(&x, &self.group.n);
let k_g_x = (&self.k * g_x) % &self.group.n;
let b_minus_k_g_x = if public_server_key >= &k_g_x {
(public_server_key - &k_g_x) % &self.group.n
} else {
(&self.group.n - (&k_g_x - public_server_key)) % &self.group.n
};
let exponent = private_client_key + &u * x;
let premaster_key = b_minus_k_g_x.modpow(&exponent, &self.group.n);
if premaster_key.is_zero() {
return Err(ClientError::KeyDerivationFailed(
"Premaster key is zero".to_string(),
));
}
Ok(premaster_key)
}
pub fn compute_session_key(&self, premaster_key: &BigUint) -> BigUint {
let mut hasher = D::new();
hasher.update(premaster_key.to_bytes_be());
BigUint::from_bytes_be(&hasher.finalize())
}
pub fn compute_client_proof(
&self,
username: &[u8],
salt: &[u8],
public_client_key: &BigUint,
public_server_key: &BigUint,
session_key: &BigUint,
) -> Vec<u8> {
utils::compute_client_proof::<D>(
&self.group,
username,
salt,
public_client_key,
public_server_key,
session_key,
)
}
pub fn verify_server_proof(
&self,
public_client_key: &BigUint,
client_proof: &[u8],
server_proof: &[u8],
session_key: &BigUint,
) -> bool {
utils::compute_server_proof::<D>(public_client_key, client_proof, session_key)
== server_proof
}
}
#[derive(Debug, Error)]
pub enum ClientError {
#[error("Invalid key length: {0}")]
InvalidKeyLength(String),
#[error("Invalid private key: {0}")]
InvalidPrivateKey(String),
#[error("Invalid public key: {0}")]
InvalidPublicKey(String),
#[error("Invalid parameter: {0}")]
InvalidParameter(String),
#[error("Key derivation failed: {0}")]
KeyDerivationFailed(String),
}