wasi-crypto 0.1.9

Experimental implementation of the WASI cryptography APIs
Documentation
use std::sync::Arc;

use ::rsa::{BigUint, PublicKey as _, PublicKeyParts as _, RsaPrivateKey, RsaPublicKey};
use ::sha2::{Digest, Sha256, Sha384, Sha512};
use rsa::pkcs1::{DecodeRsaPrivateKey as _, DecodeRsaPublicKey as _, LineEnding};
use rsa::pkcs8::{
    DecodePrivateKey as _, DecodePublicKey as _, EncodePrivateKey as _, EncodePublicKey as _,
};
use serde::{Deserialize, Serialize};
use zeroize::Zeroize;

use super::*;
use crate::asymmetric_common::*;
use crate::error::*;
use crate::rand::SecureRandom;

/// Version tag stored in the crate-local (`bincode`) key encoding.
const RAW_ENCODING_VERSION: u16 = 1;
/// Algorithm identifier stored in the crate-local key encoding.
const RAW_ENCODING_ALG_ID: u16 = 1;
/// Smallest supported RSA modulus, in bits (matches the 2048-bit floor of `modulus_bits`).
const MIN_MODULUS_SIZE: usize = 2_048;
/// Largest RSA modulus, in bits, accepted when importing keys.
const MAX_MODULUS_SIZE: usize = 4_096;

/// Secret-key handle for RSA signatures; it records only the algorithm.
#[derive(Clone, Debug)]
pub struct RsaSignatureSecretKey {
    /// Signature algorithm this secret key is associated with.
    pub alg: SignatureAlgorithm,
}

/// Serializable components of an RSA private key, used by the crate-local
/// (`bincode`) key-pair encoding.
///
/// Field order defines the on-wire layout and must not change.
/// Deriving `Zeroize` only provides an explicit `zeroize()` method; the secret
/// fields (`d`, `primes`) are *not* wiped automatically when a value is dropped.
#[derive(Zeroize, Serialize, Deserialize)]
struct RsaSignatureKeyPairParts {
    /// Encoding version, expected to equal `RAW_ENCODING_VERSION`.
    version: u16,
    /// Algorithm tag, expected to equal `RAW_ENCODING_ALG_ID`.
    alg_id: u16,
    /// Public modulus.
    n: ::rsa::BigUint,
    /// Public exponent.
    e: ::rsa::BigUint,
    /// Private exponent (secret).
    d: ::rsa::BigUint,
    /// Prime factors of `n` (secret).
    primes: Vec<::rsa::BigUint>,
}

/// An RSA signing key pair: the algorithm it is bound to plus the private key
/// (the public half is derived from it on demand).
#[derive(Debug, Clone)]
pub struct RsaSignatureKeyPair {
    /// Signature algorithm (padding scheme, hash and nominal modulus size).
    pub alg: SignatureAlgorithm,
    /// Underlying `rsa` crate private key.
    ctx: ::rsa::RsaPrivateKey,
}

/// Returns the nominal RSA modulus size, in bits, for `alg`.
///
/// # Errors
///
/// `CryptoError::UnsupportedAlgorithm` if `alg` is not one of the RSA
/// PKCS#1 v1.5 / PSS variants handled by this module.
fn modulus_bits(alg: SignatureAlgorithm) -> Result<usize, CryptoError> {
    match alg {
        SignatureAlgorithm::RSA_PKCS1_4096_SHA512 | SignatureAlgorithm::RSA_PSS_4096_SHA512 => {
            Ok(4096)
        }
        SignatureAlgorithm::RSA_PKCS1_3072_SHA384
        | SignatureAlgorithm::RSA_PKCS1_3072_SHA512
        | SignatureAlgorithm::RSA_PSS_3072_SHA384
        | SignatureAlgorithm::RSA_PSS_3072_SHA512 => Ok(3072),
        SignatureAlgorithm::RSA_PKCS1_2048_SHA256
        | SignatureAlgorithm::RSA_PKCS1_2048_SHA384
        | SignatureAlgorithm::RSA_PKCS1_2048_SHA512
        | SignatureAlgorithm::RSA_PSS_2048_SHA256
        | SignatureAlgorithm::RSA_PSS_2048_SHA384
        | SignatureAlgorithm::RSA_PSS_2048_SHA512 => Ok(2048),
        _ => bail!(CryptoError::UnsupportedAlgorithm),
    }
}

