commonware_sync/net/
io.rs1use crate::{
2 net::{request_id::RequestId, Message, MAX_MESSAGE_SIZE},
3 Error,
4};
5use commonware_macros::select;
6use commonware_runtime::{Handle, Sink, Spawner, Stream};
7use commonware_stream::utils::codec::{recv_frame, send_frame};
8use futures::{
9 channel::{mpsc, oneshot},
10 StreamExt,
11};
12use std::collections::HashMap;
13
/// Buffer size (64) passed to `mpsc::channel` when creating the queue of
/// requests consumed by the connection I/O task.
const REQUEST_BUFFER_SIZE: usize = 1 << 6;
15
16pub(super) struct Request<M: Message> {
18 pub(super) request: M,
19 pub(super) response_tx: oneshot::Sender<Result<M, Error>>,
20}
21
22async fn run_loop<Si, St, M>(
26 mut sink: Si,
27 mut stream: St,
28 mut request_rx: mpsc::Receiver<Request<M>>,
29 mut pending_requests: HashMap<RequestId, oneshot::Sender<Result<M, Error>>>,
30) where
31 Si: Sink,
32 St: Stream,
33 M: Message,
34{
35 loop {
36 select! {
37 outgoing = request_rx.next() => {
38 match outgoing {
39 Some(Request { request, response_tx }) => {
40 let request_id = request.request_id();
41 pending_requests.insert(request_id, response_tx);
42 let data = request.encode().to_vec();
43 if let Err(e) = send_frame(&mut sink, &data, MAX_MESSAGE_SIZE).await {
44 if let Some(sender) = pending_requests.remove(&request_id) {
45 let _ = sender.send(Err(Error::Network(e)));
46 }
47 return;
48 }
49 },
50 None => return,
51 }
52 },
53 incoming = recv_frame(&mut stream, MAX_MESSAGE_SIZE) => {
54 match incoming {
55 Ok(response_data) => {
56 match M::decode(&response_data[..]) {
57 Ok(message) => {
58 let request_id = message.request_id();
59 if let Some(sender) = pending_requests.remove(&request_id) {
60 let _ = sender.send(Ok(message));
61 }
62 },
63 Err(_) => { }
64 }
65 },
66 Err(_e) => {
67 for (_, sender) in pending_requests.drain() {
68 let _ = sender.send(Err(Error::RequestChannelClosed));
69 }
70 return;
71 }
72 }
73 }
74 }
75 }
76}
77
78pub(super) fn run<E, Si, St, M>(
82 context: E,
83 sink: Si,
84 stream: St,
85) -> Result<(mpsc::Sender<Request<M>>, Handle<()>), commonware_runtime::Error>
86where
87 E: Spawner,
88 Si: Sink,
89 St: Stream,
90 M: Message,
91{
92 let (request_tx, request_rx) = mpsc::channel(REQUEST_BUFFER_SIZE);
93 let handle = context.spawn(move |_| run_loop(sink, stream, request_rx, HashMap::new()));
94 Ok((request_tx, handle))
95}