ratrodlib/
utils.rs

//! Utility functions for the application.
//!
//! This module provides helpers for generating and restoring Ed25519 key pairs, signing and verifying
//! challenges, ephemeral key exchange with shared-secret derivation, and AEAD encryption/decryption.
//! It also parses tunnel definitions and pumps data bidirectionally between two streams.
5
6use anyhow::Context;
7use base64::Engine;
8use rand::{Rng, distr::Alphanumeric};
9use ring::{
10    aead::{Aad, LessSafeKey, Nonce, UnboundKey},
11    agreement::{EphemeralPrivateKey, agree_ephemeral},
12    hkdf::Salt,
13    rand::{SecureRandom, SystemRandom},
14    signature::{Ed25519KeyPair, KeyPair},
15};
16use secrecy::{ExposeSecret, SecretString};
17use tokio::io::{AsyncRead, AsyncWrite};
18use tracing::{debug, info};
19
20use crate::{
21    base::{Base64KeyPair, Constant, EncryptedData, Err, ExchangeKeyPair, Res, SharedSecret, SharedSecretNonce, SharedSecretShape, TunnelDefinition, Void},
22    protocol::{Challenge, ExchangePublicKey, Signature},
23};
24
25/// Generates a random alphanumeric string of the specified length.
26///
27/// This is used for creating unique identifiers, such as connection IDs.
28pub fn random_string(len: usize) -> String {
29    rand::rng().sample_iter(&Alphanumeric).take(len).map(char::from).collect()
30}
31
32pub fn generate_key_pair() -> Res<Base64KeyPair> {
33    let rng = SystemRandom::new();
34    // Generate Ed25519 key pair in PKCS#8 format
35    let pkcs8 = Ed25519KeyPair::generate_pkcs8(&rng).context("Unable to generate key pair")?;
36
37    let key_pair = Ed25519KeyPair::from_pkcs8(pkcs8.as_ref()).context("Failed to create key pair")?;
38
39    let public = Constant::BASE64_ENGINE.encode(key_pair.public_key().as_ref());
40    let private = Constant::BASE64_ENGINE.encode(pkcs8.as_ref());
41
42    Ok(Base64KeyPair { public_key: public, private_key: private })
43}
44
45pub fn generate_key_pair_from_key(private_key: &str) -> Res<Base64KeyPair> {
46    let key_bytes = Constant::BASE64_ENGINE.decode(private_key).context("Could not decode seed")?;
47
48    let key_pair = Ed25519KeyPair::from_pkcs8(&key_bytes).context("Failed to create key pair")?;
49
50    let public = Constant::BASE64_ENGINE.encode(key_pair.public_key().as_ref());
51
52    Ok(Base64KeyPair {
53        public_key: public,
54        private_key: private_key.to_string(),
55    })
56}
57
58pub fn generate_challenge() -> Challenge {
59    let rng = SystemRandom::new();
60    let mut challenge = Challenge::default();
61    rng.fill(&mut challenge).expect("Failed to generate challenge");
62    challenge
63}
64
65pub fn sign_challenge(challenge: &Challenge, private_key: &SecretString) -> Res<Signature> {
66    debug!("Challenge: `{:?}`", challenge);
67
68    let private_key = Constant::BASE64_ENGINE.decode(private_key.expose_secret()).context("Could not decode private key")?;
69    debug!("Signing challenge with private key: {:?}", &private_key);
70
71    let key_pair = Ed25519KeyPair::from_pkcs8(&private_key).map_err(|_| Err::msg("Invalid private key"))?;
72    debug!("Key pair: {:?}", key_pair);
73
74    let signature = key_pair.sign(challenge).as_ref()[..Constant::SIGNATURE_SIZE]
75        .try_into()
76        .map_err(|_| Err::msg("Invalid signature length"))?;
77    debug!("Signature: {:?}", &signature);
78
79    Ok(signature)
80}
81
82pub fn validate_signed_challenge(challenge: &Challenge, signature: &Signature, public_key: &str) -> Void {
83    let public_key = Constant::BASE64_ENGINE.decode(public_key).context("Could not decode public key")?;
84
85    let unparsed_public_key = ring::signature::UnparsedPublicKey::new(Constant::SIGNATURE, public_key);
86
87    unparsed_public_key.verify(challenge, signature).context("Invalid signature")?;
88
89    Ok(())
90}
91
92pub fn generate_ephemeral_key_pair() -> Res<ExchangeKeyPair> {
93    let rng = SystemRandom::new();
94
95    let my_private_key = EphemeralPrivateKey::generate(Constant::AGREEMENT, &rng)?;
96
97    let public_key = my_private_key.compute_public_key()?;
98
99    Ok(ExchangeKeyPair { public_key, private_key: my_private_key })
100}
101
102pub fn generate_shared_secret(private_key: EphemeralPrivateKey, peer_public_key: &ExchangePublicKey, salt_bytes: &[u8]) -> Res<SharedSecret> {
103    let unparsed_peer_public_key = ring::agreement::UnparsedPublicKey::new(Constant::AGREEMENT, peer_public_key);
104
105    let shared_secret = agree_ephemeral(private_key, &unparsed_peer_public_key, |shared_secret| generate_chacha_key(shared_secret, salt_bytes))??;
106    Ok(shared_secret)
107}
108
109fn generate_chacha_key(private_key: &[u8], salt_bytes: &[u8]) -> Res<SharedSecret> {
110    let salt = Salt::new(Constant::KDF, salt_bytes);
111    let info = &[salt_bytes];
112
113    let prk = salt.extract(private_key);
114    let okm = prk.expand(info, Constant::KDF)?;
115
116    let mut key = SharedSecretShape::default();
117    okm.fill(&mut key)?;
118
119    Ok(SharedSecret::init_with(|| key))
120}
121
122pub fn encrypt(shared_secret: &SharedSecret, plaintext: &[u8]) -> Res<EncryptedData> {
123    let rng = SystemRandom::new();
124    let mut nonce_bytes = [0u8; Constant::SHARED_SECRET_NONCE_SIZE];
125    rng.fill(&mut nonce_bytes).context("Could not fill nonce for encryption")?;
126
127    let unbound_key = UnboundKey::new(Constant::AEAD, shared_secret.expose_secret()).context("Could not generate unbound key for encryption")?;
128    let sealing_key = LessSafeKey::new(unbound_key);
129    let nonce = Nonce::assume_unique_for_key(nonce_bytes);
130
131    let mut in_out = plaintext.to_vec();
132    in_out.reserve_exact(Constant::AEAD.tag_len());
133
134    sealing_key
135        .seal_in_place_append_tag(nonce, Aad::empty(), &mut in_out)
136        .context("Could not seal in place during encryption")?;
137
138    Ok(EncryptedData { nonce: nonce_bytes, data: in_out })
139}
140
141pub fn decrypt(shared_secret: &SharedSecret, ciphertext: &[u8], nonce_bytes: &SharedSecretNonce) -> Res<Vec<u8>> {
142    let unbound_key = UnboundKey::new(Constant::AEAD, shared_secret.expose_secret()).context("Could not generate unbound key for decryption")?;
143    let opening_key = LessSafeKey::new(unbound_key);
144    let nonce = Nonce::assume_unique_for_key(*nonce_bytes);
145
146    let mut in_out = ciphertext.to_vec();
147    let plaintext = opening_key.open_in_place(nonce, Aad::empty(), &mut in_out).context("Could not open in place for decryption")?;
148
149    Ok(plaintext.to_vec())
150}
151
152/// Parses the tunnel definition from the given input string.
153///
154/// Input is of the form:
155/// - `local_port:destination_host:destination_port`
156/// - `local_port:destination_port`
157/// - `local_port`
158pub fn parse_tunnel_definition(tunnel: &str) -> Res<TunnelDefinition> {
159    let parts: Vec<&str> = tunnel.split(':').collect();
160
161    match parts.len() {
162        4 => {
163            let bind_address = format!("{}:{}", parts[0], parts[1]);
164            let host_address = format!("{}:{}", parts[2], parts[3]);
165
166            Ok(TunnelDefinition {
167                bind_address,
168                remote_address: host_address,
169            })
170        }
171        3 => {
172            let bind_address = format!("127.0.0.1:{}", parts[0]);
173            let host_address = format!("{}:{}", parts[1], parts[2]);
174
175            Ok(TunnelDefinition {
176                bind_address,
177                remote_address: host_address,
178            })
179        }
180        2 => {
181            let bind_address = format!("127.0.0.1:{}", parts[0]);
182            let host_address = format!("127.0.0.1:{}", parts[1]);
183
184            Ok(TunnelDefinition {
185                bind_address,
186                remote_address: host_address,
187            })
188        }
189        1 => {
190            let bind_address = format!("127.0.0.1:{}", parts[0]);
191            let host_address = format!("127.0.0.1:{}", parts[0]);
192
193            Ok(TunnelDefinition {
194                bind_address,
195                remote_address: host_address,
196            })
197        }
198        _ => Err(Err::msg("Invalid tunnel definition format")),
199    }
200}
201
202pub fn parse_tunnel_definitions<T>(tunnels: &[T]) -> Res<Vec<TunnelDefinition>>
203where
204    T: AsRef<str>,
205{
206    tunnels.iter().map(|tunnel| parse_tunnel_definition(tunnel.as_ref())).collect()
207}
208
209pub async fn handle_pump<A, B>(a: &mut A, b: &mut B) -> Res<(u64, u64)>
210where
211    A: AsyncRead + AsyncWrite + Unpin,
212    B: AsyncRead + AsyncWrite + Unpin,
213{
214    let result = tokio::io::copy_bidirectional_with_sizes(a, b, Constant::BUFFER_SIZE, Constant::BUFFER_SIZE).await?;
215
216    info!("⬅️ {} bytes ➡️ {} bytes", result.1, result.0);
217
218    Ok(result)
219}
220
#[cfg(test)]
pub mod tests {
    use pretty_assertions::assert_eq;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;
    use crate::buffed_stream::{BuffedDuplexStream, BuffedStream};