impl RsaSignatureKeyPair {
    /// Decodes a DER private key, trying PKCS#8 first and then PKCS#1.
    ///
    /// Inputs of 4096 bytes or more, and anything that fails to parse, are
    /// rejected with `CryptoError::InvalidKey`.
    fn from_pkcs8(alg: SignatureAlgorithm, der: &[u8]) -> Result<Self, CryptoError> {
        ensure!(der.len() < 4096, CryptoError::InvalidKey);
        let ctx = ::rsa::RsaPrivateKey::from_pkcs8_der(der)
            .or_else(|_| ::rsa::RsaPrivateKey::from_pkcs1_der(der))
            .map_err(|_| CryptoError::InvalidKey)?;
        Ok(RsaSignatureKeyPair { alg, ctx })
    }

    /// Decodes a PEM private key (PKCS#8 first, then PKCS#1).
    ///
    /// The input must be UTF-8 and shorter than 4096 bytes; surrounding
    /// whitespace is trimmed. Failures map to `CryptoError::InvalidKey`.
    fn from_pem(alg: SignatureAlgorithm, pem: &[u8]) -> Result<Self, CryptoError> {
        ensure!(pem.len() < 4096, CryptoError::InvalidKey);
        let pem = std::str::from_utf8(pem)
            .map_err(|_| CryptoError::InvalidKey)?
            .trim();
        let ctx = ::rsa::RsaPrivateKey::from_pkcs8_pem(pem)
            .or_else(|_| ::rsa::RsaPrivateKey::from_pkcs1_pem(pem))
            .map_err(|_| CryptoError::InvalidKey)?;
        Ok(RsaSignatureKeyPair { alg, ctx })
    }

    /// Decodes the crate-local `bincode` encoding produced by `to_local`.
    ///
    /// The version/algorithm header must match `RAW_ENCODING_VERSION` /
    /// `RAW_ENCODING_ALG_ID`; otherwise `CryptoError::InvalidKey` is returned.
    fn from_local(alg: SignatureAlgorithm, local: &[u8]) -> Result<Self, CryptoError> {
        ensure!(local.len() < 2048, CryptoError::InvalidKey);
        let mut parts: RsaSignatureKeyPairParts =
            bincode::deserialize(local).map_err(|_| CryptoError::InvalidKey)?;
        if parts.version != RAW_ENCODING_VERSION || parts.alg_id != RAW_ENCODING_ALG_ID {
            // The decoded parts hold the private exponent and primes, and
            // `derive(Zeroize)` does not wipe on drop, so clear them explicitly.
            parts.zeroize();
            bail!(CryptoError::InvalidKey);
        }
        // Fields are moved (not copied) into the private key.
        let ctx = ::rsa::RsaPrivateKey::from_components(parts.n, parts.e, parts.d, parts.primes);
        Ok(RsaSignatureKeyPair { alg, ctx })
    }

    /// Serializes the private key as PKCS#8 DER.
    fn to_pkcs8(&self) -> Result<Vec<u8>, CryptoError> {
        self.ctx
            .to_pkcs8_der()
            .map_err(|_| CryptoError::InternalError)
            .map(|x| x.as_ref().to_vec())
    }

    /// Serializes the private key as PKCS#8 PEM with LF line endings.
    fn to_pem(&self) -> Result<Vec<u8>, CryptoError> {
        self.ctx
            .to_pkcs8_pem(LineEnding::LF)
            .map(|s| s.as_bytes().to_vec())
            .map_err(|_| CryptoError::InternalError)
    }

    /// Serializes the key pair into the crate-local `bincode` encoding.
    ///
    /// The temporary copies of the secret components are wiped before
    /// returning, whether or not serialization succeeded.
    fn to_local(&self) -> Result<Vec<u8>, CryptoError> {
        let mut parts = RsaSignatureKeyPairParts {
            version: RAW_ENCODING_VERSION,
            alg_id: RAW_ENCODING_ALG_ID,
            n: self.ctx.n().clone(),
            e: self.ctx.e().clone(),
            d: self.ctx.d().clone(),
            primes: self.ctx.primes().to_vec(),
        };
        let local = bincode::serialize(&parts).map_err(|_| CryptoError::InternalError);
        // `d` and `primes` are cloned secrets; `derive(Zeroize)` does not clear
        // them on drop, so wipe them explicitly on every path.
        parts.zeroize();
        local
    }

