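// mostro_core: crypto.rs
//
// Password-based encryption helpers: Argon2id key derivation, ChaCha20-Poly1305
// authenticated encryption with base64-encoded output, and a bounded in-memory
// cache of derived keys used when decrypting.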
2use crate::prelude::*;
3use argon2::{
4    password_hash::{rand_core::OsRng, Salt, SaltString},
5    Algorithm, Argon2, Params, PasswordHasher, Version,
6};
7use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine as _};
8use chacha20poly1305::{
9    aead::{Aead, KeyInit},
10    AeadCore, ChaCha20Poly1305, Key,
11};
12use secrecy::*;
13use std::collections::{HashMap, VecDeque};
14use std::sync::{LazyLock, RwLock};
15use zeroize::Zeroize;
16
17// 🔐 Cache: global static or pass it explicitly
18static KEY_CACHE: LazyLock<RwLock<SecretBox<SimpleCache>>> =
19    LazyLock::new(|| RwLock::new(SecretBox::new(Box::new(SimpleCache::new()))));
20
// ----- Crypto parameters -----

/// Nonce length in bytes for ChaCha20-Poly1305 (96 bits).
const NONCE_SIZE: usize = 12;
/// Length in bytes of the salt stored alongside each ciphertext.
const SALT_SIZE: usize = 16;
/// Length in bytes of the key derived with Argon2.
const DERIVED_KEY_LENGTH: usize = 32;

// ----- Simple fixed-size cache -----

/// Maximum number of derived keys held in the key cache.
const MAX_CACHE_SIZE: usize = 50;
29
30// blake3 hash for cache key
31type CacheKey = blake3::Hash; // 256-bit
32
33struct SimpleCache {
34    map: HashMap<CacheKey, [u8; 32]>,
35    order: VecDeque<CacheKey>,
36}
37
38impl SimpleCache {
39    fn new() -> Self {
40        Self {
41            map: HashMap::new(),
42            order: VecDeque::new(),
43        }
44    }
45
46    fn get(&mut self, key: CacheKey) -> Option<[u8; 32]> {
47        if let Some(value) = self.map.get(&key) {
48            self.order.retain(|&k| k != key);
49            self.order.push_back(key);
50            Some(*value)
51        } else {
52            None
53        }
54    }
55
56    fn put(&mut self, key: CacheKey, value: [u8; 32]) {
57        if !self.map.contains_key(&key) && self.map.len() >= MAX_CACHE_SIZE {
58            if let Some(oldest_key) = self.order.pop_front() {
59                self.map.remove(&oldest_key);
60            }
61        }
62        self.order.retain(|&k| k != key);
63        self.order.push_back(key);
64        self.map.insert(key, value);
65    }
66}
67
68// Implementation of zeroize required by secretbox
69impl Zeroize for SimpleCache {
70    fn zeroize(&mut self) {
71        for value in self.map.values_mut() {
72            value.zeroize();
73        }
74        self.map.clear();
75        self.order.clear();
76    }
77}
78
79// On drop, zeroize the cache
80impl Drop for SimpleCache {
81    fn drop(&mut self) {
82        self.zeroize();
83    }
84}
85
86// make blake3 hash for cache key from password and salt
87fn make_cache_key(password: &str, salt: &[u8]) -> CacheKey {
88    blake3::hash([password.as_bytes(), salt].concat().as_slice())
89}
90
/// Stateless namespace for the password-based encryption and decryption helpers.
#[derive(Debug, Clone, Copy, Default)]
pub struct CryptoUtils;
92
93impl CryptoUtils {
94    /// Derive a key from password and salt with Argon2
95    pub fn derive_key(password: &str, salt: &SaltString) -> Result<Vec<u8>, ServiceError> {
96        // Common key derivation logic
97        let params = Params::new(
98            Params::DEFAULT_M_COST,
99            Params::DEFAULT_T_COST,
100            Params::DEFAULT_P_COST * 2,
101            Some(Params::DEFAULT_OUTPUT_LEN),
102        )
103        .map_err(|_| ServiceError::EncryptionError("Error creating params".to_string()))?;
104
105        let argon2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
106        let password_hash = argon2
107            .hash_password(password.as_bytes(), salt)
108            .map_err(|_| ServiceError::EncryptionError("Error hashing password".to_string()))?;
109
110        let key = password_hash
111            .hash
112            .ok_or_else(|| ServiceError::EncryptionError("Error getting hash".to_string()))?;
113        let key_bytes = key.as_bytes();
114        if key_bytes.len() != DERIVED_KEY_LENGTH {
115            return Err(ServiceError::EncryptionError(
116                "Key length is not 32 bytes".to_string(),
117            ));
118        }
119        Ok(key_bytes.to_vec())
120    }
121
122    /// Encrypt data with the provided key and return a base64 encoded string to store in the database
123    pub fn encrypt(data: &[u8], key: &[u8], salt: &[u8]) -> Result<String, ServiceError> {
124        // Encryption logic
125        // Check key length
126        if key.len() != DERIVED_KEY_LENGTH {
127            return Err(ServiceError::EncryptionError(
128                "Key length is not 32 bytes".to_string(),
129            ));
130        }
131        // Check salt length
132        if salt.len() != SALT_SIZE {
133            return Err(ServiceError::EncryptionError(
134                "Salt length is not 16 bytes".to_string(),
135            ));
136        }
137        // Create cipher
138        let cipher = ChaCha20Poly1305::new(Key::from_slice(key));
139        // Generate nonce
140        let nonce = ChaCha20Poly1305::generate_nonce(&mut OsRng); // 96-bits; unique per message
141
142        // Encrypt data
143        let ciphertext = cipher
144            .encrypt(&nonce, data)
145            .map_err(|e| ServiceError::EncryptionError(e.to_string()))?;
146
147        // Combine nonce and ciphertext
148        let mut encrypted = Vec::with_capacity(NONCE_SIZE + SALT_SIZE + ciphertext.len());
149        encrypted.extend_from_slice(&nonce);
150        encrypted.extend_from_slice(salt);
151        encrypted.extend_from_slice(&ciphertext);
152
153        // --- Encoding to String ---
154        // Encode the binary ciphertext into a Base64 String
155        let ciphertext_base64 = BASE64_STANDARD.encode(&encrypted);
156        Ok(ciphertext_base64)
157    }
158
159    /// Decrypt data with the provided key
160    /// In case of cached values return the cached value to speed up search
161    fn decrypt(data: Vec<u8>, password: &str) -> Result<Vec<u8>, ServiceError> {
162        // Split the encrypted data into nonce and data
163        let (nonce, data) = data.split_at(NONCE_SIZE);
164
165        let nonce: [u8; NONCE_SIZE] = nonce
166            .try_into()
167            .map_err(|e| ServiceError::DecryptionError(format!("Error converting nonce: {}", e)))?;
168
169        let (salt, ciphertext) = data.split_at(SALT_SIZE);
170
171        // Enecode salt from base64 to bytes
172        let salt = SaltString::encode_b64(salt)
173            .map_err(|e| ServiceError::DecryptionError(format!("Error decoding salt: {}", e)))?;
174
175        // get hash value from salt and password
176        let cache_key = make_cache_key(password, salt.as_str().as_bytes());
177
178        let mut cache = KEY_CACHE
179            .write()
180            .map_err(|_| ServiceError::DecryptionError("Error in key cache".to_string()))?;
181        // Check if the key is already in the cache
182        // If the key is in the cache, use it
183        let key_bytes = if let Some(cached_key) = cache.expose_secret_mut().get(cache_key) {
184            cached_key
185        } else {
186            // Key not cached, derive it
187            let key_bytes = CryptoUtils::derive_key(password, &salt)
188                .map_err(|_| ServiceError::DecryptionError("Error deriving key".to_string()))?;
189            let mut key_array = [0u8; 32];
190            key_array.copy_from_slice(&key_bytes);
191            cache.expose_secret_mut().put(cache_key, key_array);
192            key_array
193        };
194
195        // Create cipher
196        let cipher = ChaCha20Poly1305::new(Key::from_slice(&key_bytes));
197
198        // Decrypt the data
199        let decrypted = cipher
200            .decrypt(&nonce.into(), ciphertext)
201            .map_err(|e| ServiceError::DecryptionError(e.to_string()))?;
202
203        Ok(decrypted)
204    }
205
206    /// Decrypt an identity key from the database
207    pub fn decrypt_data(
208        data: String,
209        password: Option<&SecretString>,
210    ) -> Result<String, ServiceError> {
211        // If password is not provided, return data as it is
212        let password = match password {
213            Some(password) => password,
214            None => return Ok(data),
215        };
216        // Decode the encrypted data from base64 to bytes
217        let encrypted_bytes = BASE64_STANDARD.decode(&data).map_err(|_| {
218            ServiceError::DecryptionError("Error decoding encrypted data".to_string())
219        })?;
220
221        // Validate input length before processing
222        if encrypted_bytes.len() < NONCE_SIZE + SALT_SIZE {
223            return Err(ServiceError::DecryptionError(
224                "Invalid encrypted data: too short for nonce and salt".to_string(),
225            ));
226        }
227
228        // Extract key bytes, salt and ciphered text
229        let decrypted_data = CryptoUtils::decrypt(encrypted_bytes, password.expose_secret())?;
230
231        // Convert the decrypted data to a string and return it
232        String::from_utf8(decrypted_data).map_err(|_| {
233            ServiceError::DecryptionError("Error converting encrypted data to string".to_string())
234        })
235    }
236
237    /// Encrypt a string to save it in the database
238    ///
239    /// # Parameters
240    /// * `idkey` - The string data to be encrypted
241    /// * `password` - Optional password used for encryption. If None, returns the data unencrypted
242    /// * `fixed_salt` - Optional fixed salt for encryption. If None, generates a random salt.
243    ///   This parameter is primarily used for unit testing to ensure consistent encryption results.
244    ///
245    /// # Returns
246    /// Returns a Result containing either:
247    /// * Ok(String) - The encrypted data encoded in base64
248    /// * Err(ServiceError) - If encryption fails
249    pub fn store_encrypted(
250        idkey: &str,
251        password: Option<&SecretString>,
252        fixed_salt: Option<SaltString>,
253    ) -> Result<String, ServiceError> {
254        // If password is not provided, return data as it is
255        let password = match password {
256            Some(password) => password,
257            None => return Ok(idkey.to_string()),
258        };
259
260        // Salt generation
261        let salt = match fixed_salt {
262            Some(salt) => salt,
263            None => SaltString::generate(&mut OsRng),
264        };
265
266        // Buffer to decode salt
267        let buf = &mut [0u8; Salt::RECOMMENDED_LENGTH];
268        // Decode salt from base64 to bytes
269        let salt_decoded = salt
270            .decode_b64(buf)
271            .map_err(|e| ServiceError::EncryptionError(format!("Error decoding salt: {}", e)))?;
272
273        // Derive key as bytes
274        let key_bytes = CryptoUtils::derive_key(password.expose_secret(), &salt)
275            .map_err(|e| ServiceError::EncryptionError(format!("Error deriving key: {}", e)))?;
276
277        // Encrypt data and return base64 encoded string
278        let ciphertext_base64 = CryptoUtils::encrypt(idkey.as_bytes(), &key_bytes, salt_decoded)
279            .map_err(|e| ServiceError::EncryptionError(format!("Error encrypting data: {}", e)))?;
280
281        Ok(ciphertext_base64)
282    }
283}