airsim_client/
msgpack.rs

1#![allow(dead_code)]
2use async_std::channel::{unbounded, Receiver, Sender};
3use async_std::io::prelude::*;
4use async_std::net::{TcpStream, ToSocketAddrs};
5use async_std::sync::{Arc, Mutex};
6use async_std::task;
7use futures::future::FutureExt;
8use futures::select;
9use msgpack_rpc::message::{Message, Notification, Request, Response};
10use std::collections::HashMap;
11use std::io::Cursor;
12
13use crate::error::NetworkResult;
14use crate::NetworkError;
15
/// `MessagePack-RPC` client used to interface with the AirSim msgpack server.
///
/// All socket I/O is performed by a background task spawned in
/// [`MsgPackClient::connect`]; this handle only talks to that task through
/// channels. The fields are channel endpoints plus an `Arc`, so clones share
/// the same connection.
#[derive(Clone, Debug)]
pub struct MsgPackClient {
    /// Outgoing requests; the background task writes them to the socket.
    request_sender: Sender<Request>,
    /// Outgoing notifications; the background task writes them to the socket.
    notification_sender: Sender<Notification>,
    /// Notifications sent by the server, forwarded by the background task.
    pub notification_receiver: Receiver<Notification>,
    /// Requests initiated by the server, forwarded by the background task.
    pub request_receiver: Receiver<Request>,
    /// One sender per in-flight request id; the background task removes the
    /// entry and delivers the matching response through it.
    response_channels: Arc<Mutex<HashMap<u32, Sender<Response>>>>,
}
25
/// Work item selected by one iteration of the background I/O loop.
enum Rpc {
    /// A client message to serialize and write to the socket.
    Send(Message),
    /// Number of bytes the last socket read placed at the front of the read buffer.
    Receive(usize),
}
30
31impl MsgPackClient {
32    /// Establish a TCP socket connection to the `MessagePack-RPC` server
33    /// running in a background thread
34    pub async fn connect(addrs: impl ToSocketAddrs) -> NetworkResult<Self> {
35        let mut stream = TcpStream::connect(addrs).await?;
36        let response_channels = Arc::new(Mutex::new(HashMap::new()));
37
38        let (request_sender, request_receiver) = unbounded::<Request>();
39        let (inner_request_sender, inner_request_receiver) = unbounded::<Request>();
40        let (notification_sender, notification_receiver) = unbounded::<Notification>();
41        let (inner_notification_sender, inner_notification_receiver) = unbounded::<Notification>();
42        let res_channels = Arc::clone(&response_channels);
43
44        task::spawn(async move {
45            let mut current_message: Vec<u8> = vec![];
46
47            // 1,024 bytes = 1 kB
48            // 1kB x 1000 = 1mB
49            let buf_size: usize = 1024 * 50; // 0.1mB
50
51            // for some reason, msgpack expects a fixed size
52            // for the bytes buffer
53            let mut buf = vec![0_u8; buf_size];
54
55            loop {
56                let to_process = select! {
57                    maybe_request = request_receiver.recv().fuse() => {
58                        if let Ok(request) = maybe_request {
59                            Some(Rpc::Send(Message::Request(request)))
60                        } else {
61                            None
62                        }
63                    },
64                    maybe_notification = notification_receiver.recv().fuse() => {
65                        if let Ok(notification) = maybe_notification {
66                            Some(Rpc::Send(Message::Notification(notification)))
67                        } else {
68                            None
69                        }
70                    },
71                    maybe_bytes_read = stream.read(&mut buf).fuse() => {
72                        if let Ok(bytes_read) = maybe_bytes_read {
73                            Some(Rpc::Receive(bytes_read))
74                        } else {
75                            None
76                        }
77                    }
78                };
79                match to_process {
80                    Some(Rpc::Send(m)) => {
81                        let message = m.pack().expect("Couldn't serialize message");
82                        stream.write_all(&message).await.expect("Couldn't send message");
83                    }
84                    Some(Rpc::Receive(n)) => {
85                        current_message.extend(&buf[..n]);
86                        let mut frame = Cursor::new(current_message.clone());
87
88                        let recv_res = match Message::decode(&mut frame) {
89                            Ok(Message::Notification(n)) => inner_notification_sender
90                                .send(n)
91                                .await
92                                .map_err(|e| NetworkError::Send { message: e.to_string() }),
93                            Ok(Message::Request(r)) => inner_request_sender
94                                .send(r)
95                                .await
96                                .map_err(|e| NetworkError::Send { message: e.to_string() }),
97                            Ok(Message::Response(r)) => {
98                                let mut senders = res_channels.lock().await;
99                                let sender: Sender<Response> =
100                                    senders.remove(&r.id).expect("Got response but no request awaiting it");
101
102                                // send response to the `request` function
103                                sender
104                                    .send(r)
105                                    .await
106                                    .map_err(|e| NetworkError::Send { message: e.to_string() })
107                            }
108                            Err(e) => {
109                                // DecodeError
110                                panic!("{e}");
111                            }
112                        };
113
114                        // if error, return it
115                        if let Err(e) = recv_res {
116                            return e;
117                        }
118
119                        #[allow(clippy::cast_possible_truncation)]
120                        {
121                            let (_, remaining) = current_message.split_at(frame.position() as usize);
122                            current_message = remaining.to_vec();
123                        }
124                    }
125                    None => {}
126                }
127            }
128        });
129        Ok(Self {
130            request_sender,
131            notification_sender,
132            notification_receiver: inner_notification_receiver,
133            request_receiver: inner_request_receiver,
134            response_channels,
135        })
136    }
137
138    pub async fn request(&self, request: Request) -> Result<Response, NetworkError> {
139        let (response_sender, response_receiver) = unbounded();
140
141        // add the response sender (forwards the response from the server) by request id
142        let _ = self.response_channels.lock().await.insert(request.id, response_sender);
143
144        // forward request to the thread that then forwards it to the MessagePack-RPC server
145        // the response is added to the response channel
146        let send_res = self.request_sender.send(request).await;
147        if send_res.is_err() {
148            let e = format!("Failed to send request: {:?}", send_res);
149            return Err(NetworkError::Send { message: e });
150        }
151
152        // return result from request which is forwarded from the background thread above
153        response_receiver.recv().await.map_err(NetworkError::Recv)
154    }
155
156    pub async fn _notify(&self, notification: Notification) -> Result<(), NetworkError> {
157        let res = self.notification_sender.send(notification.to_owned()).await;
158        if res.is_err() {
159            let e = format!("Failed to send notification: {:?}", notification);
160            return Err(NetworkError::Send { message: e });
161        }
162        Ok(())
163    }
164}