use super::{DecryptError, RetrieveKeyPayload};
use recipher::key::Iv;
use serde::{Deserialize, Serialize};
use std::{
borrow::Cow,
convert::Infallible,
fmt::{Debug, Display},
};
use thiserror::Error;
use uuid::Uuid;
use zerokms_protocol::Context;
/// An error that carries nothing but a human-readable message.
///
/// Its `Display` output is exactly `message`, as set by the `#[error("{message}")]` attribute.
/// NOTE(review): no use is visible in this file; presumably intended as a
/// `Decryptable::Error` for implementors — confirm against callers.
#[derive(Debug, Error)]
#[error("{message}")]
pub struct DecryptableError {
/// Text reported verbatim by this error's `Display` implementation.
pub message: String,
}
pub trait Decryptable: Debug {
type Error: std::error::Error;
fn keyset_id(&self) -> Option<Uuid>;
fn retrieve_key_payload<'a>(&'a self) -> Result<RetrieveKeyPayload<'a>, Self::Error>;
fn into_encrypted_record(self) -> Result<EncryptedRecord, Self::Error>;
}
/// An encrypted value plus the metadata stored alongside it.
///
/// The record can be encoded as CBOR (`to_cbor_bytes` / `to_cbor_hex`) or as MessagePack
/// (`to_mp_bytes` / `to_mp_base85`).
/// NOTE(review): `rmp_serde::to_vec` is documented to encode structs positionally. Reordering
/// or inserting fields here would then break previously written MessagePack payloads —
/// confirm before changing the field layout.
#[cfg_attr(test, derive(PartialEq, Eq))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedRecord {
/// Initialisation vector, serialized as a byte string via `serde_bytes`.
#[serde(with = "serde_bytes")]
pub iv: Iv,
/// Encrypted bytes, serialized as a byte string via `serde_bytes`.
#[serde(with = "serde_bytes")]
pub ciphertext: Vec<u8>,
/// Tag bytes, serialized as a byte string and passed to `RetrieveKeyPayload::new`.
/// Presumably the AEAD authentication tag — TODO confirm.
#[serde(with = "serde_bytes")]
pub tag: Vec<u8>,
/// Free-form descriptor; `Display` renders an empty descriptor as `Undescribed`.
pub descriptor: String,
/// Optional keyset id.
///
/// On the wire this field is named `dataset_id`, not `keyset_id`; the JSON test pins that
/// name. It deserializes to `None` when the key is absent.
#[serde(default, rename = "dataset_id")]
pub keyset_id: Option<Uuid>,
}
impl Decryptable for EncryptedRecord {
type Error = Infallible;
fn keyset_id(&self) -> Option<Uuid> {
self.keyset_id
}
fn retrieve_key_payload(&self) -> Result<RetrieveKeyPayload<'_>, Self::Error> {
Ok(RetrieveKeyPayload::new(
self.iv,
self.descriptor.as_str(),
&self.tag,
))
}
fn into_encrypted_record(self) -> Result<EncryptedRecord, Self::Error> {
Ok(self)
}
}
impl Display for EncryptedRecord {
    /// Formats the record as `(descriptor)`, or as `(descriptor, keyset_id: <id>)` when a
    /// keyset id is present. An empty descriptor is shown as `Undescribed`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Borrow instead of cloning: the descriptor is only needed for this call, so
        // allocating a fresh `String` on every format was wasted work.
        let descriptor: &str = match self.descriptor.as_str() {
            "" => "Undescribed",
            described => described,
        };
        match self.keyset_id {
            Some(keyset_id) => write!(f, "({descriptor}, keyset_id: {keyset_id})"),
            None => write!(f, "({descriptor})"),
        }
    }
}
impl EncryptedRecord {
    /// Serializes this record to CBOR bytes.
    ///
    /// # Errors
    ///
    /// Returns [`DecryptError::CborError`] if CBOR encoding fails.
    #[deprecated(since = "0.12.4", note = "Use to_cbor_bytes or to_mp_bytes instead")]
    pub fn to_vec(&self) -> Result<Vec<u8>, DecryptError> {
        self.to_cbor_bytes()
    }

    /// Serializes this record to CBOR bytes.
    ///
    /// # Errors
    ///
    /// Returns [`DecryptError::CborError`] if CBOR encoding fails.
    pub fn to_cbor_bytes(&self) -> Result<Vec<u8>, DecryptError> {
        // `self` is already a reference; passing it directly avoids a needless `&&Self`.
        serde_cbor::to_vec(self).map_err(DecryptError::CborError)
    }

