use crate::tls::*;
use crate::*;
use std::collections::HashMap;
use std::net::IpAddr;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
/// Default for `Tx3RelayConfig::max_inbound_connections`: 20,480 accepted
/// TCP connections handled at once.
fn default_max_inbound_connections() -> u32 {
    20_480
}
/// Default for `Tx3RelayConfig::max_control_streams`: 320 control streams
/// across all peers.
fn default_max_control_streams() -> u32 {
    320_u32
}
/// Default for `Tx3RelayConfig::max_control_streams_per_ip`: at most 4
/// control streams from a single remote IP.
fn default_max_control_streams_per_ip() -> u32 {
    4_u32
}
/// Default for `Tx3RelayConfig::max_relays_per_control`: 64 pending or
/// active relays per control stream.
fn default_max_relays_per_control() -> u32 {
    64_u32
}
/// Default for `Tx3RelayConfig::connection_timeout_ms`: 20 seconds,
/// expressed in milliseconds.
fn default_connection_timeout_ms() -> u32 {
    const SECONDS: u32 = 20;
    SECONDS * 1_000
}
/// Configuration for a [`Tx3Relay`] server.
///
/// Serialized with camelCase field names. The embedded [`Tx3Config`] is
/// flattened into the same object, and every relay limit that is missing
/// during deserialization falls back to its `default_*` function.
#[derive(serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct Tx3RelayConfig {
    /// Base tx3 settings (bind urls, TLS, ...). Also reachable directly on
    /// this type through `Deref` / `DerefMut`.
    #[serde(flatten)]
    pub tx3_config: Tx3Config,

    /// Maximum number of accepted TCP connections being handled at once;
    /// connections accepted beyond this are dropped immediately.
    /// Default: 20480.
    #[serde(default = "default_max_inbound_connections")]
    pub max_inbound_connections: u32,

    /// Maximum number of control streams across all peers. Default: 320.
    #[serde(default = "default_max_control_streams")]
    pub max_control_streams: u32,

    /// Maximum number of control streams from a single remote IP.
    /// Default: 4.
    #[serde(default = "default_max_control_streams_per_ip")]
    pub max_control_streams_per_ip: u32,

    /// Maximum number of relays (waiting or spliced) for one control
    /// stream. Default: 64.
    #[serde(default = "default_max_relays_per_control")]
    pub max_relays_per_control: u32,

    /// Deadline in milliseconds, measured from when a connection starts
    /// being processed. It covers reading the connection's token and, for
    /// control streams, the TLS accept. It also bounds how long a parked
    /// relay waits for its counterpart. Default: 20000.
    #[serde(default = "default_connection_timeout_ms")]
    pub connection_timeout_ms: u32,
}
impl Default for Tx3RelayConfig {
    /// A default [`Tx3Config`] with every relay limit set to the same value
    /// serde uses when the field is absent.
    fn default() -> Self {
        let tx3_config = Tx3Config::default();
        Self {
            tx3_config,
            max_inbound_connections: default_max_inbound_connections(),
            max_control_streams: default_max_control_streams(),
            max_control_streams_per_ip: default_max_control_streams_per_ip(),
            max_relays_per_control: default_max_relays_per_control(),
            connection_timeout_ms: default_connection_timeout_ms(),
        }
    }
}
impl Tx3RelayConfig {
pub fn new() -> Self {
Tx3RelayConfig::default()
}
pub fn with_bind<B>(mut self, bind: B) -> Self
where
B: Into<Tx3Url>,
{
self.tx3_config.bind.push(bind.into());
self
}
}
/// Exposes the embedded [`Tx3Config`] fields (such as `bind` and `tls`)
/// directly on the relay config.
impl std::ops::Deref for Tx3RelayConfig {
    type Target = Tx3Config;
    fn deref(&self) -> &Self::Target {
        let Self { tx3_config, .. } = self;
        tx3_config
    }
}
/// Mutable access to the embedded [`Tx3Config`] fields.
impl std::ops::DerefMut for Tx3RelayConfig {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { tx3_config, .. } = self;
        tx3_config
    }
}
/// A running relay server, created by [`Tx3Relay::new`].
///
/// Dropping it notifies the listener tasks so their accept loops stop.
/// Connections already handed off to their own tasks are not affected.
pub struct Tx3Relay {
    /// Shared relay configuration; its bind list has already been consumed.
    config: Arc<Tx3RelayConfig>,
    /// Urls of the listeners that were actually bound.
    addrs: Vec<Tx3Url>,
    /// Shutdown signal shared with every listener task.
    shutdown: Arc<tokio::sync::Notify>,
}
impl Drop for Tx3Relay {
fn drop(&mut self) {
self.shutdown.notify_waiters();
}
}
impl Tx3Relay {
    /// Start a relay server bound to every url in `config.bind`.
    ///
    /// When `config.tls` is `None`, a TLS config is generated with
    /// `TlsConfigBuilder::default()`. Only `tx3-rst` urls are supported.
    /// Each url is resolved to socket addresses, and every address gets its
    /// own listener. All listeners share:
    /// - one relay state;
    /// - the `max_inbound_connections` and `max_control_streams` semaphores;
    /// - one shutdown signal, notified when the returned value is dropped.
    ///
    /// # Errors
    ///
    /// Fails if TLS generation, address resolution, or any bind fails, or
    /// if a bind url uses a scheme other than `tx3-rst`.
    pub async fn new(mut config: Tx3RelayConfig) -> Result<Self> {
        let shutdown = Arc::new(tokio::sync::Notify::new());

        if config.tls.is_none() {
            let generated = TlsConfigBuilder::default().build()?;
            config.tls = Some(generated);
        }

        let state = RelayStateSync::new();
        let mut listener_futs = Vec::new();

        // Take the requested urls out of the config; the urls actually bound
        // are reported through `local_addrs` instead.
        let requested: Vec<_> = config.bind.drain(..).collect();
        let config = Arc::new(config);

        let inbound_limit = Arc::new(tokio::sync::Semaphore::new(
            config.max_inbound_connections as usize,
        ));
        let control_limit = Arc::new(tokio::sync::Semaphore::new(
            config.max_control_streams as usize,
        ));

        for url in requested {
            match url.scheme() {
                Tx3Scheme::Tx3rst => {
                    for sock_addr in url.socket_addrs().await? {
                        let fut = bind_tx3_rst(
                            config.clone(),
                            sock_addr,
                            state.clone(),
                            inbound_limit.clone(),
                            control_limit.clone(),
                            shutdown.clone(),
                        );
                        listener_futs.push(fut);
                    }
                }
                unsupported => {
                    return Err(other_err(format!(
                        "Unsupported Scheme: {}",
                        unsupported.as_str()
                    )));
                }
            }
        }

        let bound_per_listener = futures::future::try_join_all(listener_futs).await?;
        let addrs: Vec<Tx3Url> = bound_per_listener.into_iter().flatten().collect();

        tracing::info!(?addrs, "relay running");

        Ok(Self {
            config,
            addrs,
            shutdown,
        })
    }

