antimatter 2.0.13

antimatter.io Rust library for data control
Documentation
use crate::capsule::common::{CapsuleHeader, HookInfo};
use aes_gcm::{
    aead::{Aead, AeadCore, KeyInit},
    Aes256Gcm, Key, Nonce,
};
use ciborium::from_reader;
use ciborium::ser::into_writer;
use generic_array::GenericArray;
use serde_tuple::{Deserialize_tuple, Serialize_tuple};
use std::io::Cursor;

use crate::capsule::common::*;

/// The plaintext body of a capsule: capsule-level tags, column definitions,
/// the tabulated data, and hook information.
///
/// The body is serialized as a positional tuple (`Serialize_tuple` /
/// `Deserialize_tuple`) and CBOR-encoded by `Capsule::seal`, so the declaration
/// order of the serialized fields is part of the wire format — do not reorder
/// them.
#[derive(Serialize_tuple, Deserialize_tuple, Clone)]
pub struct CapsuleBody {
    /// Capsule-level tags (initialized empty by `Capsule::new`).
    pub capsule_tags: Vec<CapsuleTag>,
    /// Identifying information for each data column stored in this capsule.
    pub columns: Vec<Column>,
    /// The tabulated data, as a vector of rows of `DataElement`s.
    pub rows: Vec<Vec<DataElement>>,
    /// Hooks used in classifying the data stored in this capsule.
    pub hook_info: Vec<HookInfo>,
    /// Open token set by `SealedCapsule::unseal`; not serialized, so it never
    /// ends up inside the sealed payload.
    #[serde(skip)]
    pub open_token: String,
    /// Read-logging flag set by `SealedCapsule::unseal`; not serialized.
    #[serde(skip)]
    pub disable_read_logging: bool,
}

/// An open capsule: its header plus the plaintext body.
///
/// Created by `Capsule::new` or `SealedCapsule::unseal`, and turned into a
/// `SealedCapsule` by `Capsule::seal`. Field order is part of the tuple
/// serialization format — do not reorder.
#[derive(Serialize_tuple, Deserialize_tuple, Clone)]
pub struct Capsule {
    /// Key and ownership information; copied unencrypted into the sealed form.
    pub header: CapsuleHeader,
    /// The plaintext contents that `seal` encrypts.
    pub body: CapsuleBody,
}

/// A capsule whose body has been encrypted by `Capsule::seal`.
///
/// The header is stored as-is (`seal` clones it without encryption); this is
/// where `encrypted_dek` and `key_id` live. Field order is part of the tuple
/// serialization format — do not reorder.
#[derive(Serialize_tuple, Deserialize_tuple)]
pub struct SealedCapsule {
    /// Unencrypted capsule header.
    pub header: CapsuleHeader,
    /// Encrypted body: the nonce followed by the AES-256-GCM output, as
    /// produced by `encrypt_aes_gcm_256`. Serialized through `serde_bytes`
    /// (as a byte string rather than a sequence of integers).
    #[serde(with = "serde_bytes")]
    pub body: Vec<u8>,
}

impl Capsule {
    /// Creates a new `Capsule` object. A capsule is the container that stores
    /// the encrypted data, its encryption key information, who the capsule
    /// belongs to, and other additional information used for correctly reading
    /// back encapsulated data and applying policy correctly.
    ///
    /// Span tags on every data element are clamped so they never point outside
    /// that element's data: `start` is limited to `data.len() - 1`, and `end`
    /// is limited to `data.len()` but never below `start`. Tags on empty data
    /// become `0..0`.
    ///
    /// # Arguments
    ///
    /// * `encrypted_dek` - The encrypted data encryption key used to decrypt
    ///        this capsule.
    /// * `key_id` - The key encryption key used to decrypt the encrypted data
    ///        key stored here.
    /// * `domain_id` - The domain that owns the capsule.
    /// * `capsule_id` - The unique identifier for this capsule.
    /// * `columns` - Identifying information for each data column stored in
    ///        this capsule.
    /// * `data_elements` - The tabulated data stored in this capsule.
    /// * `hook_info` - A vector of hooks used in classifying the data stored in
    ///        this capsule.
    ///
    /// # Returns
    ///
    /// A new `Capsule` with no capsule-level tags, an empty open token and read
    /// logging enabled.
    pub fn new(
        encrypted_dek: Vec<u8>,
        key_id: u64,
        domain_id: String,
        capsule_id: String,
        columns: Vec<Column>,
        mut data_elements: Vec<Vec<DataElement>>,
        hook_info: Vec<HookInfo>,
    ) -> Self {
        // Sanity-check the span tags so they do not fall out of bounds; tags
        // that do are truncated to fit rather than rejected.
        for data_element in data_elements.iter_mut().flatten() {
            let data_len = data_element.data.len();
            for tag in data_element.tags.iter_mut() {
                if data_len == 0 {
                    tag.start = 0;
                    tag.end = 0;
                } else {
                    tag.start = tag.start.min(data_len - 1);
                    // Uses the already-clamped `start` so the span is never
                    // inverted.
                    tag.end = tag.end.min(data_len).max(tag.start);
                }
            }
        }

        Capsule {
            header: CapsuleHeader {
                encrypted_dek,
                key_id,
                domain_id,
                capsule_id,
                disaster_recovery_token: None,
            },
            body: CapsuleBody {
                capsule_tags: vec![],
                columns,
                rows: data_elements,
                hook_info,
                open_token: "".to_string(),
                disable_read_logging: false,
            },
        }
    }

