use crate::{
net::{request_id::RequestId, Message, MAX_MESSAGE_SIZE},
Error,
};
use commonware_macros::select_loop;
use commonware_runtime::{Handle, IoBufs, Sink, Spawner, Stream};
use commonware_stream::utils::codec::{recv_frame, send_frame};
use commonware_utils::channel::{mpsc, oneshot};
use std::collections::HashMap;
use tracing::debug;
// Capacity of the bounded channel on which callers submit `Request`s to the
// I/O task (created in `run`).
const REQUEST_BUFFER_SIZE: usize = 64;
// Capacity of the bounded channel that carries received frames from
// `recv_loop` to `run_loop`.
const RECV_BUFFER_SIZE: usize = 64;
/// A request handed to the I/O task together with the channel on which its
/// outcome is reported.
pub(super) struct Request<M: Message> {
/// The message to encode and write to the sink.
pub(super) request: M,
/// One-shot sender that receives either the response whose request id matches
/// this request, or an error if the exchange fails.
pub(super) response_tx: oneshot::Sender<Result<M, Error>>,
}
/// Pumps frames from `stream` into `tx`.
///
/// Each frame is read with `recv_frame`, bounded by `MAX_MESSAGE_SIZE`, and
/// forwarded unchanged. The task finishes quietly as soon as either reading a
/// frame fails or sending on `tx` fails; no error is reported to the caller.
async fn recv_loop<St: Stream>(mut stream: St, tx: mpsc::Sender<IoBufs>) {
    // A read error ends the `while let`; a failed send breaks out explicitly.
    while let Ok(frame) = recv_frame(&mut stream, MAX_MESSAGE_SIZE).await {
        if tx.send(frame).await.is_err() {
            break;
        }
    }
}
/// Drives the connection's I/O: writes outgoing requests to `sink` and routes
/// incoming responses to whoever is waiting for them.
///
/// A helper task running `recv_loop` is spawned to read frames from `stream`
/// and pass them back over a bounded channel of `RECV_BUFFER_SIZE` entries.
/// `pending_requests` maps the id of every in-flight request to the sender on
/// which its response is delivered.
///
/// The function returns when:
/// - `request_rx` yields `None` (the receive task is aborted first),
/// - writing a frame to `sink` fails (the request whose write failed is
///   answered with `Error::Network` and the receive task is aborted), or
/// - the channel fed by the receive task yields `None` (every pending request
///   is answered with `Error::RequestChannelClosed`).
///
/// The `on_stopped` arm aborts the receive task when the context shuts down.
///
/// Responses that fail to decode are logged at debug level and dropped;
/// responses whose id has no pending entry are dropped silently.
async fn run_loop<E, Si, St, M>(
    context: E,
    mut sink: Si,
    stream: St,
    mut request_rx: mpsc::Receiver<Request<M>>,
    mut pending_requests: HashMap<RequestId, oneshot::Sender<Result<M, Error>>>,
) where
    E: Spawner + Clone,
    Si: Sink,
    St: Stream,
    M: Message,
{
    // Reading happens on its own task so that a slow write never blocks reads.
    let (response_tx, mut response_rx) = mpsc::channel(RECV_BUFFER_SIZE);
    let recv_handle = context
        .clone()
        .spawn(move |_| recv_loop(stream, response_tx));

    select_loop! {
        context,
        on_stopped => {
            debug!("context shutdown, terminating I/O task");
            recv_handle.abort();
        },
        Some(Request {
            request,
            response_tx,
        }) = request_rx.recv() else {
            // No more requests can arrive: stop reading and exit.
            recv_handle.abort();
            return;
        } => {
            // Register the responder before writing the frame.
            let request_id = request.request_id();
            pending_requests.insert(request_id, response_tx);

            let data = request.encode();
            if let Err(e) = send_frame(&mut sink, data, MAX_MESSAGE_SIZE).await {
                // Only the request whose write failed is answered explicitly;
                // the remaining senders are dropped when this function returns.
                if let Some(sender) = pending_requests.remove(&request_id) {
                    let _ = sender.send(Err(Error::Network(e)));
                }
                recv_handle.abort();
                return;
            }
        },
        Some(response_data) = response_rx.recv() else {
            // The receive task has ended, so no response can arrive anymore:
            // fail everything that is still waiting.
            for (_, sender) in pending_requests.drain() {
                let _ = sender.send(Err(Error::RequestChannelClosed));
            }
            return;
        } => {
            match M::decode(response_data.coalesce()) {
                Ok(message) => {
                    let request_id = message.request_id();
                    // A response with an unknown id is ignored.
                    if let Some(sender) = pending_requests.remove(&request_id) {
                        // The requester may have stopped waiting; that is fine.
                        let _ = sender.send(Ok(message));
                    }
                }
                Err(_) => {
                    // The request id cannot be recovered from an undecodable
                    // frame, so no pending request can be failed here. Leave a
                    // trace instead of discarding the frame invisibly.
                    debug!("failed to decode response frame, dropping it");
                }
            }
        },
    }
}
/// Starts the I/O task for a connection.
///
/// Creates the bounded request channel (`REQUEST_BUFFER_SIZE` entries) and
/// spawns `run_loop` on `context` with `sink`, `stream`, the channel's
/// receiving half and an empty table of pending requests.
///
/// # Returns
///
/// The sending half of the request channel together with the handle of the
/// spawned task. The visible code always returns `Ok`.
pub(super) fn run<E, Si, St, M>(
    context: E,
    sink: Si,
    stream: St,
) -> Result<(mpsc::Sender<Request<M>>, Handle<()>), commonware_runtime::Error>
where
    E: Spawner,
    Si: Sink,
    St: Stream,
    M: Message,
{
    let (submit_tx, submit_rx) = mpsc::channel(REQUEST_BUFFER_SIZE);
    // The task starts with no requests in flight.
    let in_flight = HashMap::new();
    let task = context.spawn(move |ctx| run_loop(ctx, sink, stream, submit_rx, in_flight));
    Ok((submit_tx, task))
}