cipherstash-client 0.34.1-alpha.1

The official CipherStash SDK
Documentation
use super::{DecryptError, RetrieveKeyPayload};
use recipher::key::Iv;
use serde::{Deserialize, Serialize};
use std::{
    borrow::Cow,
    convert::Infallible,
    fmt::{Debug, Display},
};
use thiserror::Error;
use uuid::Uuid;
use zerokms_protocol::Context;

/// A free-form error that carries only a human-readable message.
///
/// Its `Display` output (generated by `thiserror`) is exactly `message`.
#[derive(Debug, Error)]
#[error("{message}")]
pub struct DecryptableError {
    /// Text rendered verbatim by this error's `Display` implementation.
    pub message: String,
}

pub trait Decryptable: Debug {
    type Error: std::error::Error;

    /// The keyset ID associated with the record, if any.
    fn keyset_id(&self) -> Option<Uuid>;

    /// The payload used to retrieve the key for decryption.
    fn retrieve_key_payload<'a>(&'a self) -> Result<RetrieveKeyPayload<'a>, Self::Error>;

    /// Convert the record into an EncryptedRecord.
    fn into_encrypted_record(self) -> Result<EncryptedRecord, Self::Error>;
}

/// Represents an encrypted record for storage in the database.
///
/// Implements serialization and deserialization, so it can be used with any serde-compatible
/// format. Convenience methods are also provided for CBOR, MessagePack, hex, and Base85
/// encoding.
///
/// NOTE(review): field order and the serde attributes below determine the serialized form.
/// Reordering or renaming fields likely breaks compatibility with already-stored records, so
/// confirm before changing them.
#[cfg_attr(test, derive(PartialEq, Eq))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedRecord {
    /// The record's IV (`recipher::key::Iv`), serialized as a byte string via `serde_bytes`.
    #[serde(with = "serde_bytes")]
    pub iv: Iv,
    /// The encrypted payload bytes, serialized as a byte string via `serde_bytes`.
    #[serde(with = "serde_bytes")]
    pub ciphertext: Vec<u8>,
    /// Tag bytes, serialized via `serde_bytes`. They are also passed to
    /// `RetrieveKeyPayload::new` when requesting the decryption key.
    #[serde(with = "serde_bytes")]
    pub tag: Vec<u8>,
    /// Free-form description of the record (e.g. `users/name`). It is used in the key request
    /// and shown by `Display`, where an empty value is rendered as `Undescribed`.
    pub descriptor: String,
    /// Keyset this record belongs to, if any. It is serialized under the key `dataset_id` and
    /// deserializes to `None` when that key is absent.
    #[serde(default, rename = "dataset_id")]
    pub keyset_id: Option<Uuid>,
}

impl Decryptable for EncryptedRecord {
    // A bare record already holds every part it needs, so none of these methods can fail.
    type Error = Infallible;

    /// Returns the record's own `keyset_id` field unchanged.
    fn keyset_id(&self) -> Option<Uuid> {
        let Self { keyset_id, .. } = self;
        *keyset_id
    }

    /// Builds a key request from a copy of the IV plus the borrowed descriptor and tag.
    fn retrieve_key_payload(&self) -> Result<RetrieveKeyPayload<'_>, Self::Error> {
        let payload = RetrieveKeyPayload::new(self.iv, self.descriptor.as_str(), &self.tag);
        Ok(payload)
    }

    /// The value already is an `EncryptedRecord`, so it is handed back as-is.
    fn into_encrypted_record(self) -> Result<EncryptedRecord, Self::Error> {
        Ok(self)
    }
}

impl Display for EncryptedRecord {
    /// Formats the record as `(descriptor)` or `(descriptor, keyset_id: <id>)`.
    ///
    /// An empty descriptor is rendered as `Undescribed`. The IV, ciphertext and tag are never
    /// included in the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Borrow rather than clone: formatting should not allocate a copy of the descriptor.
        let descriptor: &str = if self.descriptor.is_empty() {
            "Undescribed"
        } else {
            &self.descriptor
        };

        match self.keyset_id {
            Some(keyset_id) => write!(f, "({descriptor}, keyset_id: {keyset_id})"),
            None => write!(f, "({descriptor})"),
        }
    }
}

impl EncryptedRecord {
    #[deprecated(since = "0.12.4", note = "Use to_cbor_bytes or to_mp_bytes instead")]
    /// Serialize the record to a `Vec<u8>` using CBOR encoding
    pub fn to_vec(&self) -> Result<Vec<u8>, DecryptError> {
        self.to_cbor_bytes()
    }

