use crate::protocol::*;
use crate::topic::*;
use crate::{directory::Directory, transport::Transport};
use crate::err::*;
use std::collections::HashMap;
use std::time::Duration;
use futures::SinkExt;
use futures::StreamExt;
use tokio::task::JoinHandle;
pub(crate) struct TaskPub {
pub(crate) msg: PubMsg,
pub(crate) callback_ack: tokio::sync::oneshot::Sender<AckMsg>,
}
pub(crate) struct TaskSub {
pub(crate) msg: SubMsg,
pub(crate) callback_ack: tokio::sync::oneshot::Sender<AckMsg>,
pub(crate) callback_pub: tokio::sync::mpsc::UnboundedSender<PubMsg>,
}
pub(crate) struct TaskUnsub {
pub(crate) msg: UnsubMsg,
}
pub(crate) struct TaskUnsrv {
pub(crate) msg: UnsrvMsg,
}
pub(crate) struct TaskReq {
pub(crate) msg: ReqMsg,
pub(crate) callback_ack: tokio::sync::oneshot::Sender<AckMsg>,
pub(crate) callback_rsp: tokio::sync::oneshot::Sender<RspMsg>,
}
pub(crate) struct TaskRsp {
pub(crate) msg: RspMsg,
pub(crate) callback_ack: tokio::sync::oneshot::Sender<AckMsg>,
}
pub(crate) struct TaskSrv {
pub(crate) msg: SrvMsg,
pub(crate) callback_ack: tokio::sync::oneshot::Sender<AckMsg>,
pub(crate) callback_req: tokio::sync::mpsc::UnboundedSender<(MsgId, ReqMsg)>,
}
pub(crate) struct TaskStopBus {
pub(crate) callback_ack: tokio::sync::oneshot::Sender<AckMsg>,
}
pub(crate) enum Task {
Pub(TaskPub),
Sub(TaskSub),
Req(TaskReq),
Rsp(TaskRsp),
Srv(TaskSrv),
Unsub(TaskUnsub),
Unsrv(TaskUnsrv),
Stop,
StopBus(TaskStopBus),
KeepAlive
}
pub(crate) struct ClientCore<TTransport>
where
TTransport: Transport<ProtocolClient, ProtocolServer>,
{
next_msg_id: MsgId,
transport: TTransport,
callbacks_ack: HashMap<MsgId, tokio::sync::oneshot::Sender<AckMsg>>,
callbacks_rsp: HashMap<MsgId, tokio::sync::oneshot::Sender<RspMsg>>,
callbacks_pub: HashMap<MsgId, tokio::sync::mpsc::UnboundedSender<PubMsg>>,
callbacks_req: HashMap<MsgId, tokio::sync::mpsc::UnboundedSender<(MsgId, ReqMsg)>>,
subscriptions: Directory,
task_receiver: tokio::sync::mpsc::UnboundedReceiver<Task>,
}
impl<TTransport> ClientCore<TTransport>
where
TTransport: Transport<ProtocolClient, ProtocolServer>,
{
pub fn start(
transport: TTransport,
task_receiver: tokio::sync::mpsc::UnboundedReceiver<Task>,
) -> BusResult<JoinHandle<BusResult<()>>> {
let mut core = ClientCore {
transport,
subscriptions: Directory::new(),
task_receiver,
next_msg_id: 0,
callbacks_ack: HashMap::new(),
callbacks_rsp: HashMap::new(),
callbacks_pub: HashMap::new(),
callbacks_req: HashMap::new(),
};
let handle = tokio::spawn(async move {
#[cfg(debug_assertions)]
println!("[C] Bus client started");
let result = core.main_loop().await ;
#[cfg(debug_assertions)]
println!("[C] Bus client stopped: {result:?}");
result
});
Ok(handle)
}
async fn main_loop(&mut self) -> BusResult<()> {
let mut keep_alive_interval = tokio::time::interval(Duration::from_secs(KEEP_ALIVE_INTERVAL_S));
loop {
tokio::select! {
msg = self.transport.next() => {
self.after_receive(msg.ok_or(BusError::ChannelClosed)??).await?;
},
task_option = self.task_receiver.recv() => {
match task_option {
Some(task) => {
if let Task::Stop = task {
return Ok(())
}
self.send(task).await?;
},
None => {
return Ok(())
}
}
},
_ = keep_alive_interval.tick() => {
self.send(Task::KeepAlive).await?;
}
}
}
}
async fn send(&mut self, task: Task) -> BusResult<()> {
let msg = self.before_send(task)?;
#[cfg(debug_assertions)]
{
let log_msg = crate::debug::client_msg_to_string(&msg);
println!("[C] --> [B] {}", &log_msg);
}
self.transport.send(msg).await?;
Ok(())
}
fn before_send(&mut self, task: Task) -> BusResult<Msg<ProtocolClient>> {
self.next_msg_id += 1;
let id = self.next_msg_id;
let msg = match task {
Task::Pub(params) => {
self.callbacks_ack.insert(id, params.callback_ack);
Msg {
id,
content: ProtocolClient::Pub(params.msg),
}
}
Task::Sub(params) => {
self.callbacks_ack.insert(id, params.callback_ack);
self.callbacks_pub.insert(id, params.callback_pub);
self.subscriptions
.subscribe(id, &parse_topic(¶ms.msg.topic)?);
Msg {
id,
content: ProtocolClient::Sub(params.msg),
}
}
Task::Req(params) => {
self.callbacks_ack.insert(id, params.callback_ack);
self.callbacks_rsp.insert(id, params.callback_rsp);
Msg {
id,
content: ProtocolClient::Req(params.msg),
}
}
Task::Rsp(params) => {
self.callbacks_ack.insert(id, params.callback_ack);
Msg {
id,
content: ProtocolClient::Rsp(params.msg),
}
}
Task::Unsub(params) => Msg {
id,
content: ProtocolClient::Unsub(params.msg),
},
Task::Unsrv(params) => {
let topic = parse_topic(¶ms.msg.topic)?;
let callback_id = self.subscriptions.get_owner(&topic).unwrap();
self.subscriptions.unclaim(callback_id, &topic)?;
Msg {
id,
content: ProtocolClient::Unsrv(params.msg),
}
}
Task::Srv(params) => {
self.callbacks_ack.insert(id, params.callback_ack);
self.callbacks_req.insert(id, params.callback_req);
self.subscriptions
.claim(id, &parse_topic(¶ms.msg.topic)?)?;
Msg {
id,
content: ProtocolClient::Srv(params.msg),
}
}
Task::Stop => panic!(),
Task::StopBus(params) => {
self.callbacks_ack.insert(id, params.callback_ack);
Msg {
id,
content: ProtocolClient::Stop,
}
},
Task::KeepAlive => {
Msg {
id,
content: ProtocolClient::KeepAlive,
}
}
};
Ok(msg)
}
async fn after_receive(&mut self, msg: Msg<ProtocolServer>) -> BusResult<()> {
#[cfg(debug_assertions)]
{
let log_msg = crate::debug::server_msg_to_string(&msg);
println!("[C] <-- [B] {}", &log_msg);
}
match msg.content {
ProtocolServer::Ack(payload) => {
let _ = self
.callbacks_ack
.remove(&payload.msg_id)
.map(|c| c.send(payload));
}
ProtocolServer::Rsp(payload) => {
let callback = self.callbacks_rsp.remove(&payload.req_id).unwrap();
let _ = callback.send(payload); }
ProtocolServer::Pub(payload) => {
let callback_ids = self
.subscriptions
.get_subscribers(&parse_topic(&payload.topic)?);
for callback_id in callback_ids {
let callback_result = self
.callbacks_pub
.get_mut(&callback_id)
.unwrap()
.send(payload.clone());
if callback_result.is_err() {
self.callbacks_pub.remove(&callback_id);
let topics = self.subscriptions.drop_client(callback_id);
for topic in topics {
let unsubscribe_task = Task::Unsub(TaskUnsub {
msg: UnsubMsg { topic },
});
self.send(unsubscribe_task).await?;
}
}
}
}
ProtocolServer::Req(payload) => {
let callback_id_option =
self.subscriptions.get_owner(&parse_topic(&payload.topic)?);
match callback_id_option {
Some(callback_id) => {
let callback = self.callbacks_req.get_mut(&callback_id).unwrap();
callback.send((msg.id, payload))?;
}
None => {
}
}
}
}
Ok(())
}
}