use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
/// Errors produced while constructing or validating the PSBT v2 structures in this file.
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum PsbtV2Error {
/// The PSBT version field held a value other than 2; carries the offending value.
#[error("invalid PSBT version {0}; expected 2")]
InvalidVersion(u32),
/// A required field is absent. The hex validators in this file also report
/// wrong-length and non-hex values through this variant, so the payload is a
/// free-form message rather than just a field name.
#[error("missing required field: {0}")]
MissingRequiredField(String),
/// A locktime requirement is out of range or conflicts with another requirement.
#[error("invalid locktime: {0}")]
InvalidLocktime(String),
/// A sequence number was rejected; carries the value.
/// NOTE(review): not constructed anywhere in the visible part of this file.
#[error("invalid sequence number: {0}")]
InvalidSequence(u32),
/// An input index was outside the input list; carries the index.
/// NOTE(review): not constructed anywhere in the visible part of this file.
#[error("input index {0} out of range")]
InputIndexOutOfRange(usize),
/// An output index was outside the output list; carries the index.
/// NOTE(review): not constructed anywhere in the visible part of this file.
#[error("output index {0} out of range")]
OutputIndexOutOfRange(usize),
/// Serialization failed; carries a free-form message.
/// NOTE(review): not constructed anywhere in the visible part of this file.
#[error("serialization error: {0}")]
SerializationError(String),
/// An input or output was added while the matching modifiable flag was clear.
#[error("modifiability violation: {0}")]
ModifiabilityViolation(String),
}
/// Bit-flag byte describing what may still be changed on a PSBT.
///
/// The accessor methods read bit 0 (0x01) as "inputs modifiable", bit 1 (0x02)
/// as "outputs modifiable" and bit 2 (0x04) as "has SIGHASH_SINGLE".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct TxModifiable(pub u8);
impl TxModifiable {
    /// Both inputs and outputs may be modified (bits 0 and 1 set).
    pub const INPUTS_OUTPUTS_MODIFIABLE: Self = Self(0x03);
    /// Inputs may be added or removed (bit 0).
    pub const INPUTS_MODIFIABLE: Self = Self(0x01);
    /// Outputs may be added or removed (bit 1).
    pub const OUTPUTS_MODIFIABLE: Self = Self(0x02);
    /// The transaction carries a SIGHASH_SINGLE signature (bit 2).
    pub const HAS_SIGHASH_SINGLE: Self = Self(0x04);
    /// No flag set.
    pub const NONE: Self = Self(0x00);

    /// Returns `true` when the inputs-modifiable bit is set.
    #[inline]
    pub fn inputs_modifiable(&self) -> bool {
        (self.0 & Self::INPUTS_MODIFIABLE.0) != 0
    }

    /// Returns `true` when the outputs-modifiable bit is set.
    #[inline]
    pub fn outputs_modifiable(&self) -> bool {
        (self.0 & Self::OUTPUTS_MODIFIABLE.0) != 0
    }

    /// Returns `true` when the SIGHASH_SINGLE bit is set.
    #[inline]
    pub fn has_sighash_single(&self) -> bool {
        (self.0 & Self::HAS_SIGHASH_SINGLE.0) != 0
    }

