ratrodlib/
utils.rs

//! Utility functions for the application.
//!
//! This module provides helpers for generating Ed25519 signing keys, creating, signing and
//! verifying challenges, deriving shared secrets, and encrypting/decrypting data.
//! It also includes functions for parsing tunnel definitions and for pumping data
//! bidirectionally through TCP and UDP tunnels.
6use std::sync::Arc;
7
8use anyhow::{Context, anyhow};
9use base64::Engine;
10use bytes::BytesMut;
11use futures::future::Either;
12use rand::{Rng, distr::Alphanumeric};
13use ring::{
14    aead::{Aad, LessSafeKey, Nonce, UnboundKey},
15    agreement::{EphemeralPrivateKey, agree_ephemeral},
16    hkdf::Salt,
17    rand::{SecureRandom, SystemRandom},
18    signature::{Ed25519KeyPair, KeyPair},
19};
20use secrecy::{ExposeSecret, SecretString};
21use tokio::{
22    io::{AsyncReadExt, AsyncWriteExt},
23    net::{TcpStream, UdpSocket},
24    select,
25    sync::Mutex,
26    task::JoinHandle,
27    time::Instant,
28};
29use tracing::{debug, info};
30
31use crate::{
32    base::{Base64KeyPair, Constant, EncryptedData, ExchangeKeyPair, Res, SharedSecret, SharedSecretNonce, SharedSecretShape, TunnelDefinition, Void},
33    buffed_stream::{BincodeSplit, BuffedTcpStream},
34    protocol::{BincodeReceive, BincodeSend, Challenge, ProtocolMessage, Signature},
35};
36
37/// Generates a random alphanumeric string of the specified length.
38///
39/// This is used for creating unique identifiers, such as connection IDs.
40pub fn random_string(len: usize) -> String {
41    rand::rng().sample_iter(&Alphanumeric).take(len).map(char::from).collect()
42}
43
44/// Generates a random alphanumeric string of the specified length.
45pub fn generate_key_pair() -> Res<Base64KeyPair> {
46    let rng = SystemRandom::new();
47    // Generate Ed25519 key pair in PKCS#8 format
48    let pkcs8 = Ed25519KeyPair::generate_pkcs8(&rng).context("Unable to generate key pair")?;
49
50    let key_pair = Ed25519KeyPair::from_pkcs8(pkcs8.as_ref()).context("Failed to create key pair")?;
51
52    let public = Constant::BASE64_ENGINE.encode(key_pair.public_key().as_ref());
53    let private = Constant::BASE64_ENGINE.encode(pkcs8.as_ref());
54
55    Ok(Base64KeyPair { public_key: public, private_key: private })
56}
57
58/// Generates a key pair from a given private key.
59pub fn generate_key_pair_from_key(private_key: &str) -> Res<Base64KeyPair> {
60    let key_bytes = Constant::BASE64_ENGINE.decode(private_key).context("Could not decode seed")?;
61
62    let key_pair = Ed25519KeyPair::from_pkcs8(&key_bytes).context("Failed to create key pair")?;
63
64    let public = Constant::BASE64_ENGINE.encode(key_pair.public_key().as_ref());
65
66    Ok(Base64KeyPair {
67        public_key: public,
68        private_key: private_key.to_string(),
69    })
70}
71
72/// Generates a random challenge for a peer to sign.
73pub fn generate_challenge() -> Challenge {
74    let rng = SystemRandom::new();
75    let mut challenge = Challenge::default();
76    rng.fill(&mut challenge).expect("Failed to generate challenge");
77    challenge
78}
79
80/// Signs a challenge using the provided private key.
81pub fn sign_challenge(challenge: &[u8], private_key: &SecretString) -> Res<Signature> {
82    if challenge.len() != Constant::CHALLENGE_SIZE {
83        return Err(anyhow!("Invalid challenge length"));
84    }
85
86    debug!("Challenge: `{:?}`", challenge);
87
88    let private_key = Constant::BASE64_ENGINE.decode(private_key.expose_secret()).context("Could not decode private key")?;
89    debug!("Signing challenge with private key: {:?}", &private_key);
90
91    let key_pair = Ed25519KeyPair::from_pkcs8(&private_key).map_err(|_| anyhow!("Invalid private key"))?;
92    debug!("Key pair: {:?}", key_pair);
93
94    let signature = key_pair.sign(challenge).as_ref()[..Constant::SIGNATURE_SIZE]
95        .try_into()
96        .map_err(|_| anyhow!("Invalid signature length"))?;
97    debug!("Signature: {:?}", &signature);
98
99    Ok(signature)
100}
101
102/// Validates a signed challenge using the provided public key.
103pub fn validate_signed_challenge(challenge: &[u8], signature: &[u8], public_key: &str) -> Void {
104    if challenge.len() != Constant::CHALLENGE_SIZE {
105        return Err(anyhow!("Invalid challenge length"));
106    }
107
108    if signature.len() != Constant::SIGNATURE_SIZE {
109        return Err(anyhow!("Invalid signature length"));
110    }
111
112    let public_key = Constant::BASE64_ENGINE.decode(public_key).context("Could not decode public key")?;
113
114    let unparsed_public_key = ring::signature::UnparsedPublicKey::new(Constant::SIGNATURE, public_key);
115
116    unparsed_public_key.verify(challenge, signature).context("Invalid signature")?;
117
118    Ok(())
119}
120
121/// Generates an ephemeral key pair for key exchange.
122pub fn generate_ephemeral_key_pair() -> Res<ExchangeKeyPair> {
123    let rng = SystemRandom::new();
124
125    let my_private_key = EphemeralPrivateKey::generate(Constant::AGREEMENT, &rng)?;
126
127    let public_key = my_private_key.compute_public_key()?;
128
129    Ok(ExchangeKeyPair { public_key, private_key: my_private_key })
130}
131
132/// Derives a shared secret for encrypting and decrypting data.
133pub fn generate_shared_secret(private_key: EphemeralPrivateKey, peer_public_key: &[u8], salt_bytes: &[u8]) -> Res<SharedSecret> {
134    if peer_public_key.len() != Constant::EXCHANGE_PUBLIC_KEY_SIZE {
135        return Err(anyhow!("Invalid public key length"));
136    }
137
138    let unparsed_peer_public_key = ring::agreement::UnparsedPublicKey::new(Constant::AGREEMENT, peer_public_key);
139
140    let shared_secret = agree_ephemeral(private_key, &unparsed_peer_public_key, |shared_secret| generate_chacha_key(shared_secret, salt_bytes))??;
141    Ok(shared_secret)
142}
143
144/// Generates a ChaCha20 key from the shared secret and salt bytes.
145fn generate_chacha_key(private_key: &[u8], salt_bytes: &[u8]) -> Res<SharedSecret> {
146    let salt = Salt::new(Constant::KDF, salt_bytes);
147    let info = &[salt_bytes];
148
149    let prk = salt.extract(private_key);
150    let okm = prk.expand(info, Constant::KDF)?;
151
152    let mut key = SharedSecretShape::default();
153    okm.fill(&mut key)?;
154
155    Ok(SharedSecret::init_with(|| key))
156}
157
158/// Encrypts the given plaintext using the shared secret.
159pub fn encrypt(shared_secret: &SharedSecret, plaintext: &[u8]) -> Res<EncryptedData> {
160    let mut in_out = BytesMut::from(plaintext);
161
162    let nonce = encrypt_into(shared_secret, &mut in_out)?;
163
164    Ok(EncryptedData { nonce, data: in_out.to_vec() })
165}
166
167/// Encrypts the data in place and appends the tag to the end of the buffer.
168/// The nonce is generated randomly.
169///
170/// This method updates the `in_out` length.
171pub fn encrypt_into(shared_secret: &SharedSecret, in_out: &mut BytesMut) -> Res<SharedSecretNonce> {
172    let rng = SystemRandom::new();
173    let mut nonce_bytes = [0u8; Constant::SHARED_SECRET_NONCE_SIZE];
174    rng.fill(&mut nonce_bytes).context("Could not fill nonce for encryption")?;
175
176    let unbound_key = UnboundKey::new(Constant::AEAD, shared_secret.expose_secret()).context("Could not generate unbound key for encryption")?;
177    let sealing_key = LessSafeKey::new(unbound_key);
178    let nonce = Nonce::assume_unique_for_key(nonce_bytes);
179
180    sealing_key.seal_in_place_append_tag(nonce, Aad::empty(), in_out).context("Could not seal in place during encryption")?;
181
182    Ok(nonce_bytes)
183}
184
185/// Decrypts the given ciphertext using the shared secret.
186pub fn decrypt(shared_secret: &SharedSecret, nonce_bytes: &SharedSecretNonce, ciphertext: &[u8]) -> Res<Vec<u8>> {
187    let mut in_out = BytesMut::from(ciphertext);
188
189    decrypt_in_place(shared_secret, nonce_bytes, &mut in_out)?;
190
191    Ok(in_out.to_vec())
192}
193
194/// Decrypts the data in place.
195///
196/// This method updates the `in_out` length.
197pub fn decrypt_in_place(shared_secret: &SharedSecret, nonce_bytes: &SharedSecretNonce, in_out: &mut BytesMut) -> Void {
198    let unbound_key = UnboundKey::new(Constant::AEAD, shared_secret.expose_secret()).context("Could not generate unbound key for decryption")?;
199    let opening_key = LessSafeKey::new(unbound_key);
200    let nonce = Nonce::assume_unique_for_key(*nonce_bytes);
201
202    let length = opening_key.open_in_place(nonce, Aad::empty(), in_out).context("Could not open in place for decryption")?.len();
203
204    // SAFETY: Decryption in place reduces the length, so we can safely set the length
205    // to the length of the decrypted data.
206    unsafe {
207        in_out.set_len(length);
208    }
209
210    Ok(())
211}
212
213/// Parses the tunnel definition from the given input string.
214///
215/// Input is of the form:
216/// - `local_port:destination_host:destination_port`
217/// - `local_port:destination_port`
218/// - `local_port`
219pub fn parse_tunnel_definition(tunnel: &str) -> Res<TunnelDefinition> {
220    let parts: Vec<&str> = tunnel.split(':').collect();
221
222    match parts.len() {
223        4 => {
224            let bind_address = format!("{}:{}", parts[0], parts[1]);
225            let host_address = format!("{}:{}", parts[2], parts[3]);
226
227            Ok(TunnelDefinition {
228                bind_address,
229                remote_address: host_address,
230            })
231        }
232        3 => {
233            let bind_address = format!("127.0.0.1:{}", parts[0]);
234            let host_address = format!("{}:{}", parts[1], parts[2]);
235
236            Ok(TunnelDefinition {
237                bind_address,
238                remote_address: host_address,
239            })
240        }
241        2 => {
242            let bind_address = format!("127.0.0.1:{}", parts[0]);
243            let host_address = format!("127.0.0.1:{}", parts[1]);
244
245            Ok(TunnelDefinition {
246                bind_address,
247                remote_address: host_address,
248            })
249        }
250        1 => {
251            let bind_address = format!("127.0.0.1:{}", parts[0]);
252            let host_address = format!("127.0.0.1:{}", parts[0]);
253
254            Ok(TunnelDefinition {
255                bind_address,
256                remote_address: host_address,
257            })
258        }
259        _ => Err(anyhow!("Invalid tunnel definition format")),
260    }
261}
262
263/// Parses a list of tunnel definitions from the given input strings.
264pub fn parse_tunnel_definitions<T>(tunnels: &[T]) -> Res<Vec<TunnelDefinition>>
265where
266    T: AsRef<str>,
267{
268    tunnels.iter().map(|tunnel| parse_tunnel_definition(tunnel.as_ref())).collect()
269}
270
271/// Handles bidirectional data transfer between two streams.
272pub async fn handle_tcp_pump(a: TcpStream, b: BuffedTcpStream) -> Res<(u64, u64)> {
273    let (mut read_a, mut write_a) = a.into_split();
274    let (mut read_b, mut write_b) = b.into_split();
275
276    let a_to_b: JoinHandle<Res<u64>> = tokio::spawn(async move {
277        let buf = &mut [0u8; Constant::BUFFER_SIZE];
278        let mut count = 0;
279        loop {
280            let n = read_a.read(buf).await?;
281
282            if n == 0 {
283                break;
284            }
285
286            write_b.push(ProtocolMessage::Data(&buf[..n])).await?;
287
288            count += n as u64;
289        }
290
291        Ok(count)
292    });
293
294    let b_to_a: JoinHandle<Res<u64>> = tokio::spawn(async move {
295        let mut count = 0;
296        loop {
297            let guard = read_b.pull().await?;
298            let data = match guard.message() {
299                ProtocolMessage::Data(data) => data,
300                ProtocolMessage::Shutdown => break,
301                _ => return Err(anyhow!("Failed to read data in pump (wrong type)")),
302            };
303
304            if data.is_empty() {
305                break;
306            }
307
308            write_a.write_all(data).await?;
309            write_a.flush().await?;
310
311            count += data.len() as u64;
312        }
313
314        Ok(count)
315    });
316
317    let result = futures::future::select(a_to_b, b_to_a).await;
318
319    match result {
320        Either::Left((a_to_b, other)) => {
321            let right = a_to_b??;
322            let left = other.await??;
323
324            info!("📊 {} ⮀ {}", left, right);
325
326            Ok((left, right))
327        }
328        Either::Right((b_to_a, other)) => {
329            let right = b_to_a??;
330            let left = other.await??;
331
332            info!("📊 {} ⮀ {}", left, right);
333
334            Ok((left, right))
335        }
336    }
337}
338
339/// Handles bidirectional data transfer between a UDP socket and a TCP stream.
340pub async fn handle_udp_pump(a: UdpSocket, b: BuffedTcpStream) -> Void {
341    // Split the client connection into a read and write half.
342    let (mut b_read, mut b_write) = b.into_split();
343
344    // Split the remote connection into a read and write half (just requires `Arc`ing, since the UDP send / receive does not require `&mut`).
345    let a_up = Arc::new(a);
346    let a_down = a_up.clone();
347
348    // Run the pumps.
349
350    let last_activity = Arc::new(Mutex::new(Instant::now()));
351    let last_activity_up = last_activity.clone();
352    let last_activity_down = last_activity.clone();
353
354    let pump_up: JoinHandle<Void> = tokio::spawn(async move {
355        loop {
356            let guard = b_read.pull().await?;
357            let ProtocolMessage::UdpData(data) = guard.message() else {
358                break;
359            };
360
361            a_up.send(data).await?;
362            *last_activity_up.lock().await = Instant::now();
363        }
364
365        Ok(())
366    });
367
368    let pump_down: JoinHandle<Void> = tokio::spawn(async move {
369        let mut buf = [0; Constant::BUFFER_SIZE];
370
371        loop {
372            let size = a_down.recv(&mut buf).await?;
373            b_write.push(ProtocolMessage::UdpData(&buf[..size])).await?;
374            *last_activity_down.lock().await = Instant::now();
375        }
376    });
377
378    let timeout: JoinHandle<Void> = tokio::spawn(async move {
379        loop {
380            let last_activity = *last_activity.lock().await;
381
382            if last_activity.elapsed() > Constant::TIMEOUT {
383                info!("✅ UDP connection timed out (assumed graceful close).");
384                return Ok(());
385            }
386
387            tokio::time::sleep(Constant::TIMEOUT).await;
388        }
389    });
390
391    // Wait for the pumps to finish.  This employs a "last activity" type timeout, and uses `select!` to break
392    // out of the loop if any of the pumps finish.  In general, the UDP side to the remote will not close,
393    // but the client may break the pipe, so we need to handle that.
394    // The `select!` macro will return the first result that completes,
395    // and the `timeout` will return if the last activity is too long ago.
396
397    let result = select! {
398        r = pump_up => r?,
399        r = pump_down => r?,
400        r = timeout => r?,
401    };
402
403    // Check for errors.
404
405    result?;
406
407    Ok(())
408}
409
#[cfg(test)]
pub mod tests {
    //! Unit tests for this module, plus shared fixtures that other test modules reuse.

