// crablock-core 0.1.2
//
// Core library for crablock: encryption, package format, and common utilities.
use aes_gcm::{
    aead::{Aead, KeyInit, Nonce},
    Aes256Gcm,
};
use chacha20poly1305::ChaCha20Poly1305;
use rand::RngCore;
use sha2::{Digest, Sha256};
use zeroize::{Zeroize, ZeroizeOnDrop};

use crate::error::{CrablockError, Result};

/// Length in bytes of the symmetric key used by every supported cipher (256 bits).
pub const KEY_SIZE: usize = 256 / 8;
/// Default nonce length in bytes (96 bits).
pub const NONCE_SIZE: usize = 96 / 8;
/// Authentication tag length in bytes (128 bits).
pub const TAG_SIZE: usize = 128 / 8;
/// Nonce length in bytes used with AES-256-GCM.
pub const AES_GCM_NONCE_SIZE: usize = NONCE_SIZE;
/// Nonce length in bytes used with ChaCha20-Poly1305.
pub const CHACHA_NONCE_SIZE: usize = NONCE_SIZE;

/// The ciphers this module can encrypt and decrypt with.
///
/// The enum is serializable with serde using `snake_case` variant names.
///
/// NOTE(review): serde's `snake_case` renaming appears to produce `aes256_gcm` and
/// `cha_cha20_poly1305`, which are not the same strings that `as_str()` returns
/// (`aes_256_gcm`, `chacha20_poly1305`) — confirm which spelling stored data uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EncryptionAlgorithm {
    /// AES with a 256-bit key in Galois/Counter Mode.
    Aes256Gcm,
    /// ChaCha20 stream cipher with the Poly1305 authenticator.
    ChaCha20Poly1305,
}

impl EncryptionAlgorithm {
    /// Returns the nonce length, in bytes, used with this algorithm.
    pub fn nonce_size(&self) -> usize {
        match *self {
            Self::Aes256Gcm => AES_GCM_NONCE_SIZE,
            Self::ChaCha20Poly1305 => CHACHA_NONCE_SIZE,
        }
    }

    /// Returns the lowercase, underscore-separated name of this algorithm.
    ///
    /// The returned string is one of the spellings accepted by the `FromStr` impl.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::ChaCha20Poly1305 => "chacha20_poly1305",
            Self::Aes256Gcm => "aes_256_gcm",
        }
    }
}

/// Parses an algorithm name, ignoring case.
///
/// Accepted spellings:
/// * AES-256-GCM: `aes_256_gcm`, `aes-256-gcm`, `aes256gcm`, `aes256_gcm`
/// * ChaCha20-Poly1305: `chacha20_poly1305`, `chacha20-poly1305`, `chacha20poly1305`,
///   `cha_cha20_poly1305`
///
/// The last spelling in each list is the name serde derives for the variant under
/// `rename_all = "snake_case"`, so a value copied out of serialized data parses too.
///
/// # Errors
///
/// Returns [`CrablockError::UnsupportedAlgorithm`] for any other input.
impl std::str::FromStr for EncryptionAlgorithm {
    type Err = CrablockError;

    fn from_str(s: &str) -> Result<Self> {
        match s.to_lowercase().as_str() {
            "aes_256_gcm" | "aes-256-gcm" | "aes256gcm" | "aes256_gcm" => {
                Ok(EncryptionAlgorithm::Aes256Gcm)
            }
            "chacha20_poly1305"
            | "chacha20-poly1305"
            | "chacha20poly1305"
            | "cha_cha20_poly1305" => Ok(EncryptionAlgorithm::ChaCha20Poly1305),
            _ => Err(CrablockError::UnsupportedAlgorithm(format!(
                "Unknown algorithm: {s}"
            ))),
        }
    }
}

/// A symmetric key of exactly `KEY_SIZE` bytes.
///
/// The `Zeroize` / `ZeroizeOnDrop` derives wipe the key bytes when the value is dropped.
/// `Debug` is not derived here, so the key cannot be printed with `{:?}` through a
/// derived impl (presumably deliberate — confirm before adding one).
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct EncryptionKey {
    // The key is kept in a fixed-size array so we always know its exact length.
    /// Raw key bytes. The field is public, so callers can read and copy it directly;
    /// any such copy is not covered by this struct's zeroize-on-drop.
    pub key: [u8; KEY_SIZE],
}

