modality_utils/
encrypted_text.rs1use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
2use ring::{aead, pbkdf2, rand, hkdf};
3use ring::rand::SecureRandom;
4use std::num::NonZeroU32;
5
6pub struct EncryptedText;
7
8impl EncryptedText {
9 fn derive_key(password: &str, salt: &[u8]) -> Result<[u8; 32], &'static str> {
10 const ITERATIONS: u32 = 100_000;
11 let iterations = NonZeroU32::new(ITERATIONS).unwrap();
12 let mut pbkdf2_output = [0u8; 32];
13
14 pbkdf2::derive(
15 pbkdf2::PBKDF2_HMAC_SHA256,
16 iterations,
17 salt,
18 password.as_bytes(),
19 &mut pbkdf2_output,
20 );
21
22 let hkdf_salt = hkdf::Salt::new(hkdf::HKDF_SHA256, salt);
24 let prk = hkdf_salt.extract(&pbkdf2_output);
25 let mut final_key = [0u8; 32];
26
27 prk.expand(&[b"aes-256-gcm"], hkdf::HKDF_SHA256)
29 .map_err(|_| "HKDF expand failed")?
30 .fill(&mut final_key)
31 .map_err(|_| "HKDF fill failed")?;
32
33
34 Ok(final_key)
35 }
36
37 pub fn encrypt(text: &str, password: &str) -> Result<String, &'static str> {
38 let rng = rand::SystemRandom::new();
39 let mut salt = vec![0u8; 16];
40 rng.fill(&mut salt).map_err(|_| "Failed to generate salt")?;
41
42 let key_bytes = Self::derive_key(password, &salt)?;
43
44 let unbound_key = aead::UnboundKey::new(&aead::AES_256_GCM, &key_bytes)
45 .map_err(|_| "Failed to create key")?;
46 let key = aead::LessSafeKey::new(unbound_key);
47
48 let mut nonce_bytes = vec![0u8; 12];
49 rng.fill(&mut nonce_bytes).map_err(|_| "Failed to generate nonce")?;
50 let nonce = aead::Nonce::try_assume_unique_for_key(&nonce_bytes)
51 .map_err(|_| "Invalid nonce")?;
52
53 let mut in_out = text.as_bytes().to_vec();
54 key.seal_in_place_append_tag(nonce, aead::Aad::empty(), &mut in_out)
55 .map_err(|_| "Encryption failed")?;
56
57 let mut combined = Vec::new();
58 combined.extend_from_slice(&salt);
59 combined.extend_from_slice(&nonce_bytes);
60 combined.extend_from_slice(&in_out);
61
62 Ok(BASE64.encode(combined))
63 }
64
65 pub fn decrypt(encrypted_base64: &str, password: &str) -> Result<String, &'static str> {
66 let combined = BASE64.decode(encrypted_base64)
67 .map_err(|_| "Invalid base64 data")?;
68
69 if combined.len() < 28 {
70 return Err("Data too short");
71 }
72
73 let salt = &combined[0..16];
74 let nonce_bytes = &combined[16..28];
75 let ciphertext = &combined[28..];
76
77 let key_bytes = Self::derive_key(password, salt)?;
78
79 let unbound_key = aead::UnboundKey::new(&aead::AES_256_GCM, &key_bytes)
80 .map_err(|_| "Failed to create key")?;
81 let key = aead::LessSafeKey::new(unbound_key);
82
83 let nonce = aead::Nonce::try_assume_unique_for_key(nonce_bytes)
84 .map_err(|_| "Invalid nonce")?;
85
86 let mut decrypted = ciphertext.to_vec();
87 let decrypted_len = key.open_in_place(nonce, aead::Aad::empty(), &mut decrypted)
88 .map_err(|_| "Decryption failed - invalid password or corrupted data")?
89 .len();
90
91 decrypted.truncate(decrypted_len);
93
94 String::from_utf8(decrypted)
95 .map_err(|_| "Invalid UTF-8 in decrypted data")
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `encoded` with the ASCII character at `index` replaced by a
    /// different base64 character.
    ///
    /// Overwriting with a fixed character would be a no-op whenever the random
    /// ciphertext already holds that character at `index`, which would make
    /// the corruption tests fail intermittently.
    fn corrupt_char_at(encoded: &str, index: usize) -> String {
        let replacement = if encoded.as_bytes()[index] == b'A' {
            'B'
        } else {
            'A'
        };
        format!(
            "{}{}{}",
            &encoded[..index],
            replacement,
            &encoded[index + 1..]
        )
    }

    /// Round-trips a message and checks that a wrong password is rejected.
    #[test]
    fn test_encryption_decryption() {
        let password = "MySecretPassword123!";
        let text = "Hello, Web Crypto with password-based encryption!";

        let encrypted = EncryptedText::encrypt(text, password).unwrap();
        let decrypted = EncryptedText::decrypt(&encrypted, password).unwrap();
        assert_eq!(decrypted, text);

        let result = EncryptedText::decrypt(&encrypted, "WrongPassword");
        assert!(result.is_err());
    }

    /// Tampering with the start, middle or end of the payload must be detected.
    #[test]
    fn test_corrupted_data() {
        let password = "MySecretPassword123!";
        let text = "Test message";

        let encrypted = EncryptedText::encrypt(text, password).unwrap();

        let corrupted_start = corrupt_char_at(&encrypted, 0);
        assert!(EncryptedText::decrypt(&corrupted_start, password).is_err());

        let corrupted_middle = corrupt_char_at(&encrypted, encrypted.len() / 2);
        assert!(EncryptedText::decrypt(&corrupted_middle, password).is_err());

        // The encoding of this message ends in `=` padding, so overwriting the
        // final group with "AAAA" always changes the decoded bytes.
        let corrupted_end = format!("{}AAAA", &encrypted[..encrypted.len() - 4]);
        assert!(EncryptedText::decrypt(&corrupted_end, password).is_err());
    }

    /// Round-trips empty, short, long and non-ASCII inputs.
    #[test]
    fn test_various_lengths() {
        let password = "test123";
        let texts = [
            "",
            "a",
            "hello",
            "This is a longer test message with spaces and !@#$ symbols",
            "🦀 Rust with Unicode 🔐",
        ];

        for text in texts {
            let encrypted = EncryptedText::encrypt(text, password).unwrap();
            let decrypted = EncryptedText::decrypt(&encrypted, password).unwrap();
            assert_eq!(text, decrypted);
        }
    }

    /// Decrypts a fixed, previously produced ciphertext so that changes to the
    /// key derivation or wire format are caught.
    #[test]
    fn test_known_string() {
        const KNOWN_PASSWORD: &str = "test_password_123";
        const KNOWN_MESSAGE: &str = "Hello, Cross-Platform Encryption!";
        const KNOWN_ENCRYPTED: &str = "1G73otj9BTJ5i3djZyuemijZnGkMb8XawInJVUqLqiNTIRPrBrs8MxL0y+cJWTcxGcxkS7H+/BltKwxqS0dd5TYTN81cOWaHmO7SJR0=";

        let decrypted = EncryptedText::decrypt(KNOWN_ENCRYPTED, KNOWN_PASSWORD).unwrap();
        assert_eq!(decrypted, KNOWN_MESSAGE);

        let encrypted = EncryptedText::encrypt(KNOWN_MESSAGE, KNOWN_PASSWORD).unwrap();
        let decrypted = EncryptedText::decrypt(&encrypted, KNOWN_PASSWORD).unwrap();
        assert_eq!(decrypted, KNOWN_MESSAGE);
    }
}