use std::net::SocketAddr;
use std::sync::Arc;
use std::future::Future;
use std::pin::Pin;
use tokio::net::{TcpListener, TcpStream};
use crate::capture::source::{CaptureSource, IncomingConnection};
use crate::capture::original_dst::OriginalDstProvider;
use std::io;
/// Reports whether a failed original-destination lookup can be tolerated.
///
/// Returns `true` when the error kind is `Unsupported`, `NotFound`,
/// `AddrNotAvailable` or `InvalidInput`, and `false` for every other kind.
/// `accept` uses this to fall back to "no original destination" instead of
/// failing the whole connection.
fn is_recoverable_original_dst_error(err: &io::Error) -> bool {
    // Error kinds that mean "no original destination is available" rather
    // than a genuine failure.
    const RECOVERABLE_KINDS: [io::ErrorKind; 4] = [
        io::ErrorKind::Unsupported,
        io::ErrorKind::NotFound,
        io::ErrorKind::AddrNotAvailable,
        io::ErrorKind::InvalidInput,
    ];
    RECOVERABLE_KINDS.contains(&err.kind())
}
/// Capture source that accepts TCP connections from a listener and resolves
/// each connection's original destination through an [`OriginalDstProvider`].
pub struct TransparentTcpCaptureSource {
    /// Listener that incoming connections are accepted from.
    listener: TcpListener,
    /// Resolves original destinations and contributes extra listen addresses.
    original_dst_provider: Arc<dyn OriginalDstProvider>,
}

impl TransparentTcpCaptureSource {
    /// Creates a capture source from `listener` and the provider used to look
    /// up original destinations for accepted connections.
    pub fn new(listener: TcpListener, original_dst_provider: Arc<dyn OriginalDstProvider>) -> Self {
        Self {
            original_dst_provider,
            listener,
        }
    }
}
impl CaptureSource for TransparentTcpCaptureSource {
    type IO = TcpStream;

