1use core::cmp::{self, max};
4use core::str::FromStr;
5use core::{fmt, hash};
6
7use bellscoin::taproot::{
8 LeafVersion, TaprootBuilder, TaprootSpendInfo, TAPROOT_CONTROL_BASE_SIZE,
9 TAPROOT_CONTROL_MAX_NODE_COUNT, TAPROOT_CONTROL_NODE_SIZE,
10};
11use bellscoin::{opcodes, secp256k1, Address, Network, ScriptBuf};
12use sync::Arc;
13
14use super::checksum::{self, verify_checksum};
15use crate::expression::{self, FromTree};
16use crate::miniscript::Miniscript;
17use crate::policy::semantic::Policy;
18use crate::policy::Liftable;
19use crate::prelude::*;
20use crate::util::{varint_len, witness_size};
21use crate::{
22 errstr, Error, ForEachKey, MiniscriptKey, Satisfier, ScriptContext, Tap, ToPublicKey,
23 TranslateErr, TranslatePk, Translator,
24};
25
/// A binary tree of tapscript leaves: the script-path part of a `tr(...)` descriptor.
///
/// NOTE: `Ord`/`PartialOrd` are derived, so the declaration order of the variants
/// (`Tree` before `Leaf`) is part of the ordering; do not reorder them.
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub enum TapTree<Pk: MiniscriptKey> {
    /// An internal node with a left and a right child subtree.
    Tree(Arc<TapTree<Pk>>, Arc<TapTree<Pk>>),
    /// A leaf holding a single tapscript miniscript.
    Leaf(Arc<Miniscript<Pk, Tap>>),
}
39
/// A taproot (`tr(...)`) descriptor: an internal key plus an optional tree of tapscripts.
pub struct Tr<Pk: MiniscriptKey> {
    /// The internal key; it is passed to the taproot builder to derive the output key.
    internal_key: Pk,
    /// Optional script tree; `None` means only the key path is available.
    tree: Option<TapTree<Pk>>,
    /// Lazily computed, cached spend data (filled in by `Tr::spend_info`). A mutex is used so
    /// the cache can be populated through `&self`. This field is derived data and is
    /// deliberately ignored by the equality, ordering and hashing impls.
    spend_info: Mutex<Option<Arc<TaprootSpendInfo>>>,
}
55
56impl<Pk: MiniscriptKey> Clone for Tr<Pk> {
57 fn clone(&self) -> Self {
58 Self {
62 internal_key: self.internal_key.clone(),
63 tree: self.tree.clone(),
64 spend_info: Mutex::new(
65 self.spend_info
66 .lock()
67 .expect("Lock poisoned")
68 .as_ref()
69 .map(Arc::clone),
70 ),
71 }
72 }
73}
74
75impl<Pk: MiniscriptKey> PartialEq for Tr<Pk> {
76 fn eq(&self, other: &Self) -> bool {
77 self.internal_key == other.internal_key && self.tree == other.tree
78 }
79}
80
81impl<Pk: MiniscriptKey> Eq for Tr<Pk> {}
82
83impl<Pk: MiniscriptKey> PartialOrd for Tr<Pk> {
84 fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
85 match self.internal_key.partial_cmp(&other.internal_key) {
86 Some(cmp::Ordering::Equal) => {}
87 ord => return ord,
88 }
89 self.tree.partial_cmp(&other.tree)
90 }
91}
92
93impl<Pk: MiniscriptKey> Ord for Tr<Pk> {
94 fn cmp(&self, other: &Self) -> cmp::Ordering {
95 match self.internal_key.cmp(&other.internal_key) {
96 cmp::Ordering::Equal => {}
97 ord => return ord,
98 }
99 self.tree.cmp(&other.tree)
100 }
101}
102
103impl<Pk: MiniscriptKey> hash::Hash for Tr<Pk> {
104 fn hash<H: hash::Hasher>(&self, state: &mut H) {
105 self.internal_key.hash(state);
106 self.tree.hash(state);
107 }
108}
109
110impl<Pk: MiniscriptKey> TapTree<Pk> {
111 fn taptree_height(&self) -> usize {
115 match *self {
116 TapTree::Tree(ref left_tree, ref right_tree) => {
117 1 + max(left_tree.taptree_height(), right_tree.taptree_height())
118 }
119 TapTree::Leaf(..) => 0,
120 }
121 }
122
123 pub fn iter(&self) -> TapTreeIter<Pk> {
126 TapTreeIter {
127 stack: vec![(0, self)],
128 }
129 }
130
131 fn translate_helper<T, Q, E>(&self, t: &mut T) -> Result<TapTree<Q>, TranslateErr<E>>
133 where
134 T: Translator<Pk, Q, E>,
135 Q: MiniscriptKey,
136 {
137 let frag = match self {
138 TapTree::Tree(l, r) => TapTree::Tree(
139 Arc::new(l.translate_helper(t)?),
140 Arc::new(r.translate_helper(t)?),
141 ),
142 TapTree::Leaf(ms) => TapTree::Leaf(Arc::new(ms.translate_pk(t)?)),
143 };
144 Ok(frag)
145 }
146}
147
148impl<Pk: MiniscriptKey> fmt::Display for TapTree<Pk> {
149 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
150 match self {
151 TapTree::Tree(ref left, ref right) => write!(f, "{{{},{}}}", *left, *right),
152 TapTree::Leaf(ref script) => write!(f, "{}", *script),
153 }
154 }
155}
156
157impl<Pk: MiniscriptKey> fmt::Debug for TapTree<Pk> {
158 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159 match self {
160 TapTree::Tree(ref left, ref right) => write!(f, "{{{:?},{:?}}}", *left, *right),
161 TapTree::Leaf(ref script) => write!(f, "{:?}", *script),
162 }
163 }
164}
165
166impl<Pk: MiniscriptKey> Tr<Pk> {
167 pub fn new(internal_key: Pk, tree: Option<TapTree<Pk>>) -> Result<Self, Error> {
169 Tap::check_pk(&internal_key)?;
170 let nodes = tree.as_ref().map(|t| t.taptree_height()).unwrap_or(0);
171
172 if nodes <= TAPROOT_CONTROL_MAX_NODE_COUNT {
173 Ok(Self {
174 internal_key,
175 tree,
176 spend_info: Mutex::new(None),
177 })
178 } else {
179 Err(Error::MaxRecursiveDepthExceeded)
180 }
181 }
182
183 pub fn internal_key(&self) -> &Pk {
185 &self.internal_key
186 }
187
188 pub fn taptree(&self) -> &Option<TapTree<Pk>> {
190 &self.tree
191 }
192
193 pub fn iter_scripts(&self) -> TapTreeIter<Pk> {
196 match self.tree {
197 Some(ref t) => t.iter(),
198 None => TapTreeIter { stack: vec![] },
199 }
200 }
201
202 pub fn spend_info(&self) -> Arc<TaprootSpendInfo>
208 where
209 Pk: ToPublicKey,
210 {
211 let read_lock = self.spend_info.lock().expect("Lock poisoned");
214 if let Some(ref spend_info) = *read_lock {
215 return Arc::clone(spend_info);
216 }
217 drop(read_lock);
218
219 let secp = secp256k1::Secp256k1::verification_only();
222 let data = if self.tree.is_none() {
224 TaprootSpendInfo::new_key_spend(&secp, self.internal_key.to_x_only_pubkey(), None)
225 } else {
226 let mut builder = TaprootBuilder::new();
227 for (depth, ms) in self.iter_scripts() {
228 let script = ms.encode();
229 builder = builder
230 .add_leaf(depth, script)
231 .expect("Computing spend data on a valid Tree should always succeed");
232 }
233 match builder.finalize(&secp, self.internal_key.to_x_only_pubkey()) {
235 Ok(data) => data,
236 Err(_) => unreachable!("We know the builder can be finalized"),
237 }
238 };
239 let spend_info = Arc::new(data);
240 *self.spend_info.lock().expect("Lock poisoned") = Some(Arc::clone(&spend_info));
241 spend_info
242 }
243
244 pub fn sanity_check(&self) -> Result<(), Error> {
246 for (_depth, ms) in self.iter_scripts() {
247 ms.sanity_check()?;
248 }
249 Ok(())
250 }
251
252 pub fn max_weight_to_satisfy(&self) -> Result<usize, Error> {
261 let tree = match self.taptree() {
262 None => {
263 let item_sig_size = 1 + 65;
266 let stack_varint_diff = varint_len(1) - varint_len(0);
268
269 return Ok(stack_varint_diff + item_sig_size);
270 }
271 Some(tree) => tree,
273 };
274
275 tree.iter()
276 .filter_map(|(depth, ms)| {
277 let script_size = ms.script_size();
278 let max_sat_elems = ms.max_satisfaction_witness_elements().ok()?;
279 let max_sat_size = ms.max_satisfaction_size().ok()?;
280 let control_block_size = control_block_len(depth);
281
282 let stack_varint_diff = varint_len(max_sat_elems + 1) - varint_len(0);
284
285 Some(
286 stack_varint_diff +
287 max_sat_size +
289 varint_len(script_size) +
291 script_size +
292 varint_len(control_block_size) +
294 control_block_size,
295 )
296 })
297 .max()
298 .ok_or(Error::ImpossibleSatisfaction)
299 }
300
301 #[deprecated(note = "use max_weight_to_satisfy instead")]
311 pub fn max_satisfaction_weight(&self) -> Result<usize, Error> {
312 let tree = match self.taptree() {
313 None => return Ok(4 + 1 + 1 + 65),
316 Some(tree) => tree,
318 };
319
320 tree.iter()
321 .filter_map(|(depth, ms)| {
322 let script_size = ms.script_size();
323 let max_sat_elems = ms.max_satisfaction_witness_elements().ok()?;
324 let max_sat_size = ms.max_satisfaction_size().ok()?;
325 let control_block_size = control_block_len(depth);
326 Some(
327 4 +
329 varint_len(max_sat_elems + 2) +
331 max_sat_size +
333 varint_len(script_size) +
335 script_size +
336 varint_len(control_block_size) +
338 control_block_size,
339 )
340 })
341 .max()
342 .ok_or(Error::ImpossibleSatisfaction)
343 }
344}
345
346impl<Pk: MiniscriptKey + ToPublicKey> Tr<Pk> {
347 pub fn script_pubkey(&self) -> ScriptBuf {
349 let output_key = self.spend_info().output_key();
350 let builder = bellscoin::blockdata::script::Builder::new();
351 builder
352 .push_opcode(opcodes::all::OP_PUSHNUM_1)
353 .push_slice(&output_key.serialize())
354 .into_script()
355 }
356
357 pub fn address(&self, network: Network) -> Address {
359 let spend_info = self.spend_info();
360 Address::p2tr_tweaked(spend_info.output_key(), network)
361 }
362
363 pub fn get_satisfaction<S>(&self, satisfier: S) -> Result<(Vec<Vec<u8>>, ScriptBuf), Error>
367 where
368 S: Satisfier<Pk>,
369 {
370 best_tap_spend(self, satisfier, false )
371 }
372
373 pub fn get_satisfaction_mall<S>(&self, satisfier: S) -> Result<(Vec<Vec<u8>>, ScriptBuf), Error>
377 where
378 S: Satisfier<Pk>,
379 {
380 best_tap_spend(self, satisfier, true )
381 }
382}
383
/// Depth-first, left-to-right iterator over the leaves of a [`TapTree`].
///
/// Each item is `(depth, script)`, where `depth` counts edges from the root (the root itself
/// is depth 0). Produced by `TapTree::iter` and `Tr::iter_scripts`.
#[derive(Debug, Clone)]
pub struct TapTreeIter<'a, Pk: MiniscriptKey> {
    // Explicit DFS stack of `(depth, node)` pairs still to visit.
    // NOTE(review): depth is a `u8`. `Tr::new` caps the tree height at
    // TAPROOT_CONTROL_MAX_NODE_COUNT, but a `TapTree` built directly is not checked, so a tree
    // deeper than 255 levels would overflow `depth + 1` in `next`.
    stack: Vec<(u8, &'a TapTree<Pk>)>,
}
400
401impl<'a, Pk> Iterator for TapTreeIter<'a, Pk>
402where
403 Pk: MiniscriptKey + 'a,
404{
405 type Item = (u8, &'a Miniscript<Pk, Tap>);
406
407 fn next(&mut self) -> Option<Self::Item> {
408 while !self.stack.is_empty() {
409 let (depth, last) = self.stack.pop().expect("Size checked above");
410 match *last {
411 TapTree::Tree(ref l, ref r) => {
412 self.stack.push((depth + 1, r));
413 self.stack.push((depth + 1, l));
414 }
415 TapTree::Leaf(ref ms) => return Some((depth, ms)),
416 }
417 }
418 None
419 }
420}
421
422#[rustfmt::skip]
423impl_block_str!(
424 Tr<Pk>,
425 fn parse_tr_script_spend(tree: &expression::Tree,) -> Result<TapTree<Pk>, Error> {
427 match tree {
428 expression::Tree { name, args } if !name.is_empty() && args.is_empty() => {
429 let script = Miniscript::<Pk, Tap>::from_str(name)?;
430 Ok(TapTree::Leaf(Arc::new(script)))
431 }
432 expression::Tree { name, args } if name.is_empty() && args.len() == 2 => {
433 let left = Self::parse_tr_script_spend(&args[0])?;
434 let right = Self::parse_tr_script_spend(&args[1])?;
435 Ok(TapTree::Tree(Arc::new(left), Arc::new(right)))
436 }
437 _ => Err(Error::Unexpected(
438 "unknown format for script spending paths while parsing taproot descriptor"
439 .to_string(),
440 )),
441 }
442 }
443);
444
445impl_from_tree!(
446 Tr<Pk>,
447 fn from_tree(top: &expression::Tree) -> Result<Self, Error> {
448 if top.name == "tr" {
449 match top.args.len() {
450 1 => {
451 let key = &top.args[0];
452 if !key.args.is_empty() {
453 return Err(Error::Unexpected(format!(
454 "#{} script associated with `key-path` while parsing taproot descriptor",
455 key.args.len()
456 )));
457 }
458 Tr::new(expression::terminal(key, Pk::from_str)?, None)
459 }
460 2 => {
461 let key = &top.args[0];
462 if !key.args.is_empty() {
463 return Err(Error::Unexpected(format!(
464 "#{} script associated with `key-path` while parsing taproot descriptor",
465 key.args.len()
466 )));
467 }
468 let tree = &top.args[1];
469 let ret = Self::parse_tr_script_spend(tree)?;
470 Tr::new(expression::terminal(key, Pk::from_str)?, Some(ret))
471 }
472 _ => Err(Error::Unexpected(format!(
473 "{}[#{} args] while parsing taproot descriptor",
474 top.name,
475 top.args.len()
476 ))),
477 }
478 } else {
479 Err(Error::Unexpected(format!(
480 "{}[#{} args] while parsing taproot descriptor",
481 top.name,
482 top.args.len()
483 )))
484 }
485 }
486);
487
488impl_from_str!(
489 Tr<Pk>,
490 type Err = Error;,
491 fn from_str(s: &str) -> Result<Self, Self::Err> {
492 let desc_str = verify_checksum(s)?;
493 let top = parse_tr_tree(desc_str)?;
494 Self::from_tree(&top)
495 }
496);
497
498impl<Pk: MiniscriptKey> fmt::Debug for Tr<Pk> {
499 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
500 match self.tree {
501 Some(ref s) => write!(f, "tr({:?},{:?})", self.internal_key, s),
502 None => write!(f, "tr({:?})", self.internal_key),
503 }
504 }
505}
506
507impl<Pk: MiniscriptKey> fmt::Display for Tr<Pk> {
508 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
509 use fmt::Write;
510 let mut wrapped_f = checksum::Formatter::new(f);
511 let key = &self.internal_key;
512 match self.tree {
513 Some(ref s) => write!(wrapped_f, "tr({},{})", key, s)?,
514 None => write!(wrapped_f, "tr({})", key)?,
515 }
516 wrapped_f.write_checksum_if_not_alt()
517 }
518}
519
520fn parse_tr_tree(s: &str) -> Result<expression::Tree, Error> {
522 for ch in s.bytes() {
523 if !ch.is_ascii() {
524 return Err(Error::Unprintable(ch));
525 }
526 }
527
528 if s.len() > 3 && &s[..3] == "tr(" && s.as_bytes()[s.len() - 1] == b')' {
529 let rest = &s[3..s.len() - 1];
530 if !rest.contains(',') {
531 let internal_key = expression::Tree {
532 name: rest,
533 args: vec![],
534 };
535 return Ok(expression::Tree {
536 name: "tr",
537 args: vec![internal_key],
538 });
539 }
540 let (key, script) = split_once(rest, ',')
542 .ok_or_else(|| Error::BadDescriptor("invalid taproot descriptor".to_string()))?;
543
544 let internal_key = expression::Tree {
545 name: key,
546 args: vec![],
547 };
548 if script.is_empty() {
549 return Ok(expression::Tree {
550 name: "tr",
551 args: vec![internal_key],
552 });
553 }
554 let (tree, rest) = expression::Tree::from_slice_delim(script, 1, '{')?;
555 if rest.is_empty() {
556 Ok(expression::Tree {
557 name: "tr",
558 args: vec![internal_key, tree],
559 })
560 } else {
561 Err(errstr(rest))
562 }
563 } else {
564 Err(Error::Unexpected("invalid taproot descriptor".to_string()))
565 }
566}
567
/// Splits `inp` at the first occurrence of `delim`.
///
/// Returns `None` for empty input. If `delim` is absent, or appears only as the very last
/// character, the whole input is returned as the first half with an empty second half.
/// In that trailing case the delimiter stays in the first half.
///
/// Slicing uses the byte offset returned by `str::find`, so the function is correct for
/// arbitrary UTF-8 input. The previous version used a char index as a byte offset and could
/// panic on multi-byte characters.
fn split_once(inp: &str, delim: char) -> Option<(&str, &str)> {
    if inp.is_empty() {
        return None;
    }
    let delim_len = delim.len_utf8();
    match inp.find(delim) {
        // There is at least one byte after the delimiter: split around it.
        Some(idx) if idx + delim_len < inp.len() => Some((&inp[..idx], &inp[idx + delim_len..])),
        // No delimiter, or a trailing one.
        _ => Some((inp, "")),
    }
}
587
588impl<Pk: MiniscriptKey> Liftable<Pk> for TapTree<Pk> {
589 fn lift(&self) -> Result<Policy<Pk>, Error> {
590 fn lift_helper<Pk: MiniscriptKey>(s: &TapTree<Pk>) -> Result<Policy<Pk>, Error> {
591 match s {
592 TapTree::Tree(ref l, ref r) => {
593 Ok(Policy::Threshold(1, vec![lift_helper(l)?, lift_helper(r)?]))
594 }
595 TapTree::Leaf(ref leaf) => leaf.lift(),
596 }
597 }
598
599 let pol = lift_helper(self)?;
600 Ok(pol.normalized())
601 }
602}
603
604impl<Pk: MiniscriptKey> Liftable<Pk> for Tr<Pk> {
605 fn lift(&self) -> Result<Policy<Pk>, Error> {
606 match &self.tree {
607 Some(root) => Ok(Policy::Threshold(
608 1,
609 vec![Policy::Key(self.internal_key.clone()), root.lift()?],
610 )),
611 None => Ok(Policy::Key(self.internal_key.clone())),
612 }
613 }
614}
615
616impl<Pk: MiniscriptKey> ForEachKey<Pk> for Tr<Pk> {
617 fn for_each_key<'a, F: FnMut(&'a Pk) -> bool>(&'a self, mut pred: F) -> bool {
618 let script_keys_res = self
619 .iter_scripts()
620 .all(|(_d, ms)| ms.for_each_key(&mut pred));
621 script_keys_res && pred(&self.internal_key)
622 }
623}
624
625impl<P, Q> TranslatePk<P, Q> for Tr<P>
626where
627 P: MiniscriptKey,
628 Q: MiniscriptKey,
629{
630 type Output = Tr<Q>;
631
632 fn translate_pk<T, E>(&self, translate: &mut T) -> Result<Self::Output, TranslateErr<E>>
633 where
634 T: Translator<P, Q, E>,
635 {
636 let tree = match &self.tree {
637 Some(tree) => Some(tree.translate_helper(translate)?),
638 None => None,
639 };
640 let translate_desc = Tr::new(translate.pk(&self.internal_key)?, tree)
641 .map_err(|e| TranslateErr::OuterError(e))?;
642 Ok(translate_desc)
643 }
644}
645
646fn control_block_len(depth: u8) -> usize {
648 TAPROOT_CONTROL_BASE_SIZE + (depth as usize) * TAPROOT_CONTROL_NODE_SIZE
649}
650
651fn best_tap_spend<Pk, S>(
654 desc: &Tr<Pk>,
655 satisfier: S,
656 allow_mall: bool,
657) -> Result<(Vec<Vec<u8>>, ScriptBuf), Error>
658where
659 Pk: ToPublicKey,
660 S: Satisfier<Pk>,
661{
662 let spend_info = desc.spend_info();
663 if let Some(sig) = satisfier.lookup_tap_key_spend_sig() {
665 Ok((vec![sig.to_vec()], ScriptBuf::new()))
666 } else {
667 let (mut min_wit, mut min_wit_len) = (None, None);
670 for (depth, ms) in desc.iter_scripts() {
671 let mut wit = if allow_mall {
672 match ms.satisfy_malleable(&satisfier) {
673 Ok(wit) => wit,
674 Err(..) => continue, }
676 } else {
677 match ms.satisfy(&satisfier) {
678 Ok(wit) => wit,
679 Err(..) => continue, }
681 };
682 let wit_size = witness_size(&wit)
686 + control_block_len(depth)
687 + ms.script_size()
688 + varint_len(ms.script_size());
689 if min_wit_len.is_some() && Some(wit_size) > min_wit_len {
690 continue;
691 } else {
692 let leaf_script = (ms.encode(), LeafVersion::TapScript);
693 let control_block = spend_info
694 .control_block(&leaf_script)
695 .expect("Control block must exist in script map for every known leaf");
696 wit.push(leaf_script.0.into_bytes()); wit.push(control_block.serialize());
700 min_wit = Some(wit);
702 min_wit_len = Some(wit_size);
703 }
704 }
705 match min_wit {
706 Some(wit) => Ok((wit, ScriptBuf::new())),
707 None => Err(Error::CouldNotSatisfy), }
709 }
710}
711
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ForEachKey;

    /// `for_each_key` must reach the keys inside the leaf scripts, not just the internal key.
    #[test]
    fn test_for_each() {
        // Spaces and newlines are stripped below, so this layout is only for readability.
        // NOTE: `ac12` in the last leaf is deliberately NOT prefixed with `acc`, so the
        // predicate must fail for that key and `for_each_key` must return false.
        let desc = "tr(acc0, {
            multi_a(3, acc10, acc11, acc12), {
                and_v(
                    v:multi_a(2, acc10, acc11, acc12),
                    after(10)
                ),
                and_v(
                    v:multi_a(1, acc10, acc11, ac12),
                    after(100)
                )
            }
        })";
        let desc = desc.replace(&[' ', '\n'][..], "");
        let tr = Tr::<String>::from_str(&desc).unwrap();
        assert!(!tr.for_each_key(|k| k.starts_with("acc")));
    }
}