use crate::read_exact;
use crate::util::target_addr::{read_address, TargetAddr};
use crate::{consts, AuthenticationMethod, ReplyError, Result, SocksError};
use anyhow::Context;
use async_std::{
future,
net::{TcpListener, TcpStream, ToSocketAddrs as AsyncToSocketAddrs},
sync::Arc,
task::{Context as AsyncContext, Poll},
};
use futures::{
future::{Either, Future},
stream::Stream,
AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt,
};
use std::io;
use std::net::ToSocketAddrs as StdToSocketAddrs;
use std::pin::Pin;
/// Runtime options applied to every connection handled by a [`Socks5Server`].
///
/// Start from [`Config::default`] and adjust it through the `set_*` methods.
#[derive(Clone)]
pub struct Config {
    /// Seconds allowed for connecting to the requested target before giving up.
    request_timeout: u64,
    /// When `true`, method negotiation and authentication are skipped entirely.
    skip_auth: bool,
    /// When `true`, domain-name targets are resolved before the command runs.
    dns_resolve: bool,
    /// When `true`, the server connects to the target and relays traffic itself.
    execute_command: bool,
    /// Credential checker; when present, username/password authentication is required.
    auth: Option<Arc<dyn Authentication>>,
}
impl Default for Config {
fn default() -> Self {
Config {
request_timeout: 10,
skip_auth: false,
dns_resolve: true,
execute_command: true,
auth: None,
}
}
}
/// Pluggable credential check used by SOCKS5 username/password authentication.
///
/// A checker is stored in [`Config`] behind an `Arc` and shared by every connection,
/// which is why implementations must be `Send + Sync`.
pub trait Authentication: Send + Sync {
    /// Returns `true` when the given `username` / `password` pair should be granted access.
    fn authenticate(&self, username: &str, password: &str) -> bool;
}
/// An [`Authentication`] backend that accepts exactly one username/password pair.
pub struct SimpleUserPassword {
    /// The only username that is accepted.
    pub username: String,
    /// The password expected for [`SimpleUserPassword::username`].
    pub password: String,
}
impl Authentication for SimpleUserPassword {
    /// Accepts the credentials only when both the username and password match exactly.
    ///
    /// The inputs come straight from the network, so the comparison avoids two
    /// data-dependent early exits:
    /// - Both fields are always checked (no `&&` short-circuit).
    /// - Each field is compared over all of its bytes, not stopping at the first mismatch.
    ///
    /// Only a length mismatch is rejected immediately. This is a best-effort mitigation
    /// against timing-based guessing, not a formal constant-time guarantee.
    fn authenticate(&self, username: &str, password: &str) -> bool {
        /// Byte-wise equality that folds every byte difference instead of returning early.
        fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
            if a.len() != b.len() {
                return false;
            }
            a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
        }

        let user_ok = constant_time_eq(username.as_bytes(), self.username.as_bytes());
        let pass_ok = constant_time_eq(password.as_bytes(), self.password.as_bytes());
        // Non-short-circuit `&` so the password is compared even when the username is wrong.
        user_ok & pass_ok
    }
}
/// Chainable setters; each returns `&mut Self` so calls can be strung together.
impl Config {
    /// Sets how many seconds the connection to the requested target may take.
    pub fn set_request_timeout(&mut self, n: u64) -> &mut Self {
        self.request_timeout = n;
        self
    }

    /// When `value` is `true`, clients are not asked to negotiate a method or authenticate.
    pub fn set_skip_auth(&mut self, value: bool) -> &mut Self {
        self.skip_auth = value;
        self
    }

