Skip to main content

rusty_sockslib/
helpers.rs

1use rand::RngExt;
2use rand::distr::Alphanumeric;
3use std::error::Error;
4use std::fmt::Display;
5use std::{fmt::Formatter, net::SocketAddr};
6
7use if_addrs::get_if_addrs;
8use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
9use tokio::net::TcpSocket;
10
/// An IP network expressed as a `(prefix, mask)` pair of big-endian integers.
///
/// `Helpers::parse_cidr` builds these values with the host bits of the prefix
/// already cleared, which is what `Helpers::is_ip_in_cidr` relies on when it
/// compares a masked address against the prefix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Cidr {
    /// IPv4 network: `(prefix, mask)`.
    V4(u32, u32),
    /// IPv6 network: `(prefix, mask)`.
    V6(u128, u128),
}

impl Cidr {
    /// Returns `true` when the mask is zero (a `/0` prefix length when the
    /// value comes from `Helpers::parse_cidr`), i.e. no address bit is compared.
    pub fn is_trivial(&self) -> bool {
        matches!(*self, Cidr::V4(_, 0) | Cidr::V6(_, 0))
    }
}
24
/// A TCP socket together with the remote address it is meant for.
///
/// Produced by `Helpers::create_local_socket`.
pub struct EndpointPair {
    /// Socket that `Helpers::create_local_socket` has already bound to the requested local address.
    pub socket: TcpSocket,
    /// Endpoint address chosen to have the same IP family as the local address.
    pub address: SocketAddr,
}
29
/// Stateless namespace for the helper functions in this file: port and address
/// byte conversions, CIDR parsing and matching, SOCKS reply mapping and local
/// socket creation. All functions are associated functions; no instance is needed.
pub struct Helpers;
31
32impl Helpers {
33    pub fn get_id() -> String {
34        rand::rng().sample_iter(Alphanumeric).take(4).map(char::from).collect::<String>()
35    }
36
37    pub fn bytes_to_port(data: &[u8]) -> Res<u16> {
38        if data.len() != 2 {
39            return "There must be exactly two (2) bytes for a conversion to a port.".into_error();
40        }
41
42        Ok(((data[0] as u16) << 8) + (data[1] as u16))
43    }
44
45    pub fn port_to_bytes(port: u16) -> (u8, u8) {
46        ((port >> 8) as u8, (port & 0xff) as u8)
47    }
48
49    pub fn slice_to_u32(data: &[u8]) -> Res<u32> {
50        if data.len() != 4 {
51            return "There must be exactly four (4) bytes for a conversion to an IPv4.".into_error();
52        }
53
54        Ok(((data[0] as u32) << 24) + ((data[1] as u32) << 16) + ((data[2] as u32) << 8) + (data[3] as u32))
55    }
56
57    pub fn slice_to_u128(data: &[u8]) -> Res<u128> {
58        if data.len() != 16 {
59            return "There must be exactly sixteen (16) bytes for a conversion to an IPv6.".into_error();
60        }
61
62        Ok(((data[0] as u128) << 120)
63            + ((data[1] as u128) << 112)
64            + ((data[2] as u128) << 104)
65            + ((data[3] as u128) << 96)
66            + ((data[4] as u128) << 88)
67            + ((data[5] as u128) << 80)
68            + ((data[6] as u128) << 72)
69            + ((data[7] as u128) << 64)
70            + ((data[8] as u128) << 56)
71            + ((data[9] as u128) << 48)
72            + ((data[10] as u128) << 40)
73            + ((data[11] as u128) << 32)
74            + ((data[12] as u128) << 24)
75            + ((data[13] as u128) << 16)
76            + ((data[14] as u128) << 8)
77            + (data[15] as u128))
78    }
79
80    pub fn get_socks_reply(error: i32) -> u8 {
81        match error {
82            0 => 0x00,                     // succeeded
83            10050 | 10051 => 0x03,         // Network unreachable
84            10064 | 11001 | 10065 => 0x04, // Host unreachable
85            10061 => 0x05,                 // Connection refused
86            10060 => 0x06,                 // TTL expired... [ARoney] Is this right?
87            _ => 0x01,                     // general SOCKS server failure
88        }
89    }
90
91    pub fn write_octets(buffer: &mut [u8], octets: &[u8]) {
92        buffer[..octets.len()].clone_from_slice(octets);
93    }
94
95    pub fn get_interface_ip(name: &str) -> Res<IpAddr> {
96        for iface in get_if_addrs()? {
97            if iface.name == name {
98                return Ok(iface.ip());
99            }
100        }
101
102        format!("Could not lookup IP for interface `{}`.", name).into_error()
103    }
104
105    pub fn mask_ipv4(ip: &Ipv4Addr, mask: u32) -> Res<u32> {
106        Ok(Helpers::slice_to_u32(&ip.octets())? & mask)
107    }
108
109    pub fn mask_ipv6(ip: &Ipv6Addr, mask: u128) -> Res<u128> {
110        Ok(Helpers::slice_to_u128(&ip.octets())? & mask)
111    }
112
113    pub fn is_ip_in_cidr(ip_addr: &IpAddr, cidr: &Cidr) -> Res<bool> {
114        match cidr {
115            Cidr::V4(prefix, mask) => match &ip_addr {
116                IpAddr::V4(ip) => Ok(Helpers::mask_ipv4(ip, *mask)? == *prefix),
117                _ => Err(Box::new(GenericError::from("Cannot check IPv6 addresses against IPv4 CIDRs."))),
118            },
119            Cidr::V6(prefix, mask) => match &ip_addr {
120                IpAddr::V6(ip) => Ok(Helpers::mask_ipv6(ip, *mask)? == *prefix),
121                _ => Err(Box::new(GenericError::from("Cannot check IPv4 addresses against IPv6 CIDRs."))),
122            },
123        }
124    }
125
126    pub fn parse_cidr(s: &str) -> Res<Cidr> {
127        let splits = s.split('/').collect::<Vec<&str>>();
128
129        let ip_addr = splits[0].parse::<IpAddr>()?;
130        let num_mask_bits = splits[1].parse::<u32>()?;
131
132        match ip_addr {
133            IpAddr::V4(ip) => {
134                if num_mask_bits > 32 {
135                    return Err(Box::new(GenericError::from("An IPv4 CIDR prefix must have a mask bit length less than or equal to 32.")));
136                }
137
138                let mask = !(2u32.overflowing_pow(32 - num_mask_bits).0.overflowing_sub(1).0);
139                let prefix = Helpers::slice_to_u32(&ip.octets())? & mask;
140
141                Ok(Cidr::V4(prefix, mask))
142            }
143            IpAddr::V6(ip) => {
144                if num_mask_bits > 128 {
145                    return Err(Box::new(GenericError::from("An IPv4 CIDR prefix must have a mask bit length less than or equal to 128.")));
146                }
147
148                let mask = !(2u128.overflowing_pow(128 - num_mask_bits).0.overflowing_sub(1).0);
149                let prefix = Helpers::slice_to_u128(&ip.octets())? & mask;
150
151                Ok(Cidr::V6(prefix, mask))
152            }
153        }
154    }
155
156    pub fn create_local_socket(local_addr: SocketAddr, mut endpoint_addresses: impl Iterator<Item = SocketAddr>) -> Option<EndpointPair> {
157        let is_endpoint_interface_ipv6 = local_addr.is_ipv6();
158
159        let endpoint_addr = if is_endpoint_interface_ipv6 {
160            endpoint_addresses.find(|a| a.is_ipv6())
161        } else {
162            endpoint_addresses.find(|a| a.is_ipv4())
163        };
164
165        let endpoint_addr = endpoint_addr?;
166
167        // Bind to requested local address.
168        let socket = if endpoint_addr.is_ipv4() { TcpSocket::new_v4().ok()? } else { TcpSocket::new_v6().ok()? };
169
170        socket.bind(local_addr).ok()?;
171
172        Some(EndpointPair { socket, address: endpoint_addr })
173    }
174}
175
/// Result of a fallible operation that produces no value on success.
pub type Void = Res<()>;
/// Result whose error side is any boxed standard error.
pub type Res<T> = Result<T, Box<dyn Error>>;
178
179pub trait IntoError<T> {
180    fn into_error(self) -> Res<T>;
181}
182
183impl<T, S> IntoError<T> for S
184where
185    S: AsRef<str> + ToString,
186{
187    fn into_error(self) -> Res<T> {
188        Err(Box::new(GenericError::from(self)))
189    }
190}
191
/// A minimal error type that carries nothing but a human-readable message.
#[derive(Debug)]
pub struct GenericError {
    message: String,
}