    /// Serialize the record to a `Vec<u8>` using CBOR encoding
    pub fn to_cbor_bytes(&self) -> Result<Vec<u8>, DecryptError> {
        serde_cbor::to_vec(&self).map_err(DecryptError::CborError)
    }

    /// Serialize the record to a `Vec<u8>` using MessagePack encoding
    pub fn to_mp_bytes(&self) -> Result<Vec<u8>, DecryptError> {
        rmp_serde::to_vec(&self).map_err(DecryptError::SerializeError)
    }

    #[deprecated(since = "0.12.4", note = "Use to_cbor_hex or to_mp_base85 instead")]
    /// Serialize the record using CBOR and convert to a hex-encoded string
    pub fn to_hex(&self) -> Result<String, DecryptError> {
        self.to_cbor_hex()
    }

    /// Serialize the record using CBOR and convert to a hex-encoded string
    pub fn to_cbor_hex(&self) -> Result<String, DecryptError> {
        self.to_cbor_bytes().map(hex::encode)
    }

    /// Serialize the record using MessagePack and convert to a base85-encoded string
    pub fn to_mp_base85(&self) -> Result<String, DecryptError> {
        self.to_mp_bytes().map(|v| base85::encode(&v))
    }

    #[deprecated(
        since = "0.12.4",
        note = "Use from_cbor_bytes or from_mp_bytes instead"
    )]
    /// Deserialize a record from a slice of bytes encoded using CBOR
    pub fn from_slice(bytes: &[u8]) -> Result<Self, DecryptError> {
        Self::from_cbor_bytes(bytes)
    }

    /// Deserialize a record from a slice of bytes encoded using CBOR
    pub fn from_cbor_bytes(bytes: &[u8]) -> Result<Self, DecryptError> {
        serde_cbor::from_slice(bytes).map_err(DecryptError::CborError)
    }

    /// Deserialize a record from a slice of bytes encoded using MessagePack
    pub fn from_mp_bytes(bytes: &[u8]) -> Result<Self, DecryptError> {
        rmp_serde::from_slice(bytes).map_err(DecryptError::DeserializeError)
    }

    #[deprecated(since = "0.12.4", note = "Use from_cbor_hex or from_mp_base85 instead")]
    /// Deserialize a record from a hex-encoded string of bytes encoded using CBOR
    pub fn from_hex(hexstr: impl AsRef<[u8]>) -> Result<Self, DecryptError> {
        let bytes = hex::decode(hexstr).map_err(DecryptError::HexDecodingError)?;

        Self::from_cbor_bytes(&bytes)
    }

    /// Deserialize a record from a hex-encoded string of bytes encoded using CBOR
    pub fn from_cbor_hex(hexstr: &str) -> Result<Self, DecryptError> {
        let bytes = hex::decode(hexstr).map_err(DecryptError::HexDecodingError)?;

        Self::from_cbor_bytes(&bytes)
    }

    /// Deserialize a record from a base85-encoded string of bytes encoded using MessagePack
    pub fn from_mp_base85(base85str: &str) -> Result<Self, DecryptError> {
        let bytes = base85::decode(base85str)?;
        Self::from_mp_bytes(&bytes)
    }

    /// Add a context to the record to prepare it for decryption.
    pub fn with_context<'a>(self, context: Context) -> WithContext<'a> {
        WithContext {
            context: Cow::Owned(vec![context]),
            record: self,
        }
    }
}

/// A record (by default an [`EncryptedRecord`]) paired with the context entries that are
/// attached to its key request.
///
/// The serde field names are `lockContext` and `ciphertext`. A missing `lockContext`
/// deserializes to the default, i.e. an empty context.
#[derive(Debug, Deserialize, Serialize)]
pub struct WithContext<'context, D = EncryptedRecord> {
    /// Context entries; either borrowed or owned via `Cow`.
    #[serde(default, rename = "lockContext")]
    pub context: Cow<'context, [Context]>,
    /// The wrapped record, serialized under the name `ciphertext`.
    #[serde(rename = "ciphertext")]
    pub record: D,
}

impl<'context, D: Decryptable> Decryptable for WithContext<'context, D> {
    type Error = <D as Decryptable>::Error;

    fn keyset_id(&self) -> Option<Uuid> {
        self.record.keyset_id()
    }

