rusty_sockslib/
helpers.rs1use rand::RngExt;
2use rand::distr::Alphanumeric;
3use std::error::Error;
4use std::fmt::Display;
5use std::{fmt::Formatter, net::SocketAddr};
6
7use pnet::datalink;
8use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
9use tokio::net::TcpSocket;
10
/// A CIDR block stored as `(prefix, mask)` integers.
///
/// Addresses are represented as their octets read big-endian (first octet is the
/// most significant byte). When built by `Helpers::parse_cidr`, the prefix has
/// all host bits cleared and the mask has the top `n` bits set for a `/n` block.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Cidr {
    /// IPv4 block: `(prefix, mask)`.
    V4(u32, u32),
    /// IPv6 block: `(prefix, mask)`.
    V6(u128, u128),
}

impl Cidr {
    /// Returns `true` when the mask is zero, i.e. the block is a `/0`.
    ///
    /// Only the mask is inspected. With the zero prefix that `Helpers::parse_cidr`
    /// produces for `/0`, such a block matches every address of its family.
    pub fn is_trivial(&self) -> bool {
        matches!(self, Cidr::V4(_, 0) | Cidr::V6(_, 0))
    }
}
24
25pub struct EndpointPair {
26 pub socket: TcpSocket,
27 pub address: SocketAddr,
28}
29
30pub struct Helpers;
31
32impl Helpers {
33 pub fn get_id() -> String {
34 rand::rng().sample_iter(Alphanumeric).take(4).map(char::from).collect::<String>()
35 }
36
37 pub fn bytes_to_port(data: &[u8]) -> Res<u16> {
38 if data.len() != 2 {
39 return "There must be exactly two (2) bytes for a conversion to a port.".into_error();
40 }
41
42 Ok(((data[0] as u16) << 8) + (data[1] as u16))
43 }
44
45 pub fn port_to_bytes(port: u16) -> (u8, u8) {
46 ((port >> 8) as u8, (port & 0xff) as u8)
47 }
48
49 pub fn slice_to_u32(data: &[u8]) -> Res<u32> {
50 if data.len() != 4 {
51 return "There must be exactly four (4) bytes for a conversion to an IPv4.".into_error();
52 }
53
54 Ok(((data[0] as u32) << 24) + ((data[1] as u32) << 16) + ((data[2] as u32) << 8) + (data[3] as u32))
55 }
56
57 pub fn slice_to_u128(data: &[u8]) -> Res<u128> {
58 if data.len() != 16 {
59 return "There must be exactly sixteen (16) bytes for a conversion to an IPv6.".into_error();
60 }
61
62 Ok(((data[0] as u128) << 120)
63 + ((data[1] as u128) << 112)
64 + ((data[2] as u128) << 104)
65 + ((data[3] as u128) << 96)
66 + ((data[4] as u128) << 88)
67 + ((data[5] as u128) << 80)
68 + ((data[6] as u128) << 72)
69 + ((data[7] as u128) << 64)
70 + ((data[8] as u128) << 56)
71 + ((data[9] as u128) << 48)
72 + ((data[10] as u128) << 40)
73 + ((data[11] as u128) << 32)
74 + ((data[12] as u128) << 24)
75 + ((data[13] as u128) << 16)
76 + ((data[14] as u128) << 8)
77 + (data[15] as u128))
78 }
79
80 pub fn get_socks_reply(error: i32) -> u8 {
81 match error {
82 0 => 0x00, 10050 | 10051 => 0x03, 10064 | 11001 | 10065 => 0x04, 10061 => 0x05, 10060 => 0x06, _ => 0x01, }
89 }
90
91 pub fn write_octets(buffer: &mut [u8], octets: &[u8]) {
92 buffer[..octets.len()].clone_from_slice(octets);
93 }
94
95 pub fn get_interface_ip(name: &str) -> Res<IpAddr> {
96 for iface in datalink::interfaces() {
97 if iface.name == name {
98 if iface.ips.is_empty() {
99 return format!("Found interface `{}`, but could not find an assigned IP for that interface.", name).into_error();
100 }
101
102 return Ok(iface.ips[0].ip());
103 }
104 }
105
106 format!("Could not lookup IP for interface `{}`.", name).into_error()
107 }
108
109 pub fn mask_ipv4(ip: &Ipv4Addr, mask: u32) -> Res<u32> {
110 Ok(Helpers::slice_to_u32(&ip.octets())? & mask)
111 }
112
113 pub fn mask_ipv6(ip: &Ipv6Addr, mask: u128) -> Res<u128> {
114 Ok(Helpers::slice_to_u128(&ip.octets())? & mask)
115 }
116
117 pub fn is_ip_in_cidr(ip_addr: &IpAddr, cidr: &Cidr) -> Res<bool> {
118 match cidr {
119 Cidr::V4(prefix, mask) => match &ip_addr {
120 IpAddr::V4(ip) => Ok(Helpers::mask_ipv4(ip, *mask)? == *prefix),
121 _ => Err(Box::new(GenericError::from("Cannot check IPv6 addresses against IPv4 CIDRs."))),
122 },
123 Cidr::V6(prefix, mask) => match &ip_addr {
124 IpAddr::V6(ip) => Ok(Helpers::mask_ipv6(ip, *mask)? == *prefix),
125 _ => Err(Box::new(GenericError::from("Cannot check IPv4 addresses against IPv6 CIDRs."))),
126 },
127 }
128 }
129
130 pub fn parse_cidr(s: &str) -> Res<Cidr> {
131 let splits = s.split('/').collect::<Vec<&str>>();
132
133 let ip_addr = splits[0].parse::<IpAddr>()?;
134 let num_mask_bits = splits[1].parse::<u32>()?;
135
136 match ip_addr {
137 IpAddr::V4(ip) => {
138 if num_mask_bits > 32 {
139 return Err(Box::new(GenericError::from("An IPv4 CIDR prefix must have a mask bit length less than or equal to 32.")));
140 }
141
142 let mask = !(2u32.overflowing_pow(32 - num_mask_bits).0.overflowing_sub(1).0);
143 let prefix = Helpers::slice_to_u32(&ip.octets())? & mask;
144
145 Ok(Cidr::V4(prefix, mask))
146 }
147 IpAddr::V6(ip) => {
148 if num_mask_bits > 128 {
149 return Err(Box::new(GenericError::from("An IPv4 CIDR prefix must have a mask bit length less than or equal to 128.")));
150 }
151
152 let mask = !(2u128.overflowing_pow(128 - num_mask_bits).0.overflowing_sub(1).0);
153 let prefix = Helpers::slice_to_u128(&ip.octets())? & mask;
154
155 Ok(Cidr::V6(prefix, mask))
156 }
157 }
158 }
159
160 pub fn create_local_socket(local_addr: SocketAddr, mut endpoint_addresses: impl Iterator<Item = SocketAddr>) -> Option<EndpointPair> {
161 let is_endpoint_interface_ipv6 = local_addr.is_ipv6();
162
163 let endpoint_addr = if is_endpoint_interface_ipv6 {
164 endpoint_addresses.find(|a| a.is_ipv6())
165 } else {
166 endpoint_addresses.find(|a| a.is_ipv4())
167 };
168
169 let endpoint_addr = endpoint_addr?;
170
171 let socket = if endpoint_addr.is_ipv4() { TcpSocket::new_v4().ok()? } else { TcpSocket::new_v6().ok()? };
173
174 socket.bind(local_addr).ok()?;
175
176 Some(EndpointPair { socket, address: endpoint_addr })
177 }
178}
179
/// Result of a fallible operation that produces no value on success.
pub type Void = Res<()>;
/// Crate-wide result type: a success value `T`, or any boxed, type-erased error.
pub type Res<T> = Result<T, Box<dyn Error>>;
182
183pub trait IntoError<T> {
184 fn into_error(self) -> Res<T>;
185}
186
187impl<T, S> IntoError<T> for S
188where
189 S: AsRef<str> + ToString,
190{
191 fn into_error(self) -> Res<T> {
192 Err(Box::new(GenericError::from(self)))
193 }
194}
195
/// A plain-message error for ad-hoc failures; `Display` prints the message verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GenericError {
    message: String,
}

