1use std::{borrow::Cow, fmt};
2
3use derive_ex::derive_ex;
4use ordered_float::OrderedFloat;
5use parse_display::Display;
6use serde::{Deserialize, Deserializer, Serialize};
7use serde_json::{Value, value::RawValue};
8
9use crate::{Error, OutgoingRequestId, SessionError, SessionResult, utils::write_string_no_escape};
10
11use super::Result;
12
/// Identifier of a JSON-RPC request, as carried in a message's `id` member.
///
/// Serialized transparently as the inner [`RawRequestId`]. Equality is not computed on the raw
/// variant but on the normalized key returned by `RawRequestId::key` (`#[eq(key = ...)]`), so
/// e.g. an integral `F64` id compares equal to the matching `U64` id.
/// NOTE(review): derive_ex is expected to apply the same key to `Hash`; confirm.
#[derive(Debug, Serialize, Deserialize, Clone, Display)]
#[derive_ex(Eq, PartialEq, Hash)]
#[serde(transparent)]
pub struct RequestId(#[eq(key = $.key())] RawRequestId);
17
18impl From<u64> for RequestId {
19 fn from(id: u64) -> Self {
20 RequestId(RawRequestId::U64(id))
21 }
22}
23
/// The raw JSON value of a request id: a number or a string.
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order when deserializing,
/// so the order below is significant: `U64` is attempted first, then `I64`, then `F64`, then
/// `String`. Do not reorder the variants.
#[derive(Debug, Serialize, Deserialize, Clone, Display)]
#[display("{0}")]
#[serde(untagged)]
enum RawRequestId {
    U64(u64),
    I64(i64),
    F64(f64),
    // String ids are displayed wrapped in double quotes.
    #[display("\"{0}\"")]
    String(String),
}
34
/// Normalized, borrowed view of a [`RawRequestId`] used for equality (and hashing).
///
/// `f64` implements neither `Eq` nor `Hash`, so float ids are wrapped in `OrderedFloat` to
/// allow deriving them; string ids borrow from the `RawRequestId` they came from.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
enum RawRequestIdKey<'a> {
    U64(u64),
    I64(i64),
    F64(OrderedFloat<f64>),
    String(&'a str),
}
42
43impl RawRequestId {
44 fn key(&self) -> RawRequestIdKey {
45 match self {
46 RawRequestId::U64(n) => RawRequestIdKey::U64(*n),
47 RawRequestId::I64(n) if *n > 0 => RawRequestIdKey::U64(*n as u64),
48 RawRequestId::I64(n) => RawRequestIdKey::I64(*n),
49 RawRequestId::F64(f)
50 if f.fract() == 0.0 && u64::MIN as f64 <= *f && *f <= u64::MAX as f64 =>
51 {
52 RawRequestIdKey::U64(*f as u64)
53 }
54 RawRequestId::F64(f) => RawRequestIdKey::F64(OrderedFloat(*f)),
55 RawRequestId::String(s) => RawRequestIdKey::String(s),
56 }
57 }
58}
59
60const MAX_SAFE_INTEGER: u64 = 9007199254740991;
61impl From<OutgoingRequestId> for RequestId {
62 fn from(id: OutgoingRequestId) -> Self {
63 if id.0 < MAX_SAFE_INTEGER as u128 {
64 RequestId(RawRequestId::U64(id.0 as u64))
65 } else {
66 RequestId(RawRequestId::String(id.0.to_string()))
67 }
68 }
69}
70impl TryFrom<RequestId> for OutgoingRequestId {
71 type Error = SessionError;
72 fn try_from(id: RequestId) -> SessionResult<OutgoingRequestId> {
73 TryFrom::<&RequestId>::try_from(&id)
74 }
75}
76impl TryFrom<&RequestId> for OutgoingRequestId {
77 type Error = SessionError;
78 fn try_from(id: &RequestId) -> SessionResult<OutgoingRequestId> {
79 match id.0 {
80 RawRequestId::U64(n) => return Ok(OutgoingRequestId(n as u128)),
81 RawRequestId::I64(n) => {
82 if let Ok(value) = n.try_into() {
83 return Ok(OutgoingRequestId(value));
84 }
85 }
86 RawRequestId::F64(f) => {
87 if f.fract() == 0.0 && 0.0 <= f && f <= MAX_SAFE_INTEGER as f64 {
88 return Ok(OutgoingRequestId(f as u128));
89 }
90 }
91 RawRequestId::String(ref s) => {
92 if let Ok(n) = s.parse() {
93 return Ok(OutgoingRequestId(n));
94 }
95 }
96 }
97 Err(SessionError::request_id_not_found())
98 }
99}
100
/// One JSON-RPC message in its on-the-wire shape.
///
/// A single struct covers requests, notifications, success responses and error responses;
/// `into_variants` classifies it by which members are present. `P` and `R` are the payload
/// types of `params` and `result`; both default to `RawValue`, which keeps the payload as
/// borrowed, unparsed JSON text.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq)]
#[derive_ex(Default, bound())]
#[serde(bound(deserialize = "&'a P: Deserialize<'de>, &'a R: Deserialize<'de>"))]
pub(crate) struct RawMessage<'a, P: ?Sized = RawValue, R: ?Sized = RawValue> {
    /// Protocol version string. `Default` fills in `"2.0"`; `verify_version` rejects other values.
    #[serde(borrow)]
    #[default(Cow::Borrowed("2.0"))]
    pub jsonrpc: Cow<'a, str>,
    /// Request id; omitted from the output when `None`, and `None` when the member is absent.
    ///
    /// NOTE(review): with `deserialize_some`, a present member is always handed to `RequestId`,
    /// and `RawRequestId` has no variant that accepts JSON `null`. So an explicit `"id": null`
    /// (which JSON-RPC 2.0 uses in error responses whose request id is unknown) appears to
    /// fail deserialization instead of yielding `None`; confirm this is intended.
    #[serde(
        skip_serializing_if = "Option::is_none",
        deserialize_with = "deserialize_some",
        default
    )]
    pub id: Option<RequestId>,
    /// Method name of a request or notification.
    #[serde(
        skip_serializing_if = "Option::is_none",
        deserialize_with = "deserialize_some",
        default,
        borrow
    )]
    pub method: Option<Cow<'a, str>>,
    /// Parameters of a request or notification, borrowed from the input.
    // `deserialize_some` + `default`: absent -> `None`; any present value -> `Some`.
    #[serde(
        skip_serializing_if = "Option::is_none",
        deserialize_with = "deserialize_some",
        default,
        borrow
    )]
    pub params: Option<&'a P>,
    /// Result of a success response, borrowed from the input.
    // `deserialize_some` + `default`: absent -> `None`; any present value (presumably including
    // an explicit JSON `null` result) -> `Some`.
    #[serde(
        skip_serializing_if = "Option::is_none",
        deserialize_with = "deserialize_some",
        default,
        borrow
    )]
    pub result: Option<&'a R>,
    /// Error object of an error response; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<ErrorObject>,
}
138
139fn deserialize_some<'de, D, T>(deserializer: D) -> Result<Option<T>, D::Error>
140where
141 T: Deserialize<'de>,
142 D: Deserializer<'de>,
143{
144 Ok(Some(T::deserialize(deserializer)?))
145}
146
147impl<'a, P: ?Sized, R: ?Sized> RawMessage<'a, P, R>
148where
149 &'a P: Deserialize<'a>,
150 &'a R: Deserialize<'a>,
151{
152 pub fn from_line<'b: 'a>(s: &'b str) -> Result<Vec<Self>, serde_json::Error> {
153 let s = s.trim();
154 if s.starts_with('{') {
155 let m = serde_json::from_str::<Self>(s)?;
156 Ok(vec![m])
157 } else if s.starts_with('[') {
158 serde_json::from_str::<Vec<Self>>(s)
159 } else {
160 Err(<serde_json::Error as serde::de::Error>::custom(
161 "not a json object or array",
162 ))
163 }
164 }
165 pub(crate) fn verify_version(&self) -> Result<()> {
166 if self.jsonrpc == "2.0" {
167 Ok(())
168 } else {
169 Err(Error::unsupported_version())
170 }
171 }
172}
173impl<'a> RawMessage<'a> {
174 pub(crate) fn into_variants(self) -> Result<RawMessageVariants<'a>> {
175 self.verify_version()?;
176 match self {
177 RawMessage {
178 id: Some(id),
179 method: Some(method),
180 params,
181 result: None,
182 error: None,
183 ..
184 } => Ok(RawMessageVariants::Request { id, method, params }),
185 RawMessage {
186 id: Some(id),
187 method: None,
188 params: None,
189 result: Some(result),
190 error: None,
191 ..
192 } => Ok(RawMessageVariants::Success { id, result }),
193 RawMessage {
194 id,
195 method: None,
196 params: None,
197 result: None,
198 error: Some(error),
199 ..
200 } => Ok(RawMessageVariants::Error { id, error }),
201 RawMessage {
202 id: None,
203 method: Some(method),
204 params,
205 result: None,
206 error: None,
207 ..
208 } => Ok(RawMessageVariants::Notification { method, params }),
209 _ => Err(Error::invalid_message()),
210 }
211 }
212}
213
/// A [`RawMessage`] classified by the members it carries (produced by `RawMessage::into_variants`).
pub(crate) enum RawMessageVariants<'a> {
    /// Message with both `id` and `method` (and neither `result` nor `error`).
    Request {
        id: RequestId,
        method: Cow<'a, str>,
        params: Option<&'a RawValue>,
    },
    /// Message with `id` and `result` (and no `method`, `params` or `error`).
    Success {
        id: RequestId,
        result: &'a RawValue,
    },
    /// Message with `error` and no `method`, `params` or `result`; `id` may be absent.
    Error {
        id: Option<RequestId>,
        error: ErrorObject,
    },
    /// Message with `method` but no `id` (and neither `result` nor `error`).
    Notification {
        method: Cow<'a, str>,
        params: Option<&'a RawValue>,
    },
}
233
/// Serialized JSON text of a JSON-RPC message; `Display` writes the text unchanged.
#[derive(Display, Debug)]
#[display("{0}")]
pub(crate) struct MessageData(pub String);
237
238impl MessageData {
239 pub fn from_raw_message<P, R>(msg: &RawMessage<P, R>) -> Result<Self, serde_json::Error>
240 where
241 P: Serialize,
242 R: Serialize,
243 {
244 serde_json::to_string(msg).map(Self)
245 }
246 pub fn from_request<P>(id: RequestId, method: &str, params: Option<&P>) -> SessionResult<Self>
247 where
248 P: Serialize,
249 {
250 Self::from_raw_message::<P, ()>(&RawMessage {
251 id: Some(id),
252 method: Some(Cow::Borrowed(method)),
253 params,
254 ..Default::default()
255 })
256 .map_err(SessionError::serialize_failed)
257 }
258 pub fn from_notification<P>(method: &str, params: Option<&P>) -> SessionResult<Self>
259 where
260 P: Serialize,
261 {
262 Self::from_raw_message::<P, ()>(&RawMessage {
263 method: Some(Cow::Borrowed(method)),
264 params,
265 ..Default::default()
266 })
267 .map_err(SessionError::serialize_failed)
268 }
269
270 pub fn from_success<R>(id: RequestId, result: &R) -> Result<Self>
271 where
272 R: Serialize,
273 {
274 Self::from_raw_message::<(), R>(&RawMessage {
275 id: Some(id),
276 result: Some(result),
277 ..Default::default()
278 })
279 .map_err(|e| Error::from(e).with_message("Serialize failed", true))
280 }
281 pub fn from_error(id: Option<RequestId>, e: Error, expose_internals: bool) -> Self {
282 Self::from_error_object(id, e.to_error_object(expose_internals))
283 }
284 pub fn from_error_object(id: Option<RequestId>, e: ErrorObject) -> Self {
285 Self::from_raw_message::<(), ()>(&RawMessage {
286 id,
287 error: Some(e),
288 ..Default::default()
289 })
290 .unwrap()
291 }
292 pub fn from_result(id: RequestId, r: Result<impl Serialize>, expose_internals: bool) -> Self {
293 let e = match r {
294 Ok(data) => match Self::from_success(id.clone(), &data) {
295 Ok(data) => return data,
296 Err(e) => e,
297 },
298 Err(e) => e,
299 };
300 Self::from_error(Some(id), e, expose_internals)
301 }
302 pub fn from_result_message_data(
303 id: RequestId,
304 md: Result<Self>,
305 expose_internals: bool,
306 ) -> Self {
307 match md {
308 Ok(data) => data,
309 Err(e) => Self::from_error(Some(id), e, expose_internals),
310 }
311 }
312}
313
/// JSON-RPC error object: a numeric `code`, a human-readable `message` and optional `data`.
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Debug)]
pub struct ErrorObject {
    /// Numeric error code.
    pub code: ErrorCode,
    /// Short description of the error.
    pub message: String,
    /// Optional additional information; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}
324impl fmt::Display for ErrorObject {
325 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326 write!(f, "[{}] {}", self.code, self.message)?;
327 if let Some(data) = &self.data {
328 write!(f, " (")?;
329 write_string_no_escape(data, f)?;
330 write!(f, ")")?;
331 }
332 Ok(())
333 }
334}
/// Numeric JSON-RPC error code, serialized transparently as the bare integer.
///
/// `Display` prints the number; `Debug` also prints the standard description (see `message`).
#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Hash, Ord, PartialOrd, Display)]
#[serde(transparent)]
pub struct ErrorCode(pub i64);
341
342impl ErrorCode {
343 pub const PARSE_ERROR: Self = Self(-32700);
344 pub const INVALID_REQUEST: Self = Self(-32600);
345 pub const METHOD_NOT_FOUND: Self = Self(-32601);
346 pub const INVALID_PARAMS: Self = Self(-32602);
347 pub const INTERNAL_ERROR: Self = Self(-32603);
348 pub const SERVER_ERROR_START: Self = Self(-32000);
349 pub const SERVER_ERROR_END: Self = Self(-32099);
350
351 pub fn message(self) -> &'static str {
352 match self {
353 Self::PARSE_ERROR => "Parse error",
354 Self::INVALID_REQUEST => "Invalid Request",
355 Self::METHOD_NOT_FOUND => "Method not found",
356 Self::INVALID_PARAMS => "Invalid params",
357 Self::INTERNAL_ERROR => "Internal error",
358 _ if Self::SERVER_ERROR_START <= self && self <= Self::SERVER_ERROR_END => {
359 "Server error"
360 }
361 _ => "Unknown error",
362 }
363 }
364}
365impl fmt::Debug for ErrorCode {
366 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
367 write!(f, "{}({})", self.0, self.message())
368 }
369}
370
371#[cfg(test)]
372mod tests;