use crate::transmission;
use super::{Error, RequestId, Response, Result, ShutdownEmitter};
use std::{collections::HashMap, sync::Arc};
use tokio::sync::{mpsc, oneshot};
use tokio_tungstenite::tungstenite;
use tokio_util::sync::CancellationToken;
/// Sending half of the one-shot channel through which a booked request receives its
/// `Result<Response>`.
pub type Responder = oneshot::Sender<Result<Response>>;
/// Unbounded sender of `ClientCommand`s; held by `ClientRouter`.
type CommandSender = mpsc::UnboundedSender<ClientCommand>;
/// Unbounded receiver of `ClientCommand`s; consumed by the routing task.
type CommandReceiver = mpsc::UnboundedReceiver<ClientCommand>;
/// Unbounded sender of `DeliveryCommand`s; held by `ResponseRouter`.
type ResponseSender = mpsc::UnboundedSender<DeliveryCommand>;
/// Unbounded receiver of `DeliveryCommand`s; consumed by the routing task.
type ResponseReceiver = mpsc::UnboundedReceiver<DeliveryCommand>;
/// Starts the routing task and returns the two handles that feed it.
///
/// Two unbounded channels are created: one for client commands and one for delivery
/// commands. Their receiving halves are moved into a freshly spawned `routing_task`
/// together with the cancellation token, the transmission interrupter and the shutdown
/// emitter; the sending halves are returned wrapped in [`ClientRouter`] and
/// [`ResponseRouter`].
///
/// # Panics
///
/// `tokio::spawn` panics when called outside of a Tokio runtime.
pub fn init(
    dispatching_cancellator: CancellationToken,
    transmission_interrupter: transmission::Interrupter,
    shutdown: ShutdownEmitter,
) -> (ClientRouter, ResponseRouter) {
    let (command_tx, command_rx) = mpsc::unbounded_channel();
    let (delivery_tx, delivery_rx) = mpsc::unbounded_channel();

    let client_router = ClientRouter { sender: command_tx };
    let response_router = ResponseRouter {
        sender: delivery_tx,
    };

    let task = routing_task(
        command_rx,
        delivery_rx,
        dispatching_cancellator,
        transmission_interrupter,
        shutdown,
    );
    // The join handle is intentionally discarded: the task runs detached.
    tokio::spawn(task);

    (client_router, response_router)
}
#[derive(Clone)]
pub struct ClientRouter {
sender: CommandSender,
}
impl ClientRouter {
pub fn book(&self, id: RequestId, responder: Responder) -> Result {
self.sender
.send(ClientCommand::Book { id, responder })
.map_err(|_| Arc::new(tungstenite::Error::AlreadyClosed))
}
pub fn shutdown(self) {
let _ = self.sender.send(ClientCommand::Shutdown);
}
}
/// Cloneable handle for sending `DeliveryCommand`s to the routing task.
#[derive(Clone)]
pub struct ResponseRouter {
    sender: ResponseSender,
}

impl ResponseRouter {
    /// Queues `response` for delivery to the responder booked under `id`.
    ///
    /// # Panics
    ///
    /// Panics if the delivery channel no longer accepts messages.
    pub fn deliver(&self, id: RequestId, response: Response) {
        let command = DeliveryCommand::Deliver { id, response };
        self.sender
            .send(command)
            .expect("Routing task exists while there are responses to deliver");
    }

