use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
use crate::config_watcher::ServiceChange;
use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
use crate::multi_map::MultiMap;
use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
use crate::protocol::{
self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, UdpTraffic,
HASH_WIDTH_IN_BYTES,
};
use crate::transport::{SocketOpts, TcpTransport, Transport};
use anyhow::{anyhow, bail, Context, Result};
use backoff::backoff::Backoff;
use backoff::ExponentialBackoff;
use rand::RngCore;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::time;
use tracing::{debug, error, info, info_span, instrument, warn, Instrument, Span};
#[cfg(feature = "noise")]
use crate::transport::NoiseTransport;
#[cfg(feature = "tls")]
use crate::transport::TlsTransport;
/// Digest identifying a service; computed from the service name with `protocol::digest`.
type ServiceDigest = protocol::Digest;
/// Second key of the control-channel map; a data channel presents it in its hello to
/// find the control channel it belongs to.
type Nonce = protocol::Digest;

/// Number of data-channel requests queued when a TCP service's control channel starts.
const TCP_POOL_SIZE: usize = 8;
/// Number of data-channel requests queued when a UDP service's control channel starts.
const UDP_POOL_SIZE: usize = 2;
/// Capacity of the visitor queue; the data-channel queue is twice this size.
const CHAN_SIZE: usize = 2048;
/// Seconds allowed for the transport handshake of an accepted connection.
const HANDSHAKE_TIMEOUT: u64 = 5;
/// Entry point for server mode.
///
/// Takes the `[server]` section of `config`, selects the transport named by
/// `server.transport.transport_type` and runs a [`Server`] over it until shutdown.
///
/// # Errors
/// Fails when the `[server]` block is missing, or when building or running the
/// server fails.
pub async fn run_server(
    config: &Config,
    shutdown_rx: broadcast::Receiver<bool>,
    service_rx: mpsc::Receiver<ServiceChange>,
) -> Result<()> {
    let server_config = config.server.as_ref().ok_or_else(|| {
        anyhow!("Try to run as a server, but the configuration is missing. Please add the `[server]` block")
    })?;

    match server_config.transport.transport_type {
        TransportType::Tcp => {
            Server::<TcpTransport>::from(server_config)
                .await?
                .run(shutdown_rx, service_rx)
                .await?;
        }
        TransportType::Tls => {
            #[cfg(feature = "tls")]
            {
                Server::<TlsTransport>::from(server_config)
                    .await?
                    .run(shutdown_rx, service_rx)
                    .await?;
            }
            #[cfg(not(feature = "tls"))]
            crate::helper::feature_not_compile("tls")
        }
        TransportType::Noise => {
            #[cfg(feature = "noise")]
            {
                Server::<NoiseTransport>::from(server_config)
                    .await?
                    .run(shutdown_rx, service_rx)
                    .await?;
            }
            #[cfg(not(feature = "noise"))]
            crate::helper::feature_not_compile("noise")
        }
    }

    Ok(())
}
/// Live control channels. The first key is the service digest and the second is the
/// value data channels present in their hello.
type ControlChannelMap<T> = MultiMap<ServiceDigest, Nonce, ControlChannelHandle<T>>;