    use crate::{
        buffed_stream::{BuffedDuplexStream, BuffedStream},
        protocol::ExchangePublicKey,
    };

    use super::*;
    use pretty_assertions::assert_eq;

    /// Builds a connected pair of in-memory buffered streams without encryption.
    pub fn generate_test_duplex() -> (BuffedDuplexStream, BuffedDuplexStream) {
        let (left, right) = tokio::io::duplex(Constant::BUFFER_SIZE);
        (BuffedStream::from(left), BuffedStream::from(right))
    }

    /// Like [`generate_test_duplex`], but both ends encrypt with the same freshly derived secret.
    pub fn generate_test_duplex_with_encryption() -> (BuffedDuplexStream, BuffedDuplexStream) {
        let (left, right) = tokio::io::duplex(Constant::BUFFER_SIZE);
        let secret = generate_test_shared_secret();
        let key = *secret.expose_secret();

        let wrap = |stream: tokio::io::DuplexStream| BuffedStream::from(stream).with_encryption(SharedSecret::init_with(|| key));

        (wrap(left), wrap(right))
    }

    /// A fresh ephemeral key-exchange pair; panics on failure.
    pub fn generate_test_ephemeral_key_pair() -> ExchangeKeyPair {
        generate_ephemeral_key_pair().unwrap()
    }

