1use std::{collections::HashMap, marker::PhantomData, net::SocketAddr, sync::Arc};
6
7use anyhow::{Context, anyhow};
8use bytes::{Bytes, BytesMut};
9use futures::join;
10use secrecy::SecretString;
11use tokio::{
12 net::{TcpListener, TcpStream, UdpSocket},
13 select,
14 sync::{
15 Mutex,
16 mpsc::{UnboundedReceiver, UnboundedSender},
17 },
18 task::JoinHandle,
19};
20use tracing::{Instrument, error, info, info_span};
21
22use crate::{
23 base::{ClientHandshakeData, ClientKeyExchangeData, Constant, Res, TunnelDefinition, Void},
24 buffed_stream::{BincodeSplit, BuffedTcpStream},
25 protocol::{BincodeReceive, BincodeSend, Challenge, ClientAuthentication, ClientPreamble, ProtocolMessage},
26 security::{resolve_keypath, resolve_known_hosts, resolve_private_key, resolve_public_key},
27 utils::{generate_challenge, generate_ephemeral_key_pair, generate_shared_secret, handle_tcp_pump, parse_tunnel_definitions, random_string, sign_challenge, validate_signed_challenge},
28};
29
/// Type-state marker: the instance has not been configured yet (see [`Instance::prepare`]).
pub struct ConfigState;
/// Type-state marker: the instance is configured and can be started (see [`Instance::start`]).
pub struct ReadyState;
36
37pub struct Instance<S = ConfigState> {
41 tunnel_definitions: Vec<TunnelDefinition>,
42 config: Config,
43 _phantom: PhantomData<S>,
44}
45
46impl Instance<ConfigState> {
47 pub fn prepare<A, B, C>(key_path: A, connect_address: B, tunnel_definitions: &[C], accept_all_hosts: bool, should_encrypt: bool) -> Res<Instance<ReadyState>>
49 where
50 A: Into<Option<String>>,
51 B: Into<String>,
52 C: AsRef<str>,
53 {
54 let tunnel_definitions = parse_tunnel_definitions(tunnel_definitions)?;
55
56 let key_path = resolve_keypath(key_path)?;
57 let private_key = resolve_private_key(&key_path)?;
58 let public_key = resolve_public_key(&key_path)?;
59 let known_hosts = resolve_known_hosts(&key_path);
60
61 let config = Config::new(public_key, private_key, known_hosts, connect_address.into(), accept_all_hosts, should_encrypt)?;
62
63 Ok(Instance {
64 tunnel_definitions,
65 config,
66 _phantom: PhantomData,
67 })
68 }
69}
70
71impl Instance<ReadyState> {
72 pub async fn start(self) -> Void {
76 let tasks = self
79 .tunnel_definitions
80 .into_iter()
81 .map(|tunnel_definition| async {
82 tokio::spawn(test_server_connection(tunnel_definition.clone(), self.config.clone()));
84
85 let tcp = tokio::spawn(run_tcp_server(tunnel_definition.clone(), self.config.clone()));
87 let udp = tokio::spawn(run_udp_server(tunnel_definition, self.config.clone()));
88
89 let (tcp_result, udp_result) = join!(tcp, udp);
90
91 tcp_result?;
92 udp_result?;
93
94 Void::Ok(())
95 })
96 .collect::<Vec<_>>();
97
98 futures::future::join_all(tasks).await;
101
102 Ok(())
103 }
104}
105
106async fn send_preamble<T, R>(stream: &mut T, config: &Config, remote_address: R, exchange_public_key: &[u8], is_udp: bool) -> Res<Challenge>
113where
114 T: BincodeSend,
115 R: AsRef<str>,
116{
117 if exchange_public_key.len() != Constant::EXCHANGE_PUBLIC_KEY_SIZE {
118 return Err(anyhow!(
119 "Invalid exchange public key size: expected {} bytes, got {} bytes",
120 Constant::EXCHANGE_PUBLIC_KEY_SIZE,
121 exchange_public_key.len()
122 ));
123 }
124
125 let challenge = generate_challenge();
126
127 let preamble = ClientPreamble {
128 exchange_public_key,
129 remote: remote_address.as_ref(),
130 challenge: &challenge,
131 should_encrypt: config.should_encrypt,
132 is_udp,
133 };
134
135 stream.push(ProtocolMessage::ClientPreamble(preamble)).await?;
136
137 info!("✅ Sent preamble to server ...");
138
139 Ok(challenge)
140}
141
142async fn handle_challenge<T>(stream: &mut T, config: &Config, client_challenge: &Challenge) -> Res<ClientHandshakeData>
147where
148 T: BincodeSend + BincodeReceive,
149{
150 let guard = stream.pull().await?;
153 let ProtocolMessage::ServerPreamble(server_preamble) = guard.message() else {
154 return Err(anyhow!("Handshake failed: improper message type (expected handshake challenge)"));
155 };
156
157 let result = ClientHandshakeData {
158 server_challenge: server_preamble.challenge.try_into()?,
159 server_exchange_public_key: server_preamble.exchange_public_key.try_into()?,
160 };
161
162 validate_signed_challenge(client_challenge, server_preamble.signature, server_preamble.identity_public_key)?;
165
166 info!("✅ Server's signature validated with public key `{}` ...", server_preamble.identity_public_key);
167
168 if !config.accept_all_hosts && !config.known_hosts.iter().any(|k| k == server_preamble.identity_public_key) {
171 return Err(anyhow!("Server's public key `{}` is not in the known hosts file", server_preamble.identity_public_key));
173 }
174
175 info!("🚧 Signing server challenge ...");
176
177 let client_signature = sign_challenge(server_preamble.challenge, &config.private_key)?;
178 let client_authentication = ClientAuthentication {
179 identity_public_key: &config.public_key,
180 signature: &client_signature,
181 };
182 stream.push(ProtocolMessage::ClientAuthentication(client_authentication)).await?;
183
184 info!("⏳ Awaiting challenge validation ...");
185
186 let guard = stream.pull().await?;
187 let ProtocolMessage::HandshakeCompletion = guard.message().fail_if_error()? else {
188 return Err(anyhow!("Handshake failed: improper message type (expected handshake completion)"));
189 };
190
191 Ok(result)
192}
193
194async fn handle_handshake<T, R>(stream: &mut T, config: &Config, remote_address: R, is_udp: bool) -> Res<ClientKeyExchangeData>
196where
197 T: BincodeSend + BincodeReceive,
198 R: AsRef<str>,
199{
200 let exchange_key_pair = generate_ephemeral_key_pair()?;
202 let exchange_public_key = exchange_key_pair.public_key.as_ref();
203
204 let client_challenge = send_preamble(stream, config, remote_address, exchange_public_key, is_udp).await?;
205 let handshake_data = handle_challenge(stream, config, &client_challenge).await?;
206
207 let ephemeral_data = ClientKeyExchangeData {
210 server_exchange_public_key: handshake_data.server_exchange_public_key,
211 server_challenge: handshake_data.server_challenge,
212 local_exchange_private_key: exchange_key_pair.private_key,
213 local_challenge: client_challenge,
214 };
215
216 info!("✅ Challenge accepted!");
217
218 Ok(ephemeral_data)
219}
220
221async fn server_connect(connect_address: &str) -> Res<TcpStream> {
223 let stream = TcpStream::connect(connect_address).await?;
224 info!("✅ Connected to server `{}` ...", connect_address);
225
226 Ok(stream)
227}
228
229async fn connect(config: &Config, remote_address: &str, is_udp: bool) -> Res<BuffedTcpStream> {
231 let server = server_connect(&config.connect_address).await?;
233 server.set_nodelay(true)?;
234
235 let mut server = BuffedTcpStream::from(server);
236
237 let handshake_data = handle_handshake(&mut server, config, remote_address, is_udp).await.context("Error handling handshake")?;
239
240 info!("✅ Handshake successful: connection established!");
241
242 if config.should_encrypt {
244 let salt_bytes = [handshake_data.server_challenge, handshake_data.local_challenge].concat();
245
246 let shared_secret = generate_shared_secret(handshake_data.local_exchange_private_key, &handshake_data.server_exchange_public_key, &salt_bytes)?;
247
248 server = server.with_encryption(shared_secret);
249 info!("🔒 Encryption applied ...");
250 }
251
252 Ok(server)
253}
254
255async fn run_tcp_server(tunnel_definition: TunnelDefinition, config: Config) {
261 let result: Void = async move {
262 let listener = TcpListener::bind(&tunnel_definition.bind_address).await?;
263
264 info!(
265 "📻 [TCP] Listening on `{}`, and routing through `{}` to `{}` ...",
266 tunnel_definition.bind_address, config.connect_address, tunnel_definition.remote_address
267 );
268
269 loop {
270 let (socket, _) = listener.accept().await?;
271
272 tokio::spawn(handle_tcp(socket, tunnel_definition.remote_address.clone(), config.clone()));
273 }
274 }
275 .await;
276
277 if let Err(err) = result {
278 error!("❌ Error starting TCP server, or accepting a connection (shutting down listener for this bind address): {}", err);
279 }
280}
281
282async fn handle_tcp(local: TcpStream, remote_address: String, config: Config) {
286 let id = random_string(6);
287 let span = info_span!("tcp", id = id);
288
289 let result: Void = async move {
290 let server = connect(&config, &remote_address, false).await?;
293
294 info!("⛽ Pumping data between client and remote ...");
297
298 local.set_nodelay(true)?;
299
300 handle_tcp_pump(local, server).await.context("Error handling pump")?;
301
302 info!("✅ Connection closed.");
303
304 Ok(())
305 }
306 .instrument(span.clone())
307 .await;
308
309 let _guard = span.enter();
311
312 if let Err(err) = result {
313 let chain = err.chain().collect::<Vec<_>>();
314 let full_chain = chain.iter().map(|e| format!("`{}`", e)).collect::<Vec<_>>().join(" => ");
315
316 error!("❌ Error handling the connection: {}.", full_chain);
317 }
318}
319
320async fn run_udp_server(tunnel_definition: TunnelDefinition, config: Config) {
326 let result: Void = async move {
327 let socket = Arc::new(UdpSocket::bind(&tunnel_definition.bind_address).await?);
328
329 info!(
330 "📻 [UDP] Listening on `{}`, and routing through `{}` to `{}` ...",
331 tunnel_definition.bind_address, config.connect_address, tunnel_definition.remote_address
332 );
333
334 let clients = Arc::new(Mutex::new(HashMap::<SocketAddr, UnboundedSender<Bytes>>::new()));
335 let mut buffer = BytesMut::with_capacity(2 * Constant::BUFFER_SIZE);
336
337 loop {
338 buffer.clear();
340 buffer.reserve(Constant::BUFFER_SIZE);
341
342 unsafe { buffer.set_len(Constant::BUFFER_SIZE) };
344 let (read, addr) = socket.recv_from(&mut buffer).await?;
345 unsafe { buffer.set_len(read) };
346
347 let data = buffer.split().freeze();
348
349 if let Some(data_sender) = clients.lock().await.get_mut(&addr) {
352 data_sender.send(data)?;
354 } else {
355 let socket_clone = socket.clone();
357 let config_clone = config.clone();
358
359 let (data_sender, data_receiver) = tokio::sync::mpsc::unbounded_channel();
361 data_sender.send(data)?;
362 clients.lock().await.insert(addr, data_sender);
363
364 let clients_clone = clients.clone();
366 let remote_address = tunnel_definition.remote_address.clone();
367 tokio::spawn(async move {
368 handle_udp(addr, socket_clone, data_receiver, remote_address, config_clone).await;
370
371 clients_clone.lock().await.remove(&addr);
373 });
374 }
375 }
376 }
377 .await;
378
379 if let Err(err) = result {
380 error!("❌ Error starting UDP server, or accepting a connection (shutting down listener for this bind address): {}", err);
381 }
382}
383
384async fn handle_udp(address: SocketAddr, client_socket: Arc<UdpSocket>, mut data_receiver: UnboundedReceiver<Bytes>, remote_address: String, config: Config) {
386 let id = random_string(6);
387 let span = info_span!("udp", id = id);
388
389 let result: Void = async move {
390 let server = connect(&config, &remote_address, true).await?;
393
394 info!("⛽ Pumping data between client and remote ...");
397
398 let client_socket_clone = client_socket.clone();
399 let (mut remote_read, mut remote_write) = server.into_split();
400
401 let pump_up: JoinHandle<Void> = tokio::spawn(async move {
405 while let Some(data) = data_receiver.recv().await {
406 dbg!("client up {}", String::from_utf8_lossy(&data));
407 remote_write.push(ProtocolMessage::UdpData(&data)).await?;
408 }
409
410 Ok(())
411 });
412
413 let pump_down: JoinHandle<Void> = tokio::spawn(async move {
414 loop {
415 let guard = remote_read.pull().await?;
416 let ProtocolMessage::UdpData(data) = guard.message() else {
417 break;
418 };
419
420 client_socket_clone.send_to(data, &address).await?;
421 }
422
423 Ok(())
424 });
425
426 let result = select! {
431 r = pump_up => r?,
432 r = pump_down => r?,
433 };
434
435 result?;
438
439 Ok(())
440 }
441 .instrument(span.clone())
442 .await;
443
444 let _guard = span.enter();
446
447 if let Err(err) = result {
448 let chain = err.chain().collect::<Vec<_>>();
449 let full_chain = chain.iter().map(|e| format!("`{}`", e)).collect::<Vec<_>>().join(" => ");
450
451 error!("❌ Error handling the connection: {}.", full_chain);
452 }
453}
454
455async fn test_server_connection(tunnel_definition: TunnelDefinition, config: Config) -> Void {
459 info!("⏳ Testing server connection ...");
460
461 let mut remote = BuffedTcpStream::from(server_connect(&config.connect_address).await?);
463
464 if let Err(e) = handle_handshake(&mut remote, &config, &tunnel_definition.remote_address, false).await {
466 error!("❌ Test connection failed: {}", e);
467 return Err(e);
468 }
469
470 info!("✅ Test connection successful!");
471
472 Ok(())
473}
474
475#[derive(Clone)]
481pub(crate) struct Config {
482 pub(crate) public_key: String,
483 pub(crate) private_key: SecretString,
484 pub(crate) known_hosts: Vec<String>,
485 pub(crate) connect_address: String,
486 pub(crate) accept_all_hosts: bool,
487 pub(crate) should_encrypt: bool,
488}
489
490impl Config {
491 fn new(public_key: String, private_key: SecretString, known_hosts: Vec<String>, connect_address: String, accept_all_hosts: bool, should_encrypt: bool) -> Res<Self> {
493 Ok(Self {
494 public_key,
495 private_key,
496 connect_address,
497 known_hosts,
498 accept_all_hosts,
499 should_encrypt,
500 })
501 }
502}
503
#[cfg(test)]
pub mod tests {
    use crate::utils::{
        generate_key_pair,
        tests::{generate_test_duplex, generate_test_fake_exchange_public_key},
    };

