use aes_gcm::aead::{Aead, KeyInit as GcmKeyInit};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use cms::{
authenveloped_data::AuthEnvelopedData,
cert::IssuerAndSerialNumber,
content_info::{CmsVersion, ContentInfo},
enveloped_data::{
EncryptedContentInfo, EncryptedKey, KeyAgreeRecipientIdentifier, KeyAgreeRecipientInfo,
KeyTransRecipientInfo, OriginatorIdentifierOrKey, OriginatorPublicKey,
RecipientEncryptedKey, RecipientIdentifier, RecipientInfo, RecipientInfos,
},
};
use const_oid::db::{rfc5753, rfc5911, rfc5912};
use crypto_common::Generate as _;
use der::{
asn1::{BitString, ObjectIdentifier, OctetString, SetOfVec},
Any, AnyRef, Encode, Sequence,
};
use elliptic_curve::ecdh::EphemeralSecret;
use elliptic_curve::sec1::{FromSec1Point, ModulusSize, ToSec1Point};
use elliptic_curve::{AffinePoint, CurveArithmetic, FieldBytesSize};
use getrandom::{rand_core::UnwrapErr, SysRng};
use rsa::{pkcs8::DecodePublicKey, RsaPublicKey};
use spki::AlgorithmIdentifierOwned;
use x509_cert::Certificate;
use zeroize::Zeroizing;
use crate::error::SmimeError;
// CMS content types: the outer ContentInfo type emitted by `encrypt`, and the type of the
// MIME content that gets encrypted inside it.
const ID_CT_AUTH_ENVELOPED_DATA: ObjectIdentifier = rfc5911::ID_CT_AUTH_ENVELOPED_DATA;
const ID_DATA: ObjectIdentifier = rfc5911::ID_DATA;

// AES-GCM content-encryption algorithms (RFC 5084), chosen by CEK size.
const ID_AES_128_GCM: ObjectIdentifier = rfc5911::ID_AES_128_GCM;
const ID_AES_256_GCM: ObjectIdentifier = rfc5911::ID_AES_256_GCM;

