#![allow(clippy::similar_names)]
use std::net::ToSocketAddrs;
use std::os::unix::io::AsRawFd;
use std::pin::Pin;
use std::sync::Arc;
use anyhow::anyhow;
use anyhow::{Context, Result, bail};
use clap::Parser;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::UnixDatagram;
use tracing::{debug, error, info, trace, warn};
use tarweb::sock;
/// Protobuf types for the router configuration (`SniConfig`, `Backend`, …), generated at
/// build time into `$OUT_DIR/tarweb.rs`. Presumably written by a prost build script —
/// confirm against build.rs.
pub mod protos {
include!(concat!(env!("OUT_DIR"), "/tarweb.rs"));
}
/// Initial capacity of the buffers used while reading the ClientHello; they grow past
/// this when a client sends more.
const BUF_CAPACITY: usize = 2048;
// Command-line flags for the SNI router. (Plain `//` comments on purpose: clap turns `///`
// doc comments into --help text.)
#[derive(clap::Parser)]
struct Opt {
// Filter directive handed to `tracing_subscriber`'s env filter (e.g. "info", "debug").
#[arg(long, short, default_value = "info")]
verbose: String,
// Address to accept client connections on.
#[arg(long, short, default_value = "[::]:443")]
listen: std::net::SocketAddr,
// Directories passed to `tarweb::privs::sni_drop` when dropping privileges; presumably the
// only paths still reachable afterwards — TODO confirm.
#[arg(long, default_value = "/")]
restrict_dirs: Vec<std::path::PathBuf>,
// Path of the text-format `tarweb.SNIConfig` file; loaded at startup and again on SIGHUP.
#[arg(long, short)]
config: String,
}
/// Builds the rustls server configuration used to terminate TLS in front of a backend.
///
/// Returns `Ok(None)` when the backend has no `frontend_tls` section, in which case the
/// client's TLS stream is forwarded untouched. Only TLS 1.3 is offered and client
/// certificates are not requested. Secret extraction is switched on because, after the
/// handshake, the session keys are moved into kernel TLS.
///
/// # Errors
/// Fails if the certificate chain or private key cannot be loaded, or if rustls rejects
/// the pair.
#[allow(clippy::unnecessary_wraps)]
fn load_tls(cfg: Option<&protos::backend::Tls>) -> Result<Option<Arc<rustls::ServerConfig>>> {
    match cfg {
        None => Ok(None),
        Some(tls) => {
            let cert_chain = tarweb::load_certs(&tls.cert_file)?;
            let private_key = tarweb::load_private_key(&tls.key_file)?;
            let mut server_config =
                rustls::ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13])
                    .with_no_client_auth()
                    .with_single_cert(cert_chain, private_key)?;
            server_config.enable_secret_extraction = true;
            Ok(Some(Arc::new(server_config)))
        }
    }
}
/// Converts one backend entry from the protobuf config into a [`Backend`].
///
/// * `be` — the kind of backend (null, proxy or pass).
/// * `frontend_tls` — optional TLS settings for terminating TLS before forwarding.
/// * `sorry` — optional fallback ("sorry server") used when the primary backend cannot be
///   reached. A sorry server may not have its own sorry server, and a null backend may not
///   have one at all.
///
/// # Errors
/// Fails when the nesting rules above are violated, when a sorry entry has no backend
/// type, or when loading TLS material fails.
fn load_backend(
    be: &protos::backend::BackendType,
    frontend_tls: Option<&protos::backend::Tls>,
    sorry: Option<&protos::Backend>,
) -> Result<Backend> {
    let sorry = match sorry {
        None => None,
        Some(s) => {
            if s.sorry.is_some() {
                return Err(anyhow!("sorry servers can't have sorry servers"));
            }
            // The entry comes from user-supplied config, so a missing backend type is a
            // config error. Panicking here would abort startup, or crash the process when
            // the config is reloaded on SIGHUP.
            let sorry_be = s
                .backend_type
                .as_ref()
                .ok_or_else(|| anyhow!("sorry backend missing an actual backend"))?;
            Some(Box::new(load_backend(
                sorry_be,
                s.frontend_tls.as_ref(),
                None,
            )?))
        }
    };
    Ok(match be {
        protos::backend::BackendType::Null(_) => {
            if sorry.is_some() {
                return Err(anyhow!("null backend with sorry server not allowed"));
            }
            Backend::Null
        }
        protos::backend::BackendType::Proxy(p) => Backend::Proxy {
            addr: p.addr.clone(),
            proxy_header: p.proxy_header,
            frontend_tls: load_tls(frontend_tls)?,
            sorry,
        },
        protos::backend::BackendType::Pass(p) => Backend::Pass {
            path: p.path.clone().into(),
            frontend_tls: load_tls(frontend_tls)?,
            sorry,
        },
    })
}
fn load_config(filename: &str) -> Result<Config> {
let pool = prost_reflect::DescriptorPool::decode(PROTO_DESCRIPTOR)?;
let md = pool
.get_message_by_name("tarweb.SNIConfig")
.ok_or(anyhow!("Unable to reflect SNIConfig"))?;
let cwd = std::env::current_dir()
.map(|c| c.display().to_string())
.unwrap_or("<unknown>".to_string());
let txt = std::fs::read_to_string(filename)
.context(anyhow!("opening {filename:?} from cwd {cwd:?}"))?;
let dyn_msg = prost_reflect::DynamicMessage::parse_text_format(md, &txt)?;
let protocfg: protos::SniConfig = dyn_msg.transcode_to()?;
let mut config = Config {
max_lifetime: if protocfg.max_lifetime_ms > 0 {
Some(tokio::time::Duration::from_millis(protocfg.max_lifetime_ms))
} else {
None
},
handshake_timeout: if protocfg.handshake_timeout_ms > 0 {
Some(tokio::time::Duration::from_millis(
protocfg.handshake_timeout_ms,
))
} else {
None
},
rules: vec![],
default_backend: {
let (be, frontend_tls, sorry) = protocfg
.default_backend
.as_ref()
.map(|d| (&d.backend_type, d.frontend_tls.as_ref(), d.sorry.as_deref()))
.ok_or(anyhow!("Config missing default backend"))?;
load_backend(
be.as_ref()
.ok_or(anyhow!("default backend missing an actual backend"))?,
frontend_tls,
sorry,
)?
},
default_backend_timeout: protocfg.default_backend.and_then(|b| {
let t = b.max_lifetime_ms;
if t > 0 {
Some(tokio::time::Duration::from_millis(b.max_lifetime_ms))
} else {
None
}
}),
};
for rule in protocfg.rules {
config.rules.push(Rule {
re: regex::Regex::new(&rule.regex)?,
timeout: rule.backend.as_ref().and_then(|b| {
let t = b.max_lifetime_ms;
if t > 0 {
Some(tokio::time::Duration::from_millis(b.max_lifetime_ms))
} else {
None
}
}),
backend: {
let (be, frontend_tls, sorry) = rule
.backend
.as_ref()
.map(|d| (&d.backend_type, d.frontend_tls.as_ref(), d.sorry.as_deref()))
.ok_or(anyhow!("rule missing backend"))?;
load_backend(
be.as_ref()
.ok_or(anyhow!("backend missing actual backend"))?,
frontend_tls,
sorry,
)?
},
});
}
Ok(config)
}
/// Reads TLS records from `stream` until one complete ClientHello handshake message has
/// been collected. Reading stops after the record that completes it.
///
/// Returns `(bytes, hello)`:
/// * `bytes` — every byte read from the socket (record headers included), so the caller
///   can replay the stream verbatim to a backend;
/// * `hello` — the ClientHello message (4-byte handshake header + body, trimmed to its
///   declared length), or an error explaining why the stream does not start with one.
///
/// # Errors
/// The outer `Result` only fails on socket read errors (including EOF mid-record).
async fn read_tls_clienthello(
    stream: &mut tokio::net::TcpStream,
) -> Result<(Vec<u8>, Result<Vec<u8>>)> {
    const REC_HDR_LEN: usize = 5;
    const HS_HDR_LEN: usize = 4;
    // The handshake length field is 24 bits wide, so without a cap a client could make us
    // buffer ~16 MiB per connection. Real ClientHellos are a few KiB; this is generous.
    const MAX_CLIENTHELLO_LEN: usize = HS_HDR_LEN + 0xffff;

    let mut hello = Vec::with_capacity(BUF_CAPACITY);
    let mut bytes = Vec::with_capacity(BUF_CAPACITY);
    let mut needed: Option<usize> = None;
    while needed.is_none_or(|n| hello.len() < n) {
        let mut rec_hdr = [0u8; REC_HDR_LEN];
        stream
            .read_exact(&mut rec_hdr)
            .await
            .context("read TLS record header")?;
        bytes.extend(rec_hdr);
        let content_type = rec_hdr[0];
        let _legacy_ver = u16::from_be_bytes([rec_hdr[1], rec_hdr[2]]);
        let rec_len = u16::from_be_bytes([rec_hdr[3], rec_hdr[4]]) as usize;
        if content_type != 22 {
            return Ok((
                bytes,
                Err(anyhow!(
                    "unexpected TLS content_type {content_type}, want 22 (handshake)"
                )),
            ));
        }
        if rec_len == 0 {
            return Ok((bytes, Err(anyhow!("zero-length TLS record"))));
        }
        let mut rec_payload = vec![0u8; rec_len];
        stream
            .read_exact(&mut rec_payload)
            .await
            .context("read TLS record payload")?;
        hello.extend(&rec_payload);
        bytes.extend(&rec_payload);
        if needed.is_none() {
            // The handshake header may itself be split across records.
            if hello.len() < HS_HDR_LEN {
                continue;
            }
            let msg_type = hello[0];
            if msg_type != 1 {
                return Ok((
                    bytes,
                    Err(anyhow!(
                        "first handshake msg is type {msg_type}, expected 1 (ClientHello)"
                    )),
                ));
            }
            let body_len =
                ((hello[1] as usize) << 16) | ((hello[2] as usize) << 8) | (hello[3] as usize);
            let total = HS_HDR_LEN + body_len;
            if total > MAX_CLIENTHELLO_LEN {
                return Ok((
                    bytes,
                    Err(anyhow!(
                        "ClientHello of {total} bytes exceeds limit of {MAX_CLIENTHELLO_LEN}"
                    )),
                ));
            }
            needed = Some(total);
        }
    }
    let n = needed.expect("loop only exits once the ClientHello length is known");
    // The last record may carry further handshake data after the ClientHello.
    hello.truncate(n);
    Ok((bytes, Ok(hello)))
}
/// Hands the client connection to another process. The TCP socket's file descriptor
/// (as `SCM_RIGHTS`) and `bytes` are sent as one datagram on the already-connected `sock`.
///
/// Our own copy of the descriptor is closed when `stream` is dropped on return; the
/// receiver holds its own duplicate from the moment `sendmsg` succeeds.
///
/// # Errors
/// Fails if `sendmsg` fails or sends fewer bytes than `bytes.len()`.
async fn pass_fd_over_uds(
    stream: tokio::net::TcpStream,
    sock: UnixDatagram,
    bytes: Vec<u8>,
) -> Result<()> {
    use nix::sys::socket::{ControlMessage, MsgFlags, sendmsg};
    let fd = stream.as_raw_fd();
    let iov = [std::io::IoSlice::new(&bytes)];
    let cmsg = [ControlMessage::ScmRights(&[fd])];
    // `async_io` re-waits for writability if the non-blocking send reports WouldBlock
    // (e.g. spurious readiness or a full receive queue) instead of failing outright.
    let sent = sock
        .async_io(tokio::io::Interest::WRITABLE, || {
            sendmsg::<()>(
                sock.as_raw_fd(),
                &iov,
                &cmsg,
                MsgFlags::MSG_NOSIGNAL | MsgFlags::MSG_DONTWAIT,
                None,
            )
            .map_err(std::io::Error::from)
        })
        .await
        .context("sendmsg SCM_RIGHTS")?;
    if sent != bytes.len() {
        return Err(anyhow!(
            "sendmsg: expected to send {} bytes, sent {sent}",
            bytes.len()
        ));
    }
    Ok(())
}
fn extract_sni(clienthello: &[u8]) -> Result<Option<String>> {
if clienthello.len() < 4 {
bail!("ClientHello too short for handshake header");
}
if clienthello[0] != 1 {
bail!("not a ClientHello (handshake type {})", clienthello[0]);
}
let body_len = ((clienthello[1] as usize) << 16)
| ((clienthello[2] as usize) << 8)
| (clienthello[3] as usize);
if clienthello.len() < 4 + body_len {
bail!("truncated ClientHello body");
}
let body = &clienthello[4..4 + body_len];
let mut i = 0usize;
if body.len() < 35 {
bail!("ClientHello body too short");
}
i += 2 + 32;
let sid_len = body[i] as usize;
i += 1;
if body.len() < i + sid_len {
bail!("truncated session_id");
}
i += sid_len;
if body.len() < i + 2 {
bail!("missing cipher_suites length");
}
let cs_len = u16::from_be_bytes([body[i], body[i + 1]]) as usize;
i += 2;
if body.len() < i + cs_len || !cs_len.is_multiple_of(2) {
bail!("invalid cipher_suites vector");
}
i += cs_len;
if body.len() < i + 1 {
bail!("missing compression_methods length");
}
let cmethod_len = body[i] as usize;
i += 1;
if body.len() < i + cmethod_len {
bail!("invalid compression_methods vector");
}
i += cmethod_len;
if i == body.len() {
return Ok(None); }
if body.len() < i + 2 {
bail!("missing extensions length");
}
let ext_total = u16::from_be_bytes([body[i], body[i + 1]]) as usize;
i += 2;
if body.len() < i + ext_total {
bail!("truncated extensions block");
}
let mut j = i;
while j + 4 <= i + ext_total {
let etype = u16::from_be_bytes([body[j], body[j + 1]]);
let elen = u16::from_be_bytes([body[j + 2], body[j + 3]]) as usize;
j += 4;
if j + elen > i + ext_total {
bail!("truncated extension body");
}
if etype == 0x0000 {
let ext = &body[j..j + elen];
if ext.len() < 2 {
bail!("server_name: missing list length");
}
let list_len = u16::from_be_bytes([ext[0], ext[1]]) as usize;
if ext.len() < 2 + list_len {
bail!("server_name: truncated list");
}
let mut k = 2usize;
while k + 3 <= 2 + list_len {
let name_type = ext[k];
let host_len = u16::from_be_bytes([ext[k + 1], ext[k + 2]]) as usize;
k += 3;
if k + host_len > 2 + list_len {
bail!("server_name: truncated host entry");
}
if name_type == 0 {
let host_bytes = &ext[k..k + host_len];
let host = String::from_utf8_lossy(host_bytes).to_string();
return Ok(Some(host));
}
k += host_len;
}
return Ok(None);
}
j += elen;
}
Ok(None)
}
/// Where a routed connection is sent.
#[derive(Debug)]
enum Backend {
/// Close the client connection without forwarding anything.
Null,
/// Hand the client socket, plus the bytes already read from it, to another local process
/// over the Unix datagram socket at `path` (see `pass_fd_over_uds`).
Pass {
path: std::path::PathBuf,
/// If set, TLS is terminated here before the handoff.
frontend_tls: Option<Arc<rustls::ServerConfig>>,
/// Fallback used when connecting to `path` fails.
sorry: Option<Box<Backend>>,
},
/// Open a TCP connection to `addr` and copy bytes in both directions.
Proxy {
/// Backend address as `host:port`, resolved again for every client connection.
addr: String,
/// Whether to send a PROXY protocol v1 header line upstream before any client data.
proxy_header: bool,
/// If set, TLS is terminated here and plaintext is proxied.
frontend_tls: Option<Arc<rustls::ServerConfig>>,
/// Fallback used when no resolved address of `addr` accepts a connection.
sorry: Option<Box<Backend>>,
},
}
/// One SNI routing rule. Rules are tried in order and the first one whose regex matches
/// the whole SNI wins.
#[derive(Debug)]
struct Rule {
/// Pattern tested against the SNI with `is_full_match`.
re: regex::Regex,
/// Destination for connections matching `re`.
backend: Backend,
/// Optional limit applied to connecting/handing off to `backend`, and again to the
/// proxying phase that follows (taken from the backend's `max_lifetime_ms`).
timeout: Option<tokio::time::Duration>,
}
/// Complete router configuration, as built by `load_config`.
#[derive(Debug)]
struct Config {
/// Optional cap on a whole client connection, enforced per connection in `mainloop`.
max_lifetime: Option<tokio::time::Duration>,
/// Optional cap on reading the ClientHello and connecting/handing off to a backend.
handshake_timeout: Option<tokio::time::Duration>,
/// SNI rules, tried in order.
rules: Vec<Rule>,
/// Used when there is no ClientHello, no SNI, or no matching rule.
default_backend: Backend,
/// Like `Rule::timeout`, for `default_backend`.
default_backend_timeout: Option<tokio::time::Duration>,
}
/// Outcome of routing: the reached backend plus the timeout that applies to serving it.
struct RoutedConnection {
backend: ConnectedBackend,
timeout: Option<tokio::time::Duration>,
}
/// A backend that has been reached, with whatever state is still needed to serve it.
enum ConnectedBackend {
/// Nothing left to do: the connection was handed off or intentionally dropped.
Done,
/// A TCP backend connection is open and the client must now be proxied to it.
Proxy {
/// The client connection.
stream: tokio::net::TcpStream,
/// Bytes already read from the client, to be replayed to the backend first.
bytes: Vec<u8>,
/// The backend connection.
conn: tokio::net::TcpStream,
/// Whether to send a PROXY protocol v1 header to the backend.
proxy_header: bool,
/// If set, TLS is terminated before proxying.
frontend_tls: Option<Arc<rustls::ServerConfig>>,
},
}
async fn tls_handshake(
mut stream: tokio::net::TcpStream,
mut bytes: Vec<u8>,
cfg: Arc<rustls::ServerConfig>,
) -> Result<(tokio::net::TcpStream, Vec<u8>)> {
use std::io::Read;
use tokio::io::AsyncWriteExt;
debug!("Handshaking…");
let mut tls = rustls::ServerConnection::new(cfg)
.context("creating TLS server config: This is sorry-able, but is not implemented")?;
loop {
{
let mut cur = std::io::Cursor::new(&bytes);
let n = tls.read_tls(&mut cur)?;
bytes.drain(0..n);
}
let io = tls.process_new_packets()?;
let bytes_to_write = io.tls_bytes_to_write();
if bytes_to_write > 0 {
let mut buf = vec![0u8; bytes_to_write];
let mut cur = std::io::Cursor::new(&mut buf);
let n = tls.write_tls(&mut cur)?;
stream.write_all(&buf[..n]).await?;
}
let still_handshaking = tls.is_handshaking();
if !still_handshaking {
let plain_n = io.plaintext_bytes_to_read();
let mut buf = vec![0u8; plain_n];
let n = tls.reader().read(&mut buf[..plain_n])?;
assert_eq!(plain_n, n);
let ulp_name = b"tls\0";
let rc = unsafe {
libc::setsockopt(
stream.as_raw_fd(),
libc::SOL_TCP,
libc::TCP_ULP,
ulp_name.as_ptr().cast(),
ulp_name.len().try_into()?,
)
};
if rc < 0 {
return Err(anyhow!(
"setsockopt()=>{rc}: {}",
std::io::Error::from_raw_os_error(rc.abs())
));
}
let suite = tls.negotiated_cipher_suite().ok_or(anyhow!("bleh"))?;
let keys = tls.dangerous_extract_secrets()?;
let tls_rx = ktls::CryptoInfo::from_rustls(suite, keys.rx)?;
let tls_tx = ktls::CryptoInfo::from_rustls(suite, keys.tx)?;
for (name, s) in [(libc::TLS_RX, tls_rx), (libc::TLS_TX, tls_tx)] {
let rc = unsafe {
libc::setsockopt(
stream.as_raw_fd(),
libc::SOL_TLS,
name,
s.as_ptr(),
s.size().try_into()?,
)
};
if rc < 0 {
return Err(anyhow!(
"setsockopt()=>{rc}: {}",
std::io::Error::from_raw_os_error(rc.abs())
));
}
}
return Ok((stream, buf));
}
let mut buf = [0u8; 4096];
let n = stream.read(&mut buf).await?;
if n == 0 {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"EOF during handshake",
)
.into());
}
bytes.extend(&buf[..n]);
if bytes.len() > 8192 {
return Err(anyhow!("max TLS outstanding size exceeded"));
}
}
}
async fn connect_for_proxy(id: usize, addr: &str) -> Result<tokio::net::TcpStream> {
let addrs = addr.to_socket_addrs()?;
let mut conn = None;
for addr in addrs {
match tokio::net::TcpStream::connect(addr).await {
Ok(ok) => {
trace!("id={id} Connected to backend {addr}");
conn = Some(ok);
break;
}
Err(e) => {
debug!("id={id} Failed to connect to backend {addr:?}: {e}");
}
}
}
conn.ok_or(anyhow!(
"failed to connect to any backend with address {addr}"
))
}
/// Finishes serving a connection whose backend has already been chosen and reached.
///
/// Proxied connections are pumped by [`handle_connected_proxy`]. Handed-off or closed
/// connections ([`ConnectedBackend::Done`]) need no further work.
async fn handle_connected_backend(id: usize, backend: ConnectedBackend) -> Result<()> {
    match backend {
        ConnectedBackend::Proxy {
            stream: client,
            bytes: pending,
            conn: upstream,
            proxy_header,
            frontend_tls: tls,
        } => handle_connected_proxy(id, client, pending, upstream, proxy_header, tls).await,
        ConnectedBackend::Done => Ok(()),
    }
}
/// Proxies a client connection (`stream`) to an already-open backend connection (`conn`).
///
/// If `tls` is set, TLS is terminated first via `tls_handshake`. Upstream, an optional
/// PROXY protocol v1 header is sent, then `bytes` (data already read from the client),
/// then the live client stream. Each direction is copied until EOF, and the write side is
/// then shut down.
///
/// # Errors
/// Fails on handshake errors or I/O errors in either direction.
async fn handle_connected_proxy(
    id: usize,
    stream: tokio::net::TcpStream,
    bytes: Vec<u8>,
    mut conn: tokio::net::TcpStream,
    proxy_header: bool,
    tls: Option<Arc<rustls::ServerConfig>>,
) -> Result<()> {
    let (mut stream, bytes) = match tls {
        Some(tls) => tls_handshake(stream, bytes, tls).await?,
        None => (stream, bytes),
    };
    let (mut up_r, mut up_w) = conn.split();
    let (mut down_r, mut down_w) = stream.split();
    let upstream = async {
        if proxy_header {
            let me = down_r.local_addr()?;
            let peer = down_r.peer_addr()?;
            up_w.write_all(proxy_v1_header(peer, me).as_bytes()).await?;
        }
        up_w.write_all(&bytes).await?;
        tokio::io::copy(&mut down_r, &mut up_w).await?;
        up_w.shutdown().await?;
        trace!("id={id} Upstream write completed");
        Ok::<_, anyhow::Error>(())
    };
    let downstream = async {
        tokio::io::copy(&mut up_r, &mut down_w).await?;
        down_w.shutdown().await?;
        trace!("id={id} Downstream write completed");
        Ok::<_, anyhow::Error>(())
    };
    tokio::try_join!(upstream, downstream)?;
    Ok(())
}

