#![allow(unused_assignments)]
use chrono::Utc;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use zeroize::{Zeroize, Zeroizing};
use crate::{
modules::crypto::{AEADCipher, AES, AESKeySize, BlockCipher, CipherMode},
shamir::ShamirSecret,
utils::BHashSet,
};
/// Errors returned by [`SealBox`] operations.
#[derive(Debug, Error)]
pub enum SealBoxError {
    // ---- lifecycle state -------------------------------------------------
    /// The operation requires an unsealed box, but the box is sealed.
    #[error("SealBox is sealed")]
    Sealed,
    /// An unseal was attempted on a box that is already unsealed.
    #[error("SealBox is not sealed")]
    NotSealed,
    /// More shares must be submitted before the box can be unsealed.
    #[error("SealBox is unsealing")]
    Unsealing,

    // ---- unseal failures -------------------------------------------------
    /// The collected shares could not be combined into a usable key.
    #[error("Unsealing failed: insufficient or invalid shares")]
    UnsealFailed,
    /// The submitted share has been marked as deprecated and is rejected.
    #[error("Unsealing failed: deprecated share")]
    UnsealKeyDeprecated,

    // ---- cipher failures -------------------------------------------------
    /// Serializing or encrypting the value failed.
    #[error("Encryption failed")]
    EncryptionFailed,
    /// Decrypting, authenticating or deserializing the value failed.
    #[error("Decryption failed")]
    DecryptionFailed,

