// srp6a 0.1.0
//
// Implementation of SRP-6a (Secure Remote Password) according to RFC 5054
// (https://datatracker.ietf.org/doc/html/rfc5054).
// Documentation
use crate::{group::Group, utils};
use digest::Digest;
use num_bigint::BigUint;
use num_traits::Zero;
use rand::RngCore;
use rand::rngs::OsRng;
use thiserror::Error;

/// Client side of the SRP-6a exchange, generic over the hash function `D`.
///
/// An instance holds the group parameters and the multiplier `k`; the digest
/// type `D` is only used at the type level to select the hash function used
/// by the methods.
pub struct Client<D: Digest> {
    /// Group parameters. The methods read the modulus `n`, the generator `g`
    /// and `length_n` from it.
    pub group: Group,
    /// Multiplier `k` used in `S = (B - k*g^x)^(a + u*x) mod N`; `new` sets it
    /// from `utils::compute_k::<D>(&group)`.
    pub k: BigUint,
    // Ties the digest type `D` to the struct without storing a value of it.
    _phantom: std::marker::PhantomData<D>,
}

impl<D: Digest> Client<D> {
    /// Creates a client bound to `group`.
    ///
    /// The multiplier `k` is derived once from the group through
    /// `utils::compute_k` (using the digest `D`) and stored on the instance.
    pub fn new(group: Group) -> Self {
        Self {
            // Field initializers run in source order, so `group` is only
            // borrowed here and moved into the struct on the next line.
            k: utils::compute_k::<D>(&group),
            group,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Computes the password verifier `v = g^x mod N`.
    ///
    /// The private exponent is `x = H(salt | H(username | ":" | password))`,
    /// with the final digest read as a big-endian integer and `H` being the
    /// digest `D`.
    pub fn compute_verifier(&self, username: &[u8], password: &[u8], salt: &[u8]) -> BigUint {
        // Inner hash: H(username | ":" | password)
        let identity_digest = {
            let mut hasher = D::new();
            hasher.update(username);
            hasher.update(b":");
            hasher.update(password);
            hasher.finalize()
        };

        // Outer hash: x = H(salt | inner hash)
        let x = {
            let mut hasher = D::new();
            hasher.update(salt);
            hasher.update(identity_digest);
            BigUint::from_bytes_be(&hasher.finalize())
        };

        // v = g^x mod N
        let group = &self.group;
        group.g.modpow(&x, &group.n)
    }

    /// Generates a random client key pair `(a, A)` with a private key of
    /// `length_key` bits.
    ///
    /// `length_key / 8` random bytes are drawn from `OsRng` and handed to
    /// [`Self::generate_keypair_from_private_key`].
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::InvalidKeyLength`] when `length_key` is zero, is
    /// not a multiple of 8, or exceeds `n.bits() - 1` for this group. Any
    /// error from `generate_keypair_from_private_key` is passed through.
    pub fn generate_keypair(&self, length_key: usize) -> Result<(BigUint, BigUint), ClientError> {
        // Checks run in this order; the first failing one supplies the message.
        let rejection = if length_key == 0 {
            Some("Key length must be greater than 0".to_string())
        } else if !length_key.is_multiple_of(8) {
            Some("The key length must be a multiple of 8".to_string())
        } else {
            // One bit below the modulus size keeps the private key under N.
            let max_bits = self.group.n.bits() - 1;
            (length_key > max_bits as usize).then(|| {
                format!(
                    "Key length {} exceeds maximum allowed {} bits for this group",
                    length_key, max_bits
                )
            })
        };

        if let Some(reason) = rejection {
            return Err(ClientError::InvalidKeyLength(reason));
        }

        let mut secret = vec![0u8; length_key / 8];
        OsRng.fill_bytes(&mut secret);

        self.generate_keypair_from_private_key(&secret)
    }

    /// Derives the key pair `(a, A)` from a caller-supplied private key.
    ///
    /// `private_key_bytes` is read as a big-endian integer `a`, and the public
    /// key is `A = g^a mod N`.
    ///
    /// # Errors
    ///
    /// * [`ClientError::InvalidPrivateKey`] unless `0 < a < N`.
    /// * [`ClientError::InvalidPublicKey`] if the resulting `A` is zero.
    pub fn generate_keypair_from_private_key(
        &self,
        private_key_bytes: &[u8],
    ) -> Result<(BigUint, BigUint), ClientError> {
        let modulus = &self.group.n;
        let a = BigUint::from_bytes_be(private_key_bytes);

        // Only 0 < a < N is accepted.
        let in_range = !a.is_zero() && a < *modulus;
        if !in_range {
            return Err(ClientError::InvalidPrivateKey(
                "Private key out of range".to_string(),
            ));
        }

        match self.group.g.modpow(&a, modulus) {
            big_a if big_a.is_zero() => Err(ClientError::InvalidPublicKey(
                "Public key A cannot be zero".to_string(),
            )),
            big_a => Ok((a, big_a)),
        }
    }

    /// Computes the SRP-6a premaster secret `S = (B - k*g^x)^(a + u*x) mod N`.
    ///
    /// * `public_server_key` – the server's public value `B`.
    /// * `private_client_key`, `public_client_key` – the client's `a` and `A`.
    /// * `salt`, `username`, `password` – inputs to
    ///   `x = H(salt | H(username | ":" | password))`.
    ///
    /// The scrambling parameter is `u = H(A' | B')`, where `A'` and `B'` are
    /// the big-endian encodings of `A` and `B` passed through
    /// `utils::zero_pad` with a length of `length_n / 8`.
    ///
    /// # Errors
    ///
    /// * [`ClientError::InvalidPublicKey`] if `B mod N` is zero; RFC 5054
    ///   section 2.6 requires the client to abort in that case.
    /// * [`ClientError::InvalidParameter`] if `u` is zero.
    /// * [`ClientError::KeyDerivationFailed`] if the resulting `S` is zero.
    pub fn compute_premaster_key(
        &self,
        public_server_key: &BigUint,
        private_client_key: &BigUint,
        public_client_key: &BigUint,
        salt: &[u8],
        username: &[u8],
        password: &[u8],
    ) -> Result<BigUint, ClientError> {
        let n = &self.group.n;

        // Reject B ≡ 0 (mod N) up front: with such a value the base
        // (B - k*g^x) no longer depends on the server's ephemeral secret.
        if (public_server_key % n).is_zero() {
            return Err(ClientError::InvalidPublicKey(
                "Server public key B is zero modulo N".to_string(),
            ));
        }

        // x = H(salt | H(username | ":" | password))
        let x = {
            let mut credentials_hash = D::new();
            credentials_hash.update(username);
            credentials_hash.update(b":");
            credentials_hash.update(password);
            let hash_ip = credentials_hash.finalize();

            let mut x_hash = D::new();
            x_hash.update(salt);
            x_hash.update(hash_ip);
            BigUint::from_bytes_be(&x_hash.finalize())
        };

        // u = H(A' | B') over the zero-padded encodings of both public keys.
        let pad_len = self.group.length_n / 8;
        let a_bytes = utils::zero_pad(public_client_key.to_bytes_be(), pad_len);
        let b_bytes = utils::zero_pad(public_server_key.to_bytes_be(), pad_len);

        let mut u_hash = D::new();
        u_hash.update(&a_bytes);
        u_hash.update(&b_bytes);
        let u = BigUint::from_bytes_be(&u_hash.finalize());

        if u.is_zero() {
            return Err(ClientError::InvalidParameter(
                "u parameter is zero".to_string(),
            ));
        }

        // S = (B - k*g^x)^(a + u*x) mod N
        let g_x = self.group.g.modpow(&x, n);
        let k_g_x = (&self.k * g_x) % n;

        // BigUint cannot go negative, so when B < k*g^x the difference is
        // taken the other way round and subtracted from N instead.
        let b_minus_k_g_x = if public_server_key >= &k_g_x {
            (public_server_key - &k_g_x) % n
        } else {
            (n - (&k_g_x - public_server_key)) % n
        };

        let exponent = private_client_key + &u * x;
        let premaster_key = b_minus_k_g_x.modpow(&exponent, n);

        if premaster_key.is_zero() {
            return Err(ClientError::KeyDerivationFailed(
                "Premaster key is zero".to_string(),
            ));
        }

        Ok(premaster_key)
    }

    /// Derives the session key `K = H(S)` from the premaster secret.
    ///
    /// `premaster_key` is hashed in the big-endian encoding produced by
    /// `to_bytes_be()` (no padding is applied here), and the digest is
    /// returned as a big-endian integer.
    pub fn compute_session_key(&self, premaster_key: &BigUint) -> BigUint {
        let premaster_bytes = premaster_key.to_bytes_be();

        let mut session_hash = D::new();
        session_hash.update(premaster_bytes);
        let digest = session_hash.finalize();

        BigUint::from_bytes_be(&digest)
    }

    /// Computes the client's proof bytes for this session.
    ///
    /// This delegates to `utils::compute_client_proof`, passing this client's
    /// group together with the username, salt, both public keys and the
    /// session key.
    pub fn compute_client_proof(
        &self,
        username: &[u8],
        salt: &[u8],
        public_client_key: &BigUint,
        public_server_key: &BigUint,
        session_key: &BigUint,
    ) -> Vec<u8> {
        let group = &self.group;

        utils::compute_client_proof::<D>(
            group,
            username,
            salt,
            public_client_key,
            public_server_key,
            session_key,
        )
    }

    /// Checks the proof sent by the server against the locally computed one.
    ///
    /// The expected value comes from `utils::compute_server_proof`, called
    /// with the client's public key, the client proof and the session key.
    /// Returns `true` only when `server_proof` matches it byte for byte.
    ///
    /// The comparison inspects every byte regardless of where the first
    /// mismatch occurs, so timing does not reveal how long a matching prefix
    /// the peer-supplied proof has. Only a length mismatch returns early.
    pub fn verify_server_proof(
        &self,
        public_client_key: &BigUint,
        client_proof: &[u8],
        server_proof: &[u8],
        session_key: &BigUint,
    ) -> bool {
        let expected =
            utils::compute_server_proof::<D>(public_client_key, client_proof, session_key);
        let expected: &[u8] = &expected[..];

        // The proof length is fixed by the digest and is not secret.
        if expected.len() != server_proof.len() {
            return false;
        }

        // OR together the XOR of each byte pair; the result is zero only if
        // all bytes are equal.
        let difference = expected
            .iter()
            .zip(server_proof)
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));

        difference == 0
    }
}

/// Errors returned by the fallible [`Client`] operations.
///
/// Every variant carries a human-readable message describing the specific
/// failure.
#[derive(Debug, Error)]
pub enum ClientError {
    /// The requested private-key length was rejected by `generate_keypair`
    /// (zero, not a multiple of 8, or too large for the group).
    #[error("Invalid key length: {0}")]
    InvalidKeyLength(String),

    /// The private key is zero or not smaller than the group modulus.
    #[error("Invalid private key: {0}")]
    InvalidPrivateKey(String),

    /// A public key was rejected, e.g. the computed client public key `A`
    /// is zero.
    #[error("Invalid public key: {0}")]
    InvalidPublicKey(String),

    /// A derived protocol parameter is unusable, e.g. the scrambling
    /// parameter `u` is zero.
    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),

    /// The derived premaster key is unusable (it evaluated to zero).
    #[error("Key derivation failed: {0}")]
    KeyDerivationFailed(String),
}