    /// Seals a `Capsule`, encrypting its body and preparing it to be written
    /// out.
    ///
    /// The body is CBOR-encoded and encrypted with AES-256-GCM. The header is
    /// copied into the `SealedCapsule` unencrypted.
    ///
    /// NOTE(review): takes `&mut self` although the capsule is not modified;
    /// kept as-is for API compatibility.
    ///
    /// # Arguments
    ///
    /// * `dek` - The data encryption key to encrypt the capsule with.
    /// * `nonce` - The nonce to use with the AEAD. AES-GCM requires that a
    ///        nonce is never reused with the same key.
    ///
    /// # Returns
    ///
    /// A `SealedCapsule` object.
    ///
    /// # Errors
    ///
    /// * `CapsuleError::CBOREncodeFailed` if the body cannot be CBOR-encoded.
    /// * `CapsuleError::EncryptionFailure` if encryption fails (for example a
    ///   DEK that is not 32 bytes long).
    pub fn seal(
        &mut self,
        dek: Vec<u8>,
        nonce: &[u8; NONCE_SIZE],
    ) -> Result<SealedCapsule, CapsuleError> {
        // `Vec<u8>` implements `Write` itself, so no `Cursor` is needed.
        let mut encoded_capsule = Vec::new();
        into_writer(&self.body, &mut encoded_capsule)
            .map_err(|e| CapsuleError::CBOREncodeFailed(e.to_string()))?;

        // TODO: support other types of encryption
        let ciphertext = encrypt_aes_gcm_256(dek, nonce, encoded_capsule)?;

        Ok(SealedCapsule {
            header: self.header.clone(),
            body: ciphertext,
        })
    }
}

impl SealedCapsule {
    /// Given the correct data encryption key, decrypts this `SealedCapsule`
    /// and returns the opened `Capsule`.
    ///
    /// # Arguments
    ///
    /// * `dek` - The data encryption key.
    /// * `open_token` - The open token used by `Session` to open this capsule;
    ///        stored on the returned capsule's body.
    /// * `disable_read_logging` - Disabling read logging reduces network
    ///        traffic to the logging servers, thus improving performance.
    ///
    /// # Returns
    ///
    /// A decrypted `Capsule` carrying a copy of this capsule's header.
    ///
    /// # Errors
    ///
    /// * `CapsuleError::DecryptionFailure` if the body cannot be decrypted
    ///   (wrong key, truncated or tampered ciphertext).
    /// * `CapsuleError::CBORDecodeFailed` if the decrypted bytes are not a
    ///   valid CBOR-encoded `CapsuleBody`.
    pub fn unseal(
        &self,
        dek: Vec<u8>,
        open_token: String,
        disable_read_logging: bool,
    ) -> Result<Capsule, CapsuleError> {
        let encoded = decrypt_aes_gcm_256(dek, &self.body)
            .map_err(|e| CapsuleError::DecryptionFailure(e.to_string()))?;
        let mut body: CapsuleBody = from_reader(Cursor::new(encoded))
            .map_err(|e| CapsuleError::CBORDecodeFailed(e.to_string()))?;

        // These fields are `#[serde(skip)]` on `CapsuleBody`, so they are not
        // part of the sealed payload and are filled in from the arguments.
        body.open_token = open_token;
        body.disable_read_logging = disable_read_logging;

        Ok(Capsule {
            header: self.header.clone(),
            body,
        })
    }

    /// Returns the size, in bytes, of this capsule's encrypted body.
    ///
    /// The header is not included in this count.
    pub fn len(&self) -> usize {
        self.body.len()
    }

    /// Returns `true` if the encrypted body contains no bytes.
    pub fn is_empty(&self) -> bool {
        self.body.is_empty()
    }
}

