1use std::{
2 sync::{
3 Arc,
4 atomic::{AtomicU64, Ordering},
5 },
6 time::Duration,
7};
8
9use http::{StatusCode, Version};
10use iroh::{
11 EndpointId,
12 endpoint::{Connection, ConnectionError, RecvStream, SendStream},
13 protocol::{AcceptError, ProtocolHandler},
14};
15use n0_error::{Result, StackResultExt, StdResultExt};
16use n0_future::stream::StreamExt;
17use tokio::{
18 io::{AsyncWrite, AsyncWriteExt},
19 net::TcpStream,
20};
21use tokio_util::{future::FutureExt, sync::CancellationToken, task::TaskTracker};
22use tracing::{Instrument, debug, error_span, instrument, warn};
23
24use crate::{
25 Authority, HEADER_SECTION_MAX_LENGTH, HttpResponse,
26 parse::{
27 HttpProxyRequestKind, HttpRequest, absolute_target_to_origin_form,
28 filter_hop_by_hop_headers,
29 },
30 util::{
31 Prebuffered, StreamEvent, TrackedRead, TrackedStream, TrackedWrite, forward_bidi, nores,
32 recv_to_stream,
33 },
34};
35
36mod auth;
37mod metrics;
38pub use auth::*;
39pub use metrics::*;
40
/// How long the protocol handler's `shutdown` waits for in-flight stream tasks
/// before it stops waiting for them.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(1_000);

