use std::borrow::{Borrow, BorrowMut};
use std::cmp::Ordering;
use std::fmt::{self, Debug, Display, Formatter};
use std::io::{Read, Write};
use std::ops::{Deref, Not};
use std::str::FromStr;
use amplify::Wrapper;
use bitcoin::hashes::Hash;
use bitcoin::psbt::TapTree;
use bitcoin::util::taproot::{LeafVersion, TapBranchHash, TapLeafHash, TaprootBuilder};
use bitcoin::Script;
use strict_encoding::{StrictDecode, StrictEncode};
use crate::types::IntoNodeHash;
use crate::{LeafScript, TapNodeHash, TapScript};
/// Error indicating that a branch of a taproot script tree violates lexicographic ordering of
/// its children: the hash of the left-side child is greater than the hash of the right-side
/// child.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug, Error, Display)]
#[display(
"invalid taproot tree lexicographic node ordering in branch {dfs_path}, where the hash of the \
left-side child {left_hash} is larger than the hash of the right-side child {right_hash}"
)]
pub struct TaprootTreeError {
/// Hash of the left-side child of the misordered branch.
pub left_hash: TapNodeHash,
/// Hash of the right-side child of the misordered branch.
pub right_hash: TapNodeHash,
/// DFS path from the tree root to the misordered branch.
pub dfs_path: DfsPath,
}
/// Error returned when an operation would push a node deeper than the tree can represent
/// (node depths are stored as `u8` and increased with checked arithmetic).
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug, Error, Display)]
#[display("maximum taproot script tree depth exceeded.")]
pub struct MaxDepthExceeded;
/// Error returned when a node depth would be decreased below zero, i.e. a subtree would be
/// raised above the tree root.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug, Error, Display)]
#[display("an attempt to raise subtree above its depth.")]
pub struct RaiseAboveRoot;
/// Error returned when a tree whose root (or cut point) is a leaf or hidden node is asked to be
/// split into two subtrees.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug, Error, Display)]
#[display("tree contains just a single known root node and can't be split into two parts.")]
pub struct UnsplittableTree;
/// Error returned when a partially constructed tree can't be converted into a complete one
/// because some branch lacks a child; carries the incomplete node.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug, Error, Display)]
#[display("taproot script tree is not complete at node {0:?}.")]
pub struct IncompleteTreeError<N>(N)
where
N: Node + Debug;
#[derive(
Clone, PartialOrd, Ord, PartialEq, Eq, Hash, Debug, Display, Error, From
)]
#[display(doc_comments)]
pub enum InstillError {
#[from(MaxDepthExceeded)]
MaxDepthExceeded,
#[from]
DfsTraversal(DfsTraversalError),
}
#[derive(
Clone, PartialOrd, Ord, PartialEq, Eq, Hash, Debug, Display, Error, From
)]
#[display(doc_comments)]
pub enum CutError {
#[from(UnsplittableTree)]
UnsplittableTree,
#[from]
DfsTraversal(DfsTraversalError),
}
/// Errors happening when a DFS path can't be followed within a taproot script tree.
// NB: variant doc comments below are the `Display` format strings (`#[display(doc_comments)]`);
// without them the derived `Display` has no message to render.
#[derive(Clone, PartialOrd, Ord, PartialEq, Eq, Hash, Debug, Display, Error)]
#[display(doc_comments)]
pub enum DfsTraversalError {
    /// the provided DFS path {0} does not exist within a given tree.
    PathNotExists(DfsPath),

    /// the provided DFS path traverses hidden node {node_hash} at
    /// {failed_path} to {path_leftover}.
    HiddenNode {
        /// The hash of the hidden node found during the path traversal.
        node_hash: TapNodeHash,
        /// The steps consumed so far, including the step which could not be taken from the
        /// hidden node.
        failed_path: DfsPath,
        /// The remaining part of the path which was not traversed.
        path_leftover: DfsPath,
    },

    /// the provided DFS path traverses leaf node {leaf_script} at
    /// {failed_path} to {path_leftover}.
    LeafNode {
        /// The leaf script found during the path traversal.
        leaf_script: LeafScript,
        /// The steps consumed so far, including the step which could not be taken from the
        /// leaf node.
        failed_path: DfsPath,
        /// The remaining part of the path which was not traversed.
        path_leftover: DfsPath,
    },
}
/// A single step of a DFS path: descend into the DFS-first or the DFS-last child of a branch.
// NB: variant order is significant — it defines the `by_order` strict encoding and derived `Ord`.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug, Display)]
#[derive(StrictEncode, StrictDecode)]
#[strict_encoding(by_order, repr = u8)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub enum DfsOrder {
/// Descend into the child visited first in DFS order (rendered as `0` within a `DfsPath`).
#[display("dfs-first")]
First,
/// Descend into the child visited last in DFS order (rendered as `1` within a `DfsPath`).
#[display("dfs-last")]
Last,
}
/// Flips a DFS step: the first child becomes the last one and vice versa.
impl Not for DfsOrder {
    type Output = DfsOrder;

    fn not(self) -> Self::Output {
        if self == DfsOrder::First {
            DfsOrder::Last
        } else {
            DfsOrder::First
        }
    }
}
/// Mapping between the DFS order of a branch's children and their left/right storage order.
// NB: variant order is significant — it defines the `by_order` strict encoding and derived `Ord`.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug, Display)]
#[derive(StrictEncode, StrictDecode)]
#[strict_encoding(by_order, repr = u8)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub enum DfsOrdering {
/// The left child is visited first, the right child last.
#[display("left-to-right")]
LeftRight,
/// The right child is visited first, the left child last.
#[display("right-to-left")]
RightLeft,
}
/// Reverses a DFS ordering: left-to-right becomes right-to-left and vice versa.
impl Not for DfsOrdering {
    type Output = DfsOrdering;

    fn not(self) -> Self::Output {
        if self == DfsOrdering::LeftRight {
            DfsOrdering::RightLeft
        } else {
            DfsOrdering::LeftRight
        }
    }
}
/// Path from a tree root to a node, expressed as a sequence of DFS steps. An empty path points
/// to the root itself. Displays (and parses) as a string of `0` (DFS-first) and `1` (DFS-last)
/// digits.
#[derive(
Wrapper, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Default, Debug, From
)]
#[derive(StrictEncode, StrictDecode)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct DfsPath(Vec<DfsOrder>);
/// Exposes the DFS steps as a slice.
impl AsRef<[DfsOrder]> for DfsPath {
    #[inline]
    fn as_ref(&self) -> &[DfsOrder] { self.0.as_slice() }
}
/// Borrows the DFS steps as a slice.
impl Borrow<[DfsOrder]> for DfsPath {
    #[inline]
    fn borrow(&self) -> &[DfsOrder] { &self.0[..] }
}
/// Renders the path as a string of digits: `0` for a DFS-first step, `1` for a DFS-last step.
impl Display for DfsPath {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        self.0.iter().try_for_each(|step| {
            let digit = if *step == DfsOrder::First { "0" } else { "1" };
            f.write_str(digit)
        })
    }
}
/// Error parsing a DFS path string: holds the whole input string and the first character which
/// is neither `0` nor `1`.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug, Display, Error)]
#[display("the given DFS path {0} can't be parsed: an unexpected character {1} was found.")]
pub struct DfsPathParseError(pub String, pub char);
impl FromStr for DfsPath {
type Err = DfsPathParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
s.chars()
.map(|c| match c {
'0' => Ok(DfsOrder::First),
'1' => Ok(DfsOrder::Last),
other => Err(DfsPathParseError(s.to_string(), other)),
})
.collect()
}
}
impl DfsPath {
    /// Constructs an empty DFS path, pointing to the tree root.
    #[inline]
    pub fn new() -> DfsPath { DfsPath(Vec::new()) }

