1#![allow(dead_code)]
2use async_std::channel::{unbounded, Receiver, Sender};
3use async_std::io::prelude::*;
4use async_std::net::{TcpStream, ToSocketAddrs};
5use async_std::sync::{Arc, Mutex};
6use async_std::task;
7use futures::future::FutureExt;
8use futures::select;
9use msgpack_rpc::message::{Message, Notification, Request, Response};
10use std::collections::HashMap;
11use std::io::Cursor;
12
13use crate::error::NetworkResult;
14use crate::NetworkError;
15
16#[derive(Clone, Debug)]
18pub struct MsgPackClient {
19 request_sender: Sender<Request>,
20 notification_sender: Sender<Notification>,
21 pub notification_receiver: Receiver<Notification>,
22 pub request_receiver: Receiver<Request>,
23 response_channels: Arc<Mutex<HashMap<u32, Sender<Response>>>>,
24}
25
26enum Rpc {
27 Send(Message),
28 Receive(usize),
29}
30
31impl MsgPackClient {
32 pub async fn connect(addrs: impl ToSocketAddrs) -> NetworkResult<Self> {
35 let mut stream = TcpStream::connect(addrs).await?;
36 let response_channels = Arc::new(Mutex::new(HashMap::new()));
37
38 let (request_sender, request_receiver) = unbounded::<Request>();
39 let (inner_request_sender, inner_request_receiver) = unbounded::<Request>();
40 let (notification_sender, notification_receiver) = unbounded::<Notification>();
41 let (inner_notification_sender, inner_notification_receiver) = unbounded::<Notification>();
42 let res_channels = Arc::clone(&response_channels);
43
44 task::spawn(async move {
45 let mut current_message: Vec<u8> = vec![];
46
47 let buf_size: usize = 1024 * 50; let mut buf = vec![0_u8; buf_size];
54
55 loop {
56 let to_process = select! {
57 maybe_request = request_receiver.recv().fuse() => {
58 if let Ok(request) = maybe_request {
59 Some(Rpc::Send(Message::Request(request)))
60 } else {
61 None
62 }
63 },
64 maybe_notification = notification_receiver.recv().fuse() => {
65 if let Ok(notification) = maybe_notification {
66 Some(Rpc::Send(Message::Notification(notification)))
67 } else {
68 None
69 }
70 },
71 maybe_bytes_read = stream.read(&mut buf).fuse() => {
72 if let Ok(bytes_read) = maybe_bytes_read {
73 Some(Rpc::Receive(bytes_read))
74 } else {
75 None
76 }
77 }
78 };
79 match to_process {
80 Some(Rpc::Send(m)) => {
81 let message = m.pack().expect("Couldn't serialize message");
82 stream.write_all(&message).await.expect("Couldn't send message");
83 }
84 Some(Rpc::Receive(n)) => {
85 current_message.extend(&buf[..n]);
86 let mut frame = Cursor::new(current_message.clone());
87
88 let recv_res = match Message::decode(&mut frame) {
89 Ok(Message::Notification(n)) => inner_notification_sender
90 .send(n)
91 .await
92 .map_err(|e| NetworkError::Send { message: e.to_string() }),
93 Ok(Message::Request(r)) => inner_request_sender
94 .send(r)
95 .await
96 .map_err(|e| NetworkError::Send { message: e.to_string() }),
97 Ok(Message::Response(r)) => {
98 let mut senders = res_channels.lock().await;
99 let sender: Sender<Response> =
100 senders.remove(&r.id).expect("Got response but no request awaiting it");
101
102 sender
104 .send(r)
105 .await
106 .map_err(|e| NetworkError::Send { message: e.to_string() })
107 }
108 Err(e) => {
109 panic!("{e}");
111 }
112 };
113
114 if let Err(e) = recv_res {
116 return e;
117 }
118
119 #[allow(clippy::cast_possible_truncation)]
120 {
121 let (_, remaining) = current_message.split_at(frame.position() as usize);
122 current_message = remaining.to_vec();
123 }
124 }
125 None => {}
126 }
127 }
128 });
129 Ok(Self {
130 request_sender,
131 notification_sender,
132 notification_receiver: inner_notification_receiver,
133 request_receiver: inner_request_receiver,
134 response_channels,
135 })
136 }
137
138 pub async fn request(&self, request: Request) -> Result<Response, NetworkError> {
139 let (response_sender, response_receiver) = unbounded();
140
141 let _ = self.response_channels.lock().await.insert(request.id, response_sender);
143
144 let send_res = self.request_sender.send(request).await;
147 if send_res.is_err() {
148 let e = format!("Failed to send request: {:?}", send_res);
149 return Err(NetworkError::Send { message: e });
150 }
151
152 response_receiver.recv().await.map_err(NetworkError::Recv)
154 }
155
156 pub async fn _notify(&self, notification: Notification) -> Result<(), NetworkError> {
157 let res = self.notification_sender.send(notification.to_owned()).await;
158 if res.is_err() {
159 let e = format!("Failed to send notification: {:?}", notification);
160 return Err(NetworkError::Send { message: e });
161 }
162 Ok(())
163 }
164}