Skip to main content

folk_protocol/
message.rs

1//! `MessagePack`-RPC message types.
2//!
3//! Three message variants per the `MessagePack`-RPC spec, each serialized as
4//! a positional `MessagePack` array:
5//!
6//! - `Request`: `[0, msgid, method, params]`
7//! - `Response`: `[1, msgid, error, result]`
8//! - `Notify`: `[2, method, params]`
9
10use rmpv::Value;
11use serde::de::{self, SeqAccess, Visitor};
12use serde::ser::SerializeSeq;
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14
15/// A single RPC message on the wire.
16#[derive(Debug, Clone, PartialEq)]
17pub enum RpcMessage {
18    /// Caller asks for a method invocation; expects a `Response` with matching `msgid`.
19    Request {
20        msgid: u32,
21        method: String,
22        params: Value,
23    },
24    /// Reply to a `Request` with the same `msgid`. Exactly one of `error`/`result`
25    /// is meaningful; the other should be `nil` (`Value::Nil`).
26    Response {
27        msgid: u32,
28        error: Value,
29        result: Value,
30    },
31    /// One-way message; no `Response` is expected.
32    Notify { method: String, params: Value },
33}
34
35impl RpcMessage {
36    /// Construct a `Request`.
37    pub fn request(msgid: u32, method: impl Into<String>, params: Value) -> Self {
38        Self::Request {
39            msgid,
40            method: method.into(),
41            params,
42        }
43    }
44
45    /// Construct a `Response` carrying a successful result. `error` is set to `Nil`.
46    pub fn response_ok(msgid: u32, result: Value) -> Self {
47        Self::Response {
48            msgid,
49            error: Value::Nil,
50            result,
51        }
52    }
53
54    /// Construct a `Response` carrying an error. `result` is set to `Nil`.
55    pub fn response_err(msgid: u32, error: Value) -> Self {
56        Self::Response {
57            msgid,
58            error,
59            result: Value::Nil,
60        }
61    }
62
63    /// Construct a `Notify`.
64    pub fn notify(method: impl Into<String>, params: Value) -> Self {
65        Self::Notify {
66            method: method.into(),
67            params,
68        }
69    }
70}
71
72impl Serialize for RpcMessage {
73    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
74        match self {
75            Self::Request {
76                msgid,
77                method,
78                params,
79            } => {
80                let mut seq = serializer.serialize_seq(Some(4))?;
81                seq.serialize_element(&0u8)?;
82                seq.serialize_element(msgid)?;
83                seq.serialize_element(method)?;
84                seq.serialize_element(params)?;
85                seq.end()
86            },
87            Self::Response {
88                msgid,
89                error,
90                result,
91            } => {
92                let mut seq = serializer.serialize_seq(Some(4))?;
93                seq.serialize_element(&1u8)?;
94                seq.serialize_element(msgid)?;
95                seq.serialize_element(error)?;
96                seq.serialize_element(result)?;
97                seq.end()
98            },
99            Self::Notify { method, params } => {
100                let mut seq = serializer.serialize_seq(Some(3))?;
101                seq.serialize_element(&2u8)?;
102                seq.serialize_element(method)?;
103                seq.serialize_element(params)?;
104                seq.end()
105            },
106        }
107    }
108}
109
110impl<'de> Deserialize<'de> for RpcMessage {
111    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
112        struct MsgVisitor;
113
114        impl<'de> Visitor<'de> for MsgVisitor {
115            type Value = RpcMessage;
116
117            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
118                f.write_str("a MessagePack-RPC array of 3 or 4 elements")
119            }
120
121            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
122                let kind: u8 = seq
123                    .next_element()?
124                    .ok_or_else(|| de::Error::custom("missing message type tag"))?;
125
126                match kind {
127                    0 => {
128                        let msgid: u32 = seq
129                            .next_element()?
130                            .ok_or_else(|| de::Error::custom("request: missing msgid"))?;
131                        let method: String = seq
132                            .next_element()?
133                            .ok_or_else(|| de::Error::custom("request: missing method"))?;
134                        let params: Value = seq
135                            .next_element()?
136                            .ok_or_else(|| de::Error::custom("request: missing params"))?;
137                        Ok(RpcMessage::Request {
138                            msgid,
139                            method,
140                            params,
141                        })
142                    },
143                    1 => {
144                        let msgid: u32 = seq
145                            .next_element()?
146                            .ok_or_else(|| de::Error::custom("response: missing msgid"))?;
147                        let error: Value = seq
148                            .next_element()?
149                            .ok_or_else(|| de::Error::custom("response: missing error"))?;
150                        let result: Value = seq
151                            .next_element()?
152                            .ok_or_else(|| de::Error::custom("response: missing result"))?;
153                        Ok(RpcMessage::Response {
154                            msgid,
155                            error,
156                            result,
157                        })
158                    },
159                    2 => {
160                        let method: String = seq
161                            .next_element()?
162                            .ok_or_else(|| de::Error::custom("notify: missing method"))?;
163                        let params: Value = seq
164                            .next_element()?
165                            .ok_or_else(|| de::Error::custom("notify: missing params"))?;
166                        Ok(RpcMessage::Notify { method, params })
167                    },
168                    other => Err(de::Error::custom(format!(
169                        "unknown message type tag: {other}"
170                    ))),
171                }
172            }
173        }
174
175        deserializer.deserialize_seq(MsgVisitor)
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use rmpv::Value;

    /// Encodes `msg` and decodes it back into an `RpcMessage`.
    fn round_trip(msg: &RpcMessage) -> RpcMessage {
        let bytes = rmp_serde::to_vec(msg).expect("encode");
        rmp_serde::from_slice(&bytes).expect("decode")
    }

    /// Encodes `msg` and decodes the bytes as a plain `Value` array, so the raw
    /// wire layout can be inspected independently of `RpcMessage`'s decoder.
    fn wire_array(msg: &RpcMessage) -> Vec<Value> {
        let bytes = rmp_serde::to_vec(msg).expect("encode");
        let value: Value = rmp_serde::from_slice(&bytes).expect("decode as Value");
        match value {
            Value::Array(items) => items,
            other => panic!("expected an array on the wire, got {other:?}"),
        }
    }

    /// Encodes an arbitrary `Value` and tries to decode it as an `RpcMessage`.
    fn decode(raw: &Value) -> Result<RpcMessage, rmp_serde::decode::Error> {
        let bytes = rmp_serde::to_vec(raw).expect("encode");
        rmp_serde::from_slice(&bytes)
    }

    /// Shorthand for a `MessagePack` unsigned integer.
    fn int(n: u64) -> Value {
        Value::Integer(n.into())
    }

    #[test]
    fn encodes_request_as_4_element_array_with_tag_0() {
        let array = wire_array(&RpcMessage::request(7, "echo", Value::String("hi".into())));
        assert_eq!(array.len(), 4);
        assert_eq!(array[0].as_u64(), Some(0));
        assert_eq!(array[1].as_u64(), Some(7));
        assert_eq!(array[2].as_str(), Some("echo"));
        assert_eq!(array[3].as_str(), Some("hi"));
    }

    #[test]
    fn encodes_response_ok_as_4_element_array_with_tag_1_and_nil_error() {
        let array = wire_array(&RpcMessage::response_ok(42, int(99)));
        assert_eq!(array.len(), 4);
        assert_eq!(array[0].as_u64(), Some(1));
        assert_eq!(array[1].as_u64(), Some(42));
        assert!(array[2].is_nil());
        assert_eq!(array[3].as_u64(), Some(99));
    }

    #[test]
    fn encodes_response_err_as_4_element_array_with_nil_result() {
        let array = wire_array(&RpcMessage::response_err(42, Value::String("boom".into())));
        assert_eq!(array.len(), 4);
        assert_eq!(array[0].as_u64(), Some(1));
        assert_eq!(array[2].as_str(), Some("boom"));
        assert!(array[3].is_nil());
    }

    #[test]
    fn encodes_notify_as_3_element_array_with_tag_2() {
        let array = wire_array(&RpcMessage::notify("control.ready", Value::Map(vec![])));
        assert_eq!(array.len(), 3);
        assert_eq!(array[0].as_u64(), Some(2));
        assert_eq!(array[1].as_str(), Some("control.ready"));
        assert!(array[2].is_map());
    }

    #[test]
    fn round_trip_request_preserves_all_fields() {
        let msg = RpcMessage::request(
            123,
            "http.handle",
            Value::Array(vec![
                Value::String("GET".into()),
                Value::String("/health".into()),
            ]),
        );
        assert_eq!(msg, round_trip(&msg));
    }

    #[test]
    fn round_trip_request_preserves_max_msgid() {
        let msg = RpcMessage::request(u32::MAX, "m", Value::Nil);
        assert_eq!(msg, round_trip(&msg));
    }

    #[test]
    fn round_trip_response_ok_preserves_result() {
        let msg = RpcMessage::response_ok(7, Value::String("pong".into()));
        assert_eq!(msg, round_trip(&msg));
    }

    #[test]
    fn round_trip_response_err_preserves_error() {
        let msg = RpcMessage::response_err(7, int(500));
        assert_eq!(msg, round_trip(&msg));
    }

    #[test]
    fn round_trip_notify_preserves_method_and_params() {
        let msg = RpcMessage::notify("control.heartbeat", Value::Nil);
        assert_eq!(msg, round_trip(&msg));
    }

    #[test]
    fn rejects_unknown_type_tag() {
        let bad = Value::Array(vec![int(5), int(0), Value::Nil, Value::Nil]);
        assert!(decode(&bad).is_err(), "should reject tag != 0/1/2");
    }

    #[test]
    fn rejects_empty_array() {
        let bad = Value::Array(vec![]);
        assert!(decode(&bad).is_err(), "should reject a missing type tag");
    }

    #[test]
    fn rejects_request_with_missing_fields() {
        let bad = Value::Array(vec![int(0), int(1)]);
        assert!(decode(&bad).is_err(), "should reject incomplete request");
    }

    #[test]
    fn rejects_response_with_missing_result() {
        let bad = Value::Array(vec![int(1), int(1), Value::Nil]);
        assert!(decode(&bad).is_err(), "should reject incomplete response");
    }

    #[test]
    fn rejects_notify_with_missing_params() {
        let bad = Value::Array(vec![int(2), Value::String("m".into())]);
        assert!(decode(&bad).is_err(), "should reject incomplete notify");
    }

    #[test]
    fn rejects_msgid_outside_u32_range() {
        let bad = Value::Array(vec![
            int(0),
            int(u64::from(u32::MAX) + 1),
            Value::String("m".into()),
            Value::Nil,
        ]);
        assert!(decode(&bad).is_err(), "msgid must fit in u32");
    }

    #[test]
    fn rejects_non_string_method() {
        let bad = Value::Array(vec![int(0), int(1), int(5), Value::Nil]);
        assert!(decode(&bad).is_err(), "method must be a string");
    }

    #[test]
    fn rejects_non_array_input() {
        let bad = Value::String("not a message".into());
        assert!(decode(&bad).is_err(), "should reject non-array input");
    }
}