arti_rpc_client_core/msgs/
response.rs1use std::sync::Arc;
4
5use serde::{Deserialize, Serialize};
6
7use super::{AnyRequestId, JsonAnyObj};
8use crate::{
9 conn::ErrorResponse,
10 util::{Utf8CString, define_from_for_arc},
11};
12
/// A response message held as raw text, before any parsing or validation.
///
/// Use [`UnparsedResponse::try_validate`] to check the text and turn it into a
/// [`ValidatedResponse`].
#[derive(Clone, Debug, derive_more::AsRef)]
pub(crate) struct UnparsedResponse {
    /// The text of the message, exactly as handed to [`UnparsedResponse::new`].
    msg: String,
}
21
22impl UnparsedResponse {
23 pub(crate) fn new(msg: String) -> Self {
25 Self { msg }
26 }
27}
28
/// A response that has been checked and had its identifying metadata extracted.
///
/// [`UnparsedResponse::try_validate`] produces values of this type.
#[derive(Clone, Debug)]
pub(crate) struct ValidatedResponse {
    /// The text of the response.
    ///
    /// When built by `try_validate`, this is the re-serialized JSON followed by
    /// a single trailing newline.
    pub(crate) msg: Utf8CString,
    /// The request ID and response kind extracted from the message.
    pub(crate) meta: ResponseMeta,
}
40
/// An error encountered while decoding a response.
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub(crate) enum DecodeResponseError {
    /// `serde_json` failed to parse, re-serialize, or deserialize the message.
    #[error("Arti sent a message that didn't conform to the RPC protocol")]
    JsonProtocolViolation(#[source] Arc<serde_json::Error>),

    /// The message was structurally wrong in a way described by the static string
    /// (for example, a non-error response with no `id`).
    #[error("Arti sent a message that didn't conform to the RPC protocol: {0}")]
    ProtocolViolation(&'static str),

    /// The message was an error response with no `id`.
    ///
    /// `try_validate` wraps the validated message text in an [`ErrorResponse`]
    /// for this case.
    #[error("Arti reported a fatal error: {0}")]
    Fatal(ErrorResponse),
}
// NOTE(review): presumably generates `From<serde_json::Error>` for
// `DecodeResponseError`, wrapping the error in an `Arc` inside
// `JsonProtocolViolation`; the `?` operators applied to `serde_json` results
// in this file rely on such a conversion. Confirm against the macro definition.
define_from_for_arc!( serde_json::Error => DecodeResponseError [JsonProtocolViolation] );
59
60impl UnparsedResponse {
61 pub(crate) fn try_validate(self) -> Result<ValidatedResponse, DecodeResponseError> {
64 let json: serde_json::Value = serde_json::from_str(self.as_str())?;
71 let mut msg: String = serde_json::to_string(&json)?;
72 debug_assert!(!msg.contains('\n'));
73 msg.push('\n');
74 let msg: Utf8CString = msg.try_into().map_err(|_| {
75 DecodeResponseError::ProtocolViolation("Unexpected NUL in validated message")
77 })?;
78 let response: Response = serde_json::from_value(json)?;
79 let meta = match ResponseMeta::try_from_response(&response) {
80 Ok(m) => m?,
81 Err(_) => {
82 return Err(DecodeResponseError::Fatal(
83 ErrorResponse::from_validated_string(msg),
84 ));
85 }
86 };
87 Ok(ValidatedResponse { msg, meta })
88 }
89
90 pub(crate) fn as_str(&self) -> &str {
92 self.msg.as_str()
93 }
94}
95
96impl ValidatedResponse {
97 pub(crate) fn is_final(&self) -> bool {
99 use ResponseKind as K;
100 match self.meta.kind {
101 K::Error | K::Success => true,
102 K::Update => false,
103 }
104 }
105
106 pub(crate) fn id(&self) -> &AnyRequestId {
108 &self.meta.id
109 }
110}
111
/// The identifying information extracted from a response: which request it
/// answers, and what kind of response it is.
#[derive(Clone, Debug)]
#[cfg_attr(test, derive(Eq, PartialEq))]
pub(crate) struct ResponseMeta {
    /// The ID of the request that this response answers.
    pub(crate) id: AnyRequestId,
    /// Whether the response is an error, a success, or an update.
    pub(crate) kind: ResponseKind,
}
121
/// The kind of a response, without its payload.
///
/// Derived from the body of a response (`error`, `result`, or `update`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ResponseKind {
    /// The response carried an `error` body. This is a final response.
    Error,
    /// The response carried a `result` body. This is a final response.
    Success,
    /// The response carried an `update` body. More responses may follow.
    Update,
}
134
/// Deserialization target for a whole response object.
///
/// Used only to check the structure of a message and to pull out its ID and
/// body kind.
#[derive(Deserialize, Debug)]
struct Response {
    /// The request ID, if the message had one.
    ///
    /// `ResponseMeta::try_from_response` tolerates a missing ID only on error
    /// responses, which it reports as fatal.
    id: Option<AnyRequestId>,
    /// The body of the response.
    ///
    /// Flattened, so that the `error` / `result` / `update` key appears
    /// directly in the top-level JSON object next to `id`.
    #[serde(flatten)]
    body: ResponseBody,
}
147
/// The body of a response, keyed in JSON by `error`, `result`, or `update`.
#[derive(Deserialize, Debug)]
enum ResponseBody {
    /// An error body, carried under the `error` key.
    #[serde(rename = "error")]
    Error(RpcError),
    /// A successful final result, carried under the `result` key.
    #[serde(rename = "result")]
    Success(JsonAnyObj),
    /// An incremental update, carried under the `update` key.
    #[serde(rename = "update")]
    Update(JsonAnyObj),
}
163impl<'a> From<&'a ResponseBody> for ResponseKind {
164 fn from(value: &'a ResponseBody) -> Self {
165 use ResponseBody as RMB;
166 use ResponseKind as RK;
167 match value {
170 RMB::Error(_) => RK::Error,
171 RMB::Success(_) => RK::Success,
172 RMB::Update(_) => RK::Update,
173 }
174 }
175}
176
/// Marker error returned by `ResponseMeta::try_from_response` when the
/// response is an error with no ID.
///
/// Callers in this file convert it into `DecodeResponseError::Fatal`.
#[derive(thiserror::Error, Debug, Clone)]
#[error("Response was fatal (it had no ID)")]
struct ResponseWasFatal;
182
183impl ResponseMeta {
184 fn try_from_response(
189 response: &Response,
190 ) -> Result<Result<Self, DecodeResponseError>, ResponseWasFatal> {
191 use DecodeResponseError as E;
192 use ResponseBody as Body;
193 match (&response.id, &response.body) {
194 (None, Body::Error(_ignore)) => {
195 Err(ResponseWasFatal)
198 }
199 (None, _) => Ok(Err(E::ProtocolViolation("Missing ID field"))),
200 (Some(id), body) => Ok(Ok(ResponseMeta {
201 id: id.clone(),
202 kind: (body).into(),
203 })),
204 }
205 }
206}
207
208pub(crate) fn try_decode_response_as_err(s: &str) -> Result<Option<RpcError>, DecodeResponseError> {
215 let Response { body, .. } = serde_json::from_str(s)?;
216 match body {
217 ResponseBody::Error(e) => Ok(Some(e)),
218 _ => Ok(None),
219 }
220}
221
/// An error object, as carried in the `error` member of a response.
///
/// Only the fields below are retained. The tests in this file show that
/// additional members (such as `data`) are accepted and dropped on
/// deserialization.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct RpcError {
    /// The `message` member of the error.
    message: String,
    /// The numeric `code` member of the error.
    code: RpcErrorCode,
    /// The `kinds` member of the error: a list of strings.
    kinds: Vec<String>,
}
233
234impl RpcError {
235 pub fn message(&self) -> &str {
237 self.message.as_str()
238 }
239 pub fn code(&self) -> RpcErrorCode {
241 self.code
242 }
243 pub fn kinds_iter(&self) -> impl Iterator<Item = &'_ str> {
248 self.kinds.iter().map(|s| s.as_ref())
249 }
250}
251
// Numeric error code carried in the `code` member of an `RpcError`.
//
// The four -326xx values are the error codes reserved by the JSON-RPC 2.0
// specification. The small positive values are not defined there.
// NOTE(review): the meanings of the positive codes below are inferred from
// their names only; confirm against the Arti RPC specification.
caret::caret_int! {
    #[derive(serde::Deserialize, serde::Serialize)]
    pub struct RpcErrorCode(i32) {
        // JSON-RPC 2.0 "Invalid Request".
        INVALID_REQUEST = -32600,
        // JSON-RPC 2.0 "Method not found".
        NO_SUCH_METHOD = -32601,
        // JSON-RPC 2.0 "Invalid params".
        INVALID_PARAMS = -32602,
        // JSON-RPC 2.0 "Internal error".
        INTERNAL_ERROR = -32603,
        // Presumably: a problem with the object a request was addressed to.
        OBJECT_ERROR = 1,
        // Presumably: a problem that arose while handling the request.
        REQUEST_ERROR = 2,
        // Presumably: the method is known but not implemented.
        METHOD_NOT_IMPL = 3,
    }
}
271
#[cfg(test)]
mod test {
    // Lints that are relaxed for test code.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;