    // ---- Shamir secret sharing -------------------------------------------
    /// Splitting the key into shares failed, or the share parameters are invalid.
    #[error("Shamir secret split failed")]
    ShamirSecretSplitFailed,
    /// Reconstructing a secret from Shamir shares failed.
    #[error("Shamir secret combine failed")]
    ShamirSecretCombineFailed,
}
/// A value encrypted with AES-256-GCM whose data key can be split into Shamir
/// shares (see `generate_shares`) and later reconstructed to unseal it.
///
/// Only the ciphertext, nonce, AAD, tag, share parameters and the set of
/// deprecated shares are serialized; the key, the shares collected during an
/// unseal session and the plaintext value are `#[serde(skip)]`.
///
/// `#[zeroize(drop)]` zeroizes every field not marked `#[zeroize(skip)]` when
/// the box is dropped. The plaintext `value` and `deprecated_shares` are skipped,
/// so they are NOT wiped on drop.
#[derive(Default, Serialize, Deserialize, Zeroize)]
#[zeroize(drop)]
pub struct SealBox<T> {
/// Ciphertext produced by the AES-GCM encrypter in `new`.
sealed_data: Vec<u8>,
/// Nonce/IV reported by the encrypter; reused when decrypting.
nonce: [u8; 16],
/// Creation time in milliseconds since the Unix epoch as 13 ASCII decimal
/// digits; bound to the ciphertext as additional authenticated data.
aad: [u8; 13],
/// GCM authentication tag, checked when decrypting.
tag: [u8; 16],
/// Minimum number of shares that must be submitted to unseal.
threshold: u8,
/// Number of shares produced by `generate_shares`.
total_shares: u8,
/// Shares collected so far in the current unseal session (never persisted).
#[serde(skip)]
shares: Option<Vec<Vec<u8>>>,
/// Data key; `Some` only while the box is unsealed (never persisted).
#[serde(skip)]
key: Option<[u8; 32]>,
/// Decrypted value; `Some` only while the box is unsealed. Never persisted, and
/// not zeroized on drop (`T` carries no `Zeroize` bound).
#[serde(skip)]
#[zeroize(skip)]
value: Option<T>,
/// Shares consumed by `unseal_once`; later unseal attempts reject them.
/// Persisted with the box and not zeroized.
#[zeroize(skip)]
deprecated_shares: BHashSet,
}
impl<T> SealBox<T>
where
T: Serialize + for<'de> Deserialize<'de>,
{
pub fn new(data: T, threshold: u8, total_shares: u8) -> Result<Self, SealBoxError> {
if threshold < 2 || total_shares < threshold {
return Err(SealBoxError::ShamirSecretSplitFailed);
}
let serialized = serde_json::to_vec(&data).unwrap();
let now_ms = Utc::now()
.timestamp_millis()
.to_string()
.as_bytes()
.to_vec();
let mut aes_encrypter = AES::new(
true,
Some(AESKeySize::AES256),
Some(CipherMode::GCM),
None,
None,
)
.map_err(|_| SealBoxError::EncryptionFailed)?;
aes_encrypter
.set_aad(now_ms.clone())
.map_err(|_| SealBoxError::EncryptionFailed)?;
let encrypted = aes_encrypter
.encrypt(&serialized)
.map_err(|_| SealBoxError::EncryptionFailed)?;
let mut tag: [u8; 16] = [0; 16];
tag[..16].copy_from_slice(
&aes_encrypter
.get_tag()
.map_err(|_| SealBoxError::EncryptionFailed)?,
);
let mut key: [u8; 32] = [0; 32];
key[..32].copy_from_slice(&aes_encrypter.get_key_iv().0);
let mut nonce: [u8; 16] = [0; 16];
nonce[..16].copy_from_slice(&aes_encrypter.get_key_iv().1);
let mut aad: [u8; 13] = [0; 13];
aad[..13].copy_from_slice(&now_ms);
Ok(Self {
sealed_data: encrypted,
nonce,
aad,
tag,
threshold,
total_shares,
shares: None,
key: Some(key),
value: Some(data),
deprecated_shares: BHashSet::default(),
})
}
pub fn generate_shares(&self) -> Result<Zeroizing<Vec<Vec<u8>>>, SealBoxError> {
if !self.is_unsealed() {
return Err(SealBoxError::Sealed);
}
let key = self.key.as_ref().ok_or(SealBoxError::Sealed)?;
let shares = ShamirSecret::split(key, self.total_shares, self.threshold)
.map_err(|_| SealBoxError::ShamirSecretSplitFailed)?;
Ok(shares)
}
fn do_unseal(&mut self, unseal_key: &[u8]) -> Result<(), SealBoxError> {
if self.is_unsealed() {
return Err(SealBoxError::NotSealed);
}
if self.deprecated_shares.contains(unseal_key) {
return Err(SealBoxError::UnsealKeyDeprecated);
}
let Some(shares) = self.shares.as_mut() else {
self.shares = Some(vec![unseal_key.to_vec()]);
return Err(SealBoxError::Unsealing);
};
if shares.len() < self.threshold as usize {
shares.push(unseal_key.to_vec());
}
if shares.len() < self.threshold as usize {
return Err(SealBoxError::Unsealing);
}
let key = ShamirSecret::combine(shares.clone()).ok_or(SealBoxError::UnsealFailed)?;
let mut aes_decrypter = AES::new(
false,
Some(AESKeySize::AES256),
Some(CipherMode::GCM),
Some(key.to_vec()),
Some(self.nonce.to_vec()),
)
.map_err(|_| SealBoxError::DecryptionFailed)?;
aes_decrypter
.set_aad(self.aad.to_vec())
.map_err(|_| SealBoxError::DecryptionFailed)?;
aes_decrypter
.set_tag(self.tag.to_vec())
.map_err(|_| SealBoxError::DecryptionFailed)?;
let decrypted = aes_decrypter
.decrypt(&self.sealed_data)
.map_err(|_| SealBoxError::DecryptionFailed)?;
let value: T =
serde_json::from_slice(&decrypted).map_err(|_| SealBoxError::DecryptionFailed)?;
let key: [u8; 32] = key.try_into().map_err(|_| SealBoxError::UnsealFailed)?;
self.key = Some(key);
self.value = Some(value);
Ok(())
}
pub fn unseal(&mut self, unseal_key: &[u8]) -> Result<(), SealBoxError> {
let ret = self.do_unseal(unseal_key);
match ret {
Err(SealBoxError::Unsealing) => {}
_ => self.shares = None,
}
ret
}
pub fn unseal_once(&mut self, unseal_key: &[u8]) -> Result<(), SealBoxError> {
let ret = self.do_unseal(unseal_key);
if ret.is_ok()
&& let Some(shares) = self.shares.as_ref()
{
for share in shares.iter() {
self.deprecated_shares.insert(share);
}
}
match ret {
Err(SealBoxError::Unsealing) => {}
_ => self.shares = None,
}
ret
}
pub fn seal(&mut self) {
self.shares = None;
self.key = None;
self.value = None;
}
pub fn is_unsealed(&self) -> bool {
self.key.is_some() && self.value.is_some()
}
pub fn get(&self) -> Result<&T, SealBoxError> {
self.value.as_ref().ok_or(SealBoxError::Sealed)
}
pub fn get_mut(&mut self) -> Result<&mut T, SealBoxError> {
self.value.as_mut().ok_or(SealBoxError::Sealed)
}
}