commonware_sync/net/
io.rs1use crate::{
2 net::{request_id::RequestId, Message, MAX_MESSAGE_SIZE},
3 Error,
4};
5use commonware_macros::select_loop;
6use commonware_runtime::{Handle, Sink, Spawner, Stream};
7use commonware_stream::utils::codec::{recv_frame, send_frame};
8use futures::{
9 channel::{mpsc, oneshot},
10 StreamExt,
11};
12use std::collections::HashMap;
13use tracing::debug;
14
15const REQUEST_BUFFER_SIZE: usize = 64;
16
17pub(super) struct Request<M: Message> {
19 pub(super) request: M,
20 pub(super) response_tx: oneshot::Sender<Result<M, Error>>,
21}
22
23async fn run_loop<E, Si, St, M>(
27 context: E,
28 mut sink: Si,
29 mut stream: St,
30 mut request_rx: mpsc::Receiver<Request<M>>,
31 mut pending_requests: HashMap<RequestId, oneshot::Sender<Result<M, Error>>>,
32) where
33 E: Spawner,
34 Si: Sink,
35 St: Stream,
36 M: Message,
37{
38 select_loop! {
39 context,
40 on_stopped => {
41 debug!("context shutdown, terminating I/O task");
42 },
43 outgoing = request_rx.next() => {
44 match outgoing {
45 Some(Request { request, response_tx }) => {
46 let request_id = request.request_id();
47 pending_requests.insert(request_id, response_tx);
48 let data = request.encode().to_vec();
49 if let Err(e) = send_frame(&mut sink, &data, MAX_MESSAGE_SIZE).await {
50 if let Some(sender) = pending_requests.remove(&request_id) {
51 let _ = sender.send(Err(Error::Network(e)));
52 }
53 return;
54 }
55 },
56 None => return,
57 }
58 },
59 incoming = recv_frame(&mut stream, MAX_MESSAGE_SIZE) => {
60 match incoming {
61 Ok(response_data) => {
62 match M::decode(&response_data[..]) {
63 Ok(message) => {
64 let request_id = message.request_id();
65 if let Some(sender) = pending_requests.remove(&request_id) {
66 let _ = sender.send(Ok(message));
67 }
68 },
69 Err(_) => { }
70 }
71 },
72 Err(_e) => {
73 for (_, sender) in pending_requests.drain() {
74 let _ = sender.send(Err(Error::RequestChannelClosed));
75 }
76 return;
77 }
78 }
79 }
80 }
81}
82
83pub(super) fn run<E, Si, St, M>(
87 context: E,
88 sink: Si,
89 stream: St,
90) -> Result<(mpsc::Sender<Request<M>>, Handle<()>), commonware_runtime::Error>
91where
92 E: Spawner,
93 Si: Sink,
94 St: Stream,
95 M: Message,
96{
97 let (request_tx, request_rx) = mpsc::channel(REQUEST_BUFFER_SIZE);
98 let handle =
99 context.spawn(move |context| run_loop(context, sink, stream, request_rx, HashMap::new()));
100 Ok((request_tx, handle))
101}