/// Builds a [`GenericError`] from anything string-like; the message is the
/// value's `to_string()` output.
impl<T> From<T> for GenericError
where
    T: AsRef<str> + ToString,
{
    fn from(message: T) -> Self {
        GenericError { message: message.to_string() }
    }
}

/// Displays the stored message verbatim.
impl Display for GenericError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

// The trait's default `source()` already returns `None`, so no override is needed.
impl Error for GenericError {}
217
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    /// Parses a CIDR literal that the calling test expects to be valid.
    fn cidr(text: &str) -> Cidr {
        Helpers::parse_cidr(text).unwrap()
    }

    /// Encoding a port and decoding the bytes again yields the same port.
    #[test]
    fn port_roundtrips_through_bytes() {
        for port in [0u16, 80, 443, 1080, 65535] {
            let (high, low) = Helpers::port_to_bytes(port);
            let decoded = Helpers::bytes_to_port(&[high, low]).unwrap();
            assert_eq!(decoded, port);
        }
    }

    /// Anything other than exactly two bytes is refused.
    #[test]
    fn bytes_to_port_rejects_wrong_length() {
        for bad in [&[0u8][..], &[0u8, 0, 0][..]] {
            assert!(Helpers::bytes_to_port(bad).is_err());
        }
    }

    /// Four bytes decode big-endian; three bytes are refused.
    #[test]
    fn slice_to_u32_parses_and_validates_length() {
        let loopback = [127u8, 0, 0, 1];
        assert_eq!(Helpers::slice_to_u32(&loopback).unwrap(), 0x7f00_0001);

        let too_short = [0u8; 3];
        assert!(Helpers::slice_to_u32(&too_short).is_err());
    }

    /// Sixteen bytes decode; fifteen bytes are refused.
    #[test]
    fn slice_to_u128_parses_and_validates_length() {
        let zeros = [0u8; 16];
        assert_eq!(Helpers::slice_to_u128(&zeros).unwrap(), 0);
        assert!(Helpers::slice_to_u128(&zeros[..15]).is_err());
    }

    /// A `/0` prefix produces a trivial CIDR for both families.
    #[test]
    fn parse_cidr_zero_mask_is_trivial() {
        for text in ["0.0.0.0/0", "::/0"] {
            assert!(cidr(text).is_trivial());
        }
    }

    /// Prefix lengths beyond the address width are refused.
    #[test]
    fn parse_cidr_rejects_oversized_mask() {
        for text in ["10.0.0.0/33", "::/129"] {
            assert!(Helpers::parse_cidr(text).is_err());
        }
    }

    /// IPv4 membership: one address inside the /16, one just outside.
    #[test]
    fn cidr_membership_v4() {
        let network = cidr("10.216.0.0/16");
        let member: IpAddr = Ipv4Addr::new(10, 216, 5, 5).into();
        let stranger: IpAddr = Ipv4Addr::new(10, 217, 0, 1).into();

        assert!(Helpers::is_ip_in_cidr(&member, &network).unwrap());
        assert!(!Helpers::is_ip_in_cidr(&stranger, &network).unwrap());
    }

    /// IPv6 membership: one address inside the /32, one outside.
    #[test]
    fn cidr_membership_v6() {
        let network = cidr("2001:db8::/32");
        let member: IpAddr = "2001:db8::1".parse::<Ipv6Addr>().unwrap().into();
        let stranger: IpAddr = "2001:dead::1".parse::<Ipv6Addr>().unwrap().into();

        assert!(Helpers::is_ip_in_cidr(&member, &network).unwrap());
        assert!(!Helpers::is_ip_in_cidr(&stranger, &network).unwrap());
    }

    /// Checking an IPv6 address against an IPv4 CIDR is an error, not `false`.
    #[test]
    fn cidr_membership_rejects_family_mismatch() {
        let v4_network = cidr("10.0.0.0/8");
        let v6_address = IpAddr::from(Ipv6Addr::LOCALHOST);

        assert!(Helpers::is_ip_in_cidr(&v6_address, &v4_network).is_err());
    }

    /// Success, connection refused and the general-failure fallback.
    #[test]
    fn socks_reply_maps_known_errors() {
        let expectations = [(0, 0x00u8), (10061, 0x05), (123_456, 0x01)];

        for (code, reply) in expectations {
            assert_eq!(Helpers::get_socks_reply(code), reply);
        }
    }

    /// Generated ids are four ASCII alphanumeric characters.
    #[test]
    fn get_id_is_four_alphanumerics() {
        let id = Helpers::get_id();

        assert_eq!(id.len(), 4);
        assert!(id.bytes().all(|byte| byte.is_ascii_alphanumeric()));
    }
}