use crate::new_udp_header;
use crate::parse_udp_request;
use crate::read_exact;
use crate::ready;
use crate::util::target_addr::{read_address, TargetAddr};
use crate::util::stream::tcp_connect_with_timeout;
use crate::Socks5Command;
use crate::{consts, AuthenticationMethod, ReplyError, Result, SocksError};
use anyhow::Context;
use std::future::Future;
use std::io;
use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::net::{SocketAddr, ToSocketAddrs as StdToSocketAddrs};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context as AsyncContext, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::UdpSocket;
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs as AsyncToSocketAddrs};
use tokio::try_join;
use tokio_stream::Stream;
/// Server-wide settings; every accepted connection gets a clone of the same
/// `Arc<Config>`.
#[derive(Clone)]
pub struct Config {
    /// Timeout handed to `tcp_connect_with_timeout` for outbound connects
    /// (unit is defined by that helper — presumably seconds, TODO confirm).
    request_timeout: u64,
    /// When `true`, method negotiation and authentication are skipped entirely.
    skip_auth: bool,
    /// When `true`, domain targets are resolved before the command runs.
    dns_resolve: bool,
    /// When `true`, the requested command (CONNECT / UDP ASSOCIATE) is executed.
    execute_command: bool,
    /// Whether UDP ASSOCIATE requests are accepted.
    allow_udp: bool,
    /// Optional credential checker; when `Some` (and `skip_auth` is off) the
    /// server only offers username/password authentication.
    auth: Option<Arc<dyn Authentication>>,
}

impl Default for Config {
    /// Request timeout 10, negotiation enabled without credentials, DNS
    /// resolution and command execution on, UDP off.
    fn default() -> Self {
        Self {
            auth: None,
            allow_udp: false,
            execute_command: true,
            dns_resolve: true,
            skip_auth: false,
            request_timeout: 10,
        }
    }
}
/// Pluggable username/password check used for SOCKS5 password authentication.
///
/// Implementors must be `Send + Sync`: one instance is stored in the server
/// `Config` behind an `Arc` and shared by every accepted connection.
pub trait Authentication: Send + Sync {
    /// Returns `true` when the supplied credentials are accepted.
    fn authenticate(&self, username: &str, password: &str) -> bool;
}

/// An [`Authentication`] that accepts exactly one fixed username/password pair.
pub struct SimpleUserPassword {
    pub username: String,
    pub password: String,
}

impl Authentication for SimpleUserPassword {
    fn authenticate(&self, username: &str, password: &str) -> bool {
        /// Compares two byte strings without stopping at the first mismatch,
        /// so the time taken does not reveal how long the matching prefix is.
        /// Best-effort only: lengths still leak and the optimizer gives no
        /// hard guarantee.
        fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
            a.len() == b.len() && a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
        }

        // Non-short-circuiting `&` so both fields are always compared and the
        // timing does not reveal whether the username alone was correct.
        constant_time_eq(username.as_bytes(), self.username.as_bytes())
            & constant_time_eq(password.as_bytes(), self.password.as_bytes())
    }
}
impl Config {
    /// Sets the timeout passed to `tcp_connect_with_timeout` for outbound
    /// connects (unit defined by that helper).
    pub fn set_request_timeout(&mut self, timeout: u64) -> &mut Self {
        self.request_timeout = timeout;
        self
    }

    /// Enables or disables skipping method negotiation and authentication.
    pub fn set_skip_auth(&mut self, enabled: bool) -> &mut Self {
        self.skip_auth = enabled;
        self
    }

    /// Installs a credential checker; this makes password authentication the
    /// only method offered during negotiation.
    pub fn set_authentication<T: Authentication + 'static>(
        &mut self,
        authentication: T,
    ) -> &mut Self {
        let checker: Arc<dyn Authentication> = Arc::new(authentication);
        self.auth = Some(checker);
        self
    }

    /// Controls whether the requested command is actually executed.
    pub fn set_execute_command(&mut self, enabled: bool) -> &mut Self {
        self.execute_command = enabled;
        self
    }

    /// Controls whether domain targets are resolved before executing.
    pub fn set_dns_resolve(&mut self, enabled: bool) -> &mut Self {
        self.dns_resolve = enabled;
        self
    }

    /// Controls whether UDP ASSOCIATE requests are accepted.
    pub fn set_udp_support(&mut self, enabled: bool) -> &mut Self {
        self.allow_udp = enabled;
        self
    }
}
/// A SOCKS5 server accepting client connections on a TCP listener.
pub struct Socks5Server {
    listener: TcpListener,
    config: Arc<Config>,
}

