use crate::Identity;
use async_native_tls::{Error, TlsAcceptor, TlsStream};
use pem::Pem;
use pkcs8::{
AlgorithmIdentifierRef, ObjectIdentifier, PrivateKeyInfo,
der::{Decode, Encode, asn1::AnyRef},
};
use std::{
io::{self, IoSlice, IoSliceMut},
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
};
use trillium_server_common::{Acceptor, AsyncRead, AsyncWrite, Transport};
/// Trillium [`Acceptor`] that performs server-side TLS handshakes using
/// `async_native_tls` (backed by `native-tls`).
///
/// Construct it with [`NativeTlsAcceptor::new`], one of the `from_*`
/// constructors, or any of the `From` conversions defined in this module.
#[derive(Clone, Debug)]
pub struct NativeTlsAcceptor(TlsAcceptor);
impl NativeTlsAcceptor {
    /// Builds an acceptor from any value convertible into one, e.g. an
    /// [`Identity`], a `native_tls::TlsAcceptor`, an
    /// `async_native_tls::TlsAcceptor`, or a `(pkcs12_der, password)` tuple.
    pub fn new(t: impl Into<Self>) -> Self {
        Into::<Self>::into(t)
    }

    /// Builds an acceptor from a PEM certificate chain and a PEM private key.
    ///
    /// `cert` must contain at least one `CERTIFICATE` block; other blocks are
    /// ignored, so a combined cert+key file is accepted. Certificates are
    /// expected leaf-first, and the PKCS#12 fallback treats the first one as
    /// the leaf. `key` must contain a `PRIVATE KEY` (PKCS#8),
    /// `RSA PRIVATE KEY` (PKCS#1) or `EC PRIVATE KEY` (SEC1) block, which is
    /// normalized to PKCS#8 before use.
    ///
    /// The identity is first built with `Identity::from_pkcs8`. If that fails,
    /// an in-memory PKCS#12 archive is assembled and loaded with
    /// `Identity::from_pkcs12` instead.
    ///
    /// # Panics
    ///
    /// Panics if either input is not valid PEM, if no certificate or supported
    /// key block is found, or if both identity constructions fail (the
    /// message then reports both errors).
    pub fn from_cert_and_key(cert: &[u8], key: &[u8]) -> Self {
        let chain_der = extract_cert_chain_der(cert);
        let key_der = normalize_key_to_pkcs8_der(key);
        let chain_pem = encode_cert_chain_pem(&chain_der);
        let key_pem = encode_pkcs8_pem(&key_der);

        // Only build the PKCS#12 archive when the PKCS#8 path has failed, and
        // keep both errors so the final panic can report each of them.
        let outcome = Identity::from_pkcs8(&chain_pem, &key_pem).or_else(|pkcs8_err| {
            let p12_der = build_pkcs12_der(&chain_der, &key_der);
            Identity::from_pkcs12(&p12_der, INTERNAL_P12_PASSWORD)
                .map_err(|p12_err| (pkcs8_err, p12_err))
        });

        match outcome {
            Ok(identity) => identity.into(),
            Err((pkcs8_err, p12_err)) => panic!(
                "could not build Identity from provided cert and key.\n from_pkcs8 error: \
                 {pkcs8_err}\n from_pkcs12 fallback error: {p12_err}"
            ),
        }
    }

    /// Builds an acceptor from a DER-encoded PKCS#12 archive and its password.
    ///
    /// # Panics
    ///
    /// Panics if `Identity::from_pkcs12` rejects the archive or password.
    pub fn from_pkcs12(der: &[u8], password: &str) -> Self {
        let identity = Identity::from_pkcs12(der, password)
            .expect("could not build Identity from provided pkcs12 key and password");
        Self::from(identity)
    }