/// `Upgrade` header values (matched ASCII-case-insensitively) that are forwarded
/// to the origin as protocol upgrades instead of plain HTTP requests.
const SUPPORTED_UPGRADE_PROTOCOLS: &[&str] = &["websocket"];
45
/// Upstream side of the proxy: an iroh [`ProtocolHandler`] that accepts proxy
/// requests on bidirectional streams and forwards them to origin servers.
#[derive(derive_more::Debug)]
pub struct UpstreamProxy {
    /// Authorizes every request before anything is forwarded to the origin.
    #[debug("Arc<dyn AuthHandler>")]
    auth: Arc<DynAuthHandler<'static>>,
    /// Counter that supplies the `id` field of each `accept` tracing span.
    conn_id: Arc<AtomicU64>,
    /// Cancelled on shutdown; stops accepting new streams.
    shutdown: CancellationToken,
    /// Tracks the per-stream tasks so shutdown can wait for them.
    tasks: TaskTracker,
    /// Client used for absolute-form requests that are not protocol upgrades.
    http_client: reqwest::Client,
    /// Connection/request counters, plus per-target metrics via `get_or_insert`.
    metrics: Arc<UpstreamMetrics>,
}
83
84impl ProtocolHandler for UpstreamProxy {
85 #[instrument("accept", level="error", skip_all, fields(id=self.conn_id.fetch_add(1, Ordering::SeqCst)))]
86 async fn accept(
87 &self,
88 connection: Connection,
89 ) -> std::result::Result<(), iroh::protocol::AcceptError> {
90 debug!(remote_id=%connection.remote_id().fmt_short(), "accepted connection");
91 self.metrics.connections_accepted.inc();
92 let res = self
93 .handle_connection(connection)
94 .await
95 .map_err(AcceptError::from_err);
96 self.metrics.connections_completed.inc();
97 res
98 }
99
100 async fn shutdown(&self) {
101 self.shutdown.cancel();
102 self.tasks.close();
103 debug!("shutting down ({} pending tasks)", self.tasks.len());
104 match self.tasks.wait().timeout(GRACEFUL_SHUTDOWN_TIMEOUT).await {
105 Ok(_) => debug!("all streams closed cleanly"),
106 Err(_) => debug!(
107 remaining = self.tasks.len(),
108 "not all streams closed in time, abort"
109 ),
110 }
111 }
112}
113
114impl UpstreamProxy {
115 pub fn new(auth: impl AuthHandler + 'static) -> Result<Self> {
117 Ok(Self {
118 auth: DynAuthHandler::new_arc(auth),
119 conn_id: Default::default(),
120 shutdown: CancellationToken::new(),
121 tasks: TaskTracker::new(),
122 http_client: reqwest::Client::new(),
123 metrics: Default::default(),
124 })
125 }
126
127 pub fn metrics(&self) -> Arc<UpstreamMetrics> {
129 self.metrics.clone()
130 }
131
132 pub fn on_shutdown(&self) -> impl Future<Output = ()> + Send + 'static + use<> {
134 self.shutdown.clone().cancelled_owned()
135 }
136
137 async fn handle_connection(&self, connection: Connection) -> Result<()> {
138 let remote_id = connection.remote_id();
139 let mut stream_id = 0;
140 loop {
141 let (send, recv) = match connection
142 .accept_bi()
143 .with_cancellation_token(&self.shutdown)
144 .await
145 {
146 None => return Ok(()),
147 Some(Ok(streams)) => streams,
148 Some(Err(ConnectionError::ApplicationClosed(_))) => {
149 debug!("connection closed by downstream remote");
150 return Ok(());
151 }
152 Some(Err(err)) => {
153 return Err(err).std_context("failed to accept streams");
154 }
155 };
156 let auth = self.auth.clone();
157 let shutdown = self.shutdown.clone();
158 let http_client = self.http_client.clone();
159 let metrics = self.metrics.clone();
160 self.tasks.spawn(
161 async move {
164 if let Err(err) = Self::handle_remote_streams(
165 auth,
166 remote_id,
167 send,
168 recv,
169 http_client,
170 metrics,
171 )
172 .await
173 {
174 if shutdown.is_cancelled() {
175 debug!("aborted at shutdown: {err:#}");
176 } else {
177 warn!("failed to handle streams: {err:#}");
178 }
179 }
180 }
181 .instrument(error_span!("stream", id=%stream_id)),
182 );
183 stream_id += 1;
184 }
185 }
186
187 async fn handle_remote_streams(
188 auth: Arc<DynAuthHandler<'static>>,
189 remote_id: EndpointId,
190 mut downstream_send: SendStream,
191 downstream_recv: RecvStream,
192 http_client: reqwest::Client,
193 metrics: Arc<UpstreamMetrics>,
194 ) -> Result<()> {
195 let mut downstream_recv = Prebuffered::new(downstream_recv, HEADER_SECTION_MAX_LENGTH);
196 let (request_len, req) = HttpRequest::peek(&mut downstream_recv).await?;
197 downstream_recv.discard(request_len);
198
199 debug!(?req, "handle request");
200 let req = req
201 .try_into_proxy_request()
202 .context("Received origin-form request but expected proxy request")?;
203
204 let id = req.kind.authority()?;
205 let req_metrics = metrics.get_or_insert(id);
206 req_metrics.bytes_to_origin.inc_by(request_len as u64);
207
208 match auth.authorize(remote_id, &req).await {
209 Ok(()) => {
210 metrics.requests_accepted.inc();
211 req_metrics.requests_accepted.inc();
212 debug!("request is authorized, continue");
213 }
214 Err(reason) => {
215 metrics.requests_denied.inc();
216 req_metrics.requests_denied.inc();
217 debug!(?reason, "request is not authorized, abort");
218 HttpResponse::new(StatusCode::FORBIDDEN)
219 .no_body()
220 .write(&mut downstream_send, true)
221 .await
222 .ok();
223 downstream_send.finish().anyerr()?;
224 return Ok(());
225 }
226 };
227
228 match req.kind {
229 HttpProxyRequestKind::Tunnel { target: authority } => {
230 debug!(%authority, "tunnel request: connecting to origin");
231 match TcpStream::connect(authority.to_addr()).await {
232 Err(err) => {
233 warn!("Failed to connect to origin server: {err:#}");
234 metrics.requests_failed.inc();
235 req_metrics.requests_failed.inc();
236 error_response_and_finish(downstream_send).await?;
237 Ok(())
238 }
239 Ok(tcp_stream) => {
240 debug!(%authority, "connected to origin");
241 HttpResponse::with_reason(StatusCode::OK, "Connection Established")
242 .write(&mut downstream_send, true)
243 .await
244 .context("Failed to write CONNECT response to downstream")?;
245 let (mut origin_recv, mut origin_send) = tcp_stream.into_split();
246
247 let mut downstream_recv = TrackedRead::new(&mut downstream_recv, |d| {
248 req_metrics.bytes_to_origin.inc_by(d);
249 });
250 let mut downstream_send = TrackedWrite::new(&mut downstream_send, |d| {
251 req_metrics.bytes_from_origin.inc_by(d);
252 });
253
254 match forward_bidi(
255 &mut downstream_recv,
256 &mut downstream_send,
257 &mut origin_recv,
258 &mut origin_send,
259 )
260 .await
261 {
262 Ok((to_origin, from_origin)) => {
263 metrics.requests_completed.inc();
264 req_metrics.requests_completed.inc();
265 debug!(to_origin, from_origin, "finish");
266 Ok(())
267 }
268 Err(err) => {
269 metrics.requests_failed.inc();
270 req_metrics.requests_failed.inc();
271 Err(err)
272 }
273 }
274 }
275 }
276 }
277 HttpProxyRequestKind::Absolute { method, target } => {
278 let upgrade_protocol = req
280 .headers
281 .get(http::header::UPGRADE)
282 .and_then(|v| v.to_str().ok())
283 .filter(|proto| {
284 SUPPORTED_UPGRADE_PROTOCOLS
285 .iter()
286 .any(|p| p.eq_ignore_ascii_case(proto))
287 });
288
289 if let Some(protocol) = upgrade_protocol {
290 debug!(%target, %protocol, "upgrade request: connecting to origin");
291 let mut headers = req.headers;
292 filter_hop_by_hop_headers(&mut headers);
293 let authority = Authority::from_absolute_uri(&target)?;
295 let origin_form_uri = absolute_target_to_origin_form(&target)?;
296 let request = HttpRequest {
297 version: Version::HTTP_11,
298 headers,
299 uri: origin_form_uri,
300 method,
301 };
302 match Self::handle_upgrade_request(
303 authority,
304 request,
305 downstream_recv,
306 downstream_send,
307 req_metrics.clone(),
308 )
309 .await
310 {
311 Ok(()) => {
312 metrics.requests_completed.inc();
313 req_metrics.requests_completed.inc();
314 Ok(())
315 }
316 Err(err) => {
317 metrics.requests_failed.inc();
318 req_metrics.requests_failed.inc();
319 Err(err)
320 }
321 }
322 } else {
323 debug!(%target, "origin request: connecting to origin");
324 let body = {
325 let req_metrics = req_metrics.clone();
326 let body = recv_to_stream(downstream_recv);
327 let body = TrackedStream::new(body, move |ev| match ev {
328 StreamEvent::Data(n) => nores(req_metrics.bytes_to_origin.inc_by(n)),
329 _ => {}
330 });
331 reqwest::Body::wrap_stream(body)
332 };
333
334 let mut headers = req.headers;
336 filter_hop_by_hop_headers(&mut headers);
337
338 let mut response = match http_client
340 .request(method, target.to_string())
341 .headers(headers)
342 .body(body)
343 .send()
344 .await
345 {
346 Ok(response) => response,
347 Err(err) => {
348 error_response_and_finish(downstream_send).await?;
349 metrics.requests_failed.inc();
350 req_metrics.requests_failed.inc();
351 return Err(err).anyerr();
352 }
353 };
354 filter_hop_by_hop_headers(response.headers_mut());
355 debug!(?response, "received response from origin");
356 let res = forward_reqwest_response(
357 response,
358 &mut downstream_send,
359 req_metrics.clone(),
360 )
361 .await;
362 match res {
363 Ok(total) => {
364 debug!(response_body_len=%total, "finish");
365 metrics.requests_completed.inc();
366 req_metrics.requests_completed.inc();
367 Ok(())
368 }
369 Err(err) => {
370 metrics.requests_failed.inc();
371 req_metrics.requests_failed.inc();
372 Err(err)
373 }
374 }
375 }
376 }
377 }
378 }
379
380 async fn handle_upgrade_request(
387 authority: Authority,
388 request: HttpRequest,
389 mut downstream_recv: Prebuffered<RecvStream>,
390 mut downstream_send: SendStream,
391 req_metrics: Arc<TargetMetrics>,
392 ) -> Result<()> {
393 let origin = match TcpStream::connect(authority.to_addr()).await {
395 Ok(stream) => stream,
396 Err(err) => {
397 warn!("Failed to connect to origin for upgrade: {err:#}");
398 error_response_and_finish(downstream_send).await?;
399 return Err(err).anyerr();
400 }
401 };
402 let (origin_recv, mut origin_send) = origin.into_split();
403
404 let mut downstream_recv = TrackedRead::new(&mut downstream_recv, |d| {
405 req_metrics.bytes_to_origin.inc_by(d);
406 });
407 let mut downstream_send = TrackedWrite::new(&mut downstream_send, |d| {
408 req_metrics.bytes_from_origin.inc_by(d);
409 });
410
411 request.write(&mut origin_send).await?;
413
414 let mut origin_recv = Prebuffered::new(origin_recv, HEADER_SECTION_MAX_LENGTH);
416 let response = HttpResponse::read(&mut origin_recv).await?;
417 debug!(?response, "upgrade response from origin");
418 response.write(&mut downstream_send, true).await?;
419
420 if response.status != StatusCode::SWITCHING_PROTOCOLS {
421 downstream_send.into_inner().finish().anyerr()?;
422 return Ok(());
423 }
424
425 let (to_origin, from_origin) = forward_bidi(
427 &mut downstream_recv,
428 &mut downstream_send,
429 &mut origin_recv,
430 &mut origin_send,
431 )
432 .await?;
433 debug!(to_origin, from_origin, "upgrade connection finished");
434 Ok(())
435 }
436}
437
438async fn forward_reqwest_response(
439 response: reqwest::Response,
440 send: &mut SendStream,
441 req_metrics: Arc<TargetMetrics>,
442) -> Result<usize> {
443 let mut send = TrackedWrite::new(send, |d| {
444 req_metrics.bytes_from_origin.inc_by(d);
445 });
446 write_response(&response, &mut send).await?;
447 let send = send.into_inner();
448 let mut total = 0;
449 let mut body = response.bytes_stream();
450 while let Some(bytes) = body.next().await {
451 let bytes = bytes.anyerr()?;
452 total += bytes.len();
453 req_metrics.bytes_from_origin.inc_by(bytes.len() as u64);
454 send.write_chunk(bytes).await.anyerr()?;
455 }
456 send.finish().anyerr()?;
457 Ok(total)
458}
459
460async fn error_response_and_finish(mut send: SendStream) -> Result<(), n0_error::AnyError> {
461 HttpResponse::with_reason(StatusCode::BAD_GATEWAY, "Origin Is Unreachable")
462 .no_body()
463 .write(&mut send, true)
464 .await
465 .inspect_err(|err| warn!("Failed to write error response to downstream: {err:#}"))
466 .ok();
467 send.finish().anyerr()?;
468 Ok(())
469}
470
471async fn write_response(
472 res: &reqwest::Response,
473 send: &mut (impl AsyncWrite + Unpin),
474) -> Result<()> {
475 let status_line = format!(
476 "{:?} {} {}\r\n",
477 res.version(),
478 res.status().as_u16(),
479 res.status().canonical_reason().unwrap_or_default()
481 );
482 send.write_all(status_line.as_bytes()).await.anyerr()?;
483
484 for (name, value) in res.headers().iter() {
485 send.write_all(name.as_str().as_bytes()).await.anyerr()?;
486 send.write_all(b": ").await.anyerr()?;
487 send.write_all(value.as_bytes()).await.anyerr()?;
488 send.write_all(b"\r\n").await.anyerr()?;
489 }
490 send.write_all(b"\r\n").await.anyerr()?;
491 Ok(())
492}