    /// A shared secret obtained by agreeing an ephemeral key with its own public key,
    /// salted with a random challenge.
    pub fn generate_test_shared_secret() -> SharedSecret {
        let ExchangeKeyPair { private_key, public_key } = generate_test_ephemeral_key_pair();
        let salt = generate_challenge();

        generate_shared_secret(private_key, public_key.as_ref(), &salt).unwrap()
    }

    /// A correctly sized (32-byte) but meaningless exchange public key.
    pub fn generate_test_fake_exchange_public_key() -> ExchangePublicKey {
        let bytes: &[u8] = b"this needs to be exactly 32 byte";
        bytes.try_into().unwrap()
    }

    #[test]
    fn test_generate_key_pair() {
        let Base64KeyPair { public_key, private_key } = generate_key_pair().unwrap();

        // Unpadded base64 lengths of a 32-byte public key and an 83-byte PKCS#8 document.
        assert_eq!(public_key.len(), 43);
        assert_eq!(private_key.len(), 111);
    }

    #[test]
    fn test_generate_key_pair_from_key() {
        let original = generate_key_pair().unwrap();
        let rebuilt = generate_key_pair_from_key(&original.private_key).unwrap();

        assert_eq!(rebuilt.public_key, original.public_key);
        assert_eq!(rebuilt.private_key, original.private_key);
    }