// ECDH key agreement (RFC 5753). Each single-pass KDF scheme is listed next to the AES key
// wrap it is paired with for EC recipients: SHA-256 KDF + AES-128 wrap for P-256 keys,
// SHA-384 KDF + AES-256 wrap for P-384 keys.
const DH_SHA256_KDF: ObjectIdentifier = rfc5753::DH_SINGLE_PASS_STD_DH_SHA_256_KDF_SCHEME;
const ID_AES_128_WRAP: ObjectIdentifier = rfc5911::ID_AES_128_WRAP;
const DH_SHA384_KDF: ObjectIdentifier = rfc5753::DH_SINGLE_PASS_STD_DH_SHA_384_KDF_SCHEME;
const ID_AES_256_WRAP: ObjectIdentifier = rfc5911::ID_AES_256_WRAP;
/// `ECC-CMS-SharedInfo` (RFC 5753): the DER structure passed as the X9.63 KDF "shared info"
/// when deriving the key-encryption key for an ECDH recipient.
#[derive(Clone, Debug, Eq, PartialEq, Sequence)]
struct EccCmsSharedInfo {
    /// Key-wrap algorithm the derived KEK will be used with.
    key_info: AlgorithmIdentifierOwned,
    /// Optional user keying material, `[0] EXPLICIT OCTET STRING OPTIONAL`
    /// (`ecdh_wrap_cek` sets this to `None`).
    #[asn1(context_specific = "0", tag_mode = "EXPLICIT", constructed = "true", optional = "true")]
    entity_u_info: Option<OctetString>,
    /// KEK length in bits as a 4-byte big-endian octet string, `[2] EXPLICIT OCTET STRING`.
    #[asn1(
        context_specific = "2",
        tag_mode = "EXPLICIT",
        constructed = "true"
    )]
    supp_pub_info: OctetString,
}
/// Encrypts a MIME entity for `recipients`, producing an S/MIME `authEnveloped-data` message.
///
/// The content is sealed with AES-GCM under a fresh content-encryption key (CEK).
/// AES-256-GCM is used when any recipient holds a P-384 EC key; otherwise AES-128-GCM.
/// The CEK is then delivered to every recipient through a `RecipientInfo` built from that
/// recipient's certificate. The return value is a complete MIME message: headers followed
/// by the base64-encoded DER `ContentInfo`.
///
/// # Errors
/// [`SmimeError::NoRecipients`] when `recipients` is empty. Failures from content
/// encryption, recipient-info construction, or DER encoding are propagated.
pub fn encrypt(inner_mime: &[u8], recipients: &[Certificate]) -> Result<Vec<u8>, SmimeError> {
    /// True when `cert` carries an `id-ecPublicKey` SPKI whose named-curve parameter is P-384.
    fn holds_p384_key(cert: &Certificate) -> bool {
        let algorithm = &cert.tbs_certificate().subject_public_key_info().algorithm;
        algorithm.oid == rfc5912::ID_EC_PUBLIC_KEY
            && matches!(
                algorithm
                    .parameters
                    .as_ref()
                    .map(|p: &Any| p.decode_as::<ObjectIdentifier>()),
                Some(Ok(curve)) if curve == rfc5912::SECP_384_R_1
            )
    }

    if recipients.is_empty() {
        return Err(SmimeError::NoRecipients);
    }

    // One P-384 recipient raises the shared content cipher to AES-256-GCM for everyone.
    let (key_len, gcm_oid) = if recipients.iter().any(holds_p384_key) {
        (32, ID_AES_256_GCM)
    } else {
        (16, ID_AES_128_GCM)
    };
    let (content_enc_alg, ciphertext, tag, cek) = encrypt_aes_gcm(key_len, gcm_oid, inner_mime)?;

    let infos = recipients
        .iter()
        .map(|cert| build_recipient_info(cert, &cek))
        .collect::<Result<Vec<RecipientInfo>, SmimeError>>()?;
    let encrypted_content = OctetString::new(ciphertext)?;
    let info_set: SetOfVec<RecipientInfo> = SetOfVec::try_from(infos)?;

    let auth_env_data = AuthEnvelopedData {
        version: CmsVersion::V0,
        originator_info: None,
        recip_infos: RecipientInfos::from(info_set),
        auth_encrypted_content_info: EncryptedContentInfo {
            content_type: ID_DATA,
            content_enc_alg,
            encrypted_content: Some(encrypted_content),
        },
        auth_attrs: None,
        mac: OctetString::new(tag)?,
        unauth_attrs: None,
    };

    let content_info = ContentInfo {
        content_type: ID_CT_AUTH_ENVELOPED_DATA,
        content: Any::encode_from(&auth_env_data)?,
    };
    Ok(build_mime(&content_info.to_der()?))
}
/// `GCMParameters` (RFC 5084): the parameters attached to an AES-GCM content-encryption
/// `AlgorithmIdentifier`. Field order and attributes determine the DER encoding.
#[derive(Clone, Debug, Eq, PartialEq, Sequence)]
struct GcmParameters {
/// Nonce used for the GCM encryption.
aes_nonce: OctetString,
/// Authentication-tag (ICV) length in octets; its ASN.1 DEFAULT is supplied by
/// `default_icv_len`.
#[asn1(default = "default_icv_len")]
aes_icv_len: u8,
}
/// ASN.1 `DEFAULT` for `GcmParameters::aes_icv_len`: a 12-octet tag, per RFC 5084.
fn default_icv_len() -> u8 {
    const RFC5084_DEFAULT_ICV_LEN: u8 = 12;
    RFC5084_DEFAULT_ICV_LEN
}
/// Encrypts `plaintext` with AES-GCM under a freshly generated content-encryption key (CEK).
///
/// `key_len` selects the cipher: 16 bytes for AES-128-GCM, 32 bytes for AES-256-GCM. Any
/// other length is rejected. `cek_oid` is written verbatim as the algorithm OID of the
/// returned `AlgorithmIdentifier`, so callers must pass the OID matching `key_len`. A fresh
/// random 12-byte nonce is drawn from the OS RNG on every call.
///
/// Returns `(content_enc_alg, ciphertext, tag, cek)`:
/// - `content_enc_alg`: `cek_oid` with DER-encoded `GCMParameters` (nonce, 16-octet ICV);
/// - `ciphertext`: the encrypted bytes, without the tag;
/// - `tag`: the 16-byte GCM authentication tag, which CMS AuthEnvelopedData carries
///   separately in its `mac` field;
/// - `cek`: the raw content-encryption key, in a buffer that is wiped on drop.
///
/// # Errors
/// - [`SmimeError::RngFailure`] if the OS RNG fails.
/// - [`SmimeError::Other`] for an unsupported `key_len` or an encryption failure.
/// - DER errors from encoding the parameters are propagated.
// The four-tuple is private to this module; a named struct would add little.
#[allow(clippy::type_complexity)]
fn encrypt_aes_gcm(
    key_len: usize,
    cek_oid: ObjectIdentifier,
    plaintext: &[u8],
) -> Result<
    (
        AlgorithmIdentifierOwned,
        Vec<u8>,
        Vec<u8>,
        Zeroizing<Vec<u8>>,
    ),
    SmimeError,
