1use std::{convert::Infallible, fmt::Debug, io, net::SocketAddr, sync::Arc};
2
3use bytes::Bytes;
4use http::{Method, StatusCode, Version, header};
5use http_body_util::{BodyExt, Empty, StreamBody, combinators::BoxBody};
6use hyper::{
7 Request, Response,
8 body::{Frame, Incoming},
9 service::service_fn,
10};
11use hyper_util::{
12 rt::{TokioExecutor, TokioIo},
13 server::conn::auto,
14};
15use iroh::{
16 Endpoint, EndpointId,
17 endpoint::{ConnectionError, RecvStream, SendStream},
18};
19use iroh_blobs::util::connection_pool::{self, ConnectionPool, ConnectionRef};
20use n0_error::{AnyError, Result, StdResultExt, anyerr, stack_error};
21use n0_future::TryStreamExt;
22use tokio::{
23 io::{AsyncRead, AsyncWrite, AsyncWriteExt},
24 net::{TcpListener, TcpStream},
25};
26use tokio_util::{io::ReaderStream, sync::CancellationToken};
27use tracing::{Instrument, debug, error_span, warn};
28
29pub use self::opts::{
30 Deny, ErrorResponder, HttpProxyOpts, PoolOpts, ProxyMode, RequestHandler, RequestHandlerChain,
31 StaticForwardProxy, StaticReverseProxy,
32};
33use crate::{
34 ALPN, Authority, HEADER_SECTION_MAX_LENGTH, inc_by_delta,
35 parse::{HttpRequest, HttpResponse},
36 util::{
37 Prebufferable, Prebuffered, StreamEvent, TrackedRead, TrackedStream, TrackedWrite,
38 forward_bidi, nores,
39 },
40};
41
42pub(crate) mod metrics;
43pub use self::metrics::DownstreamMetrics;
44pub(crate) mod opts;
45
/// Downstream side of the proxy: accepts local client connections (TCP or Unix sockets) and
/// forwards them to upstream proxies reached over iroh.
///
/// Clones share the same metrics (held in an `Arc`); the connection pool is presumably a
/// shared handle as well — confirm against `ConnectionPool`'s `Clone` impl.
#[derive(Clone, Debug)]
pub struct DownstreamProxy {
    // Pool of iroh connections to upstream endpoints, negotiated with `ALPN`.
    pool: ConnectionPool,
    // Counters for accepted/completed/failed requests, bytes and connection lifecycle.
    metrics: Arc<DownstreamMetrics>,
}
67
68impl DownstreamProxy {
69 pub fn new(endpoint: Endpoint, pool_opts: PoolOpts) -> Self {
71 let metrics = Arc::new(DownstreamMetrics::default());
72 let opts: connection_pool::Options = pool_opts.into();
73
74 let pool_opts = opts.with_on_connected({
76 let metrics = metrics.clone();
77 move |_endpoint, unguarded_conn| {
80 let metrics = metrics.clone();
81 async move {
82 metrics.iroh_connections_opened.inc();
83 let metrics = metrics.clone();
84 tokio::spawn(async move {
85 let reason = unguarded_conn.closed().await;
86 match reason {
87 ConnectionError::LocallyClosed => {
88 metrics.iroh_connections_closed_idle.inc();
89 }
90 _ => {
92 metrics.iroh_connections_closed_error.inc();
93 }
94 }
95 });
96 Ok(())
97 }
98 }
99 });
100
101 let pool = ConnectionPool::new(endpoint, ALPN, pool_opts);
102 Self { pool, metrics }
103 }
104
105 pub fn metrics(&self) -> &Arc<DownstreamMetrics> {
107 &self.metrics
108 }
109
110 pub async fn create_tunnel(
114 &self,
115 destination: &EndpointAuthority,
116 ) -> Result<TunnelClientStreams, ProxyError> {
117 let (conn, mut send, recv) = self
118 .connect(destination.endpoint_id)
119 .await
120 .map_err(ProxyError::gateway_timeout)?;
121 send.write_all(destination.authority.to_connect_request().as_bytes())
122 .await?;
123 let mut recv = Prebuffered::new(recv, HEADER_SECTION_MAX_LENGTH);
124 let response = HttpResponse::read(&mut recv)
125 .await
126 .map_err(ProxyError::bad_gateway)?;
127 debug!(status=%response.status, "response from upstream");
128 if response.status != StatusCode::OK {
129 Err(ProxyError::new(
130 Some(response.status),
131 anyerr!("Upstream gateway returned error response"),
132 ))
133 } else {
134 Ok(TunnelClientStreams { send, recv, conn })
135 }
136 }
137
138 pub async fn forward_tcp_listener(&self, listener: TcpListener, mode: ProxyMode) -> Result<()> {
142 let cancel_token = CancellationToken::new();
143 let _cancel_guard = cancel_token.clone().drop_guard();
144 let mut id = 0;
145 loop {
146 let (stream, addr) = listener.accept().await?;
147 let span = error_span!("tcp-accept", id);
148 let addr = SrcAddr::Tcp(addr);
149 self.spawn_forward_stream(addr, stream, mode.clone(), span, cancel_token.child_token());
150 id += 1;
151 }
152 }
153
154 #[cfg(unix)]
158 pub async fn forward_uds_listener(
159 &self,
160 listener: tokio::net::UnixListener,
161 mode: ProxyMode,
162 ) -> Result<()> {
163 let cancel_token = CancellationToken::new();
164 let _cancel_guard = cancel_token.clone().drop_guard();
165 let mut id = 0;
166 loop {
167 let (stream, addr) = listener.accept().await?;
168 let addr = SrcAddr::Unix(addr.into());
169 let span = error_span!("uds-accept", id);
170 self.spawn_forward_stream(addr, stream, mode.clone(), span, cancel_token.child_token());
171 id += 1;
172 }
173 }
174
175 fn spawn_forward_stream(
176 &self,
177 client_addr: SrcAddr,
178 stream: impl SplittableStream,
179 mode: ProxyMode,
180 span: tracing::Span,
181 cancel_token: CancellationToken,
182 ) {
183 let this = self.clone();
184 tokio::spawn(
185 cancel_token
186 .child_token()
187 .run_until_cancelled_owned(async move {
188 debug!(%client_addr, "accepted connection");
189 if let Err(err) = this.forward_stream(client_addr, stream, &mode).await {
190 warn!("Failed to handle connection: {err:#}");
191 }
192 })
193 .instrument(span),
194 );
195 }
196
197 async fn forward_stream(
204 &self,
205 src_addr: SrcAddr,
206 mut stream: impl SplittableStream + 'static,
207 mode: &ProxyMode,
208 ) -> Result<()> {
209 match mode {
210 ProxyMode::Tcp(destination) => {
211 self.metrics.requests_accepted.inc();
212 self.metrics.requests_accepted_tcp.inc();
213 let (tcp_recv, tcp_send) = stream.split();
214 let mut conn = self.create_tunnel(destination).await?;
215 debug!(endpoint_id=%conn.conn.remote_id().fmt_short(), "tunnel established");
216 let metrics = self.metrics.clone();
217 let mut tcp_recv =
218 TrackedRead::new(tcp_recv, inc_by_delta!(metrics, bytes_to_upstream));
219 let mut tcp_send =
220 TrackedWrite::new(tcp_send, inc_by_delta!(metrics, bytes_from_upstream));
221 let res =
222 forward_bidi(&mut tcp_recv, &mut tcp_send, &mut conn.recv, &mut conn.send)
223 .await
224 .map_err(ProxyError::io);
225 match res {
226 Ok(_) => {
227 self.metrics.requests_completed.inc();
228 Ok(())
229 }
230 Err(err) => {
231 self.metrics.requests_failed.inc();
232 Err(err.into())
233 }
234 }
235 }
236 ProxyMode::Http(opts) => {
237 let io = TokioIo::new(stream);
238 let service = service_fn(|req| {
239 let this = self.clone();
240 let opts = opts.clone();
241 let src_addr = src_addr.clone();
242 async move {
243 let res = match this.handle_hyper_request(src_addr, req, &opts).await {
244 Ok(res) => res,
245 Err(err) => {
246 warn!("Error while forwarding HTTP/2 request: {err:#}");
247 let status =
248 err.response_status().unwrap_or(StatusCode::BAD_GATEWAY);
249 opts.error_response(status).await
250 }
251 };
252 Ok::<_, Infallible>(res)
253 }
254 });
255 let mut builder = auto::Builder::new(TokioExecutor::new());
256 builder
257 .http2()
258 .initial_stream_window_size(1 << 20)
259 .initial_connection_window_size(1 << 20)
260 .max_concurrent_streams(1024)
261 .enable_connect_protocol();
262 builder.serve_connection_with_upgrades(io, service).await?;
263 Ok(())
264 }
265 }
266 }
267
268 async fn connect(
269 &self,
270 destination: EndpointId,
271 ) -> Result<(ConnectionRef, SendStream, RecvStream), ProxyError> {
272 let conn = self
273 .pool
274 .get_or_connect(destination)
275 .await
276 .map_err(|err| ProxyError::gateway_timeout(anyerr!(err)))?;
277 let (send, recv) = conn
278 .open_bi()
279 .await
280 .map_err(|err| ProxyError::bad_gateway(anyerr!(err)))?;
281 Ok((conn, send, recv))
282 }
283
284 async fn handle_hyper_request(
285 &self,
286 src_addr: SrcAddr,
287 mut request: Request<Incoming>,
288 opts: &HttpProxyOpts,
289 ) -> Result<Response<HyperBody>, ProxyError> {
290 debug!(?request, "incoming");
291
292 let original_version = request.version();
293 let is_upgrade = request.headers().contains_key(header::UPGRADE);
294 let is_connect = request.method() == Method::CONNECT;
295 let is_h2_extended_connect = convert_h2_extended_connect_to_upgrade(&mut request);
296 let upgrade = if is_connect || is_upgrade {
297 Some(hyper::upgrade::on(&mut request))
298 } else {
299 None
300 };
301
302 let (parts, body) = request.into_parts();
303 let mut request = HttpRequest::from_parts(parts);
304
305 let metrics = self.metrics.clone();
306
307 let destination = match opts
308 .request_handler
309 .handle_request(src_addr, &mut request)
310 .await
311 {
312 Ok(destination) => destination,
313 Err(deny) => {
314 metrics.requests_denied.inc();
315 return Err(ProxyError::from(deny));
316 }
317 };
318
319 metrics.requests_accepted.inc();
321 if original_version == Version::HTTP_2 {
322 metrics.requests_accepted_h2.inc();
323 if is_connect {
324 if is_h2_extended_connect {
325 metrics.requests_accepted_h2_extended_connect.inc();
326 } else {
327 metrics.requests_accepted_h2_connect.inc();
328 }
329 }
330 } else {
331 metrics.requests_accepted_h1.inc();
332 if is_connect {
333 metrics.requests_accepted_h1_connect.inc();
334 }
335 if is_upgrade {
336 metrics.requests_accepted_h1_upgrade.inc();
337 }
338 }
339
340 request.version = Version::HTTP_11;
342 let request = request;
344
345 debug!(destination=%destination.fmt_short(), ?request, is_connect, is_h2_extended_connect, is_upgrade, "pipe request to upstream");
346
347 let (conn, send, recv) = self.connect(destination).await?;
349 debug!(endpoint_id=%conn.remote_id().fmt_short(), "connected to upstream");
350
351 let conn_guard = Arc::new(conn);
355
356 let mut upstream_send = TrackedWrite::new(send, inc_by_delta!(metrics, bytes_to_upstream))
359 .with_guard(conn_guard.clone());
360 let upstream_recv = TrackedRead::new(recv, inc_by_delta!(metrics, bytes_from_upstream))
361 .with_guard(conn_guard.clone());
362 let mut upstream_recv = Prebuffered::new(upstream_recv, HEADER_SECTION_MAX_LENGTH);
364
365 request.write(&mut upstream_send).await?;
367
368 let response = if let Some(upgrade_fut) = upgrade {
369 let mut response = match read_response(&mut upstream_recv).await {
372 Ok(response) => response,
373 Err(err) => {
374 metrics.requests_failed.inc();
375 return Err(err.into());
376 }
377 };
378 debug!(?response, "read connect response");
379
380 if is_h2_extended_connect && response.status == StatusCode::SWITCHING_PROTOCOLS {
381 response.status = StatusCode::OK;
382 response.headers.remove(header::UPGRADE);
383 response.headers.remove(header::CONNECTION);
384 }
385
386 let is_ok = is_connect && response.status == StatusCode::OK
387 || is_upgrade && response.status == StatusCode::SWITCHING_PROTOCOLS;
388
389 if is_ok {
390 spawn(forward_hyper_upgrade(
391 upgrade_fut,
392 upstream_recv,
393 upstream_send,
394 ));
395 response_to_hyper::<tokio::io::Empty>(response, None, metrics)?
396 } else if request.method == Method::CONNECT {
397 response_to_hyper::<tokio::io::Empty>(response, None, metrics)?
398 } else {
399 spawn(forward_hyper_body_and_finish(body, upstream_send));
400 response_to_hyper(response, Some(upstream_recv), metrics)?
401 }
402 } else {
403 spawn(forward_hyper_body_and_finish(body, upstream_send));
405 let response = match read_response(&mut upstream_recv).await {
406 Ok(response) => response,
407 Err(err) => {
408 metrics.requests_failed.inc();
409 return Err(err.into());
410 }
411 };
412 debug!(
413 status = %response.status,
414 "received response header from upstream"
415 );
416 response_to_hyper(response, Some(upstream_recv), metrics)?
417 };
418
419 Ok(response)
420 }
421}
422
423fn convert_h2_extended_connect_to_upgrade(request: &mut Request<Incoming>) -> bool {
424 if request.version() != Version::HTTP_2 {
425 return false;
426 }
427 let extended_connect_protocol = request
430 .extensions()
431 .get::<hyper::ext::Protocol>()
432 .map(|p| p.as_str().to_string());
433 if let Some(protocol) = extended_connect_protocol {
434 debug!(%protocol, "extended CONNECT request, converting to upgrade request");
435 *request.method_mut() = Method::GET;
436 request
437 .headers_mut()
438 .insert(header::UPGRADE, protocol.parse().unwrap());
439 request
440 .headers_mut()
441 .insert(header::CONNECTION, "upgrade".parse().unwrap());
442 true
443 } else {
444 false
445 }
446}
447
448trait SplittableStream: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {
449 fn split<'a>(
450 &'a mut self,
451 ) -> (
452 impl AsyncRead + Send + Unpin + 'a,
453 impl AsyncWrite + Send + Unpin + 'a,
454 );
455}
456
457impl SplittableStream for TcpStream {
458 fn split<'a>(
459 &'a mut self,
460 ) -> (
461 impl AsyncRead + Send + Unpin + 'a,
462 impl AsyncWrite + Send + Unpin + 'a,
463 ) {
464 TcpStream::split(self)
465 }
466}
467
468#[cfg(unix)]
469impl SplittableStream for tokio::net::UnixStream {
470 fn split<'a>(
471 &'a mut self,
472 ) -> (
473 impl AsyncRead + Send + Unpin + 'a,
474 impl AsyncWrite + Send + Unpin + 'a,
475 ) {
476 tokio::net::UnixStream::split(self)
477 }
478}
479
/// Address of the client that opened a connection to the downstream proxy.
#[derive(derive_more::From, Debug, Clone, derive_more::Display)]
pub enum SrcAddr {
    /// A TCP client; displayed as the plain socket address.
    #[display("{_0}")]
    Tcp(SocketAddr),
    /// A Unix domain socket client; displayed as `Unix(<debug form of the address>)`.
    #[cfg(unix)]
    #[display("Unix({_0:?})")]
    Unix(std::os::unix::net::SocketAddr),
}
491
/// Streams of an established CONNECT tunnel, as returned by
/// [`DownstreamProxy::create_tunnel`].
pub struct TunnelClientStreams {
    /// Sending half of the tunnel stream; bytes written here go to the upstream proxy.
    pub send: SendStream,
    /// Receiving half of the tunnel stream. It is wrapped in [`Prebuffered`] because the
    /// CONNECT response header was read through it; any bytes buffered past that header are
    /// presumably still yielded from here (confirm against `Prebuffered`).
    pub recv: Prebuffered<RecvStream>,
    /// Pooled connection the streams were opened on; presumably keeps the connection
    /// marked as in use while held.
    pub conn: ConnectionRef,
}
505
/// A target authority, reached through a specific upstream iroh endpoint.
#[derive(Debug, Clone)]
pub struct EndpointAuthority {
    /// Upstream proxy endpoint to connect to.
    pub endpoint_id: EndpointId,
    /// Authority sent to the upstream proxy in the CONNECT request.
    pub authority: Authority,
}
517
518impl EndpointAuthority {
519 pub fn new(endpoint_id: EndpointId, authority: Authority) -> Self {
521 Self {
522 endpoint_id,
523 authority,
524 }
525 }
526
527 pub fn fmt_short(&self) -> String {
529 format!("{}->{}", self.endpoint_id.fmt_short(), self.authority)
530 }
531}
532
#[stack_error(add_meta, derive)]
// Error produced while proxying a request or tunnel. It optionally carries the HTTP status
// that should be reported to the client.
pub struct ProxyError {
    // Status to answer the client with, if this error maps to one (see `response_status`).
    response_status: Option<StatusCode>,
    // Underlying cause.
    #[error(source)]
    source: AnyError,
}
540
541impl From<Deny> for ProxyError {
542 #[track_caller]
543 fn from(value: Deny) -> Self {
544 ProxyError::new(Some(value.code), value.reason)
545 }
546}
547
548impl From<io::Error> for ProxyError {
549 fn from(value: io::Error) -> Self {
550 Self::io(value)
551 }
552}
553
554impl From<iroh::endpoint::WriteError> for ProxyError {
555 fn from(value: iroh::endpoint::WriteError) -> Self {
556 Self::io(anyerr!(value))
557 }
558}
559
560impl ProxyError {
561 pub fn response_status(&self) -> Option<StatusCode> {
563 self.response_status
564 }
565
566 fn gateway_timeout(source: impl Into<AnyError>) -> Self {
567 Self::new(Some(StatusCode::GATEWAY_TIMEOUT), source.into())
568 }
569
570 fn bad_gateway(source: impl Into<AnyError>) -> Self {
571 Self::new(Some(StatusCode::BAD_GATEWAY), source.into())
572 }
573
574 fn io(source: impl Into<AnyError>) -> Self {
575 Self::new(None, source.into())
576 }
577}
578
/// Body type of the responses handed back to hyper: a boxed body of `Bytes` chunks that can
/// fail with an I/O error.
type HyperBody = BoxBody<Bytes, io::Error>;
580
581fn response_to_hyper<R>(
582 response: HttpResponse,
583 body: Option<R>,
584 metrics: Arc<DownstreamMetrics>,
585) -> Result<Response<HyperBody>, ProxyError>
586where
587 R: AsyncRead + Send + Sync + Unpin + 'static,
588{
589 let mut builder = Response::builder().status(response.status);
590 let headers = builder.headers_mut().unwrap();
591 *headers = response.headers;
592 let body = match body {
593 Some(body) => {
594 let stream = ReaderStream::new(body);
595 let stream = TrackedStream::new(stream, move |ev| match ev {
596 StreamEvent::Done(Ok(())) => nores(metrics.requests_completed.inc()),
597 StreamEvent::Done(Err(_)) => nores(metrics.requests_failed.inc()),
598 _ => {}
599 });
600 StreamBody::new(stream.map_ok(Frame::data)).boxed()
601 }
602 None => Empty::new().map_err(infallible_to_io).boxed(),
603 };
604 builder
605 .body(body)
606 .map_err(|err| ProxyError::bad_gateway(anyerr!(err)))
607}
608
609async fn forward_hyper_body_and_finish<F, G: Unpin>(
610 body: Incoming,
611 mut send: TrackedWrite<SendStream, F, G>,
612) -> Result<()>
613where
614 F: Fn(u64) + Unpin + Send + 'static,
615{
616 forward_hyper_body(body, &mut send).await?;
617 send.into_inner().finish().anyerr()?;
618 Ok(())
619}
620
621async fn forward_hyper_body(
624 mut body: Incoming,
625 send: &mut (impl AsyncWrite + Unpin),
626) -> Result<()> {
627 while let Some(frame) = body.frame().await {
628 let frame = frame.anyerr()?;
629 if let Ok(data) = frame.into_data() {
631 send.write_all(&data).await.anyerr()?;
632 }
633 }
634 Ok(())
635}
636
637async fn forward_hyper_upgrade(
638 upgrade_fut: hyper::upgrade::OnUpgrade,
639 mut upstream_recv: impl AsyncRead + Send + Unpin,
640 mut upstream_send: impl AsyncWrite + Send + Unpin,
641) -> Result<()> {
642 let upgraded = upgrade_fut.await.std_context("HTTP/1 upgrade failed")?;
643 let upgraded = TokioIo::new(upgraded);
644 let (mut client_read, mut client_write) = tokio::io::split(upgraded);
646 forward_bidi(
647 &mut client_read,
648 &mut client_write,
649 &mut upstream_recv,
650 &mut upstream_send,
651 )
652 .await?;
653 Ok(())
654}
655
656async fn read_response(recv: &mut impl Prebufferable) -> Result<HttpResponse, ProxyError> {
657 HttpResponse::read(recv)
658 .await
659 .map_err(ProxyError::bad_gateway)
660}
661
/// Adapts an uninhabited error into `io::Error` so an infallible body can be boxed as a
/// [`HyperBody`]. It can never actually be called with a value.
fn infallible_to_io(never: Infallible) -> io::Error {
    match never {}
}
665
666fn spawn<F, T>(fut: F) -> tokio::task::JoinHandle<()>
667where
668 F: Future<Output = Result<T>> + Send + 'static,
669{
670 tokio::spawn(
671 async move {
672 if let Err(err) = fut.await {
673 warn!("{err:#}")
674 }
675 }
676 .instrument(tracing::Span::current()),
677 )
678}