use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use bytes::Bytes;
use smoltcp::iface::{SocketHandle, SocketSet};
use smoltcp::socket::tcp;
use smoltcp::wire::IpListenEndpoint;
use tokio::sync::mpsc;
/// Receive buffer size, in bytes, of every smoltcp TCP socket created by the tracker.
const TCP_RX_BUF_SIZE: usize = 64 * 1024;
/// Transmit buffer size, in bytes, of every smoltcp TCP socket created by the tracker.
const TCP_TX_BUF_SIZE: usize = 64 * 1024;
/// Socket cap used when `ConnectionTracker::new` is given `None`.
const DEFAULT_MAX_CONNECTIONS: usize = 256;
/// Capacity (in messages) of each per-connection `mpsc` channel.
const CHANNEL_CAPACITY: usize = 32;
/// Size, in bytes, of the scratch buffer `relay_data` reads socket data into.
const RELAY_BUF_SIZE: usize = 16 * 1024;
/// Tracks the smoltcp TCP sockets created for individual `(src, dst)` flows
/// and the channels linking each socket to its proxy side.
pub struct ConnectionTracker {
    /// Per-socket state, keyed by the socket's handle in the `SocketSet`.
    connections: HashMap<SocketHandle, Connection>,
    /// `(src, dst)` flows that currently have a socket; lets `has_socket_for`
    /// answer with a set lookup instead of scanning `connections`.
    connection_keys: HashSet<(SocketAddr, SocketAddr)>,
    /// Upper bound on `connections.len()`; new sockets are refused beyond it.
    max_connections: usize,
}
/// Number of `relay_data` passes a closing connection may spend waiting to
/// flush buffered proxy data before its socket is aborted instead of closed.
const DEFERRED_CLOSE_LIMIT: u16 = 64;
/// Per-socket bookkeeping for one tracked `(src, dst)` flow.
struct Connection {
    /// Source address of the flow (first element of its key).
    src: SocketAddr,
    /// Destination address of the flow; the socket listens on it.
    dst: SocketAddr,
    /// Carries bytes read from the smoltcp socket to the proxy side.
    to_proxy: mpsc::Sender<Bytes>,
    /// Delivers bytes from the proxy side to be written into the socket.
    from_proxy: mpsc::Receiver<Bytes>,
    /// Proxy-side channel ends, held until `take_new_connections` hands them out.
    proxy_channels: Option<ProxyChannels>,
    /// Set once the channel ends were handed out; until then no data is relayed.
    proxy_spawned: bool,
    /// Proxy chunk that did not fully fit in the socket's transmit buffer,
    /// together with the offset of its first unsent byte.
    write_buf: Option<(Bytes, usize)>,
    /// Chunk read from the socket that `to_proxy` could not accept yet.
    read_buf: Option<Bytes>,
    /// `relay_data` passes spent waiting for `write_buf` to drain after the
    /// proxy side went away; compared against `DEFERRED_CLOSE_LIMIT`.
    close_attempts: u16,
}
/// The proxy side's ends of a connection's two channels, parked inside
/// `Connection` until they are handed out.
struct ProxyChannels {
    /// Yields bytes read from the smoltcp socket.
    from_smoltcp: mpsc::Receiver<Bytes>,
    /// Accepts bytes to be written into the smoltcp socket.
    to_smoltcp: mpsc::Sender<Bytes>,
}
/// A tracked connection whose TCP handshake has completed, returned by
/// `ConnectionTracker::take_new_connections` so the caller can start proxying it.
///
/// Dropping `from_smoltcp` tells the tracker the proxy side is gone: it then
/// flushes whatever is already queued for the socket and closes it.
#[derive(Debug)]
pub struct NewConnection {
    /// Destination address of the flow (the address the socket listened on).
    pub dst: SocketAddr,
    /// Bytes received from the smoltcp socket's peer.
    pub from_smoltcp: mpsc::Receiver<Bytes>,
    /// Bytes to be sent to the smoltcp socket's peer.
    pub to_smoltcp: mpsc::Sender<Bytes>,
}
impl ConnectionTracker {
pub fn new(max_connections: Option<usize>) -> Self {
Self {
connections: HashMap::new(),
connection_keys: HashSet::new(),
max_connections: max_connections.unwrap_or(DEFAULT_MAX_CONNECTIONS),
}
}
pub fn has_socket_for(&self, src: &SocketAddr, dst: &SocketAddr) -> bool {
self.connection_keys.contains(&(*src, *dst))
}
pub fn create_tcp_socket(
&mut self,
src: SocketAddr,
dst: SocketAddr,
sockets: &mut SocketSet<'_>,
) -> bool {
if self.connections.len() >= self.max_connections {
return false;
}
let rx_buf = tcp::SocketBuffer::new(vec![0u8; TCP_RX_BUF_SIZE]);
let tx_buf = tcp::SocketBuffer::new(vec![0u8; TCP_TX_BUF_SIZE]);
let mut socket = tcp::Socket::new(rx_buf, tx_buf);
let listen_endpoint = IpListenEndpoint {
addr: Some(dst.ip().into()),
port: dst.port(),
};
if socket.listen(listen_endpoint).is_err() {
return false;
}
let handle = sockets.add(socket);
let (to_proxy_tx, to_proxy_rx) = mpsc::channel(CHANNEL_CAPACITY);
let (from_proxy_tx, from_proxy_rx) = mpsc::channel(CHANNEL_CAPACITY);
self.connection_keys.insert((src, dst));
self.connections.insert(
handle,
Connection {
src,
dst,
to_proxy: to_proxy_tx,
from_proxy: from_proxy_rx,
proxy_channels: Some(ProxyChannels {
from_smoltcp: to_proxy_rx,
to_smoltcp: from_proxy_tx,
}),
proxy_spawned: false,
write_buf: None,
read_buf: None,
close_attempts: 0,
},
);
true
}
pub fn relay_data(&mut self, sockets: &mut SocketSet<'_>) {
let mut relay_buf = [0u8; RELAY_BUF_SIZE];
for (&handle, conn) in &mut self.connections {
if !conn.proxy_spawned {
continue;
}
let socket = sockets.get_mut::<tcp::Socket>(handle);
if conn.to_proxy.is_closed() {
write_proxy_data(socket, conn);
if conn.write_buf.is_none() {
socket.close();
} else {
conn.close_attempts += 1;
if conn.close_attempts >= DEFERRED_CLOSE_LIMIT {
socket.abort();
}
}
continue;
}
if let Some(pending) = conn.read_buf.take()
&& let Err(e) = conn.to_proxy.try_send(pending)
{
conn.read_buf = Some(e.into_inner());
}
if conn.read_buf.is_none() {
while socket.can_recv() {
match socket.recv_slice(&mut relay_buf) {
Ok(n) if n > 0 => {
let data = Bytes::copy_from_slice(&relay_buf[..n]);
if let Err(e) = conn.to_proxy.try_send(data) {
conn.read_buf = Some(e.into_inner());
break;
}
}
_ => break,
}
}
}
write_proxy_data(socket, conn);
}
}
pub fn take_new_connections(&mut self, sockets: &mut SocketSet<'_>) -> Vec<NewConnection> {
let mut new = Vec::new();
for (&handle, conn) in &mut self.connections {
if conn.proxy_spawned {
continue;
}
let socket = sockets.get::<tcp::Socket>(handle);
if socket.state() == tcp::State::Established {
conn.proxy_spawned = true;
if let Some(channels) = conn.proxy_channels.take() {
new.push(NewConnection {
dst: conn.dst,
from_smoltcp: channels.from_smoltcp,
to_smoltcp: channels.to_smoltcp,
});
}
}
}
new
}
pub fn cleanup_closed(&mut self, sockets: &mut SocketSet<'_>) {
let keys = &mut self.connection_keys;
self.connections.retain(|&handle, conn| {
let socket = sockets.get::<tcp::Socket>(handle);
if matches!(socket.state(), tcp::State::Closed) {
keys.remove(&(conn.src, conn.dst));
sockets.remove(handle);
false
} else {
true
}
});
}
}
/// Copies bytes queued by the proxy side into the socket's transmit buffer
/// without blocking.
///
/// At most one chunk is ever parked in `conn.write_buf`, together with the
/// offset of its first unsent byte. That chunk is retried first, and nothing
/// new is pulled from `conn.from_proxy` until it has been sent completely.
/// The function stops early, keeping the remainder parked, when the socket
/// cannot send or `send_slice` fails.
fn write_proxy_data(socket: &mut tcp::Socket<'_>, conn: &mut Connection) {
    // Step 1: finish the leftover chunk from an earlier call, if any.
    if let Some((pending, sent)) = conn.write_buf.as_mut() {
        if !socket.can_send() {
            return;
        }
        let Ok(n) = socket.send_slice(&pending[*sent..]) else {
            return;
        };
        *sent += n;
        if *sent < pending.len() {
            // Still not fully sent; keep it parked and take nothing new.
            return;
        }
        conn.write_buf = None;
    }

    // Step 2: pull fresh chunks until one does not fit or the channel is empty.
    while conn.write_buf.is_none() {
        let Ok(chunk) = conn.from_proxy.try_recv() else {
            break;
        };
        let outcome = if socket.can_send() {
            socket.send_slice(&chunk).ok()
        } else {
            None
        };
        match outcome {
            Some(n) if n >= chunk.len() => {}
            Some(n) => conn.write_buf = Some((chunk, n)),
            None => conn.write_buf = Some((chunk, 0)),
        }
    }
}