impl Socks5Server {
    /// Binds a TCP listener on `addr` and uses the default [`Config`].
    ///
    /// # Errors
    /// Returns the I/O error produced while binding the listener.
    pub async fn bind<A: AsyncToSocketAddrs>(addr: A) -> io::Result<Socks5Server> {
        TcpListener::bind(&addr).await.map(|listener| Socks5Server {
            listener,
            config: Arc::default(),
        })
    }

    /// Replaces the configuration handed to connections accepted afterwards.
    pub fn set_config(&mut self, config: Config) {
        self.config = Arc::from(config);
    }

    /// Returns a stream yielding each accepted connection as a [`Socks5Socket`].
    pub fn incoming(&self) -> Incoming<'_> {
        let pending_accept = None;
        Incoming(self, pending_accept)
    }
}
/// Stream of accepted connections, created by [`Socks5Server::incoming`].
///
/// Field `.1` holds the in-flight `accept()` future between polls so that a
/// `Pending` result does not lose the pending accept.
pub struct Incoming<'a>(
    &'a Socks5Server,
    Option<Pin<Box<dyn Future<Output = io::Result<(TcpStream, SocketAddr)>> + Send + Sync + 'a>>>,
);

impl<'a> Stream for Incoming<'a> {
    type Item = Result<Socks5Socket<TcpStream>>;

    /// Polls the pending `accept()`, starting a new one if none is in flight.
    ///
    /// An accept error is yielded as `Some(Err(_))`; the stream does not end.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut AsyncContext<'_>) -> Poll<Option<Self::Item>> {
        if self.1.is_none() {
            self.1 = Some(Box::pin(self.0.listener.accept()));
        }
        let fut = match self.1.as_mut() {
            Some(fut) => fut,
            None => unreachable!("accept future is installed just above"),
        };
        let accepted = ready!(fut.as_mut().poll(cx));
        // Clear the finished future BEFORE propagating errors: re-polling a
        // completed async-fn future panics, so an accept error must not leave
        // it in the slot for the next poll_next call.
        self.1 = None;
        let (socket, peer_addr) = accepted?;
        let local_addr = socket.local_addr()?;
        debug!(
            "incoming connection from peer {} @ {}",
            &peer_addr, &local_addr
        );
        let socket = Socks5Socket::new(socket, self.0.config.clone());
        Poll::Ready(Some(Ok(socket)))
    }
}
/// A client connection going through (or finished with) the SOCKS5 handshake.
///
/// `T` is the transport to the SOCKS client; [`Incoming`] yields
/// `Socks5Socket<TcpStream>`. [`Socks5Socket::upgrade_to_socks5`] runs method
/// negotiation, optional username/password authentication and the request
/// phase; afterwards [`Socks5Socket::auth`] and [`Socks5Socket::target_addr`]
/// expose what the client negotiated.
pub struct Socks5Socket<T: AsyncRead + AsyncWrite + Unpin> {
    /// Transport to the SOCKS client.
    inner: T,
    /// Server configuration shared with the other accepted connections.
    config: Arc<Config>,
    /// Negotiated authentication; stays `AuthenticationMethod::None` unless a
    /// password login succeeds.
    auth: AuthenticationMethod,
    /// Destination from the client's request, filled in by `read_command`.
    target_addr: Option<TargetAddr>,
    /// Command from the client's request, filled in by `read_command`.
    cmd: Option<Socks5Command>,
    /// Address advertised in UDP ASSOCIATE replies; UDP relaying fails with
    /// "invalid reply ip" if this was never set.
    reply_ip: Option<IpAddr>,
}