    /// Builds an acceptor from a PEM certificate chain and a PEM PKCS#8 key,
    /// passed unchanged to `Identity::from_pkcs8`.
    ///
    /// # Panics
    ///
    /// Panics if `Identity::from_pkcs8` rejects the inputs.
    pub fn from_pkcs8(pem: &[u8], key: &[u8]) -> Self {
        let identity = Identity::from_pkcs8(pem, key)
            .expect("could not build Identity from provided pem and key");
        Self::from(identity)
    }
}
/// PEM label of an unencrypted PKCS#8 `PrivateKeyInfo` block.
const PEM_TAG_PKCS8: &str = "PRIVATE KEY";
/// PEM label of a PKCS#1 `RSAPrivateKey` block.
const PEM_TAG_PKCS1: &str = "RSA PRIVATE KEY";
/// PEM label of a SEC1 `ECPrivateKey` block.
const PEM_TAG_SEC1: &str = "EC PRIVATE KEY";
/// PEM label of an X.509 certificate block.
const PEM_TAG_CERT: &str = "CERTIFICATE";
/// `rsaEncryption` algorithm OID, used as the PKCS#8 algorithm when wrapping
/// PKCS#1 keys.
const RSA_ENCRYPTION_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113549.1.1.1");
/// `id-ecPublicKey` algorithm OID, used as the PKCS#8 algorithm when wrapping
/// SEC1 keys (the named curve goes in the algorithm parameters).
const EC_PUBLIC_KEY_OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.10045.2.1");
/// Password for the throwaway PKCS#12 archive. The archive is created by
/// `build_pkcs12_der` and immediately loaded by `from_cert_and_key`. This is a
/// fixed, hard-coded value and provides no secrecy.
const INTERNAL_P12_PASSWORD: &str = "trillium";
/// Parses every PEM block contained in `input`, in order.
///
/// # Panics
///
/// Panics if `input` is not well-formed PEM.
fn parse_pem_blocks(input: &[u8]) -> Vec<Pem> {
    match pem::parse_many(input) {
        Ok(blocks) => blocks,
        // Same message shape that `Result::expect` produces.
        Err(err) => panic!("could not parse PEM input: {err:?}"),
    }
}
fn normalize_key_to_pkcs8_der(input: &[u8]) -> Vec<u8> {
let blocks = parse_pem_blocks(input);
let key = blocks
.iter()
.find(|b| matches!(b.tag(), PEM_TAG_PKCS8 | PEM_TAG_PKCS1 | PEM_TAG_SEC1))
.expect(
"no private key block found in key input (expected PRIVATE KEY, RSA PRIVATE KEY, or \
EC PRIVATE KEY)",
);
match key.tag() {
PEM_TAG_PKCS8 => key.contents().to_vec(),
PEM_TAG_PKCS1 => wrap_pkcs1_in_pkcs8(key.contents()),
PEM_TAG_SEC1 => wrap_sec1_in_pkcs8(key.contents()),
_ => unreachable!(),
}
}
/// Wraps a DER-encoded PKCS#1 `RSAPrivateKey` in a PKCS#8 `PrivateKeyInfo`.
///
/// The algorithm identifier is `rsaEncryption` with explicit NULL
/// parameters. The PKCS#1 bytes become the `privateKey` OCTET STRING
/// unchanged.
///
/// # Panics
///
/// Panics if the resulting structure cannot be DER-encoded.
fn wrap_pkcs1_in_pkcs8(pkcs1_der: &[u8]) -> Vec<u8> {
    let rsa_algorithm = AlgorithmIdentifierRef {
        oid: RSA_ENCRYPTION_OID,
        parameters: Some(AnyRef::NULL),
    };
    let key_info = PrivateKeyInfo::new(rsa_algorithm, pkcs1_der);
    let encoded = key_info.to_der();
    encoded.expect("could not encode PKCS#1 key as PKCS#8")
}
fn wrap_sec1_in_pkcs8(sec1_der: &[u8]) -> Vec<u8> {
let parsed =
sec1::EcPrivateKey::from_der(sec1_der).expect("could not parse SEC1 EC private key");
let curve_oid = parsed
.parameters
.and_then(|p| p.named_curve())
.expect("EC private key is missing namedCurve parameters");
let curve_param: AnyRef<'_> = (&curve_oid).into();
let algorithm = AlgorithmIdentifierRef {
oid: EC_PUBLIC_KEY_OID,
parameters: Some(curve_param),
};
PrivateKeyInfo::new(algorithm, sec1_der)
.to_der()
.expect("could not encode SEC1 key as PKCS#8")
}
fn encode_cert_chain_pem(cert_chain_der: &[Vec<u8>]) -> Vec<u8> {
let blocks: Vec<Pem> = cert_chain_der
.iter()
.map(|d| Pem::new(PEM_TAG_CERT, d.clone()))
.collect();
pem::encode_many(&blocks).into_bytes()
}
/// Re-encodes PKCS#8 DER as a single `PRIVATE KEY` PEM block, returned as
/// bytes.
fn encode_pkcs8_pem(key_pkcs8_der: &[u8]) -> Vec<u8> {
    let key_block = Pem::new(PEM_TAG_PKCS8, key_pkcs8_der.to_vec());
    let pem_text = pem::encode(&key_block);
    pem_text.into_bytes()
}
/// Collects the DER contents of every `CERTIFICATE` block in `input`, in
/// order, ignoring all other PEM blocks.
///
/// # Panics
///
/// Panics if `input` is not valid PEM or contains no `CERTIFICATE` block.
fn extract_cert_chain_der(input: &[u8]) -> Vec<Vec<u8>> {
    let mut chain = Vec::new();
    for block in parse_pem_blocks(input) {
        if block.tag() == PEM_TAG_CERT {
            chain.push(block.into_contents());
        }
    }
    if chain.is_empty() {
        panic!("no CERTIFICATE blocks found in cert input");
    }
    chain
}
/// Packs a certificate chain and a PKCS#8 key into a DER PKCS#12 archive.
///
/// The first certificate is used as the leaf and the rest are added as CA
/// certificates. The archive is protected with [`INTERNAL_P12_PASSWORD`] and
/// has an empty friendly name.
///
/// # Panics
///
/// Panics if `cert_chain_der` is empty or if the `p12` crate fails to build
/// the archive.
fn build_pkcs12_der(cert_chain_der: &[Vec<u8>], key_pkcs8_der: &[u8]) -> Vec<u8> {
    let (leaf, rest) = cert_chain_der
        .split_first()
        .expect("cert chain was empty");
    let ca_certs = rest.iter().map(Vec::as_slice).collect::<Vec<&[u8]>>();
    p12::PFX::new_with_cas(leaf, key_pkcs8_der, &ca_certs, INTERNAL_P12_PASSWORD, "")
        .expect("could not build PKCS#12 archive from cert and key")
        .to_der()
}
impl From<Identity> for NativeTlsAcceptor {
    /// Builds a `native_tls::TlsAcceptor` from `i` via
    /// `native_tls::TlsAcceptor::new` and wraps it.
    ///
    /// # Panics
    ///
    /// Panics if `native_tls::TlsAcceptor::new` returns an error for this
    /// identity.
    fn from(i: Identity) -> Self {
        // Descriptive `expect` instead of a bare `unwrap`, matching every other
        // panic site in this module, so failures say what was being built.
        native_tls::TlsAcceptor::new(i)
            .expect("could not build native_tls::TlsAcceptor from Identity")
            .into()
    }
}
impl From<native_tls::TlsAcceptor> for NativeTlsAcceptor {
fn from(i: native_tls::TlsAcceptor) -> Self {
Self(i.into())
}
}
impl From<TlsAcceptor> for NativeTlsAcceptor {
fn from(i: TlsAcceptor) -> Self {
Self(i)
}
}
impl From<(&[u8], &str)> for NativeTlsAcceptor {
    /// Treats the tuple as `(pkcs12_der, password)` and delegates to
    /// [`NativeTlsAcceptor::from_pkcs12`].
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`NativeTlsAcceptor::from_pkcs12`].
    fn from((der, password): (&[u8], &str)) -> Self {
        Self::from_pkcs12(der, password)
    }
}
impl<Input> Acceptor<Input> for NativeTlsAcceptor
where
Input: Transport,
{
type Error = Error;
type Output = NativeTlsServerTransport<Input>;
async fn accept(&self, input: Input) -> Result<Self::Output, Self::Error> {
self.0.accept(input).await.map(NativeTlsServerTransport)
}
}
/// Server-side TLS stream produced by [`NativeTlsAcceptor`].
///
/// Wraps an `async_native_tls::TlsStream` over the accepted transport `T`, and
/// forwards reads, writes and `peer_addr` to it.
#[derive(Debug)]
pub struct NativeTlsServerTransport<T>(TlsStream<T>);
impl<T: AsyncWrite + AsyncRead + Unpin> AsRef<T> for NativeTlsServerTransport<T> {
    /// Borrows the transport underneath the TLS layer.
    fn as_ref(&self) -> &T {
        let Self(tls_stream) = self;
        tls_stream.get_ref()
    }
}

