1use std::{fmt, marker::PhantomData};
2
3use serde::{de, ser, Deserialize, Serialize};
4use serde_json::Value;
5
6use crate::id::Id;
7
8pub type Params = Vec<Value>;
10
11#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
13#[serde(deny_unknown_fields)]
14pub struct MethodCall {
15 pub method: String,
20 pub params: Params,
23 pub id: Id,
26}
27
28impl fmt::Display for MethodCall {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 let json = serde_json::to_string(self).expect("`MethodCall` is serializable");
31 write!(f, "{}", json)
32 }
33}
34
35impl MethodCall {
36 pub fn new<M: Into<String>>(method: M, params: Params, id: Id) -> Self {
38 Self {
39 method: method.into(),
40 params,
41 id,
42 }
43 }
44}
45
46#[derive(Clone, Debug, Eq, PartialEq)]
56pub struct Notification {
57 pub method: String,
62 pub params: Params,
65}
66
67impl fmt::Display for Notification {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 let json = serde_json::to_string(self).expect("`Notification` is serializable");
70 write!(f, "{}", json)
71 }
72}
73
74impl ser::Serialize for Notification {
75 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
76 where
77 S: ser::Serializer,
78 {
79 let mut state = ser::Serializer::serialize_struct(serializer, "Notification", 3)?;
80 ser::SerializeStruct::serialize_field(&mut state, "method", &self.method)?;
81 ser::SerializeStruct::serialize_field(&mut state, "params", &self.params)?;
82 ser::SerializeStruct::serialize_field(&mut state, "id", &Option::<Id>::None)?;
83 ser::SerializeStruct::end(state)
84 }
85}
86
87impl<'de> de::Deserialize<'de> for Notification {
88 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
89 where
90 D: de::Deserializer<'de>,
91 {
92 use self::request_field::{Field, FIELDS};
93
94 struct Visitor<'de> {
95 marker: PhantomData<Notification>,
96 lifetime: PhantomData<&'de ()>,
97 }
98 impl<'de> de::Visitor<'de> for Visitor<'de> {
99 type Value = Notification;
100
101 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
102 formatter.write_str("struct Notification")
103 }
104
105 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
106 where
107 A: de::MapAccess<'de>,
108 {
109 let mut method = Option::<String>::None;
110 let mut params = Option::<Params>::None;
111 let mut id = Option::<Option<Id>>::None;
112
113 while let Some(key) = de::MapAccess::next_key::<Field>(&mut map)? {
114 match key {
115 Field::Method => {
116 if method.is_some() {
117 return Err(de::Error::duplicate_field("method"));
118 }
119 method = Some(de::MapAccess::next_value::<String>(&mut map)?)
120 }
121 Field::Params => {
122 if params.is_some() {
123 return Err(de::Error::duplicate_field("params"));
124 }
125 params = Some(de::MapAccess::next_value::<Params>(&mut map)?)
126 }
127 Field::Id => {
128 if id.is_some() {
129 return Err(de::Error::duplicate_field("id"));
130 }
131 id = Some(de::MapAccess::next_value::<Option<Id>>(&mut map)?)
132 }
133 }
134 }
135
136 let method = method.ok_or_else(|| de::Error::missing_field("method"))?;
137 let params = params.ok_or_else(|| de::Error::missing_field("params"))?;
138 let id = id.ok_or_else(|| de::Error::missing_field("id"))?;
139 if id.is_some() {
140 return Err(de::Error::custom("JSON-RPC 1.0 notification id MUST be Null"));
141 }
142 Ok(Notification { method, params })
143 }
144 }
145
146 de::Deserializer::deserialize_struct(
147 deserializer,
148 "Notification",
149 FIELDS,
150 Visitor {
151 marker: PhantomData::<Notification>,
152 lifetime: PhantomData,
153 },
154 )
155 }
156}
157
158impl Notification {
159 pub fn new<M: Into<String>>(method: M, params: Params) -> Self {
161 Self {
162 method: method.into(),
163 params,
164 }
165 }
166}
167
168#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
170#[serde(deny_unknown_fields)]
171#[serde(untagged)]
172pub enum Call {
173 MethodCall(MethodCall),
175 Notification(Notification),
177}
178
179impl fmt::Display for Call {
180 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181 let json = serde_json::to_string(self).expect("`Call` is serializable");
182 write!(f, "{}", json)
183 }
184}
185
186impl Call {
187 pub fn method(&self) -> &str {
189 match self {
190 Self::MethodCall(call) => &call.method,
191 Self::Notification(notification) => ¬ification.method,
192 }
193 }
194
195 pub fn params(&self) -> &Params {
197 match self {
198 Self::MethodCall(call) => &call.params,
199 Self::Notification(notification) => ¬ification.params,
200 }
201 }
202
203 pub fn id(&self) -> Option<Id> {
205 match self {
206 Self::MethodCall(call) => Some(call.id.clone()),
207 Self::Notification(_notification) => None,
208 }
209 }
210}
211
212impl From<MethodCall> for Call {
213 fn from(call: MethodCall) -> Self {
214 Self::MethodCall(call)
215 }
216}
217
218impl From<Notification> for Call {
219 fn from(notify: Notification) -> Self {
220 Self::Notification(notify)
221 }
222}
223
224#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
226#[serde(deny_unknown_fields)]
227#[serde(untagged)]
228pub enum Request {
229 Single(Call),
231 Batch(Vec<Call>),
233}
234
235impl fmt::Display for Request {
236 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237 let json = serde_json::to_string(self).expect("`Request` is serializable");
238 write!(f, "{}", json)
239 }
240}
241
242#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
244#[serde(deny_unknown_fields)]
245#[serde(untagged)]
246pub enum MethodCallRequest {
247 Single(MethodCall),
249 Batch(Vec<MethodCall>),
251}
252
253impl From<MethodCall> for MethodCallRequest {
254 fn from(call: MethodCall) -> Self {
255 Self::Single(call)
256 }
257}
258
259impl From<Vec<MethodCall>> for MethodCallRequest {
260 fn from(calls: Vec<MethodCall>) -> Self {
261 Self::Batch(calls)
262 }
263}
264
265impl fmt::Display for MethodCallRequest {
266 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
267 let json = serde_json::to_string(self).expect("`MethodCallRequest` is serializable");
268 write!(f, "{}", json)
269 }
270}
271
272mod request_field {
273 use super::*;
274
275 pub const FIELDS: &[&str] = &["method", "params", "id"];
276 pub enum Field {
277 Method,
278 Params,
279 Id,
280 }
281
282 impl<'de> de::Deserialize<'de> for Field {
283 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
284 where
285 D: de::Deserializer<'de>,
286 {
287 de::Deserializer::deserialize_identifier(deserializer, FieldVisitor)
288 }
289 }
290
291 struct FieldVisitor;
292 impl<'de> de::Visitor<'de> for FieldVisitor {
293 type Value = Field;
294
295 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
296 formatter.write_str("field identifier")
297 }
298
299 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
300 where
301 E: de::Error,
302 {
303 match v {
304 "method" => Ok(Field::Method),
305 "params" => Ok(Field::Params),
306 "id" => Ok(Field::Id),
307 _ => Err(de::Error::unknown_field(v, &FIELDS)),
308 }
309 }
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Method calls paired with their expected compact JSON.
    fn method_call_cases() -> Vec<(MethodCall, &'static str)> {
        vec![
            (
                MethodCall {
                    method: "foo".to_string(),
                    params: vec![Value::from(1), Value::Bool(true)],
                    id: Id::Num(1),
                },
                r#"{"method":"foo","params":[1,true],"id":1}"#,
            ),
            (
                MethodCall {
                    method: "foo".to_string(),
                    params: vec![],
                    id: Id::Num(1),
                },
                r#"{"method":"foo","params":[],"id":1}"#,
            ),
        ]
    }

    /// Notifications paired with their expected compact JSON (always `"id":null`).
    fn notification_cases() -> Vec<(Notification, &'static str)> {
        vec![
            (
                Notification {
                    method: "foo".to_string(),
                    params: vec![Value::from(1), Value::Bool(true)],
                },
                r#"{"method":"foo","params":[1,true],"id":null}"#,
            ),
            (
                Notification {
                    method: "foo".to_string(),
                    params: vec![],
                },
                r#"{"method":"foo","params":[],"id":null}"#,
            ),
        ]
    }

    #[test]
    fn method_call_serialization() {
        for (method_call, expect) in method_call_cases() {
            let ser = serde_json::to_string(&method_call).unwrap();
            assert_eq!(ser, expect);
            let de = serde_json::from_str::<MethodCall>(expect).unwrap();
            assert_eq!(de, method_call);
        }
    }

    #[test]
    fn notification_serialization() {
        for (notification, expect) in notification_cases() {
            let ser = serde_json::to_string(&notification).unwrap();
            assert_eq!(ser, expect);
            let de = serde_json::from_str::<Notification>(expect).unwrap();
            assert_eq!(de, notification);
        }
    }

    #[test]
    fn display_matches_json() {
        for (method_call, expect) in method_call_cases() {
            assert_eq!(method_call.to_string(), expect);
            assert_eq!(Call::from(method_call.clone()).to_string(), expect);
            assert_eq!(Request::Single(Call::from(method_call.clone())).to_string(), expect);
            assert_eq!(MethodCallRequest::from(method_call).to_string(), expect);
        }

        for (notification, expect) in notification_cases() {
            assert_eq!(notification.to_string(), expect);
            assert_eq!(Call::from(notification.clone()).to_string(), expect);
            assert_eq!(Request::Single(Call::from(notification)).to_string(), expect);
        }
    }

    #[test]
    fn call_accessors() {
        let call = Call::from(MethodCall::new("foo", vec![Value::from(1)], Id::Num(7)));
        assert_eq!(call.method(), "foo");
        assert_eq!(call.params(), &vec![Value::from(1)]);
        assert_eq!(call.id(), Some(Id::Num(7)));

        let notification = Call::from(Notification::new("bar", Vec::new()));
        assert_eq!(notification.method(), "bar");
        assert!(notification.params().is_empty());
        assert_eq!(notification.id(), None);
    }

    #[test]
    fn call_serialization() {
        for (method_call, expect) in method_call_cases() {
            let call = Call::MethodCall(method_call);
            assert_eq!(serde_json::to_string(&call).unwrap(), expect);
            assert_eq!(serde_json::from_str::<Call>(expect).unwrap(), call);
        }

        for (notification, expect) in notification_cases() {
            let call = Call::Notification(notification);
            assert_eq!(serde_json::to_string(&call).unwrap(), expect);
            assert_eq!(serde_json::from_str::<Call>(expect).unwrap(), call);
        }
    }

    #[test]
    fn request_serialization() {
        for (method_call, expect) in method_call_cases() {
            let call_request = Request::Single(Call::MethodCall(method_call));
            assert_eq!(serde_json::to_string(&call_request).unwrap(), expect);
            assert_eq!(serde_json::from_str::<Request>(expect).unwrap(), call_request);
        }

        for (notification, expect) in notification_cases() {
            let notification_request = Request::Single(Call::Notification(notification));
            assert_eq!(serde_json::to_string(&notification_request).unwrap(), expect);
            assert_eq!(serde_json::from_str::<Request>(expect).unwrap(), notification_request);
        }

        let batch_request = Request::Batch(vec![
            Call::MethodCall(MethodCall {
                method: "foo".into(),
                params: vec![],
                id: Id::Num(1),
            }),
            Call::MethodCall(MethodCall {
                method: "bar".into(),
                params: vec![],
                id: Id::Num(2),
            }),
        ]);
        let batch_expect =
            r#"[{"method":"foo","params":[],"id":1},{"method":"bar","params":[],"id":2}]"#;
        assert_eq!(serde_json::to_string(&batch_request).unwrap(), batch_expect);
        assert_eq!(serde_json::from_str::<Request>(batch_expect).unwrap(), batch_request);
    }

    #[test]
    fn method_call_request_serialization() {
        let single = MethodCallRequest::from(MethodCall::new("foo", vec![], Id::Num(1)));
        let single_expect = r#"{"method":"foo","params":[],"id":1}"#;
        assert_eq!(serde_json::to_string(&single).unwrap(), single_expect);
        assert_eq!(serde_json::from_str::<MethodCallRequest>(single_expect).unwrap(), single);

        let batch = MethodCallRequest::from(vec![
            MethodCall::new("foo", vec![], Id::Num(1)),
            MethodCall::new("bar", vec![], Id::Num(2)),
        ]);
        let batch_expect =
            r#"[{"method":"foo","params":[],"id":1},{"method":"bar","params":[],"id":2}]"#;
        assert_eq!(serde_json::to_string(&batch).unwrap(), batch_expect);
        assert_eq!(serde_json::from_str::<MethodCallRequest>(batch_expect).unwrap(), batch);
    }

    #[test]
    fn invalid_method_call_request() {
        // Notifications (null id) are not method calls, alone or inside a batch.
        let cases = vec![
            r#"{"method":"foo","params":[],"id":null}"#,
            r#"[{"method":"foo","params":[],"id":null}]"#,
        ];

        for case in cases {
            let request = serde_json::from_str::<MethodCallRequest>(case);
            assert!(request.is_err(), "unexpectedly accepted: {}", case);
        }
    }

    #[test]
    fn invalid_request() {
        let cases = vec![
            r#"{"method":"foo","params":[1,true],"id":1,"unknown":[]}"#,
            r#"{"method":"foo","params":[1,true],"id":1.2}"#,
            r#"{"method":"foo","params":[1,true],"id":null,"unknown":[]}"#,
            r#"{"method":"foo","params":[1,true],"unknown":[]}"#,
            r#"{"method":"foo","params":[1,true]}"#,
            r#"{"method":"foo","unknown":[]}"#,
            r#"{"method":1,"unknown":[]}"#,
            r#"{"unknown":[]}"#,
            // Duplicate keys are rejected for both call kinds.
            r#"{"method":"foo","method":"bar","params":[],"id":null}"#,
            r#"{"method":"foo","params":[],"params":[],"id":1}"#,
            r#"{"method":"foo","params":[],"id":null,"id":null}"#,
            // One bad element invalidates the whole batch.
            r#"[{"method":"foo","params":[],"id":1},{"unknown":[]}]"#,
        ];

        for case in cases {
            let request = serde_json::from_str::<Request>(case);
            assert!(request.is_err(), "unexpectedly accepted: {}", case);
        }
    }

    #[test]
    fn valid_request() {
        let cases = vec![
            r#"{"method":"foo","params":[1,true],"id":1}"#,
            r#"{"method":"foo","params":[],"id":1}"#,
            r#"{"method":"foo","params":[1,true],"id":null}"#,
            r#"{"method":"foo","params":[],"id":null}"#,
        ];

        for case in cases {
            let request = serde_json::from_str::<Request>(case);
            assert!(request.is_ok(), "unexpectedly rejected: {}", case);
        }
    }
}