use rmpv::Value;
use serde::de::{self, SeqAccess, Visitor};
use serde::ser::SerializeSeq;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// A single MessagePack-RPC message.
///
/// The `Serialize`/`Deserialize` impls map each variant to a positional array whose first
/// element is a numeric type tag: `0` for [`RpcMessage::Request`], `1` for
/// [`RpcMessage::Response`] and `2` for [`RpcMessage::Notify`].
#[derive(Debug, Clone, PartialEq)]
pub enum RpcMessage {
    /// Type tag `0`; encoded as `[0, msgid, method, params]`.
    Request {
        /// Message id, encoded as an unsigned 32-bit integer.
        msgid: u32,
        /// Name of the method being invoked.
        method: String,
        /// Method arguments; any MessagePack value is accepted (not restricted to arrays).
        params: Value,
    },
    /// Type tag `1`; encoded as `[1, msgid, error, result]`.
    ///
    /// `RpcMessage::response_ok` sets `error` to `Nil`; `RpcMessage::response_err` sets
    /// `result` to `Nil`.
    Response {
        /// Message id, encoded as an unsigned 32-bit integer.
        msgid: u32,
        /// Error payload, or `Nil` on success.
        error: Value,
        /// Result payload, or `Nil` on error.
        result: Value,
    },
    /// Type tag `2`; encoded as `[2, method, params]` — note there is no `msgid`.
    Notify { method: String, params: Value },
}
impl RpcMessage {
pub fn request(msgid: u32, method: impl Into<String>, params: Value) -> Self {
Self::Request {
msgid,
method: method.into(),
params,
}
}
pub fn response_ok(msgid: u32, result: Value) -> Self {
Self::Response {
msgid,
error: Value::Nil,
result,
}
}
pub fn response_err(msgid: u32, error: Value) -> Self {
Self::Response {
msgid,
error,
result: Value::Nil,
}
}
pub fn notify(method: impl Into<String>, params: Value) -> Self {
Self::Notify {
method: method.into(),
params,
}
}
}
impl Serialize for RpcMessage {
    /// Writes the message as a positional array:
    /// `[0, msgid, method, params]`, `[1, msgid, error, result]` or `[2, method, params]`.
    ///
    /// The type tag is written as a `u8`, the message id as a `u32`.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        /// Emits the 4-element layout shared by requests and responses:
        /// `[tag, msgid, third, fourth]`.
        fn tagged_with_id<S, A, B>(
            serializer: S,
            tag: u8,
            msgid: &u32,
            third: &A,
            fourth: &B,
        ) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
            A: Serialize + ?Sized,
            B: Serialize + ?Sized,
        {
            let mut array = serializer.serialize_seq(Some(4))?;
            array.serialize_element(&tag)?;
            array.serialize_element(msgid)?;
            array.serialize_element(third)?;
            array.serialize_element(fourth)?;
            array.end()
        }