    use super::*;
    use pretty_assertions::assert_eq;

    /// Builds a client `Config` from the `test/client` key fixtures, with strict host checking
    /// and encryption disabled.
    pub(crate) fn generate_test_client_config() -> Config {
        let key_path = "test/client";

        let public_key = resolve_public_key(key_path).unwrap();
        let private_key = resolve_private_key(key_path).unwrap();
        let known_hosts = resolve_known_hosts(key_path);

        Config {
            public_key,
            private_key,
            known_hosts,
            connect_address: "connect_address".to_string(),
            accept_all_hosts: false,
            should_encrypt: false,
        }
    }

    #[test]
    fn test_prepare() {
        let key_path = "test/client";
        let connect_address = "connect_address";
        let tunnel_definitions = ["localhost:5000:example.com:80", "127.0.0.1:6000:api.example.com:443"];
        let accept_all_hosts = false;
        let should_encrypt = false;

        let instance = Instance::prepare(key_path.to_owned(), connect_address, &tunnel_definitions, accept_all_hosts, should_encrypt).unwrap();

        assert_eq!(instance.config.connect_address, connect_address);
        assert_eq!(instance.config.should_encrypt, should_encrypt);

        let expected_public_key = resolve_public_key(key_path).unwrap();
        assert_eq!(instance.config.public_key, expected_public_key);

        let expected_known_hosts = resolve_known_hosts(key_path);
        assert_eq!(instance.config.known_hosts, expected_known_hosts);

        assert_eq!(instance.tunnel_definitions.len(), 2);
        assert_eq!(instance.tunnel_definitions[0].bind_address, "localhost:5000");
        assert_eq!(instance.tunnel_definitions[0].remote_address, "example.com:80");
        assert_eq!(instance.tunnel_definitions[1].bind_address, "127.0.0.1:6000");
        assert_eq!(instance.tunnel_definitions[1].remote_address, "api.example.com:443");
    }