impl EncryptionKey {
    pub fn new(key: [u8; KEY_SIZE]) -> Self {
        Self { key }
    }

    pub fn from_hex(hex_str: &str) -> Result<Self> {
        let bytes = hex::decode(hex_str)
            .map_err(|e| CrablockError::InvalidKey(format!("Invalid hex: {e}")))?;

        if bytes.len() != KEY_SIZE {
            return Err(CrablockError::InvalidKey(format!(
                "Key must be {} bytes, got {}",
                KEY_SIZE,
                bytes.len()
            )));
        }

        let mut key = [0u8; KEY_SIZE];
        key.copy_from_slice(&bytes);
        Ok(Self::new(key))
    }

    pub fn from_base64(b64_str: &str) -> Result<Self> {
        use base64::Engine;
        let bytes = base64::engine::general_purpose::STANDARD
            .decode(b64_str)
            .map_err(|e| CrablockError::InvalidKey(format!("Invalid base64: {e}")))?;

        if bytes.len() != KEY_SIZE {
            return Err(CrablockError::InvalidKey(format!(
                "Key must be {} bytes, got {}",
                KEY_SIZE,
                bytes.len()
            )));
        }

        let mut key = [0u8; KEY_SIZE];
        key.copy_from_slice(&bytes);
        Ok(Self::new(key))
    }

    pub fn generate_random() -> Self {
        // We generate a fresh random key when tests or helper code need one.
        let mut key = [0u8; KEY_SIZE];
        rand::thread_rng().fill_bytes(&mut key);
        Self::new(key)
    }
}

/// Encrypts data with a fixed algorithm, key and nonce.
pub struct Encryptor {
    /// Cipher used by `encrypt`.
    algorithm: EncryptionAlgorithm,
    /// Key handed to the cipher.
    key: EncryptionKey,
    /// Nonce passed to every `encrypt` call; exposed through `nonce()`.
    nonce: Vec<u8>,
}

impl Encryptor {
    /// Creates an encryptor with a freshly generated random nonce.
    ///
    /// The nonce has `algorithm.nonce_size()` bytes and is filled from
    /// `rand::thread_rng()`.
    pub fn new(algorithm: EncryptionAlgorithm, key: EncryptionKey) -> Self {
        // Every encryption call needs a nonce.
        // We store it on the encryptor so the caller can put it in the manifest later.
        let nonce_size = algorithm.nonce_size();
        let mut nonce = vec![0u8; nonce_size];
        rand::thread_rng().fill_bytes(&mut nonce);

        Self {
            algorithm,
            key,
            nonce,
        }
    }

    /// Replaces the generated nonce with a caller-supplied one.
    ///
    /// The length is not checked here because this builder cannot fail; a nonce whose
    /// length differs from `algorithm.nonce_size()` makes [`Self::encrypt`] return an error.
    pub fn with_nonce(mut self, nonce: Vec<u8>) -> Self {
        self.nonce = nonce;
        self
    }

    /// Encrypts `plaintext` and returns the bytes produced by the selected cipher.
    ///
    /// Every call uses the same stored key and nonce. Encrypting more than one message
    /// with the same key/nonce pair breaks the security of both supported ciphers, so
    /// create a new `Encryptor` (or supply a new nonce) for each message.
    ///
    /// # Errors
    ///
    /// Returns [`CrablockError::Crypto`] if the stored nonce has the wrong length for the
    /// algorithm, if the cipher cannot be initialised, or if encryption fails.
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
        // `from_slice` below panics on a length mismatch, and `with_nonce` accepts any
        // length, so reject a badly sized nonce with a typed error first.
        let expected_len = self.algorithm.nonce_size();
        if self.nonce.len() != expected_len {
            return Err(CrablockError::Crypto(format!(
                "Invalid nonce length for {}: expected {} bytes, got {}",
                self.algorithm.as_str(),
                expected_len,
                self.nonce.len()
            )));
        }

