Skip to main content

iroh_proxy_utils/
upstream.rs

1use std::{
2    sync::{
3        Arc,
4        atomic::{AtomicU64, Ordering},
5    },
6    time::Duration,
7};
8
9use http::{StatusCode, Version};
10use iroh::{
11    EndpointId,
12    endpoint::{Connection, ConnectionError, RecvStream, SendStream},
13    protocol::{AcceptError, ProtocolHandler},
14};
15use n0_error::{Result, StackResultExt, StdResultExt};
16use n0_future::stream::StreamExt;
17use tokio::{
18    io::{AsyncWrite, AsyncWriteExt},
19    net::TcpStream,
20};
21use tokio_util::{future::FutureExt, sync::CancellationToken, task::TaskTracker};
22use tracing::{Instrument, debug, error_span, instrument, warn};
23
24use crate::{
25    Authority, HEADER_SECTION_MAX_LENGTH, HttpResponse,
26    parse::{
27        HttpProxyRequestKind, HttpRequest, absolute_target_to_origin_form,
28        filter_hop_by_hop_headers,
29    },
30    util::{
31        Prebuffered, StreamEvent, TrackedRead, TrackedStream, TrackedWrite, forward_bidi, nores,
32        recv_to_stream,
33    },
34};
35
36mod auth;
37mod metrics;
38pub use auth::*;
39pub use metrics::*;
40
/// Maximum time `shutdown` waits for tracked stream tasks to finish after the
/// shutdown token has been cancelled.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(1);

