1use std::{marker::PhantomData, sync::Arc};
4
5use anyhow::Context;
6use regex::Regex;
7use secrecy::SecretString;
8use tokio::{
9 net::{TcpListener, TcpStream, UdpSocket},
10 select,
11 task::JoinHandle,
12 time::Instant
13};
14use tracing::{Instrument, error, info, info_span};
15
16use crate::{
17 base::{Constant, ExchangeKeyPair, Res, ServerKeyExchangeData, Void},
18 buffed_stream::BuffedTcpStream,
19 protocol::{BincodeReceive, BincodeSend, Challenge, ClientPreamble, ProtocolError, ProtocolMessage, ServerPreamble, Signature},
20 security::{resolve_authorized_keys, resolve_keypath, resolve_private_key, resolve_public_key},
21 utils::{generate_challenge, generate_ephemeral_key_pair, generate_shared_secret, handle_pump, random_string, sign_challenge, validate_signed_challenge},
22};
23
/// Type-state marker for an [`Instance`] whose configuration has not been resolved yet.
///
/// `prepare` is defined on `Instance<ConfigState>` and produces an `Instance<ReadyState>`.
pub struct ConfigState;
/// Type-state marker for an [`Instance`] whose keys and settings have been resolved.
///
/// `start` is defined on `Instance<ReadyState>`.
pub struct ReadyState;

