commonware_sync/net/
io.rs1use crate::{
2 net::{request_id::RequestId, Message, MAX_MESSAGE_SIZE},
3 Error,
4};
5use commonware_macros::select_loop;
6use commonware_runtime::{Handle, IoBufs, Sink, Spawner, Stream};
7use commonware_stream::utils::codec::{recv_frame, send_frame};
8use commonware_utils::channel::{mpsc, oneshot};
9use std::collections::HashMap;
10use tracing::{debug, warn};
11
/// Capacity of the channel through which callers submit [`Request`]s to the I/O task.
const REQUEST_BUFFER_SIZE: usize = 64;

/// Capacity of the channel that carries raw received frames from the receive loop
/// to the I/O task.
const RECV_BUFFER_SIZE: usize = 64;
14
15pub(super) struct Request<M: Message> {
17 pub(super) request: M,
18 pub(super) response_tx: oneshot::Sender<Result<M, Error>>,
19}
20
21async fn recv_loop<St: Stream>(mut stream: St, tx: mpsc::Sender<IoBufs>) {
25 loop {
26 match recv_frame(&mut stream, MAX_MESSAGE_SIZE).await {
27 Ok(data) => {
28 if tx.send(data).await.is_err() {
29 return;
30 }
31 }
32 Err(_) => return,
33 }
34 }
35}
36
37async fn run_loop<E, Si, St, M>(
45 context: E,
46 mut sink: Si,
47 stream: St,
48 mut request_rx: mpsc::Receiver<Request<M>>,
49 mut pending_requests: HashMap<RequestId, oneshot::Sender<Result<M, Error>>>,
50) where
51 E: Spawner,
52 Si: Sink,
53 St: Stream,
54 M: Message,
55{
56 let (response_tx, mut response_rx) = mpsc::channel(RECV_BUFFER_SIZE);
57
58 let recv_handle = context
60 .child("recv")
61 .spawn(move |_| recv_loop(stream, response_tx));
62
63 select_loop! {
64 context,
65 on_stopped => {
66 debug!("context shutdown, terminating I/O task");
67 recv_handle.abort();
68 },
69 Some(Request {
70 request,
71 response_tx,
72 }) = request_rx.recv() else {
73 recv_handle.abort();
74 return;
75 } => {
76 let request_id = request.request_id();
77 pending_requests.insert(request_id, response_tx);
78 let data = request.encode();
79 if let Err(e) = send_frame(&mut sink, data, MAX_MESSAGE_SIZE).await {
80 if let Some(sender) = pending_requests.remove(&request_id) {
81 let _ = sender.send(Err(Error::Network(e)));
82 }
83 recv_handle.abort();
84 return;
85 }
86 },
87 Some(response_data) = response_rx.recv() else {
88 for (_, sender) in pending_requests.drain() {
89 let _ = sender.send(Err(Error::RequestChannelClosed));
90 }
91 return;
92 } => match M::decode(response_data.coalesce()) {
93 Ok(message) => {
94 let request_id = message.request_id();
95 if let Some(sender) = pending_requests.remove(&request_id) {
96 let _ = sender.send(Ok(message));
97 }
98 }
99 Err(_) => {
100 recv_handle.abort();
101 warn!(
102 pending_count = pending_requests.len(),
103 "failed to decode response; terminating I/O task"
104 );
105 for (_, sender) in pending_requests.drain() {
106 let _ = sender.send(Err(Error::InvalidResponse));
107 }
108 return;
109 }
110 },
111 }
112}
113
114pub(super) fn run<E, Si, St, M>(
118 context: E,
119 sink: Si,
120 stream: St,
121) -> Result<(mpsc::Sender<Request<M>>, Handle<()>), commonware_runtime::Error>
122where
123 E: Spawner,
124 Si: Sink,
125 St: Stream,
126 M: Message,
127{
128 let (request_tx, request_rx) = mpsc::channel(REQUEST_BUFFER_SIZE);
129 let handle =
130 context.spawn(move |context| run_loop(context, sink, stream, request_rx, HashMap::new()));
131 Ok((request_tx, handle))
132}