use crate::error::TaprootError;
use crate::tagged_hash::{TapLeafHash, TapNodeHash};
/// Leaf version byte for tapscript leaves (`0xc0`).
pub const TAPSCRIPT_LEAF_VERSION: u8 = 0xc0;

/// A taproot leaf version byte.
///
/// Prefer [`LeafVersion::new`], which validates the byte. Note that the tuple field is
/// public, so writing `LeafVersion(x)` directly bypasses that validation.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
pub struct LeafVersion(pub u8);

impl LeafVersion {
    /// The tapscript leaf version ([`TAPSCRIPT_LEAF_VERSION`]).
    pub const TAPSCRIPT: Self = Self(TAPSCRIPT_LEAF_VERSION);

    /// Annex prefix byte; BIP-341 rules it out as a leaf version because it would be
    /// ambiguous with the annex.
    const ANNEX_PREFIX: u8 = 0x50;

    /// Validates `version` as a taproot leaf version.
    ///
    /// # Errors
    ///
    /// Returns `TaprootError::InvalidLeafVersion(version)` when `version` is odd
    /// (per BIP-341 the low bit of the control-block byte is reserved for the output-key
    /// parity) or when it equals the annex prefix `0x50`.
    pub fn new(version: u8) -> Result<Self, TaprootError> {
        if version & 0x01 != 0 || version == Self::ANNEX_PREFIX {
            return Err(TaprootError::InvalidLeafVersion(version));
        }
        Ok(Self(version))
    }

    /// Returns the raw version byte.
    pub const fn to_u8(self) -> u8 {
        self.0
    }
}

impl Default for LeafVersion {
    /// Defaults to [`LeafVersion::TAPSCRIPT`].
    fn default() -> Self {
        Self::TAPSCRIPT
    }
}
/// A tapscript leaf: raw script bytes together with the leaf version they are committed under.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct TapLeaf {
    /// Leaf version passed to the leaf hash alongside the script.
    pub version: LeafVersion,
    /// Raw script bytes.
    pub script: Vec<u8>,
}

impl TapLeaf {
    /// Creates a leaf for `script` under [`LeafVersion::TAPSCRIPT`].
    pub fn new(script: Vec<u8>) -> Self {
        Self::with_version(LeafVersion::TAPSCRIPT, script)
    }

    /// Creates a leaf for `script` under an explicit `version`; no validation is performed here.
    pub fn with_version(version: LeafVersion, script: Vec<u8>) -> Self {
        TapLeaf { script, version }
    }

    /// Computes the leaf hash from the version byte and the script bytes.
    pub fn hash(&self) -> TapLeafHash {
        let version_byte = self.version.to_u8();
        TapLeafHash::from_script(version_byte, &self.script)
    }
}
/// A node of a taproot script tree: either a single leaf or a branch owning two subtrees.
#[derive(Clone, Debug)]
pub enum TapNode {
    /// A script leaf.
    Leaf(TapLeaf),
    /// An internal node with a left and a right child.
    Branch(Box<TapNode>, Box<TapNode>),
}

impl TapNode {
    /// Computes this node's hash recursively.
    ///
    /// A leaf's hash is its [`TapLeaf::hash`] wrapped with `TapNodeHash::from_leaf`; a branch
    /// combines its children's hashes (left first, then right) with
    /// `TapNodeHash::from_children`. Whether the children are reordered before hashing is up
    /// to `from_children`, which is defined elsewhere.
    pub fn hash(&self) -> TapNodeHash {
        match self {
            Self::Branch(lhs, rhs) => {
                let left_hash = lhs.hash();
                let right_hash = rhs.hash();
                TapNodeHash::from_children(&left_hash, &right_hash)
            }
            Self::Leaf(leaf) => TapNodeHash::from_leaf(leaf.hash()),
        }
    }

    /// Returns `true` when this node is a [`TapNode::Leaf`].
    pub fn is_leaf(&self) -> bool {
        self.as_leaf().is_some()
    }