    /// Generates a fresh key pair whose modulus size is dictated by `alg`.
    ///
    /// `_options` is currently ignored. Key-generation failures are reported
    /// as `CryptoError::UnsupportedAlgorithm`.
    pub fn generate(
        alg: SignatureAlgorithm,
        _options: Option<SignatureOptions>,
    ) -> Result<Self, CryptoError> {
        let modulus_bits = modulus_bits(alg)?;
        let mut rng = SecureRandom::new();
        let ctx = ::rsa::RsaPrivateKey::new(&mut rng, modulus_bits)
            .map_err(|_| CryptoError::UnsupportedAlgorithm)?;
        Ok(RsaSignatureKeyPair { alg, ctx })
    }

    /// Imports a key pair from `encoded` using `encoding` (PKCS#8, PEM or Local).
    ///
    /// # Errors
    ///
    /// - `UnsupportedAlgorithm` if `alg` is not an RSA algorithm.
    /// - `UnsupportedEncoding` for any other encoding.
    /// - `InvalidKey` if decoding fails, if the modulus is smaller than `alg`'s
    ///   nominal size or larger than `MAX_MODULUS_SIZE` bits, or if the key
    ///   fails `validate()` / `precompute()`.
    pub fn import(
        alg: SignatureAlgorithm,
        encoded: &[u8],
        encoding: KeyPairEncoding,
    ) -> Result<Self, CryptoError> {
        match alg.family() {
            SignatureAlgorithmFamily::RSA => {}
            _ => bail!(CryptoError::UnsupportedAlgorithm),
        };
        let mut kp = match encoding {
            KeyPairEncoding::Pkcs8 => Self::from_pkcs8(alg, encoded)?,
            KeyPairEncoding::Pem => Self::from_pem(alg, encoded)?,
            KeyPairEncoding::Local => Self::from_local(alg, encoded)?,
            _ => bail!(CryptoError::UnsupportedEncoding),
        };
        // `size()` is the modulus length in bytes.
        let modulus_size = kp.ctx.size();
        let min_modulus_bits = modulus_bits(alg)?;
        ensure!(
            (min_modulus_bits / 8..=MAX_MODULUS_SIZE / 8).contains(&modulus_size),
            CryptoError::InvalidKey
        );
        kp.ctx.validate().map_err(|_| CryptoError::InvalidKey)?;
        kp.ctx.precompute().map_err(|_| CryptoError::InvalidKey)?;
        Ok(kp)
    }

    /// Exports the key pair using `encoding` (PKCS#8, PEM or Local).
    ///
    /// Returns `CryptoError::UnsupportedEncoding` for any other encoding.
    pub fn export(&self, encoding: KeyPairEncoding) -> Result<Vec<u8>, CryptoError> {
        match encoding {
            KeyPairEncoding::Pkcs8 => self.to_pkcs8(),
            KeyPairEncoding::Pem => self.to_pem(),
            KeyPairEncoding::Local => self.to_local(),
            _ => bail!(CryptoError::UnsupportedEncoding),
        }
    }

    /// Returns the public half of this key pair, bound to the same algorithm.
    pub fn public_key(&self) -> Result<RsaSignaturePublicKey, CryptoError> {
        let ctx = self.ctx.to_public_key();
        Ok(RsaSignaturePublicKey { alg: self.alg, ctx })
    }
}

/// An encoded RSA signature (big-endian integer as produced by the signer).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RsaSignature {
    /// Raw signature bytes.
    pub raw: Vec<u8>,
}

impl RsaSignature {
    pub fn new(raw: Vec<u8>) -> Self {
        RsaSignature { raw }
    }

    pub fn from_raw(alg: SignatureAlgorithm, raw: &[u8]) -> Result<Self, CryptoError> {
        let expected_len = modulus_bits(alg)? / 8;
        ensure!(raw.len() == expected_len, CryptoError::InvalidSignature);
        Ok(Self::new(raw.to_vec()))
    }
}

impl SignatureLike for RsaSignature {
    /// Exposes `self` as `Any` so verifiers can downcast back to `RsaSignature`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the encoded signature bytes.
    fn as_ref(&self) -> &[u8] {
        self.raw.as_slice()
    }
}

