use std::{
default::Default,
fmt,
hash::{Hash, Hasher},
io,
};
use bitvec::prelude::*;
use halo2::pasta::{group::ff::PrimeField, pallas};
use hex::ToHex;
use incrementalmerkletree::{frontier::NonEmptyFrontier, Hashable};
use lazy_static::lazy_static;
use thiserror::Error;
use zcash_primitives::merkle_tree::HashSer;
use super::sinsemilla::*;
use crate::{
serialization::{
serde_helpers, ReadZcashExt, SerializationError, ZcashDeserialize, ZcashSerialize,
},
subtree::{NoteCommitmentSubtreeIndex, TRACKED_SUBTREE_HEIGHT},
};
pub mod legacy;
use legacy::LegacyNoteCommitmentTree;
/// The value appended to a [`NoteCommitmentTree`] for each note: a Pallas base field element.
///
/// NOTE(review): presumably the note commitment's x-coordinate, going by the `cm_x` parameter
/// name in `NoteCommitmentTree::append` — confirm against callers.
pub type NoteCommitmentUpdate = pallas::Base;
/// Depth of the Orchard note commitment tree: the number of hashing layers above the leaves.
pub(super) const MERKLE_DEPTH: u8 = 32;
/// Hashes two child nodes into their parent node (`MerkleCRH^Orchard`).
///
/// `layer` is the layer of the resulting parent node, counted from the root (layer 0). The
/// Sinsemilla message is the 10-bit little-endian encoding of `MERKLE_DEPTH - 1 - layer`,
/// followed by the low 255 bits of each child's canonical representation.
///
/// If Sinsemilla returns no result, the output is zero, matching the protocol's mapping of
/// the failure case to 0.
///
/// # Panics
///
/// Panics on `u8` underflow if `layer >= MERKLE_DEPTH`.
fn merkle_crh_orchard(layer: u8, left: pallas::Base, right: pallas::Base) -> pallas::Base {
    let level_prefix: u8 = MERKLE_DEPTH - 1 - layer;

    let prefix_bits = BitArray::<_, Lsb0>::from([level_prefix, 0]);
    let left_bits = BitArray::<_, Lsb0>::from(left.to_repr());
    let right_bits = BitArray::<_, Lsb0>::from(right.to_repr());

    let mut message = bitvec![u8, Lsb0;];
    message.extend_from_bitslice(&prefix_bits[..10]);
    message.extend_from_bitslice(&left_bits[..255]);
    message.extend_from_bitslice(&right_bits[..255]);

    sinsemilla_hash(b"z.cash:Orchard-MerkleCRH", &message).unwrap_or_else(pallas::Base::zero)
}
lazy_static! {
    /// Roots of empty subtrees, indexed by layer: index `MERKLE_DEPTH` is the empty leaf
    /// (`NoteCommitmentTree::uncommitted()`), and index 0 is the root of an entirely empty tree.
    pub(super) static ref EMPTY_ROOTS: Vec<pallas::Base> = {
        let mut roots_leaf_first = Vec::with_capacity(usize::from(MERKLE_DEPTH) + 1);

        // Hash upwards from the empty leaf: each parent combines two copies of the empty
        // node one layer below it.
        let mut node = NoteCommitmentTree::uncommitted();
        roots_leaf_first.push(node);
        for layer in (0..MERKLE_DEPTH).rev() {
            node = merkle_crh_orchard(layer, node, node);
            roots_leaf_first.push(node);
        }

        // Flip so that the vector index equals the layer number (root first).
        roots_leaf_first.reverse();
        roots_leaf_first
    };
}
/// The root of an Orchard note commitment tree, wrapping a Pallas base field element.
///
/// `PartialEq` and `Hash` are implemented by hand rather than derived. Serde goes through
/// `serde_helpers::Base` for the inner field.
#[derive(Clone, Copy, Default, Eq, Serialize, Deserialize)]
pub struct Root(#[serde(with = "serde_helpers::Base")] pub(crate) pallas::Base);
impl Root {
    /// Returns the root's bytes in the order used for display.
    ///
    /// This is the same byte order produced by converting the root into `[u8; 32]`.
    pub fn bytes_in_display_order(&self) -> [u8; 32] {
        <[u8; 32]>::from(self)
    }
}
impl fmt::Debug for Root {
    /// Formats as `Root("<hex of canonical repr>")`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let hex_repr = hex::encode(self.0.to_repr());
        f.debug_tuple("Root").field(&hex_repr).finish()
    }
}
impl From<Root> for [u8; 32] {
fn from(root: Root) -> Self {
root.0.into()
}
}
impl From<&Root> for [u8; 32] {
    /// Borrowing version of the `Root` → `[u8; 32]` conversion (`Root` is `Copy`).
    fn from(root: &Root) -> Self {
        let owned: Root = *root;
        Self::from(owned)
    }
}
impl Hash for Root {
    /// Hashes the canonical byte representation of the inner field element.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let canonical_bytes: [u8; 32] = self.0.to_repr();
        canonical_bytes.hash(state);
    }
}
impl PartialEq for Root {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl TryFrom<[u8; 32]> for Root {
    type Error = SerializationError;