    /// Borrows the contained leaf, or returns `None` for a branch.
    pub fn as_leaf(&self) -> Option<&TapLeaf> {
        if let Self::Leaf(leaf) = self {
            Some(leaf)
        } else {
            None
        }
    }
}
/// A taproot script tree, owned through its root node.
#[derive(Clone, Debug)]
pub struct TapTree {
    root: TapNode,
}

impl TapTree {
    /// Wraps an existing node as the tree root; the node's shape is not validated.
    pub fn from_node(root: TapNode) -> Self {
        TapTree { root }
    }

    /// Builds a tree made of one leaf holding `script` under the default tapscript version.
    pub fn single_leaf(script: Vec<u8>) -> Self {
        let leaf = TapLeaf::new(script);
        Self::from_node(TapNode::Leaf(leaf))
    }

    /// Returns the hash of the root node (recomputed on every call).
    pub fn root_hash(&self) -> TapNodeHash {
        self.root().hash()
    }

    /// Borrows the root node.
    pub fn root(&self) -> &TapNode {
        &self.root
    }

    /// Returns the sibling node hashes on the path from `target_leaf` up to the root.
    ///
    /// The first element is the target leaf's immediate sibling and the last is the root's
    /// other child. Leaves are matched by node hash, and left subtrees are searched before
    /// right ones, so the left-most match wins if the same leaf appears more than once.
    /// Returns `None` when no matching leaf exists (and an empty vector when the root itself
    /// is the matching leaf).
    pub fn merkle_path(&self, target_leaf: &TapLeaf) -> Option<Vec<TapNodeHash>> {
        let wanted = TapNodeHash::from_leaf(target_leaf.hash());
        Self::find_path(&self.root, &wanted)
    }

    /// Recursive search for `target`; siblings are appended while unwinding, so the path
    /// comes out ordered from the leaf upward.
    fn find_path(node: &TapNode, target: &TapNodeHash) -> Option<Vec<TapNodeHash>> {
        match node {
            TapNode::Leaf(leaf) => (TapNodeHash::from_leaf(leaf.hash()) == *target).then(Vec::new),
            TapNode::Branch(left, right) => Self::find_path(left, target)
                .map(|mut path| {
                    path.push(right.hash());
                    path
                })
                .or_else(|| {
                    Self::find_path(right, target).map(|mut path| {
                        path.push(left.hash());
                        path
                    })
                }),
        }
    }

    /// Returns every leaf in left-to-right order.
    pub fn leaves(&self) -> Vec<&TapLeaf> {
        let mut found = Vec::new();
        // Explicit stack: the right child is pushed first so the left one is popped first.
        let mut stack: Vec<&TapNode> = vec![&self.root];
        while let Some(node) = stack.pop() {
            match node {
                TapNode::Leaf(leaf) => found.push(leaf),
                TapNode::Branch(left, right) => {
                    stack.push(right);
                    stack.push(left);
                }
            }
        }
        found
    }
}
/// Collects leaves together with the depth each should occupy and assembles them into a
/// [`TapTree`].
///
/// Depths are treated as a multiset rather than a depth-first listing. Leaves are grouped
/// by depth (deepest first), and within one depth they keep the order in which they were
/// added.
#[derive(Default)]
pub struct TapTreeBuilder {
    /// Each leaf paired with its intended depth (number of edges from the root).
    leaves: Vec<(TapLeaf, u8)>,
}

impl TapTreeBuilder {
    /// Deepest leaf depth allowed by BIP-341 (a control block carries at most 128 path hashes).
    const MAX_DEPTH: u8 = 128;

    /// Creates an empty builder.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a leaf with the default tapscript version that should sit at `depth`.
    pub fn add_leaf(mut self, depth: u8, script: Vec<u8>) -> Self {
        self.leaves.push((TapLeaf::new(script), depth));
        self
    }

