// modality_utils/encrypted_text.rs

use std::num::NonZeroU32;

use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
use ring::rand::SecureRandom;
use ring::{aead, hkdf, pbkdf2, rand};

/// Stateless namespace for password-based text encryption.
///
/// `EncryptedText::encrypt` returns standard base64 of
/// `salt || nonce || ciphertext-with-auth-tag`, and `EncryptedText::decrypt`
/// reverses it. The type carries no data; it only groups the associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct EncryptedText;

8impl EncryptedText {
9    fn derive_key(password: &str, salt: &[u8]) -> Result<[u8; 32], &'static str> {
10        const ITERATIONS: u32 = 100_000;
11        let iterations = NonZeroU32::new(ITERATIONS).unwrap();
12        let mut pbkdf2_output = [0u8; 32];
13
14        pbkdf2::derive(
15            pbkdf2::PBKDF2_HMAC_SHA256,
16            iterations,
17            salt,
18            password.as_bytes(),
19            &mut pbkdf2_output,
20        );
21
22        // Then HKDF
23        let hkdf_salt = hkdf::Salt::new(hkdf::HKDF_SHA256, salt);
24        let prk = hkdf_salt.extract(&pbkdf2_output);
25        let mut final_key = [0u8; 32];
26        
27        // Use expand and fill with a fixed length
28        prk.expand(&[b"aes-256-gcm"], hkdf::HKDF_SHA256)
29            .map_err(|_| "HKDF expand failed")?
30            .fill(&mut final_key)
31            .map_err(|_| "HKDF fill failed")?;
32
33
34        Ok(final_key)
35    }
36
37    pub fn encrypt(text: &str, password: &str) -> Result<String, &'static str> {
38        let rng = rand::SystemRandom::new();
39        let mut salt = vec![0u8; 16];
40        rng.fill(&mut salt).map_err(|_| "Failed to generate salt")?;
41
42        let key_bytes = Self::derive_key(password, &salt)?;
43
44        let unbound_key = aead::UnboundKey::new(&aead::AES_256_GCM, &key_bytes)
45            .map_err(|_| "Failed to create key")?;
46        let key = aead::LessSafeKey::new(unbound_key);
47
48        let mut nonce_bytes = vec![0u8; 12];
49        rng.fill(&mut nonce_bytes).map_err(|_| "Failed to generate nonce")?;
50        let nonce = aead::Nonce::try_assume_unique_for_key(&nonce_bytes)
51            .map_err(|_| "Invalid nonce")?;
52
53        let mut in_out = text.as_bytes().to_vec();
54        key.seal_in_place_append_tag(nonce, aead::Aad::empty(), &mut in_out)
55            .map_err(|_| "Encryption failed")?;
56
57        let mut combined = Vec::new();
58        combined.extend_from_slice(&salt);
59        combined.extend_from_slice(&nonce_bytes);
60        combined.extend_from_slice(&in_out);
61
62        Ok(BASE64.encode(combined))
63    }
64
65    pub fn decrypt(encrypted_base64: &str, password: &str) -> Result<String, &'static str> {
66        let combined = BASE64.decode(encrypted_base64)
67            .map_err(|_| "Invalid base64 data")?;
68
69        if combined.len() < 28 {
70            return Err("Data too short");
71        }
72
73        let salt = &combined[0..16];
74        let nonce_bytes = &combined[16..28];
75        let ciphertext = &combined[28..];
76
77        let key_bytes = Self::derive_key(password, salt)?;
78
79        let unbound_key = aead::UnboundKey::new(&aead::AES_256_GCM, &key_bytes)
80            .map_err(|_| "Failed to create key")?;
81        let key = aead::LessSafeKey::new(unbound_key);
82
83        let nonce = aead::Nonce::try_assume_unique_for_key(nonce_bytes)
84            .map_err(|_| "Invalid nonce")?;
85
86        let mut decrypted = ciphertext.to_vec();
87        let decrypted_len = key.open_in_place(nonce, aead::Aad::empty(), &mut decrypted)
88            .map_err(|_| "Decryption failed - invalid password or corrupted data")?
89            .len();
90
91        // Truncate to the actual decrypted length (excluding auth tag)
92        decrypted.truncate(decrypted_len);
93
94        String::from_utf8(decrypted)
95            .map_err(|_| "Invalid UTF-8 in decrypted data")
96    }
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `s` with the ASCII character at byte index `idx` replaced by a base64
    /// character that is guaranteed to differ from the original.
    ///
    /// Blindly writing `'A'` leaves the data unchanged whenever the original character
    /// already is `'A'`. With a random salt and nonce that happens roughly once in 64
    /// runs per site, which made the corruption test flaky.
    fn corrupt_char_at(s: &str, idx: usize) -> String {
        let replacement = if s.as_bytes()[idx] == b'A' { "B" } else { "A" };
        format!("{}{}{}", &s[..idx], replacement, &s[idx + 1..])
    }

    #[test]
    fn test_encryption_decryption() {
        let password = "MySecretPassword123!";
        let text = "Hello, Web Crypto with password-based encryption!";

        let encrypted = EncryptedText::encrypt(text, password).unwrap();
        let decrypted = EncryptedText::decrypt(&encrypted, password).unwrap();
        assert_eq!(decrypted, text);

        let result = EncryptedText::decrypt(&encrypted, "WrongPassword");
        assert!(result.is_err());
    }

    #[test]
    fn test_corrupted_data() {
        let password = "MySecretPassword123!";
        let text = "Test message";

        let encrypted = EncryptedText::encrypt(text, password).unwrap();

        // Corrupt different parts of the data; each change is guaranteed to be real.
        let corrupted_start = corrupt_char_at(&encrypted, 0);
        assert!(EncryptedText::decrypt(&corrupted_start, password).is_err());

        let corrupted_middle = corrupt_char_at(&encrypted, encrypted.len() / 2);
        assert!(EncryptedText::decrypt(&corrupted_middle, password).is_err());

        let corrupted_end = format!("{}AAAA", &encrypted[..encrypted.len() - 4]);
        assert!(EncryptedText::decrypt(&corrupted_end, password).is_err());
    }

    #[test]
    fn test_malformed_input() {
        assert_eq!(
            EncryptedText::decrypt("not base64!", "pw"),
            Err("Invalid base64 data")
        );

        let short = BASE64.encode([0u8; 20]);
        assert_eq!(EncryptedText::decrypt(&short, "pw"), Err("Data too short"));
    }

    #[test]
    fn test_various_lengths() {
        let password = "test123";
        let texts = [
            "",
            "a",
            "hello",
            "This is a longer test message with spaces and !@#$ symbols",
            "🦀 Rust with Unicode 🔐",
        ];

        for text in texts {
            let encrypted = EncryptedText::encrypt(text, password).unwrap();
            let decrypted = EncryptedText::decrypt(&encrypted, password).unwrap();
            assert_eq!(text, decrypted);
        }
    }

    #[test]
    fn test_known_string() {
        const KNOWN_PASSWORD: &str = "test_password_123";
        const KNOWN_MESSAGE: &str = "Hello, Cross-Platform Encryption!";
        const KNOWN_ENCRYPTED: &str = "1G73otj9BTJ5i3djZyuemijZnGkMb8XawInJVUqLqiNTIRPrBrs8MxL0y+cJWTcxGcxkS7H+/BltKwxqS0dd5TYTN81cOWaHmO7SJR0=";

        // Test decryption of a ciphertext produced elsewhere (pins the wire format).
        let decrypted = EncryptedText::decrypt(KNOWN_ENCRYPTED, KNOWN_PASSWORD).unwrap();
        assert_eq!(decrypted, KNOWN_MESSAGE);

        // Test that we can also encrypt and decrypt our own message.
        let encrypted = EncryptedText::encrypt(KNOWN_MESSAGE, KNOWN_PASSWORD).unwrap();
        let decrypted = EncryptedText::decrypt(&encrypted, KNOWN_PASSWORD).unwrap();
        assert_eq!(decrypted, KNOWN_MESSAGE);
    }
}