1use std::{
3 self,
4 convert::TryInto,
5 io::{self, Cursor, ErrorKind, Read, Write},
6 sync::Arc,
7};
8
9use futures::{
10 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
11 lock::Mutex,
12};
13use rmpv::{decode::read_value, encode::write_value, Value};
14
15use crate::error::{DecodeError, EncodeError};
16
17#[derive(Debug, PartialEq, Clone)]
20pub enum RpcMessage {
21 RpcRequest {
22 msgid: u64,
23 method: String,
24 params: Vec<Value>,
25 }, RpcResponse {
27 msgid: u64,
28 error: Value,
29 result: Value,
30 }, RpcNotification {
32 method: String,
33 params: Vec<Value>,
34 }, }
36
/// Builds a `Value` array from a comma-separated list of expressions.
///
/// Every argument is converted with `Value::from`, and the resulting vector
/// is converted into a `Value` again.
macro_rules! rpc_args {
    ($($arg:expr), *) => {
        Value::from(vec![$(Value::from($arg)),*])
    };
}
45
46pub async fn decode<R: AsyncRead + Send + Unpin + 'static>(
53 reader: &mut R,
54 rest: &mut Vec<u8>,
55) -> std::result::Result<RpcMessage, Box<DecodeError>> {
56 let mut buf = Box::new([0_u8; 80 * 1024]);
57 let mut bytes_read;
58
59 loop {
60 let mut c = Cursor::new(&rest);
61
62 match decode_buffer(&mut c).map_err(|b| *b) {
63 Ok(msg) => {
64 let pos = c.position();
65 *rest = rest.split_off(pos as usize); return Ok(msg);
67 }
68 Err(DecodeError::BufferError(e))
69 if e.kind() == ErrorKind::UnexpectedEof =>
70 {
71 debug!("Not enough data, reading more!");
72 bytes_read = reader.read(&mut *buf).await;
73 }
74 Err(err) => return Err(err.into()),
75 }
76
77 match bytes_read {
78 Ok(n) if n == 0 => {
79 return Err(io::Error::new(ErrorKind::UnexpectedEof, "EOF").into());
80 }
81 Ok(n) => {
82 rest.extend_from_slice(&buf[..n]);
83 }
84 Err(err) => return Err(err.into()),
85 }
86 }
87}
88
89fn decode_buffer<R: Read>(
92 reader: &mut R,
93) -> std::result::Result<RpcMessage, Box<DecodeError>> {
94 use crate::error::InvalidMessage::*;
95
96 let arr: Vec<Value> = read_value(reader)?.try_into().map_err(NotAnArray)?;
97
98 let mut arr = arr.into_iter();
99
100 let msgtyp: u64 = arr
101 .next()
102 .ok_or(WrongArrayLength(3..=4, 0))?
103 .try_into()
104 .map_err(InvalidType)?;
105
106 match msgtyp {
107 0 => {
108 let msgid: u64 = arr
109 .next()
110 .ok_or(WrongArrayLength(4..=4, 1))?
111 .try_into()
112 .map_err(InvalidMsgid)?;
113 let method = match arr.next() {
114 Some(Value::String(s)) if s.is_str() => {
115 s.into_str().expect("Can remove using #230 of rmpv")
116 }
117 Some(val) => return Err(InvalidRequestName(msgid, val).into()),
118 None => return Err(WrongArrayLength(4..=4, 2).into()),
119 };
120 let params: Vec<Value> = arr
121 .next()
122 .ok_or(WrongArrayLength(4..=4, 3))?
123 .try_into()
124 .map_err(|val| InvalidParams(val, method.clone()))?;
125
126 Ok(RpcMessage::RpcRequest {
127 msgid,
128 method,
129 params,
130 })
131 }
132 1 => {
133 let msgid: u64 = arr
134 .next()
135 .ok_or(WrongArrayLength(4..=4, 1))?
136 .try_into()
137 .map_err(InvalidMsgid)?;
138 let error = arr.next().ok_or(WrongArrayLength(4..=4, 2))?;
139 let result = arr.next().ok_or(WrongArrayLength(4..=4, 3))?;
140 Ok(RpcMessage::RpcResponse {
141 msgid,
142 error,
143 result,
144 })
145 }
146 2 => {
147 let method = match arr.next() {
148 Some(Value::String(s)) if s.is_str() => {
149 s.into_str().expect("Can remove using #230 of rmpv")
150 }
151 Some(val) => return Err(InvalidNotificationName(val).into()),
152 None => return Err(WrongArrayLength(3..=3, 1).into()),
153 };
154 let params: Vec<Value> = arr
155 .next()
156 .ok_or(WrongArrayLength(3..=3, 2))?
157 .try_into()
158 .map_err(|val| InvalidParams(val, method.clone()))?;
159 Ok(RpcMessage::RpcNotification { method, params })
160 }
161 t => Err(UnknownMessageType(t).into()),
162 }
163}
164
165pub fn encode_sync<W: Write>(
167 writer: &mut W,
168 msg: RpcMessage,
169) -> std::result::Result<(), Box<EncodeError>> {
170 match msg {
171 RpcMessage::RpcRequest {
172 msgid,
173 method,
174 params,
175 } => {
176 let val = rpc_args!(0, msgid, method, params);
177 write_value(writer, &val)?;
178 }
179 RpcMessage::RpcResponse {
180 msgid,
181 error,
182 result,
183 } => {
184 let val = rpc_args!(1, msgid, error, result);
185 write_value(writer, &val)?;
186 }
187 RpcMessage::RpcNotification { method, params } => {
188 let val = rpc_args!(2, method, params);
189 write_value(writer, &val)?;
190 }
191 };
192
193 Ok(())
194}
195
196pub async fn encode<W: AsyncWrite + Send + Unpin + 'static>(
199 writer: Arc<Mutex<W>>,
200 msg: RpcMessage,
201) -> std::result::Result<(), Box<EncodeError>> {
202 let mut v: Vec<u8> = vec![];
203 encode_sync(&mut v, msg)?;
204
205 let mut writer = writer.lock().await;
206 writer.write_all(&v).await?;
207 writer.flush().await?;
208
209 Ok(())
210}
211
/// Consuming conversion of a value into a target type `T`.
///
/// All implementations in this file use [`Value`] as the target type.
pub trait IntoVal<T> {
    /// Consumes `self` and returns its representation as `T`.
    fn into_val(self) -> T;
}
215
216impl<'a> IntoVal<Value> for &'a str {
217 fn into_val(self) -> Value {
218 Value::from(self)
219 }
220}
221
222impl IntoVal<Value> for Vec<String> {
223 fn into_val(self) -> Value {
224 let vec: Vec<Value> = self.into_iter().map(Value::from).collect();
225 Value::from(vec)
226 }
227}
228
229impl IntoVal<Value> for Vec<Value> {
230 fn into_val(self) -> Value {
231 Value::from(self)
232 }
233}
234
235impl IntoVal<Value> for (i64, i64) {
236 fn into_val(self) -> Value {
237 Value::from(vec![Value::from(self.0), Value::from(self.1)])
238 }
239}
240
241impl IntoVal<Value> for bool {
242 fn into_val(self) -> Value {
243 Value::from(self)
244 }
245}
246
247impl IntoVal<Value> for i64 {
248 fn into_val(self) -> Value {
249 Value::from(self)
250 }
251}
252
253impl IntoVal<Value> for f64 {
254 fn into_val(self) -> Value {
255 Value::from(self)
256 }
257}
258
259impl IntoVal<Value> for String {
260 fn into_val(self) -> Value {
261 Value::from(self)
262 }
263}
264
/// Identity conversion: a `Value` is returned unchanged.
impl IntoVal<Value> for Value {
    fn into_val(self) -> Value {
        self
    }
}
270
271impl IntoVal<Value> for Vec<(Value, Value)> {
272 fn into_val(self) -> Value {
273 Value::from(self)
274 }
275}
276
#[cfg(all(test, feature = "use_tokio"))]
mod test {
    use super::*;
    use futures::{io::BufWriter, lock::Mutex};
    use std::{io::Cursor, sync::Arc};