/// Decrypts a payload produced by [`encrypt_aes_gcm_256`] using AES-256-GCM.
///
/// The input is laid out as the `NONCE_SIZE`-byte nonce followed by the AEAD
/// output (ciphertext plus authentication tag).
///
/// # Errors
///
/// Returns `aes_gcm::Error` if:
/// * the key is not 32 bytes long,
/// * the input is shorter than a nonce, or
/// * authentication fails (wrong key or tampered data).
fn decrypt_aes_gcm_256(key: Vec<u8>, ciphertext: &[u8]) -> Result<Vec<u8>, aes_gcm::Error> {
    // `Key::from_slice` panics on a wrong-length slice. Reject bad keys as an
    // ordinary decryption error so `unseal` reports `DecryptionFailure`
    // instead of aborting the caller.
    if key.len() != 32 {
        return Err(aes_gcm::Error);
    }
    if ciphertext.len() < NONCE_SIZE {
        return Err(aes_gcm::Error);
    }

    let (nonce_bytes, sealed) = ciphertext.split_at(NONCE_SIZE);
    let nonce = Nonce::from_slice(nonce_bytes);
    let key = Key::<Aes256Gcm>::from_slice(&key);
    let cipher = Aes256Gcm::new(key);

    cipher.decrypt(nonce, sealed)
}

/// Encrypts `plaintext` with AES-256-GCM and returns `nonce || ciphertext`.
///
/// # Arguments
///
/// * `key` - The 32-byte AES-256 key.
/// * `nonce` - The nonce for this encryption. AES-GCM requires that a nonce
///        is never reused with the same key.
/// * `plaintext` - The data to encrypt.
///
/// # Returns
///
/// The nonce followed by the AEAD output (ciphertext plus authentication
/// tag). This is the layout `decrypt_aes_gcm_256` expects.
///
/// # Errors
///
/// Returns `CapsuleError::EncryptionFailure` if the plaintext is too large,
/// the key is not 32 bytes long, or the cipher reports a failure.
pub fn encrypt_aes_gcm_256(
    key: Vec<u8>,
    nonce: &[u8; NONCE_SIZE],
    plaintext: Vec<u8>,
) -> Result<Vec<u8>, CapsuleError> {
    // Upper bound on the plaintext size. It is kept below the GCM limit of
    // 2^39 - 256 bits (2^36 - 32 bytes). The bound is computed in `u64`:
    // evaluated as `usize` it overflows on 32-bit targets, which is a
    // compile-time error there.
    const MAX_PLAINTEXT_LEN: u64 = 64 * 1024 * 1024 * 1024 - 256;

    // Widening `usize` -> `u64` is lossless on every target Rust supports.
    if plaintext.len() as u64 >= MAX_PLAINTEXT_LEN {
        return Err(CapsuleError::EncryptionFailure(
            "Plaintext exceeds maximum encryption size".to_string(),
        ));
    }
    if key.len() != 32 {
        return Err(CapsuleError::EncryptionFailure(
            "Invalid key length for AES-256 GCM".to_string(),
        ));
    }

    let nonce_ga: &GenericArray<u8, <Aes256Gcm as AeadCore>::NonceSize> =
        GenericArray::from_slice(nonce);
    let key = Key::<Aes256Gcm>::from_slice(&key);
    let cipher = Aes256Gcm::new(key);
    let ciphertext = cipher
        .encrypt(nonce_ga, plaintext.as_slice())
        .map_err(|e| CapsuleError::EncryptionFailure(format!("failed to encrypt: {}", e)))?;

    // Prepend the nonce so the decryptor can recover it from the payload.
    let mut result = Vec::with_capacity(nonce.len() + ciphertext.len());
    result.extend_from_slice(nonce);
    result.extend_from_slice(&ciphertext);

    Ok(result)
}

