use crate::TlsError;
use pipebuf::{tripwire, PBufRdWr};
use rustls::{ServerConfig, ServerConnection};
use std::io::ErrorKind;
use std::sync::Arc;
pub struct TlsServer {
sc: Option<ServerConnection>,
}
impl TlsServer {
pub fn new(config: Option<Arc<ServerConfig>>) -> Result<Self, rustls::Error> {
let sc = if let Some(conf) = config {
Some(ServerConnection::new(conf)?)
} else {
None
};
Ok(Self { sc })
}
pub fn connection(&self) -> Option<&ServerConnection> {
self.sc.as_ref()
}
pub fn process(&mut self, mut ext: PBufRdWr, mut int: PBufRdWr) -> Result<bool, TlsError> {
let before = tripwire!(ext.rd, ext.wr, int.rd, int.wr);
if let Some(ref mut sc) = self.sc {
loop {
if sc.wants_write() && !ext.wr.is_eof() {
sc.write_tls(&mut ext.wr).map_err(|e| {
TlsError(format!(
"Unexpected error from ServerConnection::write_tls: {e}"
))
})?;
if int.rd.is_done() && !sc.wants_write() {
ext.wr.close();
}
continue;
}
if !sc.is_handshaking() {
if !int.rd.is_empty() {
int.rd.output_to(&mut sc.writer(), false).map_err(|e| {
TlsError(format!(
"Unexpected error from ServerConnection::writer.write: {e}"
))
})?;
continue;
}
if int.rd.consume_eof() {
if int.rd.is_aborted() {
ext.wr.abort();
} else {
sc.send_close_notify();
}
continue;
}
}
if sc.wants_read() && !ext.rd.is_empty() {
sc.read_tls(&mut ext.rd).map_err(|e| {
TlsError(format!(
"Unexpected failure from ServerConnection::read_tls: {e}"
))
})?;
let state = sc
.process_new_packets()
.map_err(|e| TlsError(format!("TLS stream error: {e}")))?;
if !int.wr.is_eof() {
let read_len = state.plaintext_bytes_to_read();
if read_len > 0 {
if let Err(e) = int.wr.input_from(&mut sc.reader(), read_len) {
match e.kind() {
ErrorKind::WouldBlock => (),
ErrorKind::UnexpectedEof => int.wr.abort(),
_ => return Err(TlsError(format!("TLS read error: {e}"))),
}
}
}
}
continue;
}
if ext.rd.has_pending_eof()
&& (ext.rd.is_aborted() || ext.rd.is_empty() || int.rd.is_done())
{
ext.rd.consume_eof();
if !int.wr.is_eof() {
if ext.rd.is_aborted() {
int.wr.abort();
} else {
int.wr.close();
}
}
continue;
}
break;
}
} else {
int.rd.forward(ext.wr.reborrow());
ext.rd.forward(int.wr.reborrow());
}
let after = tripwire!(ext.rd, ext.wr, int.rd, int.wr);
Ok(after != before)
}
}