use std::{io::IsTerminal, net::SocketAddr, sync::Arc, time::Duration};
use clap::Parser;
use dhttp::{
ddns::resolvers::DnsScheme,
dquic::binds::BindPattern,
endpoint::Endpoint,
home::{self, DhttpHome, identity::IdentityProfile},
message::IntoUri,
name::DhttpName as Name,
};
use http_body_util::BodyExt;
use snafu::{IntoError, Report, ResultExt, Snafu};
use tokio::{net::TcpListener, sync::Semaphore, task::JoinSet};
use tracing::Instrument;
use tracing_subscriber::prelude::*;
// Command-line options for the dhttp HTTP proxy.
//
// These are plain `//` comments on purpose: clap's derive turns `///` doc
// comments into `--help` text, so documenting with `///` would change the
// CLI's user-visible output.
#[derive(Parser, Debug, Clone)]
#[command(version, about)]
pub struct Options {
    // Local TCP addresses the proxy accepts clients on (repeatable `--listen`).
    // Each host must be a concrete IP address; `bind_listeners` rejects anything
    // else. Defaults to IPv4 and IPv6 loopback on port 16080.
    #[arg(long = "listen", value_name = "bind", default_values = ["127.0.0.1:16080", "[::1]:16080"])]
    pub listens: Vec<BindPattern>,
    // Identity to present on the dhttp endpoint. When given, failing to locate
    // the dhttp home or to resolve this identity is a hard error; when absent,
    // the default identity is tried and the endpoint falls back to anonymous.
    #[arg(short, long, value_name = "client_identity")]
    pub id: Option<Name<'static>>,
    // Skip identity loading entirely and use an anonymous endpoint.
    #[arg(long, conflicts_with = "id")]
    pub anonymous: bool,
    // DNS schemes registered on the endpoint builder (comma-separated).
    // Hidden from `--help` when debug assertions are off (typically release builds).
    #[arg(long, value_name = "scheme", default_values = ["mdns", "h3"], value_delimiter = ',', hide = cfg!(not(debug_assertions)))]
    pub dns: Vec<DnsScheme>,
    // Bind patterns for the dhttp endpoint itself (passed to the endpoint
    // builder's `bind`), not for the proxy listener. The default `*` presumably
    // means "all interfaces" — TODO confirm. Hidden from `--help` when debug
    // assertions are off.
    #[arg(long = "interface", value_name = "bind", default_value = "*", hide = cfg!(not(debug_assertions)))]
    pub binds: Vec<BindPattern>,
    // NOTE(review): not read anywhere in this file; presumably consumed by the caller.
    #[arg(short, long)]
    pub verbose: bool,
    // Run as a daemon. When set, `init_tracing` ignores `--log` and writes to
    // stderr; presumably daemonization elsewhere redirects stderr to the log
    // file — verify against the caller.
    #[arg(long)]
    pub daemon: bool,
    // File that logs are appended to (created if missing) when not running as a daemon.
    #[arg(long, value_name = "path")]
    pub log: Option<std::path::PathBuf>,
}
/// Errors produced by the proxy.
///
/// `#[snafu(visibility(pub))]` makes the generated context selectors
/// (`BindListenerSnafu`, `ForwardConnectSnafu`, ...) public, presumably so the
/// `forward` / `tunnel` submodules can construct these variants.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
    /// Rewriting a request URI into its dhttp form failed.
    #[snafu(display("failed to normalize dhttp uri"))]
    NormalizeUri {
        source: dhttp::message::IntoUriError,
    },
    /// The dhttp home could not be located from the environment. In this file
    /// it is only returned when `--id` was given; otherwise the proxy falls
    /// back to an anonymous endpoint.
    #[snafu(display("failed to locate dhttp config"))]
    LocateDhttpHome { source: home::LocateDhttpHomeError },
    /// The identity named by `--id` could not be resolved.
    #[snafu(display("failed to load explicit identity `{name}`"))]
    LoadExplicitIdentity {
        name: Name<'static>,
        source: dhttp::home::identity::ssl::ResolveIdentityProfileError,
    },
    /// Loading the certificate and key of the selected identity failed.
    #[snafu(display("failed to load identity certificate and key"))]
    LoadIdentitySsl {
        source: dhttp::home::identity::ssl::LoadIdentityError,
    },
    /// Building the dhttp endpoint failed.
    #[snafu(display("failed to build dhttp endpoint"))]
    BuildEndpoint {
        source: dhttp::endpoint::BuildEndpointError,
    },
    /// Binding one of the `--listen` addresses failed. The address itself is
    /// not stored in this variant.
    #[snafu(display("failed to bind proxy listener"))]
    BindListener { source: std::io::Error },
    /// Connecting to the target `addr` of a tunnel failed
    /// (presumably raised by the `tunnel` module).
    #[snafu(display("failed to connect to tunnel target `{addr}`"))]
    TunnelConnect {
        addr: String,
        source: std::io::Error,
    },
    /// Upgrading the client connection for a tunnel failed.
    #[snafu(display("failed to upgrade tunnel connection"))]
    TunnelUpgrade { source: hyper::Error },
    /// Plain-HTTP forwarding: connecting to the upstream `addr` failed.
    #[snafu(display("failed to connect to `{addr}`"))]
    ForwardConnect {
        addr: String,
        source: std::io::Error,
    },
    /// Plain-HTTP forwarding: the HTTP handshake with the upstream `addr` failed.
    #[snafu(display("failed to perform HTTP handshake with `{addr}`"))]
    ForwardHandshake { addr: String, source: hyper::Error },
    /// Plain-HTTP forwarding: sending the request upstream failed.
    #[snafu(display("failed to send HTTP request"))]
    ForwardSendRequest { source: hyper::Error },
    /// Plain-HTTP forwarding: the request carried no host to forward to.
    #[snafu(display("missing host in request"))]
    ForwardMissingHost {},
    /// Plain-HTTP forwarding: the `Host` header was not valid visible ASCII.
    #[snafu(display("invalid host header"))]
    ForwardInvalidHost { source: hyper::header::ToStrError },
    /// Daemonizing the process failed (Unix only).
    #[snafu(display("failed to daemonize"))]
    #[cfg(unix)]
    Daemonize { source: daemonize::Error },
    /// Opening or creating the `--log` file failed.
    #[snafu(display("failed to create log file `{}`", path.display()))]
    CreateLogFile {
        path: std::path::PathBuf,
        source: std::io::Error,
    },
    /// Ad-hoc string errors created through the `snafu::FromString` impl for
    /// `Error`. `Display` and `source()` delegate to the inner value; it is
    /// boxed, presumably to keep `Error` small.
    #[snafu(transparent)]
    Whatever { source: Box<snafu::Whatever> },
}
/// Lets `Error` be produced from ad-hoc messages (e.g. `FromString::without_source`).
/// Every message becomes the boxed `Error::Whatever` variant.
impl snafu::FromString for Error {
    type Source = <snafu::Whatever as snafu::FromString>::Source;