fn padding_scheme(alg: SignatureAlgorithm) -> ::rsa::PaddingScheme {
    match alg {
        SignatureAlgorithm::RSA_PKCS1_2048_SHA256 => {
            ::rsa::PaddingScheme::new_pkcs1v15_sign(Some(::rsa::Hash::SHA2_256))
        }
        SignatureAlgorithm::RSA_PKCS1_2048_SHA384 | SignatureAlgorithm::RSA_PKCS1_3072_SHA384 => {
            ::rsa::PaddingScheme::new_pkcs1v15_sign(Some(::rsa::Hash::SHA2_384))
        }
        SignatureAlgorithm::RSA_PKCS1_2048_SHA512
        | SignatureAlgorithm::RSA_PKCS1_3072_SHA512
        | SignatureAlgorithm::RSA_PKCS1_4096_SHA512 => {
            ::rsa::PaddingScheme::new_pkcs1v15_sign(Some(::rsa::Hash::SHA2_512))
        }

        SignatureAlgorithm::RSA_PSS_2048_SHA256 => {
            ::rsa::PaddingScheme::new_pss::<Sha256, _>(SecureRandom::new())
        }
        SignatureAlgorithm::RSA_PSS_2048_SHA384 | SignatureAlgorithm::RSA_PSS_3072_SHA384 => {
            ::rsa::PaddingScheme::new_pss::<Sha384, _>(SecureRandom::new())
        }
        SignatureAlgorithm::RSA_PSS_2048_SHA512
        | SignatureAlgorithm::RSA_PSS_3072_SHA512
        | SignatureAlgorithm::RSA_PSS_4096_SHA512 => {
            ::rsa::PaddingScheme::new_pss::<Sha512, _>(SecureRandom::new())
        }
        _ => unreachable!(),
    }
}

/// Running message hash, one variant per digest used by the RSA algorithms.
///
/// The variants differ in size (SHA-384/512 keep a larger internal state than
/// SHA-256). The size lint is silenced rather than boxing each hasher.
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
enum HashVariant {
    Sha256(Sha256),
    Sha384(Sha384),
    Sha512(Sha512),
}

impl HashVariant {
    /// Creates an empty hasher for the digest that `alg` signs over.
    ///
    /// Returns `CryptoError::UnsupportedAlgorithm` for algorithms that are not
    /// one of the RSA variants handled here.
    fn for_alg(alg: SignatureAlgorithm) -> Result<Self, CryptoError> {
        match alg {
            SignatureAlgorithm::RSA_PKCS1_2048_SHA512
            | SignatureAlgorithm::RSA_PKCS1_3072_SHA512
            | SignatureAlgorithm::RSA_PKCS1_4096_SHA512
            | SignatureAlgorithm::RSA_PSS_2048_SHA512
            | SignatureAlgorithm::RSA_PSS_3072_SHA512
            | SignatureAlgorithm::RSA_PSS_4096_SHA512 => Ok(Self::Sha512(Sha512::new())),
            SignatureAlgorithm::RSA_PKCS1_2048_SHA384
            | SignatureAlgorithm::RSA_PKCS1_3072_SHA384
            | SignatureAlgorithm::RSA_PSS_2048_SHA384
            | SignatureAlgorithm::RSA_PSS_3072_SHA384 => Ok(Self::Sha384(Sha384::new())),
            SignatureAlgorithm::RSA_PKCS1_2048_SHA256 | SignatureAlgorithm::RSA_PSS_2048_SHA256 => {
                Ok(Self::Sha256(Sha256::new()))
            }
            _ => bail!(CryptoError::UnsupportedAlgorithm),
        }
    }
}

/// Incremental signing state: a key pair plus a running hash of the message.
#[derive(Debug)]
pub struct RsaSignatureState {
    /// Key pair used to sign the final digest.
    pub kp: RsaSignatureKeyPair,
    /// Hasher fed by `update` and finalized (on a copy) by `sign`.
    h: HashVariant,
}

impl RsaSignatureState {
    /// Creates a signing state for `kp` with an empty hash of its algorithm's digest.
    ///
    /// # Panics
    ///
    /// Panics if `kp.alg` is not an RSA algorithm known to this module. Key
    /// pairs produced by `generate`/`import` always satisfy this, but `alg` is
    /// a public field and could have been changed afterwards.
    pub fn new(kp: RsaSignatureKeyPair) -> Self {
        let h = HashVariant::for_alg(kp.alg)
            .expect("RSA key pair must carry a supported RSA signature algorithm");
        RsaSignatureState { kp, h }
    }
}

impl SignatureStateLike for RsaSignatureState {
    /// Appends `input` to the running message hash.
    fn update(&mut self, input: &[u8]) -> Result<(), CryptoError> {
        match self.h {
            HashVariant::Sha256(ref mut hasher) => hasher.update(input),
            HashVariant::Sha384(ref mut hasher) => hasher.update(input),
            HashVariant::Sha512(ref mut hasher) => hasher.update(input),
        }
        Ok(())
    }

