pub(crate) mod auth;
mod methods;
use std::{
collections::HashMap,
io::Error as IoError,
pin::Pin,
sync::{Arc, Mutex, RwLock, Weak},
};
use asynchronous_codec::JsonCodecError;
use derive_deftly::Deftly;
use futures::{
AsyncWriteExt as _, FutureExt, Sink, SinkExt as _, StreamExt,
channel::mpsc,
stream::{FusedStream, FuturesUnordered},
};
use rpc::dispatch::BoxedUpdateSink;
use serde_json::error::Category as JsonErrorCategory;
use tor_async_utils::{SinkExt as _, mpsc_channel_no_memquota};
use crate::{
RpcMgr,
cancel::{self, Cancel, CancelHandle},
err::RequestParseError,
globalid::{GlobalId, MacKey},
msgs::{BoxedResponse, FlexibleRequest, ReqMeta, Request, RequestId, ResponseBody},
objmap::{GenIdx, ObjMap},
};
use tor_rpcbase::templates::*;
use tor_rpcbase::{self as rpc, RpcError};
/// An open RPC connection, together with the state it owns.
///
/// A `Connection` tracks in-flight requests and the objects registered on it,
/// and is itself exposed as an RPC object (see the `Object` derive below).
#[derive(Deftly)]
#[derive_deftly(Object)]
pub struct Connection {
/// Mutable state for this connection, protected by a mutex.
inner: Mutex<Inner>,
/// Table used to dispatch RPC methods; handed out via `rpc::Context::dispatch_table`.
dispatch_table: Arc<RwLock<rpc::DispatchTable>>,
/// Identifier for this connection; embedded into the global object IDs it issues,
/// and compared against when decoding them.
connection_id: ConnectionId,
/// Key passed to `GlobalId::encode` / `GlobalId::try_decode` for global object IDs.
global_id_mac_key: MacKey,
/// Weak reference to the RPC manager; `Connection::mgr` fails once it has been dropped.
mgr: Weak<RpcMgr>,
/// Authentication requirement for this connection.
/// NOTE(review): not read anywhere in this file; presumably consulted by the `auth` submodule.
require_auth: tor_rpc_connect::auth::RpcAuth,
}
/// The mutable part of a [`Connection`], kept behind `Connection::inner`.
struct Inner {
/// Requests currently in progress, keyed by request ID.
///
/// The value is `Some(handle)` when the request can be cancelled, and `None`
/// when it cannot.
inflight: HashMap<RequestId, Option<CancelHandle>>,
/// Objects registered on this connection, addressed by local index.
objects: ObjMap,
/// Weak self-reference used to resolve the special `"connection"` object ID.
///
/// Set to `None` when that object is released via `release_owned`.
this_connection: Option<Weak<Connection>>,
}
/// Capacity of the channel that carries updates and final responses from
/// in-flight requests back to a connection's main loop.
const UPDATE_CHAN_SIZE: usize = 128;
/// A boxed, fused stream of requests read from the client.
///
/// Each item is either a (valid or invalid) decoded request, or a codec error.
pub(crate) type BoxedRequestStream = Pin<
Box<dyn FusedStream<Item = Result<FlexibleRequest, asynchronous_codec::JsonCodecError>> + Send>,
>;
/// A boxed sink onto which responses are encoded and written to the client.
pub(crate) type BoxedResponseSink =
Pin<Box<dyn Sink<BoxedResponse, Error = asynchronous_codec::JsonCodecError> + Send>>;
/// An identifier for an RPC connection.
///
/// It is embedded in the global object IDs that a connection issues, so that a
/// decoded global ID can be checked against the connection that received it.
/// Convertible to and from its raw `[u8; 16]` representation.
#[derive(
Copy,
Clone,
Debug,
Eq,
PartialEq,
Hash,
derive_more::From,
derive_more::Into,
derive_more::AsRef,
)]
pub(crate) struct ConnectionId([u8; 16]);
impl ConnectionId {
/// Length of a `ConnectionId`, in bytes.
///
/// Must stay equal to the array length of the inner field.
pub(crate) const LEN: usize = 16;
}
impl Connection {
    /// The object ID that refers to this connection itself.
    ///
    /// Looking up this ID yields the `Connection`, as long as its weak
    /// self-reference has not been released via `rpc::Context::release_owned`.
    const CONNECTION_OBJ_ID: &'static str = "connection";

