Skip to main content

crablock_core/
crypto.rs

1use aes_gcm::{
2    aead::{Aead, KeyInit, Nonce},
3    Aes256Gcm,
4};
5use chacha20poly1305::ChaCha20Poly1305;
6use rand::RngCore;
7use sha2::{Digest, Sha256};
8use zeroize::{Zeroize, ZeroizeOnDrop};
9
10use crate::error::{CrablockError, Result};
11
/// Length in bytes of every symmetric key (256 bits).
pub const KEY_SIZE: usize = 32;
/// Generic AEAD nonce length in bytes (96 bits).
pub const NONCE_SIZE: usize = 12;
/// Authentication tag length in bytes (128 bits) for AES-GCM and ChaCha20-Poly1305.
pub const TAG_SIZE: usize = 16;
/// Nonce length in bytes used with AES-256-GCM.
pub const AES_GCM_NONCE_SIZE: usize = NONCE_SIZE;
/// Nonce length in bytes used with ChaCha20-Poly1305.
pub const CHACHA_NONCE_SIZE: usize = NONCE_SIZE;
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
19#[serde(rename_all = "snake_case")]
20pub enum EncryptionAlgorithm {
21    Aes256Gcm,
22    ChaCha20Poly1305,
23}
24
25impl EncryptionAlgorithm {
26    pub fn nonce_size(&self) -> usize {
27        match self {
28            EncryptionAlgorithm::Aes256Gcm => AES_GCM_NONCE_SIZE,
29            EncryptionAlgorithm::ChaCha20Poly1305 => CHACHA_NONCE_SIZE,
30        }
31    }
32
33    pub fn as_str(&self) -> &'static str {
34        match self {
35            EncryptionAlgorithm::Aes256Gcm => "aes_256_gcm",
36            EncryptionAlgorithm::ChaCha20Poly1305 => "chacha20_poly1305",
37        }
38    }
39}
40
41impl std::str::FromStr for EncryptionAlgorithm {
42    type Err = CrablockError;
43
44    fn from_str(s: &str) -> Result<Self> {
45        match s.to_lowercase().as_str() {
46            "aes_256_gcm" | "aes-256-gcm" | "aes256gcm" => Ok(EncryptionAlgorithm::Aes256Gcm),
47            "chacha20_poly1305" | "chacha20-poly1305" | "chacha20poly1305" => {
48                Ok(EncryptionAlgorithm::ChaCha20Poly1305)
49            }
50            _ => Err(CrablockError::UnsupportedAlgorithm(format!(
51                "Unknown algorithm: {s}"
52            ))),
53        }
54    }
55}
56
57#[derive(Clone, Zeroize, ZeroizeOnDrop)]
58pub struct EncryptionKey {
59    // The key is kept in a fixed-size array so we always know its exact length.
60    pub key: [u8; KEY_SIZE],
61}
62
63impl EncryptionKey {
64    pub fn new(key: [u8; KEY_SIZE]) -> Self {
65        Self { key }
66    }
67
68    pub fn from_hex(hex_str: &str) -> Result<Self> {
69        let bytes = hex::decode(hex_str)
70            .map_err(|e| CrablockError::InvalidKey(format!("Invalid hex: {e}")))?;
71
72        if bytes.len() != KEY_SIZE {
73            return Err(CrablockError::InvalidKey(format!(
74                "Key must be {} bytes, got {}",
75                KEY_SIZE,
76                bytes.len()
77            )));
78        }
79
80        let mut key = [0u8; KEY_SIZE];
81        key.copy_from_slice(&bytes);
82        Ok(Self::new(key))
83    }
84
85    pub fn from_base64(b64_str: &str) -> Result<Self> {
86        use base64::Engine;
87        let bytes = base64::engine::general_purpose::STANDARD
88            .decode(b64_str)
89            .map_err(|e| CrablockError::InvalidKey(format!("Invalid base64: {e}")))?;
90
91        if bytes.len() != KEY_SIZE {
92            return Err(CrablockError::InvalidKey(format!(
93                "Key must be {} bytes, got {}",
94                KEY_SIZE,
95                bytes.len()
96            )));
97        }
98
99        let mut key = [0u8; KEY_SIZE];
100        key.copy_from_slice(&bytes);
101        Ok(Self::new(key))
102    }
103
104    pub fn generate_random() -> Self {
105        // We generate a fresh random key when tests or helper code need one.
106        let mut key = [0u8; KEY_SIZE];
107        rand::thread_rng().fill_bytes(&mut key);
108        Self::new(key)
109    }
110}
111
112pub struct Encryptor {
113    algorithm: EncryptionAlgorithm,
114    key: EncryptionKey,
115    nonce: Vec<u8>,
116}
117
118impl Encryptor {
119    pub fn new(algorithm: EncryptionAlgorithm, key: EncryptionKey) -> Self {
120        // Every encryption call needs a nonce.
121        // We store it on the encryptor so the caller can put it in the manifest later.
122        let nonce_size = algorithm.nonce_size();
123        let mut nonce = vec![0u8; nonce_size];
124        rand::thread_rng().fill_bytes(&mut nonce);
125
126        Self {
127            algorithm,
128            key,
129            nonce,
130        }
131    }
132
133    pub fn with_nonce(mut self, nonce: Vec<u8>) -> Self {
134        self.nonce = nonce;
135        self
136    }
137
138    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
139        // The algorithm choice only changes the cipher details.
140        // The rest of the app treats the encrypted bytes the same way.
141        let ciphertext = match self.algorithm {
142            EncryptionAlgorithm::Aes256Gcm => {
143                let cipher = Aes256Gcm::new_from_slice(&self.key.key)
144                    .map_err(|e| CrablockError::Crypto(format!("AES key init failed: {e:?}")))?;
145                let nonce = Nonce::<Aes256Gcm>::from_slice(&self.nonce);
146                cipher
147                    .encrypt(nonce, plaintext)
148                    .map_err(|e| CrablockError::Crypto(format!("AES encryption failed: {e:?}")))?
149            }
150            EncryptionAlgorithm::ChaCha20Poly1305 => {
151                use chacha20poly1305::aead::Aead as ChaChaAead;
152                use chacha20poly1305::aead::KeyInit as ChaChaKeyInit;
153                use chacha20poly1305::Nonce as ChaChaNonce;
154
155                let cipher = ChaCha20Poly1305::new_from_slice(&self.key.key)
156                    .map_err(|e| CrablockError::Crypto(format!("ChaCha key init failed: {e:?}")))?;
157                let nonce = ChaChaNonce::from_slice(&self.nonce);
158                cipher.encrypt(nonce, plaintext).map_err(|e| {
159                    CrablockError::Crypto(format!("ChaCha encryption failed: {e:?}"))
160                })?
161            }
162        };
163
164        Ok(ciphertext)
165    }
166
167    pub fn nonce(&self) -> &[u8] {
168        &self.nonce
169    }
170
171    pub fn algorithm(&self) -> EncryptionAlgorithm {
172        self.algorithm
173    }
174}
175
176pub struct Decryptor {
177    algorithm: EncryptionAlgorithm,
178    key: EncryptionKey,
179    nonce: Vec<u8>,
180}
181
182impl Decryptor {
183    pub fn new(algorithm: EncryptionAlgorithm, key: EncryptionKey, nonce: Vec<u8>) -> Self {
184        Self {
185            algorithm,
186            key,
187            nonce,
188        }
189    }
190
191    pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
192        // Decryption mirrors `encrypt` and returns a typed error when the key or nonce is wrong.
193        let plaintext = match self.algorithm {
194            EncryptionAlgorithm::Aes256Gcm => {
195                let cipher = Aes256Gcm::new_from_slice(&self.key.key)
196                    .map_err(|e| CrablockError::Crypto(format!("AES key init failed: {e:?}")))?;
197                let nonce = Nonce::<Aes256Gcm>::from_slice(&self.nonce);
198                cipher.decrypt(nonce, ciphertext).map_err(|e| {
199                    CrablockError::DecryptionFailed(format!("AES decryption failed: {e:?}"))
200                })?
201            }
202            EncryptionAlgorithm::ChaCha20Poly1305 => {
203                use chacha20poly1305::aead::Aead as ChaChaAead;
204                use chacha20poly1305::aead::KeyInit as ChaChaKeyInit;
205                use chacha20poly1305::Nonce as ChaChaNonce;
206
207                let cipher = ChaCha20Poly1305::new_from_slice(&self.key.key)
208                    .map_err(|e| CrablockError::Crypto(format!("ChaCha key init failed: {e:?}")))?;
209                let nonce = ChaChaNonce::from_slice(&self.nonce);
210                cipher.decrypt(nonce, ciphertext).map_err(|e| {
211                    CrablockError::DecryptionFailed(format!("ChaCha decryption failed: {e:?}"))
212                })?
213            }
214        };
215
216        Ok(plaintext)
217    }
218}
219
220pub fn compute_sha256(data: &[u8]) -> String {
221    // We store hashes as lowercase hex strings because they are easy to log and compare.
222    let mut hasher = Sha256::new();
223    hasher.update(data);
224    hex::encode(hasher.finalize())
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Every supported algorithm, so each property is checked for all of them.
    const ALGORITHMS: [EncryptionAlgorithm; 2] = [
        EncryptionAlgorithm::Aes256Gcm,
        EncryptionAlgorithm::ChaCha20Poly1305,
    ];

    /// Encrypts `plaintext` under a fresh key, asserts it decrypts back
    /// unchanged, and returns the ciphertext for further checks.
    fn roundtrip(algorithm: EncryptionAlgorithm, plaintext: &[u8]) -> Vec<u8> {
        let key = EncryptionKey::generate_random();

        let encryptor = Encryptor::new(algorithm, key.clone());
        let nonce = encryptor.nonce().to_vec();
        assert_eq!(nonce.len(), algorithm.nonce_size());
        let ciphertext = encryptor.encrypt(plaintext).unwrap();

        let decryptor = Decryptor::new(algorithm, key, nonce);
        let decrypted = decryptor.decrypt(&ciphertext).unwrap();

        assert_eq!(plaintext, decrypted.as_slice());
        ciphertext
    }

    #[test]
    fn test_aes_encryption_roundtrip() {
        roundtrip(EncryptionAlgorithm::Aes256Gcm, b"Hello, World!");
    }

    #[test]
    fn test_chacha_encryption_roundtrip() {
        roundtrip(EncryptionAlgorithm::ChaCha20Poly1305, b"Hello, World!");
    }

    #[test]
    fn test_ciphertext_includes_tag() {
        for algorithm in ALGORITHMS {
            assert_eq!(roundtrip(algorithm, b"").len(), TAG_SIZE);
            let plaintext = b"Hello, World!";
            assert_eq!(
                roundtrip(algorithm, plaintext).len(),
                plaintext.len() + TAG_SIZE
            );
        }
    }

    #[test]
    fn test_wrong_key_fails() {
        for algorithm in ALGORITHMS {
            let key1 = EncryptionKey::generate_random();
            let key2 = EncryptionKey::generate_random();

            let encryptor = Encryptor::new(algorithm, key1);
            let nonce = encryptor.nonce().to_vec();
            let ciphertext = encryptor.encrypt(b"Hello, World!").unwrap();

            let result = Decryptor::new(algorithm, key2, nonce).decrypt(&ciphertext);
            assert!(matches!(result, Err(CrablockError::DecryptionFailed(_))));
        }
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        for algorithm in ALGORITHMS {
            let key = EncryptionKey::generate_random();

            let encryptor = Encryptor::new(algorithm, key.clone());
            let nonce = encryptor.nonce().to_vec();
            let mut ciphertext = encryptor.encrypt(b"Hello, World!").unwrap();
            ciphertext[0] ^= 0x01;

            let result = Decryptor::new(algorithm, key, nonce).decrypt(&ciphertext);
            assert!(matches!(result, Err(CrablockError::DecryptionFailed(_))));
        }
    }

    #[test]
    fn test_key_from_hex_and_base64() {
        use base64::Engine;

        let bytes = [0xABu8; KEY_SIZE];

        let from_hex = EncryptionKey::from_hex(&hex::encode(bytes)).unwrap();
        assert_eq!(from_hex.key, bytes);

        let b64 = base64::engine::general_purpose::STANDARD.encode(bytes);
        let from_b64 = EncryptionKey::from_base64(&b64).unwrap();
        assert_eq!(from_b64.key, bytes);
    }

    #[test]
    fn test_key_rejects_bad_input() {
        // Valid hex, but too short.
        assert!(matches!(
            EncryptionKey::from_hex("abcd"),
            Err(CrablockError::InvalidKey(_))
        ));
        // Not hex at all.
        assert!(matches!(
            EncryptionKey::from_hex("zz"),
            Err(CrablockError::InvalidKey(_))
        ));
        // Not base64 at all.
        assert!(matches!(
            EncryptionKey::from_base64("!!!"),
            Err(CrablockError::InvalidKey(_))
        ));
    }

    #[test]
    fn test_algorithm_parse() {
        for algorithm in ALGORITHMS {
            assert_eq!(
                algorithm.as_str().parse::<EncryptionAlgorithm>().unwrap(),
                algorithm
            );
        }
        assert_eq!(
            "AES-256-GCM".parse::<EncryptionAlgorithm>().unwrap(),
            EncryptionAlgorithm::Aes256Gcm
        );
        assert_eq!(
            "ChaCha20Poly1305".parse::<EncryptionAlgorithm>().unwrap(),
            EncryptionAlgorithm::ChaCha20Poly1305
        );
        assert!(matches!(
            "rc4".parse::<EncryptionAlgorithm>(),
            Err(CrablockError::UnsupportedAlgorithm(_))
        ));
    }

    #[test]
    fn test_sha256() {
        // Known SHA-256 test vectors.
        assert_eq!(
            compute_sha256(b""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            compute_sha256(b"hello"),
            "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
        );
    }
}