    /// Signs the digest of everything hashed so far, using blinding.
    ///
    /// A clone of the hasher is finalized, so the running hash is left intact.
    /// Signing failures are reported as `CryptoError::InvalidKey`.
    fn sign(&mut self) -> Result<Signature, CryptoError> {
        let mut rng = SecureRandom::new();
        let digest: Vec<u8> = match self.h {
            HashVariant::Sha256(ref hasher) => hasher.clone().finalize().as_slice().to_vec(),
            HashVariant::Sha384(ref hasher) => hasher.clone().finalize().as_slice().to_vec(),
            HashVariant::Sha512(ref hasher) => hasher.clone().finalize().as_slice().to_vec(),
        };
        let padding = padding_scheme(self.kp.alg);
        let raw = self
            .kp
            .ctx
            .sign_blinded(&mut rng, padding, &digest)
            .map_err(|_| CryptoError::InvalidKey)?;
        Ok(Signature::new(Box::new(RsaSignature::new(raw))))
    }
}

/// Incremental verification state: a public key plus a running hash of the message.
#[derive(Debug)]
pub struct RsaSignatureVerificationState {
    /// Public key the signature is checked against.
    pub pk: RsaSignaturePublicKey,
    /// Hasher fed by `update` and finalized (on a copy) by `verify`.
    h: HashVariant,
}

impl RsaSignatureVerificationState {
    /// Creates a verification state for `pk` with an empty hash of its algorithm's digest.
    ///
    /// # Panics
    ///
    /// Panics if `pk.alg` is not an RSA algorithm known to this module. Keys
    /// produced by `import` or by deriving from a key pair satisfy this, but
    /// `alg` is a public field and could have been changed afterwards.
    pub fn new(pk: RsaSignaturePublicKey) -> Self {
        let h = HashVariant::for_alg(pk.alg)
            .expect("RSA public key must carry a supported RSA signature algorithm");
        RsaSignatureVerificationState { pk, h }
    }
}

impl SignatureVerificationStateLike for RsaSignatureVerificationState {
    /// Appends `input` to the running message hash.
    fn update(&mut self, input: &[u8]) -> Result<(), CryptoError> {
        match self.h {
            HashVariant::Sha256(ref mut hasher) => hasher.update(input),
            HashVariant::Sha384(ref mut hasher) => hasher.update(input),
            HashVariant::Sha512(ref mut hasher) => hasher.update(input),
        }
        Ok(())
    }

    /// Checks `signature` against the digest of everything hashed so far.
    ///
    /// Returns `CryptoError::InvalidSignature` if `signature` is not an
    /// `RsaSignature` or does not verify. The running hash is left intact.
    fn verify(&self, signature: &Signature) -> Result<(), CryptoError> {
        let inner = signature.inner();
        let rsa_signature = inner
            .as_any()
            .downcast_ref::<RsaSignature>()
            .ok_or_else(|| CryptoError::InvalidSignature)?;
        let digest: Vec<u8> = match self.h {
            HashVariant::Sha256(ref hasher) => hasher.clone().finalize().as_slice().to_vec(),
            HashVariant::Sha384(ref hasher) => hasher.clone().finalize().as_slice().to_vec(),
            HashVariant::Sha512(ref hasher) => hasher.clone().finalize().as_slice().to_vec(),
        };
        let padding = padding_scheme(self.pk.alg);
        self.pk
            .ctx
            .verify(padding, &digest, rsa_signature.raw.as_slice())
            .map_err(|_| CryptoError::InvalidSignature)?;
        Ok(())
    }
}

/// Serializable components of an RSA public key, used by the crate-local
/// (`bincode`) public-key encoding. Field order defines the on-wire layout.
#[derive(Zeroize, Serialize, Deserialize)]
struct RsaSignaturePublicKeyParts {
    /// Encoding version, expected to equal `RAW_ENCODING_VERSION`.
    version: u16,
    /// Algorithm tag, expected to equal `RAW_ENCODING_ALG_ID`.
    alg_id: u16,
    /// Public modulus.
    n: ::rsa::BigUint,
    /// Public exponent.
    e: ::rsa::BigUint,
}

/// An RSA verification key bound to a specific signature algorithm.
#[derive(Debug, Clone)]
pub struct RsaSignaturePublicKey {
    /// Signature algorithm (padding scheme, hash and nominal modulus size).
    pub alg: SignatureAlgorithm,
    /// Underlying `rsa` crate public key.
    ctx: ::rsa::RsaPublicKey,
}