    /// Parses a root from its canonical 32-byte encoding.
    ///
    /// # Errors
    ///
    /// Returns [`SerializationError::Parse`] if `bytes` is not a canonical encoding of a
    /// Pallas base field element.
    fn try_from(bytes: [u8; 32]) -> Result<Self, Self::Error> {
        // Convert the constant-time `CtOption` into a regular `Option` instead of pairing
        // `is_some()` with `unwrap()` (same approach as `Node`'s `TryFrom<[u8; 32]>`).
        Option::<pallas::Base>::from(pallas::Base::from_repr(bytes))
            .map(Self)
            .ok_or(SerializationError::Parse(
                "Invalid pallas::Base value for Orchard note commitment tree root",
            ))
    }
}
impl ZcashSerialize for Root {
    /// Writes the root as its 32-byte representation.
    fn zcash_serialize<W: io::Write>(&self, mut writer: W) -> Result<(), io::Error> {
        let bytes: [u8; 32] = self.into();
        writer.write_all(&bytes)
    }
}
impl ZcashDeserialize for Root {
fn zcash_deserialize<R: io::Read>(mut reader: R) -> Result<Self, SerializationError> {
Self::try_from(reader.read_32_bytes()?)
}
}
/// A node of the Orchard note commitment tree, holding a Pallas base field element.
///
/// Leaves are created from appended commitments (`NoteCommitmentTree::append` converts each
/// `NoteCommitmentUpdate` into a `Node`). Parent nodes are produced by `merkle_crh_orchard`
/// via the `Hashable` impl.
#[derive(Copy, Clone, Eq, PartialEq, Default)]
pub struct Node(pallas::Base);
impl Node {
pub fn to_repr(&self) -> [u8; 32] {
self.0.to_repr()
}
pub fn bytes_in_display_order(&self) -> [u8; 32] {
self.to_repr()
}
}
impl TryFrom<&[u8]> for Node {
    type Error = &'static str;

    /// Parses a node from a byte slice, which must be exactly 32 bytes long and hold a
    /// canonical field element encoding.
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        let array: [u8; 32] = match bytes.try_into() {
            Ok(array) => array,
            Err(_) => return Err("wrong byte slice len"),
        };
        Self::try_from(array)
    }
}
impl TryFrom<[u8; 32]> for Node {
type Error = &'static str;
fn try_from(bytes: [u8; 32]) -> Result<Self, Self::Error> {
Option::<pallas::Base>::from(pallas::Base::from_repr(bytes))
.map(Node)
.ok_or("invalid Pallas field element")
}
}
impl fmt::Display for Node {
    /// Writes the node as lowercase hex of its display-order bytes.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let hex: String = self.encode_hex();
        f.write_str(&hex)
    }
}
impl fmt::Debug for Node {
    /// Formats as `orchard::Node("<hex>")`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let hex: String = self.encode_hex();
        f.debug_tuple("orchard::Node").field(&hex).finish()
    }
}
impl ToHex for &Node {
    /// Lowercase hex of the node's display-order bytes.
    fn encode_hex<T: FromIterator<char>>(&self) -> T {
        let display_bytes = self.bytes_in_display_order();
        display_bytes.encode_hex()
    }

    /// Uppercase hex of the node's display-order bytes.
    fn encode_hex_upper<T: FromIterator<char>>(&self) -> T {
        let display_bytes = self.bytes_in_display_order();
        display_bytes.encode_hex_upper()
    }
}
impl ToHex for Node {
    /// Delegates to the `&Node` implementation.
    fn encode_hex<T: FromIterator<char>>(&self) -> T {
        <&Node as ToHex>::encode_hex::<T>(&self)
    }

