use core::convert::TryFrom;
use bytes::{Buf, BytesMut};
use super::addr::field_port;
use super::{field, Cmd, Decoder, Encodable, Encoder, Error, HasAddr, Rep, Result, SocksAddr, Ver};
/// Zero-copy view over a SOCKS5 request or reply buffer.
///
/// Both message kinds share the same layout (VER, CMD/REP, RSV, then an address section
/// located at `field::ATYP`), so a single view type serves [`CmdRepr`] and [`RepRepr`].
#[derive(Clone, Debug, PartialEq)]
pub struct Packet<T: AsRef<[u8]>>(HasAddr<T>);
impl<T: AsRef<[u8]>> Packet<T> {
pub fn new_unchecked(buffer: T) -> Packet<T> {
Packet(HasAddr::new_unchecked(field::ATYP, buffer))
}
pub fn new_checked(buffer: T) -> Result<Packet<T>> {
let packet = Self::new_unchecked(buffer);
packet.check_len()?;
Ok(packet)
}
#[inline]
fn buffer_ref(&self) -> &[u8] {
self.0.buffer.as_ref()
}
#[inline]
pub fn check_len(&self) -> Result<()> {
self.0.check_addr_len()?;
if self.buffer_ref().len() > self.total_len() {
Err(Error::Malformed)
} else {
Ok(())
}
}
#[inline]
pub fn total_len(&self) -> usize {
self.0.len_to_port()
}
#[inline]
pub fn version(&self) -> u8 {
let data = self.buffer_ref();
data[field::VER]
}
#[inline]
pub fn cmd_or_rep(&self) -> u8 {
let data = self.buffer_ref();
data[field::CMD_OR_REP]
}
#[inline]
pub fn atyp(&self) -> u8 {
self.0.atyp()
}
#[inline]
pub fn port(&self) -> u16 {
self.0.port()
}
pub fn take_buffer(self) -> T {
self.0.take_buffer()
}
}
/// Accessors for borrowed buffers that return slices tied to the buffer's lifetime `'a`
/// rather than to the borrow of the `Packet` itself.
impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> {
    /// Address bytes as returned by `HasAddr::addr`.
    ///
    /// For a domain address the first byte is the length prefix, followed by the name.
    #[inline]
    pub fn addr(&self) -> &'a [u8] {
        let inner = &self.0;
        inner.addr()
    }

    /// Encoded SOCKS address section as returned by `HasAddr::socks_addr`.
    ///
    /// This is the input `SocksAddr::try_from` consumes in `parse`.
    #[inline]
    pub fn socks_addr(&self) -> &'a [u8] {
        let inner = &self.0;
        inner.socks_addr()
    }
}
impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
#[inline]
fn buffer_mut(&mut self) -> &mut [u8] {
self.0.buffer.as_mut()
}
#[inline]
pub fn set_version(&mut self, value: u8) {
let data = self.buffer_mut();
data[field::VER] = value;
}
#[inline]
pub fn set_cmd_or_rep(&mut self, value: u8) {
let data = self.buffer_mut();
data[field::CMD_OR_REP] = value;
}
#[inline]
pub fn set_atyp(&mut self, value: u8) {
self.0.set_atyp(value)
}
#[inline]
pub fn set_addr(&mut self, value: &[u8]) {
self.0.set_addr(value)
}
#[inline]
pub fn set_port(&mut self, value: u16) {
self.0.set_port(value)
}
#[inline]
pub fn set_socks_addr(&mut self, value: &[u8]) {
self.0.set_socks_addr(value)
}
#[inline]
pub fn addr_mut(&mut self) -> &mut [u8] {
self.0.addr_mut()
}
#[inline]
pub fn socks_addr_mut(&mut self) -> &mut [u8] {
self.0.socks_addr_mut()
}
}
impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> {
#[inline]
fn as_ref(&self) -> &[u8] {
self.buffer_ref()
}
}
/// High-level SOCKS5 request: a command plus the destination address.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CmdRepr {
    /// Protocol version. `parse` always yields `Ver::SOCKS5`, and `emit` always writes
    /// SOCKS5 regardless of this field.
    pub ver: Ver,
    /// Requested command.
    pub cmd: Cmd,
    /// Destination address.
    pub addr: SocksAddr,
}
impl CmdRepr {
    /// Parses a SOCKS5 request from `packet`.
    ///
    /// The returned `ver` is always `Ver::SOCKS5`.
    ///
    /// # Errors
    /// - Any error from [`Packet::check_len`] (`Error::Truncated` when more bytes are needed,
    ///   `Error::Malformed` when the buffer is longer than one request).
    /// - `Error::Malformed` if VER is not SOCKS5 or the RSV byte is non-zero.
    /// - Errors from `Cmd::try_from` / `SocksAddr::try_from`.
    pub fn parse<T: AsRef<[u8]> + ?Sized>(packet: &Packet<&T>) -> Result<CmdRepr> {
        packet.check_len()?;
        if packet.version() != Ver::SOCKS5 as u8 {
            return Err(Error::Malformed);
        }
        if packet.as_ref()[field::RSV] != 0 {
            return Err(Error::Malformed);
        }
        Ok(CmdRepr {
            ver: Ver::SOCKS5,
            cmd: Cmd::try_from(packet.cmd_or_rep())?,
            addr: SocksAddr::try_from(packet.socks_addr())?,
        })
    }

