Skip to main content

relay_core_lib/capture/
transparent_tcp.rs

1use crate::capture::original_dst::OriginalDstProvider;
2use crate::capture::source::{CaptureSource, IncomingConnection};
3use std::future::Future;
4use std::io;
5use std::net::SocketAddr;
6use std::pin::Pin;
7use std::sync::Arc;
8use tokio::net::{TcpListener, TcpStream};
9
/// Decides whether a failed original-destination lookup may be tolerated.
///
/// Returns `true` for the error kinds `Unsupported`, `NotFound`,
/// `AddrNotAvailable` and `InvalidInput`; any other kind is considered fatal
/// and should be propagated by the caller.
fn is_recoverable_original_dst_error(err: &io::Error) -> bool {
    // Allow-list of kinds for which the caller degrades to "unknown original dst".
    const RECOVERABLE_KINDS: [io::ErrorKind; 4] = [
        io::ErrorKind::Unsupported,
        io::ErrorKind::NotFound,
        io::ErrorKind::AddrNotAvailable,
        io::ErrorKind::InvalidInput,
    ];

    let observed = err.kind();
    RECOVERABLE_KINDS.contains(&observed)
}
19
20pub struct TransparentTcpCaptureSource {
21    listener: TcpListener,
22    original_dst_provider: Arc<dyn OriginalDstProvider>,
23}
24
25impl TransparentTcpCaptureSource {
26    pub fn new(listener: TcpListener, original_dst_provider: Arc<dyn OriginalDstProvider>) -> Self {
27        Self {
28            listener,
29            original_dst_provider,
30        }
31    }
32}
33
34impl CaptureSource for TransparentTcpCaptureSource {
35    type IO = TcpStream;
36
37    fn accept(
38        &mut self,
39    ) -> Pin<Box<dyn Future<Output = crate::error::Result<IncomingConnection<Self::IO>>> + Send + '_>>
40    {
41        Box::pin(async move {
42            let (stream, client_addr) = self.listener.accept().await?;
43
44            // Try to get original destination.
45            // Some platforms/modes may return recoverable errors (e.g. unsupported IPv6),
46            // in which case we gracefully degrade to "unknown original dst".
47            let target_addr = match self.original_dst_provider.get_original_dst(&stream) {
48                Ok(target_addr) => target_addr,
49                Err(e) if is_recoverable_original_dst_error(&e) => {
50                    tracing::debug!("Original destination unavailable, falling back: {}", e);
51                    None
52                }
53                Err(e) => return Err(e.into()),
54            };
55
56            Ok(IncomingConnection {
57                stream,
58                client_addr,
59                target_addr,
60            })
61        })
62    }
63
64    fn listen_addrs(&self) -> Vec<SocketAddr> {
65        // Return listen addresses from OriginalDstProvider (which knows iptables/nftables targets)
66        // or just the listener's local address.
67        // Usually transparent proxy listens on 0.0.0.0, but we want to avoid loops to *that* address.
68        // But loop detection logic also checks local IPs.
69
70        let mut addrs = self
71            .original_dst_provider
72            .get_listen_addrs()
73            .into_iter()
74            .collect::<std::collections::BTreeSet<_>>();
75        if let Ok(addr) = self.listener.local_addr() {
76            addrs.insert(addr);
77        }
78        addrs.into_iter().collect()
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::collections::BTreeSet;
    use std::io;

    /// Test double returning a fixed original-destination result and a fixed
    /// set of listen addresses.
    struct MockOriginalDstProvider {
        listen_addrs: BTreeSet<SocketAddr>,
        dst_result: io::Result<Option<SocketAddr>>,
    }

    #[async_trait]
    impl OriginalDstProvider for MockOriginalDstProvider {
        fn get_original_dst(&self, _stream: &TcpStream) -> io::Result<Option<SocketAddr>> {
            // `io::Error` is not `Clone`, so rebuild an equivalent error with
            // the same kind on every call.
            match &self.dst_result {
                Ok(v) => Ok(*v),
                Err(e) => Err(io::Error::new(e.kind(), e.to_string())),
            }
        }

        fn get_listen_addrs(&self) -> BTreeSet<SocketAddr> {
            self.listen_addrs.clone()
        }
    }

    #[test]
    fn test_recoverable_original_dst_error_classification() {
        for kind in [
            io::ErrorKind::Unsupported,
            io::ErrorKind::NotFound,
            io::ErrorKind::AddrNotAvailable,
            io::ErrorKind::InvalidInput,
        ] {
            assert!(
                is_recoverable_original_dst_error(&io::Error::new(kind, "test")),
                "{:?} should be recoverable",
                kind
            );
        }
        for kind in [
            io::ErrorKind::PermissionDenied,
            io::ErrorKind::Other,
            io::ErrorKind::ConnectionReset,
        ] {
            assert!(
                !is_recoverable_original_dst_error(&io::Error::new(kind, "test")),
                "{:?} should not be recoverable",
                kind
            );
        }
    }

    #[tokio::test]
    async fn test_listen_addrs_contains_listener_and_is_deduplicated() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let provider_addrs = BTreeSet::from([listener_addr]);
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: provider_addrs,
            dst_result: Ok(None),
        });

        let source = TransparentTcpCaptureSource::new(listener, provider);
        let addrs = source.listen_addrs();

        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs[0], listener_addr);
    }

    #[tokio::test]
    async fn test_listen_addrs_merges_provider_and_listener_addrs_sorted() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let extra: SocketAddr = "0.0.0.0:15001".parse().expect("parse");
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::from([extra]),
            dst_result: Ok(None),
        });

        let source = TransparentTcpCaptureSource::new(listener, provider);
        let addrs = source.listen_addrs();

        let expected: Vec<SocketAddr> = BTreeSet::from([extra, listener_addr])
            .into_iter()
            .collect();
        assert_eq!(addrs, expected);
    }

    #[tokio::test]
    async fn test_accept_returns_original_dst_when_available() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let original: SocketAddr = "10.0.0.1:443".parse().expect("parse");
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::new(),
            dst_result: Ok(Some(original)),
        });
        let mut source = TransparentTcpCaptureSource::new(listener, provider);

        let connect_task = tokio::spawn(async move { TcpStream::connect(listener_addr).await });

        let conn = source.accept().await.expect("accept should succeed");
        assert_eq!(conn.target_addr, Some(original));

        let client = connect_task.await.expect("join").expect("connect");
        assert_eq!(conn.client_addr, client.local_addr().expect("client local_addr"));
    }

    #[tokio::test]
    async fn test_accept_recovers_when_original_dst_is_unsupported() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::new(),
            dst_result: Err(io::Error::new(
                io::ErrorKind::Unsupported,
                "ipv6 transparent proxy not implemented",
            )),
        });
        let mut source = TransparentTcpCaptureSource::new(listener, provider);

        let connect_task = tokio::spawn(async move {
            let _ = TcpStream::connect(listener_addr).await;
        });

        let conn = source.accept().await.expect("accept should recover");
        assert!(
            conn.target_addr.is_none(),
            "target_addr should fallback to None"
        );
        let _ = connect_task.await;
    }

    #[tokio::test]
    async fn test_accept_propagates_non_recoverable_original_dst_error() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::new(),
            dst_result: Err(io::Error::new(
                io::ErrorKind::PermissionDenied,
                "permission denied",
            )),
        });
        let mut source = TransparentTcpCaptureSource::new(listener, provider);

        let connect_task = tokio::spawn(async move {
            let _ = TcpStream::connect(listener_addr).await;
        });

        let result = source.accept().await;
        assert!(result.is_err(), "non-recoverable errors must propagate");
        let _ = connect_task.await;
    }
}