    /// Serializes this record to MessagePack bytes with `rmp_serde::to_vec`.
    ///
    /// # Errors
    ///
    /// Returns [`DecryptError::SerializeError`] if MessagePack encoding fails.
    pub fn to_mp_bytes(&self) -> Result<Vec<u8>, DecryptError> {
        rmp_serde::to_vec(self).map_err(DecryptError::SerializeError)
    }

    /// Serializes this record to CBOR and hex-encodes the bytes.
    ///
    /// # Errors
    ///
    /// Same as [`EncryptedRecord::to_cbor_bytes`].
    #[deprecated(since = "0.12.4", note = "Use to_cbor_hex or to_mp_base85 instead")]
    pub fn to_hex(&self) -> Result<String, DecryptError> {
        self.to_cbor_hex()
    }

    /// Serializes this record to CBOR and hex-encodes the bytes with `hex::encode`.
    ///
    /// # Errors
    ///
    /// Same as [`EncryptedRecord::to_cbor_bytes`].
    pub fn to_cbor_hex(&self) -> Result<String, DecryptError> {
        self.to_cbor_bytes().map(hex::encode)
    }

    /// Serializes this record to MessagePack and base85-encodes the bytes.
    ///
    /// # Errors
    ///
    /// Same as [`EncryptedRecord::to_mp_bytes`].
    pub fn to_mp_base85(&self) -> Result<String, DecryptError> {
        self.to_mp_bytes().map(|v| base85::encode(&v))
    }

