nvim_rs/rpc/
model.rs

//! Decoding and encoding msgpack-rpc messages from/to Neovim.
2use std::{
3  self,
4  convert::TryInto,
5  io::{self, Cursor, ErrorKind, Read, Write},
6  sync::Arc,
7};
8
9use futures::{
10  io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
11  lock::Mutex,
12};
13use rmpv::{decode::read_value, encode::write_value, Value};
14
15use crate::error::{DecodeError, EncodeError};
16
17/// A msgpack-rpc message, see
18/// <https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md>
19#[derive(Debug, PartialEq, Clone)]
20pub enum RpcMessage {
21  RpcRequest {
22    msgid: u64,
23    method: String,
24    params: Vec<Value>,
25  }, // 0
26  RpcResponse {
27    msgid: u64,
28    error: Value,
29    result: Value,
30  }, // 1
31  RpcNotification {
32    method: String,
33    params: Vec<Value>,
34  }, // 2
35}
36
37macro_rules! rpc_args {
38    ($($e:expr), *) => {{
39        let vec = vec![
40          $(Value::from($e),)*
41        ];
42        Value::from(vec)
43    }}
44}
45
46/// Continously reads from reader, pushing onto `rest`. Then tries to decode the
47/// contents of `rest`. If it succeeds, returns the message, and leaves any
48/// non-decoded bytes in `rest`. If we did not read enough for a full message,
49/// read more. Return on all other errors.
50//
51// TODO: This might be inefficient. Can't we read into `rest` directly?
52pub async fn decode<R: AsyncRead + Send + Unpin + 'static>(
53  reader: &mut R,
54  rest: &mut Vec<u8>,
55) -> std::result::Result<RpcMessage, Box<DecodeError>> {
56  let mut buf = Box::new([0_u8; 80 * 1024]);
57  let mut bytes_read;
58
59  loop {
60    let mut c = Cursor::new(&rest);
61
62    match decode_buffer(&mut c).map_err(|b| *b) {
63      Ok(msg) => {
64        let pos = c.position();
65        *rest = rest.split_off(pos as usize); // TODO: more efficiency
66        return Ok(msg);
67      }
68      Err(DecodeError::BufferError(e))
69        if e.kind() == ErrorKind::UnexpectedEof =>
70      {
71        debug!("Not enough data, reading more!");
72        bytes_read = reader.read(&mut *buf).await;
73      }
74      Err(err) => return Err(err.into()),
75    }
76
77    match bytes_read {
78      Ok(n) if n == 0 => {
79        return Err(io::Error::new(ErrorKind::UnexpectedEof, "EOF").into());
80      }
81      Ok(n) => {
82        rest.extend_from_slice(&buf[..n]);
83      }
84      Err(err) => return Err(err.into()),
85    }
86  }
87}
88
89/// Syncronously decode the content of a reader into an rpc message. Tries to
90/// give detailed errors if something went wrong.
91fn decode_buffer<R: Read>(
92  reader: &mut R,
93) -> std::result::Result<RpcMessage, Box<DecodeError>> {
94  use crate::error::InvalidMessage::*;
95
96  let arr: Vec<Value> = read_value(reader)?.try_into().map_err(NotAnArray)?;
97
98  let mut arr = arr.into_iter();
99
100  let msgtyp: u64 = arr
101    .next()
102    .ok_or(WrongArrayLength(3..=4, 0))?
103    .try_into()
104    .map_err(InvalidType)?;
105
106  match msgtyp {
107    0 => {
108      let msgid: u64 = arr
109        .next()
110        .ok_or(WrongArrayLength(4..=4, 1))?
111        .try_into()
112        .map_err(InvalidMsgid)?;
113      let method = match arr.next() {
114        Some(Value::String(s)) if s.is_str() => {
115          s.into_str().expect("Can remove using #230 of rmpv")
116        }
117        Some(val) => return Err(InvalidRequestName(msgid, val).into()),
118        None => return Err(WrongArrayLength(4..=4, 2).into()),
119      };
120      let params: Vec<Value> = arr
121        .next()
122        .ok_or(WrongArrayLength(4..=4, 3))?
123        .try_into()
124        .map_err(|val| InvalidParams(val, method.clone()))?;
125
126      Ok(RpcMessage::RpcRequest {
127        msgid,
128        method,
129        params,
130      })
131    }
132    1 => {
133      let msgid: u64 = arr
134        .next()
135        .ok_or(WrongArrayLength(4..=4, 1))?
136        .try_into()
137        .map_err(InvalidMsgid)?;
138      let error = arr.next().ok_or(WrongArrayLength(4..=4, 2))?;
139      let result = arr.next().ok_or(WrongArrayLength(4..=4, 3))?;
140      Ok(RpcMessage::RpcResponse {
141        msgid,
142        error,
143        result,
144      })
145    }
146    2 => {
147      let method = match arr.next() {
148        Some(Value::String(s)) if s.is_str() => {
149          s.into_str().expect("Can remove using #230 of rmpv")
150        }
151        Some(val) => return Err(InvalidNotificationName(val).into()),
152        None => return Err(WrongArrayLength(3..=3, 1).into()),
153      };
154      let params: Vec<Value> = arr
155        .next()
156        .ok_or(WrongArrayLength(3..=3, 2))?
157        .try_into()
158        .map_err(|val| InvalidParams(val, method.clone()))?;
159      Ok(RpcMessage::RpcNotification { method, params })
160    }
161    t => Err(UnknownMessageType(t).into()),
162  }
163}
164
165/// Encode the given message into the `writer`.
166pub fn encode_sync<W: Write>(
167  writer: &mut W,
168  msg: RpcMessage,
169) -> std::result::Result<(), Box<EncodeError>> {
170  match msg {
171    RpcMessage::RpcRequest {
172      msgid,
173      method,
174      params,
175    } => {
176      let val = rpc_args!(0, msgid, method, params);
177      write_value(writer, &val)?;
178    }
179    RpcMessage::RpcResponse {
180      msgid,
181      error,
182      result,
183    } => {
184      let val = rpc_args!(1, msgid, error, result);
185      write_value(writer, &val)?;
186    }
187    RpcMessage::RpcNotification { method, params } => {
188      let val = rpc_args!(2, method, params);
189      write_value(writer, &val)?;
190    }
191  };
192
193  Ok(())
194}
195
196/// Encode the given message into the `BufWriter`. Flushes the writer when
197/// finished.
198pub async fn encode<W: AsyncWrite + Send + Unpin + 'static>(
199  writer: Arc<Mutex<W>>,
200  msg: RpcMessage,
201) -> std::result::Result<(), Box<EncodeError>> {
202  let mut v: Vec<u8> = vec![];
203  encode_sync(&mut v, msg)?;
204
205  let mut writer = writer.lock().await;
206  writer.write_all(&v).await?;
207  writer.flush().await?;
208
209  Ok(())
210}
211
/// Consuming conversion of `self` into the representation `T`.
///
/// In this module it is implemented for converting common Rust types into
/// msgpack `Value`s.
pub trait IntoVal<T> {
  /// Consumes `self` and returns the converted value.
  fn into_val(self) -> T;
}
215
216impl<'a> IntoVal<Value> for &'a str {
217  fn into_val(self) -> Value {
218    Value::from(self)
219  }
220}
221
222impl IntoVal<Value> for Vec<String> {
223  fn into_val(self) -> Value {
224    let vec: Vec<Value> = self.into_iter().map(Value::from).collect();
225    Value::from(vec)
226  }
227}
228
229impl IntoVal<Value> for Vec<Value> {
230  fn into_val(self) -> Value {
231    Value::from(self)
232  }
233}
234
235impl IntoVal<Value> for (i64, i64) {
236  fn into_val(self) -> Value {
237    Value::from(vec![Value::from(self.0), Value::from(self.1)])
238  }
239}
240
241impl IntoVal<Value> for bool {
242  fn into_val(self) -> Value {
243    Value::from(self)
244  }
245}
246
247impl IntoVal<Value> for i64 {
248  fn into_val(self) -> Value {
249    Value::from(self)
250  }
251}
252
253impl IntoVal<Value> for f64 {
254  fn into_val(self) -> Value {
255    Value::from(self)
256  }
257}
258
259impl IntoVal<Value> for String {
260  fn into_val(self) -> Value {
261    Value::from(self)
262  }
263}
264
265impl IntoVal<Value> for Value {
266  fn into_val(self) -> Value {
267    self
268  }
269}
270
271impl IntoVal<Value> for Vec<(Value, Value)> {
272  fn into_val(self) -> Value {
273    Value::from(self)
274  }
275}
276
#[cfg(all(test, feature = "use_tokio"))]
mod test {
  // `Arc`, std's `Cursor`, futures' `Mutex` and `Value` come in through the
  // parent module's imports.
  use super::*;
  use futures::io::BufWriter;

