1use std::{io::IsTerminal, net::SocketAddr, sync::Arc, time::Duration};
2
3use clap::Parser;
4use dhttp::{
5 ddns::resolvers::DnsScheme,
6 dquic::binds::BindPattern,
7 endpoint::Endpoint,
8 home::{self, DhttpHome, identity::IdentityProfile},
9 message::IntoUri,
10 name::DhttpName as Name,
11};
12use http_body_util::BodyExt;
13use snafu::{IntoError, Report, ResultExt, Snafu};
14use tokio::{net::TcpListener, sync::Semaphore, task::JoinSet};
15use tracing::Instrument;
16use tracing_subscriber::prelude::*;
17
// Command-line options for the local dhttp proxy, parsed with clap's derive.
//
// Plain `//` comments are used deliberately: `///` doc comments on this struct
// or its fields would be picked up by clap as `--help` / about text.
#[derive(Parser, Debug, Clone)]
#[command(version, about)]
pub struct Options {
    // Local TCP addresses the HTTP proxy listens on (repeatable). Each host must be
    // a concrete IP address; `bind_listeners` rejects anything else.
    #[arg(long = "listen", value_name = "bind", default_values = ["127.0.0.1:16080", "[::1]:16080"])]
    pub listens: Vec<BindPattern>,

    // Name of the identity to use. If it cannot be resolved, or the dhttp home
    // cannot be located, startup fails instead of falling back to an anonymous
    // endpoint (see `load_identity_profile`).
    #[arg(short, long, value_name = "client_identity")]
    pub id: Option<Name<'static>>,

    // Skip identity loading entirely and build an anonymous endpoint.
    #[arg(long, conflicts_with = "id")]
    pub anonymous: bool,

    // DNS schemes handed to the endpoint builder, in the given order
    // (comma-separated). Hidden from `--help` when debug assertions are off.
    #[arg(long, value_name = "scheme", default_values = ["mdns", "h3"], value_delimiter = ',', hide = cfg!(not(debug_assertions)))]
    pub dns: Vec<DnsScheme>,

    // Bind patterns passed to the endpoint builder's `bind` (presumably the local
    // interfaces the dhttp endpoint's sockets use — confirm). Hidden from `--help`
    // when debug assertions are off.
    #[arg(long = "interface", value_name = "bind", default_value = "*", hide = cfg!(not(debug_assertions)))]
    pub binds: Vec<BindPattern>,

    // NOTE(review): not read anywhere in this file; log verbosity here comes from
    // the `EnvFilter` in `init_tracing`. Confirm whether a caller uses it.
    #[arg(short, long)]
    pub verbose: bool,

    // Daemon mode. Within this file it only affects logging: `init_tracing` does
    // not open `--log` and writes to stderr instead. Daemonization itself
    // presumably happens in the caller (cf. `Error::Daemonize`) — TODO confirm.
    #[arg(long)]
    pub daemon: bool,