    use tokio;

    /// Encoding a request and decoding the written bytes yields the same
    /// message again.
    #[tokio::test]
    async fn request_test() {
        let original = RpcMessage::RpcRequest {
            msgid: 1,
            method: "test_method".to_owned(),
            params: vec![],
        };

        let sink = Arc::new(Mutex::new(BufWriter::new(Vec::<u8>::new())));
        encode(Arc::clone(&sink), original.clone())
            .await
            .unwrap();

        let decoded = {
            let guard = sink.lock().await;
            decode_buffer(&mut guard.get_ref().as_slice()).unwrap()
        };

        assert_eq!(original, decoded);
    }

    /// Two requests written back to back can be decoded one after the other
    /// from the same byte stream.
    #[tokio::test]
    async fn request_test_twice() {
        let first = RpcMessage::RpcRequest {
            msgid: 1,
            method: "test_method".to_owned(),
            params: vec![],
        };
        let second = RpcMessage::RpcRequest {
            msgid: 2,
            method: "test_method_2".to_owned(),
            params: vec![],
        };

        let sink = Arc::new(Mutex::new(BufWriter::new(Vec::<u8>::new())));
        encode(Arc::clone(&sink), first.clone()).await.unwrap();
        encode(Arc::clone(&sink), second.clone()).await.unwrap();

        let guard = sink.lock().await;
        let written = guard.get_ref();
        // Both encoded messages together occupy 34 bytes.
        assert_eq!(34, written.len());

        let mut cursor = Cursor::new(written.as_slice());

        let decoded_first = decode_buffer(&mut cursor).unwrap();
        assert_eq!(first, decoded_first);
        // The first message ends after 16 bytes.
        assert_eq!(16, cursor.position());

        let decoded_second = decode_buffer(&mut cursor).unwrap();
        assert_eq!(second, decoded_second);
    }
}