1use tokio::io::AsyncWriteExt;
2use tokio::net::TcpStream;
3use tokio::{io::AsyncReadExt, task::JoinHandle};
4
5use phf::{Map, phf_map};
6use std::iter::IntoIterator;
7use std::net::{IpAddr, SocketAddr};
8use std::str::FromStr;
9use tracing::{debug, error, info, warn};
10
11use crate::handshake::Handshake;
12use crate::helpers::{Helpers, IntoError, Res, Void};
13use crate::request::{Destination, Request};
14use crate::buffer_pool::Buffer;
16use crate::copy_pump::CopyPump;
17
18pub struct Connection {
19 id: String,
20 client_socket: TcpStream,
21 endpoint_interface: String,
22 buffer: Buffer,
23 read_timeout: u64,
24}
25
26impl Connection {
27 pub fn from(client_socket: TcpStream, endpoint_interface: String, buffer: Buffer, read_timeout: u64) -> Self {
28 Connection {
29 id: Helpers::get_id(),
30 client_socket,
31 endpoint_interface,
32 buffer,
33 read_timeout,
34 }
35 }
36
37 pub fn handle(self) -> JoinHandle<()> {
40 debug!("[{}] Start.", self.id);
41
42 tokio::spawn(async move {
44 match self.handle_task().await {
45 Ok(_) => {}
46 Err(e) => {
47 error!("{}", e);
48 }
49 }
50 })
51 }
52
53 async fn handle_task(mut self) -> Void {
54 let buffer = &mut self.buffer.get().await[..];
56
57 let handshake = Connection::perform_handshake(&mut self.client_socket, buffer).await?;
60 let methods_string = handshake.methods.into_iter().map(|m| m.to_string()).collect::<Vec<String>>().join(",");
61
62 debug!("[{}] Handshake:", self.id);
63 debug!("[{}] Version: {}", self.id, handshake.version);
64 debug!("[{}] Num Methods: {}", self.id, handshake.num_methods);
65 debug!("[{}] Methods: {}", self.id, methods_string);
66
67 let request = Connection::perform_request_negotiation(&mut self.client_socket, buffer).await?;
70 let destination = match &request.destination {
71 Destination::Ipv4Addr(ipv4) => ipv4.to_string(),
72 Destination::Ipv6Addr(ipv6) => ipv6.to_string(),
73 Destination::Domain(s) => s.to_owned(),
74 };
75
76 debug!("[{}] Request:", self.id);
77 debug!("[{}] Version: {}", self.id, request.version);
78 debug!("[{}] Command: {}", self.id, COMMANDS[&request.command]);
79 debug!("[{}] Address Type: {}", self.id, ADDRESS_TYPES[&request.address_type]);
80 debug!("[{}] Destination: {}", self.id, destination);
81 debug!("[{}] Port: {}", self.id, request.port);
82
83 let endpoint_socket = match request.command {
86 0x01 => Connection::establish_connect_request(&mut self.client_socket, &self.endpoint_interface, &request, buffer).await?,
87 0x02 => return "BIND requests not supported.".into_error(),
88 0x03 => return "UDP ASSOCIATE requests not supported.".into_error(),
89 _ => return "Unknown command type.".into_error()
90 };
91
92 let client_peer_addr = self.client_socket.peer_addr()?;
95 let client_local_addr = self.client_socket.local_addr()?;
96 let endpoint_local_addr = endpoint_socket.local_addr()?;
97 let endpoint_peer_addr = endpoint_socket.peer_addr()?;
98
99 info!("[{}] {} => {} => {} => {}", self.id, client_peer_addr, client_local_addr, endpoint_local_addr, endpoint_peer_addr);
100
101 match CopyPump::from(self.client_socket, endpoint_socket, self.read_timeout).start().await {
105 Ok(_) => {}
106 Err(e) => {
107 warn!("[{}] The pump ended with an error. {}", self.id, e);
108 }
109 }
110
111 debug!("[{}] End.", self.id);
112
113 Ok(())
114 }
115
116 async fn perform_handshake(client_socket: &mut TcpStream, buffer: &mut [u8]) -> Res<Handshake> {
117 let read = client_socket.read(buffer).await?;
118
119 if read == 0 {
120 return "Read 0 bytes during handshake.".into_error();
121 }
122
123 let handshake = Handshake::from_data(buffer)?;
124
125 if handshake.version != 5 {
126 return "Bad SOCKS version.".into_error();
127 }
128
129 buffer[0] = 0x05; buffer[1] = 0x00; client_socket.write_all(&buffer[..2]).await?;
135 client_socket.flush().await?;
136
137 Ok(handshake)
138 }
139
140 async fn perform_request_negotiation(client_socket: &mut TcpStream, buffer: &mut [u8]) -> Res<Request> {
141 let read = client_socket.read(buffer).await?;
142
143 if read == 0 {
144 return "Read 0 bytes during connection negotiation.".into_error();
145 }
146
147 let request = Request::from_data(buffer)?;
148
149 Ok(request)
150 }
151
152 async fn establish_connect_request(client_socket: &mut TcpStream, endpoint_interface: &str, request: &Request, buffer: &mut [u8]) -> Res<TcpStream> {
153 let mut reply = 0u8;
154
155 let local_addr = SocketAddr::from_str(&format!("{}:{}", endpoint_interface, 0))?;
157
158 let string_to_connect = format!("{}:{}", request.destination, request.port);
160 let endpoint_addr_iterator = tokio::net::lookup_host(&string_to_connect).await;
161
162 let endpoint_socket = match endpoint_addr_iterator {
164 Ok(endpoint_addresses) => {
165 match Helpers::create_local_socket(local_addr, endpoint_addresses) {
167 Some(ep) => {
168 let socket = ep.socket;
169 let endpoint_addr = ep.address;
170
171 match socket.connect(endpoint_addr).await {
173 Ok(s) => Some(s),
174 Err(e) => {
175 warn!("Could not connect to `{}` (`{}`).", string_to_connect, endpoint_addr);
176
177 reply = match e.raw_os_error() {
178 Some(i) => Helpers::get_socks_reply(i),
179 _ => 5u8, };
181
182 None
183 }
184 }
185 }
186 None => {
187 warn!(
188 "Could not create local socket (`{}`) to `{}`. This likely means that we could not find a suitable address type for the endpoint that matches the endpoint interface type (i.e., IPv6/IPv4 mismatch).",
189 local_addr, string_to_connect
190 );
191
192 reply = 5u8; None
195 }
196 }
197 }
198 Err(e) => {
199 warn!("Could not compute an endpoint address for `{}`.", string_to_connect);
200
201 reply = match e.raw_os_error() {
202 Some(i) => Helpers::get_socks_reply(i),
203 _ => 8u8, };
205
206 None
207 }
208 };
209
210 let local_ip = local_addr.ip();
212 let (port_high, port_low) = Helpers::port_to_bytes(local_addr.port());
213
214 buffer[0] = 0x05; buffer[1] = reply;
218 buffer[2] = 0x0; let reply_length = match local_ip {
221 IpAddr::V4(ipv4) => {
222 let octets = ipv4.octets();
223
224 buffer[3] = 0x01; buffer[4] = octets[0];
226 buffer[5] = octets[1];
227 buffer[6] = octets[2];
228 buffer[7] = octets[3];
229 Helpers::write_octets(&mut buffer[4..8], &octets);
230
231 buffer[8] = port_high;
232 buffer[9] = port_low;
233
234 10
235 }
236 IpAddr::V6(ipv6) => {
237 let octets = ipv6.octets();
238
239 buffer[3] = 0x04; Helpers::write_octets(&mut buffer[4..20], &octets);
241
242 buffer[20] = port_high;
243 buffer[21] = port_low;
244
245 22
246 }
247 };
248
249 client_socket.write_all(&buffer[0..reply_length]).await?;
252 client_socket.flush().await?;
253
254 if reply != 0 {
257 return format!("The connection to `{}` failed gracefully with `{}`.", string_to_connect, ERRORS[&reply]).into_error();
258 }
259
260 Ok(endpoint_socket.unwrap())
262 }
263}
264
265static COMMANDS: Map<u8, &'static str> = phf_map! {
266 1u8 => "Connect",
267 2u8 => "Bind",
268 3u8 => "UDP Associate",
269};
270
271static ADDRESS_TYPES: Map<u8, &'static str> = phf_map! {
272 1u8 => "Ipv4",
273 3u8 => "Domain",
274 4u8 => "Ipv6",
275};
276
277static ERRORS: Map<u8, &'static str> = phf_map! {
278 0u8 => "Succeeded",
279 1u8 => "General SOCKS Server Failure",
280 3u8 => "Network Unreachable",
281 4u8 => "Host Unreachable",
282 5u8 => "Connection Refused",
283 6u8 => "TTL Expired",
284 8u8 => "Address type not supported"
285};