use super::secret::Secret;
use aead::{generic_array::typenum::Unsigned, KeySizeUser, Payload};
use chacha20poly1305::{
aead::{Aead, AeadCore, KeyInit, OsRng},
Key, XChaCha20Poly1305,
};
use cid::Cid;
use co_primitives::{from_cbor, to_cbor, Block, KnownMultiCodec, MultiCodec, MultiCodecError};
use derive_more::From;
use multihash_codetable::{Code, MultihashDigest};
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use std::{cmp::min, collections::BTreeMap, fmt::Debug, mem::take};
/// String passed to `Secret::derive_serect_with_salt` when deriving the secret that encrypts a
/// block secret inside a [`KeySlot`].
pub const BLOCK_KEY_DERIVATION: &str = "co 2023-10-24T10:25:23Z block key derivation v1";
/// String passed to `Secret::derive_serect` when deriving the payload encryption secret from a
/// block secret.
pub const BLOCK_DERIVATION: &str = "co 2023-10-26T14:31:38Z block derivation v1";
/// Numeric value of [`KnownMultiCodec::CoEncryptedBlock`], the codec used for encrypted block CIDs.
pub const BLOCK_MULTICODEC: u64 = KnownMultiCodec::CoEncryptedBlock as u64;
/// Nonce bytes handed to the cipher.
pub type Nonce = Vec<u8>;
/// Salt bytes used when deriving a key slot secret.
pub type Salt = Vec<u8>;
/// Byte type used for encrypted (ciphertext) data.
pub type CipherU8 = u8;
/// Errors returned by the encryption, decryption and block conversion routines in this module.
#[derive(Debug, thiserror::Error)]
pub enum AlgorithmError {
/// The underlying AEAD operation failed (see the `From<aead::Error>` conversion).
#[error("Generic Cipher Error")]
Cipher,
/// A caller-supplied value was unusable, e.g. a nonce or key of the wrong length, a block
/// with an unexpected codec, or a secret that opens none of the key slots.
#[error("Invalid arguments specified")]
InvalidArguments(#[source] anyhow::Error),
/// Stored data could not be decoded into a valid encrypted block.
#[error("Generic decoding error")]
Decoding,
/// A value could not be serialized.
#[error("Generic encoding error")]
Encoding,
/// Not produced anywhere in this part of the file; presumably reports oversized data.
// NOTE(review): the message reads "to large"; it likely was meant to say "too large".
#[error("Size is to large")]
Size,
}
impl From<aead::Error> for AlgorithmError {
fn from(_: aead::Error) -> Self {
AlgorithmError::Cipher
}
}
impl From<MultiCodecError> for AlgorithmError {
fn from(value: MultiCodecError) -> Self {
AlgorithmError::InvalidArguments(value.into())
}
}
/// Supported encryption algorithms.
///
/// Serialized as its `u8` discriminant, so the numeric values are part of the stored format.
#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr, PartialEq)]
#[repr(u8)]
#[derive(Default)]
pub enum Algorithm {
/// XChaCha20-Poly1305 AEAD; the default algorithm.
#[default]
XChaCha20Poly1305 = 1,
}
impl Algorithm {
    /// Returns the key length in bytes required by this algorithm.
    pub fn key_size(&self) -> usize {
        match self {
            Algorithm::XChaCha20Poly1305 => XChaCha20Poly1305::key_size(),
        }
    }

    /// Returns the nonce length in bytes required by this algorithm.
    pub fn nonce_size(&self) -> usize {
        match self {
            Algorithm::XChaCha20Poly1305 => <XChaCha20Poly1305 as AeadCore>::NonceSize::USIZE,
        }
    }

    /// Returns the length in bytes of the authentication tag produced by this algorithm.
    pub fn tag_size(&self) -> usize {
        match self {
            Algorithm::XChaCha20Poly1305 => <XChaCha20Poly1305 as AeadCore>::TagSize::USIZE,
        }
    }

