ratrodlib/
connect.rs

1//! This module contains the code for the client-side of the tunnel.
2//!
3//! It includes the state machine, operations, and configuration.
4
5use std::{collections::HashMap, marker::PhantomData, net::SocketAddr, sync::Arc};
6
7use anyhow::Context;
8use futures::join;
9use secrecy::SecretString;
10use tokio::{
11    net::{TcpListener, TcpStream, UdpSocket}, select, sync::{
12        mpsc::{UnboundedReceiver, UnboundedSender}, Mutex
13    }, task::JoinHandle
14};
15use tracing::{Instrument, error, info, info_span};
16
17use crate::{
18    base::{ClientHandshakeData, ClientKeyExchangeData, Constant, Err, Res, TunnelDefinition, Void},
19    buffed_stream::BuffedTcpStream,
20    protocol::{BincodeReceive, BincodeSend, Challenge, ClientAuthentication, ClientPreamble, ExchangePublicKey, ProtocolMessage},
21    security::{resolve_keypath, resolve_known_hosts, resolve_private_key, resolve_public_key},
22    utils::{generate_challenge, generate_ephemeral_key_pair, generate_shared_secret, handle_pump, parse_tunnel_definitions, random_string, sign_challenge, validate_signed_challenge},
23};
24
25// State machine.
26
27/// The client is in the configuration state.
28pub struct ConfigState;
29/// The client is in the ready state.
30pub struct ReadyState;
31
32/// The client instance.
33///
34/// This is the main entry point for the client. It is used to connect, configure, and start the client.
35pub struct Instance<S = ConfigState> {
36    tunnel_definitions: Vec<TunnelDefinition>,
37    config: Config,
38    _phantom: PhantomData<S>,
39}
40
41impl Instance<ConfigState> {
42    /// Prepares the client instance.
43    pub fn prepare<A, B, C>(key_path: A, connect_address: B, tunnel_definitions: &[C], accept_all_hosts: bool, should_encrypt: bool) -> Res<Instance<ReadyState>>
44    where
45        A: Into<Option<String>>,
46        B: Into<String>,
47        C: AsRef<str>,
48    {
49        let tunnel_definitions = parse_tunnel_definitions(tunnel_definitions)?;
50
51        let key_path = resolve_keypath(key_path)?;
52        let private_key = resolve_private_key(&key_path)?;
53        let public_key = resolve_public_key(&key_path)?;
54        let known_hosts = resolve_known_hosts(&key_path);
55
56        let config = Config::new(public_key, private_key, known_hosts, connect_address.into(), accept_all_hosts, should_encrypt)?;
57
58        Ok(Instance {
59            tunnel_definitions,
60            config,
61            _phantom: PhantomData,
62        })
63    }
64}
65
66impl Instance<ReadyState> {
67    /// Starts the client instance.
68    ///
69    /// This is the main entry point for the client. It is used to connect, configure, and start the client
70    pub async fn start(self) -> Void {
71        // Finally, start the server(s) (one per tunnel definition).
72
73        let tasks = self
74            .tunnel_definitions
75            .into_iter()
76            .map(|tunnel_definition| async {
77                // Schedule a test connection.
78                tokio::spawn(test_server_connection(tunnel_definition.clone(), self.config.clone()));
79
80                // Start the servers.
81                let tcp = tokio::spawn(run_tcp_server(tunnel_definition.clone(), self.config.clone()));
82                let udp = tokio::spawn(run_udp_server(tunnel_definition, self.config.clone()));
83
84                let (tcp_result, udp_result) = join!(tcp, udp);
85
86                tcp_result?;
87                udp_result?;
88
89                Void::Ok(())
90            })
91            .collect::<Vec<_>>();
92
93        // Basically, only crash if _all_ of the servers fail to start.  Otherwise, the user can use the error logs to see that some of the
94        // servers failed to start.  As a result, we _do not_ log an error, since the user can see the errors in the logs.
95        futures::future::join_all(tasks).await;
96
97        Ok(())
98    }
99}
100
101// Operations.
102
103/// Sends the preamble to the server.
104///
105/// This is the first message sent to the server. It contains the remote address and the peer public key
106/// for the future key exchange.
107async fn send_preamble<T, R>(stream: &mut T, config: &Config, remote_address: R, exchange_public_key: ExchangePublicKey, is_udp: bool) -> Res<Challenge>
108where
109    T: BincodeSend,
110    R: Into<String>,
111{
112    let challenge = generate_challenge();
113
114    let preamble = ClientPreamble {
115        exchange_public_key,
116        remote: remote_address.into(),
117        challenge,
118        should_encrypt: config.should_encrypt,
119        is_udp,
120    };
121
122    stream.push(ProtocolMessage::ClientPreamble(preamble)).await?;
123
124    info!("✅ Sent preamble to server ...");
125
126    Ok(challenge)
127}
128
129/// Handles the challenge from the server.
130///
131/// This is the second message sent to the server. It receives the challenge,
132/// signs it, and sends the signature back to the server.
133async fn handle_challenge<T>(stream: &mut T, config: &Config, client_challenge: &Challenge) -> Res<ClientHandshakeData>
134where
135    T: BincodeSend + BincodeReceive,
136{
137    // Wait for the server's preamble.
138
139    let ProtocolMessage::ServerPreamble(server_preamble) = stream.pull().await? else {
140        return Err(Err::msg("Handshake failed: improper message type (expected handshake challenge)"));
141    };
142
143    // Validate the server's signature.
144
145    validate_signed_challenge(client_challenge, &server_preamble.signature.into(), &server_preamble.identity_public_key)?;
146
147    info!("✅ Server's signature validated with public key `{}` ...", server_preamble.identity_public_key);
148
149    // Ensure that the server is in the `known_hosts` file.
150
151    if !config.accept_all_hosts && !config.known_hosts.contains(&server_preamble.identity_public_key) {
152        // Client doesn't really need to tell the server about failures, so will error and break the pipe.
153        return Err(Err::msg(format!("Server's public key `{}` is not in the known hosts file", server_preamble.identity_public_key)));
154    }
155
156    info!("🚧 Signing server challenge ...");
157
158    let client_signature = sign_challenge(&server_preamble.challenge, &config.private_key)?;
159    let client_authentication = ClientAuthentication {
160        identity_public_key: config.public_key.clone(),
161        signature: client_signature.into(),
162    };
163    stream.push(ProtocolMessage::ClientAuthentication(client_authentication)).await?;
164
165    info!("⏳ Awaiting challenge validation ...");
166
167    let ProtocolMessage::HandshakeCompletion = stream.pull().await?.fail_if_error()? else {
168        return Err(Err::msg("Handshake failed: improper message type (expected handshake completion)"));
169    };
170
171    Ok(ClientHandshakeData {
172        server_challenge: server_preamble.challenge,
173        server_exchange_public_key: server_preamble.exchange_public_key,
174    })
175}
176
177/// Handles the handshake with the server.
178async fn handle_handshake<T, R>(stream: &mut T, config: &Config, remote_address: R, is_udp: bool) -> Res<ClientKeyExchangeData>
179where
180    T: BincodeSend + BincodeReceive,
181    R: Into<String>,
182{
183    // If we want to request encryption, we need to generate an ephemeral key pair, and send the public key to the server.
184    let exchange_key_pair = generate_ephemeral_key_pair()?;
185    let exchange_public_key = exchange_key_pair.public_key.as_ref().try_into().map_err(|_| Err::msg("Could not convert peer public key to array"))?;
186
187    let client_challenge = send_preamble(stream, config, remote_address, exchange_public_key, is_udp).await?;
188    let handshake_data = handle_challenge(stream, config, &client_challenge).await?;
189
190    // Compute the ephemeral data.
191
192    let ephemeral_data = ClientKeyExchangeData {
193        server_exchange_public_key: handshake_data.server_exchange_public_key,
194        server_challenge: handshake_data.server_challenge,
195        local_exchange_private_key: exchange_key_pair.private_key,
196        local_challenge: client_challenge,
197    };
198
199    info!("✅ Challenge accepted!");
200
201    Ok(ephemeral_data)
202}
203
204/// Connects to the requested remote.
205async fn server_connect(connect_address: &str) -> Res<TcpStream> {
206    let stream = TcpStream::connect(connect_address).await?;
207    info!("✅ Connected to server `{}` ...", connect_address);
208
209    Ok(stream)
210}
211
212/// Establishes the e2e connection with server.
213async fn connect(config: &Config, remote_address: &str, is_udp: bool) -> Res<BuffedTcpStream> {
214    // Connect to the server.
215    let mut server = BuffedTcpStream::from(server_connect(&config.connect_address).await?);
216
217    // Handle the handshake.
218    let handshake_data = handle_handshake(&mut server, config, remote_address, is_udp).await.context("Error handling handshake")?;
219
220    info!("✅ Handshake successful: connection established!");
221
222    // Generate and apply the shared secret, if needed.
223    if config.should_encrypt {
224        let salt_bytes = [handshake_data.server_challenge, handshake_data.local_challenge].concat();
225
226        let shared_secret = generate_shared_secret(handshake_data.local_exchange_private_key, &handshake_data.server_exchange_public_key, &salt_bytes)?;
227
228        server = server.with_encryption(shared_secret);
229        info!("🔒 Encryption applied ...");
230    }
231
232    Ok(server)
233}
234
235// TCP connection.
236
237/// Runs the TCP server.
238///
239/// This is the main entry point for the server. It is used to accept connections and handle them.
240async fn run_tcp_server(tunnel_definition: TunnelDefinition, config: Config) {
241    let result: Void = async move {
242        let listener = TcpListener::bind(&tunnel_definition.bind_address).await?;
243
244        info!(
245            "📻 [TCP] Listening on `{}`, and routing through `{}` to `{}` ...",
246            tunnel_definition.bind_address, config.connect_address, tunnel_definition.remote_address
247        );
248
249        loop {
250            let (socket, _) = listener.accept().await?;
251
252            tokio::spawn(handle_tcp(socket, tunnel_definition.remote_address.clone(), config.clone()));
253        }
254    }
255    .await;
256
257    if let Err(err) = result {
258        error!("❌ Error starting TCP server, or accepting a connection (shutting down listener for this bind address): {}", err);
259    }
260}
261
262/// Handles the TCP connection.
263///
264/// This is the main entry point for the connection. It is used to handle the handshake and pump data between the client and server.
265async fn handle_tcp(mut local: TcpStream, remote_address: String, config: Config) {
266    let id = random_string(6);
267    let span = info_span!("tcp", id = id);
268
269    let result: Void = async move {
270        // Connect.
271
272        let mut server = connect(&config, &remote_address, false).await?;
273
274        // Handle the TCP pump.
275
276        info!("⛽ Pumping data between client and remote ...");
277
278        handle_pump(&mut local, &mut server).await.context("Error handling pump")?;
279
280        info!("✅ Connection closed.");
281
282        Ok(())
283    }
284    .instrument(span.clone())
285    .await;
286
287    // Enter the span, so that the error is logged with the span's metadata, if needed.
288    let _guard = span.enter();
289
290    if let Err(err) = result {
291        let chain = err.chain().collect::<Vec<_>>();
292        let full_chain = chain.iter().map(|e| format!("`{}`", e)).collect::<Vec<_>>().join(" => ");
293
294        error!("❌ Error handling the connection: {}.", full_chain);
295    }
296}
297
298// UDP connection.
299
300/// Runs the UDP server.
301///
302/// This is the main entry point for the server. It is used to accept connections and handle them.
303async fn run_udp_server(tunnel_definition: TunnelDefinition, config: Config) {
304    let result: Void = async move {
305        let socket = Arc::new(UdpSocket::bind(&tunnel_definition.bind_address).await?);
306
307        info!(
308            "📻 [UDP] Listening on `{}`, and routing through `{}` to `{}` ...",
309            tunnel_definition.bind_address, config.connect_address, tunnel_definition.remote_address
310        );
311
312        let clients = Arc::new(Mutex::new(HashMap::<SocketAddr, UnboundedSender<Vec<u8>>>::new()));
313
314        loop {
315            // Receive a datagram.
316
317            // TODO: _technically_, this could be up to 65,507 bytes, but that would be a bit silly, so 8 KB should be fine (since most systems use the MTU of 1500).
318            let mut buf = vec![0; Constant::BUFFER_SIZE];
319            let (read, addr) = socket.recv_from(&mut buf).await?;
320            buf.truncate(read);
321
322            // Handle the packet.
323
324            if let Some(data_sender) = clients.lock().await.get_mut(&addr) {
325                // In the case where we already have a connection, we should push the message into the channel.
326                data_sender.send(buf)?;
327            } else {
328                // In this case, we need to create a new connection.
329                let socket_clone = socket.clone();
330                let config_clone = config.clone();
331
332                // Create a new channel for the client.
333                let (data_sender, data_receiver) = tokio::sync::mpsc::unbounded_channel();
334                data_sender.send(buf)?;
335                clients.lock().await.insert(addr, data_sender);
336
337                // Spawn a new task to handle the connection.
338                let clients_clone = clients.clone();
339                let remote_address = tunnel_definition.remote_address.clone();
340                tokio::spawn(async move {
341                    // Handle the connection.
342                    handle_udp(addr, socket_clone, data_receiver, remote_address, config_clone).await;
343
344                    // Remove the client from the list of clients.
345                    clients_clone.lock().await.remove(&addr);
346                });
347            }
348        }
349    }
350    .await;
351
352    if let Err(err) = result {
353        error!("❌ Error starting UDP server, or accepting a connection (shutting down listener for this bind address): {}", err);
354    }
355}
356
357/// Handles a new UDP connection.
358async fn handle_udp(address: SocketAddr, client_socket: Arc<UdpSocket>, mut data_receiver: UnboundedReceiver<Vec<u8>>, remote_address: String, config: Config) {
359    let id = random_string(6);
360    let span = info_span!("tcp", id = id);
361
362    let result: Void = async move {
363        // Connect.
364
365        let server = connect(&config, &remote_address, true).await?;
366
367        // Handle the UDP pump.
368
369        info!("⛽ Pumping data between client and remote ...");
370
371        let client_socket_clone = client_socket.clone();
372        let (mut remote_read, mut remote_write) = server.into_split();
373
374        // TODO: Figure out logic below to get them to disconnect?
375
376        let pump_up: JoinHandle<Void> = tokio::spawn(async move {
377            while let Some(data) = data_receiver.recv().await {
378                remote_write.push(ProtocolMessage::UdpData(data.to_vec())).await?;
379            }
380
381            Ok(())
382        });
383
384        let pump_down: JoinHandle<Void> = tokio::spawn(async move {
385            while let ProtocolMessage::UdpData(data) = remote_read.pull().await? {
386                client_socket_clone.send_to(&data, &address).await?;
387            }
388
389            Ok(())
390        });
391
392        // Wait for either side to finish (server handles the connection closing when it has not detected activity on the pump).
393        // Essentially, we are waiting for either side to finish, or to time out.  The server will handle the timeout, which will close the
394        // TCP side, which will then close the UDP side (and then the client is removed from the client list).
395
396        let result = select! {
397            r = pump_up => r?,
398            r = pump_down => r?,
399        };
400
401        // Check for errors.
402
403        result?;
404
405        Ok(())
406    }
407    .instrument(span.clone())
408    .await;
409
410    // Enter the span, so that the error is logged with the span's metadata, if needed.
411    let _guard = span.enter();
412
413    if let Err(err) = result {
414        let chain = err.chain().collect::<Vec<_>>();
415        let full_chain = chain.iter().map(|e| format!("`{}`", e)).collect::<Vec<_>>().join(" => ");
416
417        error!("❌ Error handling the connection: {}.", full_chain);
418    }
419}
420
421// Client connection tests.
422
423/// Tests the server connection by performing a handshake.
424async fn test_server_connection(tunnel_definition: TunnelDefinition, config: Config) -> Void {
425    info!("⏳ Testing server connection ...");
426
427    // Connect to the server.
428    let mut remote = BuffedTcpStream::from(server_connect(&config.connect_address).await?);
429
430    // Handle the handshake.
431    if let Err(e) = handle_handshake(&mut remote, &config, &tunnel_definition.remote_address, false).await {
432        error!("❌ Test connection failed: {}", e);
433        return Err(e);
434    }
435
436    info!("✅ Test connection successful!");
437
438    Ok(())
439}
440
441// Config.
442
443/// The configuration for the client.
444///
445/// This is used to store the private key, the connect address, and whether or not to encrypt the connection.
446#[derive(Clone)]
447pub(crate) struct Config {
448    pub(crate) public_key: String,
449    pub(crate) private_key: SecretString,
450    pub(crate) known_hosts: Vec<String>,
451    pub(crate) connect_address: String,
452    pub(crate) accept_all_hosts: bool,
453    pub(crate) should_encrypt: bool,
454}
455
456impl Config {
457    /// Creates a new configuration.
458    fn new(public_key: String, private_key: SecretString, known_hosts: Vec<String>, connect_address: String, accept_all_hosts: bool, should_encrypt: bool) -> Res<Self> {
459        Ok(Self {
460            public_key,
461            private_key,
462            connect_address,
463            known_hosts,
464            accept_all_hosts,
465            should_encrypt,
466        })
467    }
468}
469
470// Tests.
471
#[cfg(test)]
pub mod tests {
    use crate::utils::{
        generate_key_pair,
        tests::{generate_test_duplex, generate_test_fake_exchange_public_key},
    };

