use crate::{
tls::{TlsAcceptorSettings, UnimplementedTls},
traits::{CertifiedConn, StreamOps, TlsConnector, TlsProvider},
};
use async_trait::async_trait;
use futures::{AsyncRead, AsyncWrite};
use native_tls_crate as native_tls;
use std::{
borrow::Cow,
io::{Error as IoError, Result as IoResult},
};
use tracing::instrument;
/// A [`TlsProvider`] backed by the `native-tls` crate (via `async-native-tls`).
///
/// It can wrap any stream type that is `AsyncRead + AsyncWrite + StreamOps + Unpin + Send +
/// 'static`.
///
/// Only the client side is implemented. [`TlsProvider::tls_acceptor`] always returns an
/// error for this provider, and it does not support exporting keying material.
#[cfg_attr(
docsrs,
doc(cfg(all(
feature = "native-tls",
any(feature = "tokio", feature = "async-std", feature = "smol")
)))
)]
#[derive(Default, Clone)]
#[non_exhaustive]
pub struct NativeTlsProvider {}
impl<S> CertifiedConn for async_native_tls::TlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Return the DER encoding of the certificate the peer presented, if any.
    ///
    /// Any native-tls failure, whether fetching the certificate or DER-encoding it, is
    /// converted with `io::Error::other`.
    fn peer_certificate(&self) -> IoResult<Option<Cow<'_, [u8]>>> {
        // This call resolves to the *inherent* `TlsStream::peer_certificate`
        // from async-native-tls, because inherent methods take precedence over
        // trait methods. It is not a recursive call to this function.
        let presented = self.peer_certificate().map_err(IoError::other)?;
        presented
            .map(|certificate| {
                certificate
                    .to_der()
                    .map(Cow::from)
                    .map_err(IoError::other)
            })
            .transpose()
    }

    /// Keying-material export is not available with this backend.
    ///
    /// Always fails with `ErrorKind::Unsupported`, whatever the arguments.
    fn export_keying_material(
        &self,
        _len: usize,
        _label: &[u8],
        _context: Option<&[u8]>,
    ) -> IoResult<Vec<u8>> {
        let reason =
            tor_error::bad_api_usage!("native-tls does not support exporting keying material");
        Err(IoError::new(std::io::ErrorKind::Unsupported, reason))
    }

    /// This implementation never reports a local certificate; it always returns `Ok(None)`.
    fn own_certificate(&self) -> IoResult<Option<Cow<'_, [u8]>>> {
        Ok(None)
    }
}
impl<S> StreamOps for async_native_tls::TlsStream<S>
where
    S: AsyncRead + AsyncWrite + StreamOps + Unpin,
{
    /// Apply `TCP_NOTSENT_LOWAT` to the stream underneath the TLS layer.
    fn set_tcp_notsent_lowat(&self, notsent_lowat: u32) -> IoResult<()> {
        let transport: &S = self.get_ref();
        transport.set_tcp_notsent_lowat(notsent_lowat)
    }

    /// Return the handle produced by the stream underneath the TLS layer.
    fn new_handle(&self) -> Box<dyn StreamOps + Send + Unpin> {
        let transport: &S = self.get_ref();
        transport.new_handle()
    }
}
/// Client-side connector returned by [`NativeTlsProvider`]'s `tls_connector`.
///
/// The type parameter `S` names the underlying stream type that this connector
/// negotiates TLS over; no `S` value is stored.
pub struct NativeTlsConnector<S> {
/// The async-native-tls connector that performs the handshake.
///
/// In this module it is built by `NativeTlsProvider::tls_connector`.
connector: async_native_tls::TlsConnector,
/// Type-level marker for `S`.
///
/// Because the marker is `fn(S) -> S` rather than `S`, this struct owns no `S`.
/// Its `Send`/`Sync` status therefore does not depend on `S`, and it is invariant in `S`.
_phantom: std::marker::PhantomData<fn(S) -> S>,
}
#[async_trait]
impl<S> TlsConnector<S> for NativeTlsConnector<S>
where
    S: AsyncRead + AsyncWrite + StreamOps + Unpin + Send + 'static,
{
    type Conn = async_native_tls::TlsStream<S>;

    /// Perform a TLS client handshake over `stream`, sending `sni_hostname` as SNI.
    ///
    /// Verification is whatever the wrapped connector was configured with.
    /// `NativeTlsProvider::tls_connector` builds it to accept invalid certificates
    /// and hostnames, so any check of the peer's identity is left to the caller.
    /// Handshake failures are converted with `io::Error::other`.
    #[instrument(skip_all, level = "trace")]
    async fn negotiate_unvalidated(&self, stream: S, sni_hostname: &str) -> IoResult<Self::Conn> {
        self.connector
            .connect(sni_hostname, stream)
            .await
            .map_err(IoError::other)
    }
}
impl<S> TlsProvider<S> for NativeTlsProvider
where
    S: AsyncRead + AsyncWrite + StreamOps + Unpin + Send + 'static,
{
    type Connector = NativeTlsConnector<S>;
    type TlsStream = async_native_tls::TlsStream<S>;
    // Server-side TLS is not provided by this backend.
    type Acceptor = UnimplementedTls;
    type TlsServerStream = UnimplementedTls;

    /// Build a client connector that performs no Web-PKI checks.
    fn tls_connector(&self) -> Self::Connector {
        let mut builder = native_tls::TlsConnector::builder();
        builder
            // Despite the scary names, these only mean the handshake does not
            // require the server certificate to chain to a trusted CA, nor to
            // match the requested hostname. Peer authentication is handled by
            // the caller, not by the Web PKI.
            .danger_accept_invalid_certs(true)
            .danger_accept_invalid_hostnames(true)
            // With CA validation disabled, the system root store would be unused,
            // so skip loading it.
            .disable_built_in_roots(true);

        NativeTlsConnector {
            connector: builder.into(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Always fails: this provider cannot act as a TLS server.
    fn tls_acceptor(&self, _settings: TlsAcceptorSettings) -> IoResult<Self::Acceptor> {
        let unsupported = crate::tls::TlsServerUnsupported {};
        Err(unsupported.into())
    }

    /// Always `false`: `export_keying_material` is unsupported on native-tls streams.
    fn supports_keying_material_export(&self) -> bool {
        false
    }
}