use std::{
collections::HashMap,
future::poll_fn,
mem,
sync::{Arc, Mutex},
};
use tokio::sync::{broadcast, mpsc, oneshot};
use crate::{
message::*,
relay::{
simple::MessageQueueHandle, MessageRelayService, MessageSendError,
Relay, Sender, SplitSender,
},
};
/// A request sent from an [`InternalConnection`] to the mux's output task.
enum Output {
    /// Flush the underlying relay's sender, then signal completion on the
    /// enclosed one-shot channel.
    Flush(oneshot::Sender<()>),
    /// Forward a plain message to the underlying relay.
    Msg(Bytes),
    /// Forward an ask (a message whose length equals `MESSAGE_HEADER_SIZE`)
    /// and route the reply into the enclosed queue.
    Ask(Bytes, MessageQueueHandle<Vec<BytesMut>>),
}
/// Multiplexes a single underlying relay between many [`InternalConnection`]s
/// (built by `MsgRelayMux::new`, handed out by `connect`).
pub struct MsgRelayMux {
    /// Outgoing channel drained by the output task; cloned into every
    /// connection.
    tx: mpsc::Sender<Output>,
    /// Broadcast of header-sized (ask) messages read from the underlying
    /// relay; every new connection subscribes to it.
    br: broadcast::Sender<Bytes>,
}
/// One logical connection handed out by [`MsgRelayMux`]; all such
/// connections share the mux's underlying relay.
pub struct InternalConnection {
    /// Outgoing channel to the mux's output task.
    tx: mpsc::Sender<Output>,
    /// This connection's own queue; replies to the asks it sends are pushed
    /// here by the mux's input task.
    input: MessageQueueHandle<Vec<BytesMut>>,
    /// Receives the header-sized (ask) messages the mux broadcasts to every
    /// connection.
    br: broadcast::Receiver<Bytes>,
}
impl MsgRelayMux {
pub fn new<R>(
mut relay: R,
output_buffer: usize,
input_ask_buffer: usize,
) -> Self
where
R: Relay + SplitSender + Send + 'static,
{
let asks = Arc::new(Mutex::new(HashMap::new()));
let tx = {
let (tx, mut rx) = mpsc::channel::<Output>(output_buffer);
let sender = relay.split_sender();
tokio::spawn({
let asks = asks.clone();
async move {
while let Some(output) = rx.recv().await {
match output {
Output::Msg(msg) => {
if sender.feed(msg).await.is_err() {
break;
}
}
Output::Flush(notify) => {
if sender.flush().await.is_err() {
break;
}
let _ = notify.send(());
}
Output::Ask(msg, q) => {
if let Ok(id) = MsgId::try_from(msg.as_ref())
{
asks.lock().unwrap().insert(id, q);
}
if sender.feed(msg).await.is_err() {
break;
}
}
}
}
rx.close();
let asks = mem::take(&mut *asks.lock().unwrap());
for q in asks.into_values() {
q.close();
}
}
});
tx
};
let (br, _) = broadcast::channel(input_ask_buffer);
let input_asks = br.clone();
tokio::spawn(async move {
while let Some(msg) = relay.next().await {
if msg.len() == MESSAGE_HEADER_SIZE {
let _ = input_asks.send(msg.into());
} else {
if let Some(q) = <&MsgId>::try_from(msg.as_ref())
.ok()
.and_then(|id| asks.lock().unwrap().remove(id))
{
q.push(msg);
}
}
}
});
Self { tx, br }
}
}
impl MessageRelayService for MsgRelayMux {
    type MessageRelay = InternalConnection;

    /// Hands out a new [`InternalConnection`] over the shared relay.
    ///
    /// The connection gets:
    /// * a clone of the mux's outgoing channel;
    /// * a freshly defaulted reply queue of its own;
    /// * a new subscription to the ask broadcast.
    ///
    /// This never fails, so the result is always `Some`.
    async fn connect(&self) -> Option<Self::MessageRelay> {
        let tx = self.tx.clone();
        let input = MessageQueueHandle::default();
        let br = self.br.subscribe();
        Some(InternalConnection { tx, input, br })
    }
}
impl Relay for InternalConnection {
    /// Waits for the next message for this connection.
    ///
    /// Races two sources:
    /// * `input`: this connection's own queue, into which the mux pushes
    ///   replies to the asks this connection sent;
    /// * `br`: header-sized (ask) messages read from the underlying relay and
    ///   broadcast by the mux to every connection.
    ///
    /// Returns whatever the input queue yields, including `None`. Broadcast
    /// asks that were overwritten because this receiver lagged are skipped.
    /// Once the broadcast channel is closed (every sender dropped), only the
    /// input queue is awaited.
    async fn next(&mut self) -> Option<BytesMut> {
        loop {
            tokio::select!(
                msg = poll_fn(|cx| self.input.poll_recv(cx)) => return msg,
                ask = self.br.recv() => {
                    match ask {
                        Ok(bytes) => return Some(bytes.into()),
                        // A closed broadcast makes `recv` fail immediately on
                        // every call; looping back into `select!` would spin
                        // the CPU forever. No more asks can arrive, so wait on
                        // the input queue alone.
                        Err(broadcast::error::RecvError::Closed) => {
                            return poll_fn(|cx| self.input.poll_recv(cx)).await;
                        }
                        // `Lagged`: some asks were overwritten before this
                        // receiver read them; skip them and keep waiting.
                        Err(_) => {}
                    }
                }
            )
        }
    }