impl<T: AsyncWrite + AsyncRead + Unpin> AsMut<T> for NativeTlsServerTransport<T> {
    /// Mutably borrows the transport underneath the TLS layer.
    fn as_mut(&mut self) -> &mut T {
        let Self(tls_stream) = self;
        tls_stream.get_mut()
    }
}

impl<T> AsRef<TlsStream<T>> for NativeTlsServerTransport<T> {
    /// Borrows the wrapped `async_native_tls::TlsStream`.
    fn as_ref(&self) -> &TlsStream<T> {
        let Self(tls_stream) = self;
        tls_stream
    }
}

impl<T> AsMut<TlsStream<T>> for NativeTlsServerTransport<T> {
    /// Mutably borrows the wrapped `async_native_tls::TlsStream`.
    fn as_mut(&mut self) -> &mut TlsStream<T> {
        let Self(tls_stream) = self;
        tls_stream
    }
}
impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for NativeTlsServerTransport<T> {
    /// Reads decrypted bytes by delegating to the inner TLS stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // `get_mut` needs `Self: Unpin`, the same bound the original
        // `DerefMut`-based projection relied on.
        let tls_stream = &mut self.get_mut().0;
        Pin::new(tls_stream).poll_read(cx, buf)
    }

    /// Vectored variant of [`Self::poll_read`], also delegated.
    fn poll_read_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [IoSliceMut<'_>],
    ) -> Poll<io::Result<usize>> {
        let tls_stream = &mut self.get_mut().0;
        Pin::new(tls_stream).poll_read_vectored(cx, bufs)
    }
}
impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for NativeTlsServerTransport<T> {
    /// Encrypts and writes `buf` by delegating to the inner TLS stream.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let tls_stream = &mut self.get_mut().0;
        Pin::new(tls_stream).poll_write(cx, buf)
    }

    /// Vectored variant of [`Self::poll_write`], also delegated.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let tls_stream = &mut self.get_mut().0;
        Pin::new(tls_stream).poll_write_vectored(cx, bufs)
    }

    /// Flushes the inner TLS stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let tls_stream = &mut self.get_mut().0;
        Pin::new(tls_stream).poll_flush(cx)
    }

    /// Closes the inner TLS stream.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let tls_stream = &mut self.get_mut().0;
        Pin::new(tls_stream).poll_close(cx)
    }
}
impl<T: Transport> Transport for NativeTlsServerTransport<T> {
    /// Reports the peer address of the transport beneath the TLS layer.
    fn peer_addr(&self) -> io::Result<Option<SocketAddr>> {
        let Self(tls_stream) = self;
        let inner = tls_stream.get_ref();
        inner.peer_addr()
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the PEM -> PKCS#8 key normalization and certificate
    // extraction helpers. Inputs are PEM fixtures embedded at compile time
    // from `tests/fixtures/`.
    use super::{
        EC_PUBLIC_KEY_OID, RSA_ENCRYPTION_OID, extract_cert_chain_der, normalize_key_to_pkcs8_der,
    };
    use pkcs8::PrivateKeyInfo;

    // Fixture names suggest: an RSA certificate with a PKCS#1 key, and an EC
    // certificate with the same kind of key in SEC1 and PKCS#8 form.
    // NOTE(review): confirm the fixture contents match their file names.
    const RSA_CERT: &[u8] = include_bytes!("../tests/fixtures/rsa.crt");
    const RSA_PKCS1: &[u8] = include_bytes!("../tests/fixtures/rsa-pkcs1.key");
    const EC_CERT: &[u8] = include_bytes!("../tests/fixtures/ec.crt");
    const EC_SEC1: &[u8] = include_bytes!("../tests/fixtures/ec-sec1.key");
    const EC_PKCS8: &[u8] = include_bytes!("../tests/fixtures/ec-pkcs8.key");

    // Decodes normalizer output, failing the test if it is not valid PKCS#8.
    fn parse_pkcs8_der(der: &[u8]) -> PrivateKeyInfo<'_> {
        PrivateKeyInfo::try_from(der).expect("output not parseable as PKCS#8")
    }