    /// Generates a fresh random key for this algorithm using the operating system RNG.
    ///
    /// The spelling of the name is part of the public API and is kept as is.
    pub fn generate_serect(&self) -> Secret {
        match self {
            Algorithm::XChaCha20Poly1305 => Secret::new(XChaCha20Poly1305::generate_key(&mut OsRng).to_vec()),
        }
    }

    /// Generates a fresh random nonce for this algorithm using the operating system RNG.
    pub fn generate_nonce(&self) -> Nonce {
        match self {
            Algorithm::XChaCha20Poly1305 => XChaCha20Poly1305::generate_nonce(&mut OsRng).to_vec(),
        }
    }

    /// Encrypts `plaintext` with `secret` and `nonce`, authenticating `aad` alongside it.
    ///
    /// # Errors
    ///
    /// - [`AlgorithmError::InvalidArguments`] if `nonce` is not [`Self::nonce_size`] bytes long or
    ///   `secret` is not [`Self::key_size`] bytes long.
    /// - [`AlgorithmError::Cipher`] if the AEAD implementation reports a failure.
    pub fn encrypt(
        &self,
        secret: &Secret,
        nonce: &Nonce,
        plaintext: &[u8],
        aad: &[u8],
    ) -> Result<Vec<u8>, AlgorithmError> {
        self.validate_key_and_nonce(secret, nonce)?;
        match self {
            Algorithm::XChaCha20Poly1305 => {
                let cipher = XChaCha20Poly1305::new(Key::from_slice(secret.divulge()));
                let payload = Payload { msg: plaintext, aad };
                cipher
                    .encrypt(aead::Nonce::<XChaCha20Poly1305>::from_slice(nonce.as_slice()), payload)
                    .map_err(AlgorithmError::from)
            },
        }
    }

    /// Decrypts `ciphertext` with `secret` and `nonce`, verifying it together with `aad`.
    ///
    /// # Errors
    ///
    /// - [`AlgorithmError::InvalidArguments`] if `nonce` is not [`Self::nonce_size`] bytes long or
    ///   `secret` is not [`Self::key_size`] bytes long.
    /// - [`AlgorithmError::Cipher`] if the AEAD implementation rejects the input.
    pub fn decrypt(
        &self,
        secret: &Secret,
        nonce: &Nonce,
        ciphertext: &[CipherU8],
        aad: &[u8],
    ) -> Result<Vec<u8>, AlgorithmError> {
        self.validate_key_and_nonce(secret, nonce)?;
        match self {
            Algorithm::XChaCha20Poly1305 => {
                let cipher = XChaCha20Poly1305::new(Key::from_slice(secret.divulge()));
                let payload = Payload { msg: ciphertext, aad };
                cipher
                    .decrypt(aead::Nonce::<XChaCha20Poly1305>::from_slice(nonce.as_slice()), payload)
                    .map_err(AlgorithmError::from)
            },
        }
    }

