// modality_utils/encrypted_text.rs

use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
use ring::{aead, pbkdf2, rand, hkdf};
use ring::rand::SecureRandom;
use std::num::NonZeroU32;

/// Container for the three components of an encrypted payload.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads
/// this struct; `EncryptedText` works on a single concatenated byte buffer
/// instead. Presumably it mirrors that layout — confirm before relying on it.
#[derive(Debug)]
pub struct EncryptedData {
    /// Salt bytes (presumably the key-derivation salt — TODO confirm).
    salt: Vec<u8>,
    /// Nonce bytes (presumably the AEAD nonce — TODO confirm).
    nonce: Vec<u8>,
    /// Encrypted bytes (whether the auth tag is included is not shown here).
    ciphertext: Vec<u8>,
}

/// Namespace type for password-based text encryption.
///
/// Carries no state; all functionality is exposed as associated functions
/// (`EncryptedText::encrypt` and `EncryptedText::decrypt`).
pub struct EncryptedText;

/// Password-based authenticated encryption of UTF-8 text.
///
/// Wire format (before base64): `salt (16) || nonce (12) || ciphertext || tag (16)`.
/// The key is derived as PBKDF2-HMAC-SHA256 (100 000 iterations) followed by
/// HKDF-SHA256 (same salt, info `"aes-256-gcm"`), and the cipher is AES-256-GCM
/// with empty associated data.
impl EncryptedText {
    /// Length in bytes of the random salt stored at the start of the payload.
    const SALT_LEN: usize = 16;
    /// Length in bytes of the AES-GCM nonce stored right after the salt.
    const NONCE_LEN: usize = 12;
    /// Length in bytes of the AES-GCM authentication tag appended to the ciphertext.
    const TAG_LEN: usize = 16;
    /// Length in bytes of the derived AES-256 key.
    const KEY_LEN: usize = 32;
    /// Number of bytes preceding the ciphertext (salt + nonce).
    const HEADER_LEN: usize = Self::SALT_LEN + Self::NONCE_LEN;
    /// PBKDF2 iteration count. Changing it breaks decryption of existing payloads.
    const PBKDF2_ITERATIONS: u32 = 100_000;
    /// HKDF `info` parameter binding the derived key to its purpose.
    const HKDF_INFO: &'static [u8] = b"aes-256-gcm";

    /// Derives the 32-byte AES key from `password` and `salt`.
    ///
    /// The password is first stretched with PBKDF2-HMAC-SHA256, then the result
    /// is passed through HKDF-SHA256 (extract with `salt`, expand with
    /// [`Self::HKDF_INFO`]).
    ///
    /// # Errors
    /// Returns a static message if the iteration count is zero or an HKDF step fails.
    fn derive_key(password: &str, salt: &[u8]) -> Result<[u8; 32], &'static str> {
        // Surface a misconfigured constant as an error instead of panicking.
        let iterations = NonZeroU32::new(Self::PBKDF2_ITERATIONS)
            .ok_or("Invalid PBKDF2 iteration count")?;

        let mut stretched = [0u8; Self::KEY_LEN];
        pbkdf2::derive(
            pbkdf2::PBKDF2_HMAC_SHA256,
            iterations,
            salt,
            password.as_bytes(),
            &mut stretched,
        );