    /// Create a new `Connection`, returned inside an `Arc`.
    ///
    /// The connection starts with no in-flight requests and no registered
    /// objects. It keeps a weak reference to itself so that
    /// [`Self::CONNECTION_OBJ_ID`] can be resolved by [`Connection::lookup_object`].
    pub(crate) fn new(
        connection_id: ConnectionId,
        dispatch_table: Arc<RwLock<rpc::DispatchTable>>,
        global_id_mac_key: MacKey,
        mgr: Weak<RpcMgr>,
        require_auth: tor_rpc_connect::auth::RpcAuth,
    ) -> Arc<Self> {
        Arc::new_cyclic(|this_connection| Self {
            inner: Mutex::new(Inner {
                inflight: HashMap::new(),
                objects: ObjMap::new(),
                this_connection: Some(Weak::clone(this_connection)),
            }),
            dispatch_table,
            connection_id,
            global_id_mac_key,
            mgr,
            require_auth,
        })
    }

    /// Convert an object ID into an index in this connection's object map.
    ///
    /// If `GlobalId::try_decode` recognizes `id` as a global ID, it is accepted
    /// only when its embedded connection ID is ours. Otherwise `id` is decoded
    /// as a local [`GenIdx`].
    ///
    /// # Errors
    ///
    /// Returns [`rpc::LookupError::NoObject`] for a global ID that belongs to a
    /// different connection, and propagates any decoding error.
    fn id_into_local_idx(&self, id: &rpc::ObjectId) -> Result<GenIdx, rpc::LookupError> {
        if let Some(global_id) = GlobalId::try_decode(&self.global_id_mac_key, id)? {
            if global_id.connection == self.connection_id {
                Ok(global_id.local_id)
            } else {
                Err(rpc::LookupError::NoObject(id.clone()))
            }
        } else {
            Ok(GenIdx::try_decode(id)?)
        }
    }

    /// Look up the object that `id` refers to on this connection.
    ///
    /// The special ID [`Self::CONNECTION_OBJ_ID`] resolves to this connection
    /// itself. Any other ID is translated with `id_into_local_idx` and looked
    /// up in the object map.
    ///
    /// # Errors
    ///
    /// Returns [`rpc::LookupError::NoObject`] if there is no such object. This
    /// includes the case where the connection's self-reference was released or
    /// can no longer be upgraded. Also propagates errors from decoding `id`.
    pub(crate) fn lookup_object(
        &self,
        id: &rpc::ObjectId,
    ) -> Result<Arc<dyn rpc::Object>, rpc::LookupError> {
        if id.as_ref() == Self::CONNECTION_OBJ_ID {
            let this = self
                .inner
                .lock()
                .expect("lock poisoned")
                .this_connection
                .as_ref()
                .ok_or_else(|| rpc::LookupError::NoObject(id.clone()))?
                .upgrade()
                .ok_or_else(|| rpc::LookupError::NoObject(id.clone()))?;
            Ok(this as Arc<_>)
        } else {
            let local_id = self.id_into_local_idx(id)?;
            // Use `ok_or_else` so the error (and its cloned ID) is only built on a miss.
            self.lookup_by_idx(local_id)
                .ok_or_else(|| rpc::LookupError::NoObject(id.clone()))
        }
    }

    /// Return the object stored at `idx` in this connection's object map, if any.
    pub(crate) fn lookup_by_idx(&self, idx: GenIdx) -> Option<Arc<dyn rpc::Object>> {
        let inner = self.inner.lock().expect("lock poisoned");
        inner.objects.lookup(idx)
    }

    /// Forget the in-flight request with ID `id`, if one is registered.
    fn remove_request(&self, id: &RequestId) {
        let mut inner = self.inner.lock().expect("lock poisoned");
        inner.inflight.remove(id);
    }

    /// Record that request `id` is in flight.
    ///
    /// `handle` is `Some` if the request can be cancelled. Any existing entry
    /// with the same ID is replaced.
    fn register_request(&self, id: RequestId, handle: Option<CancelHandle>) {
        let mut inner = self.inner.lock().expect("lock poisoned");
        inner.inflight.insert(id, handle);
    }