impl<T: AsyncRead + AsyncWrite + Unpin> Socks5Socket<T> {
    /// Wraps `socket` with fresh, not-yet-negotiated SOCKS5 state.
    pub fn new(socket: T, config: Arc<Config>) -> Self {
        Socks5Socket {
            inner: socket,
            config,
            auth: AuthenticationMethod::None,
            target_addr: None,
            cmd: None,
            reply_ip: None,
        }
    }

    /// Sets the IP address returned to the client in a UDP ASSOCIATE reply.
    pub fn set_reply_ip(&mut self, addr: IpAddr) {
        self.reply_ip = Some(addr);
    }

    /// Runs the server side of the SOCKS5 handshake on this connection.
    ///
    /// Unless `skip_auth` is set, this does three things:
    /// - reads the client's method list;
    /// - replies with the single method this server supports;
    /// - authenticates with username/password when an authenticator is configured.
    ///
    /// It then reads the request. Depending on the config, it also resolves the
    /// target and executes the command, which relays traffic until the session ends.
    ///
    /// # Errors
    /// Any handshake, authentication or I/O failure. When the request phase
    /// fails with a [`SocksError::ReplyError`], the matching SOCKS5 error reply
    /// is sent to the client before the error is returned.
    pub async fn upgrade_to_socks5(mut self) -> Result<Socks5Socket<T>> {
        trace!("upgrading to socks5...");
        if self.config.skip_auth {
            debug!("skipping auth");
        } else {
            let methods = self.get_methods().await?;
            self.can_accept_method(methods).await?;
            if self.config.auth.is_some() {
                let (username, password) = self.authenticate().await?;
                self.auth = AuthenticationMethod::Password { username, password };
            }
        }

        match self.request().await {
            Ok(()) => Ok(self),
            Err(SocksError::ReplyError(e)) => {
                // Tell the client why its request failed before giving up.
                self.reply_error(&e).await?;
                Err(e.into())
            }
            Err(other) => Err(other),
        }
    }

    /// Reads the client greeting (version, method count, method ids) and
    /// returns the offered method ids.
    async fn get_methods(&mut self) -> Result<Vec<u8>> {
        trace!("Socks5Socket: get_methods()");
        let [version, methods_len] =
            read_exact!(self.inner, [0u8; 2]).context("Can't read methods")?;
        debug!(
            "Handshake headers: [version: {version}, methods len: {len}]",
            version = version,
            len = methods_len,
        );
        if version != consts::SOCKS5_VERSION {
            return Err(SocksError::UnsupportedSocksVersion(version));
        }
        let methods = read_exact!(self.inner, vec![0u8; methods_len as usize])
            .context("Can't get methods.")?;
        debug!("methods supported sent by the client: {:?}", &methods);
        Ok(methods)
    }

    /// Answers the greeting with the single method this server accepts.
    ///
    /// Password auth is selected when an authenticator is configured,
    /// otherwise "no authentication". If the client did not offer that
    /// method, replies `SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE` and fails.
    async fn can_accept_method(&mut self, client_methods: Vec<u8>) -> Result<()> {
        let method_supported = if self.config.auth.is_some() {
            consts::SOCKS5_AUTH_METHOD_PASSWORD
        } else {
            consts::SOCKS5_AUTH_METHOD_NONE
        };

        if !client_methods.contains(&method_supported) {
            debug!("Don't support this auth method, reply with (0xff)");
            self.inner
                .write_all(&[
                    consts::SOCKS5_VERSION,
                    consts::SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE,
                ])
                .await
                .context("Can't reply with method not acceptable.")?;
            self.inner.flush().await.context("Can't flush the reply!")?;
            return Err(SocksError::AuthMethodUnacceptable(client_methods));
        }

        debug!(
            "Reply with method {} ({})",
            AuthenticationMethod::from_u8(method_supported).context("Method not supported")?,
            method_supported
        );
        self.inner
            .write_all(&[consts::SOCKS5_VERSION, method_supported])
            .await
            .context("Can't reply with the selected auth method")?;
        // Flush so a buffering transport doesn't sit on the reply while we
        // wait for the client's next message.
        self.inner.flush().await.context("Can't flush the reply!")?;
        Ok(())
    }

