ratrodlib/
connect.rs

1//! This module contains the code for the client-side of the tunnel.
2//!
3//! It includes the state machine, operations, and configuration.
4
5use std::{collections::HashMap, marker::PhantomData, net::SocketAddr, sync::Arc};
6
7use anyhow::{Context, anyhow};
8use bytes::{Bytes, BytesMut};
9use futures::join;
10use secrecy::SecretString;
11use tokio::{
12    net::{TcpListener, TcpStream, UdpSocket},
13    select,
14    sync::{
15        Mutex,
16        mpsc::{UnboundedReceiver, UnboundedSender},
17    },
18    task::JoinHandle,
19};
20use tracing::{Instrument, error, info, info_span};
21
22use crate::{
23    base::{ClientHandshakeData, ClientKeyExchangeData, Constant, Res, TunnelDefinition, Void},
24    buffed_stream::{BincodeSplit, BuffedTcpStream},
25    protocol::{BincodeReceive, BincodeSend, Challenge, ClientAuthentication, ClientPreamble, ProtocolMessage},
26    security::{resolve_keypath, resolve_known_hosts, resolve_private_key, resolve_public_key},
27    utils::{generate_challenge, generate_ephemeral_key_pair, generate_shared_secret, handle_tcp_pump, parse_tunnel_definitions, random_string, sign_challenge, validate_signed_challenge},
28};
29
30// State machine.
31
/// Type-state marker: the client has not been prepared yet.
///
/// `Instance::prepare` is defined only for `Instance<ConfigState>`, and it produces an `Instance<ReadyState>`.
pub struct ConfigState;
/// Type-state marker: the client has been prepared and can be started.
///
/// `Instance::start` is defined only for `Instance<ReadyState>`.
pub struct ReadyState;
36
37/// The client instance.
38///
39/// This is the main entry point for the client. It is used to connect, configure, and start the client.
40pub struct Instance<S = ConfigState> {
41    tunnel_definitions: Vec<TunnelDefinition>,
42    config: Config,
43    _phantom: PhantomData<S>,
44}
45
46impl Instance<ConfigState> {
47    /// Prepares the client instance.
48    pub fn prepare<A, B, C>(key_path: A, connect_address: B, tunnel_definitions: &[C], accept_all_hosts: bool, should_encrypt: bool) -> Res<Instance<ReadyState>>
49    where
50        A: Into<Option<String>>,
51        B: Into<String>,
52        C: AsRef<str>,
53    {
54        let tunnel_definitions = parse_tunnel_definitions(tunnel_definitions)?;
55
56        let key_path = resolve_keypath(key_path)?;
57        let private_key = resolve_private_key(&key_path)?;
58        let public_key = resolve_public_key(&key_path)?;
59        let known_hosts = resolve_known_hosts(&key_path);
60
61        let config = Config::new(public_key, private_key, known_hosts, connect_address.into(), accept_all_hosts, should_encrypt)?;
62
63        Ok(Instance {
64            tunnel_definitions,
65            config,
66            _phantom: PhantomData,
67        })
68    }
69}
70
71impl Instance<ReadyState> {
72    /// Starts the client instance.
73    ///
74    /// This is the main entry point for the client. It is used to connect, configure, and start the client
75    pub async fn start(self) -> Void {
76        // Finally, start the server(s) (one per tunnel definition).
77
78        let tasks = self
79            .tunnel_definitions
80            .into_iter()
81            .map(|tunnel_definition| async {
82                // Schedule a test connection.
83                tokio::spawn(test_server_connection(tunnel_definition.clone(), self.config.clone()));
84
85                // Start the servers.
86                let tcp = tokio::spawn(run_tcp_server(tunnel_definition.clone(), self.config.clone()));
87                let udp = tokio::spawn(run_udp_server(tunnel_definition, self.config.clone()));
88
89                let (tcp_result, udp_result) = join!(tcp, udp);
90
91                tcp_result?;
92                udp_result?;
93
94                Void::Ok(())
95            })
96            .collect::<Vec<_>>();
97
98        // Only exit if _all_ of the servers fail to start.  Otherwise, the user can use the error logs to see that some of the
99        // servers failed to start.  As a result, we _do not_ return an error, since the user can see the errors in the logs.
100        futures::future::join_all(tasks).await;
101
102        Ok(())
103    }
104}
105
106// Operations.
107
108/// Sends the preamble to the server.
109///
110/// This is the first message sent to the server. It contains the remote address and the peer public key
111/// for the future key exchange.
112async fn send_preamble<T, R>(stream: &mut T, config: &Config, remote_address: R, exchange_public_key: &[u8], is_udp: bool) -> Res<Challenge>
113where
114    T: BincodeSend,
115    R: AsRef<str>,
116{
117    if exchange_public_key.len() != Constant::EXCHANGE_PUBLIC_KEY_SIZE {
118        return Err(anyhow!(
119            "Invalid exchange public key size: expected {} bytes, got {} bytes",
120            Constant::EXCHANGE_PUBLIC_KEY_SIZE,
121            exchange_public_key.len()
122        ));
123    }
124
125    let challenge = generate_challenge();
126
127    let preamble = ClientPreamble {
128        exchange_public_key,
129        remote: remote_address.as_ref(),
130        challenge: &challenge,
131        should_encrypt: config.should_encrypt,
132        is_udp,
133    };
134
135    stream.push(ProtocolMessage::ClientPreamble(preamble)).await?;
136
137    info!("✅ Sent preamble to server ...");
138
139    Ok(challenge)
140}
141
142/// Handles the challenge from the server.
143///
144/// This is the second message sent to the server. It receives the challenge,
145/// signs it, and sends the signature back to the server.
146async fn handle_challenge<T>(stream: &mut T, config: &Config, client_challenge: &Challenge) -> Res<ClientHandshakeData>
147where
148    T: BincodeSend + BincodeReceive,
149{
150    // Wait for the server's preamble.
151
152    let guard = stream.pull().await?;
153    let ProtocolMessage::ServerPreamble(server_preamble) = guard.message() else {
154        return Err(anyhow!("Handshake failed: improper message type (expected handshake challenge)"));
155    };
156
157    let result = ClientHandshakeData {
158        server_challenge: server_preamble.challenge.try_into()?,
159        server_exchange_public_key: server_preamble.exchange_public_key.try_into()?,
160    };
161
162    // Validate the server's signature.
163
164    validate_signed_challenge(client_challenge, server_preamble.signature, server_preamble.identity_public_key)?;
165
166    info!("✅ Server's signature validated with public key `{}` ...", server_preamble.identity_public_key);
167
168    // Ensure that the server is in the `known_hosts` file.
169
170    if !config.accept_all_hosts && !config.known_hosts.iter().any(|k| k == server_preamble.identity_public_key) {
171        // Client doesn't really need to tell the server about failures, so will error and break the pipe.
172        return Err(anyhow!("Server's public key `{}` is not in the known hosts file", server_preamble.identity_public_key));
173    }
174
175    info!("🚧 Signing server challenge ...");
176
177    let client_signature = sign_challenge(server_preamble.challenge, &config.private_key)?;
178    let client_authentication = ClientAuthentication {
179        identity_public_key: &config.public_key,
180        signature: &client_signature,
181    };
182    stream.push(ProtocolMessage::ClientAuthentication(client_authentication)).await?;
183
184    info!("⏳ Awaiting challenge validation ...");
185
186    let guard = stream.pull().await?;
187    let ProtocolMessage::HandshakeCompletion = guard.message().fail_if_error()? else {
188        return Err(anyhow!("Handshake failed: improper message type (expected handshake completion)"));
189    };
190
191    Ok(result)
192}
193
194/// Handles the handshake with the server.
195async fn handle_handshake<T, R>(stream: &mut T, config: &Config, remote_address: R, is_udp: bool) -> Res<ClientKeyExchangeData>
196where
197    T: BincodeSend + BincodeReceive,
198    R: AsRef<str>,
199{
200    // If we want to request encryption, we need to generate an ephemeral key pair, and send the public key to the server.
201    let exchange_key_pair = generate_ephemeral_key_pair()?;
202    let exchange_public_key = exchange_key_pair.public_key.as_ref();
203
204    let client_challenge = send_preamble(stream, config, remote_address, exchange_public_key, is_udp).await?;
205    let handshake_data = handle_challenge(stream, config, &client_challenge).await?;
206
207    // Compute the ephemeral data.
208
209    let ephemeral_data = ClientKeyExchangeData {
210        server_exchange_public_key: handshake_data.server_exchange_public_key,
211        server_challenge: handshake_data.server_challenge,
212        local_exchange_private_key: exchange_key_pair.private_key,
213        local_challenge: client_challenge,
214    };
215
216    info!("✅ Challenge accepted!");
217
218    Ok(ephemeral_data)
219}
220
221/// Connects to the requested remote.
222async fn server_connect(connect_address: &str) -> Res<TcpStream> {
223    let stream = TcpStream::connect(connect_address).await?;
224    info!("✅ Connected to server `{}` ...", connect_address);
225
226    Ok(stream)
227}
228
229/// Establishes the e2e connection with server.
230async fn connect(config: &Config, remote_address: &str, is_udp: bool) -> Res<BuffedTcpStream> {
231    // Connect to the server.
232    let server = server_connect(&config.connect_address).await?;
233    server.set_nodelay(true)?;
234
235    let mut server = BuffedTcpStream::from(server);
236
237    // Handle the handshake.
238    let handshake_data = handle_handshake(&mut server, config, remote_address, is_udp).await.context("Error handling handshake")?;
239
240    info!("✅ Handshake successful: connection established!");
241
242    // Generate and apply the shared secret, if needed.
243    if config.should_encrypt {
244        let salt_bytes = [handshake_data.server_challenge, handshake_data.local_challenge].concat();
245
246        let shared_secret = generate_shared_secret(handshake_data.local_exchange_private_key, &handshake_data.server_exchange_public_key, &salt_bytes)?;
247
248        server = server.with_encryption(shared_secret);
249        info!("🔒 Encryption applied ...");
250    }
251
252    Ok(server)
253}
254
255// TCP connection.
256
257/// Runs the TCP server.
258///
259/// This is the main entry point for the server. It is used to accept connections and handle them.
260async fn run_tcp_server(tunnel_definition: TunnelDefinition, config: Config) {
261    let result: Void = async move {
262        let listener = TcpListener::bind(&tunnel_definition.bind_address).await?;
263
264        info!(
265            "📻 [TCP] Listening on `{}`, and routing through `{}` to `{}` ...",
266            tunnel_definition.bind_address, config.connect_address, tunnel_definition.remote_address
267        );
268
269        loop {
270            let (socket, _) = listener.accept().await?;
271
272            tokio::spawn(handle_tcp(socket, tunnel_definition.remote_address.clone(), config.clone()));
273        }
274    }
275    .await;
276
277    if let Err(err) = result {
278        error!("❌ Error starting TCP server, or accepting a connection (shutting down listener for this bind address): {}", err);
279    }
280}
281
282/// Handles the TCP connection.
283///
284/// This is the main entry point for the connection. It is used to handle the handshake and pump data between the client and server.
285async fn handle_tcp(local: TcpStream, remote_address: String, config: Config) {
286    let id = random_string(6);
287    let span = info_span!("tcp", id = id);
288
289    let result: Void = async move {
290        // Connect.
291
292        let server = connect(&config, &remote_address, false).await?;
293
294        // Handle the TCP pump.
295
296        info!("⛽ Pumping data between client and remote ...");
297
298        local.set_nodelay(true)?;
299
300        handle_tcp_pump(local, server).await.context("Error handling pump")?;
301
302        info!("✅ Connection closed.");
303
304        Ok(())
305    }
306    .instrument(span.clone())
307    .await;
308
309    // Enter the span, so that the error is logged with the span's metadata, if needed.
310    let _guard = span.enter();
311
312    if let Err(err) = result {
313        let chain = err.chain().collect::<Vec<_>>();
314        let full_chain = chain.iter().map(|e| format!("`{}`", e)).collect::<Vec<_>>().join(" => ");
315
316        error!("❌ Error handling the connection: {}.", full_chain);
317    }
318}
319
320// UDP connection.
321
322/// Runs the UDP server.
323///
324/// This is the main entry point for the server. It is used to accept connections and handle them.
325async fn run_udp_server(tunnel_definition: TunnelDefinition, config: Config) {
326    let result: Void = async move {
327        let socket = Arc::new(UdpSocket::bind(&tunnel_definition.bind_address).await?);
328
329        info!(
330            "📻 [UDP] Listening on `{}`, and routing through `{}` to `{}` ...",
331            tunnel_definition.bind_address, config.connect_address, tunnel_definition.remote_address
332        );
333
334        let clients = Arc::new(Mutex::new(HashMap::<SocketAddr, UnboundedSender<Bytes>>::new()));
335        let mut buffer = BytesMut::with_capacity(2 * Constant::BUFFER_SIZE);
336
337        loop {
338            // Clear and reclaim the bufer.
339            buffer.clear();
340            buffer.reserve(Constant::BUFFER_SIZE);
341
342            // Receive a datagram.
343            unsafe { buffer.set_len(Constant::BUFFER_SIZE) };
344            let (read, addr) = socket.recv_from(&mut buffer).await?;
345            unsafe { buffer.set_len(read) };
346
347            let data = buffer.split().freeze();
348
349            // Handle the packet.
350
351            if let Some(data_sender) = clients.lock().await.get_mut(&addr) {
352                // In the case where we already have a connection, we should push the message into the channel.
353                data_sender.send(data)?;
354            } else {
355                // In this case, we need to create a new connection.
356                let socket_clone = socket.clone();
357                let config_clone = config.clone();
358
359                // Create a new channel for the client.
360                let (data_sender, data_receiver) = tokio::sync::mpsc::unbounded_channel();
361                data_sender.send(data)?;
362                clients.lock().await.insert(addr, data_sender);
363
364                // Spawn a new task to handle the connection.
365                let clients_clone = clients.clone();
366                let remote_address = tunnel_definition.remote_address.clone();
367                tokio::spawn(async move {
368                    // Handle the connection.
369                    handle_udp(addr, socket_clone, data_receiver, remote_address, config_clone).await;
370
371                    // Remove the client from the list of clients.
372                    clients_clone.lock().await.remove(&addr);
373                });
374            }
375        }
376    }
377    .await;
378
379    if let Err(err) = result {
380        error!("❌ Error starting UDP server, or accepting a connection (shutting down listener for this bind address): {}", err);
381    }
382}
383
384/// Handles a new UDP connection.
385async fn handle_udp(address: SocketAddr, client_socket: Arc<UdpSocket>, mut data_receiver: UnboundedReceiver<Bytes>, remote_address: String, config: Config) {
386    let id = random_string(6);
387    let span = info_span!("udp", id = id);
388
389    let result: Void = async move {
390        // Connect.
391
392        let server = connect(&config, &remote_address, true).await?;
393
394        // Handle the UDP pump.
395
396        info!("⛽ Pumping data between client and remote ...");
397
398        let client_socket_clone = client_socket.clone();
399        let (mut remote_read, mut remote_write) = server.into_split();
400
401        // Connection will be closed automatically when either client side disconnects or
402        // when the server detects inactivity timeout. No explicit disconnect logic needed here.
403
404        let pump_up: JoinHandle<Void> = tokio::spawn(async move {
405            while let Some(data) = data_receiver.recv().await {
406                dbg!("client up {}", String::from_utf8_lossy(&data));
407                remote_write.push(ProtocolMessage::UdpData(&data)).await?;
408            }
409
410            Ok(())
411        });
412
413        let pump_down: JoinHandle<Void> = tokio::spawn(async move {
414            loop {
415                let guard = remote_read.pull().await?;
416                let ProtocolMessage::UdpData(data) = guard.message() else {
417                    break;
418                };
419
420                client_socket_clone.send_to(data, &address).await?;
421            }
422
423            Ok(())
424        });
425
426        // Wait for either side to finish (server handles the connection closing when it has not detected activity on the pump).
427        // Essentially, we are waiting for either side to finish, or to time out.  The server will handle the timeout, which will close the
428        // TCP side, which will then close the UDP side (and then the client is removed from the client list).
429
430        let result = select! {
431            r = pump_up => r?,
432            r = pump_down => r?,
433        };
434
435        // Check for errors.
436
437        result?;
438
439        Ok(())
440    }
441    .instrument(span.clone())
442    .await;
443
444    // Enter the span, so that the error is logged with the span's metadata, if needed.
445    let _guard = span.enter();
446
447    if let Err(err) = result {
448        let chain = err.chain().collect::<Vec<_>>();
449        let full_chain = chain.iter().map(|e| format!("`{}`", e)).collect::<Vec<_>>().join(" => ");
450
451        error!("❌ Error handling the connection: {}.", full_chain);
452    }
453}
454
455// Client connection tests.
456
457/// Tests the server connection by performing a handshake.
458async fn test_server_connection(tunnel_definition: TunnelDefinition, config: Config) -> Void {
459    info!("⏳ Testing server connection ...");
460
461    // Connect to the server.
462    let mut remote = BuffedTcpStream::from(server_connect(&config.connect_address).await?);
463
464    // Handle the handshake.
465    if let Err(e) = handle_handshake(&mut remote, &config, &tunnel_definition.remote_address, false).await {
466        error!("❌ Test connection failed: {}", e);
467        return Err(e);
468    }
469
470    info!("✅ Test connection successful!");
471
472    Ok(())
473}
474
475// Config.
476
477/// The configuration for the client.
478///
479/// This is used to store the private key, the connect address, and whether or not to encrypt the connection.
480#[derive(Clone)]
481pub(crate) struct Config {
482    pub(crate) public_key: String,
483    pub(crate) private_key: SecretString,
484    pub(crate) known_hosts: Vec<String>,
485    pub(crate) connect_address: String,
486    pub(crate) accept_all_hosts: bool,
487    pub(crate) should_encrypt: bool,
488}
489
490impl Config {
491    /// Creates a new configuration.
492    fn new(public_key: String, private_key: SecretString, known_hosts: Vec<String>, connect_address: String, accept_all_hosts: bool, should_encrypt: bool) -> Res<Self> {
493        Ok(Self {
494            public_key,
495            private_key,
496            connect_address,
497            known_hosts,
498            accept_all_hosts,
499            should_encrypt,
500        })
501    }
502}
503
504// Tests.
505
#[cfg(test)]
pub mod tests {
    use crate::utils::{
        generate_key_pair,
        tests::{generate_test_duplex, generate_test_fake_exchange_public_key},
    };

