use crate::TlsError;
use pipebuf::{tripwire, PBufRdWr, PBufState, PBufWr};
use rustls::client::UnbufferedClientConnection;
use rustls::pki_types::ServerName;
use rustls::server::UnbufferedServerConnection;
use rustls::unbuffered::{ConnectionState, EncodeError, EncryptError};
use rustls::{ClientConfig, ServerConfig};
use std::sync::Arc;
/// Compile-time switch for extra close handling in `process!`.
///
/// The inbound direction ends in one of two ways: EOF on the TLS input
/// (`ext.rd`) or `ConnectionState::Closed`. When this is `true` and that
/// happens, `int.rd` is also checked. If its EOF can be consumed, any data
/// left in it is discarded, and `ext.wr` is closed. If `int.rd` was aborted,
/// `ext.wr` is aborted instead.
const FIXUP_CLOSE: bool = true;
/// Handles `ConnectionState::ReadEarlyData` on behalf of `process!`.
///
/// Server form (`true`): appends every early-data record's payload to
/// `$int.wr` and adds each record's `discard` count to `$discard`.
/// Client form (`false`): early data is not expected here, so the enclosing
/// function returns a `TlsError`.
macro_rules! read_early_data {
    (true, $red:ident, $discard:ident, $int:ident) => {{
        loop {
            match $red.next_record() {
                None => break,
                Some(Err(e)) => {
                    return Err(TlsError(format!("Failed fetching TLS incoming data: {e}")));
                }
                Some(Ok(record)) => {
                    $discard += record.discard;
                    $int.wr.append(record.payload);
                }
            }
        }
    }};
    (false, $red:ident, $discard:ident, $int:ident) => {{
        return Err(TlsError("Not expecting early data on client".into()));
    }};
}
macro_rules! process {
($ext:ident, $int:ident, $conn:ident, $is_server:tt) => {{
if $int.rd.is_aborted() || $ext.rd.is_aborted() {
$int.rd.consume($int.rd.data().len());
$int.rd.consume_eof();
$ext.rd.consume($ext.rd.data().len());
$ext.rd.consume_eof();
if !$ext.wr.is_eof() {
$ext.wr.abort();
}
if !$int.wr.is_eof() {
$int.wr.abort();
}
} else {
let mut discard = 0;
loop {
$ext.rd.consume(discard);
discard = 0;
if $ext.rd.data().len() == 0 && $ext.rd.consume_eof() {
if !$int.wr.is_eof() {
$int.wr.close();
}
if FIXUP_CLOSE && $int.rd.consume_eof() {
$int.rd.consume($int.rd.data().len());
if $int.rd.is_aborted() {
$ext.wr.abort();
} else {
$ext.wr.close();
}
}
break;
}
let status = $conn.process_tls_records($ext.rd.data_mut());
discard += status.discard;
let state = status.state.map_err(|e| {
TlsError(format!(
"Failed whilst processing incoming TLS records: {e}"
))
})?;
match state {
ConnectionState::ReadTraffic(mut rt) => {
while let Some(rec) = rt.next_record() {
let rec = rec.map_err(|e| {
TlsError(format!("Failed fetching TLS incoming data: {e}"))
})?;
discard += rec.discard;
$int.wr.append(rec.payload);
}
}
ConnectionState::ReadEarlyData(mut _red) => {
read_early_data!($is_server, _red, discard, $int);
}
ConnectionState::Closed => {
if !$int.wr.is_eof() {
$int.wr.close();
}
if FIXUP_CLOSE && $int.rd.consume_eof() {
$int.rd.consume($int.rd.data().len());
if $int.rd.is_aborted() {
$ext.wr.abort();
} else {
$ext.wr.close();
}
}
break;
}
ConnectionState::EncodeTlsData(mut etd) => {
let len = etd.encode($ext.wr.space(18 * 1024)).map_err(|e| {
TlsError(format!("Failed to write TLS handshake record: {e}"))
})?;
if !$ext.wr.is_eof() {
$ext.wr.commit(len);
}
}
ConnectionState::TransmitTlsData(ttd) => {
$ext.wr.push();
ttd.done();
}
ConnectionState::BlockedHandshake => break,
ConnectionState::WriteTraffic(mut wt) => {
let wr_open = !$ext.wr.is_eof();
let data = $int.rd.data();
let len = data.len();
let closing = $int.rd.state() == PBufState::Closing;
if len == 0 && !closing {
break;
}
if len > 0 && wr_open {
let space = $ext.wr.space(len + (len >> 3).max(100));
let written = wt.encrypt(data, space).map_err(|e| {
TlsError(format!("Error encrypting outgoing data: {e}"))
})?;
$ext.wr.commit(written);
$int.rd.consume(len);
}
if closing {
$int.rd.consume_eof();
let space = $ext.wr.space(1024);
let written = wt.queue_close_notify(space).map_err(|e| {
TlsError(format!("Error encrypting outgoing close_notify: {e}"))
})?;
if wr_open {
$ext.wr.commit(written);
$ext.wr.close();
}
}
}
_ => return Err(TlsError(format!("Unexpected TLS state: {state:?}"))),
}
}
$ext.rd.consume(discard);
}
}};
}
/// Server-side TLS stage between two pipe pairs.
///
/// With a rustls config, `ext` carries TLS records and `int` carries
/// plaintext. Without one (`new(None)`), `process` forwards data unchanged in
/// both directions.
pub struct TlsServer {
    sc: Option<UnbufferedServerConnection>,
}

