// rusty_sockslib/helpers.rs

1use rand::RngExt;
2use rand::distr::Alphanumeric;
3use std::error::Error;
4use std::fmt::Display;
5use std::{fmt::Formatter, net::SocketAddr};
6
7use pnet::datalink;
8use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
9use tokio::net::TcpSocket;
10
/// A CIDR block stored as a `(prefix, mask)` pair of integers.
///
/// `Helpers::parse_cidr` builds these values with the prefix already ANDed
/// with the mask, and `Helpers::is_ip_in_cidr` compares a masked address
/// against the stored prefix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Cidr {
    /// IPv4 block: `(prefix, mask)`, each a 32-bit integer.
    V4(u32, u32),
    /// IPv6 block: `(prefix, mask)`, each a 128-bit integer.
    V6(u128, u128),
}

impl Cidr {
    /// Returns `true` when the mask is zero, regardless of the stored prefix.
    ///
    /// A zero mask keeps no address bits, which is what a `/0` block parses to.
    pub fn is_trivial(&self) -> bool {
        matches!(self, Cidr::V4(_, 0) | Cidr::V6(_, 0))
    }
}
24
25pub struct EndpointPair {
26    pub socket: TcpSocket,
27    pub address: SocketAddr,
28}
29
/// Stateless namespace type; every helper is an associated function on it.
#[derive(Debug, Clone, Copy, Default)]
pub struct Helpers;
31
32impl Helpers {
33    pub fn get_id() -> String {
34        rand::rng().sample_iter(Alphanumeric).take(4).map(char::from).collect::<String>()
35    }
36
37    pub fn bytes_to_port(data: &[u8]) -> Res<u16> {
38        if data.len() != 2 {
39            return "There must be exactly two (2) bytes for a conversion to a port.".into_error();
40        }
41
42        Ok(((data[0] as u16) << 8) + (data[1] as u16))
43    }
44
45    pub fn port_to_bytes(port: u16) -> (u8, u8) {
46        ((port >> 8) as u8, (port & 0xff) as u8)
47    }
48
49    pub fn slice_to_u32(data: &[u8]) -> Res<u32> {
50        if data.len() != 4 {
51            return "There must be exactly four (4) bytes for a conversion to an IPv4.".into_error();
52        }
53
54        Ok(((data[0] as u32) << 24) + ((data[1] as u32) << 16) + ((data[2] as u32) << 8) + (data[3] as u32))
55    }
56
57    pub fn slice_to_u128(data: &[u8]) -> Res<u128> {
58        if data.len() != 16 {
59            return "There must be exactly sixteen (16) bytes for a conversion to an IPv6.".into_error();
60        }
61
62        Ok(((data[0] as u128) << 120)
63            + ((data[1] as u128) << 112)
64            + ((data[2] as u128) << 104)
65            + ((data[3] as u128) << 96)
66            + ((data[4] as u128) << 88)
67            + ((data[5] as u128) << 80)
68            + ((data[6] as u128) << 72)
69            + ((data[7] as u128) << 64)
70            + ((data[8] as u128) << 56)
71            + ((data[9] as u128) << 48)
72            + ((data[10] as u128) << 40)
73            + ((data[11] as u128) << 32)
74            + ((data[12] as u128) << 24)
75            + ((data[13] as u128) << 16)
76            + ((data[14] as u128) << 8)
77            + (data[15] as u128))
78    }
79
80    pub fn get_socks_reply(error: i32) -> u8 {
81        match error {
82            0 => 0x00,                     // succeeded
83            10050 | 10051 => 0x03,         // Network unreachable
84            10064 | 11001 | 10065 => 0x04, // Host unreachable
85            10061 => 0x05,                 // Connection refused
86            10060 => 0x06,                 // TTL expired... [ARoney] Is this right?
87            _ => 0x01,                     // general SOCKS server failure
88        }
89    }
90
91    pub fn write_octets(buffer: &mut [u8], octets: &[u8]) {
92        buffer[..octets.len()].clone_from_slice(octets);
93    }
94
95    pub fn get_interface_ip(name: &str) -> Res<IpAddr> {
96        for iface in datalink::interfaces() {
97            if iface.name == name {
98                if iface.ips.is_empty() {
99                    return format!("Found interface `{}`, but could not find an assigned IP for that interface.", name).into_error();
100                }
101
102                return Ok(iface.ips[0].ip());
103            }
104        }
105
106        format!("Could not lookup IP for interface `{}`.", name).into_error()
107    }
108
109    pub fn mask_ipv4(ip: &Ipv4Addr, mask: u32) -> Res<u32> {
110        Ok(Helpers::slice_to_u32(&ip.octets())? & mask)
111    }
112
113    pub fn mask_ipv6(ip: &Ipv6Addr, mask: u128) -> Res<u128> {
114        Ok(Helpers::slice_to_u128(&ip.octets())? & mask)
115    }
116
117    pub fn is_ip_in_cidr(ip_addr: &IpAddr, cidr: &Cidr) -> Res<bool> {
118        match cidr {
119            Cidr::V4(prefix, mask) => match &ip_addr {
120                IpAddr::V4(ip) => Ok(Helpers::mask_ipv4(ip, *mask)? == *prefix),
121                _ => Err(Box::new(GenericError::from("Cannot check IPv6 addresses against IPv4 CIDRs."))),
122            },
123            Cidr::V6(prefix, mask) => match &ip_addr {
124                IpAddr::V6(ip) => Ok(Helpers::mask_ipv6(ip, *mask)? == *prefix),
125                _ => Err(Box::new(GenericError::from("Cannot check IPv4 addresses against IPv6 CIDRs."))),
126            },
127        }
128    }
129
130    pub fn parse_cidr(s: &str) -> Res<Cidr> {
131        let splits = s.split('/').collect::<Vec<&str>>();
132
133        let ip_addr = splits[0].parse::<IpAddr>()?;
134        let num_mask_bits = splits[1].parse::<u32>()?;
135
136        match ip_addr {
137            IpAddr::V4(ip) => {
138                if num_mask_bits > 32 {
139                    return Err(Box::new(GenericError::from("An IPv4 CIDR prefix must have a mask bit length less than or equal to 32.")));
140                }
141
142                let mask = !(2u32.overflowing_pow(32 - num_mask_bits).0.overflowing_sub(1).0);
143                let prefix = Helpers::slice_to_u32(&ip.octets())? & mask;
144
145                Ok(Cidr::V4(prefix, mask))
146            }
147            IpAddr::V6(ip) => {
148                if num_mask_bits > 128 {
149                    return Err(Box::new(GenericError::from("An IPv4 CIDR prefix must have a mask bit length less than or equal to 128.")));
150                }
151
152                let mask = !(2u128.overflowing_pow(128 - num_mask_bits).0.overflowing_sub(1).0);
153                let prefix = Helpers::slice_to_u128(&ip.octets())? & mask;
154
155                Ok(Cidr::V6(prefix, mask))
156            }
157        }
158    }
159
160    pub fn create_local_socket(local_addr: SocketAddr, mut endpoint_addresses: impl Iterator<Item = SocketAddr>) -> Option<EndpointPair> {
161        let is_endpoint_interface_ipv6 = local_addr.is_ipv6();
162
163        let endpoint_addr = if is_endpoint_interface_ipv6 {
164            endpoint_addresses.find(|a| a.is_ipv6())
165        } else {
166            endpoint_addresses.find(|a| a.is_ipv4())
167        };
168
169        let endpoint_addr = endpoint_addr?;
170
171        // Bind to requested local address.
172        let socket = if endpoint_addr.is_ipv4() { TcpSocket::new_v4().ok()? } else { TcpSocket::new_v6().ok()? };
173
174        socket.bind(local_addr).ok()?;
175
176        Some(EndpointPair { socket, address: endpoint_addr })
177    }
178}
179
/// Result of a fallible operation yielding `T`; the failure is any boxed error.
pub type Res<T> = Result<T, Box<dyn std::error::Error>>;
/// Result of a fallible operation that yields nothing on success.
pub type Void = Res<()>;
182
183pub trait IntoError<T> {
184    fn into_error(self) -> Res<T>;
185}
186
187impl<T, S> IntoError<T> for S
188where
189    S: AsRef<str> + ToString,
190{
191    fn into_error(self) -> Res<T> {
192        Err(Box::new(GenericError::from(self)))
193    }
194}
195
/// Error type that carries nothing but a human-readable message.
#[derive(Debug)]
pub struct GenericError {
    message: String,
}

