Skip to main content

genmeta_proxy/
lib.rs

1use std::{io::IsTerminal, net::SocketAddr, sync::Arc, time::Duration};
2
3use clap::Parser;
4use dhttp::{
5    ddns::resolvers::DnsScheme,
6    dquic::binds::BindPattern,
7    endpoint::Endpoint,
8    home::{self, DhttpHome, identity::IdentityProfile},
9    message::IntoUri,
10    name::DhttpName as Name,
11};
12use http_body_util::BodyExt;
13use snafu::{IntoError, Report, ResultExt, Snafu};
14use tokio::{net::TcpListener, sync::Semaphore, task::JoinSet};
15use tracing::Instrument;
16use tracing_subscriber::prelude::*;
17
// Command-line options for the proxy, parsed by clap. The `///` comments on the
// fields below are rendered by clap as `--help` text, so review notes use `//`.
#[derive(Parser, Debug, Clone)]
#[command(version, about)]
pub struct Options {
    // Each entry's host must be a concrete IP address; `bind_listeners` rejects
    // anything for which `as_ip_addr()` returns `None`.
    /// Proxy listen address patterns
    #[arg(long = "listen", value_name = "bind", default_values = ["127.0.0.1:16080", "[::1]:16080"])]
    pub listens: Vec<BindPattern>,

    // When absent (and not `--anonymous`), `load_identity_profile` falls back to
    // the dhttp home's default identity.
    /// Client identity for DHTTP/3 connections
    #[arg(short, long, value_name = "client_identity")]
    pub id: Option<Name<'static>>,

    // Mutually exclusive with `--id` (enforced by clap via `conflicts_with`).
    /// Skip identity loading and use anonymous mode
    #[arg(long, conflicts_with = "id")]
    pub anonymous: bool,

    // Comma-separated; hidden from `--help` when debug assertions are off
    // (typically release builds).
    /// DNS resolution schemes
    #[arg(long, value_name = "scheme", default_values = ["mdns", "h3"], value_delimiter = ',', hide = cfg!(not(debug_assertions)))]
    pub dns: Vec<DnsScheme>,

    // Handed to the endpoint builder's `bind`; hidden from `--help` when debug
    // assertions are off.
    /// Bind patterns for DHTTP/3 connections
    #[arg(long = "interface", value_name = "bind", default_value = "*", hide = cfg!(not(debug_assertions)))]
    pub binds: Vec<BindPattern>,

    // NOTE(review): not read anywhere in lib.rs — confirm it is consumed
    // elsewhere, otherwise this flag currently has no effect.
    /// Show detailed request logging
    #[arg(short, long)]
    pub verbose: bool,

    // When set, `init_tracing` does not open `--log` itself and writes to stderr
    // instead; presumably the binary redirects stderr when daemonizing — TODO confirm.
    /// Run as daemon (background process)
    #[arg(long)]
    pub daemon: bool,