    /// Performs the username/password sub-negotiation and returns the
    /// accepted `(username, password)`.
    ///
    /// # Errors
    /// Empty or non-UTF-8 credentials, a missing authenticator, rejected
    /// credentials (after sending a failure status), or I/O errors.
    async fn authenticate(&mut self) -> Result<(String, String)> {
        trace!("Socks5Socket: authenticate()");
        // The sub-negotiation version byte is only logged, not validated.
        let [version, user_len] =
            read_exact!(self.inner, [0u8; 2]).context("Can't read user len")?;
        debug!(
            "Auth: [version: {version}, user len: {len}]",
            version = version,
            len = user_len,
        );
        if user_len < 1 {
            return Err(SocksError::AuthenticationFailed(format!(
                "Username malformed ({} chars)",
                user_len
            )));
        }
        let username =
            read_exact!(self.inner, vec![0u8; user_len as usize]).context("Can't get username.")?;
        debug!("username bytes: {:?}", &username);

        let [pass_len] = read_exact!(self.inner, [0u8; 1]).context("Can't read pass len")?;
        debug!("Auth: [pass len: {len}]", len = pass_len,);
        if pass_len < 1 {
            return Err(SocksError::AuthenticationFailed(format!(
                "Password malformed ({} chars)",
                pass_len
            )));
        }
        // The password bytes are deliberately NOT logged: logs must never
        // contain plaintext credentials.
        let password =
            read_exact!(self.inner, vec![0u8; pass_len as usize]).context("Can't get password.")?;

        let username = String::from_utf8(username).context("Failed to convert username")?;
        let password = String::from_utf8(password).context("Failed to convert password")?;

        let accepted = self
            .config
            .auth
            .as_ref()
            .context("No auth module")?
            .authenticate(&username, &password);

        if !accepted {
            self.inner
                .write_all(&[1, consts::SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE])
                .await
                .context("Can't reply with auth method not acceptable.")?;
            self.inner.flush().await.context("Can't flush the reply!")?;
            return Err(SocksError::AuthenticationRejected(format!(
                "Authentication with username `{}`, rejected.",
                username
            )));
        }

        self.inner
            .write_all(&[1, consts::SOCKS5_REPLY_SUCCEEDED])
            .await
            .context("Can't reply auth success")?;
        self.inner.flush().await.context("Can't flush the reply!")?;
        info!("User `{}` logged successfully.", username);
        Ok((username, password))
    }

    /// Reads the request, then (per config) resolves DNS and executes it.
    async fn request(&mut self) -> Result<()> {
        self.read_command().await?;

        if self.config.dns_resolve {
            self.resolve_dns().await?;
        } else {
            debug!("Domain won't be resolved because `dns_resolve`'s config has been turned off.")
        }

        if self.config.execute_command {
            self.execute_command().await?;
        }
        Ok(())
    }