    /// Installs the credential checker.
    ///
    /// Unless `skip_auth` is set, clients must then use username/password authentication.
    pub fn set_authentication<T: Authentication + 'static>(
        &mut self,
        authentication: T,
    ) -> &mut Self {
        let checker: Arc<dyn Authentication> = Arc::new(authentication);
        self.auth = Some(checker);
        self
    }

    /// When `value` is `false`, the request is read (and optionally resolved).
    ///
    /// The server then does not connect to the target or relay any traffic.
    pub fn set_execute_command(&mut self, value: bool) -> &mut Self {
        self.execute_command = value;
        self
    }

    /// When `value` is `false`, domain-name targets are left unresolved.
    pub fn set_dns_resolve(&mut self, value: bool) -> &mut Self {
        self.dns_resolve = value;
        self
    }
}
/// A SOCKS5 listener: a bound TCP socket plus the [`Config`] handed to every accepted client.
pub struct Socks5Server {
    /// Socket that accepts client connections.
    listener: TcpListener,
    /// Options shared (via `Arc`) with every accepted [`Socks5Socket`].
    config: Arc<Config>,
}
impl Socks5Server {
    /// Binds a listening TCP socket on `addr` and pairs it with [`Config::default`].
    ///
    /// # Errors
    /// Returns the I/O error produced by the underlying bind.
    pub async fn bind<A: AsyncToSocketAddrs>(addr: A) -> io::Result<Socks5Server> {
        Ok(Socks5Server {
            listener: TcpListener::bind(&addr).await?,
            config: Arc::new(Config::default()),
        })
    }

    /// Replaces the configuration used for connections accepted from now on.
    pub fn set_config(&mut self, config: Config) {
        let shared = Arc::new(config);
        self.config = shared;
    }

