use core::fmt;
use crate::{
pubsub::{In, JsonSink, Listener, Out},
types::Request,
HandlerArgs,
};
use serde_json::value::RawValue;
use tokio::{
select,
sync::{mpsc, oneshot, watch},
task::JoinHandle,
};
use tokio_stream::StreamExt;
use tracing::{debug, debug_span, error, instrument, trace, Instrument};
/// Default number of buffered outbound JSON messages per client connection.
///
/// NOTE(review): presumably the default for
/// `ConnectionManager::notification_buffer_per_task`, which sizes each
/// connection's outbound channel; confirm at the construction site.
pub const DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT: usize = 16;
/// Identifier handed out to each accepted connection by `ConnectionManager`,
/// counting up from the manager's `next_id` field.
pub type ConnectionId = u64;
/// Handle owning the sending half of a server's shutdown `watch` channel.
///
/// Nothing is sent on the channel from this module. Dropping the handle closes
/// the channel, which makes `changed()` return on its receivers. Connection
/// write tasks treat that return as a shutdown signal.
/// NOTE(review): assumes `ConnectionManager::shutdown` was created from this
/// sender; confirm at the construction site.
#[derive(Debug)]
pub struct ServerShutdown {
// Not read in this module; held so the channel stays open while the handle lives.
pub(crate) _shutdown: watch::Sender<()>,
}
impl From<watch::Sender<()>> for ServerShutdown {
fn from(sender: watch::Sender<()>) -> Self {
Self { _shutdown: sender }
}
}
/// Owns a [`Listener`] together with the [`ConnectionManager`] that sets up
/// the per-connection tasks for everything the listener accepts.
pub(crate) struct ListenerTask<T>
where
    T: Listener,
{
    /// Source of new connections.
    pub(crate) listener: T,
    /// Spawns the route/write task pair for each accepted connection.
    pub(crate) manager: ConnectionManager,
}
impl<T> ListenerTask<T>
where
T: Listener,
{
pub(crate) async fn task_future(self) {
let ListenerTask {
listener,
mut manager,
} = self;
loop {
let (resp_sink, req_stream) = match listener.accept().await {
Ok((resp_sink, req_stream)) => (resp_sink, req_stream),
Err(err) => {
error!(%err, "Failed to accept connection");
continue;
}
};
manager.handle_new_connection::<T>(req_stream, resp_sink);
}
}
pub(crate) fn spawn(self) -> JoinHandle<()> {
let future = self.task_future();
tokio::spawn(future)
}
}
/// Shared state used to set up the tasks serving each accepted connection.
pub(crate) struct ConnectionManager {
/// Shutdown receiver; each connection's write task gets its own clone.
pub(crate) shutdown: watch::Receiver<()>,
/// Id to hand to the next connection; advanced by `next_id()`.
pub(crate) next_id: ConnectionId,
/// Router cloned into every connection's route task.
pub(crate) router: crate::Router<()>,
/// Capacity of each connection's outbound JSON channel (route task -> write task).
pub(crate) notification_buffer_per_task: usize,
}
impl ConnectionManager {
    /// Return the id for a new connection and advance the counter.
    fn next_id(&mut self) -> ConnectionId {
        let id = self.next_id;
        self.next_id += 1;
        id
    }

    /// Clone the shared router for use by a single connection.
    fn router(&self) -> crate::Router<()> {
        self.router.clone()
    }

    /// Build the route/write task pair for one connection.
    ///
    /// The two tasks are linked by:
    /// - an `mpsc` channel carrying outbound JSON from the route task to the
    ///   write task, with capacity `notification_buffer_per_task` (at least 1);
    /// - a `oneshot` "gone" channel whose sender the route task drops when it
    ///   finishes, which tells the write task to stop.
    fn make_tasks<T: Listener>(
        &self,
        conn_id: ConnectionId,
        requests: In<T>,
        connection: Out<T>,
    ) -> (RouteTask<T>, WriteTask<T>) {
        // `mpsc::channel` panics on a capacity of 0. This runs inside the
        // listener's accept loop, so that panic would kill the listener on the
        // first connection. Clamp to a single slot instead.
        let buffer = self.notification_buffer_per_task.max(1);
        let (tx, rx) = mpsc::channel(buffer);
        let (gone_tx, gone_rx) = oneshot::channel();

        let rt = RouteTask {
            router: self.router(),
            conn_id,
            write_task: tx,
            requests,
            gone: gone_tx,
        };
        let wt = WriteTask {
            shutdown: self.shutdown.clone(),
            gone: gone_rx,
            conn_id,
            json: rx,
            connection,
        };
        (rt, wt)
    }

    /// Allocate an id for a new connection and spawn its route and write
    /// tasks. The tasks are detached; their join handles are not kept.
    fn spawn_tasks<T: Listener>(&mut self, requests: In<T>, connection: Out<T>) {
        let conn_id = self.next_id();
        let (rt, wt) = self.make_tasks::<T>(conn_id, requests, connection);
        rt.spawn();
        wt.spawn();
    }

