ratrodlib/
serve.rs

1//! This module implements the server-side of the protocol.
2
3use std::{marker::PhantomData, sync::Arc};
4
5use anyhow::Context;
6use regex::Regex;
7use secrecy::SecretString;
8use tokio::{
9    net::{TcpListener, TcpStream, UdpSocket}, select, task::JoinHandle, time::Instant
10};
11use tracing::{Instrument, error, info, info_span};
12
13use crate::{
14    base::{Constant, ExchangeKeyPair, Res, ServerKeyExchangeData, Void},
15    buffed_stream::BuffedTcpStream,
16    protocol::{BincodeReceive, BincodeSend, Challenge, ClientPreamble, ProtocolError, ProtocolMessage, ServerPreamble, Signature},
17    security::{resolve_authorized_keys, resolve_keypath, resolve_private_key, resolve_public_key},
18    utils::{generate_challenge, generate_ephemeral_key_pair, generate_shared_secret, handle_pump, random_string, sign_challenge, validate_signed_challenge},
19};
20
21// State machine.
22
/// Type-state marker: the server is still being configured.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ConfigState;
/// Type-state marker: the server is configured and can be started.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ReadyState;
27
28/// The server instance.
29///
30/// This is the main entry point for the server. It is used to prepare the server, and start it.
31pub struct Instance<S = ConfigState> {
32    config: Config,
33    _phantom: PhantomData<S>,
34}
35
36impl Instance<ConfigState> {
37    /// Prepares the server instance.
38    pub fn prepare<A, B, C>(key_path: A, remote_regex: B, bind_address: C) -> Res<Instance<ReadyState>>
39    where
40        A: Into<Option<String>>,
41        B: AsRef<str>,
42        C: Into<String>,
43    {
44        let remote_regex = Regex::new(remote_regex.as_ref()).context("Invalid regex for remote host.")?;
45
46        let key_path = resolve_keypath(key_path)?;
47        let private_key = resolve_private_key(&key_path)?;
48        let public_key = resolve_public_key(&key_path)?;
49        let authorized_keys = resolve_authorized_keys(&key_path);
50
51        let config = Config::new(public_key, private_key, authorized_keys, bind_address.into(), remote_regex);
52
53        Ok(Instance { config, _phantom: PhantomData })
54    }
55}
56
57impl Instance<ReadyState> {
58    /// Starts the server instance.
59    pub async fn start(self) -> Void {
60        info!("🚀 Starting server on `{}` ...", self.config.bind_address);
61
62        run_tcp_server(self.config.clone()).await?;
63
64        Ok(())
65    }
66}
67
68// Operations.
69
70/// Verifies the preamble from the client.
71///
72/// This is used to ensure that the client is allowed to connect to the specified remote.
73async fn verify_client_preamble<T>(stream: &mut T, config: &Config, preamble: &ClientPreamble) -> Res<Signature>
74where
75    T: BincodeSend,
76{
77    // Validate the remote is OK.
78
79    if !config.remote_regex.is_match(&preamble.remote) {
80        return ProtocolError::InvalidHost(format!("Invalid host from client (supplied `{}`, but need to satisfy `{}`)", preamble.remote, config.remote_regex))
81            .send_and_bail(stream)
82            .await;
83    }
84
85    // Sign the challenge.
86
87    let signature = sign_challenge(&preamble.challenge, &config.private_key)?;
88
89    Ok(signature)
90}
91
92async fn send_server_preamble<T>(stream: &mut T, config: &Config, server_signature: &Signature, server_challenge: &Challenge) -> Res<ExchangeKeyPair>
93where
94    T: BincodeSend,
95{
96    info!("🚧 Sending handshake challenge to client ...");
97
98    let exchange_key_pair = generate_ephemeral_key_pair()?;
99    let exchange_public_key = exchange_key_pair.public_key.as_ref().try_into()?;
100
101    let preamble = ServerPreamble {
102        challenge: *server_challenge,
103        exchange_public_key,
104        identity_public_key: config.public_key.clone(),
105        signature: server_signature.into(),
106    };
107
108    stream.push(ProtocolMessage::ServerPreamble(preamble)).await?;
109
110    Ok(exchange_key_pair)
111}
112
113/// Handles the key challenge from the client.
114///
115/// This is used to ensure that the client is allowed to connect to the server.
116/// It also verifies the signature of the challenge from the client, authenticating the client.
117async fn handle_and_validate_key_challenge<T>(stream: &mut T, config: &Config, server_challenge: &Challenge) -> Void
118where
119    T: BincodeSend + BincodeReceive,
120{
121    // Wait for the client to respond.
122
123    let ProtocolMessage::ClientAuthentication(client_authentication) = stream.pull().await?.fail_if_error()? else {
124        return ProtocolError::InvalidKey("Invalid handshake response".into()).send_and_bail(stream).await;
125    };
126
127    // Verify the signature.
128
129    if validate_signed_challenge(server_challenge, &client_authentication.signature.into(), &client_authentication.identity_public_key).is_err() {
130        return ProtocolError::InvalidKey("Invalid challenge signature from client".into()).send_and_bail(stream).await;
131    }
132
133    // Validate that the key is authorized.
134    if !config.authorized_keys.contains(&client_authentication.identity_public_key) {
135        return ProtocolError::InvalidKey("Unauthorized key from client".into()).send_and_bail(stream).await;
136    }
137
138    info!("✅ Handshake challenge completed!");
139
140    Ok(())
141}
142
143/// Completes the handshake.
144///
145/// This is used to send the server's ephemeral public key to the client
146/// for the key exchange.
147async fn complete_handshake<T>(stream: &mut T) -> Void
148where
149    T: BincodeSend,
150{
151    let completion = ProtocolMessage::HandshakeCompletion;
152
153    stream.push(completion).await?;
154
155    info!("✅ Handshake completed.");
156
157    Ok(())
158}
159
160/// Handles the e2e handshake.
161///
162/// This is used to handle the handshake between the client and server.
163/// It verifies the preamble, handles the key challenge, and completes the handshake.
164async fn handle_handshake<T>(stream: &mut T, config: &Config) -> Res<ServerKeyExchangeData>
165where
166    T: BincodeReceive + BincodeSend,
167{
168    // Ingest the preamble from the client.
169
170    let ProtocolMessage::ClientPreamble(preamble) = stream.pull().await? else {
171        return ProtocolError::Unknown("Invalid handshake start".into()).send_and_bail(stream).await;
172    };
173
174    // Verify the preamble.
175
176    let server_signature = verify_client_preamble(stream, config, &preamble).await?;
177
178    // Create a challenge.
179
180    let server_challenge = generate_challenge();
181
182    // Send the server preamble.
183
184    let local_exchange_key_pair = send_server_preamble(stream, config, &server_signature, &server_challenge).await?;
185
186    // Validate the client's auth response.
187
188    handle_and_validate_key_challenge(stream, config, &server_challenge).await?;
189
190    // Complete the handshake.
191
192    complete_handshake(stream).await?;
193
194    Ok(ServerKeyExchangeData {
195        client_exchange_public_key: preamble.exchange_public_key,
196        client_challenge: preamble.challenge,
197        local_exchange_private_key: local_exchange_key_pair.private_key,
198        local_challenge: server_challenge,
199        requested_remote_address: preamble.remote,
200        requested_should_encrypt: preamble.should_encrypt,
201        requested_is_udp: preamble.is_udp,
202    })
203}
204
205/// Runs the pump with a TCP-connected remote.
206async fn run_tcp_pump(mut client: BuffedTcpStream, remote_address: &str) -> Void {
207    let Ok(mut remote) = TcpStream::connect(remote_address).await.context("Error connecting to remote") else {
208        return ProtocolError::RemoteFailed(format!("Failed to connect to remote `{}`", remote_address))
209            .send_and_bail(&mut client)
210            .await;
211    };
212
213    info!("✅ Connected to remote server `{}`.", remote_address);
214
215    handle_pump(&mut client, &mut remote).await.context("Error handling TCP pump.")?;
216
217    Ok(())
218}
219
220/// Runs the pump with a UDP-connected remote.
221async fn run_udp_pump(mut client: BuffedTcpStream, remote_address: &str) -> Void {
222    // Attempt to connect to the remote address (should only fail in weird circumstances).
223
224    let remote = UdpSocket::bind("127.0.0.1:0").await.context("Error binding UDP socket")?;
225    if remote.connect(remote_address).await.is_err() {
226        return ProtocolError::RemoteFailed(format!("Failed to connect to remote `{}`", remote_address))
227            .send_and_bail(&mut client)
228            .await;
229    }
230
231    info!("✅ Connected to remote server `{}`.", remote_address);
232
233    // Split the client connection into a read and write half.
234    let (mut client_read, mut client_write) = client.into_split();
235
236    // Split the remote connection into a read and write half (just requires `Arc`ing, since the UDP send / receive does not require `&mut`).
237    let remote_up = Arc::new(remote);
238    let remote_down = remote_up.clone();
239
240    // Run the pumps.
241    
242    let last_activity = Arc::new(tokio::sync::Mutex::new(Instant::now()));
243    let last_activity_up = last_activity.clone();
244    let last_activity_down = last_activity.clone();
245
246    let pump_up: JoinHandle<Void> = tokio::spawn(async move {
247        while let ProtocolMessage::UdpData(data) = client_read.pull().await? {
248            remote_up.send(&data).await?;
249            *last_activity_up.lock().await = Instant::now();
250        }
251
252        Ok(())
253    });
254
255    let pump_down: JoinHandle<Void> = tokio::spawn(async move {
256        let mut buf = [0; Constant::BUFFER_SIZE];
257
258        loop {
259            let size = remote_down.recv(&mut buf).await?;
260            client_write.push(ProtocolMessage::UdpData(buf[..size].to_vec())).await?;
261            *last_activity_down.lock().await = Instant::now();
262        }
263    });
264
265    let timeout: JoinHandle<Void> = tokio::spawn(async move {
266        loop {
267            let last_activity = *last_activity.lock().await;
268
269            if last_activity.elapsed() > Constant::UDP_TIMEOUT {
270                info!("✅ UDP connection timed out (assumed graceful close).");
271                return Ok(());
272            }
273
274            tokio::time::sleep(Constant::UDP_TIMEOUT).await;
275        }
276    });
277
278    // Wait for the pumps to finish.  This employs a "last activity" type timeout, and uses `select!` to break
279    // out of the loop if any of the pumps finish.  In general, the UDP side to the remote will not close,
280    // but the client may break the pipe, so we need to handle that.
281    // The `select!` macro will return the first result that completes,
282    // and the `timeout` will return if the last activity is too long ago.
283
284    let result = select! {
285        r = pump_up => r?,
286        r = pump_down => r?,
287        r = timeout => r?,
288    };
289
290    // Check for errors.
291
292    result?;
293
294    Ok(())
295}
296
297/// Runs the TCP server.
298///
299/// This is the main entry point for the server. It binds to the specified address, and handles incoming connections.
300async fn run_tcp_server(config: Config) -> Void {
301    let listener = TcpListener::bind(&config.bind_address).await?;
302
303    loop {
304        let (socket, _) = listener.accept().await?;
305
306        tokio::spawn(handle_connection(socket, config.clone()));
307    }
308}
309
310/// Handles the TCP connection.
311///
312/// This is used to handle the TCP connection between the client and server.
313/// It handles the handshake, and pumps data between the client and server.
314async fn handle_connection(client: TcpStream, config: Config) {
315    let mut client = BuffedTcpStream::from(client);
316
317    let id = random_string(6);
318    let span = info_span!("conn", id = id);
319
320    let result: Void = async move {
321        let peer_addr = client.as_inner_tcp_write_ref().peer_addr().context("Error getting peer address")?;
322
323        info!("✅ Accepted connection from `{}`.", peer_addr);
324
325        // Handle the handshake.
326
327        let handshake_data = handle_handshake(&mut client, &config).await.context("Error handling handshake")?;
328
329        // Generate and apply the shared secret, if needed.
330        if handshake_data.requested_should_encrypt {
331            let private_key = handshake_data.local_exchange_private_key;
332            let salt_bytes = [handshake_data.local_challenge, handshake_data.client_challenge].concat();
333            let shared_secret = generate_shared_secret(private_key, &handshake_data.client_exchange_public_key, &salt_bytes)?;
334
335            client = client.with_encryption(shared_secret);
336            info!("🔒 Encryption applied ...");
337        }
338
339        // Handle the pump.
340
341        info!("⛽ Pumping data between client and remote ...");
342
343        if handshake_data.requested_is_udp {
344            run_udp_pump(client, &handshake_data.requested_remote_address).await?;
345        } else {
346            run_tcp_pump(client, &handshake_data.requested_remote_address).await?;
347        }
348
349        info!("✅ Connection closed.");
350
351        Ok(())
352    }
353    .instrument(span.clone())
354    .await;
355
356    // Enter the span, so that the error is logged with the span's metadata, if needed.
357    let _guard = span.enter();
358
359    if let Err(err) = result {
360        let chain = err.chain().collect::<Vec<_>>();
361        let full_chain = chain.iter().map(|e| format!("`{}`", e)).collect::<Vec<_>>().join(" => ");
362
363        error!("❌ Error handling connection: {}.", full_chain);
364    }
365}
366
367// Config.
368
369/// The server configuration.
370///
371/// This is used to store the server's configuration.
372#[derive(Clone)]
373pub(crate) struct Config {
374    pub(crate) public_key: String,
375    pub(crate) private_key: SecretString,
376    pub(crate) authorized_keys: Vec<String>,
377    pub(crate) bind_address: String,
378    pub(crate) remote_regex: Regex,
379}
380
381impl Config {
382    /// Creates a new server configuration.
383    fn new(public_key: String, private_key: SecretString, authorized_keys: Vec<String>, bind_address: String, remote_regex: Regex) -> Self {
384        Self {
385            public_key,
386            private_key,
387            authorized_keys,
388            bind_address,
389            remote_regex,
390        }
391    }
392}
393
394// Tests.
395
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use crate::{
        connect::tests::generate_test_client_config,
        protocol::ClientAuthentication,
        utils::{
            generate_key_pair, sign_challenge,
            tests::{generate_test_duplex, generate_test_fake_exchange_public_key},
        },
    };

    use super::*;

    /// Builds a server config from the key material in `test/server` that accepts any remote.
    pub(crate) fn generate_test_server_config() -> Config {
        let key_path = "test/server";

        let public_key = resolve_public_key(key_path).unwrap();
        let private_key = resolve_private_key(key_path).unwrap();
        let authorized_keys = resolve_authorized_keys(key_path);
        let remote_regex = Regex::new(".*").unwrap();

        Config {
            public_key,
            private_key,
            authorized_keys,
            bind_address: "bind_address".to_string(),
            remote_regex,
        }
    }

    #[test]
    fn test_prepare_config() {
        let instance = Instance::prepare("test/server".to_string(), ".*", "foo").unwrap();

        assert_eq!(instance.config.public_key, "HQYY0BNIhdawY2Jw62DudkUsK2GKj3hGO3qSVBlCinI");
        assert_eq!(instance.config.remote_regex.as_str(), ".*");
        assert_eq!(instance.config.bind_address, "foo");
    }

    #[tokio::test]
    async fn test_handle_handshake_success() {
        // Setup test environment
        let mut config = generate_test_server_config();
        let (mut client, mut server) = generate_test_duplex();
        let client_config = generate_test_client_config();

        // Add client's public key to authorized keys
        config.authorized_keys.push(client_config.public_key.clone());

        // Client sends preamble
        let client_challenge: Challenge = [8u8; 32];
        let client_preamble = ClientPreamble {
            remote: "localhost".to_string(),
            challenge: client_challenge,
            exchange_public_key: generate_test_fake_exchange_public_key(),
            should_encrypt: true,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(client_preamble.clone())).await.unwrap();

        // Client answers the server's challenge and hands it back so it can be checked below.
        let client_handle = tokio::spawn(async move {
            // Get server preamble
            let message = client.pull().await.unwrap();
            let server_challenge = match message {
                ProtocolMessage::ServerPreamble(preamble) => preamble.challenge,
                _ => panic!("Expected ServerPreamble message, got: {:?}", message),
            };

            // Send client authentication
            let signature = sign_challenge(&server_challenge, &client_config.private_key).unwrap();
            let client_auth = ClientAuthentication {
                identity_public_key: client_config.public_key,
                signature: signature.into(),
            };

            client.push(ProtocolMessage::ClientAuthentication(client_auth)).await.unwrap();

            // Verify handshake completion
            let message = client.pull().await.unwrap();
            assert!(matches!(message, ProtocolMessage::HandshakeCompletion));

            server_challenge
        });

        // Execute handshake on server side
        let result = handle_handshake(&mut server, &config).await;

        // Wait for client to complete
        let server_challenge = client_handle.await.unwrap();

        // Verify success (printing the error chain on failure)
        let key_data = result.expect("handshake should succeed");

        // Verify returned data
        assert_eq!(key_data.client_exchange_public_key, client_preamble.exchange_public_key);
        assert_eq!(key_data.client_challenge, client_challenge);
        assert_eq!(key_data.local_challenge, server_challenge);
        assert_eq!(key_data.requested_remote_address, "localhost");
        assert!(key_data.requested_should_encrypt);
        assert!(!key_data.requested_is_udp);
    }

    #[tokio::test]
    async fn test_handle_handshake_invalid_start() {
        // Setup
        let config = generate_test_server_config();
        let (mut client, mut server) = generate_test_duplex();

        // Send invalid initial message
        client.push(ProtocolMessage::HandshakeCompletion).await.unwrap();

        // Execute handshake
        let result = handle_handshake(&mut server, &config).await;

        // Verify error
        assert!(result.is_err());

        // Verify client received error
        let message = client.pull().await.unwrap();
        if let ProtocolMessage::Error(error) = message {
            assert_eq!(error, ProtocolError::Unknown("Invalid handshake start".into()));
        } else {
            panic!("Expected error message, got: {:?}", message);
        }
    }

    #[tokio::test]
    async fn test_handle_handshake_invalid_host() {
        // Setup
        let mut config = generate_test_server_config();
        config.remote_regex = Regex::new("^only-this-host$").unwrap();
        let (mut client, mut server) = generate_test_duplex();

        // Send preamble with non-matching host
        let client_preamble = ClientPreamble {
            remote: "different-host".to_string(),
            challenge: [9u8; 32],
            exchange_public_key: generate_test_fake_exchange_public_key(),
            should_encrypt: false,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(client_preamble)).await.unwrap();

        // Execute handshake
        let result = handle_handshake(&mut server, &config).await;

        // Verify error
        assert!(result.is_err());

        // Verify client received error
        let message = client.pull().await.unwrap();
        if let ProtocolMessage::Error(error) = message {
            assert!(matches!(error, ProtocolError::InvalidHost(_)));
        } else {
            panic!("Expected error message, got: {:?}", message);
        }
    }

    #[tokio::test]
    async fn test_handle_handshake_unauthorized_key() {
        // Setup
        let config = generate_test_server_config();
        let (mut client, mut server) = generate_test_duplex();

        // Client sends preamble
        let client_preamble = ClientPreamble {
            remote: "localhost".to_string(),
            challenge: [10u8; 32],
            exchange_public_key: generate_test_fake_exchange_public_key(),
            should_encrypt: false,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(client_preamble)).await.unwrap();

        // Generate unauthorized key pair
        let unauthorized_key_pair = generate_key_pair().unwrap();
        let unauthorized_private_key = unauthorized_key_pair.private_key.into();

        // Client responds with unauthorized key
        let client_handle = tokio::spawn(async move {
            // Get server preamble
            let message = client.pull().await.unwrap();
            let server_challenge = match message {
                ProtocolMessage::ServerPreamble(preamble) => preamble.challenge,
                _ => panic!("Expected ServerPreamble message, got: {:?}", message),
            };

            // Send client authentication with unauthorized key
            let signature = sign_challenge(&server_challenge, &unauthorized_private_key).unwrap();
            let client_auth = ClientAuthentication {
                identity_public_key: unauthorized_key_pair.public_key,
                signature: signature.into(),
            };

            client.push(ProtocolMessage::ClientAuthentication(client_auth)).await.unwrap();

            // The signature itself is valid, so the rejection must come from the authorization check.
            let message = client.pull().await.unwrap();
            if let ProtocolMessage::Error(error) = message {
                assert_eq!(error, ProtocolError::InvalidKey("Unauthorized key from client".into()));
            } else {
                panic!("Expected error message, got: {:?}", message);
            }
        });

        // Execute handshake on server side
        let result = handle_handshake(&mut server, &config).await;

        // Wait for client to complete
        client_handle.await.unwrap();

        // Verify failure
        assert!(result.is_err());
    }
}