#![deny(clippy::all)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::expect_used)]
#![deny(clippy::std_instead_of_core)]
#![deny(clippy::std_instead_of_alloc)]
#![deny(clippy::alloc_instead_of_core)]
#![warn(missing_docs)]
extern crate alloc;
use alloc::vec;
use core::time::Duration;
use std::{
io,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, TcpStream, ToSocketAddrs},
};
#[cfg(feature = "client")]
pub use v4::client::Socks4Stream;
#[cfg(feature = "client")]
pub use v5::client::Socks5Stream;
#[cfg(feature = "bind")]
pub use v4::bind::Socks4Listener;
#[cfg(feature = "bind")]
pub use v5::bind::Socks5Listener;
#[cfg(feature = "udp")]
pub use v5::udp::Socks5Datagram;
pub use error::{is_io_socks2_error, unwrap_io_to_socks2_error, Error};
mod error;
mod ext_bytes;
#[cfg(feature = "udp")]
mod ext_io;
#[cfg(any(feature = "client", feature = "bind"))]
mod v4;
#[cfg(any(feature = "client", feature = "bind", feature = "udp"))]
mod v5;
/// Destination address that a SOCKS proxy is asked to connect to.
///
/// Either a literal IP socket address or an unresolved domain name plus port.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub enum TargetAddr {
    /// A literal IP socket address (IPv4 or IPv6) with its port.
    Ip(SocketAddr),
    /// A domain name and a port. The name is stored as given, without resolving it.
    Domain(String, u16),
}
impl core::fmt::Display for TargetAddr {
    /// Formats the target as `host:port`.
    ///
    /// IP targets use [`SocketAddr`]'s `Display` output, so IPv6 addresses are
    /// written in brackets (e.g. `[::1]:443`). Domain targets are written as
    /// `domain:port`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Ip(addr) => write!(f, "{addr}"),
            Self::Domain(domain, port) => write!(f, "{domain}:{port}"),
        }
    }
}
impl ToSocketAddrs for TargetAddr {
    type Iter = Iter;
    /// Turns this target into an iterator of socket addresses.
    ///
    /// An IP target yields exactly its own address and performs no lookup. A
    /// domain target is resolved through the standard library's
    /// `(&str, u16)` implementation, and any lookup error is returned as-is.
    fn to_socket_addrs(&self) -> io::Result<Iter> {
        match self {
            Self::Ip(addr) => Ok(Iter(IterInner::Ip(Some(*addr)))),
            Self::Domain(host, port) => (host.as_str(), *port)
                .to_socket_addrs()
                .map(|resolved| Iter(IterInner::Domain(resolved))),
        }
    }
}
/// Backing storage for [`Iter`].
///
/// Holds either the single address of an IP target or the resolved addresses
/// of a domain target.
#[derive(Debug)]
enum IterInner {
    /// Yields the contained address once, then `None` forever.
    Ip(Option<SocketAddr>),
    /// Addresses produced by resolving a domain name.
    Domain(vec::IntoIter<SocketAddr>),
}
/// Iterator over the socket addresses a `TargetAddr` resolves to.
///
/// Returned by the [`ToSocketAddrs`] implementation for `TargetAddr`. Its
/// length is known up front (it implements [`ExactSizeIterator`]), and once it
/// returns `None` it keeps returning `None`.
#[derive(Debug)]
pub struct Iter(IterInner);
impl Iterator for Iter {
    type Item = SocketAddr;
    fn next(&mut self) -> Option<SocketAddr> {
        match &mut self.0 {
            // `take` leaves `None` behind, so the single address is yielded only once.
            IterInner::Ip(addr) => addr.take(),
            IterInner::Domain(it) => it.next(),
        }
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        match &self.0 {
            IterInner::Ip(addr) => addr.iter().size_hint(),
            IterInner::Domain(it) => it.size_hint(),
        }
    }
}
impl ExactSizeIterator for Iter {}
impl core::iter::FusedIterator for Iter {}
/// Conversion of a value into a [`TargetAddr`].
///
/// The implementations in this module never perform DNS lookups. A host that
/// is not an IP literal is kept as [`TargetAddr::Domain`].
pub trait ToTargetAddr {
    /// Converts `self` into a [`TargetAddr`].
    ///
    /// # Errors
    ///
    /// Returns an [`io::Error`] when the value cannot be interpreted as a
    /// target address, for example a malformed `"host:port"` string.
    fn to_target_addr(&self) -> io::Result<TargetAddr>;
}
impl ToTargetAddr for TargetAddr {
fn to_target_addr(&self) -> io::Result<TargetAddr> {
Ok(self.clone())
}
}
impl ToTargetAddr for SocketAddr {
    /// Wraps the address as [`TargetAddr::Ip`]; never fails.
    fn to_target_addr(&self) -> io::Result<TargetAddr> {
        let addr = *self;
        Ok(TargetAddr::Ip(addr))
    }
}
impl ToTargetAddr for SocketAddrV4 {
    /// Widens to a [`SocketAddr`] and wraps it as [`TargetAddr::Ip`]; never fails.
    fn to_target_addr(&self) -> io::Result<TargetAddr> {
        Ok(TargetAddr::Ip(SocketAddr::from(*self)))
    }
}
impl ToTargetAddr for SocketAddrV6 {
    /// Widens to a [`SocketAddr`] and wraps it as [`TargetAddr::Ip`]; never fails.
    fn to_target_addr(&self) -> io::Result<TargetAddr> {
        Ok(TargetAddr::Ip(SocketAddr::from(*self)))
    }
}
impl ToTargetAddr for (Ipv4Addr, u16) {
fn to_target_addr(&self) -> io::Result<TargetAddr> {
SocketAddrV4::new(self.0, self.1).to_target_addr()
}
}
impl ToTargetAddr for (Ipv6Addr, u16) {
fn to_target_addr(&self) -> io::Result<TargetAddr> {
SocketAddrV6::new(self.0, self.1, 0, 0).to_target_addr()
}
}
impl ToTargetAddr for (&str, u16) {
    /// Interprets `(host, port)` as a target address.
    ///
    /// The host is tried as an IPv4 literal, then as an unbracketed IPv6
    /// literal. Anything else is kept verbatim as [`TargetAddr::Domain`].
    /// No DNS lookup happens here.
    fn to_target_addr(&self) -> io::Result<TargetAddr> {
        let (host, port) = *self;
        if let Ok(v4) = host.parse::<Ipv4Addr>() {
            (v4, port).to_target_addr()
        } else if let Ok(v6) = host.parse::<Ipv6Addr>() {
            (v6, port).to_target_addr()
        } else {
            Ok(TargetAddr::Domain(host.to_owned(), port))
        }
    }
}
impl ToTargetAddr for &str {
    /// Parses a `"host:port"` string into a [`TargetAddr`].
    ///
    /// Accepted forms, tried in order:
    /// 1. an IPv4 socket address, e.g. `"127.0.0.1:1080"`;
    /// 2. a bracketed IPv6 socket address, e.g. `"[::1]:1080"`;
    /// 3. `host:port`, split at the last `:`. `host` is then interpreted by the
    ///    `(&str, u16)` implementation as an IP literal or a domain name.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidSocksAddress`] if there is no `:` separator or the
    ///   host part is empty.
    /// - [`Error::InvalidPortValue`] if the port is not a valid `u16`.
    fn to_target_addr(&self) -> io::Result<TargetAddr> {
        if let Ok(addr) = self.parse::<SocketAddrV4>() {
            return addr.to_target_addr();
        }
        if let Ok(addr) = self.parse::<SocketAddrV6>() {
            return addr.to_target_addr();
        }
        // Split at the *last* colon so an unbracketed IPv6 literal such as
        // `::1:80` still yields host `::1` and port `80`. An empty host
        // (e.g. `":80"`) can never be a valid destination, so reject it here
        // instead of producing `TargetAddr::Domain("", port)`.
        let Some((host, port_str)) = self
            .rsplit_once(':')
            .filter(|(host, _)| !host.is_empty())
        else {
            return Err(Error::InvalidSocksAddress {
                addr: (*self).to_string(),
            }
            .into_io());
        };
        let Ok(port) = port_str.parse::<u16>() else {
            return Err(Error::InvalidPortValue {
                addr: (*self).to_string(),
                port: port_str.to_string(),
            }
            .into_io());
        };
        (host, port).to_target_addr()
    }
}
fn tcp_stream_connect<T>(proxy: T, connect_timeout: Option<Duration>) -> io::Result<TcpStream>
where
T: ToSocketAddrs,
{
match connect_timeout {
None => TcpStream::connect(proxy),
Some(t) => {
let mut addrs = proxy.to_socket_addrs()?;
let mut last_err = None;
for addr in &mut addrs {
match TcpStream::connect_timeout(&addr, t) {
Ok(t) => return Ok(t),
Err(err) => last_err = Some(err),
}
}
Err(last_err.unwrap_or_else(|| Error::NoResolveSocketAddrs {}.into_io()))
}
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrap is acceptable in tests: a panic is a test failure
mod test {
    use super::*;
    /// `"host:port"` strings whose host is not an IP literal become
    /// `TargetAddr::Domain`, and a missing port value is rejected with
    /// `Error::InvalidPortValue`.
    #[test]
    fn domains_to_target_addr() {
        assert_eq!(
            "localhost:80".to_target_addr().unwrap(),
            TargetAddr::Domain("localhost".to_owned(), 80)
        );
        // NOTE(review): the parser builds this error with `addr: "localhost:"`
        // and `port: ""`, but the expected value uses an empty `addr`. This
        // only passes if `Error`'s `PartialEq` ignores field contents (e.g.
        // compares variants only). Confirm in the `error` module.
        assert_eq!(
            unwrap_io_to_socks2_error(&"localhost:".to_target_addr().unwrap_err()),
            Some(&Error::InvalidPortValue {
                addr: String::new(),
                port: String::new()
            })
        );
        assert_eq!(
            "github.com:443".to_target_addr().unwrap(),
            TargetAddr::Domain("github.com".to_owned(), 443)
        );
    }
}