    /// Checks the nonce and key lengths shared by [`Self::encrypt`] and [`Self::decrypt`].
    ///
    /// The nonce is checked first, so a call with both lengths wrong reports "nonce size".
    /// Checking up front matters because `from_slice` on a wrongly sized slice would panic.
    fn validate_key_and_nonce(&self, secret: &Secret, nonce: &[u8]) -> Result<(), AlgorithmError> {
        if self.nonce_size() != nonce.len() {
            return Err(AlgorithmError::InvalidArguments(anyhow::anyhow!("nonce size")));
        }
        if self.key_size() != secret.divulge().len() {
            return Err(AlgorithmError::InvalidArguments(anyhow::anyhow!("key size")));
        }
        Ok(())
    }
}
/// Format version stored in a [`Header`]; serialized as its `u8` discriminant.
#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr, PartialEq)]
#[repr(u8)]
pub enum EncryptionVersion {
/// The only version defined here.
V1 = 1,
}
/// An encrypted block: a [`Header`] describing how to decrypt, plus the encrypted payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedBlock {
/// Version, algorithm, key slots and payload nonce; serialized under key `h`.
#[serde(rename = "h")]
pub header: Header,
/// Encrypted serialized [`BlockPayload`], inline or split into blocks; serialized under key `d`.
#[serde(rename = "d")]
pub payload: EncryptedData,
}
impl EncryptedBlock {
    /// Encrypts `block` under `secret`, using a freshly generated random block secret.
    ///
    /// # Errors
    ///
    /// See [`Self::encrypt_with_block_secret`].
    pub fn encrypt(
        algorithm: Algorithm,
        secret: &Secret,
        block: impl Into<BlockPayload>,
    ) -> Result<EncryptedBlock, AlgorithmError> {
        let block_secret = algorithm.generate_serect();
        Self::encrypt_with_block_secret(algorithm, secret, &block_secret, block)
    }

    /// Encrypts `block` using the given `block_secret`.
    ///
    /// The payload is encrypted with a secret derived from `block_secret`, while `block_secret`
    /// itself is stored in a single [`KeySlot`] encrypted under `secret`. The header's
    /// [`Header::aad`] bytes are authenticated together with the payload.
    ///
    /// # Errors
    ///
    /// - [`AlgorithmError::Encoding`] if the payload cannot be serialized.
    /// - Any error from [`KeySlot::new`] or [`Algorithm::encrypt`].
    pub fn encrypt_with_block_secret(
        algorithm: Algorithm,
        secret: &Secret,
        block_secret: &Secret,
        block: impl Into<BlockPayload>,
    ) -> Result<EncryptedBlock, AlgorithmError> {
        let block: BlockPayload = block.into();
        let data_secret = block_secret.derive_serect(BLOCK_DERIVATION);
        let key_slot = KeySlot::new(algorithm, secret, block_secret)?;
        let header = Header::new(algorithm, vec![key_slot]);
        let aad = header.aad();
        let data = block.to_bytes().map_err(|_e| AlgorithmError::Encoding)?;
        let ciphertext = header
            .algorithm
            .encrypt(&data_secret, &header.nonce, data.as_slice(), aad.as_slice())?;
        Ok(Self { header, payload: ciphertext.into() })
    }

    /// Decrypts the payload with `secret` and decodes it into a [`BlockPayload`].
    ///
    /// The payload must be inline; see [`EncryptedData::try_inline_blocks`] for payloads that
    /// were split into separate blocks.
    ///
    /// # Errors
    ///
    /// - [`AlgorithmError::InvalidArguments`] if no key slot can be opened with `secret`, if the
    ///   payload is not inline, or if the decrypted bytes cannot be decoded.
    /// - Any error from [`Algorithm::decrypt`].
    pub fn block(&self, secret: &Secret) -> Result<BlockPayload, AlgorithmError> {
        // `ok_or_else` keeps the error (an `anyhow::Error`) from being built on the success path.
        let block_secret = self
            .header
            .block_secret(secret)
            .ok_or_else(|| AlgorithmError::InvalidArguments(anyhow::anyhow!("key")))?;
        let aad = self.header.aad();
        let data = self
            .payload
            .inline()
            .ok_or_else(|| AlgorithmError::InvalidArguments(anyhow::anyhow!("Expected inline data")))?;
        let data_plain = self.decrypt_data(&block_secret, data, &aad)?;
        from_cbor(&data_plain).map_err(|err| AlgorithmError::InvalidArguments(err.into()))
    }

    /// Decrypts `data` with the secret derived from `block_secret`, using the header's algorithm
    /// and nonce.
    fn decrypt_data(&self, block_secret: &Secret, data: &[u8], aad: &[u8]) -> Result<Vec<u8>, AlgorithmError> {
        let data_secret = block_secret.derive_serect(BLOCK_DERIVATION);
        self.header.algorithm.decrypt(&data_secret, &self.header.nonce, data, aad)
    }

