1use std::{collections::HashMap, marker::PhantomData, net::SocketAddr, sync::Arc};
6
7use anyhow::Context;
8use futures::join;
9use secrecy::SecretString;
10use tokio::{
11 net::{TcpListener, TcpStream, UdpSocket}, select, sync::{
12 mpsc::{UnboundedReceiver, UnboundedSender}, Mutex
13 }, task::JoinHandle
14};
15use tracing::{Instrument, error, info, info_span};
16
17use crate::{
18 base::{ClientHandshakeData, ClientKeyExchangeData, Constant, Err, Res, TunnelDefinition, Void},
19 buffed_stream::BuffedTcpStream,
20 protocol::{BincodeReceive, BincodeSend, Challenge, ClientAuthentication, ClientPreamble, ExchangePublicKey, ProtocolMessage},
21 security::{resolve_keypath, resolve_known_hosts, resolve_private_key, resolve_public_key},
22 utils::{generate_challenge, generate_ephemeral_key_pair, generate_shared_secret, handle_pump, parse_tunnel_definitions, random_string, sign_challenge, validate_signed_challenge},
23};
24
25pub struct ConfigState;
29pub struct ReadyState;
31
32pub struct Instance<S = ConfigState> {
36 tunnel_definitions: Vec<TunnelDefinition>,
37 config: Config,
38 _phantom: PhantomData<S>,
39}
40
41impl Instance<ConfigState> {
42 pub fn prepare<A, B, C>(key_path: A, connect_address: B, tunnel_definitions: &[C], accept_all_hosts: bool, should_encrypt: bool) -> Res<Instance<ReadyState>>
44 where
45 A: Into<Option<String>>,
46 B: Into<String>,
47 C: AsRef<str>,
48 {
49 let tunnel_definitions = parse_tunnel_definitions(tunnel_definitions)?;
50
51 let key_path = resolve_keypath(key_path)?;
52 let private_key = resolve_private_key(&key_path)?;
53 let public_key = resolve_public_key(&key_path)?;
54 let known_hosts = resolve_known_hosts(&key_path);
55
56 let config = Config::new(public_key, private_key, known_hosts, connect_address.into(), accept_all_hosts, should_encrypt)?;
57
58 Ok(Instance {
59 tunnel_definitions,
60 config,
61 _phantom: PhantomData,
62 })
63 }
64}
65
66impl Instance<ReadyState> {
67 pub async fn start(self) -> Void {
71 let tasks = self
74 .tunnel_definitions
75 .into_iter()
76 .map(|tunnel_definition| async {
77 tokio::spawn(test_server_connection(tunnel_definition.clone(), self.config.clone()));
79
80 let tcp = tokio::spawn(run_tcp_server(tunnel_definition.clone(), self.config.clone()));
82 let udp = tokio::spawn(run_udp_server(tunnel_definition, self.config.clone()));
83
84 let (tcp_result, udp_result) = join!(tcp, udp);
85
86 tcp_result?;
87 udp_result?;
88
89 Void::Ok(())
90 })
91 .collect::<Vec<_>>();
92
93 futures::future::join_all(tasks).await;
96
97 Ok(())
98 }
99}
100
101async fn send_preamble<T, R>(stream: &mut T, config: &Config, remote_address: R, exchange_public_key: ExchangePublicKey, is_udp: bool) -> Res<Challenge>
108where
109 T: BincodeSend,
110 R: Into<String>,
111{
112 let challenge = generate_challenge();
113
114 let preamble = ClientPreamble {
115 exchange_public_key,
116 remote: remote_address.into(),
117 challenge,
118 should_encrypt: config.should_encrypt,
119 is_udp,
120 };
121
122 stream.push(ProtocolMessage::ClientPreamble(preamble)).await?;
123
124 info!("✅ Sent preamble to server ...");
125
126 Ok(challenge)
127}
128
129async fn handle_challenge<T>(stream: &mut T, config: &Config, client_challenge: &Challenge) -> Res<ClientHandshakeData>
134where
135 T: BincodeSend + BincodeReceive,
136{
137 let ProtocolMessage::ServerPreamble(server_preamble) = stream.pull().await? else {
140 return Err(Err::msg("Handshake failed: improper message type (expected handshake challenge)"));
141 };
142
143 validate_signed_challenge(client_challenge, &server_preamble.signature.into(), &server_preamble.identity_public_key)?;
146
147 info!("✅ Server's signature validated with public key `{}` ...", server_preamble.identity_public_key);
148
149 if !config.accept_all_hosts && !config.known_hosts.contains(&server_preamble.identity_public_key) {
152 return Err(Err::msg(format!("Server's public key `{}` is not in the known hosts file", server_preamble.identity_public_key)));
154 }
155
156 info!("🚧 Signing server challenge ...");
157
158 let client_signature = sign_challenge(&server_preamble.challenge, &config.private_key)?;
159 let client_authentication = ClientAuthentication {
160 identity_public_key: config.public_key.clone(),
161 signature: client_signature.into(),
162 };
163 stream.push(ProtocolMessage::ClientAuthentication(client_authentication)).await?;
164
165 info!("⏳ Awaiting challenge validation ...");
166
167 let ProtocolMessage::HandshakeCompletion = stream.pull().await?.fail_if_error()? else {
168 return Err(Err::msg("Handshake failed: improper message type (expected handshake completion)"));
169 };
170
171 Ok(ClientHandshakeData {
172 server_challenge: server_preamble.challenge,
173 server_exchange_public_key: server_preamble.exchange_public_key,
174 })
175}
176
177async fn handle_handshake<T, R>(stream: &mut T, config: &Config, remote_address: R, is_udp: bool) -> Res<ClientKeyExchangeData>
179where
180 T: BincodeSend + BincodeReceive,
181 R: Into<String>,
182{
183 let exchange_key_pair = generate_ephemeral_key_pair()?;
185 let exchange_public_key = exchange_key_pair.public_key.as_ref().try_into().map_err(|_| Err::msg("Could not convert peer public key to array"))?;
186
187 let client_challenge = send_preamble(stream, config, remote_address, exchange_public_key, is_udp).await?;
188 let handshake_data = handle_challenge(stream, config, &client_challenge).await?;
189
190 let ephemeral_data = ClientKeyExchangeData {
193 server_exchange_public_key: handshake_data.server_exchange_public_key,
194 server_challenge: handshake_data.server_challenge,
195 local_exchange_private_key: exchange_key_pair.private_key,
196 local_challenge: client_challenge,
197 };
198
199 info!("✅ Challenge accepted!");
200
201 Ok(ephemeral_data)
202}
203
204async fn server_connect(connect_address: &str) -> Res<TcpStream> {
206 let stream = TcpStream::connect(connect_address).await?;
207 info!("✅ Connected to server `{}` ...", connect_address);
208
209 Ok(stream)
210}
211
212async fn connect(config: &Config, remote_address: &str, is_udp: bool) -> Res<BuffedTcpStream> {
214 let mut server = BuffedTcpStream::from(server_connect(&config.connect_address).await?);
216
217 let handshake_data = handle_handshake(&mut server, config, remote_address, is_udp).await.context("Error handling handshake")?;
219
220 info!("✅ Handshake successful: connection established!");
221
222 if config.should_encrypt {
224 let salt_bytes = [handshake_data.server_challenge, handshake_data.local_challenge].concat();
225
226 let shared_secret = generate_shared_secret(handshake_data.local_exchange_private_key, &handshake_data.server_exchange_public_key, &salt_bytes)?;
227
228 server = server.with_encryption(shared_secret);
229 info!("🔒 Encryption applied ...");
230 }
231
232 Ok(server)
233}
234
235async fn run_tcp_server(tunnel_definition: TunnelDefinition, config: Config) {
241 let result: Void = async move {
242 let listener = TcpListener::bind(&tunnel_definition.bind_address).await?;
243
244 info!(
245 "📻 [TCP] Listening on `{}`, and routing through `{}` to `{}` ...",
246 tunnel_definition.bind_address, config.connect_address, tunnel_definition.remote_address
247 );
248
249 loop {
250 let (socket, _) = listener.accept().await?;
251
252 tokio::spawn(handle_tcp(socket, tunnel_definition.remote_address.clone(), config.clone()));
253 }
254 }
255 .await;
256
257 if let Err(err) = result {
258 error!("❌ Error starting TCP server, or accepting a connection (shutting down listener for this bind address): {}", err);
259 }
260}
261
262async fn handle_tcp(mut local: TcpStream, remote_address: String, config: Config) {
266 let id = random_string(6);
267 let span = info_span!("tcp", id = id);
268
269 let result: Void = async move {
270 let mut server = connect(&config, &remote_address, false).await?;
273
274 info!("⛽ Pumping data between client and remote ...");
277
278 handle_pump(&mut local, &mut server).await.context("Error handling pump")?;
279
280 info!("✅ Connection closed.");
281
282 Ok(())
283 }
284 .instrument(span.clone())
285 .await;
286
287 let _guard = span.enter();
289
290 if let Err(err) = result {
291 let chain = err.chain().collect::<Vec<_>>();
292 let full_chain = chain.iter().map(|e| format!("`{}`", e)).collect::<Vec<_>>().join(" => ");
293
294 error!("❌ Error handling the connection: {}.", full_chain);
295 }
296}
297
298async fn run_udp_server(tunnel_definition: TunnelDefinition, config: Config) {
304 let result: Void = async move {
305 let socket = Arc::new(UdpSocket::bind(&tunnel_definition.bind_address).await?);
306
307 info!(
308 "📻 [UDP] Listening on `{}`, and routing through `{}` to `{}` ...",
309 tunnel_definition.bind_address, config.connect_address, tunnel_definition.remote_address
310 );
311
312 let clients = Arc::new(Mutex::new(HashMap::<SocketAddr, UnboundedSender<Vec<u8>>>::new()));
313
314 loop {
315 let mut buf = vec![0; Constant::BUFFER_SIZE];
319 let (read, addr) = socket.recv_from(&mut buf).await?;
320 buf.truncate(read);
321
322 if let Some(data_sender) = clients.lock().await.get_mut(&addr) {
325 data_sender.send(buf)?;
327 } else {
328 let socket_clone = socket.clone();
330 let config_clone = config.clone();
331
332 let (data_sender, data_receiver) = tokio::sync::mpsc::unbounded_channel();
334 data_sender.send(buf)?;
335 clients.lock().await.insert(addr, data_sender);
336
337 let clients_clone = clients.clone();
339 let remote_address = tunnel_definition.remote_address.clone();
340 tokio::spawn(async move {
341 handle_udp(addr, socket_clone, data_receiver, remote_address, config_clone).await;
343
344 clients_clone.lock().await.remove(&addr);
346 });
347 }
348 }
349 }
350 .await;
351
352 if let Err(err) = result {
353 error!("❌ Error starting UDP server, or accepting a connection (shutting down listener for this bind address): {}", err);
354 }
355}
356
357async fn handle_udp(address: SocketAddr, client_socket: Arc<UdpSocket>, mut data_receiver: UnboundedReceiver<Vec<u8>>, remote_address: String, config: Config) {
359 let id = random_string(6);
360 let span = info_span!("tcp", id = id);
361
362 let result: Void = async move {
363 let server = connect(&config, &remote_address, true).await?;
366
367 info!("⛽ Pumping data between client and remote ...");
370
371 let client_socket_clone = client_socket.clone();
372 let (mut remote_read, mut remote_write) = server.into_split();
373
374 let pump_up: JoinHandle<Void> = tokio::spawn(async move {
377 while let Some(data) = data_receiver.recv().await {
378 remote_write.push(ProtocolMessage::UdpData(data.to_vec())).await?;
379 }
380
381 Ok(())
382 });
383
384 let pump_down: JoinHandle<Void> = tokio::spawn(async move {
385 while let ProtocolMessage::UdpData(data) = remote_read.pull().await? {
386 client_socket_clone.send_to(&data, &address).await?;
387 }
388
389 Ok(())
390 });
391
392 let result = select! {
397 r = pump_up => r?,
398 r = pump_down => r?,
399 };
400
401 result?;
404
405 Ok(())
406 }
407 .instrument(span.clone())
408 .await;
409
410 let _guard = span.enter();
412
413 if let Err(err) = result {
414 let chain = err.chain().collect::<Vec<_>>();
415 let full_chain = chain.iter().map(|e| format!("`{}`", e)).collect::<Vec<_>>().join(" => ");
416
417 error!("❌ Error handling the connection: {}.", full_chain);
418 }
419}
420
421async fn test_server_connection(tunnel_definition: TunnelDefinition, config: Config) -> Void {
425 info!("⏳ Testing server connection ...");
426
427 let mut remote = BuffedTcpStream::from(server_connect(&config.connect_address).await?);
429
430 if let Err(e) = handle_handshake(&mut remote, &config, &tunnel_definition.remote_address, false).await {
432 error!("❌ Test connection failed: {}", e);
433 return Err(e);
434 }
435
436 info!("✅ Test connection successful!");
437
438 Ok(())
439}
440
441#[derive(Clone)]
447pub(crate) struct Config {
448 pub(crate) public_key: String,
449 pub(crate) private_key: SecretString,
450 pub(crate) known_hosts: Vec<String>,
451 pub(crate) connect_address: String,
452 pub(crate) accept_all_hosts: bool,
453 pub(crate) should_encrypt: bool,
454}
455
456impl Config {
457 fn new(public_key: String, private_key: SecretString, known_hosts: Vec<String>, connect_address: String, accept_all_hosts: bool, should_encrypt: bool) -> Res<Self> {
459 Ok(Self {
460 public_key,
461 private_key,
462 connect_address,
463 known_hosts,
464 accept_all_hosts,
465 should_encrypt,
466 })
467 }
468}
469
#[cfg(test)]
pub mod tests {
    use crate::utils::{
        generate_key_pair,
        tests::{generate_test_duplex, generate_test_fake_exchange_public_key},
    };

