ratrodlib/
serve.rs

1//! This module implements the server-side of the protocol.
2
3use std::marker::PhantomData;
4
5use anyhow::Context;
6use regex::Regex;
7use secrecy::SecretString;
8use tokio::net::{TcpListener, TcpStream, UdpSocket};
9use tracing::{Instrument, error, info, info_span};
10
11use crate::{
12    base::{ExchangeKeyPair, Res, ServerKeyExchangeData, Void},
13    buffed_stream::{BincodeSplit, BuffedTcpStream},
14    protocol::{BincodeReceive, BincodeSend, Challenge, ClientPreamble, ProtocolError, ProtocolMessage, ServerPreamble, Signature},
15    security::{resolve_authorized_keys, resolve_keypath, resolve_private_key, resolve_public_key},
16    utils::{generate_challenge, generate_ephemeral_key_pair, generate_shared_secret, handle_tcp_pump, handle_udp_pump, random_string, sign_challenge, validate_signed_challenge},
17};
18
19// State machine.
20
/// Type-state marker: the server has not been prepared yet.
///
/// This is the default state parameter of `Instance`; preparing an instance is only
/// possible while it is in this state.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConfigState;
/// Type-state marker: the server has been prepared and can be started.
#[derive(Debug, Clone, Copy, Default)]
pub struct ReadyState;
25
26/// The server instance.
27///
28/// This is the main entry point for the server. It is used to prepare the server, and start it.
29pub struct Instance<S = ConfigState> {
30    config: Config,
31    _phantom: PhantomData<S>,
32}
33
34impl Instance<ConfigState> {
35    /// Prepares the server instance.
36    pub fn prepare<A, B, C>(key_path: A, remote_regex: B, bind_address: C) -> Res<Instance<ReadyState>>
37    where
38        A: Into<Option<String>>,
39        B: AsRef<str>,
40        C: Into<String>,
41    {
42        let remote_regex = Regex::new(remote_regex.as_ref()).context("Invalid regex for remote host.")?;
43
44        let key_path = resolve_keypath(key_path)?;
45        let private_key = resolve_private_key(&key_path)?;
46        let public_key = resolve_public_key(&key_path)?;
47        let authorized_keys = resolve_authorized_keys(&key_path);
48
49        let config = Config::new(public_key, private_key, authorized_keys, bind_address.into(), remote_regex);
50
51        Ok(Instance { config, _phantom: PhantomData })
52    }
53}
54
55impl Instance<ReadyState> {
56    /// Starts the server instance.
57    pub async fn start(self) -> Void {
58        info!("🚀 Starting server on `{}` ...", self.config.bind_address);
59
60        run_tcp_server(self.config.clone()).await?;
61
62        Ok(())
63    }
64}
65
66// Operations.
67
68/// Verifies the preamble from the client.
69///
70/// This is used to ensure that the client is allowed to connect to the specified remote.
71async fn verify_client_preamble<T>(stream: &mut T, config: &Config, preamble: &ClientPreamble<'_>) -> Res<Signature>
72where
73    T: BincodeSend,
74{
75    // Validate the remote is OK.
76
77    if !config.remote_regex.is_match(preamble.remote) {
78        return ProtocolError::InvalidHost(format!("Invalid host from client (supplied `{}`, but need to satisfy `{}`)", preamble.remote, config.remote_regex))
79            .send_and_bail(stream)
80            .await;
81    }
82
83    // Sign the challenge.
84
85    let signature = sign_challenge(preamble.challenge, &config.private_key)?;
86
87    Ok(signature)
88}
89
90/// Sends the server preamble to the client.
91async fn send_server_preamble<T>(stream: &mut T, config: &Config, server_signature: &Signature, server_challenge: &Challenge) -> Res<ExchangeKeyPair>
92where
93    T: BincodeSend,
94{
95    info!("🚧 Sending handshake challenge to client ...");
96
97    let exchange_key_pair = generate_ephemeral_key_pair()?;
98    let exchange_public_key = exchange_key_pair.public_key.as_ref();
99
100    let preamble = ServerPreamble {
101        challenge: server_challenge,
102        exchange_public_key,
103        identity_public_key: &config.public_key,
104        signature: server_signature,
105    };
106
107    stream.push(ProtocolMessage::ServerPreamble(preamble)).await?;
108
109    Ok(exchange_key_pair)
110}
111
112/// Handles the key challenge from the client.
113///
114/// This is used to ensure that the client is allowed to connect to the server.
115/// It also verifies the signature of the challenge from the client, authenticating the client.
116async fn handle_and_validate_key_challenge<T>(stream: &mut T, config: &Config, server_challenge: &Challenge) -> Void
117where
118    T: BincodeSend + BincodeReceive,
119{
120    // Wait for the client to respond.
121
122    let guard = stream.pull().await?;
123    let ProtocolMessage::ClientAuthentication(client_authentication) = guard.message().fail_if_error()? else {
124        return ProtocolError::InvalidKey("Invalid handshake response".into()).send_and_bail(stream).await;
125    };
126
127    // Verify the signature.
128
129    if validate_signed_challenge(server_challenge, client_authentication.signature, client_authentication.identity_public_key).is_err() {
130        return ProtocolError::InvalidKey("Invalid challenge signature from client".into()).send_and_bail(stream).await;
131    }
132
133    // Validate that the key is authorized.
134    if !config.authorized_keys.iter().any(|k| k == client_authentication.identity_public_key) {
135        return ProtocolError::InvalidKey("Unauthorized key from client".into()).send_and_bail(stream).await;
136    }
137
138    info!("✅ Handshake challenge completed!");
139
140    Ok(())
141}
142
143/// Completes the handshake.
144///
145/// This is used to send the server's ephemeral public key to the client
146/// for the key exchange.
147async fn complete_handshake<T>(stream: &mut T) -> Void
148where
149    T: BincodeSend,
150{
151    let completion = ProtocolMessage::HandshakeCompletion;
152
153    stream.push(completion).await?;
154
155    info!("✅ Handshake completed.");
156
157    Ok(())
158}
159
160/// Handles the e2e handshake.
161///
162/// This is used to handle the handshake between the client and server.
163/// It verifies the preamble, handles the key challenge, and completes the handshake.
164async fn handle_handshake<T>(stream: &mut T, config: &Config) -> Res<ServerKeyExchangeData>
165where
166    T: BincodeSplit + BincodeReceive + BincodeSend,
167{
168    let (read, write) = stream.split();
169
170    // Ingest the preamble from the client.
171
172    let guard = read.pull().await?;
173    let ProtocolMessage::ClientPreamble(preamble) = guard.message() else {
174        return ProtocolError::Unknown("Invalid handshake start".into()).send_and_bail(stream).await;
175    };
176
177    // Extract the preamble ta so the borrows into the stream can be dropped.
178
179    let client_exchange_public_key = preamble.exchange_public_key.try_into()?;
180    let client_challenge = preamble.challenge.try_into()?;
181    let requested_remote_address = preamble.remote.into();
182    let requested_should_encrypt = preamble.should_encrypt;
183    let requested_is_udp = preamble.is_udp;
184
185    // Verify the preamble.
186
187    let server_signature = verify_client_preamble(write, config, preamble).await?;
188
189    // Create a challenge.
190
191    let server_challenge = generate_challenge();
192
193    // Send the server preamble.
194
195    let local_exchange_key_pair = send_server_preamble(write, config, &server_signature, &server_challenge).await?;
196
197    // Validate the client's auth response.
198
199    handle_and_validate_key_challenge(stream, config, &server_challenge).await?;
200
201    // Complete the handshake.
202
203    complete_handshake(stream).await?;
204
205    Ok(ServerKeyExchangeData {
206        client_exchange_public_key,
207        client_challenge,
208        local_exchange_private_key: local_exchange_key_pair.private_key,
209        local_challenge: server_challenge,
210        requested_remote_address,
211        requested_should_encrypt,
212        requested_is_udp,
213    })
214}
215
216/// Runs the pump with a TCP-connected remote.
217async fn run_tcp_pump(mut client: BuffedTcpStream, remote_address: &str) -> Void {
218    let Ok(remote) = TcpStream::connect(remote_address).await.context("Error connecting to remote") else {
219        return ProtocolError::RemoteFailed(format!("Failed to connect to remote `{}`", remote_address))
220            .send_and_bail(&mut client)
221            .await;
222    };
223
224    remote.set_nodelay(true)?;
225
226    info!("✅ Connected to remote server `{}`.", remote_address);
227
228    handle_tcp_pump(remote, client).await.context("Error handling TCP pump.")?;
229
230    Ok(())
231}
232
233/// Runs the pump with a UDP-connected remote.
234async fn run_udp_pump(mut client: BuffedTcpStream, remote_address: &str) -> Void {
235    let remote = UdpSocket::bind("127.0.0.1:0").await.context("Error binding UDP socket")?;
236    if remote.connect(remote_address).await.is_err() {
237        return ProtocolError::RemoteFailed(format!("Failed to connect to remote `{}`", remote_address))
238            .send_and_bail(&mut client)
239            .await;
240    }
241
242    info!("✅ Connected to remote server `{}`.", remote_address);
243
244    handle_udp_pump(remote, client).await.context("Error handling UDP pump.")?;
245
246    Ok(())
247}
248
249/// Runs the TCP server.
250///
251/// This is the main entry point for the server. It binds to the specified address, and handles incoming connections.
252async fn run_tcp_server(config: Config) -> Void {
253    let listener = TcpListener::bind(&config.bind_address).await?;
254
255    loop {
256        let (socket, _) = listener.accept().await?;
257
258        tokio::spawn(handle_connection(socket, config.clone()));
259    }
260}
261
262/// Handles the TCP connection.
263///
264/// This is used to handle the TCP connection between the client and server.
265/// It handles the handshake, and pumps data between the client and server.
266async fn handle_connection(client: TcpStream, config: Config) {
267    let id = random_string(6);
268    let span = info_span!("conn", id = id);
269
270    let result: Void = async move {
271        client.set_nodelay(true)?;
272        let peer_addr = client.peer_addr().context("Error getting peer address")?;
273
274        let mut client = BuffedTcpStream::from(client);
275
276        info!("✅ Accepted connection from `{}`.", peer_addr);
277
278        // Handle the handshake.
279
280        let handshake_data = handle_handshake(&mut client, &config).await.context("Error handling handshake")?;
281
282        // Generate and apply the shared secret, if needed.
283        if handshake_data.requested_should_encrypt {
284            let private_key = handshake_data.local_exchange_private_key;
285            let salt_bytes = [handshake_data.local_challenge, handshake_data.client_challenge].concat();
286            let shared_secret = generate_shared_secret(private_key, &handshake_data.client_exchange_public_key, &salt_bytes)?;
287
288            client = client.with_encryption(shared_secret);
289            info!("🔒 Encryption applied ...");
290        }
291
292        // Handle the pump.
293
294        info!("⛽ Pumping data between client and remote ...");
295
296        if handshake_data.requested_is_udp {
297            run_udp_pump(client, &handshake_data.requested_remote_address).await?;
298        } else {
299            run_tcp_pump(client, &handshake_data.requested_remote_address).await?;
300        }
301
302        info!("✅ Connection closed.");
303
304        Ok(())
305    }
306    .instrument(span.clone())
307    .await;
308
309    // Enter the span, so that the error is logged with the span's metadata, if needed.
310    let _guard = span.enter();
311
312    if let Err(err) = result {
313        let chain = err.chain().collect::<Vec<_>>();
314        let full_chain = chain.iter().map(|e| format!("`{}`", e)).collect::<Vec<_>>().join(" => ");
315
316        error!("❌ Error handling connection: {}.", full_chain);
317    }
318}
319
320// Config.
321
322/// The server configuration.
323///
324/// This is used to store the server's configuration.
325#[derive(Clone)]
326pub(crate) struct Config {
327    pub(crate) public_key: String,
328    pub(crate) private_key: SecretString,
329    pub(crate) authorized_keys: Vec<String>,
330    pub(crate) bind_address: String,
331    pub(crate) remote_regex: Regex,
332}
333
334impl Config {
335    /// Creates a new server configuration.
336    fn new(public_key: String, private_key: SecretString, authorized_keys: Vec<String>, bind_address: String, remote_regex: Regex) -> Self {
337        Self {
338            public_key,
339            private_key,
340            authorized_keys,
341            bind_address,
342            remote_regex,
343        }
344    }
345}
346
347// Tests.
348
349#[cfg(test)]
350mod tests {
351    use pretty_assertions::assert_eq;
352
353    use crate::{
354        connect::tests::generate_test_client_config,
355        protocol::ClientAuthentication,
356        utils::{
357            generate_key_pair, sign_challenge,
358            tests::{generate_test_duplex, generate_test_fake_exchange_public_key},
359        },
360    };
361
362    use super::*;
363
364    pub(crate) fn generate_test_server_config() -> Config {
365        let key_path = "test/server";
366
367        let public_key = resolve_public_key(key_path).unwrap();
368        let private_key = resolve_private_key(key_path).unwrap();
369        let authorized_keys = resolve_authorized_keys(key_path);
370        let remote_regex = Regex::new(".*").unwrap();
371
372        Config {
373            public_key,
374            private_key,
375            authorized_keys,
376            bind_address: "bind_address".to_string(),
377            remote_regex,
378        }
379    }
380
381    #[test]
382    fn test_prepare_config() {
383        let instance = Instance::prepare("test/server".to_string(), ".*", "foo").unwrap();
384
385        assert_eq!(instance.config.public_key, "HQYY0BNIhdawY2Jw62DudkUsK2GKj3hGO3qSVBlCinI");
386        assert_eq!(instance.config.remote_regex.as_str(), ".*");
387        assert_eq!(instance.config.bind_address, "foo");
388    }
389
390    #[tokio::test]
391    async fn test_handle_handshake_success() {
392        // Setup test environment
393        let mut config = generate_test_server_config();
394        let (mut client, mut server) = generate_test_duplex();
395        let client_config = generate_test_client_config();
396
397        // Add client's public key to authorized keys
398        config.authorized_keys.push(client_config.public_key.clone());
399
400        // Client sends preamble
401        let client_challenge: Challenge = [8u8; 32];
402        let client_preamble = ClientPreamble {
403            remote: "localhost",
404            challenge: &client_challenge,
405            exchange_public_key: &generate_test_fake_exchange_public_key(),
406            should_encrypt: true,
407            is_udp: false,
408        };
409
410        client.push(ProtocolMessage::ClientPreamble(client_preamble.clone())).await.unwrap();
411
412        // Client prepares to respond to server's challenge
413        let client_handle = tokio::spawn(async move {
414            // Get server preamble
415            let guard = client.pull().await.unwrap();
416            let server_challenge = match guard.message() {
417                ProtocolMessage::ServerPreamble(preamble) => preamble.challenge,
418                _ => panic!("Expected ServerPreamble message, got: {:?}", guard.message()),
419            };
420
421            // Send client authentication
422            let signature = sign_challenge(server_challenge, &client_config.private_key).unwrap();
423            let client_auth = ClientAuthentication {
424                identity_public_key: &client_config.public_key,
425                signature: &signature,
426            };
427
428            client.push(ProtocolMessage::ClientAuthentication(client_auth)).await.unwrap();
429
430            // Verify handshake completion
431            let guard = client.pull().await.unwrap();
432            assert!(matches!(guard.message(), ProtocolMessage::HandshakeCompletion));
433        });
434
435        // Execute handshake on server side
436        let result = handle_handshake(&mut server, &config).await;
437
438        // Wait for client to complete
439        client_handle.await.unwrap();
440
441        // Verify success
442        assert!(result.is_ok());
443        let key_data = result.unwrap();
444
445        // Verify returned data
446        assert_eq!(key_data.client_exchange_public_key, client_preamble.exchange_public_key);
447        assert_eq!(key_data.client_challenge, client_challenge);
448        assert_eq!(key_data.requested_remote_address, "localhost");
449        assert_eq!(key_data.requested_should_encrypt, true);
450    }
451
452    #[tokio::test]
453    async fn test_handle_handshake_invalid_start() {
454        // Setup
455        let config = generate_test_server_config();
456        let (mut client, mut server) = generate_test_duplex();
457
458        // Send invalid initial message
459        client.push(ProtocolMessage::HandshakeCompletion).await.unwrap();
460
461        // Execute handshake
462        let result = handle_handshake(&mut server, &config).await;
463
464        // Verify error
465        assert!(result.is_err());
466
467        // Verify client received error
468        let guard = client.pull().await.unwrap();
469        if let ProtocolMessage::Error(error) = guard.message() {
470            assert_eq!(error, &ProtocolError::Unknown("Invalid handshake start".into()));
471        } else {
472            panic!("Expected error message, got: {:?}", guard.message());
473        }
474    }
475
476    #[tokio::test]
477    async fn test_handle_handshake_invalid_host() {
478        // Setup
479        let mut config = generate_test_server_config();
480        config.remote_regex = Regex::new("^only-this-host$").unwrap();
481        let (mut client, mut server) = generate_test_duplex();
482
483        // Send preamble with non-matching host
484        let client_preamble = ClientPreamble {
485            remote: "different-host",
486            challenge: &[9u8; 32],
487            exchange_public_key: &generate_test_fake_exchange_public_key(),
488            should_encrypt: false,
489            is_udp: false,
490        };
491
492        client.push(ProtocolMessage::ClientPreamble(client_preamble)).await.unwrap();
493
494        // Execute handshake
495        let result = handle_handshake(&mut server, &config).await;
496
497        // Verify error
498        assert!(result.is_err());
499
500        // Verify client received error
501        let guard = client.pull().await.unwrap();
502        if let ProtocolMessage::Error(error) = guard.message() {
503            assert!(matches!(error, ProtocolError::InvalidHost(_)));
504        } else {
505            panic!("Expected error message, got: {:?}", guard.message());
506        }
507    }
508
509    #[tokio::test]
510    async fn test_handle_handshake_unauthorized_key() {
511        // Setup
512        let config = generate_test_server_config();
513        let (mut client, mut server) = generate_test_duplex();
514
515        // Client sends preamble
516        let client_preamble = ClientPreamble {
517            remote: "localhost",
518            challenge: &[10u8; 32],
519            exchange_public_key: &generate_test_fake_exchange_public_key(),
520            should_encrypt: false,
521            is_udp: false,
522        };
523
524        client.push(ProtocolMessage::ClientPreamble(client_preamble)).await.unwrap();
525
526        // Generate unauthorized key pair
527        let unauthorized_key_pair = generate_key_pair().unwrap();
528        let unauthorized_private_key = unauthorized_key_pair.private_key.into();
529
530        // Client responds with unauthorized key
531        let client_handle = tokio::spawn(async move {
532            // Get server preamble
533            let guard = client.pull().await.unwrap();
534            let server_challenge = match guard.message() {
535                ProtocolMessage::ServerPreamble(preamble) => preamble.challenge,
536                _ => panic!("Expected ServerPreamble message, got: {:?}", guard.message()),
537            };
538
539            // Send client authentication with unauthorized key
540            let signature = sign_challenge(server_challenge, &unauthorized_private_key).unwrap();
541            let client_auth = ClientAuthentication {
542                identity_public_key: &unauthorized_key_pair.public_key,
543                signature: &signature,
544            };
545
546            client.push(ProtocolMessage::ClientAuthentication(client_auth)).await.unwrap();
547
548            // Check for error response
549            let guard = client.pull().await.unwrap();
550            assert!(matches!(guard.message(), &ProtocolMessage::Error(_)));
551        });
552
553        // Execute handshake on server side
554        let result = handle_handshake(&mut server, &config).await;
555
556        // Wait for client to complete
557        client_handle.await.unwrap();
558
559        // Verify failure
560        assert!(result.is_err());
561    }
562}