    /// Digest of the TLS certificate this relay presents.
    pub fn local_tls_cert_digest(&self) -> &TlsCertDigest {
        let tls = self.config.priv_tls();
        tls.cert_digest()
    }

    /// Urls the relay is reachable at, one entry per bound address.
    pub fn local_addrs(&self) -> &[Tx3Url] {
        self.addrs.as_slice()
    }
}
async fn bind_tx3_rst(
config: Arc<Tx3RelayConfig>,
addr: SocketAddr,
state: RelayStateSync,
inbound_limit: Arc<tokio::sync::Semaphore>,
control_limit: Arc<tokio::sync::Semaphore>,
shutdown: Arc<tokio::sync::Notify>,
) -> Result<Vec<Tx3Url>> {
let listener = tokio::net::TcpListener::bind(addr).await?;
let addr = listener.local_addr()?;
let mut out = Vec::new();
for a in upgrade_addr(addr)? {
out.push(Tx3Url::new(
url::Url::parse(&format!(
"tx3-rst://{}/{}",
a,
config.priv_tls().cert_digest().to_b64(),
))
.map_err(other_err)?,
));
}
tokio::task::spawn(async move {
loop {
let res = tokio::select! {
biased;
_ = shutdown.notified() => break,
r = listener.accept() => r,
};
match res {
Err(err) => {
tracing::warn!(?err, "accept error");
}
Ok((socket, addr)) => {
let ip = addr.ip();
let con_permit = match inbound_limit
.clone()
.try_acquire_owned()
{
Err(_) => {
tracing::warn!("Dropping incoming connection, max_inbound_connections reached");
continue;
}
Ok(con_permit) => con_permit,
};
let socket = match crate::tcp::tx3_tcp_configure(socket) {
Err(err) => {
tracing::warn!(?err, "tcp_configure error");
continue;
}
Ok(socket) => socket,
};
tokio::task::spawn(process_socket(
config.clone(),
socket,
ip,
state.clone(),
con_permit,
control_limit.clone(),
));
}
}
}
});
Ok(out)
}
/// Commands delivered to a control stream's task over its channel.
enum ControlCmd {
    /// A relay connection is parked under this splice token; the control
    /// peer is told the token so it can claim the relay.
    NotifyPending(TlsCertDigest),
}
/// Registration of one live control stream in the relay state.
#[derive(Clone)]
struct ControlInfo {
    /// Sends commands to the task driving this control stream.
    ctrl_send: tokio::sync::mpsc::Sender<ControlCmd>,
    /// Caps how many relays may be pending or spliced for this control
    /// stream at once.
    relay_limit: Arc<tokio::sync::Semaphore>,
}
/// An incoming relay connection parked until its counterpart connects with
/// the matching splice token.
struct PendingInfo {
    /// The parked TCP connection.
    socket: tokio::net::TcpStream,
    /// Permit from the owning control stream's `relay_limit`; held until the
    /// relay is dropped or finishes splicing.
    relay_permit: tokio::sync::OwnedSemaphorePermit,
}
/// Mutable relay bookkeeping, shared between connection tasks through
/// `RelayStateSync`.
#[derive(Default)]
struct RelayState {
    /// Live control streams, keyed by the remote's certificate digest.
    control_channels: HashMap<TlsCertDigest, ControlInfo>,
    /// Number of open control streams per remote IP.
    control_addrs: HashMap<IpAddr, u32>,
    /// Parked relay connections, keyed by their splice token.
    pending_tokens: HashMap<TlsCertDigest, PendingInfo>,
}
impl RelayState {
    /// Empty state: no control streams, no per-IP counts, no pending relays.
    fn new() -> Self {
        Self::default()
    }
}
/// Cloneable handle to the shared `RelayState`, guarded by a
/// `parking_lot` mutex.
#[derive(Clone)]
struct RelayStateSync(Arc<parking_lot::Mutex<RelayState>>);
impl RelayStateSync {
    /// Create a handle to a fresh, empty state.
    fn new() -> Self {
        let guarded = parking_lot::Mutex::new(RelayState::new());
        Self(Arc::new(guarded))
    }

