rusty_sockslib/
helpers.rs1use rand::RngExt;
2use rand::distr::Alphanumeric;
3use std::error::Error;
4use std::fmt::Display;
5use std::{fmt::Formatter, net::SocketAddr};
6
7use if_addrs::get_if_addrs;
8use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
9use tokio::net::TcpSocket;
10
/// A CIDR block stored as `(prefix, mask)` integers assembled in big-endian order.
///
/// When produced by [`Helpers::parse_cidr`] the prefix has already been ANDed with
/// the mask, so an address `a` of the same family is a member exactly when
/// `a & mask == prefix`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Cidr {
    /// IPv4 block: `(prefix, mask)`.
    V4(u32, u32),
    /// IPv6 block: `(prefix, mask)`.
    V6(u128, u128),
}

impl Cidr {
    /// Returns `true` when the mask has no bits set (a `/0` block).
    ///
    /// For a block built by `parse_cidr` the prefix is then `0`, so every address
    /// of that family is a member.
    pub fn is_trivial(&self) -> bool {
        matches!(self, Cidr::V4(_, 0) | Cidr::V6(_, 0))
    }
}
24
25pub struct EndpointPair {
26 pub socket: TcpSocket,
27 pub address: SocketAddr,
28}
29
30pub struct Helpers;
31
32impl Helpers {
33 pub fn get_id() -> String {
34 rand::rng().sample_iter(Alphanumeric).take(4).map(char::from).collect::<String>()
35 }
36
37 pub fn bytes_to_port(data: &[u8]) -> Res<u16> {
38 if data.len() != 2 {
39 return "There must be exactly two (2) bytes for a conversion to a port.".into_error();
40 }
41
42 Ok(((data[0] as u16) << 8) + (data[1] as u16))
43 }
44
45 pub fn port_to_bytes(port: u16) -> (u8, u8) {
46 ((port >> 8) as u8, (port & 0xff) as u8)
47 }
48
49 pub fn slice_to_u32(data: &[u8]) -> Res<u32> {
50 if data.len() != 4 {
51 return "There must be exactly four (4) bytes for a conversion to an IPv4.".into_error();
52 }
53
54 Ok(((data[0] as u32) << 24) + ((data[1] as u32) << 16) + ((data[2] as u32) << 8) + (data[3] as u32))
55 }
56
57 pub fn slice_to_u128(data: &[u8]) -> Res<u128> {
58 if data.len() != 16 {
59 return "There must be exactly sixteen (16) bytes for a conversion to an IPv6.".into_error();
60 }
61
62 Ok(((data[0] as u128) << 120)
63 + ((data[1] as u128) << 112)
64 + ((data[2] as u128) << 104)
65 + ((data[3] as u128) << 96)
66 + ((data[4] as u128) << 88)
67 + ((data[5] as u128) << 80)
68 + ((data[6] as u128) << 72)
69 + ((data[7] as u128) << 64)
70 + ((data[8] as u128) << 56)
71 + ((data[9] as u128) << 48)
72 + ((data[10] as u128) << 40)
73 + ((data[11] as u128) << 32)
74 + ((data[12] as u128) << 24)
75 + ((data[13] as u128) << 16)
76 + ((data[14] as u128) << 8)
77 + (data[15] as u128))
78 }
79
80 pub fn get_socks_reply(error: i32) -> u8 {
81 match error {
82 0 => 0x00, 10050 | 10051 => 0x03, 10064 | 11001 | 10065 => 0x04, 10061 => 0x05, 10060 => 0x06, _ => 0x01, }
89 }
90
91 pub fn write_octets(buffer: &mut [u8], octets: &[u8]) {
92 buffer[..octets.len()].clone_from_slice(octets);
93 }
94
95 pub fn get_interface_ip(name: &str) -> Res<IpAddr> {
96 for iface in get_if_addrs()? {
97 if iface.name == name {
98 return Ok(iface.ip());
99 }
100 }
101
102 format!("Could not lookup IP for interface `{}`.", name).into_error()
103 }
104
105 pub fn mask_ipv4(ip: &Ipv4Addr, mask: u32) -> Res<u32> {
106 Ok(Helpers::slice_to_u32(&ip.octets())? & mask)
107 }
108
109 pub fn mask_ipv6(ip: &Ipv6Addr, mask: u128) -> Res<u128> {
110 Ok(Helpers::slice_to_u128(&ip.octets())? & mask)
111 }
112
113 pub fn is_ip_in_cidr(ip_addr: &IpAddr, cidr: &Cidr) -> Res<bool> {
114 match cidr {
115 Cidr::V4(prefix, mask) => match &ip_addr {
116 IpAddr::V4(ip) => Ok(Helpers::mask_ipv4(ip, *mask)? == *prefix),
117 _ => Err(Box::new(GenericError::from("Cannot check IPv6 addresses against IPv4 CIDRs."))),
118 },
119 Cidr::V6(prefix, mask) => match &ip_addr {
120 IpAddr::V6(ip) => Ok(Helpers::mask_ipv6(ip, *mask)? == *prefix),
121 _ => Err(Box::new(GenericError::from("Cannot check IPv4 addresses against IPv6 CIDRs."))),
122 },
123 }
124 }
125
126 pub fn parse_cidr(s: &str) -> Res<Cidr> {
127 let splits = s.split('/').collect::<Vec<&str>>();
128
129 let ip_addr = splits[0].parse::<IpAddr>()?;
130 let num_mask_bits = splits[1].parse::<u32>()?;
131
132 match ip_addr {
133 IpAddr::V4(ip) => {
134 if num_mask_bits > 32 {
135 return Err(Box::new(GenericError::from("An IPv4 CIDR prefix must have a mask bit length less than or equal to 32.")));
136 }
137
138 let mask = !(2u32.overflowing_pow(32 - num_mask_bits).0.overflowing_sub(1).0);
139 let prefix = Helpers::slice_to_u32(&ip.octets())? & mask;
140
141 Ok(Cidr::V4(prefix, mask))
142 }
143 IpAddr::V6(ip) => {
144 if num_mask_bits > 128 {
145 return Err(Box::new(GenericError::from("An IPv4 CIDR prefix must have a mask bit length less than or equal to 128.")));
146 }
147
148 let mask = !(2u128.overflowing_pow(128 - num_mask_bits).0.overflowing_sub(1).0);
149 let prefix = Helpers::slice_to_u128(&ip.octets())? & mask;
150
151 Ok(Cidr::V6(prefix, mask))
152 }
153 }
154 }
155
156 pub fn create_local_socket(local_addr: SocketAddr, mut endpoint_addresses: impl Iterator<Item = SocketAddr>) -> Option<EndpointPair> {
157 let is_endpoint_interface_ipv6 = local_addr.is_ipv6();
158
159 let endpoint_addr = if is_endpoint_interface_ipv6 {
160 endpoint_addresses.find(|a| a.is_ipv6())
161 } else {
162 endpoint_addresses.find(|a| a.is_ipv4())
163 };
164
165 let endpoint_addr = endpoint_addr?;
166
167 let socket = if endpoint_addr.is_ipv4() { TcpSocket::new_v4().ok()? } else { TcpSocket::new_v6().ok()? };
169
170 socket.bind(local_addr).ok()?;
171
172 Some(EndpointPair { socket, address: endpoint_addr })
173 }
174}
175
/// Fallible result whose error is a boxed, type-erased `std::error::Error`.
pub type Res<T> = Result<T, Box<dyn std::error::Error>>;
/// Fallible result with no success payload; shorthand for `Res<()>`.
pub type Void = Res<()>;
178
179pub trait IntoError<T> {
180 fn into_error(self) -> Res<T>;
181}
182
183impl<T, S> IntoError<T> for S
184where
185 S: AsRef<str> + ToString,
186{
187 fn into_error(self) -> Res<T> {
188 Err(Box::new(GenericError::from(self)))
189 }
190}
191
/// A minimal error type that carries only a human-readable message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GenericError {
    message: String,
}

