use crate::{
address::Address,
dns::{query_record, sha256_truncated},
message::{is_encrypted, plaintext_message, raw_header, split_content_headers}
};
use base64::{Engine, prelude::BASE64_STANDARD};
use futures_util::future;
use hickory_resolver::{
proto::rr::{
Name, RData, RecordType,
rdata::{
SMIMEA,
tlsa::{CertUsage, Matching, Selector}
}
},
recursor::RecursorError
};
use log::{error, info};
use mail_parser::MessageParser;
use openssl::{
cms::{CMSOptions, CmsContentInfo},
stack::Stack,
symm::Cipher,
x509::X509
};
use std::borrow::Cow;
use thiserror::Error;
/// Errors that can occur while looking up S/MIME certificates in DNS.
#[derive(Debug, Error)]
pub enum Error {
/// A DNS error, wrapping the underlying `RecursorError`; the `#[from]`
/// attribute lets `?` convert such an error into this variant.
#[error("DNS Error: {0}")]
DnsError(#[from] RecursorError),
/// Returned by `find_certs_for_recipient` when none of the records it
/// inspected produced a certificate.
#[error("The SMIMEA DNS record(s) contained no usable S/MIME encryption keys")]
NoUsableKeys
}
/// Convenience alias for `std::result::Result` whose error type defaults to
/// [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
pub async fn find_certs_for_recipient<A>(recipient: &A) -> Result<Vec<X509>>
where
A: Address
{
let mut name: Name = recipient.domain().parse().unwrap();
name.set_fqdn(true);
name = name.prepend_label("_smimecert").unwrap();
name = name
.prepend_label(sha256_truncated(recipient.local().as_bytes()))
.unwrap();
let raw_certs = query_record(&name, RecordType::SMIMEA, |record| {
let RData::SMIMEA(SMIMEA(smimea)) = record.into_data() else {
return None;
};
Some(smimea)
})
.await?;
let mut certs = Vec::with_capacity(raw_certs.len());
for smimea in raw_certs {
if smimea.cert_usage != CertUsage::PkixEe
&& smimea.cert_usage != CertUsage::DaneEe
{
continue;
}
if smimea.selector != Selector::Full {
continue;
}
if smimea.matching != Matching::Raw {
continue;
}
let cert = match X509::from_der(&smimea.cert_data) {
Ok(cert) => cert,
Err(err) => {
error!("Unable to parse SMIMEA record for {name}: {err}");
continue;
}
};
certs.push(cert);
}
if certs.is_empty() {
return Err(Error::NoUsableKeys);
}
Ok(certs)
}
/// Looks up the S/MIME certificates of every recipient in `to`.
///
/// One [`find_certs_for_recipient`] lookup is created per recipient and all
/// of them are awaited together through `future::try_join_all`. The
/// certificates of all recipients are returned as a single flat list.
///
/// # Errors
///
/// If any of the lookups fails, that error is returned and no certificates
/// are produced.
pub async fn find_certs_for_all_recipients<A>(to: &[A]) -> Result<Vec<X509>>
where
    A: Address
{
    let lookups = to.iter().map(find_certs_for_recipient);
    let per_recipient = future::try_join_all(lookups).await?;

    let mut all_certs = Vec::new();
    for recipient_certs in per_recipient {
        all_certs.extend(recipient_certs);
    }
    Ok(all_certs)
}
/// Encrypts a mail message as S/MIME enveloped data for `recipient_certs`.
///
/// The input is returned unchanged (borrowed) when it cannot be parsed as a
/// message or when `is_encrypted` reports that it is already encrypted.
/// Otherwise the content headers are split off the message, the plaintext
/// produced by `plaintext_message` is encrypted with AES-256-CBC into a CMS
/// structure, and a new message is assembled from:
///
/// 1. the raw headers remaining on the root part,
/// 2. a fixed set of headers describing the `application/pkcs7-mime` body,
/// 3. the DER-encoded CMS structure in base64, 48 input bytes per line, each
///    line terminated by CRLF.
///
/// # Panics
///
/// Panics if the certificate stack cannot be created or filled, if the CMS
/// encryption fails, or if the result cannot be serialised to DER.
pub fn encrypt(msg_bytes: &[u8], recipient_certs: Vec<X509>) -> Cow<'_, [u8]> {
    // Number of DER bytes per output line; 48 bytes encode to 64 base64 chars.
    const BYTES_PER_LINE: usize = 48;
    // Headers describing the encrypted body, ending with the blank line that
    // separates the header section from the body.
    const ENVELOPE_HEADERS: [&[u8]; 6] = [
        b"Content-Disposition: attachment; filename=\"smime.p7m\"\r\n",
        b"Content-Type: application/pkcs7-mime; name=\"smime.p7m\";\r\n",
        b" smime-type=\"enveloped-data\"\r\n",
        b"Content-Description: S/MIME Encrypted Message\r\n",
        b"Content-Transfer-Encoding: base64\r\n",
        b"\r\n"
    ];

    let mut msg = match MessageParser::new().parse(msg_bytes) {
        Some(parsed) => parsed,
        None => {
            error!("Unable to parse message, not encrypting");
            return Cow::Borrowed(msg_bytes);
        }
    };
    if is_encrypted(&msg) {
        info!("Found encrypted message, not encrypting twice");
        return Cow::Borrowed(msg_bytes);
    }

    // The content headers belong to the part that gets encrypted, so they are
    // removed from the outer message before the plaintext is built.
    let content_headers = split_content_headers(&mut msg);
    let plaintext = plaintext_message(&msg, &content_headers);

    let mut recipients = Stack::new().unwrap();
    recipient_certs
        .into_iter()
        .for_each(|cert| recipients.push(cert).unwrap());

    // Outer header section: remaining original headers first, then the
    // S/MIME envelope headers.
    let mut output = Vec::<u8>::new();
    for header in msg.root_part().headers() {
        output.extend_from_slice(raw_header(&msg, header));
    }
    for line in &ENVELOPE_HEADERS {
        output.extend_from_slice(line);
    }

    let der = CmsContentInfo::encrypt(
        &recipients,
        &plaintext,
        Cipher::aes_256_cbc(),
        CMSOptions::empty()
    )
    .unwrap()
    .to_der()
    .unwrap();

    // Body: base64 of the DER data, wrapped after every `BYTES_PER_LINE` bytes.
    for chunk in der.chunks(BYTES_PER_LINE) {
        output.extend_from_slice(BASE64_STANDARD.encode(chunk).as_bytes());
        output.extend_from_slice(b"\r\n");
    }

    Cow::Owned(output)
}