use std::{
self,
convert::TryInto,
io::{self, Cursor, ErrorKind, Read},
sync::Arc,
};
use futures::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt,},
lock::Mutex,
};
use rmpv::{decode::read_value, encode::write_value, Value};
use crate::error::{DecodeError, EncodeError};
#[derive(Debug, PartialEq, Clone)]
pub enum RpcMessage {
RpcRequest {
msgid: u64,
method: String,
params: Vec<Value>,
}, RpcResponse {
msgid: u64,
error: Value,
result: Value,
}, RpcNotification {
method: String,
params: Vec<Value>,
}, }
macro_rules! rpc_args {
($($e:expr), *) => {{
let vec = vec![
$(Value::from($e),)*
];
Value::from(vec)
}}
}
pub async fn decode<R: AsyncRead + Send + Unpin + 'static>(
reader: &mut R,
rest: &mut Vec<u8>,
) -> std::result::Result<RpcMessage, Box<DecodeError>> {
let mut buf = Box::new([0_u8; 80 * 1024]);
let mut bytes_read;
loop {
let mut c = Cursor::new(&rest);
match decode_buffer(&mut c).map_err(|b| *b) {
Ok(msg) => {
let pos = c.position();
*rest = rest.split_off(pos as usize); return Ok(msg);
}
Err(DecodeError::BufferError(e))
if e.kind() == ErrorKind::UnexpectedEof =>
{
debug!("Not enough data, reading more!");
bytes_read = reader.read(&mut *buf).await;
}
Err(err) => return Err(err.into()),
}
match bytes_read {
Ok(n) if n == 0 => {
return Err(io::Error::new(ErrorKind::UnexpectedEof, "EOF").into());
}
Ok(n) => {
rest.extend_from_slice(&buf[..n]);
}
Err(err) => return Err(err.into()),
}
}
}
fn decode_buffer<R: Read>(
reader: &mut R,
) -> std::result::Result<RpcMessage, Box<DecodeError>> {
use crate::error::InvalidMessage::*;
let arr: Vec<Value> = read_value(reader)?.try_into().map_err(NotAnArray)?;
let mut arr = arr.into_iter();
let msgtyp: u64 = arr
.next()
.ok_or(WrongArrayLength(3..=4, 0))?
.try_into()
.map_err(InvalidType)?;
match msgtyp {
0 => {
let msgid: u64 = arr
.next()
.ok_or(WrongArrayLength(4..=4, 1))?
.try_into()
.map_err(InvalidMsgid)?;
let method = match arr.next() {
Some(Value::String(s)) if s.is_str() => {
s.into_str().expect("Can remove using #230 of rmpv")
}
Some(val) => return Err(InvalidRequestName(msgid, val).into()),
None => return Err(WrongArrayLength(4..=4, 2).into()),
};
let params: Vec<Value> = arr
.next()
.ok_or(WrongArrayLength(4..=4, 3))?
.try_into()
.map_err(|val| InvalidParams(val, method.clone()))?;
Ok(RpcMessage::RpcRequest {
msgid,
method,
params,
})
}
1 => {
let msgid: u64 = arr
.next()
.ok_or(WrongArrayLength(4..=4, 1))?
.try_into()
.map_err(InvalidMsgid)?;
let error = arr.next().ok_or(WrongArrayLength(4..=4, 2))?;
let result = arr.next().ok_or(WrongArrayLength(4..=4, 3))?;
Ok(RpcMessage::RpcResponse {
msgid,
error,
result,
})
}
2 => {
let method = match arr.next() {
Some(Value::String(s)) if s.is_str() => {
s.into_str().expect("Can remove using #230 of rmpv")
}
Some(val) => return Err(InvalidNotificationName(val).into()),
None => return Err(WrongArrayLength(3..=3, 1).into()),
};
let params: Vec<Value> = arr
.next()
.ok_or(WrongArrayLength(3..=3, 2))?
.try_into()
.map_err(|val| InvalidParams(val, method.clone()))?;
Ok(RpcMessage::RpcNotification { method, params })
}
t => Err(UnknownMessageType(t).into()),
}
}
pub async fn encode<W: AsyncWrite + Send + Unpin + 'static>(
writer: Arc<Mutex<W>>,
msg: RpcMessage,
) -> std::result::Result<(), Box<EncodeError>> {
let mut v: Vec<u8> = vec![];
match msg {
RpcMessage::RpcRequest {
msgid,
method,
params,
} => {
let val = rpc_args!(0, msgid, method, params);
write_value(&mut v, &val)?;
}
RpcMessage::RpcResponse {
msgid,
error,
result,
} => {
let val = rpc_args!(1, msgid, error, result);
write_value(&mut v, &val)?;
}
RpcMessage::RpcNotification { method, params } => {
let val = rpc_args!(2, method, params);
write_value(&mut v, &val)?;
}
};
let mut writer = writer.lock().await;
writer.write_all(&v).await?;
writer.flush().await?;
Ok(())
}
pub trait IntoVal<T> {
fn into_val(self) -> T;
}
impl<'a> IntoVal<Value> for &'a str {
fn into_val(self) -> Value {
Value::from(self)
}
}
impl IntoVal<Value> for Vec<String> {
fn into_val(self) -> Value {
let vec: Vec<Value> = self.into_iter().map(Value::from).collect();
Value::from(vec)
}
}
impl IntoVal<Value> for Vec<Value> {
fn into_val(self) -> Value {
Value::from(self)
}
}
impl IntoVal<Value> for (i64, i64) {
fn into_val(self) -> Value {
Value::from(vec![Value::from(self.0), Value::from(self.1)])
}
}
impl IntoVal<Value> for bool {
fn into_val(self) -> Value {
Value::from(self)
}
}
impl IntoVal<Value> for i64 {
fn into_val(self) -> Value {
Value::from(self)
}
}
impl IntoVal<Value> for f64 {
fn into_val(self) -> Value {
Value::from(self)
}
}
impl IntoVal<Value> for String {
fn into_val(self) -> Value {
Value::from(self)
}
}
impl IntoVal<Value> for Value {
fn into_val(self) -> Value {
self
}
}
impl IntoVal<Value> for Vec<(Value, Value)> {
fn into_val(self) -> Value {
Value::from(self)
}
}
#[cfg(all(test, feature = "use_tokio"))]
mod test {
use super::*;
use futures::{io::BufWriter, lock::Mutex};
use std::{io::Cursor, sync::Arc};
use tokio;
#[tokio::test]
async fn request_test() {
let msg = RpcMessage::RpcRequest {
msgid: 1,
method: "test_method".to_owned(),
params: vec![],
};
let buff: Vec<u8> = vec![];
let tmp = Arc::new(Mutex::new(BufWriter::new(buff)));
let tmp2 = tmp.clone();
let msg2 = msg.clone();
encode(tmp2, msg2).await.unwrap();
let msg_dest = {
let v = &mut *tmp.lock().await;
let x = v.get_mut();
decode_buffer(&mut x.as_slice()).unwrap()
};
assert_eq!(msg, msg_dest);
}
#[tokio::test]
async fn request_test_twice() {
let msg_1 = RpcMessage::RpcRequest {
msgid: 1,
method: "test_method".to_owned(),
params: vec![],
};
let msg_2 = RpcMessage::RpcRequest {
msgid: 2,
method: "test_method_2".to_owned(),
params: vec![],
};
let buff: Vec<u8> = vec![];
let tmp = Arc::new(Mutex::new(BufWriter::new(buff)));
let msg_1_c = msg_1.clone();
let msg_2_c = msg_2.clone();
let tmp_c = tmp.clone();
encode(tmp_c, msg_1_c).await.unwrap();
let tmp_c = tmp.clone();
encode(tmp_c, msg_2_c).await.unwrap();
let len = (*tmp).lock().await.get_ref().len();
assert_eq!(34, len); let v = &mut *tmp.lock().await;
let x = v.get_mut();
let mut cursor = Cursor::new(x.as_slice());
let msg_dest_1 = decode_buffer(&mut cursor).unwrap();
assert_eq!(msg_1, msg_dest_1);
assert_eq!(16, cursor.position());
let msg_dest_2 = decode_buffer(&mut cursor).unwrap();
assert_eq!(msg_2, msg_dest_2);
}
}