    /// Returns `true` when neither inputs nor outputs may be modified.
    ///
    /// Only the two modifiability bits are inspected. The SIGHASH_SINGLE bit
    /// (and any undefined higher bit) is informational and does not grant
    /// permission to modify anything, so it must not make a locked-down flag
    /// set look unsealed.
    #[inline]
    pub fn is_sealed(&self) -> bool {
        (self.0 & Self::INPUTS_OUTPUTS_MODIFIABLE.0) == 0
    }
}
impl Default for TxModifiable {
fn default() -> Self {
Self::NONE
}
}
impl std::ops::BitOr for TxModifiable {
type Output = Self;
fn bitor(self, rhs: Self) -> Self {
Self(self.0 | rhs.0)
}
}
impl std::ops::BitOrAssign for TxModifiable {
fn bitor_assign(&mut self, rhs: Self) {
self.0 |= rhs.0;
}
}
/// Checks that `hex_str` is exactly `expected_bytes` bytes written as hex.
///
/// The length check runs first (two characters per byte, measured with
/// `str::len`), then every character must be an ASCII hex digit. `field` is
/// only used to label the error message.
///
/// # Errors
///
/// Returns [`PsbtV2Error::MissingRequiredField`] describing the first failed
/// check.
fn validate_hex_length(
    hex_str: &str,
    expected_bytes: usize,
    field: &str,
) -> Result<(), PsbtV2Error> {
    let expected_chars = expected_bytes * 2;
    let actual_len = hex_str.len();

    let problem = if actual_len != expected_chars {
        Some(format!(
            "{field}: expected {expected_bytes} bytes ({expected_chars} hex chars), got {} chars",
            actual_len
        ))
    } else if hex_str.bytes().any(|b| !b.is_ascii_hexdigit()) {
        // Bytes >= 0x80 (non-ASCII text) are never hex digits, so a byte scan
        // rejects exactly the same strings as a char scan.
        Some(format!("{field}: contains non-hex characters"))
    } else {
        None
    };

    match problem {
        Some(message) => Err(PsbtV2Error::MissingRequiredField(message)),
        None => Ok(()),
    }
}
/// A single PSBT v2 input, with binary fields carried as strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PsbtV2Input {
/// Txid of the transaction being spent; `validate` requires exactly 64 hex
/// characters (32 bytes).
pub previous_txid: String,
/// Index of the spent output within the previous transaction.
pub output_index: u32,
/// Optional sequence number for this input (not checked by `validate`).
pub sequence: Option<u32>,
/// Optional time-based locktime requirement; `validate` requires a value of
/// at least 500_000_000.
pub required_time_locktime: Option<u32>,
/// Optional height-based locktime requirement; `validate` requires a value
/// below 500_000_000.
pub required_height_locktime: Option<u32>,
/// Optional witness UTXO.
/// NOTE(review): presumably hex-encoded; the encoding is not checked in this file.
pub witness_utxo: Option<String>,
/// Final scriptSig; its presence marks the input as finalized (see `is_finalized`).
pub final_script_sig: Option<String>,
/// Final witness stack; its presence marks the input as finalized (see `is_finalized`).
pub final_script_witness: Option<Vec<String>>,
/// Optional sighash type (not interpreted anywhere in this file).
pub sighash_type: Option<u32>,
/// Partial signatures as string pairs.
/// NOTE(review): presumably (public key, signature) — confirm against callers.
pub partial_sigs: Vec<(String, String)>,
}
impl PsbtV2Input {
    /// Creates an input that spends output `output_index` of `previous_txid`.
    ///
    /// Every optional field starts as `None` and the partial-signature list
    /// starts empty. Nothing is validated here; call [`Self::validate`].
    pub fn new(previous_txid: String, output_index: u32) -> Self {
        Self {
            partial_sigs: Vec::new(),
            sighash_type: None,
            final_script_witness: None,
            final_script_sig: None,
            witness_utxo: None,
            required_height_locktime: None,
            required_time_locktime: None,
            sequence: None,
            output_index,
            previous_txid,
        }
    }

    /// Returns the input with its sequence number set to `seq`.
    #[must_use]
    pub fn with_sequence(self, seq: u32) -> Self {
        Self {
            sequence: Some(seq),
            ..self
        }
    }

    /// Returns the input with a time-based locktime requirement of `locktime`.
    #[must_use]
    pub fn with_time_locktime(self, locktime: u32) -> Self {
        Self {
            required_time_locktime: Some(locktime),
            ..self
        }
    }

    /// Returns the input with a height-based locktime requirement of `locktime`.
    #[must_use]
    pub fn with_height_locktime(self, locktime: u32) -> Self {
        Self {
            required_height_locktime: Some(locktime),
            ..self
        }
    }

    /// Reports whether a final scriptSig or a final witness stack is present.
    pub fn is_finalized(&self) -> bool {
        !(self.final_script_sig.is_none() && self.final_script_witness.is_none())
    }