    /// Constructs a DFS path by copying the steps yielded by `iter`.
    pub fn with<'path>(iter: impl IntoIterator<Item = &'path DfsOrder>) -> Self {
        DfsPath(iter.into_iter().copied().collect())
    }
}
impl<'path> IntoIterator for &'path DfsPath {
type Item = DfsOrder;
type IntoIter = core::iter::Cloned<core::slice::Iter<'path, DfsOrder>>;
fn into_iter(self) -> Self::IntoIter { self.0.iter().cloned() }
}
impl IntoIterator for DfsPath {
type Item = DfsOrder;
type IntoIter = std::vec::IntoIter<DfsOrder>;
fn into_iter(self) -> Self::IntoIter { self.0.into_iter() }
}
impl FromIterator<DfsOrder> for DfsPath {
fn from_iter<T: IntoIterator<Item = DfsOrder>>(iter: T) -> Self {
Self::from_inner(iter.into_iter().collect())
}
}
impl<'iter> FromIterator<&'iter DfsOrder> for DfsPath {
fn from_iter<T: IntoIterator<Item = &'iter DfsOrder>>(iter: T) -> Self {
Self::from_inner(iter.into_iter().copied().collect())
}
}
/// Accessors shared by complete (`BranchNode`) and partial (`PartialBranchNode`) branch nodes.
pub trait Branch {
/// Maximal subtree depth among the branch children, or `None` if it can't be determined
/// (e.g. a child is hidden or missing).
fn subtree_depth(&self) -> Option<u8>;
/// Mapping between the DFS order of the children and their left/right storage order.
fn dfs_ordering(&self) -> DfsOrdering;
/// Hash of this branch.
fn branch_hash(&self) -> TapBranchHash;
}
/// Accessors shared by complete (`TreeNode`) and partial (`PartialTreeNode`) tree nodes.
pub trait Node {
/// Whether only the hash of the node is known.
fn is_hidden(&self) -> bool;
/// Whether the node is a branch.
fn is_branch(&self) -> bool;
/// Whether the node is a leaf script.
fn is_leaf(&self) -> bool;
/// Tap node hash of this node.
fn node_hash(&self) -> TapNodeHash;
/// Depth of this node counted from the tree root (the root has depth 0).
fn node_depth(&self) -> u8;
// NOTE(review): `TreeNode` counts a leaf as subtree depth 1 while `PartialTreeNode` counts it
// as 0 — confirm which convention callers expect.
/// Depth of the subtree rooted at this node, or `None` if it can't be determined.
fn subtree_depth(&self) -> Option<u8>;
}
/// Branch of a complete taproot script tree.
///
/// Children are kept with the lexicographically smaller hash on the left (enforced by
/// `BranchNode::with` and validated by `TreeNode::check`); `dfs_ordering` records how DFS order
/// maps onto this storage order.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
#[derive(StrictEncode, StrictDecode)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub struct BranchNode {
/// Child with the smaller node hash.
left: Box<TreeNode>,
/// Child with the larger (or equal) node hash.
right: Box<TreeNode>,
/// Which of the stored children is visited first in DFS order.
dfs_ordering: DfsOrdering,
}
impl Branch for BranchNode {
    /// Deepest subtree depth among both children; `None` if either child depth is unknown.
    #[inline]
    fn subtree_depth(&self) -> Option<u8> {
        let left_depth = self.left.subtree_depth()?;
        let right_depth = self.right.subtree_depth()?;
        Some(left_depth.max(right_depth))
    }

    fn dfs_ordering(&self) -> DfsOrdering { self.dfs_ordering }

    /// Computes the branch hash from the node hashes of both children.
    fn branch_hash(&self) -> TapBranchHash {
        let left_hash = self.left.node_hash();
        let right_hash = self.right.node_hash();
        TapBranchHash::from_node_hashes(left_hash, right_hash)
    }
}
impl BranchNode {
    /// Builds a branch from two children given in DFS order (`first` is visited before `last`).
    ///
    /// The child with the smaller node hash is stored on the left; the DFS ordering is set so
    /// that `first` remains the DFS-first child.
    pub(self) fn with(first: TreeNode, last: TreeNode) -> Self {
        let first_goes_left = first.node_hash() < last.node_hash();
        let (left, right, dfs_ordering) = if first_goes_left {
            (first, last, DfsOrdering::LeftRight)
        } else {
            (last, first, DfsOrdering::RightLeft)
        };
        BranchNode {
            left: Box::new(left),
            right: Box::new(right),
            dfs_ordering,
        }
    }

    /// Consumes the branch, returning its children in storage (left, right) order.
    #[inline]
    pub fn split(self) -> (TreeNode, TreeNode) {
        let BranchNode { left, right, .. } = self;
        (*left, *right)
    }

    /// Consumes the branch, returning its children in DFS (first, last) order.
    #[inline]
    pub fn split_dfs(self) -> (TreeNode, TreeNode) {
        let BranchNode {
            left,
            right,
            dfs_ordering,
        } = self;
        if dfs_ordering == DfsOrdering::LeftRight {
            (*left, *right)
        } else {
            (*right, *left)
        }
    }

    /// Tells whether the child reached by the DFS step `direction` is stored on the left side.
    #[inline]
    fn dfs_child_is_left(&self, direction: DfsOrder) -> bool {
        matches!(
            (direction, self.dfs_ordering),
            (DfsOrder::First, DfsOrdering::LeftRight) | (DfsOrder::Last, DfsOrdering::RightLeft)
        )
    }

    /// Returns the child stored on the left (smaller hash) side.
    #[inline]
    pub fn as_left_node(&self) -> &TreeNode { &*self.left }

    /// Returns the child stored on the right (larger hash) side.
    #[inline]
    pub fn as_right_node(&self) -> &TreeNode { &*self.right }

    #[inline]
    pub(self) fn as_left_node_mut(&mut self) -> &mut TreeNode { &mut *self.left }

    #[inline]
    pub(self) fn as_right_node_mut(&mut self) -> &mut TreeNode { &mut *self.right }

    /// Returns the child reached by taking the DFS step `direction`.
    #[inline]
    pub fn as_dfs_child_node(&self, direction: DfsOrder) -> &TreeNode {
        if self.dfs_child_is_left(direction) {
            self.as_left_node()
        } else {
            self.as_right_node()
        }
    }

    /// Returns the child visited first in DFS order.
    #[inline]
    pub fn as_dfs_first_node(&self) -> &TreeNode { self.as_dfs_child_node(DfsOrder::First) }

    /// Returns the child visited last in DFS order.
    #[inline]
    pub fn as_dfs_last_node(&self) -> &TreeNode { self.as_dfs_child_node(DfsOrder::Last) }

    #[inline]
    pub(self) fn as_dfs_first_node_mut(&mut self) -> &mut TreeNode {
        if self.dfs_child_is_left(DfsOrder::First) {
            self.as_left_node_mut()
        } else {
            self.as_right_node_mut()
        }
    }