    use super::*;
    use pretty_assertions::assert_eq;

    /// Builds a client config from the `test/client` fixtures (known hosts enforced, no encryption).
    pub(crate) fn generate_test_client_config() -> Config {
        let key_path = "test/client";

        let public_key = resolve_public_key(key_path).unwrap();
        let private_key = resolve_private_key(key_path).unwrap();
        let known_hosts = resolve_known_hosts(key_path);

        Config {
            public_key,
            private_key,
            known_hosts,
            connect_address: "connect_address".to_string(),
            accept_all_hosts: false,
            should_encrypt: false,
        }
    }

    #[test]
    fn test_prepare() {
        let key_path = "test/client";
        let connect_address = "connect_address";
        let tunnel_definitions = ["localhost:5000:example.com:80", "127.0.0.1:6000:api.example.com:443"];
        let accept_all_hosts = false;
        let should_encrypt = false;

        let instance = Instance::prepare(key_path.to_owned(), connect_address, &tunnel_definitions, accept_all_hosts, should_encrypt).unwrap();

        // Verify config
        assert_eq!(instance.config.connect_address, connect_address);
        assert_eq!(instance.config.accept_all_hosts, accept_all_hosts);
        assert_eq!(instance.config.should_encrypt, should_encrypt);

        // Verify the public key was loaded correctly
        let expected_public_key = resolve_public_key(key_path).unwrap();
        assert_eq!(instance.config.public_key, expected_public_key);

        // Verify known hosts were loaded correctly
        let expected_known_hosts = resolve_known_hosts(key_path);
        assert_eq!(instance.config.known_hosts, expected_known_hosts);

        // Verify tunnel definitions
        assert_eq!(instance.tunnel_definitions.len(), 2);
        assert_eq!(instance.tunnel_definitions[0].bind_address, "localhost:5000");
        assert_eq!(instance.tunnel_definitions[0].remote_address, "example.com:80");
        assert_eq!(instance.tunnel_definitions[1].bind_address, "127.0.0.1:6000");
        assert_eq!(instance.tunnel_definitions[1].remote_address, "api.example.com:443");
    }

