use crate::{group::Group, utils};
use digest::Digest;
use num_bigint::BigUint;
use num_traits::Zero;
use rand::RngCore;
use rand::rngs::OsRng;
use thiserror::Error;
/// Server-side state for a password-authenticated key exchange over a
/// modular group, generic over the hash function `D`.
///
/// The arithmetic performed by the methods (`B = k*v + g^b mod n`,
/// `S = (A * v^u)^b mod n`) has the shape of the SRP-6a server role.
/// NOTE(review): confirm the exact protocol variant against `utils`.
pub struct Server<D>
where
    D: Digest,
{
    /// Group parameters; the methods read `n`, `g` and `length_n` from it.
    pub group: Group,
    /// Multiplier obtained from `utils::compute_k::<D>` for `group` at
    /// construction time.
    pub k: BigUint,
    /// Ties the hash function `D` to the type without storing a value of it.
    _phantom: std::marker::PhantomData<D>,
}
impl<D: Digest> Server<D> {
    /// Creates a server for `group`, caching the multiplier returned by
    /// `utils::compute_k::<D>` for that group.
    pub fn new(group: Group) -> Self {
        let k = utils::compute_k::<D>(&group);
        Self {
            group,
            k,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Generates a random private key of `length_key` bits from the operating
    /// system RNG and derives the matching public key for `verifier`.
    ///
    /// Returns `(private_key, public_key)`.
    ///
    /// # Errors
    ///
    /// * [`ServerError::InvalidKeyLength`] if `length_key` is zero, is not a
    ///   multiple of 8, or exceeds `bits(n) - 1` for the group modulus `n`.
    /// * Any error returned by [`Self::generate_keypair_from_private_key`].
    pub fn generate_keypair(
        &self,
        length_key: usize,
        verifier: &BigUint,
    ) -> Result<(BigUint, BigUint), ServerError> {
        if length_key == 0 {
            return Err(ServerError::InvalidKeyLength(
                "Key length must be greater than 0".to_string(),
            ));
        }
        if !length_key.is_multiple_of(8) {
            return Err(ServerError::InvalidKeyLength(
                "The key length must be a multiple of 8".to_string(),
            ));
        }

        // Capping at one bit below the modulus size keeps the private key
        // strictly below `n`. `saturating_sub` avoids an arithmetic underflow
        // if the group was built with a zero modulus; in that case every
        // length is rejected below instead of panicking.
        let max_bits = self.group.n.bits().saturating_sub(1);
        if length_key > max_bits as usize {
            return Err(ServerError::InvalidKeyLength(format!(
                "Key length {} exceeds maximum allowed {} bits for this group",
                length_key, max_bits
            )));
        }

        let mut b_bytes = vec![0u8; length_key / 8];
        OsRng.fill_bytes(&mut b_bytes);
        self.generate_keypair_from_private_key(&b_bytes, verifier)
    }

    /// Derives the server key pair from an explicit big-endian private key.
    ///
    /// The public key is `B = (k * v + g^b) mod n`, where `b` is the private
    /// key, `v` is `verifier`, and `k`, `g`, `n` come from `self`.
    ///
    /// Returns `(private_key, public_key)`.
    ///
    /// # Errors
    ///
    /// * [`ServerError::InvalidParameter`] if the private key is zero (which
    ///   includes an empty byte slice). With `b = 0` the public key collapses
    ///   to `k*v + 1` and the premaster secret is always `1`.
    /// * [`ServerError::InvalidPublicKey`] if the resulting public key is zero.
    pub fn generate_keypair_from_private_key(
        &self,
        private_key_bytes: &[u8],
        verifier: &BigUint,
    ) -> Result<(BigUint, BigUint), ServerError> {
        let private_key = BigUint::from_bytes_be(private_key_bytes);
        if private_key.is_zero() {
            return Err(ServerError::InvalidParameter(
                "Private key b cannot be zero".to_string(),
            ));
        }

        let n = &self.group.n;
        // Reduce both factors first so the product stays below n^2.
        let kv = ((&self.k % n) * (verifier % n)) % n;
        let gb = self.group.g.modpow(&private_key, n);
        let public_key = (kv + gb) % n;
        if public_key.is_zero() {
            return Err(ServerError::InvalidPublicKey(
                "Public key B cannot be zero".to_string(),
            ));
        }
        Ok((private_key, public_key))
    }

    /// Computes the premaster secret `S = (A * v^u)^b mod n`.
    ///
    /// `u` is the digest `D` of the two public keys, each passed through
    /// `utils::zero_pad` with a target length of `group.length_n / 8`, and
    /// interpreted as a big-endian integer.
    /// NOTE(review): this presumes `length_n` is the modulus size in bits.
    ///
    /// # Errors
    ///
    /// * [`ServerError::InvalidPublicKey`] if `public_client_key` is zero or
    ///   not strictly below the modulus.
    /// * [`ServerError::InvalidParameter`] if `u` is zero.
    /// * [`ServerError::KeyDerivationFailed`] if the resulting secret is zero.
    pub fn compute_premaster_key(
        &self,
        public_client_key: &BigUint,
        verifier: &BigUint,
        private_server_key: &BigUint,
        public_server_key: &BigUint,
    ) -> Result<BigUint, ServerError> {
        let n = &self.group.n;
        if public_client_key.is_zero() || public_client_key >= n {
            return Err(ServerError::InvalidPublicKey(
                "Client public key A is invalid".to_string(),
            ));
        }

        let pad_len = self.group.length_n / 8;
        let a_bytes = utils::zero_pad(public_client_key.to_bytes_be(), pad_len);
        let b_bytes = utils::zero_pad(public_server_key.to_bytes_be(), pad_len);

        let mut u_hash = D::new();
        u_hash.update(&a_bytes);
        u_hash.update(&b_bytes);
        let u = BigUint::from_bytes_be(&u_hash.finalize());
        if u.is_zero() {
            return Err(ServerError::InvalidParameter(
                "u parameter is zero".to_string(),
            ));
        }

        let v_u = verifier.modpow(&u, n);
        let a_v_u = (public_client_key * v_u) % n;
        let premaster_key = a_v_u.modpow(private_server_key, n);
        if premaster_key.is_zero() {
            return Err(ServerError::KeyDerivationFailed(
                "Premaster key is zero".to_string(),
            ));
        }
        Ok(premaster_key)
    }

    /// Hashes the big-endian bytes of `premaster_key` (without padding) with
    /// `D` and returns the digest as a big-endian integer.
    pub fn compute_session_key(&self, premaster_key: &BigUint) -> BigUint {
        let mut hasher = D::new();
        hasher.update(premaster_key.to_bytes_be());
        BigUint::from_bytes_be(&hasher.finalize())
    }

    /// Computes the server's proof by delegating to
    /// `utils::compute_server_proof::<D>` with the client's public key, the
    /// client's proof and the session key.
    pub fn compute_server_proof(
        &self,
        public_client_key: &BigUint,
        client_proof: &[u8],
        session_key: &BigUint,
    ) -> Vec<u8> {
        utils::compute_server_proof::<D>(public_client_key, client_proof, session_key)
    }

    /// Recomputes the expected client proof with
    /// `utils::compute_client_proof::<D>` and checks it against `client_proof`.
    ///
    /// The comparison does not short-circuit on the first differing byte, so
    /// the time taken does not reveal how long a matching prefix a forged
    /// proof has.
    ///
    /// Returns `true` only if both proofs have the same length and contents.
    pub fn verify_client_proof(
        &self,
        username: &[u8],
        salt: &[u8],
        public_client_key: &BigUint,
        public_server_key: &BigUint,
        session_key: &BigUint,
        client_proof: &[u8],
    ) -> bool {
        let expected_proof = utils::compute_client_proof::<D>(
            &self.group,
            username,
            salt,
            public_client_key,
            public_server_key,
            session_key,
        );
        Self::constant_time_eq(&expected_proof[..], client_proof)
    }

    /// Compares two byte slices without an early exit on mismatching content.
    ///
    /// A length mismatch returns immediately; that only reveals the expected
    /// length, which is fixed by the digest and not secret.
    fn constant_time_eq(lhs: &[u8], rhs: &[u8]) -> bool {
        if lhs.len() != rhs.len() {
            return false;
        }
        let diff = lhs
            .iter()
            .zip(rhs.iter())
            .fold(0u8, |acc, (a, b)| acc | (a ^ b));
        // `black_box` is a best-effort barrier that discourages the optimizer
        // from turning the accumulation back into a short-circuiting compare.
        // A vetted crate such as `subtle` would give a stronger guarantee.
        std::hint::black_box(diff) == 0
    }
}
/// Errors returned by the fallible `Server` operations.
///
/// Each variant carries a human-readable detail string that is interpolated
/// into the message given in its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum ServerError {
/// The requested private-key length was rejected (for example zero, not a
/// multiple of 8, or too large for the group).
#[error("Invalid key length: {0}")]
InvalidKeyLength(String),
/// A public key value is unusable (for example a zero server key, or a
/// client key that is zero or not below the modulus).
#[error("Invalid public key: {0}")]
InvalidPublicKey(String),
/// The password verifier is unusable.
/// NOTE(review): no code in this file constructs this variant; presumably it
/// is used elsewhere or reserved for future validation.
#[error("Invalid verifier: {0}")]
InvalidVerifier(String),
/// A protocol parameter is unusable (for example the hashed value `u`
/// being zero).
#[error("Invalid parameter: {0}")]
InvalidParameter(String),
/// Key derivation produced an unusable result (for example a zero premaster
/// key).
#[error("Key derivation failed: {0}")]
KeyDerivationFailed(String),
}