    /// Returns whether the header passes [`Header::is_valid`]; the payload is not inspected.
    pub fn is_valid(&self) -> bool {
        self.header.is_valid()
    }
}
impl TryInto<Block> for EncryptedBlock {
type Error = AlgorithmError;
fn try_into(self) -> Result<Block, Self::Error> {
let encrypted_data = to_cbor(&self).map_err(|_| AlgorithmError::Encoding)?;
let mh = Code::Blake3_256.digest(&encrypted_data);
let cid = Cid::new_v1(KnownMultiCodec::CoEncryptedBlock.into(), mh);
Ok(Block::new_unchecked(cid, encrypted_data))
}
}
/// Parses a stored [`Block`] back into an [`EncryptedBlock`].
///
/// # Errors
///
/// - [`AlgorithmError::InvalidArguments`] if the block's CID fails the
///   [`KnownMultiCodec::CoEncryptedBlock`] codec check.
/// - [`AlgorithmError::Decoding`] if the data cannot be decoded or the decoded block is not valid.
impl TryFrom<Block> for EncryptedBlock {
    type Error = AlgorithmError;

    fn try_from(value: Block) -> Result<Self, Self::Error> {
        MultiCodec::with_codec(KnownMultiCodec::CoEncryptedBlock, value.cid())?;
        let decoded: EncryptedBlock = match from_cbor(value.data()) {
            Ok(decoded) => decoded,
            Err(_) => return Err(AlgorithmError::Decoding),
        };
        if decoded.is_valid() {
            Ok(decoded)
        } else {
            Err(AlgorithmError::Decoding)
        }
    }
}
/// Encrypted payload of an [`EncryptedBlock`], stored inline or as references to other blocks.
///
/// Serialized untagged: either a byte string or a list of CIDs.
#[derive(Debug, Clone, Serialize, Deserialize, From)]
#[serde(untagged)]
pub enum EncryptedData {
/// The ciphertext bytes themselves.
#[from]
#[serde(with = "serde_bytes")]
Inline(Vec<CipherU8>),
/// CIDs of the blocks that hold the ciphertext, in order (see `fit_into_blocks`).
#[from]
Block(Vec<Cid>),
}
impl EncryptedData {
    /// Returns the ciphertext if it is stored inline, `None` otherwise.
    pub fn inline(&self) -> Option<&[u8]> {
        match self {
            Self::Inline(data) => Some(data),
            Self::Block(_) => None,
        }
    }

    /// Returns the referenced CIDs if the ciphertext is stored in separate blocks, `None` otherwise.
    pub fn blocks(&self) -> Option<&[Cid]> {
        match self {
            Self::Block(cids) => Some(cids),
            Self::Inline(_) => None,
        }
    }

    /// Moves inline ciphertext into separate raw blocks when it does not fit into one block.
    ///
    /// `inline_offset` is the number of additional bytes stored next to the inline data. If the
    /// data plus the offset fits into `max_block_size`, nothing changes and an empty vector is
    /// returned. Otherwise the data is cut into chunks of at most `max_block_size` bytes, `self`
    /// becomes [`Self::Block`] listing the chunk CIDs in order, and the chunk blocks are
    /// returned so the caller can store them.
    ///
    /// Data that is already [`Self::Block`] is left untouched. A `max_block_size` of zero can
    /// never hold any data, so the data is left inline instead of chunking forever.
    pub fn fit_into_blocks(&mut self, max_block_size: usize, inline_offset: Option<usize>) -> Vec<Block> {
        let Self::Inline(inline) = self else {
            return Vec::new();
        };
        let required = inline.len().saturating_add(inline_offset.unwrap_or(0));
        if max_block_size == 0 || max_block_size >= required {
            return Vec::new();
        }
        let data = take(inline);

        // Walk the buffer once; every chunk is copied exactly one time.
        let mut extra_blocks = Vec::with_capacity(data.len().div_ceil(max_block_size));
        let mut rest = data.as_slice();
        while !rest.is_empty() {
            let (chunk, tail) = rest.split_at(min(rest.len(), max_block_size));
            extra_blocks.push(Block::new_data(KnownMultiCodec::Raw, chunk.to_vec()));
            rest = tail;
        }

        *self = Self::Block(extra_blocks.iter().map(|block| *block.cid()).collect());
        extra_blocks
    }

