Skip to main content

arti_rpc_client_core/
conn.rs

1//! Middle-level API for RPC connections
2//!
//! This module focuses on the `RpcConn` type, which supports sending RPC requests
4//! and matching them with their responses.
5
6use std::{
7    io::{self},
8    sync::{Arc, Mutex},
9};
10
11use crate::msgs::{
12    AnyRequestId, ObjectId,
13    request::InvalidRequestError,
14    response::{ResponseKind, RpcError, ValidatedResponse},
15};
16
17mod auth;
18mod builder;
19mod connimpl;
20mod stream;
21
22use crate::util::Utf8CString;
23pub use builder::{BuilderError, ConnPtDescription, RpcConnBuilder};
24pub use connimpl::RpcConn;
25use serde::{Deserialize, de::DeserializeOwned};
26pub use stream::StreamError;
27use tor_rpc_connect::{HasClientErrorAction, auth::cookie::CookieAccessError};
28
/// A handle to an open request.
///
/// These handles are created with [`RpcConn::execute_with_handle`].
///
/// Note that dropping a RequestHandle does not cancel the associated request:
/// it will continue running, but you won't have a way to receive updates from it.
/// To cancel a request, use [`RpcConn::cancel`].
#[derive(educe::Educe)]
#[educe(Debug)]
pub struct RequestHandle {
    /// The underlying `Receiver` that we'll use to get updates for this request.
    ///
    /// It's wrapped in a `Mutex` to prevent concurrent calls to `Receiver::wait_on_message_for`.
    /// (This field is omitted from the `Debug` output.)
    //
    // NOTE: As an alternative to using a Mutex here, we _could_ remove
    // the restriction from `wait_on_message_for` that says that only one thread
    // may be waiting on a given request ID at once.  But that would introduce
    // complexity to the implementation,
    // and it's not clear that the benefit would be worth it.
    #[educe(Debug(ignore))]
    conn: Mutex<Arc<connimpl::Receiver>>,
    /// The ID of this request.
    ///
    /// This is the ID we pass to `wait_on_message_for` to receive this request's responses.
    id: AnyRequestId,
}
53
54// TODO RPC: Possibly abolish these types.
55//
56// I am keeping this for now because it makes it more clear that we can never reinterpret
57// a success as an update or similar.
58//
59// I am not at all pleased with these types; we should revise them.
60//
61// TODO RPC: Possibly, all of these should be reconstructed
62// from their serde_json::Values rather than forwarded verbatim.
// (But why would we want our JSON to be more canonical than Arti's? See #1491.)
64//
65// DODGY TYPES BEGIN: TODO RPC
66
/// A Success Response from Arti, indicating that a request was successful.
///
/// This is the complete message, including `id` and `result` fields.
/// The raw message text is available through `AsRef` (forwarded to the inner string)
/// or by converting this value `Into` the underlying string type.
//
// Invariant: it is valid JSON and contains no NUL bytes or newlines.
// TODO RPC: check that the newline invariant is enforced in constructors.
#[derive(Clone, Debug, derive_more::AsRef, derive_more::Into)]
#[as_ref(forward)]
pub struct SuccessResponse(Utf8CString);
76
77impl SuccessResponse {
78    /// Helper: Decode the `result` field of this response as an instance of D.
79    fn decode<D: DeserializeOwned>(&self) -> Result<D, serde_json::Error> {
80        /// Helper object for decoding the "result" field.
81        #[derive(Deserialize)]
82        struct Response<R> {
83            /// The decoded value.
84            result: R,
85        }
86        let response: Response<D> = serde_json::from_str(self.as_ref())?;
87        Ok(response.result)
88    }
89}
90
/// An Update Response from Arti, with information about the progress of a request.
///
/// This is the complete message, including `id` and `update` fields.
/// The raw message text is available through `AsRef` (forwarded to the inner string)
/// or by converting this value `Into` the underlying string type.
//
// Invariant: it is valid JSON and contains no NUL bytes or newlines.
// TODO RPC: check that the newline invariant is enforced in constructors.
// TODO RPC consider changing this to CString.
#[derive(Clone, Debug, derive_more::AsRef, derive_more::Into)]
#[as_ref(forward)]
pub struct UpdateResponse(Utf8CString);
101
/// An Error Response from Arti, indicating that an error occurred.
///
/// (This is the complete message, including the `error` field.
/// It also includes an `id` if it
/// is in response to a request; but not if it is a fatal protocol error.)
///
/// Use [`ErrorResponse::decode`] to examine the error itself.
//
// Invariant: Does not contain a NUL. (Safe to convert to CString.)
//
// Invariant: This field MUST encode a response whose body is an RPC error.
//
// Otherwise the `decode` method may panic.
//
// TODO RPC: check that the newline invariant is enforced in constructors.
#[derive(Clone, Debug, derive_more::AsRef, derive_more::Into)]
#[as_ref(forward)]
// TODO: If we keep this, it should implement Error.
pub struct ErrorResponse(Utf8CString);
119impl ErrorResponse {
120    /// Construct an ErrorResponse from the Error reply.
121    ///
122    /// This not a From impl because we want it to be crate-internal.
123    pub(crate) fn from_validated_string(s: Utf8CString) -> Self {
124        ErrorResponse(s)
125    }
126
127    /// Convert this response into an internal error in response to `cmd`.
128    ///
129    /// This is only appropriate when the error cannot be caused because of user behavior.
130    pub(crate) fn internal_error(&self, cmd: &str) -> ProtoError {
131        ProtoError::InternalRequestFailed(UnexpectedReply {
132            request: cmd.to_string(),
133            reply: self.to_string(),
134            problem: UnexpectedReplyProblem::ErrorNotExpected,
135        })
136    }
137
138    /// Try to interpret this response as an [`RpcError`].
139    pub fn decode(&self) -> RpcError {
140        crate::msgs::response::try_decode_response_as_err(self.0.as_ref())
141            .expect("Could not decode response that was already decoded as an error?")
142            .expect("Could not extract error from response that was already decoded as an error?")
143    }
144}
145
146impl std::fmt::Display for ErrorResponse {
147    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148        let e = self.decode();
149        write!(f, "Peer said {:?}", e.message())
150    }
151}
152
/// A final response -- that is, the last one that we expect to receive for a request.
///
/// `Ok` holds Arti's success reply; `Err` holds the error reply that Arti sent.
/// (Failures to send the request or receive a reply are reported separately, as [`ProtoError`].)
type FinalResponse = Result<SuccessResponse, ErrorResponse>;
156
157/// Any of the three types of Arti responses.
158#[derive(Clone, Debug)]
159#[allow(clippy::exhaustive_structs)]
160pub enum AnyResponse {
161    /// The request has succeeded; no more response will be given.
162    Success(SuccessResponse),
163    /// The request has failed; no more response will be given.
164    Error(ErrorResponse),
165    /// An incremental update; more messages may arrive.
166    Update(UpdateResponse),
167}
168// TODO RPC: DODGY TYPES END.
169
170impl AnyResponse {
171    /// Convert `v` into `AnyResponse`.
172    fn from_validated(v: ValidatedResponse) -> Self {
173        // TODO RPC, Perhaps unify AnyResponse with ValidatedResponse, once we are sure what
174        // AnyResponse should look like.
175        match v.meta.kind {
176            ResponseKind::Error => AnyResponse::Error(ErrorResponse::from_validated_string(v.msg)),
177            ResponseKind::Success => AnyResponse::Success(SuccessResponse(v.msg)),
178            ResponseKind::Update => AnyResponse::Update(UpdateResponse(v.msg)),
179        }
180    }
181
182    /// Consume this `AnyResponse`, and return its internal string.
183    #[cfg(feature = "ffi")]
184    pub(crate) fn into_string(self) -> Utf8CString {
185        match self {
186            AnyResponse::Success(m) => m.into(),
187            AnyResponse::Error(m) => m.into(),
188            AnyResponse::Update(m) => m.into(),
189        }
190    }
191}
192
193impl RpcConn {
194    /// Return the ObjectId for the negotiated Session.
195    ///
196    /// Nearly all RPC methods require a Session, or some other object
197    /// accessed via the session.
198    ///
199    /// (This function will only return None if no authentication has been performed.
200    /// TODO RPC: It is not currently possible to make an unauthenticated connection.)
201    pub fn session(&self) -> Option<&ObjectId> {
202        self.session.as_ref()
203    }
204
205    /// Run a command, and wait for success or failure.
206    ///
207    /// Note that this function will return `Err(.)` only if sending the command or getting a
208    /// response failed.  If the command was sent successfully, and Arti reported an error in response,
209    /// this function returns `Ok(Err(.))`.
210    ///
211    /// Note that the command does not need to include an `id` field.  If you omit it,
212    /// one will be generated.
213    pub fn execute(&self, cmd: &str) -> Result<FinalResponse, ProtoError> {
214        let hnd = self.execute_with_handle(cmd)?;
215        hnd.wait()
216    }
217
218    /// Helper for executing internally-generated requests and decoding their results.
219    ///
220    /// Behaves like `execute`, except on success, where it tries to decode the `result` field
221    /// of the response as a `T`.
222    ///
223    /// Use this method in cases where it's reasonable for Arti to sometimes return an RPC error:
224    /// in other words, where it's not necessarily a programming error or version mismatch.
225    ///
226    /// Don't use this for user-generated requests: it will misreport unexpected replies
227    /// as internal errors.
228    pub(crate) fn execute_internal<T: DeserializeOwned>(
229        &self,
230        cmd: &str,
231    ) -> Result<Result<T, ErrorResponse>, ProtoError> {
232        match self.execute(cmd)? {
233            Ok(success) => match success.decode::<T>() {
234                Ok(result) => Ok(Ok(result)),
235                Err(json_error) => Err(ProtoError::InternalRequestFailed(UnexpectedReply {
236                    request: cmd.to_string(),
237                    reply: Utf8CString::from(success).to_string(),
238                    problem: UnexpectedReplyProblem::CannotDecode(Arc::new(json_error)),
239                })),
240            },
241            Err(error) => Ok(Err(error)),
242        }
243    }
244
245    /// Helper for executing internally-generated requests and decoding their results.
246    ///
247    /// Behaves like `execute_internal`, except that it treats any RPC error reply
248    /// as an internal error or version mismatch.
249    ///
250    /// Don't use this for user-generated requests, or for requests that can fail because of
251    /// incorrect user inputs: it will misreport failures in those requests as internal errors.
252    pub(crate) fn execute_internal_ok<T: DeserializeOwned>(
253        &self,
254        cmd: &str,
255    ) -> Result<T, ProtoError> {
256        match self.execute_internal(cmd)? {
257            Ok(v) => Ok(v),
258            Err(err_response) => Err(err_response.internal_error(cmd)),
259        }
260    }
261
262    /// Cancel a request by ID.
263    pub fn cancel(&self, request_id: &AnyRequestId) -> Result<(), ProtoError> {
264        /// Arguments to an `rpc::cancel` request.
265        #[derive(serde::Serialize, Debug)]
266        struct CancelParams<'a> {
267            /// The request to cancel.
268            request_id: &'a AnyRequestId,
269        }
270
271        let request = crate::msgs::request::Request::new(
272            ObjectId::connection_id(),
273            "rpc:cancel",
274            CancelParams { request_id },
275        );
276        match self.execute_internal::<EmptyReply>(&request.encode()?)? {
277            Ok(EmptyReply {}) => Ok(()),
278            Err(_) => Err(ProtoError::RequestCompleted),
279        }
280    }
281
282    /// Like `execute`, but don't wait.  This lets the caller see the
283    /// request ID and  maybe cancel it.
284    pub fn execute_with_handle(&self, cmd: &str) -> Result<RequestHandle, ProtoError> {
285        self.send_request(cmd)
286    }
287    /// As execute(), but run update_cb for every update we receive.
288    pub fn execute_with_updates<F>(
289        &self,
290        cmd: &str,
291        mut update_cb: F,
292    ) -> Result<FinalResponse, ProtoError>
293    where
294        F: FnMut(UpdateResponse) + Send + Sync,
295    {
296        let hnd = self.execute_with_handle(cmd)?;
297        loop {
298            match hnd.wait_with_updates()? {
299                AnyResponse::Success(s) => return Ok(Ok(s)),
300                AnyResponse::Error(e) => return Ok(Err(e)),
301                AnyResponse::Update(u) => update_cb(u),
302            }
303        }
304    }
305
306    /// Helper: Tell Arti to release `obj`.
307    ///
308    /// Do not use this method for a user-provided object ID:
309    /// It gives an internal error if the object does not exist.
310    pub(crate) fn release_obj(&self, obj: ObjectId) -> Result<(), ProtoError> {
311        let release_request = crate::msgs::request::Request::new(obj, "rpc:release", NoParams {});
312        let _empty_response: EmptyReply = self.execute_internal_ok(&release_request.encode()?)?;
313        Ok(())
314    }
315
316    // TODO RPC: shutdown() on the socket on Drop.
317}
318
319impl RequestHandle {
320    /// Return the ID of this request, to help cancelling it.
321    pub fn id(&self) -> &AnyRequestId {
322        &self.id
323    }
324    /// Wait for success or failure, and return what happened.
325    ///
326    /// (Ignores any update messages that are received.)
327    ///
328    /// Note that this function will return `Err(.)` only if sending the command or getting a
329    /// response failed.  If the command was sent successfully, and Arti reported an error in response,
330    /// this function returns `Ok(Err(.))`.
331    pub fn wait(self) -> Result<FinalResponse, ProtoError> {
332        loop {
333            match self.wait_with_updates()? {
334                AnyResponse::Success(s) => return Ok(Ok(s)),
335                AnyResponse::Error(e) => return Ok(Err(e)),
336                AnyResponse::Update(_) => {}
337            }
338        }
339    }
340    /// Wait for the next success, failure, or update from this handle.
341    ///
342    /// Note that this function will return `Err(.)` only if sending the command or getting a
343    /// response failed.  If the command was sent successfully, and Arti reported an error in response,
344    /// this function returns `Ok(AnyResponse::Error(.))`.
345    ///
346    /// You may call this method on the same `RequestHandle` from multiple threads.
347    /// If you do so, those calls will receive responses (or errors) in an unspecified order.
348    ///
349    /// If this function returns Success or Error, then you shouldn't call it again.
350    /// All future calls to this function will fail with `CmdError::RequestCancelled`.
351    /// (TODO RPC: Maybe rename that error.)
352    pub fn wait_with_updates(&self) -> Result<AnyResponse, ProtoError> {
353        let conn = self.conn.lock().expect("Poisoned lock");
354        let validated = conn.wait_on_message_for(&self.id)?;
355
356        Ok(AnyResponse::from_validated(validated))
357    }
358
359    // TODO RPC: Sketch out how we would want to do this in an async world,
360    // or with poll
361}
362
/// An error (or other condition) that has caused an RPC connection to shut down.
///
/// (See also [`ProtoError::Shutdown`], which wraps this type.)
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ShutdownError {
    // TODO nb: Read/Write are no longer well separated in the API.
    //
    /// I/O error occurred while reading.
    #[error("Unable to read response")]
    Read(#[source] Arc<io::Error>),
    /// I/O error occurred while writing.
    #[error("Unable to write request")]
    Write(#[source] Arc<io::Error>),
    /// Something was wrong with Arti's responses; this is a protocol violation.
    ///
    /// The string describes what was wrong.
    #[error("Arti sent a message that didn't conform to the RPC protocol: {0:?}")]
    ProtocolViolated(String),
    /// Arti has told us that we violated the protocol somehow.
    ///
    /// This holds the error message that Arti sent.
    #[error("Arti reported a fatal error: {0:?}")]
    ProtocolViolationReport(ErrorResponse),
    /// The underlying connection closed.
    ///
    /// This probably means that Arti has shut down.
    #[error("Connection closed")]
    ConnectionClosed,
}
387
388impl From<crate::msgs::response::DecodeResponseError> for ShutdownError {
389    fn from(value: crate::msgs::response::DecodeResponseError) -> Self {
390        use crate::msgs::response::DecodeResponseError::*;
391        use ShutdownError as E;
392        match value {
393            JsonProtocolViolation(e) => E::ProtocolViolated(e.to_string()),
394            ProtocolViolation(s) => E::ProtocolViolated(s.to_string()),
395            Fatal(rpc_err) => E::ProtocolViolationReport(rpc_err),
396        }
397    }
398}
399
/// An error that has occurred while launching an RPC command.
///
/// This is the error type returned by [`RpcConn::execute`] and the other request methods.
/// Note that an error *reply* from Arti is not reported as a `ProtoError`;
/// see [`RpcConn::execute`] for details.
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ProtoError {
    /// The RPC connection failed, or was closed by the other side.
    #[error("RPC connection is shut down")]
    Shutdown(#[from] ShutdownError),

    /// There was a problem in the request we tried to send.
    #[error("Invalid request")]
    InvalidRequest(#[from] InvalidRequestError),

    /// We tried to send a request with an ID that was already pending.
    #[error("Request ID already in use.")]
    RequestIdInUse,

    /// We tried to wait for or inspect a request that had already succeeded or failed.
    #[error("Request has already completed (or failed)")]
    RequestCompleted,

    /// We tried to wait for the same request more than once.
    ///
    /// (This should be impossible.)
    #[error("Internal error: waiting on the same request more than once at a time.")]
    DuplicateWait,

    /// We got an internal error while trying to encode an RPC request.
    ///
    /// (This should be impossible.)
    #[error("Internal error while encoding request")]
    CouldNotEncode(#[source] Arc<serde_json::Error>),

    /// We got a response to some internally generated request that wasn't what we expected.
    //
    // NOTE(review): this variant both displays its inner error (`{0}`) and reports it as
    // its `source`, so reports that walk the source chain may print it twice.
    #[error("{0}")]
    InternalRequestFailed(#[source] UnexpectedReply),
}
436
/// A set of errors encountered while trying to connect to the Arti process.
///
/// Its `Display` form is only a one-line summary; use
/// [`ConnectFailure::display_verbose`] to see every individual error.
#[derive(Clone, Debug, thiserror::Error)]
pub struct ConnectFailure {
    /// A list of all the declined connect points we encountered, and how they failed.
    declined: Vec<(builder::ConnPtDescription, ConnectError)>,
    /// A description of where we found the final error (if it's an abort.)
    final_desc: Option<builder::ConnPtDescription>,
    /// The final error explaining why we couldn't connect.
    ///
    /// This is either an abort, an AllAttemptsDeclined, or an error that prevented the
    /// search process from even beginning.
    #[source]
    pub(crate) final_error: ConnectError,
}
451
452impl std::fmt::Display for ConnectFailure {
453    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
454        write!(f, "Unable to connect")?;
455        if !self.declined.is_empty() {
456            write!(
457                f,
458                " ({} attempts failed{})",
459                self.declined.len(),
460                if matches!(self.final_error, ConnectError::AllAttemptsDeclined) {
461                    ""
462                } else {
463                    " before fatal error"
464                }
465            )?;
466        }
467        Ok(())
468    }
469}
470
471impl ConnectFailure {
472    /// If this attempt failed because of a fatal error that made a connect point attempt abort,
473    /// return a description of the origin of that connect point.
474    pub fn fatal_error_origin(&self) -> Option<&builder::ConnPtDescription> {
475        self.final_desc.as_ref()
476    }
477
478    /// For each connect attempt that failed nonfatally, return a description of the
479    /// origin of that connect point, and the error that caused it to fail.
480    pub fn declined_attempt_outcomes(
481        &self,
482    ) -> impl Iterator<Item = (&builder::ConnPtDescription, &ConnectError)> {
483        // Note: this map looks like a no-op, but isn't.
484        self.declined.iter().map(|(a, b)| (a, b))
485    }
486
487    /// Return a helper type to format this error, and all of its internal errors recursively.
488    ///
489    /// Unlike [`tor_error::Report`], this method includes not only fatal errors, but also
490    /// information about connect attempts that failed nonfatally.
491    pub fn display_verbose(&self) -> ConnectFailureVerboseFmt<'_> {
492        ConnectFailureVerboseFmt(self)
493    }
494}
495
/// Helper type to format a ConnectFailure along with all of its internal errors,
/// including non-fatal errors.
///
/// Obtain one with [`ConnectFailure::display_verbose`]; its `Display` output
/// spans multiple lines.
#[derive(Debug, Clone)]
pub struct ConnectFailureVerboseFmt<'a>(&'a ConnectFailure);
500
501impl<'a> std::fmt::Display for ConnectFailureVerboseFmt<'a> {
502    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
503        use tor_error::ErrorReport as _;
504        writeln!(f, "{}:", self.0)?;
505        for (idx, (origin, error)) in self.0.declined_attempt_outcomes().enumerate() {
506            writeln!(f, "  {}. {}: {}", idx + 1, origin, error.report())?;
507        }
508        if let Some(origin) = self.0.fatal_error_origin() {
509            writeln!(
510                f,
511                "  {}. [FATAL] {}: {}",
512                self.0.declined.len() + 1,
513                origin,
514                self.0.final_error.report()
515            )?;
516        } else {
517            writeln!(f, "  - {}", self.0.final_error.report())?;
518        }
519        Ok(())
520    }
521}
522
/// An error while trying to connect to the Arti process.
///
/// Whether a given error aborts the whole connection search, or only declines the
/// current connect point, is decided by this type's [`HasClientErrorAction`] implementation.
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ConnectError {
    /// Unable to parse connect points from an environment variable.
    #[error("Cannot parse connect points from environment variable")]
    BadEnvironment,
    /// We were unable to load and/or parse a given connect point.
    #[error("Unable to load and parse connect point")]
    CannotParse(#[from] tor_rpc_connect::load::LoadError),
    /// The path used to specify a connect file couldn't be resolved.
    #[error("Unable to resolve connect point path")]
    CannotResolvePath(#[source] tor_config_path::CfgPathError),
    /// A parsed connect point couldn't be resolved.
    #[error("Unable to resolve connect point")]
    CannotResolveConnectPoint(#[from] tor_rpc_connect::ResolveError),
    /// I/O error while connecting to Arti.
    #[error("Unable to make a connection")]
    CannotConnect(#[from] tor_rpc_connect::ConnectError),
    /// The connect point told us to connect via a type of stream we don't know how to support.
    #[error("Connect point stream type was unsupported")]
    StreamTypeUnsupported,
    /// Opened a connection, but didn't get a banner message.
    ///
    /// (This isn't a `BadMessage`, since it is likelier to represent something that isn't
    /// pretending to be Arti at all than it is to be a malfunctioning Arti.)
    #[error("Did not receive expected banner message upon connecting")]
    InvalidBanner,
    /// All attempted connect points were declined, and none were aborted.
    #[error("All connect points were declined (or there were none)")]
    AllAttemptsDeclined,
    /// A connect file or directory was given as a relative path.
    /// (Only absolute paths are supported.)
    #[error("Connect file was given as a relative path.")]
    RelativeConnectFile,
    /// One of our authentication messages received an error.
    #[error("Received an error while trying to authenticate: {0}")]
    AuthenticationFailed(ErrorResponse),
    /// The connect point uses an RPC authentication type we don't support.
    #[error("Authentication type is not supported")]
    AuthenticationNotSupported,
    /// We couldn't decode one of the responses we got.
    #[error("Message not in expected format")]
    BadMessage(#[source] Arc<serde_json::Error>),
    /// A protocol error occurred during negotiations.
    #[error("Error while negotiating with Arti")]
    ProtoError(#[from] ProtoError),
    /// The server thinks it is listening on an address where we don't expect to find it.
    /// This can be misconfiguration or an attempted MITM attack.
    #[error("We connected to the server at {ours}, but it thinks it's listening at {theirs}")]
    ServerAddressMismatch {
        /// The address we think the server has
        ours: String,
        /// The address that the server says it has.
        theirs: String,
    },
    /// The server tried to prove knowledge of a cookie file, but its proof was incorrect.
    #[error("Server's cookie MAC was not as expected.")]
    CookieMismatch,
    /// We were unable to access the configured cookie file.
    #[error("Unable to load secret cookie value")]
    LoadCookie(#[from] CookieAccessError),
}
586
587impl HasClientErrorAction for ConnectError {
588    fn client_action(&self) -> tor_rpc_connect::ClientErrorAction {
589        use ConnectError as E;
590        use tor_rpc_connect::ClientErrorAction as A;
591        match self {
592            E::BadEnvironment => A::Abort,
593            E::CannotParse(e) => e.client_action(),
594            E::CannotResolvePath(_) => A::Abort,
595            E::CannotResolveConnectPoint(e) => e.client_action(),
596            E::CannotConnect(e) => e.client_action(),
597            E::StreamTypeUnsupported => A::Decline,
598            E::InvalidBanner => A::Decline,
599            E::RelativeConnectFile => A::Abort,
600            E::AuthenticationFailed(_) => A::Decline,
601            // TODO RPC: Is this correct?  This error can also occur when
602            // we are talking to something other than an RPC server.
603            E::BadMessage(_) => A::Abort,
604            E::ProtoError(e) => e.client_action(),
605            E::AllAttemptsDeclined => A::Abort,
606            E::AuthenticationNotSupported => A::Decline,
607            E::ServerAddressMismatch { .. } => A::Abort,
608            E::CookieMismatch => A::Abort,
609            E::LoadCookie(e) => e.client_action(),
610        }
611    }
612}
613
614impl HasClientErrorAction for ProtoError {
615    fn client_action(&self) -> tor_rpc_connect::ClientErrorAction {
616        use ProtoError as E;
617        use tor_rpc_connect::ClientErrorAction as A;
618        match self {
619            E::Shutdown(_) => A::Decline,
620            E::InternalRequestFailed(_) => A::Decline,
621            // These are always internal errors if they occur while negotiating a connection to RPC,
622            // which is the context we care about for `HasClientErrorAction`.
623            E::InvalidRequest(_)
624            | E::RequestIdInUse
625            | E::RequestCompleted
626            | E::DuplicateWait
627            | E::CouldNotEncode(_) => A::Abort,
628        }
629    }
630}
631
/// In response to a request that we generated internally,
/// Arti gave a reply that we did not understand.
///
/// This could be due to a bug in this library, a bug in Arti,
/// or a compatibility issue between the two.
#[derive(Clone, Debug, thiserror::Error)]
#[error("In response to our request {request:?}, Arti gave the unexpected reply {reply:?}")]
pub struct UnexpectedReply {
    /// The request we sent, as text.
    request: String,
    /// The response we got, as text.
    reply: String,
    /// What was wrong with the response.
    ///
    /// (Exposed as this error's `source`.)
    #[source]
    problem: UnexpectedReplyProblem,
}
648
649/// Underlying reason for an UnexpectedReply
650#[derive(Clone, Debug, thiserror::Error)]
651enum UnexpectedReplyProblem {
652    /// There was a json failure while trying to decode the response:
653    /// the result type was not what we expected.
654    #[error("Cannot decode as correct JSON type")]
655    CannotDecode(Arc<serde_json::Error>),
656    /// Arti replied with an RPC error in a context no error should have been possible.
657    #[error("Unexpected error")]
658    ErrorNotExpected,
659}
660
/// Arguments to a request that takes no parameters.
///
/// (Used as the parameters of the `rpc:release` request sent by `RpcConn::release_obj`.)
#[derive(serde::Serialize, Debug)]
struct NoParams {}
664
/// A reply with no data.
///
/// (Used as the expected `result` of the internally generated `rpc:cancel`
/// and `rpc:release` requests.)
#[derive(serde::Deserialize, Debug)]
struct EmptyReply {}
668
669#[cfg(test)]
670mod test {
671    // @@ begin test lint list maintained by maint/add_warning @@
672    #![allow(clippy::bool_assert_comparison)]
673    #![allow(clippy::clone_on_copy)]
674    #![allow(clippy::dbg_macro)]
675    #![allow(clippy::mixed_attributes_style)]
676    #![allow(clippy::print_stderr)]
677    #![allow(clippy::print_stdout)]
678    #![allow(clippy::single_char_pattern)]
679    #![allow(clippy::unwrap_used)]
680    #![allow(clippy::unchecked_time_subtraction)]
681    #![allow(clippy::useless_vec)]
682    #![allow(clippy::needless_pass_by_value)]
683    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
684
685    use std::{sync::atomic::AtomicUsize, thread, time::Duration};
686
687    use io::{BufRead as _, BufReader, Write as _};
688    use rand::{Rng as _, SeedableRng as _, seq::SliceRandom as _};
689    use tor_basic_utils::{RngExt as _, test_rng::testing_rng};
690
691    use crate::{
692        msgs::request::{JsonMap, Request, ValidatedRequest},
693        nb_stream::PollingStream,
694    };
695
696    use super::*;
697
698    /// helper: Return a dummy RpcConn, along with a socketpair for it to talk to.
699    fn dummy_connected() -> (RpcConn, crate::testing::SocketpairStream) {
700        let (s1, s2) = crate::testing::construct_socketpair().unwrap();
701        let conn = RpcConn::new(PollingStream::new(s1).unwrap());
702
703        (conn, s2)
704    }
705
706    fn write_val(w: &mut impl io::Write, v: &serde_json::Value) {
707        let mut enc = serde_json::to_string(v).unwrap();
708        enc.push('\n');
709        w.write_all(enc.as_bytes()).unwrap();
710    }
711
712    #[test]
713    fn simple() {
714        let (conn, sock) = dummy_connected();
715
716        let user_thread = thread::spawn(move || {
717            let response1 = conn
718                .execute_internal_ok::<JsonMap>(
719                    r#"{"obj":"fred","method":"arti:x-frob","params":{}}"#,
720                )
721                .unwrap();
722            (response1, conn)
723        });
724
725        let fake_arti_thread = thread::spawn(move || {
726            let mut sock = BufReader::new(sock);
727            let mut s = String::new();
728            let _len = sock.read_line(&mut s).unwrap();
729            let request = ValidatedRequest::from_string_strict(s.as_ref()).unwrap();
730            let response = serde_json::json!({
731                "id": request.id().clone(),
732                "result": { "xyz" : 3 }
733            });
734            write_val(sock.get_mut(), &response);
735            sock // prevent close
736        });
737
738        let _sock = fake_arti_thread.join().unwrap();
739        let (map, _conn) = user_thread.join().unwrap();
740        assert_eq!(map.get("xyz"), Some(&serde_json::Value::Number(3.into())));
741    }
742
    /// Stress test: many user threads share one connection and issue many
    /// requests, while a fake Arti answers them in a scrambled order.
    ///
    /// Each request randomly asks for updates and randomly asks for a failure.
    /// Each user thread checks that it receives exactly what it asked for.
    /// At the end, the test checks that every request completed.
    #[test]
    fn complex() {
        use std::sync::atomic::Ordering::SeqCst;
        let n_threads = 16;
        let n_commands_per_thread = 128;
        let n_commands_total = n_threads * n_commands_per_thread;
        // Incremented by a user thread each time it has fully checked a response.
        let n_completed = Arc::new(AtomicUsize::new(0));

        let (conn, sock) = dummy_connected();
        let conn = Arc::new(conn);
        let mut user_threads = Vec::new();
        let mut rng = testing_rng();

        // -------
        // User threads: Make a bunch of requests.
        for th_idx in 0..n_threads {
            let conn = Arc::clone(&conn);
            let n_completed = Arc::clone(&n_completed);
            // Each user thread gets its own RNG, seeded from the test RNG.
            let mut rng = rand_chacha::ChaCha12Rng::from_seed(rng.random());
            let th = thread::spawn(move || {
                for cmd_idx in 0..n_commands_per_thread {
                    // Each user thread runs its commands in sequence: it sends a
                    // request, then waits for that request's final response before
                    // sending the next one.  Each request randomly asks for updates,
                    // and randomly asks for an error or a success.
                    // We double-check that each request gets the response it asked for.
                    //
                    // `s` is unique per request ("thread:command"), so the echoed
                    // value tells us whether the reply really belongs to this request.
                    let s = format!("{}:{}", th_idx, cmd_idx);
                    let want_updates: bool = rng.random();
                    let want_failure: bool = rng.random();
                    let req = serde_json::json!({
                        "obj":"fred",
                        "method":"arti:x-echo",
                        "meta": {
                            "updates": want_updates,
                        },
                        "params": {
                            "val": &s,
                            "fail": want_failure,
                        },
                    });
                    let req = serde_json::to_string(&req).unwrap();

                    // Wait for a final response, processing updates if we asked for them.
                    let mut n_updates = 0;
                    let outcome = conn
                        .execute_with_updates(&req, |_update| {
                            n_updates += 1;
                        })
                        .unwrap();
                    // We must see updates if and only if we asked for them.
                    assert_eq!(n_updates > 0, want_updates);

                    // See if we liked the final response.
                    if want_failure {
                        let e = outcome.unwrap_err().decode();
                        assert_eq!(e.message(), "You asked me to fail");
                        assert_eq!(i32::from(e.code()), 33);
                        assert_eq!(
                            e.kinds_iter().collect::<Vec<_>>(),
                            vec!["Example".to_string()]
                        );
                    } else {
                        let success = outcome.unwrap();
                        let map = success.decode::<JsonMap>().unwrap();
                        assert_eq!(map.get("echo"), Some(&serde_json::Value::String(s)));
                    }
                    n_completed.fetch_add(1, SeqCst);
                    // Occasionally (about 2% of commands) pause briefly, to perturb
                    // the interleaving between user threads.
                    if rng.random::<f32>() < 0.02 {
                        thread::sleep(Duration::from_millis(3));
                    }
                }
            });
            user_threads.push(th);
        }

        /// The `params` of an `arti:x-echo` request, as the fake Arti parses them.
        #[derive(serde::Deserialize, Debug)]
        struct Echo {
            val: String,
            fail: bool,
        }

        // -----
        // Worker thread: handles user requests.
        let worker_rng = rand_chacha::ChaCha12Rng::from_seed(rng.random());
        let worker_thread = thread::spawn(move || {
            let mut rng = worker_rng;
            let mut sock = BufReader::new(sock);
            // Requests received but not yet answered.
            let mut pending: Vec<Request<Echo>> = Vec::new();
            let mut n_received = 0;

            // How many requests do we buffer before we shuffle them and answer them out-of-order?
            let scramble_factor = 7;
            // After receiving how many requests do we stop shuffling requests?
            //
            // (Our shuffling algorithm can deadlock us otherwise.)
            //
            // Each user thread only sends its next request after its previous one is
            // answered, so once fewer than `scramble_factor` user threads are still
            // running, waiting for a full batch would never finish.  The margin
            // `(n_commands_per_thread + 1) * scramble_factor` is larger than the
            // number of requests that fewer than `scramble_factor` threads could
            // still have left to send.
            let scramble_threshold =
                n_commands_total - (n_commands_per_thread + 1) * scramble_factor;

            // Each pass: collect a batch of requests, then answer the whole batch
            // in random order.  We stop when the socket reaches EOF.
            'outer: loop {
                let flush_pending_at = if n_received >= scramble_threshold {
                    1
                } else {
                    scramble_factor
                };

                // Queue a handful of requests in "pending"
                while pending.len() < flush_pending_at {
                    let mut buf = String::new();
                    if sock.read_line(&mut buf).unwrap() == 0 {
                        break 'outer;
                    }
                    n_received += 1;
                    let req: Request<Echo> = serde_json::from_str(&buf).unwrap();
                    pending.push(req);
                }

                // Handle the requests in "pending" in random order.
                let mut handling = std::mem::take(&mut pending);
                handling.shuffle(&mut rng);

                for req in handling {
                    // If updates were requested, send between 1 and 3 of them
                    // (`1..4` is half-open) before the final response.
                    if req.meta.unwrap_or_default().updates {
                        let n_updates = rng.gen_range_checked(1..4).unwrap();
                        for _ in 0..n_updates {
                            let up = serde_json::json!({
                                "id": req.id.clone(),
                                "update": {
                                    "hello": req.params.val.clone(),
                                }
                            });
                            write_val(sock.get_mut(), &up);
                        }
                    }

                    // Final response: an error if the request asked to fail,
                    // otherwise a result echoing the request's value.
                    let response = if req.params.fail {
                        serde_json::json!({
                            "id": req.id.clone(),
                            "error": { "message": "You asked me to fail", "code": 33, "kinds": ["Example"], "data": req.params.val },
                        })
                    } else {
                        serde_json::json!({
                            "id": req.id.clone(),
                            "result": {
                                "echo": req.params.val
                            }
                        })
                    };
                    write_val(sock.get_mut(), &response);
                }
            }
        });
        // Drop our own handle to the connection, so that only the user threads
        // hold it.  The worker loop above exits only on EOF, which presumably
        // happens once the last handle to the connection is dropped.
        drop(conn);
        for t in user_threads {
            t.join().unwrap();
        }

        worker_thread.join().unwrap();

        // Every command from every user thread must have completed.
        assert_eq!(n_completed.load(SeqCst), n_commands_total);
    }
901
902    #[test]
903    fn arti_socket_closed() {
904        // Here we send a bunch of requests and then close the socket without answering them.
905        //
906        // Every request should get a ProtoError::Shutdown.
907        let n_threads = 16;
908
909        let (conn, sock) = dummy_connected();
910        let conn = Arc::new(conn);
911        let mut user_threads = Vec::new();
912        for _ in 0..n_threads {
913            let conn = Arc::clone(&conn);
914            let th = thread::spawn(move || {
915                // We are spawning a bunch of worker threads, each of which will run a number of
916                // We will double-check that each request gets the response it asked for.
917                let req = serde_json::json!({
918                    "obj":"fred",
919                    "method":"arti:x-echo",
920                    "params":{}
921                });
922                let req = serde_json::to_string(&req).unwrap();
923                let outcome = conn.execute(&req);
924                if !matches!(
925                    &outcome,
926                    Err(ProtoError::Shutdown(ShutdownError::Write(_)))
927                        | Err(ProtoError::Shutdown(ShutdownError::Read(_))),
928                ) {
929                    dbg!(&outcome);
930                }
931
932                assert!(matches!(
933                    outcome,
934                    Err(ProtoError::Shutdown(ShutdownError::Write(_)))
935                        | Err(ProtoError::Shutdown(ShutdownError::Read(_)))
936                        | Err(ProtoError::Shutdown(ShutdownError::ConnectionClosed))
937                ));
938            });
939            user_threads.push(th);
940        }
941
942        drop(sock);
943
944        for t in user_threads {
945            t.join().unwrap();
946        }
947    }
948
949    /// Send a bunch of requests and then send back a single reply.
950    ///
951    /// That reply should cause every request to get closed.
952    fn proto_err_with_msg<F>(msg: &str, outcome_ok: F)
953    where
954        F: Fn(ProtoError) -> bool,
955    {
956        let n_threads = 16;
957
958        let (conn, mut sock) = dummy_connected();
959        let conn = Arc::new(conn);
960        let mut user_threads = Vec::new();
961        for _ in 0..n_threads {
962            let conn = Arc::clone(&conn);
963            let th = thread::spawn(move || {
964                // We are spawning a bunch of worker threads, each of which will run a number of
965                // We will double-check that each request gets the response it asked for.
966                let req = serde_json::json!({
967                    "obj":"fred",
968                    "method":"arti:x-echo",
969                    "params":{}
970                });
971                let req = serde_json::to_string(&req).unwrap();
972                conn.execute(&req)
973            });
974            user_threads.push(th);
975        }
976
977        sock.write_all(msg.as_bytes()).unwrap();
978
979        for t in user_threads {
980            let outcome = t.join().unwrap();
981            assert!(outcome_ok(outcome.unwrap_err()));
982        }
983    }
984
985    #[test]
986    fn syntax_error() {
987        proto_err_with_msg("this is not json\n", |outcome| {
988            matches!(
989                outcome,
990                ProtoError::Shutdown(ShutdownError::ProtocolViolated(_))
991            )
992        });
993    }
994
995    #[test]
996    fn fatal_error() {
997        let j = serde_json::json!({
998            "error":{ "message": "This test is doomed", "code": 413, "kinds": ["Example"], "data": {} },
999        });
1000        let mut s = serde_json::to_string(&j).unwrap();
1001        s.push('\n');
1002
1003        proto_err_with_msg(&s, |outcome| {
1004            matches!(
1005                outcome,
1006                ProtoError::Shutdown(ShutdownError::ProtocolViolationReport(_))
1007            )
1008        });
1009    }
1010}