arti_rpc_client_core/msgs/
request.rs1use std::sync::Arc;
12
13use serde::{Deserialize, Serialize};
14
15pub(crate) type JsonMap = serde_json::Map<String, serde_json::Value>;
17
18use crate::conn::ProtoError;
19
20use super::{AnyRequestId, JsonAnyObj, ObjectId};
21
22#[derive(Serialize, Debug)]
27#[cfg_attr(test, derive(Eq, PartialEq, Deserialize))]
29#[allow(clippy::missing_docs_in_private_items)] pub(crate) struct Request<T> {
31 #[serde(skip_serializing_if = "Option::is_none")]
32 pub(crate) id: Option<AnyRequestId>,
33 pub(crate) obj: ObjectId,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 pub(crate) meta: Option<RequestMeta>,
36 pub(crate) method: String,
37 pub(crate) params: T,
38}
39
40#[derive(Clone, Debug, thiserror::Error)]
42#[non_exhaustive]
43pub enum InvalidRequestError {
44 #[error("Request was not valid Json")]
46 InvalidJson(#[source] Arc<serde_json::Error>),
47 #[error("Request's fields were invalid or missing")]
49 InvalidFormat(#[source] Arc<serde_json::Error>),
50 #[error("Unable to re-encode or format request")]
52 ReencodeFailed(#[source] Arc<serde_json::Error>),
53}
54
55impl<T: Serialize> Request<T> {
56 pub(crate) fn new(obj: ObjectId, method: impl Into<String>, params: T) -> Self {
58 Self {
59 id: None,
60 obj,
61 meta: Default::default(),
62 method: method.into(),
63 params,
64 }
65 }
66 pub(crate) fn encode(&self) -> Result<String, ProtoError> {
70 serde_json::to_string(self).map_err(|e| ProtoError::CouldNotEncode(Arc::new(e)))
71 }
72}
73
74#[derive(Deserialize, Debug)]
78#[allow(dead_code)] struct ParsedRequestFields {
81 id: AnyRequestId,
85 obj: ObjectId,
89 #[serde(skip_serializing_if = "Option::is_none")]
91 meta: Option<RequestMeta>,
92 method: String,
94 params: JsonAnyObj,
96}
97
98#[derive(derive_more::AsRef, Debug, Clone)]
100pub(crate) struct ValidatedRequest {
101 #[as_ref]
103 msg: String,
104 id: AnyRequestId,
106}
107
108impl ValidatedRequest {
109 pub(crate) fn id(&self) -> &AnyRequestId {
111 &self.id
112 }
113
114 fn from_json_value(val: serde_json::Value) -> Result<Self, InvalidRequestError> {
116 let mut msg = serde_json::to_string(&val)
117 .map_err(|e| InvalidRequestError::ReencodeFailed(Arc::new(e)))?;
118 debug_assert!(!msg.contains('\n'));
119 msg.push('\n');
120
121 let req: ParsedRequestFields = serde_json::from_value(val)
122 .map_err(|e| InvalidRequestError::InvalidFormat(Arc::new(e)))?;
123 let id = req.id;
124
125 Ok(ValidatedRequest { id, msg })
126 }
127
128 #[allow(dead_code)]
131 pub(crate) fn from_string_strict(s: &str) -> Result<Self, InvalidRequestError> {
132 let value: serde_json::Value =
133 serde_json::from_str(s).map_err(|e| InvalidRequestError::InvalidJson(Arc::new(e)))?;
134 Self::from_json_value(value)
135 }
136
137 pub(crate) fn from_string_loose<F>(
141 s: &str,
142 id_generator: F,
143 ) -> Result<Self, InvalidRequestError>
144 where
145 F: FnOnce() -> AnyRequestId,
146 {
147 let mut value: serde_json::Value =
148 serde_json::from_str(s).map_err(|e| InvalidRequestError::InvalidJson(Arc::new(e)))?;
149
150 if let Some(obj) = value.as_object_mut() {
151 obj.entry("id")
152 .or_insert_with(|| id_generator().into_json_value());
153 }
154
155 Self::from_json_value(value)
156 }
157}
158
159#[derive(Deserialize, Serialize, Debug, Default)]
161#[cfg_attr(test, derive(Eq, PartialEq))]
162pub(crate) struct RequestMeta {
163 #[serde(default)]
168 pub(crate) updates: bool,
169 #[serde(flatten)]
172 pub(crate) unrecognized_fields: JsonMap,
173}
174
/// A counter-based source of request identifiers.
///
/// A default-constructed generator starts counting from zero.
#[derive(Default, Debug)]
pub(crate) struct IdGenerator {
    /// The number to embed in the next identifier handed out.
    next_id: u64,
}
185
186impl IdGenerator {
187 pub(crate) fn next_id(&mut self) -> AnyRequestId {
189 let id = self.next_id;
190 self.next_id += 1;
191 format!("!auto!--{id}").into()
192 }
193}
194
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]

    impl ParsedRequestFields {
        /// Return true if this request's metadata asks for updates.
        fn updates_requested(&self) -> bool {
            self.meta.as_ref().map(|m| m.updates).unwrap_or(false)
        }
    }

    use crate::util::assert_same_json;

    use super::*;

    /// A request with a numeric id, metadata, and non-empty parameters.
    const REQ1: &str = r#"{"id":7, "obj": "hi", "meta": {"updates": true}, "method":"twiddle", "params":{"stuff": "nonsense"} }"#;
    /// A request with a string id and no metadata.
    const REQ2: &str = r#"{"id":"fred", "obj": "hi", "method":"twiddle", "params":{} }"#;
    /// A request carrying a top-level field we do not recognize.
    const REQ3: &str =
        r#"{"id":"fred", "obj": "hi", "method":"twiddle", "params":{},"unrecognized":"waffles"}"#;

    #[test]
    fn parse_requests() {
        let req1: ParsedRequestFields = serde_json::from_str(REQ1).unwrap();
        assert_eq!(req1.id, 7.into());
        assert_eq!(req1.obj.as_ref(), "hi");
        assert_eq!(req1.updates_requested(), true);
        assert_eq!(req1.method, "twiddle");

        let req2: ParsedRequestFields = serde_json::from_str(REQ2).unwrap();
        assert_eq!(req2.id, "fred".to_string().into());
        assert_eq!(req2.obj.as_ref(), "hi");
        assert_eq!(req2.updates_requested(), false);
        assert_eq!(req2.method, "twiddle");

        // Unknown top-level fields must be tolerated, not rejected.
        let req3: ParsedRequestFields = serde_json::from_str(REQ3).unwrap();
        assert_eq!(req3.id, "fred".to_string().into());
        assert_eq!(req3.method, "twiddle");
    }

    #[test]
    fn reencode_requests() {
        for r in [REQ1, REQ2, REQ3] {
            let val1 = ValidatedRequest::from_string_strict(r).unwrap();
            let val2 = ValidatedRequest::from_string_loose(r, || panic!()).unwrap();

            assert_same_json!(val1.as_ref(), val2.as_ref());
            assert_same_json!(val1.as_ref(), r);
        }
    }

    #[test]
    fn bad_requests() {
        for text in [
            // Not an object.
            "123",
            // Not valid JSON (trailing comma).
            r#"{"id":12,}"#,
            // Missing "id".
            r#"{"obj":"hi", "method":"twiddle", "params":{"stuff":"nonsense"}}"#,
            // Missing "params".
            r#"{"obj":"hi", "id": 7, "method":"twiddle"}"#,
            // "params" is an array.
            r#"{"obj":"hi", "id": 7, "method":"twiddle", "params": []}"#,
            // "obj" is a number.
            r#"{"obj":7, "id": 7, "method":"twiddle", "params":{"stuff":"nonsense"}}"#,
            // "id" is an array.
            r#"{"obj":"hi", "id": [], "method":"twiddle", "params":{"stuff":"nonsense"}}"#,
            // "method" is a number (well-formed JSON, wrong type).
            r#"{"obj":"hi", "id": 7, "method":6, "params":{"stuff":"nonsense"}}"#,
        ] {
            let r: Result<ParsedRequestFields, _> = serde_json::from_str(dbg!(text));
            assert!(r.is_err());
            assert!(ValidatedRequest::from_string_strict(text).is_err());
        }
    }

    #[test]
    fn fix_requests() {
        let no_id = r#"{"obj":"hi", "method":"twiddle", "params":{"stuff":"nonsense"}}"#;
        let validated = ValidatedRequest::from_string_loose(no_id, || 7.into()).unwrap();
        let expected_with_id =
            r#"{"id": 7, "obj":"hi", "method":"twiddle", "params":{"stuff":"nonsense"}}"#;
        assert_same_json!(validated.as_ref(), expected_with_id);
    }

    #[test]
    fn preserve_fields() {
        let orig = r#"
            {"obj":"hi",
             "meta": { "updates": true, "waffles": "yesplz" },
             "method":"twiddle",
             "params":{"stuff":"nonsense"},
             "explosions": -70
            }"#;
        let validated = ValidatedRequest::from_string_loose(orig, || 77.into()).unwrap();
        let expected_with_id = r#"
            {"id":77,
             "obj":"hi",
             "meta": { "updates": true, "waffles": "yesplz" },
             "method":"twiddle",
             "params":{"stuff":"nonsense"},
             "explosions": -70
            }"#;
        assert_same_json!(validated.as_ref(), expected_with_id);
    }

    #[test]
    fn generated_ids() {
        let mut generator = IdGenerator::default();
        assert_eq!(generator.next_id(), "!auto!--0".to_string().into());
        assert_eq!(generator.next_id(), "!auto!--1".to_string().into());
    }

    #[test]
    fn ok_request_encode() {
        let expected_encoded_request =
            r#"{"obj":"connection","method":"arti:get_rpc_proxy_info","params":"123"}"#;
        let obj_id = ObjectId::connection_id();
        let encoded_request = Request::new(obj_id, "arti:get_rpc_proxy_info", "123")
            .encode()
            .unwrap();
        assert_eq!(expected_encoded_request, encoded_request);
    }

    #[test]
    fn err_request_encode() {
        /// A parameter type whose serialization always fails.
        struct FailingSerialization;

        impl serde::Serialize for FailingSerialization {
            fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                Err(serde::ser::Error::custom(
                    "Intentional serialization failure",
                ))
            }
        }

        let obj_id = ObjectId::connection_id();
        let failing_request = Request::new(obj_id, "arti:get_rpc_proxy_info", FailingSerialization);

        let err = failing_request.encode().unwrap_err();
        assert!(matches!(err, ProtoError::CouldNotEncode(_)));
    }
}