        match self {
            Self::Request { msgid, method, params } => {
                tagged_with_id(serializer, 0, msgid, method, params)
            }
            Self::Response { msgid, error, result } => {
                tagged_with_id(serializer, 1, msgid, error, result)
            }
            // Notifications carry no message id, so they use a shorter 3-element array.
            Self::Notify { method, params } => {
                let mut array = serializer.serialize_seq(Some(3))?;
                array.serialize_element(&2u8)?;
                array.serialize_element(method)?;
                array.serialize_element(params)?;
                array.end()
            }
        }
    }
}
impl<'de> Deserialize<'de> for RpcMessage {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
struct MsgVisitor;
impl<'de> Visitor<'de> for MsgVisitor {
type Value = RpcMessage;
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.write_str("a MessagePack-RPC array of 3 or 4 elements")
}
fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
let kind: u8 = seq
.next_element()?
.ok_or_else(|| de::Error::custom("missing message type tag"))?;
match kind {
0 => {
let msgid: u32 = seq
.next_element()?
.ok_or_else(|| de::Error::custom("request: missing msgid"))?;
let method: String = seq
.next_element()?
.ok_or_else(|| de::Error::custom("request: missing method"))?;
let params: Value = seq
.next_element()?
.ok_or_else(|| de::Error::custom("request: missing params"))?;
Ok(RpcMessage::Request {
msgid,
method,
params,
})
},
1 => {
let msgid: u32 = seq
.next_element()?
.ok_or_else(|| de::Error::custom("response: missing msgid"))?;
let error: Value = seq
.next_element()?
.ok_or_else(|| de::Error::custom("response: missing error"))?;
let result: Value = seq
.next_element()?
.ok_or_else(|| de::Error::custom("response: missing result"))?;
Ok(RpcMessage::Response {
msgid,
error,
result,
})
},
2 => {
let method: String = seq
.next_element()?
.ok_or_else(|| de::Error::custom("notify: missing method"))?;
let params: Value = seq
.next_element()?
.ok_or_else(|| de::Error::custom("notify: missing params"))?;
Ok(RpcMessage::Notify { method, params })
},
other => Err(de::Error::custom(format!(
"unknown message type tag: {other}"
))),
}
}
}
deserializer.deserialize_seq(MsgVisitor)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rmpv::Value;

    /// Encodes `msg` to MessagePack and decodes the bytes straight back.
    fn round_trip(msg: &RpcMessage) -> RpcMessage {
        let wire = rmp_serde::to_vec(msg).expect("encode");
        rmp_serde::from_slice(&wire).expect("decode")
    }

    /// Encodes `msg` and returns the elements of the MessagePack array it produced.
    fn encoded_items(msg: &RpcMessage) -> Vec<Value> {
        let wire = rmp_serde::to_vec(msg).unwrap();
        match rmp_serde::from_slice::<Value>(&wire).unwrap() {
            Value::Array(items) => items,
            other => panic!("array expected, got {other:?}"),
        }
    }

    /// Encodes an arbitrary value and reports whether decoding it as an `RpcMessage` fails.
    fn is_rejected(raw: &Value) -> bool {
        let wire = rmp_serde::to_vec(raw).unwrap();
        rmp_serde::from_slice::<RpcMessage>(&wire).is_err()
    }

    #[test]
    fn encodes_request_as_4_element_array_with_tag_0() {
        let items = encoded_items(&RpcMessage::request(7, "echo", Value::String("hi".into())));
        assert_eq!(items.len(), 4);
        assert_eq!(items[0].as_u64(), Some(0));
        assert_eq!(items[1].as_u64(), Some(7));
        assert_eq!(items[2].as_str(), Some("echo"));
        assert_eq!(items[3].as_str(), Some("hi"));
    }

    #[test]
    fn encodes_response_ok_as_4_element_array_with_tag_1_and_nil_error() {
        let items = encoded_items(&RpcMessage::response_ok(42, Value::Integer(99.into())));
        assert_eq!(items.len(), 4);
        assert_eq!(items[0].as_u64(), Some(1));
        assert_eq!(items[1].as_u64(), Some(42));
        assert!(items[2].is_nil());
        assert_eq!(items[3].as_u64(), Some(99));
    }

    #[test]
    fn encodes_response_err_as_4_element_array_with_nil_result() {
        let items = encoded_items(&RpcMessage::response_err(42, Value::String("boom".into())));
        assert_eq!(items.len(), 4);
        assert_eq!(items[0].as_u64(), Some(1));
        assert_eq!(items[2].as_str(), Some("boom"));
        assert!(items[3].is_nil());
    }

    #[test]
    fn encodes_notify_as_3_element_array_with_tag_2() {
        let items = encoded_items(&RpcMessage::notify("control.ready", Value::Map(vec![])));
        assert_eq!(items.len(), 3);
        assert_eq!(items[0].as_u64(), Some(2));
        assert_eq!(items[1].as_str(), Some("control.ready"));
        assert!(items[2].is_map());
    }

    #[test]
    fn round_trip_request_preserves_all_fields() {
        let params = Value::Array(vec![
            Value::String("GET".into()),
            Value::String("/health".into()),
        ]);
        let msg = RpcMessage::request(123, "http.handle", params);
        assert_eq!(msg, round_trip(&msg));
    }

    #[test]
    fn round_trip_response_ok_preserves_result() {
        let msg = RpcMessage::response_ok(7, Value::String("pong".into()));
        assert_eq!(msg, round_trip(&msg));
    }

    #[test]
    fn round_trip_response_err_preserves_error() {
        let msg = RpcMessage::response_err(7, Value::Integer(500.into()));
        assert_eq!(msg, round_trip(&msg));
    }

    #[test]
    fn round_trip_notify_preserves_method_and_params() {
        let msg = RpcMessage::notify("control.heartbeat", Value::Nil);
        assert_eq!(msg, round_trip(&msg));
    }

    #[test]
    fn rejects_unknown_type_tag() {
        let bad = Value::Array(vec![
            Value::Integer(5.into()),
            Value::Integer(0.into()),
            Value::Nil,
            Value::Nil,
        ]);
        assert!(is_rejected(&bad), "should reject tag != 0/1/2");
    }

    #[test]
    fn rejects_request_with_missing_fields() {
        let bad = Value::Array(vec![Value::Integer(0.into()), Value::Integer(1.into())]);
        assert!(is_rejected(&bad), "should reject incomplete request");
    }

    #[test]
    fn rejects_non_array_input() {
        let bad = Value::String("not a message".into());
        assert!(is_rejected(&bad), "should reject non-array input");
    }
}