    /// Delegates to the `&Node` implementation.
    fn encode_hex_upper<T: FromIterator<char>>(&self) -> T {
        <&Node as ToHex>::encode_hex_upper::<T>(&self)
    }
}
impl HashSer for Node {
fn read<R: io::Read>(mut reader: R) -> io::Result<Self> {
let mut repr = [0u8; 32];
reader.read_exact(&mut repr)?;
let maybe_node = pallas::Base::from_repr(repr).map(Self);
<Option<_>>::from(maybe_node).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Non-canonical encoding of Pallas base field value.",
)
})
}
fn write<W: io::Write>(&self, mut writer: W) -> io::Result<()> {
writer.write_all(&self.0.to_repr())
}
}
impl Hashable for Node {
fn empty_leaf() -> Self {
Self(NoteCommitmentTree::uncommitted())
}
fn combine(level: incrementalmerkletree::Level, a: &Self, b: &Self) -> Self {
let layer = MERKLE_DEPTH - 1 - u8::from(level);
Self(merkle_crh_orchard(layer, a.0, b.0))
}
fn empty_root(level: incrementalmerkletree::Level) -> Self {
let layer_below = usize::from(MERKLE_DEPTH) - usize::from(level);
Self(EMPTY_ROOTS[layer_below])
}
}
impl From<pallas::Base> for Node {
fn from(x: pallas::Base) -> Self {
Node(x)
}
}
impl serde::Serialize for Node {
    /// Serializes the node as its canonical 32-byte array.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let canonical: [u8; 32] = self.0.to_repr();
        serde::Serialize::serialize(&canonical, serializer)
    }
}
impl<'de> serde::Deserialize<'de> for Node {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let bytes = <[u8; 32]>::deserialize(deserializer)?;
Option::<pallas::Base>::from(pallas::Base::from_repr(bytes))
.map(Node)
.ok_or_else(|| serde::de::Error::custom("invalid Pallas field element"))
}
}
/// Errors returned when updating a [`NoteCommitmentTree`].
#[derive(Error, Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[allow(missing_docs)]
pub enum NoteCommitmentTreeError {
    /// The tree has no room for another leaf, so the commitment could not be appended.
    #[error("The note commitment tree is full")]
    FullTree,
}
/// An Orchard note commitment tree, stored as its frontier plus a cached root.
///
/// Serde conversion goes through [`LegacyNoteCommitmentTree`] in both directions.
/// NOTE(review): presumably to keep a previously-used serialized format stable — confirm.
#[derive(Debug, Serialize, Deserialize)]
#[serde(into = "LegacyNoteCommitmentTree")]
#[serde(from = "LegacyNoteCommitmentTree")]
pub struct NoteCommitmentTree {
    // The tree's frontier (its rightmost path): the only tree data this struct keeps, and the
    // source for appends, positions, and root recalculation.
    inner: incrementalmerkletree::frontier::Frontier<Node, MERKLE_DEPTH>,
    // Memoized result of `recalculate_root`; reset to `None` whenever `append` succeeds.
    cached_root: std::sync::RwLock<Option<Root>>,
}
impl NoteCommitmentTree {
    /// Appends a note commitment as the next leaf of the tree.
    ///
    /// On success the cached root is invalidated, so the next [`Self::root`] call recomputes it.
    ///
    /// # Errors
    ///
    /// Returns [`NoteCommitmentTreeError::FullTree`] if the frontier rejects the append
    /// because the tree has no room left.
    ///
    /// # Panics
    ///
    /// Panics if the `cached_root` lock is poisoned.
    // `expect` inside a `Result`-returning fn: a poisoned lock is a bug, not a caller error.
    #[allow(clippy::unwrap_in_result)]
    pub fn append(&mut self, cm_x: NoteCommitmentUpdate) -> Result<(), NoteCommitmentTreeError> {
        if self.inner.append(cm_x.into()) {
            // We hold `&mut self`, so `get_mut` needs no locking; it only fails if poisoned.
            let cached_root = self
                .cached_root
                .get_mut()
                .expect("a thread that previously held exclusive lock access panicked");
            *cached_root = None;
            Ok(())
        } else {
            Err(NoteCommitmentTreeError::FullTree)
        }
    }

    /// Returns the non-empty frontier, or `None` if no leaves have been appended.
    fn frontier(&self) -> Option<&NonEmptyFrontier<Node>> {
        self.inner.value()
    }

    /// Returns the position (0-based leaf index) of the most recently appended leaf, or `None`
    /// if the tree is empty.
    pub fn position(&self) -> Option<u64> {
        self.frontier().map(|tree| tree.position().into())
    }

