use crate::error::*;
use crate::sansio::Server;
use async_trait::async_trait;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, PoisonError};
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
#[cfg(unix)]
use tokio::net::{UnixListener, UnixStream};
/// A bound listening socket: TCP, or (on unix targets only) a Unix-domain socket.
///
/// Construct with `AsyncListener::new` and obtain connections with `accept`.
// NOTE(review): variant names are upper-case acronyms (clippy `upper_case_acronyms`);
// they are public, so renaming them would be a breaking change.
#[derive(Debug)]
pub enum AsyncListener {
    /// Listener created from a `tcp:` address.
    TCP(TcpListener),
    /// Listener created from a `unix:` address (a filesystem path, or on
    /// Linux/Android an abstract `@name`).
    #[cfg(unix)]
    UNIX(UnixListener),
}
impl AsyncListener {
    /// Binds a listener for `address`, which must start with a transport prefix.
    ///
    /// * `tcp:<addr>` — binds a TCP listener on `<addr>`.
    /// * `unix:<addr>` — binds a Unix-domain listener (unix targets only); see
    ///   `create_unix_listener` for the accepted `<addr>` forms.
    ///
    /// # Errors
    /// `ErrorKind::InvalidAddress` for an unrecognised prefix or for `unix:` on a
    /// non-unix target; `ErrorKind::Io` (carrying the I/O error kind) if binding fails.
    pub async fn new<S: AsRef<str>>(address: S) -> Result<Self> {
        let address = address.as_ref();

        if let Some(tcp_addr) = address.strip_prefix("tcp:") {
            return TcpListener::bind(tcp_addr)
                .await
                .map(AsyncListener::TCP)
                .map_err(|e| context!(ErrorKind::Io(e.kind())));
        }

        match address.strip_prefix("unix:") {
            #[cfg(unix)]
            Some(unix_addr) => Self::create_unix_listener(unix_addr).await,
            // Unix-domain sockets are unavailable on this target.
            #[cfg(not(unix))]
            Some(_) => Err(context!(ErrorKind::InvalidAddress)),
            None => Err(context!(ErrorKind::InvalidAddress)),
        }
    }

    /// Binds a Unix-domain listener for the part of the address after `unix:`.
    ///
    /// Anything from the first `;` onward is ignored. A leading `@` selects the
    /// abstract namespace (Linux/Android only; other targets get
    /// `ErrorKind::InvalidAddress`). Otherwise the text is a filesystem path; any
    /// existing file at that path is removed (best effort) before binding.
    #[cfg(unix)]
    async fn create_unix_listener(addr: &str) -> Result<Self> {
        use std::fs;

        match addr.strip_prefix('@') {
            #[cfg(any(target_os = "linux", target_os = "android"))]
            Some(name) => {
                let name = name.split(';').next().unwrap_or(name);
                // A leading NUL byte selects the abstract socket namespace.
                UnixListener::bind(format!("\0{}", name))
                    .map(AsyncListener::UNIX)
                    .map_err(|e| context!(ErrorKind::Io(e.kind())))
            }
            #[cfg(not(any(target_os = "linux", target_os = "android")))]
            Some(_) => Err(context!(ErrorKind::InvalidAddress)),
            None => {
                let path = addr.split(';').next().unwrap_or(addr);
                // Best effort: a leftover socket file would make bind() fail.
                let _ = fs::remove_file(path);
                UnixListener::bind(path)
                    .map(AsyncListener::UNIX)
                    .map_err(|e| context!(ErrorKind::Io(e.kind())))
            }
        }
    }

