use alloc::string::ToString;
use alloc::vec::Vec;
use miden_core::EMPTY_WORD;
use miden_core::utils::{ByteReader, ByteWriter, Deserializable, Serializable};
use miden_processor::DeserializationError;
use crate::Word;
use crate::block::{BlockNumber, NullifierWitness};
use crate::crypto::merkle::{MutationSet, SMT_DEPTH, Smt};
use crate::errors::NullifierTreeError;
use crate::note::Nullifier;
/// A sparse Merkle tree recording spent nullifiers.
///
/// Each nullifier's word is used as a leaf key. The leaf value encodes the block number in which
/// the nullifier was spent (see `block_num_to_leaf_value`). A nullifier whose leaf holds
/// [`NullifierTree::UNSPENT_NULLIFIER`] (the empty word) is treated as unspent.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NullifierTree {
    /// Underlying SMT, keyed by `Nullifier::as_word()`.
    smt: Smt,
}
impl NullifierTree {
pub const DEPTH: u8 = SMT_DEPTH;
pub const UNSPENT_NULLIFIER: Word = EMPTY_WORD;
pub fn new() -> Self {
Self { smt: Smt::new() }
}
pub fn with_entries(
entries: impl IntoIterator<Item = (Nullifier, BlockNumber)>,
) -> Result<Self, NullifierTreeError> {
let leaves = entries.into_iter().map(|(nullifier, block_num)| {
(nullifier.as_word(), Self::block_num_to_leaf_value(block_num))
});
let smt = Smt::with_entries(leaves)
.map_err(NullifierTreeError::DuplicateNullifierBlockNumbers)?;
Ok(Self { smt })
}
pub fn root(&self) -> Word {
self.smt.root()
}
pub fn num_nullifiers(&self) -> usize {
self.smt.num_entries()
}
pub fn entries(&self) -> impl Iterator<Item = (Nullifier, BlockNumber)> {
self.smt.entries().map(|(nullifier, block_num)| {
(Nullifier::from(*nullifier), Self::leaf_value_to_block_num(*block_num))
})
}
pub fn open(&self, nullifier: &Nullifier) -> NullifierWitness {
NullifierWitness::new(self.smt.open(&nullifier.as_word()))
}
pub fn get_block_num(&self, nullifier: &Nullifier) -> Option<BlockNumber> {
let value = self.smt.get_value(&nullifier.as_word());
if value == Self::UNSPENT_NULLIFIER {
return None;
}
Some(Self::leaf_value_to_block_num(value))
}
pub fn compute_mutations<I>(
&self,
nullifiers: impl IntoIterator<Item = (Nullifier, BlockNumber), IntoIter = I>,
) -> Result<NullifierMutationSet, NullifierTreeError>
where
I: Iterator<Item = (Nullifier, BlockNumber)> + Clone,
{
let nullifiers = nullifiers.into_iter();
for (nullifier, _) in nullifiers.clone() {
if self.get_block_num(&nullifier).is_some() {
return Err(NullifierTreeError::NullifierAlreadySpent(nullifier));
}
}
let mutation_set = self
.smt
.compute_mutations(nullifiers.into_iter().map(|(nullifier, block_num)| {
(nullifier.as_word(), Self::block_num_to_leaf_value(block_num))
}))
.map_err(NullifierTreeError::ComputeMutations)?;
Ok(NullifierMutationSet::new(mutation_set))
}
pub fn mark_spent(
&mut self,
nullifier: Nullifier,
block_num: BlockNumber,
) -> Result<(), NullifierTreeError> {
let prev_nullifier_value = self
.smt
.insert(nullifier.as_word(), Self::block_num_to_leaf_value(block_num))
.map_err(NullifierTreeError::MaxLeafEntriesExceeded)?;
if prev_nullifier_value != Self::UNSPENT_NULLIFIER {
Err(NullifierTreeError::NullifierAlreadySpent(nullifier))
} else {
Ok(())
}
}
pub fn apply_mutations(
&mut self,
mutations: NullifierMutationSet,
) -> Result<(), NullifierTreeError> {
self.smt
.apply_mutations(mutations.into_mutation_set())
.map_err(NullifierTreeError::TreeRootConflict)
}
pub(super) fn block_num_to_leaf_value(block: BlockNumber) -> Word {
Word::from([block.as_u32(), 0, 0, 0])
}
fn leaf_value_to_block_num(value: Word) -> BlockNumber {
let block_num: u32 =
value[0].as_int().try_into().expect("invalid block number found in store");
block_num.into()
}
}
impl Default for NullifierTree {
fn default() -> Self {
Self::new()
}
}
impl Serializable for NullifierTree {
fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.entries().collect::<Vec<_>>().write_into(target);
}
}
impl Deserializable for NullifierTree {
    /// Reads a vector of `(nullifier, block number)` entries and rebuilds the tree from them.
    ///
    /// A tree-construction failure is reported as [`DeserializationError::InvalidValue`]
    /// carrying the error's string form.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let decoded = Vec::<(Nullifier, BlockNumber)>::read_from(source)?;
        match Self::with_entries(decoded) {
            Ok(tree) => Ok(tree),
            Err(build_err) => Err(DeserializationError::InvalidValue(build_err.to_string())),
        }
    }
}
/// SMT mutations produced by [`NullifierTree::compute_mutations`], to be applied later with
/// [`NullifierTree::apply_mutations`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NullifierMutationSet {
    /// Raw mutation set over the nullifier tree's SMT (key and value are both `Word`s).
    mutation_set: MutationSet<{ NullifierTree::DEPTH }, Word, Word>,
}
impl NullifierMutationSet {
fn new(mutation_set: MutationSet<{ NullifierTree::DEPTH }, Word, Word>) -> Self {
Self { mutation_set }
}
pub fn as_mutation_set(&self) -> &MutationSet<{ NullifierTree::DEPTH }, Word, Word> {
&self.mutation_set
}
pub fn into_mutation_set(self) -> MutationSet<{ NullifierTree::DEPTH }, Word, Word> {
self.mutation_set
}
}
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;

    use super::NullifierTree;
    use crate::block::BlockNumber;
    use crate::note::Nullifier;
    use crate::{NullifierTreeError, Word};

    /// A block number is encoded in the first element of the leaf word; the rest are zero.
    #[test]
    fn leaf_value_encoding() {
        let raw_block: u32 = 123;
        let encoded = NullifierTree::block_num_to_leaf_value(BlockNumber::from(raw_block));
        assert_eq!(encoded, Word::from([raw_block, 0, 0, 0]));
    }

    /// Decoding `[n, 0, 0, 0]` yields block number `n`.
    #[test]
    fn leaf_value_decoding() {
        let raw_block: u32 = 123;
        let leaf = Word::from([raw_block, 0, 0, 0]);
        assert_eq!(NullifierTree::leaf_value_to_block_num(leaf), BlockNumber::from(raw_block));
    }

    #[test]
    fn apply_mutations() {
        let (nullifier_a, nullifier_b, nullifier_c) =
            (Nullifier::dummy(1), Nullifier::dummy(2), Nullifier::dummy(3));
        let (block_a, block_b, block_c) =
            (BlockNumber::from(1), BlockNumber::from(2), BlockNumber::from(3));

        let mut tree = NullifierTree::with_entries([(nullifier_a, block_a)]).unwrap();

        // `nullifier_b` appears twice in the batch; the later entry (`block_b`) is expected to win.
        let batch = [(nullifier_b, block_a), (nullifier_c, block_c), (nullifier_b, block_b)];
        let mutations = tree.compute_mutations(batch).unwrap();
        tree.apply_mutations(mutations).unwrap();

        assert_eq!(tree.num_nullifiers(), 3);
        let expected = [(nullifier_a, block_a), (nullifier_b, block_b), (nullifier_c, block_c)];
        for (nullifier, block_num) in expected {
            assert_eq!(tree.get_block_num(&nullifier).unwrap(), block_num);
        }
    }

    #[test]
    fn nullifier_already_spent() {
        let spent = Nullifier::dummy(1);
        let original_block = BlockNumber::from(1);
        let tree = NullifierTree::with_entries([(spent, original_block)]).unwrap();

        // Re-spending must be rejected whether the same or a different block number is supplied.
        for block_num in [original_block, BlockNumber::from(2)] {
            let err = tree.compute_mutations([(spent, block_num)]).unwrap_err();
            assert_matches!(err, NullifierTreeError::NullifierAlreadySpent(n) if n == spent);

            // `mark_spent` needs `&mut self`; run it on a clone so every iteration starts from
            // the same tree.
            let err = tree.clone().mark_spent(spent, block_num).unwrap_err();
            assert_matches!(err, NullifierTreeError::NullifierAlreadySpent(n) if n == spent);
        }
    }
}