srp6a 0.1.0

Implementation of SRP-6a (Secure Remote Password) following RFC 5054 (https://datatracker.ietf.org/doc/html/rfc5054).
Documentation
use crate::{group::Group, utils};
use digest::Digest;
use num_bigint::BigUint;
use num_traits::Zero;
use rand::RngCore;
use rand::rngs::OsRng;
use thiserror::Error;

pub struct Server<D: Digest> {
    pub group: Group,
    pub k: BigUint,
    _phantom: std::marker::PhantomData<D>,
}

impl<D: Digest> Server<D> {
    /// Creates a server for `group`, precomputing the SRP-6a multiplier `k`
    /// with the digest `D` (via `utils::compute_k`).
    pub fn new(group: Group) -> Self {
        let k = utils::compute_k::<D>(&group);

        Self {
            group,
            k,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Generates a fresh server key pair `(b, B)` where `b` is `length_key`
    /// random bits from [`OsRng`] and `B = (k·v + g^b) mod N`.
    ///
    /// An all-zero draw is discarded and redrawn. With `b = 0` the premaster
    /// secret would be the constant `1`, independent of the password.
    /// RFC 5054 recommends at least 256 bits for `b`.
    ///
    /// # Errors
    ///
    /// - [`ServerError::InvalidKeyLength`] if `length_key` is 0, is not a multiple
    ///   of 8, or exceeds `bits(N) - 1`.
    /// - Any error from [`Self::generate_keypair_from_private_key`].
    pub fn generate_keypair(
        &self,
        length_key: usize,
        verifier: &BigUint,
    ) -> Result<(BigUint, BigUint), ServerError> {
        if length_key == 0 {
            return Err(ServerError::InvalidKeyLength(
                "Key length must be greater than 0".to_string(),
            ));
        }

        if !length_key.is_multiple_of(8) {
            return Err(ServerError::InvalidKeyLength(
                "The key length must be a multiple of 8".to_string(),
            ));
        }

        let max_bits = self.group.n.bits() - 1;
        if length_key > max_bits as usize {
            return Err(ServerError::InvalidKeyLength(format!(
                "Key length {} exceeds maximum allowed {} bits for this group",
                length_key, max_bits
            )));
        }

        let mut b_bytes = vec![0u8; length_key / 8];
        let mut rng = OsRng;
        // Redraw on an all-zero buffer: b = 0 is rejected below, and an RNG fluke
        // should not surface to the caller as an error.
        loop {
            rng.fill_bytes(&mut b_bytes);
            if b_bytes.iter().any(|&byte| byte != 0) {
                break;
            }
        }

        self.generate_keypair_from_private_key(&b_bytes, verifier)
    }

    /// Derives the server key pair from caller-supplied private key bytes.
    ///
    /// `private_key_bytes` is interpreted as a big-endian integer `b`. The
    /// function returns `(b, B)` with `B = (k·v + g^b) mod N`.
    ///
    /// # Errors
    ///
    /// - [`ServerError::InvalidParameter`] if `b` is zero (including empty input).
    ///   A zero `b` makes the premaster secret `1` for every password.
    /// - [`ServerError::InvalidVerifier`] if `verifier ≡ 0 (mod N)`.
    /// - [`ServerError::InvalidPublicKey`] if the resulting `B` is zero.
    pub fn generate_keypair_from_private_key(
        &self,
        private_key_bytes: &[u8],
        verifier: &BigUint,
    ) -> Result<(BigUint, BigUint), ServerError> {
        let private_key = BigUint::from_bytes_be(private_key_bytes);

        if private_key.is_zero() {
            return Err(ServerError::InvalidParameter(
                "Private key b cannot be zero".to_string(),
            ));
        }

        let v_mod = verifier % &self.group.n;
        if v_mod.is_zero() {
            return Err(ServerError::InvalidVerifier(
                "Verifier v cannot be zero modulo N".to_string(),
            ));
        }

        let k_mod = &self.k % &self.group.n;
        let kv = (k_mod * v_mod) % &self.group.n;

        let gb = self.group.g.modpow(&private_key, &self.group.n);
        let public_key = (kv + gb) % &self.group.n;

        if public_key.is_zero() {
            return Err(ServerError::InvalidPublicKey(
                "Public key B cannot be zero".to_string(),
            ));
        }

        Ok((private_key, public_key))
    }

    /// Computes the premaster secret `S = (A · v^u)^b mod N`, where
    /// `u = H(PAD(A) | PAD(B))`.
    ///
    /// `PAD` left-pads to `group.length_n / 8` bytes via `utils::zero_pad`.
    /// This presumes `length_n` is the modulus size in bits; confirm against `Group`.
    ///
    /// # Errors
    ///
    /// - [`ServerError::InvalidPublicKey`] if `A` is zero or `A >= N`.
    /// - [`ServerError::InvalidParameter`] if `b` is zero or `u` hashes to zero.
    /// - [`ServerError::KeyDerivationFailed`] if `S` is zero.
    pub fn compute_premaster_key(
        &self,
        public_client_key: &BigUint,
        verifier: &BigUint,
        private_server_key: &BigUint,
        public_server_key: &BigUint,
    ) -> Result<BigUint, ServerError> {
        if public_client_key.is_zero() || public_client_key >= &self.group.n {
            return Err(ServerError::InvalidPublicKey(
                "Client public key A is invalid".to_string(),
            ));
        }

        // b = 0 would force S = 1 regardless of the password, letting any client
        // derive the session key; refuse it even if the caller built b by hand.
        if private_server_key.is_zero() {
            return Err(ServerError::InvalidParameter(
                "Private key b cannot be zero".to_string(),
            ));
        }

        let a_bytes = utils::zero_pad(public_client_key.to_bytes_be(), self.group.length_n / 8);
        let b_bytes = utils::zero_pad(public_server_key.to_bytes_be(), self.group.length_n / 8);

        let mut u_hash = D::new();
        u_hash.update(&a_bytes);
        u_hash.update(&b_bytes);
        let u = BigUint::from_bytes_be(&u_hash.finalize());

        if u.is_zero() {
            return Err(ServerError::InvalidParameter(
                "u parameter is zero".to_string(),
            ));
        }

        let v_u = verifier.modpow(&u, &self.group.n);
        let a_v_u = (public_client_key * v_u) % &self.group.n;
        let premaster_key = a_v_u.modpow(private_server_key, &self.group.n);

        if premaster_key.is_zero() {
            return Err(ServerError::KeyDerivationFailed(
                "Premaster key is zero".to_string(),
            ));
        }

        Ok(premaster_key)
    }

    /// Derives the session key `K = H(S)` from the premaster secret `S`.
    ///
    /// `S` is hashed as its minimal big-endian encoding (no padding to the
    /// modulus length). The client must derive `K` the same way for the proofs
    /// to match.
    pub fn compute_session_key(&self, premaster_key: &BigUint) -> BigUint {
        let mut hasher = D::new();
        hasher.update(premaster_key.to_bytes_be());
        BigUint::from_bytes_be(&hasher.finalize())
    }

    /// Computes the server proof (M2) from `A`, the client proof and the session
    /// key by delegating to `utils::compute_server_proof`.
    pub fn compute_server_proof(
        &self,
        public_client_key: &BigUint,
        client_proof: &[u8],
        session_key: &BigUint,
    ) -> Vec<u8> {
        utils::compute_server_proof::<D>(public_client_key, client_proof, session_key)
    }

    /// Checks the client's proof (M1) against the value recomputed by
    /// `utils::compute_client_proof` from this session's parameters.
    ///
    /// The comparison does not stop at the first mismatching byte. Response
    /// timing therefore does not reveal how much of a forged proof was correct.
    #[must_use]
    pub fn verify_client_proof(
        &self,
        username: &[u8],
        salt: &[u8],
        public_client_key: &BigUint,
        public_server_key: &BigUint,
        session_key: &BigUint,
        client_proof: &[u8],
    ) -> bool {
        let expected_proof = utils::compute_client_proof::<D>(
            &self.group,
            username,
            salt,
            public_client_key,
            public_server_key,
            session_key,
        );
        constant_time_eq(&expected_proof, client_proof)
    }
}

/// Compares two byte slices without short-circuiting on the first difference.
///
/// Lengths are compared first, since proof length is not secret. For
/// equal-length inputs every byte pair is visited, and the accumulator goes
/// through `black_box` to discourage the optimizer from reintroducing an early
/// exit. This is a best-effort mitigation, not a formal constant-time guarantee.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a
        .iter()
        .zip(b)
        .fold(0u8, |acc, (x, y)| std::hint::black_box(acc | (x ^ y)));
    diff == 0
}

#[derive(Debug, Error)]
pub enum ServerError {
    #[error("Invalid key length: {0}")]
    InvalidKeyLength(String),

    #[error("Invalid public key: {0}")]
    InvalidPublicKey(String),

    #[error("Invalid verifier: {0}")]
    InvalidVerifier(String),

    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),

    #[error("Key derivation failed: {0}")]
    KeyDerivationFailed(String),
}