    // Helper: decode `s` as a `Response` and extract its `ResponseMeta`.
    //
    // Mirrors the tail end of `UnparsedResponse::try_validate`: a response
    // that "was fatal" is reported as `DecodeResponseError::Fatal`.
    fn response_meta(s: &str) -> Result<ResponseMeta, DecodeResponseError> {
        match ResponseMeta::try_from_response(&serde_json::from_str::<Response>(s)?) {
            Ok(v) => v,
            Err(_) => {
                let utf8 = Utf8CString::try_from(s.to_string())
                    .map_err(|_| DecodeResponseError::ProtocolViolation("not utf8cstr?"))?;
                Err(DecodeResponseError::Fatal(
                    ErrorResponse::from_validated_string(utf8),
                ))
            }
        }
    }

    // Well-formed responses of each kind yield the expected ID and kind.
    #[test]
    fn response_meta_good() {
        use ResponseKind as RK;
        use ResponseMeta as RM;
        for (s, expected) in [
            // A success with an integer ID.
            (
                r#"{"id":7, "result": {}}"#,
                RM {
                    id: 7.into(),
                    kind: RK::Success,
                },
            ),
            // An update with a string ID.
            (
                r#"{"id":"hi", "update": {"here":["goes", "nothing"]}}"#,
                RM {
                    id: "hi".to_string().into(),
                    kind: RK::Update,
                },
            ),
            // An error with exactly the fields that `RpcError` declares.
            (
                r#"{"id": 6, "error": {"message":"iffy wobbler", "code":999, "kinds": ["BadVibes"]}}"#,
                RM {
                    id: 6.into(),
                    kind: RK::Error,
                },
            ),
            // An error with an additional `data` member, which is accepted.
            (
                r#"{"id": 6, "error": {"message":"iffy wobbler", "code":999, "kinds": ["BadVibes"], "data": {"a":"b"}}}"#,
                RM {
                    id: 6.into(),
                    kind: RK::Error,
                },
            ),
        ] {
            let got = response_meta(s).unwrap();
            assert_eq!(got, expected);
        }
    }

