relay_core_lib/capture/
transparent_tcp.rs1use crate::capture::original_dst::OriginalDstProvider;
2use crate::capture::source::{CaptureSource, IncomingConnection};
3use std::future::Future;
4use std::io;
5use std::net::SocketAddr;
6use std::pin::Pin;
7use std::sync::Arc;
8use tokio::net::{TcpListener, TcpStream};
9
/// Classifies an error from an original-destination lookup.
///
/// Returns `true` for the error kinds `Unsupported`, `NotFound`,
/// `AddrNotAvailable` and `InvalidInput`, which callers treat as "no original
/// destination available" and recover from. Returns `false` for every other
/// kind, which callers should propagate.
fn is_recoverable_original_dst_error(err: &io::Error) -> bool {
    // Error kinds that mean the lookup merely could not produce an answer for
    // this connection, rather than a genuine failure of the capture path.
    const RECOVERABLE_KINDS: [io::ErrorKind; 4] = [
        io::ErrorKind::Unsupported,
        io::ErrorKind::NotFound,
        io::ErrorKind::AddrNotAvailable,
        io::ErrorKind::InvalidInput,
    ];

    RECOVERABLE_KINDS.contains(&err.kind())
}
19
20pub struct TransparentTcpCaptureSource {
21 listener: TcpListener,
22 original_dst_provider: Arc<dyn OriginalDstProvider>,
23}
24
25impl TransparentTcpCaptureSource {
26 pub fn new(listener: TcpListener, original_dst_provider: Arc<dyn OriginalDstProvider>) -> Self {
27 Self {
28 listener,
29 original_dst_provider,
30 }
31 }
32}
33
34impl CaptureSource for TransparentTcpCaptureSource {
35 type IO = TcpStream;
36
37 fn accept(
38 &mut self,
39 ) -> Pin<Box<dyn Future<Output = crate::error::Result<IncomingConnection<Self::IO>>> + Send + '_>>
40 {
41 Box::pin(async move {
42 let (stream, client_addr) = self.listener.accept().await?;
43
44 let target_addr = match self.original_dst_provider.get_original_dst(&stream) {
48 Ok(target_addr) => target_addr,
49 Err(e) if is_recoverable_original_dst_error(&e) => {
50 tracing::debug!("Original destination unavailable, falling back: {}", e);
51 None
52 }
53 Err(e) => return Err(e.into()),
54 };
55
56 Ok(IncomingConnection {
57 stream,
58 client_addr,
59 target_addr,
60 })
61 })
62 }
63
64 fn listen_addrs(&self) -> Vec<SocketAddr> {
65 let mut addrs = self
71 .original_dst_provider
72 .get_listen_addrs()
73 .into_iter()
74 .collect::<std::collections::BTreeSet<_>>();
75 if let Ok(addr) = self.listener.local_addr() {
76 addrs.insert(addr);
77 }
78 addrs.into_iter().collect()
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::collections::BTreeSet;
    use std::io;

    /// Test double for `OriginalDstProvider` that replays canned results.
    struct MockOriginalDstProvider {
        /// Addresses reported by `get_listen_addrs`.
        listen_addrs: BTreeSet<SocketAddr>,
        /// Result replayed by every `get_original_dst` call.
        dst_result: io::Result<Option<SocketAddr>>,
    }

    #[async_trait]
    impl OriginalDstProvider for MockOriginalDstProvider {
        fn get_original_dst(&self, _stream: &TcpStream) -> io::Result<Option<SocketAddr>> {
            match &self.dst_result {
                Ok(v) => Ok(*v),
                // `io::Error` is not `Clone`; rebuild it from kind + message.
                Err(e) => Err(io::Error::new(e.kind(), e.to_string())),
            }
        }

        fn get_listen_addrs(&self) -> BTreeSet<SocketAddr> {
            self.listen_addrs.clone()
        }
    }

    #[tokio::test]
    async fn test_listen_addrs_contains_listener_and_is_deduplicated() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let provider_addrs = BTreeSet::from([listener_addr]);
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: provider_addrs,
            dst_result: Ok(None),
        });

        let source = TransparentTcpCaptureSource::new(listener, provider);
        let addrs = source.listen_addrs();

        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs[0], listener_addr);
    }

    #[tokio::test]
    async fn test_listen_addrs_merges_distinct_provider_addrs_in_sorted_order() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let low: SocketAddr = "10.0.0.1:8080".parse().expect("valid addr");
        let high: SocketAddr = "192.168.1.1:443".parse().expect("valid addr");
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::from([high, low]),
            dst_result: Ok(None),
        });

        let source = TransparentTcpCaptureSource::new(listener, provider);

        // IPv4 socket addresses order by IP first, so 10.x < 127.x < 192.x.
        assert_eq!(source.listen_addrs(), vec![low, listener_addr, high]);
    }

    #[tokio::test]
    async fn test_accept_passes_through_original_dst() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let original_dst: SocketAddr = "10.0.0.1:443".parse().expect("valid addr");
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::new(),
            dst_result: Ok(Some(original_dst)),
        });
        let mut source = TransparentTcpCaptureSource::new(listener, provider);

        let connect_task = tokio::spawn(async move {
            let _ = TcpStream::connect(listener_addr).await;
        });

        let conn = source.accept().await.expect("accept should succeed");
        assert_eq!(conn.target_addr, Some(original_dst));
        assert!(conn.client_addr.ip().is_loopback());
        let _ = connect_task.await;
    }

    #[tokio::test]
    async fn test_accept_recovers_when_original_dst_is_unsupported() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::new(),
            dst_result: Err(io::Error::new(
                io::ErrorKind::Unsupported,
                "ipv6 transparent proxy not implemented",
            )),
        });
        let mut source = TransparentTcpCaptureSource::new(listener, provider);

        let connect_task = tokio::spawn(async move {
            let _ = TcpStream::connect(listener_addr).await;
        });

        let conn = source.accept().await.expect("accept should recover");
        assert!(
            conn.target_addr.is_none(),
            "target_addr should fallback to None"
        );
        let _ = connect_task.await;
    }

    #[tokio::test]
    async fn test_accept_propagates_non_recoverable_original_dst_error() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::new(),
            dst_result: Err(io::Error::new(
                io::ErrorKind::PermissionDenied,
                "permission denied",
            )),
        });
        let mut source = TransparentTcpCaptureSource::new(listener, provider);

        let connect_task = tokio::spawn(async move {
            let _ = TcpStream::connect(listener_addr).await;
        });

        let result = source.accept().await;
        assert!(result.is_err(), "non-recoverable errors must propagate");
        let _ = connect_task.await;
    }
}