        // The algorithm choice only changes the cipher details.
        // The rest of the app treats the encrypted bytes the same way.
        let ciphertext = match self.algorithm {
            EncryptionAlgorithm::Aes256Gcm => {
                let cipher = Aes256Gcm::new_from_slice(&self.key.key)
                    .map_err(|e| CrablockError::Crypto(format!("AES key init failed: {e:?}")))?;
                let nonce = Nonce::<Aes256Gcm>::from_slice(&self.nonce);
                cipher
                    .encrypt(nonce, plaintext)
                    .map_err(|e| CrablockError::Crypto(format!("AES encryption failed: {e:?}")))?
            }
            EncryptionAlgorithm::ChaCha20Poly1305 => {
                use chacha20poly1305::aead::Aead as ChaChaAead;
                use chacha20poly1305::aead::KeyInit as ChaChaKeyInit;
                use chacha20poly1305::Nonce as ChaChaNonce;

                let cipher = ChaCha20Poly1305::new_from_slice(&self.key.key)
                    .map_err(|e| CrablockError::Crypto(format!("ChaCha key init failed: {e:?}")))?;
                let nonce = ChaChaNonce::from_slice(&self.nonce);
                cipher.encrypt(nonce, plaintext).map_err(|e| {
                    CrablockError::Crypto(format!("ChaCha encryption failed: {e:?}"))
                })?
            }
        };

        Ok(ciphertext)
    }

    /// Returns the nonce that `encrypt` uses, so the caller can store it with the ciphertext.
    pub fn nonce(&self) -> &[u8] {
        &self.nonce
    }

    /// Returns the algorithm this encryptor was created with.
    pub fn algorithm(&self) -> EncryptionAlgorithm {
        self.algorithm
    }
}

/// Decrypts data with a fixed algorithm, key and nonce.
pub struct Decryptor {
    /// Cipher used by `decrypt`.
    algorithm: EncryptionAlgorithm,
    /// Key handed to the cipher.
    key: EncryptionKey,
    /// Nonce supplied by the caller at construction and passed to the cipher.
    nonce: Vec<u8>,
}

impl Decryptor {
    /// Creates a decryptor from the algorithm, key and nonce used for encryption.
    ///
    /// The nonce length is not checked here because this constructor cannot fail; a
    /// nonce of the wrong length makes [`Self::decrypt`] return an error.
    pub fn new(algorithm: EncryptionAlgorithm, key: EncryptionKey, nonce: Vec<u8>) -> Self {
        Self {
            algorithm,
            key,
            nonce,
        }
    }

    /// Decrypts `ciphertext` and returns the plaintext produced by the selected cipher.
    ///
    /// # Errors
    ///
    /// Returns [`CrablockError::DecryptionFailed`] if the stored nonce has the wrong
    /// length for the algorithm or if the cipher rejects the input (for example with
    /// the wrong key, as the tests below exercise). Returns [`CrablockError::Crypto`]
    /// if the cipher cannot be initialised.
    pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
        // The nonce is caller-supplied, and `from_slice` below panics on a length
        // mismatch, so turn a badly sized nonce into a typed error.
        let expected_len = self.algorithm.nonce_size();
        if self.nonce.len() != expected_len {
            return Err(CrablockError::DecryptionFailed(format!(
                "Invalid nonce length for {}: expected {} bytes, got {}",
                self.algorithm.as_str(),
                expected_len,
                self.nonce.len()
            )));
        }

        // Decryption mirrors `encrypt` and returns a typed error when the key or nonce is wrong.
        let plaintext = match self.algorithm {
            EncryptionAlgorithm::Aes256Gcm => {
                let cipher = Aes256Gcm::new_from_slice(&self.key.key)
                    .map_err(|e| CrablockError::Crypto(format!("AES key init failed: {e:?}")))?;
                let nonce = Nonce::<Aes256Gcm>::from_slice(&self.nonce);
                cipher.decrypt(nonce, ciphertext).map_err(|e| {
                    CrablockError::DecryptionFailed(format!("AES decryption failed: {e:?}"))
                })?
            }
            EncryptionAlgorithm::ChaCha20Poly1305 => {
                use chacha20poly1305::aead::Aead as ChaChaAead;
                use chacha20poly1305::aead::KeyInit as ChaChaKeyInit;
                use chacha20poly1305::Nonce as ChaChaNonce;

                let cipher = ChaCha20Poly1305::new_from_slice(&self.key.key)
                    .map_err(|e| CrablockError::Crypto(format!("ChaCha key init failed: {e:?}")))?;
                let nonce = ChaChaNonce::from_slice(&self.nonce);
                cipher.decrypt(nonce, ciphertext).map_err(|e| {
                    CrablockError::DecryptionFailed(format!("ChaCha decryption failed: {e:?}"))
                })?
            }
        };