    /// Number of bytes `emit` needs: the header plus the encoded address and port.
    pub fn buffer_len(&self) -> usize {
        let addr_len = self.addr.addr_len();
        field_port(field::ADDR_PORT.start, addr_len).end
    }

    /// Serialises this request into `packet`.
    ///
    /// # Panics
    /// Panics if the buffer is shorter than [`CmdRepr::buffer_len`] (direct indexing).
    pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>(&self, packet: &mut Packet<T>) {
        packet.set_version(Ver::SOCKS5.into());
        packet.set_cmd_or_rep(self.cmd as u8);
        // `parse` rejects a non-zero RSV byte, so write it explicitly rather than relying on
        // the caller's buffer being zero-initialised (e.g. a reused encode buffer).
        packet.buffer_mut()[field::RSV] = 0;
        packet.set_socks_addr(&self.addr.to_vec());
    }
}
impl Decoder<CmdRepr> for CmdRepr {
    /// Decodes one SOCKS5 request from `src`.
    ///
    /// Returns `Ok(None)` when more bytes are needed. On success, removes exactly the bytes
    /// the request occupied on the wire from the front of `src`.
    ///
    /// Because [`Packet::check_len`] rejects buffers longer than one request, any trailing
    /// bytes in `src` yield `Error::Malformed`.
    fn decode(src: &mut BytesMut) -> Result<Option<Self>> {
        let pkt = Packet::new_unchecked(src.as_ref());
        match CmdRepr::parse(&pkt) {
            Ok(repr) => {
                // Consume the length that was actually validated on the wire, not a length
                // recomputed from the parsed address. The two could diverge if address
                // parsing normalises the bytes, which would mis-frame `src` or panic.
                let consumed = pkt.total_len();
                src.advance(consumed);
                Ok(Some(repr))
            }
            Err(Error::Truncated) => Ok(None),
            Err(err) => Err(err),
        }
    }
}
impl Encodable for CmdRepr {
    /// Writes this request at the start of `dst`, first growing `dst` with zeros if it is
    /// shorter than [`CmdRepr::buffer_len`]. Bytes beyond the request are left untouched.
    fn encode_into(&self, dst: &mut BytesMut) {
        let needed = self.buffer_len();
        if needed > dst.len() {
            dst.resize(needed, 0);
        }
        self.emit(&mut Packet::new_unchecked(dst));
    }
}
impl Encoder<CmdRepr> for CmdRepr {
fn encode(item: &CmdRepr, dst: &mut BytesMut) {
item.encode_into(dst);
}
}
/// High-level SOCKS5 reply: a reply code plus the address carried in the reply.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RepRepr {
    /// Protocol version. `parse` always yields `Ver::SOCKS5`, and `emit` always writes
    /// SOCKS5 regardless of this field.
    pub ver: Ver,
    /// Reply code.
    pub rep: Rep,
    /// Address carried in the reply.
    pub addr: SocksAddr,
}
impl RepRepr {
    /// Parses a SOCKS5 reply from `packet`.
    ///
    /// The returned `ver` is always `Ver::SOCKS5`.
    ///
    /// # Errors
    /// - Any error from [`Packet::check_len`] (`Error::Truncated` when more bytes are needed,
    ///   `Error::Malformed` when the buffer is longer than one reply).
    /// - `Error::Malformed` if VER is not SOCKS5 or the RSV byte is non-zero.
    /// - Errors from `Rep::try_from` / `SocksAddr::try_from`.
    pub fn parse<T: AsRef<[u8]> + ?Sized>(packet: &Packet<&T>) -> Result<RepRepr> {
        packet.check_len()?;
        if packet.version() != Ver::SOCKS5 as u8 {
            return Err(Error::Malformed);
        }
        if packet.as_ref()[field::RSV] != 0 {
            return Err(Error::Malformed);
        }
        Ok(RepRepr {
            ver: Ver::SOCKS5,
            rep: Rep::try_from(packet.cmd_or_rep())?,
            addr: SocksAddr::try_from(packet.socks_addr())?,
        })
    }

