use std::{io::Error, net::SocketAddr};
#[cfg(feature = "ssl")]
use std::sync::Arc;
use anyhow::Result;
use bytes::{Bytes, BytesMut};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
net::{TcpListener as TokioTcpListener, TcpStream},
sync::mpsc::{Sender, UnboundedReceiver, channel, unbounded_channel},
};
#[cfg(feature = "ssl")]
use tokio_rustls::{
TlsAcceptor,
rustls::{
ServerConfig,
pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
},
server::TlsStream,
};
use crate::{
codec::Decoder,
server::transport::{MAX_MESSAGE_SIZE, Server, ServerOptions, Socket},
};
/// A freshly accepted connection: either plain TCP or wrapped in TLS.
enum MaybeSslStream {
    /// Plain, unencrypted TCP stream.
    Base(TcpStream),
    /// TLS-wrapped TCP stream (boxed).
    #[cfg(feature = "ssl")]
    Ssl(Box<TlsStream<TcpStream>>),
}

impl MaybeSslStream {
    /// Consumes the stream and splits it into separately owned read and
    /// write halves, preserving whether it is plain TCP or TLS.
    fn split(self) -> (Reader, Writer) {
        match self {
            #[cfg(feature = "ssl")]
            Self::Ssl(stream) => {
                let (read_half, write_half) = tokio::io::split(stream);
                (Reader::Ssl(read_half), Writer::Ssl(write_half))
            }
            Self::Base(stream) => {
                let (read_half, write_half) = tokio::io::split(stream);
                (Reader::Base(read_half), Writer::Base(write_half))
            }
        }
    }
}
enum Reader {
#[cfg(feature = "ssl")]
Ssl(ReadHalf<Box<TlsStream<TcpStream>>>),
Base(ReadHalf<TcpStream>),
}
impl Reader {
async fn read_buf(&mut self, buffer: &mut BytesMut) -> Result<usize, Error> {
match self {
Self::Base(it) => it.read_buf(buffer).await,
#[cfg(feature = "ssl")]
Self::Ssl(it) => it.read_buf(buffer).await,
}
}
}
enum Writer {
#[cfg(feature = "ssl")]
Ssl(WriteHalf<Box<TlsStream<TcpStream>>>),
Base(WriteHalf<TcpStream>),
}
impl Writer {
async fn write_all(&mut self, buffer: &[u8]) -> Result<(), Error> {
match self {
Self::Base(it) => it.write_all(buffer).await,
#[cfg(feature = "ssl")]
Self::Ssl(it) => it.write_all(buffer).await,
}
}
}
/// A connected client socket over plain TCP or TLS.
///
/// A background task spawned by [`TcpSocket::new`] reads from the stream,
/// cuts the bytes into messages using `Decoder::message_size`, and forwards
/// each complete message over an unbounded channel. Writes go straight to the
/// stream's write half.
pub struct TcpSocket {
    /// Write half of the underlying stream.
    writer: Writer,
    /// Complete messages produced by the background read task.
    receiver: UnboundedReceiver<Bytes>,
    /// Tells the background read task to stop.
    close_signal_sender: Sender<()>,
}