    /// Returns two connected, unencrypted buffered streams backed by `tokio::io::duplex`.
    pub fn generate_test_duplex() -> (BuffedDuplexStream, BuffedDuplexStream) {
        let (near, far) = tokio::io::duplex(Constant::BUFFER_SIZE);
        (BuffedStream::from(near), BuffedStream::from(far))
    }

    /// Returns two connected buffered streams, both configured with the same freshly derived
    /// shared secret.
    pub fn generate_test_duplex_with_encryption() -> (BuffedDuplexStream, BuffedDuplexStream) {
        let (near, far) = tokio::io::duplex(Constant::BUFFER_SIZE);
        let secret = generate_test_shared_secret();
        // Each stream takes ownership of its own secret, so copy the key bytes and wrap them twice.
        let key_bytes = *secret.expose_secret();

        let near = BuffedStream::from(near).with_encryption(SharedSecret::init_with(|| key_bytes));
        let far = BuffedStream::from(far).with_encryption(SharedSecret::init_with(|| key_bytes));
        (near, far)
    }

    /// Generates an ephemeral key pair, panicking on failure.
    pub fn generate_test_ephemeral_key_pair() -> ExchangeKeyPair {
        generate_ephemeral_key_pair().unwrap()
    }

    /// Derives a shared secret by agreeing a fresh ephemeral key pair with its own public key,
    /// salted with a random challenge.
    pub fn generate_test_shared_secret() -> SharedSecret {
        let ExchangeKeyPair { public_key, private_key } = generate_test_ephemeral_key_pair();
        let salt = generate_challenge();
        let own_public_key: &ExchangePublicKey = public_key.as_ref().try_into().unwrap();

        generate_shared_secret(private_key, own_public_key, &salt).unwrap()
    }

