use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use serde::{Deserialize, Serialize};
use umbral_pre::{
decrypt_original, encrypt, serde_bytes, Capsule, EncryptionError, PublicKey, SecretKey,
Signature, Signer, VerifiedKeyFrag,
};
use crate::address::Address;
use crate::hrac::HRAC;
use crate::key_frag::{DecryptionError, EncryptedKeyFrag};
use crate::versioning::{
messagepack_deserialize, messagepack_serialize, ProtocolObject, ProtocolObjectInner,
};
use crate::RevocationOrder;
/// A treasure map: the encrypted key fragment assigned to each Ursula, together with the
/// threshold and the keys it was built with.
///
/// NOTE: serde serializes fields in declaration order, so reordering them would likely
/// change the encoded bytes.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct TreasureMap {
    /// Threshold value; `TreasureMap::new` requires it to be non-zero and no larger than
    /// the number of entries in `destinations`.
    pub threshold: u8,
    /// The HRAC given to `TreasureMap::new`; it is also passed to `EncryptedKeyFrag::new`
    /// for every destination.
    pub hrac: HRAC,
    /// One encrypted key fragment per Ursula, keyed by her address.
    pub destinations: BTreeMap<Address, EncryptedKeyFrag>,
    /// The policy encrypting key supplied at construction.
    pub policy_encrypting_key: PublicKey,
    /// Verifying key of the signer that built the map (`signer.verifying_key()` in
    /// `TreasureMap::new`).
    pub publisher_verifying_key: PublicKey,
}
impl TreasureMap {
    /// Builds a treasure map from the key fragments assigned to each Ursula.
    ///
    /// Each item of `assigned_kfrags` pairs an Ursula's address with her encrypting key and
    /// the verified key fragment meant for her. Every fragment is turned into an
    /// [`EncryptedKeyFrag`] using `signer`, that encrypting key and `hrac`.
    ///
    /// # Panics
    ///
    /// Panics if `threshold` is zero, if an address occurs more than once in
    /// `assigned_kfrags`, or if `threshold` exceeds the number of assigned fragments.
    pub fn new(
        signer: &Signer,
        hrac: &HRAC,
        policy_encrypting_key: &PublicKey,
        assigned_kfrags: impl IntoIterator<Item = (Address, (PublicKey, VerifiedKeyFrag))>,
        threshold: u8,
    ) -> Self {
        assert!(threshold != 0, "threshold must be non-zero");

        let mut destinations = BTreeMap::new();
        for (address, (encrypting_key, kfrag)) in assigned_kfrags {
            let ekfrag = EncryptedKeyFrag::new(signer, &encrypting_key, hrac, kfrag);
            // `insert` hands back the previous value when the address was already present.
            let replaced = destinations.insert(address, ekfrag);
            if replaced.is_some() {
                let message = format!("Repeating address in assigned_kfrags: {:?}", address);
                panic!("{}", message);
            }
        }

        assert!(
            threshold as usize <= destinations.len(),
            "threshold cannot be larger than the total number of shares"
        );

        Self {
            threshold,
            hrac: *hrac,
            destinations,
            policy_encrypting_key: *policy_encrypting_key,
            publisher_verifying_key: signer.verifying_key(),
        }
    }

    /// Signs this map for `recipient_key` with `signer`, then encrypts the signed bundle
    /// to `recipient_key`.
    pub fn encrypt(&self, signer: &Signer, recipient_key: &PublicKey) -> EncryptedTreasureMap {
        EncryptedTreasureMap::new(signer, recipient_key, self)
    }

    /// Creates one [`RevocationOrder`] per destination, in ascending address order.
    pub fn make_revocation_orders(&self, signer: &Signer) -> Vec<RevocationOrder> {
        let mut orders = Vec::with_capacity(self.destinations.len());
        for (address, ekfrag) in &self.destinations {
            orders.push(RevocationOrder::new(signer, address, ekfrag));
        }
        orders
    }
}
impl<'a> ProtocolObjectInner<'a> for TreasureMap {
    /// Protocol version as `(major, minor)`.
    fn version() -> (u16, u16) {
        (2, 0)
    }

    /// Four-byte type tag for treasure maps.
    fn brand() -> [u8; 4] {
        *b"TMap"
    }

    fn unversioned_to_bytes(&self) -> Box<[u8]> {
        messagepack_serialize(self)
    }

    /// Decodes the payload for minor version 0; any other minor version yields `None`.
    fn unversioned_from_bytes(minor_version: u16, bytes: &[u8]) -> Option<Result<Self, String>> {
        match minor_version {
            0 => Some(messagepack_deserialize(bytes)),
            _ => None,
        }
    }
}

impl<'a> ProtocolObject<'a> for TreasureMap {}
/// A treasure map paired with the publisher's signature over the recipient's key and the
/// map's bytes (see `AuthorizedTreasureMap::message_to_sign`).
///
/// NOTE: serde serializes fields in declaration order; keep the order stable.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
struct AuthorizedTreasureMap {
    /// Signature produced by `Signer::sign` over `message_to_sign(recipient_key, treasure_map)`.
    signature: Signature,
    /// The signed treasure map.
    treasure_map: TreasureMap,
}
impl AuthorizedTreasureMap {
    /// Returns the signed payload: the recipient's compressed public key immediately
    /// followed by `treasure_map.to_bytes()`.
    fn message_to_sign(recipient_key: &PublicKey, treasure_map: &TreasureMap) -> Vec<u8> {
        let key_bytes = recipient_key.to_compressed_bytes();
        let map_bytes = treasure_map.to_bytes();
        let mut message = Vec::with_capacity(key_bytes.len() + map_bytes.len());
        message.extend_from_slice(&key_bytes);
        message.extend_from_slice(&map_bytes);
        message
    }