    /// Waits for the next incoming connection and wraps it in the matching
    /// `AsyncStream` variant; the peer address is discarded.
    ///
    /// # Errors
    /// `ErrorKind::Io` (carrying the I/O error kind) if accepting fails.
    pub async fn accept(&self) -> Result<AsyncStream> {
        let accepted = match self {
            AsyncListener::TCP(tcp) => tcp.accept().await.map(|(conn, _)| AsyncStream::TCP(conn)),
            #[cfg(unix)]
            AsyncListener::UNIX(unix) => {
                unix.accept().await.map(|(conn, _)| AsyncStream::UNIX(conn))
            }
        };
        accepted.map_err(|e| context!(ErrorKind::Io(e.kind())))
    }
}
/// An accepted client connection: TCP, or (on unix targets only) a Unix-domain socket.
///
/// Produced by `AsyncListener::accept`. Derives `Debug` like `AsyncListener`, so
/// connections can be logged and embedded in `Debug` types.
#[derive(Debug)]
pub enum AsyncStream {
    /// Connection accepted from an `AsyncListener::TCP` listener.
    TCP(TcpStream),
    /// Connection accepted from an `AsyncListener::UNIX` listener.
    #[cfg(unix)]
    UNIX(UnixStream),
}
impl AsyncStream {
async fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
AsyncStream::TCP(stream) => stream.read(buf).await,
#[cfg(unix)]
AsyncStream::UNIX(stream) => stream.read(buf).await,
}
}
async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
match self {
AsyncStream::TCP(stream) => stream.write_all(buf).await,
#[cfg(unix)]
AsyncStream::UNIX(stream) => stream.write_all(buf).await,
}
}
async fn flush(&mut self) -> std::io::Result<()> {
match self {
AsyncStream::TCP(stream) => stream.flush().await,
#[cfg(unix)]
AsyncStream::UNIX(stream) => stream.flush().await,
}
}
}
/// Application callback that `listen_async` invokes for every accepted connection.
///
/// Implementors must be `Send + Sync`, because one `Arc<H>` is shared by all
/// connection tasks.
#[async_trait]
pub trait AsyncConnectionHandler: Send + Sync {
    /// Called each time `server` has consumed a chunk of bytes read from the client.
    ///
    /// * `server` — the per-connection protocol state. Any transmits it has queued
    ///   when this returns are written back to the client.
    /// * `upgraded_iface` — the value this method returned on the previous call for
    ///   the same connection (`None` on the first call).
    ///
    /// The returned `Option<String>` is passed back as `upgraded_iface` on the next
    /// call. Returning `Err` ends the connection.
    async fn handle(
        &self,
        server: &mut Server,
        upgraded_iface: Option<String>,
    ) -> Result<Option<String>>;
}
/// Options that control when [`listen_async`] stops accepting connections.
///
/// `Default` gives a zero `idle_timeout` and no stop flag. The impl is derived
/// because both fields' own defaults are exactly those values (clippy
/// `derivable_impls`).
#[derive(Debug, Clone, Default)]
pub struct ListenAsyncConfig {
    /// Idle timeout used by `listen_async`; `Duration::ZERO` (the default) disables it.
    pub idle_timeout: Duration,
    /// When set, `listen_async` periodically checks this flag and returns `Ok(())`
    /// once it reads `true`. Defaults to `None` (no external stop signal).
    pub stop_listening: Option<Arc<AtomicBool>>,
}
pub async fn listen_async<S: AsRef<str>, H: AsyncConnectionHandler + 'static>(
handler: Arc<H>,
address: S,
config: &ListenAsyncConfig,
) -> Result<()> {
let listener = AsyncListener::new(address).await?;
let mut active_connections = 0usize;
loop {
let stream = if config.idle_timeout.as_secs() > 0 || config.stop_listening.is_some() {
let timeout_duration = if config.stop_listening.is_some() {
Duration::from_millis(100)
} else {
config.idle_timeout
};
match tokio::time::timeout(timeout_duration, listener.accept()).await {
Ok(Ok(stream)) => stream,
Ok(Err(e)) => return Err(e),
Err(_) => {
if let Some(stop) = &config.stop_listening {
if stop.load(Ordering::SeqCst) {
return Ok(());
}
}
if config.idle_timeout.as_secs() > 0 && active_connections == 0 {
return Err(context!(ErrorKind::Timeout));
}
continue;
}
}
} else {
listener.accept().await?
};
let handler = Arc::clone(&handler);
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, handler).await {
match e.kind() {
ErrorKind::ConnectionClosed => {}
_ => eprintln!("Connection error: {:?}", e),
}
}
});
active_connections += 1;
}
}
/// Drives a single client connection until the peer closes it.
///
/// The loop per read is:
/// 1. Read a chunk of up to `READ_BUF_SIZE` bytes from `stream`.
/// 2. Feed it to this connection's own `Server`.
/// 3. Invoke `handler`.
/// 4. Write every transmit the server has queued back to the client.
///
/// The `Option<String>` returned by `handler` is passed back to it on the next
/// call (`upgraded_iface`, initially `None`).
///
/// Returns `Ok(())` when a read returns 0 bytes (end of stream).
///
/// # Errors
/// * Any read, write or flush failure on `stream` is reported as
///   `ErrorKind::ConnectionClosed`.
/// * Errors from `Server::handle_input` and from `handler.handle` are propagated
///   unchanged.
async fn handle_connection<H: AsyncConnectionHandler>(
    mut stream: AsyncStream,
    handler: Arc<H>,
) -> Result<()> {
    const READ_BUF_SIZE: usize = 8192;

    let mut server = Server::new();
    let mut buf = vec![0u8; READ_BUF_SIZE];
    let mut upgraded_iface: Option<String> = None;
    loop {
        let n = stream
            .read(&mut buf)
            .await
            .map_err(|_| context!(ErrorKind::ConnectionClosed))?;
        if n == 0 {
            return Ok(());
        }
        server.handle_input(&buf[..n])?;
        // The previous value is replaced by the handler's result (or the function
        // returns on error), so it can be moved out instead of cloned.
        upgraded_iface = handler.handle(&mut server, upgraded_iface.take()).await?;

        let mut wrote_any = false;
        while let Some(transmit) = server.poll_transmit() {
            stream
                .write_all(&transmit.payload)
                .await
                .map_err(|_| context!(ErrorKind::ConnectionClosed))?;
            wrote_any = true;
        }
        // One flush per batch of transmits instead of one per transmit.
        if wrote_any {
            stream
                .flush()
                .await
                .map_err(|_| context!(ErrorKind::ConnectionClosed))?;
        }
    }
}