use std::{any, cell::RefCell, io, sync::Arc};
use ntex_io::{Filter, FilterBuf, FilterLayer, Io, Layer};
use ntex_util::{time, time::Millis};
use tls_rustls::{ServerConfig, ServerConnection};
use crate::{Servername, rustls::Stream};
/// Server-side TLS filter layer that owns a `ServerConnection` session.
///
/// The type implements `FilterLayer`; its read/write hooks hand the buffers to
/// the crate's `Stream` helper together with this session.
#[derive(Debug)]
pub struct TlsServerFilter {
/// TLS session state. Wrapped in a `RefCell` because the `FilterLayer`
/// methods receive `&self` yet borrow the session mutably.
session: RefCell<ServerConnection>,
}
impl FilterLayer for TlsServerFilter {
    /// Look up a value of type `id` associated with this TLS layer.
    ///
    /// The lookup is first delegated to `Stream::query`. If that yields
    /// nothing and `id` identifies [`Servername`], the server name reported by
    /// the session (if any) is returned as an owned, boxed `Servername`.
    /// Every other case yields `None`.
    fn query(&self, id: any::TypeId) -> Option<Box<dyn any::Any>> {
        // Hold the mutable borrow for the whole lookup; both steps need it.
        let mut guard = self.session.borrow_mut();
        let conn = &mut *guard;

        if let Some(found) = Stream::new(conn).query(id) {
            return Some(found);
        }
        if id != any::TypeId::of::<Servername>() {
            return None;
        }

        conn.server_name()
            .map(|name| Box::new(Servername(name.to_string())) as Box<dyn any::Any>)
    }

    /// Forward read-side buffer processing to `Stream`, lending it the
    /// session for the duration of the call.
    fn process_read_buf(&self, buf: &mut FilterBuf<'_>) -> io::Result<()> {
        Stream::new(&mut *self.session.borrow_mut()).process_read_buf(buf)
    }

    /// Forward write-side buffer processing to `Stream`, lending it the
    /// session for the duration of the call.
    fn process_write_buf(&self, buf: &mut FilterBuf<'_>) -> io::Result<()> {
        Stream::new(&mut *self.session.borrow_mut()).process_write_buf(buf)
    }
}
impl TlsServerFilter {
pub async fn create<F: Filter>(
io: Io<F>,
cfg: Arc<ServerConfig>,
timeout: Millis,
) -> Result<Io<Layer<TlsServerFilter, F>>, io::Error> {
log::trace!("{}: Initiate server connection", io.tag());
time::timeout(timeout, async {
let mut session = ServerConnection::new(cfg).map_err(io::Error::other)?;
session.set_buffer_limit(Some(io.cfg().write_page_size().capacity()));
let io = io.add_filter(TlsServerFilter {
session: RefCell::new(session),
});
super::stream::handshake(&io.filter().session, &io).await?;
log::trace!("{}: TLS Handshake successed", io.tag());
Ok(io)
})
.await
.map_err(|()| io::Error::new(io::ErrorKind::TimedOut, "rustls handshake timeout"))
.and_then(|item| item)
}
}