use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum TaprootPsbtError {
#[error("invalid key path: {0}")]
InvalidKeyPath(String),
#[error("invalid leaf script: {0}")]
InvalidLeafScript(String),
#[error("invalid control block: {0}")]
InvalidControlBlock(String),
#[error("invalid Taproot signature: {0}")]
InvalidSignature(String),
#[error("invalid internal key: {0}")]
InvalidInternalKey(String),
#[error("invalid Merkle root: {0}")]
InvalidMerkleRoot(String),
#[error("invalid derivation path: {0}")]
InvalidDerivationPath(String),
#[error("input index {0} out of range")]
InputIndexOutOfRange(usize),
#[error("output index {0} out of range")]
OutputIndexOutOfRange(usize),
}
fn validate_hex_length(
hex_str: &str,
expected_bytes: usize,
field: &str,
) -> Result<(), TaprootPsbtError> {
let expected_chars = expected_bytes * 2;
if hex_str.len() != expected_chars {
return Err(TaprootPsbtError::InvalidInternalKey(format!(
"{field}: expected {expected_bytes} bytes ({expected_chars} hex chars), got {} chars",
hex_str.len()
)));
}
if !hex_str.chars().all(|c| c.is_ascii_hexdigit()) {
return Err(TaprootPsbtError::InvalidInternalKey(format!(
"{field}: contains non-hex characters"
)));
}
Ok(())
}
/// Checks that `hex_str` is exactly `expected_bytes * 2` ASCII hex digits.
///
/// The length is checked first, then the character set. On failure, a description
/// of the problem is handed to `make_err` to build the caller's error type.
fn validate_hex_with<E, F>(hex_str: &str, expected_bytes: usize, make_err: F) -> Result<(), E>
where
    F: FnOnce(String) -> E,
{
    let wanted_len = expected_bytes * 2;
    let actual_len = hex_str.len();
    let problem = if actual_len != wanted_len {
        Some(format!(
            "expected {expected_bytes} bytes ({wanted_len} hex chars), got {actual_len} chars"
        ))
    } else if hex_str.bytes().any(|byte| !byte.is_ascii_hexdigit()) {
        // Any non-ASCII char contributes bytes >= 0x80, which are never hex digits.
        Some("contains non-hex characters".to_string())
    } else {
        None
    };
    match problem {
        Some(description) => Err(make_err(description)),
        None => Ok(()),
    }
}
/// Version of a tapscript leaf, as carried in the leaf version byte.
///
/// `TapScript` is the `0xC0` version; any other byte value is carried verbatim
/// in `Future`. Defaults to `TapScript`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LeafVersion {
    /// The tapscript leaf version, byte `0xC0`.
    #[default]
    TapScript,
    /// Any other leaf version byte, stored as-is.
    Future(u8),
}
impl LeafVersion {
    /// Returns the raw leaf version byte.
    ///
    /// `TapScript` yields `0xC0`; `Future(b)` yields `b` unchanged. Note that
    /// `Future(0xC0)` also yields `0xC0`, which [`LeafVersion::from_byte`] maps
    /// back to `TapScript`.
    pub fn to_byte(&self) -> u8 {
        match *self {
            Self::Future(raw) => raw,
            Self::TapScript => 0xC0,
        }
    }

    /// Interprets a raw leaf version byte.
    ///
    /// `0xC0` becomes `TapScript`; every other value (no masking is applied) is
    /// wrapped in `Future`.
    pub fn from_byte(b: u8) -> Self {
        match b {
            0xC0 => Self::TapScript,
            other => Self::Future(other),
        }
    }
}
/// A tapscript leaf: the script, its leaf version and the raw control block bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TapLeafScript {
    /// Leaf version that `script` is to be interpreted under.
    pub leaf_version: LeafVersion,
    /// Raw script bytes.
    pub script: Vec<u8>,
    /// Raw control block bytes; `TapLeafScript::new` leaves this empty.
    pub control_block: Vec<u8>,
}
impl TapLeafScript {
    /// Creates a leaf for `script` under `version`, with an empty control block.
    pub fn new(script: Vec<u8>, version: LeafVersion) -> Self {
        Self {
            leaf_version: version,
            script,
            control_block: Vec::new(),
        }
    }

