commonware_sync/net/
io.rs

1use crate::{
2    net::{request_id::RequestId, Message, MAX_MESSAGE_SIZE},
3    Error,
4};
5use commonware_macros::select_loop;
6use commonware_runtime::{Handle, Sink, Spawner, Stream};
7use commonware_stream::utils::codec::{recv_frame, send_frame};
8use futures::{
9    channel::{mpsc, oneshot},
10    StreamExt,
11};
12use std::collections::HashMap;
13use tracing::debug;
14
/// Buffer size passed to `mpsc::channel` when [`run`] creates the channel through which
/// callers submit [`Request`]s to the I/O task.
const REQUEST_BUFFER_SIZE: usize = 64;
16
/// A request to be sent by the I/O task, paired with the channel on which its response
/// is delivered.
///
/// The I/O task registers `response_tx` under `request.request_id()` and completes it
/// with the first decoded incoming message carrying the same request ID.
pub(super) struct Request<M: Message> {
    /// The message to encode and write to the sink as a single frame.
    pub(super) request: M,
    /// Completed with `Ok(response)` when a matching response arrives, or with `Err(..)`
    /// on some connection failures. The sender may also be dropped without sending a
    /// value (for example when the I/O task stops), so the receiving side should treat
    /// cancellation as a failure too.
    pub(super) response_tx: oneshot::Sender<Result<M, Error>>,
}
22
23/// Run the I/O loop which:
24/// - Receives requests from the request channel and sends them to the sink.
25/// - Receives responses from the stream and forwards them to their callback channel.
26async fn run_loop<E, Si, St, M>(
27    context: E,
28    mut sink: Si,
29    mut stream: St,
30    mut request_rx: mpsc::Receiver<Request<M>>,
31    mut pending_requests: HashMap<RequestId, oneshot::Sender<Result<M, Error>>>,
32) where
33    E: Spawner,
34    Si: Sink,
35    St: Stream,
36    M: Message,
37{
38    select_loop! {
39        context,
40        on_stopped => {
41            debug!("context shutdown, terminating I/O task");
42        },
43        outgoing = request_rx.next() => {
44            match outgoing {
45                Some(Request { request, response_tx }) => {
46                    let request_id = request.request_id();
47                    pending_requests.insert(request_id, response_tx);
48                    let data = request.encode().to_vec();
49                    if let Err(e) = send_frame(&mut sink, &data, MAX_MESSAGE_SIZE).await {
50                        if let Some(sender) = pending_requests.remove(&request_id) {
51                            let _ = sender.send(Err(Error::Network(e)));
52                        }
53                        return;
54                    }
55                },
56                None => return,
57            }
58        },
59        incoming = recv_frame(&mut stream, MAX_MESSAGE_SIZE) => {
60            match incoming {
61                Ok(response_data) => {
62                    match M::decode(&response_data[..]) {
63                        Ok(message) => {
64                            let request_id = message.request_id();
65                            if let Some(sender) = pending_requests.remove(&request_id) {
66                                let _ = sender.send(Ok(message));
67                            }
68                        },
69                        Err(_) => { /* ignore */ }
70                    }
71                },
72                Err(_e) => {
73                    for (_, sender) in pending_requests.drain() {
74                        let _ = sender.send(Err(Error::RequestChannelClosed));
75                    }
76                    return;
77                }
78            }
79        }
80    }
81}
82
83/// Starts the I/O task and returns a sender for requests and a handle to the task.
84/// The I/O task is responsible for sending and receiving messages over the network.
85/// The I/O task uses a oneshot channel to send responses back to the caller.
86pub(super) fn run<E, Si, St, M>(
87    context: E,
88    sink: Si,
89    stream: St,
90) -> Result<(mpsc::Sender<Request<M>>, Handle<()>), commonware_runtime::Error>
91where
92    E: Spawner,
93    Si: Sink,
94    St: Stream,
95    M: Message,
96{
97    let (request_tx, request_rx) = mpsc::channel(REQUEST_BUFFER_SIZE);
98    let handle =
99        context.spawn(move |context| run_loop(context, sink, stream, request_rx, HashMap::new()));
100    Ok((request_tx, handle))
101}