use alloc::string::ToString;
use miden_crypto::merkle::{InnerNodeInfo, MerkleError, PartialSmt, SmtLeaf, SmtProof};
use super::{AssetVault, AssetVaultKey};
use crate::Word;
use crate::asset::{Asset, AssetWitness};
use crate::errors::PartialAssetVaultError;
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
/// A partial view of an [`AssetVault`], backed by a [`PartialSmt`].
///
/// Entries of the underlying tree are asset vault keys (as [`Word`]s) mapped to assets (as
/// [`Word`]s). When the vault is built from an arbitrary [`PartialSmt`] (via `TryFrom` or
/// deserialization), every entry is checked to be a valid [`Asset`] stored under its own vault
/// key.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct PartialVault {
    /// The partial sparse Merkle tree holding the tracked subset of the vault's assets.
    partial_smt: PartialSmt,
}
impl PartialVault {
pub fn new(root: Word) -> Self {
PartialVault { partial_smt: PartialSmt::new(root) }
}
pub fn new_full(vault: AssetVault) -> Self {
let partial_smt = PartialSmt::from(vault.asset_tree);
PartialVault { partial_smt }
}
pub fn new_minimal(vault: &AssetVault) -> Self {
PartialVault::new(vault.root())
}
pub fn root(&self) -> Word {
self.partial_smt.root()
}
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
self.partial_smt.inner_nodes()
}
pub fn leaves(&self) -> impl Iterator<Item = &SmtLeaf> {
self.partial_smt.leaves().map(|(_, leaf)| leaf)
}
pub fn open(&self, vault_key: AssetVaultKey) -> Result<AssetWitness, PartialAssetVaultError> {
let smt_proof = self
.partial_smt
.open(&vault_key.into())
.map_err(PartialAssetVaultError::UntrackedAsset)?;
Ok(AssetWitness::new_unchecked(smt_proof))
}
pub fn get(&self, vault_key: AssetVaultKey) -> Result<Option<Asset>, MerkleError> {
self.partial_smt.get_value(&vault_key.into()).map(|word| {
if word.is_empty() {
None
} else {
Some(Asset::try_from(word).expect("partial vault should only track valid assets"))
}
})
}
pub fn add(&mut self, witness: AssetWitness) -> Result<(), PartialAssetVaultError> {
let proof = SmtProof::from(witness);
self.partial_smt
.add_proof(proof)
.map_err(PartialAssetVaultError::FailedToAddProof)
}
fn validate_entries<'a>(
entries: impl IntoIterator<Item = &'a (Word, Word)>,
) -> Result<(), PartialAssetVaultError> {
for (vault_key, asset) in entries {
let asset = Asset::try_from(asset).map_err(|source| {
PartialAssetVaultError::InvalidAssetInSmt { entry: *asset, source }
})?;
if *vault_key != asset.vault_key().into() {
return Err(PartialAssetVaultError::AssetVaultKeyMismatch {
expected: asset.vault_key(),
actual: *vault_key,
});
}
}
Ok(())
}
}
impl TryFrom<PartialSmt> for PartialVault {
type Error = PartialAssetVaultError;
fn try_from(partial_smt: PartialSmt) -> Result<Self, Self::Error> {
Self::validate_entries(partial_smt.entries())?;
Ok(PartialVault { partial_smt })
}
}
impl Serializable for PartialVault {
    /// Serializes the vault as exactly its underlying partial SMT; no extra framing is written.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        self.partial_smt.write_into(target);
    }
}
impl Deserializable for PartialVault {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let partial_smt: PartialSmt = source.read()?;
PartialVault::try_from(partial_smt)
.map_err(|err| DeserializationError::InvalidValue(err.to_string()))
}
}
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use miden_crypto::merkle::Smt;

    use super::*;
    use crate::asset::FungibleAsset;

    /// A partial SMT whose entry is not a valid asset must be rejected by `TryFrom`.
    #[test]
    fn partial_vault_ensures_asset_validity() -> anyhow::Result<()> {
        let invalid_asset = Word::from([0, 0, 0, 5u32]);
        let smt = Smt::with_entries([(invalid_asset, invalid_asset)])?;
        let proof = smt.open(&invalid_asset);
        // `proof` is not used again, so it is moved rather than cloned.
        let partial_smt = PartialSmt::from_proofs([proof])?;

        let err = PartialVault::try_from(partial_smt).unwrap_err();
        assert_matches!(err, PartialAssetVaultError::InvalidAssetInSmt { entry, .. } => {
            assert_eq!(entry, invalid_asset);
        });

        Ok(())
    }

    /// A valid asset stored under a key other than its own vault key must be rejected.
    #[test]
    fn partial_vault_ensures_asset_vault_key_matches() -> anyhow::Result<()> {
        let asset = FungibleAsset::mock(500);
        let invalid_vault_key = Word::from([0, 1, 2, 3u32]);
        let smt = Smt::with_entries([(invalid_vault_key, asset.into())])?;
        let proof = smt.open(&invalid_vault_key);
        let partial_smt = PartialSmt::from_proofs([proof])?;

        let err = PartialVault::try_from(partial_smt).unwrap_err();
        assert_matches!(err, PartialAssetVaultError::AssetVaultKeyMismatch { expected, actual } => {
            assert_eq!(actual, invalid_vault_key);
            assert_eq!(expected, asset.vault_key());
        });

        Ok(())
    }

    /// A valid asset stored under its own vault key is accepted and can be read and opened.
    #[test]
    fn partial_vault_tracks_valid_assets() -> anyhow::Result<()> {
        let asset = FungibleAsset::mock(500);
        let vault_key: Word = asset.vault_key().into();
        let smt = Smt::with_entries([(vault_key, asset.into())])?;
        let partial_smt = PartialSmt::from_proofs([smt.open(&vault_key)])?;

        let vault = PartialVault::try_from(partial_smt).expect("entries should be valid assets");
        assert_eq!(vault.root(), smt.root());
        assert_eq!(vault.get(asset.vault_key())?, Some(Asset::from(asset)));
        assert!(vault.open(asset.vault_key()).is_ok());

        Ok(())
    }

    /// Deserialization must apply the same entry validation as `TryFrom<PartialSmt>`.
    #[test]
    fn partial_vault_deserialization_rejects_invalid_assets() -> anyhow::Result<()> {
        let invalid_asset = Word::from([0, 0, 0, 5u32]);
        let smt = Smt::with_entries([(invalid_asset, invalid_asset)])?;
        let partial_smt = PartialSmt::from_proofs([smt.open(&invalid_asset)])?;

        let result = PartialVault::read_from_bytes(&partial_smt.to_bytes());
        assert_matches!(result, Err(DeserializationError::InvalidValue(_)));

        Ok(())
    }
}