/// Formats a PROXY protocol v1 header line for a connection from `peer` to `local`.
///
/// IPv4-mapped IPv6 addresses (what a dual-stack `[::]` listener reports for IPv4
/// clients) are canonicalized to plain IPv4. The backend then sees `TCP4 1.2.3.4 …` rather
/// than `TCP6 ::ffff:1.2.3.4 …`. If the two ends still differ in family, `UNKNOWN` is sent;
/// receivers ignore the rest of the line in that case.
fn proxy_v1_header(peer: std::net::SocketAddr, local: std::net::SocketAddr) -> String {
    let src = peer.ip().to_canonical();
    let dst = local.ip().to_canonical();
    let proto = match (src, dst) {
        (std::net::IpAddr::V4(_), std::net::IpAddr::V4(_)) => "TCP4",
        (std::net::IpAddr::V6(_), std::net::IpAddr::V6(_)) => "TCP6",
        _ => "UNKNOWN",
    };
    format!(
        "PROXY {proto} {src} {dst} {} {}\r\n",
        peer.port(),
        local.port()
    )
}
/// Connects to, or hands the client off to, `backend`, falling back to its sorry server
/// when the primary cannot be reached.
///
/// * `Null` drops the client connection.
/// * `Pass` connects a Unix datagram socket to `path`, optionally terminates TLS, and
///   sends the client fd plus the already-read bytes over it. Nothing is left to do after.
/// * `Proxy` only opens the backend TCP connection. The proxying itself, and any TLS
///   termination, happen later in [`handle_connected_proxy`].
///
/// The future is boxed because the sorry-server fallback recurses into this function.
fn connect_or_handoff_backend<'a>(
    id: usize,
    stream: tokio::net::TcpStream,
    bytes: Vec<u8>,
    backend: &'a Backend,
) -> Pin<Box<dyn std::future::Future<Output = Result<ConnectedBackend>> + Send + 'a>> {
    Box::pin(async move {
        match backend {
            Backend::Null => {
                trace!("id={id} Null backend. Closing");
                Ok(ConnectedBackend::Done)
            }
            Backend::Pass {
                path,
                frontend_tls,
                sorry,
            } => {
                let sock = tokio::net::UnixDatagram::unbound().context("create UnixDatagram")?;
                if let Err(e) = sock
                    .connect(path)
                    .with_context(|| format!("connect to {:?}", path.display()))
                {
                    info!("id={id} Primary backend connect failure: {e}");
                    if let Some(s) = sorry {
                        return connect_or_handoff_backend(id, stream, bytes, s).await;
                    }
                    return Err(e);
                }
                // Disabled diagnostic: log the credentials of the receiving process.
                if false {
                    let ucred = nix::sys::socket::getsockopt(
                        &sock,
                        nix::sys::socket::sockopt::PeerCredentials,
                    )?;
                    debug!(
                        "id={id} peer pid={} uid={} gid={}",
                        ucred.pid(),
                        ucred.uid(),
                        ucred.gid()
                    );
                }
                let (stream, bytes) = if let Some(tls) = frontend_tls {
                    tls_handshake(stream, bytes, tls.clone()).await?
                } else {
                    (stream, bytes)
                };
                pass_fd_over_uds(stream, sock, bytes).await?;
                Ok(ConnectedBackend::Done)
            }
            Backend::Proxy {
                addr,
                proxy_header,
                frontend_tls,
                sorry,
            } => {
                let conn = match connect_for_proxy(id, addr).await {
                    Ok(c) => c,
                    Err(e) => {
                        info!("id={id} Primary backend connect failure: {e}");
                        return match sorry {
                            None => Err(e),
                            Some(s) => connect_or_handoff_backend(id, stream, bytes, s).await,
                        };
                    }
                };
                Ok(ConnectedBackend::Proxy {
                    stream,
                    bytes,
                    conn,
                    proxy_header: *proxy_header,
                    frontend_tls: frontend_tls.clone(),
                })
            }
        }
    })
}
/// Runs [`connect_or_handoff_backend`], optionally bounded by `timeout`.
///
/// When the limit expires, the in-flight attempt is dropped and an error is returned.
async fn connect_or_handoff_backend_with_timeout(
    id: usize,
    stream: tokio::net::TcpStream,
    bytes: Vec<u8>,
    backend: &Backend,
    timeout: Option<tokio::time::Duration>,
) -> Result<ConnectedBackend> {
    let attempt = connect_or_handoff_backend(id, stream, bytes, backend);
    let Some(limit) = timeout else {
        return attempt.await;
    };
    tokio::time::timeout(limit, attempt)
        .await
        .unwrap_or_else(|elapsed| Err(anyhow!("backend connect/handoff timeout: {elapsed}")))
}
fn is_full_match(re: ®ex::Regex, text: &str) -> bool {
match re.find(text) {
Some(m) => m.start() == 0 && m.end() == text.len(),
None => false,
}
}
/// Reads the client's ClientHello, picks a backend by SNI and connects to it.
///
/// The first rule whose regex fully matches the SNI wins. Without a ClientHello, an SNI,
/// or a matching rule, the default backend is used. Every byte read from the client so far
/// is carried along, so the backend sees the complete stream.
///
/// # Errors
/// Fails on socket errors while reading the ClientHello, on a ClientHello whose SNI
/// cannot be parsed, and on backend connect/handoff failures.
async fn route_and_connect(
    id: usize,
    mut stream: tokio::net::TcpStream,
    config: &Config,
) -> Result<RoutedConnection> {
    let (bytes, clienthello) = read_tls_clienthello(&mut stream).await?;
    let sni = match clienthello {
        Err(e) => {
            warn!("id={id} Using default backend because no clienthello: {e}");
            None
        }
        Ok(hello) => {
            debug!("id={id} ClientHello len={} bytes", hello.len());
            let sni = extract_sni(&hello)?;
            if sni.is_none() {
                warn!("id={id} Failed to extract SNI");
            }
            sni
        }
    };
    let matched_rule = sni.as_deref().and_then(|sni| {
        debug!("id={id} SNI: {sni:?}");
        let rule = config
            .rules
            .iter()
            .find(|rule| is_full_match(&rule.re, sni))?;
        trace!("id={id} SNI {sni} matched rule {rule:?}");
        Some(rule)
    });
    let (backend, timeout) = match matched_rule {
        Some(rule) => (&rule.backend, rule.timeout),
        None => (&config.default_backend, config.default_backend_timeout),
    };
    Ok(RoutedConnection {
        backend: connect_or_handoff_backend_with_timeout(id, stream, bytes, backend, timeout)
            .await?,
        timeout,
    })
}
/// Serves one accepted client connection from start to finish.
///
/// Routing (reading the ClientHello, then connecting or handing off to a backend) is
/// bounded by `config.handshake_timeout`. The proxying phase that follows is bounded by
/// the chosen backend's timeout. Both limits are optional.
///
/// # Errors
/// Fails on routing or proxying errors, or when either limit expires.
async fn handle_conn(id: usize, stream: tokio::net::TcpStream, config: &Config) -> Result<()> {
    let routing = route_and_connect(id, stream, config);
    let routed = match config.handshake_timeout {
        Some(limit) => tokio::time::timeout(limit, routing)
            .await
            .map_err(|e| anyhow!("handshake timeout: {e}"))??,
        None => routing.await?,
    };
    let serving = handle_connected_backend(id, routed.backend);
    match routed.timeout {
        // Name which limit expired; a bare `Elapsed` only says "deadline has elapsed".
        Some(limit) => tokio::time::timeout(limit, serving)
            .await
            .map_err(|e| anyhow!("backend max lifetime exceeded: {e}"))?,
        None => serving.await,
    }
}
/// Accepts client connections forever and serves each one on its own task.
///
/// On SIGHUP the configuration is re-read from `config_filename`. Connections already in
/// flight keep the `Config` they started with, and a config that fails to load is logged
/// and ignored. Each connection is bounded by `config.max_lifetime` when that is set.
///
/// # Errors
/// Only fails if the SIGHUP handler cannot be registered. Errors from `accept()` are
/// logged and retried, so one failure (e.g. running out of file descriptors) does not
/// shut down the router.
async fn mainloop(
    mut config: Arc<Config>,
    config_filename: &str,
    listener: tokio::net::TcpListener,
) -> Result<()> {
    // Pause after a failed accept(). Errors such as EMFILE persist until some connection
    // closes, so retrying immediately would just spin.
    const ACCEPT_ERROR_BACKOFF: tokio::time::Duration = tokio::time::Duration::from_millis(100);
    let mut id = 0;
    let mut hups = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
        .context("Registering SIGHUP")?;
    loop {
        let (stream, peer) = tokio::select! {
            r = listener.accept() => match r {
                Ok(accepted) => accepted,
                Err(e) => {
                    error!("Failed to accept connection: {e}");
                    tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
                    continue;
                }
            },
            _ = hups.recv() => {
                let cwd = std::env::current_dir()
                    .map(|c| c.display().to_string())
                    .unwrap_or_else(|_| "<unknown>".to_string());
                info!("Got SIGHUP. Loading new config {config_filename:?} in cwd {cwd:?}");
                match load_config(config_filename) {
                    Ok(c) => config = Arc::new(c),
                    Err(e) => error!(
                        "Failed to load config {config_filename:?}, staying with old config: {e}"
                    ),
                }
                continue;
            }
        };
        debug!("id={id} fd={} Accepted {}", stream.as_raw_fd(), peer);
        let config = config.clone();
        tokio::spawn(async move {
            let fut = handle_conn(id, stream, &config);
            let res = if let Some(timeout) = config.max_lifetime {
                match tokio::time::timeout(timeout, fut).await {
                    Ok(o) => o,
                    Err(e) => Err(anyhow!("connection timeout for peer {peer}: {e}")),
                }
            } else {
                fut.await
            };
            if let Err(e) = res {
                warn!("id={id} Handling connection to {peer}: {e:#}");
            }
            debug!("id={id} Done");
        });
        id += 1;
    }
}
/// Serialized protobuf descriptor set for the config schema. `load_config` decodes it with
/// `prost_reflect` to parse the text-format config file. Presumably produced by the build
/// alongside the generated `protos` module — confirm.
const PROTO_DESCRIPTOR: &[u8] = include_bytes!("../../descriptor.bin");
#[tokio::main]
async fn main() -> Result<()> {
let opt = Opt::parse();
rustls::crypto::aws_lc_rs::default_provider()
.install_default()
.unwrap();
tracing_subscriber::fmt()
.with_env_filter(&opt.verbose)
.with_writer(std::io::stderr)
.init();
info!("SNI Router");
let listener = tokio::net::TcpListener::bind(&opt.listen)
.await
.context(format!("listening to {}", opt.listen))?;
debug!("Listening on {}", listener.local_addr()?);
tarweb::privs::sni_drop(
&opt.restrict_dirs
.iter()
.map(std::path::PathBuf::as_path)
.collect::<Vec<_>>(),
)?;
sock::set_nodelay(listener.as_raw_fd())?;
let config = load_config(&opt.config).context(format!("Loading config {:?}", opt.config))?;
mainloop(Arc::new(config), &opt.config, listener).await
}
#[cfg(test)]
mod tests {
    #![allow(clippy::too_many_lines)]
    use super::*;
    use std::net::SocketAddr;
    use std::sync::atomic::Ordering;

