Skip to main content

koibumi_socks_core/
lib.rs

//! This crate is a core module of a minimal SOCKS5 client library.
//!
//! It is intended to be used with a local Tor SOCKS5 proxy.

// See RFC 1928, "SOCKS Protocol Version 5".
6
7#![warn(missing_docs)]
8
9use std::{
10    fmt,
11    net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6},
12    num::ParseIntError,
13    str::FromStr,
14};
15
/// A SOCKS5 domain name: an opaque byte string of at most 255 bytes.
///
/// SOCKS5 encodes a domain name's length in a single octet, which is where
/// the 255-byte limit comes from. The bytes need not be valid UTF-8.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct DomainName {
    // Raw name bytes; the length limit is checked by `DomainName::new`.
    bytes: Vec<u8>,
}
23
24impl AsRef<[u8]> for DomainName {
25    fn as_ref(&self) -> &[u8] {
26        self.bytes.as_ref()
27    }
28}
29
/// Error returned when bytes cannot be turned into a [`DomainName`].
///
/// Produced by [`DomainName::new()`] and by the `FromStr` implementation
/// for [`DomainName`].
///
/// [`DomainName::new()`]: struct.DomainName.html#method.new
/// [`DomainName`]: struct.DomainName.html
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum ParseDomainNameError {
    /// The input exceeded the SOCKS5 domain-name length limit.
    ///
    /// Both the accepted maximum and the offending length are carried,
    /// measured in bytes.
    TooLong {
        /// The largest accepted length, in bytes.
        max: usize,
        /// The length of the rejected input, in bytes.
        len: usize,
    },
}

impl fmt::Display for ParseDomainNameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            ParseDomainNameError::TooLong { max, len } => {
                f.write_fmt(format_args!("length must be <={}, but {}", max, len))
            }
        }
    }
}