    /// Signs a copy of `treasure_map`, binding it to `recipient_key`.
    fn new(signer: &Signer, recipient_key: &PublicKey, treasure_map: &TreasureMap) -> Self {
        Self {
            signature: signer.sign(&Self::message_to_sign(recipient_key, treasure_map)),
            treasure_map: treasure_map.clone(),
        }
    }

    /// Checks the signature against `publisher_verifying_key` for `recipient_key`.
    ///
    /// Returns the inner treasure map on success and `None` otherwise.
    fn verify(
        self,
        recipient_key: &PublicKey,
        publisher_verifying_key: &PublicKey,
    ) -> Option<TreasureMap> {
        let message = Self::message_to_sign(recipient_key, &self.treasure_map);
        let is_valid = self.signature.verify(publisher_verifying_key, &message);
        is_valid.then(|| self.treasure_map)
    }
}
impl<'a> ProtocolObjectInner<'a> for AuthorizedTreasureMap {
    /// Protocol version as `(major, minor)`.
    fn version() -> (u16, u16) {
        (2, 0)
    }

    /// Four-byte type tag for signed (authorized) treasure maps.
    fn brand() -> [u8; 4] {
        *b"AMap"
    }

    fn unversioned_to_bytes(&self) -> Box<[u8]> {
        messagepack_serialize(self)
    }

    /// Decodes the payload for minor version 0; any other minor version yields `None`.
    fn unversioned_from_bytes(minor_version: u16, bytes: &[u8]) -> Option<Result<Self, String>> {
        match minor_version {
            0 => Some(messagepack_deserialize(bytes)),
            _ => None,
        }
    }
}

impl<'a> ProtocolObject<'a> for AuthorizedTreasureMap {}
/// A signed treasure map encrypted to a recipient's public key with Umbral `encrypt`.
///
/// NOTE: serde serializes fields in declaration order; keep the order stable.
#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedTreasureMap {
    /// Capsule returned by `encrypt`, needed by `decrypt_original`.
    capsule: Capsule,
    /// Encrypted bytes of the serialized signed treasure map; (de)serialized through
    /// `serde_bytes::as_base64`.
    #[serde(with = "serde_bytes::as_base64")]
    ciphertext: Box<[u8]>,
}
impl EncryptedTreasureMap {
    /// Signs `treasure_map` for `recipient_key` and encrypts the signed bytes to that key.
    fn new(signer: &Signer, recipient_key: &PublicKey, treasure_map: &TreasureMap) -> Self {
        let (capsule, ciphertext) = encrypt(
            recipient_key,
            &AuthorizedTreasureMap::new(signer, recipient_key, treasure_map).to_bytes(),
        )
        .unwrap_or_else(|err| match err {
            EncryptionError::PlaintextTooLarge => panic!("encryption failed - out of memory?"),
        });
        Self { capsule, ciphertext }
    }

    /// Decrypts with `sk`, deserializes the signed map, and verifies its signature.
    ///
    /// The signature is checked against `publisher_verifying_key`, with `sk.public_key()`
    /// as the recipient key.
    ///
    /// # Errors
    ///
    /// - `DecryptionError::DecryptionFailed` if `decrypt_original` fails.
    /// - `DecryptionError::DeserializationFailed` if the plaintext does not deserialize.
    /// - `DecryptionError::VerificationFailed` if the signature does not verify.
    pub fn decrypt(
        &self,
        sk: &SecretKey,
        publisher_verifying_key: &PublicKey,
    ) -> Result<TreasureMap, DecryptionError> {
        let plaintext = decrypt_original(sk, &self.capsule, &self.ciphertext)
            .map_err(DecryptionError::DecryptionFailed)?;
        let authorized = AuthorizedTreasureMap::from_bytes(&plaintext)
            .map_err(DecryptionError::DeserializationFailed)?;
        // NOTE(review): the decrypted map's own `publisher_verifying_key` field is not
        // compared with `publisher_verifying_key` here; confirm callers do not rely on it.
        match authorized.verify(&sk.public_key(), publisher_verifying_key) {
            Some(treasure_map) => Ok(treasure_map),
            None => Err(DecryptionError::VerificationFailed),
        }
    }
}
impl<'a> ProtocolObjectInner<'a> for EncryptedTreasureMap {
    /// Protocol version as `(major, minor)`.
    fn version() -> (u16, u16) {
        (2, 0)
    }

    /// Four-byte type tag for encrypted treasure maps.
    fn brand() -> [u8; 4] {
        *b"EMap"
    }

    fn unversioned_to_bytes(&self) -> Box<[u8]> {
        messagepack_serialize(self)
    }

    /// Decodes the payload for minor version 0; any other minor version yields `None`.
    fn unversioned_from_bytes(minor_version: u16, bytes: &[u8]) -> Option<Result<Self, String>> {
        match minor_version {
            0 => Some(messagepack_deserialize(bytes)),
            _ => None,
        }
    }
}

impl<'a> ProtocolObject<'a> for EncryptedTreasureMap {}