    use super::*;
    use pretty_assertions::assert_eq;

    /// Builds a client `Config` from the `test/client` key fixtures.
    ///
    /// Known hosts are enforced (`accept_all_hosts: false`) and encryption is off.
    pub(crate) fn generate_test_client_config() -> Config {
        let key_path = "test/client";

        let public_key = resolve_public_key(key_path).unwrap();
        let private_key = resolve_private_key(key_path).unwrap();
        let known_hosts = resolve_known_hosts(key_path);

        Config {
            public_key,
            private_key,
            known_hosts,
            connect_address: "connect_address".to_string(),
            accept_all_hosts: false,
            should_encrypt: false,
        }
    }

    #[test]
    fn test_prepare() {
        let key_path = "test/client";
        let connect_address = "connect_address";
        let tunnel_definitions = ["localhost:5000:example.com:80", "127.0.0.1:6000:api.example.com:443"];
        // Distinct flag values, so a swapped or hard-coded flag is caught.
        let accept_all_hosts = true;
        let should_encrypt = false;

        let instance = Instance::prepare(key_path.to_owned(), connect_address, &tunnel_definitions, accept_all_hosts, should_encrypt).unwrap();

        assert_eq!(instance.config.connect_address, connect_address);
        assert_eq!(instance.config.accept_all_hosts, accept_all_hosts);
        assert_eq!(instance.config.should_encrypt, should_encrypt);

        let expected_public_key = resolve_public_key(key_path).unwrap();
        assert_eq!(instance.config.public_key, expected_public_key);

        let expected_known_hosts = resolve_known_hosts(key_path);
        assert_eq!(instance.config.known_hosts, expected_known_hosts);

        assert_eq!(instance.tunnel_definitions.len(), 2);
        assert_eq!(instance.tunnel_definitions[0].bind_address, "localhost:5000");
        assert_eq!(instance.tunnel_definitions[0].remote_address, "example.com:80");
        assert_eq!(instance.tunnel_definitions[1].bind_address, "127.0.0.1:6000");
        assert_eq!(instance.tunnel_definitions[1].remote_address, "api.example.com:443");
    }