    use super::*;
    use pretty_assertions::assert_eq;

    /// Builds a `Config` from the `test/client` key fixtures.
    ///
    /// Host checking is enabled (`accept_all_hosts: false`) and encryption is off.
    pub(crate) fn generate_test_client_config() -> Config {
        let key_path = "test/client";

        let public_key = resolve_public_key(key_path).unwrap();
        let private_key = resolve_private_key(key_path).unwrap();
        let known_hosts = resolve_known_hosts(key_path);

        Config {
            public_key,
            private_key,
            known_hosts,
            connect_address: "connect_address".to_string(),
            accept_all_hosts: false,
            should_encrypt: false,
        }
    }

    #[test]
    fn test_prepare() {
        let key_path = "test/client";
        let connect_address = "connect_address";
        let tunnel_definitions = ["localhost:5000:example.com:80", "127.0.0.1:6000:api.example.com:443"];
        let accept_all_hosts = false;
        let should_encrypt = false;

        let instance = Instance::prepare(key_path.to_owned(), connect_address, &tunnel_definitions, accept_all_hosts, should_encrypt).unwrap();

        // Verify config
        assert_eq!(instance.config.connect_address, connect_address);
        assert_eq!(instance.config.should_encrypt, should_encrypt);
        assert_eq!(instance.config.accept_all_hosts, accept_all_hosts);

        // Verify the public key was loaded correctly
        let expected_public_key = resolve_public_key(key_path).unwrap();
        assert_eq!(instance.config.public_key, expected_public_key);

        // Verify known hosts were loaded correctly
        let expected_known_hosts = resolve_known_hosts(key_path);
        assert_eq!(instance.config.known_hosts, expected_known_hosts);

        // Verify tunnel definitions
        assert_eq!(instance.tunnel_definitions.len(), 2);
        assert_eq!(instance.tunnel_definitions[0].bind_address, "localhost:5000");
        assert_eq!(instance.tunnel_definitions[0].remote_address, "example.com:80");
        assert_eq!(instance.tunnel_definitions[1].bind_address, "127.0.0.1:6000");
        assert_eq!(instance.tunnel_definitions[1].remote_address, "api.example.com:443");
    }