    /// Number of bytes `emit` needs: the header plus the encoded address and port.
    pub fn buffer_len(&self) -> usize {
        let addr_len = self.addr.addr_len();
        field_port(field::ADDR_PORT.start, addr_len).end
    }

    /// Serialises this reply into `packet`.
    ///
    /// # Panics
    /// Panics if the buffer is shorter than [`RepRepr::buffer_len`] (direct indexing).
    pub fn emit<T: AsRef<[u8]> + AsMut<[u8]>>(&self, packet: &mut Packet<T>) {
        packet.set_version(Ver::SOCKS5.into());
        packet.set_cmd_or_rep(self.rep as u8);
        // `parse` rejects a non-zero RSV byte, so write it explicitly rather than relying on
        // the caller's buffer being zero-initialised (e.g. a reused encode buffer).
        packet.buffer_mut()[field::RSV] = 0;
        packet.set_socks_addr(&self.addr.to_vec());
    }
}
impl Decoder<RepRepr> for RepRepr {
    /// Decodes one SOCKS5 reply from `src`.
    ///
    /// Returns `Ok(None)` when more bytes are needed. On success, removes exactly the bytes
    /// the reply occupied on the wire from the front of `src`.
    ///
    /// Because [`Packet::check_len`] rejects buffers longer than one reply, any trailing
    /// bytes in `src` yield `Error::Malformed`.
    fn decode(src: &mut BytesMut) -> Result<Option<Self>> {
        let pkt = Packet::new_unchecked(src.as_ref());
        match RepRepr::parse(&pkt) {
            Ok(repr) => {
                // Consume the length that was actually validated on the wire, not a length
                // recomputed from the parsed address. The two could diverge if address
                // parsing normalises the bytes, which would mis-frame `src` or panic.
                let consumed = pkt.total_len();
                src.advance(consumed);
                Ok(Some(repr))
            }
            Err(Error::Truncated) => Ok(None),
            Err(err) => Err(err),
        }
    }
}
impl Encodable for RepRepr {
    /// Writes this reply at the start of `dst`, first growing `dst` with zeros if it is
    /// shorter than [`RepRepr::buffer_len`]. Bytes beyond the reply are left untouched.
    fn encode_into(&self, dst: &mut BytesMut) {
        let needed = self.buffer_len();
        if needed > dst.len() {
            dst.resize(needed, 0);
        }
        self.emit(&mut Packet::new_unchecked(dst));
    }
}
impl Encoder<RepRepr> for RepRepr {
fn encode(item: &RepRepr, dst: &mut BytesMut) {
item.encode_into(dst);
}
}
#[cfg(test)]
mod tests {
    // Round-trip tests for SOCKS5 request (`CmdRepr`) and reply (`RepRepr`) packets.
    // Address-variant checks use `match` + `panic!` so an unexpected variant fails the test
    // instead of silently skipping the assertions (as an `if let` would).
    use bytes::BytesMut;
    #[cfg(any(feature = "proto-ipv4", feature = "proto-ipv6"))]
    use smolsocket::SocketAddr;
    #[cfg(feature = "proto-ipv4")]
    use smoltcp::wire::Ipv4Address;
    #[cfg(feature = "proto-ipv6")]
    use smoltcp::wire::Ipv6Address;

    use crate::Atyp;

    use super::*;