    /// Adds a leaf with an explicit leaf `version` that should sit at `depth`.
    pub fn add_leaf_with_version(
        mut self,
        depth: u8,
        version: LeafVersion,
        script: Vec<u8>,
    ) -> Self {
        self.leaves.push((TapLeaf::with_version(version, script), depth));
        self
    }

    /// Assembles the collected leaves into a tree.
    ///
    /// The tree is built bottom-up, one depth at a time. At depth `d`, the branches formed
    /// from depth `d + 1` come first, followed by the leaves declared at depth `d` in
    /// insertion order. Consecutive pairs `(left, right)` are then joined into branches at
    /// depth `d - 1`. Exactly one node must remain at depth 0.
    ///
    /// A lone leaf becomes the root regardless of its declared depth (kept for backward
    /// compatibility), and it keeps its leaf version.
    ///
    /// # Errors
    ///
    /// - `TaprootError::EmptyTree` if no leaves were added.
    /// - `TaprootError::TreeError` if any depth exceeds 128, or if the depths cannot form a
    ///   complete binary tree. That happens when some depth ends up with an odd number of
    ///   nodes, or when more than one node remains at depth 0.
    pub fn build(self) -> Result<TapTree, TaprootError> {
        let mut leaves = self.leaves;
        if leaves.is_empty() {
            return Err(TaprootError::EmptyTree);
        }
        if leaves.len() == 1 {
            // Move the whole leaf (not just its script) so a custom version is preserved.
            let (leaf, _depth) = leaves.remove(0);
            return Ok(TapTree::from_node(TapNode::Leaf(leaf)));
        }
        if leaves.iter().any(|(_, depth)| *depth > Self::MAX_DEPTH) {
            return Err(TaprootError::TreeError(
                "Leaf depth exceeds the taproot maximum of 128".into(),
            ));
        }

        // Stable sort: deepest first, insertion order preserved among equal depths.
        leaves.sort_by(|a, b| b.1.cmp(&a.1));
        let max_depth = leaves[0].1;
        let mut pending = leaves.into_iter().peekable();
        // Nodes currently sitting at the depth being processed.
        let mut level: Vec<TapNode> = Vec::new();

        for depth in (1..=max_depth).rev() {
            while let Some((leaf, _)) = pending.next_if(|(_, d)| *d == depth) {
                level.push(TapNode::Leaf(leaf));
            }
            // Every node below the root needs a sibling.
            if level.len() % 2 != 0 {
                return Err(Self::unbalanced());
            }
            let mut nodes = level.into_iter();
            let mut parents = Vec::with_capacity(nodes.len() / 2);
            while let (Some(left), Some(right)) = (nodes.next(), nodes.next()) {
                parents.push(TapNode::Branch(Box::new(left), Box::new(right)));
            }
            level = parents;
        }

        // Only depth-0 leaves can remain; they compete with `level` for the single root slot.
        level.extend(pending.map(|(leaf, _)| TapNode::Leaf(leaf)));
        if level.len() != 1 {
            return Err(Self::unbalanced());
        }
        let root = level.pop().expect("level holds exactly one node");
        Ok(TapTree::from_node(root))
    }

