1use anyhow::{anyhow, Result};
8use base64::engine::general_purpose::URL_SAFE_NO_PAD;
9use base64::Engine;
10use chacha20poly1305::aead::{Aead, KeyInit};
11use chacha20poly1305::{XChaCha20Poly1305, XNonce};
12use rand::RngCore;
13use zeroize::{Zeroize, ZeroizeOnDrop};
14
/// Length in bytes of a [`TunnelKey`] (a 256-bit XChaCha20-Poly1305 key).
pub const KEY_SIZE: usize = 32;

/// Length in bytes of the random XChaCha20 nonce that `TunnelCrypto::encrypt`
/// writes in front of every ciphertext.
pub const NONCE_SIZE: usize = 24;
20
/// A `KEY_SIZE`-byte symmetric key used to build a [`TunnelCrypto`].
///
/// The key bytes are wiped when the value is dropped (`ZeroizeOnDrop`); clones
/// are independent copies that are wiped on their own drop. The hand-written
/// `Debug` impl for this type redacts the key so it cannot leak into logs.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct TunnelKey {
    // Raw key material; never exposed except through `as_bytes`/`to_base64`.
    key: [u8; KEY_SIZE],
}
29
30impl TunnelKey {
31 pub fn generate() -> Self {
33 let mut key = [0u8; KEY_SIZE];
34 rand::rngs::OsRng.fill_bytes(&mut key);
35 Self { key }
36 }
37
38 pub fn from_bytes(bytes: [u8; KEY_SIZE]) -> Self {
40 Self { key: bytes }
41 }
42
43 pub fn from_base64(encoded: &str) -> Result<Self> {
45 let bytes = URL_SAFE_NO_PAD
46 .decode(encoded)
47 .map_err(|e| anyhow!("Invalid base64: {}", e))?;
48
49 if bytes.len() != KEY_SIZE {
50 return Err(anyhow!(
51 "Invalid key length: expected {}, got {}",
52 KEY_SIZE,
53 bytes.len()
54 ));
55 }
56
57 let mut key = [0u8; KEY_SIZE];
58 key.copy_from_slice(&bytes);
59 Ok(Self { key })
60 }
61
62 pub fn to_base64(&self) -> String {
64 URL_SAFE_NO_PAD.encode(self.key)
65 }
66
67 pub fn as_bytes(&self) -> &[u8; KEY_SIZE] {
69 &self.key
70 }
71}
72
73impl std::fmt::Debug for TunnelKey {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 f.debug_struct("TunnelKey")
76 .field("key", &"[REDACTED]")
77 .finish()
78 }
79}
80
/// Authenticated encryption for tunnel payloads using XChaCha20-Poly1305.
///
/// `encrypt` draws a fresh random nonce per call and stores it in front of the
/// AEAD output; `decrypt` splits that nonce back off before authenticating.
pub struct TunnelCrypto {
    // AEAD cipher keyed from a `TunnelKey` in `TunnelCrypto::new`.
    cipher: XChaCha20Poly1305,
}
91
92impl TunnelCrypto {
93 pub fn new(key: &TunnelKey) -> Self {
95 let cipher = XChaCha20Poly1305::new(key.as_bytes().into());
96 Self { cipher }
97 }
98
99 pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
103 let nonce = Self::generate_nonce();
104 let ciphertext = self
105 .cipher
106 .encrypt(XNonce::from_slice(&nonce), plaintext)
107 .map_err(|_| anyhow!("Encryption failed"))?;
108
109 let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
111 result.extend_from_slice(&nonce);
112 result.extend(ciphertext);
113 Ok(result)
114 }
115
116 pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
118 if data.len() < NONCE_SIZE + 16 {
120 return Err(anyhow!(
121 "Ciphertext too short: need at least {} bytes, got {}",
122 NONCE_SIZE + 16,
123 data.len()
124 ));
125 }
126
127 let (nonce, ciphertext) = data.split_at(NONCE_SIZE);
128 self.cipher
129 .decrypt(XNonce::from_slice(nonce), ciphertext)
130 .map_err(|_| anyhow!("Decryption failed: authentication tag mismatch"))
131 }
132
133 fn generate_nonce() -> [u8; NONCE_SIZE] {
135 let mut nonce = [0u8; NONCE_SIZE];
136 rand::rngs::OsRng.fill_bytes(&mut nonce);
137 nonce
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_generation() {
        let key1 = TunnelKey::generate();
        let key2 = TunnelKey::generate();

        // Two independent OsRng draws of 256 bits must not collide.
        assert_ne!(key1.as_bytes(), key2.as_bytes());
    }

    #[test]
    fn test_key_from_bytes_roundtrip() {
        let key = TunnelKey::from_bytes([7u8; KEY_SIZE]);
        assert_eq!(key.as_bytes(), &[7u8; KEY_SIZE]);
    }

    #[test]
    fn test_key_base64_roundtrip() {
        let key = TunnelKey::generate();
        let encoded = key.to_base64();
        let decoded = TunnelKey::from_base64(&encoded).unwrap();

        assert_eq!(key.as_bytes(), decoded.as_bytes());
    }

    #[test]
    fn test_key_base64_length() {
        let key = TunnelKey::generate();
        let encoded = key.to_base64();

        // 32 bytes -> ceil(32 * 8 / 6) = 43 unpadded base64 characters.
        assert_eq!(encoded.len(), 43);
    }

    #[test]
    fn test_key_from_base64_rejects_invalid_base64() {
        let err = TunnelKey::from_base64("not base64!").unwrap_err();
        assert!(err.to_string().starts_with("Invalid base64"));
    }

    #[test]
    fn test_key_from_base64_rejects_wrong_length() {
        // 22 'A's decode to 16 zero bytes; 44 'A's decode to 33 zero bytes.
        let too_short = TunnelKey::from_base64(&"A".repeat(22)).unwrap_err();
        assert_eq!(
            too_short.to_string(),
            "Invalid key length: expected 32, got 16"
        );

        let too_long = TunnelKey::from_base64(&"A".repeat(44)).unwrap_err();
        assert_eq!(
            too_long.to_string(),
            "Invalid key length: expected 32, got 33"
        );
    }

    #[test]
    fn test_key_debug_is_redacted() {
        let key = TunnelKey::generate();
        assert_eq!(format!("{:?}", key), "TunnelKey { key: \"[REDACTED]\" }");
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);

        let plaintext = b"Hello, World! This is a test message.";
        let ciphertext = crypto.encrypt(plaintext).unwrap();
        let decrypted = crypto.decrypt(&ciphertext).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_encrypt_produces_different_ciphertext() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);

        let plaintext = b"Same message";
        let ciphertext1 = crypto.encrypt(plaintext).unwrap();
        let ciphertext2 = crypto.encrypt(plaintext).unwrap();

        // Fresh random nonces make repeated encryptions differ.
        assert_ne!(ciphertext1, ciphertext2);
    }

    #[test]
    fn test_decrypt_wrong_key_fails() {
        let key1 = TunnelKey::generate();
        let key2 = TunnelKey::generate();
        let crypto1 = TunnelCrypto::new(&key1);
        let crypto2 = TunnelCrypto::new(&key2);

        let plaintext = b"Secret message";
        let ciphertext = crypto1.encrypt(plaintext).unwrap();

        let result = crypto2.decrypt(&ciphertext);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_tampered_data_fails() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);

        let plaintext = b"Original message";
        let mut ciphertext = crypto.encrypt(plaintext).unwrap();

        if let Some(byte) = ciphertext.get_mut(NONCE_SIZE + 5) {
            *byte ^= 0xFF;
        }

        let result = crypto.decrypt(&ciphertext);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_tampered_nonce_fails() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);

        let mut ciphertext = crypto.encrypt(b"Original message").unwrap();
        ciphertext[0] ^= 0x01;

        assert!(crypto.decrypt(&ciphertext).is_err());
    }

    #[test]
    fn test_decrypt_tampered_tag_fails() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);

        let mut ciphertext = crypto.encrypt(b"Original message").unwrap();
        let last = ciphertext.len() - 1;
        ciphertext[last] ^= 0x01;

        assert!(crypto.decrypt(&ciphertext).is_err());
    }

    #[test]
    fn test_decrypt_too_short_fails() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);

        let short_data = vec![0u8; 30];
        let result = crypto.decrypt(&short_data);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_minimum_length_boundary() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);

        // One byte below nonce + tag: rejected by the length check.
        let just_short = vec![0u8; NONCE_SIZE + 15];
        let err = crypto.decrypt(&just_short).unwrap_err();
        assert!(err.to_string().contains("too short"));

        // Exactly nonce + tag: passes the length check, fails authentication.
        let minimal = vec![0u8; NONCE_SIZE + 16];
        let err = crypto.decrypt(&minimal).unwrap_err();
        assert!(err.to_string().contains("authentication"));
    }

    #[test]
    fn test_empty_plaintext() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);

        let plaintext = b"";
        let ciphertext = crypto.encrypt(plaintext).unwrap();
        // Empty input still carries the nonce and the 16-byte tag.
        assert_eq!(ciphertext.len(), NONCE_SIZE + 16);

        let decrypted = crypto.decrypt(&ciphertext).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_large_plaintext() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);

        let plaintext = vec![0xAB; 1024 * 1024];
        let ciphertext = crypto.encrypt(&plaintext).unwrap();
        let decrypted = crypto.decrypt(&ciphertext).unwrap();

        assert_eq!(plaintext, decrypted);
    }
}