/// Supported HTTP upgrade protocols. Only these will be forwarded with upgrade support.
///
/// Entries are compared ASCII-case-insensitively against the value of the request's
/// `Upgrade` header.
const SUPPORTED_UPGRADE_PROTOCOLS: &[&str] = &["websocket"];
45
/// Proxy that receives iroh streams and forwards them to origin servers.
///
/// The upstream proxy is the server-side component that accepts connections from
/// downstream proxies over iroh and forwards requests to actual TCP origin servers.
///
/// # Protocol Support
///
/// - **CONNECT tunnels**: Establishes TCP connections to the requested authority
///   and bidirectionally forwards data.
/// - **Absolute-form requests**: Forwards HTTP requests to origin servers using
///   reqwest, with hop-by-hop header filtering per RFC 9110.
/// - **HTTP upgrades**: Absolute-form requests whose `Upgrade` header names a
///   supported protocol (currently `websocket`) bypass reqwest. The request is sent
///   to the origin over a direct TCP connection and, if the origin answers
///   `101 Switching Protocols`, both directions are piped through.
///
/// # Authorization
///
/// All requests pass through an [`AuthHandler`] before processing. Unauthorized
/// requests receive a 403 Forbidden response.
///
/// # Usage
///
/// Implements [`ProtocolHandler`] for use with iroh's [`Router`](iroh::protocol::Router):
///
/// ```ignore
/// let proxy = UpstreamProxy::new(AcceptAll)?;
/// let router = Router::builder(endpoint)
///     .accept(ALPN, proxy)
///     .spawn();
/// ```
#[derive(derive_more::Debug)]
pub struct UpstreamProxy {
    /// Authorization handler consulted for every request before it is forwarded.
    #[debug("Arc<dyn AuthHandler>")]
    auth: Arc<DynAuthHandler<'static>>,
    /// Counter that supplies the `id` field of the `accept` tracing span; incremented
    /// once per accepted connection.
    conn_id: Arc<AtomicU64>,
    /// Cancelled by `shutdown`; stops the per-connection accept loops.
    shutdown: CancellationToken,
    /// Tracks the per-stream tasks so `shutdown` can wait for them.
    tasks: TaskTracker,
    /// Client used to forward absolute-form (non-upgrade) requests to origins.
    http_client: reqwest::Client,
    /// Connection and request counters, shared with callers via `metrics()`.
    metrics: Arc<UpstreamMetrics>,
}
83
84impl ProtocolHandler for UpstreamProxy {
85    #[instrument("accept", level="error", skip_all, fields(id=self.conn_id.fetch_add(1, Ordering::SeqCst)))]
86    async fn accept(
87        &self,
88        connection: Connection,
89    ) -> std::result::Result<(), iroh::protocol::AcceptError> {
90        debug!(remote_id=%connection.remote_id().fmt_short(), "accepted connection");
91        self.metrics.connections_accepted.inc();
92        let res = self
93            .handle_connection(connection)
94            .await
95            .map_err(AcceptError::from_err);
96        self.metrics.connections_completed.inc();
97        res
98    }
99
100    async fn shutdown(&self) {
101        self.shutdown.cancel();
102        self.tasks.close();
103        debug!("shutting down ({} pending tasks)", self.tasks.len());
104        match self.tasks.wait().timeout(GRACEFUL_SHUTDOWN_TIMEOUT).await {
105            Ok(_) => debug!("all streams closed cleanly"),
106            Err(_) => debug!(
107                remaining = self.tasks.len(),
108                "not all streams closed in time, abort"
109            ),
110        }
111    }
112}
113
114impl UpstreamProxy {
115    /// Creates a new upstream proxy with the provided authorization handler.
116    pub fn new(auth: impl AuthHandler + 'static) -> Result<Self> {
117        Ok(Self {
118            auth: DynAuthHandler::new_arc(auth),
119            conn_id: Default::default(),
120            shutdown: CancellationToken::new(),
121            tasks: TaskTracker::new(),
122            http_client: reqwest::Client::new(),
123            metrics: Default::default(),
124        })
125    }
126
127    /// Returns the metrics tracker for this upstream proxy.
128    pub fn metrics(&self) -> Arc<UpstreamMetrics> {
129        self.metrics.clone()
130    }
131
132    /// Returns a future that resolves when this upstream proxy begins shutting down.
133    pub fn on_shutdown(&self) -> impl Future<Output = ()> + Send + 'static + use<> {
134        self.shutdown.clone().cancelled_owned()
135    }
136
137    async fn handle_connection(&self, connection: Connection) -> Result<()> {
138        let remote_id = connection.remote_id();
139        let mut stream_id = 0;
140        loop {
141            let (send, recv) = match connection
142                .accept_bi()
143                .with_cancellation_token(&self.shutdown)
144                .await
145            {
146                None => return Ok(()),
147                Some(Ok(streams)) => streams,
148                Some(Err(ConnectionError::ApplicationClosed(_))) => {
149                    debug!("connection closed by downstream remote");
150                    return Ok(());
151                }
152                Some(Err(err)) => {
153                    return Err(err).std_context("failed to accept streams");
154                }
155            };
156            let auth = self.auth.clone();
157            let shutdown = self.shutdown.clone();
158            let http_client = self.http_client.clone();
159            let metrics = self.metrics.clone();
160            self.tasks.spawn(
161                // We don't actually shutdown the stream task. If it didn't end by the time we stop waiting at shutdown,
162                // the connection will be closed, which causes the task to finish.
163                async move {
164                    if let Err(err) = Self::handle_remote_streams(
165                        auth,
166                        remote_id,
167                        send,
168                        recv,
169                        http_client,
170                        metrics,
171                    )
172                    .await
173                    {
174                        if shutdown.is_cancelled() {
175                            debug!("aborted at shutdown: {err:#}");
176                        } else {
177                            warn!("failed to handle streams: {err:#}");
178                        }
179                    }
180                }
181                .instrument(error_span!("stream", id=%stream_id)),
182            );
183            stream_id += 1;
184        }
185    }
186
187    async fn handle_remote_streams(
188        auth: Arc<DynAuthHandler<'static>>,
189        remote_id: EndpointId,
190        mut downstream_send: SendStream,
191        downstream_recv: RecvStream,
192        http_client: reqwest::Client,
193        metrics: Arc<UpstreamMetrics>,
194    ) -> Result<()> {
195        let mut downstream_recv = Prebuffered::new(downstream_recv, HEADER_SECTION_MAX_LENGTH);
196        let (request_len, req) = HttpRequest::peek(&mut downstream_recv).await?;
197        downstream_recv.discard(request_len);
198
199        debug!(?req, "handle request");
200        let req = req
201            .try_into_proxy_request()
202            .context("Received origin-form request but expected proxy request")?;
203
204        let id = req.kind.authority()?;
205        let req_metrics = metrics.get_or_insert(id);
206        req_metrics.bytes_to_origin.inc_by(request_len as u64);
207
208        match auth.authorize(remote_id, &req).await {
209            Ok(()) => {
210                metrics.requests_accepted.inc();
211                req_metrics.requests_accepted.inc();
212                debug!("request is authorized, continue");
213            }
214            Err(reason) => {
215                metrics.requests_denied.inc();
216                req_metrics.requests_denied.inc();
217                debug!(?reason, "request is not authorized, abort");
218                HttpResponse::new(StatusCode::FORBIDDEN)
219                    .no_body()
220                    .write(&mut downstream_send, true)
221                    .await
222                    .ok();
223                downstream_send.finish().anyerr()?;
224                return Ok(());
225            }
226        };
227
228        match req.kind {
229            HttpProxyRequestKind::Tunnel { target: authority } => {
230                debug!(%authority, "tunnel request: connecting to origin");
231                match TcpStream::connect(authority.to_addr()).await {
232                    Err(err) => {
233                        warn!("Failed to connect to origin server: {err:#}");
234                        metrics.requests_failed.inc();
235                        req_metrics.requests_failed.inc();
236                        error_response_and_finish(downstream_send).await?;
237                        Ok(())
238                    }
239                    Ok(tcp_stream) => {
240                        debug!(%authority, "connected to origin");
241                        HttpResponse::with_reason(StatusCode::OK, "Connection Established")
242                            .write(&mut downstream_send, true)
243                            .await
244                            .context("Failed to write CONNECT response to downstream")?;
245                        let (mut origin_recv, mut origin_send) = tcp_stream.into_split();
246
247                        let mut downstream_recv = TrackedRead::new(&mut downstream_recv, |d| {
248                            req_metrics.bytes_to_origin.inc_by(d);
249                        });
250                        let mut downstream_send = TrackedWrite::new(&mut downstream_send, |d| {
251                            req_metrics.bytes_from_origin.inc_by(d);
252                        });
253
254                        match forward_bidi(
255                            &mut downstream_recv,
256                            &mut downstream_send,
257                            &mut origin_recv,
258                            &mut origin_send,
259                        )
260                        .await
261                        {
262                            Ok((to_origin, from_origin)) => {
263                                metrics.requests_completed.inc();
264                                req_metrics.requests_completed.inc();
265                                debug!(to_origin, from_origin, "finish");
266                                Ok(())
267                            }
268                            Err(err) => {
269                                metrics.requests_failed.inc();
270                                req_metrics.requests_failed.inc();
271                                Err(err)
272                            }
273                        }
274                    }
275                }
276            }
277            HttpProxyRequestKind::Absolute { method, target } => {
278                // Check if this is an upgrade request we should handle specially
279                let upgrade_protocol = req
280                    .headers
281                    .get(http::header::UPGRADE)
282                    .and_then(|v| v.to_str().ok())
283                    .filter(|proto| {
284                        SUPPORTED_UPGRADE_PROTOCOLS
285                            .iter()
286                            .any(|p| p.eq_ignore_ascii_case(proto))
287                    });
288
289                if let Some(protocol) = upgrade_protocol {
290                    debug!(%target, %protocol, "upgrade request: connecting to origin");
291                    let mut headers = req.headers;
292                    filter_hop_by_hop_headers(&mut headers);
293                    // Request came in absolute-form over the tunnel; convert to origin-form for the origin.
294                    let authority = Authority::from_absolute_uri(&target)?;
295                    let origin_form_uri = absolute_target_to_origin_form(&target)?;
296                    let request = HttpRequest {
297                        version: Version::HTTP_11,
298                        headers,
299                        uri: origin_form_uri,
300                        method,
301                    };
302                    match Self::handle_upgrade_request(
303                        authority,
304                        request,
305                        downstream_recv,
306                        downstream_send,
307                        req_metrics.clone(),
308                    )
309                    .await
310                    {
311                        Ok(()) => {
312                            metrics.requests_completed.inc();
313                            req_metrics.requests_completed.inc();
314                            Ok(())
315                        }
316                        Err(err) => {
317                            metrics.requests_failed.inc();
318                            req_metrics.requests_failed.inc();
319                            Err(err)
320                        }
321                    }
322                } else {
323                    debug!(%target, "origin request: connecting to origin");
324                    let body = {
325                        let req_metrics = req_metrics.clone();
326                        let body = recv_to_stream(downstream_recv);
327                        let body = TrackedStream::new(body, move |ev| match ev {
328                            StreamEvent::Data(n) => nores(req_metrics.bytes_to_origin.inc_by(n)),
329                            _ => {}
330                        });
331                        reqwest::Body::wrap_stream(body)
332                    };
333
334                    // Filter hop-by-hop headers before forwarding to upstream per RFC 9110.
335                    let mut headers = req.headers;
336                    filter_hop_by_hop_headers(&mut headers);
337
338                    // Forward the request to the upstream server.
339                    let mut response = match http_client
340                        .request(method, target.to_string())
341                        .headers(headers)
342                        .body(body)
343                        .send()
344                        .await
345                    {
346                        Ok(response) => response,
347                        Err(err) => {
348                            error_response_and_finish(downstream_send).await?;
349                            metrics.requests_failed.inc();
350                            req_metrics.requests_failed.inc();
351                            return Err(err).anyerr();
352                        }
353                    };
354                    filter_hop_by_hop_headers(response.headers_mut());
355                    debug!(?response, "received response from origin");
356                    let res = forward_reqwest_response(
357                        response,
358                        &mut downstream_send,
359                        req_metrics.clone(),
360                    )
361                    .await;
362                    match res {
363                        Ok(total) => {
364                            debug!(response_body_len=%total, "finish");
365                            metrics.requests_completed.inc();
366                            req_metrics.requests_completed.inc();
367                            Ok(())
368                        }
369                        Err(err) => {
370                            metrics.requests_failed.inc();
371                            req_metrics.requests_failed.inc();
372                            Err(err)
373                        }
374                    }
375                }
376            }
377        }
378    }
379
380    /// Handle HTTP upgrade requests (e.g., WebSocket) by connecting directly to origin.
381    ///
382    /// This bypasses reqwest since it doesn't support HTTP upgrades. We send the
383    /// request manually over TCP, and if we get 101 Switching Protocols, we pipe
384    /// the connection bidirectionally. The request URI should be in origin-form
385    /// (path + query only); `authority` is used for the TCP connection.
386    async fn handle_upgrade_request(
387        authority: Authority,
388        request: HttpRequest,
389        mut downstream_recv: Prebuffered<RecvStream>,
390        mut downstream_send: SendStream,
391        req_metrics: Arc<TargetMetrics>,
392    ) -> Result<()> {
393        // Connect to origin
394        let origin = match TcpStream::connect(authority.to_addr()).await {
395            Ok(stream) => stream,
396            Err(err) => {
397                warn!("Failed to connect to origin for upgrade: {err:#}");
398                error_response_and_finish(downstream_send).await?;
399                return Err(err).anyerr();
400            }
401        };
402        let (origin_recv, mut origin_send) = origin.into_split();
403
404        let mut downstream_recv = TrackedRead::new(&mut downstream_recv, |d| {
405            req_metrics.bytes_to_origin.inc_by(d);
406        });
407        let mut downstream_send = TrackedWrite::new(&mut downstream_send, |d| {
408            req_metrics.bytes_from_origin.inc_by(d);
409        });
410
411        // Send the HTTP request to origin
412        request.write(&mut origin_send).await?;
413
414        // Read and forward the response from origin (expect 101 Switching Protocols)
415        let mut origin_recv = Prebuffered::new(origin_recv, HEADER_SECTION_MAX_LENGTH);
416        let response = HttpResponse::read(&mut origin_recv).await?;
417        debug!(?response, "upgrade response from origin");
418        response.write(&mut downstream_send, true).await?;
419
420        if response.status != StatusCode::SWITCHING_PROTOCOLS {
421            downstream_send.into_inner().finish().anyerr()?;
422            return Ok(());
423        }
424
425        // Pipe bidirectionally after successful upgrade
426        let (to_origin, from_origin) = forward_bidi(
427            &mut downstream_recv,
428            &mut downstream_send,
429            &mut origin_recv,
430            &mut origin_send,
431        )
432        .await?;
433        debug!(to_origin, from_origin, "upgrade connection finished");
434        Ok(())
435    }
436}
437
438async fn forward_reqwest_response(
439    response: reqwest::Response,
440    send: &mut SendStream,
441    req_metrics: Arc<TargetMetrics>,
442) -> Result<usize> {
443    let mut send = TrackedWrite::new(send, |d| {
444        req_metrics.bytes_from_origin.inc_by(d);
445    });
446    write_response(&response, &mut send).await?;
447    let send = send.into_inner();
448    let mut total = 0;
449    let mut body = response.bytes_stream();
450    while let Some(bytes) = body.next().await {
451        let bytes = bytes.anyerr()?;
452        total += bytes.len();
453        req_metrics.bytes_from_origin.inc_by(bytes.len() as u64);
454        send.write_chunk(bytes).await.anyerr()?;
455    }
456    send.finish().anyerr()?;
457    Ok(total)
458}
459
460async fn error_response_and_finish(mut send: SendStream) -> Result<(), n0_error::AnyError> {
461    HttpResponse::with_reason(StatusCode::BAD_GATEWAY, "Origin Is Unreachable")
462        .no_body()
463        .write(&mut send, true)
464        .await
465        .inspect_err(|err| warn!("Failed to write error response to downstream: {err:#}"))
466        .ok();
467    send.finish().anyerr()?;
468    Ok(())
469}
470
471async fn write_response(
472    res: &reqwest::Response,
473    send: &mut (impl AsyncWrite + Unpin),
474) -> Result<()> {
475    let status_line = format!(
476        "{:?} {} {}\r\n",
477        res.version(),
478        res.status().as_u16(),
479        // TODO: get reason phrase as returned from upstream.
480        res.status().canonical_reason().unwrap_or_default()
481    );
482    send.write_all(status_line.as_bytes()).await.anyerr()?;
483
484    for (name, value) in res.headers().iter() {
485        send.write_all(name.as_str().as_bytes()).await.anyerr()?;
486        send.write_all(b": ").await.anyerr()?;
487        send.write_all(value.as_bytes()).await.anyerr()?;
488        send.write_all(b"\r\n").await.anyerr()?;
489    }
490    send.write_all(b"\r\n").await.anyerr()?;
491    Ok(())
492}