impl TcpSocket {
    /// Wraps an accepted stream and spawns the task that reads and frames
    /// incoming messages.
    ///
    /// `addr` is the peer address and is only used in log messages.
    ///
    /// The read task stops when any of the following happens:
    /// - the peer closes the connection;
    /// - a read error occurs;
    /// - a message larger than `MAX_MESSAGE_SIZE` is announced;
    /// - more than `MAX_MESSAGE_SIZE * 3` bytes are buffered;
    /// - the message receiver is closed or dropped;
    /// - a close signal arrives.
    ///
    /// When the task stops, it drops its sender. `read()` then yields `None`
    /// once the already-forwarded messages have been consumed.
    fn new(stream: MaybeSslStream, addr: SocketAddr) -> Self {
        let (close_signal_sender, mut close_signal_receiver) = channel::<()>(1);
        let (tx, receiver) = unbounded_channel::<Bytes>();
        let (mut reader, writer) = stream.split();
        tokio::spawn(async move {
            let mut buffer = BytesMut::new();
            'a: loop {
                tokio::select! {
                    result = reader.read_buf(&mut buffer) => {
                        // A read error must end the task. Otherwise `tx` stays
                        // alive and `read()` waits forever on a dead connection.
                        let size = match result {
                            Ok(size) => size,
                            Err(e) => {
                                log::debug!("tcp socket read failed: addr={addr:?}, err={e}");
                                break 'a;
                            }
                        };
                        // `Ok(0)`: the peer closed its side of the connection.
                        if size == 0 {
                            break 'a;
                        }
                        if buffer.len() < 4 {
                            continue 'a;
                        }
                        // Cap the memory held for a peer that never completes a message.
                        if buffer.len() > MAX_MESSAGE_SIZE * 3 {
                            break 'a;
                        }
                        // Forward every complete message currently buffered.
                        loop {
                            if buffer.len() <= 4 {
                                break;
                            }
                            let size = match Decoder::message_size(&buffer, true) {
                                // The size cannot be determined from the buffered
                                // bytes, so wait for more input. A zero size is
                                // handled the same way: `split_to(0)` would consume
                                // nothing and this loop would spin forever.
                                Err(_) | Ok(0) => break,
                                Ok(size) if size > MAX_MESSAGE_SIZE => {
                                    log::warn!(
                                        "tcp message size too large: \
                                        size={size}, \
                                        max={MAX_MESSAGE_SIZE}, \
                                        addr={addr:?}"
                                    );
                                    break 'a;
                                }
                                // The message has not been fully received yet.
                                Ok(size) if size > buffer.len() => break,
                                Ok(size) => size,
                            };
                            if tx.send(buffer.split_to(size).freeze()).is_err() {
                                break 'a;
                            }
                        }
                    }
                    // Fires on an explicit close signal, or with `None` once the
                    // owning `TcpSocket` (and its sender) has been dropped.
                    _ = close_signal_receiver.recv() => {
                        break 'a;
                    }
                }
            }
        });
        Self {
            close_signal_sender,
            writer,
            receiver,
        }
    }
}
impl Socket for TcpSocket {
    /// Waits for the next complete message from the background read task.
    ///
    /// Returns `None` when no further messages can arrive, either because the
    /// read task has stopped or because the socket was closed, and every
    /// message already queued has been taken.
    async fn read(&mut self) -> Option<Bytes> {
        let receiver = &mut self.receiver;
        receiver.recv().await
    }

    /// Writes the entire `buffer` to the peer.
    async fn write(&mut self, buffer: &[u8]) -> Result<()> {
        self.writer.write_all(buffer).await?;
        Ok(())
    }

    /// Stops accepting new messages and asks the background read task to exit.
    async fn close(&mut self) {
        self.receiver.close();
        let signal = self.close_signal_sender.send(());
        // The read task may already have exited; a failed send is harmless.
        let _ = signal.await;
    }
}
pub struct TcpServer {
socket_receiver: UnboundedReceiver<(TcpSocket, SocketAddr)>,
local_addr: SocketAddr,
}
impl Server for TcpServer {
type Socket = TcpSocket;
async fn bind(options: &ServerOptions) -> Result<Self> {
#[cfg(feature = "ssl")]
let acceptor = if let Some(ssl) = &options.ssl {
Some(TlsAcceptor::from(Arc::new(
ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(
CertificateDer::pem_file_iter(ssl.certificate_chain.clone())?
.collect::<Result<Vec<_>, _>>()?,
PrivateKeyDer::from_pem_file(ssl.private_key.clone())?,
)?,
)))
} else {
None
};
let listener = TokioTcpListener::bind(options.listen).await?;
let local_addr = listener.local_addr()?;
let (tx, socket_receiver) = unbounded_channel::<(TcpSocket, SocketAddr)>();
tokio::spawn(async move {
while let Ok((socket, addr)) = listener.accept().await {
if let Err(e) = socket.set_nodelay(true) {
log::warn!("tls socket set nodelay failed!: addr={addr}, err={e}");
}
#[cfg(feature = "ssl")]
if let Some(acceptor) = acceptor.clone() {
let tx = tx.clone();
tokio::spawn(async move {
if let Ok(socket) = acceptor.accept(socket).await {
let _ = tx.send((
TcpSocket::new(MaybeSslStream::Ssl(socket.into()), addr),
addr,
));
};
});
continue;
}
if tx
.send((TcpSocket::new(MaybeSslStream::Base(socket), addr), addr))
.is_err()
{
break;
}
}
});
Ok(Self {
socket_receiver,
local_addr,
})
}
async fn accept(&mut self) -> Option<(Self::Socket, SocketAddr)> {
self.socket_receiver.recv().await
}
fn local_addr(&self) -> Result<SocketAddr> {
Ok(self.local_addr)
}
}