    /// Try to cancel the in-flight request with ID `id`.
    ///
    /// On success, the request's registration is removed and its cancel handle
    /// is triggered. The lock is released before the handle is used.
    ///
    /// # Errors
    ///
    /// * [`CancelError::RequestNotFound`] if no request with that ID is
    ///   registered, or if the handle reports that the request already finished.
    /// * [`CancelError::CannotCancelRequest`] if the request is not
    ///   cancellable. Its registration is left in place.
    /// * [`CancelError::AlreadyCancelled`] if the handle reports a previous
    ///   cancellation.
    fn cancel_request(&self, id: &RequestId) -> Result<(), CancelError> {
        let mut inner = self.inner.lock().expect("lock poisoned");
        match inner.inflight.remove(id) {
            Some(Some(handle)) => {
                // Don't hold the connection lock while cancelling.
                drop(inner);
                handle.cancel()?;
                Ok(())
            }
            Some(None) => {
                // Not cancellable: restore the entry we just removed.
                inner.inflight.insert(id.clone(), None);
                Err(CancelError::CannotCancelRequest)
            }
            None => Err(CancelError::RequestNotFound),
        }
    }

    /// Serve RPC requests read from `input`, writing responses to `output`.
    ///
    /// First writes (and flushes) a one-line JSON banner to `output`. Then it
    /// frames `input` as a stream of JSON requests and `output` as a JSON-lines
    /// response sink, and runs [`Connection::run_loop`] on them.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectionError::WriteFailed`] if the banner cannot be written
    /// or flushed, and otherwise any error returned by `run_loop`.
    pub async fn run<IN, OUT>(
        self: Arc<Self>,
        input: IN,
        mut output: OUT,
    ) -> Result<(), ConnectionError>
    where
        IN: futures::AsyncRead + Send + Sync + Unpin + 'static,
        OUT: futures::AsyncWrite + Send + Sync + Unpin + 'static,
    {
        /// Banner line sent before any requests are processed.
        const BANNER: &[u8] = b"{\"arti_rpc\":{}}\n";
        output
            .write_all(BANNER)
            .await
            .map_err(|e| ConnectionError::WriteFailed(Arc::new(e)))?;
        // Flush so the banner reaches the client even if `output` buffers writes.
        // Without this, a buffered writer could hold the banner back until the
        // first response is flushed.
        output
            .flush()
            .await
            .map_err(|e| ConnectionError::WriteFailed(Arc::new(e)))?;
        let write = Box::pin(asynchronous_codec::FramedWrite::new(
            output,
            crate::codecs::JsonLinesEncoder::<BoxedResponse>::default(),
        ));
        let read = Box::pin(
            asynchronous_codec::FramedRead::new(
                input,
                asynchronous_codec::JsonCodec::<(), FlexibleRequest>::new(),
            )
            .fuse(),
        );
        self.run_loop(read, write).await
    }

    /// Main loop for a connection.
    ///
    /// Reads requests from `request_stream`, starts them, and writes their
    /// responses to `response_sink`. It multiplexes over three sources:
    /// * `request_stream`: incoming requests from the client;
    /// * `finished_requests`: futures for in-progress requests, polled here
    ///   until they complete;
    /// * `rx_response`: updates and final responses sent by those futures.
    ///
    /// Any requests still in progress when the loop exits are dropped along
    /// with `finished_requests`.
    ///
    /// Returns `Ok(())` when the request stream ends. Also returns `Ok(())` when
    /// the loop fails with an error that [`ConnectionError::is_connection_close`]
    /// classifies as the connection closing.
    ///
    /// # Errors
    ///
    /// Returns a [`ConnectionError`] if reading or writing fails for another
    /// reason, or if an invalid request with no ID is received.
    pub(crate) async fn run_loop(
        self: Arc<Self>,
        mut request_stream: BoxedRequestStream,
        mut response_sink: BoxedResponseSink,
    ) -> Result<(), ConnectionError> {
        let (tx_response, mut rx_response) =
            mpsc_channel_no_memquota::<BoxedResponse>(UPDATE_CHAN_SIZE);
        let mut finished_requests = FuturesUnordered::new();
        // A never-completing future keeps `finished_requests` from ever being
        // empty, so its `next()` never yields `None`.
        finished_requests.push(futures::future::pending().boxed());
        /// Marker value: every `select!` arm either produces this or returns.
        struct Continue;
        let outcome = async {
            loop {
                let _: Continue = futures::select! {
                    r = finished_requests.next() => {
                        // A request future finished; nothing more to do for it here.
                        let () = r.expect("Somehow, future::pending() terminated.");
                        Continue
                    }
                    r = rx_response.next() => {
                        // We hold `tx_response`, so this channel never closes.
                        let update = r.expect("Somehow, tx_update got closed.");
                        response_sink.send(update).await.map_err(ConnectionError::writing)?;
                        Continue
                    }
                    req = request_stream.next() => {
                        match req {
                            None => {
                                // The client has finished sending requests.
                                return Ok(());
                            }
                            Some(Err(e)) => {
                                return Err(ConnectionError::from_read_error(e));
                            }
                            Some(Ok(FlexibleRequest::Invalid(bad_req))) => {
                                // Parsed as JSON but not as a valid request: report
                                // what was wrong with it.
                                let response = BoxedResponse::from_error(
                                    bad_req.id().cloned(), bad_req.error()
                                );
                                response_sink
                                    .send(response)
                                    .await
                                    .map_err(ConnectionError::writing)?;
                                // Without an ID, the error can't be tied to a
                                // request, so the connection is ended with an error.
                                if bad_req.id().is_none() {
                                    return Err(bad_req.error().into());
                                }
                                Continue
                            }
                            Some(Ok(FlexibleRequest::Valid(req))) => {
                                let tx = tx_response.clone();
                                let fut = self.run_method_and_deliver_response(tx, req);
                                finished_requests.push(fut.boxed());
                                Continue
                            }
                        }
                    }
                };
            }
        }
        .await;
        match outcome {
            Err(e) if e.is_connection_close() => Ok(()),
            other => other,
        }
    }