    fn without_source(message: String) -> Self {
        let inner = <snafu::Whatever as snafu::FromString>::without_source(message);
        Self::Whatever {
            source: Box::new(inner),
        }
    }

    fn with_source(source: Self::Source, message: String) -> Self {
        let inner = <snafu::Whatever as snafu::FromString>::with_source(source, message);
        Self::Whatever {
            source: Box::new(inner),
        }
    }
}
/// Response body type returned by the proxy: a boxed body of `bytes::Bytes`
/// chunks whose error is type-erased into a `Send + Sync` boxed error. Every
/// route's response is mapped into it so all routes share one response type.
type BoxBody = http_body_util::combinators::UnsyncBoxBody<
    bytes::Bytes,
    Box<dyn std::error::Error + Send + Sync>,
>;
/// Wraps a static string in a complete, in-memory body for proxy-generated
/// responses. The body cannot fail, so its `Infallible` error is mapped away.
fn full_body(text: &'static str) -> BoxBody {
    let payload = bytes::Bytes::from(text);
    let body = http_body_util::Full::new(payload);
    body.map_err(|infallible| match infallible {}).boxed_unsync()
}
/// Erases a concrete body type into [`BoxBody`] by boxing its error and the body itself.
fn box_body<B>(body: B) -> BoxBody
where
    B: http_body_util::BodyExt<Data = bytes::Bytes> + Send + 'static,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let with_boxed_errors =
        body.map_err(|err| -> Box<dyn std::error::Error + Send + Sync> { err.into() });
    with_boxed_errors.boxed_unsync()
}
async fn handle_request(
req: hyper::Request<hyper::body::Incoming>,
client: &Endpoint,
router: &route::Router,
) -> Result<hyper::Response<BoxBody>, hyper::Error> {
let route = router.classify(&req);
tracing::info!(method = %req.method(), uri = %req.uri(), route = ?route, "proxy request");
match route {
route::Route::GenmetaPlainHttp { .. } => {
let mut req = req;
let self_name = client.name();
match req.uri().clone().into_uri(self_name.as_ref()) {
Ok(uri) => *req.uri_mut() = uri,
Err(e) => {
let error = NormalizeUriSnafu.into_error(e);
tracing::error!(
error = %Report::from_error(&error),
"failed to normalize dhttp uri"
);
return Ok(hyper::Response::builder()
.status(502)
.body(full_body("Bad Gateway"))
.expect("valid static response"));
}
}
match h3_forward::forward_h3(req, client).await {
Ok(resp) => Ok(resp.map(box_body)),
Err(e) => {
tracing::error!(error = %Report::from_error(&e), "h3 forward failed");
Ok(hyper::Response::builder()
.status(502)
.body(full_body("Bad Gateway"))
.expect("valid static response"))
}
}
}
route::Route::GenmetaConnect { .. } => Ok(hyper::Response::builder()
.status(502)
.body(full_body("HTTPS proxy to .dhttp.net not supported"))
.expect("valid static response")),
route::Route::TunnelConnect { authority } => {
match tunnel::tunnel_connect(req, authority.as_str()).await {
Ok(resp) => Ok(resp.map(box_body)),
Err(e) => {
tracing::error!(error = %Report::from_error(&e), "tunnel connect failed");
Ok(hyper::Response::builder()
.status(502)
.body(full_body("Bad Gateway"))
.expect("valid static response"))
}
}
}
route::Route::StandardForward { .. } => match forward::forward_http(req).await {
Ok(resp) => Ok(resp.map(box_body)),
Err(e) => {
tracing::error!(error = %Report::from_error(&e), "http forward failed");
Ok(hyper::Response::builder()
.status(502)
.body(full_body("Bad Gateway"))
.expect("valid static response"))
}
},
}
}
fn init_tracing(options: &Options) -> Result<tracing_appender::non_blocking::WorkerGuard, Error> {
let (writer, guard) = if let Some(ref log_path) = options.log
&& !options.daemon
{
let file = std::fs::OpenOptions::new()
.create(true)
.append(true)
.open(log_path)
.context(CreateLogFileSnafu {
path: log_path.clone(),
})?;
tracing_appender::non_blocking(file)
} else {
tracing_appender::non_blocking(std::io::stderr())
};
let use_ansi = (options.log.is_none() || options.daemon) && std::io::stderr().is_terminal();
tracing_subscriber::registry()
.with(
tracing_subscriber::fmt::layer()
.with_ansi(use_ansi)
.with_timer(tracing_subscriber::fmt::time::LocalTime::rfc_3339())
.with_writer(writer),
)
.with(
tracing_subscriber::EnvFilter::builder()
.with_default_directive(tracing_subscriber::filter::LevelFilter::INFO.into())
.from_env_lossy()
.add_directive(
"netlink_packet_route=error"
.parse()
.expect("BUG: static tracing directive is valid"),
),
)
.init();
Ok(guard)
}
async fn bind_listeners(options: &Options) -> Result<Vec<TcpListener>, Error> {
let mut listeners = Vec::new();
for bind in &options.listens {
let ip = bind.host.as_ip_addr().ok_or_else(|| {
<Error as snafu::FromString>::without_source(format!(
"listen bind `{}` must be a concrete ip address",
bind.host
))
})?;
let addr = SocketAddr::new(ip, bind.effective_port());
let listener = TcpListener::bind(addr).await.context(BindListenerSnafu)?;
tracing::info!(%addr, "proxy listening");
listeners.push(listener);
}
Ok(listeners)
}
/// Picks the identity profile for the dhttp endpoint, or `None` for anonymous.
///
/// - `--anonymous`: always `None`.
/// - `--id <name>`: the dhttp home must be locatable and the named identity
///   must resolve; either failure is returned as an error.
/// - Neither flag: the default identity is used if it resolves. A missing home
///   (logged at `warn`) or a missing default identity (logged at `debug`)
///   yields `None` instead of an error.
///
/// # Errors
/// [`Error::LocateDhttpHome`] or [`Error::LoadExplicitIdentity`], only when `--id` is given.
async fn load_identity_profile(options: &Options) -> Result<Option<IdentityProfile>, Error> {
    if options.anonymous {
        return Ok(None);
    }

    let home = match DhttpHome::load_from_environment() {
        Ok(home) => home,
        Err(source) => {
            // An explicitly requested identity cannot silently degrade to anonymous.
            if options.id.is_some() {
                return Err(LocateDhttpHomeSnafu.into_error(source));
            }
            tracing::warn!(
                error = %snafu::Report::from_error(&source),
                "failed to locate dhttp config, using anonymous endpoint"
            );
            return Ok(None);
        }
    };

    match &options.id {
        Some(name) => {
            tracing::debug!(%name, "trying to load command line identity");
            let profile = home
                .resolve_identity_profile(name.clone())
                .await
                .context(LoadExplicitIdentitySnafu { name: name.clone() })?;
            Ok(Some(profile))
        }
        None => {
            let profile = match home.resolve_default_identity_profile().await {
                Ok(identity) => {
                    tracing::debug!(name = %identity.name(), "using default identity");
                    Some(identity)
                }
                Err(source) => {
                    tracing::debug!(
                        error = %snafu::Report::from_error(&source),
                        "failed to load default identity, using anonymous endpoint"
                    );
                    None
                }
            };
            Ok(profile)
        }
    }
}
pub async fn run(options: Options) -> Result<(), Error> {
let _guard = init_tracing(&options)?;
let identity_profile = load_identity_profile(&options).await?;
let identity = match &identity_profile {
Some(profile) => Some(Arc::new(
profile
.load_identity()
.await
.context(LoadIdentitySslSnafu)?,
)),
None => None,
};
let mut builder = Endpoint::builder()
.bind(Arc::new(options.binds.clone()))
.maybe_identity(identity);
for scheme in options.dns.iter().copied() {
builder = builder.dns(scheme);
}
let client = Arc::new(builder.build().await.context(BuildEndpointSnafu)?);
let listeners = bind_listeners(&options).await?;
let router = Arc::new(route::Router::new());
let semaphore = Arc::new(Semaphore::new(1024));
let mut tasks = JoinSet::new();
for listener in listeners {
let client = client.clone();
let router = router.clone();
let semaphore = semaphore.clone();
tasks.spawn(accept_loop(listener, client, router, semaphore));
}
while let Some(result) = tasks.join_next().await {
match result {
Ok(()) => tracing::info!("listener task exited"),
Err(e) => {
tracing::error!(error = %snafu::Report::from_error(&e), "listener task panicked")
}
}
}
Ok(())
}
fn configure_tcp_keepalive(stream: &tokio::net::TcpStream) {
let sock = socket2::SockRef::from(stream);
let keepalive = socket2::TcpKeepalive::new()
.with_time(Duration::from_secs(60))
.with_interval(Duration::from_secs(10));
#[cfg(any(
target_os = "android",
target_os = "dragonfly",
target_os = "freebsd",
target_os = "fuchsia",
target_os = "illumos",
target_os = "linux",
target_os = "netbsd",
target_vendor = "apple",
))]
let keepalive = keepalive.with_retries(3);
if let Err(e) = sock.set_tcp_keepalive(&keepalive) {
tracing::warn!(error = %e, "failed to set TCP keepalive");
}
}
async fn accept_loop(
listener: TcpListener,
client: Arc<Endpoint>,
router: Arc<route::Router>,
semaphore: Arc<Semaphore>,
) {
loop {
let (stream, addr) = match listener.accept().await {
Ok(accepted) => accepted,
Err(e) => {
tracing::warn!(error = %snafu::Report::from_error(&e), "accept failed, retrying");
tokio::time::sleep(Duration::from_millis(33)).await;
continue;
}
};
configure_tcp_keepalive(&stream);
let permit = match semaphore.clone().acquire_owned().await {
Ok(permit) => permit,
Err(_) => break, };
tracing::debug!(%addr, "accepted connection");
let client = client.clone();
let router = router.clone();
let span = tracing::info_span!("conn", %addr);
tokio::spawn(
async move {
let _permit = permit;
let io = hyper_util::rt::TokioIo::new(stream);
if let Err(e) = hyper::server::conn::http1::Builder::new()
.timer(hyper_util::rt::TokioTimer::new())
.header_read_timeout(Some(Duration::from_secs(120)))
.preserve_header_case(true)
.title_case_headers(true)
.serve_connection(
io,
hyper::service::service_fn(move |req| {
let client = client.clone();
let router = router.clone();
async move { handle_request(req, &client, &router).await }
}),
)
.with_upgrades()
.await
{
tracing::error!(error = %Report::from_error(&e), %addr, "connection error");
}
}
.instrument(span),
);
}
}
pub mod forward;
pub mod h3_forward;
pub mod route;
pub mod tunnel;