//! Downstream proxy (`downstream.rs`): accepts TCP/UDS client connections and
//! forwards them to upstream proxies over pooled iroh QUIC connections.

1use std::{convert::Infallible, fmt::Debug, io, net::SocketAddr, sync::Arc};
2
3use bytes::Bytes;
4use http::{Method, StatusCode, Version, header};
5use http_body_util::{BodyExt, Empty, StreamBody, combinators::BoxBody};
6use hyper::{
7    Request, Response,
8    body::{Frame, Incoming},
9    service::service_fn,
10};
11use hyper_util::{
12    rt::{TokioExecutor, TokioIo},
13    server::conn::auto,
14};
15use iroh::{
16    Endpoint, EndpointId,
17    endpoint::{ConnectionError, RecvStream, SendStream},
18};
19use iroh_blobs::util::connection_pool::{self, ConnectionPool, ConnectionRef};
20use n0_error::{AnyError, Result, StdResultExt, anyerr, stack_error};
21use n0_future::TryStreamExt;
22use tokio::{
23    io::{AsyncRead, AsyncWrite, AsyncWriteExt},
24    net::{TcpListener, TcpStream},
25};
26use tokio_util::{io::ReaderStream, sync::CancellationToken};
27use tracing::{Instrument, debug, error_span, warn};
28
29pub use self::opts::{
30    Deny, ErrorResponder, HttpProxyOpts, PoolOpts, ProxyMode, RequestHandler, RequestHandlerChain,
31    StaticForwardProxy, StaticReverseProxy,
32};
33use crate::{
34    ALPN, Authority, HEADER_SECTION_MAX_LENGTH, inc_by_delta,
35    parse::{HttpRequest, HttpResponse},
36    util::{
37        Prebufferable, Prebuffered, StreamEvent, TrackedRead, TrackedStream, TrackedWrite,
38        forward_bidi, nores,
39    },
40};
41
42pub(crate) mod metrics;
43pub use self::metrics::DownstreamMetrics;
44pub(crate) mod opts;
45
/// Proxy that accepts TCP connections and forwards them over iroh.
///
/// The downstream proxy is the client-facing component that receives incoming
/// TCP connections (typically HTTP requests) and routes them to upstream proxies
/// via iroh's peer-to-peer QUIC connections.
///
/// # Modes
///
/// - **TCP mode**: Blindly tunnels all traffic to a fixed upstream destination.
/// - **HTTP mode**: Parses HTTP requests to enable dynamic routing and supports
///   both forward proxy (absolute-form) and reverse proxy (origin-form) requests.
///
/// # Connection Pooling
///
/// Maintains a pool of iroh connections to upstream endpoints for efficiency.
/// Multiple requests to the same endpoint share a single QUIC connection.
#[derive(Clone, Debug)]
pub struct DownstreamProxy {
    // Pool of iroh connections to upstream endpoints; `Clone` is cheap because
    // both fields are shared handles.
    pool: ConnectionPool,
    // Shared metrics tracker, also exposed via [`Self::metrics`].
    metrics: Arc<DownstreamMetrics>,
}
67
68impl DownstreamProxy {
69    /// Creates a downstream proxy with the given endpoint and pool options.
70    pub fn new(endpoint: Endpoint, pool_opts: PoolOpts) -> Self {
71        let metrics = Arc::new(DownstreamMetrics::default());
72        let opts: connection_pool::Options = pool_opts.into();
73
74        // Track connection open/close metrics.
75        let pool_opts = opts.with_on_connected({
76            let metrics = metrics.clone();
77            // This conn clone is *not* protected by a pool guard, so by awaiting its close method
78            // we don't keep the connection alive post-idle.
79            move |_endpoint, unguarded_conn| {
80                let metrics = metrics.clone();
81                async move {
82                    metrics.iroh_connections_opened.inc();
83                    let metrics = metrics.clone();
84                    tokio::spawn(async move {
85                        let reason = unguarded_conn.closed().await;
86                        match reason {
87                            ConnectionError::LocallyClosed => {
88                                metrics.iroh_connections_closed_idle.inc();
89                            }
90                            // It's always us that closes connections gracefully.
91                            _ => {
92                                metrics.iroh_connections_closed_error.inc();
93                            }
94                        }
95                    });
96                    Ok(())
97                }
98            }
99        });
100
101        let pool = ConnectionPool::new(endpoint, ALPN, pool_opts);
102        Self { pool, metrics }
103    }
104
105    /// Returns the downstream metrics tracker.
106    pub fn metrics(&self) -> &Arc<DownstreamMetrics> {
107        &self.metrics
108    }
109
110    /// Opens a CONNECT tunnel to the upstream proxy and returns the client streams.
111    ///
112    /// Note: any non-`200 OK` response from upstream is returned as a `ProxyError`.
113    pub async fn create_tunnel(
114        &self,
115        destination: &EndpointAuthority,
116    ) -> Result<TunnelClientStreams, ProxyError> {
117        let (conn, mut send, recv) = self
118            .connect(destination.endpoint_id)
119            .await
120            .map_err(ProxyError::gateway_timeout)?;
121        send.write_all(destination.authority.to_connect_request().as_bytes())
122            .await?;
123        let mut recv = Prebuffered::new(recv, HEADER_SECTION_MAX_LENGTH);
124        let response = HttpResponse::read(&mut recv)
125            .await
126            .map_err(ProxyError::bad_gateway)?;
127        debug!(status=%response.status, "response from upstream");
128        if response.status != StatusCode::OK {
129            Err(ProxyError::new(
130                Some(response.status),
131                anyerr!("Upstream gateway returned error response"),
132            ))
133        } else {
134            Ok(TunnelClientStreams { send, recv, conn })
135        }
136    }
137
138    /// Accepts TCP connections from the listener and forwards each in a new task.
139    ///
140    /// Runs indefinitely until the listener errors or the task is cancelled.
141    pub async fn forward_tcp_listener(&self, listener: TcpListener, mode: ProxyMode) -> Result<()> {
142        let cancel_token = CancellationToken::new();
143        let _cancel_guard = cancel_token.clone().drop_guard();
144        let mut id = 0;
145        loop {
146            let (stream, addr) = listener.accept().await?;
147            let span = error_span!("tcp-accept", id);
148            let addr = SrcAddr::Tcp(addr);
149            self.spawn_forward_stream(addr, stream, mode.clone(), span, cancel_token.child_token());
150            id += 1;
151        }
152    }
153
154    /// Accepts UDS connections from the Unix Domain Socket and forwards each in a new task.
155    ///
156    /// Runs indefinitely until the listener errors or the task is cancelled.
157    #[cfg(unix)]
158    pub async fn forward_uds_listener(
159        &self,
160        listener: tokio::net::UnixListener,
161        mode: ProxyMode,
162    ) -> Result<()> {
163        let cancel_token = CancellationToken::new();
164        let _cancel_guard = cancel_token.clone().drop_guard();
165        let mut id = 0;
166        loop {
167            let (stream, addr) = listener.accept().await?;
168            let addr = SrcAddr::Unix(addr.into());
169            let span = error_span!("uds-accept", id);
170            self.spawn_forward_stream(addr, stream, mode.clone(), span, cancel_token.child_token());
171            id += 1;
172        }
173    }
174
175    fn spawn_forward_stream(
176        &self,
177        client_addr: SrcAddr,
178        stream: impl SplittableStream,
179        mode: ProxyMode,
180        span: tracing::Span,
181        cancel_token: CancellationToken,
182    ) {
183        let this = self.clone();
184        tokio::spawn(
185            cancel_token
186                .child_token()
187                .run_until_cancelled_owned(async move {
188                    debug!(%client_addr, "accepted connection");
189                    if let Err(err) = this.forward_stream(client_addr, stream, &mode).await {
190                        warn!("Failed to handle connection: {err:#}");
191                    }
192                })
193                .instrument(span),
194        );
195    }
196
197    /// Forwards a single TCP stream.
198    ///
199    /// For [`ProxyMode::Http`], this parses the first HTTP request from the stream, and then forwards or rejects according
200    /// to the configured [`HttpProxyOpts`].
201    /// For [`ProxyMode::Tcp`], this creates a CONNECT tunnel to the configured upstream and authority, and forwards the TCP
202    /// stream without parsing anything.
203    async fn forward_stream(
204        &self,
205        src_addr: SrcAddr,
206        mut stream: impl SplittableStream + 'static,
207        mode: &ProxyMode,
208    ) -> Result<()> {
209        match mode {
210            ProxyMode::Tcp(destination) => {
211                self.metrics.requests_accepted.inc();
212                self.metrics.requests_accepted_tcp.inc();
213                let (tcp_recv, tcp_send) = stream.split();
214                let mut conn = self.create_tunnel(destination).await?;
215                debug!(endpoint_id=%conn.conn.remote_id().fmt_short(), "tunnel established");
216                let metrics = self.metrics.clone();
217                let mut tcp_recv =
218                    TrackedRead::new(tcp_recv, inc_by_delta!(metrics, bytes_to_upstream));
219                let mut tcp_send =
220                    TrackedWrite::new(tcp_send, inc_by_delta!(metrics, bytes_from_upstream));
221                let res =
222                    forward_bidi(&mut tcp_recv, &mut tcp_send, &mut conn.recv, &mut conn.send)
223                        .await
224                        .map_err(ProxyError::io);
225                match res {
226                    Ok(_) => {
227                        self.metrics.requests_completed.inc();
228                        Ok(())
229                    }
230                    Err(err) => {
231                        self.metrics.requests_failed.inc();
232                        Err(err.into())
233                    }
234                }
235            }
236            ProxyMode::Http(opts) => {
237                let io = TokioIo::new(stream);
238                let service = service_fn(|req| {
239                    let this = self.clone();
240                    let opts = opts.clone();
241                    let src_addr = src_addr.clone();
242                    async move {
243                        let res = match this.handle_hyper_request(src_addr, req, &opts).await {
244                            Ok(res) => res,
245                            Err(err) => {
246                                warn!("Error while forwarding HTTP/2 request: {err:#}");
247                                let status =
248                                    err.response_status().unwrap_or(StatusCode::BAD_GATEWAY);
249                                opts.error_response(status).await
250                            }
251                        };
252                        Ok::<_, Infallible>(res)
253                    }
254                });
255                let mut builder = auto::Builder::new(TokioExecutor::new());
256                builder
257                    .http2()
258                    .initial_stream_window_size(1 << 20)
259                    .initial_connection_window_size(1 << 20)
260                    .max_concurrent_streams(1024)
261                    .enable_connect_protocol();
262                builder.serve_connection_with_upgrades(io, service).await?;
263                Ok(())
264            }
265        }
266    }
267
268    async fn connect(
269        &self,
270        destination: EndpointId,
271    ) -> Result<(ConnectionRef, SendStream, RecvStream), ProxyError> {
272        let conn = self
273            .pool
274            .get_or_connect(destination)
275            .await
276            .map_err(|err| ProxyError::gateway_timeout(anyerr!(err)))?;
277        let (send, recv) = conn
278            .open_bi()
279            .await
280            .map_err(|err| ProxyError::bad_gateway(anyerr!(err)))?;
281        Ok((conn, send, recv))
282    }
283
284    async fn handle_hyper_request(
285        &self,
286        src_addr: SrcAddr,
287        mut request: Request<Incoming>,
288        opts: &HttpProxyOpts,
289    ) -> Result<Response<HyperBody>, ProxyError> {
290        debug!(?request, "incoming");
291
292        let original_version = request.version();
293        let is_upgrade = request.headers().contains_key(header::UPGRADE);
294        let is_connect = request.method() == Method::CONNECT;
295        let is_h2_extended_connect = convert_h2_extended_connect_to_upgrade(&mut request);
296        let upgrade = if is_connect || is_upgrade {
297            Some(hyper::upgrade::on(&mut request))
298        } else {
299            None
300        };
301
302        let (parts, body) = request.into_parts();
303        let mut request = HttpRequest::from_parts(parts);
304
305        let metrics = self.metrics.clone();
306
307        let destination = match opts
308            .request_handler
309            .handle_request(src_addr, &mut request)
310            .await
311        {
312            Ok(destination) => destination,
313            Err(deny) => {
314                metrics.requests_denied.inc();
315                return Err(ProxyError::from(deny));
316            }
317        };
318
319        // track metrics.
320        metrics.requests_accepted.inc();
321        if original_version == Version::HTTP_2 {
322            metrics.requests_accepted_h2.inc();
323            if is_connect {
324                if is_h2_extended_connect {
325                    metrics.requests_accepted_h2_extended_connect.inc();
326                } else {
327                    metrics.requests_accepted_h2_connect.inc();
328                }
329            }
330        } else {
331            metrics.requests_accepted_h1.inc();
332            if is_connect {
333                metrics.requests_accepted_h1_connect.inc();
334            }
335            if is_upgrade {
336                metrics.requests_accepted_h1_upgrade.inc();
337            }
338        }
339
340        // We always forward as HTTP/1.1.
341        request.version = Version::HTTP_11;
342        // Now we shouldn't mutate the request anymore.
343        let request = request;
344
345        debug!(destination=%destination.fmt_short(), ?request, is_connect, is_h2_extended_connect, is_upgrade, "pipe request to upstream");
346
347        // Connect to upstream.
348        let (conn, send, recv) = self.connect(destination).await?;
349        debug!(endpoint_id=%conn.remote_id().fmt_short(), "connected to upstream");
350
351        // We need to keep `conn` alive until the request is fully processed in both directions.
352        // `conn` is a `ConnectionRef` guard handed out from the connection pool. Once it is dropped,
353        // the connection is marked as idle, and will be closed after the idle timeout.
354        let conn_guard = Arc::new(conn);
355
356        // We want to track bytes written to/from upstream. And we store the `conn_guard` into the streams.
357        // Once both streams are dropped, the request is fully done, and we can drop the conn ref safely.
358        let mut upstream_send = TrackedWrite::new(send, inc_by_delta!(metrics, bytes_to_upstream))
359            .with_guard(conn_guard.clone());
360        let upstream_recv = TrackedRead::new(recv, inc_by_delta!(metrics, bytes_from_upstream))
361            .with_guard(conn_guard.clone());
362        // We need to prebuffer for reading the response before passing it on.
363        let mut upstream_recv = Prebuffered::new(upstream_recv, HEADER_SECTION_MAX_LENGTH);
364
365        // Send request headers.
366        request.write(&mut upstream_send).await?;
367
368        let response = if let Some(upgrade_fut) = upgrade {
369            // For upgrade requests: First read the response to see if the upgrade was accepted.
370            // Only then forward the request body as upgrade stream.
371            let mut response = match read_response(&mut upstream_recv).await {
372                Ok(response) => response,
373                Err(err) => {
374                    metrics.requests_failed.inc();
375                    return Err(err.into());
376                }
377            };
378            debug!(?response, "read connect response");
379
380            if is_h2_extended_connect && response.status == StatusCode::SWITCHING_PROTOCOLS {
381                response.status = StatusCode::OK;
382                response.headers.remove(header::UPGRADE);
383                response.headers.remove(header::CONNECTION);
384            }
385
386            let is_ok = is_connect && response.status == StatusCode::OK
387                || is_upgrade && response.status == StatusCode::SWITCHING_PROTOCOLS;
388
389            if is_ok {
390                spawn(forward_hyper_upgrade(
391                    upgrade_fut,
392                    upstream_recv,
393                    upstream_send,
394                ));
395                response_to_hyper::<tokio::io::Empty>(response, None, metrics)?
396            } else if request.method == Method::CONNECT {
397                response_to_hyper::<tokio::io::Empty>(response, None, metrics)?
398            } else {
399                spawn(forward_hyper_body_and_finish(body, upstream_send));
400                response_to_hyper(response, Some(upstream_recv), metrics)?
401            }
402        } else {
403            // For non-upgrade requests: Forward the body and read the response concurrently.
404            spawn(forward_hyper_body_and_finish(body, upstream_send));
405            let response = match read_response(&mut upstream_recv).await {
406                Ok(response) => response,
407                Err(err) => {
408                    metrics.requests_failed.inc();
409                    return Err(err.into());
410                }
411            };
412            debug!(
413                status = %response.status,
414                "received response header from upstream"
415            );
416            response_to_hyper(response, Some(upstream_recv), metrics)?
417        };
418
419        Ok(response)
420    }
421}
422
423fn convert_h2_extended_connect_to_upgrade(request: &mut Request<Incoming>) -> bool {
424    if request.version() != Version::HTTP_2 {
425        return false;
426    }
427    // Handle HTTP/2 extended CONNECT (RFC 8441) - convert to upgrade-style request.
428    // Extended CONNECT uses :protocol pseudo-header instead of Upgrade header.
429    let extended_connect_protocol = request
430        .extensions()
431        .get::<hyper::ext::Protocol>()
432        .map(|p| p.as_str().to_string());
433    if let Some(protocol) = extended_connect_protocol {
434        debug!(%protocol, "extended CONNECT request, converting to upgrade request");
435        *request.method_mut() = Method::GET;
436        request
437            .headers_mut()
438            .insert(header::UPGRADE, protocol.parse().unwrap());
439        request
440            .headers_mut()
441            .insert(header::CONNECTION, "upgrade".parse().unwrap());
442        true
443    } else {
444        false
445    }
446}
447
/// A byte stream that can be split into independent read and write halves.
///
/// Abstracts over [`TcpStream`] and (on unix) `tokio::net::UnixStream` so the
/// forwarding code can treat both transports uniformly.
trait SplittableStream: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {
    /// Splits the stream into read/write halves that borrow `self`.
    fn split<'a>(
        &'a mut self,
    ) -> (
        impl AsyncRead + Send + Unpin + 'a,
        impl AsyncWrite + Send + Unpin + 'a,
    );
}
456
457impl SplittableStream for TcpStream {
458    fn split<'a>(
459        &'a mut self,
460    ) -> (
461        impl AsyncRead + Send + Unpin + 'a,
462        impl AsyncWrite + Send + Unpin + 'a,
463    ) {
464        TcpStream::split(self)
465    }
466}
467
468#[cfg(unix)]
469impl SplittableStream for tokio::net::UnixStream {
470    fn split<'a>(
471        &'a mut self,
472    ) -> (
473        impl AsyncRead + Send + Unpin + 'a,
474        impl AsyncWrite + Send + Unpin + 'a,
475    ) {
476        tokio::net::UnixStream::split(self)
477    }
478}
479
/// Source address for downstream client streams.
///
/// Carried through request handling so routing decisions and log lines can
/// refer to the peer that opened the stream.
#[derive(derive_more::From, Debug, Clone, derive_more::Display)]
pub enum SrcAddr {
    /// TCP source address
    #[display("{_0}")]
    Tcp(SocketAddr),
    /// UDS source address
    #[cfg(unix)]
    #[display("Unix({_0:?})")]
    Unix(std::os::unix::net::SocketAddr),
}
491
/// Bidirectional QUIC streams for an established tunnel.
///
/// Returned by [`DownstreamProxy::create_tunnel`] after a successful CONNECT
/// handshake with the upstream proxy. Use these streams for bidirectional
/// data transfer through the tunnel.
pub struct TunnelClientStreams {
    /// Send stream toward the upstream proxy.
    pub send: SendStream,
    /// Receive stream from the upstream proxy (with read-ahead buffer).
    pub recv: Prebuffered<RecvStream>,
    /// Connection reference that keeps the QUIC connection alive.
    /// Dropping it marks the pooled connection as idle.
    pub conn: ConnectionRef,
}
505
/// Routing destination combining an iroh endpoint and target authority.
///
/// Specifies both the upstream proxy to connect to (via `endpoint_id`) and
/// the origin server to reach through that proxy (via `authority`).
#[derive(Debug, Clone)]
pub struct EndpointAuthority {
    /// Iroh endpoint identifier of the upstream proxy.
    pub endpoint_id: EndpointId,
    /// Target authority for the CONNECT request (host:port).
    pub authority: Authority,
}
517
518impl EndpointAuthority {
519    /// Creates a new endpoint-authority pair.
520    pub fn new(endpoint_id: EndpointId, authority: Authority) -> Self {
521        Self {
522            endpoint_id,
523            authority,
524        }
525    }
526
527    /// Returns a short string representation for logging.
528    pub fn fmt_short(&self) -> String {
529        format!("{}->{}", self.endpoint_id.fmt_short(), self.authority)
530    }
531}
532
/// Error from downstream proxy operations.
///
/// Pairs the underlying error with an optional HTTP status code to surface to
/// the client (see [`ProxyError::response_status`]).
#[stack_error(add_meta, derive)]
pub struct ProxyError {
    // Status to report to the downstream client; `None` for plain I/O errors.
    response_status: Option<StatusCode>,
    #[error(source)]
    source: AnyError,
}
540
impl From<Deny> for ProxyError {
    // A request-handler denial carries its own status code and reason.
    #[track_caller]
    fn from(value: Deny) -> Self {
        ProxyError::new(Some(value.code), value.reason)
    }
}
547
impl From<io::Error> for ProxyError {
    // I/O errors have no associated HTTP status (see `ProxyError::io`).
    fn from(value: io::Error) -> Self {
        Self::io(value)
    }
}
553
impl From<iroh::endpoint::WriteError> for ProxyError {
    // QUIC stream write errors are treated like plain I/O errors (no status).
    fn from(value: iroh::endpoint::WriteError) -> Self {
        Self::io(anyerr!(value))
    }
}
559
impl ProxyError {
    /// Returns the HTTP status code to surface to the client, if any.
    pub fn response_status(&self) -> Option<StatusCode> {
        self.response_status
    }

