1use std::marker::PhantomData;
4
5use anyhow::Context;
6use regex::Regex;
7use secrecy::SecretString;
8use tokio::net::{TcpListener, TcpStream, UdpSocket};
9use tracing::{Instrument, error, info, info_span};
10
11use crate::{
12 base::{ExchangeKeyPair, Res, ServerKeyExchangeData, Void},
13 buffed_stream::{BincodeSplit, BuffedTcpStream},
14 protocol::{BincodeReceive, BincodeSend, Challenge, ClientPreamble, ProtocolError, ProtocolMessage, ServerPreamble, Signature},
15 security::{resolve_authorized_keys, resolve_keypath, resolve_private_key, resolve_public_key},
16 utils::{generate_challenge, generate_ephemeral_key_pair, generate_shared_secret, handle_tcp_pump, handle_udp_pump, random_string, sign_challenge, validate_signed_challenge},
17};
18
/// Type-state marker for an unprepared [`Instance`]: `Instance<ConfigState>` only exposes
/// `prepare`, which produces an `Instance<ReadyState>`.
pub struct ConfigState;
/// Type-state marker for a configured [`Instance`]: `Instance<ReadyState>` exposes `start`.
pub struct ReadyState;
25
/// A server instance whose available operations depend on the type-state `S`
/// ([`ConfigState`] or [`ReadyState`]).
pub struct Instance<S = ConfigState> {
    /// Fully resolved server configuration.
    config: Config,
    /// Zero-sized marker carrying the type-state; holds no data.
    _phantom: PhantomData<S>,
}
33
34impl Instance<ConfigState> {
35 pub fn prepare<A, B, C>(key_path: A, remote_regex: B, bind_address: C) -> Res<Instance<ReadyState>>
37 where
38 A: Into<Option<String>>,
39 B: AsRef<str>,
40 C: Into<String>,
41 {
42 let remote_regex = Regex::new(remote_regex.as_ref()).context("Invalid regex for remote host.")?;
43
44 let key_path = resolve_keypath(key_path)?;
45 let private_key = resolve_private_key(&key_path)?;
46 let public_key = resolve_public_key(&key_path)?;
47 let authorized_keys = resolve_authorized_keys(&key_path);
48
49 let config = Config::new(public_key, private_key, authorized_keys, bind_address.into(), remote_regex);
50
51 Ok(Instance { config, _phantom: PhantomData })
52 }
53}
54
55impl Instance<ReadyState> {
56 pub async fn start(self) -> Void {
58 info!("🚀 Starting server on `{}` ...", self.config.bind_address);
59
60 run_tcp_server(self.config.clone()).await?;
61
62 Ok(())
63 }
64}
65
66async fn verify_client_preamble<T>(stream: &mut T, config: &Config, preamble: &ClientPreamble<'_>) -> Res<Signature>
72where
73 T: BincodeSend,
74{
75 if !config.remote_regex.is_match(preamble.remote) {
78 return ProtocolError::InvalidHost(&format!("Invalid host from client (supplied `{}`, but need to satisfy `{}`)", preamble.remote, config.remote_regex))
79 .send_and_bail(stream)
80 .await;
81 }
82
83 let signature = sign_challenge(preamble.challenge, &config.private_key)?;
86
87 Ok(signature)
88}
89
90async fn send_server_preamble<T>(stream: &mut T, config: &Config, server_signature: &Signature, server_challenge: &Challenge) -> Res<ExchangeKeyPair>
92where
93 T: BincodeSend,
94{
95 info!("🚧 Sending handshake challenge to client ...");
96
97 let exchange_key_pair = generate_ephemeral_key_pair()?;
98 let exchange_public_key = exchange_key_pair.public_key.as_ref();
99
100 let preamble = ServerPreamble {
101 challenge: server_challenge,
102 exchange_public_key,
103 identity_public_key: &config.public_key,
104 signature: server_signature,
105 };
106
107 stream.push(ProtocolMessage::ServerPreamble(preamble)).await?;
108
109 Ok(exchange_key_pair)
110}
111
112async fn handle_and_validate_key_challenge<T>(stream: &mut T, config: &Config, server_challenge: &Challenge) -> Void
117where
118 T: BincodeSend + BincodeReceive,
119{
120 let guard = stream.pull().await?;
123 let ProtocolMessage::ClientAuthentication(client_authentication) = guard.message().fail_if_error()? else {
124 return ProtocolError::InvalidKey("Invalid handshake response").send_and_bail(stream).await;
125 };
126
127 if validate_signed_challenge(server_challenge, client_authentication.signature, client_authentication.identity_public_key).is_err() {
130 return ProtocolError::InvalidKey("Invalid challenge signature from client").send_and_bail(stream).await;
131 }
132
133 if !config.authorized_keys.iter().any(|k| k == client_authentication.identity_public_key) {
135 return ProtocolError::InvalidKey("Unauthorized key from client").send_and_bail(stream).await;
136 }
137
138 info!("✅ Handshake challenge completed!");
139
140 Ok(())
141}
142
143async fn complete_handshake<T>(stream: &mut T) -> Void
148where
149 T: BincodeSend,
150{
151 let completion = ProtocolMessage::HandshakeCompletion;
152
153 stream.push(completion).await?;
154
155 info!("✅ Handshake completed.");
156
157 Ok(())
158}
159
/// Runs the server side of the handshake on `stream`. Returns the material the caller
/// needs to set up the tunnel: exchange keys, both challenges, and the requested
/// remote and options.
///
/// Steps, in order:
/// 1. Read the first message. It must be a `ClientPreamble`; otherwise
///    `ProtocolError::Unknown("Invalid handshake start")` is sent and the call fails.
/// 2. `verify_client_preamble`: check the requested remote against
///    `config.remote_regex` and sign the client's challenge.
/// 3. Generate a server challenge and send the `ServerPreamble`
///    (`send_server_preamble`), which also creates the ephemeral exchange key pair.
/// 4. `handle_and_validate_key_challenge`: read and verify the client's authentication.
/// 5. `complete_handshake`: send `HandshakeCompletion`.
///
/// # Errors
/// Returns an error if any step fails. For the rejections described above, the client
/// is first sent a `ProtocolError`.
async fn handle_handshake<T>(stream: &mut T, config: &Config) -> Res<ServerKeyExchangeData>
where
    T: BincodeSplit + BincodeReceive + BincodeSend,
{
    // Split the stream so `preamble` (borrowed from a message read via `read`) can stay
    // in use while replies are written through `write`.
    let (read, write) = stream.split();

    let guard = read.pull().await?;

    // NOTE(review): unlike `handle_and_validate_key_challenge`, this does not call
    // `fail_if_error()`. An `Error` message from the client is therefore answered with
    // "Invalid handshake start"; confirm this is intended.
    let ProtocolMessage::ClientPreamble(preamble) = guard.message() else {
        return ProtocolError::Unknown("Invalid handshake start").send_and_bail(stream).await;
    };

    // Extract what the result needs from the borrowed preamble.
    // NOTE(review): a failing `try_into()` (presumably a length mismatch) returns via `?`
    // without sending a `ProtocolError` to the client.
    let client_exchange_public_key = preamble.exchange_public_key.try_into()?;
    let client_challenge = preamble.challenge.try_into()?;
    let requested_remote_address = preamble.remote.into();
    let requested_should_encrypt = preamble.should_encrypt;
    let requested_is_udp = preamble.is_udp;

    // Rejects disallowed remotes; on success yields our signature over the client's challenge.
    let server_signature = verify_client_preamble(write, config, preamble).await?;

    // Challenge generated for this handshake; the client must sign it (checked in step 4).
    let server_challenge = generate_challenge();

    let local_exchange_key_pair = send_server_preamble(write, config, &server_signature, &server_challenge).await?;

    handle_and_validate_key_challenge(stream, config, &server_challenge).await?;

    complete_handshake(stream).await?;

    Ok(ServerKeyExchangeData {
        client_exchange_public_key,
        client_challenge,
        local_exchange_private_key: local_exchange_key_pair.private_key,
        local_challenge: server_challenge,
        requested_remote_address,
        requested_should_encrypt,
        requested_is_udp,
    })
}
215
216async fn run_tcp_pump(mut client: BuffedTcpStream, remote_address: &str) -> Void {
218 let Ok(remote) = TcpStream::connect(remote_address).await.context("Error connecting to remote") else {
219 return ProtocolError::RemoteFailed(&format!("Failed to connect to remote `{}`", remote_address))
220 .send_and_bail(&mut client)
221 .await;
222 };
223
224 remote.set_nodelay(true)?;
225
226 info!("✅ Connected to remote server `{}`.", remote_address);
227
228 handle_tcp_pump(remote, client).await.context("Error handling TCP pump.")?;
229
230 Ok(())
231}
232
233async fn run_udp_pump(mut client: BuffedTcpStream, remote_address: &str) -> Void {
235 let remote = UdpSocket::bind("127.0.0.1:0").await.context("Error binding UDP socket")?;
236 if remote.connect(remote_address).await.is_err() {
237 return ProtocolError::RemoteFailed(&format!("Failed to connect to remote `{}`", remote_address))
238 .send_and_bail(&mut client)
239 .await;
240 }
241
242 info!("✅ Connected to remote server `{}`.", remote_address);
243
244 handle_udp_pump(remote, client).await.context("Error handling UDP pump.")?;
245
246 Ok(())
247}
248
249async fn run_tcp_server(config: Config) -> Void {
253 let listener = TcpListener::bind(&config.bind_address).await?;
254
255 loop {
256 let (socket, _) = listener.accept().await?;
257
258 tokio::spawn(handle_connection(socket, config.clone()));
259 }
260}
261
262async fn handle_connection(client: TcpStream, config: Config) {
267 let id = random_string(6);
268 let span = info_span!("conn", id = id);
269
270 let result: Void = async move {
271 client.set_nodelay(true)?;
272 let peer_addr = client.peer_addr().context("Error getting peer address")?;
273
274 let mut client = BuffedTcpStream::from(client);
275
276 info!("✅ Accepted connection from `{}`.", peer_addr);
277
278 let handshake_data = handle_handshake(&mut client, &config).await.context("Error handling handshake")?;
281
282 if handshake_data.requested_should_encrypt {
284 let private_key = handshake_data.local_exchange_private_key;
285 let salt_bytes = [handshake_data.local_challenge, handshake_data.client_challenge].concat();
286 let shared_secret = generate_shared_secret(private_key, &handshake_data.client_exchange_public_key, &salt_bytes)?;
287
288 client = client.with_encryption(shared_secret);
289 info!("🔒 Encryption applied ...");
290 }
291
292 info!("⛽ Pumping data between client and remote ...");
295
296 if handshake_data.requested_is_udp {
297 run_udp_pump(client, &handshake_data.requested_remote_address).await?;
298 } else {
299 run_tcp_pump(client, &handshake_data.requested_remote_address).await?;
300 }
301
302 info!("✅ Connection closed.");
303
304 Ok(())
305 }
306 .instrument(span.clone())
307 .await;
308
309 let _guard = span.enter();
311
312 if let Err(err) = result {
313 let chain = err.chain().collect::<Vec<_>>();
314 let full_chain = chain.iter().map(|e| format!("`{}`", e)).collect::<Vec<_>>().join(" => ");
315
316 error!("❌ Error handling connection: {}.", full_chain);
317 }
318}
319
/// Server configuration shared by all connections (each accepted connection gets its own clone).
#[derive(Clone)]
pub(crate) struct Config {
    /// Server identity public key, sent to clients in the `ServerPreamble`.
    pub(crate) public_key: String,
    /// Server identity private key, used to sign client challenges.
    pub(crate) private_key: SecretString,
    /// Client identity public keys allowed to complete the handshake.
    pub(crate) authorized_keys: Vec<String>,
    /// Address the TCP listener binds to.
    pub(crate) bind_address: String,
    /// Pattern every client-requested remote must match. It is used with
    /// `Regex::is_match`, so it is unanchored unless the pattern itself is anchored.
    pub(crate) remote_regex: Regex,
}
333
334impl Config {
335 fn new(public_key: String, private_key: SecretString, authorized_keys: Vec<String>, bind_address: String, remote_regex: Regex) -> Self {
337 Self {
338 public_key,
339 private_key,
340 authorized_keys,
341 bind_address,
342 remote_regex,
343 }
344 }
345}
346
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use crate::{
        connect::tests::generate_test_client_config,
        protocol::ClientAuthentication,
        utils::{
            generate_key_pair, sign_challenge,
            tests::{generate_test_duplex, generate_test_fake_exchange_public_key},
        },
    };

    use super::*;

    /// Builds a server `Config` from the `test/server` key fixtures that accepts any
    /// remote host (`.*`).
    pub(crate) fn generate_test_server_config() -> Config {
        let key_path = "test/server";

        let public_key = resolve_public_key(key_path).unwrap();
        let private_key = resolve_private_key(key_path).unwrap();
        let authorized_keys = resolve_authorized_keys(key_path);
        let remote_regex = Regex::new(".*").unwrap();

        Config {
            public_key,
            private_key,
            authorized_keys,
            bind_address: "bind_address".to_string(),
            remote_regex,
        }
    }

    #[test]
    fn test_prepare_config() {
        let instance = Instance::prepare("test/server".to_string(), ".*", "foo").unwrap();

        assert_eq!(instance.config.public_key, "HQYY0BNIhdawY2Jw62DudkUsK2GKj3hGO3qSVBlCinI");
        assert_eq!(instance.config.remote_regex.as_str(), ".*");
        assert_eq!(instance.config.bind_address, "foo");
    }

    /// Happy path: a client whose identity key is authorized completes the handshake.
    #[tokio::test]
    async fn test_handle_handshake_success() {
        let mut config = generate_test_server_config();
        let (mut client, mut server) = generate_test_duplex();
        let client_config = generate_test_client_config();

        // Authorize the test client's identity key.
        config.authorized_keys.push(client_config.public_key.clone());

        let client_challenge: Challenge = [8u8; 32];
        let client_preamble = ClientPreamble {
            remote: "localhost",
            challenge: &client_challenge,
            exchange_public_key: &generate_test_fake_exchange_public_key(),
            should_encrypt: true,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(client_preamble.clone())).await.unwrap();

        // Play the client side of the handshake concurrently with the server.
        let client_handle = tokio::spawn(async move {
            let guard = client.pull().await.unwrap();
            let server_challenge = match guard.message() {
                ProtocolMessage::ServerPreamble(preamble) => preamble.challenge,
                _ => panic!("Expected ServerPreamble message, got: {:?}", guard.message()),
            };

            let signature = sign_challenge(server_challenge, &client_config.private_key).unwrap();
            let client_auth = ClientAuthentication {
                identity_public_key: &client_config.public_key,
                signature: &signature,
            };

            client.push(ProtocolMessage::ClientAuthentication(client_auth)).await.unwrap();

            let guard = client.pull().await.unwrap();
            assert!(matches!(guard.message(), ProtocolMessage::HandshakeCompletion));
        });

        let result = handle_handshake(&mut server, &config).await;

        client_handle.await.unwrap();

        // `expect` reports the handshake error itself instead of a bare `is_ok()` failure.
        let key_data = result.expect("handshake should succeed");

        assert_eq!(key_data.client_exchange_public_key, client_preamble.exchange_public_key);
        assert_eq!(key_data.client_challenge, client_challenge);
        assert_eq!(key_data.requested_remote_address, "localhost");
        assert!(key_data.requested_should_encrypt);
        assert!(!key_data.requested_is_udp);
    }

    /// A first message other than `ClientPreamble` is rejected with `Unknown`.
    #[tokio::test]
    async fn test_handle_handshake_invalid_start() {
        let config = generate_test_server_config();
        let (mut client, mut server) = generate_test_duplex();

        client.push(ProtocolMessage::HandshakeCompletion).await.unwrap();

        let result = handle_handshake(&mut server, &config).await;

        assert!(result.is_err());

        let guard = client.pull().await.unwrap();
        if let ProtocolMessage::Error(error) = guard.message() {
            assert_eq!(error, &ProtocolError::Unknown("Invalid handshake start"));
        } else {
            panic!("Expected error message, got: {:?}", guard.message());
        }
    }

    /// A requested remote that does not match `remote_regex` is rejected with `InvalidHost`.
    #[tokio::test]
    async fn test_handle_handshake_invalid_host() {
        let mut config = generate_test_server_config();
        config.remote_regex = Regex::new("^only-this-host$").unwrap();
        let (mut client, mut server) = generate_test_duplex();

        let client_preamble = ClientPreamble {
            remote: "different-host",
            challenge: &[9u8; 32],
            exchange_public_key: &generate_test_fake_exchange_public_key(),
            should_encrypt: false,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(client_preamble)).await.unwrap();

        let result = handle_handshake(&mut server, &config).await;

        assert!(result.is_err());

        let guard = client.pull().await.unwrap();
        if let ProtocolMessage::Error(error) = guard.message() {
            assert!(matches!(error, ProtocolError::InvalidHost(_)));
        } else {
            panic!("Expected error message, got: {:?}", guard.message());
        }
    }

    /// A correctly signed challenge from a key that is not authorized is rejected with `InvalidKey`.
    #[tokio::test]
    async fn test_handle_handshake_unauthorized_key() {
        let config = generate_test_server_config();
        let (mut client, mut server) = generate_test_duplex();

        let client_preamble = ClientPreamble {
            remote: "localhost",
            challenge: &[10u8; 32],
            exchange_public_key: &generate_test_fake_exchange_public_key(),
            should_encrypt: false,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(client_preamble)).await.unwrap();

        let unauthorized_key_pair = generate_key_pair().unwrap();
        let unauthorized_private_key = unauthorized_key_pair.private_key.into();

        let client_handle = tokio::spawn(async move {
            let guard = client.pull().await.unwrap();
            let server_challenge = match guard.message() {
                ProtocolMessage::ServerPreamble(preamble) => preamble.challenge,
                _ => panic!("Expected ServerPreamble message, got: {:?}", guard.message()),
            };

            let signature = sign_challenge(server_challenge, &unauthorized_private_key).unwrap();
            let client_auth = ClientAuthentication {
                identity_public_key: &unauthorized_key_pair.public_key,
                signature: &signature,
            };

            client.push(ProtocolMessage::ClientAuthentication(client_auth)).await.unwrap();

            // The rejection must be a key error, not some other protocol error.
            let guard = client.pull().await.unwrap();
            match guard.message() {
                ProtocolMessage::Error(error) => assert!(matches!(error, ProtocolError::InvalidKey(_))),
                other => panic!("Expected error message, got: {:?}", other),
            }
        });

        let result = handle_handshake(&mut server, &config).await;

        client_handle.await.unwrap();

        assert!(result.is_err());
    }
}