use std::sync::{Arc, Mutex};
use futures::io::BufReader;
use futures::{
AsyncRead, AsyncWrite, Future, FutureExt as _, Stream, StreamExt as _, select_biased,
};
use itertools::iproduct;
use oneshot_fused_workaround as oneshot;
use safelog::sensitive as sv;
use std::collections::HashMap;
use std::io::Error as IoError;
use strum::IntoEnumIterator;
use tor_cell::relaycell::msg as relaymsg;
use tor_error::{ErrorKind, HasKind, debug_report};
use tor_hsservice::{HsNickname, RendRequest, StreamRequest};
use tor_log_ratelim::log_ratelim;
use tor_proto::client::stream::{DataStream, IncomingStreamRequest};
use tor_rtcompat::{Runtime, SpawnExt as _};
use crate::config::{
Encapsulation, ProxyAction, ProxyActionDiscriminants, ProxyConfig, TargetAddr,
};
/// A reverse proxy that answers incoming onion service stream requests
/// according to a [`ProxyConfig`].
///
/// All mutable state lives behind a single mutex, so one proxy can be shared
/// (`new` hands it out inside an `Arc`) and reconfigured or shut down while
/// requests are being handled.
#[derive(Debug)]
pub struct OnionServiceReverseProxy {
    /// The proxy's configuration together with its shutdown channel.
    state: Mutex<State>,
}
/// The mutable interior of an [`OnionServiceReverseProxy`], guarded by its mutex.
#[derive(Debug)]
struct State {
    /// The rules used to decide what to do with each incoming stream request.
    config: ProxyConfig,
    /// Sender half of the shutdown channel.
    ///
    /// Nothing can ever be sent on it (its item type is `void::Void`); taking
    /// and dropping it is what signals shutdown.  `None` once shutdown was requested.
    shutdown_tx: Option<oneshot::Sender<void::Void>>,
    /// Shareable receiver half of the shutdown channel; each request-handling
    /// loop waits on its own clone of this future.
    shutdown_rx: futures::future::Shared<oneshot::Receiver<void::Void>>,
}
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub enum HandleRequestsError {
#[error("Unable to spawn a task")]
Spawn(#[source] Arc<futures::task::SpawnError>),
}
impl HasKind for HandleRequestsError {
    /// Classify this error by delegating to the wrapped spawn failure.
    fn kind(&self) -> ErrorKind {
        match self {
            Self::Spawn(spawn_err) => spawn_err.kind(),
        }
    }
}
impl OnionServiceReverseProxy {
    /// Create a new proxy that applies the rules in `config`.
    ///
    /// The proxy starts out running; call [`shutdown`](Self::shutdown) to make
    /// every `handle_requests` loop return.
    pub fn new(config: ProxyConfig) -> Arc<Self> {
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        Arc::new(Self {
            state: Mutex::new(State {
                config,
                shutdown_tx: Some(shutdown_tx),
                shutdown_rx: shutdown_rx.shared(),
            }),
        })
    }

    /// Replace this proxy's configuration with `config`.
    ///
    /// Every configuration change is allowed.  When `how` is
    /// `Reconfigure::CheckAllOrNothing`, this returns `Ok(())` and changes nothing.
    /// Otherwise the new configuration is stored.  Each incoming request reads
    /// the stored configuration when it arrives, so the change applies to
    /// requests received from then on.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn reconfigure(
        &self,
        config: ProxyConfig,
        how: tor_config::Reconfigure,
    ) -> Result<(), tor_config::ReconfigureError> {
        if how == tor_config::Reconfigure::CheckAllOrNothing {
            // Every possible reconfiguration is acceptable, so checking is a no-op.
            return Ok(());
        }
        let mut state = self.state.lock().expect("poisoned lock");
        state.config = config;
        Ok(())
    }

    /// Ask every `handle_requests` loop on this proxy to stop.
    ///
    /// This drops the shutdown sender.  That completes the shared shutdown
    /// receiver, so each loop (including loops started later) returns `Ok(())`.
    /// Calling this more than once is harmless.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn shutdown(&self) {
        let mut state = self.state.lock().expect("poisoned lock");
        let _ = state.shutdown_tx.take();
    }

    /// Handle the rendezvous requests in `requests` until shutdown.
    ///
    /// Each stream request is resolved to a [`ProxyAction`] using the current
    /// configuration, and that action runs in its own task spawned on `runtime`.
    /// Failures of individual actions are logged (rate-limited) but do not stop
    /// the loop.
    ///
    /// Returns `Ok(())` when [`shutdown`](Self::shutdown) is called or when the
    /// stream of requests ends.
    ///
    /// # Errors
    ///
    /// Returns [`HandleRequestsError::Spawn`] if a task cannot be spawned.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub async fn handle_requests<R, S>(
        &self,
        runtime: R,
        nickname: HsNickname,
        requests: S,
    ) -> Result<(), HandleRequestsError>
    where
        R: Runtime,
        S: Stream<Item = RendRequest> + Unpin,
    {
        let mut stream_requests = tor_hsservice::handle_rend_requests(requests).fuse();
        let mut shutdown_rx = self
            .state
            .lock()
            .expect("poisoned lock")
            .shutdown_rx
            .clone()
            .fuse();
        let nickname = Arc::new(nickname);

        /// Which metrics counter (total / succeeded / failed) to bump for an action.
        #[cfg(feature = "metrics")]
        #[derive(Clone, Copy, Eq, PartialEq, Hash)]
        enum CounterSelector {
            Ret(Result<(), ()>),
            Total,
        }

        // Register one counter for every (action kind, outcome) pair up front,
        // so the per-request tasks only need to look them up.
        #[cfg(feature = "metrics")]
        let metrics_counters = {
            use CounterSelector as CS;
            let counters = iproduct!(
                ProxyActionDiscriminants::iter(),
                [
                    (CS::Total, "arti_hss_proxy_connections_total"),
                    (CS::Ret(Ok(())), "arti_hss_proxy_connections_ok_total"),
                    (CS::Ret(Err(())), "arti_hss_proxy_connections_failed_total"),
                ],
            )
            .map(|(action, (outcome, name))| {
                let k = (action, outcome);
                let nickname = nickname.to_string();
                let action: &str = action.into();
                let v = metrics::counter!(name, "nickname" => nickname, "action" => action);
                (k, v)
            })
            .collect::<HashMap<(ProxyActionDiscriminants, CounterSelector), _>>();
            Arc::new(counters)
        };

        loop {
            // Check for shutdown first, so it wins over any pending request.
            let stream_request = select_biased! {
                _ = shutdown_rx => return Ok(()),
                stream_request = stream_requests.next() => match stream_request {
                    None => return Ok(()),
                    Some(s) => s,
                }
            };

            runtime.spawn({
                let action = self.choose_action(stream_request.request());
                let runtime = runtime.clone();
                let nickname = nickname.clone();
                // Keep a copy of the request for the failure log message, since
                // `stream_request` itself is consumed by `run_action`.
                let req = stream_request.request().clone();
                #[cfg(feature = "metrics")]
                let metrics_counters = metrics_counters.clone();
                async move {
                    let outcome =
                        run_action(runtime, nickname.as_ref(), action.clone(), stream_request).await;

                    #[cfg(feature = "metrics")]
                    {
                        use CounterSelector as CS;
                        let action = ProxyActionDiscriminants::from(&action);
                        let outcome = outcome.as_ref().map(|_| ()).map_err(|_| ());
                        for outcome in [CS::Total, CS::Ret(outcome)] {
                            // Every (action, outcome) pair was registered above, so
                            // this lookup should always succeed.  If it somehow
                            // does not, skip the counter rather than panicking.
                            if let Some(counter) = metrics_counters.get(&(action, outcome)) {
                                counter.increment(1);
                            }
                        }
                    }

                    log_ratelim!(
                        "Performing action on {}", nickname;
                        outcome;
                        Err(_) => WARN, "Unable to take action {:?} for request {:?}", sv(action), sv(req)
                    );
                }
            })
            .map_err(|e| HandleRequestsError::Spawn(Arc::new(e)))?;
        }
    }

    /// Decide what to do with `stream_request` according to the current configuration.
    ///
    /// Only `BEGIN` requests are expected here.  Any other request kind is
    /// logged as an internal error and answered with
    /// [`ProxyAction::DestroyCircuit`].  A port with no matching rule also
    /// gets `DestroyCircuit`.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    fn choose_action(&self, stream_request: &IncomingStreamRequest) -> ProxyAction {
        let port: u16 = match stream_request {
            IncomingStreamRequest::Begin(begin) => {
                // Only the port matters; the address in the BEGIN message is ignored.
                begin.port()
            }
            other => {
                tracing::warn!(
                    "Rejecting onion service request for invalid command {:?}. Internal error.",
                    other
                );
                return ProxyAction::DestroyCircuit;
            }
        };
        self.state
            .lock()
            .expect("poisoned lock")
            .config
            .resolve_port_for_begin(port)
            .cloned()
            // With no matching rule, the default is to destroy the circuit.
            .unwrap_or(ProxyAction::DestroyCircuit)
    }
}
/// Carry out `action` for a single onion service stream `request`.
///
/// * `IgnoreStream` drops the request without replying.
/// * `RejectStream` answers with an `END` message (reason `DONE`).
/// * `DestroyCircuit` tears down the circuit the request arrived on.
/// * `Forward` connects to the configured target via `runtime` and relays
///   data between it and the onion service stream.
///
/// # Errors
///
/// Returns a [`RequestFailed`] describing whichever step could not be completed.
async fn run_action<R: Runtime>(
    runtime: R,
    nickname: &HsNickname,
    action: ProxyAction,
    request: StreamRequest,
) -> Result<(), RequestFailed> {
    match action {
        ProxyAction::IgnoreStream => {
            // Drop the request without sending any reply.
            drop(request);
        }
        ProxyAction::RejectStream => {
            let end_msg = relaymsg::End::new_with_reason(relaymsg::EndReason::DONE);
            request
                .reject(end_msg)
                .await
                .map_err(RequestFailed::CantReject)?;
        }
        ProxyAction::DestroyCircuit => {
            request
                .shutdown_circuit()
                .map_err(RequestFailed::CantDestroy)?;
        }
        ProxyAction::Forward(encap, target) => match (encap, &target) {
            (Encapsulation::Simple, TargetAddr::Inet(sock_addr)) => {
                let connecting = runtime.connect(sock_addr);
                forward_connection(runtime.clone(), request, connecting, nickname, &target)
                    .await?;
            }
        },
    }
    Ok(())
}
#[derive(thiserror::Error, Debug, Clone)]
enum RequestFailed {
#[error("Unable to destroy onion service circuit")]
CantDestroy(#[source] tor_error::Bug),
#[error("Unable to reject onion service request")]
CantReject(#[source] tor_hsservice::ClientError),
#[error("Unable to accept onion service connection")]
AcceptRemote(#[source] tor_hsservice::ClientError),
#[error("Unable to spawn task")]
Spawn(#[source] Arc<futures::task::SpawnError>),
}
impl HasKind for RequestFailed {
    /// Classify this failure by delegating to the wrapped underlying error.
    fn kind(&self) -> ErrorKind {
        match self {
            Self::CantDestroy(bug) => bug.kind(),
            Self::CantReject(client_err) | Self::AcceptRemote(client_err) => client_err.kind(),
            Self::Spawn(spawn_err) => spawn_err.kind(),
        }
    }
}
/// Capacity, in bytes, of the read buffer on each side of a forwarded connection.
const STREAM_BUF_LEN: usize = 4 * 1024;
/// Relay traffic between an onion service stream `request` and a local target.
///
/// `target_stream_future` resolves to the connection to the local target
/// (`addr`).  Its outcome is logged (rate-limited).  If connecting failed, the
/// client's request is rejected with an `END` message (reason `DONE`) and the
/// function returns `Ok(())`.  Otherwise the request is accepted, and a task on
/// `runtime` copies data in both directions, closing each side at EOF.
///
/// # Errors
///
/// * [`RequestFailed::CantReject`] if the rejection could not be sent.
/// * [`RequestFailed::AcceptRemote`] if the request could not be accepted.
/// * [`RequestFailed::Spawn`] if the copying task could not be spawned.
async fn forward_connection<R, FUT, TS>(
    runtime: R,
    request: StreamRequest,
    target_stream_future: FUT,
    nickname: &HsNickname,
    addr: &TargetAddr,
) -> Result<(), RequestFailed>
where
    R: Runtime,
    FUT: Future<Output = Result<TS, IoError>>,
    TS: AsyncRead + AsyncWrite + Send + 'static,
{
    // The error is wrapped in an `Arc` before being handed to the rate-limited logger.
    let connect_result = target_stream_future.await.map_err(Arc::new);
    log_ratelim!(
        "Connecting to {} for onion service {}", sv(addr), nickname;
        connect_result
    );

    let Ok(target_stream) = connect_result else {
        // The connection error was already logged above; only the rejection's
        // own failure (if any) is reported here.
        let end_msg = relaymsg::End::new_with_reason(relaymsg::EndReason::DONE);
        return match request.reject(end_msg).await {
            Ok(_) => Ok(()),
            Err(reject_err) => {
                debug_report!(
                    &reject_err,
                    "Unable to reject onion service request from client"
                );
                Err(RequestFailed::CantReject(reject_err))
            }
        };
    };

    let accepted: DataStream = request
        .accept(relaymsg::Connected::new_empty())
        .await
        .map_err(RequestFailed::AcceptRemote)?;

    let relay_task = futures_copy::copy_buf_bidirectional(
        BufReader::with_capacity(STREAM_BUF_LEN, accepted),
        BufReader::with_capacity(STREAM_BUF_LEN, target_stream),
        futures_copy::eof::Close,
        futures_copy::eof::Close,
    )
    .map(|_| ());

    runtime
        .spawn(relay_task)
        .map_err(|e| RequestFailed::Spawn(Arc::new(e)))
}