> {
    // Length of the tag appended by the `aes_gcm` ciphers, in octets.
    const TAG_LEN: u8 = 16;

    // Hold the key in `Zeroizing` from the moment it is allocated. That way the CEK is wiped
    // even when a later step (nonce RNG, encryption, DER encoding) returns early via `?`.
    let mut cek = Zeroizing::new(vec![0u8; key_len]);
    let mut nonce_buf = [0u8; 12];
    getrandom::fill(cek.as_mut_slice()).map_err(|e| SmimeError::RngFailure(format!("{e}")))?;
    getrandom::fill(&mut nonce_buf).map_err(|e| SmimeError::RngFailure(format!("{e}")))?;
    let nonce = aes_gcm::Nonce::<aes_gcm::aead::consts::U12>::try_from(nonce_buf.as_slice())
        .map_err(|_| SmimeError::Other("nonce length mismatch".into()))?;

    // `Aead::encrypt` returns `ciphertext || tag`; the tag is split off below.
    let mut ciphertext = match key_len {
        32 => {
            let key_arr: &[u8; 32] = cek
                .as_slice()
                .try_into()
                .map_err(|_| SmimeError::Other("AES-256-GCM key length mismatch".into()))?;
            aes_gcm::Aes256Gcm::new(key_arr.into())
                .encrypt(&nonce, plaintext)
                .map_err(|_| SmimeError::Other("AES-256-GCM encrypt failed".into()))?
        }
        16 => {
            let key_arr: &[u8; 16] = cek
                .as_slice()
                .try_into()
                .map_err(|_| SmimeError::Other("AES-128-GCM key length mismatch".into()))?;
            aes_gcm::Aes128Gcm::new(key_arr.into())
                .encrypt(&nonce, plaintext)
                .map_err(|_| SmimeError::Other("AES-128-GCM encrypt failed".into()))?
        }
        other => {
            return Err(SmimeError::Other(format!(
                "unsupported AES-GCM key length: {other} bytes"
            )));
        }
    };

    let tag_start = ciphertext
        .len()
        .checked_sub(usize::from(TAG_LEN))
        .ok_or_else(|| SmimeError::Other("AES-GCM output shorter than its tag".into()))?;
    // Splits in place: `ciphertext` keeps the body, `tag` receives the trailing 16 bytes.
    let tag = ciphertext.split_off(tag_start);

    let gcm_params = GcmParameters {
        aes_nonce: OctetString::new(nonce_buf.as_slice())?,
        aes_icv_len: TAG_LEN,
    };
    let alg = AlgorithmIdentifierOwned {
        oid: cek_oid,
        parameters: Some(Any::encode_from(&gcm_params)?),
    };
    Ok((alg, ciphertext, tag, cek))
}
fn build_recipient_info(cert: &Certificate, cek: &[u8]) -> Result<RecipientInfo, SmimeError> {
let spki = cert.tbs_certificate().subject_public_key_info();
let alg_oid = spki.algorithm.oid;
if alg_oid == rfc5912::RSA_ENCRYPTION {
build_rsa_recipient(cert, cek)
} else if alg_oid == rfc5912::ID_RSAES_OAEP {
Err(SmimeError::UnsupportedAlgorithm(
"RSAES-OAEP recipient certs (id-RSAES-OAEP SPKI OID) are not supported; \
use RSA with PKCS#1v1.5"
.into(),
))
} else if alg_oid == rfc5912::ID_EC_PUBLIC_KEY {
let curve_oid = spki
.algorithm
.parameters
.as_ref()
.and_then(|p: &Any| p.decode_as::<ObjectIdentifier>().ok())
.ok_or_else(|| {
SmimeError::UnsupportedAlgorithm("EC public key missing curve OID parameter".into())
})?;
if curve_oid == rfc5912::SECP_256_R_1 {
build_ec_recipient::<p256::NistP256, sha2::Sha256>(
cert,
cek,
"P-256",
rfc5912::SECP_256_R_1,
DH_SHA256_KDF,
ID_AES_128_WRAP,
128,
)
} else if curve_oid == rfc5912::SECP_384_R_1 {
build_ec_recipient::<p384::NistP384, sha2::Sha384>(
cert,
cek,
"P-384",
rfc5912::SECP_384_R_1,
DH_SHA384_KDF,
ID_AES_256_WRAP,
256,
)
} else {
Err(SmimeError::UnsupportedAlgorithm(format!(
"EC curve {} not supported",
curve_oid
)))
}
} else {
Err(SmimeError::UnsupportedAlgorithm(format!(
"recipient key algorithm {} not supported",
alg_oid
)))
}
}
/// Builds a key-transport `RecipientInfo` (version 0) that encrypts `cek` to the RSA public
/// key in `cert` with RSAES-PKCS1-v1_5. The recipient is identified by issuer and serial
/// number.
///
/// # Errors
/// - [`SmimeError::RngFailure`] if the up-front RNG probe fails.
/// - [`SmimeError::MalformedInput`] if the certificate key is not a parseable RSA public key.
/// - [`SmimeError::Other`] if RSA encryption fails.
/// - DER errors are propagated.
fn build_rsa_recipient(cert: &Certificate, cek: &[u8]) -> Result<RecipientInfo, SmimeError> {
    use rsa::Pkcs1v15Encrypt;

    // Probe the OS RNG first so an outage surfaces as an `RngFailure` error here rather than
    // inside the `UnwrapErr` adapter used for the padding below. This is best-effort: it
    // cannot guarantee that the later draw succeeds. The probe bytes are discarded.
    getrandom::fill(&mut [0u8; 256]).map_err(|e| SmimeError::RngFailure(format!("{e}")))?;

    let tbs = cert.tbs_certificate();
    let spki_der = tbs.subject_public_key_info().to_der()?;
    let public_key = RsaPublicKey::from_public_key_der(&spki_der).map_err(|e| {
        SmimeError::MalformedInput(format!("RSA public key in recipient cert: {e}"))
    })?;

    let encrypted_key = public_key
        .encrypt(&mut UnwrapErr(SysRng), Pkcs1v15Encrypt, cek)
        .map_err(|e| SmimeError::Other(format!("RSA PKCS#1v15 encrypt: {e}")))?;

    let rid = RecipientIdentifier::IssuerAndSerialNumber(IssuerAndSerialNumber {
        issuer: tbs.issuer().clone(),
        serial_number: tbs.serial_number().clone(),
    });
    Ok(RecipientInfo::Ktri(KeyTransRecipientInfo {
        version: CmsVersion::V0,
        rid,
        // rsaEncryption carries explicit NULL parameters.
        key_enc_alg: AlgorithmIdentifierOwned {
            oid: rfc5912::RSA_ENCRYPTION,
            parameters: Some(Any::null()),
        },
        enc_key: EncryptedKey::new(encrypted_key)?,
    }))
}
/// Builds a key-agreement `RecipientInfo` for a recipient whose EC key lies on curve `C`.
///
/// Steps:
/// 1. Generate an ephemeral key pair on `C`.
/// 2. Perform ECDH with the recipient's public key from `cert`.
/// 3. Derive a `wrap_key_bits`-bit KEK with the X9.63 KDF over digest `D`.
/// 4. AES-key-wrap `cek` under that KEK.
/// 5. Package the wrapped key together with the uncompressed ephemeral public key.
///
/// `curve_name` appears only in error messages. `curve_oid`, `kdf_oid`, and `wrap_oid` are
/// written into the resulting structure.
///
/// # Errors
/// - [`SmimeError::MalformedInput`] if the certificate key is not a valid SEC1 point on `C`.
/// - [`SmimeError::RngFailure`] if ephemeral key generation fails.
/// - Errors from key wrapping or DER encoding are propagated.
fn build_ec_recipient<C, D>(
    cert: &Certificate,
    cek: &[u8],
    curve_name: &str,
    curve_oid: ObjectIdentifier,
    kdf_oid: ObjectIdentifier,
    wrap_oid: ObjectIdentifier,
    wrap_key_bits: u32,
) -> Result<RecipientInfo, SmimeError>
where
    C: CurveArithmetic,
    AffinePoint<C>: FromSec1Point<C> + ToSec1Point<C>,
    FieldBytesSize<C>: ModulusSize,
    D: sha2::digest::Digest + sha2::digest::FixedOutputReset,
{
    let spki = cert.tbs_certificate().subject_public_key_info();
    let peer_key =
        elliptic_curve::PublicKey::<C>::from_sec1_bytes(spki.subject_public_key.raw_bytes())
            .map_err(|e| {
                SmimeError::MalformedInput(format!(
                    "{curve_name} public key in recipient cert: {e}"
                ))
            })?;

    // A fresh ephemeral key for this recipient; its public half is sent as the originator key.
    let ephemeral: EphemeralSecret<C> = EphemeralSecret::try_generate_from_rng(&mut SysRng)
        .map_err(|e| SmimeError::RngFailure(format!("{e}")))?;
    let shared = ephemeral.diffie_hellman(&peer_key);
    let wrapped_cek = ecdh_wrap_cek::<D>(
        shared.raw_secret_bytes().as_ref(),
        wrap_oid,
        wrap_key_bits,
        cek,
    )?;

    let ephemeral_point = ephemeral.public_key().to_sec1_point(false);
    build_kari_recipient(
        cert,
        ephemeral_point.as_bytes(),
        curve_oid,
        kdf_oid,
        wrap_oid,
        wrapped_cek,
    )
}
/// Assembles a version-3 `KeyAgreeRecipientInfo` for one EC recipient.
///
/// - Originator: the ephemeral public key `ephemeral_pub_bytes`, labelled `id-ecPublicKey`
///   with `curve_oid` as its parameter.
/// - Recipient: the issuer and serial number of `cert`.
/// - Key-encryption algorithm: `kdf_oid`, parameterized by the key-wrap algorithm `wrap_oid`.
/// - `wrapped_cek` becomes the single recipient encrypted key.
///
/// No UKM is included.
///
/// # Errors
/// DER encoding errors are propagated.
fn build_kari_recipient(
    cert: &Certificate,
    ephemeral_pub_bytes: &[u8],
    curve_oid: ObjectIdentifier,
    kdf_oid: ObjectIdentifier,
    wrap_oid: ObjectIdentifier,
    wrapped_cek: Vec<u8>,
) -> Result<RecipientInfo, SmimeError> {
    let originator = OriginatorIdentifierOrKey::OriginatorKey(OriginatorPublicKey {
        algorithm: AlgorithmIdentifierOwned {
            oid: rfc5912::ID_EC_PUBLIC_KEY,
            parameters: Some(Any::from(&curve_oid)),
        },
        public_key: BitString::from_bytes(ephemeral_pub_bytes)?,
    });
    let key_enc_alg = AlgorithmIdentifierOwned {
        oid: kdf_oid,
        parameters: Some(wrap_alg_any(wrap_oid)?),
    };

    let tbs = cert.tbs_certificate();
    let recipient_key = RecipientEncryptedKey {
        rid: KeyAgreeRecipientIdentifier::IssuerAndSerialNumber(IssuerAndSerialNumber {
            issuer: tbs.issuer().clone(),
            serial_number: tbs.serial_number().clone(),
        }),
        enc_key: EncryptedKey::new(wrapped_cek)?,
    };

    Ok(RecipientInfo::Kari(KeyAgreeRecipientInfo {
        version: CmsVersion::V3,
        originator,
        ukm: None,
        key_enc_alg,
        recipient_enc_keys: vec![recipient_key],
    }))
}
/// Derives a key-encryption key (KEK) from an ECDH shared secret and AES-key-wraps `cek`.
///
/// The KEK is `wrap_key_bits / 8` bytes of ANSI X9.63 KDF output over digest `D`. The KDF's
/// shared info is the DER-encoded `ECC-CMS-SharedInfo`, containing:
/// - `wrap_oid` with no parameters;
/// - no entity info;
/// - `wrap_key_bits` as a 4-byte big-endian value.
///
/// Only 128- and 256-bit KEKs are supported. The returned wrapped key is 8 bytes longer than
/// `cek`.
///
/// # Errors
/// - [`SmimeError::Other`] if the KDF or key wrap fails, or the KEK length is unsupported.
/// - DER encoding errors are propagated.
fn ecdh_wrap_cek<D>(
    shared_secret_bytes: &[u8],
    wrap_oid: ObjectIdentifier,
    wrap_key_bits: u32,
    cek: &[u8],
) -> Result<Vec<u8>, SmimeError>
where
    D: sha2::digest::Digest + sha2::digest::FixedOutputReset,
{
    use aes_kw::cipher::KeyInit;

    let shared_info_der = EccCmsSharedInfo {
        key_info: AlgorithmIdentifierOwned {
            oid: wrap_oid,
            parameters: None,
        },
        entity_u_info: None,
        supp_pub_info: OctetString::new(wrap_key_bits.to_be_bytes().as_slice())?,
    }
    .to_der()?;

    let kek_len = (wrap_key_bits / 8) as usize;
    let mut kek = Zeroizing::new(vec![0u8; kek_len]);
    ansi_x963_kdf::derive_key_into::<D>(shared_secret_bytes, &shared_info_der, &mut kek)
        .map_err(|_| SmimeError::Other("ANSI X9.63 KDF failed".into()))?;

    // AES key wrap (RFC 3394) adds one 8-byte integrity block to the input.
    let mut wrapped = vec![0u8; cek.len() + 8];
    match kek_len {
        16 => {
            let kek_arr: &[u8; 16] = kek
                .as_slice()
                .try_into()
                .map_err(|_| SmimeError::Other("KEK length mismatch".into()))?;
            aes_kw::KwAes128::new(kek_arr.into())
                .wrap_key(cek, &mut wrapped)
                .map_err(|e| SmimeError::Other(e.to_string()))?;
        }
        32 => {
            let kek_arr: &[u8; 32] = kek
                .as_slice()
                .try_into()
                .map_err(|_| SmimeError::Other("KEK length mismatch".into()))?;
            aes_kw::KwAes256::new(kek_arr.into())
                .wrap_key(cek, &mut wrapped)
                .map_err(|e| SmimeError::Other(e.to_string()))?;
        }
        _ => {
            return Err(SmimeError::Other(format!(
                "unsupported KEK length: {kek_len} bytes"
            )));
        }
    }
    Ok(wrapped)
}
/// Encodes the key-wrap `AlgorithmIdentifier` (`wrap_oid`, parameters absent) as an `Any`.
/// The result is used as the parameters of a key-agreement key-encryption algorithm.
///
/// # Errors
/// [`SmimeError::Der`] if DER encoding fails.
fn wrap_alg_any(wrap_oid: ObjectIdentifier) -> Result<Any, SmimeError> {
    Any::encode_from(&AlgorithmIdentifierOwned {
        oid: wrap_oid,
        parameters: None,
    })
    .map_err(SmimeError::Der)
}
/// Wraps DER-encoded CMS bytes in an `application/pkcs7-mime`
/// (`smime-type=authEnveloped-data`) MIME entity.
///
/// Layout:
/// - a fixed CRLF-terminated header block, then a blank line;
/// - the base64 encoding of `der`, folded into lines of at most 76 characters, each ending
///   in CRLF.
fn build_mime(der: &[u8]) -> Vec<u8> {
    const HEADERS: &str = "MIME-Version: 1.0\r\n\
        Content-Type: application/pkcs7-mime; smime-type=authEnveloped-data; name=smime.p7m\r\n\
        Content-Transfer-Encoding: base64\r\n\
        Content-Disposition: attachment; filename=smime.p7m\r\n\
        \r\n";
    const LINE_LEN: usize = 76;

    let encoded = BASE64.encode(der);
    let line_count = encoded.len().div_ceil(LINE_LEN);
    let mut message = String::with_capacity(HEADERS.len() + encoded.len() + 2 * line_count);
    message.push_str(HEADERS);

    // Base64 output is pure ASCII, so every byte offset is a valid `split_at` boundary.
    let mut remaining = encoded.as_str();
    while !remaining.is_empty() {
        let (line, tail) = remaining.split_at(remaining.len().min(LINE_LEN));
        message.push_str(line);
        message.push_str("\r\n");
        remaining = tail;
    }
    message.into_bytes()
}