use alloc::string::ToString;
use alloc::vec::Vec;
use crate::block::BlockNumber;
use crate::crypto::merkle::MerkleError;
use crate::crypto::merkle::smt::{MutationSet, SMT_DEPTH, Smt};
use crate::errors::NullifierTreeError;
use crate::note::Nullifier;
use crate::utils::serde::{
ByteReader,
ByteWriter,
Deserializable,
DeserializationError,
Serializable,
};
use crate::{Felt, Word};
mod backend;
pub use backend::NullifierTreeBackend;
mod witness;
pub use witness::NullifierWitness;
mod partial;
pub use partial::PartialNullifierTree;
/// A sparse Merkle tree mapping each nullifier to the number of the block in which it was spent.
///
/// The storage is pluggable through `Backend` (an [`Smt`] by default). Leaf values are encoded as
/// [`NullifierBlock`]s. A nullifier whose stored value equals [`NullifierBlock::UNSPENT`] is
/// reported as unspent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NullifierTree<Backend = Smt> {
    // Underlying sparse Merkle tree storage.
    smt: Backend,
}
impl<Backend: Default> Default for NullifierTree<Backend> {
    /// Creates a tree on top of a default-constructed backend.
    fn default() -> Self {
        let smt = Backend::default();
        Self { smt }
    }
}
impl<Backend> NullifierTree<Backend>
where
Backend: NullifierTreeBackend<Error = MerkleError>,
{
pub const DEPTH: u8 = SMT_DEPTH;
pub fn new_unchecked(backend: Backend) -> Self {
NullifierTree { smt: backend }
}
pub fn root(&self) -> Word {
self.smt.root()
}
pub fn num_nullifiers(&self) -> usize {
self.smt.num_entries()
}
pub fn entries(&self) -> impl Iterator<Item = (Nullifier, BlockNumber)> {
self.smt.entries().map(|(nullifier, value)| {
(
Nullifier::from_raw(nullifier),
NullifierBlock::new(value)
.expect("SMT should only store valid NullifierBlocks")
.into(),
)
})
}
pub fn open(&self, nullifier: &Nullifier) -> NullifierWitness {
NullifierWitness::new(self.smt.open(&nullifier.as_word()))
}
pub fn get_block_num(&self, nullifier: &Nullifier) -> Option<BlockNumber> {
let nullifier_block = self.smt.get_value(&nullifier.as_word());
if nullifier_block.is_unspent() {
return None;
}
Some(nullifier_block.into())
}
pub fn compute_mutations<I>(
&self,
nullifiers: impl IntoIterator<Item = (Nullifier, BlockNumber), IntoIter = I>,
) -> Result<NullifierMutationSet, NullifierTreeError>
where
I: Iterator<Item = (Nullifier, BlockNumber)> + Clone,
{
let nullifiers = nullifiers.into_iter();
for (nullifier, _) in nullifiers.clone() {
if self.get_block_num(&nullifier).is_some() {
return Err(NullifierTreeError::NullifierAlreadySpent(nullifier));
}
}
let mutation_set = self
.smt
.compute_mutations(
nullifiers
.into_iter()
.map(|(nullifier, block_num)| {
(nullifier.as_word(), NullifierBlock::from(block_num).into())
})
.collect::<Vec<_>>(),
)
.map_err(NullifierTreeError::ComputeMutations)?;
Ok(NullifierMutationSet::new(mutation_set))
}
pub fn mark_spent(
&mut self,
nullifier: Nullifier,
block_num: BlockNumber,
) -> Result<(), NullifierTreeError> {
let prev_nullifier_value = self
.smt
.insert(nullifier.as_word(), NullifierBlock::from(block_num))
.map_err(NullifierTreeError::MaxLeafEntriesExceeded)?;
if prev_nullifier_value.is_spent() {
Err(NullifierTreeError::NullifierAlreadySpent(nullifier))
} else {
Ok(())
}
}
pub fn apply_mutations(
&mut self,
mutations: NullifierMutationSet,
) -> Result<(), NullifierTreeError> {
self.smt
.apply_mutations(mutations.into_mutation_set())
.map_err(NullifierTreeError::TreeRootConflict)
}
}
impl Serializable for NullifierTree {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.entries().collect::<Vec<_>>().write_into(target);
}
}
impl Deserializable for NullifierTree {
    /// Reads a vector of `(nullifier, block number)` entries and rebuilds the tree from it.
    ///
    /// An error raised while rebuilding the tree is reported as
    /// [`DeserializationError::InvalidValue`] carrying the error's string form.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        Vec::<(Nullifier, BlockNumber)>::read_from(source).and_then(|entries| {
            Self::with_entries(entries)
                .map_err(|err| DeserializationError::InvalidValue(err.to_string()))
        })
    }
}
/// Mutations to a [`NullifierTree`], computed by [`NullifierTree::compute_mutations`] and applied
/// with [`NullifierTree::apply_mutations`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NullifierMutationSet {
    // Underlying sparse Merkle tree mutation set.
    mutation_set: MutationSet<SMT_DEPTH, Word, Word>,
}
impl NullifierMutationSet {
fn new(mutation_set: MutationSet<SMT_DEPTH, Word, Word>) -> Self {
Self { mutation_set }
}
pub fn as_mutation_set(&self) -> &MutationSet<SMT_DEPTH, Word, Word> {
&self.mutation_set
}
pub fn into_mutation_set(self) -> MutationSet<SMT_DEPTH, Word, Word> {
self.mutation_set
}
}
/// The value stored in a nullifier tree leaf: the number of the block in which the nullifier was
/// spent.
///
/// It is encoded as the word `[block_num, 0, 0, 0]`. The genesis block number doubles as the
/// [`NullifierBlock::UNSPENT`] sentinel.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub struct NullifierBlock(BlockNumber);
impl NullifierBlock {
    /// Leaf value of a nullifier that has not been spent, represented by the genesis block
    /// number.
    pub const UNSPENT: NullifierBlock = NullifierBlock(BlockNumber::GENESIS);