    /// Error for leaf depths that do not describe a complete binary tree.
    fn unbalanced() -> TaprootError {
        TaprootError::TreeError("Could not build balanced tree".into())
    }
}
/// Builds a tree whose root is a branch with `script1` as the left leaf and `script2` as
/// the right leaf, both under the default tapscript leaf version.
pub fn two_leaf_tree(script1: Vec<u8>, script2: Vec<u8>) -> TapTree {
    let leaf_node = |script: Vec<u8>| Box::new(TapNode::Leaf(TapLeaf::new(script)));
    let root = TapNode::Branch(leaf_node(script1), leaf_node(script2));
    TapTree::from_node(root)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_leaf_hash() {
        // Independently built leaves with the same script must agree; a different script must not.
        let hash_a = TapLeaf::new(vec![0x51]).hash();
        let hash_b = TapLeaf::new(vec![0x51]).hash();
        let hash_other = TapLeaf::new(vec![0x52]).hash();
        assert_eq!(hash_a, hash_b);
        assert_ne!(hash_a, hash_other);
    }

    #[test]
    fn test_single_leaf_tree() {
        let tree = TapTree::single_leaf(vec![0x51]);
        let leaves = tree.leaves();
        assert_eq!(leaves.len(), 1);
        assert_eq!(leaves[0].script, vec![0x51]);
        assert!(tree.root().is_leaf());
        assert_eq!(
            tree.root_hash(),
            TapNodeHash::from_leaf(TapLeaf::new(vec![0x51]).hash())
        );
    }

    #[test]
    fn test_two_leaf_tree() {
        let tree = two_leaf_tree(vec![0x51], vec![0x52]);
        let leaves = tree.leaves();
        let scripts: Vec<Vec<u8>> = leaves.iter().map(|leaf| leaf.script.clone()).collect();
        assert_eq!(scripts, vec![vec![0x51], vec![0x52]]);
        assert!(!tree.root().is_leaf());
        assert!(tree.root().as_leaf().is_none());
    }

    #[test]
    fn test_merkle_path() {
        let script1 = vec![0x51];
        let script2 = vec![0x52];
        let tree = two_leaf_tree(script1.clone(), script2.clone());
        let leaf1 = TapLeaf::new(script1);
        let leaf2 = TapLeaf::new(script2);
        // In a two-leaf tree each path consists solely of the sibling leaf's node hash.
        assert_eq!(
            tree.merkle_path(&leaf1).unwrap(),
            vec![TapNodeHash::from_leaf(leaf2.hash())]
        );
        assert_eq!(
            tree.merkle_path(&leaf2).unwrap(),
            vec![TapNodeHash::from_leaf(leaf1.hash())]
        );
    }

    #[test]
    fn test_merkle_path_missing_leaf() {
        let tree = two_leaf_tree(vec![0x51], vec![0x52]);
        assert!(tree.merkle_path(&TapLeaf::new(vec![0x53])).is_none());
    }

    #[test]
    fn test_builder_empty() {
        assert!(TapTreeBuilder::new().build().is_err());
    }

    #[test]
    fn test_builder_single_leaf() {
        let tree = TapTreeBuilder::new()
            .add_leaf(0, vec![0x51])
            .build()
            .unwrap();
        assert_eq!(tree.leaves().len(), 1);
    }

    #[test]
    fn test_builder_two_leaves() {
        let tree = TapTreeBuilder::new()
            .add_leaf(1, vec![0x51])
            .add_leaf(1, vec![0x52])
            .build()
            .unwrap();
        assert_eq!(tree.leaves().len(), 2);
        // The builder should produce the same shape as the hand-built helper.
        assert_eq!(
            tree.root_hash(),
            two_leaf_tree(vec![0x51], vec![0x52]).root_hash()
        );
    }

    #[test]
    fn test_leaf_version() {
        assert!(LeafVersion::new(0xc0).is_ok());
        assert!(LeafVersion::new(0xc2).is_ok());
        // Odd versions are invalid.
        assert!(LeafVersion::new(0xc1).is_err());
        assert_eq!(LeafVersion::default(), LeafVersion::TAPSCRIPT);
        assert_eq!(LeafVersion::TAPSCRIPT.to_u8(), TAPSCRIPT_LEAF_VERSION);
    }

    #[test]
    fn test_branch_hash_deterministic() {
        let tree1 = two_leaf_tree(vec![0x51], vec![0x52]);
        let tree2 = two_leaf_tree(vec![0x51], vec![0x52]);
        assert_eq!(tree1.root_hash(), tree2.root_hash());
        let tree3 = two_leaf_tree(vec![0x51], vec![0x53]);
        assert_ne!(tree1.root_hash(), tree3.root_hash());
    }
}