    /// Sends a SOCKS5 reply carrying `error`, with an unspecified (0.0.0.0:0)
    /// bound address.
    async fn reply_error(&mut self, error: &ReplyError) -> Result<()> {
        let reply = new_reply(error, SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0));
        debug!("reply error to be written: {:?}", &reply);
        self.inner
            .write_all(&reply)
            .await
            .context("Can't write the reply!")?;
        self.inner.flush().await.context("Can't flush the reply!")?;
        Ok(())
    }

    /// Parses the request header and target address into `self.cmd` and
    /// `self.target_addr`.
    ///
    /// BIND is always refused; UDP ASSOCIATE is refused unless `allow_udp` is set.
    async fn read_command(&mut self) -> Result<()> {
        let [version, cmd, rsv, address_type] =
            read_exact!(self.inner, [0u8; 4]).context("Malformed request")?;
        debug!(
            "Request: [version: {version}, command: {cmd}, rev: {rsv}, address_type: {address_type}]",
            version = version,
            cmd = cmd,
            rsv = rsv,
            address_type = address_type,
        );
        if version != consts::SOCKS5_VERSION {
            return Err(SocksError::UnsupportedSocksVersion(version));
        }

        let cmd = Socks5Command::from_u8(cmd).ok_or(ReplyError::CommandNotSupported)?;
        match cmd {
            Socks5Command::TCPConnect => {}
            Socks5Command::UDPAssociate if self.config.allow_udp => {}
            Socks5Command::UDPAssociate | Socks5Command::TCPBind => {
                return Err(ReplyError::CommandNotSupported.into());
            }
        }
        self.cmd = Some(cmd);

        let target_addr = read_address(&mut self.inner, address_type)
            .await
            .map_err(|e| {
                error!("{:#}", e);
                ReplyError::AddressTypeNotSupported
            })?;
        debug!("Request target is {}", &target_addr);
        self.target_addr = Some(target_addr);
        Ok(())
    }

    /// Replaces a domain target with its resolved form; IP targets are kept.
    pub async fn resolve_dns(&mut self) -> Result<()> {
        trace!("resolving dns");
        if let Some(target_addr) = self.target_addr.take() {
            self.target_addr = match target_addr {
                TargetAddr::Domain(_, _) => Some(target_addr.resolve_dns().await?),
                TargetAddr::Ip(_) => Some(target_addr),
            };
        }
        Ok(())
    }

    /// Dispatches to the handler for the command recorded by `read_command`.
    async fn execute_command(&mut self) -> Result<()> {
        match self.cmd {
            Some(Socks5Command::TCPConnect) => self.execute_command_connect().await,
            Some(Socks5Command::UDPAssociate) if self.config.allow_udp => {
                self.execute_command_udp_assoc().await
            }
            Some(Socks5Command::UDPAssociate) | Some(Socks5Command::TCPBind) | None => {
                Err(ReplyError::CommandNotSupported.into())
            }
        }
    }

    /// CONNECT: opens the outbound TCP connection and sends the success reply.
    /// It then relays bytes in both directions until the session ends.
    async fn execute_command_connect(&mut self) -> Result<()> {
        // NOTE(review): for an unresolved domain (dns_resolve off) this goes
        // through std's ToSocketAddrs, which presumably blocks the executor
        // thread during the lookup — confirm against TargetAddr's impl.
        let addr = self
            .target_addr
            .as_ref()
            .context("target_addr empty")?
            .to_socket_addrs()?
            .next()
            .context("unreachable")?;
        let outbound = tcp_connect_with_timeout(addr, self.config.request_timeout).await?;
        debug!("Connected to remote destination");

        // BND.ADDR is a fixed placeholder, not the outbound socket's local address.
        let reply = new_reply(
            &ReplyError::Succeeded,
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0),
        );
        self.inner
            .write_all(&reply)
            .await
            .context("Can't write successful reply")?;
        self.inner.flush().await.context("Can't flush the reply!")?;
        debug!("Wrote success");

        transfer(&mut self.inner, outbound).await
    }

    /// UDP ASSOCIATE: binds a relay socket and advertises `reply_ip` plus that
    /// socket's port. It then relays datagrams until a UDP error occurs.
    ///
    /// NOTE(review): RFC 1928 ties the association's lifetime to the TCP
    /// control connection; this relay does not watch `self.inner`.
    async fn execute_command_udp_assoc(&mut self) -> Result<()> {
        let peer_sock = UdpSocket::bind("[::]:0").await?;
        let relay_addr = SocketAddr::new(
            self.reply_ip.context("invalid reply ip")?,
            peer_sock.local_addr()?.port(),
        );
        self.inner
            .write_all(&new_reply(&ReplyError::Succeeded, relay_addr))
            .await
            .context("Can't write successful reply")?;
        self.inner.flush().await.context("Can't flush the reply!")?;
        debug!("Wrote success");
        transfer_udp(peer_sock).await
    }

    /// The target requested by the client, if the request has been read.
    pub fn target_addr(&self) -> Option<&TargetAddr> {
        self.target_addr.as_ref()
    }

    /// The authentication method negotiated with the client.
    pub fn auth(&self) -> &AuthenticationMethod {
        &self.auth
    }
}
/// Copies bytes in both directions between `inbound` and `outbound` until
/// both sides are done.
///
/// Always returns `Ok(())`: the byte counts or the copy error are only
/// logged, because the session is over either way.
async fn transfer<I, O>(mut inbound: I, mut outbound: O) -> Result<()>
where
    I: AsyncRead + AsyncWrite + Unpin,
    O: AsyncRead + AsyncWrite + Unpin,
{
    let outcome = tokio::io::copy_bidirectional(&mut inbound, &mut outbound).await;
    match outcome {
        Ok((inbound_to_outbound, outbound_to_inbound)) => {
            info!("transfer closed ({}, {})", inbound_to_outbound, outbound_to_inbound)
        }
        Err(err) => error!("transfer error: {:?}", err),
    }
    Ok(())
}
async fn handle_udp_request(inbound: &UdpSocket, outbound: &UdpSocket) -> Result<()> {
let mut buf = vec![0u8; 0x10000];
loop {
let (size, client_addr) = inbound.recv_from(&mut buf).await?;
debug!("Server recieve udp from {}", client_addr);
inbound.connect(client_addr).await?;
let (frag, target_addr, data) = parse_udp_request(&buf[..size]).await?;
if frag != 0 {
debug!("Discard UDP frag packets sliently.");
return Ok(());
}
debug!("Server forward to packet to {}", target_addr);
let mut target_addr = target_addr
.to_socket_addrs()?
.next()
.context("unreachable")?;
target_addr.set_ip(match target_addr.ip() {
std::net::IpAddr::V4(v4) => std::net::IpAddr::V6(v4.to_ipv6_mapped()),
v6 @ std::net::IpAddr::V6(_) => v6,
});
outbound.send_to(data, target_addr).await?;
}
}
/// Relays datagrams arriving on `outbound` back to the SOCKS client.
///
/// Each payload is prefixed with the header built by `new_udp_header` for its
/// sender address. The framed datagram is sent on `inbound`, which must
/// already be connected to the client.
async fn handle_udp_response(inbound: &UdpSocket, outbound: &UdpSocket) -> Result<()> {
    let mut packet = vec![0u8; 0x10000];
    loop {
        let (len, from) = outbound.recv_from(&mut packet).await?;
        debug!("Recieve packet from {}", from);
        let payload = &packet[..len];
        let mut framed = new_udp_header(from)?;
        framed.extend_from_slice(payload);
        inbound.send(&framed).await?;
    }
}
/// Runs the UDP relay for one association.
///
/// `inbound` faces the client, and a fresh socket bound to "[::]:0" faces the
/// remote peers. Both directions run concurrently, and the first error from
/// either one ends the relay.
async fn transfer_udp(inbound: UdpSocket) -> Result<()> {
    let outbound = UdpSocket::bind("[::]:0").await?;
    try_join!(
        handle_udp_request(&inbound, &outbound),
        handle_udp_response(&inbound, &outbound)
    )?;
    Ok(())
}
impl<T> AsyncRead for Socks5Socket<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.inner).poll_read(context, buf)
}
}
impl<T> AsyncWrite for Socks5Socket<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(context, buf)
}
fn poll_flush(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(context)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(context)
}
}
fn new_reply(error: &ReplyError, sock_addr: SocketAddr) -> Vec<u8> {
let (addr_type, mut ip_oct, mut port) = match sock_addr {
SocketAddr::V4(sock) => (
consts::SOCKS5_ADDR_TYPE_IPV4,
sock.ip().octets().to_vec(),
sock.port().to_be_bytes().to_vec(),
),
SocketAddr::V6(sock) => (
consts::SOCKS5_ADDR_TYPE_IPV6,
sock.ip().octets().to_vec(),
sock.port().to_be_bytes().to_vec(),
),
};
let mut reply = vec![
consts::SOCKS5_VERSION,
error.as_u8(), 0x00, addr_type, ];
reply.append(&mut ip_oct);
reply.append(&mut port);
reply
}
#[cfg(test)]
mod test {
    use crate::server::Socks5Server;
    use tokio_test::block_on;

    /// Binding must succeed and yield a concrete loopback address.
    ///
    /// Port 0 lets the OS pick a free port, so the test does not fail when
    /// 1080 is already taken (e.g. by a real SOCKS proxy on the machine).
    #[test]
    fn test_bind() {
        let f = async {
            let server = Socks5Server::bind("127.0.0.1:0").await.unwrap();
            let addr = server.listener.local_addr().unwrap();
            assert!(addr.ip().is_loopback());
            assert_ne!(addr.port(), 0);
        };
        block_on(f);
    }
}