    #[tokio::test]
    async fn test_send_preamble() {
        let (mut client, mut server) = generate_test_duplex();
        let config = generate_test_client_config();
        let remote_address = "remote_address:3000";
        let exchange_public_key = &generate_test_fake_exchange_public_key();

        let client_challenge = send_preamble(&mut client, &config, remote_address, exchange_public_key, false).await.unwrap();

        let guard = server.pull().await.unwrap();
        match guard.message() {
            ProtocolMessage::ClientPreamble(preamble) => {
                assert_eq!(preamble.remote, remote_address);
                assert_eq!(preamble.exchange_public_key, exchange_public_key);
                assert_eq!(preamble.challenge, client_challenge);
                assert_eq!(preamble.should_encrypt, config.should_encrypt);
                // The TCP/UDP flag must round-trip too.
                assert!(!preamble.is_udp);
            }
            _ => panic!("Expected ClientPreamble, got different message type"),
        }
    }

    #[tokio::test]
    async fn test_send_preamble_rejects_invalid_key_size() {
        let (mut client, _server) = generate_test_duplex();
        let config = generate_test_client_config();

        // An empty key can never match `Constant::EXCHANGE_PUBLIC_KEY_SIZE`.
        let result = send_preamble(&mut client, &config, "remote_address:3000", &[], false).await;

        let err = match result {
            Ok(_) => panic!("Expected an invalid exchange public key size error"),
            Err(err) => err,
        };
        assert!(err.to_string().starts_with("Invalid exchange public key size"));
    }

    #[tokio::test]
    async fn test_handle_challenge_bad_key() {
        let (mut client, mut server) = generate_test_duplex();
        let config = generate_test_client_config();
        let client_challenge = generate_challenge();
        // An arbitrary key with no relation to the signature sent below.
        let bad_key = generate_key_pair().unwrap().private_key;

        tokio::spawn(async move {
            let preamble = crate::protocol::ServerPreamble {
                identity_public_key: &bad_key,
                // Deliberately bogus all-zero signature.
                signature: &[0u8; 64],
                challenge: &generate_challenge(),
                exchange_public_key: &generate_test_fake_exchange_public_key(),
            };

            server.push(ProtocolMessage::ServerPreamble(preamble)).await.unwrap();
        });

        let result = handle_challenge(&mut client, &config, &client_challenge).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "Invalid signature");
    }

    #[tokio::test]
    async fn test_handle_challenge_wrong_message_type() {
        let (mut client, mut server) = generate_test_duplex();
        let config = generate_test_client_config();
        let client_challenge = generate_challenge();

        tokio::spawn(async move {
            server.push(ProtocolMessage::HandshakeCompletion).await.unwrap();
        });

        let result = handle_challenge(&mut client, &config, &client_challenge).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("improper message type"));
    }
}