1use crate::accept::Accept;
39use std::{
40 fmt,
41 future::Future,
42 io,
43 net::{IpAddr, SocketAddr},
44 pin::Pin,
45 task::{Context, Poll},
46 time::Duration,
47};
48
49use http::HeaderValue;
50use http::Request;
51use ppp::{v1, v2, HeaderResult};
52use tokio::{
53 io::{AsyncRead, AsyncReadExt, AsyncWrite},
54 time::timeout,
55};
56use tower_service::Service;
57
58pub(crate) mod future;
59use self::future::ProxyProtocolAcceptorFuture;
60
/// Length of the v1 signature `PROXY`.
const V1_PREFIX_LEN: usize = "PROXY".len();
/// Maximum length of a complete v1 header line in bytes, terminating CRLF included.
const V1_MAX_LENGTH: usize = 107;
/// Byte sequence that ends a v1 header line.
const V1_TERMINATOR: &[u8] = b"\r\n";
/// Length of the binary v2 signature.
const V2_PREFIX_LEN: usize = 12;
/// Size of the fixed part of a v2 header: signature, version/command byte,
/// family/protocol byte and the two-byte length field.
const V2_MINIMUM_LEN: usize = V2_PREFIX_LEN + 1 + 1 + 2;
/// Offset of the big-endian 16-bit length field inside a v2 header.
const V2_LENGTH_INDEX: usize = V2_PREFIX_LEN + 1 + 1;
/// Size of the stack buffer headers are read into; v2 headers that do not fit
/// are read into a heap buffer instead.
const READ_BUFFER_LEN: usize = 512;
75
76pub(crate) async fn read_proxy_header<I>(
77 mut stream: I,
78) -> Result<(I, Option<SocketAddr>), io::Error>
79where
80 I: AsyncRead + Unpin,
81{
82 let mut buffer = [0; READ_BUFFER_LEN];
84 let mut dynamic_buffer = None;
86
87 stream.read_exact(&mut buffer[..V1_PREFIX_LEN]).await?;
89
90 if &buffer[..V1_PREFIX_LEN] == v1::PROTOCOL_PREFIX.as_bytes() {
91 read_v1_header(&mut stream, &mut buffer).await?;
92 } else {
93 stream
94 .read_exact(&mut buffer[V1_PREFIX_LEN..V2_MINIMUM_LEN])
95 .await?;
96 if &buffer[..V2_PREFIX_LEN] == v2::PROTOCOL_PREFIX {
97 dynamic_buffer = read_v2_header(&mut stream, &mut buffer).await?;
98 } else {
99 return Err(io::Error::new(
100 io::ErrorKind::InvalidData,
101 "No valid Proxy Protocol header detected",
102 ));
103 }
104 }
105
106 let buffer = dynamic_buffer.as_deref().unwrap_or(&buffer[..]);
108
109 let header = HeaderResult::parse(buffer);
111 match header {
112 HeaderResult::V1(Ok(header)) => {
113 let client_address = match header.addresses {
114 v1::Addresses::Tcp4(ip) => {
115 SocketAddr::new(IpAddr::V4(ip.source_address), ip.source_port)
116 }
117 v1::Addresses::Tcp6(ip) => {
118 SocketAddr::new(IpAddr::V6(ip.source_address), ip.source_port)
119 }
120 v1::Addresses::Unknown => {
121 return Ok((stream, None));
123 }
124 };
125
126 Ok((stream, Some(client_address)))
127 }
128 HeaderResult::V2(Ok(header)) => {
129 let client_address = match header.addresses {
130 v2::Addresses::IPv4(ip) => {
131 SocketAddr::new(IpAddr::V4(ip.source_address), ip.source_port)
132 }
133 v2::Addresses::IPv6(ip) => {
134 SocketAddr::new(IpAddr::V6(ip.source_address), ip.source_port)
135 }
136 v2::Addresses::Unix(unix) => {
137 return Err(io::Error::new(
138 io::ErrorKind::InvalidData,
139 format!(
140 "Unix socket addresses are not supported. Addresses: {:?}",
141 unix
142 ),
143 ));
144 }
145 v2::Addresses::Unspecified => {
146 return Ok((stream, None));
148 }
149 };
150
151 Ok((stream, Some(client_address)))
152 }
153 HeaderResult::V1(Err(_error)) => Err(io::Error::new(
154 io::ErrorKind::InvalidData,
155 "No valid V1 Proxy Protocol header received",
156 )),
157 HeaderResult::V2(Err(_error)) => Err(io::Error::new(
158 io::ErrorKind::InvalidData,
159 "No valid V2 Proxy Protocol header received",
160 )),
161 }
162}
163
164async fn read_v2_header<I>(
165 mut stream: I,
166 buffer: &mut [u8; READ_BUFFER_LEN],
167) -> Result<Option<Vec<u8>>, io::Error>
168where
169 I: AsyncRead + Unpin,
170{
171 let length =
172 u16::from_be_bytes([buffer[V2_LENGTH_INDEX], buffer[V2_LENGTH_INDEX + 1]]) as usize;
173 let full_length = V2_MINIMUM_LEN + length;
174
175 if full_length > READ_BUFFER_LEN {
177 let mut dynamic_buffer = Vec::with_capacity(full_length);
178 dynamic_buffer.extend_from_slice(&buffer[..V2_MINIMUM_LEN]);
179
180 stream
182 .read_exact(&mut dynamic_buffer[V2_MINIMUM_LEN..full_length])
183 .await?;
184
185 Ok(Some(dynamic_buffer))
186 } else {
187 stream
189 .read_exact(&mut buffer[V2_MINIMUM_LEN..full_length])
190 .await?;
191
192 Ok(None)
193 }
194}
195
196async fn read_v1_header<I>(
197 mut stream: I,
198 buffer: &mut [u8; READ_BUFFER_LEN],
199) -> Result<(), io::Error>
200where
201 I: AsyncRead + Unpin,
202{
203 let mut end_found = false;
205 for i in V1_PREFIX_LEN..V1_MAX_LENGTH {
206 buffer[i] = stream.read_u8().await?;
207
208 if [buffer[i - 1], buffer[i]] == V1_TERMINATOR {
209 end_found = true;
210 break;
211 }
212 }
213 if !end_found {
214 return Err(io::Error::new(
215 io::ErrorKind::InvalidData,
216 "No valid Proxy Protocol header detected",
217 ));
218 }
219
220 Ok(())
221}
222
223#[derive(Debug, Clone)]
226pub struct ForwardClientIp<S> {
227 inner: S,
228 client_address: Option<SocketAddr>,
229}
230
231impl<B, S> Service<Request<B>> for ForwardClientIp<S>
232where
233 S: Service<Request<B>>,
234{
235 type Response = S::Response;
236 type Error = S::Error;
237 type Future = S::Future;
238
239 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
240 self.inner.poll_ready(cx)
241 }
242
243 fn call(&mut self, mut req: Request<B>) -> Self::Future {
244 let mut forwarded_string = match self.client_address {
246 Some(socket_addr) => match socket_addr {
247 SocketAddr::V4(addr) => {
248 format!("for={}:{}", addr.ip(), addr.port())
249 }
250 SocketAddr::V6(addr) => {
251 format!("for=\"[{}]:{}\"", addr.ip(), addr.port())
252 }
253 },
254 None => "for=unknown".to_string(),
255 };
256
257 if let Some(existing_value) = req.headers_mut().get("Forwarded") {
258 forwarded_string = format!(
259 "{}, {}",
260 existing_value.to_str().unwrap_or(""),
261 forwarded_string
262 );
263 }
264
265 if let Ok(header_value) = HeaderValue::from_str(&forwarded_string) {
266 req.headers_mut().insert("Forwarded", header_value);
267 }
268
269 self.inner.call(req)
270 }
271}
272
/// Acceptor that reads and parses a PROXY protocol header from each incoming stream.
/// The stream is then handed to the wrapped acceptor `A`.
#[derive(Clone)]
pub struct ProxyProtocolAcceptor<A> {
    /// Acceptor that receives the stream once the header has been read.
    inner: A,
    /// Maximum time allowed for reading and parsing the header.
    parsing_timeout: Duration,
}

