1use std::{marker::PhantomData, sync::Arc};
4
5use anyhow::Context;
6use regex::Regex;
7use secrecy::SecretString;
8use tokio::{
9 net::{TcpListener, TcpStream, UdpSocket}, select, task::JoinHandle, time::Instant
10};
11use tracing::{Instrument, error, info, info_span};
12
13use crate::{
14 base::{Constant, ExchangeKeyPair, Res, ServerKeyExchangeData, Void},
15 buffed_stream::BuffedTcpStream,
16 protocol::{BincodeReceive, BincodeSend, Challenge, ClientPreamble, ProtocolError, ProtocolMessage, ServerPreamble, Signature},
17 security::{resolve_authorized_keys, resolve_keypath, resolve_private_key, resolve_public_key},
18 utils::{generate_challenge, generate_ephemeral_key_pair, generate_shared_secret, handle_pump, random_string, sign_challenge, validate_signed_challenge},
19};
20
/// Type-state marker for an [`Instance`] that has not been prepared yet.
///
/// This is the default state parameter of [`Instance`] and the state on which
/// `Instance::prepare` is defined.
pub struct ConfigState;
/// Type-state marker for an [`Instance`] produced by `Instance::prepare`.
///
/// `Instance::start` is only available in this state.
pub struct ReadyState;
27
/// A server instance whose lifecycle stage is encoded in the type parameter `S`.
///
/// `S` is a marker type (defaulting to [`ConfigState`]); no value of `S` is ever
/// stored, so the markers only decide which methods are callable.
pub struct Instance<S = ConfigState> {
    /// Server settings assembled by `prepare` and consumed by `start`.
    config: Config,
    /// Carries the state marker `S` without holding a value of that type.
    _phantom: PhantomData<S>,
}
35
36impl Instance<ConfigState> {
37 pub fn prepare<A, B, C>(key_path: A, remote_regex: B, bind_address: C) -> Res<Instance<ReadyState>>
39 where
40 A: Into<Option<String>>,
41 B: AsRef<str>,
42 C: Into<String>,
43 {
44 let remote_regex = Regex::new(remote_regex.as_ref()).context("Invalid regex for remote host.")?;
45
46 let key_path = resolve_keypath(key_path)?;
47 let private_key = resolve_private_key(&key_path)?;
48 let public_key = resolve_public_key(&key_path)?;
49 let authorized_keys = resolve_authorized_keys(&key_path);
50
51 let config = Config::new(public_key, private_key, authorized_keys, bind_address.into(), remote_regex);
52
53 Ok(Instance { config, _phantom: PhantomData })
54 }
55}
56
57impl Instance<ReadyState> {
58 pub async fn start(self) -> Void {
60 info!("🚀 Starting server on `{}` ...", self.config.bind_address);
61
62 run_tcp_server(self.config.clone()).await?;
63
64 Ok(())
65 }
66}
67
68async fn verify_client_preamble<T>(stream: &mut T, config: &Config, preamble: &ClientPreamble) -> Res<Signature>
74where
75 T: BincodeSend,
76{
77 if !config.remote_regex.is_match(&preamble.remote) {
80 return ProtocolError::InvalidHost(format!("Invalid host from client (supplied `{}`, but need to satisfy `{}`)", preamble.remote, config.remote_regex))
81 .send_and_bail(stream)
82 .await;
83 }
84
85 let signature = sign_challenge(&preamble.challenge, &config.private_key)?;
88
89 Ok(signature)
90}
91
92async fn send_server_preamble<T>(stream: &mut T, config: &Config, server_signature: &Signature, server_challenge: &Challenge) -> Res<ExchangeKeyPair>
93where
94 T: BincodeSend,
95{
96 info!("🚧 Sending handshake challenge to client ...");
97
98 let exchange_key_pair = generate_ephemeral_key_pair()?;
99 let exchange_public_key = exchange_key_pair.public_key.as_ref().try_into()?;
100
101 let preamble = ServerPreamble {
102 challenge: *server_challenge,
103 exchange_public_key,
104 identity_public_key: config.public_key.clone(),
105 signature: server_signature.into(),
106 };
107
108 stream.push(ProtocolMessage::ServerPreamble(preamble)).await?;
109
110 Ok(exchange_key_pair)
111}
112
113async fn handle_and_validate_key_challenge<T>(stream: &mut T, config: &Config, server_challenge: &Challenge) -> Void
118where
119 T: BincodeSend + BincodeReceive,
120{
121 let ProtocolMessage::ClientAuthentication(client_authentication) = stream.pull().await?.fail_if_error()? else {
124 return ProtocolError::InvalidKey("Invalid handshake response".into()).send_and_bail(stream).await;
125 };
126
127 if validate_signed_challenge(server_challenge, &client_authentication.signature.into(), &client_authentication.identity_public_key).is_err() {
130 return ProtocolError::InvalidKey("Invalid challenge signature from client".into()).send_and_bail(stream).await;
131 }
132
133 if !config.authorized_keys.contains(&client_authentication.identity_public_key) {
135 return ProtocolError::InvalidKey("Unauthorized key from client".into()).send_and_bail(stream).await;
136 }
137
138 info!("✅ Handshake challenge completed!");
139
140 Ok(())
141}
142
143async fn complete_handshake<T>(stream: &mut T) -> Void
148where
149 T: BincodeSend,
150{
151 let completion = ProtocolMessage::HandshakeCompletion;
152
153 stream.push(completion).await?;
154
155 info!("✅ Handshake completed.");
156
157 Ok(())
158}
159
160async fn handle_handshake<T>(stream: &mut T, config: &Config) -> Res<ServerKeyExchangeData>
165where
166 T: BincodeReceive + BincodeSend,
167{
168 let ProtocolMessage::ClientPreamble(preamble) = stream.pull().await? else {
171 return ProtocolError::Unknown("Invalid handshake start".into()).send_and_bail(stream).await;
172 };
173
174 let server_signature = verify_client_preamble(stream, config, &preamble).await?;
177
178 let server_challenge = generate_challenge();
181
182 let local_exchange_key_pair = send_server_preamble(stream, config, &server_signature, &server_challenge).await?;
185
186 handle_and_validate_key_challenge(stream, config, &server_challenge).await?;
189
190 complete_handshake(stream).await?;
193
194 Ok(ServerKeyExchangeData {
195 client_exchange_public_key: preamble.exchange_public_key,
196 client_challenge: preamble.challenge,
197 local_exchange_private_key: local_exchange_key_pair.private_key,
198 local_challenge: server_challenge,
199 requested_remote_address: preamble.remote,
200 requested_should_encrypt: preamble.should_encrypt,
201 requested_is_udp: preamble.is_udp,
202 })
203}
204
205async fn run_tcp_pump(mut client: BuffedTcpStream, remote_address: &str) -> Void {
207 let Ok(mut remote) = TcpStream::connect(remote_address).await.context("Error connecting to remote") else {
208 return ProtocolError::RemoteFailed(format!("Failed to connect to remote `{}`", remote_address))
209 .send_and_bail(&mut client)
210 .await;
211 };
212
213 info!("✅ Connected to remote server `{}`.", remote_address);
214
215 handle_pump(&mut client, &mut remote).await.context("Error handling TCP pump.")?;
216
217 Ok(())
218}
219
220async fn run_udp_pump(mut client: BuffedTcpStream, remote_address: &str) -> Void {
222 let remote = UdpSocket::bind("127.0.0.1:0").await.context("Error binding UDP socket")?;
225 if remote.connect(remote_address).await.is_err() {
226 return ProtocolError::RemoteFailed(format!("Failed to connect to remote `{}`", remote_address))
227 .send_and_bail(&mut client)
228 .await;
229 }
230
231 info!("✅ Connected to remote server `{}`.", remote_address);
232
233 let (mut client_read, mut client_write) = client.into_split();
235
236 let remote_up = Arc::new(remote);
238 let remote_down = remote_up.clone();
239
240 let last_activity = Arc::new(tokio::sync::Mutex::new(Instant::now()));
243 let last_activity_up = last_activity.clone();
244 let last_activity_down = last_activity.clone();
245
246 let pump_up: JoinHandle<Void> = tokio::spawn(async move {
247 while let ProtocolMessage::UdpData(data) = client_read.pull().await? {
248 remote_up.send(&data).await?;
249 *last_activity_up.lock().await = Instant::now();
250 }
251
252 Ok(())
253 });
254
255 let pump_down: JoinHandle<Void> = tokio::spawn(async move {
256 let mut buf = [0; Constant::BUFFER_SIZE];
257
258 loop {
259 let size = remote_down.recv(&mut buf).await?;
260 client_write.push(ProtocolMessage::UdpData(buf[..size].to_vec())).await?;
261 *last_activity_down.lock().await = Instant::now();
262 }
263 });
264
265 let timeout: JoinHandle<Void> = tokio::spawn(async move {
266 loop {
267 let last_activity = *last_activity.lock().await;
268
269 if last_activity.elapsed() > Constant::UDP_TIMEOUT {
270 info!("✅ UDP connection timed out (assumed graceful close).");
271 return Ok(());
272 }
273
274 tokio::time::sleep(Constant::UDP_TIMEOUT).await;
275 }
276 });
277
278 let result = select! {
285 r = pump_up => r?,
286 r = pump_down => r?,
287 r = timeout => r?,
288 };
289
290 result?;
293
294 Ok(())
295}
296
297async fn run_tcp_server(config: Config) -> Void {
301 let listener = TcpListener::bind(&config.bind_address).await?;
302
303 loop {
304 let (socket, _) = listener.accept().await?;
305
306 tokio::spawn(handle_connection(socket, config.clone()));
307 }
308}
309
310async fn handle_connection(client: TcpStream, config: Config) {
315 let mut client = BuffedTcpStream::from(client);
316
317 let id = random_string(6);
318 let span = info_span!("conn", id = id);
319
320 let result: Void = async move {
321 let peer_addr = client.as_inner_tcp_write_ref().peer_addr().context("Error getting peer address")?;
322
323 info!("✅ Accepted connection from `{}`.", peer_addr);
324
325 let handshake_data = handle_handshake(&mut client, &config).await.context("Error handling handshake")?;
328
329 if handshake_data.requested_should_encrypt {
331 let private_key = handshake_data.local_exchange_private_key;
332 let salt_bytes = [handshake_data.local_challenge, handshake_data.client_challenge].concat();
333 let shared_secret = generate_shared_secret(private_key, &handshake_data.client_exchange_public_key, &salt_bytes)?;
334
335 client = client.with_encryption(shared_secret);
336 info!("🔒 Encryption applied ...");
337 }
338
339 info!("⛽ Pumping data between client and remote ...");
342
343 if handshake_data.requested_is_udp {
344 run_udp_pump(client, &handshake_data.requested_remote_address).await?;
345 } else {
346 run_tcp_pump(client, &handshake_data.requested_remote_address).await?;
347 }
348
349 info!("✅ Connection closed.");
350
351 Ok(())
352 }
353 .instrument(span.clone())
354 .await;
355
356 let _guard = span.enter();
358
359 if let Err(err) = result {
360 let chain = err.chain().collect::<Vec<_>>();
361 let full_chain = chain.iter().map(|e| format!("`{}`", e)).collect::<Vec<_>>().join(" => ");
362
363 error!("❌ Error handling connection: {}.", full_chain);
364 }
365}
366
/// Server settings shared by every connection handler.
///
/// The value is cloned once per accepted connection.
#[derive(Clone)]
pub(crate) struct Config {
    /// Server identity public key, sent to clients in the server preamble.
    pub(crate) public_key: String,
    /// Server identity private key, used to sign the client's challenge.
    pub(crate) private_key: SecretString,
    /// Client identity public keys that are allowed to complete the handshake.
    pub(crate) authorized_keys: Vec<String>,
    /// Address the TCP listener binds to.
    pub(crate) bind_address: String,
    /// Pattern that the remote requested by a client must match.
    pub(crate) remote_regex: Regex,
}
380
381impl Config {
382 fn new(public_key: String, private_key: SecretString, authorized_keys: Vec<String>, bind_address: String, remote_regex: Regex) -> Self {
384 Self {
385 public_key,
386 private_key,
387 authorized_keys,
388 bind_address,
389 remote_regex,
390 }
391 }
392}
393
/// Unit tests for configuration preparation and the server side of the handshake.
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use crate::{
        connect::tests::generate_test_client_config,
        protocol::ClientAuthentication,
        utils::{
            generate_key_pair, sign_challenge,
            tests::{generate_test_duplex, generate_test_fake_exchange_public_key},
        },
    };

    use super::*;

    /// Builds a server configuration from the keys under `test/server`, with a
    /// remote pattern that accepts every host and a placeholder bind address.
    pub(crate) fn generate_test_server_config() -> Config {
        let key_path = "test/server";

        let public_key = resolve_public_key(key_path).unwrap();
        let private_key = resolve_private_key(key_path).unwrap();
        let authorized_keys = resolve_authorized_keys(key_path);
        let remote_regex = Regex::new(".*").unwrap();

        Config {
            public_key,
            private_key,
            authorized_keys,
            bind_address: "bind_address".to_string(),
            remote_regex,
        }
    }

    /// `prepare` resolves the keys and keeps the regex and bind address as given.
    #[test]
    fn test_prepare_config() {
        let instance = Instance::prepare("test/server".to_string(), ".*", "foo").unwrap();

        assert_eq!(instance.config.public_key, "HQYY0BNIhdawY2Jw62DudkUsK2GKj3hGO3qSVBlCinI");
        assert_eq!(instance.config.remote_regex.as_str(), ".*");
        assert_eq!(instance.config.bind_address, "foo");
    }

    /// A client with an authorized key completes the handshake, and the
    /// returned data mirrors what the client put in its preamble.
    #[tokio::test]
    async fn test_handle_handshake_success() {
        let mut config = generate_test_server_config();
        let (mut client, mut server) = generate_test_duplex();
        let client_config = generate_test_client_config();

        // Authorize the test client's key for this run.
        config.authorized_keys.push(client_config.public_key.clone());

        let client_challenge: Challenge = [8u8; 32];
        let client_preamble = ClientPreamble {
            remote: "localhost".to_string(),
            challenge: client_challenge,
            exchange_public_key: generate_test_fake_exchange_public_key(),
            should_encrypt: true,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(client_preamble.clone())).await.unwrap();

        // The client half runs concurrently with the server-side handshake.
        let client_handle = tokio::spawn(async move {
            let message = client.pull().await.unwrap();
            let server_challenge = match message {
                ProtocolMessage::ServerPreamble(preamble) => preamble.challenge,
                _ => panic!("Expected ServerPreamble message, got: {:?}", message),
            };

            let signature = sign_challenge(&server_challenge, &client_config.private_key).unwrap();
            let client_auth = ClientAuthentication {
                identity_public_key: client_config.public_key,
                signature: signature.into(),
            };

            client.push(ProtocolMessage::ClientAuthentication(client_auth)).await.unwrap();

            let message = client.pull().await.unwrap();
            assert!(matches!(message, ProtocolMessage::HandshakeCompletion));
        });

        let result = handle_handshake(&mut server, &config).await;

        client_handle.await.unwrap();

        assert!(result.is_ok());
        let key_data = result.unwrap();

        assert_eq!(key_data.client_exchange_public_key, client_preamble.exchange_public_key);
        assert_eq!(key_data.client_challenge, client_challenge);
        assert_eq!(key_data.requested_remote_address, "localhost");
        assert!(key_data.requested_should_encrypt);
        // The preamble asked for TCP, so the UDP flag must come back unset.
        assert!(!key_data.requested_is_udp);
    }

    /// A first message that is not a client preamble is answered with an `Unknown` error.
    #[tokio::test]
    async fn test_handle_handshake_invalid_start() {
        let config = generate_test_server_config();
        let (mut client, mut server) = generate_test_duplex();

        client.push(ProtocolMessage::HandshakeCompletion).await.unwrap();

        let result = handle_handshake(&mut server, &config).await;

        assert!(result.is_err());

        let message = client.pull().await.unwrap();
        if let ProtocolMessage::Error(error) = message {
            assert_eq!(error, ProtocolError::Unknown("Invalid handshake start".into()));
        } else {
            panic!("Expected error message, got: {:?}", message);
        }
    }

    /// A requested remote that does not match the configured pattern is
    /// rejected with an `InvalidHost` error.
    #[tokio::test]
    async fn test_handle_handshake_invalid_host() {
        let mut config = generate_test_server_config();
        config.remote_regex = Regex::new("^only-this-host$").unwrap();
        let (mut client, mut server) = generate_test_duplex();

        let client_preamble = ClientPreamble {
            remote: "different-host".to_string(),
            challenge: [9u8; 32],
            exchange_public_key: generate_test_fake_exchange_public_key(),
            should_encrypt: false,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(client_preamble)).await.unwrap();

        let result = handle_handshake(&mut server, &config).await;

        assert!(result.is_err());

        let message = client.pull().await.unwrap();
        if let ProtocolMessage::Error(error) = message {
            assert!(matches!(error, ProtocolError::InvalidHost(_)));
        } else {
            panic!("Expected error message, got: {:?}", message);
        }
    }

    /// A correctly signed challenge from a key that was never authorized fails the handshake.
    #[tokio::test]
    async fn test_handle_handshake_unauthorized_key() {
        let config = generate_test_server_config();
        let (mut client, mut server) = generate_test_duplex();

        let client_preamble = ClientPreamble {
            remote: "localhost".to_string(),
            challenge: [10u8; 32],
            exchange_public_key: generate_test_fake_exchange_public_key(),
            should_encrypt: false,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(client_preamble)).await.unwrap();

        // A freshly generated key pair that is not added to the server's authorized keys.
        let unauthorized_key_pair = generate_key_pair().unwrap();
        let unauthorized_private_key = unauthorized_key_pair.private_key.into();

        let client_handle = tokio::spawn(async move {
            let message = client.pull().await.unwrap();
            let server_challenge = match message {
                ProtocolMessage::ServerPreamble(preamble) => preamble.challenge,
                _ => panic!("Expected ServerPreamble message, got: {:?}", message),
            };

            let signature = sign_challenge(&server_challenge, &unauthorized_private_key).unwrap();
            let client_auth = ClientAuthentication {
                identity_public_key: unauthorized_key_pair.public_key,
                signature: signature.into(),
            };

            client.push(ProtocolMessage::ClientAuthentication(client_auth)).await.unwrap();

            let message = client.pull().await.unwrap();
            assert!(matches!(message, ProtocolMessage::Error(_)));
        });

        let result = handle_handshake(&mut server, &config).await;

        client_handle.await.unwrap();

        assert!(result.is_err());
    }
}