use std::net::ToSocketAddrs;
use std::net::{Ipv4Addr, SocketAddr};
use std::path::PathBuf;
use std::sync::Arc;

use anyhow::{Context, Error, Result};
use async_trait::async_trait;
use tokio::io::split;
use tokio::net::{lookup_host, TcpListener, TcpStream};
use tokio_rustls::rustls::pki_types::pem::PemObject;
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use tokio_rustls::rustls::RootCertStore;
use tokio_rustls::{rustls, TlsAcceptor};

use crate::connector::Connector;
/// Settings used to build a [`TlsConnector`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Local address the TLS listener binds to for incoming connections.
    listen_addr: SocketAddr,
    /// Optional PEM file whose certificates become the trust roots for outgoing
    /// connections. When `None`, the client root store is left empty, so no server
    /// certificate can chain to a trusted root and outgoing handshakes are expected
    /// to fail verification.
    ca_cert_file: Option<PathBuf>,
    /// PEM file holding the certificate chain presented to incoming peers.
    server_cert_file: PathBuf,
    /// PEM file holding the private key matching `server_cert_file`.
    server_key_file: PathBuf,
}

/// Port used for listening by default and for dialing when a locator has no `port` hint.
const DEFAULT_TLS_PORT: u16 = 9445;

impl Default for Config {
    /// Listens on all IPv4 interfaces at [`DEFAULT_TLS_PORT`], trusts no CA roots, and
    /// loads the server identity from `cert.pem` / `key.pem` (relative to the
    /// process working directory).
    fn default() -> Self {
        Config {
            listen_addr: SocketAddr::from((Ipv4Addr::UNSPECIFIED, DEFAULT_TLS_PORT)),
            ca_cert_file: None,
            server_cert_file: PathBuf::from("cert.pem"),
            server_key_file: PathBuf::from("key.pem"),
        }
    }
}
/// TLS transport: dials outgoing TLS connections and accepts incoming ones on a
/// bound TCP listener. Built from a `Config` via `TryFrom`.
///
/// Note: this type shares its name with `tokio_rustls::TlsConnector`, which is why
/// that type is always referred to by its full path in this file.
pub(crate) struct TlsConnector {
    /// Client-side TLS connector; trusts only the roots loaded from the config's CA file.
    connector: tokio_rustls::TlsConnector,
    /// Server-side acceptor configured with the server certificate chain and key.
    acceptor: TlsAcceptor,
    /// Listener bound to the config's `listen_addr`; incoming connections come from here.
    listener: TcpListener,
}
impl TryFrom<Config> for TlsConnector {
type Error = Error;
fn try_from(config: Config) -> Result<Self, Self::Error> {
let mut root_cert_store = RootCertStore::empty();
if let Some(cafile) = config.ca_cert_file {
for cert in CertificateDer::pem_file_iter(cafile)? {
root_cert_store.add(cert?)?;
}
}
let tls_client_config = rustls::ClientConfig::builder()
.with_root_certificates(root_cert_store)
.with_no_client_auth(); let connector = tokio_rustls::TlsConnector::from(Arc::new(tls_client_config));
let server_certs = CertificateDer::pem_file_iter(&config.server_cert_file)?
.collect::<Result<Vec<_>, _>>()?;
let server_key = PrivateKeyDer::from_pem_file(&config.server_key_file)?;
let tls_server_config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(server_certs, server_key)?;
let acceptor = TlsAcceptor::from(Arc::new(tls_server_config));
let listener = std::net::TcpListener::bind(config.listen_addr)?;
Ok(TlsConnector {
acceptor,
connector,
listener: TcpListener::from_std(listener)?,
})
}
}
#[async_trait]
impl Connector for TlsConnector {
async fn new_outgoing_connection(
&mut self,
locator: &crate::locator::Peer,
) -> Result<crate::connection::Connection> {
if !locator.transport.eq("tls") {
return Err(Error::msg("not a tls address"));
}
let port: u16 = match locator.hints.get("port") {
Some(port_as_str) => port_as_str.parse()?,
None => DEFAULT_TLS_PORT,
};
let addr = (locator.designator.as_str(), port)
.to_socket_addrs()?
.next()
.ok_or_else(|| Error::msg("could not resolve designator host"))?;
let domain = ServerName::try_from(locator.designator.as_str())?.to_owned();
let tcp_stream = TcpStream::connect(addr).await?;
let tls_stream = self.connector.connect(domain, tcp_stream).await?;
let split_stream = split(tls_stream);
Ok(split_stream.into())
}
async fn accept_incoming_connection(
&mut self,
) -> Result<Option<crate::connection::Connection>> {
let (tcp_stream, _) = self.listener.accept().await?;
let acceptor = self.acceptor.clone();
let tls_stream = acceptor.accept(tcp_stream).await?;
let split_stream = split(tls_stream);
Ok(Some(split_stream.into()))
}
}