/// A server instance whose available operations are selected by the type-state `S`
/// (either [`ConfigState`] or [`ReadyState`]); defaults to [`ConfigState`].
pub struct Instance<S = ConfigState> {
    /// Resolved server configuration (identity keys, authorized keys, bind address, remote filter).
    config: Config,
    /// Zero-sized marker that only carries the type-state `S`.
    _phantom: PhantomData<S>,
}
38
39impl Instance<ConfigState> {
40 pub fn prepare<A, B, C>(key_path: A, remote_regex: B, bind_address: C) -> Res<Instance<ReadyState>>
42 where
43 A: Into<Option<String>>,
44 B: AsRef<str>,
45 C: Into<String>,
46 {
47 let remote_regex = Regex::new(remote_regex.as_ref()).context("Invalid regex for remote host.")?;
48
49 let key_path = resolve_keypath(key_path)?;
50 let private_key = resolve_private_key(&key_path)?;
51 let public_key = resolve_public_key(&key_path)?;
52 let authorized_keys = resolve_authorized_keys(&key_path);
53
54 let config = Config::new(public_key, private_key, authorized_keys, bind_address.into(), remote_regex);
55
56 Ok(Instance { config, _phantom: PhantomData })
57 }
58}
59
60impl Instance<ReadyState> {
61 pub async fn start(self) -> Void {
63 info!("🚀 Starting server on `{}` ...", self.config.bind_address);
64
65 run_tcp_server(self.config.clone()).await?;
66
67 Ok(())
68 }
69}
70
71async fn verify_client_preamble<T>(stream: &mut T, config: &Config, preamble: &ClientPreamble) -> Res<Signature>
77where
78 T: BincodeSend,
79{
80 if !config.remote_regex.is_match(&preamble.remote) {
83 return ProtocolError::InvalidHost(format!("Invalid host from client (supplied `{}`, but need to satisfy `{}`)", preamble.remote, config.remote_regex))
84 .send_and_bail(stream)
85 .await;
86 }
87
88 let signature = sign_challenge(&preamble.challenge, &config.private_key)?;
91
92 Ok(signature)
93}
94
95async fn send_server_preamble<T>(stream: &mut T, config: &Config, server_signature: &Signature, server_challenge: &Challenge) -> Res<ExchangeKeyPair>
96where
97 T: BincodeSend,
98{
99 info!("🚧 Sending handshake challenge to client ...");
100
101 let exchange_key_pair = generate_ephemeral_key_pair()?;
102 let exchange_public_key = exchange_key_pair.public_key.as_ref().try_into()?;
103
104 let preamble = ServerPreamble {
105 challenge: *server_challenge,
106 exchange_public_key,
107 identity_public_key: config.public_key.clone(),
108 signature: server_signature.into(),
109 };
110
111 stream.push(ProtocolMessage::ServerPreamble(preamble)).await?;
112
113 Ok(exchange_key_pair)
114}
115
116async fn handle_and_validate_key_challenge<T>(stream: &mut T, config: &Config, server_challenge: &Challenge) -> Void
121where
122 T: BincodeSend + BincodeReceive,
123{
124 let ProtocolMessage::ClientAuthentication(client_authentication) = stream.pull().await?.fail_if_error()? else {
127 return ProtocolError::InvalidKey("Invalid handshake response".into()).send_and_bail(stream).await;
128 };
129
130 if validate_signed_challenge(server_challenge, &client_authentication.signature.into(), &client_authentication.identity_public_key).is_err() {
133 return ProtocolError::InvalidKey("Invalid challenge signature from client".into()).send_and_bail(stream).await;
134 }
135
136 if !config.authorized_keys.contains(&client_authentication.identity_public_key) {
138 return ProtocolError::InvalidKey("Unauthorized key from client".into()).send_and_bail(stream).await;
139 }
140
141 info!("✅ Handshake challenge completed!");
142
143 Ok(())
144}
145
146async fn complete_handshake<T>(stream: &mut T) -> Void
151where
152 T: BincodeSend,
153{
154 let completion = ProtocolMessage::HandshakeCompletion;
155
156 stream.push(completion).await?;
157
158 info!("✅ Handshake completed.");
159
160 Ok(())
161}
162
163async fn handle_handshake<T>(stream: &mut T, config: &Config) -> Res<ServerKeyExchangeData>
168where
169 T: BincodeReceive + BincodeSend,
170{
171 let ProtocolMessage::ClientPreamble(preamble) = stream.pull().await? else {
174 return ProtocolError::Unknown("Invalid handshake start".into()).send_and_bail(stream).await;
175 };
176
177 let server_signature = verify_client_preamble(stream, config, &preamble).await?;
180
181 let server_challenge = generate_challenge();
184
185 let local_exchange_key_pair = send_server_preamble(stream, config, &server_signature, &server_challenge).await?;
188
189 handle_and_validate_key_challenge(stream, config, &server_challenge).await?;
192
193 complete_handshake(stream).await?;
196
197 Ok(ServerKeyExchangeData {
198 client_exchange_public_key: preamble.exchange_public_key,
199 client_challenge: preamble.challenge,
200 local_exchange_private_key: local_exchange_key_pair.private_key,
201 local_challenge: server_challenge,
202 requested_remote_address: preamble.remote,
203 requested_should_encrypt: preamble.should_encrypt,
204 requested_is_udp: preamble.is_udp,
205 })
206}
207
208async fn run_tcp_pump(mut client: BuffedTcpStream, remote_address: &str) -> Void {
210 let Ok(mut remote) = TcpStream::connect(remote_address).await.context("Error connecting to remote") else {
211 return ProtocolError::RemoteFailed(format!("Failed to connect to remote `{}`", remote_address))
212 .send_and_bail(&mut client)
213 .await;
214 };
215
216 info!("✅ Connected to remote server `{}`.", remote_address);
217
218 handle_pump(&mut client, &mut remote).await.context("Error handling TCP pump.")?;
219
220 Ok(())
221}
222
223async fn run_udp_pump(mut client: BuffedTcpStream, remote_address: &str) -> Void {
225 let remote = UdpSocket::bind("127.0.0.1:0").await.context("Error binding UDP socket")?;
228 if remote.connect(remote_address).await.is_err() {
229 return ProtocolError::RemoteFailed(format!("Failed to connect to remote `{}`", remote_address))
230 .send_and_bail(&mut client)
231 .await;
232 }
233
234 info!("✅ Connected to remote server `{}`.", remote_address);
235
236 let (mut client_read, mut client_write) = client.into_split();
238
239 let remote_up = Arc::new(remote);
241 let remote_down = remote_up.clone();
242
243 let last_activity = Arc::new(tokio::sync::Mutex::new(Instant::now()));
246 let last_activity_up = last_activity.clone();
247 let last_activity_down = last_activity.clone();
248
249 let pump_up: JoinHandle<Void> = tokio::spawn(async move {
250 while let ProtocolMessage::UdpData(data) = client_read.pull().await? {
251 remote_up.send(&data).await?;
252 *last_activity_up.lock().await = Instant::now();
253 }
254
255 Ok(())
256 });
257
258 let pump_down: JoinHandle<Void> = tokio::spawn(async move {
259 let mut buf = [0; Constant::BUFFER_SIZE];
260
261 loop {
262 let size = remote_down.recv(&mut buf).await?;
263 client_write.push(ProtocolMessage::UdpData(buf[..size].to_vec())).await?;
264 *last_activity_down.lock().await = Instant::now();
265 }
266 });
267
268 let timeout: JoinHandle<Void> = tokio::spawn(async move {
269 loop {
270 let last_activity = *last_activity.lock().await;
271
272 if last_activity.elapsed() > Constant::UDP_TIMEOUT {
273 info!("✅ UDP connection timed out (assumed graceful close).");
274 return Ok(());
275 }
276
277 tokio::time::sleep(Constant::UDP_TIMEOUT).await;
278 }
279 });
280
281 let result = select! {
288 r = pump_up => r?,
289 r = pump_down => r?,
290 r = timeout => r?,
291 };
292
293 result?;
296
297 Ok(())
298}
299
300async fn run_tcp_server(config: Config) -> Void {
304 let listener = TcpListener::bind(&config.bind_address).await?;
305
306 loop {
307 let (socket, _) = listener.accept().await?;
308
309 tokio::spawn(handle_connection(socket, config.clone()));
310 }
311}
312
313async fn handle_connection(client: TcpStream, config: Config) {
318 let mut client = BuffedTcpStream::from(client);
319
320 let id = random_string(6);
321 let span = info_span!("conn", id = id);
322
323 let result: Void = async move {
324 let peer_addr = client.as_inner_tcp_write_ref().peer_addr().context("Error getting peer address")?;
325
326 info!("✅ Accepted connection from `{}`.", peer_addr);
327
328 let handshake_data = handle_handshake(&mut client, &config).await.context("Error handling handshake")?;
331
332 if handshake_data.requested_should_encrypt {
334 let private_key = handshake_data.local_exchange_private_key;
335 let salt_bytes = [handshake_data.local_challenge, handshake_data.client_challenge].concat();
336 let shared_secret = generate_shared_secret(private_key, &handshake_data.client_exchange_public_key, &salt_bytes)?;
337
338 client = client.with_encryption(shared_secret);
339 info!("🔒 Encryption applied ...");
340 }
341
342 info!("⛽ Pumping data between client and remote ...");
345
346 if handshake_data.requested_is_udp {
347 run_udp_pump(client, &handshake_data.requested_remote_address).await?;
348 } else {
349 run_tcp_pump(client, &handshake_data.requested_remote_address).await?;
350 }
351
352 info!("✅ Connection closed.");
353
354 Ok(())
355 }
356 .instrument(span.clone())
357 .await;
358
359 let _guard = span.enter();
361
362 if let Err(err) = result {
363 let chain = err.chain().collect::<Vec<_>>();
364 let full_chain = chain.iter().map(|e| format!("`{}`", e)).collect::<Vec<_>>().join(" => ");
365
366 error!("❌ Error handling connection: {}.", full_chain);
367 }
368}
369
/// Server runtime configuration; `run_tcp_server` clones it into every connection handler.
#[derive(Clone)]
pub(crate) struct Config {
    /// Server identity public key; sent to clients as `identity_public_key` in the `ServerPreamble`.
    pub(crate) public_key: String,
    /// Server identity private key; used to sign the challenge the client sends in its preamble.
    pub(crate) private_key: SecretString,
    /// Client identity public keys allowed to complete the handshake (checked with `contains`).
    pub(crate) authorized_keys: Vec<String>,
    /// Address the TCP listener binds to.
    pub(crate) bind_address: String,
    /// Filter for the client's requested remote, applied with `Regex::is_match`.
    /// `is_match` matches anywhere in the string unless the pattern is anchored with `^`/`$`.
    pub(crate) remote_regex: Regex,
}
383
384impl Config {
385 fn new(public_key: String, private_key: SecretString, authorized_keys: Vec<String>, bind_address: String, remote_regex: Regex) -> Self {
387 Self {
388 public_key,
389 private_key,
390 authorized_keys,
391 bind_address,
392 remote_regex,
393 }
394 }
395}
396
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use crate::{
        connect::tests::generate_test_client_config,
        protocol::ClientAuthentication,
        utils::{
            generate_key_pair, sign_challenge,
            tests::{generate_test_duplex, generate_test_fake_exchange_public_key},
        },
    };

    use super::*;

    /// Builds a server `Config` from the fixture keys under `test/server` that accepts any remote.
    pub(crate) fn generate_test_server_config() -> Config {
        let key_path = "test/server";

        let public_key = resolve_public_key(key_path).unwrap();
        let private_key = resolve_private_key(key_path).unwrap();
        let authorized_keys = resolve_authorized_keys(key_path);
        let remote_regex = Regex::new(".*").unwrap();

        Config {
            public_key,
            private_key,
            authorized_keys,
            bind_address: "bind_address".to_string(),
            remote_regex,
        }
    }

    #[test]
    fn test_prepare_config() {
        let instance = Instance::prepare("test/server".to_string(), ".*", "foo").unwrap();

        assert_eq!(instance.config.public_key, "HQYY0BNIhdawY2Jw62DudkUsK2GKj3hGO3qSVBlCinI");
        assert_eq!(instance.config.remote_regex.as_str(), ".*");
        assert_eq!(instance.config.bind_address, "foo");
    }

    #[tokio::test]
    async fn test_handle_handshake_success() {
        let mut config = generate_test_server_config();
        let (mut client, mut server) = generate_test_duplex();
        let client_config = generate_test_client_config();

        config.authorized_keys.push(client_config.public_key.clone());

        let client_challenge: Challenge = [8u8; 32];
        let client_preamble = ClientPreamble {
            remote: "localhost".to_string(),
            challenge: client_challenge,
            exchange_public_key: generate_test_fake_exchange_public_key(),
            should_encrypt: true,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(client_preamble.clone())).await.unwrap();

        let client_handle = tokio::spawn(async move {
            let message = client.pull().await.unwrap();
            let server_challenge = match message {
                ProtocolMessage::ServerPreamble(preamble) => preamble.challenge,
                _ => panic!("Expected ServerPreamble message, got: {:?}", message),
            };

            let signature = sign_challenge(&server_challenge, &client_config.private_key).unwrap();
            let client_auth = ClientAuthentication {
                identity_public_key: client_config.public_key,
                signature: signature.into(),
            };

            client.push(ProtocolMessage::ClientAuthentication(client_auth)).await.unwrap();

            let message = client.pull().await.unwrap();
            assert!(matches!(message, ProtocolMessage::HandshakeCompletion), "Expected HandshakeCompletion, got: {:?}", message);
        });

        let result = handle_handshake(&mut server, &config).await;

        client_handle.await.unwrap();

        // `expect` prints the handshake error on failure; `assert!(is_ok())` would hide it.
        let key_data = result.expect("handshake should succeed");

        assert_eq!(key_data.client_exchange_public_key, client_preamble.exchange_public_key);
        assert_eq!(key_data.client_challenge, client_challenge);
        assert_eq!(key_data.requested_remote_address, "localhost");
        assert!(key_data.requested_should_encrypt);
        assert!(!key_data.requested_is_udp);
    }

    #[tokio::test]
    async fn test_handle_handshake_invalid_start() {
        let config = generate_test_server_config();
        let (mut client, mut server) = generate_test_duplex();

        client.push(ProtocolMessage::HandshakeCompletion).await.unwrap();

        let result = handle_handshake(&mut server, &config).await;

        assert!(result.is_err());

        let message = client.pull().await.unwrap();
        if let ProtocolMessage::Error(error) = message {
            assert_eq!(error, ProtocolError::Unknown("Invalid handshake start".into()));
        } else {
            panic!("Expected error message, got: {:?}", message);
        }
    }

    #[tokio::test]
    async fn test_handle_handshake_invalid_host() {
        let mut config = generate_test_server_config();
        config.remote_regex = Regex::new("^only-this-host$").unwrap();
        let (mut client, mut server) = generate_test_duplex();

        let client_preamble = ClientPreamble {
            remote: "different-host".to_string(),
            challenge: [9u8; 32],
            exchange_public_key: generate_test_fake_exchange_public_key(),
            should_encrypt: false,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(client_preamble)).await.unwrap();

        let result = handle_handshake(&mut server, &config).await;

        assert!(result.is_err());

        let message = client.pull().await.unwrap();
        if let ProtocolMessage::Error(error) = message {
            assert!(matches!(error, ProtocolError::InvalidHost(_)));
        } else {
            panic!("Expected error message, got: {:?}", message);
        }
    }

    #[tokio::test]
    async fn test_handle_handshake_unauthorized_key() {
        let config = generate_test_server_config();
        let (mut client, mut server) = generate_test_duplex();

        let client_preamble = ClientPreamble {
            remote: "localhost".to_string(),
            challenge: [10u8; 32],
            exchange_public_key: generate_test_fake_exchange_public_key(),
            should_encrypt: false,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(client_preamble)).await.unwrap();

        let unauthorized_key_pair = generate_key_pair().unwrap();
        let unauthorized_private_key = unauthorized_key_pair.private_key.into();

        let client_handle = tokio::spawn(async move {
            let message = client.pull().await.unwrap();
            let server_challenge = match message {
                ProtocolMessage::ServerPreamble(preamble) => preamble.challenge,
                _ => panic!("Expected ServerPreamble message, got: {:?}", message),
            };

            let signature = sign_challenge(&server_challenge, &unauthorized_private_key).unwrap();
            let client_auth = ClientAuthentication {
                identity_public_key: unauthorized_key_pair.public_key,
                signature: signature.into(),
            };

            client.push(ProtocolMessage::ClientAuthentication(client_auth)).await.unwrap();

            // The signature is valid for the presented key, so the rejection must come from
            // the authorized-keys check specifically, not from some other failure.
            let message = client.pull().await.unwrap();
            match message {
                ProtocolMessage::Error(error) => {
                    assert_eq!(error, ProtocolError::InvalidKey("Unauthorized key from client".into()));
                }
                other => panic!("Expected error message, got: {:?}", other),
            }
        });

        let result = handle_handshake(&mut server, &config).await;

        client_handle.await.unwrap();

        assert!(result.is_err());
    }
}