    /// Queues `msg` for the mux's output task, waiting for channel capacity.
    ///
    /// A message whose length equals `MESSAGE_HEADER_SIZE` is sent as an ask,
    /// with this connection's input queue attached so the reply is routed
    /// back here.
    ///
    /// # Errors
    ///
    /// Returns [`MessageSendError`] if the output task's receiver is closed
    /// or gone.
    async fn feed(&self, msg: Bytes) -> Result<(), MessageSendError> {
        let output = if msg.len() == MESSAGE_HEADER_SIZE {
            Output::Ask(msg, self.input.clone())
        } else {
            Output::Msg(msg)
        };
        self.tx.send(output).await.map_err(|_| MessageSendError)
    }

    /// Asks the output task to flush the underlying sender and waits until
    /// it has done so.
    ///
    /// # Errors
    ///
    /// Returns [`MessageSendError`] in either of these cases:
    /// * the output task's receiver is closed or gone;
    /// * the output task drops the completion signal without sending it,
    ///   which happens when the flush itself fails.
    async fn flush(&self) -> Result<(), MessageSendError> {
        let (tx, rx) = oneshot::channel();
        self.tx
            .send(Output::Flush(tx))
            .await
            .map_err(|_| MessageSendError)?;
        rx.await.map_err(|_| MessageSendError)
    }
}
#[cfg(test)]
mod test {
    use std::time::Duration;
    use crate::{
        message::{allocate_message, InstanceId, MessageTag, MsgId},
        relay::SimpleMessageRelay,
    };
    use super::*;

    /// Builds a test message for `id` via `allocate_message`. The remaining
    /// arguments are always the same: a 10 s duration, `0`, and the byte
    /// slice `[0, 255]`.
    fn mk_msg(id: &MsgId) -> Bytes {
        allocate_message(id, Duration::from_secs(10), 0, &[0, 255])
    }

    /// End-to-end check of the mux:
    /// * ask replies are routed to the asking internal connection;
    /// * messages sent through the mux reach the underlying relay;
    /// * asks arriving from the relay are broadcast to every internal
    ///   connection.
    #[tokio::test(flavor = "multi_thread")]
    async fn mux() {
        let sk = &[1];
        let instance = InstanceId::from([1u8; 32]);
        // Shared in-process relay with two direct connections. `_c1` is bound
        // (not `_`), so it stays alive for the whole test, but it is unused.
        let s = SimpleMessageRelay::new();
        let _c1 = s.connect();
        let mut c2 = s.connect();
        // The mux sits on a third connection, with output and ask buffers of 1.
        let mux = MsgRelayMux::new(s.connection(true), 1, 1);
        // Two internal connections multiplexed over that single connection.
        let mut m1 = mux.connect().await.unwrap();
        let mut m2 = mux.connect().await.unwrap();

        // 1. m1 asks for msg_0 through the mux; when c2 publishes msg_0, the
        //    reply must be delivered to m1.
        let msg_0_id = MsgId::new(&instance, sk, None, MessageTag::tag(0));
        let msg_0 = mk_msg(&msg_0_id);
        m1.ask(&msg_0_id, Duration::from_secs(10)).await.unwrap();
        c2.send(msg_0.clone()).await.unwrap();
        let msg_0_in = m1.next().await.unwrap();
        assert_eq!(msg_0, msg_0_in);

        // 2. m2 publishes msg_1 through the mux; c2 then asks for it directly
        //    and must receive it.
        //    NOTE(review): relies on the underlying relay answering an ask for
        //    an already-published message; confirm in SimpleMessageRelay.
        let msg_1_id = MsgId::new(&instance, sk, None, MessageTag::tag(1));
        let msg_1 = mk_msg(&msg_1_id);
        m2.send(msg_1.clone()).await.unwrap();
        c2.ask(&msg_1_id, Duration::from_secs(10)).await.unwrap();
        let msg_1_in = c2.next().await.unwrap();
        assert_eq!(msg_1, msg_1_in);

        // 3. c2 asks for msg_2, which nobody publishes. The ask reaches the
        //    mux as a header-sized message and is broadcast, so m1 and m2
        //    must both observe the same ask, carrying msg_2_id.
        let msg_2_id = MsgId::new(&instance, sk, None, MessageTag::tag(2));
        c2.ask(&msg_2_id, Duration::from_secs(10)).await.unwrap();
        let ask_m1 = m1.next().await.unwrap();
        let ask_m2 = m2.next().await.unwrap();
        assert_eq!(ask_m1, ask_m2);
        assert_eq!(MsgId::try_from(ask_m1.as_ref()).unwrap(), msg_2_id);
    }
}