1use rmpv::Value;
11use serde::de::{self, SeqAccess, Visitor};
12use serde::ser::SerializeSeq;
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14
15#[derive(Debug, Clone, PartialEq)]
17pub enum RpcMessage {
18 Request {
20 msgid: u32,
21 method: String,
22 params: Value,
23 },
24 Response {
27 msgid: u32,
28 error: Value,
29 result: Value,
30 },
31 Notify { method: String, params: Value },
33}
34
35impl RpcMessage {
36 pub fn request(msgid: u32, method: impl Into<String>, params: Value) -> Self {
38 Self::Request {
39 msgid,
40 method: method.into(),
41 params,
42 }
43 }
44
45 pub fn response_ok(msgid: u32, result: Value) -> Self {
47 Self::Response {
48 msgid,
49 error: Value::Nil,
50 result,
51 }
52 }
53
54 pub fn response_err(msgid: u32, error: Value) -> Self {
56 Self::Response {
57 msgid,
58 error,
59 result: Value::Nil,
60 }
61 }
62
63 pub fn notify(method: impl Into<String>, params: Value) -> Self {
65 Self::Notify {
66 method: method.into(),
67 params,
68 }
69 }
70}
71
72impl Serialize for RpcMessage {
73 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
74 match self {
75 Self::Request {
76 msgid,
77 method,
78 params,
79 } => {
80 let mut seq = serializer.serialize_seq(Some(4))?;
81 seq.serialize_element(&0u8)?;
82 seq.serialize_element(msgid)?;
83 seq.serialize_element(method)?;
84 seq.serialize_element(params)?;
85 seq.end()
86 },
87 Self::Response {
88 msgid,
89 error,
90 result,
91 } => {
92 let mut seq = serializer.serialize_seq(Some(4))?;
93 seq.serialize_element(&1u8)?;
94 seq.serialize_element(msgid)?;
95 seq.serialize_element(error)?;
96 seq.serialize_element(result)?;
97 seq.end()
98 },
99 Self::Notify { method, params } => {
100 let mut seq = serializer.serialize_seq(Some(3))?;
101 seq.serialize_element(&2u8)?;
102 seq.serialize_element(method)?;
103 seq.serialize_element(params)?;
104 seq.end()
105 },
106 }
107 }
108}
109
110impl<'de> Deserialize<'de> for RpcMessage {
111 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
112 struct MsgVisitor;
113
114 impl<'de> Visitor<'de> for MsgVisitor {
115 type Value = RpcMessage;
116
117 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
118 f.write_str("a MessagePack-RPC array of 3 or 4 elements")
119 }
120
121 fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
122 let kind: u8 = seq
123 .next_element()?
124 .ok_or_else(|| de::Error::custom("missing message type tag"))?;
125
126 match kind {
127 0 => {
128 let msgid: u32 = seq
129 .next_element()?
130 .ok_or_else(|| de::Error::custom("request: missing msgid"))?;
131 let method: String = seq
132 .next_element()?
133 .ok_or_else(|| de::Error::custom("request: missing method"))?;
134 let params: Value = seq
135 .next_element()?
136 .ok_or_else(|| de::Error::custom("request: missing params"))?;
137 Ok(RpcMessage::Request {
138 msgid,
139 method,
140 params,
141 })
142 },
143 1 => {
144 let msgid: u32 = seq
145 .next_element()?
146 .ok_or_else(|| de::Error::custom("response: missing msgid"))?;
147 let error: Value = seq
148 .next_element()?
149 .ok_or_else(|| de::Error::custom("response: missing error"))?;
150 let result: Value = seq
151 .next_element()?
152 .ok_or_else(|| de::Error::custom("response: missing result"))?;
153 Ok(RpcMessage::Response {
154 msgid,
155 error,
156 result,
157 })
158 },
159 2 => {
160 let method: String = seq
161 .next_element()?
162 .ok_or_else(|| de::Error::custom("notify: missing method"))?;
163 let params: Value = seq
164 .next_element()?
165 .ok_or_else(|| de::Error::custom("notify: missing params"))?;
166 Ok(RpcMessage::Notify { method, params })
167 },
168 other => Err(de::Error::custom(format!(
169 "unknown message type tag: {other}"
170 ))),
171 }
172 }
173 }
174
175 deserializer.deserialize_seq(MsgVisitor)
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use rmpv::Value;

    /// Encodes `msg` with `rmp_serde` and decodes it back into a message.
    fn round_trip(msg: &RpcMessage) -> RpcMessage {
        let encoded = rmp_serde::to_vec(msg).expect("encode");
        rmp_serde::from_slice(&encoded).expect("decode")
    }

    /// Encodes `msg` and returns the elements of the resulting top-level
    /// array, decoded as generic values.
    fn wire_elements(msg: &RpcMessage) -> Vec<Value> {
        let encoded = rmp_serde::to_vec(msg).unwrap();
        let decoded: Value = rmp_serde::from_slice(&encoded).unwrap();
        decoded.as_array().expect("array").clone()
    }

    /// Encodes the generic value `raw` and reports whether decoding those
    /// bytes as an `RpcMessage` fails.
    fn fails_to_decode(raw: &Value) -> bool {
        let encoded = rmp_serde::to_vec(raw).unwrap();
        rmp_serde::from_slice::<RpcMessage>(&encoded).is_err()
    }

    #[test]
    fn encodes_request_as_4_element_array_with_tag_0() {
        let array = wire_elements(&RpcMessage::request(
            7,
            "echo",
            Value::String("hi".into()),
        ));

        assert_eq!(array.len(), 4);
        assert_eq!(array[0].as_u64(), Some(0));
        assert_eq!(array[1].as_u64(), Some(7));
        assert_eq!(array[2].as_str(), Some("echo"));
        assert_eq!(array[3].as_str(), Some("hi"));
    }

    #[test]
    fn encodes_response_ok_as_4_element_array_with_tag_1_and_nil_error() {
        let array = wire_elements(&RpcMessage::response_ok(42, Value::Integer(99.into())));

        assert_eq!(array.len(), 4);
        assert_eq!(array[0].as_u64(), Some(1));
        assert_eq!(array[1].as_u64(), Some(42));
        assert!(array[2].is_nil());
        assert_eq!(array[3].as_u64(), Some(99));
    }

    #[test]
    fn encodes_response_err_as_4_element_array_with_nil_result() {
        let array = wire_elements(&RpcMessage::response_err(
            42,
            Value::String("boom".into()),
        ));

        assert_eq!(array.len(), 4);
        assert_eq!(array[0].as_u64(), Some(1));
        assert_eq!(array[2].as_str(), Some("boom"));
        assert!(array[3].is_nil());
    }

    #[test]
    fn encodes_notify_as_3_element_array_with_tag_2() {
        let array = wire_elements(&RpcMessage::notify("control.ready", Value::Map(vec![])));

        assert_eq!(array.len(), 3);
        assert_eq!(array[0].as_u64(), Some(2));
        assert_eq!(array[1].as_str(), Some("control.ready"));
        assert!(array[2].is_map());
    }

    #[test]
    fn round_trip_request_preserves_all_fields() {
        let params = Value::Array(vec![
            Value::String("GET".into()),
            Value::String("/health".into()),
        ]);
        let msg = RpcMessage::request(123, "http.handle", params);
        assert_eq!(msg, round_trip(&msg));
    }

    #[test]
    fn round_trip_response_ok_preserves_result() {
        let msg = RpcMessage::response_ok(7, Value::String("pong".into()));
        assert_eq!(msg, round_trip(&msg));
    }

    #[test]
    fn round_trip_response_err_preserves_error() {
        let msg = RpcMessage::response_err(7, Value::Integer(500.into()));
        assert_eq!(msg, round_trip(&msg));
    }

    #[test]
    fn round_trip_notify_preserves_method_and_params() {
        let msg = RpcMessage::notify("control.heartbeat", Value::Nil);
        assert_eq!(msg, round_trip(&msg));
    }

    #[test]
    fn rejects_unknown_type_tag() {
        // Tag 5 is none of the tags the decoder accepts.
        let bad = Value::Array(vec![
            Value::Integer(5.into()),
            Value::Integer(0.into()),
            Value::Nil,
            Value::Nil,
        ]);
        assert!(fails_to_decode(&bad), "should reject tag != 0/1/2");
    }

    #[test]
    fn rejects_request_with_missing_fields() {
        // A request tag followed by a msgid only: method and params are absent.
        let bad = Value::Array(vec![Value::Integer(0.into()), Value::Integer(1.into())]);
        assert!(fails_to_decode(&bad), "should reject incomplete request");
    }

    #[test]
    fn rejects_non_array_input() {
        let bad = Value::String("not a message".into());
        assert!(fails_to_decode(&bad), "should reject non-array input");
    }
}