use crate::TlsError;
use pipebuf::{tripwire, PBufRdWr};
use rustls::{pki_types::ServerName, ClientConfig, ClientConnection};
use std::io::ErrorKind;
use std::sync::Arc;
pub struct TlsClient {
cc: Option<ClientConnection>,
}
impl TlsClient {
pub fn new(
config: Option<(Arc<ClientConfig>, ServerName<'static>)>,
) -> Result<Self, rustls::Error> {
let cc = if let Some((conf, name)) = config {
Some(ClientConnection::new(conf, name)?)
} else {
None
};
Ok(Self { cc })
}
pub fn connection(&self) -> Option<&ClientConnection> {
self.cc.as_ref()
}
pub fn process(&mut self, mut ext: PBufRdWr, mut int: PBufRdWr) -> Result<bool, TlsError> {
let before = tripwire!(ext.rd, ext.wr, int.rd, int.wr);
if let Some(ref mut cc) = self.cc {
loop {
if cc.wants_write() && !ext.wr.is_eof() {
cc.write_tls(&mut ext.wr).map_err(|e| {
TlsError(format!(
"Unexpected error from ClientConnection::write_tls: {e}"
))
})?;
if int.rd.is_done() && !cc.wants_write() {
ext.wr.close();
}
continue;
}
if !cc.is_handshaking() {
if !int.rd.is_empty() {
int.rd.output_to(&mut cc.writer(), false).map_err(|e| {
TlsError(format!(
"Unexpected error from ClientConnection::writer.write: {e}"
))
})?;
continue;
}
if int.rd.consume_eof() {
if int.rd.is_aborted() {
ext.wr.abort();
} else {
cc.send_close_notify();
}
continue;
}
}
if cc.wants_read() && !ext.rd.is_empty() {
cc.read_tls(&mut ext.rd).map_err(|e| {
TlsError(format!(
"Unexpected failure from ClientConnection::read_tls: {e}"
))
})?;
let state = cc
.process_new_packets()
.map_err(|e| TlsError(format!("TLS stream error: {e}")))?;
if !int.wr.is_eof() {
let read_len = state.plaintext_bytes_to_read();
if read_len > 0 {
if let Err(e) = int.wr.input_from(&mut cc.reader(), read_len) {
match e.kind() {
ErrorKind::WouldBlock => (),
ErrorKind::UnexpectedEof => int.wr.abort(),
_ => return Err(TlsError(format!("TLS read error: {e}"))),
}
}
}
}
continue;
}
if ext.rd.has_pending_eof()
&& (ext.rd.is_aborted() || ext.rd.is_empty() || int.rd.is_done())
{
ext.rd.consume_eof();
if !int.wr.is_eof() {
if ext.rd.is_aborted() {
int.wr.abort();
} else {
int.wr.close();
}
}
continue;
}
break;
}
} else {
int.rd.forward(ext.wr.reborrow());
ext.rd.forward(int.wr.reborrow());
}
let after = tripwire!(ext.rd, ext.wr, int.rd, int.wr);
Ok(after != before)
}
}