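//! I/O task for `commonware_sync/net/io.rs`: writes requests to a sink and
//! routes decoded responses from a stream back to the callers awaiting them.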

1use crate::{
2    net::{request_id::RequestId, Message, MAX_MESSAGE_SIZE},
3    Error,
4};
5use commonware_macros::select_loop;
6use commonware_runtime::{Handle, IoBufs, Sink, Spawner, Stream};
7use commonware_stream::utils::codec::{recv_frame, send_frame};
8use commonware_utils::channel::{mpsc, oneshot};
9use std::collections::HashMap;
10use tracing::debug;
11
/// Capacity of the channel that carries requests from callers to the I/O task.
const REQUEST_BUFFER_SIZE: usize = 64;

/// Capacity of the channel that carries raw frames from the recv task to the I/O task.
const RECV_BUFFER_SIZE: usize = 64;
14
15/// A request and callback for a response.
16pub(super) struct Request<M: Message> {
17    pub(super) request: M,
18    pub(super) response_tx: oneshot::Sender<Result<M, Error>>,
19}
20
21/// Dedicated recv task: reads frames from the stream and forwards them on a
22/// channel. Runs in its own task so that `recv_frame` is never cancelled by
23/// `select!` (cancelling a partially-read frame corrupts the stream).
24async fn recv_loop<St: Stream>(mut stream: St, tx: mpsc::Sender<IoBufs>) {
25    loop {
26        match recv_frame(&mut stream, MAX_MESSAGE_SIZE).await {
27            Ok(data) => {
28                if tx.send(data).await.is_err() {
29                    return;
30                }
31            }
32            Err(_) => return,
33        }
34    }
35}
36
37/// Run the I/O loop which:
38/// - Receives requests from the request channel and sends them to the sink.
39/// - Receives responses (via channel from the recv task) and forwards them to
40///   their callback channel.
41///
42/// Both select branches (`request_rx.recv()` and `response_rx.recv()`) are
43/// cancellation-safe, unlike `recv_frame`.
44async fn run_loop<E, Si, St, M>(
45    context: E,
46    mut sink: Si,
47    stream: St,
48    mut request_rx: mpsc::Receiver<Request<M>>,
49    mut pending_requests: HashMap<RequestId, oneshot::Sender<Result<M, Error>>>,
50) where
51    E: Spawner + Clone,
52    Si: Sink,
53    St: Stream,
54    M: Message,
55{
56    let (response_tx, mut response_rx) = mpsc::channel(RECV_BUFFER_SIZE);
57
58    // Spawn dedicated recv task so recv_frame is never cancelled.
59    let recv_handle = context
60        .clone()
61        .spawn(move |_| recv_loop(stream, response_tx));
62
63    select_loop! {
64        context,
65        on_stopped => {
66            debug!("context shutdown, terminating I/O task");
67            recv_handle.abort();
68        },
69        Some(Request {
70            request,
71            response_tx,
72        }) = request_rx.recv() else {
73            recv_handle.abort();
74            return;
75        } => {
76            let request_id = request.request_id();
77            pending_requests.insert(request_id, response_tx);
78            let data = request.encode();
79            if let Err(e) = send_frame(&mut sink, data, MAX_MESSAGE_SIZE).await {
80                if let Some(sender) = pending_requests.remove(&request_id) {
81                    let _ = sender.send(Err(Error::Network(e)));
82                }
83                recv_handle.abort();
84                return;
85            }
86        },
87        Some(response_data) = response_rx.recv() else {
88            for (_, sender) in pending_requests.drain() {
89                let _ = sender.send(Err(Error::RequestChannelClosed));
90            }
91            return;
92        } => {
93            match M::decode(response_data.coalesce()) {
94                Ok(message) => {
95                    let request_id = message.request_id();
96                    if let Some(sender) = pending_requests.remove(&request_id) {
97                        let _ = sender.send(Ok(message));
98                    }
99                }
100                Err(_) => { /* ignore */ }
101            }
102        },
103    }
104}
105
106/// Starts the I/O task and returns a sender for requests and a handle to the task.
107/// The I/O task is responsible for sending and receiving messages over the network.
108/// The I/O task uses a oneshot channel to send responses back to the caller.
109pub(super) fn run<E, Si, St, M>(
110    context: E,
111    sink: Si,
112    stream: St,
113) -> Result<(mpsc::Sender<Request<M>>, Handle<()>), commonware_runtime::Error>
114where
115    E: Spawner,
116    Si: Sink,
117    St: Stream,
118    M: Message,
119{
120    let (request_tx, request_rx) = mpsc::channel(REQUEST_BUFFER_SIZE);
121    let handle =
122        context.spawn(move |context| run_loop(context, sink, stream, request_rx, HashMap::new()));
123    Ok((request_tx, handle))
124}