impl<T> From<T> for GenericError
where
    T: AsRef<str> + ToString,
{
    /// Builds an error whose `Display` output is exactly `message`.
    fn from(message: T) -> Self {
        GenericError { message: message.to_string() }
    }
}

impl Display for GenericError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

// No underlying cause is recorded, so the default `source()` (which returns `None`) applies.
impl Error for GenericError {}
221
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    #[test]
    fn port_roundtrips_through_bytes() {
        for port in [0u16, 80, 443, 1080, 65535] {
            let (hi, lo) = Helpers::port_to_bytes(port);
            assert_eq!(Helpers::bytes_to_port(&[hi, lo]).unwrap(), port);
        }
    }

    #[test]
    fn port_bytes_are_big_endian() {
        assert_eq!(Helpers::port_to_bytes(0x1234), (0x12, 0x34));
        assert_eq!(Helpers::bytes_to_port(&[0x04, 0x38]).unwrap(), 1080);
    }

    #[test]
    fn bytes_to_port_rejects_wrong_length() {
        assert!(Helpers::bytes_to_port(&[0]).is_err());
        assert!(Helpers::bytes_to_port(&[0, 0, 0]).is_err());
    }

    #[test]
    fn slice_to_u32_parses_and_validates_length() {
        assert_eq!(Helpers::slice_to_u32(&[127, 0, 0, 1]).unwrap(), 0x7f00_0001);
        assert!(Helpers::slice_to_u32(&[0, 0, 0]).is_err());
    }

    #[test]
    fn slice_to_u128_parses_and_validates_length() {
        assert_eq!(Helpers::slice_to_u128(&[0u8; 16]).unwrap(), 0);

        let mut bytes = [0u8; 16];
        bytes[0] = 0x20;
        bytes[15] = 0x01;
        assert_eq!(Helpers::slice_to_u128(&bytes).unwrap(), (0x20u128 << 120) | 1);

        assert!(Helpers::slice_to_u128(&[0u8; 15]).is_err());
    }

    #[test]
    fn write_octets_fills_only_the_prefix() {
        let mut buffer = [0xffu8; 6];
        Helpers::write_octets(&mut buffer, &[1, 2, 3, 4]);
        assert_eq!(buffer, [1, 2, 3, 4, 0xff, 0xff]);
    }

    #[test]
    fn parse_cidr_zero_mask_is_trivial() {
        assert!(Helpers::parse_cidr("0.0.0.0/0").unwrap().is_trivial());
        assert!(Helpers::parse_cidr("::/0").unwrap().is_trivial());
        assert!(!Helpers::parse_cidr("10.0.0.0/8").unwrap().is_trivial());
    }

    #[test]
    fn parse_cidr_computes_mask_and_clears_host_bits() {
        match Helpers::parse_cidr("10.216.5.5/16").unwrap() {
            Cidr::V4(prefix, mask) => {
                assert_eq!(mask, 0xffff_0000);
                assert_eq!(prefix, 0x0ad8_0000);
            }
            Cidr::V6(..) => panic!("expected an IPv4 CIDR"),
        }

        match Helpers::parse_cidr("10.0.0.1/32").unwrap() {
            Cidr::V4(prefix, mask) => {
                assert_eq!(mask, u32::MAX);
                assert_eq!(prefix, 0x0a00_0001);
            }
            Cidr::V6(..) => panic!("expected an IPv4 CIDR"),
        }

        match Helpers::parse_cidr("2001:db8::1/32").unwrap() {
            Cidr::V6(prefix, mask) => {
                assert_eq!(mask, u128::MAX << 96);
                assert_eq!(prefix, 0x2001_0db8u128 << 96);
            }
            Cidr::V4(..) => panic!("expected an IPv6 CIDR"),
        }
    }

    #[test]
    fn parse_cidr_rejects_oversized_mask() {
        assert!(Helpers::parse_cidr("10.0.0.0/33").is_err());
        assert!(Helpers::parse_cidr("::/129").is_err());
    }

    #[test]
    fn parse_cidr_rejects_malformed_parts() {
        assert!(Helpers::parse_cidr("10.0.0.0/abc").is_err());
        assert!(Helpers::parse_cidr("300.0.0.0/8").is_err());
        assert!(Helpers::parse_cidr("10.0.0.0/-1").is_err());
    }

    #[test]
    fn cidr_membership_v4() {
        let cidr = Helpers::parse_cidr("10.216.0.0/16").unwrap();
        let inside = IpAddr::V4(Ipv4Addr::new(10, 216, 5, 5));
        let outside = IpAddr::V4(Ipv4Addr::new(10, 217, 0, 1));
        assert!(Helpers::is_ip_in_cidr(&inside, &cidr).unwrap());
        assert!(!Helpers::is_ip_in_cidr(&outside, &cidr).unwrap());
    }

    #[test]
    fn cidr_membership_v6() {
        let cidr = Helpers::parse_cidr("2001:db8::/32").unwrap();
        let inside = IpAddr::V6("2001:db8::1".parse::<Ipv6Addr>().unwrap());
        let outside = IpAddr::V6("2001:dead::1".parse::<Ipv6Addr>().unwrap());
        assert!(Helpers::is_ip_in_cidr(&inside, &cidr).unwrap());
        assert!(!Helpers::is_ip_in_cidr(&outside, &cidr).unwrap());
    }

    #[test]
    fn host_cidr_matches_only_itself() {
        let host = Helpers::parse_cidr("192.168.1.10/32").unwrap();
        let same = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 10));
        let neighbour = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 11));
        assert!(Helpers::is_ip_in_cidr(&same, &host).unwrap());
        assert!(!Helpers::is_ip_in_cidr(&neighbour, &host).unwrap());
    }

    #[test]
    fn zero_length_cidr_matches_every_address() {
        let any_v4 = Helpers::parse_cidr("0.0.0.0/0").unwrap();
        for ip in [Ipv4Addr::new(1, 2, 3, 4), Ipv4Addr::BROADCAST, Ipv4Addr::UNSPECIFIED] {
            assert!(Helpers::is_ip_in_cidr(&IpAddr::V4(ip), &any_v4).unwrap());
        }

        let any_v6 = Helpers::parse_cidr("::/0").unwrap();
        assert!(Helpers::is_ip_in_cidr(&IpAddr::V6(Ipv6Addr::LOCALHOST), &any_v6).unwrap());
    }

    #[test]
    fn cidr_membership_rejects_family_mismatch() {
        let v4_cidr = Helpers::parse_cidr("10.0.0.0/8").unwrap();
        let v6_addr = IpAddr::V6(Ipv6Addr::LOCALHOST);
        assert!(Helpers::is_ip_in_cidr(&v6_addr, &v4_cidr).is_err());

        let v6_cidr = Helpers::parse_cidr("::/0").unwrap();
        let v4_addr = IpAddr::V4(Ipv4Addr::LOCALHOST);
        assert!(Helpers::is_ip_in_cidr(&v4_addr, &v6_cidr).is_err());
    }

    #[test]
    fn socks_reply_maps_known_errors() {
        let cases = [
            (0, 0x00),
            (10050, 0x03),
            (10051, 0x03),
            (10064, 0x04),
            (10065, 0x04),
            (11001, 0x04),
            (10061, 0x05),
            (10060, 0x06),
            (123_456, 0x01),
            (-1, 0x01),
        ];
        for (code, reply) in cases {
            assert_eq!(Helpers::get_socks_reply(code), reply, "error code {}", code);
        }
    }

    #[test]
    fn get_id_is_four_alphanumerics() {
        let id = Helpers::get_id();
        assert_eq!(id.len(), 4);
        assert!(id.chars().all(|c| c.is_ascii_alphanumeric()));
    }
}
304}