        // Second stage: HKDF over the stretched secret, reusing the same salt.
        let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, salt).extract(&stretched);
        let mut final_key = [0u8; Self::KEY_LEN];
        prk.expand(&[Self::HKDF_INFO], hkdf::HKDF_SHA256)
            .map_err(|_| "HKDF expand failed")?
            .fill(&mut final_key)
            .map_err(|_| "HKDF fill failed")?;

        Ok(final_key)
    }

    /// Builds the AES-256-GCM key object for `password` and `salt`.
    fn cipher_for(password: &str, salt: &[u8]) -> Result<aead::LessSafeKey, &'static str> {
        let key_bytes = Self::derive_key(password, salt)?;
        aead::UnboundKey::new(&aead::AES_256_GCM, &key_bytes)
            .map(aead::LessSafeKey::new)
            .map_err(|_| "Failed to create key")
    }

    /// Encrypts `text` under `password` and returns the base64-encoded payload.
    ///
    /// A fresh random salt and nonce are generated on every call, so encrypting
    /// the same input twice yields different outputs.
    ///
    /// # Errors
    /// Returns a static message if random generation, key derivation, or
    /// sealing fails.
    pub fn encrypt(text: &str, password: &str) -> Result<String, &'static str> {
        let rng = rand::SystemRandom::new();

        let mut salt = [0u8; Self::SALT_LEN];
        rng.fill(&mut salt).map_err(|_| "Failed to generate salt")?;

        let key = Self::cipher_for(password, &salt)?;

        let mut nonce_bytes = [0u8; Self::NONCE_LEN];
        rng.fill(&mut nonce_bytes).map_err(|_| "Failed to generate nonce")?;
        let nonce = aead::Nonce::try_assume_unique_for_key(&nonce_bytes)
            .map_err(|_| "Invalid nonce")?;

        // Sealing happens in place: the buffer ends up as ciphertext || tag.
        let mut sealed = text.as_bytes().to_vec();
        key.seal_in_place_append_tag(nonce, aead::Aad::empty(), &mut sealed)
            .map_err(|_| "Encryption failed")?;

        let mut combined = Vec::with_capacity(Self::HEADER_LEN + sealed.len());
        combined.extend_from_slice(&salt);
        combined.extend_from_slice(&nonce_bytes);
        combined.extend_from_slice(&sealed);

        Ok(BASE64.encode(combined))
    }

    /// Decrypts a payload produced by [`Self::encrypt`] using `password`.
    ///
    /// # Errors
    /// Returns a static message when the input is not valid base64, is too
    /// short to contain salt, nonce and authentication tag, fails
    /// authentication (wrong password or tampered data), or does not decrypt
    /// to valid UTF-8.
    pub fn decrypt(encrypted_base64: &str, password: &str) -> Result<String, &'static str> {
        let combined = BASE64.decode(encrypted_base64)
            .map_err(|_| "Invalid base64 data")?;

        // A valid payload always carries the auth tag; rejecting anything
        // shorter up front avoids running the expensive KDF on hopeless input.
        if combined.len() < Self::HEADER_LEN + Self::TAG_LEN {
            return Err("Data too short");
        }

        let (header, sealed) = combined.split_at(Self::HEADER_LEN);
        let (salt, nonce_bytes) = header.split_at(Self::SALT_LEN);

        let key = Self::cipher_for(password, salt)?;
        let nonce = aead::Nonce::try_assume_unique_for_key(nonce_bytes)
            .map_err(|_| "Invalid nonce")?;

        let mut buffer = sealed.to_vec();
        let plaintext_len = key.open_in_place(nonce, aead::Aad::empty(), &mut buffer)
            .map_err(|_| "Decryption failed - invalid password or corrupted data")?
            .len();

        // Drop the trailing authentication tag, keeping only the plaintext.
        buffer.truncate(plaintext_len);

        String::from_utf8(buffer)
            .map_err(|_| "Invalid UTF-8 in decrypted data")
    }
}