    #[tokio::test]
    async fn test_send_preamble() {
        let (mut client, mut server) = generate_test_duplex();
        let config = generate_test_client_config();
        let remote_address = "remote_address:3000";
        let exchange_public_key = generate_test_fake_exchange_public_key();

        let client_challenge = send_preamble(&mut client, &config, remote_address, exchange_public_key, false).await.unwrap();

        let received = server.pull().await.unwrap();

        match received {
            ProtocolMessage::ClientPreamble(preamble) => {
                assert_eq!(preamble.remote, remote_address);
                assert_eq!(preamble.exchange_public_key, exchange_public_key);
                assert_eq!(preamble.challenge, client_challenge);
                assert_eq!(preamble.should_encrypt, config.should_encrypt);
                assert!(!preamble.is_udp);
            }
            _ => panic!("Expected ClientPreamble, got different message type"),
        }
    }

    #[tokio::test]
    async fn test_handle_challenge_bad_key() {
        let (mut client, mut server) = generate_test_duplex();
        let config = generate_test_client_config();
        let client_challenge = generate_challenge();
        let bad_key = generate_key_pair().unwrap().private_key;

        tokio::spawn(async move {
            // Create and send ServerPreamble with unknown key
            let preamble = crate::protocol::ServerPreamble {
                identity_public_key: bad_key,
                signature: [0u8; 64].into(), // Mock signature
                challenge: generate_challenge(),
                exchange_public_key: generate_test_fake_exchange_public_key(),
            };

            server.push(ProtocolMessage::ServerPreamble(preamble)).await.unwrap();
        });

        let result = handle_challenge(&mut client, &config, &client_challenge).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "Invalid signature");
    }

    #[tokio::test]
    async fn test_handle_challenge_wrong_message_type() {
        let (mut client, mut server) = generate_test_duplex();
        let config = generate_test_client_config();
        let client_challenge = generate_challenge();

        tokio::spawn(async move {
            // Send wrong message type
            server.push(ProtocolMessage::HandshakeCompletion).await.unwrap();
        });

        let result = handle_challenge(&mut client, &config, &client_challenge).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("improper message type"));
    }
}