    #[inline]
    pub(self) fn as_dfs_last_node_mut(&mut self) -> &mut TreeNode {
        if self.dfs_child_is_left(DfsOrder::Last) {
            self.as_left_node_mut()
        } else {
            self.as_right_node_mut()
        }
    }
}
/// Node of a complete taproot script tree. Every variant carries the node depth counted from
/// the tree root (the root has depth 0).
// NB: variant order is significant — it defines the `by_order` strict encoding and derived `Ord`.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
#[derive(StrictEncode, StrictDecode)]
#[strict_encoding(by_order, repr = u8)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
pub enum TreeNode {
/// Known leaf script.
Leaf(LeafScript, u8),
/// Subtree of which only the node hash is known.
Hidden(TapNodeHash, u8),
/// Branch with two children.
Branch(BranchNode, u8),
}
/// Strict-encodes a boxed node exactly like the node itself.
impl strict_encoding::StrictEncode for Box<TreeNode> {
    fn strict_encode<E: Write>(&self, mut e: E) -> Result<usize, strict_encoding::Error> {
        // Serialize into a buffer first, then emit the whole buffer with a single write.
        let node: &TreeNode = self;
        let bytes = node.strict_serialize()?;
        e.write_all(&bytes)?;
        Ok(bytes.len())
    }
}
/// Strict-decodes a node and places it on the heap.
impl strict_encoding::StrictDecode for Box<TreeNode> {
    fn strict_decode<D: Read>(d: D) -> Result<Self, strict_encoding::Error> {
        let node = TreeNode::strict_decode(d)?;
        Ok(Box::new(node))
    }
}
impl TreeNode {
pub fn with_tap_script(script: TapScript, depth: u8) -> TreeNode {
TreeNode::Leaf(LeafScript::tapscript(script), depth)
}
pub fn with_branch(a: TreeNode, b: TreeNode, depth: u8) -> TreeNode {
TreeNode::Branch(BranchNode::with(a, b), depth)
}
pub fn as_branch(&self) -> Option<&BranchNode> {
match self {
TreeNode::Branch(branch, _) => Some(branch),
_ => None,
}
}
pub(self) fn as_branch_mut(&mut self) -> Option<&mut BranchNode> {
match self {
TreeNode::Branch(branch, _) => Some(branch),
_ => None,
}
}
pub fn as_leaf_script(&self) -> Option<&LeafScript> {
match self {
TreeNode::Leaf(leaf_script, _) => Some(leaf_script),
_ => None,
}
}
#[inline]
pub fn node_at(&self, path: impl AsRef<[DfsOrder]>) -> Result<&TreeNode, DfsTraversalError> {
let mut curr = self;
let mut past_steps = vec![];
let path = path.as_ref();
let mut iter = path.iter();
for step in iter.by_ref() {
past_steps.push(step);
let branch = match curr {
TreeNode::Branch(branch, _) => branch,
TreeNode::Leaf(leaf_script, _) => {
return Err(DfsTraversalError::LeafNode {
leaf_script: leaf_script.clone(),
failed_path: DfsPath::with(past_steps),
path_leftover: iter.collect(),
})
}
TreeNode::Hidden(hash, _) => {
return Err(DfsTraversalError::HiddenNode {
node_hash: *hash,
failed_path: DfsPath::with(past_steps),
path_leftover: iter.collect(),
})
}
};
curr = match step {
DfsOrder::First => branch.as_dfs_first_node(),
DfsOrder::Last => branch.as_dfs_last_node(),
};
}
Ok(curr)
}
#[inline]
pub(self) fn node_mut_at<'path>(
&mut self,
path: impl IntoIterator<Item = &'path DfsOrder>,
) -> Result<&mut TreeNode, DfsTraversalError> {
let mut curr = self;
let mut past_steps = vec![];
let mut iter = path.into_iter();
for step in iter.by_ref() {
past_steps.push(step);
let branch = match curr {
TreeNode::Branch(branch, _) => branch,
TreeNode::Leaf(leaf_script, _) => {
return Err(DfsTraversalError::LeafNode {
leaf_script: leaf_script.clone(),
failed_path: DfsPath::with(past_steps),
path_leftover: iter.collect(),
})
}
TreeNode::Hidden(hash, _) => {
return Err(DfsTraversalError::HiddenNode {
node_hash: *hash,
failed_path: DfsPath::with(past_steps),
path_leftover: iter.collect(),
})
}
};
curr = match step {
DfsOrder::First => branch.as_dfs_first_node_mut(),
DfsOrder::Last => branch.as_dfs_last_node_mut(),
};
}
Ok(curr)
}
pub(self) fn nodes_on_path<'node, 'path>(
&'node self,
path: &'path [DfsOrder],
) -> TreePathIter<'node, 'path> {
TreePathIter {
next_node: Some(self),
full_path: path,
remaining_path: path.iter(),
}
}
pub(self) fn nodes(&self) -> TreeNodeIter { TreeNodeIter::from(self) }
pub(self) fn nodes_mut(&mut self) -> TreeNodeIterMut { TreeNodeIterMut::from(self) }
pub(self) fn lower(&mut self, inc: u8) -> Result<u8, MaxDepthExceeded> {
let old_depth = self.node_depth();
match self {
TreeNode::Leaf(_, depth) | TreeNode::Hidden(_, depth) | TreeNode::Branch(_, depth) => {
*depth = depth.checked_add(inc).ok_or(MaxDepthExceeded)?;
}
}
Ok(old_depth)
}
pub(self) fn raise(&mut self, dec: u8) -> Result<u8, RaiseAboveRoot> {
let old_depth = self.node_depth();
match self {
TreeNode::Leaf(_, depth) | TreeNode::Hidden(_, depth) | TreeNode::Branch(_, depth) => {
*depth = depth.checked_sub(dec).ok_or(RaiseAboveRoot)?;
}
}
Ok(old_depth)
}
pub fn check(&self) -> Result<(), TaprootTreeError> {
for (node, dfs_path) in self.nodes() {
if let Some(branch) = node.as_branch() {
let left_hash = branch.left.node_hash();
let right_hash = branch.right.node_hash();
if left_hash > right_hash {
return Err(TaprootTreeError {
left_hash,
right_hash,
dfs_path,
});
}
}
}
Ok(())
}
}
impl Node for TreeNode {
    fn is_hidden(&self) -> bool { matches!(self, TreeNode::Hidden(_, _)) }

    fn is_branch(&self) -> bool { matches!(self, TreeNode::Branch(_, _)) }

    fn is_leaf(&self) -> bool { matches!(self, TreeNode::Leaf(_, _)) }

    /// Leaf hash for leaves, stored hash for hidden nodes, branch hash for branches.
    fn node_hash(&self) -> TapNodeHash {
        match self {
            TreeNode::Hidden(hash, _) => *hash,
            TreeNode::Leaf(leaf_script, _) => leaf_script.tap_leaf_hash().into_node_hash(),
            TreeNode::Branch(branch, _) => branch.branch_hash().into_node_hash(),
        }
    }

    fn node_depth(&self) -> u8 {
        match *self {
            TreeNode::Leaf(_, depth) | TreeNode::Hidden(_, depth) | TreeNode::Branch(_, depth) => {
                depth
            }
        }
    }

    /// A leaf counts as depth 1; a hidden node makes the depth unknown.
    fn subtree_depth(&self) -> Option<u8> {
        match self {
            TreeNode::Hidden(..) => None,
            TreeNode::Leaf(..) => Some(1),
            TreeNode::Branch(branch, _) => branch.subtree_depth().map(|depth| depth + 1),
        }
    }
}
impl TryFrom<PartialTreeNode> for TreeNode {
type Error = IncompleteTreeError<PartialTreeNode>;
fn try_from(partial_node: PartialTreeNode) -> Result<Self, Self::Error> {
Ok(match partial_node {
PartialTreeNode::Leaf(leaf_script, depth) => TreeNode::Leaf(leaf_script, depth),
ref node @ PartialTreeNode::Branch(ref branch, depth) => TreeNode::with_branch(
branch
.first
.as_ref()
.ok_or_else(|| IncompleteTreeError(node.clone()))?
.deref()
.clone()
.try_into()?,
branch
.second
.as_ref()
.ok_or_else(|| IncompleteTreeError(node.clone()))?
.deref()
.clone()
.try_into()?,
depth,
),
})
}
}
/// Prints one line per leaf or hidden node: `<dfs path> (<depth>): <script or hash>`.
impl Display for TreeNode {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        self.nodes().try_for_each(|(node, path)| match node {
            TreeNode::Leaf(leaf_script, depth) => {
                writeln!(f, "{} ({}): {}", path, depth, leaf_script)
            }
            TreeNode::Hidden(hash, depth) => writeln!(f, "{} ({}): {}", path, depth, hash),
            TreeNode::Branch(..) => Ok(()),
        })
    }
}
/// Branch of a tree being reconstructed: its hash is known upfront, while children are filled
/// in one by one via `push_child`.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub struct PartialBranchNode {
/// Hash of the branch.
hash: TapBranchHash,
/// Child pushed first, if any.
first: Option<Box<PartialTreeNode>>,
/// Child pushed second, if any.
second: Option<Box<PartialTreeNode>>,
}
impl Branch for PartialBranchNode {
fn subtree_depth(&self) -> Option<u8> {
Some(
self.first
.as_ref()?
.subtree_depth()?
.max(self.second.as_ref()?.subtree_depth()?),
)
}
fn dfs_ordering(&self) -> DfsOrdering {
match (
self.first
.as_ref()
.map(Box::as_ref)
.and_then(PartialTreeNode::subtree_depth),
self.second
.as_ref()
.map(Box::as_ref)
.and_then(PartialTreeNode::subtree_depth),
) {
(Some(first), Some(second)) => match first.cmp(&second) {
Ordering::Equal => DfsOrdering::LeftRight,
Ordering::Less => DfsOrdering::LeftRight,
Ordering::Greater => DfsOrdering::RightLeft,
},
_ => DfsOrdering::LeftRight,
}
}
fn branch_hash(&self) -> TapBranchHash { self.hash }
}
impl PartialBranchNode {
    /// Constructs a branch with a known hash and no children yet.
    pub fn with(hash: TapBranchHash) -> Self {
        PartialBranchNode {
            hash,
            first: None,
            second: None,
        }
    }