/// Round-trip, tamper-detection and compatibility tests for `EncryptedText`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `encoded` with the character at byte offset `index` replaced by a
    /// *different* base64 alphabet character.
    ///
    /// Blindly writing `'A'` would be a no-op whenever the random ciphertext
    /// already has `'A'` at that position, making the corruption tests flaky.
    /// `encoded` must be ASCII (true for base64), so byte offsets are char
    /// boundaries.
    fn corrupt_char_at(encoded: &str, index: usize) -> String {
        let replacement = if encoded.as_bytes()[index] == b'A' { 'B' } else { 'A' };
        format!("{}{}{}", &encoded[..index], replacement, &encoded[index + 1..])
    }

    #[test]
    fn test_encryption_decryption() {
        let password = "MySecretPassword123!";
        let text = "Hello, Web Crypto with password-based encryption!";

        let encrypted = EncryptedText::encrypt(text, password).unwrap();
        let decrypted = EncryptedText::decrypt(&encrypted, password).unwrap();
        assert_eq!(decrypted, text);

        // A wrong password must fail authentication rather than yield garbage.
        let result = EncryptedText::decrypt(&encrypted, "WrongPassword");
        assert!(result.is_err());
    }

    #[test]
    fn test_corrupted_data() {
        let password = "MySecretPassword123!";
        let text = "Test message";

        let encrypted = EncryptedText::encrypt(text, password).unwrap();

        // Corrupt the salt region (start of the payload).
        let corrupted_start = corrupt_char_at(&encrypted, 0);
        assert_ne!(corrupted_start, encrypted);
        assert!(EncryptedText::decrypt(&corrupted_start, password).is_err());

        // Corrupt the middle of the payload.
        let mid = encrypted.len() / 2;
        let corrupted_middle = corrupt_char_at(&encrypted, mid);
        assert_ne!(corrupted_middle, encrypted);
        assert!(EncryptedText::decrypt(&corrupted_middle, password).is_err());

        // Corrupt the tail (authentication tag), again guaranteeing a real change.
        let tail = if encrypted.ends_with("AAAA") { "BBBB" } else { "AAAA" };
        let corrupted_end = format!("{}{}", &encrypted[..encrypted.len() - 4], tail);
        assert_ne!(corrupted_end, encrypted);
        assert!(EncryptedText::decrypt(&corrupted_end, password).is_err());
    }

    #[test]
    fn test_malformed_input() {
        let password = "test123";

        // Not base64 at all.
        assert!(EncryptedText::decrypt("!!!not-base64!!!", password).is_err());
        // Empty payload.
        assert!(EncryptedText::decrypt("", password).is_err());
        // Valid base64 but far too short to hold salt and nonce.
        assert!(EncryptedText::decrypt("AAAA", password).is_err());
    }

    #[test]
    fn test_various_lengths() {
        let password = "test123";
        let texts = [
            "",
            "a",
            "hello",
            "This is a longer test message with spaces and !@#$ symbols",
            "🦀 Rust with Unicode 🔐",
        ];

        for text in texts {
            let encrypted = EncryptedText::encrypt(text, password).unwrap();
            let decrypted = EncryptedText::decrypt(&encrypted, password).unwrap();
            assert_eq!(text, decrypted);
        }
    }

    #[test]
    fn test_known_string() {
        const KNOWN_PASSWORD: &str = "test_password_123";
        const KNOWN_MESSAGE: &str = "Hello, Cross-Platform Encryption!";
        const KNOWN_ENCRYPTED: &str = "1G73otj9BTJ5i3djZyuemijZnGkMb8XawInJVUqLqiNTIRPrBrs8MxL0y+cJWTcxGcxkS7H+/BltKwxqS0dd5TYTN81cOWaHmO7SJR0=";

        // Decrypting a fixed vector guards the on-wire format and KDF parameters.
        let decrypted = EncryptedText::decrypt(KNOWN_ENCRYPTED, KNOWN_PASSWORD).unwrap();
        assert_eq!(decrypted, KNOWN_MESSAGE);

        // A fresh encryption of the same message must also round-trip.
        let encrypted = EncryptedText::encrypt(KNOWN_MESSAGE, KNOWN_PASSWORD).unwrap();
        let decrypted = EncryptedText::decrypt(&encrypted, KNOWN_PASSWORD).unwrap();
        assert_eq!(decrypted, KNOWN_MESSAGE);
    }
}