use super::common::protocol::Proto;
use crate::{
Config, SUPPORTED_PROTOCOL_VERSIONS,
auth::JwtClaims,
print_endpoint,
server::{CloneableWbApi, common::init_server_socket},
stats::VERSION,
};
use hashbrown::HashMap;
use miette::{IntoDiagnostic, Result};
use std::{
io,
net::{IpAddr, SocketAddr},
ops::ControlFlow,
time::Duration,
};
use tokio::{
io::{AsyncBufReadExt, BufReader, Lines},
net::{
TcpListener, TcpStream,
tcp::{OwnedReadHalf, OwnedWriteHalf},
},
select,
sync::mpsc,
};
use tosub::SubsystemHandle;
use tracing::{debug, error, info, trace, warn};
use uuid::Uuid;
use worterbuch_common::{
Protocol, ServerInfo, ServerMessage, WbApi, Welcome, write_line_and_flush,
};
enum SocketEvent {
Disconnected(Option<Uuid>),
Connected(Option<Result<(TcpStream, SocketAddr), io::Error>>),
ShutdownRequested,
}
pub async fn start(
worterbuch: CloneableWbApi,
config: Config,
bind_addr: IpAddr,
port: u16,
subsys: SubsystemHandle,
) -> Result<()> {
let addr = format!("{bind_addr}:{port}");
info!("Serving TCP endpoint at {addr}");
let listener = init_server_socket(bind_addr, port, config.clone())?;
if config.print_endpoints {
print_endpoint(&listener, true)?;
}
let listener = TcpListener::from_std(listener).into_diagnostic()?;
let (conn_closed_tx, mut conn_closed_rx) = mpsc::channel(100);
let mut waiting_for_free_connections = false;
let mut clients = HashMap::new();
loop {
let evt = next_socket_event(
&subsys,
&mut conn_closed_rx,
&listener,
waiting_for_free_connections,
)
.await;
match evt {
SocketEvent::Disconnected(uuid) => {
if let Some(id) = uuid {
clients.remove(&id);
while let Ok(id) = conn_closed_rx.try_recv() {
clients.remove(&id);
}
debug!("{} TCP connection(s) open.", clients.len());
waiting_for_free_connections = false;
} else {
break;
}
}
SocketEvent::Connected(con) => {
if let Some(con) = con {
debug!("Trying to accept new client connection.");
match con {
Ok((socket, remote_addr)) => {
let id = Uuid::new_v4();
debug!("{} TCP connection(s) open.", clients.len());
let worterbuch = worterbuch.clone();
let conn_closed_tx = conn_closed_tx.clone();
let client = subsys.spawn(format!("client-{id}"), async move |s| {
select! {
s = serve(&s, id, remote_addr, worterbuch, socket) => if let Err(e) = s {
error!("Connection to client {id} ({remote_addr:?}) closed with error: {e}");
},
_ = s.shutdown_requested() => (),
}
conn_closed_tx.send(id).await.ok();
Ok::<(),miette::Error>(())
});
clients.insert(id, client);
}
Err(e) => {
error!("Error while trying to accept client connection: {e}");
warn!(
"{} TCP connections open, waiting for connections to close.",
clients.len()
);
waiting_for_free_connections = true;
}
}
debug!("Ready to accept new connections.");
}
}
SocketEvent::ShutdownRequested => break,
}
}
for (cid, subsys) in clients {
subsys.request_local_shutdown();
debug!("Waiting for connection to client {cid} to close …");
subsys.join().await;
}
debug!("All clients disconnected.");
drop(listener);
debug!("tcpserver subsystem completed.");
Ok(())
}
async fn next_socket_event(
subsys: &SubsystemHandle,
conn_closed_rx: &mut mpsc::Receiver<Uuid>,
listener: &TcpListener,
waiting_for_free_connections: bool,
) -> SocketEvent {
select! {
recv = conn_closed_rx.recv() => SocketEvent::Disconnected(recv),
con = listener.accept() => if !waiting_for_free_connections {
SocketEvent::Connected(Some(con))
} else {
SocketEvent::Connected(None)
},
_ = subsys.shutdown_requested() => SocketEvent::ShutdownRequested,
}
}
async fn serve(
subsys: &SubsystemHandle,
client_id: Uuid,
remote_addr: SocketAddr,
worterbuch: CloneableWbApi,
socket: TcpStream,
) -> Result<()> {
info!("New client connected: {client_id} ({remote_addr})");
if let Err(e) = worterbuch
.connected(client_id, Some(remote_addr), Protocol::TCP)
.await
{
error!("Error while adding new client: {e}");
} else {
debug!("Receiving messages from client {client_id} ({remote_addr}) …",);
if let Err(e) = serve_loop(subsys, client_id, remote_addr, worterbuch.clone(), socket).await
{
error!("Error in serve loop: {e}");
}
}
worterbuch
.disconnected(client_id, Some(remote_addr))
.await?;
Ok(())
}
struct ServeLoop {
client_id: Uuid,
remote_addr: SocketAddr,
authorized: Option<JwtClaims>,
tcp_rx: Lines<BufReader<OwnedReadHalf>>,
proto: Proto,
}
async fn serve_loop(
subsys: &SubsystemHandle,
client_id: Uuid,
remote_addr: SocketAddr,
worterbuch: CloneableWbApi,
socket: TcpStream,
) -> Result<()> {
let config = worterbuch.config().to_owned();
let authorization_required = config.auth_token_key.is_some();
let send_timeout = config.send_timeout;
let authorized = None;
let (tcp_rx, tcp_tx) = socket.into_split();
let (tcp_send_tx, tcp_send_rx) = mpsc::channel(config.channel_buffer_size);
subsys.spawn("forward_messages_to_socket", async move |s| {
forward_messages_to_socket(s, tcp_send_rx, tcp_tx, client_id, send_timeout).await
});
let tcp_rx = BufReader::new(tcp_rx);
let tcp_rx = tcp_rx.lines();
let supported_protocol_versions = SUPPORTED_PROTOCOL_VERSIONS.into();
tcp_send_tx
.send(ServerMessage::Welcome(Welcome {
client_id: client_id.to_string(),
info: ServerInfo::new(
VERSION.to_owned(),
supported_protocol_versions,
authorization_required,
),
}))
.await
.into_diagnostic()?;
let proto = Proto::new(
client_id,
tcp_send_tx,
authorization_required,
config,
worterbuch,
);
let serve_loop = ServeLoop {
authorized,
client_id,
proto,
remote_addr,
tcp_rx,
};
serve_loop.run().await
}
async fn forward_messages_to_socket(
subsys: SubsystemHandle,
mut tcp_send_rx: mpsc::Receiver<ServerMessage>,
mut tcp_tx: OwnedWriteHalf,
client_id: Uuid,
send_timeout: Option<Duration>,
) -> Result<()> {
loop {
select! {
recv = tcp_send_rx.recv() => if let Some(msg) = recv {
if let Err(e) = write_line_and_flush(msg, &mut tcp_tx, send_timeout, client_id).await {
error!("Error sending TCP message: {e}");
break;
}
} else {
warn!("Message forwarding to client {client_id} stopped: channel closed.");
break;
},
_ = subsys.shutdown_requested() => {
warn!("Message forwarding to client {client_id} stopped: subsystem stopped.");
break;
},
}
}
Ok(())
}
impl ServeLoop {
async fn run(mut self) -> Result<()> {
loop {
let next_line = self.tcp_rx.next_line().await;
if let ControlFlow::Break(it) = self.process_next_line(next_line).await? {
break Ok(it);
}
}
}
async fn process_next_line(
&mut self,
next_line: Result<Option<String>, io::Error>,
) -> Result<ControlFlow<()>> {
match next_line {
Ok(Some(json)) => self.process_line(json).await,
Ok(None) => self.done(),
Err(e) => self.tcp_error(e),
}
}
async fn process_line(&mut self, json: String) -> Result<ControlFlow<()>> {
trace!("Processing incoming message …");
let msg_processed = self
.proto
.process_incoming_message(&json, &mut self.authorized)
.await?;
if !msg_processed {
return Ok(ControlFlow::Break(()));
}
trace!("Processing incoming message done.");
Ok(ControlFlow::Continue(()))
}
fn tcp_error(&mut self, e: io::Error) -> std::result::Result<ControlFlow<()>, miette::Error> {
warn!(
"TCP stream of client {} ({}) closed with error:, {}",
self.client_id, self.remote_addr, e
);
Ok(ControlFlow::Break(()))
}
fn done(&self) -> std::result::Result<ControlFlow<()>, miette::Error> {
debug!(
"TCP stream of client {} ({}) closed normally.",
self.client_id, self.remote_addr
);
Ok(ControlFlow::Break(()))
}
}