#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::capsule::common::BASE58_CHARSET;
    use rand::{distributions::Alphanumeric, Rng};

    /// Returns a random string of `len` characters drawn from `BASE58_CHARSET`.
    pub fn generate_random_base58_string(len: usize) -> String {
        let charset = BASE58_CHARSET.as_bytes();
        let mut rng = rand::thread_rng();

        (0..len)
            .map(|_| {
                let idx = rng.gen_range(0..charset.len());
                charset[idx] as char
            })
            .collect()
    }

    /// Returns a random alphanumeric string of `len` characters.
    pub fn generate_random_string(len: usize) -> String {
        rand::thread_rng()
            .sample_iter(&Alphanumeric)
            .take(len)
            .map(char::from)
            .collect()
    }

    /// Returns a vector of `length` random bytes.
    pub fn generate_random_vec(length: usize) -> Vec<u8> {
        let mut rng = rand::thread_rng();
        let mut vec = vec![0u8; length];
        rng.fill(&mut vec[..]);
        vec
    }

    /// Returns a random `NONCE_SIZE`-byte nonce for use with `Capsule::seal`.
    pub fn generate_random_nonce() -> [u8; NONCE_SIZE] {
        let mut rng = rand::thread_rng();
        let mut result: [u8; NONCE_SIZE] = [0; NONCE_SIZE];
        rng.fill(&mut result);
        result
    }

    /// Builds a capsule with no columns, rows or hooks, and random domain
    /// and capsule identifiers.
    fn create_test_capsule() -> Capsule {
        let encrypted_dek = "abcdef".as_bytes().to_vec();
        let key_id = 1;
        let domain_id = format!("dm-{}", generate_random_base58_string(11));
        let capsule_id = format!("ca-{}", generate_random_base58_string(22));
        let column_defs = vec![];
        let data_elements = vec![];

        Capsule::new(
            encrypted_dek,
            key_id,
            domain_id,
            capsule_id,
            column_defs,
            data_elements,
            vec![],
        )
    }

    /// Builds a manually-sourced unary span tag with the given bounds.
    fn unary_span_tag(name: &str, start: usize, end: usize) -> SpanTag {
        SpanTag {
            tag: CapsuleTag {
                name: name.to_string(),
                tag_type: TagType::Unary,
                value: "".to_string(),
                source: "manual".to_string(),
                hook_version: (1, 0, 0),
            },
            start,
            end,
        }
    }

    /// Builds a capsule holding a single data element.
    fn capsule_with_single_element(data_element: DataElement) -> Capsule {
        Capsule::new(
            vec![],
            0,
            "domain_id".to_string(),
            "capsule_id".to_string(),
            vec![],
            vec![vec![data_element]],
            vec![],
        )
    }

    #[test]
    fn create_test_capsule_adjusted_spantags() {
        let tags = vec![
            unary_span_tag("Sample too long", 2, 20),
            unary_span_tag("Sample out of bounds", 15, 20),
        ];

        let capsule = capsule_with_single_element(DataElement {
            data: vec![0, 1, 2, 3, 4], // Data of length 5
            tags,
        });

        let adjusted_0 = &capsule.body.rows[0][0].tags[0];
        let adjusted_1 = &capsule.body.rows[0][0].tags[1];
        assert_eq!(adjusted_0.start, 2);
        assert_eq!(adjusted_0.end, 5);
        assert_eq!(adjusted_1.start, 4);
        assert_eq!(adjusted_1.end, 5);
    }

    #[test]
    fn create_test_capsule_empty_data_spantags() {
        let capsule = capsule_with_single_element(DataElement {
            data: vec![],
            tags: vec![unary_span_tag("Sample on empty data", 3, 7)],
        });

        let adjusted = &capsule.body.rows[0][0].tags[0];
        assert_eq!(adjusted.start, 0);
        assert_eq!(adjusted.end, 0);
    }

    pub fn create_sealed_capsule() -> SealedCapsule {
        create_test_capsule()
            .seal(vec![0; 32], &generate_random_nonce())
            .unwrap()
    }

    // Test sealing and then unsealing a Capsule
    #[test]
    fn test_seal_unseal_capsule() {
        let mut capsule = create_test_capsule();
        let dek = vec![0; 32];

        // Seal the capsule
        let sealed_capsule = capsule
            .seal(dek.clone(), &generate_random_nonce())
            .expect("Sealing failed");
        assert!(!sealed_capsule.body.is_empty());

        // Unseal the capsule
        let open_token = "token".to_string();
        let unsealed_capsule = sealed_capsule
            .unseal(dek, open_token.clone(), true)
            .expect("Unsealing failed");
        assert_eq!(unsealed_capsule.header.domain_id, capsule.header.domain_id);
        assert_eq!(
            unsealed_capsule.header.capsule_id,
            capsule.header.capsule_id
        );
        assert_eq!(unsealed_capsule.body.open_token, open_token);
        assert!(unsealed_capsule.body.disable_read_logging);
    }

    // Unsealing with a different (correctly sized) key must fail cleanly.
    #[test]
    fn test_unseal_with_wrong_key_fails() {
        // `create_sealed_capsule` seals with an all-zero key.
        let sealed_capsule = create_sealed_capsule();
        let result = sealed_capsule.unseal(vec![1; 32], "token".to_string(), false);
        assert!(matches!(result, Err(CapsuleError::DecryptionFailure(_))));
    }
}