    /// Adds `child` to the first free slot, or finds an existing child with the same node hash.
    ///
    /// Returns the stored (possibly pre-existing) child, or `None` if both slots are occupied
    /// by children with different hashes (in which case `child` is dropped).
    pub fn push_child(&mut self, child: PartialTreeNode) -> Option<&mut PartialTreeNode> {
        let child = Box::new(child);
        let child_hash = child.node_hash();

        match &self.first {
            None => {
                self.first = Some(child);
                return self.first.as_deref_mut();
            }
            Some(first) if first.node_hash() == child_hash => return self.first.as_deref_mut(),
            Some(_) => {}
        }

        match &self.second {
            None => {
                self.second = Some(child);
                self.second.as_deref_mut()
            }
            Some(second) if second.node_hash() == child_hash => self.second.as_deref_mut(),
            Some(_) => None,
        }
    }

    /// Returns the branch hash re-typed as a tap node hash.
    #[inline]
    pub fn node_hash(&self) -> TapNodeHash {
        let inner = self.hash.into_inner();
        TapNodeHash::from_inner(inner)
    }
}
/// Node of a tree being reconstructed (see `From<TapTree> for TaprootScriptTree`). Every variant
/// carries the node depth counted from the tree root.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
pub enum PartialTreeNode {
/// Known leaf script.
Leaf(LeafScript, u8),
/// Branch whose children may not be known yet.
Branch(PartialBranchNode, u8),
}
impl PartialTreeNode {
    /// Constructs a leaf node from a leaf version and script at the given depth.
    pub fn with_leaf(leaf_version: LeafVersion, script: Script, depth: u8) -> PartialTreeNode {
        let leaf_script = LeafScript::with(leaf_version, script.into());
        PartialTreeNode::Leaf(leaf_script, depth)
    }

    /// Constructs a childless branch node with a known hash at the given depth.
    pub fn with_branch(hash: TapBranchHash, depth: u8) -> PartialTreeNode {
        let branch = PartialBranchNode::with(hash);
        PartialTreeNode::Branch(branch, depth)
    }

    /// Returns the branch data if this node is a branch.
    pub fn as_branch(&self) -> Option<&PartialBranchNode> {
        if let PartialTreeNode::Branch(branch, _) = self {
            Some(branch)
        } else {
            None
        }
    }

    /// Returns mutable branch data if this node is a branch.
    pub fn as_branch_mut(&mut self) -> Option<&mut PartialBranchNode> {
        if let PartialTreeNode::Branch(branch, _) = self {
            Some(branch)
        } else {
            None
        }
    }
}
impl Node for PartialTreeNode {
    /// Partial trees never contain hidden nodes.
    #[inline]
    fn is_hidden(&self) -> bool { false }

    fn is_branch(&self) -> bool { self.as_branch().is_some() }

    // With only two variants, anything that is not a branch is a leaf.
    fn is_leaf(&self) -> bool { !self.is_branch() }

    fn node_hash(&self) -> TapNodeHash {
        match self {
            PartialTreeNode::Branch(branch, _) => branch.node_hash(),
            PartialTreeNode::Leaf(leaf_script, _) => leaf_script.tap_leaf_hash().into_node_hash(),
        }
    }

    fn node_depth(&self) -> u8 {
        match *self {
            PartialTreeNode::Branch(_, depth) | PartialTreeNode::Leaf(_, depth) => depth,
        }
    }

    /// A leaf counts as depth 0 here; a branch reports its children's maximal depth.
    fn subtree_depth(&self) -> Option<u8> {
        match self {
            PartialTreeNode::Branch(branch, _) => branch.subtree_depth(),
            PartialTreeNode::Leaf(..) => Some(0),
        }
    }
}
/// Taproot script tree with known structure: leaves, hidden subtrees and branches, each branch
/// keeping its children in lexicographic hash order. Displays as one line per leaf or hidden
/// node (see `Display for TreeNode`).
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug, Display)]
#[derive(StrictEncode, StrictDecode)]
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[display("{root}")]
pub struct TaprootScriptTree {
/// Root node of the tree (depth 0).
root: TreeNode,
}
impl AsRef<TreeNode> for TaprootScriptTree {
#[inline]
fn as_ref(&self) -> &TreeNode { &self.root }
}
impl Borrow<TreeNode> for TaprootScriptTree {
#[inline]
fn borrow(&self) -> &TreeNode { &self.root }
}
impl BorrowMut<TreeNode> for TaprootScriptTree {
#[inline]
fn borrow_mut(&mut self) -> &mut TreeNode { &mut self.root }
}
impl TaprootScriptTree {
    /// Constructs a script tree from a `root` node, verifying that every branch keeps its
    /// children in lexicographic hash order.
    ///
    /// # Errors
    ///
    /// Returns [`TaprootTreeError`] describing the first misordered branch found.
    #[inline]
    pub fn with(root: TreeNode) -> Result<TaprootScriptTree, TaprootTreeError> {
        root.check()?;
        Ok(TaprootScriptTree { root })
    }

    /// Constructs a script tree from a `root` node, swapping the children of every branch that
    /// violates lexicographic hash ordering. DFS paths of all nodes are preserved, since each
    /// swap also flips the DFS ordering of the affected branch.
    #[stability::unstable(reason = "not sufficiently tested")]
    #[inline]
    pub fn with_fixes(root: TreeNode) -> TaprootScriptTree {
        let mut tree = TaprootScriptTree { root };
        tree.fix();
        tree
    }

    /// Returns an iterator over known leaf scripts and their depths in DFS order; hidden
    /// subtrees are skipped.
    #[inline]
    pub fn scripts(&self) -> TreeScriptIter { TreeScriptIter::from(self) }

    /// Returns an iterator over all nodes of the tree paired with their DFS paths.
    #[inline]
    pub fn nodes(&self) -> TreeNodeIter { TreeNodeIter::from(self) }

    /// Returns an iterator over mutable references to all nodes of the tree.
    #[inline]
    pub(self) fn nodes_mut(&mut self) -> TreeNodeIterMut { TreeNodeIterMut::from(self) }

    /// Returns an iterator over the nodes lying on `path`, starting from the root.
    pub fn nodes_on_path<'node, 'path>(
        &'node self,
        path: &'path [DfsOrder],
    ) -> TreePathIter<'node, 'path> {
        self.root.nodes_on_path(path)
    }

    /// Returns the node located at the DFS `path`.
    ///
    /// # Errors
    ///
    /// Returns [`DfsTraversalError`] if the path runs into a leaf or hidden node before being
    /// fully consumed.
    #[inline]
    pub fn node_at(&self, path: impl AsRef<[DfsOrder]>) -> Result<&TreeNode, DfsTraversalError> {
        self.root.node_at(path)
    }

