Skip to main content

rusty_sockslib/
request.rs

1use std::fmt::Display;
2use std::net::{Ipv4Addr, Ipv6Addr};
3
4use crate::helpers::{Helpers, IntoError, Res};
5
/// A SOCKS5 client request as produced by [`Request::from_data`].
///
/// The header bytes are stored exactly as they appeared in the input buffer.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Request {
    /// VER: protocol version byte.
    pub version: u8,
    /// CMD: requested command byte.
    pub command: u8,
    /// RSV: reserved byte.
    pub reserved: u8,
    /// ATYP: address type byte (`0x01` IPv4, `0x03` domain name, `0x04` IPv6).
    pub address_type: u8,
    /// DST.PORT: destination port.
    pub port: u16,
    /// DST.ADDR: destination address, decoded according to `address_type`.
    pub destination: Destination,
}

/// The destination address carried in a request, one variant per supported ATYP.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Destination {
    /// A literal IPv4 address.
    Ipv4Addr(Ipv4Addr),
    /// A literal IPv6 address.
    Ipv6Addr(Ipv6Addr),
    /// A host name that still has to be resolved.
    Domain(String),
}

impl Display for Destination {
    /// Formats the destination as a bare address or host name (without the port).
    ///
    /// Formatting is delegated to the inner value's own `Display` implementation so
    /// that formatter options such as width and alignment (e.g. `{:>20}`) are
    /// honoured. Re-formatting through `write!(f, "{}", ..)` would discard them.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Ipv4Addr(ipv4) => Display::fmt(ipv4, f),
            Self::Ipv6Addr(ipv6) => Display::fmt(ipv6, f),
            Self::Domain(domain) => Display::fmt(domain, f),
        }
    }
}
30
31impl Request {
32    pub fn from_data(data: &[u8]) -> Res<Self> {
33        // VER, CMD, RSV, ATYP.
34        if data.len() < 4 {
35            return "Request too short: need at least the four-byte header.".into_error();
36        }
37
38        let version = data[0];
39        let command = data[1];
40        let reserved = data[2];
41        let address_type = data[3];
42
43        match address_type {
44            0x01 => {
45                // IPv4: four address bytes followed by a two-byte port.
46                if data.len() < 10 {
47                    return "Request too short for an IPv4 address.".into_error();
48                }
49
50                let address = Ipv4Addr::from(Helpers::slice_to_u32(&data[4..8])?);
51                let port = Helpers::bytes_to_port(&data[8..10])?;
52
53                Ok(Request {
54                    version,
55                    command,
56                    reserved,
57                    address_type,
58                    port,
59                    destination: Destination::Ipv4Addr(address),
60                })
61            }
62            0x03 => {
63                // Domain: a length byte, that many name bytes, then a two-byte port.
64                if data.len() < 5 {
65                    return "Request too short for a domain name.".into_error();
66                }
67
68                let name_length = data[4] as usize;
69                let port_start = 5 + name_length;
70
71                if data.len() < port_start + 2 {
72                    return "Request too short for the stated domain length.".into_error();
73                }
74
75                let name = std::str::from_utf8(&data[5..port_start])?.to_owned();
76                let port = Helpers::bytes_to_port(&data[port_start..port_start + 2])?;
77
78                Ok(Request {
79                    version,
80                    command,
81                    reserved,
82                    address_type,
83                    port,
84                    destination: Destination::Domain(name),
85                })
86            }
87            0x04 => {
88                // IPv6: sixteen address bytes followed by a two-byte port.
89                if data.len() < 22 {
90                    return "Request too short for an IPv6 address.".into_error();
91                }
92
93                let address = Ipv6Addr::from(Helpers::slice_to_u128(&data[4..20])?);
94                let port = Helpers::bytes_to_port(&data[20..22])?;
95
96                Ok(Request {
97                    version,
98                    command,
99                    reserved,
100                    address_type,
101                    port,
102                    destination: Destination::Ipv6Addr(address),
103                })
104            }
105            _ => "Unknown request type, or data corrupt.".into_error(),
106        }
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Builds a CONNECT request frame: VER=5, CMD=1, RSV=0, the given ATYP, the raw
    /// address bytes, and finally the port in network byte order.
    fn connect_frame(address_type: u8, address: &[u8], port: u16) -> Vec<u8> {
        let mut frame = vec![0x05, 0x01, 0x00, address_type];
        frame.extend_from_slice(address);
        frame.extend_from_slice(&port.to_be_bytes());
        frame
    }

    #[test]
    fn parses_ipv4_connect() {
        let target = Ipv4Addr::new(93, 184, 216, 34);
        let frame = connect_frame(0x01, &target.octets(), 443);
        let parsed = Request::from_data(&frame).unwrap();

        assert_eq!(parsed.version, 5);
        assert_eq!(parsed.command, 1);
        assert_eq!(parsed.address_type, 1);
        assert_eq!(parsed.port, 443);
        match parsed.destination {
            Destination::Ipv4Addr(ip) => assert_eq!(ip, target),
            other => panic!("expected ipv4 destination, got {other}"),
        }
    }

    #[test]
    fn parses_domain_connect() {
        let host = "example.com";
        // Domain addresses are length-prefixed.
        let mut address = vec![u8::try_from(host.len()).unwrap()];
        address.extend_from_slice(host.as_bytes());

        let parsed = Request::from_data(&connect_frame(0x03, &address, 80)).unwrap();

        assert_eq!(parsed.address_type, 3);
        assert_eq!(parsed.port, 80);
        match parsed.destination {
            Destination::Domain(name) => assert_eq!(name, host),
            other => panic!("expected domain destination, got {other}"),
        }
    }

    #[test]
    fn parses_ipv6_connect() {
        let target = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let parsed = Request::from_data(&connect_frame(0x04, &target.octets(), 8080)).unwrap();

        assert_eq!(parsed.address_type, 4);
        assert_eq!(parsed.port, 8080);
        match parsed.destination {
            Destination::Ipv6Addr(ip) => assert_eq!(ip, target),
            other => panic!("expected ipv6 destination, got {other}"),
        }
    }

    #[test]
    fn rejects_unknown_address_type() {
        // ATYP 0x09 is not defined; the remaining six bytes are zero padding.
        let frame = connect_frame(0x09, &[0; 4], 0);
        assert!(Request::from_data(&frame).is_err());
    }

    #[test]
    fn rejects_truncated_header() {
        let frame = [0x05, 0x01];
        assert!(Request::from_data(&frame).is_err());
    }

    #[test]
    fn rejects_truncated_ipv4() {
        // ATYP says IPv4, but only three of the four address bytes arrive.
        let frame = [0x05, 0x01, 0x00, 0x01, 127, 0, 0];
        assert!(Request::from_data(&frame).is_err());
    }

    #[test]
    fn rejects_domain_length_overrun() {
        // The length byte promises 50 name bytes; only two follow.
        let frame = [0x05, 0x01, 0x00, 0x03, 50, b'a', b'b'];
        assert!(Request::from_data(&frame).is_err());
    }
}