  #[tokio::test]
  async fn request_test() {
    let msg = RpcMessage::RpcRequest {
      msgid: 1,
      method: "test_method".to_owned(),
      params: vec![],
    };

    let buff: Vec<u8> = vec![];
    let tmp = Arc::new(Mutex::new(BufWriter::new(buff)));
    let tmp2 = tmp.clone();
    let msg2 = msg.clone();

    encode(tmp2, msg2).await.unwrap();

    let msg_dest = {
      let v = &mut *tmp.lock().await;
      let x = v.get_mut();
      decode_buffer(&mut x.as_slice()).unwrap()
    };

    assert_eq!(msg, msg_dest);
  }

  #[tokio::test]
  async fn request_test_twice() {
    let msg_1 = RpcMessage::RpcRequest {
      msgid: 1,
      method: "test_method".to_owned(),
      params: vec![],
    };

    let msg_2 = RpcMessage::RpcRequest {
      msgid: 2,
      method: "test_method_2".to_owned(),
      params: vec![],
    };

    let buff: Vec<u8> = vec![];
    let tmp = Arc::new(Mutex::new(BufWriter::new(buff)));
    let msg_1_c = msg_1.clone();
    let msg_2_c = msg_2.clone();

    let tmp_c = tmp.clone();
    encode(tmp_c, msg_1_c).await.unwrap();
    let tmp_c = tmp.clone();
    encode(tmp_c, msg_2_c).await.unwrap();
    let len = (*tmp).lock().await.get_ref().len();
    // msg_1 encodes to 16 bytes. msg_2's method name is 2 bytes longer, so it
    // encodes to 18.
    assert_eq!(34, len);

    let v = &mut *tmp.lock().await;
    let x = v.get_mut();
    let mut cursor = Cursor::new(x.as_slice());
    let msg_dest_1 = decode_buffer(&mut cursor).unwrap();

    assert_eq!(msg_1, msg_dest_1);
    assert_eq!(16, cursor.position());

    let msg_dest_2 = decode_buffer(&mut cursor).unwrap();
    assert_eq!(msg_2, msg_dest_2);
  }