    /// Wraps `source` as a `504 Gateway Timeout` error (upstream unreachable).
    fn gateway_timeout(source: impl Into<AnyError>) -> Self {
        Self::new(Some(StatusCode::GATEWAY_TIMEOUT), source.into())
    }

    /// Wraps `source` as a `502 Bad Gateway` error (upstream misbehaved).
    fn bad_gateway(source: impl Into<AnyError>) -> Self {
        Self::new(Some(StatusCode::BAD_GATEWAY), source.into())
    }

    /// Wraps `source` as a plain I/O error with no HTTP status attached.
    fn io(source: impl Into<AnyError>) -> Self {
        Self::new(None, source.into())
    }
}
578
579type HyperBody = BoxBody<Bytes, io::Error>;
580
581fn response_to_hyper<R>(
582    response: HttpResponse,
583    body: Option<R>,
584    metrics: Arc<DownstreamMetrics>,
585) -> Result<Response<HyperBody>, ProxyError>
586where
587    R: AsyncRead + Send + Sync + Unpin + 'static,
588{
589    let mut builder = Response::builder().status(response.status);
590    let headers = builder.headers_mut().unwrap();
591    *headers = response.headers;
592    let body = match body {
593        Some(body) => {
594            let stream = ReaderStream::new(body);
595            let stream = TrackedStream::new(stream, move |ev| match ev {
596                StreamEvent::Done(Ok(())) => nores(metrics.requests_completed.inc()),
597                StreamEvent::Done(Err(_)) => nores(metrics.requests_failed.inc()),
598                _ => {}
599            });
600            StreamBody::new(stream.map_ok(Frame::data)).boxed()
601        }
602        None => Empty::new().map_err(infallible_to_io).boxed(),
603    };
604    builder
605        .body(body)
606        .map_err(|err| ProxyError::bad_gateway(anyerr!(err)))
607}
608
609async fn forward_hyper_body_and_finish<F, G: Unpin>(
610    body: Incoming,
611    mut send: TrackedWrite<SendStream, F, G>,
612) -> Result<()>
613where
614    F: Fn(u64) + Unpin + Send + 'static,
615{
616    forward_hyper_body(body, &mut send).await?;
617    send.into_inner().finish().anyerr()?;
618    Ok(())
619}
620
621/// Forwards hyper body to send stream without finishing.
622/// Used for upgrade requests where we may need to continue using the stream.
623async fn forward_hyper_body(
624    mut body: Incoming,
625    send: &mut (impl AsyncWrite + Unpin),
626) -> Result<()> {
627    while let Some(frame) = body.frame().await {
628        let frame = frame.anyerr()?;
629        // TODO: Add support for trailers.
630        if let Ok(data) = frame.into_data() {
631            send.write_all(&data).await.anyerr()?;
632        }
633    }
634    Ok(())
635}
636
637async fn forward_hyper_upgrade(
638    upgrade_fut: hyper::upgrade::OnUpgrade,
639    mut upstream_recv: impl AsyncRead + Send + Unpin,
640    mut upstream_send: impl AsyncWrite + Send + Unpin,
641) -> Result<()> {
642    let upgraded = upgrade_fut.await.std_context("HTTP/1 upgrade failed")?;
643    let upgraded = TokioIo::new(upgraded);
644    // Split the upgraded connection for bidirectional copy
645    let (mut client_read, mut client_write) = tokio::io::split(upgraded);
646    forward_bidi(
647        &mut client_read,
648        &mut client_write,
649        &mut upstream_recv,
650        &mut upstream_send,
651    )
652    .await?;
653    Ok(())
654}
655
/// Reads and parses an HTTP response header section from `recv`.
///
/// Parse/read failures are surfaced as `502 Bad Gateway` [`ProxyError`]s.
async fn read_response(recv: &mut impl Prebufferable) -> Result<HttpResponse, ProxyError> {
    HttpResponse::read(recv)
        .await
        .map_err(ProxyError::bad_gateway)
}
661
/// Converts [`Infallible`] into [`io::Error`]; statically unreachable, exists
/// only to satisfy the body error type of [`Empty`].
fn infallible_to_io(err: Infallible) -> io::Error {
    // `Infallible` has no values, so this match has no arms.
    match err {}
}
665
666fn spawn<F, T>(fut: F) -> tokio::task::JoinHandle<()>
667where
668    F: Future<Output = Result<T>> + Send + 'static,
669{
670    tokio::spawn(
671        async move {
672            if let Err(err) = fut.await {
673                warn!("{err:#}")
674            }
675        }
676        .instrument(tracing::Span::current()),
677    )
678}