    /// Run `f` with exclusive access to the state and return its result.
    ///
    /// The lock is held for the whole call, so `f` should be short and must
    /// not call `access` on this state again (the mutex is not reentrant).
    fn access<R, F>(&self, f: F) -> R
    where
        F: FnOnce(&mut RelayState) -> R,
    {
        let mut guard = self.0.lock();
        let locked: &mut RelayState = &mut guard;
        f(locked)
    }
}
async fn process_socket(
config: Arc<Tx3RelayConfig>,
socket: tokio::net::TcpStream,
ip: IpAddr,
state: RelayStateSync,
con_permit: tokio::sync::OwnedSemaphorePermit,
control_limit: Arc<tokio::sync::Semaphore>,
) {
if let Err(err) =
process_socket_err(config, socket, ip, state, con_permit, control_limit)
.await
{
tracing::debug!(?err, "process_socket error");
}
}
/// Handle one accepted TCP connection whose inbound permit is `con_permit`.
///
/// Every connection first sends a 32-byte token, which must arrive before
/// `timeout` (`connection_timeout_ms` after this function starts). The token
/// selects the role of the connection:
///
/// - Equal to this relay's own cert digest: the peer wants a control stream.
///   A `control_limit` permit is required; without one the connection is
///   dropped and `Ok(())` is returned. Handling continues in
///   `process_relay_control`.
/// - Equal to a registered control stream's cert digest: an incoming relay
///   for that peer. If the control's `relay_limit` has a free permit:
///   - the socket is parked in `pending_tokens` under a fresh random splice
///     token;
///   - the control stream is sent `NotifyPending(splice_token)`;
///   - a cleanup task removes the parked entry once `timeout` passes.
/// - Equal to a currently parked splice token: the other half of a relay.
///   The parked socket is taken out and both sockets are spliced with
///   `copy_bidirectional`.
/// - Anything else (or no free relay permit): the connection is dropped.
///
/// NOTE(review): `con_permit` is released when this function returns. A
/// socket parked in `pending_tokens` therefore no longer counts against
/// `max_inbound_connections` while it waits (it still holds a relay permit).
///
/// # Errors
///
/// Returns an error on token read timeout or failure, random generation
/// failure, errors from `process_relay_control`, or I/O errors while
/// splicing.
async fn process_socket_err(
config: Arc<Tx3RelayConfig>,
mut socket: tokio::net::TcpStream,
ip: IpAddr,
state: RelayStateSync,
con_permit: tokio::sync::OwnedSemaphorePermit,
control_limit: Arc<tokio::sync::Semaphore>,
) -> Result<()> {
// One deadline for the handshake phase of this connection; also reused
// for the control stream's TLS accept and for expiring a parked relay.
let timeout = tokio::time::Instant::now()
+ std::time::Duration::from_millis(config.connection_timeout_ms as u64);
let this_cert = config.priv_tls().cert_digest().clone();
let mut token = [0; 32];
tokio::time::timeout_at(timeout, socket.read_exact(&mut token[..]))
.await??;
let token = TlsCertDigest(Arc::new(token));
// Token is our own cert digest: this is a control stream request.
if token == this_cert {
let control_permit = match control_limit.try_acquire_owned() {
Err(_) => {
tracing::warn!(
"Dropping incoming connection, max_control_streams reached"
);
return Ok(());
}
Ok(control_permit) => control_permit,
};
return process_relay_control(
config,
socket,
ip,
state,
timeout,
con_permit,
control_permit,
)
.await;
}
// Random token that will identify this relay if `token` names a
// registered control stream (generated before the lookup below).
let mut splice_token = [0; 32];
use ring::rand::SecureRandom;
ring::rand::SystemRandom::new()
.fill(&mut splice_token[..])
.map_err(|_| other_err("SystemRandomFailure"))?;
let splice_token = TlsCertDigest(Arc::new(splice_token));
// Outcome of the state lookup, decided while the state lock is held.
enum TokenRes {
HaveControl(tokio::sync::mpsc::Sender<ControlCmd>),
HaveConToken(
tokio::net::TcpStream,
tokio::net::TcpStream,
tokio::sync::OwnedSemaphorePermit,
),
Drop,
}
let res = {
let splice_token = splice_token.clone();
// `socket` is moved into this closure. It is parked in
// `pending_tokens`, returned in `HaveConToken`, or dropped here along
// with the closure.
state.access(move |state| {
if let Some(info) = state.control_channels.get(&token) {
let relay_permit = match info.relay_limit.clone().try_acquire_owned() {
Err(_) => {
tracing::debug!("Dropping incoming connection, max_relays_per_control reached");
return TokenRes::Drop;
}
Ok(permit) => permit,
};
tracing::debug!(cert = ?token, splice = ?splice_token, "incoming pending relay");
state.pending_tokens.insert(splice_token, PendingInfo {
socket,
relay_permit,
});
TokenRes::HaveControl(info.ctrl_send.clone())
} else if let Some(info) = state.pending_tokens.remove(&token) {
tracing::debug!(splice = ?token, "incoming relay fulfill");
TokenRes::HaveConToken(socket, info.socket, info.relay_permit)
} else {
TokenRes::Drop
}
})
};
match res {
TokenRes::Drop => (), TokenRes::HaveControl(ctrl_send) => {
// Expire the parked socket at the deadline if no counterpart
// has claimed it by then.
{
let state = state.clone();
let splice_token = splice_token.clone();
tokio::task::spawn(async move {
tokio::time::sleep_until(timeout).await;
state.access(move |state| {
state.pending_tokens.remove(&splice_token);
});
});
}
// A send error (control stream gone) is ignored; the parked entry
// then just expires via the task above.
let _ = ctrl_send
.send(ControlCmd::NotifyPending(splice_token))
.await;
}
TokenRes::HaveConToken(mut socket, mut socket2, relay_permit) => {
// Keep the control's relay permit for as long as the splice runs.
let _relay_permit = relay_permit;
tokio::io::copy_bidirectional(&mut socket, &mut socket2).await?;
}
}
drop(con_permit);
Ok(())
}
/// Handle a connection whose token matched this relay's own cert digest.
///
/// Steps:
/// 1. Complete the TLS accept before `timeout`.
/// 2. Register the peer as a control stream, keyed by its certificate
///    digest and limited by `max_control_streams_per_ip`.
/// 3. Drive the stream with `process_relay_control_inner`.
/// 4. Undo the registration when it ends.
///
/// A new control stream for an already-registered certificate replaces the
/// old entry. Dropping the old entry's sender lets the old stream wind down.
/// On cleanup, an ended stream only removes the `control_channels` entry if
/// that entry is still its own. Previously an old, replaced stream would
/// remove (and thereby shut down) its replacement.
///
/// # Errors
///
/// Returns an error if the TLS accept fails or times out, if the per-IP
/// control limit is exceeded, or if the control stream itself errors.
async fn process_relay_control(
    config: Arc<Tx3RelayConfig>,
    socket: tokio::net::TcpStream,
    ip: IpAddr,
    state: RelayStateSync,
    timeout: tokio::time::Instant,
    con_permit: tokio::sync::OwnedSemaphorePermit,
    control_permit: tokio::sync::OwnedSemaphorePermit,
) -> Result<()> {
    let socket = tokio::time::timeout_at(
        timeout,
        Tx3Connection::priv_accept(config.priv_tls().clone(), socket),
    )
    .await??;
    let remote_cert = socket.remote_tls_cert_digest().clone();
    tracing::debug!(cert = ?remote_cert, "control stream established");

    let (ctrl_send, ctrl_recv) = tokio::sync::mpsc::channel(1);

    let relay_limit = Arc::new(tokio::sync::Semaphore::new(
        config.max_relays_per_control as usize,
    ));
    // Identity of *this* registration, so cleanup can tell whether the map
    // entry is still ours. A clone of `ctrl_send` cannot serve here: holding
    // it would keep the channel open and stop a replaced stream from ending.
    let our_relay_limit = relay_limit.clone();

    {
        let remote_cert = remote_cert.clone();
        let max_per_ip = config.max_control_streams_per_ip;
        state.access(move |state| {
            // Check before inserting, so a rejected IP never leaves a
            // zero-count entry behind.
            let ip_count = state.control_addrs.get(&ip).copied().unwrap_or(0);
            if ip_count >= max_per_ip {
                tracing::warn!("Dropping incoming connection, max_control_streams_per_ip reached");
                return Err(other_err("ExceededMaxCtrlPerIp"));
            }
            state.control_addrs.insert(ip, ip_count + 1);

            let info = ControlInfo {
                ctrl_send,
                relay_limit,
            };
            if state
                .control_channels
                .insert(remote_cert.clone(), info)
                .is_some()
            {
                tracing::debug!(
                    cert = ?remote_cert,
                    "replaced existing control stream",
                );
            }
            Ok(())
        })?;
    }

    let res =
        process_relay_control_inner(socket, state.clone(), ctrl_recv).await;

    state.access(|state| {
        let ip_now_unused = match state.control_addrs.get_mut(&ip) {
            Some(ip_count) => {
                *ip_count = ip_count.saturating_sub(1);
                *ip_count == 0
            }
            None => false,
        };
        if ip_now_unused {
            state.control_addrs.remove(&ip);
        }

        // Only unregister if the entry is still ours; a newer control
        // stream for the same certificate may have replaced it.
        let still_ours = state
            .control_channels
            .get(&remote_cert)
            .map_or(false, |info| Arc::ptr_eq(&info.relay_limit, &our_relay_limit));
        if still_ours {
            state.control_channels.remove(&remote_cert);
        }
    });

    tracing::debug!(cert = ?remote_cert, "control stream ended");
    drop(con_permit);
    drop(control_permit);
    res
}
/// Drive an established control stream until either side ends it.
///
/// Writes and flushes a single `0` byte to signal readiness, then runs two
/// tasks concurrently:
/// - Read half: the remote is not expected to send anything. A read error
///   or EOF ends the stream normally; any received byte is a protocol
///   violation (`ControlReadData`).
/// - Command loop: each `ControlCmd::NotifyPending` splice token is written
///   to the remote and flushed right away, so it is not left sitting in a
///   write buffer while the relay waits.
///
/// Returns with the result of whichever task finishes first. The command loop
/// ends cleanly once every sender for `ctrl_recv` has been dropped. A write
/// error is now returned instead of being silently turned into `Ok(())`.
async fn process_relay_control_inner(
    mut socket: Tx3Connection,
    _state: RelayStateSync,
    mut ctrl_recv: tokio::sync::mpsc::Receiver<ControlCmd>,
) -> Result<()> {
    socket.write_all(&[0]).await?;
    socket.flush().await?;
    let (mut read, mut write) = tokio::io::split(socket);
    tokio::select! {
        r = async move {
            let mut buf = [0];
            if read.read_exact(&mut buf).await.is_err() {
                Ok(())
            } else {
                Err(other_err("ControlReadData"))
            }
        } => r,
        r = async move {
            while let Some(cmd) = ctrl_recv.recv().await {
                match cmd {
                    ControlCmd::NotifyPending(splice_token) => {
                        write.write_all(&splice_token[..]).await?;
                        write.flush().await?;
                    }
                }
            }
            Result::Ok(())
        } => r,
    }
}