use crate::error::CryptError;
use aes_kw::KekAes256;
use anyhow::{Result, anyhow};
use blake3::traits::digest::Digest;
use chacha20poly1305::{
AeadInPlace, Tag, XChaCha20Poly1305, XNonce,
aead::{Aead, AeadCore, KeyInit},
};
use rand_core::CryptoRngCore;
use serde::{Deserialize, Serialize};
use skip_ratchet::Ratchet;
use std::fmt::Debug;
use wnfs_common::utils;
/// Length in bytes of the XChaCha20-Poly1305 nonce (`XNonce`). `SnapshotKey::encrypt`
/// prefixes this many nonce bytes to its output and `SnapshotKey::decrypt` splits them off.
pub(crate) const NONCE_SIZE: usize = 24;
/// Length in bytes of the Poly1305 authentication tag produced by XChaCha20-Poly1305.
pub(crate) const AUTHENTICATION_TAG_SIZE: usize = 16;
/// Length in bytes of the raw key material held by `SnapshotKey` and `TemporalKey`.
pub const KEY_BYTE_SIZE: usize = 32;
// Domain-separation strings. Each is mixed into a key derivation, so changing any value
// changes every key derived with it.
/// Domain-separation string for revision segments derived from a ratchet
/// (not used in this file — presumably consumed elsewhere in the crate).
pub(crate) const REVISION_SEGMENT_DSI: &str = "wnfs/1.0/revision segment derivation from ratchet";
/// Domain-separation string for hiding segments derived from a content key
/// (not used in this file — presumably consumed elsewhere in the crate).
pub(crate) const HIDING_SEGMENT_DSI: &str = "wnfs/1.0/hiding segment derivation from content key";
/// Domain-separation string for per-file-block segment derivation
/// (not used in this file — presumably consumed elsewhere in the crate).
pub(crate) const BLOCK_SEGMENT_DSI: &str = "wnfs/1.0/segment derivation for file block";
/// Domain-separation string passed to the ratchet by `TemporalKey::new`.
pub(crate) const TEMPORAL_KEY_DSI: &str = "wnfs/1.0/temporal derivation from ratchet";
/// BLAKE3 key-derivation context used by `TemporalKey::derive_snapshot_key`.
pub(crate) const SNAPSHOT_KEY_DSI: &str = "wnfs/1.0/snapshot key derivation from temporal";
/// A `KEY_BYTE_SIZE`-byte symmetric key used with XChaCha20-Poly1305
/// (see `SnapshotKey::encrypt` / `SnapshotKey::decrypt`).
///
/// Serialized as a fixed-size byte array via `serde_byte_array`.
#[derive(PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct SnapshotKey(#[serde(with = "serde_byte_array")] pub(crate) [u8; KEY_BYTE_SIZE]);
/// A `KEY_BYTE_SIZE`-byte key derived from a skip ratchet (`TemporalKey::new`).
///
/// It can derive the matching `SnapshotKey` and is used as the key-encryption key for
/// AES key wrap (`TemporalKey::key_wrap_encrypt` / `key_wrap_decrypt`).
/// Serialized as a fixed-size byte array via `serde_byte_array`.
#[derive(PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct TemporalKey(#[serde(with = "serde_byte_array")] pub(crate) [u8; KEY_BYTE_SIZE]);
impl TemporalKey {
    /// Derives a temporal key from the current state of `ratchet`.
    ///
    /// The ratchet's key derivation is domain-separated with [`TEMPORAL_KEY_DSI`], and the
    /// finalized digest becomes the key bytes.
    pub fn new(ratchet: &Ratchet) -> Self {
        Self(ratchet.derive_key(TEMPORAL_KEY_DSI).finalize().into())
    }

    /// Derives the [`SnapshotKey`] belonging to this temporal key.
    ///
    /// Uses BLAKE3's key-derivation mode with [`SNAPSHOT_KEY_DSI`] as the context string.
    /// Anyone holding the temporal key can therefore recompute the snapshot key.
    pub fn derive_snapshot_key(&self) -> SnapshotKey {
        SnapshotKey(blake3::derive_key(SNAPSHOT_KEY_DSI, &self.0))
    }

    /// Wraps `cleartext` with AES-256 key wrap with padding, using this key as the KEK.
    ///
    /// # Errors
    ///
    /// Returns [`CryptError::UnableToEncrypt`] if the key-wrap operation fails.
    pub fn key_wrap_encrypt(&self, cleartext: &[u8]) -> Result<Vec<u8>> {
        Ok(KekAes256::from(self.0)
            .wrap_with_padding_vec(cleartext)
            .map_err(|e| CryptError::UnableToEncrypt(anyhow!(e)))?)
    }