    /// Returns a stream of accepted, not-yet-upgraded [`Socks5Socket`]s.
    pub fn incoming(&self) -> Incoming<'_> {
        Incoming(self)
    }
}
pub struct Incoming<'a>(&'a Socks5Server);
impl<'a> Stream for Incoming<'a> {
type Item = Result<Socks5Socket<TcpStream>>;
fn poll_next(self: Pin<&mut Self>, cx: &mut AsyncContext<'_>) -> Poll<Option<Self::Item>> {
let fut = self.0.listener.accept();
futures::pin_mut!(fut);
let (socket, peer_addr) = futures::ready!(fut.poll(cx))?;
let local_addr = socket.local_addr()?;
debug!(
"incoming connection from peer {} @ {}",
&peer_addr, &local_addr
);
let socket = Socks5Socket::new(socket, self.0.config.clone());
Poll::Ready(Some(Ok(socket)))
}
}
/// A client connection that speaks (or is about to speak) SOCKS5.
///
/// Created with [`Socks5Socket::new`]; the handshake runs in
/// [`Socks5Socket::upgrade_to_socks5`]. Reads and writes are forwarded to the wrapped stream.
pub struct Socks5Socket<T: AsyncRead + AsyncWrite + Unpin> {
    /// The underlying client stream.
    inner: T,
    /// Options governing the handshake, shared via `Arc`.
    config: Arc<Config>,
    /// `AuthenticationMethod::None` until username/password credentials are accepted.
    auth: AuthenticationMethod,
    /// Destination requested by the client, filled in once the request has been read.
    target_addr: Option<TargetAddr>,
}
impl<T: AsyncRead + AsyncWrite + Unpin> Socks5Socket<T> {
pub fn new(socket: T, config: Arc<Config>) -> Self {
Socks5Socket {
inner: socket,
config,
auth: AuthenticationMethod::None,
target_addr: None,
}
}
pub async fn upgrade_to_socks5(mut self) -> Result<Socks5Socket<T>> {
trace!("upgrading to socks5...");
if self.config.skip_auth == false {
let methods = self.get_methods().await?;
self.can_accept_method(methods).await?;
if self.config.auth.is_some() {
let credentials = self.authenticate().await?;
self.auth = AuthenticationMethod::Password {
username: credentials.0,
password: credentials.1,
};
}
} else {
debug!("skipping auth");
}
match self.request().await {
Ok(_) => {}
Err(SocksError::ReplyError(e)) => {
self.reply(&e).await?;
Err(e)? }
Err(d) => return Err(d),
};
Ok(self)
}
async fn get_methods(&mut self) -> Result<Vec<u8>> {
trace!("Socks5Socket: get_methods()");
let [version, methods_len] =
read_exact!(self.inner, [0u8; 2]).context("Can't read methods")?;
debug!(
"Handshake headers: [version: {version}, methods len: {len}]",
version = version,
len = methods_len,
);
if version != consts::SOCKS5_VERSION {
return Err(SocksError::UnsupportedSocksVersion(version));
}
let methods = read_exact!(self.inner, vec![0u8; methods_len as usize])
.context("Can't get methods.")?;
debug!("methods supported sent by the client: {:?}", &methods);
Ok(methods)
}
async fn can_accept_method(&mut self, client_methods: Vec<u8>) -> Result<()> {
let method_supported;
if self.config.auth.is_some() {
method_supported = consts::SOCKS5_AUTH_METHOD_PASSWORD;
} else {
method_supported = consts::SOCKS5_AUTH_METHOD_NONE;
}
if !client_methods.contains(&method_supported) {
debug!("Don't support this auth method, reply with (0xff)");
self.inner
.write(&[
consts::SOCKS5_VERSION,
consts::SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE,
])
.await
.context("Can't reply with method not acceptable.")?;
return Err(SocksError::AuthMethodUnacceptable(client_methods));
}
debug!(
"Reply with method {} ({})",
AuthenticationMethod::from_u8(method_supported).context("Method not supported")?,
method_supported
);
self.inner
.write(&[consts::SOCKS5_VERSION, method_supported])
.await
.context("Can't reply with method auth-none")?;
Ok(())
}
async fn authenticate(&mut self) -> Result<(String, String)> {
trace!("Socks5Socket: authenticate()");
let [version, user_len] =
read_exact!(self.inner, [0u8; 2]).context("Can't read user len")?;
debug!(
"Auth: [version: {version}, user len: {len}]",
version = version,
len = user_len,
);
if user_len < 1 {
return Err(SocksError::AuthenticationFailed(format!(
"Username malformed ({} chars)",
user_len
)));
}
let username =
read_exact!(self.inner, vec![0u8; user_len as usize]).context("Can't get username.")?;
debug!("username bytes: {:?}", &username);
let [pass_len] = read_exact!(self.inner, [0u8; 1]).context("Can't read pass len")?;
debug!("Auth: [pass len: {len}]", len = pass_len,);
if pass_len < 1 {
return Err(SocksError::AuthenticationFailed(format!(
"Password malformed ({} chars)",
pass_len
)));
}
let password =
read_exact!(self.inner, vec![0u8; pass_len as usize]).context("Can't get password.")?;
debug!("password bytes: {:?}", &password);
let username = String::from_utf8(username).context("Failed to convert username")?;
let password = String::from_utf8(password).context("Failed to convert password")?;
let auth = self.config.auth.as_ref().context("No auth module")?;
if auth.authenticate(&username, &password) {
self.inner
.write(&[1, consts::SOCKS5_REPLY_SUCCEEDED])
.await
.context("Can't reply auth success")?;
} else {
self.inner
.write(&[1, consts::SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE])
.await
.context("Can't reply with auth method not acceptable.")?;
return Err(SocksError::AuthenticationRejected(format!(
"Authentication with username `{}`, rejected.",
username
)));
}
info!("User `{}` logged successfully.", username);
Ok((username, password))
}
async fn request(&mut self) -> Result<()> {
self.read_command().await?;
if self.config.dns_resolve {
self.resolve_dns().await?;
} else {
debug!("Domain won't be resolved because `dns_resolve`'s config has been turned off.")
}
if self.config.execute_command {
self.execute_command().await?;
}
Ok(())
}
async fn reply(&mut self, error: &ReplyError) -> Result<()> {
let reply = &[
consts::SOCKS5_VERSION,
error.as_u8(), 0x00, 1, 127, 0,
0,
1,
0, 0,
];
debug!("reply error to be written: {:?}", &reply);
self.inner
.write(reply)
.await
.context("Can't write the reply!")?;
Ok(())
}
async fn read_command(&mut self) -> Result<()> {
let [version, cmd, rsv, address_type] =
read_exact!(self.inner, [0u8; 4]).context("Malformed request")?;
debug!(
"Request: [version: {version}, command: {cmd}, rev: {rsv}, address_type: {address_type}]",
version = version,
cmd = cmd,
rsv = rsv,
address_type = address_type,
);
if version != consts::SOCKS5_VERSION {
return Err(SocksError::UnsupportedSocksVersion(version));
}
if cmd != consts::SOCKS5_CMD_TCP_CONNECT {
return Err(ReplyError::CommandNotSupported)?;
}
let target_addr = read_address(&mut self.inner, address_type)
.await
.map_err(|e| {
error!("{:#}", e);
ReplyError::AddressTypeNotSupported
})?;
self.target_addr = Some(target_addr);
debug!("Request target is {}", self.target_addr.as_ref().unwrap());
Ok(())
}
pub async fn resolve_dns(&mut self) -> Result<()> {
trace!("resolving dns");
if let Some(target_addr) = self.target_addr.take() {
self.target_addr = match target_addr {
TargetAddr::Domain(_, _) => Some(target_addr.resolve_dns().await?),
TargetAddr::Ip(_) => Some(target_addr),
};
}
Ok(())
}
async fn execute_command(&mut self) -> Result<()> {
let addr = self
.target_addr
.as_ref()
.context("target_addr empty")?
.to_socket_addrs()?
.next()
.context("unreachable")?;
let outbound = match future::timeout(
std::time::Duration::from_secs(self.config.request_timeout),
TcpStream::connect(addr),
)
.await
{
Ok(e) => match e {
Ok(o) => o,
Err(e) => match e.kind() {
io::ErrorKind::ConnectionRefused => Err(ReplyError::ConnectionRefused)?,
io::ErrorKind::ConnectionAborted => Err(ReplyError::ConnectionNotAllowed)?,
io::ErrorKind::ConnectionReset => Err(ReplyError::ConnectionNotAllowed)?,
io::ErrorKind::NotConnected => Err(ReplyError::NetworkUnreachable)?,
_ => Err(e)?, },
},
Err(_) => Err(ReplyError::TtlExpired)?,
};
debug!("Connected to remote destination");
self.inner
.write_all(&[
consts::SOCKS5_VERSION,
consts::SOCKS5_REPLY_SUCCEEDED,
0x00, 1, 127, 0,
0,
1,
0, 0,
])
.await
.context("Can't write successful reply")?;
trace!("Wrote success");
transfer(&mut self.inner, outbound).await
}
pub fn target_addr(&self) -> Option<&TargetAddr> {
self.target_addr.as_ref()
}
pub fn auth(&self) -> &AuthenticationMethod {
&self.auth
}
}
async fn transfer<I, O>(mut inbound: I, outbound: O) -> Result<()>
where
I: AsyncRead + AsyncWrite + Unpin,
O: AsyncRead + AsyncWrite + Unpin,
{
let (mut ri, mut wi) = futures::io::AsyncReadExt::split(&mut inbound);
let (mut ro, mut wo) = futures::io::AsyncReadExt::split(outbound);
let inbound_to_outbound = futures::io::copy(&mut ri, &mut wo);
let outbound_to_inbound = futures::io::copy(&mut ro, &mut wi);
match futures::future::select(inbound_to_outbound, outbound_to_inbound).await {
Either::Left((Ok(data), _)) => {
info!("local closed -> remote target ({} bytes consumed)", data)
}
Either::Left((Err(err), _)) => {
error!("local closed -> remote target with error {:?}", err,)
}
Either::Right((Ok(data), _)) => {
info!("local <- remote target closed ({} bytes consumed)", data)
}
Either::Right((Err(err), _)) => {
error!("local <- remote target closed with error {:?}", err,)
}
};
Ok(())
}
/// Reads are served directly from the wrapped stream.
impl<T> AsyncRead for Socks5Socket<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    fn poll_read(
        self: Pin<&mut Self>,
        context: &mut std::task::Context,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_read(context, buf)
    }
}
/// Writes, flushes and closes are forwarded unchanged to the wrapped stream.
impl<T> AsyncWrite for Socks5Socket<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    fn poll_write(
        self: Pin<&mut Self>,
        context: &mut std::task::Context,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_write(context, buf)
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        context: &mut std::task::Context,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_flush(context)
    }

    fn poll_close(
        self: Pin<&mut Self>,
        context: &mut std::task::Context,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_close(context)
    }
}
#[cfg(test)]
mod test {
    use crate::socks5::server::Socks5Server;

    /// `bind` must succeed on an ephemeral loopback port and expose the bound address.
    ///
    /// A bare `async` block is inert until polled, hence `block_on`. Port 0 avoids clashing
    /// with whatever may already listen on a fixed port.
    #[test]
    fn test_bind() {
        async_std::task::block_on(async {
            let server = Socks5Server::bind("127.0.0.1:0").await.unwrap();
            let addr = server.listener.local_addr().unwrap();
            assert!(addr.ip().is_loopback());
            assert_ne!(addr.port(), 0);
        });
    }
}