    /// Accepts the next TCP connection and resolves its original destination.
    ///
    /// The original destination comes from the configured
    /// [`OriginalDstProvider`]. If the provider fails with an error that
    /// `is_recoverable_original_dst_error` accepts, the failure is logged at
    /// debug level and the connection is returned with `target_addr: None`.
    ///
    /// # Errors
    ///
    /// Returns an error if accepting on the listener fails, or if the provider
    /// fails with a non-recoverable error.
    fn accept(&mut self) -> Pin<Box<dyn Future<Output = crate::error::Result<IncomingConnection<Self::IO>>> + Send + '_>> {
        Box::pin(async move {
            let (stream, client_addr) = self.listener.accept().await?;
            let target_addr = match self.original_dst_provider.get_original_dst(&stream) {
                Ok(target_addr) => target_addr,
                Err(e) if is_recoverable_original_dst_error(&e) => {
                    // Keep the connection but report no original destination.
                    // Record which client hit the fallback so it can be traced.
                    tracing::debug!(%client_addr, "Original destination unavailable, falling back: {}", e);
                    None
                }
                Err(e) => return Err(e.into()),
            };
            Ok(IncomingConnection {
                stream,
                client_addr,
                target_addr,
            })
        })
    }

    /// Returns every address this source listens on, sorted and deduplicated.
    ///
    /// This is the provider's listen addresses plus the listener's own local
    /// address.
    fn listen_addrs(&self) -> Vec<SocketAddr> {
        // The provider already returns an ordered set. Extend it in place
        // instead of rebuilding an identical set from its iterator.
        let mut addrs = self.original_dst_provider.get_listen_addrs();
        // Best-effort: if the local address cannot be queried, the listener
        // simply contributes nothing.
        if let Ok(addr) = self.listener.local_addr() {
            addrs.insert(addr);
        }
        addrs.into_iter().collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::collections::BTreeSet;
    use std::io;

    /// Test double whose listen addresses and original-destination result are
    /// fixed at construction time.
    struct MockOriginalDstProvider {
        listen_addrs: BTreeSet<SocketAddr>,
        dst_result: io::Result<Option<SocketAddr>>,
    }

    #[async_trait]
    impl OriginalDstProvider for MockOriginalDstProvider {
        fn get_original_dst(&self, _stream: &TcpStream) -> io::Result<Option<SocketAddr>> {
            // io::Error is not Clone, so rebuild an equivalent error per call.
            match &self.dst_result {
                Ok(v) => Ok(*v),
                Err(e) => Err(io::Error::new(e.kind(), e.to_string())),
            }
        }
        fn get_listen_addrs(&self) -> BTreeSet<SocketAddr> {
            self.listen_addrs.clone()
        }
    }

    /// Binds a loopback listener, wraps it in a source whose provider yields
    /// `dst_result`, opens one client connection, and returns what `accept`
    /// produced.
    async fn accept_one(
        dst_result: io::Result<Option<SocketAddr>>,
    ) -> crate::error::Result<IncomingConnection<TcpStream>> {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::new(),
            dst_result,
        });
        let mut source = TransparentTcpCaptureSource::new(listener, provider);
        let connect_task = tokio::spawn(async move {
            let _ = TcpStream::connect(listener_addr).await;
        });
        let result = source.accept().await;
        let _ = connect_task.await;
        result
    }

    #[test]
    fn test_recoverable_error_classification() {
        for kind in [
            io::ErrorKind::Unsupported,
            io::ErrorKind::NotFound,
            io::ErrorKind::AddrNotAvailable,
            io::ErrorKind::InvalidInput,
        ] {
            assert!(
                is_recoverable_original_dst_error(&io::Error::from(kind)),
                "{:?} should be recoverable",
                kind
            );
        }
        for kind in [
            io::ErrorKind::PermissionDenied,
            io::ErrorKind::ConnectionReset,
            io::ErrorKind::Other,
        ] {
            assert!(
                !is_recoverable_original_dst_error(&io::Error::from(kind)),
                "{:?} should not be recoverable",
                kind
            );
        }
    }

    #[tokio::test]
    async fn test_listen_addrs_contains_listener_and_is_deduplicated() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::from([listener_addr]),
            dst_result: Ok(None),
        });
        let source = TransparentTcpCaptureSource::new(listener, provider);
        let addrs = source.listen_addrs();
        assert_eq!(addrs.len(), 1);
        assert_eq!(addrs[0], listener_addr);
    }

    #[tokio::test]
    async fn test_listen_addrs_merges_provider_addrs_in_sorted_order() {
        let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
        let listener_addr = listener.local_addr().expect("local_addr");
        let v6: SocketAddr = "[::1]:8443".parse().expect("valid socket addr");
        let unspecified: SocketAddr = "0.0.0.0:8080".parse().expect("valid socket addr");
        let provider = Arc::new(MockOriginalDstProvider {
            listen_addrs: BTreeSet::from([v6, unspecified]),
            dst_result: Ok(None),
        });
        let source = TransparentTcpCaptureSource::new(listener, provider);
        let expected: Vec<SocketAddr> = BTreeSet::from([v6, unspecified, listener_addr])
            .into_iter()
            .collect();
        assert_eq!(source.listen_addrs(), expected);
    }

    #[tokio::test]
    async fn test_accept_passes_through_original_dst() {
        let original: SocketAddr = "192.0.2.10:443".parse().expect("valid socket addr");
        let conn = accept_one(Ok(Some(original))).await.expect("accept");
        assert_eq!(conn.target_addr, Some(original));
    }

    #[tokio::test]
    async fn test_accept_recovers_when_original_dst_is_unsupported() {
        let conn = accept_one(Err(io::Error::new(
            io::ErrorKind::Unsupported,
            "ipv6 transparent proxy not implemented",
        )))
        .await
        .expect("accept should recover");
        assert!(conn.target_addr.is_none(), "target_addr should fallback to None");
    }

    #[tokio::test]
    async fn test_accept_recovers_for_every_recoverable_kind() {
        for kind in [
            io::ErrorKind::NotFound,
            io::ErrorKind::AddrNotAvailable,
            io::ErrorKind::InvalidInput,
        ] {
            let conn = match accept_one(Err(io::Error::from(kind))).await {
                Ok(conn) => conn,
                Err(e) => panic!("accept should recover from {:?}: {:?}", kind, e),
            };
            assert!(
                conn.target_addr.is_none(),
                "target_addr should fallback to None for {:?}",
                kind
            );
        }
    }

    #[tokio::test]
    async fn test_accept_propagates_non_recoverable_original_dst_error() {
        let result = accept_one(Err(io::Error::new(
            io::ErrorKind::PermissionDenied,
            "permission denied",
        )))
        .await;
        assert!(result.is_err(), "non-recoverable errors must propagate");
    }
}