use crate::{
pubsub::{In, JsonSink, Listener, Out},
types::InboundData,
HandlerCtx, TaskSet,
};
use core::fmt;
use serde_json::value::RawValue;
use std::sync::{atomic::AtomicU64, Arc};
use tokio::{pin, runtime::Handle, select, sync::mpsc, task::JoinHandle};
use tokio_stream::StreamExt;
use tokio_util::sync::WaitForCancellationFutureOwned;
use tracing::{debug, debug_span, error, trace, Instrument};
/// Default capacity of the per-connection channel carrying serialized JSON
/// from the routing task to the write task.
pub const DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT: usize = 16;
/// Identifier assigned to each accepted connection, drawn from a
/// monotonically increasing counter.
pub type ConnectionId = u64;
/// Task state for the accept loop: pulls new connections off a [`Listener`]
/// and hands each one to the [`ConnectionManager`].
pub(crate) struct ListenerTask<T: Listener> {
    /// Transport-specific listener yielding a (sink, stream) pair per connection.
    pub(crate) listener: T,
    /// Manager that spawns the per-connection route/write task pair.
    pub(crate) manager: ConnectionManager,
}
impl<T> ListenerTask<T>
where
T: Listener,
{
pub(crate) async fn task_future(self) {
let ListenerTask { listener, manager } = self;
loop {
let (resp_sink, req_stream) = match listener.accept().await {
Ok((resp_sink, req_stream)) => (resp_sink, req_stream),
Err(err) => {
error!(%err, "Failed to accept connection");
continue;
}
};
manager.handle_new_connection::<T>(req_stream, resp_sink);
}
}
pub(crate) fn spawn(self) -> JoinHandle<Option<()>> {
let tasks = self.manager.root_tasks.clone();
let future = self.task_future();
tasks.spawn_cancellable(future)
}
}
/// Shared state for managing connections: hands out connection ids and
/// spawns the route/write task pair for each new connection.
pub(crate) struct ConnectionManager {
    /// Root task set; each connection's tasks are spawned as children of it.
    pub(crate) root_tasks: TaskSet,
    /// Source of monotonically increasing connection ids.
    pub(crate) next_id: Arc<AtomicU64>,
    /// Router cloned into every connection's route task.
    pub(crate) router: crate::Router<()>,
    /// Capacity of the per-connection route→write channel.
    /// NOTE(review): field says `per_task` while the builder setter says
    /// `per_client` — consider unifying the naming.
    pub(crate) notification_buffer_per_task: usize,
}
impl ConnectionManager {
pub(crate) fn new(router: crate::Router<()>) -> Self {
Self {
root_tasks: Default::default(),
next_id: AtomicU64::new(0).into(),
router,
notification_buffer_per_task: DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT,
}
}
pub(crate) fn with_root_tasks(mut self, root_tasks: TaskSet) -> Self {
self.root_tasks = root_tasks;
self
}
pub(crate) fn with_handle(mut self, handle: Handle) -> Self {
self.root_tasks = self.root_tasks.with_handle(handle);
self
}
pub(crate) const fn with_notification_buffer_per_client(
mut self,
notification_buffer_per_client: usize,
) -> Self {
self.notification_buffer_per_task = notification_buffer_per_client;
self
}
fn next_id(&self) -> ConnectionId {
self.next_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}
fn router(&self) -> crate::Router<()> {
self.router.clone()
}
fn make_tasks<T: Listener>(
&self,
conn_id: ConnectionId,
requests: In<T>,
connection: Out<T>,
) -> (RouteTask<T>, WriteTask<T>) {
let (tx, rx) = mpsc::channel(self.notification_buffer_per_task);
let tasks = self.root_tasks.child();
let rt = RouteTask {
router: self.router(),
conn_id,
write_task: tx,
requests,
tasks: tasks.clone(),
};
let wt = WriteTask {
tasks,
conn_id,
json: rx,
connection,
};
(rt, wt)
}
fn spawn_tasks<T: Listener>(&self, requests: In<T>, connection: Out<T>) {
let conn_id = self.next_id();
let (rt, wt) = self.make_tasks::<T>(conn_id, requests, connection);
rt.spawn();
wt.spawn();
}
pub(crate) fn handle_new_connection<T: Listener>(&self, requests: In<T>, connection: Out<T>) {
self.spawn_tasks::<T>(requests, connection);
}
}
/// Per-connection routing task state: reads inbound requests, dispatches
/// them through the router, and forwards serialized responses to the paired
/// [`WriteTask`].
struct RouteTask<T: crate::pubsub::Listener> {
    /// Router used to dispatch each request batch.
    pub(crate) router: crate::Router<()>,
    /// Identifier of the connection this task serves.
    pub(crate) conn_id: ConnectionId,
    /// Sender half of the channel feeding the write task.
    pub(crate) write_task: mpsc::Sender<Box<RawValue>>,
    /// Inbound request stream for this connection.
    pub(crate) requests: In<T>,
    /// Task set on which request handlers are spawned.
    pub(crate) tasks: TaskSet,
}
impl<T: crate::pubsub::Listener> fmt::Debug for RouteTask<T> {
    /// Debug output shows only the connection id; remaining fields are
    /// elided via `finish_non_exhaustive`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut dbg = f.debug_struct("RouteTask");
        dbg.field("conn_id", &self.conn_id);
        dbg.finish_non_exhaustive()
    }
}
impl<T> RouteTask<T>
where
    T: crate::pubsub::Listener,
{
    /// Drive the per-connection routing loop until cancellation, the write
    /// task going away, or the inbound stream ending.
    ///
    /// Each inbound item is parsed into a batch, dispatched through the
    /// router on a spawned child task, and any serialized response is sent
    /// to the paired `WriteTask` through a pre-reserved channel permit.
    pub(crate) async fn task_future(self, cancel: WaitForCancellationFutureOwned) {
        let RouteTask {
            router,
            mut requests,
            write_task,
            tasks,
            ..
        } = self;
        // Child task set for request handlers; shut down when the loop exits.
        let children = tasks.child();

        pin!(cancel);

        loop {
            select! {
                // `biased`: check cancellation first, then writer liveness,
                // before pulling more inbound work.
                biased;
                _ = &mut cancel => {
                    debug!("RouteTask cancelled");
                    break;
                }
                // If the write task's receiver is gone there is nowhere to
                // send responses; stop routing.
                _ = write_task.closed() => {
                    debug!("WriteTask has gone away");
                    break;
                }
                item = requests.next() => {
                    let Some(item) = item else {
                        trace!("inbound read stream has closed");
                        break;
                    };
                    // NOTE(review): a frame that fails to parse falls back to
                    // an empty batch and is silently dropped — confirm this
                    // is the intended handling of malformed input.
                    let reqs = InboundData::try_from(item).unwrap_or_default();
                    // Reserve an outbound slot *before* spawning the handler.
                    // This applies backpressure: the loop stops reading new
                    // requests while the write channel is full.
                    let Ok(permit) = write_task.clone().reserve_owned().await else {
                        tracing::error!("write task dropped while waiting for permit");
                        break;
                    };
                    let ctx =
                        HandlerCtx::new(
                            Some(write_task.clone()),
                            children.clone(),
                        );
                    let fut = router.handle_request_batch(ctx, reqs);
                    // Run the handler concurrently so a slow request does not
                    // block reads; the pre-reserved permit means the send
                    // itself never awaits.
                    children.spawn_cancellable(
                        async move {
                            if let Some(rv) = fut.await {
                                let _ = permit.send(
                                    rv
                                );
                            }
                        }
                    );
                }
            }
        }
        // Wait for in-flight handlers to finish or cancel before returning.
        children.shutdown().await;
    }

    /// Spawn the routing loop on this connection's task set with graceful
    /// cancellation: the cancel future is handed to `task_future`.
    pub(crate) fn spawn(self) -> tokio::task::JoinHandle<()> {
        let tasks = self.tasks.clone();
        let future = move |cancel| self.task_future(cancel);
        tasks.spawn_graceful(future)
    }
}
/// Per-connection write task state: drains serialized JSON produced by the
/// [`RouteTask`] and writes it to the outbound connection.
struct WriteTask<T: Listener> {
    /// Task set observed for shutdown/cancellation.
    pub(crate) tasks: TaskSet,
    /// Identifier of the connection this task serves (used in tracing).
    pub(crate) conn_id: ConnectionId,
    /// Receiver half of the channel fed by the route task.
    pub(crate) json: mpsc::Receiver<Box<RawValue>>,
    /// Outbound sink for this connection.
    pub(crate) connection: Out<T>,
}
impl<T: Listener> WriteTask<T> {
pub(crate) async fn task_future(self) {
let WriteTask {
tasks,
mut json,
mut connection,
..
} = self;
loop {
select! {
biased;
_ = tasks.cancelled() => {
debug!("Shutdown signal received");
break;
}
json = json.recv() => {
let Some(json) = json else {
tracing::error!("Json stream has closed");
break;
};
let span = debug_span!("WriteTask", conn_id = self.conn_id);
if let Err(err) = connection.send_json(json).instrument(span).await {
debug!(%err, conn_id = self.conn_id, "Failed to send json");
break;
}
}
}
}
}
pub(crate) fn spawn(self) -> tokio::task::JoinHandle<Option<()>> {
let tasks = self.tasks.clone();
let future = self.task_future();
tasks.spawn_cancellable(future)
}
}