    /// Reports `err` to the routing task as a `Shutdown` command, consuming this handle.
    ///
    /// # Panics
    ///
    /// Panics if the delivery channel no longer accepts messages.
    pub fn shutdown(self, err: Error) {
        let command = DeliveryCommand::Shutdown(err);
        self.sender
            .send(command)
            .expect("Delivery error must be received by the router in any circumstances");
    }
}
/// Body of the spawned routing task: drives the router through its whole lifecycle.
///
/// 1. Runs `normal_operation` until a shutdown is requested or an error is reported.
/// 2. Forwards the outcome of that phase (`None` or the error) to
///    `transmission_interrupter`; the result of that send is ignored.
/// 3. If no error occurred, runs `graceful_shutdown` to flush the outstanding requests.
/// 4. If either phase produced an error, runs `error_handler` so every pending and
///    late-arriving responder receives that error; otherwise cancels
///    `dispatching_cancellator`.
/// 5. Logs completion and emits `true` on `shutdown` (send result ignored).
///
/// Every exit path reaches step 5. Previously an error coming out of the normal phase
/// returned early and skipped the shutdown notification, while an error coming out of the
/// graceful phase did emit it.
async fn routing_task(
    mut client_commands: CommandReceiver,
    mut responses: ResponseReceiver,
    dispatching_cancellator: CancellationToken,
    transmission_interrupter: transmission::Interrupter,
    shutdown: ShutdownEmitter,
) {
    let mut router = InnerRouter::new();

    let internal_error = normal_operation(&mut router, &mut client_commands, &mut responses).await;
    let _ = transmission_interrupter.send(internal_error.clone());

    // Collapse both phases into a single optional failure so the two error paths are
    // handled identically below.
    let failure = match internal_error {
        Some(err) => Some(err),
        None => match graceful_shutdown(&mut router, &mut client_commands, &mut responses).await {
            Ok(()) => None,
            Err(err) => Some(err),
        },
    };

    match failure {
        Some(err) => error_handler(router, client_commands, responses, err).await,
        // The token is cancelled only after a clean graceful shutdown, as before.
        None => dispatching_cancellator.cancel(),
    }

    log::debug!("Router task finished");
    let _ = shutdown.send(true);
}
/// Steady-state loop of the router: books requests and delivers responses.
///
/// Returns `None` when a `ClientCommand::Shutdown` arrives or the client command channel
/// yields `None`, and `Some(err)` when the delivery side sends
/// `DeliveryCommand::Shutdown(err)`. On either explicit shutdown command the client
/// command channel is closed before returning.
///
/// The `select!` is `biased`, so the client command branch is polled before the response
/// branch on every iteration (presumably so that a booking is processed before the
/// response that answers it).
///
/// # Panics
///
/// - if a response arrives for an ID that is not booked;
/// - if the response channel yields `None` without a preceding `Shutdown`;
/// - if the client channel yields `None` while not reporting itself closed;
/// - if `InnerRouter::book` rejects a duplicated ID.
async fn normal_operation(
    router: &mut InnerRouter,
    client_commands: &mut CommandReceiver,
    responses: &mut ResponseReceiver,
) -> Option<Error> {
    loop {
        tokio::select! {
            biased;

            command = client_commands.recv() => {
                match command {
                    Some(ClientCommand::Book { id, responder }) => {
                        router.book(id, responder);
                    }
                    Some(ClientCommand::Shutdown) => {
                        // Stop accepting further commands; buffered ones stay readable.
                        client_commands.close();
                        return None;
                    }
                    None => {
                        assert!(client_commands.is_closed());
                        return None;
                    }
                }
            }

            incoming = responses.recv() => {
                match incoming {
                    Some(DeliveryCommand::Deliver { id, response }) => {
                        let delivered = router.deliver(id, Ok(response));
                        assert!(delivered, "Request ID is booked before sending a request");
                    }
                    Some(DeliveryCommand::Shutdown(err)) => {
                        client_commands.close();
                        return Some(err);
                    }
                    None => unreachable!("Dispatcher task always sends Shutdown before dropping the channel"),
                }
            }
        }
    }
}
async fn graceful_shutdown(
router: &mut InnerRouter,
client_commands: &mut CommandReceiver,
responses: &mut ResponseReceiver,
) -> Result {
while let Some(cmd) = client_commands.recv().await {
match cmd {
ClientCommand::Book { responder, .. } => {
let _ = responder.send(Err(Arc::new(tungstenite::Error::AlreadyClosed)));
}
ClientCommand::Shutdown => {
log::warn!(
"Ignoring another disconnect() call because the client is already shutting down"
);
}
}
}
while !router.table.is_empty() {
match responses.recv().await {
Some(DeliveryCommand::Deliver { id, response }) => {
if !router.deliver(id, Ok(response)) {
log::warn!(
"Dropping response for unknown corrId {id} — request lost race with disconnect"
);
}
}
Some(DeliveryCommand::Shutdown(err)) => {
return Err(err);
}
None => {
unreachable!("Dispatcher task always sends Shutdown before dropping the channel")
}
}
}
Ok(())
}
/// Terminates the router after an error, failing every request with that error.
///
/// Closes the response channel, logs the error, then drains `client_commands` until
/// `recv` yields `None`: each `Book` gets a clone of `error` on its responder and each
/// `Shutdown` is logged and ignored. Finally every responder still present in the routing
/// table receives a clone of `error`. Failures to notify a responder are ignored.
async fn error_handler(
    router: InnerRouter,
    mut client_commands: CommandReceiver,
    mut receiver: ResponseReceiver,
    error: Error,
) {
    receiver.close();
    log::error!("Terminating the router task due to an error: {error}");

    // Bookings that were queued but never entered the table.
    loop {
        match client_commands.recv().await {
            Some(ClientCommand::Book { responder, .. }) => {
                let _ = responder.send(Err(Arc::clone(&error)));
            }
            Some(ClientCommand::Shutdown) => {
                log::warn!(
                    "Ignoring disconnect() call because the client is already shutting down due to an error"
                );
            }
            None => break,
        }
    }

    // Requests that were booked and are still waiting for a response.
    for responder in router.table.into_values() {
        let _ = responder.send(Err(Arc::clone(&error)));
    }
}
/// Messages sent from `ResponseRouter` to the routing task.
#[derive(Debug)]
enum DeliveryCommand {
/// Hand `response` to the responder booked under `id`.
Deliver { id: RequestId, response: Response },
/// Report an error from the delivery side; the routing task then fails the pending
/// requests with it.
Shutdown(Error),
}
/// Messages sent from `ClientRouter` to the routing task.
enum ClientCommand {
/// Register `responder` to receive the result of the request identified by `id`.
Book { id: RequestId, responder: Responder },
/// Ask the routing task to stop accepting bookings and begin shutting down.
Shutdown,
}
/// Table of pending requests: maps each booked `RequestId` to the responder that is
/// waiting for that request's result.
#[derive(Default)]
struct InnerRouter {
    table: HashMap<RequestId, Responder>,
}

impl InnerRouter {
    /// Creates a router with an empty table.
    fn new() -> Self {
        Self::default()
    }

    /// Registers `responder` under `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is already booked. The new responder has replaced the previous one
    /// in the table by the time the panic fires.
    fn book(&mut self, id: RequestId, responder: Responder) {
        let prev = self.table.insert(id, responder);
        assert!(prev.is_none(), "Request ID cannot be duplicated");
    }

    /// Removes the responder booked under `id` and sends `result` to it.
    ///
    /// Returns `true` if a responder was booked for `id` (even when sending to it fails
    /// and the failure is ignored), `false` if `id` is unknown.
    fn deliver(&mut self, id: RequestId, result: Result<Response>) -> bool {
        match self.table.remove(&id) {
            Some(responder) => {
                let _ = responder.send(result);
                true
            }
            None => false,
        }
    }
}