1use std::fmt::Display;
2use std::net::{Ipv4Addr, Ipv6Addr};
3
4use crate::helpers::{Helpers, IntoError, Res};
5
/// A client request decoded by [`Request::from_data`].
///
/// The wire layout matches a SOCKS5 (RFC 1928) request:
/// `VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT`.
/// The four header bytes are stored exactly as received; `from_data` does not
/// validate `version`, `command`, or `reserved`, so callers must check them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Request {
    /// Protocol version byte (byte 0); `0x05` for SOCKS5.
    pub version: u8,
    /// Command byte (byte 1), e.g. `0x01` for CONNECT.
    pub command: u8,
    /// Reserved byte (byte 2), kept as received.
    pub reserved: u8,
    /// Address type byte (byte 3): `0x01` IPv4, `0x03` domain name, `0x04` IPv6.
    pub address_type: u8,
    /// Destination port, decoded from the two bytes following the address
    /// (network byte order, e.g. `0x01 0xBB` is 443).
    pub port: u16,
    /// Destination host, decoded according to `address_type`.
    pub destination: Destination,
}

/// Destination host carried by a [`Request`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Destination {
    /// A literal IPv4 address (`ATYP = 0x01`).
    Ipv4Addr(Ipv4Addr),
    /// A literal IPv6 address (`ATYP = 0x04`).
    Ipv6Addr(Ipv6Addr),
    /// A domain name (`ATYP = 0x03`), stored as sent; it is not resolved here.
    Domain(String),
}
20
21impl Display for Destination {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 match &self {
24 Self::Ipv4Addr(ipv4) => write!(f, "{}", ipv4),
25 Self::Ipv6Addr(ipv6) => write!(f, "{}", ipv6),
26 Self::Domain(domain) => write!(f, "{}", domain),
27 }
28 }
29}
30
31impl Request {
32 pub fn from_data(data: &[u8]) -> Res<Self> {
33 if data.len() < 4 {
35 return "Request too short: need at least the four-byte header.".into_error();
36 }
37
38 let version = data[0];
39 let command = data[1];
40 let reserved = data[2];
41 let address_type = data[3];
42
43 match address_type {
44 0x01 => {
45 if data.len() < 10 {
47 return "Request too short for an IPv4 address.".into_error();
48 }
49
50 let address = Ipv4Addr::from(Helpers::slice_to_u32(&data[4..8])?);
51 let port = Helpers::bytes_to_port(&data[8..10])?;
52
53 Ok(Request {
54 version,
55 command,
56 reserved,
57 address_type,
58 port,
59 destination: Destination::Ipv4Addr(address),
60 })
61 }
62 0x03 => {
63 if data.len() < 5 {
65 return "Request too short for a domain name.".into_error();
66 }
67
68 let name_length = data[4] as usize;
69 let port_start = 5 + name_length;
70
71 if data.len() < port_start + 2 {
72 return "Request too short for the stated domain length.".into_error();
73 }
74
75 let name = std::str::from_utf8(&data[5..port_start])?.to_owned();
76 let port = Helpers::bytes_to_port(&data[port_start..port_start + 2])?;
77
78 Ok(Request {
79 version,
80 command,
81 reserved,
82 address_type,
83 port,
84 destination: Destination::Domain(name),
85 })
86 }
87 0x04 => {
88 if data.len() < 22 {
90 return "Request too short for an IPv6 address.".into_error();
91 }
92
93 let address = Ipv6Addr::from(Helpers::slice_to_u128(&data[4..20])?);
94 let port = Helpers::bytes_to_port(&data[20..22])?;
95
96 Ok(Request {
97 version,
98 command,
99 reserved,
100 address_type,
101 port,
102 destination: Destination::Ipv6Addr(address),
103 })
104 }
105 _ => "Unknown request type, or data corrupt.".into_error(),
106 }
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::net::{Ipv4Addr, Ipv6Addr};

    #[test]
    fn parses_ipv4_connect() {
        let data = [0x05, 0x01, 0x00, 0x01, 93, 184, 216, 34, 0x01, 0xBB];
        let req = Request::from_data(&data).unwrap();

        assert_eq!(req.version, 5);
        assert_eq!(req.command, 1);
        assert_eq!(req.address_type, 1);
        assert_eq!(req.port, 443);
        match req.destination {
            Destination::Ipv4Addr(ip) => assert_eq!(ip, Ipv4Addr::new(93, 184, 216, 34)),
            other => panic!("expected ipv4 destination, got {other}"),
        }
    }

    #[test]
    fn parses_domain_connect() {
        let domain = b"example.com";
        let mut data = vec![0x05, 0x01, 0x00, 0x03, domain.len() as u8];
        data.extend_from_slice(domain);
        data.extend_from_slice(&[0x00, 0x50]);

        let req = Request::from_data(&data).unwrap();

        assert_eq!(req.address_type, 3);
        assert_eq!(req.port, 80);
        match req.destination {
            Destination::Domain(name) => assert_eq!(name, "example.com"),
            other => panic!("expected domain destination, got {other}"),
        }
    }

    #[test]
    fn parses_ipv6_connect() {
        let ip = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let mut data = vec![0x05, 0x01, 0x00, 0x04];
        data.extend_from_slice(&ip.octets());
        data.extend_from_slice(&[0x1F, 0x90]);

        let req = Request::from_data(&data).unwrap();

        assert_eq!(req.address_type, 4);
        assert_eq!(req.port, 8080);
        match req.destination {
            Destination::Ipv6Addr(parsed) => assert_eq!(parsed, ip),
            other => panic!("expected ipv6 destination, got {other}"),
        }
    }

    #[test]
    fn rejects_unknown_address_type() {
        let data = [0x05, 0x01, 0x00, 0x09, 0, 0, 0, 0, 0, 0];
        assert!(Request::from_data(&data).is_err());
    }

    #[test]
    fn rejects_truncated_header() {
        assert!(Request::from_data(&[0x05, 0x01]).is_err());
    }

    #[test]
    fn rejects_truncated_ipv4() {
        assert!(Request::from_data(&[0x05, 0x01, 0x00, 0x01, 127, 0, 0]).is_err());
    }

    #[test]
    fn rejects_truncated_ipv6() {
        // Header plus only 17 of the 18 address+port bytes.
        let mut data = vec![0x05, 0x01, 0x00, 0x04];
        data.extend_from_slice(&Ipv6Addr::LOCALHOST.octets());
        data.push(0x00);
        assert!(Request::from_data(&data).is_err());
    }

    #[test]
    fn rejects_missing_domain_length() {
        assert!(Request::from_data(&[0x05, 0x01, 0x00, 0x03]).is_err());
    }

    #[test]
    fn rejects_domain_length_overrun() {
        assert!(Request::from_data(&[0x05, 0x01, 0x00, 0x03, 50, b'a', b'b']).is_err());
    }

    #[test]
    fn rejects_non_utf8_domain() {
        assert!(Request::from_data(&[0x05, 0x01, 0x00, 0x03, 1, 0xFF, 0x00, 0x50]).is_err());
    }

    #[test]
    fn displays_destinations_as_bare_hosts() {
        assert_eq!(
            Destination::Ipv4Addr(Ipv4Addr::new(93, 184, 216, 34)).to_string(),
            "93.184.216.34"
        );
        assert_eq!(
            Destination::Ipv6Addr(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)).to_string(),
            "2001:db8::1"
        );
        assert_eq!(
            Destination::Domain("example.com".to_string()).to_string(),
            "example.com"
        );
    }
}