relay_core_lib/capture/
transparent_tcp.rs1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::future::Future;
4use std::pin::Pin;
5use tokio::net::{TcpListener, TcpStream};
6use crate::capture::source::{CaptureSource, IncomingConnection};
7use crate::capture::original_dst::OriginalDstProvider;
8use std::io;
9
/// Decides whether failing to look up a connection's original destination is tolerable.
///
/// A tolerable failure lets the caller continue with no known target instead of
/// rejecting the connection. Exactly four error kinds qualify: `Unsupported`,
/// `NotFound`, `AddrNotAvailable` and `InvalidInput`. Every other kind is
/// reported as non-recoverable.
fn is_recoverable_original_dst_error(err: &io::Error) -> bool {
    use io::ErrorKind;

    const TOLERATED_KINDS: [ErrorKind; 4] = [
        ErrorKind::Unsupported,
        ErrorKind::NotFound,
        ErrorKind::AddrNotAvailable,
        ErrorKind::InvalidInput,
    ];

    TOLERATED_KINDS.contains(&err.kind())
}
19
20pub struct TransparentTcpCaptureSource {
21 listener: TcpListener,
22 original_dst_provider: Arc<dyn OriginalDstProvider>,
23}
24
25impl TransparentTcpCaptureSource {
26 pub fn new(listener: TcpListener, original_dst_provider: Arc<dyn OriginalDstProvider>) -> Self {
27 Self { listener, original_dst_provider }
28 }
29}
30
31impl CaptureSource for TransparentTcpCaptureSource {
32 type IO = TcpStream;
33
34 fn accept(&mut self) -> Pin<Box<dyn Future<Output = crate::error::Result<IncomingConnection<Self::IO>>> + Send + '_>> {
35 Box::pin(async move {
36 let (stream, client_addr) = self.listener.accept().await?;
37
38 let target_addr = match self.original_dst_provider.get_original_dst(&stream) {
42 Ok(target_addr) => target_addr,
43 Err(e) if is_recoverable_original_dst_error(&e) => {
44 tracing::debug!("Original destination unavailable, falling back: {}", e);
45 None
46 }
47 Err(e) => return Err(e.into()),
48 };
49
50 Ok(IncomingConnection {
51 stream,
52 client_addr,
53 target_addr,
54 })
55 })
56 }
57
58 fn listen_addrs(&self) -> Vec<SocketAddr> {
59 let mut addrs = self
65 .original_dst_provider
66 .get_listen_addrs()
67 .into_iter()
68 .collect::<std::collections::BTreeSet<_>>();
69 if let Ok(addr) = self.listener.local_addr() {
70 addrs.insert(addr);
71 }
72 addrs.into_iter().collect()
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::collections::BTreeSet;
    use std::io;

    /// Test double whose listen addresses and original-destination result are fixed up front.
    struct MockOriginalDstProvider {
        listen_addrs: BTreeSet<SocketAddr>,
        dst_result: io::Result<Option<SocketAddr>>,
    }

    #[async_trait]
    impl OriginalDstProvider for MockOriginalDstProvider {
        fn get_original_dst(&self, _stream: &TcpStream) -> io::Result<Option<SocketAddr>> {
            // `io::Error` is not `Clone`, so rebuild an equivalent error on every call.
            match &self.dst_result {
                Ok(v) => Ok(*v),
                Err(e) => Err(io::Error::new(e.kind(), e.to_string())),
            }
        }

        fn get_listen_addrs(&self) -> BTreeSet<SocketAddr> {
            self.listen_addrs.clone()
        }
    }

    #[tokio::test]
    async fn test_listen_addrs_contains_listener_and_is_deduplicated() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let provider_addrs = BTreeSet::from([listener_addr]);
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: provider_addrs,
            dst_result: Ok(None),
        });

        let source = TransparentTcpCaptureSource::new(listener, provider);
        let addrs = source.listen_addrs();

        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs[0], listener_addr);
    }

    #[tokio::test]
    async fn test_listen_addrs_merges_distinct_provider_addrs_sorted() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let extra: SocketAddr = "0.0.0.0:1".parse().expect("parse extra addr");
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::from([extra]),
            dst_result: Ok(None),
        });

        let source = TransparentTcpCaptureSource::new(listener, provider);
        let addrs = source.listen_addrs();

        // Both addresses must be present, in ascending order, each exactly once.
        let expected: Vec<SocketAddr> = BTreeSet::from([listener_addr, extra]).into_iter().collect();
        assert_eq!(addrs, expected);
    }

    #[tokio::test]
    async fn test_accept_passes_through_original_dst() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let original: SocketAddr = "10.0.0.1:443".parse().expect("parse original dst");
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::new(),
            dst_result: Ok(Some(original)),
        });
        let mut source = TransparentTcpCaptureSource::new(listener, provider);

        // Keep the client stream alive so its local address can be compared below.
        let connect_task = tokio::spawn(async move { TcpStream::connect(listener_addr).await });

        let conn = source.accept().await.expect("accept should succeed");
        assert_eq!(conn.target_addr, Some(original));

        let client = connect_task
            .await
            .expect("connect task join")
            .expect("connect");
        assert_eq!(conn.client_addr, client.local_addr().expect("client local_addr"));
    }

    #[tokio::test]
    async fn test_accept_recovers_when_original_dst_is_unsupported() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::new(),
            dst_result: Err(io::Error::new(
                io::ErrorKind::Unsupported,
                "ipv6 transparent proxy not implemented",
            )),
        });
        let mut source = TransparentTcpCaptureSource::new(listener, provider);

        let connect_task = tokio::spawn(async move {
            let _ = TcpStream::connect(listener_addr).await;
        });

        let conn = source.accept().await.expect("accept should recover");
        assert!(conn.target_addr.is_none(), "target_addr should fallback to None");
        let _ = connect_task.await;
    }

    #[tokio::test]
    async fn test_accept_propagates_non_recoverable_original_dst_error() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::new(),
            dst_result: Err(io::Error::new(
                io::ErrorKind::PermissionDenied,
                "permission denied",
            )),
        });
        let mut source = TransparentTcpCaptureSource::new(listener, provider);

        let connect_task = tokio::spawn(async move {
            let _ = TcpStream::connect(listener_addr).await;
        });

        let result = source.accept().await;
        assert!(result.is_err(), "non-recoverable errors must propagate");
        let _ = connect_task.await;
    }
}