    /// Unwraps `ciphertext` previously produced by [`TemporalKey::key_wrap_encrypt`].
    ///
    /// # Errors
    ///
    /// Returns [`CryptError::UnableToDecrypt`] if unwrapping fails, e.g. because the
    /// ciphertext is malformed or was wrapped under a different key.
    pub fn key_wrap_decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
        Ok(KekAes256::from(self.0)
            .unwrap_with_padding_vec(ciphertext)
            // This is the decrypt direction, so report a decryption failure (was UnableToEncrypt).
            .map_err(|e| CryptError::UnableToDecrypt(anyhow!(e)))?)
    }

    /// Returns a reference to the raw key bytes.
    pub fn as_bytes(&self) -> &[u8; KEY_BYTE_SIZE] {
        &self.0
    }
}
impl SnapshotKey {
    /// Generates a fresh, uniformly random snapshot key from `rng`.
    pub fn new(rng: &mut impl CryptoRngCore) -> Self {
        Self(utils::get_random_bytes(rng))
    }

    /// Encrypts `data` with XChaCha20-Poly1305 under a freshly generated random nonce.
    ///
    /// The output is `nonce (NONCE_SIZE bytes) || ciphertext-with-tag`, where the second part
    /// is exactly what `Aead::encrypt` returns. [`SnapshotKey::decrypt`] expects this layout.
    ///
    /// # Errors
    ///
    /// Returns [`CryptError::UnableToEncrypt`] if the AEAD encryption fails.
    pub fn encrypt(&self, data: &[u8], rng: &mut impl CryptoRngCore) -> Result<Vec<u8>> {
        let nonce = Self::generate_nonce(rng);
        let key = self.0.into();
        let cipher_text = XChaCha20Poly1305::new(&key)
            .encrypt(&nonce, data)
            .map_err(|e| CryptError::UnableToEncrypt(anyhow!(e)))?;

        // Prefix the nonce so the output is self-contained; build it in a single allocation.
        let mut sealed = Vec::with_capacity(nonce.len() + cipher_text.len());
        sealed.extend_from_slice(nonce.as_slice());
        sealed.extend_from_slice(&cipher_text);
        Ok(sealed)
    }

    /// Generates a random XChaCha20-Poly1305 nonce (`NONCE_SIZE` bytes) from `rng`.
    pub(crate) fn generate_nonce(rng: &mut impl CryptoRngCore) -> XNonce {
        XChaCha20Poly1305::generate_nonce(rng)
    }

    /// Encrypts `buffer` in place under `nonce`, with empty associated data, and returns
    /// the detached authentication tag.
    ///
    /// # Errors
    ///
    /// Returns [`CryptError::UnableToEncrypt`] if the AEAD encryption fails.
    pub(crate) fn encrypt_in_place(&self, nonce: &XNonce, buffer: &mut [u8]) -> Result<Tag> {
        let key = self.0.into();
        let tag = XChaCha20Poly1305::new(&key)
            .encrypt_in_place_detached(nonce, &[], buffer)
            .map_err(|e| CryptError::UnableToEncrypt(anyhow!(e)))?;
        Ok(tag)
    }

    /// Decrypts bytes produced by [`SnapshotKey::encrypt`]: a `NONCE_SIZE`-byte nonce
    /// followed by the ciphertext and its tag.
    ///
    /// # Errors
    ///
    /// Returns [`CryptError::UnableToDecrypt`] if `cipher_text` is too short to contain a
    /// nonce, or if authenticated decryption fails (wrong key, tampered or truncated data).
    pub fn decrypt(&self, cipher_text: &[u8]) -> Result<Vec<u8>> {
        // `split_at` panics when `mid > len`; reject short (possibly untrusted) input instead.
        if cipher_text.len() < NONCE_SIZE {
            return Err(CryptError::UnableToDecrypt(anyhow!(
                "ciphertext is {} bytes, shorter than the {}-byte nonce prefix",
                cipher_text.len(),
                NONCE_SIZE
            ))
            .into());
        }
        let (nonce_bytes, data) = cipher_text.split_at(NONCE_SIZE);
        let key = self.0.into();
        let nonce = XNonce::from_slice(nonce_bytes);
        Ok(XChaCha20Poly1305::new(&key)
            .decrypt(nonce, data)
            .map_err(|e| CryptError::UnableToDecrypt(anyhow!(e)))?)
    }