  /// Responses survive an `encode_sync` / `decode_buffer` round trip.
  #[test]
  fn response_roundtrip_test() {
    let msg = RpcMessage::RpcResponse {
      msgid: 7,
      error: Value::Nil,
      result: Value::from("ok"),
    };

    let mut bytes: Vec<u8> = vec![];
    encode_sync(&mut bytes, msg.clone()).unwrap();

    assert_eq!(msg, decode_buffer(&mut bytes.as_slice()).unwrap());
  }

  /// Notifications survive an `encode_sync` / `decode_buffer` round trip.
  #[test]
  fn notification_roundtrip_test() {
    let msg = RpcMessage::RpcNotification {
      method: "test_notification".to_owned(),
      params: vec![Value::from("arg"), Value::Boolean(true)],
    };

    let mut bytes: Vec<u8> = vec![];
    encode_sync(&mut bytes, msg.clone()).unwrap();

    assert_eq!(msg, decode_buffer(&mut bytes.as_slice()).unwrap());
  }

  /// The async `decode` buffers surplus bytes in `rest` across calls and
  /// reports an error once the stream is exhausted.
  #[tokio::test]
  async fn decode_stream_test() {
    let msg_1 = RpcMessage::RpcRequest {
      msgid: 3,
      method: "first".to_owned(),
      params: vec![],
    };
    let msg_2 = RpcMessage::RpcNotification {
      method: "second".to_owned(),
      params: vec![Value::from("x")],
    };

    let mut bytes: Vec<u8> = vec![];
    encode_sync(&mut bytes, msg_1.clone()).unwrap();
    encode_sync(&mut bytes, msg_2.clone()).unwrap();

    let mut reader = futures::io::Cursor::new(bytes);
    let mut rest: Vec<u8> = vec![];

    assert_eq!(msg_1, decode(&mut reader, &mut rest).await.unwrap());
    // The first read pulled in both messages; the second is served from
    // `rest` alone.
    assert_eq!(msg_2, decode(&mut reader, &mut rest).await.unwrap());
    assert!(rest.is_empty());

    // Nothing left to read: the reader is at EOF.
    assert!(decode(&mut reader, &mut rest).await.is_err());
  }
}