impl TlsServer {
    /// Creates the stage. `None` gives a pass-through stage with no TLS.
    ///
    /// # Errors
    /// Propagates the `rustls::Error` from `UnbufferedServerConnection::new`.
    pub fn new(config: Option<Arc<ServerConfig>>) -> Result<Self, rustls::Error> {
        let sc = config.map(UnbufferedServerConnection::new).transpose()?;
        Ok(Self { sc })
    }

    /// Returns the underlying rustls connection, or `None` in pass-through mode.
    pub fn connection(&self) -> Option<&UnbufferedServerConnection> {
        self.sc.as_ref()
    }

    /// Moves as much data as possible between `ext` and `int`.
    ///
    /// With TLS enabled this runs the `process!` state machine. Without TLS it
    /// forwards `int.rd` to `ext.wr` and `ext.rd` to `int.wr`. Returns
    /// `Ok(true)` when the `tripwire!` snapshot of the four pipe ends differs
    /// between the start and the end of the call.
    ///
    /// # Errors
    /// Returns `TlsError` when rustls reports a failure.
    pub fn process(&mut self, mut ext: PBufRdWr, mut int: PBufRdWr) -> Result<bool, TlsError> {
        let trip_start = tripwire!(ext.rd, ext.wr, int.rd, int.wr);
        match self.sc.as_mut() {
            Some(sc) => process!(ext, int, sc, true),
            None => {
                int.rd.forward(ext.wr.reborrow());
                ext.rd.forward(int.wr.reborrow());
            }
        }
        let trip_end = tripwire!(ext.rd, ext.wr, int.rd, int.wr);
        Ok(trip_end != trip_start)
    }
}
/// Client-side TLS stage between two pipe pairs.
///
/// With a rustls config and server name, `ext` carries TLS records and `int`
/// carries plaintext. Without one (`new(None)`), `process` forwards data
/// unchanged in both directions.
pub struct TlsClient {
    cc: Option<UnbufferedClientConnection>,
}

impl TlsClient {
    /// Creates the stage. `None` gives a pass-through stage with no TLS.
    ///
    /// # Errors
    /// Propagates the `rustls::Error` from `UnbufferedClientConnection::new`.
    pub fn new(
        config: Option<(Arc<ClientConfig>, ServerName<'static>)>,
    ) -> Result<Self, rustls::Error> {
        let cc = config
            .map(|(conf, name)| UnbufferedClientConnection::new(conf, name))
            .transpose()?;
        Ok(Self { cc })
    }

    /// Returns the underlying rustls connection, or `None` in pass-through mode.
    pub fn connection(&self) -> Option<&UnbufferedClientConnection> {
        self.cc.as_ref()
    }

    /// Moves as much data as possible between `ext` and `int`.
    ///
    /// With TLS enabled this runs the `process!` state machine, in which early
    /// data is treated as an error on the client. Without TLS it forwards
    /// `int.rd` to `ext.wr` and `ext.rd` to `int.wr`. Returns `Ok(true)` when
    /// the `tripwire!` snapshot of the four pipe ends differs between the start
    /// and the end of the call.
    ///
    /// # Errors
    /// Returns `TlsError` when rustls reports a failure.
    pub fn process(&mut self, mut ext: PBufRdWr, mut int: PBufRdWr) -> Result<bool, TlsError> {
        let trip_start = tripwire!(ext.rd, ext.wr, int.rd, int.wr);
        match self.cc.as_mut() {
            Some(cc) => process!(ext, int, cc, false),
            None => {
                int.rd.forward(ext.wr.reborrow());
                ext.rd.forward(int.wr.reborrow());
            }
        }
        let trip_end = tripwire!(ext.rd, ext.wr, int.rd, int.wr);
        Ok(trip_end != trip_start)
    }
}