    #[tokio::test]
    async fn test_send_preamble() {
        let (mut client, mut server) = generate_test_duplex();
        let config = generate_test_client_config();
        let remote_address = "remote_address:3000";
        let exchange_public_key = &generate_test_fake_exchange_public_key();
        // Use the non-default value so the test proves the flag is actually forwarded.
        let is_udp = true;

        let client_challenge = send_preamble(&mut client, &config, remote_address, exchange_public_key, is_udp).await.unwrap();

        let guard = server.pull().await.unwrap();
        match guard.message() {
            ProtocolMessage::ClientPreamble(preamble) => {
                assert_eq!(preamble.remote, remote_address);
                assert_eq!(preamble.exchange_public_key, exchange_public_key);
                assert_eq!(preamble.challenge, client_challenge);
                assert_eq!(preamble.should_encrypt, config.should_encrypt);
                assert_eq!(preamble.is_udp, is_udp);
            }
            _ => panic!("Expected ClientPreamble, got different message type"),
        }
    }

    #[tokio::test]
    async fn test_send_preamble_rejects_bad_key_size() {
        let (mut client, _server) = generate_test_duplex();
        let config = generate_test_client_config();

        // An empty key can never match `Constant::EXCHANGE_PUBLIC_KEY_SIZE`, so nothing should be sent.
        let result = send_preamble(&mut client, &config, "remote_address:3000", &[], false).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Invalid exchange public key size"));
    }

    #[tokio::test]
    async fn test_handle_challenge_bad_key() {
        let (mut client, mut server) = generate_test_duplex();
        let config = generate_test_client_config();
        let client_challenge = generate_challenge();
        let bad_key = generate_key_pair().unwrap().private_key;

        tokio::spawn(async move {
            // Create and send ServerPreamble with unknown key
            let preamble = crate::protocol::ServerPreamble {
                identity_public_key: &bad_key,
                signature: &[0u8; 64], // Mock signature
                challenge: &generate_challenge(),
                exchange_public_key: &generate_test_fake_exchange_public_key(),
            };

            server.push(ProtocolMessage::ServerPreamble(preamble)).await.unwrap();
        });

        let result = handle_challenge(&mut client, &config, &client_challenge).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "Invalid signature");
    }

    #[tokio::test]
    async fn test_handle_challenge_wrong_message_type() {
        let (mut client, mut server) = generate_test_duplex();
        let config = generate_test_client_config();
        let client_challenge = generate_challenge();

        tokio::spawn(async move {
            // Send wrong message type
            server.push(ProtocolMessage::HandshakeCompletion).await.unwrap();
        });

        let result = handle_challenge(&mut client, &config, &client_challenge).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("improper message type"));
    }
}