    #[test]
    fn test_ed25519() {
        let key_pair = generate_key_pair().unwrap();
        let challenge = generate_challenge();

        let secret: SecretString = key_pair.private_key.into();
        let signature = sign_challenge(&challenge, &secret).unwrap();

        validate_signed_challenge(&challenge, &signature, &key_pair.public_key).unwrap();
    }

    #[test]
    fn test_ephemeral_key_exchange() {
        let first = generate_ephemeral_key_pair().unwrap();
        let second = generate_ephemeral_key_pair().unwrap();
        let salt = generate_challenge();

        let secret_first = generate_shared_secret(first.private_key, second.public_key.as_ref(), &salt).unwrap();
        let secret_second = generate_shared_secret(second.private_key, first.public_key.as_ref(), &salt).unwrap();

        assert_eq!(secret_first.expose_secret().len(), Constant::SHARED_SECRET_SIZE);
        assert_eq!(secret_first.expose_secret(), secret_second.expose_secret());
    }

    #[test]
    fn test_encrypt_decrypt() {
        let secret = generate_test_shared_secret();
        let message = b"Hello, world!";

        let sealed = encrypt(&secret, message).unwrap();
        let opened = decrypt(&secret, &sealed.nonce, &sealed.data).unwrap();

        assert_eq!(opened, message);
    }

    #[test]
    fn test_parse_tunnel_definition() {
        // (input, expected bind address, expected remote address)
        let cases = [
            ("a:b:c:d", "a:b", "c:d"),
            ("a:b:c", "127.0.0.1:a", "b:c"),
            ("a:b", "127.0.0.1:a", "127.0.0.1:b"),
            ("a", "127.0.0.1:a", "127.0.0.1:a"),
        ];

        for (input, bind, remote) in cases {
            let parsed = parse_tunnel_definition(input).unwrap();
            assert_eq!(parsed.bind_address, bind);
            assert_eq!(parsed.remote_address, remote);
        }
    }

    #[test]
    fn test_bad_tunnel_definition() {
        for input in ["a:b:c:d:e", "a:b:c:d:e:f"] {
            assert!(parse_tunnel_definition(input).is_err());
        }
    }
}
532}