    #[cfg(feature = "proto-ipv4")]
    #[test]
    fn test_cmd_invalid_len() {
        // Too short to hold the IPv4 address and port.
        let mut truncated_bytes = vec![0u8; 4];
        let mut truncated = Packet::new_unchecked(&mut truncated_bytes);
        truncated.set_version(Ver::SOCKS5 as u8);
        truncated.set_atyp(Atyp::V4 as u8);
        assert_eq!(truncated.check_len(), Err(Error::Truncated));
        let mut truncated_bytes_mut = BytesMut::new();
        truncated_bytes_mut.extend(truncated_bytes);
        assert_eq!(CmdRepr::decode(&mut truncated_bytes_mut), Ok(None));

        let mut truncated_bytes = vec![0u8; 5];
        let mut truncated = Packet::new_unchecked(&mut truncated_bytes);
        truncated.set_version(Ver::SOCKS5 as u8);
        truncated.set_atyp(Atyp::V4 as u8);
        assert_eq!(truncated.check_len(), Err(Error::Truncated));
        assert_eq!(truncated.total_len(), 10);

        // One byte longer than a complete request.
        let mut malformed_bytes = vec![0u8; truncated.total_len() + 1];
        let mut malformed = Packet::new_unchecked(&mut malformed_bytes);
        malformed.set_version(Ver::SOCKS5 as u8);
        malformed.set_atyp(Atyp::V4 as u8);
        assert_eq!(malformed.check_len(), Err(Error::Malformed));
        let mut malformed_bytes_mut = BytesMut::new();
        malformed_bytes_mut.extend(malformed_bytes);
        assert_eq!(
            CmdRepr::decode(&mut malformed_bytes_mut),
            Err(Error::Malformed)
        );
    }

    #[cfg(feature = "proto-ipv4")]
    #[test]
    fn test_cmd_connect_ip4() {
        let socket_addr = SocketAddr::new_ip4_port(127, 0, 0, 1, 80);
        let socks_addr = SocksAddr::SocketAddr(socket_addr);
        let repr = CmdRepr {
            ver: Ver::SOCKS5,
            cmd: Cmd::Connect,
            addr: socks_addr.clone(),
        };
        assert_eq!(repr.buffer_len(), 10);
        let mut bytes = vec![0u8; repr.buffer_len()];
        let mut pkt = Packet::new_unchecked(&mut bytes);
        assert_eq!(pkt.atyp(), 0);
        pkt.set_atyp(Atyp::V4 as u8);
        assert_eq!(pkt.atyp(), Atyp::V4 as u8);
        assert_eq!(&pkt.addr_mut(), &Ipv4Address::new(0, 0, 0, 0).as_bytes());
        pkt.set_addr(Ipv4Address::new(192, 168, 0, 1).as_bytes());
        assert_eq!(
            &pkt.addr_mut(),
            &Ipv4Address::new(192, 168, 0, 1).as_bytes()
        );
        assert_eq!(pkt.port(), 0);
        pkt.set_port(8080);
        assert_eq!(pkt.port(), 8080);
        repr.emit(&mut pkt);
        assert_eq!(pkt.socks_addr_mut(), socks_addr.to_vec().as_slice());
        let pkt_to_parse = Packet::new_checked(pkt.as_ref()).expect("should be valid");
        assert_eq!(
            pkt_to_parse.addr(),
            Ipv4Address::new(127, 0, 0, 1).as_bytes()
        );
        let parsed = CmdRepr::parse(&pkt_to_parse).expect("should parse");
        assert_eq!(parsed, repr);
        assert_eq!(parsed.addr.atyp(), Atyp::V4);
        match parsed.addr {
            SocksAddr::SocketAddr(SocketAddr::V4(socket_addr)) => {
                assert!(socket_addr.addr.is_loopback())
            }
            other => panic!("expected an IPv4 socket address, got {:?}", other),
        }
        let mut bytes_mut = BytesMut::new();
        CmdRepr::encode(&repr, &mut bytes_mut);
        let decoded = CmdRepr::decode(&mut bytes_mut);
        assert_eq!(decoded, Ok(Some(repr)));
    }