impl<A> ProxyProtocolAcceptor<A> {
    /// Wraps `inner` with PROXY header parsing.
    ///
    /// The parsing timeout defaults to 5 seconds, or 1 second when compiled for tests.
    pub fn new(inner: A) -> Self {
        let parsing_timeout = if cfg!(test) {
            Duration::from_secs(1)
        } else {
            Duration::from_secs(5)
        };

        Self {
            inner,
            parsing_timeout,
        }
    }

    /// Sets the maximum time allowed for reading and parsing the PROXY header.
    pub fn parsing_timeout(self, val: Duration) -> Self {
        Self {
            parsing_timeout: val,
            ..self
        }
    }

    /// Replaces the wrapped acceptor and keeps the configured parsing timeout.
    pub fn acceptor<Acceptor>(self, acceptor: Acceptor) -> ProxyProtocolAcceptor<Acceptor> {
        let Self {
            parsing_timeout, ..
        } = self;

        ProxyProtocolAcceptor {
            inner: acceptor,
            parsing_timeout,
        }
    }
}
313
314impl<A, I, S> Accept<I, S> for ProxyProtocolAcceptor<A>
315where
316 A: Accept<I, S> + Clone,
317 A::Stream: AsyncRead + AsyncWrite + Unpin,
318 I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
319{
320 type Stream = A::Stream;
321 type Service = ForwardClientIp<A::Service>;
322 type Future = ProxyProtocolAcceptorFuture<
323 Pin<Box<dyn Future<Output = Result<(I, Option<SocketAddr>), io::Error>> + Send>>,
324 A,
325 I,
326 S,
327 >;
328
329 fn accept(&self, stream: I, service: S) -> Self::Future {
330 let future = Box::pin(read_proxy_header(stream));
331
332 ProxyProtocolAcceptorFuture::new(
333 timeout(self.parsing_timeout, future),
334 self.inner.clone(),
335 service,
336 )
337 }
338}
339
340impl<A> fmt::Debug for ProxyProtocolAcceptor<A> {
341 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342 f.debug_struct("ProxyProtocolAcceptor").finish()
343 }
344}
345
// End-to-end tests: a real server is started on 127.0.0.1, and an in-process TCP proxy is put
// in front of it. The proxy prepends a PROXY protocol header (v1, v2 or none) to each
// connection. The handler echoes the request's `Forwarded` header back in its response so
// the tests can inspect it.
#[cfg(test)]
mod tests {
    #[cfg(feature = "tls-openssl")]
    use crate::tls_openssl::{
        self,
        tests::{dns_name as openssl_dns_name, tls_connector as openssl_connector},
        OpenSSLConfig,
    };
    #[cfg(feature = "tls-rustls")]
    use crate::tls_rustls::{
        self,
        tests::{dns_name as rustls_dns_name, tls_connector as rustls_connector},
        RustlsConfig,
    };
    use crate::{handle::Handle, server::Server};
    use axum::http::Response;
    use axum::{routing::get, Router};
    use bytes::Bytes;
    use http::{response, Request};
    use hyper::{
        client::conn::{handshake, SendRequest},
        Body,
    };
    use ppp::v2::{Builder, Command, Protocol, Type, Version};
    use std::{io, net::SocketAddr, time::Duration};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::{
        net::{TcpListener, TcpStream},
        task::JoinHandle,
        time::timeout,
    };
    use tower::{Service, ServiceExt};