    /// Returns true if this tree contains at least one subtree (at `TRACKED_SUBTREE_HEIGHT`)
    /// that was completed after `prev_tree`.
    pub fn contains_new_subtree(&self, prev_tree: &Self) -> bool {
        // Use -1 for the index of an empty tree, so the comparisons below stay valid.
        let index = self.subtree_index().map_or(-1, |index| i32::from(index.0));
        let prev_index = prev_tree
            .subtree_index()
            .map_or(-1, |index| i32::from(index.0));

        // Can't overflow: both operands come from u16 values widened to i32.
        let index_difference = index - prev_index;

        // A strictly lower index can't contain anything new.
        if index < prev_index {
            return false;
        }

        // Skipping past at least one whole subtree index means one was completed in between.
        if index_difference > 1 {
            return true;
        }

        // Same index: only new if `self` completed the subtree and `prev_tree` had not already.
        // Within one subtree index there is exactly one completing position, so equal positions
        // mean `prev_tree` already held the completed subtree (e.g. the two trees are identical).
        if index == prev_index {
            return self.is_complete_subtree() && self.position() != prev_tree.position();
        }

        // From here on, `self` is exactly one index ahead of `prev_tree`.

        // If `self` just completed its own subtree, that subtree is new regardless of `prev_tree`.
        if self.is_complete_subtree() {
            return true;
        }

        // Spurious index differences: the index only advances when the first note of the next
        // subtree is added. If `prev_tree` itself ended on a completed subtree, that subtree is
        // not new; and an empty `prev_tree` (index -1) moving to index 0 completed nothing.
        if prev_tree.is_complete_subtree() || prev_index == -1 {
            return false;
        }

        // The subtree `prev_tree` was in got completed by a note that `prev_tree` doesn't have.
        true
    }

    /// Returns true if the most recently appended leaf is the last leaf of a subtree of height
    /// `TRACKED_SUBTREE_HEIGHT`. An empty tree returns false.
    pub fn is_complete_subtree(&self) -> bool {
        let Some(tree) = self.frontier() else {
            return false;
        };
        tree.position()
            .is_complete_subtree(TRACKED_SUBTREE_HEIGHT.into())
    }

    /// Returns the index of the subtree (at `TRACKED_SUBTREE_HEIGHT`) that contains the most
    /// recently appended leaf, or `None` if the tree is empty.
    ///
    /// # Panics
    ///
    /// Panics if the index does not convert into a [`NoteCommitmentSubtreeIndex`].
    #[allow(clippy::unwrap_in_result)]
    pub fn subtree_index(&self) -> Option<NoteCommitmentSubtreeIndex> {
        let tree = self.frontier()?;
        let index = incrementalmerkletree::Address::above_position(
            TRACKED_SUBTREE_HEIGHT.into(),
            tree.position(),
        )
        .index()
        .try_into()
        .expect("fits in u16");
        Some(index)
    }

    /// Returns how many more leaves must be appended to complete the current subtree at
    /// `TRACKED_SUBTREE_HEIGHT`.
    ///
    /// Returns 0 when the last leaf completed a subtree. For an empty tree, returns the number
    /// of leaves in a whole subtree.
    ///
    /// # Panics
    ///
    /// Panics if the count does not fit in `usize`.
    #[allow(clippy::unwrap_in_result)]
    pub fn remaining_subtree_leaf_nodes(&self) -> usize {
        let remaining = match self.frontier() {
            Some(tree) => {
                let max_position = incrementalmerkletree::Address::above_position(
                    TRACKED_SUBTREE_HEIGHT.into(),
                    tree.position(),
                )
                .max_position();
                max_position - tree.position().into()
            }
            None => {
                // Position 0 always lies in the first subtree, whose range starts at 0.
                let subtree_address = incrementalmerkletree::Address::above_position(
                    TRACKED_SUBTREE_HEIGHT.into(),
                    0.into(),
                );
                assert_eq!(
                    subtree_address.position_range_start(),
                    0.into(),
                    "address is not in the first subtree"
                );
                subtree_address.position_range_end()
            }
        };
        u64::from(remaining).try_into().expect("fits in usize")
    }

    /// If the last appended leaf completed a subtree, returns that subtree's index and root.
    pub fn completed_subtree_index_and_root(&self) -> Option<(NoteCommitmentSubtreeIndex, Node)> {
        if !self.is_complete_subtree() {
            return None;
        }
        let index = self.subtree_index()?;
        let root = self.frontier()?.root(Some(TRACKED_SUBTREE_HEIGHT.into()));
        Some((index, root))
    }