/// State shared by the accept loop and the per-connection tasks.
struct Server<'a, T>
where
    T: Transport,
{
    /// The `[server]` section this server was built from.
    config: &'a ServerConfig,
    /// Configured services, keyed by the digest of the service name.
    services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
    /// Established control channels.
    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
    /// Transport used to accept and handshake incoming connections.
    transport: Arc<T>,
}
/// Builds the service lookup table, keyed by `protocol::digest` of each service name
/// and holding a clone of that service's configuration.
fn generate_service_hashmap(
    server_config: &ServerConfig,
) -> HashMap<ServiceDigest, ServerServiceConfig> {
    server_config
        .services
        .iter()
        .map(|(name, service)| (protocol::digest(name.as_bytes()), service.clone()))
        .collect()
}
impl<'a, T: 'static + Transport> Server<'a, T> {
    /// Creates a server for `config`.
    ///
    /// The service table is built from `config.services` and the transport from
    /// `config.transport`. No socket is bound until [`Server::run`] is called.
    ///
    /// # Errors
    /// Returns any error produced by `T::new` while constructing the transport.
    pub async fn from(config: &'a ServerConfig) -> Result<Server<'a, T>> {
        Ok(Server {
            config,
            services: Arc::new(RwLock::new(generate_service_hashmap(config))),
            control_channels: Arc::new(RwLock::new(ControlChannelMap::new())),
            transport: Arc::new(T::new(&config.transport)?),
        })
    }

    /// Binds `server.bind_addr` and serves until shutdown.
    ///
    /// Each accepted connection gets a transport handshake bounded by
    /// `HANDSHAKE_TIMEOUT` seconds and is then handled on its own task. Service changes
    /// arriving on `service_rx` are applied with `handle_hot_reload`.
    ///
    /// The loop ends in two cases:
    /// - `shutdown_rx` yields a value or an error;
    /// - the accept backoff gives up.
    ///
    /// # Errors
    /// Returns an error only if binding `server.bind_addr` fails.
    pub async fn run(
        &mut self,
        mut shutdown_rx: broadcast::Receiver<bool>,
        mut service_rx: mpsc::Receiver<ServiceChange>,
    ) -> Result<()> {
        let l = self
            .transport
            .bind(&self.config.bind_addr)
            .await
            .with_context(|| "Failed to listen at `server.bind_addr`")?;
        info!("Listening at {}", self.config.bind_addr);

        let mut backoff = ExponentialBackoff {
            max_interval: Duration::from_millis(100),
            max_elapsed_time: None,
            ..Default::default()
        };

        loop {
            tokio::select! {
                ret = self.transport.accept(&l) => {
                    match ret {
                        Err(err) => {
                            if let Some(io_err) = err.downcast_ref::<io::Error>() {
                                // IO errors (e.g. running out of file descriptors) may be
                                // transient, so back off before accepting again.
                                if let Some(d) = backoff.next_backoff() {
                                    error!("Failed to accept: {}. Retry in {:?}...", io_err, d);
                                    time::sleep(d).await;
                                } else {
                                    error!("Too many retries. Aborting...");
                                    break;
                                }
                            } else {
                                // Non-IO errors come from the transport layer and only affect
                                // this one connection. Record them and keep accepting.
                                debug!("Failed to accept: {:?}", err);
                            }
                        }
                        Ok((conn, addr)) => {
                            backoff.reset();
                            match time::timeout(Duration::from_secs(HANDSHAKE_TIMEOUT), self.transport.handshake(conn)).await {
                                Ok(conn) => {
                                    match conn.with_context(|| "Failed to do transport handshake") {
                                        Ok(conn) => {
                                            let services = self.services.clone();
                                            let control_channels = self.control_channels.clone();
                                            tokio::spawn(async move {
                                                if let Err(err) = handle_connection(conn, services, control_channels).await {
                                                    error!("{:?}", err);
                                                }
                                            }.instrument(info_span!("handle_connection", %addr)));
                                        }
                                        Err(e) => {
                                            error!("{:?}", e);
                                        }
                                    }
                                }
                                Err(e) => {
                                    error!("Transport handshake timeout: {}", e);
                                }
                            }
                        }
                    }
                },
                _ = shutdown_rx.recv() => {
                    info!("Shuting down gracefully...");
                    break;
                },
                // Matching `Some(..)` matters. Once every sender is gone, `recv()` resolves
                // to `None` immediately. A plain binding would then win the select on every
                // iteration and busy-spin. With the pattern, a `None` only disables this branch
                // for the current `select!`.
                Some(e) = service_rx.recv() => {
                    self.handle_hot_reload(e).await;
                }
            }
        }

        info!("Shutdown");
        Ok(())
    }

    /// Applies one configuration change to the running server.
    ///
    /// An added or deleted service updates the service table. Any control channel
    /// registered for that service is also removed from the map. Other change kinds
    /// are ignored.
    async fn handle_hot_reload(&mut self, e: ServiceChange) {
        match e {
            ServiceChange::ServerAdd(s) => {
                let hash = protocol::digest(s.name.as_bytes());
                let mut wg = self.services.write().await;
                let _ = wg.insert(hash, s);
                // The services write guard above is still alive here, so no handshake can
                // read the new config until the old control channel is gone.
                let mut wg = self.control_channels.write().await;
                let _ = wg.remove1(&hash);
            }
            ServiceChange::ServerDelete(s) => {
                let hash = protocol::digest(s.as_bytes());
                let _ = self.services.write().await.remove(&hash);
                let mut wg = self.control_channels.write().await;
                let _ = wg.remove1(&hash);
            }
            _ => (),
        }
    }
}
/// Reads the hello on a freshly handshaken connection and dispatches it.
///
/// A `ControlChannelHello` goes to the control-channel handshake. A
/// `DataChannelHello` goes to the data-channel handshake.
///
/// # Errors
/// Propagates failures from reading the hello or from the selected handshake.
async fn handle_connection<T: 'static + Transport>(
    mut conn: T::Stream,
    services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
) -> Result<()> {
    match read_hello(&mut conn).await? {
        ControlChannelHello(_, service_digest) => {
            do_control_channel_handshake(conn, services, control_channels, service_digest)
                .await
        }
        DataChannelHello(_, nonce) => {
            do_data_channel_handshake(conn, control_channels, nonce).await
        }
    }
}
/// Runs the server side of the control-channel handshake for `service_digest`.
///
/// Steps:
/// 1. Send a `ControlChannelHello` carrying a fresh random nonce.
/// 2. Look the service up. If it is unknown, reply `Ack::ServiceNotExist` and fail.
/// 3. Read the client's `Auth` digest and compare it with `digest(token || nonce)`.
///    On mismatch, reply `Ack::AuthFailed` and fail.
/// 4. Reply `Ack::Ok`, drop any previous control channel for the service, and register
///    the new one under `(service_digest, session_key)`.
///
/// # Errors
/// - I/O errors on `conn`;
/// - an unknown service;
/// - a service without a token;
/// - failed authentication.
async fn do_control_channel_handshake<T: 'static + Transport>(
    mut conn: T::Stream,
    services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
    service_digest: ServiceDigest,
) -> Result<()> {
    info!("Try to handshake a control channel");
    T::hint(&conn, SocketOpts::for_control_channel());

    let mut nonce = vec![0u8; HASH_WIDTH_IN_BYTES];
    rand::thread_rng().fill_bytes(&mut nonce);
    let hello_send = Hello::ControlChannelHello(
        protocol::CURRENT_PROTO_VERSION,
        nonce.clone().try_into().unwrap(),
    );
    conn.write_all(&bincode::serialize(&hello_send).unwrap())
        .await?;
    conn.flush().await?;

    // Copy the config out so the `services` read lock is released before any further
    // network I/O. Holding it across a write to a slow peer would stall hot reloads.
    let service_config = services.read().await.get(&service_digest).cloned();
    let service_config = match service_config {
        Some(v) => v,
        None => {
            conn.write_all(&bincode::serialize(&Ack::ServiceNotExist).unwrap())
                .await?;
            // Flush so buffered transports actually deliver the rejection.
            conn.flush().await?;
            bail!("No such a service {}", hex::encode(&service_digest));
        }
    };
    let service_name = &service_config.name;
    let token = service_config
        .token
        .as_ref()
        .with_context(|| format!("Service {} has no token configured", service_name))?;

    let mut concat = Vec::from(token.as_bytes());
    concat.append(&mut nonce);

    let protocol::Auth(d) = read_auth(&mut conn).await?;
    let session_key = protocol::digest(&concat);
    if session_key != d {
        conn.write_all(&bincode::serialize(&Ack::AuthFailed).unwrap())
            .await?;
        conn.flush().await?;
        debug!(
            "Expect {}, but got {}",
            hex::encode(session_key),
            hex::encode(d)
        );
        bail!("Service {} failed the authentication", service_name);
    }

    // Acknowledge before taking the global `control_channels` write lock. A slow peer
    // then cannot hold that lock during I/O and block every data-channel handshake.
    conn.write_all(&bincode::serialize(&Ack::Ok).unwrap())
        .await?;
    conn.flush().await?;

    let mut h = control_channels.write().await;
    if h.remove1(&service_digest).is_some() {
        warn!(
            "Dropping previous control channel for service {}",
            service_name
        );
    }
    info!(service = %service_config.name, "Control channel established");
    let handle = ControlChannelHandle::new(conn, service_config);
    let _ = h.insert(service_digest, session_key, handle);
    Ok(())
}
/// Hands an incoming data channel to the control channel registered under `nonce`.
///
/// On a match, the socket options for that service are applied to `conn` and the
/// stream is queued to the service's connection pool. An unknown nonce is logged and
/// the connection is dropped.
///
/// # Errors
/// Fails if the pool side of the data-channel queue has been closed.
async fn do_data_channel_handshake<T: 'static + Transport>(
    conn: T::Stream,
    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
    nonce: Nonce,
) -> Result<()> {
    debug!("Try to handshake a data channel");
    let control_channels_guard = control_channels.read().await;
    let data_ch_tx = match control_channels_guard.get2(&nonce) {
        Some(handle) => {
            T::hint(&conn, SocketOpts::from_server_cfg(&handle.service));
            handle.data_ch_tx.clone()
        }
        None => {
            warn!("Data channel has incorrect nonce");
            return Ok(());
        }
    };
    // The queue is bounded, so `send` may wait. Release the read lock first so
    // writers (hot reload, new control channels) are not blocked behind it.
    drop(control_channels_guard);

    data_ch_tx
        .send(conn)
        .await
        .with_context(|| "Data channel for a stale control channel")?;
    Ok(())
}
/// Server-side handle to an established control channel.
///
/// Creating one spawns the control-channel task and the service's connection-pool
/// task. `_shutdown_tx` is never sent on. It is the only sender of the broadcast
/// channel those tasks listen on, so dropping the handle closes that channel, and the
/// tasks treat the close as a shutdown signal.
pub struct ControlChannelHandle<T: Transport> {
    _shutdown_tx: broadcast::Sender<bool>,
    /// Delivers authenticated data channels to the connection pool.
    data_ch_tx: mpsc::Sender<T::Stream>,
    /// Config of the service this control channel serves.
    service: ServerServiceConfig,
}