    /// Validates the txid and the locktime requirements of this input.
    ///
    /// # Errors
    ///
    /// * [`PsbtV2Error::MissingRequiredField`] when `previous_txid` is not 64
    ///   hex characters (checked first).
    /// * [`PsbtV2Error::InvalidLocktime`] when both locktime kinds are set,
    ///   when the time locktime is below 500_000_000, or when the height
    ///   locktime is 500_000_000 or above.
    pub fn validate(&self) -> Result<(), PsbtV2Error> {
        // Values below this threshold are block heights, values at or above it are timestamps.
        const LOCKTIME_THRESHOLD: u32 = 500_000_000;

        validate_hex_length(&self.previous_txid, 32, "previous_txid")?;

        let complaint = match (self.required_time_locktime, self.required_height_locktime) {
            (Some(_), Some(_)) => Some(
                "cannot set both time-based and height-based locktimes on the same input"
                    .to_string(),
            ),
            (Some(t), None) if t < LOCKTIME_THRESHOLD => Some(format!(
                "required_time_locktime {t} is below the UNIX-timestamp range (500_000_000)"
            )),
            (None, Some(h)) if h >= LOCKTIME_THRESHOLD => Some(format!(
                "required_height_locktime {h} is not in the block-height range (< 500_000_000)"
            )),
            _ => None,
        };

        match complaint {
            Some(message) => Err(PsbtV2Error::InvalidLocktime(message)),
            None => Ok(()),
        }
    }
}
/// A single PSBT v2 output, with scripts carried as strings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PsbtV2Output {
/// Output value.
/// NOTE(review): presumably satoshis — the unit is not stated in this file.
pub amount: u64,
/// Output script as hex; `validate` requires it to be non-empty with an even length.
pub script_pubkey: String,
/// Optional redeem script (not checked anywhere in this file).
pub redeem_script: Option<String>,
/// Optional witness script (not checked anywhere in this file).
pub witness_script: Option<String>,
/// BIP32 derivation data as string pairs.
/// NOTE(review): presumably (public key, derivation path) — confirm against callers.
pub bip32_derivations: Vec<(String, String)>,
}
impl PsbtV2Output {
pub fn new(amount: u64, script_pubkey: String) -> Self {
Self {
amount,
script_pubkey,
redeem_script: None,
witness_script: None,
bip32_derivations: Vec::new(),
}
}
pub fn p2wpkh(amount: u64, pubkey_hex: &str) -> Self {
let keyhash_hex = sha256_truncated_20_hex(pubkey_hex);
let script_pubkey = format!("0014{keyhash_hex}");
Self::new(amount, script_pubkey)
}
pub fn validate(&self) -> Result<(), PsbtV2Error> {
if self.script_pubkey.is_empty() {
return Err(PsbtV2Error::MissingRequiredField(
"script_pubkey".to_string(),
));
}
if self.script_pubkey.len() % 2 != 0 {
return Err(PsbtV2Error::MissingRequiredField(
"script_pubkey: odd hex length".to_string(),
));
}
Ok(())
}
}
fn sha256_truncated_20_hex(input_hex: &str) -> String {
use bitcoin::hashes::{Hash, sha256};
let hash = sha256::Hash::hash(input_hex.as_bytes());
let bytes = hash.to_byte_array();
bytes[..20]
.iter()
.map(|b| format!("{b:02x}"))
.collect::<String>()
}
/// Read-only snapshot of a `PsbtV2`, as produced by `PsbtV2::to_summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PsbtV2Summary {
/// PSBT version copied from the source PSBT.
pub version: u32,
/// Number of inputs.
pub input_count: usize,
/// Number of outputs.
pub output_count: usize,
/// Sum of all output amounts.
pub total_output_value: u64,
/// Whether the PSBT has at least one input and every input is finalized.
pub is_complete: bool,
/// Whether the modifiable flags report the PSBT as sealed.
pub is_sealed: bool,
/// Locktime derived from the inputs' requirements or the fallback locktime.
pub effective_locktime: u32,
}
/// In-memory model of a version 2 PSBT: global fields plus input and output lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PsbtV2 {
/// PSBT version; `validate` requires it to be 2.
pub version: u32,
/// Version of the transaction being built (not checked by `validate`).
pub tx_version: u32,
/// Locktime used by `effective_locktime` when no input requires one.
pub fallback_locktime: Option<u32>,
/// Flags gating `add_input` and `add_output`.
pub modifiable: TxModifiable,
/// Inputs in insertion order.
pub inputs: Vec<PsbtV2Input>,
/// Outputs in insertion order.
pub outputs: Vec<PsbtV2Output>,
/// Unrecognised global key/value pairs; not written by `serialize_to_map`.
/// NOTE(review): presumably hex-encoded key and value — confirm against callers.
pub unknown_globals: HashMap<String, String>,
}
impl PsbtV2 {
pub fn new(tx_version: u32) -> Self {
Self {
version: 2,
tx_version,
fallback_locktime: None,
modifiable: TxModifiable::INPUTS_MODIFIABLE | TxModifiable::OUTPUTS_MODIFIABLE,
inputs: Vec::new(),
outputs: Vec::new(),
unknown_globals: HashMap::new(),
}
}
pub fn add_input(&mut self, input: PsbtV2Input) -> Result<(), PsbtV2Error> {
if !self.modifiable.inputs_modifiable() {
return Err(PsbtV2Error::ModifiabilityViolation(
"PSBT inputs are not modifiable".to_string(),
));
}
self.inputs.push(input);
Ok(())
}
pub fn add_output(&mut self, output: PsbtV2Output) -> Result<(), PsbtV2Error> {
if !self.modifiable.outputs_modifiable() {
return Err(PsbtV2Error::ModifiabilityViolation(
"PSBT outputs are not modifiable".to_string(),
));
}
self.outputs.push(output);
Ok(())
}
pub fn seal(&mut self) {
self.modifiable = TxModifiable::NONE;
}
pub fn is_complete(&self) -> bool {
if self.inputs.is_empty() {
return false;
}
self.inputs.iter().all(|i| i.is_finalized())
}
pub fn input_count(&self) -> usize {
self.inputs.len()
}
pub fn output_count(&self) -> usize {
self.outputs.len()
}
pub fn total_output_value(&self) -> u64 {
self.outputs.iter().map(|o| o.amount).sum()
}
pub fn validate(&self) -> Result<(), PsbtV2Error> {
if self.version != 2 {
return Err(PsbtV2Error::InvalidVersion(self.version));
}
for input in &self.inputs {
input.validate()?;
}
for output in &self.outputs {
output.validate()?;
}
let has_time = self
.inputs
.iter()
.any(|i| i.required_time_locktime.is_some());
let has_height = self
.inputs
.iter()
.any(|i| i.required_height_locktime.is_some());
if has_time && has_height {
return Err(PsbtV2Error::InvalidLocktime(
"inputs mix time-based and height-based locktime requirements".to_string(),
));
}
Ok(())
}
pub fn effective_locktime(&self) -> u32 {
let max_time: Option<u32> = self
.inputs
.iter()
.filter_map(|i| i.required_time_locktime)
.max();
if let Some(t) = max_time {
return t;
}
let max_height: Option<u32> = self
.inputs
.iter()
.filter_map(|i| i.required_height_locktime)
.max();
if let Some(h) = max_height {
return h;
}
self.fallback_locktime.unwrap_or(0)
}
pub fn to_summary(&self) -> PsbtV2Summary {
PsbtV2Summary {
version: self.version,
input_count: self.input_count(),
output_count: self.output_count(),
total_output_value: self.total_output_value(),
is_complete: self.is_complete(),
is_sealed: self.modifiable.is_sealed(),
effective_locktime: self.effective_locktime(),
}
}
pub fn serialize_to_map(&self) -> HashMap<String, serde_json::Value> {
let mut map = HashMap::new();
map.insert(
"PSBT_GLOBAL_VERSION".to_string(),
serde_json::Value::from(self.version),
);
map.insert(
"PSBT_GLOBAL_TX_VERSION".to_string(),
serde_json::Value::from(self.tx_version),
);
if let Some(lt) = self.fallback_locktime {
map.insert(
"PSBT_GLOBAL_FALLBACK_LOCKTIME".to_string(),
serde_json::Value::from(lt),
);
}
map.insert(
"PSBT_GLOBAL_INPUT_COUNT".to_string(),
serde_json::Value::from(self.inputs.len()),
);
map.insert(
"PSBT_GLOBAL_OUTPUT_COUNT".to_string(),
serde_json::Value::from(self.outputs.len()),
);
map.insert(
"PSBT_GLOBAL_TX_MODIFIABLE".to_string(),
serde_json::Value::from(self.modifiable.0),
);
map.insert(
"inputs".to_string(),
serde_json::to_value(&self.inputs).unwrap_or(serde_json::Value::Null),
);
map.insert(
"outputs".to_string(),
serde_json::to_value(&self.outputs).unwrap_or(serde_json::Value::Null),
);
map
}
}
/// Step-by-step constructor for [`PsbtV2`] that validates the result in
/// [`PsbtV2Builder::build`].
#[derive(Debug)]
pub struct PsbtV2Builder {
    tx_version: u32,
    fallback_locktime: Option<u32>,
    modifiable: TxModifiable,
    inputs: Vec<PsbtV2Input>,
    outputs: Vec<PsbtV2Output>,
}