    /// Decodes a leaf value from `word`, which must have the form `[block_num, 0, 0, 0]`.
    ///
    /// # Errors
    ///
    /// Returns [`NullifierTreeError::InvalidNullifierBlockNumber`] if the first element does not
    /// fit in a `u32`, or if any of the remaining three elements is non-zero.
    pub fn new(word: Word) -> Result<Self, NullifierTreeError> {
        let invalid = || NullifierTreeError::InvalidNullifierBlockNumber(word);

        let raw_block_num = u32::try_from(word[0].as_canonical_u64()).map_err(|_| invalid())?;

        // The upper three elements are padding and must all be zero.
        if !word[1..4].iter().all(|felt| *felt == Felt::ZERO) {
            return Err(invalid());
        }

        Ok(NullifierBlock(BlockNumber::from(raw_block_num)))
    }

    /// Returns `true` if this value records a spend, i.e. it differs from [`Self::UNSPENT`].
    pub fn is_spent(&self) -> bool {
        *self != Self::UNSPENT
    }

    /// Returns `true` if this value equals [`Self::UNSPENT`].
    pub fn is_unspent(&self) -> bool {
        *self == Self::UNSPENT
    }
}
impl From<BlockNumber> for NullifierBlock {
fn from(block_num: BlockNumber) -> Self {
Self(block_num)
}
}
impl From<NullifierBlock> for BlockNumber {
fn from(value: NullifierBlock) -> BlockNumber {
value.0
}
}
impl From<NullifierBlock> for Word {
    /// Encodes the leaf value as `[block_num, 0, 0, 0]`.
    fn from(value: NullifierBlock) -> Word {
        let mut elements = [Felt::ZERO; 4];
        elements[0] = Felt::from(value.0);
        Word::from(elements)
    }
}
impl TryFrom<Word> for NullifierBlock {
type Error = NullifierTreeError;
fn try_from(value: Word) -> Result<Self, Self::Error> {
Self::new(value)
}
}
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::NullifierTree;
    use crate::Word;
    use crate::block::BlockNumber;
    use crate::block::nullifier_tree::NullifierBlock;
    use crate::errors::NullifierTreeError;
    use crate::note::Nullifier;

    /// Asserts that `err` reports `expected` as an already-spent nullifier.
    #[track_caller]
    fn assert_already_spent(err: NullifierTreeError, expected: Nullifier) {
        assert_matches!(
            err,
            NullifierTreeError::NullifierAlreadySpent(nullifier) if nullifier == expected
        );
    }

    /// Encodes a `(nullifier, block)` pair as the raw key-value pair stored by an SMT backend.
    #[cfg(feature = "std")]
    fn leaf(nullifier: Nullifier, block_num: BlockNumber) -> (Word, Word) {
        (nullifier.as_word(), NullifierBlock::from(block_num).into())
    }

    /// Round-trips the largest representable block number through `NullifierBlock`.
    #[test]
    fn leaf_value_encode_decode() {
        let original = BlockNumber::from(u32::MAX);
        let recovered: BlockNumber = NullifierBlock::from(original).into();
        assert_eq!(original, recovered);
    }

    /// Encoding a block number matches decoding the word `[block_num, 0, 0, 0]`.
    #[test]
    fn leaf_value_encoding() {
        let block_num = BlockNumber::from(123);
        let decoded = NullifierBlock::new(Word::from([block_num.as_u32(), 0, 0, 0u32])).unwrap();
        assert_eq!(NullifierBlock::from(block_num), decoded);
    }

    /// Decoding `[block_num, 0, 0, 0]` yields the original block number.
    #[test]
    fn leaf_value_decoding() {
        let raw = 123u32;
        let decoded: BlockNumber = NullifierBlock::new(Word::from([raw, 0, 0, 0])).unwrap().into();
        assert_eq!(decoded, BlockNumber::from(raw));
    }

    /// Applying a computed batch adds every new nullifier. A nullifier that appears twice in the
    /// batch ends up with the block number of its last occurrence.
    #[test]
    fn apply_mutations() {
        let [nullifier1, nullifier2, nullifier3] = [1, 2, 3].map(Nullifier::dummy);
        let [block1, block2, block3] = [1u32, 2, 3].map(BlockNumber::from);

        let mut tree = NullifierTree::with_entries([(nullifier1, block1)]).unwrap();
        let mutations = tree
            .compute_mutations([(nullifier2, block1), (nullifier3, block3), (nullifier2, block2)])
            .unwrap();
        tree.apply_mutations(mutations).unwrap();

        assert_eq!(tree.num_nullifiers(), 3);
        for (nullifier, expected) in
            [(nullifier1, block1), (nullifier2, block2), (nullifier3, block3)]
        {
            assert_eq!(tree.get_block_num(&nullifier), Some(expected));
        }
    }

    /// Re-spending a nullifier is rejected by both the batch and the single-insert APIs, whether
    /// or not the block number matches the recorded one.
    #[test]
    fn nullifier_already_spent() {
        let nullifier1 = Nullifier::dummy(1);
        let [block1, block2] = [1u32, 2].map(BlockNumber::from);
        let mut tree = NullifierTree::with_entries([(nullifier1, block1)]).unwrap();

        for block in [block1, block2] {
            let err = tree.compute_mutations([(nullifier1, block)]).unwrap_err();
            assert_already_spent(err, nullifier1);
        }

        let err = tree.clone().mark_spent(nullifier1, block1).unwrap_err();
        assert_already_spent(err, nullifier1);

        let err = tree.mark_spent(nullifier1, block2).unwrap_err();
        assert_already_spent(err, nullifier1);
    }

    /// Lookups, openings and single inserts work on a `LargeSmt` backend.
    #[cfg(feature = "std")]
    #[test]
    fn large_smt_backend_basic_operations() {
        use miden_crypto::merkle::smt::{LargeSmt, MemoryStorage};

        let [nullifier1, nullifier2, nullifier3] = [1, 2, 3].map(Nullifier::dummy);
        let [block1, block2, block3] = [1u32, 2, 3].map(BlockNumber::from);

        let backend = LargeSmt::with_entries(
            MemoryStorage::default(),
            [leaf(nullifier1, block1), leaf(nullifier2, block2)],
        )
        .unwrap();
        let mut tree = NullifierTree::new_unchecked(backend);

        assert_eq!(tree.num_nullifiers(), 2);
        assert_eq!(tree.get_block_num(&nullifier1), Some(block1));
        assert_eq!(tree.get_block_num(&nullifier2), Some(block2));
        let _witness1 = tree.open(&nullifier1);

        tree.mark_spent(nullifier3, block3).unwrap();
        assert_eq!(tree.num_nullifiers(), 3);
        assert_eq!(tree.get_block_num(&nullifier3), Some(block3));
    }

    /// A `LargeSmt` backend also rejects re-spending a nullifier.
    #[cfg(feature = "std")]
    #[test]
    fn large_smt_backend_nullifier_already_spent() {
        use miden_crypto::merkle::smt::{LargeSmt, MemoryStorage};

        let nullifier1 = Nullifier::dummy(1);
        let [block1, block2] = [1u32, 2].map(BlockNumber::from);

        let mut tree =
            LargeSmt::with_entries(MemoryStorage::default(), [leaf(nullifier1, block1)])
                .map(NullifierTree::new_unchecked)
                .unwrap();

        assert_eq!(tree.get_block_num(&nullifier1), Some(block1));
        let err = tree.mark_spent(nullifier1, block2).unwrap_err();
        assert_already_spent(err, nullifier1);
    }

    /// Batch mutations apply correctly on a `LargeSmt` backend.
    #[cfg(feature = "std")]
    #[test]
    fn large_smt_backend_apply_mutations() {
        use miden_crypto::merkle::smt::{LargeSmt, MemoryStorage};

        let [nullifier1, nullifier2, nullifier3] = [1, 2, 3].map(Nullifier::dummy);
        let [block1, block2, block3] = [1u32, 2, 3].map(BlockNumber::from);

        let mut tree =
            LargeSmt::with_entries(MemoryStorage::default(), [leaf(nullifier1, block1)])
                .map(NullifierTree::new_unchecked)
                .unwrap();

        let mutations =
            tree.compute_mutations([(nullifier2, block2), (nullifier3, block3)]).unwrap();
        tree.apply_mutations(mutations).unwrap();

        assert_eq!(tree.num_nullifiers(), 3);
        for (nullifier, expected) in
            [(nullifier1, block1), (nullifier2, block2), (nullifier3, block3)]
        {
            assert_eq!(tree.get_block_num(&nullifier), Some(expected));
        }
    }

    /// A `LargeSmt`-backed tree and a regular `Smt`-backed tree built from the same entries agree
    /// on root and contents.
    #[cfg(feature = "std")]
    #[test]
    fn large_smt_backend_same_root_as_regular_smt() {
        use miden_crypto::merkle::smt::{LargeSmt, MemoryStorage};

        let [nullifier1, nullifier2] = [1, 2].map(Nullifier::dummy);
        let [block1, block2] = [1u32, 2].map(BlockNumber::from);
        let entries = [(nullifier1, block1), (nullifier2, block2)];

        let large_tree = LargeSmt::with_entries(
            MemoryStorage::default(),
            entries.map(|(nullifier, block_num)| leaf(nullifier, block_num)),
        )
        .map(NullifierTree::new_unchecked)
        .unwrap();
        let regular_tree = NullifierTree::with_entries(entries).unwrap();

        assert_eq!(large_tree.root(), regular_tree.root());

        let large_entries: std::collections::BTreeMap<_, _> = large_tree.entries().collect();
        let regular_entries: std::collections::BTreeMap<_, _> = regular_tree.entries().collect();
        assert_eq!(large_entries, regular_entries);
    }
}