Skip to main content

rusty_sockslib/
connection.rs

1use tokio::io::AsyncWriteExt;
2use tokio::net::TcpStream;
3use tokio::{io::AsyncReadExt, task::JoinHandle};
4
5use phf::{Map, phf_map};
6use std::iter::IntoIterator;
7use std::net::{IpAddr, SocketAddr};
8use std::str::FromStr;
9use tracing::{debug, error, info, warn};
10
11use crate::handshake::Handshake;
12use crate::helpers::{Helpers, IntoError, Res, Void};
13use crate::request::{Destination, Request};
14//use crate::custom_pump::CustomPump;
15use crate::buffer_pool::Buffer;
16use crate::copy_pump::CopyPump;
17
18pub struct Connection {
19    id: String,
20    client_socket: TcpStream,
21    endpoint_interface: String,
22    buffer: Buffer,
23    read_timeout: u64,
24}
25
26impl Connection {
27    pub fn from(client_socket: TcpStream, endpoint_interface: String, buffer: Buffer, read_timeout: u64) -> Self {
28        Connection {
29            id: Helpers::get_id(),
30            client_socket,
31            endpoint_interface,
32            buffer,
33            read_timeout,
34        }
35    }
36
37    // `self` Connection is moved when the handle method is called, and ownership is given
38    // fully to the thread, so `this` Connection will drop when the spawned thread ends.
39    pub fn handle(self) -> JoinHandle<()> {
40        debug!("[{}] Start.", self.id);
41
42        // Move self into the spawned thread, as well.
43        tokio::spawn(async move {
44            match self.handle_task().await {
45                Ok(_) => {}
46                Err(e) => {
47                    error!("{}", e);
48                }
49            }
50        })
51    }
52
53    async fn handle_task(mut self) -> Void {
54        // Get a &mut slice from the leased buffer.
55        let buffer = &mut self.buffer.get().await[..];
56
57        // Complete handshake.
58
59        let handshake = Connection::perform_handshake(&mut self.client_socket, buffer).await?;
60        let methods_string = handshake.methods.into_iter().map(|m| m.to_string()).collect::<Vec<String>>().join(",");
61
62        debug!("[{}]   Handshake:", self.id);
63        debug!("[{}]     Version: {}", self.id, handshake.version);
64        debug!("[{}]     Num Methods: {}", self.id, handshake.num_methods);
65        debug!("[{}]     Methods: {}", self.id, methods_string);
66
67        // Get request from client.
68
69        let request = Connection::perform_request_negotiation(&mut self.client_socket, buffer).await?;
70        let destination = match &request.destination {
71            Destination::Ipv4Addr(ipv4) => ipv4.to_string(),
72            Destination::Ipv6Addr(ipv6) => ipv6.to_string(),
73            Destination::Domain(s) => s.to_owned(),
74        };
75
76        debug!("[{}]   Request:", self.id);
77        debug!("[{}]     Version: {}", self.id, request.version);
78        debug!("[{}]     Command: {}", self.id, COMMANDS[&request.command]);
79        debug!("[{}]     Address Type: {}", self.id, ADDRESS_TYPES[&request.address_type]);
80        debug!("[{}]     Destination: {}", self.id, destination);
81        debug!("[{}]     Port: {}", self.id, request.port);
82
83        // Perform requested action.
84
85        let endpoint_socket = match request.command {
86            0x01 /* CONNECT */ => Connection::establish_connect_request(&mut self.client_socket, &self.endpoint_interface, &request, buffer).await?,
87            0x02 /* BIND */ => return "BIND requests not supported.".into_error(),
88            0x03 /* UDP ASSOCIATE */ => return "UDP ASSOCIATE requests not supported.".into_error(),
89            _ => return "Unknown command type.".into_error()
90        };
91
92        // Print the data path.
93
94        let client_peer_addr = self.client_socket.peer_addr()?;
95        let client_local_addr = self.client_socket.local_addr()?;
96        let endpoint_local_addr = endpoint_socket.local_addr()?;
97        let endpoint_peer_addr = endpoint_socket.peer_addr()?;
98
99        info!("[{}] {} => {} => {} => {}", self.id, client_peer_addr, client_local_addr, endpoint_local_addr, endpoint_peer_addr);
100
101        // Run the pump (all errors in pumps are emitted as log messages and should not disrupt the execution flow).
102
103        //CustomPump::from(&self.id, self.client_socket, endpoint_socket, buffer, self.read_timeout).start().await;
104        match CopyPump::from(self.client_socket, endpoint_socket, self.read_timeout).start().await {
105            Ok(_) => {}
106            Err(e) => {
107                warn!("[{}] The pump ended with an error.  {}", self.id, e);
108            }
109        }
110
111        debug!("[{}] End.", self.id);
112
113        Ok(())
114    }
115
116    async fn perform_handshake(client_socket: &mut TcpStream, buffer: &mut [u8]) -> Res<Handshake> {
117        let read = client_socket.read(buffer).await?;
118
119        if read == 0 {
120            return "Read 0 bytes during handshake.".into_error();
121        }
122
123        let handshake = Handshake::from_data(buffer)?;
124
125        if handshake.version != 5 {
126            return "Bad SOCKS version.".into_error();
127        }
128
129        // Reuse the buffer since we are borrowing it anyway.
130
131        buffer[0] = 0x05; // VERSION.
132        buffer[1] = 0x00; // NO AUTH.
133
134        client_socket.write_all(&buffer[..2]).await?;
135        client_socket.flush().await?;
136
137        Ok(handshake)
138    }
139
140    async fn perform_request_negotiation(client_socket: &mut TcpStream, buffer: &mut [u8]) -> Res<Request> {
141        let read = client_socket.read(buffer).await?;
142
143        if read == 0 {
144            return "Read 0 bytes during connection negotiation.".into_error();
145        }
146
147        let request = Request::from_data(buffer)?;
148
149        Ok(request)
150    }
151
152    async fn establish_connect_request(client_socket: &mut TcpStream, endpoint_interface: &str, request: &Request, buffer: &mut [u8]) -> Res<TcpStream> {
153        let mut reply = 0u8;
154
155        // Get requested local interface.
156        let local_addr = SocketAddr::from_str(&format!("{}:{}", endpoint_interface, 0))?;
157
158        // Get endpoint address.
159        let string_to_connect = format!("{}:{}", request.destination, request.port);
160        let endpoint_addr_iterator = tokio::net::lookup_host(&string_to_connect).await;
161
162        // Compute valid endpoint addresses, and connect to endpoint.
163        let endpoint_socket = match endpoint_addr_iterator {
164            Ok(endpoint_addresses) => {
165                // Try to create a local socket that can connect to the endpoint.
166                match Helpers::create_local_socket(local_addr, endpoint_addresses) {
167                    Some(ep) => {
168                        let socket = ep.socket;
169                        let endpoint_addr = ep.address;
170
171                        // Connect to endpoint.
172                        match socket.connect(endpoint_addr).await {
173                            Ok(s) => Some(s),
174                            Err(e) => {
175                                warn!("Could not connect to `{}` (`{}`).", string_to_connect, endpoint_addr);
176
177                                reply = match e.raw_os_error() {
178                                    Some(i) => Helpers::get_socks_reply(i),
179                                    _ => 5u8, // Connection refused?.
180                                };
181
182                                None
183                            }
184                        }
185                    }
186                    None => {
187                        warn!(
188                            "Could not create local socket (`{}`) to `{}`. This likely means that we could not find a suitable address type for the endpoint that matches the endpoint interface type (i.e., IPv6/IPv4 mismatch).",
189                            local_addr, string_to_connect
190                        );
191
192                        reply = 5u8; // Connection refused?.
193
194                        None
195                    }
196                }
197            }
198            Err(e) => {
199                warn!("Could not compute an endpoint address for `{}`.", string_to_connect);
200
201                reply = match e.raw_os_error() {
202                    Some(i) => Helpers::get_socks_reply(i),
203                    _ => 8u8, // Address type not supported.
204                };
205
206                None
207            }
208        };
209
210        // Get the local IP and port.
211        let local_ip = local_addr.ip();
212        let (port_high, port_low) = Helpers::port_to_bytes(local_addr.port());
213
214        // Prepare reply.
215
216        buffer[0] = 0x05; // VERSION.
217        buffer[1] = reply;
218        buffer[2] = 0x0; // RESERVED.
219
220        let reply_length = match local_ip {
221            IpAddr::V4(ipv4) => {
222                let octets = ipv4.octets();
223
224                buffer[3] = 0x01; // ADDRESS TYPE (IPv4).
225                buffer[4] = octets[0];
226                buffer[5] = octets[1];
227                buffer[6] = octets[2];
228                buffer[7] = octets[3];
229                Helpers::write_octets(&mut buffer[4..8], &octets);
230
231                buffer[8] = port_high;
232                buffer[9] = port_low;
233
234                10
235            }
236            IpAddr::V6(ipv6) => {
237                let octets = ipv6.octets();
238
239                buffer[3] = 0x04; // ADDRESS TYPE (IPv6).
240                Helpers::write_octets(&mut buffer[4..20], &octets);
241
242                buffer[20] = port_high;
243                buffer[21] = port_low;
244
245                22
246            }
247        };
248
249        // Send a response to the client, even if there is a failure.
250
251        client_socket.write_all(&buffer[0..reply_length]).await?;
252        client_socket.flush().await?;
253
254        // In a failure scenario, ensure the SOCKS process does not continue.
255
256        if reply != 0 {
257            return format!("The connection to `{}` failed gracefully with `{}`.", string_to_connect, ERRORS[&reply]).into_error();
258        }
259
260        // This should only be `None` if there is an error, which aborts above.
261        Ok(endpoint_socket.unwrap())
262    }
263}
264
265static COMMANDS: Map<u8, &'static str> = phf_map! {
266    1u8 => "Connect",
267    2u8 => "Bind",
268    3u8 => "UDP Associate",
269};
270
271static ADDRESS_TYPES: Map<u8, &'static str> = phf_map! {
272    1u8 => "Ipv4",
273    3u8 => "Domain",
274    4u8 => "Ipv6",
275};
276
277static ERRORS: Map<u8, &'static str> = phf_map! {
278    0u8 => "Succeeded",
279    1u8 => "General SOCKS Server Failure",
280    3u8 => "Network Unreachable",
281    4u8 => "Host Unreachable",
282    5u8 => "Connection Refused",
283    6u8 => "TTL Expired",
284    8u8 => "Address type not supported"
285};