Skip to main content

relay_core_lib/capture/transparent_tcp.rs

1use std::net::SocketAddr;
2use std::sync::Arc;
3use std::future::Future;
4use std::pin::Pin;
5use tokio::net::{TcpListener, TcpStream};
6use crate::capture::source::{CaptureSource, IncomingConnection};
7use crate::capture::original_dst::OriginalDstProvider;
8use std::io;
9
/// Reports whether a failed original-destination lookup may be downgraded to
/// "destination unknown" rather than treated as a hard failure.
///
/// Returns `true` only when `err.kind()` is one of `Unsupported`, `NotFound`,
/// `AddrNotAvailable` or `InvalidInput`; every other kind yields `false`.
fn is_recoverable_original_dst_error(err: &io::Error) -> bool {
    // Error kinds after which the caller can continue without a target address.
    const RECOVERABLE_KINDS: [io::ErrorKind; 4] = [
        io::ErrorKind::Unsupported,
        io::ErrorKind::NotFound,
        io::ErrorKind::AddrNotAvailable,
        io::ErrorKind::InvalidInput,
    ];

    RECOVERABLE_KINDS.contains(&err.kind())
}
19
20pub struct TransparentTcpCaptureSource {
21    listener: TcpListener,
22    original_dst_provider: Arc<dyn OriginalDstProvider>,
23}
24
25impl TransparentTcpCaptureSource {
26    pub fn new(listener: TcpListener, original_dst_provider: Arc<dyn OriginalDstProvider>) -> Self {
27        Self { listener, original_dst_provider }
28    }
29}
30
31impl CaptureSource for TransparentTcpCaptureSource {
32    type IO = TcpStream;
33
34    fn accept(&mut self) -> Pin<Box<dyn Future<Output = crate::error::Result<IncomingConnection<Self::IO>>> + Send + '_>> {
35        Box::pin(async move {
36            let (stream, client_addr) = self.listener.accept().await?;
37            
38            // Try to get original destination.
39            // Some platforms/modes may return recoverable errors (e.g. unsupported IPv6),
40            // in which case we gracefully degrade to "unknown original dst".
41            let target_addr = match self.original_dst_provider.get_original_dst(&stream) {
42                Ok(target_addr) => target_addr,
43                Err(e) if is_recoverable_original_dst_error(&e) => {
44                    tracing::debug!("Original destination unavailable, falling back: {}", e);
45                    None
46                }
47                Err(e) => return Err(e.into()),
48            };
49            
50            Ok(IncomingConnection {
51                stream,
52                client_addr,
53                target_addr,
54            })
55        })
56    }
57
58    fn listen_addrs(&self) -> Vec<SocketAddr> {
59        // Return listen addresses from OriginalDstProvider (which knows iptables/nftables targets)
60        // or just the listener's local address.
61        // Usually transparent proxy listens on 0.0.0.0, but we want to avoid loops to *that* address.
62        // But loop detection logic also checks local IPs.
63        
64        let mut addrs = self
65            .original_dst_provider
66            .get_listen_addrs()
67            .into_iter()
68            .collect::<std::collections::BTreeSet<_>>();
69        if let Ok(addr) = self.listener.local_addr() {
70            addrs.insert(addr);
71        }
72        addrs.into_iter().collect()
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::collections::BTreeSet;
    use std::io;

    /// Test double that returns a preconfigured lookup result and address set.
    struct MockOriginalDstProvider {
        listen_addrs: BTreeSet<SocketAddr>,
        dst_result: io::Result<Option<SocketAddr>>,
    }

    #[async_trait]
    impl OriginalDstProvider for MockOriginalDstProvider {
        fn get_original_dst(&self, _stream: &TcpStream) -> io::Result<Option<SocketAddr>> {
            // `io::Error` is not `Clone`, so an equivalent error (same kind and
            // message) is rebuilt for every call.
            self.dst_result
                .as_ref()
                .map(|dst| *dst)
                .map_err(|err| io::Error::new(err.kind(), err.to_string()))
        }

        fn get_listen_addrs(&self) -> BTreeSet<SocketAddr> {
            self.listen_addrs.clone()
        }
    }

    /// Binds a listener to an ephemeral loopback port and returns it together
    /// with the address it ended up on.
    async fn bind_loopback() -> (TcpListener, SocketAddr) {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        (listener, listener_addr)
    }

    /// Builds a source whose provider advertises no listen addresses and whose
    /// original-destination lookup always fails with `kind` and `message`.
    fn source_with_failing_lookup(
        listener: TcpListener,
        kind: io::ErrorKind,
        message: &'static str,
    ) -> TransparentTcpCaptureSource {
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::new(),
            dst_result: Err(io::Error::new(kind, message)),
        });
        TransparentTcpCaptureSource::new(listener, provider)
    }

    #[tokio::test]
    async fn test_listen_addrs_contains_listener_and_is_deduplicated() {
        let (listener, listener_addr) = bind_loopback().await;
        // The provider reports the very same address the listener is bound to,
        // so the merged result must contain it exactly once.
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::from([listener_addr]),
            dst_result: Ok(None),
        });

        let source = TransparentTcpCaptureSource::new(listener, provider);
        let addrs = source.listen_addrs();

        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs[0], listener_addr);
    }

    #[tokio::test]
    async fn test_accept_recovers_when_original_dst_is_unsupported() {
        let (listener, listener_addr) = bind_loopback().await;
        let mut source = source_with_failing_lookup(
            listener,
            io::ErrorKind::Unsupported,
            "ipv6 transparent proxy not implemented",
        );

        // Connect from a background task so `accept` has something to return.
        let connect_task = tokio::spawn(async move {
            let _ = TcpStream::connect(listener_addr).await;
        });

        let conn = source.accept().await.expect("accept should recover");
        assert!(conn.target_addr.is_none(), "target_addr should fallback to None");
        let _ = connect_task.await;
    }

    #[tokio::test]
    async fn test_accept_propagates_non_recoverable_original_dst_error() {
        let (listener, listener_addr) = bind_loopback().await;
        let mut source = source_with_failing_lookup(
            listener,
            io::ErrorKind::PermissionDenied,
            "permission denied",
        );

        // Connect from a background task so `accept` has something to return.
        let connect_task = tokio::spawn(async move {
            let _ = TcpStream::connect(listener_addr).await;
        });

        let result = source.accept().await;
        assert!(result.is_err(), "non-recoverable errors must propagate");
        let _ = connect_task.await;
    }
}