    /// Upper bound on how long any single test connection may take.
    const MAX_TEST_CONNECTION_TIME: tokio::time::Duration = tokio::time::Duration::from_secs(5);

    /// Builds a minimal ClientHello handshake message. With `Some(host)`, it carries a
    /// `server_name` extension whose only entry is `host`; with `None`, it has no
    /// extensions block at all.
    fn build_clienthello(sni: Option<&str>) -> Vec<u8> {
        let mut body = vec![0x03, 0x03]; // legacy_version
        body.extend([0u8; 32]); // random
        body.push(0); // empty session_id
        body.extend([0x00, 0x02, 0x13, 0x01]); // one cipher suite
        body.extend([0x01, 0x00]); // one compression method (null)
        if let Some(host) = sni {
            let host = host.as_bytes();
            let mut list = vec![0u8]; // name_type = host_name
            list.extend(u16::try_from(host.len()).unwrap().to_be_bytes());
            list.extend(host);
            let mut ext = u16::try_from(list.len()).unwrap().to_be_bytes().to_vec();
            ext.extend(list);
            let mut exts = vec![0x00, 0x00]; // extension type server_name
            exts.extend(u16::try_from(ext.len()).unwrap().to_be_bytes());
            exts.extend(ext);
            body.extend(u16::try_from(exts.len()).unwrap().to_be_bytes());
            body.extend(exts);
        }
        let len = u32::try_from(body.len()).unwrap().to_be_bytes();
        let mut msg = vec![1, len[1], len[2], len[3]];
        msg.extend(body);
        msg
    }