impl std::error::Error for ParseDomainNameError {}
59
60impl DomainName {
61    const MAX_LEN: usize = 0xff;
62
63    /// Constructs a domain name from a byte string.
64    ///
65    /// The byte length is checked.
66    pub fn new(bytes: Vec<u8>) -> std::result::Result<Self, ParseDomainNameError> {
67        if bytes.len() > Self::MAX_LEN {
68            return Err(ParseDomainNameError::TooLong {
69                max: Self::MAX_LEN,
70                len: bytes.len(),
71            });
72        }
73        Ok(Self { bytes })
74    }
75}
76
77impl fmt::Display for DomainName {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        String::from_utf8_lossy(self.bytes.as_ref()).fmt(f)
80    }
81}
82
83impl FromStr for DomainName {
84    type Err = ParseDomainNameError;
85
86    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
87        Ok(Self::new(s.as_bytes().to_vec())?)
88    }
89}
90
91/// This type represents IPv4, IPv6 address or a domain name used by SOCKS5.
92///
93/// The inner IP address types are defined in the `std::net` module.
94/// On the other hand, the inner domain name type is defined in this own crate.
95#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
96pub enum Addr {
97    /// An IPv4 address.
98    Ipv4(Ipv4Addr),
99
100    /// A domain name.
101    DomainName(DomainName),
102
103    /// An IPv6 address.
104    Ipv6(Ipv6Addr),
105}
106
107impl fmt::Display for Addr {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        match self {
110            Addr::Ipv4(addr) => addr.fmt(f),
111            Addr::DomainName(addr) => addr.fmt(f),
112            Addr::Ipv6(addr) => addr.fmt(f),
113        }
114    }
115}
116
117/// An error which can be returned when parsing a SOCKS address
118/// or a SOCKS socket address.
119///
120/// This error is used as the error type for the `FromStr` implementation
121/// for [`Addr`].
122/// This error is also used as a component of the error type
123/// [`ParseSocketAddrError`].
124///
125/// [`Addr`]: enum.Addr.html
126/// [`ParseSocketAddrError`]: enum.ParseSocketAddrError.html
127#[derive(Clone, PartialEq, Eq, Debug)]
128pub enum ParseAddrError {
129    /// Failed when parsing the input as a domain name.
130    /// The actual error caught is returned as a payload of this variant.
131    DomainName(ParseDomainNameError),
132}
133
134impl fmt::Display for ParseAddrError {
135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136        match self {
137            Self::DomainName(err) => err.fmt(f),
138        }
139    }
140}
141
142impl From<ParseDomainNameError> for ParseAddrError {
143    fn from(err: ParseDomainNameError) -> Self {
144        Self::DomainName(err)
145    }
146}
147
148impl std::error::Error for ParseAddrError {}
149
150impl FromStr for Addr {
151    type Err = ParseAddrError;
152
153    fn from_str(s: &str) -> Result<Self, Self::Err> {
154        if let Ok(addr) = s.parse::<Ipv4Addr>() {
155            return Ok(Self::Ipv4(addr));
156        }
157        if let Ok(addr) = s.parse::<Ipv6Addr>() {
158            return Ok(Self::Ipv6(addr));
159        }
160        Ok(Self::DomainName(DomainName::new(s.as_bytes().to_vec())?))
161    }
162}
163
164type Port = u16;
165
166/// This type represents a socket address which uses domain name,
167/// that is, a domain name with a port.
168///
169/// For restrictions on a domain name, see documents for
170/// [`DomainName`](struct.DomainName.html).
171#[derive(Clone, PartialEq, Eq, Hash, Debug)]
172pub struct SocketDomainName {
173    domain_name: DomainName,
174    port: Port,
175}
176
177impl SocketDomainName {
178    /// Constructs a socket address from a domain name with a port.
179    pub fn new(domain_name: DomainName, port: Port) -> Self {
180        Self { domain_name, port }
181    }
182
183    /// Returns the domain name part of the socket address.
184    pub fn domain_name(&self) -> &DomainName {
185        &self.domain_name
186    }
187
188    /// Returns the port part of the socket address.
189    pub fn port(&self) -> Port {
190        self.port
191    }
192}
193
194impl fmt::Display for SocketDomainName {
195    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196        write!(f, "{}:{}", self.domain_name, self.port)
197    }
198}
199
200/// This type represents a socket address used by SOCKS5.
201///
202/// The inner IP socket address types are defined in the `std::net` module.
203/// On the other hand, the inner domain name socket address type is
204/// defined in this own crate.
205#[derive(Clone, PartialEq, Eq, Hash, Debug)]
206pub enum SocketAddr {
207    /// A socket address using an IPv4 address.
208    Ipv4(SocketAddrV4),
209
210    /// A socket address using a domain name.
211    DomainName(SocketDomainName),
212
213    /// A socket address using an IPv6 address.
214    Ipv6(SocketAddrV6),
215}
216
217impl SocketAddr {
218    /// Constructs a socket address from an address and a port.
219    pub fn new(addr: Addr, port: Port) -> Self {
220        match addr {
221            Addr::Ipv4(addr) => Self::Ipv4(SocketAddrV4::new(addr, port)),
222            Addr::DomainName(addr) => Self::DomainName(SocketDomainName::new(addr, port)),
223            Addr::Ipv6(addr) => Self::Ipv6(SocketAddrV6::new(addr, port, 0, 0)),
224        }
225    }
226}
227
228impl fmt::Display for SocketAddr {
229    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230        match self {
231            SocketAddr::Ipv4(addr) => addr.fmt(f),
232            SocketAddr::DomainName(addr) => addr.fmt(f),
233            SocketAddr::Ipv6(addr) => addr.fmt(f),
234        }
235    }
236}
237
238/// An error which can be returned
239/// when parsing a SOCKS socket address.
240///
241/// This error is used as the error type for the `FromStr` implementation
242/// for [`SocketAddr`].
243///
244/// [`SocketAddr`]: enum.SocketAddr.html
245#[derive(Clone, PartialEq, Eq, Debug)]
246pub enum ParseSocketAddrError {
247    /// The input did not have any port number.
248    PortNotFound,
249
250    /// An error was caught when parsing a port number.
251    /// The actual error caught is returned
252    /// as a payload of this variant.
253    InvalidPort(ParseIntError),
254
255    /// An error was caught when parsing
256    /// a extended Bitmessage network address.
257    /// The actual error caught is returned
258    /// as a payload of this variant.
259    InvalidAddr(ParseAddrError),
260}
261
262impl fmt::Display for ParseSocketAddrError {
263    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
264        match self {
265            Self::PortNotFound => write!(f, "port not found"),
266            Self::InvalidPort(err) => err.fmt(f),
267            Self::InvalidAddr(err) => err.fmt(f),
268        }
269    }
270}
271
272impl std::error::Error for ParseSocketAddrError {}
273
274impl From<ParseIntError> for ParseSocketAddrError {
275    fn from(err: ParseIntError) -> Self {
276        Self::InvalidPort(err)
277    }
278}
279
280impl From<ParseAddrError> for ParseSocketAddrError {
281    fn from(err: ParseAddrError) -> Self {
282        Self::InvalidAddr(err)
283    }
284}
285
286impl FromStr for SocketAddr {
287    type Err = ParseSocketAddrError;
288
289    fn from_str(s: &str) -> Result<Self, Self::Err> {
290        let colon = s.rfind(':');
291        if colon.is_none() {
292            return Err(Self::Err::PortNotFound);
293        }
294        let colon = colon.unwrap();
295
296        let mut addr_part = &s[..colon];
297        if addr_part.starts_with('[') && addr_part.ends_with(']') {
298            addr_part = &addr_part[1..addr_part.len() - 1];
299        }
300        let port_part = &s[colon + 1..];
301
302        let port = port_part.parse::<Port>()?;
303        let addr = addr_part.parse()?;
304        Ok(Self::new(addr, port))
305    }
306}
307
/// Exercises `SocketAddr::from_str` across every address kind and every
/// error variant, plus `Display` round-trips.
#[test]
fn test_socket_addr_ext_from_str() {
    // Domain name with a port.
    let test: SocketAddr = "www.example.net:8080".parse().unwrap();
    let domain = SocketDomainName::new(DomainName::new(b"www.example.net".to_vec()).unwrap(), 8080);
    let expected = SocketAddr::DomainName(domain);
    assert_eq!(test, expected);
    assert_eq!(test.to_string(), "www.example.net:8080");

    // IPv4 address with a port.
    let test: SocketAddr = "127.0.0.1:9050".parse().unwrap();
    assert_eq!(
        test,
        SocketAddr::Ipv4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 9050))
    );
    assert_eq!(test.to_string(), "127.0.0.1:9050");

    // Bracketed IPv6 address with a port; the brackets are stripped first.
    let test: SocketAddr = "[::1]:9050".parse().unwrap();
    assert_eq!(
        test,
        SocketAddr::Ipv6(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 9050, 0, 0))
    );
    assert_eq!(test.to_string(), "[::1]:9050");

    // No port at all.
    assert_eq!(
        "www.example.net".parse::<SocketAddr>(),
        Err(ParseSocketAddrError::PortNotFound)
    );

    // Non-numeric and out-of-range ports.
    assert!(matches!(
        "www.example.net:http".parse::<SocketAddr>(),
        Err(ParseSocketAddrError::InvalidPort(_))
    ));
    assert!(matches!(
        "www.example.net:65536".parse::<SocketAddr>(),
        Err(ParseSocketAddrError::InvalidPort(_))
    ));

    // Domain names longer than 255 bytes are rejected.
    let long = format!("{}:80", "a".repeat(256));
    assert_eq!(
        long.parse::<SocketAddr>(),
        Err(ParseSocketAddrError::InvalidAddr(ParseAddrError::DomainName(
            ParseDomainNameError::TooLong { max: 255, len: 256 }
        )))
    );
}