    /// Reassembles inline ciphertext from the supplied `(cid, data)` blocks.
    ///
    /// Inline data is left untouched and `blocks` is not consumed. For [`Self::Block`], the data
    /// of every referenced CID is concatenated in reference order and `self` becomes
    /// [`Self::Inline`]. A CID that is referenced more than once is resolved each time, so
    /// identical chunks need to be supplied only once.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` and leaves `self` unchanged if any referenced CID is missing from
    /// `blocks`.
    pub fn try_inline_blocks(&mut self, blocks: impl IntoIterator<Item = (Cid, Vec<u8>)>) -> Result<(), ()> {
        let Self::Block(cids) = self else {
            return Ok(());
        };
        let available: BTreeMap<Cid, Vec<u8>> = blocks.into_iter().collect();
        // Resolve every reference before touching `self`, so a missing block changes nothing.
        let parts = cids
            .iter()
            .map(|cid| available.get(cid).map(Vec::as_slice).ok_or(()))
            .collect::<Result<Vec<&[u8]>, ()>>()?;
        *self = Self::Inline(parts.concat());
        Ok(())
    }
}
/// Plaintext content of an encrypted block: the original block's CID and data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlockPayload {
/// CID of the original (unencrypted) block; serialized under key `c`.
#[serde(rename = "c")]
pub cid: Cid,
/// CID-to-CID mapping, omitted from the encoding when empty; serialized under key `r`.
// NOTE(review): presumably maps CIDs referenced by `data` to their encrypted counterparts;
// confirm against the code that fills this map.
#[serde(rename = "r", default, skip_serializing_if = "BTreeMap::is_empty")]
pub references: BTreeMap<Cid, Cid>,
/// Data of the original block; serialized as a byte string under key `d`.
#[serde(with = "serde_bytes", rename = "d")]
pub data: Vec<u8>,
}
impl BlockPayload {
pub fn cid(&self) -> &Cid {
&self.cid
}
pub fn to_bytes(&self) -> Result<Vec<u8>, anyhow::Error> {
Ok(to_cbor(self)?)
}
}
impl From<Block> for BlockPayload {
fn from(value: Block) -> Self {
let (cid, data) = value.into_inner();
Self { cid, data, references: Default::default() }
}
}
impl From<BlockPayload> for Block {
fn from(value: BlockPayload) -> Self {
Block::new_unchecked(value.cid, value.data)
}
}
/// Unencrypted header of an [`EncryptedBlock`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Header {
/// Format version; serialized under key `v`.
#[serde(rename = "v")]
pub version: EncryptionVersion,
/// Algorithm used to encrypt the payload; serialized under key `a`.
#[serde(rename = "a")]
pub algorithm: Algorithm,
/// Slots that each hold the block secret in encrypted form; serialized under key `k`.
#[serde(rename = "k")]
pub key_slots: Vec<KeySlot>,
/// Nonce used to encrypt the payload; serialized as a byte string under key `n`.
#[serde(rename = "n", with = "serde_bytes")]
pub nonce: Nonce,
}
impl Header {
    /// Creates a version 1 header for `algorithm` with the given key slots and a fresh random nonce.
    pub fn new(algorithm: Algorithm, key_slots: Vec<KeySlot>) -> Self {
        let nonce = algorithm.generate_nonce();
        Self { version: EncryptionVersion::V1, algorithm, key_slots, nonce }
    }