    #[test]
    fn extract_sni_parses_minimal_clienthello() -> Result<()> {
        assert_eq!(
            extract_sni(&build_clienthello(Some("example.com")))?,
            Some("example.com".to_string())
        );
        assert_eq!(extract_sni(&build_clienthello(None))?, None);
        let hello = build_clienthello(Some("example.com"));
        assert!(extract_sni(&hello[..hello.len() - 1]).is_err());
        Ok(())
    }

    #[test]
    fn config_loads_handshake_timeout() -> Result<()> {
        let tmp_dir = tempfile::TempDir::new()?;
        let config_file = tmp_dir.path().join("config.cfg");
        std::fs::write(
            &config_file,
            r#"
default_backend: <
  null: <>
>
handshake_timeout_ms: 1234
"#,
        )?;
        let config = load_config(config_file.to_str().unwrap())?;
        assert_eq!(
            config.handshake_timeout,
            Some(tokio::time::Duration::from_millis(1234))
        );
        Ok(())
    }

    #[tokio::test]
    async fn default_client() -> Result<()> {
        if false {
            tracing_subscriber::fmt()
                .with_env_filter("trace")
                .with_writer(std::io::stderr)
                .init();
        }
        // curl spells these `--tlsvX[.Y]`; each sets the minimum TLS version offered.
        for curl_opt in ["--tlsv1", "--tlsv1.1", "--tlsv1.2", "--tlsv1.3"] {
            for sni in ["foo", "bar", "bar2", "socket"] {
                info!("TESTING: sni={sni} opt={curl_opt}");
                let tmp_dir = tempfile::TempDir::new()?;
                let hit_something = std::sync::atomic::AtomicBool::new(false);
                let listener =
                    tokio::net::TcpListener::bind("[::1]:0".parse::<SocketAddr>()?).await?;
                let listener_port = listener.local_addr()?.port();
                let backend_bar =
                    tokio::net::TcpListener::bind("[::1]:0".parse::<SocketAddr>()?).await?;
                let backend_bar_port = backend_bar.local_addr()?.port();
                let backend_baz =
                    tokio::net::TcpListener::bind("[::1]:0".parse::<SocketAddr>()?).await?;
                let backend_baz_port = backend_baz.local_addr()?.port();
                let sockfile = tmp_dir.path().join("tarweb-testing.sock");
                let backend_sock = tokio::net::UnixDatagram::bind(&sockfile)?;
                #[allow(clippy::regex_creation_in_loops)]
                let config = Config {
                    max_lifetime: Some(MAX_TEST_CONNECTION_TIME),
                    handshake_timeout: None,
                    rules: vec![
                        Rule {
                            re: regex::Regex::new("foo")?,
                            backend: Backend::Null,
                            timeout: None,
                        },
                        Rule {
                            re: regex::Regex::new("socket")?,
                            backend: Backend::Pass {
                                path: sockfile.clone(),
                                frontend_tls: None,
                                sorry: None,
                            },
                            timeout: None,
                        },
                        Rule {
                            re: regex::Regex::new("bar")?,
                            backend: Backend::Proxy {
                                addr: format!("[::1]:{backend_bar_port}"),
                                proxy_header: false,
                                frontend_tls: None,
                                sorry: None,
                            },
                            timeout: None,
                        },
                    ],
                    default_backend: Backend::Proxy {
                        addr: format!("[::1]:{backend_baz_port}"),
                        proxy_header: false,
                        frontend_tls: None,
                        sorry: None,
                    },
                    default_backend_timeout: None,
                };
                let _main =
                    tokio::task::spawn(
                        async move { mainloop(Arc::new(config), "", listener).await },
                    );
                let (done_tx1, mut done_rx_bar) = tokio::sync::mpsc::channel::<()>(1);
                let (done_tx2, mut done_rx_baz) = tokio::sync::mpsc::channel::<()>(1);
                let (done_tx3, mut done_rx_sock) = tokio::sync::mpsc::channel::<()>(1);
                let client = async {
                    let _status = tokio::process::Command::new("curl")
                        .arg("-S")
                        .arg("--no-progress-meter")
                        .arg(curl_opt)
                        .arg("--connect-to")
                        .arg(format!("foo:443:[::1]:{listener_port}"))
                        .arg("--connect-to")
                        .arg(format!("bar:443:[::1]:{listener_port}"))
                        .arg("--connect-to")
                        .arg(format!("socket:443:[::1]:{listener_port}"))
                        .arg("--connect-to")
                        .arg(format!("bar2:443:[::1]:{listener_port}"))
                        .arg(format!("https://{sni}/"))
                        .spawn()?
                        .wait()
                        .await?;
                    drop(done_tx1);
                    drop(done_tx2);
                    drop(done_tx3);
                    Ok::<(), anyhow::Error>(())
                };
                let backend_bar = async {
                    if sni == "bar" {
                        info!("COVERED: bar");
                        hit_something.store(true, Ordering::Relaxed);
                        tokio::select! {
                            _ = backend_bar.accept() => Ok(()),
                            _ = done_rx_bar.recv() => Err(anyhow!("nobody connected to backend")),
                        }
                    } else {
                        Ok(())
                    }
                };
                let backend_baz = async {
                    if sni == "bar2" {
                        info!("COVERED: default");
                        hit_something.store(true, Ordering::Relaxed);
                        tokio::select! {
                            _ = backend_baz.accept() => Ok(()),
                            _ = done_rx_baz.recv() => Err(anyhow!("nobody connected to backend")),
                        }
                    } else {
                        Ok(())
                    }
                };
                let backend_sock = async {
                    if sni == "socket" {
                        info!("COVERED: socket");
                        hit_something.store(true, Ordering::Relaxed);
                        let mut buf = [0u8; 2048];
                        tokio::select! {
                            _ = backend_sock.recv(&mut buf) => Ok(()),
                            _ = done_rx_sock.recv() => Err(anyhow!("nobody connected to backend")),
                        }
                    } else {
                        Ok(())
                    }
                };
                if sni == "foo" {
                    hit_something.store(true, Ordering::Relaxed);
                }
                tokio::time::timeout(MAX_TEST_CONNECTION_TIME, async {
                    tokio::try_join!(client, backend_bar, backend_baz, backend_sock,)
                })
                .await??;
                assert!(
                    hit_something.load(Ordering::Relaxed),
                    "SNI {sni:?} and opts {curl_opt:?} did not do anything"
                );
            }
        }
        Ok(())
    }

