use std::{
future::Future,
io,
sync::{Arc, Mutex},
};
use native_tls::Identity;
use openssl::{
ec::{EcGroup, EcKey},
hash::MessageDigest,
nid::Nid,
pkey::{HasPrivate, HasPublic, PKey, PKeyRef, Private},
x509::{X509NameBuilder, X509ReqBuilder},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_native_tls::TlsStream;
use crate::{
acme::{self, Account},
Error,
};
/// Selects how (and whether) TLS is set up.
///
/// The variant decides what [`TlsMode::create_acme_account`] and
/// [`TlsMode::create_tls_acceptor`] produce.
pub enum TlsMode {
/// No TLS: neither an ACME account nor a TLS acceptor is created.
None,
/// TLS with a fixed, caller-supplied identity; no ACME account is created.
Simple(Identity),
/// An ACME account is created against `acme::LETS_ENCRYPT_DIRECTORY`, and the
/// TLS acceptor starts out without any identity (see `TlsAcceptor::dummy`).
LetsEncrypt,
}
impl TlsMode {
    /// Creates an ACME account when the mode is [`TlsMode::LetsEncrypt`].
    ///
    /// A fresh ACME client is created, the directory at
    /// `acme::LETS_ENCRYPT_DIRECTORY` is opened and a new account is requested
    /// from it. Every other mode yields `Ok(None)` without doing any work.
    ///
    /// # Errors
    ///
    /// Propagates any failure of the ACME client, the directory lookup or the
    /// account creation.
    pub async fn create_acme_account(&self) -> Result<Option<Account>, Error> {
        // Guard clause: only the Let's Encrypt mode needs an account.
        if !matches!(self, Self::LetsEncrypt) {
            return Ok(None);
        }

        Ok(Some(
            acme::Client::new()
                .await?
                .open_directory(acme::LETS_ENCRYPT_DIRECTORY)
                .await?
                .new_account(None)
                .await?,
        ))
    }

    /// Builds the TLS acceptor that matches this mode.
    ///
    /// * `None` yields no acceptor.
    /// * `Simple` yields an acceptor built from a clone of the stored identity.
    /// * `LetsEncrypt` yields an identity-less acceptor ([`TlsAcceptor::dummy`])
    ///   that can be given an identity later via [`TlsAcceptor::set_identity`].
    ///
    /// # Errors
    ///
    /// Fails if the acceptor cannot be constructed from the `Simple` identity.
    pub fn create_tls_acceptor(&self) -> Result<Option<TlsAcceptor>, Error> {
        let maybe_acceptor = match self {
            Self::None => None,
            Self::Simple(identity) => Some(TlsAcceptor::new(identity.clone())?),
            Self::LetsEncrypt => Some(TlsAcceptor::dummy()),
        };
        Ok(maybe_acceptor)
    }
}
/// A cloneable TLS acceptor whose underlying acceptor can be swapped at runtime.
///
/// Clones share the same `inner` slot (it sits behind an `Arc`), so replacing
/// the acceptor through one handle is visible through all of them.
#[derive(Clone)]
pub struct TlsAcceptor {
// Shared slot holding the current acceptor; `None` means no identity has been
// installed yet. The inner `Arc` lets callers take a snapshot of the acceptor
// by cloning it while the mutex is held only briefly.
inner: Arc<Mutex<Option<Arc<tokio_native_tls::TlsAcceptor>>>>,
}
impl TlsAcceptor {
pub fn dummy() -> Self {
Self {
inner: Arc::new(Mutex::new(None)),
}
}
pub fn new(identity: Identity) -> Result<Self, Error> {
let acceptor = native_tls::TlsAcceptor::new(identity)?;
let res = Self {
inner: Arc::new(Mutex::new(Some(Arc::new(acceptor.into())))),
};
Ok(res)
}
pub fn set_identity(&self, identity: Identity) -> Result<(), Error> {
let acceptor = native_tls::TlsAcceptor::new(identity)?;
let mut inner = self.inner.lock().unwrap();
*inner = Some(Arc::new(acceptor.into()));
Ok(())
}
pub fn accept<S>(&self, stream: S) -> impl Future<Output = io::Result<TlsStream<S>>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let acceptor = self.inner.lock().unwrap().clone();
async move {
acceptor
.ok_or_else(|| io::Error::from(io::ErrorKind::ConnectionRefused))?
.accept(stream)
.await
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))
}
}
}
/// Generates a fresh elliptic-curve private key on the `secp384r1` curve.
///
/// # Errors
///
/// Propagates any OpenSSL failure from the curve lookup, the key generation
/// or the conversion into a `PKey`.
pub fn generate_tls_key() -> Result<PKey<Private>, Error> {
    let curve = EcGroup::from_curve_name(Nid::SECP384R1)?;
    let private_key = PKey::from_ec_key(EcKey::generate(&curve)?)?;
    Ok(private_key)
}
/// Builds a DER-encoded certificate signing request (PKCS#10) for `hostname`.
///
/// The request's subject consists of a single common-name entry set to
/// `hostname`. The public half of `key` is embedded in the request, and the
/// request is signed with `key` using SHA-256.
///
/// # Errors
///
/// Propagates any OpenSSL failure while building, signing or encoding the
/// request (for example when `hostname` is not acceptable as a common name).
pub fn create_csr<T>(key: &PKeyRef<T>, hostname: &str) -> Result<Vec<u8>, Error>
where
    T: HasPrivate + HasPublic,
{
    let mut subject_name_builder = X509NameBuilder::new()?;
    subject_name_builder.append_entry_by_nid(Nid::COMMONNAME, hostname)?;
    let subject_name = subject_name_builder.build();

    let mut csr_builder = X509ReqBuilder::new()?;
    // `set_version` writes the raw numeric version field. PKCS#10 (RFC 2986)
    // defines only v1, which is encoded as 0; any other value produces a
    // non-conformant request that strict implementations reject.
    csr_builder.set_version(0)?;
    csr_builder.set_subject_name(&subject_name)?;
    csr_builder.set_pubkey(key)?;
    csr_builder.sign(key, MessageDigest::sha256())?;

    let res = csr_builder.build().to_der()?;
    Ok(res)
}