/// Builds a `GenericError` from any string-like message (`&str`, `String`, ...).
impl<T> From<T> for GenericError
where
    T: AsRef<str> + ToString,
{
    fn from(message: T) -> Self {
        GenericError { message: message.to_string() }
    }
}

/// Displays the stored message verbatim.
impl Display for GenericError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

/// A leaf error: the default `source()` (always `None`) is what we want.
impl Error for GenericError {}
217
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

    #[test]
    fn port_roundtrips_through_bytes() {
        for port in [0u16, 80, 443, 1080, 65535] {
            let (hi, lo) = Helpers::port_to_bytes(port);
            assert_eq!(Helpers::bytes_to_port(&[hi, lo]).unwrap(), port);
        }
    }

    #[test]
    fn port_bytes_are_big_endian() {
        // A round-trip alone cannot catch a symmetric byte-order bug, so pin the layout:
        // 1080 == 0x0438 and network order puts the high byte first.
        assert_eq!(Helpers::port_to_bytes(1080), (0x04, 0x38));
        assert_eq!(Helpers::bytes_to_port(&[0x04, 0x38]).unwrap(), 1080);
    }

    #[test]
    fn bytes_to_port_rejects_wrong_length() {
        assert!(Helpers::bytes_to_port(&[]).is_err());
        assert!(Helpers::bytes_to_port(&[0]).is_err());
        assert!(Helpers::bytes_to_port(&[0, 0, 0]).is_err());
    }

    #[test]
    fn slice_to_u32_parses_and_validates_length() {
        assert_eq!(Helpers::slice_to_u32(&[127, 0, 0, 1]).unwrap(), 0x7f00_0001);
        assert!(Helpers::slice_to_u32(&[0, 0, 0]).is_err());
        assert!(Helpers::slice_to_u32(&[0, 0, 0, 0, 0]).is_err());
    }

    #[test]
    fn slice_to_u128_parses_and_validates_length() {
        assert_eq!(Helpers::slice_to_u128(&[0u8; 16]).unwrap(), 0);

        let mut bytes = [0u8; 16];
        bytes[0] = 0x20;
        bytes[1] = 0x01;
        bytes[15] = 0x01;
        assert_eq!(Helpers::slice_to_u128(&bytes).unwrap(), (0x2001u128 << 112) | 1);

        assert!(Helpers::slice_to_u128(&[0u8; 15]).is_err());
        assert!(Helpers::slice_to_u128(&[0u8; 17]).is_err());
    }

    #[test]
    fn write_octets_fills_prefix_only() {
        let mut buf = [0xffu8; 6];
        Helpers::write_octets(&mut buf, &[1, 2, 3, 4]);
        assert_eq!(buf, [1, 2, 3, 4, 0xff, 0xff]);
    }

    #[test]
    fn mask_helpers_apply_mask() {
        assert_eq!(Helpers::mask_ipv4(&Ipv4Addr::new(192, 168, 1, 77), 0xffff_ff00).unwrap(), 0xc0a8_0100);
        assert_eq!(Helpers::mask_ipv6(&Ipv6Addr::LOCALHOST, u128::MAX).unwrap(), 1);
        assert_eq!(Helpers::mask_ipv6(&Ipv6Addr::LOCALHOST, 0).unwrap(), 0);
    }

    #[test]
    fn parse_cidr_zero_mask_is_trivial() {
        assert!(Helpers::parse_cidr("0.0.0.0/0").unwrap().is_trivial());
        assert!(Helpers::parse_cidr("::/0").unwrap().is_trivial());
        assert!(!Helpers::parse_cidr("10.0.0.0/8").unwrap().is_trivial());
    }

    #[test]
    fn parse_cidr_masks_host_bits() {
        match Helpers::parse_cidr("10.216.5.5/16").unwrap() {
            Cidr::V4(prefix, mask) => {
                assert_eq!(prefix, 0x0ad8_0000);
                assert_eq!(mask, 0xffff_0000);
            }
            Cidr::V6(..) => panic!("expected an IPv4 CIDR"),
        }
    }

    #[test]
    fn parse_cidr_rejects_oversized_mask() {
        assert!(Helpers::parse_cidr("10.0.0.0/33").is_err());
        assert!(Helpers::parse_cidr("::/129").is_err());
    }

    #[test]
    fn parse_cidr_rejects_malformed_parts() {
        assert!(Helpers::parse_cidr("not-an-ip/8").is_err());
        assert!(Helpers::parse_cidr("10.0.0.0/abc").is_err());
        assert!(Helpers::parse_cidr("10.0.0.0/-1").is_err());
    }

    #[test]
    fn cidr_membership_v4() {
        let cidr = Helpers::parse_cidr("10.216.0.0/16").unwrap();
        let inside = IpAddr::V4(Ipv4Addr::new(10, 216, 5, 5));
        let outside = IpAddr::V4(Ipv4Addr::new(10, 217, 0, 1));
        assert!(Helpers::is_ip_in_cidr(&inside, &cidr).unwrap());
        assert!(!Helpers::is_ip_in_cidr(&outside, &cidr).unwrap());
    }

    #[test]
    fn cidr_membership_v6() {
        let cidr = Helpers::parse_cidr("2001:db8::/32").unwrap();
        let inside = IpAddr::V6("2001:db8::1".parse::<Ipv6Addr>().unwrap());
        let outside = IpAddr::V6("2001:dead::1".parse::<Ipv6Addr>().unwrap());
        assert!(Helpers::is_ip_in_cidr(&inside, &cidr).unwrap());
        assert!(!Helpers::is_ip_in_cidr(&outside, &cidr).unwrap());
    }

    #[test]
    fn cidr_membership_boundaries() {
        let host = Helpers::parse_cidr("192.168.1.1/32").unwrap();
        assert!(Helpers::is_ip_in_cidr(&IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), &host).unwrap());
        assert!(!Helpers::is_ip_in_cidr(&IpAddr::V4(Ipv4Addr::new(192, 168, 1, 2)), &host).unwrap());

        let everything = Helpers::parse_cidr("0.0.0.0/0").unwrap();
        assert!(Helpers::is_ip_in_cidr(&IpAddr::V4(Ipv4Addr::new(203, 0, 113, 9)), &everything).unwrap());

        let v6_host = Helpers::parse_cidr("::1/128").unwrap();
        assert!(Helpers::is_ip_in_cidr(&IpAddr::V6(Ipv6Addr::LOCALHOST), &v6_host).unwrap());
        assert!(!Helpers::is_ip_in_cidr(&IpAddr::V6(Ipv6Addr::UNSPECIFIED), &v6_host).unwrap());
    }

    #[test]
    fn cidr_membership_rejects_family_mismatch() {
        let v4_cidr = Helpers::parse_cidr("10.0.0.0/8").unwrap();
        let v6_addr = IpAddr::V6(Ipv6Addr::LOCALHOST);
        assert!(Helpers::is_ip_in_cidr(&v6_addr, &v4_cidr).is_err());

        let v6_cidr = Helpers::parse_cidr("2001:db8::/32").unwrap();
        let v4_addr = IpAddr::V4(Ipv4Addr::LOCALHOST);
        assert!(Helpers::is_ip_in_cidr(&v4_addr, &v6_cidr).is_err());
    }

    #[test]
    fn socks_reply_maps_known_errors() {
        let cases = [
            (0, 0x00),
            (10050, 0x03),
            (10051, 0x03),
            (10064, 0x04),
            (11001, 0x04),
            (10065, 0x04),
            (10061, 0x05),
            (10060, 0x06),
            (-1, 0x01),
            (123_456, 0x01),
        ];
        for (code, reply) in cases {
            assert_eq!(Helpers::get_socks_reply(code), reply);
        }
    }

    #[test]
    fn create_local_socket_requires_matching_family() {
        // No IPv4 endpoint is offered for an IPv4 local address, so no socket is created.
        let local = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
        let remote = "[::1]:80".parse::<SocketAddr>().unwrap();
        assert!(Helpers::create_local_socket(local, std::iter::once(remote)).is_none());
    }

    #[test]
    fn get_id_is_four_alphanumerics() {
        let id = Helpers::get_id();
        assert_eq!(id.len(), 4);
        assert!(id.chars().all(|c| c.is_ascii_alphanumeric()));
    }
}