    /// Returns a correctly sized (32-byte) but meaningless exchange public key.
    pub fn generate_test_fake_exchange_public_key() -> ExchangePublicKey {
        let fake: &[u8] = b"this needs to be exactly 32 byte";
        fake.try_into().unwrap()
    }

    /// Writes the canned greetings from the two outer stream ends, shutting each one down
    /// right after its write.
    async fn send_greetings(client: &mut BuffedDuplexStream, remote: &mut BuffedDuplexStream) {
        client.write_all(b"Hello, remote!").await.unwrap();
        client.shutdown().await.unwrap();
        remote.write_all(b"Hello, client!!").await.unwrap();
        remote.shutdown().await.unwrap();
    }

    #[test]
    fn test_generate_key_pair() {
        let Base64KeyPair { public_key, private_key } = generate_key_pair().unwrap();

        // Expected lengths of the base64 text produced by `Constant::BASE64_ENGINE`.
        assert_eq!(public_key.len(), 43);
        assert_eq!(private_key.len(), 111);
    }

    #[test]
    fn test_generate_key_pair_from_key() {
        let original = generate_key_pair().unwrap();
        let restored = generate_key_pair_from_key(&original.private_key).unwrap();

        assert_eq!(restored.public_key, original.public_key);
        assert_eq!(restored.private_key, original.private_key);
    }