    // Opened in append mode (created if missing) by `init_tracing`, unless
    // `--daemon` is also set.
    /// Log file path (write tracing output to this file instead of stderr)
    #[arg(long, value_name = "path")]
    pub log: Option<std::path::PathBuf>,
}
53
/// Errors produced while setting up or running the proxy.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
    /// The URI of a genmeta plain-HTTP request could not be normalized into a
    /// DHTTP URI.
    #[snafu(display("failed to normalize dhttp uri"))]
    NormalizeUri {
        source: dhttp::message::IntoUriError,
    },

    /// The dhttp home could not be located while an explicit `--id` was given.
    #[snafu(display("failed to locate dhttp config"))]
    LocateDhttpHome { source: home::LocateDhttpHomeError },

    /// The identity named by `--id` could not be resolved to a profile.
    #[snafu(display("failed to load explicit identity `{name}`"))]
    LoadExplicitIdentity {
        name: Name<'static>,
        source: dhttp::home::identity::ssl::ResolveIdentityProfileError,
    },

    /// The selected identity profile's certificate and key could not be loaded.
    #[snafu(display("failed to load identity certificate and key"))]
    LoadIdentitySsl {
        source: dhttp::home::identity::ssl::LoadIdentityError,
    },

    /// Building the DHTTP client endpoint failed.
    #[snafu(display("failed to build dhttp endpoint"))]
    BuildEndpoint {
        source: dhttp::endpoint::BuildEndpointError,
    },

    /// Binding one of the `--listen` addresses failed. Note: the failing address
    /// is not recorded in this variant.
    #[snafu(display("failed to bind proxy listener"))]
    BindListener { source: std::io::Error },

    /// Connecting to a tunnel target failed (not raised in lib.rs; presumably by
    /// the `tunnel` module).
    #[snafu(display("failed to connect to tunnel target `{addr}`"))]
    TunnelConnect {
        addr: String,
        source: std::io::Error,
    },

    /// Upgrading the client connection for a tunnel failed (presumably raised by
    /// the `tunnel` module).
    #[snafu(display("failed to upgrade tunnel connection"))]
    TunnelUpgrade { source: hyper::Error },

    /// Opening a TCP connection to a forward target failed (presumably raised by
    /// the `forward` module).
    #[snafu(display("failed to connect to `{addr}`"))]
    ForwardConnect {
        addr: String,
        source: std::io::Error,
    },

    /// The HTTP handshake with a forward target failed (presumably raised by the
    /// `forward` module).
    #[snafu(display("failed to perform HTTP handshake with `{addr}`"))]
    ForwardHandshake { addr: String, source: hyper::Error },

    /// Sending a forwarded request upstream failed.
    #[snafu(display("failed to send HTTP request"))]
    ForwardSendRequest { source: hyper::Error },

    /// A request to be forwarded carried no host.
    #[snafu(display("missing host in request"))]
    ForwardMissingHost {},

    /// A request's `Host` header was not valid visible ASCII.
    #[snafu(display("invalid host header"))]
    ForwardInvalidHost { source: hyper::header::ToStrError },

    /// Daemonizing the process failed (Unix only; not raised in lib.rs).
    #[snafu(display("failed to daemonize"))]
    #[cfg(unix)]
    Daemonize { source: daemonize::Error },

    /// The `--log` file could not be opened for appending.
    #[snafu(display("failed to create log file `{}`", path.display()))]
    CreateLogFile {
        path: std::path::PathBuf,
        source: std::io::Error,
    },

    // Ad-hoc, message-based errors built through `snafu::FromString`; Display
    // and `source()` are forwarded transparently to the inner `Whatever`.
    #[snafu(transparent)]
    Whatever { source: Box<snafu::Whatever> },
}
124
125impl snafu::FromString for Error {
126    type Source = <snafu::Whatever as snafu::FromString>::Source;
127
128    fn without_source(message: String) -> Self {
129        Error::Whatever {
130            source: Box::new(snafu::Whatever::without_source(message)),
131        }
132    }
133
134    fn with_source(source: Self::Source, message: String) -> Self {
135        Error::Whatever {
136            source: Box::new(snafu::Whatever::with_source(source, message)),
137        }
138    }
139}
140type BoxBody = http_body_util::combinators::UnsyncBoxBody<
141    bytes::Bytes,
142    Box<dyn std::error::Error + Send + Sync>,
143>;
144
145fn full_body(text: &'static str) -> BoxBody {
146    http_body_util::Full::new(bytes::Bytes::from(text))
147        .map_err(|never| match never {})
148        .boxed_unsync()
149}
150
151fn box_body<B>(body: B) -> BoxBody
152where
153    B: http_body_util::BodyExt<Data = bytes::Bytes> + Send + 'static,
154    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
155{
156    body.map_err(Into::into).boxed_unsync()
157}
158
159async fn handle_request(
160    req: hyper::Request<hyper::body::Incoming>,
161    client: &Endpoint,
162    router: &route::Router,
163) -> Result<hyper::Response<BoxBody>, hyper::Error> {
164    let route = router.classify(&req);
165    tracing::info!(method = %req.method(), uri = %req.uri(), route = ?route, "proxy request");
166    match route {
167        route::Route::GenmetaPlainHttp { .. } => {
168            let mut req = req;
169            let self_name = client.name();
170            match req.uri().clone().into_uri(self_name.as_ref()) {
171                Ok(uri) => *req.uri_mut() = uri,
172                Err(e) => {
173                    let error = NormalizeUriSnafu.into_error(e);
174                    tracing::error!(
175                        error = %Report::from_error(&error),
176                        "failed to normalize dhttp uri"
177                    );
178                    return Ok(hyper::Response::builder()
179                        .status(502)
180                        .body(full_body("Bad Gateway"))
181                        .expect("valid static response"));
182                }
183            }
184            match h3_forward::forward_h3(req, client).await {
185                Ok(resp) => Ok(resp.map(box_body)),
186                Err(e) => {
187                    tracing::error!(error = %Report::from_error(&e), "h3 forward failed");
188                    Ok(hyper::Response::builder()
189                        .status(502)
190                        .body(full_body("Bad Gateway"))
191                        .expect("valid static response"))
192                }
193            }
194        }
195        route::Route::GenmetaConnect { .. } => Ok(hyper::Response::builder()
196            .status(502)
197            .body(full_body("HTTPS proxy to .dhttp.net not supported"))
198            .expect("valid static response")),
199        route::Route::TunnelConnect { authority } => {
200            match tunnel::tunnel_connect(req, authority.as_str()).await {
201                Ok(resp) => Ok(resp.map(box_body)),
202                Err(e) => {
203                    tracing::error!(error = %Report::from_error(&e), "tunnel connect failed");
204                    Ok(hyper::Response::builder()
205                        .status(502)
206                        .body(full_body("Bad Gateway"))
207                        .expect("valid static response"))
208                }
209            }
210        }
211        route::Route::StandardForward { .. } => match forward::forward_http(req).await {
212            Ok(resp) => Ok(resp.map(box_body)),
213            Err(e) => {
214                tracing::error!(error = %Report::from_error(&e), "http forward failed");
215                Ok(hyper::Response::builder()
216                    .status(502)
217                    .body(full_body("Bad Gateway"))
218                    .expect("valid static response"))
219            }
220        },
221    }
222}
223
224/// Initialize tracing subscriber, optionally writing to a log file.
225fn init_tracing(options: &Options) -> Result<tracing_appender::non_blocking::WorkerGuard, Error> {
226    let (writer, guard) = if let Some(ref log_path) = options.log
227        && !options.daemon
228    {
229        let file = std::fs::OpenOptions::new()
230            .create(true)
231            .append(true)
232            .open(log_path)
233            .context(CreateLogFileSnafu {
234                path: log_path.clone(),
235            })?;
236        tracing_appender::non_blocking(file)
237    } else {
238        tracing_appender::non_blocking(std::io::stderr())
239    };
240    let use_ansi = (options.log.is_none() || options.daemon) && std::io::stderr().is_terminal();
241    tracing_subscriber::registry()
242        .with(
243            tracing_subscriber::fmt::layer()
244                .with_ansi(use_ansi)
245                .with_timer(tracing_subscriber::fmt::time::LocalTime::rfc_3339())
246                .with_writer(writer),
247        )
248        .with(
249            tracing_subscriber::EnvFilter::builder()
250                .with_default_directive(tracing_subscriber::filter::LevelFilter::INFO.into())
251                .from_env_lossy()
252                .add_directive(
253                    "netlink_packet_route=error"
254                        .parse()
255                        .expect("BUG: static tracing directive is valid"),
256                ),
257        )
258        .init();
259    Ok(guard)
260}
261
262/// Bind TCP listeners on the configured listen addresses.
263async fn bind_listeners(options: &Options) -> Result<Vec<TcpListener>, Error> {
264    let mut listeners = Vec::new();
265    for bind in &options.listens {
266        let ip = bind.host.as_ip_addr().ok_or_else(|| {
267            <Error as snafu::FromString>::without_source(format!(
268                "listen bind `{}` must be a concrete ip address",
269                bind.host
270            ))
271        })?;
272        let addr = SocketAddr::new(ip, bind.effective_port());
273        let listener = TcpListener::bind(addr).await.context(BindListenerSnafu)?;
274        tracing::info!(%addr, "proxy listening");
275        listeners.push(listener);
276    }
277    Ok(listeners)
278}
279
280async fn load_identity_profile(options: &Options) -> Result<Option<IdentityProfile>, Error> {
281    if options.anonymous {
282        return Ok(None);
283    }
284
285    let home = match DhttpHome::load_from_environment() {
286        Ok(home) => home,
287        Err(source) if options.id.is_none() => {
288            tracing::warn!(
289                error = %snafu::Report::from_error(&source),
290                "failed to locate dhttp config, using anonymous endpoint"
291            );
292            return Ok(None);
293        }
294        Err(source) => return Err(LocateDhttpHomeSnafu.into_error(source)),
295    };
296
297    if let Some(name) = &options.id {
298        tracing::debug!(%name, "trying to load command line identity");
299        return home
300            .resolve_identity_profile(name.clone())
301            .await
302            .context(LoadExplicitIdentitySnafu { name: name.clone() })
303            .map(Some);
304    }
305
306    match home.resolve_default_identity_profile().await {
307        Ok(identity) => {
308            tracing::debug!(name = %identity.name(), "using default identity");
309            Ok(Some(identity))
310        }
311        Err(source) => {
312            tracing::debug!(
313                error = %snafu::Report::from_error(&source),
314                "failed to load default identity, using anonymous endpoint"
315            );
316            Ok(None)
317        }
318    }
319}
320
321pub async fn run(options: Options) -> Result<(), Error> {
322    let _guard = init_tracing(&options)?;
323
324    let identity_profile = load_identity_profile(&options).await?;
325    let identity = match &identity_profile {
326        Some(profile) => Some(Arc::new(
327            profile
328                .load_identity()
329                .await
330                .context(LoadIdentitySslSnafu)?,
331        )),
332        None => None,
333    };
334
335    let mut builder = Endpoint::builder()
336        .bind(Arc::new(options.binds.clone()))
337        .maybe_identity(identity);
338    for scheme in options.dns.iter().copied() {
339        builder = builder.dns(scheme);
340    }
341    let client = Arc::new(builder.build().await.context(BuildEndpointSnafu)?);
342
343    let listeners = bind_listeners(&options).await?;
344    let router = Arc::new(route::Router::new());
345
346    let semaphore = Arc::new(Semaphore::new(1024));
347    let mut tasks = JoinSet::new();
348    for listener in listeners {
349        let client = client.clone();
350        let router = router.clone();
351        let semaphore = semaphore.clone();
352        tasks.spawn(accept_loop(listener, client, router, semaphore));
353    }
354
355    while let Some(result) = tasks.join_next().await {
356        match result {
357            Ok(()) => tracing::info!("listener task exited"),
358            Err(e) => {
359                tracing::error!(error = %snafu::Report::from_error(&e), "listener task panicked")
360            }
361        }
362    }
363
364    Ok(())
365}
366
367/// Configure TCP keepalive on a stream to detect dead peers.
368///
369/// After 60 seconds of idle, sends probes every 10 seconds; 3 consecutive
370/// failures trigger a RST (~90 seconds total).
371fn configure_tcp_keepalive(stream: &tokio::net::TcpStream) {
372    let sock = socket2::SockRef::from(stream);
373    let keepalive = socket2::TcpKeepalive::new()
374        .with_time(Duration::from_secs(60))
375        .with_interval(Duration::from_secs(10));
376    // `with_retries` is only available on platforms that support TCP_KEEPCNT.
377    #[cfg(any(
378        target_os = "android",
379        target_os = "dragonfly",
380        target_os = "freebsd",
381        target_os = "fuchsia",
382        target_os = "illumos",
383        target_os = "linux",
384        target_os = "netbsd",
385        target_vendor = "apple",
386    ))]
387    let keepalive = keepalive.with_retries(3);
388    if let Err(e) = sock.set_tcp_keepalive(&keepalive) {
389        tracing::warn!(error = %e, "failed to set TCP keepalive");
390    }
391}
392
393/// Accept loop for a single TCP listener. Runs until the listener is dropped.
394async fn accept_loop(
395    listener: TcpListener,
396    client: Arc<Endpoint>,
397    router: Arc<route::Router>,
398    semaphore: Arc<Semaphore>,
399) {
400    loop {
401        let (stream, addr) = match listener.accept().await {
402            Ok(accepted) => accepted,
403            Err(e) => {
404                tracing::warn!(error = %snafu::Report::from_error(&e), "accept failed, retrying");
405                tokio::time::sleep(Duration::from_millis(33)).await;
406                continue;
407            }
408        };
409        configure_tcp_keepalive(&stream);
410        let permit = match semaphore.clone().acquire_owned().await {
411            Ok(permit) => permit,
412            Err(_) => break, // semaphore closed
413        };
414        tracing::debug!(%addr, "accepted connection");
415        let client = client.clone();
416        let router = router.clone();
417        let span = tracing::info_span!("conn", %addr);
418        // Inherent termination: TCP keepalive detects dead peers (~90s),
419        // header_read_timeout closes idle keep-alive connections (120s).
420        tokio::spawn(
421            async move {
422                let _permit = permit;
423                let io = hyper_util::rt::TokioIo::new(stream);
424                if let Err(e) = hyper::server::conn::http1::Builder::new()
425                    .timer(hyper_util::rt::TokioTimer::new())
426                    .header_read_timeout(Some(Duration::from_secs(120)))
427                    .preserve_header_case(true)
428                    .title_case_headers(true)
429                    .serve_connection(
430                        io,
431                        hyper::service::service_fn(move |req| {
432                            let client = client.clone();
433                            let router = router.clone();
434                            async move { handle_request(req, &client, &router).await }
435                        }),
436                    )
437                    .with_upgrades()
438                    .await
439                {
440                    tracing::error!(error = %Report::from_error(&e), %addr, "connection error");
441                }
442            }
443            .instrument(span),
444        );
445    }
446}
447
448pub mod forward;
449pub mod h3_forward;
450pub mod route;
451pub mod tunnel;