    /// Decrypts `buffer` in place under `nonce` and the detached `tag`, with empty
    /// associated data. This is the inverse of [`SnapshotKey::encrypt_in_place`].
    ///
    /// # Errors
    ///
    /// Returns [`CryptError::UnableToDecrypt`] if authentication or decryption fails.
    // The only caller visible in this file is the proptest module, hence the lint allowance
    // for non-test builds.
    #[allow(dead_code)]
    pub(crate) fn decrypt_in_place(
        &self,
        nonce: &XNonce,
        tag: &Tag,
        buffer: &mut [u8],
    ) -> Result<()> {
        let key = self.0.into();
        XChaCha20Poly1305::new(&key)
            .decrypt_in_place_detached(nonce, &[], buffer, tag)
            .map_err(|e| CryptError::UnableToDecrypt(anyhow!(e)))?;
        Ok(())
    }

    /// Returns a reference to the raw key bytes.
    pub fn as_bytes(&self) -> &[u8; KEY_BYTE_SIZE] {
        &self.0
    }
}
impl Debug for SnapshotKey {
    /// Renders as `SnapshotKey("<hex>")`, where `<hex>` encodes only the first 8 key bytes.
    // NOTE(review): this still exposes 64 bits of key material in debug output; a hash-based
    // fingerprint would avoid that, but would change existing Debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, _rest) = self.0.split_at(8);
        let fingerprint = hex::encode(prefix);
        f.debug_tuple("SnapshotKey").field(&fingerprint).finish()
    }
}
impl Debug for TemporalKey {
    /// Renders as `TemporalKey("<hex>")`, where `<hex>` encodes only the first 8 key bytes.
    // NOTE(review): this still exposes 64 bits of key material in debug output; a hash-based
    // fingerprint would avoid that, but would change existing Debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, _rest) = self.0.split_at(8);
        let fingerprint = hex::encode(prefix);
        f.debug_tuple("TemporalKey").field(&fingerprint).finish()
    }
}
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::{prelude::any, prop_assert_eq, prop_assert_ne};
    use rand_chacha::ChaCha12Rng;
    use rand_core::SeedableRng;
    use test_strategy::proptest;

    /// Round-trips arbitrary data through `SnapshotKey::encrypt` / `decrypt`. For inputs of
    /// at least 16 bytes it also checks that the ciphertext body differs from the plaintext.
    #[proptest(cases = 100)]
    fn snapshot_key_can_encrypt_and_decrypt_data(
        #[strategy(any::<Vec<u8>>())] data: Vec<u8>,
        #[strategy(any::<[u8; KEY_BYTE_SIZE]>())] rng_seed: [u8; KEY_BYTE_SIZE],
        key_bytes: [u8; KEY_BYTE_SIZE],
    ) {
        let snapshot_key = SnapshotKey(key_bytes);
        let mut seeded_rng = ChaCha12Rng::from_seed(rng_seed);

        let sealed = snapshot_key.encrypt(&data, &mut seeded_rng).unwrap();
        let opened = snapshot_key.decrypt(&sealed).unwrap();

        // Very short inputs can coincide with their ciphertext by chance, so only compare
        // longer ones. The ciphertext body starts right after the nonce prefix.
        if data.len() >= 16 {
            let body = &sealed[NONCE_SIZE..NONCE_SIZE + data.len()];
            prop_assert_ne!(body, &opened);
        }
        prop_assert_eq!(&opened, &data);
    }

    /// Round-trips arbitrary data through the detached-tag, in-place encrypt/decrypt pair.
    #[proptest(cases = 100)]
    fn snapshot_key_can_encrypt_and_decrypt_data_in_place(
        data: Vec<u8>,
        key_bytes: [u8; KEY_BYTE_SIZE],
        nonce: [u8; NONCE_SIZE],
    ) {
        let snapshot_key = SnapshotKey(key_bytes);
        let xnonce = XNonce::from_slice(&nonce);
        let mut buffer = data.to_vec();

        let tag = snapshot_key.encrypt_in_place(xnonce, &mut buffer).unwrap();
        if buffer.len() >= 16 {
            prop_assert_ne!(&buffer, &data);
        }

        snapshot_key
            .decrypt_in_place(xnonce, &tag, &mut buffer)
            .unwrap();
        prop_assert_eq!(&buffer, &data);
    }
}