    // File that log output is appended to (created if missing) when not running
    // with `--daemon`.
    #[arg(long, value_name = "path")]
    pub log: Option<std::path::PathBuf>,
}
53
/// Errors produced while starting or running the proxy.
///
/// The snafu context selectors (`NormalizeUriSnafu`, `BindListenerSnafu`, ...)
/// are made `pub` by `#[snafu(visibility(pub))]` so sibling modules can use them.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
    /// A plain-HTTP request URI could not be rewritten into dhttp form
    /// (raised in `handle_request`).
    #[snafu(display("failed to normalize dhttp uri"))]
    NormalizeUri {
        source: dhttp::message::IntoUriError,
    },

    /// The dhttp home could not be located while an explicit `--id` was given.
    #[snafu(display("failed to locate dhttp config"))]
    LocateDhttpHome { source: home::LocateDhttpHomeError },

    /// The identity profile named by `--id` could not be resolved.
    #[snafu(display("failed to load explicit identity `{name}`"))]
    LoadExplicitIdentity {
        /// The identity name that was requested.
        name: Name<'static>,
        source: dhttp::home::identity::ssl::ResolveIdentityProfileError,
    },

    /// The selected identity profile's certificate/key could not be loaded (`run`).
    #[snafu(display("failed to load identity certificate and key"))]
    LoadIdentitySsl {
        source: dhttp::home::identity::ssl::LoadIdentityError,
    },

    /// Building the dhttp [`Endpoint`] failed (`run`).
    #[snafu(display("failed to build dhttp endpoint"))]
    BuildEndpoint {
        source: dhttp::endpoint::BuildEndpointError,
    },

    /// Binding one of the `--listen` TCP sockets failed (`bind_listeners`).
    #[snafu(display("failed to bind proxy listener"))]
    BindListener { source: std::io::Error },

    /// Presumably raised by the `tunnel` module: connecting to a CONNECT target failed.
    #[snafu(display("failed to connect to tunnel target `{addr}`"))]
    TunnelConnect {
        /// The target address that could not be reached.
        addr: String,
        source: std::io::Error,
    },

    /// Presumably raised by the `tunnel` module: the HTTP upgrade for a tunnel failed.
    #[snafu(display("failed to upgrade tunnel connection"))]
    TunnelUpgrade { source: hyper::Error },

    /// Presumably raised by the `forward` module: connecting to the upstream failed.
    #[snafu(display("failed to connect to `{addr}`"))]
    ForwardConnect {
        /// The upstream address that could not be reached.
        addr: String,
        source: std::io::Error,
    },

    /// Presumably raised by the `forward` module: the upstream HTTP handshake failed.
    #[snafu(display("failed to perform HTTP handshake with `{addr}`"))]
    ForwardHandshake { addr: String, source: hyper::Error },

    /// Presumably raised by the `forward` module: sending the request upstream failed.
    #[snafu(display("failed to send HTTP request"))]
    ForwardSendRequest { source: hyper::Error },

    /// Presumably raised by the `forward` module: the request carried no host.
    #[snafu(display("missing host in request"))]
    ForwardMissingHost {},

    /// Presumably raised by the `forward` module: the host header was not valid text.
    #[snafu(display("invalid host header"))]
    ForwardInvalidHost { source: hyper::header::ToStrError },

    /// Unix only. Not constructed in this file; presumably used by the caller that
    /// handles `--daemon` — confirm.
    #[snafu(display("failed to daemonize"))]
    #[cfg(unix)]
    Daemonize { source: daemonize::Error },

    /// The `--log` file could not be opened for appending (`init_tracing`).
    #[snafu(display("failed to create log file `{}`", path.display()))]
    CreateLogFile {
        /// The log file path that failed to open.
        path: std::path::PathBuf,
        source: std::io::Error,
    },

    /// Free-form error built through the `snafu::FromString` impl below
    /// (e.g. a `--listen` host that is not an IP address).
    #[snafu(transparent)]
    Whatever { source: Box<snafu::Whatever> },
}
124
125impl snafu::FromString for Error {
126 type Source = <snafu::Whatever as snafu::FromString>::Source;
127
128 fn without_source(message: String) -> Self {
129 Error::Whatever {
130 source: Box::new(snafu::Whatever::without_source(message)),
131 }
132 }
133
134 fn with_source(source: Self::Source, message: String) -> Self {
135 Error::Whatever {
136 source: Box::new(snafu::Whatever::with_source(source, message)),
137 }
138 }
139}
140type BoxBody = http_body_util::combinators::UnsyncBoxBody<
141 bytes::Bytes,
142 Box<dyn std::error::Error + Send + Sync>,
143>;
144
145fn full_body(text: &'static str) -> BoxBody {
146 http_body_util::Full::new(bytes::Bytes::from(text))
147 .map_err(|never| match never {})
148 .boxed_unsync()
149}
150
151fn box_body<B>(body: B) -> BoxBody
152where
153 B: http_body_util::BodyExt<Data = bytes::Bytes> + Send + 'static,
154 B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
155{
156 body.map_err(Into::into).boxed_unsync()
157}
158
159async fn handle_request(
160 req: hyper::Request<hyper::body::Incoming>,
161 client: &Endpoint,
162 router: &route::Router,
163) -> Result<hyper::Response<BoxBody>, hyper::Error> {
164 let route = router.classify(&req);
165 tracing::info!(method = %req.method(), uri = %req.uri(), route = ?route, "proxy request");
166 match route {
167 route::Route::GenmetaPlainHttp { .. } => {
168 let mut req = req;
169 let self_name = client.name();
170 match req.uri().clone().into_uri(self_name.as_ref()) {
171 Ok(uri) => *req.uri_mut() = uri,
172 Err(e) => {
173 let error = NormalizeUriSnafu.into_error(e);
174 tracing::error!(
175 error = %Report::from_error(&error),
176 "failed to normalize dhttp uri"
177 );
178 return Ok(hyper::Response::builder()
179 .status(502)
180 .body(full_body("Bad Gateway"))
181 .expect("valid static response"));
182 }
183 }
184 match h3_forward::forward_h3(req, client).await {
185 Ok(resp) => Ok(resp.map(box_body)),
186 Err(e) => {
187 tracing::error!(error = %Report::from_error(&e), "h3 forward failed");
188 Ok(hyper::Response::builder()
189 .status(502)
190 .body(full_body("Bad Gateway"))
191 .expect("valid static response"))
192 }
193 }
194 }
195 route::Route::GenmetaConnect { .. } => Ok(hyper::Response::builder()
196 .status(502)
197 .body(full_body("HTTPS proxy to .dhttp.net not supported"))
198 .expect("valid static response")),
199 route::Route::TunnelConnect { authority } => {
200 match tunnel::tunnel_connect(req, authority.as_str()).await {
201 Ok(resp) => Ok(resp.map(box_body)),
202 Err(e) => {
203 tracing::error!(error = %Report::from_error(&e), "tunnel connect failed");
204 Ok(hyper::Response::builder()
205 .status(502)
206 .body(full_body("Bad Gateway"))
207 .expect("valid static response"))
208 }
209 }
210 }
211 route::Route::StandardForward { .. } => match forward::forward_http(req).await {
212 Ok(resp) => Ok(resp.map(box_body)),
213 Err(e) => {
214 tracing::error!(error = %Report::from_error(&e), "http forward failed");
215 Ok(hyper::Response::builder()
216 .status(502)
217 .body(full_body("Bad Gateway"))
218 .expect("valid static response"))
219 }
220 },
221 }
222}
223
224fn init_tracing(options: &Options) -> Result<tracing_appender::non_blocking::WorkerGuard, Error> {
226 let (writer, guard) = if let Some(ref log_path) = options.log
227 && !options.daemon
228 {
229 let file = std::fs::OpenOptions::new()
230 .create(true)
231 .append(true)
232 .open(log_path)
233 .context(CreateLogFileSnafu {
234 path: log_path.clone(),
235 })?;
236 tracing_appender::non_blocking(file)
237 } else {
238 tracing_appender::non_blocking(std::io::stderr())
239 };
240 let use_ansi = (options.log.is_none() || options.daemon) && std::io::stderr().is_terminal();
241 tracing_subscriber::registry()
242 .with(
243 tracing_subscriber::fmt::layer()
244 .with_ansi(use_ansi)
245 .with_timer(tracing_subscriber::fmt::time::LocalTime::rfc_3339())
246 .with_writer(writer),
247 )
248 .with(
249 tracing_subscriber::EnvFilter::builder()
250 .with_default_directive(tracing_subscriber::filter::LevelFilter::INFO.into())
251 .from_env_lossy()
252 .add_directive(
253 "netlink_packet_route=error"
254 .parse()
255 .expect("BUG: static tracing directive is valid"),
256 ),
257 )
258 .init();
259 Ok(guard)
260}
261
262async fn bind_listeners(options: &Options) -> Result<Vec<TcpListener>, Error> {
264 let mut listeners = Vec::new();
265 for bind in &options.listens {
266 let ip = bind.host.as_ip_addr().ok_or_else(|| {
267 <Error as snafu::FromString>::without_source(format!(
268 "listen bind `{}` must be a concrete ip address",
269 bind.host
270 ))
271 })?;
272 let addr = SocketAddr::new(ip, bind.effective_port());
273 let listener = TcpListener::bind(addr).await.context(BindListenerSnafu)?;
274 tracing::info!(%addr, "proxy listening");
275 listeners.push(listener);
276 }
277 Ok(listeners)
278}
279
280async fn load_identity_profile(options: &Options) -> Result<Option<IdentityProfile>, Error> {
281 if options.anonymous {
282 return Ok(None);
283 }
284
285 let home = match DhttpHome::load_from_environment() {
286 Ok(home) => home,
287 Err(source) if options.id.is_none() => {
288 tracing::warn!(
289 error = %snafu::Report::from_error(&source),
290 "failed to locate dhttp config, using anonymous endpoint"
291 );
292 return Ok(None);
293 }
294 Err(source) => return Err(LocateDhttpHomeSnafu.into_error(source)),
295 };
296
297 if let Some(name) = &options.id {
298 tracing::debug!(%name, "trying to load command line identity");
299 return home
300 .resolve_identity_profile(name.clone())
301 .await
302 .context(LoadExplicitIdentitySnafu { name: name.clone() })
303 .map(Some);
304 }
305
306 match home.resolve_default_identity_profile().await {
307 Ok(identity) => {
308 tracing::debug!(name = %identity.name(), "using default identity");
309 Ok(Some(identity))
310 }
311 Err(source) => {
312 tracing::debug!(
313 error = %snafu::Report::from_error(&source),
314 "failed to load default identity, using anonymous endpoint"
315 );
316 Ok(None)
317 }
318 }
319}
320
321pub async fn run(options: Options) -> Result<(), Error> {
322 let _guard = init_tracing(&options)?;
323
324 let identity_profile = load_identity_profile(&options).await?;
325 let identity = match &identity_profile {
326 Some(profile) => Some(Arc::new(
327 profile
328 .load_identity()
329 .await
330 .context(LoadIdentitySslSnafu)?,
331 )),
332 None => None,
333 };
334
335 let mut builder = Endpoint::builder()
336 .bind(Arc::new(options.binds.clone()))
337 .maybe_identity(identity);
338 for scheme in options.dns.iter().copied() {
339 builder = builder.dns(scheme);
340 }
341 let client = Arc::new(builder.build().await.context(BuildEndpointSnafu)?);
342
343 let listeners = bind_listeners(&options).await?;
344 let router = Arc::new(route::Router::new());
345
346 let semaphore = Arc::new(Semaphore::new(1024));
347 let mut tasks = JoinSet::new();
348 for listener in listeners {
349 let client = client.clone();
350 let router = router.clone();
351 let semaphore = semaphore.clone();
352 tasks.spawn(accept_loop(listener, client, router, semaphore));
353 }
354
355 while let Some(result) = tasks.join_next().await {
356 match result {
357 Ok(()) => tracing::info!("listener task exited"),
358 Err(e) => {
359 tracing::error!(error = %snafu::Report::from_error(&e), "listener task panicked")
360 }
361 }
362 }
363
364 Ok(())
365}
366
367fn configure_tcp_keepalive(stream: &tokio::net::TcpStream) {
372 let sock = socket2::SockRef::from(stream);
373 let keepalive = socket2::TcpKeepalive::new()
374 .with_time(Duration::from_secs(60))
375 .with_interval(Duration::from_secs(10));
376 #[cfg(any(
378 target_os = "android",
379 target_os = "dragonfly",
380 target_os = "freebsd",
381 target_os = "fuchsia",
382 target_os = "illumos",
383 target_os = "linux",
384 target_os = "netbsd",
385 target_vendor = "apple",
386 ))]
387 let keepalive = keepalive.with_retries(3);
388 if let Err(e) = sock.set_tcp_keepalive(&keepalive) {
389 tracing::warn!(error = %e, "failed to set TCP keepalive");
390 }
391}
392
393async fn accept_loop(
395 listener: TcpListener,
396 client: Arc<Endpoint>,
397 router: Arc<route::Router>,
398 semaphore: Arc<Semaphore>,
399) {
400 loop {
401 let (stream, addr) = match listener.accept().await {
402 Ok(accepted) => accepted,
403 Err(e) => {
404 tracing::warn!(error = %snafu::Report::from_error(&e), "accept failed, retrying");
405 tokio::time::sleep(Duration::from_millis(33)).await;
406 continue;
407 }
408 };
409 configure_tcp_keepalive(&stream);
410 let permit = match semaphore.clone().acquire_owned().await {
411 Ok(permit) => permit,
412 Err(_) => break, };
414 tracing::debug!(%addr, "accepted connection");
415 let client = client.clone();
416 let router = router.clone();
417 let span = tracing::info_span!("conn", %addr);
418 tokio::spawn(
421 async move {
422 let _permit = permit;
423 let io = hyper_util::rt::TokioIo::new(stream);
424 if let Err(e) = hyper::server::conn::http1::Builder::new()
425 .timer(hyper_util::rt::TokioTimer::new())
426 .header_read_timeout(Some(Duration::from_secs(120)))
427 .preserve_header_case(true)
428 .title_case_headers(true)
429 .serve_connection(
430 io,
431 hyper::service::service_fn(move |req| {
432 let client = client.clone();
433 let router = router.clone();
434 async move { handle_request(req, &client, &router).await }
435 }),
436 )
437 .with_upgrades()
438 .await
439 {
440 tracing::error!(error = %Report::from_error(&e), %addr, "connection error");
441 }
442 }
443 .instrument(span),
444 );
445 }
446}
447
448pub mod forward;
449pub mod h3_forward;
450pub mod route;
451pub mod tunnel;