use crate::{
    EncryptionError, EncryptionKey, EncryptionNonce, Hash, KeyPair, PublicKey, SignatureBytes,
    SigningError, decrypt, encrypt, generate_key, generate_nonce, hash, verify,
};
use std::fs::{File, OpenOptions};
use std::io::{self, Read, Write};
use std::path::Path;
use thiserror::Error;

12#[derive(Debug, Error)]
14pub enum UtilError {
15 #[error("IO error: {0}")]
16 Io(#[from] io::Error),
17
18 #[error("Encryption error: {0}")]
19 Encryption(#[from] EncryptionError),
20
21 #[error("Signing error: {0}")]
22 Signing(#[from] SigningError),
23
24 #[error("Invalid file format: {0}")]
25 InvalidFormat(String),
26}
27
28pub type UtilResult<T> = Result<T, UtilError>;
30
31#[derive(Debug, Clone)]
33pub struct EncryptedMessage {
34 pub ciphertext: Vec<u8>,
36 pub nonce: EncryptionNonce,
38 pub plaintext_hash: Option<Hash>,
40}
41
42impl EncryptedMessage {
43 pub fn new(ciphertext: Vec<u8>, nonce: EncryptionNonce) -> Self {
45 Self {
46 ciphertext,
47 nonce,
48 plaintext_hash: None,
49 }
50 }
51
52 pub fn with_hash(ciphertext: Vec<u8>, nonce: EncryptionNonce, plaintext_hash: Hash) -> Self {
54 Self {
55 ciphertext,
56 nonce,
57 plaintext_hash: Some(plaintext_hash),
58 }
59 }
60
61 pub fn total_size(&self) -> usize {
63 self.ciphertext.len() + 12 + if self.plaintext_hash.is_some() { 32 } else { 0 }
64 }
65
66 pub fn to_bytes(&self) -> Vec<u8> {
68 let mut bytes = Vec::with_capacity(self.total_size());
69 bytes.extend_from_slice(&self.nonce);
70 bytes.extend_from_slice(&self.ciphertext);
71 if let Some(hash) = &self.plaintext_hash {
72 bytes.extend_from_slice(hash);
73 }
74 bytes
75 }
76
77 pub fn from_bytes(bytes: &[u8], with_hash: bool) -> UtilResult<Self> {
79 if bytes.len() < 12 {
80 return Err(UtilError::InvalidFormat(
81 "Too short for encrypted message".to_string(),
82 ));
83 }
84
85 let mut nonce = [0u8; 12];
86 nonce.copy_from_slice(&bytes[0..12]);
87
88 let ciphertext_end = if with_hash {
89 if bytes.len() < 44 {
90 return Err(UtilError::InvalidFormat(
91 "Too short for encrypted message with hash".to_string(),
92 ));
93 }
94 bytes.len() - 32
95 } else {
96 bytes.len()
97 };
98
99 let ciphertext = bytes[12..ciphertext_end].to_vec();
100
101 let plaintext_hash = if with_hash {
102 let mut hash = [0u8; 32];
103 hash.copy_from_slice(&bytes[ciphertext_end..]);
104 Some(hash)
105 } else {
106 None
107 };
108
109 Ok(Self {
110 ciphertext,
111 nonce,
112 plaintext_hash,
113 })
114 }
115}
116
117#[derive(Debug, Clone)]
119pub struct SignedMessage {
120 pub message: Vec<u8>,
122 pub signature: SignatureBytes,
124 pub public_key: PublicKey,
126}
127
128impl SignedMessage {
129 pub fn new(message: Vec<u8>, signature: SignatureBytes, public_key: PublicKey) -> Self {
131 Self {
132 message,
133 signature,
134 public_key,
135 }
136 }
137
138 pub fn sign(message: Vec<u8>, keypair: &KeyPair) -> Self {
140 let signature = keypair.sign(&message);
141 let public_key = keypair.public_key();
142 Self::new(message, signature, public_key)
143 }
144
145 pub fn verify(&self) -> Result<(), SigningError> {
147 verify(&self.public_key, &self.message, &self.signature)
148 }
149
150 pub fn to_bytes(&self) -> Vec<u8> {
152 let mut bytes = Vec::with_capacity(self.message.len() + 64 + 32);
153 bytes.extend_from_slice(&(self.message.len() as u32).to_le_bytes());
154 bytes.extend_from_slice(&self.message);
155 bytes.extend_from_slice(&self.signature);
156 bytes.extend_from_slice(&self.public_key);
157 bytes
158 }
159
160 pub fn from_bytes(bytes: &[u8]) -> UtilResult<Self> {
162 if bytes.len() < 100 {
163 return Err(UtilError::InvalidFormat(
165 "Too short for signed message".to_string(),
166 ));
167 }
168
169 let msg_len = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
170 if bytes.len() < 4 + msg_len + 64 + 32 {
171 return Err(UtilError::InvalidFormat(
172 "Invalid signed message length".to_string(),
173 ));
174 }
175
176 let message = bytes[4..4 + msg_len].to_vec();
177
178 let mut signature = [0u8; 64];
179 signature.copy_from_slice(&bytes[4 + msg_len..4 + msg_len + 64]);
180
181 let mut public_key = [0u8; 32];
182 public_key.copy_from_slice(&bytes[4 + msg_len + 64..4 + msg_len + 96]);
183
184 Ok(Self {
185 message,
186 signature,
187 public_key,
188 })
189 }
190}
191
192#[derive(Debug, Clone)]
194pub struct EncryptedAndSigned {
195 pub encrypted: EncryptedMessage,
197 pub signature: SignatureBytes,
199 pub signer_public_key: PublicKey,
201}
202
203impl EncryptedAndSigned {
204 pub fn create(
206 plaintext: &[u8],
207 encryption_key: &EncryptionKey,
208 signing_keypair: &KeyPair,
209 ) -> UtilResult<Self> {
210 let nonce = generate_nonce();
212 let ciphertext = encrypt(plaintext, encryption_key, &nonce)?;
213
214 let plaintext_hash = hash(plaintext);
216
217 let encrypted = EncryptedMessage::with_hash(ciphertext, nonce, plaintext_hash);
219
220 let signature = signing_keypair.sign(&encrypted.ciphertext);
222 let signer_public_key = signing_keypair.public_key();
223
224 Ok(Self {
225 encrypted,
226 signature,
227 signer_public_key,
228 })
229 }
230
231 pub fn verify_and_decrypt(&self, decryption_key: &EncryptionKey) -> UtilResult<Vec<u8>> {
233 verify(
235 &self.signer_public_key,
236 &self.encrypted.ciphertext,
237 &self.signature,
238 )?;
239
240 let plaintext = decrypt(
242 &self.encrypted.ciphertext,
243 decryption_key,
244 &self.encrypted.nonce,
245 )?;
246
247 if let Some(expected_hash) = &self.encrypted.plaintext_hash {
249 let actual_hash = hash(&plaintext);
250 if &actual_hash != expected_hash {
251 return Err(UtilError::InvalidFormat(
252 "Plaintext hash mismatch".to_string(),
253 ));
254 }
255 }
256
257 Ok(plaintext)
258 }
259}
260
261pub fn encrypt_file(
263 input_path: impl AsRef<Path>,
264 output_path: impl AsRef<Path>,
265 key: &EncryptionKey,
266) -> UtilResult<EncryptionNonce> {
267 let mut file = File::open(input_path)?;
268 let mut plaintext = Vec::new();
269 file.read_to_end(&mut plaintext)?;
270
271 let nonce = generate_nonce();
272 let ciphertext = encrypt(&plaintext, key, &nonce)?;
273
274 let mut output = File::create(output_path)?;
275 output.write_all(&nonce)?;
276 output.write_all(&ciphertext)?;
277
278 Ok(nonce)
279}
280
281pub fn decrypt_file(
283 input_path: impl AsRef<Path>,
284 output_path: impl AsRef<Path>,
285 key: &EncryptionKey,
286) -> UtilResult<()> {
287 let mut file = File::open(input_path)?;
288 let mut data = Vec::new();
289 file.read_to_end(&mut data)?;
290
291 if data.len() < 12 {
292 return Err(UtilError::InvalidFormat(
293 "File too short to contain nonce".to_string(),
294 ));
295 }
296
297 let mut nonce = [0u8; 12];
298 nonce.copy_from_slice(&data[0..12]);
299
300 let ciphertext = &data[12..];
301 let plaintext = decrypt(ciphertext, key, &nonce)?;
302
303 let mut output = File::create(output_path)?;
304 output.write_all(&plaintext)?;
305
306 Ok(())
307}
308
309pub fn generate_and_save_key(path: impl AsRef<Path>) -> UtilResult<EncryptionKey> {
311 let key = generate_key();
312 let mut file = File::create(path)?;
313 file.write_all(&key)?;
314 Ok(key)
315}
316
317pub fn load_key(path: impl AsRef<Path>) -> UtilResult<EncryptionKey> {
319 let mut file = File::open(path)?;
320 let mut key = [0u8; 32];
321 file.read_exact(&mut key)?;
322 Ok(key)
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_encrypted_message_roundtrip() {
        let key = generate_key();
        let nonce = generate_nonce();
        let plaintext = b"Hello, World!";

        let ciphertext = encrypt(plaintext, &key, &nonce).unwrap();
        let msg = EncryptedMessage::new(ciphertext, nonce);

        let bytes = msg.to_bytes();
        assert_eq!(bytes.len(), msg.total_size());
        let restored = EncryptedMessage::from_bytes(&bytes, false).unwrap();

        assert_eq!(msg.nonce, restored.nonce);
        assert_eq!(msg.ciphertext, restored.ciphertext);
        assert!(restored.plaintext_hash.is_none());
    }

    #[test]
    fn test_encrypted_message_with_hash() {
        let key = generate_key();
        let nonce = generate_nonce();
        let plaintext = b"Hello, World!";
        let plaintext_hash = hash(plaintext);

        let ciphertext = encrypt(plaintext, &key, &nonce).unwrap();
        let msg = EncryptedMessage::with_hash(ciphertext, nonce, plaintext_hash);

        let bytes = msg.to_bytes();
        assert_eq!(bytes.len(), msg.total_size());
        let restored = EncryptedMessage::from_bytes(&bytes, true).unwrap();

        // The hash suffix must not bleed into the ciphertext (or vice versa).
        assert_eq!(msg.nonce, restored.nonce);
        assert_eq!(msg.ciphertext, restored.ciphertext);
        assert_eq!(msg.plaintext_hash, restored.plaintext_hash);
    }

    #[test]
    fn test_encrypted_message_rejects_short_input() {
        assert!(matches!(
            EncryptedMessage::from_bytes(&[0u8; 11], false),
            Err(UtilError::InvalidFormat(_))
        ));
        assert!(matches!(
            EncryptedMessage::from_bytes(&[0u8; 43], true),
            Err(UtilError::InvalidFormat(_))
        ));

        // Exactly nonce + hash is accepted and yields an empty ciphertext.
        let msg = EncryptedMessage::from_bytes(&[0u8; 44], true).unwrap();
        assert!(msg.ciphertext.is_empty());
        assert!(msg.plaintext_hash.is_some());
    }

    #[test]
    fn test_signed_message_roundtrip() {
        let keypair = KeyPair::generate();
        let message = b"Test message".to_vec();

        let signed = SignedMessage::sign(message.clone(), &keypair);
        assert!(signed.verify().is_ok());

        let bytes = signed.to_bytes();
        assert_eq!(bytes.len(), 4 + message.len() + 64 + 32);
        let restored = SignedMessage::from_bytes(&bytes).unwrap();

        assert_eq!(signed.message, restored.message);
        assert_eq!(signed.signature, restored.signature);
        assert_eq!(signed.public_key, restored.public_key);
        assert!(restored.verify().is_ok());
    }

    #[test]
    fn test_signed_message_rejects_bad_length() {
        assert!(matches!(
            SignedMessage::from_bytes(&[0u8; 99]),
            Err(UtilError::InvalidFormat(_))
        ));

        // Prefix claims one message byte, but only signature + key follow.
        let mut bytes = vec![0u8; 100];
        bytes[..4].copy_from_slice(&1u32.to_le_bytes());
        assert!(matches!(
            SignedMessage::from_bytes(&bytes),
            Err(UtilError::InvalidFormat(_))
        ));

        // A huge prefix must be rejected, not overflow or panic.
        bytes[..4].copy_from_slice(&u32::MAX.to_le_bytes());
        assert!(matches!(
            SignedMessage::from_bytes(&bytes),
            Err(UtilError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_encrypted_and_signed() {
        let encryption_key = generate_key();
        let signing_keypair = KeyPair::generate();
        let plaintext = b"Secure message";

        let encrypted_signed =
            EncryptedAndSigned::create(plaintext, &encryption_key, &signing_keypair).unwrap();

        let decrypted = encrypted_signed
            .verify_and_decrypt(&encryption_key)
            .unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_encrypted_and_signed_rejects_tampering() {
        let encryption_key = generate_key();
        let signing_keypair = KeyPair::generate();
        let plaintext = b"Secure message";

        // A flipped ciphertext bit must break the signature.
        let mut tampered =
            EncryptedAndSigned::create(plaintext, &encryption_key, &signing_keypair).unwrap();
        tampered.encrypted.ciphertext[0] ^= 0x01;
        assert!(tampered.verify_and_decrypt(&encryption_key).is_err());

        // A wrong recorded hash must be caught after decryption.
        let mut wrong_hash =
            EncryptedAndSigned::create(plaintext, &encryption_key, &signing_keypair).unwrap();
        wrong_hash.encrypted.plaintext_hash = Some([0u8; 32]);
        assert!(matches!(
            wrong_hash.verify_and_decrypt(&encryption_key),
            Err(UtilError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_file_encryption() {
        let key = generate_key();

        let mut input_file = NamedTempFile::new().unwrap();
        let output_file = NamedTempFile::new().unwrap();
        let decrypted_file = NamedTempFile::new().unwrap();

        let plaintext = b"This is a test file content";
        input_file.write_all(plaintext).unwrap();
        input_file.flush().unwrap();

        encrypt_file(input_file.path(), output_file.path(), &key).unwrap();

        decrypt_file(output_file.path(), decrypted_file.path(), &key).unwrap();

        let mut decrypted = Vec::new();
        File::open(decrypted_file.path())
            .unwrap()
            .read_to_end(&mut decrypted)
            .unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_decrypt_file_rejects_short_input() {
        let key = generate_key();

        let mut input_file = NamedTempFile::new().unwrap();
        input_file.write_all(&[0u8; 11]).unwrap();
        input_file.flush().unwrap();
        let output_file = NamedTempFile::new().unwrap();

        assert!(matches!(
            decrypt_file(input_file.path(), output_file.path(), &key),
            Err(UtilError::InvalidFormat(_))
        ));
    }

    #[test]
    fn test_key_save_load() {
        let key_file = NamedTempFile::new().unwrap();

        let original_key = generate_and_save_key(key_file.path()).unwrap();
        let loaded_key = load_key(key_file.path()).unwrap();

        assert_eq!(original_key, loaded_key);
    }
}