// commonware_sync/net/io.rs

1use crate::{
2    net::{request_id::RequestId, Message, MAX_MESSAGE_SIZE},
3    Error,
4};
5use commonware_macros::select;
6use commonware_runtime::{Handle, Sink, Spawner, Stream};
7use commonware_stream::utils::codec::{recv_frame, send_frame};
8use futures::{
9    channel::{mpsc, oneshot},
10    StreamExt,
11};
12use std::collections::HashMap;
13
14const REQUEST_BUFFER_SIZE: usize = 64;
15
16/// A request and callback for a response.
17pub(super) struct Request<M: Message> {
18    pub(super) request: M,
19    pub(super) response_tx: oneshot::Sender<Result<M, Error>>,
20}
21
22/// Run the I/O loop which:
23/// - Receives requests from the request channel and sends them to the sink.
24/// - Receives responses from the stream and forwards them to their callback channel.
25async fn run_loop<Si, St, M>(
26    mut sink: Si,
27    mut stream: St,
28    mut request_rx: mpsc::Receiver<Request<M>>,
29    mut pending_requests: HashMap<RequestId, oneshot::Sender<Result<M, Error>>>,
30) where
31    Si: Sink,
32    St: Stream,
33    M: Message,
34{
35    loop {
36        select! {
37            outgoing = request_rx.next() => {
38                match outgoing {
39                    Some(Request { request, response_tx }) => {
40                        let request_id = request.request_id();
41                        pending_requests.insert(request_id, response_tx);
42                        let data = request.encode().to_vec();
43                        if let Err(e) = send_frame(&mut sink, &data, MAX_MESSAGE_SIZE).await {
44                            if let Some(sender) = pending_requests.remove(&request_id) {
45                                let _ = sender.send(Err(Error::Network(e)));
46                            }
47                            return;
48                        }
49                    },
50                    None => return,
51                }
52            },
53            incoming = recv_frame(&mut stream, MAX_MESSAGE_SIZE) => {
54                match incoming {
55                    Ok(response_data) => {
56                        match M::decode(&response_data[..]) {
57                            Ok(message) => {
58                                let request_id = message.request_id();
59                                if let Some(sender) = pending_requests.remove(&request_id) {
60                                    let _ = sender.send(Ok(message));
61                                }
62                            },
63                            Err(_) => { /* ignore */ }
64                        }
65                    },
66                    Err(_e) => {
67                        for (_, sender) in pending_requests.drain() {
68                            let _ = sender.send(Err(Error::RequestChannelClosed));
69                        }
70                        return;
71                    }
72                }
73            }
74        }
75    }
76}
77
78/// Starts the I/O task and returns a sender for requests and a handle to the task.
79/// The I/O task is responsible for sending and receiving messages over the network.
80/// The I/O task uses a oneshot channel to send responses back to the caller.
81pub(super) fn run<E, Si, St, M>(
82    context: E,
83    sink: Si,
84    stream: St,
85) -> Result<(mpsc::Sender<Request<M>>, Handle<()>), commonware_runtime::Error>
86where
87    E: Spawner,
88    Si: Sink,
89    St: Stream,
90    M: Message,
91{
92    let (request_tx, request_rx) = mpsc::channel(REQUEST_BUFFER_SIZE);
93    let handle = context.spawn(move |_| run_loop(sink, stream, request_rx, HashMap::new()));
94    Ok((request_tx, handle))
95}