use std::future::Future;
use std::sync::Arc;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{broadcast, mpsc, Semaphore};
use crate::connection::Connection;
use crate::shutdown::Shutdown;
/// Maximum number of client connections proxied at the same time.
///
/// Enforced through the `limit_connections` semaphore: the accept loop takes a
/// permit *before* calling `accept`, and each connection task holds that permit
/// until it finishes. With a value of 1, no further client is accepted until the
/// current connection task completes.
// NOTE(review): a limit of 1 serialises all clients — confirm this is intentional
// (e.g. the proxied server only supports one session) rather than a leftover.
const MAX_CONNECTIONS: usize = 1;
/// State of the proxy's accept loop: the bound listening socket, the upstream
/// address, and the semaphore/channels used for connection limiting and for
/// graceful shutdown.
#[derive(Debug)]
struct Listener {
    /// Socket on which inbound client connections are accepted.
    listener: TcpListener,
    /// Address of the upstream server; each accepted client gets its own
    /// `TcpStream::connect` to this address.
    proxied_server: String,
    /// Broadcast sender that every handler subscribes to (via `subscribe`).
    /// The top-level `run` drops it to start shutting the handlers down.
    notify_shutdown: broadcast::Sender<()>,
    /// Bounds concurrent connections; initialised with `MAX_CONNECTIONS` permits.
    limit_connections: Arc<Semaphore>,
    /// Awaited during shutdown; `recv` yields `None` once every clone of
    /// `shutdown_complete_tx` (one per handler, plus this struct's own) is dropped.
    shutdown_complete_rx: mpsc::Receiver<()>,
    /// Cloned into each handler. No value is sent on it in the code shown here;
    /// it acts as a liveness token whose drop signals that a handler finished.
    shutdown_complete_tx: mpsc::Sender<()>,
}
/// Per-connection task state: the client/upstream connection plus the
/// plumbing needed to take part in graceful shutdown.
#[derive(Debug)]
struct Handler {
    /// Built from the accepted client socket and the upstream socket
    /// (see `Connection::new`); driven through its forward/backward pipes.
    connection: Connection,
    /// Receives the server-wide shutdown notification broadcast by the listener.
    shutdown: Shutdown,
    /// Never read. It is held only so that dropping the handler drops this
    /// sender, which lets the top-level `run` observe that the handler is done.
    _shutdown_complete: tokio::sync::mpsc::Sender<()>,
}
impl Handler {
    /// Relays traffic for one client until either pipe finishes or a shutdown
    /// signal arrives.
    ///
    /// Both pipes of the connection and the shutdown listener are raced with
    /// `tokio::select!`. The first branch to complete decides the outcome, and
    /// the other futures are dropped.
    ///
    /// - Forward pipe finishes first: treated as a normal close, returns `Ok(())`.
    ///   Presumably this means the client hung up; confirm against `Connection`.
    /// - Shutdown signal received: returns `Ok(())`.
    /// - Already shut down when called: returns `Ok(())` without starting the pipes.
    ///
    /// # Errors
    ///
    /// Returns an `UnexpectedEof` I/O error ("remote server prematurely closed
    /// connection") when the backward pipe finishes first.
    async fn run(&mut self) -> crate::Result<()> {
        // Nothing to relay if shutdown already began before this handler ran.
        if self.shutdown.is_shutdown() {
            return Ok(());
        }

        log::trace!("starting forward/backward pipes");
        // NOTE(review): the pipes' own return values are discarded here. If
        // `run()` reports I/O errors, they are currently never logged.
        tokio::select! {
            _ = self.connection.forward_pipe.run() => {
                log::trace!("pipe closed via forward pipe");
            },
            _ = self.connection.backward_pipe.run() => {
                log::trace!("pipe closed via backward pipe");
                let err = std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "remote server prematurely closed connection"
                );
                return Err(err);
            },
            _ = self.shutdown.recv() => {
                // Stop explicitly instead of looping back and relying on `recv`
                // having flipped the `is_shutdown` flag. Restarting the pipes
                // after a shutdown signal would be wrong either way.
                log::trace!("pipe closed via shutdown signal");
            },
        }
        Ok(())
    }
}
impl Listener {
pub async fn run(&mut self, config: &config::Config) -> crate::Result<()> {
log::info!("listener is running, awaiting connections");
loop {
log::trace!("awaiting permit to accept new connection");
let permit = self
.limit_connections
.clone()
.acquire_owned()
.await
.unwrap();
log::trace!(
"permit acquired (remaining: {})",
self.limit_connections.available_permits()
);
log::debug!("awaiting new connection or shutdown signal");
let (client_socket, client_addr) = self.accept().await?;
log::info!("new connection from: {}", client_addr);
let server_socket = TcpStream::connect(&self.proxied_server)
.await
.expect("error - failed connecting to proxied server");
let mut handler = Handler {
connection: Connection::new(client_socket, server_socket, config).await,
shutdown: Shutdown::new(self.notify_shutdown.subscribe()),
_shutdown_complete: self.shutdown_complete_tx.clone(),
};
tokio::spawn(async move {
log::trace!("spawned task to manage {}", client_addr);
if let Err(err) = handler.run().await {
log::error!("connection error: {}", err);
}
drop(permit);
log::info!("closing connection from: {}", client_addr);
});
}
}
async fn accept(&mut self) -> crate::Result<(TcpStream, std::net::SocketAddr)> {
let mut backoff = 1;
loop {
match self.listener.accept().await {
Ok((socket, peer_addr)) => return Ok((socket, peer_addr)),
Err(err) => {
if backoff > 64 {
log::trace!("accept failed too many times, backoff strategy exhausted");
return Err(err);
}
}
}
tokio::time::sleep(tokio::time::Duration::from_secs(backoff)).await;
backoff *= 2;
}
}
}
/// Runs the proxy until `shutdown` completes or the accept loop fails, then
/// waits for every spawned connection handler to finish.
///
/// * `listener` – bound socket on which clients are accepted.
/// * `srv_addr` – upstream address; every client gets its own connection to it.
/// * `shutdown` – future whose completion starts a graceful shutdown.
/// * `config` – handed to `Connection::new` for each client.
///
/// Shutdown sequencing:
/// 1. Dropping the broadcast sender closes the channel that every handler's
///    `Shutdown` subscribed to; presumably that wakes `Shutdown::recv`, which
///    should be confirmed in `crate::shutdown`.
/// 2. Dropping this function's completion sender leaves only the clones held by
///    handlers, so `recv` resolves with `None` once all handlers are gone.
pub async fn run(
    listener: TcpListener,
    srv_addr: &str,
    shutdown: impl Future,
    config: &config::Config,
) {
    // Keep only the sending half; each handler subscribes its own receiver.
    let (notify_shutdown, _) = broadcast::channel(1);
    let (shutdown_complete_tx, shutdown_complete_rx) = tokio::sync::mpsc::channel(1);

    let mut server = Listener {
        listener,
        proxied_server: srv_addr.to_string(),
        notify_shutdown,
        limit_connections: Arc::new(Semaphore::new(MAX_CONNECTIONS)),
        shutdown_complete_tx,
        shutdown_complete_rx,
    };

    tokio::select! {
        outcome = server.run(config) => {
            match outcome {
                Ok(()) => {}
                Err(err) => log::error!("failed to accept: {}", err),
            }
        },
        _ = shutdown => {
            log::info!("shutdown signal received; shutting down listener");
        }
    }

    // Move the pieces needed for teardown out of `server`. The remaining fields
    // (including the listening socket) stay put and drop when this fn returns.
    let mut shutdown_complete_rx = server.shutdown_complete_rx;
    drop(server.notify_shutdown);
    drop(server.shutdown_complete_tx);

    // Resolves once every handler's completion-sender clone has been dropped.
    let _ = shutdown_complete_rx.recv().await;
}