impl RsaSignaturePublicKey {
    fn from_pkcs8(alg: SignatureAlgorithm, der: &[u8]) -> Result<Self, CryptoError> {
        ensure!(der.len() < 4096, CryptoError::InvalidKey);
        let ctx = ::rsa::RsaPublicKey::from_public_key_der(der)
            .or_else(|_| ::rsa::RsaPublicKey::from_pkcs1_der(der))
            .map_err(|_| CryptoError::InvalidKey)?;
        Ok(RsaSignaturePublicKey { alg, ctx })
    }

    fn from_pem(alg: SignatureAlgorithm, pem: &[u8]) -> Result<Self, CryptoError> {
        ensure!(pem.len() < 4096, CryptoError::InvalidKey);
        let pem = std::str::from_utf8(pem)
            .map_err(|_| CryptoError::InvalidKey)?
            .trim();
        let parsed_pem = ::rsa::RsaPublicKey::from_public_key_pem(pem)
            .or_else(|_| ::rsa::RsaPublicKey::from_pkcs1_pem(pem))
            .map_err(|_| CryptoError::InvalidKey)?;
        let ctx = ::rsa::RsaPublicKey::try_from(parsed_pem).map_err(|_| CryptoError::InvalidKey)?;
        Ok(RsaSignaturePublicKey { alg, ctx })
    }

    fn from_local(alg: SignatureAlgorithm, local: &[u8]) -> Result<Self, CryptoError> {
        ensure!(local.len() < 1024, CryptoError::InvalidKey);
        let parts: RsaSignaturePublicKeyParts =
            bincode::deserialize(local).map_err(|_| CryptoError::InvalidKey)?;
        ensure!(
            parts.version == RAW_ENCODING_VERSION && parts.alg_id == RAW_ENCODING_ALG_ID,
            CryptoError::InvalidKey
        );
        let ctx =
            ::rsa::RsaPublicKey::new(parts.n, parts.e).map_err(|_| CryptoError::InvalidKey)?;
        Ok(RsaSignaturePublicKey { alg, ctx })
    }

    fn to_pkcs8(&self) -> Result<Vec<u8>, CryptoError> {
        self.ctx
            .to_public_key_der()
            .map_err(|_| CryptoError::InternalError)
            .map(|x| x.as_ref().to_vec())
    }

    fn to_pem(&self) -> Result<Vec<u8>, CryptoError> {
        self.ctx
            .to_public_key_pem(LineEnding::LF)
            .map(|s| s.as_bytes().to_vec())
            .map_err(|_| CryptoError::InternalError)
    }

    fn to_local(&self) -> Result<Vec<u8>, CryptoError> {
        let parts = RsaSignaturePublicKeyParts {
            version: RAW_ENCODING_VERSION,
            alg_id: RAW_ENCODING_ALG_ID,
            n: self.ctx.n().clone(),
            e: self.ctx.e().clone(),
        };
        let local = bincode::serialize(&parts).map_err(|_| CryptoError::InternalError)?;
        Ok(local)
    }

    pub fn import(
        alg: SignatureAlgorithm,
        encoded: &[u8],
        encoding: PublicKeyEncoding,
    ) -> Result<Self, CryptoError> {
        let pk = match encoding {
            PublicKeyEncoding::Pkcs8 => Self::from_pkcs8(alg, encoded)?,
            PublicKeyEncoding::Pem => Self::from_pem(alg, encoded)?,
            PublicKeyEncoding::Local => Self::from_local(alg, encoded)?,
            _ => bail!(CryptoError::UnsupportedEncoding),
        };
        let modulus_size = pk.ctx.size();
        let min_modulus_bits = modulus_bits(alg)?;
        ensure!(
            modulus_size >= min_modulus_bits / 8 && modulus_size <= MAX_MODULUS_SIZE / 8,
            CryptoError::InvalidKey
        );
        Ok(pk)
    }

    pub fn export(&self, encoding: PublicKeyEncoding) -> Result<Vec<u8>, CryptoError> {
        match encoding {
            PublicKeyEncoding::Pkcs8 => self.to_pkcs8(),
            PublicKeyEncoding::Pem => self.to_pem(),
            PublicKeyEncoding::Local => self.to_local(),
            _ => bail!(CryptoError::UnsupportedEncoding),
        }
    }
}