1use anyhow::Context;
7use base64::Engine;
8use rand::{Rng, distr::Alphanumeric};
9use ring::{
10 aead::{Aad, LessSafeKey, Nonce, UnboundKey},
11 agreement::{EphemeralPrivateKey, agree_ephemeral},
12 hkdf::Salt,
13 rand::{SecureRandom, SystemRandom},
14 signature::{Ed25519KeyPair, KeyPair},
15};
16use secrecy::{ExposeSecret, SecretString};
17use tokio::io::{AsyncRead, AsyncWrite};
18use tracing::{debug, info};
19
20use crate::{
21 base::{Base64KeyPair, Constant, EncryptedData, Err, ExchangeKeyPair, Res, SharedSecret, SharedSecretNonce, SharedSecretShape, TunnelDefinition, Void},
22 protocol::{Challenge, ExchangePublicKey, Signature},
23};
24
25pub fn random_string(len: usize) -> String {
26 rand::rng().sample_iter(&Alphanumeric).take(len).map(char::from).collect()
27}
28
29pub fn generate_key_pair() -> Res<Base64KeyPair> {
30 let rng = SystemRandom::new();
31 let pkcs8 = Ed25519KeyPair::generate_pkcs8(&rng).context("Unable to generate key pair")?;
32
33 let key_pair = Ed25519KeyPair::from_pkcs8(pkcs8.as_ref()).context("Failed to create key pair")?;
34
35 let public = Constant::BASE64_ENGINE.encode(key_pair.public_key().as_ref());
36 let private = Constant::BASE64_ENGINE.encode(pkcs8.as_ref());
37
38 Ok(Base64KeyPair { public_key: public, private_key: private })
39}
40
41pub fn generate_key_pair_from_key(private_key: &str) -> Res<Base64KeyPair> {
42 let key_bytes = Constant::BASE64_ENGINE.decode(private_key).context("Could not decode seed")?;
43
44 let key_pair = Ed25519KeyPair::from_pkcs8(&key_bytes).context("Failed to create key pair")?;
45
46 let public = Constant::BASE64_ENGINE.encode(key_pair.public_key().as_ref());
47
48 Ok(Base64KeyPair {
49 public_key: public,
50 private_key: private_key.to_string(),
51 })
52}
53
54pub fn generate_challenge() -> Challenge {
55 let rng = SystemRandom::new();
56 let mut challenge = Challenge::default();
57 rng.fill(&mut challenge).expect("Failed to generate challenge");
58 challenge
59}
60
61pub fn sign_challenge(challenge: &Challenge, private_key: &SecretString) -> Res<Signature> {
62 debug!("Challenge: `{:?}`", challenge);
63
64 let private_key = Constant::BASE64_ENGINE.decode(private_key.expose_secret()).context("Could not decode private key")?;
65 debug!("Signing challenge with private key: {:?}", &private_key);
66
67 let key_pair = Ed25519KeyPair::from_pkcs8(&private_key).map_err(|_| Err::msg("Invalid private key"))?;
68 debug!("Key pair: {:?}", key_pair);
69
70 let signature = key_pair.sign(challenge).as_ref()[..Constant::SIGNATURE_SIZE]
71 .try_into()
72 .map_err(|_| Err::msg("Invalid signature length"))?;
73 debug!("Signature: {:?}", &signature);
74
75 Ok(signature)
76}
77
78pub fn validate_signed_challenge(challenge: &Challenge, signature: &Signature, public_key: &str) -> Void {
79 let public_key = Constant::BASE64_ENGINE.decode(public_key).context("Could not decode public key")?;
80
81 let unparsed_public_key = ring::signature::UnparsedPublicKey::new(Constant::SIGNATURE, public_key);
82
83 unparsed_public_key.verify(challenge, signature).context("Invalid signature")?;
84
85 Ok(())
86}
87
88pub fn generate_ephemeral_key_pair() -> Res<ExchangeKeyPair> {
89 let rng = SystemRandom::new();
90
91 let my_private_key = EphemeralPrivateKey::generate(Constant::AGREEMENT, &rng)?;
92
93 let public_key = my_private_key.compute_public_key()?;
94
95 Ok(ExchangeKeyPair { public_key, private_key: my_private_key })
96}
97
98pub fn generate_shared_secret(private_key: EphemeralPrivateKey, peer_public_key: &ExchangePublicKey, salt_bytes: &[u8]) -> Res<SharedSecret> {
99 let unparsed_peer_public_key = ring::agreement::UnparsedPublicKey::new(Constant::AGREEMENT, peer_public_key);
100
101 let shared_secret = agree_ephemeral(private_key, &unparsed_peer_public_key, |shared_secret| generate_chacha_key(shared_secret, salt_bytes))??;
102 Ok(shared_secret)
103}
104
105fn generate_chacha_key(private_key: &[u8], salt_bytes: &[u8]) -> Res<SharedSecret> {
106 let salt = Salt::new(Constant::KDF, salt_bytes);
107 let info = &[salt_bytes];
108
109 let prk = salt.extract(private_key);
110 let okm = prk.expand(info, Constant::KDF)?;
111
112 let mut key = SharedSecretShape::default();
113 okm.fill(&mut key)?;
114
115 Ok(SharedSecret::init_with(|| key))
116}
117
118pub fn encrypt(shared_secret: &SharedSecret, plaintext: &[u8]) -> Res<EncryptedData> {
119 let rng = SystemRandom::new();
120 let mut nonce_bytes = [0u8; Constant::SHARED_SECRET_NONCE_SIZE];
121 rng.fill(&mut nonce_bytes).context("Could not fill nonce for encryption")?;
122
123 let unbound_key = UnboundKey::new(Constant::AEAD, shared_secret.expose_secret()).context("Could not generate unbound key for encryption")?;
124 let sealing_key = LessSafeKey::new(unbound_key);
125 let nonce = Nonce::assume_unique_for_key(nonce_bytes);
126
127 let mut in_out = plaintext.to_vec();
128 in_out.reserve_exact(Constant::AEAD.tag_len());
129
130 sealing_key
131 .seal_in_place_append_tag(nonce, Aad::empty(), &mut in_out)
132 .context("Could not seal in place during encryption")?;
133
134 Ok(EncryptedData { nonce: nonce_bytes, data: in_out })
135}
136
137pub fn decrypt(shared_secret: &SharedSecret, ciphertext: &[u8], nonce_bytes: &SharedSecretNonce) -> Res<Vec<u8>> {
138 let unbound_key = UnboundKey::new(Constant::AEAD, shared_secret.expose_secret()).context("Could not generate unbound key for decryption")?;
139 let opening_key = LessSafeKey::new(unbound_key);
140 let nonce = Nonce::assume_unique_for_key(*nonce_bytes);
141
142 let mut in_out = ciphertext.to_vec();
143 let plaintext = opening_key.open_in_place(nonce, Aad::empty(), &mut in_out).context("Could not open in place for decryption")?;
144
145 Ok(plaintext.to_vec())
146}
147
148pub fn parse_tunnel_definition(tunnel: &str) -> Res<TunnelDefinition> {
155 let parts: Vec<&str> = tunnel.split(':').collect();
156
157 match parts.len() {
158 4 => {
159 let bind_address = format!("{}:{}", parts[0], parts[1]);
160 let host_address = format!("{}:{}", parts[2], parts[3]);
161
162 Ok(TunnelDefinition {
163 bind_address,
164 remote_address: host_address,
165 })
166 }
167 3 => {
168 let bind_address = format!("127.0.0.1:{}", parts[0]);
169 let host_address = format!("{}:{}", parts[1], parts[2]);
170
171 Ok(TunnelDefinition {
172 bind_address,
173 remote_address: host_address,
174 })
175 }
176 2 => {
177 let bind_address = format!("127.0.0.1:{}", parts[0]);
178 let host_address = format!("127.0.0.1:{}", parts[1]);
179
180 Ok(TunnelDefinition {
181 bind_address,
182 remote_address: host_address,
183 })
184 }
185 1 => {
186 let bind_address = format!("127.0.0.1:{}", parts[0]);
187 let host_address = format!("127.0.0.1:{}", parts[0]);
188
189 Ok(TunnelDefinition {
190 bind_address,
191 remote_address: host_address,
192 })
193 }
194 _ => Err(Err::msg("Invalid tunnel definition format")),
195 }
196}
197
198pub fn parse_tunnel_definitions<T>(tunnels: &[T]) -> Res<Vec<TunnelDefinition>>
199where
200 T: AsRef<str>,
201{
202 tunnels.iter().map(|tunnel| parse_tunnel_definition(tunnel.as_ref())).collect()
203}
204
205pub async fn handle_pump<A, B>(a: &mut A, b: &mut B) -> Res<(u64, u64)>
206where
207 A: AsyncRead + AsyncWrite + Unpin,
208 B: AsyncRead + AsyncWrite + Unpin,
209{
210 let result = tokio::io::copy_bidirectional_with_sizes(a, b, Constant::BUFFER_SIZE, Constant::BUFFER_SIZE).await?;
211
212 info!("⬅️ {} bytes ➡️ {} bytes", result.1, result.0);
213
214 Ok(result)
215}
216
#[cfg(test)]
pub mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use crate::buffed_stream::{BuffedDuplexStream, BuffedStream};

    use super::*;
    use pretty_assertions::assert_eq;

    /// Creates a connected in-memory stream pair without encryption.
    pub fn generate_test_duplex() -> (BuffedDuplexStream, BuffedDuplexStream) {
        let (left, right) = tokio::io::duplex(Constant::BUFFER_SIZE);
        (BuffedStream::from(left), BuffedStream::from(right))
    }

    /// Creates a connected in-memory stream pair where both ends encrypt with the same freshly generated secret.
    pub fn generate_test_duplex_with_encryption() -> (BuffedDuplexStream, BuffedDuplexStream) {
        let (left, right) = tokio::io::duplex(Constant::BUFFER_SIZE);
        let secret = generate_test_shared_secret();

        // Each end needs its own owned copy of the secret.
        let fresh_copy = || SharedSecret::init_with(|| *secret.expose_secret());

        (
            BuffedStream::from(left).with_encryption(fresh_copy()),
            BuffedStream::from(right).with_encryption(fresh_copy()),
        )
    }

    /// Generates an ephemeral key pair, panicking on failure.
    pub fn generate_test_ephemeral_key_pair() -> ExchangeKeyPair {
        generate_ephemeral_key_pair().unwrap()
    }

    /// Derives a shared secret by agreeing a key pair with its own public key, salted with a fresh challenge.
    pub fn generate_test_shared_secret() -> SharedSecret {
        let ExchangeKeyPair { public_key, private_key } = generate_test_ephemeral_key_pair();
        let salt = generate_challenge();

        generate_shared_secret(private_key, public_key.as_ref().try_into().unwrap(), &salt).unwrap()
    }

    /// Returns a fixed, non-random value with the shape of an exchange public key.
    pub fn generate_test_fake_exchange_public_key() -> ExchangePublicKey {
        b"this needs to be exactly 32 byte".as_ref().try_into().unwrap()
    }

    #[test]
    fn test_generate_key_pair() {
        let Base64KeyPair { public_key, private_key } = generate_key_pair().unwrap();

        assert_eq!(public_key.len(), 43);
        assert_eq!(private_key.len(), 111);
    }

    #[test]
    fn test_generate_key_pair_from_key() {
        let original = generate_key_pair().unwrap();
        let rebuilt = generate_key_pair_from_key(&original.private_key).unwrap();

        assert_eq!(rebuilt.public_key, original.public_key);
        assert_eq!(rebuilt.private_key, original.private_key);
    }

    #[test]
    fn test_ed25519() {
        let Base64KeyPair { public_key, private_key } = generate_key_pair().unwrap();

        let challenge = generate_challenge();
        let secret_key: SecretString = private_key.into();
        let signature = sign_challenge(&challenge, &secret_key).unwrap();

        validate_signed_challenge(&challenge, &signature, &public_key).unwrap();
    }

    #[test]
    fn test_ephemeral_key_exchange() {
        let first = generate_ephemeral_key_pair().unwrap();
        let second = generate_ephemeral_key_pair().unwrap();
        let salt = generate_challenge();

        // Each side combines its own private key with the other side's public key.
        let first_secret = generate_shared_secret(first.private_key, second.public_key.as_ref().try_into().unwrap(), &salt).unwrap();
        let second_secret = generate_shared_secret(second.private_key, first.public_key.as_ref().try_into().unwrap(), &salt).unwrap();

        assert_eq!(first_secret.expose_secret().len(), Constant::SHARED_SECRET_SIZE);
        assert_eq!(first_secret.expose_secret(), second_secret.expose_secret());
    }

    #[test]
    fn test_encrypt_decrypt() {
        let secret = generate_test_shared_secret();
        let message = b"Hello, world!";

        let EncryptedData { nonce, data } = encrypt(&secret, message).unwrap();
        let round_tripped = decrypt(&secret, &data, &nonce).unwrap();

        assert_eq!(round_tripped, message);
    }

    #[test]
    fn test_parse_tunnel_definition() {
        // (input, expected bind address, expected remote address)
        let cases = [
            ("a:b:c:d", "a:b", "c:d"),
            ("a:b:c", "127.0.0.1:a", "b:c"),
            ("a:b", "127.0.0.1:a", "127.0.0.1:b"),
            ("a", "127.0.0.1:a", "127.0.0.1:a"),
        ];

        for (input, expected_bind, expected_remote) in cases {
            let definition = parse_tunnel_definition(input).unwrap();
            assert_eq!(definition.bind_address, expected_bind);
            assert_eq!(definition.remote_address, expected_remote);
        }
    }

    #[test]
    fn test_bad_tunnel_definition() {
        for input in ["a:b:c:d:e", "a:b:c:d:e:f"] {
            assert!(parse_tunnel_definition(input).is_err());
        }
    }

    #[tokio::test]
    async fn test_handle_pump() {
        let (mut client, mut client_side) = generate_test_duplex();
        let (mut remote_side, mut remote) = generate_test_duplex();

        let to_remote = b"Hello, remote!";
        let to_client = b"Hello, client!!";

        // Queue the data and close both outer ends before pumping so the pump can run to completion.
        client.write_all(to_remote).await.unwrap();
        client.shutdown().await.unwrap();
        remote.write_all(to_client).await.unwrap();
        remote.shutdown().await.unwrap();

        let (up, down) = handle_pump(&mut client_side, &mut remote_side).await.unwrap();

        assert_eq!(up, 14);
        assert_eq!(down, 15);

        let mut client_received = Vec::new();
        client.read_to_end(&mut client_received).await.unwrap();
        assert_eq!(client_received, to_client);

        let mut remote_received = Vec::new();
        remote.read_to_end(&mut remote_received).await.unwrap();
        assert_eq!(remote_received, to_remote);
    }

    #[tokio::test]
    async fn test_handle_pump_with_encryption() {
        let (mut client, mut client_side) = generate_test_duplex_with_encryption();
        let (mut remote_side, mut remote) = generate_test_duplex_with_encryption();

        client.write_all(b"Hello, remote!").await.unwrap();
        client.shutdown().await.unwrap();
        remote.write_all(b"Hello, client!!").await.unwrap();
        remote.shutdown().await.unwrap();

        let (up, down) = handle_pump(&mut client_side, &mut remote_side).await.unwrap();

        assert_eq!(up, 14);
        assert_eq!(down, 15);
    }
}