    /// Returns `true` if the version is V1, the nonce has the length the algorithm requires, and
    /// every key slot is valid.
    pub fn is_valid(&self) -> bool {
        if self.version != EncryptionVersion::V1 {
            return false;
        }
        if self.nonce.len() != self.algorithm.nonce_size() {
            return false;
        }
        self.key_slots.iter().all(|slot| slot.is_valid())
    }

    /// Builds the additional authenticated data for the payload: the version byte, the
    /// algorithm byte, then the nonce.
    pub fn aad(&self) -> Vec<u8> {
        let mut aad = Vec::with_capacity(1 + 1 + self.nonce.len());
        aad.push(self.version as u8);
        aad.push(self.algorithm as u8);
        aad.extend_from_slice(&self.nonce);
        aad
    }

    /// Returns the block secret from the first key slot that `secret` can decrypt, or `None` if
    /// no slot can be opened.
    pub fn block_secret(&self, secret: &Secret) -> Option<Secret> {
        self.key_slots.iter().find_map(|slot| slot.block_secret(secret).ok())
    }

    /// Returns the encoded size in bytes of a header that holds exactly one key slot.
    ///
    /// The terms below follow the order of the struct fields; the per-term byte accounting
    /// (key byte, type/length prefixes) is presumably that of the CBOR map encoding.
    pub fn encoded_size(algorithm: Algorithm) -> usize {
        let field_size = 1;
        let cbor_size = 1;

        let map_prefix = cbor_size;
        let version_entry = 1 + field_size + cbor_size;
        let algorithm_entry = 1 + field_size + cbor_size;
        let key_slots_entry = KeySlot::encoded_size(algorithm) + field_size + cbor_size + cbor_size;
        let nonce_entry = algorithm.nonce_size() + field_size + cbor_size + cbor_size + cbor_size;

        map_prefix + version_entry + algorithm_entry + key_slots_entry + nonce_entry
    }
}
/// Format version stored in a [`KeySlot`]; serialized as its `u8` discriminant.
#[derive(Debug, Clone, Copy, Serialize_repr, Deserialize_repr, PartialEq)]
#[repr(u8)]
pub enum KeySlotVersion {
/// The only version defined here.
V1 = 1,
}
/// A block secret encrypted under a secret derived from a caller-supplied secret and a salt.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct KeySlot {
/// Format version; serialized under key `v`.
#[serde(rename = "v")]
pub version: KeySlotVersion,
/// Algorithm used to encrypt `key`; serialized under key `a`.
#[serde(rename = "a")]
pub algorithm: Algorithm,
/// The encrypted block secret; serialized as a byte string under key `k`.
#[serde(rename = "k", with = "serde_bytes")]
pub key: Vec<CipherU8>,
/// Salt used when deriving the secret that encrypts `key`; serialized under key `s`.
#[serde(rename = "s", with = "serde_bytes")]
pub salt: Salt,
/// Nonce used when encrypting `key`; serialized under key `n`.
#[serde(rename = "n", with = "serde_bytes")]
pub nonce: Nonce,
}
impl KeySlot {
    /// Returns the encoded size in bytes of a key slot for `algorithm`.
    ///
    /// The terms below follow the order of the struct fields; the per-term byte accounting
    /// (key byte, type/length prefixes) is presumably that of the CBOR map encoding.
    pub fn encoded_size(algorithm: Algorithm) -> usize {
        let tag_size = algorithm.tag_size();
        let field_size = 1;
        let cbor_size = 1;

        let map_prefix = cbor_size;
        let version_entry = 1 + field_size + cbor_size;
        let algorithm_entry = 1 + field_size + cbor_size;
        let key_entry = algorithm.key_size() + field_size + tag_size + cbor_size + cbor_size + cbor_size;
        // The salt is generated with `generate_nonce`, hence the nonce size here as well.
        let salt_entry = algorithm.nonce_size() + field_size + cbor_size + cbor_size + cbor_size;
        let nonce_entry = algorithm.nonce_size() + field_size + cbor_size + cbor_size + cbor_size;

        map_prefix + version_entry + algorithm_entry + key_entry + salt_entry + nonce_entry
    }

