use std::io::{self, Read, Write};
use std::ops::DerefMut;
use std::sync::Arc;
use tokio::net::TcpStream;
use rustls::{ClientConfig, ConnectionCommon};
use webpki::Error as BadCertError;
#[cfg(feature = "dangerous_tls")]
pub(super) mod dangerous;
mod client_auth;
pub use client_auth::ClientCertResolver;
use crate::options::OptionsT;
use crate::transport::{ReadyFor, Transport as TransportTrait, WantedFlags};
/// Borrowed certificate or private-key material, tagged with the encoding
/// the bytes are in.
pub enum CryptoBytes<'a> {
    /// Bytes in DER (binary) form.
    Der(&'a [u8]),
    /// Bytes in PEM (base64 text) form.
    Pem(&'a [u8]),
}
/// Errors produced while loading client-authentication certificate/key
/// material (e.g. returned from `enable_tls_client_auth`).
#[derive(Debug)]
pub enum CryptoError {
    /// The certificate was rejected (wraps a `webpki::Error`).
    BadCert(BadCertError),
    /// An underlying I/O error.
    IoError(io::Error),
    /// PEM input did not contain any usable item.
    NoValidItemsInPem,
    /// The supplied private key was not accepted.
    InvalidPrivateKey,
}
/// Connection options whose TLS configuration is a rustls [`ClientConfig`].
pub type OptionsWithTls<'a> = OptionsT<'a, ClientConfig>;

impl<'a> OptionsWithTls<'a> {
    /// Enables TLS client authentication with the given certificate and private key.
    ///
    /// Builds a [`ClientCertResolver`] from `cert_bytes` and `pk_bytes` and installs it
    /// as the client-auth certificate resolver, replacing whatever was set before.
    /// Returns `self` so calls can be chained.
    ///
    /// # Errors
    ///
    /// Returns a [`CryptoError`] if the resolver cannot be built from the given bytes.
    /// In that case the configuration is left untouched.
    pub fn enable_tls_client_auth(&mut self, cert_bytes: CryptoBytes, pk_bytes: CryptoBytes) -> Result<&mut Self, CryptoError> {
        // Build first so a failure never leaves a half-updated configuration.
        let resolver = ClientCertResolver::from_bytes(cert_bytes, pk_bytes)?;
        self.connection_cfg.client_auth_cert_resolver = Arc::new(resolver);
        Ok(self)
    }
}
struct TryRWProxy<'a>(&'a TcpStream);
impl<'a> Read for TryRWProxy<'a>
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize>
{
self.0.try_read(buf)
}
}
impl<'a> Write for TryRWProxy<'a>
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.try_write(buf)
}
fn flush(&mut self) -> io::Result<()> { Ok(()) }
}
/// A TCP socket paired with a TLS session object `T`.
///
/// `T` is presumably a rustls connection; the `TransportTrait` impl requires
/// `T: DerefMut<Target = ConnectionCommon<_>>`.
pub(super) struct Transport<T> {
    /// Socket carrying the encrypted byte stream.
    tcp: TcpStream,
    /// TLS session state that encrypts/decrypts over `tcp`.
    tls: T,
}

impl<T> Transport<T> {
    /// Pairs the socket `tcp` with the TLS session `tls`.
    pub(super) fn new(tcp: TcpStream, tls: T) -> Self {
        Transport { tls, tcp }
    }
}
impl<D, T: DerefMut<Target = ConnectionCommon<D>> + Send> TransportTrait for Transport<T> {
    /// Reports which directions the TLS session wants serviced.
    ///
    /// The readiness hints are ignored. The answer comes solely from the session's
    /// `wants_read` / `wants_write`.
    fn wants(&self, _rd_hint: bool, _wr_hint: bool) -> WantedFlags {
        WantedFlags { read: self.tls.wants_read(), write: self.tls.wants_write() }
    }

    /// Readiness wrapper for the underlying TCP socket.
    fn ready_for(&self) -> ReadyFor {
        ReadyFor::wrap(&self.tcp)
    }

    /// Moves ciphertext from the socket into the TLS session and processes it.
    ///
    /// Returns `Ok(true)` when decrypted plaintext is ready for `read`. Returns
    /// `Ok(false)` when the socket would block, or when no plaintext was produced.
    ///
    /// # Errors
    ///
    /// * TLS protocol errors from `process_new_packets`, wrapped as `ErrorKind::Other`.
    /// * Non-`WouldBlock` errors from `read_tls`, but only when no plaintext is
    ///   buffered. If plaintext is waiting, it is reported first.
    fn pre_read(&mut self) -> io::Result<bool> {
        let read_err = match self.tls.read_tls(&mut TryRWProxy(&self.tcp)) {
            Ok(_) => None,
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => return Ok(false),
            // Don't return yet: rustls also fails `read_tls` when its plaintext
            // buffer is full, and that buffered data must still be drained.
            Err(err) => Some(err),
        };
        let tls_state = self
            .tls
            .process_new_packets()
            .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
        let has_plaintext = tls_state.plaintext_bytes_to_read() > 0;
        match read_err {
            // A real socket error with nothing left to hand the caller must not be
            // silently dropped.
            Some(err) if !has_plaintext => Err(err),
            _ => Ok(has_plaintext),
        }
    }

    /// Copies decrypted plaintext from the TLS session into `dst`.
    fn read(&mut self, dst: &mut [u8]) -> io::Result<usize> {
        self.tls.reader().read(dst)
    }

    /// Flushes pending TLS records to the socket.
    ///
    /// `WouldBlock` is treated as success; the caller retries later. A zero-byte
    /// write is reported as `UnexpectedEof`.
    fn pre_write(&mut self) -> io::Result<()> {
        match self.tls.write_tls(&mut TryRWProxy(&self.tcp)) {
            // NOTE(review): rustls also returns Ok(0) when nothing is queued; this
            // assumes callers only get here when `wants().write` is set — confirm.
            Ok(0) => Err(io::ErrorKind::UnexpectedEof.into()),
            Ok(_) => Ok(()),
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Ok(()),
            Err(err) => Err(err),
        }
    }

    /// Queues plaintext `src` into the TLS session for encryption.
    ///
    /// If nothing was accepted, returns `Ok(0)` when `zero_if_would_block` is set.
    /// Otherwise it returns a `WouldBlock` error.
    fn write(&mut self, src: &[u8], zero_if_would_block: bool) -> io::Result<usize> {
        let written = self.tls.writer().write(src)?;
        if written == 0 && !zero_if_would_block {
            Err(io::ErrorKind::WouldBlock.into())
        } else {
            Ok(written)
        }
    }

    /// Flushes the TLS session's plaintext writer.
    fn flush(&mut self) -> io::Result<()> {
        self.tls.writer().flush()
    }

    /// Queues a TLS `close_notify` alert on the session.
    fn send_close_notify(&mut self) {
        self.tls.send_close_notify()
    }
}