    /// Run a single request to completion and send its final response on
    /// `tx_response`.
    ///
    /// If the request asked for updates (`meta.updates`), updates are sent on a
    /// clone of `tx_response`, tagged with the request's ID. Otherwise updates
    /// are discarded.
    ///
    /// While running, the request is registered as in flight, with a cancel
    /// handle if the method is cancellable. It is unregistered after its final
    /// response has been sent.
    ///
    /// NOTE: `register_request` overwrites an existing entry and
    /// `remove_request` removes by ID. If a client reuses the ID of a request
    /// that is still running, the two requests' registrations can interfere.
    async fn run_method_and_deliver_response(
        self: &Arc<Self>,
        mut tx_response: mpsc::Sender<BoxedResponse>,
        request: Request,
    ) {
        let Request {
            id,
            obj,
            meta,
            method,
        } = request;
        let update_sender: BoxedUpdateSink = if meta.updates {
            let id_clone = id.clone();
            let sink =
                tx_response
                    .clone()
                    .with_fn(move |obj: Box<dyn erased_serde::Serialize + Send>| {
                        Result::<BoxedResponse, _>::Ok(BoxedResponse {
                            id: Some(id_clone.clone()),
                            body: ResponseBody::Update(obj),
                        })
                    });
            Box::pin(sink)
        } else {
            // The client didn't ask for updates: silently discard them.
            let sink = futures::sink::drain().sink_err_into();
            Box::pin(sink)
        };
        let is_cancellable = method.is_cancellable();
        let fut = self.run_method_lowlevel(update_sender, obj, method, meta);
        let outcome = if is_cancellable {
            let (handle, fut) = Cancel::new(fut);
            self.register_request(id.clone(), Some(handle));
            fut.await
        } else {
            self.register_request(id.clone(), None);
            Ok(fut.await)
        };
        let body = match outcome {
            Ok(Ok(value)) => ResponseBody::Success(value),
            Ok(Err(err)) => {
                if err.is_internal() {
                    tracing::warn!(
                        "Reporting an internal error on an RPC connection: {:?}",
                        err
                    );
                }
                ResponseBody::Error(Box::new(err))
            }
            Err(_cancelled) => ResponseBody::Error(Box::new(rpc::RpcError::from(RequestCancelled))),
        };
        // A send failure is deliberately ignored; presumably it can only mean
        // the connection's main loop has gone away.
        let _ignore_err = tx_response
            .send(BoxedResponse {
                id: Some(id.clone()),
                body,
            })
            .await;
        self.remove_request(&id);
    }