    /// Returns the script as lowercase hex, two characters per byte.
    pub fn script_hex(&self) -> String {
        Self::encode_hex(&self.script)
    }

    /// Returns the control block as lowercase hex, two characters per byte.
    pub fn control_block_hex(&self) -> String {
        Self::encode_hex(&self.control_block)
    }

    /// Returns the number of 32-byte Merkle path nodes in the control block.
    ///
    /// A control block is a 33-byte header followed by 32-byte nodes. Blocks shorter
    /// than 33 bytes yield `0`, and trailing bytes that do not form a whole node are
    /// ignored.
    pub fn proof_len(&self) -> usize {
        self.control_block.len().saturating_sub(33) / 32
    }

    /// Lowercase-hex encodes `bytes` into a single pre-sized `String`.
    ///
    /// This avoids allocating a temporary `String` for every byte, as a
    /// `map(format!).collect()` chain would.
    fn encode_hex(bytes: &[u8]) -> String {
        const DIGITS: &[u8; 16] = b"0123456789abcdef";
        let mut out = String::with_capacity(bytes.len() * 2);
        for &byte in bytes {
            out.push(char::from(DIGITS[usize::from(byte >> 4)]));
            out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
        }
        out
    }
}
/// BIP32 derivation info for an x-only key, with hex-encoded key material.
///
/// See `validate` for the format checks applied to each field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TapBip32Derivation {
    /// Hex-encoded 32-byte x-only public key.
    pub x_only_pubkey: String,
    /// Hex-encoded 4-byte master key fingerprint.
    pub master_fingerprint: String,
    /// Derivation path text; only checked for non-emptiness (format is not parsed).
    pub derivation_path: String,
    /// Hex-encoded 32-byte leaf hashes; empty means the key is used for the key path.
    pub leaf_hashes: Vec<String>,
}
impl TapBip32Derivation {
    /// Creates a derivation entry with no leaf hashes.
    pub fn new(x_only_pubkey: String, master_fingerprint: String, derivation_path: String) -> Self {
        let leaf_hashes = Vec::new();
        Self {
            x_only_pubkey,
            master_fingerprint,
            derivation_path,
            leaf_hashes,
        }
    }

    /// Returns `true` when no leaf hashes are attached.
    pub fn is_key_path(&self) -> bool {
        self.leaf_hashes.is_empty()
    }