    // A request sent through a v2 proxy to a proxy-aware server succeeds.
    #[tokio::test]
    async fn start_and_request() {
        let (_handle, _server_task, server_addr) = start_server(true).await;

        let addr = start_proxy(server_addr, ProxyVersion::V2)
            .await
            .expect("Failed to start proxy");

        let (mut client, _conn, _client_addr) = connect(addr).await;

        let (_parts, body) = send_empty_request(&mut client).await;

        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    // With a v2 header, the server's `Forwarded` header names the client's real
    // address, not the proxy's.
    #[tokio::test]
    async fn server_receives_client_address() {
        let (_handle, _server_task, server_addr) = start_server(true).await;

        let addr = start_proxy(server_addr, ProxyVersion::V2)
            .await
            .expect("Failed to start proxy");

        let (mut client, _conn, client_addr) = connect(addr).await;

        let (parts, body) = send_empty_request(&mut client).await;

        let forwarded_header = parts
            .headers
            .get("Forwarded")
            .expect("No Forwarded header present")
            .to_str()
            .expect("Failed to convert Forwarded header to str");

        assert!(forwarded_header.contains(&format!("for={}", client_addr)));
        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    // Same as above, but the proxy sends a text v1 header.
    #[tokio::test]
    async fn server_receives_client_address_v1() {
        let (_handle, _server_task, server_addr) = start_server(true).await;

        let addr = start_proxy(server_addr, ProxyVersion::V1)
            .await
            .expect("Failed to start proxy");

        let (mut client, _conn, client_addr) = connect(addr).await;

        let (parts, body) = send_empty_request(&mut client).await;

        let forwarded_header = parts
            .headers
            .get("Forwarded")
            .expect("No Forwarded header present")
            .to_str()
            .expect("Failed to convert Forwarded header to str");

        assert!(forwarded_header.contains(&format!("for={}", client_addr)));
        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    // Client address forwarding with a rustls TLS server behind the proxy.
    #[cfg(feature = "tls-rustls")]
    #[tokio::test]
    async fn rustls_server_receives_client_address() {
        let (_handle, _server_task, server_addr) = start_rustls_server().await;

        let addr = start_proxy(server_addr, ProxyVersion::V2)
            .await
            .expect("Failed to start proxy");

        let (mut client, _conn, client_addr) = rustls_connect(addr).await;

        let (parts, body) = send_empty_request(&mut client).await;

        let forwarded_header = parts
            .headers
            .get("Forwarded")
            .expect("No Forwarded header present")
            .to_str()
            .expect("Failed to convert Forwarded header to str");

        assert!(forwarded_header.contains(&format!("for={}", client_addr)));
        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    // Client address forwarding with an OpenSSL TLS server behind the proxy.
    #[cfg(feature = "tls-openssl")]
    #[tokio::test]
    async fn openssl_server_receives_client_address() {
        let (_handle, _server_task, server_addr) = start_openssl_server().await;

        let addr = start_proxy(server_addr, ProxyVersion::V2)
            .await
            .expect("Failed to start proxy");

        let (mut client, _conn, client_addr) = openssl_connect(addr).await;

        let (parts, body) = send_empty_request(&mut client).await;

        let forwarded_header = parts
            .headers
            .get("Forwarded")
            .expect("No Forwarded header present")
            .to_str()
            .expect("Failed to convert Forwarded header to str");

        assert!(forwarded_header.contains(&format!("for={}", client_addr)));
        assert_eq!(body.as_ref(), b"Hello, world!");
    }

    // A server WITHOUT proxy protocol parsing receives a v2 header in front of the HTTP
    // request. The only error accepted is an incomplete message.
    #[tokio::test]
    async fn not_parsing_when_header_present_fails() {
        let (_handle, _server_task, server_addr) = start_server(false).await;

        let addr = start_proxy(server_addr, ProxyVersion::V2)
            .await
            .expect("Failed to start proxy");

        let (mut client, _conn, _client_addr) = connect(addr).await;

        match client
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await
        {
            Ok(_o) => {
                // NOTE(review): any response is tolerated here. Presumably the server may
                // answer the garbled request with an error status instead of dropping the
                // connection. Confirm intent.
            }
            Err(e) => {
                if e.is_incomplete_message() {
                    // Expected: the connection was closed before a response arrived.
                } else {
                    panic!("Received unexpected error");
                }
            }
        }
    }

    // A proxy-aware server receives a plain HTTP request with no PROXY header. It must
    // reject the connection, which the client sees as an incomplete message.
    #[tokio::test]
    async fn parsing_when_header_not_present_fails() {
        let (_handle, _server_task, server_addr) = start_server(true).await;

        let addr = start_proxy(server_addr, ProxyVersion::None)
            .await
            .expect("Failed to start proxy");

        let (mut client, _conn, _client_addr) = connect(addr).await;

        match client
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await
        {
            Ok(_) => panic!("Should have failed"),
            Err(e) => {
                if e.is_incomplete_message() {
                    // Expected: the server dropped the connection.
                } else {
                    panic!("Received unexpected error");
                }
            }
        }
    }

    /// Handler that responds with "Hello, world!" and copies the request's `Forwarded`
    /// header (if any) into the response, so tests can read what the server saw.
    async fn forward_ip_handler(req: Request<Body>) -> Response<Body> {
        let mut response = Response::new(Body::from("Hello, world!"));

        if let Some(header_value) = req.headers().get("Forwarded") {
            response
                .headers_mut()
                .insert("Forwarded", header_value.clone());
        }

        response
    }

    /// Spawns a plain HTTP server on an ephemeral 127.0.0.1 port, optionally with PROXY
    /// protocol parsing, and returns once the server is listening.
    async fn start_server(
        parse_proxy_header: bool,
    ) -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_handle = handle.clone();
        let server_task = tokio::spawn(async move {
            let app = Router::new().route("/", get(forward_ip_handler));

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            if parse_proxy_header {
                Server::bind(addr)
                    .handle(server_handle)
                    .enable_proxy_protocol(None)
                    .serve(app.into_make_service())
                    .await
            } else {
                Server::bind(addr)
                    .handle(server_handle)
                    .serve(app.into_make_service())
                    .await
            }
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    /// Spawns a rustls HTTPS server with PROXY protocol parsing enabled, using the
    /// self-signed example certificate.
    #[cfg(feature = "tls-rustls")]
    async fn start_rustls_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_handle = handle.clone();
        let server_task = tokio::spawn(async move {
            let app = Router::new().route("/", get(forward_ip_handler));

            let config = RustlsConfig::from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .await?;

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_rustls::bind_rustls(addr, config)
                .handle(server_handle)
                .enable_proxy_protocol(None)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    /// Spawns an OpenSSL HTTPS server with PROXY protocol parsing enabled, using the
    /// self-signed example certificate.
    #[cfg(feature = "tls-openssl")]
    async fn start_openssl_server() -> (Handle, JoinHandle<io::Result<()>>, SocketAddr) {
        let handle = Handle::new();

        let server_handle = handle.clone();
        let server_task = tokio::spawn(async move {
            let app = Router::new().route("/", get(forward_ip_handler));

            let config = OpenSSLConfig::from_pem_file(
                "examples/self-signed-certs/cert.pem",
                "examples/self-signed-certs/key.pem",
            )
            .unwrap();

            let addr = SocketAddr::from(([127, 0, 0, 1], 0));

            tls_openssl::bind_openssl(addr, config)
                .handle(server_handle)
                .enable_proxy_protocol(None)
                .serve(app.into_make_service())
                .await
        });

        let addr = handle.listening().await.unwrap();

        (handle, server_task, addr)
    }

    /// Which PROXY header, if any, the test proxy prepends to each connection.
    #[derive(Debug, Clone, Copy)]
    enum ProxyVersion {
        V1,
        V2,
        None,
    }

    /// Starts a TCP proxy on an ephemeral 127.0.0.1 port that forwards every accepted
    /// connection to `server_address`, prefixed with the requested PROXY header.
    /// Returns the proxy's listening address.
    async fn start_proxy(
        server_address: SocketAddr,
        proxy_version: ProxyVersion,
    ) -> Result<SocketAddr, Box<dyn std::error::Error>> {
        let proxy_address = SocketAddr::from(([127, 0, 0, 1], 0));
        let listener = TcpListener::bind(proxy_address).await?;
        let proxy_address = listener.local_addr()?;

        let _proxy_task = tokio::spawn(async move {
            loop {
                match listener.accept().await {
                    Ok((client_stream, _)) => {
                        tokio::spawn(async move {
                            if let Err(e) =
                                handle_conn(client_stream, server_address, proxy_version).await
                            {
                                println!("Error handling connection: {:?}", e);
                            }
                        });
                    }
                    Err(e) => println!("Failed to accept a connection: {:?}", e),
                }
            }
        });

        Ok(proxy_address)
    }

    /// Proxies one client connection.
    ///
    /// Connects to the server and sends the PROXY header (client address as source,
    /// server address as destination). It then copies bytes in both directions, and each
    /// direction is cut off after one second. Transfer errors and timeouts are ignored.
    async fn handle_conn(
        mut client_stream: TcpStream,
        server_address: SocketAddr,
        proxy_version: ProxyVersion,
    ) -> io::Result<()> {
        let client_address = client_stream.peer_addr()?;
        let mut server_stream = TcpStream::connect(server_address).await?;
        let server_address = server_stream.peer_addr()?;
        let (mut client_read, mut client_write) = client_stream.split();
        let (mut server_read, mut server_write) = server_stream.split();

        send_proxy_header(
            &mut server_write,
            client_address,
            server_address,
            proxy_version,
        )
        .await?;

        let duration = Duration::from_secs(1);
        let client_to_server = async {
            match timeout(duration, transfer(&mut client_read, &mut server_write)).await {
                Ok(result) => result,
                Err(_) => Err(io::Error::new(
                    io::ErrorKind::TimedOut,
                    "Client to Server transfer timed out",
                )),
            }
        };

        let server_to_client = async {
            match timeout(duration, transfer(&mut server_read, &mut client_write)).await {
                Ok(result) => result,
                Err(_) => Err(io::Error::new(
                    io::ErrorKind::TimedOut,
                    "Server to Client transfer timed out",
                )),
            }
        };

        let _ = tokio::try_join!(client_to_server, server_to_client);

        Ok(())
    }

    /// Copies bytes from `read_stream` to `write_stream` until the reader reports EOF
    /// (a zero-length read).
    async fn transfer(
        read_stream: &mut (impl AsyncReadExt + Unpin),
        write_stream: &mut (impl AsyncWriteExt + Unpin),
    ) -> io::Result<()> {
        let mut buf = [0; 4096];
        loop {
            let n = read_stream.read(&mut buf).await?;
            if n == 0 {
                break;
            }
            write_stream.write_all(&buf[..n]).await?;
        }
        Ok(())
    }

    /// Writes a PROXY header for `client_address` -> `server_address` to `write_stream`.
    ///
    /// The v2 header also carries a NoOp TLV. Headers are written one byte per
    /// `write_all` call, presumably so the server's parser sees the header split
    /// across many small writes.
    async fn send_proxy_header(
        write_stream: &mut (impl AsyncWriteExt + Unpin),
        client_address: SocketAddr,
        server_address: SocketAddr,
        proxy_version: ProxyVersion,
    ) -> io::Result<()> {
        match proxy_version {
            ProxyVersion::V1 => {
                let header = ppp::v1::Addresses::from((client_address, server_address)).to_string();

                for byte in header.as_bytes() {
                    write_stream.write_all(&[*byte]).await?;
                }
            }
            ProxyVersion::V2 => {
                let mut header = Builder::with_addresses(
                    Version::Two | Command::Proxy,
                    Protocol::Stream,
                    (client_address, server_address),
                )
                .write_tlv(Type::NoOp, b"Hello, World!")?
                .build()?;

                for byte in header.drain(..) {
                    write_stream.write_all(&[byte]).await?;
                }
            }
            ProxyVersion::None => {}
        }

        Ok(())
    }

    /// Opens a plain HTTP/1 client connection to `addr`.
    ///
    /// Returns the request sender, the task driving the connection, and the client's
    /// local address (the address the server should see in `Forwarded`).
    async fn connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>, SocketAddr) {
        let stream = TcpStream::connect(addr).await.unwrap();
        let client_addr = stream.local_addr().unwrap();

        let (send_request, connection) = handshake(stream).await.unwrap();

        let task = tokio::spawn(async move {
            let _ = connection.await;
        });

        (send_request, task, client_addr)
    }

    /// Like `connect`, but wraps the TCP stream in a rustls client before the HTTP
    /// handshake.
    #[cfg(feature = "tls-rustls")]
    async fn rustls_connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>, SocketAddr) {
        let stream = TcpStream::connect(addr).await.unwrap();
        let client_addr = stream.local_addr().unwrap();
        let tls_stream = rustls_connector()
            .connect(rustls_dns_name(), stream)
            .await
            .unwrap();

        let (send_request, connection) = handshake(tls_stream).await.unwrap();

        let task = tokio::spawn(async move {
            let _ = connection.await;
        });

        (send_request, task, client_addr)
    }

    /// Like `connect`, but wraps the TCP stream in an OpenSSL client before the HTTP
    /// handshake.
    #[cfg(feature = "tls-openssl")]
    async fn openssl_connect(addr: SocketAddr) -> (SendRequest<Body>, JoinHandle<()>, SocketAddr) {
        let stream = TcpStream::connect(addr).await.unwrap();
        let client_addr = stream.local_addr().unwrap();
        let tls_stream = openssl_connector(openssl_dns_name(), stream).await;

        let (send_request, connection) = handshake(tls_stream).await.unwrap();

        let task = tokio::spawn(async move {
            let _ = connection.await;
        });

        (send_request, task, client_addr)
    }

    /// Sends an empty default request (`GET /`) and returns the response parts and the
    /// fully buffered body. Panics on any error.
    async fn send_empty_request(client: &mut SendRequest<Body>) -> (response::Parts, Bytes) {
        let (parts, body) = client
            .ready()
            .await
            .unwrap()
            .call(Request::new(Body::empty()))
            .await
            .unwrap()
            .into_parts();
        let body = hyper::body::to_bytes(body).await.unwrap();

        (parts, body)
    }
}