Skip to main content

arti_rpc_client_core/msgs/
request.rs

1//! Support for encoding and decoding RPC Requests.
2//!
3//! There are several types in this module:
4//!
5//! - [`Request`] is for requests that are generated from within this crate,
6//!   to implement authentication, negotiation, and other functionality.
7//! - `ParsedRequestFields` (internal) is for a request we've completely validated,
8//!   with all of its fields present.
9//! - [`ValidatedRequest`] is for a string that we have validated as a request.
10
11use std::sync::Arc;
12
13use serde::{Deserialize, Serialize};
14
15/// Alias for a Map as used by the serde_json.
16pub(crate) type JsonMap = serde_json::Map<String, serde_json::Value>;
17
18use crate::conn::ProtoError;
19
20use super::{AnyRequestId, JsonAnyObj, ObjectId};
21
22/// An outbound request that we have generated from within this crate.
23///
24/// It lacks a required `id` field (since we will generate one when sending it),
25/// and it allows any Serialize for its `params`.
26#[derive(Serialize, Debug)]
27// Testing only. Don't implement Deserialize here; this is not the type you should parse into!
28#[cfg_attr(test, derive(Eq, PartialEq, Deserialize))]
29#[allow(clippy::missing_docs_in_private_items)] // Fields are as for ParsedRequest.
30pub(crate) struct Request<T> {
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub(crate) id: Option<AnyRequestId>,
33    pub(crate) obj: ObjectId,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub(crate) meta: Option<RequestMeta>,
36    pub(crate) method: String,
37    pub(crate) params: T,
38}
39
40/// An error that has prevented us from validating an request.
41#[derive(Clone, Debug, thiserror::Error)]
42#[non_exhaustive]
43pub enum InvalidRequestError {
44    /// We failed to turn the request into any kind of json.
45    #[error("Request was not valid Json")]
46    InvalidJson(#[source] Arc<serde_json::Error>),
47    /// We got the request into json, but we couldn't find the fields we wanted.
48    #[error("Request's fields were invalid or missing")]
49    InvalidFormat(#[source] Arc<serde_json::Error>),
50    /// We validated the request, but couldn't re-encode it.
51    #[error("Unable to re-encode or format request")]
52    ReencodeFailed(#[source] Arc<serde_json::Error>),
53}
54
55impl<T: Serialize> Request<T> {
56    /// Construct a new outbound Request.
57    pub(crate) fn new(obj: ObjectId, method: impl Into<String>, params: T) -> Self {
58        Self {
59            id: None,
60            obj,
61            meta: Default::default(),
62            method: method.into(),
63            params,
64        }
65    }
66    /// Try to encode this request as a String.
67    ///
68    /// The string may not yet be a valid request; it might need to get an ID assigned.
69    pub(crate) fn encode(&self) -> Result<String, ProtoError> {
70        serde_json::to_string(self).map_err(|e| ProtoError::CouldNotEncode(Arc::new(e)))
71    }
72}
73
74/// A request in its decoded (or unencoded) format.
75///
76/// We use this type to validate outbound requests from the application.
77#[derive(Deserialize, Debug)]
78// Don't implement Serialize here; this is not for generating requests!
79#[allow(dead_code)] // The fields here are only used for validating serde objects.
80struct ParsedRequestFields {
81    /// The identifier for this request.
82    ///
83    /// Used to match a request with its responses.
84    id: AnyRequestId,
85    /// The ID for the object to which this request is addressed.
86    ///
87    /// (Every request goes to a single object.)
88    obj: ObjectId,
89    /// Additional information for Arti about how to handle the request.
90    #[serde(skip_serializing_if = "Option::is_none")]
91    meta: Option<RequestMeta>,
92    /// The name of the method to invoke.
93    method: String,
94    /// Parameters to pass to the method.
95    params: JsonAnyObj,
96}
97
98/// A known-valid request, encoded as a string (in a single line, with a terminating newline).
99#[derive(derive_more::AsRef, Debug, Clone)]
100pub(crate) struct ValidatedRequest {
101    /// The message itself, as encoded.
102    #[as_ref]
103    msg: String,
104    /// The ID for this request.
105    id: AnyRequestId,
106}
107
108impl ValidatedRequest {
109    /// Return the Id associated with this request.
110    pub(crate) fn id(&self) -> &AnyRequestId {
111        &self.id
112    }
113
114    /// Try to construct a validated request from a `serde_json::Value`.
115    fn from_json_value(val: serde_json::Value) -> Result<Self, InvalidRequestError> {
116        let mut msg = serde_json::to_string(&val)
117            .map_err(|e| InvalidRequestError::ReencodeFailed(Arc::new(e)))?;
118        debug_assert!(!msg.contains('\n'));
119        msg.push('\n');
120
121        let req: ParsedRequestFields = serde_json::from_value(val)
122            .map_err(|e| InvalidRequestError::InvalidFormat(Arc::new(e)))?;
123        let id = req.id;
124
125        Ok(ValidatedRequest { id, msg })
126    }
127
128    /// Try to construct a validated request using `s`.
129    // TODO nb: Expose or remove.
130    #[allow(dead_code)]
131    pub(crate) fn from_string_strict(s: &str) -> Result<Self, InvalidRequestError> {
132        let value: serde_json::Value =
133            serde_json::from_str(s).map_err(|e| InvalidRequestError::InvalidJson(Arc::new(e)))?;
134        Self::from_json_value(value)
135    }
136
137    /// Try to construct a ValidatedRequest from the string in `s`.
138    ///
139    /// If it has no `id`, add one using `id_generator`.
140    pub(crate) fn from_string_loose<F>(
141        s: &str,
142        id_generator: F,
143    ) -> Result<Self, InvalidRequestError>
144    where
145        F: FnOnce() -> AnyRequestId,
146    {
147        let mut value: serde_json::Value =
148            serde_json::from_str(s).map_err(|e| InvalidRequestError::InvalidJson(Arc::new(e)))?;
149
150        if let Some(obj) = value.as_object_mut() {
151            obj.entry("id")
152                .or_insert_with(|| id_generator().into_json_value());
153        }
154
155        Self::from_json_value(value)
156    }
157}
158
159/// Crate-internal: The "meta" field in a request.
160#[derive(Deserialize, Serialize, Debug, Default)]
161#[cfg_attr(test, derive(Eq, PartialEq))]
162pub(crate) struct RequestMeta {
163    /// If true, the application wants to receive incremental updates
164    /// about the request that it sent.
165    ///
166    /// (Default: false)
167    #[serde(default)]
168    pub(crate) updates: bool,
169    /// Any unrecognized fields that we received from the user.
170    /// (We re-encode these in case the user knows about fields that we don't.)
171    #[serde(flatten)]
172    pub(crate) unrecognized_fields: JsonMap,
173}
174
175/// A helper to return unique Request identifiers.
176///
177/// All identifiers are prefixed with `"!aut o!--"`:
178/// if you don't use that string in your own IDs,
179/// you won't have any collisions.
180#[derive(Debug, Default)]
181pub(crate) struct IdGenerator {
182    /// The number
183    next_id: u64,
184}
185
186impl IdGenerator {
187    /// Return a previously unyielded identifier.
188    pub(crate) fn next_id(&mut self) -> AnyRequestId {
189        let id = self.next_id;
190        self.next_id += 1;
191        format!("!auto!--{id}").into()
192    }
193}
194
#[cfg(test)]
mod test {
    // @@ begin test lint list maintained by maint/add_warning @@
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->

    impl ParsedRequestFields {
        /// Return true if this request is asking for updates.
        fn updates_requested(&self) -> bool {
            self.meta.as_ref().map(|m| m.updates).unwrap_or(false)
        }
    }

    use crate::util::assert_same_json;

    use super::*;
    const REQ1: &str = r#"{"id":7, "obj": "hi", "meta": {"updates": true}, "method":"twiddle", "params":{"stuff": "nonsense"} }"#;
    const REQ2: &str = r#"{"id":"fred", "obj": "hi", "method":"twiddle", "params":{} }"#;
    /// Like REQ2, but with an unrecognized top-level field.
    const REQ3: &str =
        r#"{"id":"fred", "obj": "hi", "method":"twiddle", "params":{},"unrecognized":"waffles"}"#;

    #[test]
    fn parse_requests() {
        let req1: ParsedRequestFields = serde_json::from_str(REQ1).unwrap();
        assert_eq!(req1.id, 7.into());
        assert_eq!(req1.obj.as_ref(), "hi");
        assert_eq!(req1.updates_requested(), true);
        assert_eq!(req1.method, "twiddle");

        let req2: ParsedRequestFields = serde_json::from_str(REQ2).unwrap();
        assert_eq!(req2.id, "fred".to_string().into());
        assert_eq!(req2.obj.as_ref(), "hi");
        assert_eq!(req2.updates_requested(), false);
        assert_eq!(req2.method, "twiddle");

        // Unrecognized top-level fields are ignored when parsing.
        let req3: ParsedRequestFields = serde_json::from_str(REQ3).unwrap();
        assert_eq!(req3.id, "fred".to_string().into());
        assert_eq!(req3.obj.as_ref(), "hi");
        assert_eq!(req3.updates_requested(), false);
        assert_eq!(req3.method, "twiddle");
    }

    #[test]
    fn reencode_requests() {
        for r in [REQ1, REQ2, REQ3] {
            let val1 = ValidatedRequest::from_string_strict(r).unwrap();
            let val2 = ValidatedRequest::from_string_loose(r, || panic!()).unwrap();

            assert_same_json!(val1.as_ref(), val2.as_ref());
            assert_same_json!(val1.as_ref(), r);
        }
    }

    #[test]
    fn bad_requests() {
        // Every entry is syntactically valid JSON, so that each one is rejected
        // for the reason given in its comment rather than for a syntax error.
        for text in [
            // not an object.
            "123",
            // missing most parts.
            r#"{"id":12}"#,
            // no id.
            r#"{"obj":"hi", "method":"twiddle", "params":{"stuff":"nonsense"}}"#,
            // no params
            r#"{"obj":"hi", "id": 7, "method":"twiddle"}"#,
            // bad params type
            r#"{"obj":"hi", "id": 7, "method":"twiddle", "params": []}"#,
            // weird obj.
            r#"{"obj":7, "id": 7, "method":"twiddle", "params":{"stuff":"nonsense"}}"#,
            // weird id.
            r#"{"obj":"hi", "id": [], "method":"twiddle", "params":{"stuff":"nonsense"}}"#,
            // weird method
            r#"{"obj":"hi", "id": 7, "method":6, "params":{"stuff":"nonsense"}}"#,
        ] {
            let r: Result<ParsedRequestFields, _> = serde_json::from_str(dbg!(text));
            assert!(r.is_err());
        }
    }

    #[test]
    fn fix_requests() {
        let no_id = r#"{"obj":"hi", "method":"twiddle", "params":{"stuff":"nonsense"}}"#;
        let validated = ValidatedRequest::from_string_loose(no_id, || 7.into()).unwrap();
        let expected_with_id =
            r#"{"id": 7, "obj":"hi", "method":"twiddle", "params":{"stuff":"nonsense"}}"#;
        assert_same_json!(validated.as_ref(), expected_with_id);
    }

    #[test]
    fn preserve_fields() {
        let orig = r#"
            {"obj":"hi",
             "meta": { "updates": true, "waffles": "yesplz" },
             "method":"twiddle",
             "params":{"stuff":"nonsense"},
             "explosions": -70
            }"#;
        let validated = ValidatedRequest::from_string_loose(orig, || 77.into()).unwrap();
        let expected_with_id = r#"
            {"id":77,
            "obj":"hi",
            "meta": { "updates": true, "waffles": "yesplz" },
            "method":"twiddle",
            "params":{"stuff":"nonsense"},
            "explosions": -70
            }"#;
        assert_same_json!(validated.as_ref(), expected_with_id);
    }

    #[test]
    fn ok_request_encode() {
        let expected_encoded_request =
            r#"{"obj":"connection","method":"arti:get_rpc_proxy_info","params":"123"}"#;
        let obj_id = ObjectId::connection_id();
        let encoded_request = Request::new(obj_id, "arti:get_rpc_proxy_info", "123")
            .encode()
            .unwrap();
        assert_eq!(expected_encoded_request, encoded_request);
    }

    // This should not be possible
    #[test]
    fn err_request_encode() {
        struct FailingSerialization;

        impl serde::Serialize for FailingSerialization {
            fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                Err(serde::ser::Error::custom(
                    "Intentional serialization failure",
                ))
            }
        }

        let obj_id = ObjectId::connection_id();
        let failing_request = Request::new(obj_id, "arti:get_rpc_proxy_info", FailingSerialization);

        let err = failing_request.encode().unwrap_err();
        assert!(matches!(err, ProtoError::CouldNotEncode(_)));
    }
}