    /// Returns the tree root, computing and caching it on first use after an append.
    ///
    /// # Panics
    ///
    /// Panics if the `cached_root` lock is poisoned.
    pub fn root(&self) -> Root {
        if let Some(root) = self.cached_root() {
            return root;
        }

        let mut write_root = self
            .cached_root
            .write()
            .expect("a thread that previously held exclusive lock access panicked");

        // Another thread may have filled the cache between the read above and this write lock.
        if let Some(root) = *write_root {
            return root;
        }

        let root = self.recalculate_root();
        *write_root = Some(root);
        root
    }

    /// Returns the cached root, if one has been computed since the last append.
    ///
    /// # Panics
    ///
    /// Panics if the `cached_root` lock is poisoned.
    #[allow(clippy::unwrap_in_result)]
    pub fn cached_root(&self) -> Option<Root> {
        *self
            .cached_root
            .read()
            .expect("a thread that previously held exclusive lock access panicked")
    }

    /// Recomputes the root from the frontier, ignoring (and not updating) the cache.
    pub fn recalculate_root(&self) -> Root {
        Root(self.inner.root().0)
    }

    /// Returns the tree root as 32 bytes.
    pub fn hash(&self) -> [u8; 32] {
        self.root().into()
    }

    /// The value of an empty leaf: the field element 2 (`Uncommitted^Orchard`).
    pub fn uncommitted() -> pallas::Base {
        pallas::Base::one().double()
    }

    /// Returns the number of leaves appended so far.
    pub fn count(&self) -> u64 {
        self.inner
            .value()
            .map_or(0, |x| u64::from(x.position()) + 1)
    }

    /// Test helper: asserts that both trees have equal cached roots, frontiers, and RPC bytes.
    #[cfg(any(test, feature = "proptest-impl"))]
    pub fn assert_frontier_eq(&self, other: &Self) {
        assert_eq!(self.cached_root(), other.cached_root());
        assert_eq!(self.inner, other.inner);
        assert_eq!(self.to_rpc_bytes(), other.to_rpc_bytes());
    }

    /// Serializes the tree as a `CommitmentTree` via
    /// `zcash_primitives::merkle_tree::write_commitment_tree`.
    ///
    /// # Panics
    ///
    /// Panics if writing to the in-memory buffer fails.
    pub fn to_rpc_bytes(&self) -> Vec<u8> {
        let tree = incrementalmerkletree::frontier::CommitmentTree::from_frontier(&self.inner);
        let mut rpc_bytes = vec![];
        zcash_primitives::merkle_tree::write_commitment_tree(&tree, &mut rpc_bytes)
            .expect("serializable tree");
        rpc_bytes
    }
}
impl Clone for NoteCommitmentTree {
    /// Clones the frontier and carries over the current cached root into a fresh lock.
    fn clone(&self) -> Self {
        Self {
            cached_root: std::sync::RwLock::new(self.cached_root()),
            inner: self.inner.clone(),
        }
    }
}
impl Default for NoteCommitmentTree {
    /// An empty tree with no cached root.
    fn default() -> Self {
        let inner = incrementalmerkletree::frontier::Frontier::empty();
        let cached_root = std::sync::RwLock::new(None);
        Self { inner, cached_root }
    }
}
impl Eq for NoteCommitmentTree {}

impl PartialEq for NoteCommitmentTree {
    /// Compares cached roots when both trees have one; otherwise compares the frontiers.
    ///
    /// Comparing frontiers avoids triggering root recalculation on either side.
    fn eq(&self, other: &Self) -> bool {
        match (self.cached_root(), other.cached_root()) {
            (Some(lhs_root), Some(rhs_root)) => lhs_root == rhs_root,
            _ => self.inner == other.inner,
        }
    }
}
impl From<Vec<pallas::Base>> for NoteCommitmentTree {
    /// Builds a tree by appending each note commitment in `values`, in order.
    ///
    /// `From` cannot fail. If the tree fills up, the remaining commitments are dropped
    /// rather than causing a panic. A full tree rejects every further append, so appending
    /// stops at the first failure instead of retrying each remaining value.
    fn from(values: Vec<pallas::Base>) -> Self {
        let mut tree = Self::default();
        for cm_x in values {
            if tree.append(cm_x).is_err() {
                // NOTE(review): silently truncating matches the previous behaviour; a
                // `TryFrom` impl would let callers detect an overfull input.
                break;
            }
        }
        tree
    }
}