    /// Look up the target object and invoke `method` on it.
    ///
    /// # Errors
    ///
    /// Fails if the object lookup fails. Fails with a `FeatureNotPresent` error
    /// if the request lists any required features, since none are supported
    /// here. Otherwise returns any error from dispatching or running the method.
    async fn run_method_lowlevel(
        self: &Arc<Self>,
        tx_updates: rpc::dispatch::BoxedUpdateSink,
        obj_id: rpc::ObjectId,
        method: Box<dyn rpc::DeserMethod>,
        meta: ReqMeta,
    ) -> Result<Box<dyn erased_serde::Serialize + Send + 'static>, rpc::RpcError> {
        let obj = self.lookup_object(&obj_id)?;
        if !meta.require.is_empty() {
            return Err(MissingFeaturesError(meta.require).into());
        }
        let context: Arc<dyn rpc::Context> = self.clone() as Arc<_>;
        let invoke_future =
            rpc::invoke_rpc_method(context, &obj_id, obj, method.upcast_box(), tx_updates)?;
        invoke_future.await
    }

    /// Return the RPC manager for this connection.
    ///
    /// # Errors
    ///
    /// Returns [`MgrDisappearedError::RpcMgrDisappeared`] if the manager has
    /// been dropped.
    pub(crate) fn mgr(&self) -> Result<Arc<RpcMgr>, MgrDisappearedError> {
        self.mgr
            .upgrade()
            .ok_or(MgrDisappearedError::RpcMgrDisappeared)
    }
}
/// Error for a request whose `meta.require` list is non-empty.
///
/// `run_method_lowlevel` rejects any request that requires features, so every
/// listed feature is reported as unsupported.
#[derive(Clone, Debug, thiserror::Error)]
#[error("Required features not available")]
struct MissingFeaturesError(
/// The names of the features that the request required.
Vec<String>,
);
impl From<MissingFeaturesError> for RpcError {
fn from(err: MissingFeaturesError) -> Self {
let mut e = RpcError::new(
err.to_string(),
tor_rpcbase::RpcErrorKind::FeatureNotPresent,
);
e.set_datum("rpc:unsupported_features".to_string(), err.0)
.expect("invalid keyword");
e
}
}
/// An error that ends an RPC connection's main loop.
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub enum ConnectionError {
/// An I/O error occurred while writing to the connection.
#[error("Could not write to connection")]
WriteFailed(#[source] Arc<IoError>),
/// An I/O error occurred while reading from the connection.
#[error("Problem reading from connection")]
ReadFailed(#[source] Arc<IoError>),
/// Incoming data could not be decoded as JSON.
#[error("Unable to decode request from connection")]
DecodeFailed(#[source] Arc<serde_json::Error>),
/// An outgoing response could not be encoded as JSON.
#[error("Unable to encode response onto connection")]
EncodeFailed(#[source] Arc<serde_json::Error>),
/// A request could not be parsed in a way that allowed the connection to continue
/// (for example, an invalid request without an ID).
#[error("Unrecoverable problem from parsed request")]
RequestParseFailed(#[from] RequestParseError),
}
impl ConnectionError {
fn writing(error: JsonCodecError) -> Self {
match error {
JsonCodecError::Io(e) => Self::WriteFailed(Arc::new(e)),
JsonCodecError::Json(e) => Self::EncodeFailed(Arc::new(e)),
}
}
fn is_connection_close(&self) -> bool {
use JsonErrorCategory as JK;
use std::io::ErrorKind as IK;
#[allow(clippy::match_like_matches_macro)]
match self {
Self::ReadFailed(e) | Self::WriteFailed(e) => match e.kind() {
IK::UnexpectedEof | IK::ConnectionAborted | IK::BrokenPipe => true,
_ => false,
},
Self::DecodeFailed(e) => match e.classify() {
JK::Eof => true,
_ => false,
},
_ => false,
}
}
fn from_read_error(error: JsonCodecError) -> Self {
match error {
JsonCodecError::Io(e) => Self::ReadFailed(Arc::new(e)),
JsonCodecError::Json(e) => Self::DecodeFailed(Arc::new(e)),
}
}
}
/// Error returned by `Connection::mgr` when the RPC manager no longer exists.
#[derive(Clone, Debug, thiserror::Error, serde::Serialize)]
pub(crate) enum MgrDisappearedError {
/// The weak reference to the `RpcMgr` could not be upgraded.
#[error("RPC manager disappeared; Arti is shutting down?")]
RpcMgrDisappeared,
}
impl tor_error::HasKind for MgrDisappearedError {
fn kind(&self) -> tor_error::ErrorKind {
tor_error::ErrorKind::ArtiShuttingDown
}
}
impl rpc::Context for Connection {
    /// Delegate to the inherent [`Connection::lookup_object`].
    fn lookup_object(&self, id: &rpc::ObjectId) -> Result<Arc<dyn rpc::Object>, rpc::LookupError> {
        Connection::lookup_object(self, id)
    }

    /// Store `object` as a strong reference in this connection's object map and
    /// return an ID for it.
    ///
    /// Objects that report `expose_outside_of_session()` get a global ID that
    /// embeds this connection's ID. Other objects get a plain local ID.
    fn register_owned(&self, object: Arc<dyn rpc::Object>) -> rpc::ObjectId {
        let wants_global_id = object.expose_outside_of_session();
        let idx = {
            let mut guard = self.inner.lock().expect("Lock poisoned");
            guard.objects.insert_strong(object)
        };
        if !wants_global_id {
            return idx.encode();
        }
        GlobalId::new(self.connection_id, idx).encode(&self.global_id_mac_key)
    }

    /// Release the strong reference that this connection holds for `id`.
    ///
    /// For the special connection ID, this drops the connection's
    /// self-reference. For any other ID, the index must refer to a strong
    /// entry.
    ///
    /// # Errors
    ///
    /// Returns `WrongType` for a non-strong index, `NoObject` if nothing was
    /// removed, and propagates ID decoding errors.
    fn release_owned(&self, id: &rpc::ObjectId) -> Result<(), rpc::LookupError> {
        let releasing_self = id.as_ref() == Self::CONNECTION_OBJ_ID;
        // Each removal is a single chained expression, so the removed value is
        // dropped before the lock guard.
        let removed_some = if releasing_self {
            self.inner
                .lock()
                .expect("Lock poisoned")
                .this_connection
                .take()
                .is_some()
        } else {
            let idx = self.id_into_local_idx(id)?;
            if !idx.is_strong() {
                return Err(rpc::LookupError::WrongType(id.clone()));
            }
            self.inner
                .lock()
                .expect("Lock poisoned")
                .objects
                .remove(idx)
                .is_some()
        };
        if !removed_some {
            return Err(rpc::LookupError::NoObject(id.clone()));
        }
        Ok(())
    }

    /// Return the table used to dispatch methods on this connection.
    fn dispatch_table(&self) -> &Arc<std::sync::RwLock<rpc::DispatchTable>> {
        &self.dispatch_table
    }
}
/// Marker error for a request that was cancelled before it completed.
///
/// Converted into an `RpcError` of kind `RequestCancelled` when the
/// cancellation is reported to the client.
#[derive(thiserror::Error, Clone, Debug, serde::Serialize)]
#[error("RPC request was cancelled")]
pub(crate) struct RequestCancelled;
impl From<RequestCancelled> for RpcError {
    /// Produce the RPC error reported to the client for a cancelled request.
    fn from(_cancelled: RequestCancelled) -> Self {
        let kind = rpc::RpcErrorKind::RequestCancelled;
        let message = String::from("Request cancelled");
        RpcError::new(message, kind)
    }
}
/// Reasons why `Connection::cancel_request` could not cancel a request.
#[derive(thiserror::Error, Clone, Debug, serde::Serialize)]
pub(crate) enum CancelError {
/// No in-flight request had the given ID, or the request had already finished.
#[error("RPC request not found")]
RequestNotFound,
/// The request exists but its method is not cancellable.
#[error("Uncancellable request")]
CannotCancelRequest,
/// The cancel handle reported a previous cancellation.
/// (Reported to the client as an internal error.)
#[error("Request somehow cancelled twice!")]
AlreadyCancelled,
}
impl From<cancel::CannotCancel> for CancelError {
fn from(value: cancel::CannotCancel) -> Self {
use CancelError as CE;
use cancel::CannotCancel as CC;
match value {
CC::Cancelled => CE::AlreadyCancelled,
CC::Finished => CE::RequestNotFound,
}
}
}
impl From<CancelError> for RpcError {
fn from(err: CancelError) -> Self {
use CancelError as CE;
use rpc::RpcErrorKind as REK;
let code = match err {
CE::RequestNotFound => REK::RequestError,
CE::CannotCancelRequest => REK::RequestError,
CE::AlreadyCancelled => REK::InternalError,
};
RpcError::new(err.to_string(), code)
}
}