mrpc/
message.rs

1//! Defines the MessagePack-RPC message types and their serialization/deserialization.
2//!
3//! Includes structures for requests, responses, and notifications, as well as
4//! utilities for encoding and decoding these messages.
5use rmpv::Value;
6use std::io::{Read, Write};
7
8use crate::error::*;
9
/// Type tag placed in the first array element of an encoded request.
const REQUEST_MESSAGE: u64 = 0;
/// Type tag placed in the first array element of an encoded response.
const RESPONSE_MESSAGE: u64 = 1;
/// Type tag placed in the first array element of an encoded notification.
const NOTIFICATION_MESSAGE: u64 = 2;
13
14/// Represents the different types of RPC messages: requests, responses, and notifications.
15#[derive(PartialEq, Clone, Debug)]
16pub enum Message {
17    Request(Request),
18    Response(Response),
19    Notification(Notification),
20}
21
22/// An RPC request message containing an ID, method name, and parameters.
23#[derive(PartialEq, Clone, Debug)]
24pub struct Request {
25    pub id: u32,
26    pub method: String,
27    pub params: Vec<Value>,
28}
29
30/// An RPC response message containing an ID and either a result or an error.
31#[derive(PartialEq, Clone, Debug)]
32pub struct Response {
33    pub id: u32,
34    pub result: std::result::Result<Value, Value>,
35}
36
37/// An RPC notification message containing a method name and parameters.
38#[derive(PartialEq, Clone, Debug)]
39pub struct Notification {
40    pub method: String,
41    pub params: Vec<Value>,
42}
43
44impl Message {
45    /// Converts the message to a MessagePack-RPC compatible Value.
46    pub fn to_value(&self) -> Value {
47        match self {
48            Message::Request(req) => Value::Array(vec![
49                Value::Integer(REQUEST_MESSAGE.into()),
50                Value::Integer(req.id.into()),
51                Value::String(req.method.clone().into()),
52                Value::Array(req.params.clone()),
53            ]),
54            Message::Response(resp) => Value::Array(vec![
55                Value::Integer(RESPONSE_MESSAGE.into()),
56                Value::Integer(resp.id.into()),
57                match &resp.result {
58                    Ok(_value) => Value::Nil,
59                    Err(err) => err.clone(),
60                },
61                match &resp.result {
62                    Ok(value) => value.clone(),
63                    Err(_) => Value::Nil,
64                },
65            ]),
66            Message::Notification(notif) => Value::Array(vec![
67                Value::Integer(NOTIFICATION_MESSAGE.into()),
68                Value::String(notif.method.clone().into()),
69                Value::Array(notif.params.clone()),
70            ]),
71        }
72    }
73
74    /// Creates a Message from a MessagePack-RPC compatible Value.
75    pub fn from_value(value: Value) -> Result<Self> {
76        match value {
77            Value::Array(array) => {
78                if array.is_empty() {
79                    return Err(RpcError::Protocol("Empty message array".into()));
80                }
81                match array[0] {
82                    Value::Integer(msg_type) => match msg_type.as_u64() {
83                        Some(REQUEST_MESSAGE) => {
84                            if array.len() != 4 {
85                                return Err(RpcError::Protocol(
86                                    "Invalid request message length".into(),
87                                ));
88                            }
89                            let id = array[1]
90                                .as_u64()
91                                .ok_or(RpcError::Protocol("Invalid request id".into()))?
92                                as u32;
93                            let method = array[2]
94                                .as_str()
95                                .ok_or(RpcError::Protocol("Invalid request method".into()))?
96                                .to_string();
97                            let params = match &array[3] {
98                                Value::Array(params) => params.clone(),
99                                _ => {
100                                    return Err(RpcError::Protocol("Invalid request params".into()))
101                                }
102                            };
103                            Ok(Message::Request(Request { id, method, params }))
104                        }
105                        Some(RESPONSE_MESSAGE) => {
106                            if array.len() != 4 {
107                                return Err(RpcError::Protocol(
108                                    "Invalid response message length".into(),
109                                ));
110                            }
111                            let id = array[1]
112                                .as_u64()
113                                .ok_or(RpcError::Protocol("Invalid response id".into()))?
114                                as u32;
115                            let result = if array[2] == Value::Nil {
116                                Ok(array[3].clone())
117                            } else {
118                                Err(array[2].clone())
119                            };
120                            Ok(Message::Response(Response { id, result }))
121                        }
122                        Some(NOTIFICATION_MESSAGE) => {
123                            if array.len() != 3 {
124                                return Err(RpcError::Protocol(
125                                    "Invalid notification message length".into(),
126                                ));
127                            }
128                            let method = array[1]
129                                .as_str()
130                                .ok_or(RpcError::Protocol("Invalid notification method".into()))?
131                                .to_string();
132                            let params = match &array[2] {
133                                Value::Array(params) => params.clone(),
134                                _ => {
135                                    return Err(RpcError::Protocol(
136                                        "Invalid notification params".into(),
137                                    ))
138                                }
139                            };
140                            Ok(Message::Notification(Notification { method, params }))
141                        }
142                        _ => Err(RpcError::Protocol("Invalid message type".into())),
143                    },
144                    _ => Err(RpcError::Protocol("Invalid message type".into())),
145                }
146            }
147            _ => Err(RpcError::Protocol("Invalid message format".into())),
148        }
149    }
150
151    /// Encodes the message to MessagePack format and writes it to the given writer.
152    pub fn encode<W: Write>(&self, writer: &mut W) -> Result<()> {
153        let value = self.to_value();
154        rmpv::encode::write_value(writer, &value)?;
155        Ok(())
156    }
157
158    /// Reads and decodes a message from MessagePack format using the given reader.
159    pub fn decode<R: Read>(reader: &mut R) -> Result<Self> {
160        match rmpv::decode::read_value(reader) {
161            Ok(value) => Self::from_value(value),
162            Err(rmpv::decode::Error::InvalidMarkerRead(e))
163            | Err(rmpv::decode::Error::InvalidDataRead(e)) => Err(RpcError::from(e)),
164            Err(rmpv::decode::Error::DepthLimitExceeded) => {
165                Err(RpcError::Protocol("Depth limit exceeded".into()))
166            }
167        }
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Builds the shared fixture: at least one message of every kind, including
    /// an error response and a request with nested array/map parameters.
    fn sample_messages() -> Vec<Message> {
        let simple_request = Message::Request(Request {
            id: 1,
            method: "test_method".to_string(),
            params: vec![Value::String("param1".into()), Value::Integer(42.into())],
        });
        let ok_response = Message::Response(Response {
            id: 2,
            result: Ok(Value::String("success".into())),
        });
        let error_response = Message::Response(Response {
            id: 3,
            result: Err(Value::String("error".into())),
        });
        let notification = Message::Notification(Notification {
            method: "test_notification".to_string(),
            params: vec![Value::Boolean(true), Value::F64(2.14)],
        });
        let nested_request = Message::Request(Request {
            id: 4,
            method: "complex_method".to_string(),
            params: vec![
                Value::Array(vec![Value::String("nested".into()), Value::Integer(1.into())]),
                Value::Map(vec![
                    (Value::String("key".into()), Value::Boolean(true)),
                    (Value::String("value".into()), Value::F64(1.718)),
                ]),
            ],
        });

        vec![
            simple_request,
            ok_response,
            error_response,
            notification,
            nested_request,
        ]
    }

    #[test]
    fn test_message_idempotence_and_invalid_inputs() {
        // Every sample must survive a to_value -> from_value round trip unchanged.
        for expected in sample_messages() {
            let actual = Message::from_value(expected.to_value()).unwrap();
            assert_eq!(expected, actual);
        }

        // Values that are not well-formed message arrays must be rejected.
        let malformed = vec![
            Value::Nil,
            Value::Boolean(true),
            Value::Integer(42.into()),
            Value::String("not an array".into()),
            Value::Array(vec![]),
            // Unknown message type tag.
            Value::Array(vec![Value::Integer(999.into())]),
            // Request tag with none of the remaining fields.
            Value::Array(vec![Value::Integer(REQUEST_MESSAGE.into())]),
        ];
        for value in malformed {
            assert!(Message::from_value(value).is_err());
        }
    }

    #[test]
    fn test_message_round_trip_with_buffer() {
        for expected in sample_messages() {
            // Encode into an in-memory byte buffer.
            let mut bytes = Vec::new();
            expected.encode(&mut bytes).unwrap();

            // Decode from a cursor over those same bytes.
            let mut cursor = Cursor::new(bytes);
            let actual = Message::decode(&mut cursor).unwrap();
            assert_eq!(expected, actual);

            // Decoding must have consumed the whole buffer.
            assert_eq!(cursor.position() as usize, cursor.get_ref().len());
        }
    }
}