    // An RSA key input must come out as PKCS#8 tagged rsaEncryption.
    #[test]
    fn pkcs1_wraps_to_pkcs8_with_rsa_oid() {
        let der = normalize_key_to_pkcs8_der(RSA_PKCS1);
        assert_eq!(parse_pkcs8_der(&der).algorithm.oid, RSA_ENCRYPTION_OID);
    }

    // An EC key input must come out tagged id-ecPublicKey and carry curve
    // parameters (only their presence is checked, not the curve itself).
    #[test]
    fn sec1_wraps_to_pkcs8_with_ec_oid_and_curve_param() {
        let der = normalize_key_to_pkcs8_der(EC_SEC1);
        let pki = parse_pkcs8_der(&der);
        assert_eq!(pki.algorithm.oid, EC_PUBLIC_KEY_OID);
        assert!(
            pki.algorithm.parameters.is_some(),
            "EC PKCS#8 must carry namedCurve OID in algorithm parameters"
        );
    }

    // PKCS#8 input passes through with its algorithm OID intact.
    #[test]
    fn pkcs8_pass_through_preserves_algorithm() {
        let der = normalize_key_to_pkcs8_der(EC_PKCS8);
        assert_eq!(parse_pkcs8_der(&der).algorithm.oid, EC_PUBLIC_KEY_OID);
    }

    // With a cert+key bundle, only the certificate blocks are extracted, and
    // they match parsing the certificate file directly.
    #[test]
    fn cert_extracted_from_concatenated_bundle() {
        let mut bundle = Vec::new();
        bundle.extend_from_slice(EC_CERT);
        bundle.extend_from_slice(EC_SEC1);
        let extracted = extract_cert_chain_der(&bundle);
        let original: Vec<Vec<u8>> = pem::parse_many(EC_CERT)
            .unwrap()
            .into_iter()
            .map(pem::Pem::into_contents)
            .collect();
        assert_eq!(extracted, original);
    }

    // With a cert+key bundle, the key is still found after the certificate
    // blocks and normalized.
    #[test]
    fn key_extracted_from_concatenated_bundle() {
        let mut bundle = Vec::new();
        bundle.extend_from_slice(RSA_CERT);
        bundle.extend_from_slice(RSA_PKCS1);
        let der = normalize_key_to_pkcs8_der(&bundle);
        assert_eq!(parse_pkcs8_der(&der).algorithm.oid, RSA_ENCRYPTION_OID);
    }
}