    #[test]
    fn test_ed25519() {
        let Base64KeyPair { public_key, private_key } = generate_key_pair().unwrap();
        let challenge = generate_challenge();

        let secret: SecretString = private_key.into();
        let signature = sign_challenge(&challenge, &secret).unwrap();

        validate_signed_challenge(&challenge, &signature, &public_key).unwrap();
    }

    #[test]
    fn test_ephemeral_key_exchange() {
        let alice = generate_ephemeral_key_pair().unwrap();
        let bob = generate_ephemeral_key_pair().unwrap();
        let salt = generate_challenge();

        let bob_public: &ExchangePublicKey = bob.public_key.as_ref().try_into().unwrap();
        let alice_secret = generate_shared_secret(alice.private_key, bob_public, &salt).unwrap();

        let alice_public: &ExchangePublicKey = alice.public_key.as_ref().try_into().unwrap();
        let bob_secret = generate_shared_secret(bob.private_key, alice_public, &salt).unwrap();

        // Both sides must derive the same, correctly sized key.
        assert_eq!(alice_secret.expose_secret().len(), Constant::SHARED_SECRET_SIZE);
        assert_eq!(alice_secret.expose_secret(), bob_secret.expose_secret());
    }

    #[test]
    fn test_encrypt_decrypt() {
        let secret = generate_test_shared_secret();
        let message = b"Hello, world!";

        let sealed = encrypt(&secret, message).unwrap();
        let recovered = decrypt(&secret, &sealed.data, &sealed.nonce).unwrap();

        assert_eq!(recovered, message);
    }

    #[test]
    fn test_parse_tunnel_definition() {
        let cases = [
            ("a:b:c:d", "a:b", "c:d"),
            ("a:b:c", "127.0.0.1:a", "b:c"),
            ("a:b", "127.0.0.1:a", "127.0.0.1:b"),
            ("a", "127.0.0.1:a", "127.0.0.1:a"),
        ];

        for (input, expected_bind, expected_remote) in cases {
            let parsed = parse_tunnel_definition(input).unwrap();
            assert_eq!(parsed.bind_address, expected_bind);
            assert_eq!(parsed.remote_address, expected_remote);
        }
    }

    #[test]
    fn test_bad_tunnel_definition() {
        for input in ["a:b:c:d:e", "a:b:c:d:e:f"] {
            assert!(parse_tunnel_definition(input).is_err());
        }
    }

    #[tokio::test]
    async fn test_handle_pump() {
        let (mut client, mut proxy_client_side) = generate_test_duplex();
        let (mut proxy_remote_side, mut remote) = generate_test_duplex();
        send_greetings(&mut client, &mut remote).await;

        let (up, down) = handle_pump(&mut proxy_client_side, &mut proxy_remote_side).await.unwrap();

        assert_eq!(up, 14);
        assert_eq!(down, 15);

        let mut at_client = Vec::new();
        client.read_to_end(&mut at_client).await.unwrap();
        assert_eq!(at_client, b"Hello, client!!");

        let mut at_remote = Vec::new();
        remote.read_to_end(&mut at_remote).await.unwrap();
        assert_eq!(at_remote, b"Hello, remote!");
    }

    #[tokio::test]
    async fn test_handle_pump_with_encryption() {
        let (mut client, mut proxy_client_side) = generate_test_duplex_with_encryption();
        let (mut proxy_remote_side, mut remote) = generate_test_duplex_with_encryption();
        send_greetings(&mut client, &mut remote).await;

        let (up, down) = handle_pump(&mut proxy_client_side, &mut proxy_remote_side).await.unwrap();

        assert_eq!(up, 14);
        assert_eq!(down, 15);
    }
}