    // Malformed or ID-less responses yield the expected error variant.
    #[test]
    fn response_meta_bad() {
        // Assert that decoding `$s` fails with an error matching pattern `$p`.
        macro_rules! check_err {
            { $s:expr, $p:pat } => {
                let got_err = response_meta($s).unwrap_err();
                assert!(matches!(got_err, $p));
            }

        }

        use DecodeResponseError as E;

        // An error with no ID is fatal.
        check_err!(
            r#"{"error": {"message":"iffy wobbler", "code":999, "kinds": ["BadVibes"], "data": {"a":"b"}}}"#,
            E::Fatal(_)
        );
        // A success with no ID.
        check_err!(r#"{"result": {}}"#, E::ProtocolViolation(_));
        // An update with no ID.
        check_err!(r#"{"update": {}}"#, E::ProtocolViolation(_));
        // No recognized body key.
        check_err!(r#"{"id": 7, "flupdate": {}}"#, E::JsonProtocolViolation(_));
        // Not JSON at all.
        check_err!(r#"{{{{{"#, E::JsonProtocolViolation(_));
        // Malformed JSON: there is no comma after the `id` member.
        check_err!(
            r#"{"id": 77 "error": {"message":"iffy wobbler"}}"#,
            E::JsonProtocolViolation(_)
        );
    }

    // `serde_json` itself rejects these inputs.
    #[test]
    fn bad_json() {
        for s in [
            "{ ",         // Incomplete object.
            "",           // Empty input.
            "{ \0 }",     // NUL byte between tokens.
            "{ \"\0\" }", // NUL byte inside a string.
        ] {
            let r: Result<serde_json::Value, _> = serde_json::from_str(s);
            assert!(dbg!(r.err()).is_some());
        }
    }

    // Validation re-encodes the message without changing its JSON value,
    // including members that this module does not know about
    // (`explosion`, `xyzzy`).
    #[test]
    fn re_encode() {
        let response = r#"{
            "id": 6,
            "error": {
                "message":"iffy wobbler",
                "code":999,
                "kinds": ["BadVibes"],
                "data": {"a":"b"},
                "explosion": 22
            },
            "xyzzy":"plugh"
        }"#;
        let json_orig: serde_json::Value = serde_json::from_str(response).unwrap();
        let resp = UnparsedResponse::new(response.into());
        let valid = resp.try_validate().unwrap();
        let msg: &str = valid.msg.as_ref();
        let json_reencoded: serde_json::Value = serde_json::from_str(msg).unwrap();
        assert_eq!(json_orig, json_reencoded);
    }
}