        Ok(plaintext)
    }
}

/// Returns the SHA-256 digest of `data` as a lowercase hex string.
pub fn compute_sha256(data: &[u8]) -> String {
    // Hex is used for the stored form because it is easy to log and compare.
    let digest = Sha256::digest(data);
    hex::encode(digest)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Encrypts `plaintext` with a fresh nonce, decrypts it again and returns the result.
    fn roundtrip(algorithm: EncryptionAlgorithm, plaintext: &[u8]) -> Vec<u8> {
        let key = EncryptionKey::generate_random();

        let encryptor = Encryptor::new(algorithm, key.clone());
        let nonce = encryptor.nonce().to_vec();
        let ciphertext = encryptor.encrypt(plaintext).unwrap();

        let decryptor = Decryptor::new(algorithm, key, nonce);
        decryptor.decrypt(&ciphertext).unwrap()
    }

    #[test]
    fn test_aes_encryption_roundtrip() {
        let plaintext = b"Hello, World!";
        let decrypted = roundtrip(EncryptionAlgorithm::Aes256Gcm, plaintext);
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_chacha_encryption_roundtrip() {
        let plaintext = b"Hello, World!";
        let decrypted = roundtrip(EncryptionAlgorithm::ChaCha20Poly1305, plaintext);
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_wrong_key_fails() {
        // Both ciphers must reject a ciphertext when the key does not match.
        for algorithm in [
            EncryptionAlgorithm::Aes256Gcm,
            EncryptionAlgorithm::ChaCha20Poly1305,
        ] {
            let key1 = EncryptionKey::generate_random();
            let key2 = EncryptionKey::generate_random();
            let plaintext = b"Hello, World!";

            let encryptor = Encryptor::new(algorithm, key1);
            let nonce = encryptor.nonce().to_vec();
            let ciphertext = encryptor.encrypt(plaintext).unwrap();

            let decryptor = Decryptor::new(algorithm, key2, nonce);
            let result = decryptor.decrypt(&ciphertext);

            assert!(result.is_err());
        }
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        // Flipping a single bit must make authentication fail.
        let key = EncryptionKey::generate_random();

        let encryptor = Encryptor::new(EncryptionAlgorithm::Aes256Gcm, key.clone());
        let nonce = encryptor.nonce().to_vec();
        let mut ciphertext = encryptor.encrypt(b"Hello, World!").unwrap();
        ciphertext[0] ^= 0x01;

        let decryptor = Decryptor::new(EncryptionAlgorithm::Aes256Gcm, key, nonce);
        assert!(decryptor.decrypt(&ciphertext).is_err());
    }

    #[test]
    fn test_key_from_hex() {
        let key = EncryptionKey::generate_random();
        let encoded = hex::encode(key.key);

        let parsed = EncryptionKey::from_hex(&encoded).unwrap();
        assert_eq!(parsed.key, key.key);

        // Too short, and not hex at all.
        assert!(EncryptionKey::from_hex("abcd").is_err());
        assert!(EncryptionKey::from_hex("not hex").is_err());
    }

    #[test]
    fn test_algorithm_from_str() {
        for algorithm in [
            EncryptionAlgorithm::Aes256Gcm,
            EncryptionAlgorithm::ChaCha20Poly1305,
        ] {
            // The canonical name must always parse back to the same variant.
            let parsed: EncryptionAlgorithm = algorithm.as_str().parse().unwrap();
            assert_eq!(parsed, algorithm);
        }

        let parsed: EncryptionAlgorithm = "AES-256-GCM".parse().unwrap();
        assert_eq!(parsed, EncryptionAlgorithm::Aes256Gcm);
        assert!("rot13".parse::<EncryptionAlgorithm>().is_err());
    }

    #[test]
    fn test_sha256() {
        // Well-known SHA-256 test vector for the ASCII string "hello".
        let hash = compute_sha256(b"hello");
        assert_eq!(
            hash,
            "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
        );
    }
}