    /// Entry point for a freshly accepted connection.
    fn handle_new_connection<T: Listener>(&mut self, requests: In<T>, connection: Out<T>) {
        self.spawn_tasks::<T>(requests, connection);
    }
}
/// Per-connection task that reads inbound requests, routes them, and queues
/// each handler's output for the paired [`WriteTask`].
struct RouteTask<T>
where
    T: crate::pubsub::Listener,
{
    /// Router used to handle every request on this connection.
    pub(crate) router: crate::Router<()>,
    /// Id of the connection this task serves.
    pub(crate) conn_id: ConnectionId,
    /// Sending half of the outbound JSON channel drained by the write task.
    pub(crate) write_task: mpsc::Sender<Box<RawValue>>,
    /// Inbound request stream for this connection.
    pub(crate) requests: In<T>,
    /// Dropped when this task exits; the paired write task stops on that.
    pub(crate) gone: oneshot::Sender<()>,
}
impl<T> fmt::Debug for RouteTask<T>
where
    T: crate::pubsub::Listener,
{
    /// Show only the connection id; the remaining fields are elided and the
    /// output is marked non-exhaustive.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("RouteTask");
        out.field("conn_id", &self.conn_id);
        out.finish_non_exhaustive()
    }
}
/// Request-routing loop for a single connection.
impl<T> RouteTask<T>
where
T: crate::pubsub::Listener,
{
/// Read requests from the connection and dispatch each one to the router on
/// its own spawned task. Each handler's output is forwarded to the write task.
///
/// The loop exits when the write task's channel closes, when the request
/// stream ends, or when reserving a slot in the outbound channel fails. On
/// exit `gone` is dropped, which the paired write task observes.
///
/// Items that fail `Request::try_from` are logged and skipped; no reply is
/// sent for them. NOTE(review): confirm this is intended (JSON-RPC normally
/// answers unparseable input with a parse error).
#[instrument(name = "RouteTask", skip(self), fields(conn_id = self.conn_id))]
pub async fn task_future(self) {
// `conn_id` is skipped here; it is only used by the `instrument` span above.
let RouteTask {
router,
mut requests,
write_task,
gone,
..
} = self;
loop {
select! {
// Check whether the write task has gone away before pulling the next request.
biased;
_ = write_task.closed() => {
debug!("IpcWriteTask has gone away");
break;
}
item = requests.next() => {
let Some(item) = item else {
trace!("IPC read stream has closed");
break;
};
let Ok(req) = Request::try_from(item) else {
tracing::warn!("inbound request is malformatted");
continue
};
let span = debug_span!("ipc request handling", id = req.id(), method = req.method());
// The handler context gets its own clone of the outbound sender.
// NOTE(review): presumably so handlers can push notifications; confirm in HandlerArgs.
let args = HandlerArgs {
ctx: write_task.clone().into(),
req,
};
let fut = router.handle_request(args);
let write_task = write_task.clone();
// Reserve an outbound slot before spawning the handler. While the channel
// is full this waits, so the loop stops reading requests until the write
// task catches up. The owned permit lets the spawned task send its result
// without waiting again.
let Ok(permit) = write_task.reserve_owned().await else {
tracing::error!("write task dropped while waiting for permit");
break;
};
tokio::spawn(
async move {
// The expect message indicates the handler's error case is not expected to occur.
let rv = fut.await.expect("infallible");
// `OwnedPermit::send` returns the sender; it is discarded here.
let _ = permit.send(
rv
);
}
.instrument(span)
);
}
}
}
// Dropping the oneshot sender resolves the write task's `gone` receiver.
drop(gone);
}
/// Spawn the future produced by `task_future` onto the tokio runtime.
pub(crate) fn spawn(self) -> tokio::task::JoinHandle<()> {
let future = self.task_future();
tokio::spawn(future)
}
}
/// Per-connection task that writes queued JSON to the outbound half of the
/// connection.
struct WriteTask<T>
where
    T: Listener,
{
    /// Shutdown signal receiver; a change or closure stops this task.
    pub(crate) shutdown: watch::Receiver<()>,
    /// Resolves once the paired route task drops its sender.
    pub(crate) gone: oneshot::Receiver<()>,
    /// Id of the connection this task serves.
    pub(crate) conn_id: ConnectionId,
    /// Receiving half of the outbound JSON channel.
    pub(crate) json: mpsc::Receiver<Box<RawValue>>,
    /// Outbound half of the connection that JSON is written to.
    pub(crate) connection: Out<T>,
}
/// Outbound writer loop for a single connection.
impl<T: Listener> WriteTask<T> {
/// Take JSON messages off the outbound channel and write them to the connection.
///
/// Exits when any of the following happens:
/// - the route task drops its `gone` sender;
/// - the shutdown channel reports a change or closes;
/// - every sender of the JSON channel is dropped;
/// - a write fails.
///
/// Exiting drops the JSON receiver, which the route task observes via `closed()`.
#[instrument(skip(self), fields(conn_id = self.conn_id))]
pub(crate) async fn task_future(self) {
let WriteTask {
mut shutdown,
mut gone,
mut json,
mut connection,
..
} = self;
// Mark the current value as seen, so that only a later send (or the sender
// being dropped) makes `changed()` resolve.
shutdown.mark_unchanged();
loop {
select! {
// Priority order: connection gone, then shutdown, then outbound data.
// NOTE(review): JSON still queued in the channel is not flushed once
// either of the first two branches fires.
biased;
_ = &mut gone => {
debug!("Connection has gone away");
break;
}
// `_` matches both Ok (value sent) and Err (sender dropped).
_ = shutdown.changed() => {
debug!("shutdown signal received");
break;
}
json = json.recv() => {
let Some(json) = json else {
tracing::error!("Json stream has closed");
break;
};
if let Err(err) = connection.send_json(json).await {
debug!(%err, "Failed to send json");
break;
}
}
}
}
}
/// Spawn the future produced by `task_future` onto the tokio runtime.
pub(crate) fn spawn(self) -> JoinHandle<()> {
tokio::spawn(self.task_future())
}
}