1use rmpv::Value;
6use std::io::{Read, Write};
7
8use crate::error::*;
9
/// Type tag (first array element) marking a request: `[0, id, method, params]`.
const REQUEST_MESSAGE: u64 = 0;
/// Type tag (first array element) marking a response: `[1, id, error, result]`.
const RESPONSE_MESSAGE: u64 = 1;
/// Type tag (first array element) marking a notification: `[2, method, params]`.
const NOTIFICATION_MESSAGE: u64 = 2;
13
14#[derive(PartialEq, Clone, Debug)]
16pub enum Message {
17 Request(Request),
18 Response(Response),
19 Notification(Notification),
20}
21
22#[derive(PartialEq, Clone, Debug)]
24pub struct Request {
25 pub id: u32,
26 pub method: String,
27 pub params: Vec<Value>,
28}
29
30#[derive(PartialEq, Clone, Debug)]
32pub struct Response {
33 pub id: u32,
34 pub result: std::result::Result<Value, Value>,
35}
36
37#[derive(PartialEq, Clone, Debug)]
39pub struct Notification {
40 pub method: String,
41 pub params: Vec<Value>,
42}
43
44impl Message {
45 pub fn to_value(&self) -> Value {
47 match self {
48 Message::Request(req) => Value::Array(vec![
49 Value::Integer(REQUEST_MESSAGE.into()),
50 Value::Integer(req.id.into()),
51 Value::String(req.method.clone().into()),
52 Value::Array(req.params.clone()),
53 ]),
54 Message::Response(resp) => Value::Array(vec![
55 Value::Integer(RESPONSE_MESSAGE.into()),
56 Value::Integer(resp.id.into()),
57 match &resp.result {
58 Ok(_value) => Value::Nil,
59 Err(err) => err.clone(),
60 },
61 match &resp.result {
62 Ok(value) => value.clone(),
63 Err(_) => Value::Nil,
64 },
65 ]),
66 Message::Notification(notif) => Value::Array(vec![
67 Value::Integer(NOTIFICATION_MESSAGE.into()),
68 Value::String(notif.method.clone().into()),
69 Value::Array(notif.params.clone()),
70 ]),
71 }
72 }
73
74 pub fn from_value(value: Value) -> Result<Self> {
76 match value {
77 Value::Array(array) => {
78 if array.is_empty() {
79 return Err(RpcError::Protocol("Empty message array".into()));
80 }
81 match array[0] {
82 Value::Integer(msg_type) => match msg_type.as_u64() {
83 Some(REQUEST_MESSAGE) => {
84 if array.len() != 4 {
85 return Err(RpcError::Protocol(
86 "Invalid request message length".into(),
87 ));
88 }
89 let id = array[1]
90 .as_u64()
91 .ok_or(RpcError::Protocol("Invalid request id".into()))?
92 as u32;
93 let method = array[2]
94 .as_str()
95 .ok_or(RpcError::Protocol("Invalid request method".into()))?
96 .to_string();
97 let params = match &array[3] {
98 Value::Array(params) => params.clone(),
99 _ => {
100 return Err(RpcError::Protocol("Invalid request params".into()))
101 }
102 };
103 Ok(Message::Request(Request { id, method, params }))
104 }
105 Some(RESPONSE_MESSAGE) => {
106 if array.len() != 4 {
107 return Err(RpcError::Protocol(
108 "Invalid response message length".into(),
109 ));
110 }
111 let id = array[1]
112 .as_u64()
113 .ok_or(RpcError::Protocol("Invalid response id".into()))?
114 as u32;
115 let result = if array[2] == Value::Nil {
116 Ok(array[3].clone())
117 } else {
118 Err(array[2].clone())
119 };
120 Ok(Message::Response(Response { id, result }))
121 }
122 Some(NOTIFICATION_MESSAGE) => {
123 if array.len() != 3 {
124 return Err(RpcError::Protocol(
125 "Invalid notification message length".into(),
126 ));
127 }
128 let method = array[1]
129 .as_str()
130 .ok_or(RpcError::Protocol("Invalid notification method".into()))?
131 .to_string();
132 let params = match &array[2] {
133 Value::Array(params) => params.clone(),
134 _ => {
135 return Err(RpcError::Protocol(
136 "Invalid notification params".into(),
137 ))
138 }
139 };
140 Ok(Message::Notification(Notification { method, params }))
141 }
142 _ => Err(RpcError::Protocol("Invalid message type".into())),
143 },
144 _ => Err(RpcError::Protocol("Invalid message type".into())),
145 }
146 }
147 _ => Err(RpcError::Protocol("Invalid message format".into())),
148 }
149 }
150
151 pub fn encode<W: Write>(&self, writer: &mut W) -> Result<()> {
153 let value = self.to_value();
154 rmpv::encode::write_value(writer, &value)?;
155 Ok(())
156 }
157
158 pub fn decode<R: Read>(reader: &mut R) -> Result<Self> {
160 match rmpv::decode::read_value(reader) {
161 Ok(value) => Self::from_value(value),
162 Err(rmpv::decode::Error::InvalidMarkerRead(e))
163 | Err(rmpv::decode::Error::InvalidDataRead(e)) => Err(RpcError::from(e)),
164 Err(rmpv::decode::Error::DepthLimitExceeded) => {
165 Err(RpcError::Protocol("Depth limit exceeded".into()))
166 }
167 }
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Builds well-formed messages covering every message kind, both response
    /// outcomes (including an `Ok(Nil)` result), empty params, the largest
    /// representable id, and nested parameter values.
    fn test_cases() -> Vec<Message> {
        vec![
            Message::Request(Request {
                id: 1,
                method: "test_method".to_string(),
                params: vec![Value::String("param1".into()), Value::Integer(42.into())],
            }),
            Message::Response(Response {
                id: 2,
                result: Ok(Value::String("success".into())),
            }),
            Message::Response(Response {
                id: 3,
                result: Err(Value::String("error".into())),
            }),
            Message::Notification(Notification {
                method: "test_notification".to_string(),
                params: vec![Value::Boolean(true), Value::F64(2.14)],
            }),
            Message::Request(Request {
                id: 4,
                method: "complex_method".to_string(),
                params: vec![
                    Value::Array(vec![Value::String("nested".into()), Value::Integer(1.into())]),
                    Value::Map(vec![
                        (Value::String("key".into()), Value::Boolean(true)),
                        (Value::String("value".into()), Value::F64(1.718)),
                    ]),
                ],
            }),
            Message::Response(Response {
                id: 5,
                result: Ok(Value::Nil),
            }),
            Message::Notification(Notification {
                method: "empty_params".to_string(),
                params: vec![],
            }),
            Message::Request(Request {
                id: u32::MAX,
                method: "max_id".to_string(),
                params: vec![],
            }),
        ]
    }

    /// Shorthand for an integer `Value`.
    fn int(n: i64) -> Value {
        Value::Integer(n.into())
    }

    /// Shorthand for a message type tag `Value`.
    fn tag(t: u64) -> Value {
        Value::Integer(t.into())
    }

    /// Shorthand for a string `Value`.
    fn string(s: &str) -> Value {
        Value::String(s.into())
    }

    #[test]
    fn test_message_idempotence_and_invalid_inputs() {
        fn assert_idempotence(message: &Message) {
            let value = message.to_value();
            let roundtrip_message = Message::from_value(value).unwrap();
            assert_eq!(message, &roundtrip_message);
        }

        for message in &test_cases() {
            assert_idempotence(message);
        }

        let invalid_values = vec![
            // Not an array at all.
            Value::Nil,
            Value::Boolean(true),
            int(42),
            string("not an array"),
            // Empty array or unusable type tag.
            Value::Array(vec![]),
            Value::Array(vec![int(999)]),
            Value::Array(vec![int(-1)]),
            Value::Array(vec![string("0")]),
            // Requests: wrong length, bad id, bad method, bad params.
            Value::Array(vec![tag(REQUEST_MESSAGE)]),
            Value::Array(vec![tag(REQUEST_MESSAGE), int(1), string("m")]),
            Value::Array(vec![tag(REQUEST_MESSAGE), int(-1), string("m"), Value::Array(vec![])]),
            Value::Array(vec![tag(REQUEST_MESSAGE), string("1"), string("m"), Value::Array(vec![])]),
            Value::Array(vec![tag(REQUEST_MESSAGE), int(1), int(2), Value::Array(vec![])]),
            Value::Array(vec![tag(REQUEST_MESSAGE), int(1), string("m"), Value::Nil]),
            // Responses: wrong length, bad id.
            Value::Array(vec![tag(RESPONSE_MESSAGE), int(1), Value::Nil]),
            Value::Array(vec![tag(RESPONSE_MESSAGE), Value::Nil, Value::Nil, Value::Nil]),
            // Notifications: wrong length, bad method, bad params.
            Value::Array(vec![tag(NOTIFICATION_MESSAGE), string("m")]),
            Value::Array(vec![tag(NOTIFICATION_MESSAGE), int(1), Value::Array(vec![])]),
            Value::Array(vec![tag(NOTIFICATION_MESSAGE), string("m"), string("p")]),
        ];

        for invalid_value in invalid_values {
            assert!(
                Message::from_value(invalid_value.clone()).is_err(),
                "accepted invalid value {:?}",
                invalid_value
            );
        }
    }

    #[test]
    fn test_message_round_trip_with_buffer() {
        for original_message in &test_cases() {
            let mut write_buffer = Vec::new();
            original_message.encode(&mut write_buffer).unwrap();

            let mut read_buffer = Cursor::new(write_buffer);
            let deserialized_message = Message::decode(&mut read_buffer).unwrap();

            assert_eq!(original_message, &deserialized_message);
            // Decoding must consume exactly the bytes that were encoded.
            assert_eq!(read_buffer.position() as usize, read_buffer.get_ref().len());
        }
    }

    #[test]
    fn test_decode_rejects_truncated_input() {
        // Dropping the final byte always leaves the last encoded value incomplete.
        for original_message in &test_cases() {
            let mut bytes = Vec::new();
            original_message.encode(&mut bytes).unwrap();
            bytes.pop();
            assert!(Message::decode(&mut Cursor::new(bytes)).is_err());
        }
        assert!(Message::decode(&mut Cursor::new(Vec::<u8>::new())).is_err());
    }
}