impl Default for PsbtV2Builder {
    /// Same as [`PsbtV2Builder::new`].
    fn default() -> Self {
        PsbtV2Builder::new()
    }
}

impl PsbtV2Builder {
    /// Starts a builder with transaction version 2, no fallback locktime,
    /// both inputs and outputs modifiable, and no inputs or outputs.
    pub fn new() -> Self {
        let modifiable = TxModifiable::INPUTS_MODIFIABLE | TxModifiable::OUTPUTS_MODIFIABLE;
        Self {
            tx_version: 2,
            fallback_locktime: None,
            modifiable,
            inputs: Vec::new(),
            outputs: Vec::new(),
        }
    }

    /// Sets the transaction version.
    #[must_use]
    pub fn tx_version(self, v: u32) -> Self {
        Self {
            tx_version: v,
            ..self
        }
    }

    /// Sets the fallback locktime.
    #[must_use]
    pub fn fallback_locktime(self, lt: u32) -> Self {
        Self {
            fallback_locktime: Some(lt),
            ..self
        }
    }

    /// Replaces the modifiable flags of the PSBT being built.
    #[must_use]
    pub fn modifiable(self, m: TxModifiable) -> Self {
        Self {
            modifiable: m,
            ..self
        }
    }

    /// Appends an input. The builder's modifiable flags are not consulted.
    #[must_use]
    pub fn add_input(mut self, input: PsbtV2Input) -> Self {
        self.inputs.extend(Some(input));
        self
    }

