Skip to main content

commonware_sync/net/
io.rs

1use crate::{
2    net::{request_id::RequestId, Message, MAX_MESSAGE_SIZE},
3    Error,
4};
5use commonware_macros::select_loop;
6use commonware_runtime::{Handle, Sink, Spawner, Stream};
7use commonware_stream::utils::codec::{recv_frame, send_frame};
8use commonware_utils::channel::{mpsc, oneshot};
9use std::collections::HashMap;
10use tracing::debug;
11
/// Buffer size handed to `mpsc::channel` when [`run`] creates the request channel, i.e. how
/// many requests may be queued for the I/O task at once.
const REQUEST_BUFFER_SIZE: usize = 64;
13
14/// A request and callback for a response.
15pub(super) struct Request<M: Message> {
16    pub(super) request: M,
17    pub(super) response_tx: oneshot::Sender<Result<M, Error>>,
18}
19
20/// Run the I/O loop which:
21/// - Receives requests from the request channel and sends them to the sink.
22/// - Receives responses from the stream and forwards them to their callback channel.
23async fn run_loop<E, Si, St, M>(
24    context: E,
25    mut sink: Si,
26    mut stream: St,
27    mut request_rx: mpsc::Receiver<Request<M>>,
28    mut pending_requests: HashMap<RequestId, oneshot::Sender<Result<M, Error>>>,
29) where
30    E: Spawner,
31    Si: Sink,
32    St: Stream,
33    M: Message,
34{
35    select_loop! {
36        context,
37        on_stopped => {
38            debug!("context shutdown, terminating I/O task");
39        },
40        Some(Request {
41            request,
42            response_tx,
43        }) = request_rx.recv() else return => {
44            let request_id = request.request_id();
45            pending_requests.insert(request_id, response_tx);
46            let data = request.encode();
47            if let Err(e) = send_frame(&mut sink, data, MAX_MESSAGE_SIZE).await {
48                if let Some(sender) = pending_requests.remove(&request_id) {
49                    let _ = sender.send(Err(Error::Network(e)));
50                }
51                return;
52            }
53        },
54        Ok(response_data) = recv_frame(&mut stream, MAX_MESSAGE_SIZE) else {
55            for (_, sender) in pending_requests.drain() {
56                let _ = sender.send(Err(Error::RequestChannelClosed));
57            }
58            return;
59        } => {
60            match M::decode(response_data.coalesce()) {
61                Ok(message) => {
62                    let request_id = message.request_id();
63                    if let Some(sender) = pending_requests.remove(&request_id) {
64                        let _ = sender.send(Ok(message));
65                    }
66                }
67                Err(_) => { /* ignore */ }
68            }
69        },
70    }
71}
72
73/// Starts the I/O task and returns a sender for requests and a handle to the task.
74/// The I/O task is responsible for sending and receiving messages over the network.
75/// The I/O task uses a oneshot channel to send responses back to the caller.
76pub(super) fn run<E, Si, St, M>(
77    context: E,
78    sink: Si,
79    stream: St,
80) -> Result<(mpsc::Sender<Request<M>>, Handle<()>), commonware_runtime::Error>
81where
82    E: Spawner,
83    Si: Sink,
84    St: Stream,
85    M: Message,
86{
87    let (request_tx, request_rx) = mpsc::channel(REQUEST_BUFFER_SIZE);
88    let handle =
89        context.spawn(move |context| run_loop(context, sink, stream, request_rx, HashMap::new()));
90    Ok((request_tx, handle))
91}