    /// Checks the hex encoding and sizes of every field.
    ///
    /// Checks run in this order, returning the first failure:
    /// - `x_only_pubkey` must be 32 bytes of hex (`InvalidKeyPath`);
    /// - `master_fingerprint` must be 4 bytes of hex (`InvalidDerivationPath`);
    /// - `derivation_path` must be non-empty (`InvalidDerivationPath`);
    /// - each leaf hash must be 32 bytes of hex (`InvalidKeyPath`).
    pub fn validate(&self) -> Result<(), TaprootPsbtError> {
        validate_hex_with(&self.x_only_pubkey, 32, |detail| {
            TaprootPsbtError::InvalidKeyPath(format!("x_only_pubkey: {detail}"))
        })?;
        validate_hex_with(&self.master_fingerprint, 4, |detail| {
            TaprootPsbtError::InvalidDerivationPath(format!("master_fingerprint: {detail}"))
        })?;
        if self.derivation_path.is_empty() {
            return Err(TaprootPsbtError::InvalidDerivationPath(
                "derivation path is empty".to_string(),
            ));
        }
        self.leaf_hashes.iter().try_for_each(|leaf_hash| {
            validate_hex_with(leaf_hash, 32, |detail| {
                TaprootPsbtError::InvalidKeyPath(format!("leaf_hash: {detail}"))
            })
        })
    }
}
/// Taproot-specific PSBT input fields, with keys and signatures held as hex strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaprootInputFields {
    /// Key-path Schnorr signature, hex encoded.
    pub tap_key_sig: Option<String>,
    /// Script-path signatures as `(control block hex, signature hex)` pairs.
    ///
    /// NOTE(review): BIP371 keys script-path signatures by (x-only key, leaf hash);
    /// this type treats the first element as a control block — confirm intent.
    pub tap_script_sigs: Vec<(String, String)>,
    /// Leaf scripts associated with this input.
    pub tap_leaf_scripts: Vec<TapLeafScript>,
    /// BIP32 derivation entries for keys involved in this input.
    pub tap_bip32_derivations: Vec<TapBip32Derivation>,
    /// Hex-encoded 32-byte x-only internal key.
    pub tap_internal_key: Option<String>,
    /// Hex-encoded 32-byte script-tree Merkle root.
    pub tap_merkle_root: Option<String>,
}
impl TaprootInputFields {
pub fn new() -> Self {
Self {
tap_key_sig: None,
tap_script_sigs: Vec::new(),
tap_leaf_scripts: Vec::new(),
tap_bip32_derivations: Vec::new(),
tap_internal_key: None,
tap_merkle_root: None,
}
}
pub fn with_key_sig(mut self, sig: String) -> Result<Self, TaprootPsbtError> {
let is_valid = sig.len() == 128 || sig.len() == 130 ;
if !is_valid || !sig.chars().all(|c| c.is_ascii_hexdigit()) {
return Err(TaprootPsbtError::InvalidSignature(format!(
"tap_key_sig must be 64 or 65 bytes (128 or 130 hex chars), got {} chars",
sig.len()
)));
}
self.tap_key_sig = Some(sig);
Ok(self)
}
pub fn with_internal_key(mut self, key: String) -> Result<Self, TaprootPsbtError> {
validate_hex_length(&key, 32, "tap_internal_key").map_err(|_| {
TaprootPsbtError::InvalidInternalKey(format!(
"tap_internal_key must be 32 bytes (64 hex chars), got {} chars",
key.len()
))
})?;
self.tap_internal_key = Some(key);
Ok(self)
}
pub fn with_merkle_root(mut self, root: String) -> Result<Self, TaprootPsbtError> {
validate_hex_with(&root, 32, TaprootPsbtError::InvalidMerkleRoot)?;
self.tap_merkle_root = Some(root);
Ok(self)
}
pub fn add_leaf_script(&mut self, leaf: TapLeafScript) {
self.tap_leaf_scripts.push(leaf);
}
pub fn add_bip32_derivation(&mut self, deriv: TapBip32Derivation) {
self.tap_bip32_derivations.push(deriv);
}
pub fn is_key_path_signed(&self) -> bool {
self.tap_key_sig.is_some()
}
pub fn is_script_path_signed(&self) -> bool {
!self.tap_script_sigs.is_empty()
}
pub fn is_finalized(&self) -> bool {
self.is_key_path_signed() || self.is_script_path_signed()
}
pub fn validate(&self) -> Result<(), TaprootPsbtError> {
if let Some(ref sig) = self.tap_key_sig {
let valid = (sig.len() == 128 || sig.len() == 130)
&& sig.chars().all(|c| c.is_ascii_hexdigit());
if !valid {
return Err(TaprootPsbtError::InvalidSignature(format!(
"tap_key_sig: expected 64 or 65 bytes, got {} hex chars",
sig.len()
)));
}
}
if let Some(ref key) = self.tap_internal_key {
validate_hex_length(key, 32, "tap_internal_key").map_err(|_| {
TaprootPsbtError::InvalidInternalKey(format!(
"tap_internal_key must be 32 bytes (64 hex chars), got {} chars",
key.len()
))
})?;
}
if let Some(ref root) = self.tap_merkle_root {
validate_hex_with(root, 32, TaprootPsbtError::InvalidMerkleRoot)?;
}
for (cb, sig) in &self.tap_script_sigs {
if cb.len() < 66 || cb.len() % 2 != 0 || !cb.chars().all(|c| c.is_ascii_hexdigit()) {
return Err(TaprootPsbtError::InvalidControlBlock(format!(
"script sig control block is malformed (len {})",
cb.len()
)));
}
let sig_valid = (sig.len() == 128 || sig.len() == 130)
&& sig.chars().all(|c| c.is_ascii_hexdigit());
if !sig_valid {
return Err(TaprootPsbtError::InvalidSignature(format!(
"script-path sig must be 64 or 65 bytes, got {} hex chars",
sig.len()
)));
}
}
for deriv in &self.tap_bip32_derivations {
deriv.validate()?;
}
Ok(())
}
}
impl Default for TaprootInputFields {
fn default() -> Self {
Self::new()
}
}
/// Taproot-specific PSBT output fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaprootOutputFields {
    /// Hex-encoded 32-byte x-only internal key; its presence marks a Taproot output.
    pub tap_internal_key: Option<String>,
    /// Leaves of the output's script tree, in insertion order.
    pub tap_tree: Vec<TapLeafScript>,
    /// BIP32 derivation entries for keys involved in this output.
    pub tap_bip32_derivations: Vec<TapBip32Derivation>,
}
impl TaprootOutputFields {
pub fn new() -> Self {
Self {
tap_internal_key: None,
tap_tree: Vec::new(),
tap_bip32_derivations: Vec::new(),
}
}
pub fn with_internal_key(mut self, key: String) -> Result<Self, TaprootPsbtError> {
validate_hex_with(&key, 32, |msg| {
TaprootPsbtError::InvalidInternalKey(format!("tap_internal_key: {msg}"))
})?;
self.tap_internal_key = Some(key);
Ok(self)
}
pub fn add_tree_leaf(&mut self, leaf: TapLeafScript) {
self.tap_tree.push(leaf);
}
pub fn is_taproot_output(&self) -> bool {
self.tap_internal_key.is_some()
}
pub fn validate(&self) -> Result<(), TaprootPsbtError> {
if let Some(ref key) = self.tap_internal_key {
validate_hex_with(key, 32, |msg| {
TaprootPsbtError::InvalidInternalKey(format!("tap_internal_key: {msg}"))
})?;
}
for deriv in &self.tap_bip32_derivations {
deriv.validate()?;
}
Ok(())
}
}
impl Default for TaprootOutputFields {
fn default() -> Self {
Self::new()
}
}
/// A PSBT held as opaque hex, plus per-input and per-output Taproot fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaprootPsbt {
    /// The underlying PSBT, hex encoded and stored verbatim (nothing here parses it).
    pub base_psbt_hex: String,
    /// Taproot fields indexed by input position.
    pub input_fields: Vec<TaprootInputFields>,
    /// Taproot fields indexed by output position.
    pub output_fields: Vec<TaprootOutputFields>,
}
impl TaprootPsbt {
pub fn new(psbt_hex: String) -> Self {
Self {
base_psbt_hex: psbt_hex,
input_fields: Vec::new(),
output_fields: Vec::new(),
}
}
pub fn input_count(&self) -> usize {
self.input_fields.len()
}
pub fn output_count(&self) -> usize {
self.output_fields.len()
}
pub fn is_complete(&self) -> bool {
if self.input_fields.is_empty() {
return false;
}
self.input_fields.iter().all(|f| f.is_finalized())
}
pub fn get_input_fields(&self, index: usize) -> Option<&TaprootInputFields> {
self.input_fields.get(index)
}
pub fn get_output_fields(&self, index: usize) -> Option<&TaprootOutputFields> {
self.output_fields.get(index)
}
pub fn add_input_taproot_fields(
&mut self,
index: usize,
fields: TaprootInputFields,
) -> Result<(), TaprootPsbtError> {
while self.input_fields.len() <= index {
self.input_fields.push(TaprootInputFields::new());
}
self.input_fields[index] = fields;
Ok(())
}
pub fn add_output_taproot_fields(
&mut self,
index: usize,
fields: TaprootOutputFields,
) -> Result<(), TaprootPsbtError> {
while self.output_fields.len() <= index {
self.output_fields.push(TaprootOutputFields::new());
}
self.output_fields[index] = fields;
Ok(())
}
pub fn validate_all(&self) -> Result<(), TaprootPsbtError> {
for fields in &self.input_fields {
fields.validate()?;
}
for fields in &self.output_fields {
fields.validate()?;
}
Ok(())
}
}
/// Incrementally assembles a [`TaprootPsbt`], appending input/output fields in call order.
#[derive(Debug)]
pub struct TaprootPsbtBuilder {
    /// Hex of the base PSBT, passed through to the built value.
    psbt_hex: String,
    /// Input fields, in the order they were added.
    input_fields: Vec<TaprootInputFields>,
    /// Output fields, in the order they were added.
    output_fields: Vec<TaprootOutputFields>,
}
impl TaprootPsbtBuilder {
    /// Starts a builder for `psbt_hex` with no input or output fields.
    pub fn new(psbt_hex: String) -> Self {
        Self {
            psbt_hex,
            input_fields: Vec::new(),
            output_fields: Vec::new(),
        }
    }