    /// Creates a key slot that stores `block_secret` encrypted under a secret derived from
    /// `secret` and a fresh random salt.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Algorithm::encrypt`].
    pub fn new(algorithm: Algorithm, secret: &Secret, block_secret: &Secret) -> Result<Self, AlgorithmError> {
        let salt = algorithm.generate_nonce();
        let wrapping_secret = secret.derive_serect_with_salt(BLOCK_KEY_DERIVATION, &salt);
        let nonce = algorithm.generate_nonce();
        let key = algorithm.encrypt(&wrapping_secret, &nonce, block_secret.divulge(), b"")?;
        Ok(Self { version: KeySlotVersion::V1, algorithm, key, salt, nonce })
    }

    /// Returns `true` if the version is V1, the encrypted key is exactly key size plus tag size
    /// long, and the nonce has the length the algorithm requires.
    pub fn is_valid(&self) -> bool {
        if self.version != KeySlotVersion::V1 {
            return false;
        }
        let expected_key_len = self.algorithm.key_size() + self.algorithm.tag_size();
        if self.key.len() != expected_key_len {
            return false;
        }
        self.nonce.len() == self.algorithm.nonce_size()
    }

    /// Decrypts the stored block secret using a secret derived from `secret` and this slot's salt.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Algorithm::decrypt`], e.g. when `secret` does not open this slot.
    pub fn block_secret(&self, secret: &Secret) -> Result<Secret, AlgorithmError> {
        let wrapping_secret = secret.derive_serect_with_salt(BLOCK_KEY_DERIVATION, &self.salt);
        let plain = self
            .algorithm
            .decrypt(&wrapping_secret, &self.nonce, self.key.as_slice(), b"")?;
        Ok(Secret::new(plain))
    }
}
// Unit tests for the algorithm parameters, the header/key slot encoding sizes and the
// encrypt/decrypt round trips.
#[cfg(test)]
mod tests {
use super::{Algorithm, EncryptedBlock, Header, KeySlot};
use crate::crypto::{block::EncryptedData, secret::Secret};
use cid::Cid;
use co_primitives::{from_cbor, to_cbor, Block, BlockSerializer, DefaultParams, KnownMultiCodec, StoreParams};
use std::iter::repeat_n;
// XChaCha20-Poly1305 uses a 32 byte key.
#[test]
fn algorithm_key_size() {
assert_eq!(Algorithm::XChaCha20Poly1305.key_size(), 32);
}
// XChaCha20-Poly1305 uses a 24 byte nonce.
#[test]
fn algorithm_nonce_size() {
assert_eq!(Algorithm::XChaCha20Poly1305.nonce_size(), 24);
}
// A freshly built header with one key slot is valid.
#[test]
fn is_valid() {
let secret = Secret::new(repeat_n(0u8, Algorithm::default().key_size()).collect());
let block_secret = Secret::new(repeat_n(1u8, Algorithm::default().key_size()).collect());
let key_slot = KeySlot::new(Algorithm::default(), &secret, &block_secret).unwrap();
let header = Header::new(Algorithm::default(), vec![key_slot]);
assert!(header.is_valid());
}
// A header with one key slot encodes to 153 bytes and survives a serialization round trip.
#[test]
fn serialize_header() {
let secret = Secret::new(repeat_n(0u8, Algorithm::default().key_size()).collect());
let block_secret = Secret::new(repeat_n(1u8, Algorithm::default().key_size()).collect());
let key_slot = KeySlot::new(Algorithm::default(), &secret, &block_secret).unwrap();
let header = Header::new(Algorithm::default(), vec![key_slot]);
let bytes = to_cbor(&header).unwrap();
assert_eq!(bytes.len(), 153);
let header_deserialized: Header = from_cbor(bytes.as_slice()).unwrap();
assert_eq!(header_deserialized, header);
assert!(header.is_valid());
}
// `KeySlot::encoded_size` matches the actual encoded length of a key slot.
#[test]
fn key_slot_encoded_size() {
let secret = Secret::new(repeat_n(0u8, Algorithm::default().key_size()).collect());
let block_secret = Secret::new(repeat_n(1u8, Algorithm::default().key_size()).collect());
let key_slot = KeySlot::new(Algorithm::default(), &secret, &block_secret).unwrap();
let bytes = to_cbor(&key_slot).unwrap();
assert_eq!(bytes.len(), KeySlot::encoded_size(Algorithm::default()));
}
// `Header::encoded_size` matches the actual encoded length of a header with one key slot.
#[test]
fn header_encoded_size() {
let secret = Secret::new(repeat_n(0u8, Algorithm::default().key_size()).collect());
let block_secret = Secret::new(repeat_n(1u8, Algorithm::default().key_size()).collect());
let key_slot = KeySlot::new(Algorithm::default(), &secret, &block_secret).unwrap();
let header = Header::new(Algorithm::default(), vec![key_slot]);
let bytes = to_cbor(&header).unwrap();
assert_eq!(bytes.len(), Header::encoded_size(Algorithm::default()));
}
// Encrypting, serializing, deserializing and decrypting a small block yields the original
// CID and data; the ciphertext differs from the plaintext data.
#[test]
fn encrypt_block_roundtrip() {
let secret = Secret::new(repeat_n(0u8, Algorithm::default().key_size()).collect());
let block = BlockSerializer::default().serialize(&"Hello World!").unwrap();
let encrypted_block = EncryptedBlock::encrypt(Algorithm::default(), &secret, block.clone()).unwrap();
assert_ne!(encrypted_block.payload.inline().unwrap(), block.data());
let encrypted_block_bytes = to_cbor(&encrypted_block).unwrap();
assert_eq!(encrypted_block_bytes.len(), 236);
let encrypted_block_deserialized: EncryptedBlock = from_cbor(&encrypted_block_bytes).unwrap();
let decrypted_block = encrypted_block_deserialized.block(&secret).unwrap();
assert_eq!(decrypted_block.cid(), block.cid());
assert_eq!(&decrypted_block.data, block.data());
}
// A block whose data is already `MAX_BLOCK_SIZE` long is split into extra blocks, and after
// inlining those blocks again it decrypts to the original CID and data.
#[test]
fn test_fit_to_blocks() {
let secret = Secret::new(repeat_n(0u8, Algorithm::default().key_size()).collect());
let data: Vec<u8> = repeat_n(0u8, DefaultParams::MAX_BLOCK_SIZE).collect();
let block = Block::new_data(KnownMultiCodec::Raw, data);
let mut encrypted_block = EncryptedBlock::encrypt(Algorithm::default(), &secret, block.clone()).unwrap();
let encrypted_extra_blocks = encrypted_block
.payload
.fit_into_blocks(DefaultParams::MAX_BLOCK_SIZE, Some(Header::encoded_size(Algorithm::default())));
// The payload must now reference exactly the returned blocks, in order.
assert!(match &encrypted_block.payload {
EncryptedData::Block(blocks) =>
blocks == &encrypted_extra_blocks.iter().map(|b| *b.cid()).collect::<Vec<Cid>>(),
_ => false,
});
encrypted_block
.payload
.try_inline_blocks(encrypted_extra_blocks.into_iter().map(|v| v.into_inner()))
.unwrap();
let decrypted_block = encrypted_block.block(&secret).unwrap();
assert_eq!(decrypted_block.cid(), block.cid());
assert_eq!(&decrypted_block.data, block.data());
}
}