    #[tokio::test]
    async fn test_send_preamble() {
        let (mut client, mut server) = generate_test_duplex();
        let config = generate_test_client_config();
        let remote_address = "remote_address:3000";
        let exchange_public_key = generate_test_fake_exchange_public_key();

        let client_challenge = send_preamble(&mut client, &config, remote_address, exchange_public_key, false).await.unwrap();

        let received = server.pull().await.unwrap();

        match received {
            ProtocolMessage::ClientPreamble(preamble) => {
                assert_eq!(preamble.remote, remote_address);
                assert_eq!(preamble.exchange_public_key, exchange_public_key);
                assert_eq!(preamble.challenge, client_challenge);
                assert_eq!(preamble.should_encrypt, config.should_encrypt);
                assert!(!preamble.is_udp);
            }
            _ => panic!("Expected ClientPreamble, got different message type"),
        }
    }

    #[tokio::test]
    async fn test_handle_challenge_bad_key() {
        let (mut client, mut server) = generate_test_duplex();
        let config = generate_test_client_config();
        let client_challenge = generate_challenge();
        // A key that did not produce the (all-zero) signature below.
        let bad_key = generate_key_pair().unwrap().private_key;

        tokio::spawn(async move {
            let preamble = crate::protocol::ServerPreamble {
                identity_public_key: bad_key,
                signature: [0u8; 64].into(),
                challenge: generate_challenge(),
                exchange_public_key: generate_test_fake_exchange_public_key(),
            };

            server.push(ProtocolMessage::ServerPreamble(preamble)).await.unwrap();
        });

        let result = handle_challenge(&mut client, &config, &client_challenge).await;

        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "Invalid signature");
    }

    #[tokio::test]
    async fn test_handle_challenge_wrong_message_type() {
        let (mut client, mut server) = generate_test_duplex();
        let config = generate_test_client_config();
        let client_challenge = generate_challenge();

        tokio::spawn(async move {
            server.push(ProtocolMessage::HandshakeCompletion).await.unwrap();
        });

        let result = handle_challenge(&mut client, &config, &client_challenge).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("improper message type"));
    }
}