/// Builds the error from any string-like value, using its `to_string` form.
impl<T: AsRef<str> + ToString> From<T> for GenericError {
    fn from(message: T) -> Self {
        let message = message.to_string();
        Self { message }
    }
}

/// Displays the message verbatim.
impl Display for GenericError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

/// Has no underlying cause, so the default `source` (which returns `None`) is used.
impl Error for GenericError {}
221
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    /// Splitting a port into bytes and reassembling it must give the port back.
    #[test]
    fn port_roundtrips_through_bytes() {
        let samples: [u16; 5] = [0, 80, 443, 1080, 65535];

        for port in samples {
            let (high, low) = Helpers::port_to_bytes(port);
            let rebuilt = Helpers::bytes_to_port(&[high, low]).unwrap();
            assert_eq!(rebuilt, port);
        }
    }

    /// Slices that are too short or too long are rejected.
    #[test]
    fn bytes_to_port_rejects_wrong_length() {
        let too_short = Helpers::bytes_to_port(&[0]);
        let too_long = Helpers::bytes_to_port(&[0, 0, 0]);

        assert!(too_short.is_err());
        assert!(too_long.is_err());
    }

    /// Four bytes decode big-endian; three bytes are an error.
    #[test]
    fn slice_to_u32_parses_and_validates_length() {
        let loopback = Helpers::slice_to_u32(&[127, 0, 0, 1]).unwrap();
        assert_eq!(loopback, 0x7f00_0001);

        let truncated = Helpers::slice_to_u32(&[0, 0, 0]);
        assert!(truncated.is_err());
    }

    /// Sixteen zero bytes decode to zero; fifteen bytes are an error.
    #[test]
    fn slice_to_u128_parses_and_validates_length() {
        let zeroes = Helpers::slice_to_u128(&[0u8; 16]).unwrap();
        assert_eq!(zeroes, 0);

        let truncated = Helpers::slice_to_u128(&[0u8; 15]);
        assert!(truncated.is_err());
    }

    /// A `/0` block has a zero mask in both address families.
    #[test]
    fn parse_cidr_zero_mask_is_trivial() {
        for text in ["0.0.0.0/0", "::/0"] {
            let cidr = Helpers::parse_cidr(text).unwrap();
            assert!(cidr.is_trivial());
        }
    }

    /// Prefix lengths beyond the address width are rejected.
    #[test]
    fn parse_cidr_rejects_oversized_mask() {
        for text in ["10.0.0.0/33", "::/129"] {
            assert!(Helpers::parse_cidr(text).is_err());
        }
    }

    /// IPv4 membership: same /16 is inside, neighbouring /16 is outside.
    #[test]
    fn cidr_membership_v4() {
        let cidr = Helpers::parse_cidr("10.216.0.0/16").unwrap();
        let inside: IpAddr = Ipv4Addr::new(10, 216, 5, 5).into();
        let outside: IpAddr = Ipv4Addr::new(10, 217, 0, 1).into();

        let inside_result = Helpers::is_ip_in_cidr(&inside, &cidr).unwrap();
        let outside_result = Helpers::is_ip_in_cidr(&outside, &cidr).unwrap();

        assert!(inside_result);
        assert!(!outside_result);
    }

    /// IPv6 membership: same /32 is inside, a different /32 is outside.
    #[test]
    fn cidr_membership_v6() {
        let cidr = Helpers::parse_cidr("2001:db8::/32").unwrap();
        let inside = IpAddr::from("2001:db8::1".parse::<Ipv6Addr>().unwrap());
        let outside = IpAddr::from("2001:dead::1".parse::<Ipv6Addr>().unwrap());

        let inside_result = Helpers::is_ip_in_cidr(&inside, &cidr).unwrap();
        let outside_result = Helpers::is_ip_in_cidr(&outside, &cidr).unwrap();

        assert!(inside_result);
        assert!(!outside_result);
    }

    /// Checking an IPv6 address against an IPv4 block is an error, not `false`.
    #[test]
    fn cidr_membership_rejects_family_mismatch() {
        let v4_cidr = Helpers::parse_cidr("10.0.0.0/8").unwrap();
        let v6_addr = IpAddr::from(Ipv6Addr::LOCALHOST);

        let outcome = Helpers::is_ip_in_cidr(&v6_addr, &v4_cidr);
        assert!(outcome.is_err());
    }

    /// Known error codes map to their reply byte; unknown ones fall back to 0x01.
    #[test]
    fn socks_reply_maps_known_errors() {
        let expectations: [(i32, u8); 3] = [
            (0, 0x00),
            (10061, 0x05),   // connection refused
            (123_456, 0x01), // general failure fallback
        ];

        for (error, reply) in expectations {
            assert_eq!(Helpers::get_socks_reply(error), reply);
        }
    }

    /// Generated identifiers are four ASCII alphanumeric characters.
    #[test]
    fn get_id_is_four_alphanumerics() {
        let id = Helpers::get_id();

        assert_eq!(id.len(), 4);
        assert!(id.chars().all(|character| character.is_ascii_alphanumeric()));
    }
}