    /// Parses a record from CBOR bytes.
    ///
    /// # Errors
    ///
    /// Returns [`DecryptError::CborError`] if the bytes are not a valid CBOR record.
    #[deprecated(
        since = "0.12.4",
        note = "Use from_cbor_bytes or from_mp_bytes instead"
    )]
    pub fn from_slice(bytes: &[u8]) -> Result<Self, DecryptError> {
        Self::from_cbor_bytes(bytes)
    }

    /// Parses a record from CBOR bytes.
    ///
    /// # Errors
    ///
    /// Returns [`DecryptError::CborError`] if the bytes are not a valid CBOR record.
    pub fn from_cbor_bytes(bytes: &[u8]) -> Result<Self, DecryptError> {
        serde_cbor::from_slice(bytes).map_err(DecryptError::CborError)
    }

    /// Parses a record from MessagePack bytes.
    ///
    /// # Errors
    ///
    /// Returns [`DecryptError::DeserializeError`] if the bytes are not a valid
    /// MessagePack record.
    pub fn from_mp_bytes(bytes: &[u8]) -> Result<Self, DecryptError> {
        rmp_serde::from_slice(bytes).map_err(DecryptError::DeserializeError)
    }

    /// Parses a record from hex-encoded CBOR.
    ///
    /// # Errors
    ///
    /// Returns [`DecryptError::HexDecodingError`] for invalid hex. Otherwise behaves as
    /// [`EncryptedRecord::from_cbor_bytes`].
    #[deprecated(since = "0.12.4", note = "Use from_cbor_hex or from_mp_base85 instead")]
    pub fn from_hex(hexstr: impl AsRef<[u8]>) -> Result<Self, DecryptError> {
        Self::decode_cbor_hex(hexstr.as_ref())
    }

    /// Parses a record from hex-encoded CBOR.
    ///
    /// # Errors
    ///
    /// Returns [`DecryptError::HexDecodingError`] for invalid hex. Otherwise behaves as
    /// [`EncryptedRecord::from_cbor_bytes`].
    pub fn from_cbor_hex(hexstr: &str) -> Result<Self, DecryptError> {
        Self::decode_cbor_hex(hexstr.as_bytes())
    }

    /// Shared hex-then-CBOR decoding used by both `from_hex` and `from_cbor_hex`.
    fn decode_cbor_hex(hex_input: &[u8]) -> Result<Self, DecryptError> {
        let bytes = hex::decode(hex_input).map_err(DecryptError::HexDecodingError)?;
        Self::from_cbor_bytes(&bytes)
    }

    /// Parses a record from base85-encoded MessagePack.
    ///
    /// # Errors
    ///
    /// A base85 decoding failure is converted into `DecryptError` via `From`. Otherwise
    /// behaves as [`EncryptedRecord::from_mp_bytes`].
    pub fn from_mp_base85(base85str: &str) -> Result<Self, DecryptError> {
        let bytes = base85::decode(base85str)?;
        Self::from_mp_bytes(&bytes)
    }

    /// Wraps this record in a [`WithContext`] holding exactly one lock context.
    pub fn with_context<'a>(self, context: Context) -> WithContext<'a> {
        WithContext {
            context: Cow::Owned(vec![context]),
            record: self,
        }
    }
}
/// A decryptable value (by default an [`EncryptedRecord`]) paired with a list of
/// [`Context`] values.
///
/// The serde field names are `lockContext` and `ciphertext`. A missing `lockContext`
/// deserializes to the `Default` of `Cow<[Context]>`, which is an empty list.
#[derive(Debug, Deserialize, Serialize)]
pub struct WithContext<'context, D = EncryptedRecord> {
/// Contexts attached to the key payload by this type's `Decryptable` impl.
/// `Cow` lets callers either lend a slice or hand over an owned list.
#[serde(default, rename = "lockContext")]
pub context: Cow<'context, [Context]>,
/// The wrapped decryptable value.
#[serde(rename = "ciphertext")]
pub record: D,
}
impl<'context, D: Decryptable> Decryptable for WithContext<'context, D> {
type Error = <D as Decryptable>::Error;
fn keyset_id(&self) -> Option<Uuid> {
self.record.keyset_id()
}
fn retrieve_key_payload<'a>(&'a self) -> Result<RetrieveKeyPayload<'a>, Self::Error> {
self.record
.retrieve_key_payload()
.map(|payload| payload.with_context(self.context.clone()))
}
fn into_encrypted_record(self) -> Result<EncryptedRecord, Self::Error> {
self.record.into_encrypted_record()
}
}
pub mod formats {
pub mod mp_base85 {
use super::super::*;
use serde::Deserialize;
pub fn serialize<S>(ciphertext: &EncryptedRecord, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let s = ciphertext
.to_mp_base85()
.map_err(serde::ser::Error::custom)?;
serializer.serialize_str(&s)
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<EncryptedRecord, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
EncryptedRecord::from_mp_base85(&s).map_err(serde::de::Error::custom)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Total raw size of the record's fields, plus 16 bytes when a keyset UUID is present.
    // No test calls this at the moment. It is kept as a helper for inspecting encoding
    // overhead, which is why `dead_code` is allowed.
    #[allow(dead_code)]
    fn unencoded_bytes_len(record: &EncryptedRecord) -> usize {
        let len =
            record.iv.len() + record.ciphertext.len() + record.tag.len() + record.descriptor.len();
        if record.keyset_id.is_some() {
            len + 16
        } else {
            len
        }
    }

    /// A representative record with a descriptor and a random keyset id.
    fn record() -> EncryptedRecord {
        EncryptedRecord {
            iv: Iv::default(),
            ciphertext: vec![1; 32],
            tag: vec![1; 16],
            descriptor: "users/name".to_string(),
            keyset_id: Some(Uuid::new_v4()),
        }
    }

    #[test]
    fn test_encrypted_record_to_from_cbor() {
        let record = record();
        let bytes = record.to_cbor_bytes().unwrap();
        let record2 = EncryptedRecord::from_cbor_bytes(&bytes).unwrap();
        assert_eq!(record, record2);
    }

    #[test]
    fn test_encrypted_record_to_from_mp() {
        let record = record();
        let bytes = record.to_mp_bytes().unwrap();
        let record2 = EncryptedRecord::from_mp_bytes(&bytes).unwrap();
        assert_eq!(record, record2);
    }

    #[test]
    fn test_encrypted_record_cbor_hex() {
        let record = record();
        let hex = record.to_cbor_hex().unwrap();
        let record2 = EncryptedRecord::from_cbor_hex(&hex).unwrap();
        assert_eq!(record, record2);
    }

    #[test]
    fn test_encrypted_record_mp_base85() {
        let record = record();
        let encoded = record.to_mp_base85().unwrap();
        let record2 = EncryptedRecord::from_mp_base85(&encoded).unwrap();
        assert_eq!(record, record2);
    }

    #[test]
    fn test_json_serialisation() {
        let record = record();
        let json: serde_json::Value = serde_json::to_value(record).unwrap();
        // The keyset id is stored under its legacy wire name.
        assert!(json.get("dataset_id").is_some());
        assert!(json.get("keyset_id").is_none());
    }

    #[test]
    fn test_display() {
        let mut record = record();
        let keyset_id = record
            .keyset_id
            .expect("record() always sets a keyset_id");
        assert_eq!(
            record.to_string(),
            format!("(users/name, keyset_id: {keyset_id})")
        );

        // Empty descriptor without a keyset id falls back to the placeholder text.
        record.descriptor.clear();
        record.keyset_id = None;
        assert_eq!(record.to_string(), "(Undescribed)");
    }
}