    /// Appends an output. The builder's modifiable flags are not consulted.
    #[must_use]
    pub fn add_output(mut self, output: PsbtV2Output) -> Self {
        self.outputs.extend(Some(output));
        self
    }

    /// Assembles the PSBT (PSBT version 2, no unknown globals) and validates it.
    ///
    /// # Errors
    ///
    /// Returns whatever `PsbtV2::validate` reports for the assembled PSBT.
    pub fn build(self) -> Result<PsbtV2, PsbtV2Error> {
        let Self {
            tx_version,
            fallback_locktime,
            modifiable,
            inputs,
            outputs,
        } = self;
        let assembled = PsbtV2 {
            version: 2,
            tx_version,
            fallback_locktime,
            modifiable,
            inputs,
            outputs,
            unknown_globals: HashMap::new(),
        };
        assembled.validate().map(|()| assembled)
    }
}
// Unit tests for the PSBT v2 model: flag handling, modifiability gating,
// completion, locktime resolution, validation and the builder.
#[cfg(test)]
mod tests {
use super::*;
// 64 hex characters, i.e. a txid of the length `PsbtV2Input::validate` accepts.
fn dummy_txid() -> String {
"a".repeat(64)
}
// Input with a valid-length txid and every optional field unset.
fn minimal_input() -> PsbtV2Input {
PsbtV2Input::new(dummy_txid(), 0)
}
// Output of 100_000 with a fixed `0014`-prefixed script.
fn minimal_output() -> PsbtV2Output {
PsbtV2Output::new(
100_000,
"0014aabbccddeeff00112233445566778899aabb".to_string(),
)
}
// Inputs|outputs sets both modifiability bits and nothing else.
#[test]
fn test_tx_modifiable_flags() {
let m = TxModifiable::INPUTS_MODIFIABLE | TxModifiable::OUTPUTS_MODIFIABLE;
assert!(m.inputs_modifiable());
assert!(m.outputs_modifiable());
assert!(!m.has_sighash_single());
assert!(!m.is_sealed());
}
// The empty flag set is sealed and allows no modification.
#[test]
fn test_tx_modifiable_sealed() {
let m = TxModifiable::NONE;
assert!(!m.inputs_modifiable());
assert!(!m.outputs_modifiable());
assert!(m.is_sealed());
}
// `|` combines flags without setting unrelated bits.
#[test]
fn test_tx_modifiable_bitor() {
let m = TxModifiable::INPUTS_MODIFIABLE | TxModifiable::HAS_SIGHASH_SINGLE;
assert!(m.inputs_modifiable());
assert!(!m.outputs_modifiable());
assert!(m.has_sighash_single());
}
// A new PSBT is version 2, fully modifiable and empty.
#[test]
fn test_psbt_v2_new() {
let psbt = PsbtV2::new(2);
assert_eq!(psbt.version, 2);
assert_eq!(psbt.tx_version, 2);
assert!(psbt.modifiable.inputs_modifiable());
assert!(psbt.modifiable.outputs_modifiable());
assert_eq!(psbt.input_count(), 0);
assert_eq!(psbt.output_count(), 0);
}
// Adding an input after `seal()` is rejected.
#[test]
fn test_add_input_sealed_fails() {
let mut psbt = PsbtV2::new(2);
psbt.seal();
let result = psbt.add_input(minimal_input());
assert!(matches!(
result,
Err(PsbtV2Error::ModifiabilityViolation(_))
));
}
// Adding an output after `seal()` is rejected.
#[test]
fn test_add_output_sealed_fails() {
let mut psbt = PsbtV2::new(2);
psbt.seal();
let result = psbt.add_output(minimal_output());
assert!(matches!(
result,
Err(PsbtV2Error::ModifiabilityViolation(_))
));
}
// A PSBT without inputs is never complete.
#[test]
fn test_is_complete_no_inputs() {
let psbt = PsbtV2::new(2);
assert!(!psbt.is_complete());
}
// A single input with a final scriptSig makes the PSBT complete.
#[test]
fn test_is_complete_with_finalized_input() {
let mut psbt = PsbtV2::new(2);
let mut input = minimal_input();
input.final_script_sig = Some("deadbeef".to_string());
psbt.add_input(input).unwrap();
assert!(psbt.is_complete());
}
// Output amounts are summed.
#[test]
fn test_total_output_value() {
let mut psbt = PsbtV2::new(2);
psbt.add_output(PsbtV2Output::new(50_000, "0014aa".to_string()))
.unwrap();
psbt.add_output(PsbtV2Output::new(75_000, "0014bb".to_string()))
.unwrap();
assert_eq!(psbt.total_output_value(), 125_000);
}
// The builder carries its settings, inputs and outputs into the PSBT.
#[test]
fn test_builder_basic() {
let psbt = PsbtV2Builder::new()
.tx_version(2)
.fallback_locktime(0)
.add_input(minimal_input())
.add_output(minimal_output())
.build()
.expect("builder should produce a valid PSBT");
assert_eq!(psbt.version, 2);
assert_eq!(psbt.input_count(), 1);
assert_eq!(psbt.output_count(), 1);
assert_eq!(psbt.fallback_locktime, Some(0));
}
// The largest height requirement wins; without inputs the fallback is used.
#[test]
fn test_effective_locktime_from_inputs() {
let mut psbt = PsbtV2Builder::new()
.fallback_locktime(100)
.add_input(PsbtV2Input::new(dummy_txid(), 0).with_height_locktime(800_000))
.add_input(PsbtV2Input::new(dummy_txid(), 1).with_height_locktime(850_000))
.build()
.expect("valid PSBT");
assert_eq!(psbt.effective_locktime(), 850_000);
psbt.inputs.clear();
assert_eq!(psbt.effective_locktime(), 100);
}
// The largest time requirement wins when only time requirements exist.
#[test]
fn test_effective_locktime_time_based() {
let psbt = PsbtV2Builder::new()
.add_input(PsbtV2Input::new(dummy_txid(), 0).with_time_locktime(1_700_000_000))
.add_input(PsbtV2Input::new(dummy_txid(), 1).with_time_locktime(1_800_000_000))
.build()
.expect("valid PSBT");
assert_eq!(psbt.effective_locktime(), 1_800_000_000);
}
// A txid of the wrong length is reported as MissingRequiredField.
#[test]
fn test_input_validation_missing_txid() {
let short_txid = "aabb".to_string(); let input = PsbtV2Input::new(short_txid, 0);
let result = input.validate();
assert!(
matches!(result, Err(PsbtV2Error::MissingRequiredField(_))),
"expected MissingRequiredField, got {result:?}"
);
}
// Setting both locktime kinds on one input is rejected.
#[test]
fn test_input_validation_mixed_locktimes() {
let input = PsbtV2Input::new(dummy_txid(), 0)
.with_time_locktime(1_700_000_000)
.with_height_locktime(800_000);
let result = input.validate();
assert!(
matches!(result, Err(PsbtV2Error::InvalidLocktime(_))),
"expected InvalidLocktime, got {result:?}"
);
}
// The summary reflects counts, total value, completion and sealed state.
#[test]
fn test_psbt_v2_summary() {
let mut psbt = PsbtV2::new(2);
let mut input = minimal_input();
input.final_script_sig = Some("cafebabe".to_string());
psbt.add_input(input).unwrap();
psbt.add_output(minimal_output()).unwrap();
psbt.seal();
let summary = psbt.to_summary();
assert_eq!(summary.version, 2);
assert_eq!(summary.input_count, 1);
assert_eq!(summary.output_count, 1);
assert_eq!(summary.total_output_value, 100_000);
assert!(summary.is_complete);
assert!(summary.is_sealed);
}
// The always-present global keys appear in the serialized map.
#[test]
fn test_serialize_to_map() {
let psbt = PsbtV2::new(2);
let map = psbt.serialize_to_map();
assert!(map.contains_key("PSBT_GLOBAL_VERSION"));
assert!(map.contains_key("PSBT_GLOBAL_TX_VERSION"));
assert!(map.contains_key("PSBT_GLOBAL_INPUT_COUNT"));
assert!(map.contains_key("PSBT_GLOBAL_OUTPUT_COUNT"));
assert!(map.contains_key("PSBT_GLOBAL_TX_MODIFIABLE"));
}
// `p2wpkh` yields a `0014`-prefixed script of 44 hex characters
// (2 header bytes plus a 20-byte hash).
#[test]
fn test_p2wpkh_output_convenience() {
let pubkey_hex = "02c6047f9441ed7d6d3045406e95c07cd85c778e4b8cef3ca7abac09b95c709ee5";
let output = PsbtV2Output::p2wpkh(50_000, pubkey_hex);
assert_eq!(output.amount, 50_000);
assert!(
output.script_pubkey.starts_with("0014"),
"P2WPKH script should start with 0014"
);
assert_eq!(output.script_pubkey.len(), 44); }
// Inputs that are individually valid but mix time and height requirements
// fail PSBT-level validation; removing the height requirement fixes it.
#[test]
fn test_validate_mixed_locktime_across_inputs_fails() {
let mut psbt = PsbtV2 {
version: 2,
tx_version: 2,
fallback_locktime: None,
modifiable: TxModifiable::NONE,
inputs: vec![
PsbtV2Input::new(dummy_txid(), 0).with_time_locktime(1_700_000_000),
PsbtV2Input::new(dummy_txid(), 1).with_height_locktime(800_000),
],
outputs: Vec::new(),
unknown_globals: HashMap::new(),
};
assert!(psbt.inputs[0].validate().is_ok());
assert!(psbt.inputs[1].validate().is_ok());
let result = psbt.validate();
assert!(
matches!(result, Err(PsbtV2Error::InvalidLocktime(_))),
"expected cross-input locktime error, got {result:?}"
);
psbt.inputs[1].required_height_locktime = None;
assert!(psbt.validate().is_ok());
}
}