ratrodlib/
serve.rs

1//! This module implements the server-side of the protocol.
2
3use std::{marker::PhantomData, sync::Arc};
4
5use anyhow::Context;
6use regex::Regex;
7use secrecy::SecretString;
8use tokio::{
9    net::{TcpListener, TcpStream, UdpSocket}, 
10    select, 
11    task::JoinHandle, 
12    time::Instant
13};
14use tracing::{Instrument, error, info, info_span};
15
16use crate::{
17    base::{Constant, ExchangeKeyPair, Res, ServerKeyExchangeData, Void},
18    buffed_stream::BuffedTcpStream,
19    protocol::{BincodeReceive, BincodeSend, Challenge, ClientPreamble, ProtocolError, ProtocolMessage, ServerPreamble, Signature},
20    security::{resolve_authorized_keys, resolve_keypath, resolve_private_key, resolve_public_key},
21    utils::{generate_challenge, generate_ephemeral_key_pair, generate_shared_secret, handle_pump, random_string, sign_challenge, validate_signed_challenge},
22};
23
24// State machine.
25
26/// The server is in the configuration state.
27pub struct ConfigState;
28/// The server is in the ready state.
29pub struct ReadyState;
30
31/// The server instance.
32///
33/// This is the main entry point for the server. It is used to prepare the server, and start it.
34pub struct Instance<S = ConfigState> {
35    config: Config,
36    _phantom: PhantomData<S>,
37}
38
39impl Instance<ConfigState> {
40    /// Prepares the server instance.
41    pub fn prepare<A, B, C>(key_path: A, remote_regex: B, bind_address: C) -> Res<Instance<ReadyState>>
42    where
43        A: Into<Option<String>>,
44        B: AsRef<str>,
45        C: Into<String>,
46    {
47        let remote_regex = Regex::new(remote_regex.as_ref()).context("Invalid regex for remote host.")?;
48
49        let key_path = resolve_keypath(key_path)?;
50        let private_key = resolve_private_key(&key_path)?;
51        let public_key = resolve_public_key(&key_path)?;
52        let authorized_keys = resolve_authorized_keys(&key_path);
53
54        let config = Config::new(public_key, private_key, authorized_keys, bind_address.into(), remote_regex);
55
56        Ok(Instance { config, _phantom: PhantomData })
57    }
58}
59
60impl Instance<ReadyState> {
61    /// Starts the server instance.
62    pub async fn start(self) -> Void {
63        info!("🚀 Starting server on `{}` ...", self.config.bind_address);
64
65        run_tcp_server(self.config.clone()).await?;
66
67        Ok(())
68    }
69}
70
71// Operations.
72
73/// Verifies the preamble from the client.
74///
75/// This is used to ensure that the client is allowed to connect to the specified remote.
76async fn verify_client_preamble<T>(stream: &mut T, config: &Config, preamble: &ClientPreamble) -> Res<Signature>
77where
78    T: BincodeSend,
79{
80    // Validate the remote is OK.
81
82    if !config.remote_regex.is_match(&preamble.remote) {
83        return ProtocolError::InvalidHost(format!("Invalid host from client (supplied `{}`, but need to satisfy `{}`)", preamble.remote, config.remote_regex))
84            .send_and_bail(stream)
85            .await;
86    }
87
88    // Sign the challenge.
89
90    let signature = sign_challenge(&preamble.challenge, &config.private_key)?;
91
92    Ok(signature)
93}
94
95async fn send_server_preamble<T>(stream: &mut T, config: &Config, server_signature: &Signature, server_challenge: &Challenge) -> Res<ExchangeKeyPair>
96where
97    T: BincodeSend,
98{
99    info!("🚧 Sending handshake challenge to client ...");
100
101    let exchange_key_pair = generate_ephemeral_key_pair()?;
102    let exchange_public_key = exchange_key_pair.public_key.as_ref().try_into()?;
103
104    let preamble = ServerPreamble {
105        challenge: *server_challenge,
106        exchange_public_key,
107        identity_public_key: config.public_key.clone(),
108        signature: server_signature.into(),
109    };
110
111    stream.push(ProtocolMessage::ServerPreamble(preamble)).await?;
112
113    Ok(exchange_key_pair)
114}
115
116/// Handles the key challenge from the client.
117///
118/// This is used to ensure that the client is allowed to connect to the server.
119/// It also verifies the signature of the challenge from the client, authenticating the client.
120async fn handle_and_validate_key_challenge<T>(stream: &mut T, config: &Config, server_challenge: &Challenge) -> Void
121where
122    T: BincodeSend + BincodeReceive,
123{
124    // Wait for the client to respond.
125
126    let ProtocolMessage::ClientAuthentication(client_authentication) = stream.pull().await?.fail_if_error()? else {
127        return ProtocolError::InvalidKey("Invalid handshake response".into()).send_and_bail(stream).await;
128    };
129
130    // Verify the signature.
131
132    if validate_signed_challenge(server_challenge, &client_authentication.signature.into(), &client_authentication.identity_public_key).is_err() {
133        return ProtocolError::InvalidKey("Invalid challenge signature from client".into()).send_and_bail(stream).await;
134    }
135
136    // Validate that the key is authorized.
137    if !config.authorized_keys.contains(&client_authentication.identity_public_key) {
138        return ProtocolError::InvalidKey("Unauthorized key from client".into()).send_and_bail(stream).await;
139    }
140
141    info!("✅ Handshake challenge completed!");
142
143    Ok(())
144}
145
146/// Completes the handshake.
147///
148/// This is used to send the server's ephemeral public key to the client
149/// for the key exchange.
150async fn complete_handshake<T>(stream: &mut T) -> Void
151where
152    T: BincodeSend,
153{
154    let completion = ProtocolMessage::HandshakeCompletion;
155
156    stream.push(completion).await?;
157
158    info!("✅ Handshake completed.");
159
160    Ok(())
161}
162
163/// Handles the e2e handshake.
164///
165/// This is used to handle the handshake between the client and server.
166/// It verifies the preamble, handles the key challenge, and completes the handshake.
167async fn handle_handshake<T>(stream: &mut T, config: &Config) -> Res<ServerKeyExchangeData>
168where
169    T: BincodeReceive + BincodeSend,
170{
171    // Ingest the preamble from the client.
172
173    let ProtocolMessage::ClientPreamble(preamble) = stream.pull().await? else {
174        return ProtocolError::Unknown("Invalid handshake start".into()).send_and_bail(stream).await;
175    };
176
177    // Verify the preamble.
178
179    let server_signature = verify_client_preamble(stream, config, &preamble).await?;
180
181    // Create a challenge.
182
183    let server_challenge = generate_challenge();
184
185    // Send the server preamble.
186
187    let local_exchange_key_pair = send_server_preamble(stream, config, &server_signature, &server_challenge).await?;
188
189    // Validate the client's auth response.
190
191    handle_and_validate_key_challenge(stream, config, &server_challenge).await?;
192
193    // Complete the handshake.
194
195    complete_handshake(stream).await?;
196
197    Ok(ServerKeyExchangeData {
198        client_exchange_public_key: preamble.exchange_public_key,
199        client_challenge: preamble.challenge,
200        local_exchange_private_key: local_exchange_key_pair.private_key,
201        local_challenge: server_challenge,
202        requested_remote_address: preamble.remote,
203        requested_should_encrypt: preamble.should_encrypt,
204        requested_is_udp: preamble.is_udp,
205    })
206}
207
208/// Runs the pump with a TCP-connected remote.
209async fn run_tcp_pump(mut client: BuffedTcpStream, remote_address: &str) -> Void {
210    let Ok(mut remote) = TcpStream::connect(remote_address).await.context("Error connecting to remote") else {
211        return ProtocolError::RemoteFailed(format!("Failed to connect to remote `{}`", remote_address))
212            .send_and_bail(&mut client)
213            .await;
214    };
215
216    info!("✅ Connected to remote server `{}`.", remote_address);
217
218    handle_pump(&mut client, &mut remote).await.context("Error handling TCP pump.")?;
219
220    Ok(())
221}
222
223/// Runs the pump with a UDP-connected remote.
224async fn run_udp_pump(mut client: BuffedTcpStream, remote_address: &str) -> Void {
225    // Attempt to connect to the remote address (should only fail in weird circumstances).
226
227    let remote = UdpSocket::bind("127.0.0.1:0").await.context("Error binding UDP socket")?;
228    if remote.connect(remote_address).await.is_err() {
229        return ProtocolError::RemoteFailed(format!("Failed to connect to remote `{}`", remote_address))
230            .send_and_bail(&mut client)
231            .await;
232    }
233
234    info!("✅ Connected to remote server `{}`.", remote_address);
235
236    // Split the client connection into a read and write half.
237    let (mut client_read, mut client_write) = client.into_split();
238
239    // Split the remote connection into a read and write half (just requires `Arc`ing, since the UDP send / receive does not require `&mut`).
240    let remote_up = Arc::new(remote);
241    let remote_down = remote_up.clone();
242
243    // Run the pumps.
244    
245    let last_activity = Arc::new(tokio::sync::Mutex::new(Instant::now()));
246    let last_activity_up = last_activity.clone();
247    let last_activity_down = last_activity.clone();
248
249    let pump_up: JoinHandle<Void> = tokio::spawn(async move {
250        while let ProtocolMessage::UdpData(data) = client_read.pull().await? {
251            remote_up.send(&data).await?;
252            *last_activity_up.lock().await = Instant::now();
253        }
254
255        Ok(())
256    });
257
258    let pump_down: JoinHandle<Void> = tokio::spawn(async move {
259        let mut buf = [0; Constant::BUFFER_SIZE];
260
261        loop {
262            let size = remote_down.recv(&mut buf).await?;
263            client_write.push(ProtocolMessage::UdpData(buf[..size].to_vec())).await?;
264            *last_activity_down.lock().await = Instant::now();
265        }
266    });
267
268    let timeout: JoinHandle<Void> = tokio::spawn(async move {
269        loop {
270            let last_activity = *last_activity.lock().await;
271
272            if last_activity.elapsed() > Constant::UDP_TIMEOUT {
273                info!("✅ UDP connection timed out (assumed graceful close).");
274                return Ok(());
275            }
276
277            tokio::time::sleep(Constant::UDP_TIMEOUT).await;
278        }
279    });
280
281    // Wait for the pumps to finish.  This employs a "last activity" type timeout, and uses `select!` to break
282    // out of the loop if any of the pumps finish.  In general, the UDP side to the remote will not close,
283    // but the client may break the pipe, so we need to handle that.
284    // The `select!` macro will return the first result that completes,
285    // and the `timeout` will return if the last activity is too long ago.
286
287    let result = select! {
288        r = pump_up => r?,
289        r = pump_down => r?,
290        r = timeout => r?,
291    };
292
293    // Check for errors.
294
295    result?;
296
297    Ok(())
298}
299
300/// Runs the TCP server.
301///
302/// This is the main entry point for the server. It binds to the specified address, and handles incoming connections.
303async fn run_tcp_server(config: Config) -> Void {
304    let listener = TcpListener::bind(&config.bind_address).await?;
305
306    loop {
307        let (socket, _) = listener.accept().await?;
308
309        tokio::spawn(handle_connection(socket, config.clone()));
310    }
311}
312
313/// Handles the TCP connection.
314///
315/// This is used to handle the TCP connection between the client and server.
316/// It handles the handshake, and pumps data between the client and server.
317async fn handle_connection(client: TcpStream, config: Config) {
318    let mut client = BuffedTcpStream::from(client);
319
320    let id = random_string(6);
321    let span = info_span!("conn", id = id);
322
323    let result: Void = async move {
324        let peer_addr = client.as_inner_tcp_write_ref().peer_addr().context("Error getting peer address")?;
325
326        info!("✅ Accepted connection from `{}`.", peer_addr);
327
328        // Handle the handshake.
329
330        let handshake_data = handle_handshake(&mut client, &config).await.context("Error handling handshake")?;
331
332        // Generate and apply the shared secret, if needed.
333        if handshake_data.requested_should_encrypt {
334            let private_key = handshake_data.local_exchange_private_key;
335            let salt_bytes = [handshake_data.local_challenge, handshake_data.client_challenge].concat();
336            let shared_secret = generate_shared_secret(private_key, &handshake_data.client_exchange_public_key, &salt_bytes)?;
337
338            client = client.with_encryption(shared_secret);
339            info!("🔒 Encryption applied ...");
340        }
341
342        // Handle the pump.
343
344        info!("⛽ Pumping data between client and remote ...");
345
346        if handshake_data.requested_is_udp {
347            run_udp_pump(client, &handshake_data.requested_remote_address).await?;
348        } else {
349            run_tcp_pump(client, &handshake_data.requested_remote_address).await?;
350        }
351
352        info!("✅ Connection closed.");
353
354        Ok(())
355    }
356    .instrument(span.clone())
357    .await;
358
359    // Enter the span, so that the error is logged with the span's metadata, if needed.
360    let _guard = span.enter();
361
362    if let Err(err) = result {
363        let chain = err.chain().collect::<Vec<_>>();
364        let full_chain = chain.iter().map(|e| format!("`{}`", e)).collect::<Vec<_>>().join(" => ");
365
366        error!("❌ Error handling connection: {}.", full_chain);
367    }
368}
369
370// Config.
371
372/// The server configuration.
373///
374/// This is used to store the server's configuration.
375#[derive(Clone)]
376pub(crate) struct Config {
377    pub(crate) public_key: String,
378    pub(crate) private_key: SecretString,
379    pub(crate) authorized_keys: Vec<String>,
380    pub(crate) bind_address: String,
381    pub(crate) remote_regex: Regex,
382}
383
384impl Config {
385    /// Creates a new server configuration.
386    fn new(public_key: String, private_key: SecretString, authorized_keys: Vec<String>, bind_address: String, remote_regex: Regex) -> Self {
387        Self {
388            public_key,
389            private_key,
390            authorized_keys,
391            bind_address,
392            remote_regex,
393        }
394    }
395}
396
397// Tests.
398
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use crate::{
        connect::tests::generate_test_client_config,
        protocol::ClientAuthentication,
        utils::{
            generate_key_pair, sign_challenge,
            tests::{generate_test_duplex, generate_test_fake_exchange_public_key},
        },
    };

    use super::*;

    /// Builds a server config from the `test/server` key directory, allowing any remote.
    pub(crate) fn generate_test_server_config() -> Config {
        let key_path = "test/server";

        let public_key = resolve_public_key(key_path).unwrap();
        let private_key = resolve_private_key(key_path).unwrap();
        let authorized_keys = resolve_authorized_keys(key_path);
        let remote_regex = Regex::new(".*").unwrap();

        Config {
            public_key,
            private_key,
            authorized_keys,
            bind_address: "bind_address".to_string(),
            remote_regex,
        }
    }

    #[test]
    fn test_prepare_config() {
        let instance = Instance::prepare("test/server".to_string(), ".*", "foo").unwrap();

        assert_eq!(instance.config.public_key, "HQYY0BNIhdawY2Jw62DudkUsK2GKj3hGO3qSVBlCinI");
        assert_eq!(instance.config.remote_regex.as_str(), ".*");
        assert_eq!(instance.config.bind_address, "foo");
    }

    #[tokio::test]
    async fn test_handle_handshake_success() {
        // Setup test environment
        let mut config = generate_test_server_config();
        let (mut client, mut server) = generate_test_duplex();
        let client_config = generate_test_client_config();

        // Add client's public key to authorized keys
        config.authorized_keys.push(client_config.public_key.clone());

        // Client sends preamble
        let client_challenge: Challenge = [8u8; 32];
        let client_preamble = ClientPreamble {
            remote: "localhost".to_string(),
            challenge: client_challenge,
            exchange_public_key: generate_test_fake_exchange_public_key(),
            should_encrypt: true,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(client_preamble.clone())).await.unwrap();

        // Client prepares to respond to server's challenge; it hands back the challenge it received.
        let client_handle = tokio::spawn(async move {
            // Get server preamble
            let message = client.pull().await.unwrap();
            let server_challenge = match message {
                ProtocolMessage::ServerPreamble(preamble) => preamble.challenge,
                _ => panic!("Expected ServerPreamble message, got: {:?}", message),
            };

            // Send client authentication
            let signature = sign_challenge(&server_challenge, &client_config.private_key).unwrap();
            let client_auth = ClientAuthentication {
                identity_public_key: client_config.public_key,
                signature: signature.into(),
            };

            client.push(ProtocolMessage::ClientAuthentication(client_auth)).await.unwrap();

            // Verify handshake completion
            let message = client.pull().await.unwrap();
            assert!(matches!(message, ProtocolMessage::HandshakeCompletion));

            server_challenge
        });

        // Execute handshake on server side
        let result = handle_handshake(&mut server, &config).await;

        // Wait for client to complete
        let server_challenge = client_handle.await.unwrap();

        // Verify success
        let key_data = result.expect("handshake should succeed");

        // Verify returned data
        assert_eq!(key_data.client_exchange_public_key, client_preamble.exchange_public_key);
        assert_eq!(key_data.client_challenge, client_challenge);
        assert_eq!(key_data.local_challenge, server_challenge);
        assert_eq!(key_data.requested_remote_address, "localhost");
        assert!(key_data.requested_should_encrypt);
        assert!(!key_data.requested_is_udp);
    }

    #[tokio::test]
    async fn test_handle_handshake_invalid_start() {
        // Setup
        let config = generate_test_server_config();
        let (mut client, mut server) = generate_test_duplex();

        // Send invalid initial message
        client.push(ProtocolMessage::HandshakeCompletion).await.unwrap();

        // Execute handshake
        let result = handle_handshake(&mut server, &config).await;

        // Verify error
        assert!(result.is_err());

        // Verify client received error
        let message = client.pull().await.unwrap();
        if let ProtocolMessage::Error(error) = message {
            assert_eq!(error, ProtocolError::Unknown("Invalid handshake start".into()));
        } else {
            panic!("Expected error message, got: {:?}", message);
        }
    }

    #[tokio::test]
    async fn test_handle_handshake_invalid_host() {
        // Setup
        let mut config = generate_test_server_config();
        config.remote_regex = Regex::new("^only-this-host$").unwrap();
        let (mut client, mut server) = generate_test_duplex();

        // Send preamble with non-matching host
        let client_preamble = ClientPreamble {
            remote: "different-host".to_string(),
            challenge: [9u8; 32],
            exchange_public_key: generate_test_fake_exchange_public_key(),
            should_encrypt: false,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(client_preamble)).await.unwrap();

        // Execute handshake
        let result = handle_handshake(&mut server, &config).await;

        // Verify error
        assert!(result.is_err());

        // Verify client received error
        let message = client.pull().await.unwrap();
        if let ProtocolMessage::Error(error) = message {
            assert!(matches!(error, ProtocolError::InvalidHost(_)));
        } else {
            panic!("Expected error message, got: {:?}", message);
        }
    }

    #[tokio::test]
    async fn test_handle_handshake_unauthorized_key() {
        // Setup
        let config = generate_test_server_config();
        let (mut client, mut server) = generate_test_duplex();

        // Client sends preamble
        let client_preamble = ClientPreamble {
            remote: "localhost".to_string(),
            challenge: [10u8; 32],
            exchange_public_key: generate_test_fake_exchange_public_key(),
            should_encrypt: false,
            is_udp: false,
        };

        client.push(ProtocolMessage::ClientPreamble(client_preamble)).await.unwrap();

        // Generate unauthorized key pair
        let unauthorized_key_pair = generate_key_pair().unwrap();
        let unauthorized_private_key = unauthorized_key_pair.private_key.into();

        // Client responds with unauthorized key
        let client_handle = tokio::spawn(async move {
            // Get server preamble
            let message = client.pull().await.unwrap();
            let server_challenge = match message {
                ProtocolMessage::ServerPreamble(preamble) => preamble.challenge,
                _ => panic!("Expected ServerPreamble message, got: {:?}", message),
            };

            // Send client authentication with unauthorized key (validly signed, so only the
            // authorization check can reject it)
            let signature = sign_challenge(&server_challenge, &unauthorized_private_key).unwrap();
            let client_auth = ClientAuthentication {
                identity_public_key: unauthorized_key_pair.public_key,
                signature: signature.into(),
            };

            client.push(ProtocolMessage::ClientAuthentication(client_auth)).await.unwrap();

            // Check for the specific authorization error
            let message = client.pull().await.unwrap();
            if let ProtocolMessage::Error(error) = message {
                assert_eq!(error, ProtocolError::InvalidKey("Unauthorized key from client".into()));
            } else {
                panic!("Expected error message, got: {:?}", message);
            }
        });

        // Execute handshake on server side
        let result = handle_handshake(&mut server, &config).await;

        // Wait for client to complete
        client_handle.await.unwrap();

        // Verify failure
        assert!(result.is_err());
    }
}