    #[inline]
    pub(self) fn node_mut_at<'path>(
        &mut self,
        path: impl IntoIterator<Item = &'path DfsOrder>,
    ) -> Result<&mut TreeNode, DfsTraversalError> {
        self.root.node_mut_at(path)
    }

    /// Restores lexicographic ordering of children for every strict ancestor of the node at
    /// `path` (the node itself is not checked), from the deepest ancestor up to the root.
    ///
    /// # Panics
    ///
    /// If `path` does not exist within the tree.
    fn update_ancestors_ordering(&mut self, path: impl Borrow<[DfsOrder]>) {
        let path = path.borrow();
        for step in (0..path.len()).rev() {
            let ancestor = self
                .node_mut_at(&path[..step])
                .expect("the path must be checked to be valid");
            let branch = match ancestor.as_branch_mut() {
                Some(branch) => branch,
                None => return,
            };
            if branch.left.node_hash() > branch.right.node_hash() {
                // Swap the boxed children (no deep clones needed) and flip the DFS ordering, so
                // every DFS path keeps pointing to the same node.
                branch.dfs_ordering = !branch.dfs_ordering;
                std::mem::swap(&mut branch.left, &mut branch.right);
            }
        }
    }

    /// Joins `other_tree` with this tree under a new root branch, where `other_tree` becomes
    /// the `other_dfs_order` child of the new root.
    ///
    /// # Errors
    ///
    /// Returns [`MaxDepthExceeded`] if the resulting tree would be too deep.
    #[inline]
    pub fn join(
        mut self,
        other_tree: TaprootScriptTree,
        other_dfs_order: DfsOrder,
    ) -> Result<TaprootScriptTree, MaxDepthExceeded> {
        self.instill(other_tree, [], other_dfs_order)
            .map_err(|_| MaxDepthExceeded)?;
        Ok(self)
    }

    /// Splits the tree at its root, returning the DFS-last subtree first and the DFS-first
    /// subtree second.
    ///
    /// # Errors
    ///
    /// Returns [`UnsplittableTree`] if the root is not a branch.
    pub fn split(self) -> Result<(TaprootScriptTree, TaprootScriptTree), UnsplittableTree> {
        self.cut([], DfsOrder::First).map_err(|_| UnsplittableTree)
    }

    /// Instills `other_tree` at `path`: the node at `path` is replaced with a new branch whose
    /// `dfs_order` child is `other_tree` and whose other child is the replaced node. Returns the
    /// DFS path to the root of the instilled subtree.
    ///
    /// # Errors
    ///
    /// - [`InstillError::DfsTraversal`] if `path` does not exist within the tree;
    /// - [`InstillError::MaxDepthExceeded`] if the resulting tree would be too deep. In this
    ///   case the tree is left unmodified.
    pub fn instill(
        &mut self,
        mut other_tree: TaprootScriptTree,
        path: impl AsRef<[DfsOrder]>,
        dfs_order: DfsOrder,
    ) -> Result<DfsPath, InstillError> {
        let path = path.as_ref();
        let depth: u8 = path.len().try_into().map_err(|_| MaxDepthExceeded)?;

        let instill_point = self.node_mut_at(path)?;
        let other_shift = depth.checked_add(1).ok_or(MaxDepthExceeded)?;

        // Validate depth limits before mutating anything: lowering nodes one by one and failing
        // half-way would otherwise leave the tree with inconsistent depths.
        let deepest = |node: &TreeNode| {
            node.nodes()
                .map(|(n, _)| n.node_depth())
                .max()
                .unwrap_or_default()
        };
        deepest(&*instill_point)
            .checked_add(1)
            .ok_or(MaxDepthExceeded)?;
        deepest(other_tree.as_root_node())
            .checked_add(other_shift)
            .ok_or(MaxDepthExceeded)?;

        for n in instill_point.nodes_mut() {
            n.lower(1)?;
        }
        for n in other_tree.nodes_mut() {
            n.lower(other_shift)?;
        }
        let instill_root = other_tree.into_root_node();
        let branch = if dfs_order == DfsOrder::First {
            BranchNode::with(instill_root, instill_point.clone())
        } else {
            BranchNode::with(instill_point.clone(), instill_root)
        };
        *instill_point = TreeNode::Branch(branch, depth);

        // The hash of the node at `path` has changed, so ordering of its ancestors may break.
        self.update_ancestors_ordering(path);

        let mut path = DfsPath::with(path);
        path.push(dfs_order);
        Ok(path)
    }

    /// Cuts the branch at `path` in two: its `dfs_side` child becomes a standalone tree
    /// (returned second), while the other child takes the place of the branch in this tree
    /// (returned first).
    ///
    /// # Errors
    ///
    /// - [`CutError::UnsplittableTree`] if the node at `path` is not a branch;
    /// - [`CutError::DfsTraversal`] if `path` does not exist within the tree.
    pub fn cut(
        mut self,
        path: impl AsRef<[DfsOrder]>,
        dfs_side: DfsOrder,
    ) -> Result<(TaprootScriptTree, TaprootScriptTree), CutError> {
        let path = path.as_ref();
        let depth: u8 = path
            .len()
            .try_into()
            .map_err(|_| DfsTraversalError::PathNotExists(path.to_vec().into()))?;

        let (mut cut, mut remnant) = match self.node_at(path)? {
            TreeNode::Leaf(_, _) | TreeNode::Hidden(_, _) => {
                return Err(CutError::UnsplittableTree)
            }
            TreeNode::Branch(branch, _) if dfs_side == DfsOrder::First => {
                branch.clone().split_dfs()
            }
            TreeNode::Branch(branch, _) => {
                let (remnant, cut) = branch.clone().split_dfs();
                (cut, remnant)
            }
        };

        // The cut subtree becomes a tree of its own (its root goes to depth 0), while the
        // remnant moves one level up to take the place of the removed branch.
        for n in cut.nodes_mut() {
            n.raise(depth + 1)
                .expect("broken taproot tree cut algorithm");
        }
        for n in remnant.nodes_mut() {
            n.raise(1).expect("broken taproot tree cut algorithm");
        }

        let mut path_iter = path.iter();
        if let Some(last_step) = path_iter.next_back() {
            let cut_parent = self.node_mut_at(path_iter)?;
            let parent_branch_node = cut_parent
                .as_branch_mut()
                .expect("parent node always a branch node at this point");
            let replaced_child = match last_step {
                DfsOrder::First => parent_branch_node.as_dfs_first_node_mut(),
                DfsOrder::Last => parent_branch_node.as_dfs_last_node_mut(),
            };
            *replaced_child = remnant;
        } else {
            self = TaprootScriptTree { root: remnant };
        }

        let subtree = TaprootScriptTree { root: cut };

        self.update_ancestors_ordering(path);

        Ok((self, subtree))
    }

    /// Returns a reference to the root node.
    #[inline]
    pub fn as_root_node(&self) -> &TreeNode { &self.root }

    /// Consumes the tree, returning its root node.
    #[inline]
    pub fn into_root_node(self) -> TreeNode { self.root }

    /// Returns a copy of the root node.
    #[inline]
    pub fn to_root_node(&self) -> TreeNode { self.root.clone() }

    /// Verifies that every branch keeps its children in lexicographic hash order.
    #[inline]
    #[stability::unstable(
        reason = "current stable API assumes that taproot script trees always have correct \
                  structure"
    )]
    pub fn check(&self) -> Result<(), TaprootTreeError> { self.root.check() }

    /// Re-orders the children of every branch violating lexicographic hash ordering, returning
    /// the number of branches which were re-ordered.
    ///
    /// Branches are processed from the deepest to the shallowest one, so each branch is
    /// inspected only after all its descendants have been fixed. Every swap also flips the DFS
    /// ordering of the branch, so the DFS paths collected upfront stay valid. This replaces an
    /// earlier loop that only fixed the ancestors of one terminal node and could spin forever
    /// when the misordered branch was elsewhere in the tree.
    #[stability::unstable(reason = "not sufficiently tested")]
    fn fix(&mut self) -> usize {
        let mut branch_paths = self
            .nodes()
            .filter(|(node, _)| node.is_branch())
            .map(|(_, path)| path.into_inner())
            .collect::<Vec<Vec<DfsOrder>>>();
        branch_paths.sort_by_key(|path| std::cmp::Reverse(path.len()));

        let mut fix_count = 0usize;
        for path in branch_paths {
            let branch = self
                .node_mut_at(&path)
                .ok()
                .and_then(|node| node.as_branch_mut())
                .expect("DFS paths are preserved by branch re-ordering");
            if branch.left.node_hash() > branch.right.node_hash() {
                branch.dfs_ordering = !branch.dfs_ordering;
                std::mem::swap(&mut branch.left, &mut branch.right);
                fix_count += 1;
            }
        }
        fix_count
    }
}
/// Reconstructs the full tree structure from a `TapTree` by replaying each leaf's merkle branch
/// from the root down, merging shared branches by their hashes.
///
/// # Panics
///
/// If the `TapTree` structure is inconsistent (see the `expect`/`unreachable!` calls below).
impl From<TapTree> for TaprootScriptTree {
fn from(tree: TapTree) -> Self {
let mut root: Option<PartialTreeNode> = None;
// NOTE(review): leaves are processed in reverse order; the reason is not visible here —
// presumably a workaround for leaf ordering in rust-bitcoin's `TapTree`. Confirm before changing.
let mut script_leaves = tree.script_leaves().collect::<Vec<_>>();
script_leaves.reverse();
for leaf in script_leaves {
let merkle_branch = leaf.merkle_branch().as_inner();
// Depth of the leaf equals the number of sibling hashes on its merkle branch.
// NOTE(review): `as u8` silently truncates lengths above 255.
let leaf_depth = merkle_branch.len() as u8;
let mut curr_hash =
TapLeafHash::from_script(leaf.script(), leaf.leaf_version()).into_node_hash();
// Turn sibling hashes (ordered from the leaf upwards) into the hashes of the leaf's
// ancestor branches: the last element is the hash of the tree root.
let merkle_branch = merkle_branch
.iter()
.map(|step| {
curr_hash = TapBranchHash::from_node_hashes(*step, curr_hash).into_node_hash();
curr_hash
})
.collect::<Vec<_>>();
// Walk ancestors from the root downwards.
let mut hash_iter = merkle_branch.iter().rev();
match (root.is_some(), hash_iter.next()) {
// A leaf without ancestors is the whole tree.
(false, None) => {
root = Some(PartialTreeNode::with_leaf(
leaf.leaf_version(),
leaf.script().clone(),
0,
))
}
// The first processed leaf defines the root branch.
(false, Some(hash)) => {
root = Some(PartialTreeNode::with_branch(
TapBranchHash::from_inner(hash.into_inner()),
0,
))
}
(true, None) => unreachable!("broken TapTree structure"),
(true, Some(_)) => {}
}
let mut node = root.as_mut().expect("unreachable");
// Descend through intermediate branches, creating them or re-using already known ones
// (`push_child` matches existing children by node hash).
for (depth, hash) in hash_iter.enumerate() {
match node {
PartialTreeNode::Leaf(..) => unreachable!("broken TapTree structure"),
PartialTreeNode::Branch(branch, _) => {
let child = PartialTreeNode::with_branch(
TapBranchHash::from_inner(hash.into_inner()),
depth as u8 + 1,
);
node = branch.push_child(child).expect("broken TapTree structure");
}
}
}
let leaf =
PartialTreeNode::with_leaf(leaf.leaf_version(), leaf.script().clone(), leaf_depth);
// Attach the leaf to its parent branch. A leaf `node` here means the tree is a single
// leaf, which is already stored as the root.
// NOTE(review): the result of `push_child` is ignored, so a leaf is silently dropped if
// its parent already has two different children.
match node {
PartialTreeNode::Leaf(..) => { }
PartialTreeNode::Branch(branch, _) => {
branch.push_child(leaf);
}
}
}
// Convert to a complete tree; fails if any branch was left without both children.
let root = root
.map(TreeNode::try_from)
.transpose()
.ok()
.flatten()
.expect("broken TapTree structure");
TaprootScriptTree { root }
}
}
/// Iterator over the nodes lying on a DFS path, starting from the node the iteration was
/// created for.
pub struct TreePathIter<'tree, 'path> {
/// Node to be yielded next; `None` once the path end has been reached.
next_node: Option<&'tree TreeNode>,
/// The whole path being iterated (used for error reporting).
full_path: &'path [DfsOrder],
/// Steps not yet taken.
remaining_path: core::slice::Iter<'path, DfsOrder>,
}
impl<'tree, 'path> Iterator for TreePathIter<'tree, 'path> {
    type Item = Result<&'tree TreeNode, DfsTraversalError>;