    fn retrieve_key_payload<'a>(&'a self) -> Result<RetrieveKeyPayload<'a>, Self::Error> {
        self.record
            .retrieve_key_payload()
            .map(|payload| payload.with_context(self.context.clone()))
    }

    fn into_encrypted_record(self) -> Result<EncryptedRecord, Self::Error> {
        self.record.into_encrypted_record()
    }
}

pub mod formats {
    pub mod mp_base85 {
        use super::super::*;
        use serde::Deserialize;

        pub fn serialize<S>(ciphertext: &EncryptedRecord, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            let s = ciphertext
                .to_mp_base85()
                .map_err(serde::ser::Error::custom)?;
            serializer.serialize_str(&s)
        }

        pub fn deserialize<'de, D>(deserializer: D) -> Result<EncryptedRecord, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            let s = String::deserialize(deserializer)?;
            EncryptedRecord::from_mp_base85(&s).map_err(serde::de::Error::custom)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of the raw field sizes of `record`. A present keyset ID counts as its 16 UUID
    /// bytes. Used to sanity-check the raw sizes quoted in the test comments below.
    // Not called by any test; kept as a helper for manual size checks, hence the allow.
    #[allow(dead_code)]
    fn unencoded_bytes_len(record: &EncryptedRecord) -> usize {
        let keyset_id_len = record.keyset_id.map_or(0, |_| 16);
        record.iv.len()
            + record.ciphertext.len()
            + record.tag.len()
            + record.descriptor.len()
            + keyset_id_len
    }

    /// A representative record with a descriptor and a random keyset ID.
    fn record() -> EncryptedRecord {
        EncryptedRecord {
            iv: Iv::default(),
            ciphertext: vec![1; 32],
            tag: vec![1; 16],
            descriptor: "users/name".to_string(),
            keyset_id: Some(Uuid::new_v4()),
        }
    }

    #[test]
    fn test_encrypted_record_to_from_cbor() {
        // 90 bytes raw expands to 137 bytes encoded
        let record = record();
        let bytes = record.to_cbor_bytes().unwrap();
        let record2 = EncryptedRecord::from_cbor_bytes(&bytes).unwrap();

        assert_eq!(record, record2);
    }

    #[test]
    fn test_encrypted_record_to_from_mp() {
        // 90 bytes raw expands to 100 bytes encoded
        let record = record();
        let bytes = record.to_mp_bytes().unwrap();
        let record2 = EncryptedRecord::from_mp_bytes(&bytes).unwrap();

        assert_eq!(record, record2);
    }

    #[test]
    fn test_encrypted_record_cbor_hex() {
        // 90 bytes raw expands to 274 bytes encoded! ~200% expansion
        let record = record();
        let hex = record.to_cbor_hex().unwrap();
        let record2 = EncryptedRecord::from_cbor_hex(&hex).unwrap();

        assert_eq!(record, record2);
    }

    #[test]
    fn test_encrypted_record_mp_base85() {
        // 90 bytes raw expands to 125 bytes encoded! ~38% expansion
        let record = record();
        let encoded = record.to_mp_base85().unwrap();
        let record2 = EncryptedRecord::from_mp_base85(&encoded).unwrap();

        assert_eq!(record, record2);
    }

    #[test]
    fn test_json_serialisation() {
        let record = record();
        let json: serde_json::Value = serde_json::to_value(record).unwrap();

        assert!(json.get("dataset_id").is_some());
        assert!(json.get("keyset_id").is_none());
    }

    #[test]
    fn test_display() {
        let record = record();
        let keyset_id = record.keyset_id.unwrap();
        assert_eq!(
            record.to_string(),
            format!("(users/name, keyset_id: {keyset_id})")
        );

        let undescribed = EncryptedRecord {
            descriptor: String::new(),
            keyset_id: None,
            ..record
        };
        assert_eq!(undescribed.to_string(), "(Undescribed)");
    }

    #[test]
    fn test_mp_base85_serde_adapter() {
        let record = record();
        let value =
            formats::mp_base85::serialize(&record, serde_json::value::Serializer).unwrap();

        // The adapter emits exactly the string produced by `to_mp_base85`.
        assert_eq!(
            value,
            serde_json::Value::String(record.to_mp_base85().unwrap())
        );

        let record2 = formats::mp_base85::deserialize(value).unwrap();
        assert_eq!(record, record2);
    }
}