use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use bytes::Bytes;
use socks_lib::io::{self, AsyncRead, AsyncWrite};
use socks_lib::net::{TcpListener, TcpStream, UdpSocket};
use socks_lib::v5::server::auth::NoAuthentication;
use socks_lib::v5::server::{Config, Handler, Server};
use socks_lib::v5::{Address, Request, Response, Stream, UdpPacket};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tokio::time::{MissedTickBehavior, interval};
#[tokio::main]
async fn main() {
let listener = TcpListener::bind("0.0.0.0:1081").await.unwrap();
println!(
"SOCKS server listening on {}",
listener.local_addr().unwrap()
);
let config = Config::new(NoAuthentication, CommandHandler);
Server::run(listener, config.into(), async {
tokio::signal::ctrl_c().await.unwrap();
})
.await
.unwrap();
}
pub struct CommandHandler;
impl Handler for CommandHandler {
async fn handle<T>(&self, stream: &mut Stream<T>, request: Request) -> io::Result<()>
where
T: AsyncRead + AsyncWrite + Unpin + Send + Sync,
{
println!("Request: {request:?}");
match request {
Request::Connect(addr) => {
stream.write_response_unspecified().await?;
let mut target = TcpStream::connect(addr.to_string()).await?;
let copy = io::copy_bidirectional(stream, &mut target).await?;
println!(
"[TCP] {} -> {} | Sent: {}, Received: {}",
stream.peer_addr(),
addr,
copy.0,
copy.1
);
}
Request::Associate(_) => {
let server_ip = stream.local_addr().ip();
let inbound = Arc::new(UdpSocket::bind("0.0.0.0:0").await?);
let udp_port = inbound.local_addr().unwrap().port();
let reply_addr = SocketAddr::new(server_ip, udp_port);
let bind_addr = Address::from(reply_addr);
stream
.write_response(&Response::Success(&bind_addr))
.await?;
println!(
"[UDP] Association created for {}. Client should send UDP to {}.",
stream.peer_addr(),
bind_addr
);
udp_session_run(inbound, Duration::from_secs(180)).await?;
println!("[UDP] Association for {} ended.", stream.peer_addr());
}
_ => {
stream.write_response_unsupported().await?;
}
}
Ok(())
}
}
struct OutboundEntry {
socket: Arc<UdpSocket>,
last_active: Instant,
recv_task: JoinHandle<()>,
}
async fn udp_session_run(inbound: Arc<UdpSocket>, idle_timeout: Duration) -> io::Result<()> {
let mut buf = vec![0u8; 65535];
let (n, client_addr) = tokio::time::timeout(idle_timeout, inbound.recv_from(&mut buf))
.await
.map_err(|_| {
io::Error::new(
io::ErrorKind::TimedOut,
"Client did not send any UDP packet",
)
})??;
let nat: Arc<Mutex<HashMap<String, OutboundEntry>>> = Arc::new(Mutex::new(HashMap::new()));
let nat_for_sweep = nat.clone();
let sweep_handle = tokio::spawn(async move {
let mut tick = interval(Duration::from_secs(30)); tick.set_missed_tick_behavior(MissedTickBehavior::Delay);
loop {
tick.tick().await;
let mut map = nat_for_sweep.lock().await;
let dead_keys: Vec<String> = map
.iter()
.filter(|(_, v)| v.last_active.elapsed() > idle_timeout)
.map(|(k, v)| {
v.recv_task.abort();
k.clone()
})
.collect();
for k in dead_keys {
if map.remove(&k).is_some() {
println!("[UDP] NAT entry for {} timed out and was removed.", k);
}
}
}
});
handle_client_packet(&inbound, client_addr, &nat, &buf[..n]).await?;
let main_loop_result = loop {
match tokio::time::timeout(idle_timeout, inbound.recv_from(&mut buf)).await {
Ok(Ok((n, src))) => {
if src != client_addr {
continue;
}
if let Err(e) = handle_client_packet(&inbound, client_addr, &nat, &buf[..n]).await {
eprintln!("[UDP] Error handling client packet: {}", e);
break Err(e);
}
}
Ok(Err(e)) => {
break Err(e);
}
Err(_) => {
break Ok(());
}
}
};
sweep_handle.abort();
let mut guard = nat.lock().await;
for (_, entry) in guard.drain() {
entry.recv_task.abort();
}
main_loop_result
}
async fn handle_client_packet(
inbound: &Arc<UdpSocket>,
client_addr: SocketAddr,
nat: &Arc<Mutex<HashMap<String, OutboundEntry>>>,
raw: &[u8],
) -> io::Result<()> {
let pkt = UdpPacket::from_bytes(&mut &raw[..]).map_err(|e| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Failed to parse UDP packet: {e}"),
)
})?;
let target_sock_addr: SocketAddr = match pkt.address {
Address::IPv4(v4) => SocketAddr::V4(v4),
Address::IPv6(v6) => SocketAddr::V6(v6),
Address::Domain(ref domain, port) => {
let full_addr = format!("{}:{}", domain.format_as_str(), port);
tokio::net::lookup_host(full_addr)
.await?
.next()
.ok_or_else(|| {
io::Error::new(io::ErrorKind::NotFound, "Domain resolution failed")
})?
}
};
let target_key = target_sock_addr.to_string();
let mut map = nat.lock().await;
if !map.contains_key(&target_key) {
let outbound = Arc::new(UdpSocket::bind("0.0.0.0:0").await?);
let recv_task = {
let outbound_rx = outbound.clone();
let inbound_tx = inbound.clone();
let original_target_key = target_key.clone();
tokio::spawn(async move {
let mut buf = vec![0u8; 65535];
loop {
match outbound_rx.recv_from(&mut buf).await {
Ok((n, src_addr)) => {
let data = Bytes::copy_from_slice(&buf[..n]);
let response_packet = UdpPacket::un_frag(Address::from(src_addr), data);
if let Err(e) = inbound_tx
.send_to(&response_packet.to_bytes(), client_addr)
.await
{
eprintln!(
"[UDP] Failed to send reply to client for target {}: {}",
original_target_key, e
);
break;
}
}
Err(e) => {
eprintln!(
"[UDP] Error on outbound socket for target {}: {}",
original_target_key, e
);
break;
}
}
}
})
};
println!("[UDP] New NAT entry created for target: {}", target_key);
let entry = OutboundEntry {
socket: outbound,
last_active: Instant::now(),
recv_task,
};
map.insert(target_key.clone(), entry);
}
if let Some(entry) = map.get_mut(&target_key) {
entry.last_active = Instant::now();
let data = pkt.data;
entry.socket.send_to(&data, target_sock_addr).await?;
} else {
return Err(io::Error::new(
io::ErrorKind::Other,
"Failed to get or create NAT entry",
));
}
Ok(())
}