    /// Yields each node on the path, from the starting node to the node at the path end.
    ///
    /// If a step can't be taken (a leaf or hidden node is met before the path ends), a single
    /// error is yielded and the iterator is fused: further calls return `None`. Previously the
    /// iterator kept going from the failing node and could yield it as `Ok` afterwards. The
    /// error's `failed_path`/`path_leftover` are expressed relative to the full iterated path,
    /// the same way `TreeNode::node_at` reports them.
    fn next(&mut self) -> Option<Self::Item> {
        match (self.next_node, self.remaining_path.next()) {
            (Some(curr_node), Some(step)) => match curr_node.node_at([*step]) {
                Ok(next_node) => {
                    self.next_node = Some(next_node);
                    Some(Ok(curr_node))
                }
                Err(err) => {
                    let full_path = self.full_path;
                    let consumed = full_path.len() - self.remaining_path.len();
                    let failed_path = DfsPath::with(&full_path[..consumed]);
                    let path_leftover = DfsPath::with(self.remaining_path.as_slice());

                    // Fuse the iterator: no node and no steps left.
                    self.next_node = None;
                    self.remaining_path = full_path[full_path.len()..].iter();

                    let err = match err {
                        DfsTraversalError::LeafNode { leaf_script, .. } => {
                            DfsTraversalError::LeafNode {
                                leaf_script,
                                failed_path,
                                path_leftover,
                            }
                        }
                        DfsTraversalError::HiddenNode { node_hash, .. } => {
                            DfsTraversalError::HiddenNode {
                                node_hash,
                                failed_path,
                                path_leftover,
                            }
                        }
                        err @ DfsTraversalError::PathNotExists(_) => err,
                    };
                    Some(Err(err))
                }
            },
            (Some(curr_node), None) => {
                self.next_node = None;
                Some(Ok(curr_node))
            }
            (None, None) => None,
            (None, Some(_)) => Some(Err(DfsTraversalError::PathNotExists(DfsPath::with(
                self.full_path,
            )))),
        }
    }
}
/// Pre-order iterator over all nodes of a tree, paired with their DFS paths. The DFS-last
/// subtree of each branch is visited before its DFS-first subtree.
pub struct TreeNodeIter<'tree> {
/// Nodes still to be visited, with their paths; the top of the stack is visited next.
stack: Vec<(&'tree TreeNode, DfsPath)>,
}
/// Starts node iteration at the root of anything borrowable as a `TreeNode`.
impl<'tree, T> From<&'tree T> for TreeNodeIter<'tree>
where
    T: Borrow<TreeNode>,
{
    fn from(tree: &'tree T) -> Self {
        let root: &'tree TreeNode = tree.borrow();
        TreeNodeIter {
            stack: vec![(root, DfsPath::new())],
        }
    }
}
impl<'tree> Iterator for TreeNodeIter<'tree> {
    type Item = (&'tree TreeNode, DfsPath);

    /// Pops the next node and schedules its children: the DFS-first child is pushed before the
    /// DFS-last one, so the DFS-last subtree gets visited first.
    fn next(&mut self) -> Option<Self::Item> {
        let (node, path) = self.stack.pop()?;
        if let Some(branch) = node.as_branch() {
            let children = [
                (DfsOrder::First, branch.as_dfs_first_node()),
                (DfsOrder::Last, branch.as_dfs_last_node()),
            ];
            for (step, child) in children.iter() {
                let mut child_path = path.clone();
                child_path.push(*step);
                self.stack.push((*child, child_path));
            }
        }
        Some((node, path))
    }
}
/// Pre-order iterator over mutable references to all nodes of a tree. Instead of storing node
/// references it stores DFS paths and re-walks from `root` for every item.
struct TreeNodeIterMut<'tree> {
/// Root of the iterated tree.
root: &'tree mut TreeNode,
/// DFS paths of nodes still to be visited; the top of the stack is visited next.
stack: Vec<Vec<DfsOrder>>,
}
/// Starts mutable node iteration at the root of anything mutably borrowable as a `TreeNode`.
impl<'tree, T> From<&'tree mut T> for TreeNodeIterMut<'tree>
where
    T: BorrowMut<TreeNode>,
{
    fn from(tree: &'tree mut T) -> Self {
        let root: &'tree mut TreeNode = tree.borrow_mut();
        // A single empty path: iteration starts at the root itself.
        TreeNodeIterMut {
            root,
            stack: vec![Vec::new()],
        }
    }
}
impl<'tree> Iterator for TreeNodeIterMut<'tree> {
type Item = &'tree mut TreeNode;
/// Yields the next node in pre-order (the DFS-last subtree of a branch is visited before its
/// DFS-first subtree), re-walking from the root along the stored path.
fn next(&mut self) -> Option<Self::Item> {
let mut path = self.stack.pop()?;
// SAFETY(review): the raw-pointer cast extends the reborrow of `self.root` to `'tree`, so
// the yielded reference outlives the `&mut self` borrow. Later items are descendants of
// earlier ones, so this is only sound if a caller never keeps a previously yielded
// reference alive while requesting the next item (otherwise two live `&mut` alias). The
// callers in this file use each item and drop it before advancing — TODO confirm for any
// new caller. Callers must also not change the tree shape, or the re-walk below panics.
let mut curr = unsafe { &mut *(self.root as *mut TreeNode) as &'tree mut TreeNode };
for step in &path {
let branch = match curr {
TreeNode::Branch(branch, _) => branch,
_ => unreachable!("iteration algorithm is broken"),
};
curr = match step {
DfsOrder::First => branch.as_dfs_first_node_mut(),
DfsOrder::Last => branch.as_dfs_last_node_mut(),
};
}
// Schedule children: DFS-first pushed before DFS-last, so DFS-last is popped first.
if curr.is_branch() {
path.push(DfsOrder::First);
self.stack.push(path.clone());
path.pop();
path.push(DfsOrder::Last);
self.stack.push(path);
}
Some(curr)
}
}
/// Traversal state of a branch remembered by `TreeScriptIter`.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
enum BranchDirection {
/// The branch has not been descended into yet: continue with its DFS-first child.
Shallow,
/// The DFS-first child has been visited: continue with the DFS-last child.
Deep,
}
/// Iterator over known leaf scripts of a tree in DFS order, yielding `(depth, leaf_script)`;
/// hidden nodes are skipped.
pub struct TreeScriptIter<'tree> {
/// Stack of nodes to resume traversal from, with their traversal state.
path: Vec<(&'tree TreeNode, BranchDirection)>,
}
/// Starts script iteration at the root of anything borrowable as a `TreeNode`.
impl<'tree, T> From<&'tree T> for TreeScriptIter<'tree>
where
    T: Borrow<TreeNode>,
{
    fn from(tree: &'tree T) -> Self {
        let root: &'tree TreeNode = tree.borrow();
        TreeScriptIter {
            path: vec![(root, BranchDirection::Shallow)],
        }
    }
}
impl<'tree> Iterator for TreeScriptIter<'tree> {
    type Item = (u8, &'tree LeafScript);

    /// Descends from the most recently saved node until a leaf is found. Each branch met for
    /// the first time is saved as `Deep` before going into its DFS-first child, so that its
    /// DFS-last child is explored when traversal resumes from it. Hidden nodes are skipped.
    fn next(&mut self) -> Option<Self::Item> {
        while let Some((mut node, mut direction)) = self.path.pop() {
            loop {
                match (node, direction) {
                    (TreeNode::Leaf(leaf_script, depth), _) => return Some((*depth, leaf_script)),
                    (TreeNode::Hidden(..), _) => break,
                    (TreeNode::Branch(branch, _), BranchDirection::Shallow) => {
                        self.path.push((node, BranchDirection::Deep));
                        node = branch.as_dfs_first_node();
                    }
                    (TreeNode::Branch(branch, _), BranchDirection::Deep) => {
                        node = branch.as_dfs_last_node();
                        direction = BranchDirection::Shallow;
                    }
                }
            }
        }
        None
    }
}
impl<'tree> IntoIterator for &'tree TaprootScriptTree {
type Item = (u8, &'tree LeafScript);
type IntoIter = TreeScriptIter<'tree>;
#[inline]
fn into_iter(self) -> Self::IntoIter { self.scripts() }
}
/// Builds a `TapTree` by feeding all known leaf scripts, in DFS order, to a `TaprootBuilder`.
///
/// # Panics
///
/// If the builder rejects a leaf or the resulting tree.
impl From<&TaprootScriptTree> for TapTree {
    fn from(tree: &TaprootScriptTree) -> Self {
        let builder =
            tree.scripts()
                .fold(TaprootBuilder::new(), |builder, (depth, leaf_script)| {
                    builder
                        .add_leaf_with_ver(
                            depth,
                            leaf_script.script.to_inner(),
                            leaf_script.version,
                        )
                        .expect("broken TaprootScriptTree")
                });
        TapTree::try_from(builder).expect("broken TaprootScriptTree")
    }
}
impl From<TaprootScriptTree> for TapTree {
#[inline]
fn from(tree: TaprootScriptTree) -> Self { TapTree::from(&tree) }
}
#[cfg(test)]
mod test {
use std::collections::BTreeSet;
use amplify::Wrapper;
use bitcoin::blockdata::opcodes::all;
use bitcoin::hashes::hex::FromHex;
use bitcoin::util::taproot::TaprootBuilder;
use super::*;
/// Builds a `TapTree` with one single-opcode script per entry of `depth_map`; the opcode byte
/// starts at `opcode` and increases (wrapping around) with every leaf.
fn compose_tree(opcode: u8, depth_map: impl IntoIterator<Item = u8>) -> TapTree {
    let mut builder = TaprootBuilder::new();
    for (offset, depth) in depth_map.into_iter().enumerate() {
        // Truncating the offset to `u8` plus wrapping addition equals repeated wrapping increments.
        let byte = opcode.wrapping_add(offset as u8);
        let script = Script::from_hex(&format!("{:02x}", byte)).unwrap();
        builder = builder.add_leaf(depth, script).unwrap();
    }
    TapTree::try_from(builder).unwrap()
}
/// Round-trips a composed `TapTree` through `TaprootScriptTree`, checking that leaf depths,
/// versions and scripts are preserved and that converting back yields the same `TapTree`.
fn test_tree(opcode: u8, depth_map: impl IntoIterator<Item = u8>) {
    let taptree = compose_tree(opcode, depth_map);
    let script_tree = TaprootScriptTree::from(taptree.clone());

    let expected: BTreeSet<_> = taptree
        .script_leaves()
        .map(|leaf| {
            let depth = leaf.merkle_branch().as_inner().len() as u8;
            (depth, leaf.leaf_version(), leaf.script())
        })
        .collect();
    let actual: BTreeSet<_> = script_tree
        .scripts()
        .map(|(depth, leaf_script)| (depth, leaf_script.version, leaf_script.script.as_inner()))
        .collect();
    assert_eq!(expected, actual);

    let round_tripped = TapTree::from(&script_tree);
    assert_eq!(taptree, round_tripped);
}
/// Joins a single `OP_RETURN` leaf tree into a composed tree as its DFS-first child, checks the
/// merged structure, then splits it back and expects both original trees.
fn test_join_split(depth_map: impl IntoIterator<Item = u8>) {
let taptree = compose_tree(0x51, depth_map);
let script_tree = TaprootScriptTree::from(taptree);
assert!(script_tree.check().is_ok());
let instill_tree: TaprootScriptTree = compose_tree(all::OP_RETURN.to_u8(), [0]).into();
let merged_tree = script_tree
.clone()
.join(instill_tree.clone(), DfsOrder::First)
.unwrap();
assert!(merged_tree.check().is_ok());
// Conversion back to `TapTree` must not panic.
let _ = TapTree::from(&merged_tree);
assert_ne!(merged_tree, script_tree);
let order = merged_tree.root.as_branch().unwrap().dfs_ordering;
match (
merged_tree.node_at([DfsOrder::First]).unwrap(),
merged_tree.node_at([DfsOrder::Last]).unwrap(),
order,
) {
// Expected: the instilled `OP_RETURN` leaf is the DFS-first child at depth 1, whatever
// the storage ordering of the root branch is.
(TreeNode::Leaf(leaf_script, 1), _, DfsOrdering::LeftRight)
| (TreeNode::Leaf(leaf_script, 1), _, DfsOrdering::RightLeft)
if leaf_script.script[0] == all::OP_RETURN.to_u8() =>
{
}
// The instilled leaf ended up as the DFS-last child instead.
(_, TreeNode::Leaf(leaf_script, 1), ordering)
if leaf_script.script[0] == all::OP_RETURN.to_u8() =>
{
panic!(
"instilled tree with script `{:?}` has incorrect DFS ordering {:?}",
leaf_script.script, ordering
)
}
(TreeNode::Leaf(_, x), _, _) => {
panic!("broken mergged tree depth of first branches: {}", x);
}
_ => panic!("instilled tree is not present as first branch of the merged tree"),
}
// `split` returns the DFS-last subtree first and the DFS-first subtree second.
let (script_tree_prime, instill_tree_prime) = merged_tree.split().unwrap();
assert!(script_tree_prime.check().is_ok());
assert!(instill_tree_prime.check().is_ok());
assert_eq!(instill_tree, instill_tree_prime);
assert_eq!(script_tree, script_tree_prime);
}
/// Instills a tree composed from `depth_map2` into a tree composed from `depth_map1` at the DFS
/// `path`, then cuts it back out and checks that both original trees are recovered.
///
/// # Panics
///
/// Panics (failing the calling test) if `path` does not parse, or if any validation, instill,
/// cut or round-trip comparison fails.
fn test_instill_cut(
    depth_map1: impl IntoIterator<Item = u8>,
    depth_map2: impl IntoIterator<Item = u8>,
    path: &str,
) {
    let dfs_path = DfsPath::from_str(path).unwrap();

    let host: TaprootScriptTree = compose_tree(0x51, depth_map1).into();
    assert!(host.check().is_ok());
    let guest: TaprootScriptTree = compose_tree(50, depth_map2).into();
    assert!(guest.check().is_ok());

    let mut merged = host.clone();
    merged
        .instill(guest.clone(), &dfs_path, DfsOrder::First)
        .unwrap();
    assert!(merged.check().is_ok());
    // Result is discarded: this only exercises the conversion on the merged tree.
    let _ = TapTree::from(&merged);
    assert_ne!(merged, host);

    let (host_back, guest_back) = merged.cut(dfs_path, DfsOrder::First).unwrap();
    assert!(host_back.check().is_ok());
    assert!(guest_back.check().is_ok());
    assert_eq!(guest, guest_back);
    assert_eq!(host, host_back);
}
fn testsuite_tree_structures(opcode: u8) {
test_tree(opcode, [0]);
test_tree(opcode, [1, 1]);
test_tree(opcode, [1, 2, 2]);
test_tree(opcode, [2, 2, 2, 2]);
test_tree(opcode, [1, 2, 3, 3]);
test_tree(opcode, [1, 3, 3, 3, 3]);
test_tree(opcode, [2, 2, 2, 3, 3]);
test_tree(opcode, [2, 2, 3, 3, 3, 3]);
test_tree(opcode, [2, 3, 3, 3, 3, 3, 3]);
test_tree(opcode, [3, 3, 3, 3, 3, 3, 3, 3]);
}
#[test]
/// Runs the tree-structure suite once for each of several leaf opcodes.
fn taptree_parsing() {
    for &opcode in &[0x51, 51, 0, 0x80] {
        testsuite_tree_structures(opcode);
    }
}
#[test]
/// A tree consisting of a single leaf cannot be cut at its root from either side.
fn taptree_edge_ops() {
    let script_tree = TaprootScriptTree::from(compose_tree(0x51, [0]));
    assert!(script_tree.check().is_ok());
    for side in [DfsOrder::First, DfsOrder::Last] {
        let err = script_tree.clone().cut([], side).unwrap_err();
        assert_eq!(err, CutError::UnsplittableTree);
    }
}
#[test]
/// Join/split round-trip over a range of host-tree layouts.
fn taptree_join_split() {
    // Each entry lists the host tree's leaf depths in DFS order.
    const HOST_DEPTH_MAPS: &[&[u8]] = &[
        &[0],
        &[1, 1],
        &[1, 2, 2],
        &[2, 2, 2, 2],
        &[1, 2, 3, 3],
        &[1, 3, 3, 3, 3],
        &[2, 2, 2, 3, 3],
        &[2, 2, 3, 3, 3, 3],
        &[2, 3, 3, 3, 3, 3, 3],
        &[3, 3, 3, 3, 3, 3, 3, 3],
    ];
    for host in HOST_DEPTH_MAPS {
        test_join_split(host.iter().copied());
    }
}
#[test]
/// Instill/cut round-trip for every combination of instilled-tree layout and DFS path, applied
/// to a fixed host tree.
fn taptree_instill_cut() {
    const INSTILL_DEPTH_MAPS: [&[u8]; 2] = [&[0], &[1, 2, 3, 3]];
    const PATHS: [&str; 9] = ["", "0", "1", "00", "01", "10", "11", "110", "111"];
    for instill_depth_map in INSTILL_DEPTH_MAPS.iter() {
        for &path in &PATHS {
            test_instill_cut([2, 2, 2, 3, 3], instill_depth_map.iter().copied(), path);
        }
    }
}
#[test]
/// After instilling a subtree, the sibling of every node on the returned instill path must match
/// a known list of scripts and node hashes.
fn instill_path_proof() {
    /// What sits on the opposite side of each branch along the instill path.
    #[derive(PartialEq, Eq, Debug)]
    enum PartnerNode {
        Script(String),
        Hash(TapNodeHash),
    }

    let path = DfsPath::from_str("00101").unwrap();
    let taptree = compose_tree(0x51, [3, 5, 5, 4, 3, 3, 2, 3, 4, 5, 6, 8, 8, 7]);
    let mut merged_tree = TaprootScriptTree::from(taptree);
    assert!(merged_tree.check().is_ok());

    let instill_tree: TaprootScriptTree = compose_tree(50, [2, 2, 2, 3, 3]).into();
    assert!(instill_tree.check().is_ok());

    let instill_path = merged_tree
        .instill(instill_tree, &path, DfsOrder::First)
        .unwrap();
    assert!(merged_tree.check().is_ok());

    let mut path_partners = Vec::new();
    for (node, step) in merged_tree.nodes_on_path(&instill_path).zip(&instill_path) {
        let branch = node.unwrap().as_branch().unwrap();
        // The partner is the child on the side opposite to the step taken.
        let partner = match branch.as_dfs_child_node(!step) {
            TreeNode::Leaf(leaf, _) => PartnerNode::Script(leaf.script.as_inner().to_string()),
            TreeNode::Hidden(hash, _) => PartnerNode::Hash(*hash),
            TreeNode::Branch(subtree, _) => {
                PartnerNode::Hash(subtree.branch_hash().into_node_hash())
            }
        };
        path_partners.push(partner);
    }

    let hash = |hex: &str| PartnerNode::Hash(hex.parse().unwrap());
    let script = |text: &str| PartnerNode::Script(text.to_owned());
    let expected = vec![
        hash("e1cc80c5229fa380040f65495b5a7adf102ec6b1bfe51b5c3dbda04ee258529f"),
        hash("ddad73a07b9a7725185f19d6772b02bd4b3a5525d05afde705c186cdcf588c37"),
        script("Script(OP_PUSHNUM_1)"),
        script("Script(OP_PUSHNUM_4)"),
        script("Script(OP_PUSHNUM_2)"),
        script("Script(OP_PUSHNUM_3)"),
    ];
    assert_eq!(path_partners, expected);
}
#[test]
/// Converting `TapTree` -> `TaprootScriptTree` -> `TapTree` must reproduce the input tree.
fn tapscripttree_roudtrip() {
    let depth_map = [3, 5, 5, 4, 3, 3, 2, 3, 4, 5, 6, 8, 8, 7];
    let original = compose_tree(0x51, depth_map);
    let roundtripped = TapTree::from(TaprootScriptTree::from(original.clone()));
    assert_eq!(original, roundtripped);
}
#[test]
/// `TapTree::script_leaves` and `TaprootScriptTree::scripts` must expose the same scripts. The
/// test compares them with the `TapTree` sequence reversed.
fn tapscripttree_taptree_eq() {
    let taptree = compose_tree(0x51, [3, 5, 5, 4, 3, 3, 2, 3, 4, 5, 6, 8, 8, 7]);
    let script_tree = TaprootScriptTree::from(taptree.clone());
    assert!(script_tree.check().is_ok());
    let mut script_leaves = taptree.script_leaves().collect::<Vec<_>>();
    script_leaves.reverse();
    let scripts = script_tree.scripts().collect::<Vec<_>>();
    // `zip` stops at the shorter side, so a missing or extra leaf would otherwise go unnoticed.
    assert_eq!(
        script_leaves.len(),
        scripts.len(),
        "TapTree and TaprootScriptTree expose a different number of scripts"
    );
    for (leaf, (_, leaf_script)) in script_leaves.iter().zip(scripts) {
        assert_eq!(leaf.script(), leaf_script.script.as_inner());
    }
}
#[test]
/// `TaprootScriptTree::scripts` must yield exactly one leaf per `depth_map` entry, in DFS order.
/// Each leaf must have the expected depth and the single-opcode script assigned by
/// `compose_tree`, starting at `0x51` and incrementing by one per leaf.
fn tapscripttree_dfs() {
    let depth_map = [3, 5, 5, 4, 3, 3, 2, 3, 4, 5, 6, 8, 8, 7];
    let mut val = 0x51;
    let taptree = compose_tree(val, depth_map);
    let script_tree = TaprootScriptTree::from(taptree);
    assert!(script_tree.check().is_ok());
    let scripts = script_tree.scripts().collect::<Vec<_>>();
    // Without this check, a tree yielding too few leaves would pass the loop below unnoticed.
    assert_eq!(
        scripts.len(),
        depth_map.len(),
        "script tree yields a different number of leaves than the depth map"
    );
    for ((depth, leaf_script), expected_depth) in scripts.into_iter().zip(depth_map) {
        let script = Script::from_hex(&format!("{:02x}", val)).unwrap();
        assert_eq!(depth, expected_depth);
        assert_eq!(script, leaf_script.script.to_inner());
        val = val.wrapping_add(1);
    }
}
}