    #[cfg(feature = "proto-ipv6")]
    #[test]
    fn test_cmd_connect_ip6() {
        let socket_addr = SocketAddr::new_ip6_port(0, 0, 0, 0, 0, 0, 0, 1, 80);
        let socks_addr = SocksAddr::SocketAddr(socket_addr);
        let repr = CmdRepr {
            ver: Ver::SOCKS5,
            cmd: Cmd::Connect,
            addr: socks_addr.clone(),
        };
        assert_eq!(repr.buffer_len(), 22);
        let mut bytes = vec![0u8; repr.buffer_len()];
        let mut pkt = Packet::new_unchecked(&mut bytes);
        assert_eq!(pkt.atyp(), 0);
        pkt.set_atyp(Atyp::V6 as u8);
        assert_eq!(pkt.atyp(), Atyp::V6 as u8);
        assert_eq!(
            &pkt.addr_mut(),
            &Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 0).as_bytes()
        );
        pkt.set_addr(Ipv6Address::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1).as_bytes());
        assert_eq!(
            &pkt.addr_mut(),
            &Ipv6Address::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1).as_bytes()
        );
        assert_eq!(pkt.port(), 0);
        pkt.set_port(8080);
        assert_eq!(pkt.port(), 8080);
        repr.emit(&mut pkt);
        assert_eq!(pkt.socks_addr_mut(), socks_addr.to_vec().as_slice());
        let pkt_to_parse = Packet::new_checked(pkt.as_ref()).expect("should be valid");
        assert_eq!(
            pkt_to_parse.addr(),
            Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1).as_bytes()
        );
        let parsed = CmdRepr::parse(&pkt_to_parse).expect("should parse");
        assert_eq!(parsed, repr);
        assert_eq!(parsed.addr.atyp(), Atyp::V6);
        match parsed.addr {
            SocksAddr::SocketAddr(SocketAddr::V6(socket_addr)) => {
                assert!(socket_addr.addr.is_loopback())
            }
            other => panic!("expected an IPv6 socket address, got {:?}", other),
        }
        let mut bytes_mut = BytesMut::new();
        CmdRepr::encode(&repr, &mut bytes_mut);
        let decoded = CmdRepr::decode(&mut bytes_mut);
        assert_eq!(decoded, Ok(Some(repr)));
    }

    #[test]
    fn test_cmd_connect_domain() {
        let socks_addr = SocksAddr::DomainPort("google.com".to_string(), 443);
        let repr = CmdRepr {
            ver: Ver::SOCKS5,
            cmd: Cmd::Connect,
            addr: socks_addr.clone(),
        };
        assert_eq!(repr.buffer_len(), 17);
        let mut bytes = vec![0u8; repr.buffer_len()];
        let mut pkt = Packet::new_unchecked(&mut bytes);
        assert_eq!(pkt.atyp(), 0);
        pkt.set_atyp(Atyp::Domain as u8);
        assert_eq!(pkt.atyp(), Atyp::Domain as u8);
        assert_eq!(pkt.addr_mut()[0], 0);
        pkt.addr_mut()[0] = 10;
        assert_eq!(&pkt.addr_mut()[1..], b"\0\0\0\0\0\0\0\0\0\0");
        pkt.set_addr(b" ");
        assert_eq!(&pkt.addr_mut()[1..], b" ");
        assert_eq!(pkt.port(), 0);
        pkt.set_port(8080);
        assert_eq!(pkt.port(), 8080);
        repr.emit(&mut pkt);
        assert_eq!(pkt.socks_addr_mut(), socks_addr.to_vec().as_slice());
        let pkt_to_parse = Packet::new_checked(pkt.as_ref()).expect("should be valid");
        assert_eq!(pkt_to_parse.addr()[0], 10);
        assert_eq!(&pkt_to_parse.addr()[1..], b"google.com");
        let parsed = CmdRepr::parse(&pkt_to_parse).expect("should parse");
        assert_eq!(parsed, repr);
        assert_eq!(parsed.addr.atyp(), Atyp::Domain);
        match parsed.addr {
            SocksAddr::DomainPort(domain, port) => {
                assert_eq!(domain, "google.com".to_string());
                assert_eq!(port, 443);
            }
            other => panic!("expected a domain address, got {:?}", other),
        }
        let mut bytes_mut = BytesMut::new();
        CmdRepr::encode(&repr, &mut bytes_mut);
        let decoded = CmdRepr::decode(&mut bytes_mut);
        assert_eq!(decoded, Ok(Some(repr)));
    }

    #[cfg(feature = "proto-ipv4")]
    #[test]
    fn test_rep_invalid_len() {
        // Too short to hold the IPv4 address and port.
        let mut truncated_bytes = vec![0u8; 4];
        let mut truncated = Packet::new_unchecked(&mut truncated_bytes);
        truncated.set_version(Ver::SOCKS5 as u8);
        truncated.set_atyp(Atyp::V4 as u8);
        assert_eq!(truncated.check_len(), Err(Error::Truncated));
        let mut truncated_bytes_mut = BytesMut::new();
        truncated_bytes_mut.extend(truncated_bytes);
        assert_eq!(RepRepr::decode(&mut truncated_bytes_mut), Ok(None));

        let mut truncated_bytes = vec![0u8; 5];
        let mut truncated = Packet::new_unchecked(&mut truncated_bytes);
        truncated.set_version(Ver::SOCKS5 as u8);
        truncated.set_atyp(Atyp::V4 as u8);
        assert_eq!(truncated.check_len(), Err(Error::Truncated));
        assert_eq!(truncated.total_len(), 10);

        // One byte longer than a complete reply.
        let mut malformed_bytes = vec![0u8; truncated.total_len() + 1];
        let mut malformed = Packet::new_unchecked(&mut malformed_bytes);
        malformed.set_version(Ver::SOCKS5 as u8);
        malformed.set_atyp(Atyp::V4 as u8);
        assert_eq!(malformed.check_len(), Err(Error::Malformed));
        let mut malformed_bytes_mut = BytesMut::new();
        malformed_bytes_mut.extend(malformed_bytes);
        assert_eq!(
            RepRepr::decode(&mut malformed_bytes_mut),
            Err(Error::Malformed)
        );
    }

    #[cfg(feature = "proto-ipv4")]
    #[test]
    fn test_rep_success_ip4() {
        let socket_addr = SocketAddr::new_ip4_port(127, 0, 0, 1, 80);
        let socks_addr = SocksAddr::SocketAddr(socket_addr);
        let repr = RepRepr {
            ver: Ver::SOCKS5,
            rep: Rep::Success,
            addr: socks_addr.clone(),
        };
        assert_eq!(repr.buffer_len(), 10);
        let mut bytes = vec![0u8; repr.buffer_len()];
        let mut pkt = Packet::new_unchecked(&mut bytes);
        assert_eq!(pkt.atyp(), 0);
        pkt.set_atyp(Atyp::V4 as u8);
        assert_eq!(pkt.atyp(), Atyp::V4 as u8);
        assert_eq!(&pkt.addr_mut(), &Ipv4Address::new(0, 0, 0, 0).as_bytes());
        pkt.set_addr(Ipv4Address::new(192, 168, 0, 1).as_bytes());
        assert_eq!(
            &pkt.addr_mut(),
            &Ipv4Address::new(192, 168, 0, 1).as_bytes()
        );
        assert_eq!(pkt.port(), 0);
        pkt.set_port(8080);
        assert_eq!(pkt.port(), 8080);
        repr.emit(&mut pkt);
        assert_eq!(pkt.socks_addr_mut(), socks_addr.to_vec().as_slice());
        let pkt_to_parse = Packet::new_checked(pkt.as_ref()).expect("should be valid");
        assert_eq!(
            pkt_to_parse.addr(),
            Ipv4Address::new(127, 0, 0, 1).as_bytes()
        );
        let parsed = RepRepr::parse(&pkt_to_parse).expect("should parse");
        assert_eq!(parsed, repr);
        assert_eq!(parsed.addr.atyp(), Atyp::V4);
        match parsed.addr {
            SocksAddr::SocketAddr(SocketAddr::V4(socket_addr)) => {
                assert!(socket_addr.addr.is_loopback())
            }
            other => panic!("expected an IPv4 socket address, got {:?}", other),
        }
        let mut bytes_mut = BytesMut::new();
        RepRepr::encode(&repr, &mut bytes_mut);
        let decoded = RepRepr::decode(&mut bytes_mut);
        assert_eq!(decoded, Ok(Some(repr)));
    }

    #[cfg(feature = "proto-ipv6")]
    #[test]
    fn test_rep_success_ip6() {
        let socket_addr = SocketAddr::new_ip6_port(0, 0, 0, 0, 0, 0, 0, 1, 80);
        let socks_addr = SocksAddr::SocketAddr(socket_addr);
        let repr = RepRepr {
            ver: Ver::SOCKS5,
            rep: Rep::Success,
            addr: socks_addr.clone(),
        };
        assert_eq!(repr.buffer_len(), 22);
        let mut bytes = vec![0u8; repr.buffer_len()];
        let mut pkt = Packet::new_unchecked(&mut bytes);
        assert_eq!(pkt.atyp(), 0);
        pkt.set_atyp(Atyp::V6 as u8);
        assert_eq!(pkt.atyp(), Atyp::V6 as u8);
        assert_eq!(
            &pkt.addr_mut(),
            &Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 0).as_bytes()
        );
        pkt.set_addr(Ipv6Address::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1).as_bytes());
        assert_eq!(
            &pkt.addr_mut(),
            &Ipv6Address::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1).as_bytes()
        );
        assert_eq!(pkt.port(), 0);
        pkt.set_port(8080);
        assert_eq!(pkt.port(), 8080);
        repr.emit(&mut pkt);
        assert_eq!(pkt.socks_addr_mut(), socks_addr.to_vec().as_slice());
        let pkt_to_parse = Packet::new_checked(pkt.as_ref()).expect("should be valid");
        assert_eq!(
            pkt_to_parse.addr(),
            Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 1).as_bytes()
        );
        let parsed = RepRepr::parse(&pkt_to_parse).expect("should parse");
        assert_eq!(parsed, repr);
        assert_eq!(parsed.addr.atyp(), Atyp::V6);
        match parsed.addr {
            SocksAddr::SocketAddr(SocketAddr::V6(socket_addr)) => {
                assert!(socket_addr.addr.is_loopback())
            }
            other => panic!("expected an IPv6 socket address, got {:?}", other),
        }
        let mut bytes_mut = BytesMut::new();
        RepRepr::encode(&repr, &mut bytes_mut);
        let decoded = RepRepr::decode(&mut bytes_mut);
        assert_eq!(decoded, Ok(Some(repr)));
    }

    #[test]
    fn test_rep_success_domain() {
        let socks_addr = SocksAddr::DomainPort("google.com".to_string(), 443);
        let repr = RepRepr {
            ver: Ver::SOCKS5,
            rep: Rep::Success,
            addr: socks_addr.clone(),
        };
        assert_eq!(repr.buffer_len(), 17);
        let mut bytes = vec![0u8; repr.buffer_len()];
        let mut pkt = Packet::new_unchecked(&mut bytes);
        assert_eq!(pkt.atyp(), 0);
        pkt.set_atyp(Atyp::Domain as u8);
        assert_eq!(pkt.atyp(), Atyp::Domain as u8);
        assert_eq!(pkt.addr_mut()[0], 0);
        pkt.addr_mut()[0] = 10;
        assert_eq!(&pkt.addr_mut()[1..], b"\0\0\0\0\0\0\0\0\0\0");
        pkt.set_addr(b" ");
        assert_eq!(&pkt.addr_mut()[1..], b" ");
        assert_eq!(pkt.port(), 0);
        pkt.set_port(8080);
        assert_eq!(pkt.port(), 8080);
        repr.emit(&mut pkt);
        assert_eq!(pkt.socks_addr_mut(), socks_addr.to_vec().as_slice());
        let pkt_to_parse = Packet::new_checked(pkt.as_ref()).expect("should be valid");
        assert_eq!(pkt_to_parse.addr()[0], 10);
        assert_eq!(&pkt_to_parse.addr()[1..], b"google.com");
        let parsed = RepRepr::parse(&pkt_to_parse).expect("should parse");
        assert_eq!(parsed, repr);
        assert_eq!(parsed.addr.atyp(), Atyp::Domain);
        match parsed.addr {
            SocksAddr::DomainPort(domain, port) => {
                assert_eq!(domain, "google.com".to_string());
                assert_eq!(port, 443);
            }
            other => panic!("expected a domain address, got {:?}", other),
        }
        let mut bytes_mut = BytesMut::new();
        RepRepr::encode(&repr, &mut bytes_mut);
        let decoded = RepRepr::decode(&mut bytes_mut);
        assert_eq!(decoded, Ok(Some(repr)));
    }
}