impl<T> ControlChannelHandle<T>
where
    T: 'static + Transport,
{
    /// Starts serving `service` over the authenticated control stream `conn`.
    ///
    /// This does three things:
    /// - queues an initial batch of data-channel requests (`TCP_POOL_SIZE` or
    ///   `UDP_POOL_SIZE`);
    /// - spawns the matching connection pool bound to `service.bind_addr`;
    /// - spawns the [`ControlChannel`] task, which writes a `CreateDataChannel` command
    ///   to the client for each request.
    #[instrument(skip_all, fields(service = %service.name))]
    fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
        let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
        let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
        let (data_ch_req_tx, data_ch_req_rx) = mpsc::unbounded_channel();

        let pool_size = match service.service_type {
            ServiceType::Tcp => TCP_POOL_SIZE,
            ServiceType::Udp => UDP_POOL_SIZE,
        };
        for _i in 0..pool_size {
            if let Err(e) = data_ch_req_tx.send(true) {
                error!("Failed to request data channel {}", e);
            };
        }

        let shutdown_rx_clone = shutdown_tx.subscribe();
        let bind_addr = service.bind_addr.clone();
        match service.service_type {
            ServiceType::Tcp => tokio::spawn(
                async move {
                    if let Err(e) = run_tcp_connection_pool::<T>(
                        bind_addr,
                        data_ch_rx,
                        data_ch_req_tx,
                        shutdown_rx_clone,
                    )
                    .await
                    .with_context(|| "Failed to run TCP connection pool")
                    {
                        error!("{:?}", e);
                    }
                }
                .instrument(Span::current()),
            ),
            ServiceType::Udp => tokio::spawn(
                async move {
                    if let Err(e) = run_udp_connection_pool::<T>(
                        bind_addr,
                        data_ch_rx,
                        data_ch_req_tx,
                        shutdown_rx_clone,
                    )
                    .await
                    .with_context(|| "Failed to run UDP connection pool")
                    {
                        error!("{:?}", e);
                    }
                }
                .instrument(Span::current()),
            ),
        };

        let ch = ControlChannel::<T> {
            conn,
            shutdown_rx,
            service: service.clone(),
            data_ch_req_rx,
        };
        tokio::spawn(
            async move {
                if let Err(err) = ch.run().await {
                    error!("{:?}", err);
                }
            }
            .instrument(Span::current()),
        );

        ControlChannelHandle {
            _shutdown_tx: shutdown_tx,
            data_ch_tx,
            service,
        }
    }
}
struct ControlChannel<T: Transport> {
conn: T::Stream, service: ServerServiceConfig, shutdown_rx: broadcast::Receiver<bool>, data_ch_req_rx: mpsc::UnboundedReceiver<bool>, }
impl<T: Transport> ControlChannel<T> {
#[instrument(skip(self), fields(service = %self.service.name))]
async fn run(mut self) -> Result<()> {
let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
loop {
tokio::select! {
val = self.data_ch_req_rx.recv() => {
match val {
Some(_) => {
if let Err(e) = self.conn.write_all(&cmd).await.with_context(||"Failed to write control cmds") {
error!("{:?}", e);
break;
}
if let Err(e) = self.conn.flush().await.with_context(|| "Failed to flush control cmds") {
error!("{:?}", e);
break;
}
}
None => {
break;
}
}
},
_ = self.shutdown_rx.recv() => {
break;
}
}
}
info!("Control channel shutdown");
Ok(())
}
}
/// Listens for a service's TCP visitors and forwards them to the connection pool.
///
/// Spawns a task that binds `addr`, retrying per `listen_backoff()`. For every accepted
/// visitor, the task sends one data-channel request on `data_ch_req_tx` and then pushes
/// the visitor stream into the returned receiver.
///
/// The task stops when any of the following happens:
/// - `shutdown_rx` yields a value or an error;
/// - either channel is closed;
/// - the accept backoff gives up.
fn tcp_listen_and_send(
    addr: String,
    data_ch_req_tx: mpsc::UnboundedSender<bool>,
    mut shutdown_rx: broadcast::Receiver<bool>,
) -> mpsc::Receiver<TcpStream> {
    let (tx, rx) = mpsc::channel(CHAN_SIZE);

    tokio::spawn(async move {
        let l = backoff::future::retry_notify(listen_backoff(), || async {
            Ok(TcpListener::bind(&addr).await?)
        }, |e, duration| {
            error!("{:?}. Retry in {:?}", e, duration);
        })
        .await
        .with_context(|| "Failed to listen for the service");

        let l: TcpListener = match l {
            Ok(v) => v,
            Err(e) => {
                error!("{:?}", e);
                return;
            }
        };

        info!("Listening at {}", &addr);

        let mut backoff = ExponentialBackoff {
            max_interval: Duration::from_secs(1),
            max_elapsed_time: None,
            ..Default::default()
        };

        loop {
            tokio::select! {
                val = l.accept() => {
                    match val {
                        Err(e) => {
                            error!("{}. Sleep for a while", e);
                            if let Some(d) = backoff.next_backoff() {
                                time::sleep(d).await;
                            } else {
                                error!("Too many retries. Aborting...");
                                break;
                            }
                        }
                        Ok((incoming, visitor_addr)) => {
                            if let Err(e) = data_ch_req_tx
                                .send(true)
                                .with_context(|| "Failed to send data chan create request")
                            {
                                error!("{:?}", e);
                                break;
                            }
                            backoff.reset();
                            debug!("New visitor from {}", visitor_addr);
                            // A closed queue means the connection pool has exited. Accepting
                            // more visitors would only drop them, so stop listening.
                            if tx.send(incoming).await.is_err() {
                                break;
                            }
                        }
                    }
                },
                _ = shutdown_rx.recv() => {
                    break;
                }
            }
        }

        info!("TCPListener shutdown");
    }.instrument(Span::current()));

    rx
}
/// Pairs each TCP visitor of a service with a data channel and proxies between them.
///
/// For every visitor from `tcp_listen_and_send`, a data channel is taken from
/// `data_ch_rx`. The channel is sent `StartForwardTcp` and flushed, then bytes are
/// copied in both directions on a spawned task. If the command cannot be delivered,
/// the channel is treated as broken: another one is requested on `data_ch_req_tx`
/// and the visitor is retried with it.
///
/// The pool ends when any of the following closes:
/// - the visitor queue;
/// - the data-channel queue;
/// - the request channel.
#[instrument(skip_all)]
async fn run_tcp_connection_pool<T: Transport>(
    bind_addr: String,
    mut data_ch_rx: mpsc::Receiver<T::Stream>,
    data_ch_req_tx: mpsc::UnboundedSender<bool>,
    shutdown_rx: broadcast::Receiver<bool>,
) -> Result<()> {
    let mut visitor_rx = tcp_listen_and_send(bind_addr, data_ch_req_tx.clone(), shutdown_rx);
    let cmd = bincode::serialize(&DataChannelCmd::StartForwardTcp).unwrap();

    'pool: while let Some(mut visitor) = visitor_rx.recv().await {
        loop {
            let mut ch = match data_ch_rx.recv().await {
                Some(ch) => ch,
                None => break 'pool,
            };
            // Flush explicitly: `copy_bidirectional` only flushes bytes it wrote itself.
            // On buffered transports the command could otherwise sit unsent while both
            // sides wait (e.g. for server-first protocols).
            if ch.write_all(&cmd).await.is_ok() && ch.flush().await.is_ok() {
                tokio::spawn(async move {
                    let _ = copy_bidirectional(&mut ch, &mut visitor).await;
                });
                break;
            }
            // This pooled data channel is broken. Ask for a fresh one and retry the visitor.
            if data_ch_req_tx.send(true).is_err() {
                break 'pool;
            }
        }
    }

    info!("Shutdown");
    Ok(())
}
/// Serves a UDP service over a single data channel.
///
/// Steps:
/// 1. Bind `bind_addr`, retrying per `listen_backoff()`.
/// 2. Take one data channel from `data_ch_rx`, send it `StartForwardUdp` and flush it.
/// 3. Relay datagrams in both directions:
///    - visitor datagrams are framed onto the channel with `UdpTraffic::write_slice`;
///    - frames read back from the channel are sent to `t.from`.
///
/// The relay stops when `shutdown_rx` yields a value or an error.
///
/// `_data_ch_req_tx` is unused but held for the pool's lifetime. Dropping it would
/// close the request channel that the control-channel task reads from.
///
/// # Errors
/// Fails in any of these cases:
/// - binding gives up;
/// - no data channel arrives;
/// - any socket or channel I/O fails.
#[instrument(skip_all)]
async fn run_udp_connection_pool<T: Transport>(
    bind_addr: String,
    mut data_ch_rx: mpsc::Receiver<T::Stream>,
    _data_ch_req_tx: mpsc::UnboundedSender<bool>,
    mut shutdown_rx: broadcast::Receiver<bool>,
) -> Result<()> {
    // Context is attached once, after retrying. Adding it inside the closure as well
    // made the final error print "Failed to listen for the service" twice.
    let l: UdpSocket = backoff::future::retry_notify(
        listen_backoff(),
        || async { Ok(UdpSocket::bind(&bind_addr).await?) },
        |e, duration| {
            warn!("{:?}. Retry in {:?}", e, duration);
        },
    )
    .await
    .with_context(|| "Failed to listen for the service")?;

    info!("Listening at {}", &bind_addr);

    let cmd = bincode::serialize(&DataChannelCmd::StartForwardUdp).unwrap();

    let mut conn = data_ch_rx
        .recv()
        .await
        .ok_or_else(|| anyhow!("No available data channels"))?;
    conn.write_all(&cmd).await?;
    // Push the command out now so the client starts forwarding even on buffered transports.
    conn.flush().await?;

    let mut buf = [0u8; UDP_BUFFER_SIZE];
    loop {
        tokio::select! {
            val = l.recv_from(&mut buf) => {
                let (n, from) = val?;
                UdpTraffic::write_slice(&mut conn, from, &buf[..n]).await?;
            },
            hdr_len = conn.read_u8() => {
                let t = UdpTraffic::read(&mut conn, hdr_len?).await?;
                l.send_to(&t.data, t.from).await?;
            }
            _ = shutdown_rx.recv() => {
                break;
            }
        }
    }

    debug!("UDP pool dropped");
    Ok(())
}