    /// Appends `fields` as the next input entry.
    #[must_use]
    pub fn add_input_fields(self, fields: TaprootInputFields) -> Self {
        let mut builder = self;
        builder.input_fields.push(fields);
        builder
    }

    /// Appends `fields` as the next output entry.
    #[must_use]
    pub fn add_output_fields(self, fields: TaprootOutputFields) -> Self {
        let mut builder = self;
        builder.output_fields.push(fields);
        builder
    }

    /// Consumes the builder and produces the [`TaprootPsbt`] (no validation is run).
    pub fn build(self) -> TaprootPsbt {
        let Self {
            psbt_hex,
            input_fields,
            output_fields,
        } = self;
        TaprootPsbt {
            base_psbt_hex: psbt_hex,
            input_fields,
            output_fields,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hex for a well-formed 32-byte x-only public key.
    fn xonly_key() -> String {
        "b2c9c8b6a5d3e1f0".repeat(4)
    }

    /// Hex for a 64-byte Schnorr signature.
    fn schnorr_sig_64() -> String {
        "aa".repeat(64)
    }

    /// Hex for a 65-byte Schnorr signature ending in a `0x01` sighash byte.
    fn schnorr_sig_65() -> String {
        let mut sig = "bb".repeat(64);
        sig.push_str("01");
        sig
    }

    /// Hex for a 33-byte control block with no Merkle path nodes.
    fn minimal_control_block() -> String {
        format!("c0{}", "ab".repeat(32))
    }

    #[test]
    fn test_leaf_version_tapscript() {
        assert_eq!(LeafVersion::TapScript.to_byte(), 0xC0);
        assert_eq!(LeafVersion::from_byte(0xC0), LeafVersion::TapScript);
    }

    #[test]
    fn test_leaf_version_future() {
        let future = LeafVersion::Future(0xC2);
        assert_eq!(future.to_byte(), 0xC2);
        assert_eq!(LeafVersion::from_byte(0xC2), future);
    }

    #[test]
    fn test_tap_leaf_script_new() {
        let script = vec![0x51, 0x20];
        let leaf = TapLeafScript::new(script.clone(), LeafVersion::TapScript);
        assert_eq!(leaf.leaf_version, LeafVersion::TapScript);
        assert_eq!(leaf.script, script);
        assert!(leaf.control_block.is_empty());
        assert_eq!(leaf.script_hex(), "5120");
    }

    #[test]
    fn test_tap_leaf_script_proof_len() {
        let mut leaf = TapLeafScript::new(Vec::new(), LeafVersion::TapScript);
        for (cb_len, expected) in [(0, 0), (33, 0), (65, 1), (97, 2)] {
            leaf.control_block = vec![0xc0; cb_len];
            assert_eq!(leaf.proof_len(), expected, "control block of {cb_len} bytes");
        }
    }

    #[test]
    fn test_taproot_input_fields_default() {
        let fields = TaprootInputFields::default();
        assert!(fields.tap_key_sig.is_none());
        assert!(fields.tap_internal_key.is_none());
        assert!(fields.tap_merkle_root.is_none());
        assert!(fields.tap_script_sigs.is_empty());
        assert!(fields.tap_leaf_scripts.is_empty());
        assert!(fields.tap_bip32_derivations.is_empty());
        assert!(!fields.is_finalized());
    }

    #[test]
    fn test_taproot_input_key_path_signed() {
        let fields = TaprootInputFields::new()
            .with_key_sig(schnorr_sig_64())
            .expect("64-byte sig should be accepted");
        assert!(fields.is_key_path_signed());
        assert!(!fields.is_script_path_signed());
        assert!(fields.is_finalized());
        assert!(fields.validate().is_ok());
    }

    #[test]
    fn test_taproot_input_key_sig_65_bytes() {
        let fields = TaprootInputFields::new()
            .with_key_sig(schnorr_sig_65())
            .expect("65-byte sig should be accepted");
        assert!(fields.is_key_path_signed());
        assert!(fields.validate().is_ok());
    }

    #[test]
    fn test_taproot_input_key_sig_invalid_length() {
        let result = TaprootInputFields::new().with_key_sig("cc".repeat(63));
        assert!(
            matches!(result, Err(TaprootPsbtError::InvalidSignature(_))),
            "expected InvalidSignature"
        );
    }

    #[test]
    fn test_taproot_input_internal_key_valid() {
        let fields = TaprootInputFields::new()
            .with_internal_key(xonly_key())
            .expect("valid 32-byte key");
        assert!(fields.tap_internal_key.is_some());
        assert!(fields.validate().is_ok());
    }

    #[test]
    fn test_taproot_input_internal_key_invalid() {
        let short_key = "aabb".to_string();
        let result = TaprootInputFields::new().with_internal_key(short_key);
        assert!(
            matches!(result, Err(TaprootPsbtError::InvalidInternalKey(_))),
            "expected InvalidInternalKey"
        );
    }

    #[test]
    fn test_taproot_output_fields() {
        let fields = TaprootOutputFields::new()
            .with_internal_key(xonly_key())
            .expect("valid key");
        assert!(fields.is_taproot_output());
        assert!(fields.validate().is_ok());
    }

    #[test]
    fn test_taproot_output_fields_default_not_taproot() {
        let fields = TaprootOutputFields::default();
        assert!(!fields.is_taproot_output());
        assert!(fields.validate().is_ok());
    }

    #[test]
    fn test_taproot_psbt_new() {
        let psbt = TaprootPsbt::new("deadbeef".to_string());
        assert_eq!(psbt.base_psbt_hex, "deadbeef");
        assert_eq!(psbt.input_count(), 0);
        assert_eq!(psbt.output_count(), 0);
        assert!(!psbt.is_complete());
    }

    #[test]
    fn test_taproot_psbt_add_input_fields() {
        let mut psbt = TaprootPsbt::new("deadbeef".to_string());
        let fields = TaprootInputFields::new()
            .with_key_sig(schnorr_sig_64())
            .expect("valid sig");
        psbt.add_input_taproot_fields(0, fields)
            .expect("index 0 is addressable");
        assert_eq!(psbt.input_count(), 1);
        assert!(psbt.is_complete());
        let retrieved = psbt.get_input_fields(0).expect("input 0 was just set");
        assert!(retrieved.is_key_path_signed());
        assert!(psbt.get_input_fields(99).is_none());
    }

    #[test]
    fn test_taproot_psbt_add_output_fields_sparse() {
        let mut psbt = TaprootPsbt::new("cafebabe".to_string());
        let fields = TaprootOutputFields::new()
            .with_internal_key(xonly_key())
            .expect("valid key");
        psbt.add_output_taproot_fields(2, fields)
            .expect("index 2 is addressable");
        assert_eq!(psbt.output_count(), 3);
        assert!(psbt.get_output_fields(2).unwrap().is_taproot_output());
        assert!(!psbt.get_output_fields(0).unwrap().is_taproot_output());
    }

    #[test]
    fn test_bip32_derivation() {
        let deriv = TapBip32Derivation::new(
            xonly_key(),
            "deadbeef".to_string(),
            "m/86'/0'/0'/0/0".to_string(),
        );
        assert!(deriv.is_key_path());
        assert!(deriv.validate().is_ok());
    }

    #[test]
    fn test_bip32_derivation_invalid_fingerprint() {
        let deriv = TapBip32Derivation::new(
            xonly_key(),
            "dead".to_string(),
            "m/86'/0'/0'/0/0".to_string(),
        );
        let result = deriv.validate();
        assert!(
            matches!(result, Err(TaprootPsbtError::InvalidDerivationPath(_))),
            "expected InvalidDerivationPath for bad fingerprint, got {result:?}"
        );
    }

    #[test]
    fn test_taproot_psbt_builder() {
        let input_fields = TaprootInputFields::new()
            .with_key_sig(schnorr_sig_64())
            .expect("valid sig");
        let output_fields = TaprootOutputFields::new()
            .with_internal_key(xonly_key())
            .expect("valid key");
        let psbt = TaprootPsbtBuilder::new("aabbccdd".to_string())
            .add_input_fields(input_fields)
            .add_output_fields(output_fields)
            .build();
        assert_eq!(psbt.base_psbt_hex, "aabbccdd");
        assert_eq!(psbt.input_count(), 1);
        assert_eq!(psbt.output_count(), 1);
        assert!(psbt.is_complete());
        assert!(psbt.validate_all().is_ok());
    }

    #[test]
    fn test_validate_all_with_script_path_sigs() {
        let mut input_fields = TaprootInputFields::new();
        input_fields
            .tap_script_sigs
            .push((minimal_control_block(), schnorr_sig_64()));
        let psbt = TaprootPsbtBuilder::new("ff".to_string())
            .add_input_fields(input_fields)
            .build();
        assert!(psbt.is_complete());
        assert!(psbt.validate_all().is_ok());
    }

    #[test]
    fn test_merkle_root_validation() {
        let fields = TaprootInputFields::new()
            .with_merkle_root("ab".repeat(32))
            .expect("valid 32-byte merkle root");
        assert!(fields.tap_merkle_root.is_some());
        assert!(fields.validate().is_ok());

        let result = TaprootInputFields::new().with_merkle_root("deadbeef".to_string());
        assert!(
            matches!(result, Err(TaprootPsbtError::InvalidMerkleRoot(_))),
            "expected InvalidMerkleRoot for short value"
        );
    }

    #[test]
    fn test_tap_leaf_script_add_and_retrieve() {
        let mut input_fields = TaprootInputFields::new();
        input_fields.add_leaf_script(TapLeafScript::new(vec![0x51], LeafVersion::TapScript));
        input_fields.add_leaf_script(TapLeafScript::new(vec![0x52], LeafVersion::Future(0xC2)));
        assert_eq!(input_fields.tap_leaf_scripts.len(), 2);
        assert_eq!(input_fields.tap_leaf_scripts[0].script_hex(), "51");
        assert_eq!(
            input_fields.tap_leaf_scripts[1].leaf_version,
            LeafVersion::Future(0xC2)
        );
    }
}