    #[tokio::test]
    async fn handshake_timeout_closes_idle_preroute_client() -> Result<()> {
        let listener = tokio::net::TcpListener::bind("[::1]:0".parse::<SocketAddr>()?).await?;
        let listener_port = listener.local_addr()?.port();
        let config = Config {
            max_lifetime: Some(MAX_TEST_CONNECTION_TIME),
            handshake_timeout: Some(tokio::time::Duration::from_millis(50)),
            rules: vec![],
            default_backend: Backend::Null,
            default_backend_timeout: None,
        };
        let _main =
            tokio::task::spawn(async move { mainloop(Arc::new(config), "", listener).await });
        let mut stream = tokio::net::TcpStream::connect(format!("[::1]:{listener_port}")).await?;
        let mut buf = [0u8; 1];
        let read = tokio::time::timeout(MAX_TEST_CONNECTION_TIME, stream.read(&mut buf)).await?;
        match read {
            Ok(0) => Ok(()),
            Ok(n) => Err(anyhow!("idle preroute client read unexpected {n} bytes")),
            Err(e) if e.kind() == std::io::ErrorKind::ConnectionReset => Ok(()),
            Err(e) => Err(e.into()),
        }
    }

    #[tokio::test]
    async fn handshake_timeout_stops_after_proxy_backend_connects() -> Result<()> {
        let listener = tokio::net::TcpListener::bind("[::1]:0".parse::<SocketAddr>()?).await?;
        let listener_port = listener.local_addr()?.port();
        let backend = tokio::net::TcpListener::bind("[::1]:0".parse::<SocketAddr>()?).await?;
        let backend_port = backend.local_addr()?.port();
        let config = Config {
            max_lifetime: Some(MAX_TEST_CONNECTION_TIME),
            handshake_timeout: Some(tokio::time::Duration::from_millis(50)),
            rules: vec![],
            default_backend: Backend::Proxy {
                addr: format!("[::1]:{backend_port}"),
                proxy_header: false,
                frontend_tls: None,
                sorry: None,
            },
            default_backend_timeout: None,
        };
        let _main =
            tokio::task::spawn(async move { mainloop(Arc::new(config), "", listener).await });
        let backend = tokio::spawn(async move {
            let (mut stream, _) = backend.accept().await?;
            let mut got = [0u8; 5];
            stream.read_exact(&mut got).await?;
            if got != *b"abcde" {
                return Err(anyhow!("backend got unexpected bytes: {got:?}"));
            }
            tokio::time::sleep(tokio::time::Duration::from_millis(150)).await;
            stream.write_all(b"ok").await?;
            stream.shutdown().await?;
            Ok::<(), anyhow::Error>(())
        });
        let mut stream = tokio::net::TcpStream::connect(format!("[::1]:{listener_port}")).await?;
        stream.write_all(b"abcde").await?;
        let mut got = Vec::new();
        tokio::time::timeout(MAX_TEST_CONNECTION_TIME, stream.read_to_end(&mut got)).await??;
        backend.await??;
        assert_eq!(got, b"ok");
        Ok(())
    }
}