use core::convert::TryFrom;
use core::fmt;
use bitcoin::bip32::{ChildNumber, DerivationPath, Fingerprint, KeySource, Xpub};
use bitcoin::consensus::{encode as consensus, Decodable};
use bitcoin::locktime::absolute;
#[cfg(feature = "silent-payments")]
use bitcoin::CompressedPublicKey;
use bitcoin::{bip32, transaction, VarInt};
use crate::consts::{
PSBT_GLOBAL_FALLBACK_LOCKTIME, PSBT_GLOBAL_INPUT_COUNT, PSBT_GLOBAL_OUTPUT_COUNT,
PSBT_GLOBAL_PROPRIETARY, PSBT_GLOBAL_TX_MODIFIABLE, PSBT_GLOBAL_TX_VERSION,
PSBT_GLOBAL_UNSIGNED_TX, PSBT_GLOBAL_VERSION, PSBT_GLOBAL_XPUB,
};
#[cfg(feature = "silent-payments")]
use crate::consts::{PSBT_GLOBAL_SP_DLEQ, PSBT_GLOBAL_SP_ECDH_SHARE};
use crate::error::{write_err, InconsistentKeySourcesError};
use crate::io::{Cursor, Read};
use crate::prelude::*;
use crate::serialize::Serialize;
#[cfg(feature = "silent-payments")]
use crate::v2::dleq::DleqProof;
use crate::v2::map::Map;
use crate::version::Version;
use crate::{consts, raw, serialize, V2};
/// Bit 0 of `Global::tx_modifiable_flags`: the inputs-modifiable flag.
const INPUTS_MODIFIABLE: u8 = 0b0000_0001;
/// Bit 1 of `Global::tx_modifiable_flags`: the outputs-modifiable flag.
const OUTPUTS_MODIFIABLE: u8 = 0b0000_0010;
/// Bit 2 of `Global::tx_modifiable_flags`: the SIGHASH_SINGLE flag.
const SIGHASH_SINGLE: u8 = 0b0000_0100;
/// The global key-value map of a PSBT v2.
///
/// Built from raw pairs by `Global::decode` and turned back into raw pairs by
/// `Map::get_pairs`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Global {
/// PSBT version (`PSBT_GLOBAL_VERSION`); decoding only accepts version 2.
pub version: Version,
/// Version of the transaction being constructed (`PSBT_GLOBAL_TX_VERSION`).
pub tx_version: transaction::Version,
/// Optional fallback lock time (`PSBT_GLOBAL_FALLBACK_LOCKTIME`).
pub fallback_lock_time: Option<absolute::LockTime>,
/// Bit flags (`PSBT_GLOBAL_TX_MODIFIABLE`): bit 0 inputs modifiable, bit 1 outputs
/// modifiable, bit 2 SIGHASH_SINGLE. Decoded as 0 when the key is absent.
pub tx_modifiable_flags: u8,
/// Number of inputs (`PSBT_GLOBAL_INPUT_COUNT`, serialized as a `VarInt`).
pub input_count: usize,
/// Number of outputs (`PSBT_GLOBAL_OUTPUT_COUNT`, serialized as a `VarInt`).
pub output_count: usize,
/// Extended public keys (`PSBT_GLOBAL_XPUB`) mapped to their key source
/// (fingerprint and derivation path).
pub xpubs: BTreeMap<Xpub, KeySource>,
/// Silent-payment ECDH shares (`PSBT_GLOBAL_SP_ECDH_SHARE`), keyed by scan key.
#[cfg(feature = "silent-payments")]
#[cfg_attr(feature = "serde", serde(with = "crate::serde_utils::btreemap_as_seq"))]
pub sp_ecdh_shares: BTreeMap<CompressedPublicKey, CompressedPublicKey>,
/// Silent-payment DLEQ proofs (`PSBT_GLOBAL_SP_DLEQ`), keyed by scan key.
#[cfg(feature = "silent-payments")]
#[cfg_attr(feature = "serde", serde(with = "crate::serde_utils::btreemap_as_seq"))]
pub sp_dleq_proofs: BTreeMap<CompressedPublicKey, DleqProof>,
/// Proprietary key-value pairs (`PSBT_GLOBAL_PROPRIETARY`).
#[cfg_attr(feature = "serde", serde(with = "crate::serde_utils::btreemap_as_seq_byte_values"))]
pub proprietaries: BTreeMap<raw::ProprietaryKey, Vec<u8>>,
/// Pairs whose key type is not otherwise recognised, kept verbatim.
#[cfg_attr(feature = "serde", serde(with = "crate::serde_utils::btreemap_as_seq_byte_values"))]
pub unknowns: BTreeMap<raw::Key, Vec<u8>>,
}
impl Global {
    /// Creates a PSBT v2 global map with transaction version 2, no fallback lock time,
    /// no modifiable flags set, zero inputs/outputs and empty key maps.
    fn new() -> Self {
        Global {
            version: V2,
            tx_version: transaction::Version::TWO,
            fallback_lock_time: None,
            tx_modifiable_flags: 0x00,
            input_count: 0,
            output_count: 0,
            xpubs: Default::default(),
            #[cfg(feature = "silent-payments")]
            sp_ecdh_shares: Default::default(),
            #[cfg(feature = "silent-payments")]
            sp_dleq_proofs: Default::default(),
            proprietaries: Default::default(),
            unknowns: Default::default(),
        }
    }

    /// Sets the inputs-modifiable bit of `tx_modifiable_flags`.
    pub(crate) fn set_inputs_modifiable_flag(&mut self) {
        self.tx_modifiable_flags |= INPUTS_MODIFIABLE;
    }

    /// Sets the outputs-modifiable bit of `tx_modifiable_flags`.
    pub(crate) fn set_outputs_modifiable_flag(&mut self) {
        self.tx_modifiable_flags |= OUTPUTS_MODIFIABLE;
    }

    /// Sets the SIGHASH_SINGLE bit of `tx_modifiable_flags`.
    // `dead_code` allowed: presumably not yet called elsewhere in the crate.
    #[allow(dead_code)]
    pub(crate) fn set_sighash_single_flag(&mut self) { self.tx_modifiable_flags |= SIGHASH_SINGLE; }

    /// Clears the inputs-modifiable bit of `tx_modifiable_flags`.
    pub(crate) fn clear_inputs_modifiable_flag(&mut self) {
        self.tx_modifiable_flags &= !INPUTS_MODIFIABLE;
    }

    /// Clears the outputs-modifiable bit of `tx_modifiable_flags`.
    pub(crate) fn clear_outputs_modifiable_flag(&mut self) {
        self.tx_modifiable_flags &= !OUTPUTS_MODIFIABLE;
    }

    /// Clears the SIGHASH_SINGLE bit of `tx_modifiable_flags`.
    // `dead_code` allowed: presumably not yet called elsewhere in the crate.
    #[allow(dead_code)]
    pub(crate) fn clear_sighash_single_flag(&mut self) {
        self.tx_modifiable_flags &= !SIGHASH_SINGLE;
    }

    /// Returns `true` if the inputs-modifiable bit is set.
    pub(crate) fn is_inputs_modifiable(&self) -> bool {
        self.tx_modifiable_flags & INPUTS_MODIFIABLE > 0
    }

    /// Returns `true` if the outputs-modifiable bit is set.
    pub(crate) fn is_outputs_modifiable(&self) -> bool {
        self.tx_modifiable_flags & OUTPUTS_MODIFIABLE > 0
    }

    /// Returns `true` if the SIGHASH_SINGLE bit is set.
    // `dead_code` allowed: presumably not yet called elsewhere in the crate.
    #[allow(dead_code)]
    pub(crate) fn has_sighash_single(&self) -> bool {
        self.tx_modifiable_flags & SIGHASH_SINGLE > 0
    }

    /// Decodes a PSBT v2 global map from `r`.
    ///
    /// Raw pairs are read until `raw::Pair::decode` reports `NoMorePairs`. Each pair is
    /// validated and stored by its key type.
    ///
    /// # Errors
    ///
    /// - [`DecodeError::DeserPair`] if a raw pair cannot be read.
    /// - [`DecodeError::InsertPair`] if a pair is malformed or duplicated, or if its key type
    ///   is excluded from v2 (`PSBT_GLOBAL_UNSIGNED_TX`).
    /// - [`DecodeError::MissingVersion`], [`DecodeError::MissingTxVersion`],
    ///   [`DecodeError::MissingInputCount`] or [`DecodeError::MissingOutputCount`] if a
    ///   required field is absent.
    /// - [`DecodeError::InputCountOverflow`] / [`DecodeError::OutputCountOverflow`] if a
    ///   count does not fit in `usize`.
    /// - [`DecodeError::FieldMismatch`] (only with `silent-payments`) if exactly one of the
    ///   ECDH share and DLEQ proof maps is non-empty.
    pub(crate) fn decode<R: Read + ?Sized>(r: &mut R) -> Result<Self, DecodeError> {
        let mut version: Option<Version> = None;
        let mut tx_version: Option<transaction::Version> = None;
        let mut fallback_lock_time: Option<absolute::LockTime> = None;
        let mut tx_modifiable_flags: Option<u8> = None;
        let mut input_count: Option<u64> = None;
        let mut output_count: Option<u64> = None;
        let mut xpubs: BTreeMap<Xpub, (Fingerprint, DerivationPath)> = Default::default();
        #[cfg(feature = "silent-payments")]
        let mut sp_ecdh_shares: BTreeMap<CompressedPublicKey, CompressedPublicKey> =
            Default::default();
        #[cfg(feature = "silent-payments")]
        let mut sp_dleq_proofs: BTreeMap<CompressedPublicKey, DleqProof> = Default::default();
        let mut proprietaries: BTreeMap<raw::ProprietaryKey, Vec<u8>> = Default::default();
        let mut unknowns: BTreeMap<raw::Key, Vec<u8>> = Default::default();

        let mut insert_pair = |pair: raw::Pair| {
            match pair.key.type_value {
                PSBT_GLOBAL_VERSION =>
                    if pair.key.key.is_empty() {
                        if version.is_none() {
                            let vlen: usize = pair.value.len();
                            let mut decoder = Cursor::new(pair.value);
                            if vlen != 4 {
                                return Err::<(), InsertPairError>(
                                    InsertPairError::ValueWrongLength(vlen, 4),
                                );
                            }
                            let ver = Decodable::consensus_decode(&mut decoder)?;
                            if ver != 2 {
                                return Err(InsertPairError::WrongVersion(ver));
                            }
                            version = Some(Version::try_from(ver).expect("valid, this is 2"));
                        } else {
                            return Err(InsertPairError::DuplicateKey(pair.key));
                        }
                    } else {
                        return Err(InsertPairError::InvalidKeyDataNotEmpty(pair.key));
                    },
                PSBT_GLOBAL_TX_VERSION => {
                    if pair.key.key.is_empty() {
                        if tx_version.is_none() {
                            let vlen: usize = pair.value.len();
                            let mut decoder = Cursor::new(pair.value);
                            if vlen != 4 {
                                return Err(InsertPairError::ValueWrongLength(vlen, 4));
                            }
                            tx_version = Some(Decodable::consensus_decode(&mut decoder)?);
                        } else {
                            return Err(InsertPairError::DuplicateKey(pair.key));
                        }
                    } else {
                        return Err(InsertPairError::InvalidKeyDataNotEmpty(pair.key));
                    }
                }
                PSBT_GLOBAL_FALLBACK_LOCKTIME =>
                    if pair.key.key.is_empty() {
                        if fallback_lock_time.is_none() {
                            let vlen: usize = pair.value.len();
                            if vlen != 4 {
                                return Err(InsertPairError::ValueWrongLength(vlen, 4));
                            }
                            let mut decoder = Cursor::new(pair.value);
                            fallback_lock_time = Some(Decodable::consensus_decode(&mut decoder)?);
                        } else {
                            return Err(InsertPairError::DuplicateKey(pair.key));
                        }
                    } else {
                        return Err(InsertPairError::InvalidKeyDataNotEmpty(pair.key));
                    },
                PSBT_GLOBAL_INPUT_COUNT => {
                    if pair.key.key.is_empty() {
                        // Must test `input_count` (not `output_count`): a repeated input count
                        // is a duplicate, and an input count that follows the output count
                        // is valid.
                        if input_count.is_none() {
                            let mut decoder = Cursor::new(pair.value);
                            let count: VarInt = Decodable::consensus_decode(&mut decoder)?;
                            input_count = Some(count.0);
                        } else {
                            return Err(InsertPairError::DuplicateKey(pair.key));
                        }
                    } else {
                        return Err(InsertPairError::InvalidKeyDataNotEmpty(pair.key));
                    }
                }
                PSBT_GLOBAL_OUTPUT_COUNT => {
                    if pair.key.key.is_empty() {
                        if output_count.is_none() {
                            let mut decoder = Cursor::new(pair.value);
                            let count: VarInt = Decodable::consensus_decode(&mut decoder)?;
                            output_count = Some(count.0);
                        } else {
                            return Err(InsertPairError::DuplicateKey(pair.key));
                        }
                    } else {
                        return Err(InsertPairError::InvalidKeyDataNotEmpty(pair.key));
                    }
                }
                PSBT_GLOBAL_TX_MODIFIABLE =>
                    if pair.key.key.is_empty() {
                        if tx_modifiable_flags.is_none() {
                            let vlen: usize = pair.value.len();
                            if vlen != 1 {
                                return Err(InsertPairError::ValueWrongLength(vlen, 1));
                            }
                            let mut decoder = Cursor::new(pair.value);
                            tx_modifiable_flags = Some(Decodable::consensus_decode(&mut decoder)?);
                        } else {
                            return Err(InsertPairError::DuplicateKey(pair.key));
                        }
                    } else {
                        return Err(InsertPairError::InvalidKeyDataNotEmpty(pair.key));
                    },
                PSBT_GLOBAL_XPUB => {
                    if !pair.key.key.is_empty() {
                        let xpub = Xpub::decode(&pair.key.key)?;
                        // The value is a 4-byte fingerprint followed by zero or more 4-byte
                        // little-endian child numbers. An empty value fails the first check.
                        if pair.value.len() < 4 {
                            return Err(InsertPairError::XpubInvalidFingerprint);
                        }
                        if pair.value.len() % 4 != 0 {
                            return Err(InsertPairError::XpubInvalidPath(pair.value.len()));
                        }
                        let (fingerprint_bytes, path_bytes) = pair.value.split_at(4);
                        let mut fingerprint = [0u8; 4];
                        fingerprint.copy_from_slice(fingerprint_bytes);
                        // The length check above guarantees `chunks_exact` leaves no remainder.
                        let path: Vec<ChildNumber> = path_bytes
                            .chunks_exact(4)
                            .map(|chunk| {
                                let mut index = [0u8; 4];
                                index.copy_from_slice(chunk);
                                ChildNumber::from(u32::from_le_bytes(index))
                            })
                            .collect();
                        let derivation = DerivationPath::from(path);
                        if let Some(key_source) =
                            xpubs.insert(xpub, (Fingerprint::from(fingerprint), derivation))
                        {
                            return Err(InsertPairError::DuplicateXpub(key_source));
                        }
                    } else {
                        return Err(InsertPairError::InvalidKeyDataEmpty(pair.key));
                    }
                }
                PSBT_GLOBAL_PROPRIETARY =>
                    if !pair.key.key.is_empty() {
                        match proprietaries.entry(
                            raw::ProprietaryKey::try_from(pair.key.clone())
                                .map_err(|_| InsertPairError::InvalidProprietaryKey)?,
                        ) {
                            btree_map::Entry::Vacant(empty_key) => {
                                empty_key.insert(pair.value);
                            }
                            btree_map::Entry::Occupied(_) =>
                                return Err(InsertPairError::DuplicateKey(pair.key)),
                        }
                    } else {
                        return Err(InsertPairError::InvalidKeyDataEmpty(pair.key));
                    },
                #[cfg(feature = "silent-payments")]
                PSBT_GLOBAL_SP_ECDH_SHARE => {
                    v2_impl_psbt_insert_sp_pair!(
                        sp_ecdh_shares,
                        pair.key,
                        pair.value,
                        compressed_pubkey
                    );
                }
                #[cfg(feature = "silent-payments")]
                PSBT_GLOBAL_SP_DLEQ => {
                    v2_impl_psbt_insert_sp_pair!(sp_dleq_proofs, pair.key, pair.value, dleq_proof);
                }
                // The v0 unsigned transaction must not appear in a v2 global map.
                v if v == PSBT_GLOBAL_UNSIGNED_TX =>
                    return Err(InsertPairError::ExcludedKey { key_type_value: v }),
                _ => match unknowns.entry(pair.key) {
                    btree_map::Entry::Vacant(empty_key) => {
                        empty_key.insert(pair.value);
                    }
                    btree_map::Entry::Occupied(k) => {
                        return Err(InsertPairError::DuplicateKey(k.key().clone()));
                    }
                },
            }
            Ok(())
        };

        loop {
            match raw::Pair::decode(r) {
                Ok(pair) => insert_pair(pair)?,
                Err(serialize::Error::NoMorePairs) => break,
                Err(e) => return Err(DecodeError::DeserPair(e)),
            }
        }

        let version = version.ok_or(DecodeError::MissingVersion)?;
        let tx_version = tx_version.ok_or(DecodeError::MissingTxVersion)?;
        // An absent PSBT_GLOBAL_TX_MODIFIABLE key means no flags are set.
        let tx_modifiable_flags = tx_modifiable_flags.unwrap_or(0_u8);
        let input_count = input_count.ok_or(DecodeError::MissingInputCount)?;
        let input_count = usize::try_from(input_count)
            .map_err(|_| DecodeError::InputCountOverflow(input_count))?;
        let output_count = output_count.ok_or(DecodeError::MissingOutputCount)?;
        let output_count = usize::try_from(output_count)
            .map_err(|_| DecodeError::OutputCountOverflow(output_count))?;

        #[cfg(feature = "silent-payments")]
        {
            let has_ecdh = !sp_ecdh_shares.is_empty();
            let has_dleq = !sp_dleq_proofs.is_empty();
            if has_ecdh != has_dleq {
                return Err(DecodeError::FieldMismatch);
            }
        }

        Ok(Global {
            tx_version,
            fallback_lock_time,
            input_count,
            output_count,
            tx_modifiable_flags,
            version,
            #[cfg(feature = "silent-payments")]
            sp_ecdh_shares,
            #[cfg(feature = "silent-payments")]
            sp_dleq_proofs,
            xpubs,
            proprietaries,
            unknowns,
        })
    }

    /// Merges `other` into `self`.
    ///
    /// When both maps contain the same xpub, the key sources are reconciled as follows:
    /// - Identical fingerprint and derivation path: keep ours.
    /// - Their path is a strictly shorter suffix of ours: keep ours.
    /// - Our path is a strictly shorter suffix of theirs: take theirs.
    /// - Anything else, including equal paths with different fingerprints, is an error.
    ///
    /// The remaining maps are merged with `v2_combine_map!`.
    ///
    /// # Errors
    ///
    /// - [`CombineError::VersionMismatch`] or [`CombineError::TxVersionMismatch`] if the
    ///   versions differ.
    /// - [`CombineError::InconsistentKeySources`] if an xpub's key sources cannot be
    ///   reconciled. Xpubs merged before the conflict remain in `self`.
    pub fn combine(&mut self, other: Self) -> Result<(), CombineError> {
        if self.version != other.version {
            return Err(CombineError::VersionMismatch { this: self.version, that: other.version });
        }

        if self.tx_version != other.tx_version {
            return Err(CombineError::TxVersionMismatch {
                this: self.tx_version,
                that: other.tx_version,
            });
        }

        for (xpub, (fingerprint1, derivation1)) in other.xpubs {
            match self.xpubs.entry(xpub) {
                btree_map::Entry::Vacant(entry) => {
                    entry.insert((fingerprint1, derivation1));
                }
                btree_map::Entry::Occupied(mut entry) => {
                    let (fingerprint2, derivation2) = entry.get();
                    let ours = &derivation2[..];
                    let theirs = &derivation1[..];
                    let keep_ours = (ours == theirs && *fingerprint2 == fingerprint1)
                        || (theirs.len() < ours.len() && ours.ends_with(theirs));
                    // Comparing lengths before testing the suffix avoids the `usize`
                    // underflow (and panic) of slicing by `theirs.len() - ours.len()` when
                    // `theirs` is the shorter path.
                    let take_theirs = theirs.len() > ours.len() && theirs.ends_with(ours);
                    if keep_ours {
                        continue;
                    }
                    if take_theirs {
                        entry.insert((fingerprint1, derivation1));
                        continue;
                    }
                    return Err(InconsistentKeySourcesError(xpub).into());
                }
            }
        }

        #[cfg(feature = "silent-payments")]
        v2_combine_map!(sp_ecdh_shares, self, other);
        #[cfg(feature = "silent-payments")]
        v2_combine_map!(sp_dleq_proofs, self, other);
        v2_combine_map!(proprietaries, self, other);
        v2_combine_map!(unknowns, self, other);

        Ok(())
    }
}
impl Default for Global {
fn default() -> Self { Self::new() }
}
impl Map for Global {
    /// Converts the global map into raw key-value pairs.
    ///
    /// Pairs are emitted in this order: version, tx version, fallback lock time (via
    /// `v2_impl_psbt_get_pair!`), input count, output count, modifiable flags, xpubs,
    /// silent-payment ECDH shares and DLEQ proofs (when the feature is enabled),
    /// proprietary pairs, unknown pairs.
    fn get_pairs(&self) -> Vec<raw::Pair> {
        // Builds a pair whose key has a type value but no key data.
        fn keyless(type_value: u8, value: Vec<u8>) -> raw::Pair {
            raw::Pair { key: raw::Key { type_value, key: Vec::new() }, value }
        }

        let mut rv: Vec<raw::Pair> = Vec::new();

        rv.push(keyless(PSBT_GLOBAL_VERSION, self.version.serialize()));
        rv.push(keyless(PSBT_GLOBAL_TX_VERSION, self.tx_version.serialize()));

        v2_impl_psbt_get_pair! {
            rv.push(self.fallback_lock_time, PSBT_GLOBAL_FALLBACK_LOCKTIME)
        }

        rv.push(keyless(PSBT_GLOBAL_INPUT_COUNT, VarInt::from(self.input_count).serialize()));
        rv.push(keyless(PSBT_GLOBAL_OUTPUT_COUNT, VarInt::from(self.output_count).serialize()));
        rv.push(keyless(PSBT_GLOBAL_TX_MODIFIABLE, vec![self.tx_modifiable_flags]));

        // Xpub value: 4-byte fingerprint followed by each child number as 4 LE bytes.
        for (xpub, (fingerprint, derivation)) in &self.xpubs {
            let mut value = Vec::with_capacity(4 + 4 * derivation.len());
            value.extend_from_slice(fingerprint.as_bytes());
            for child in derivation {
                value.extend_from_slice(&u32::from(*child).to_le_bytes());
            }
            rv.push(raw::Pair {
                key: raw::Key { type_value: PSBT_GLOBAL_XPUB, key: xpub.encode().to_vec() },
                value,
            });
        }

        #[cfg(feature = "silent-payments")]
        for (scan_key, ecdh_share) in &self.sp_ecdh_shares {
            rv.push(raw::Pair {
                key: raw::Key {
                    type_value: PSBT_GLOBAL_SP_ECDH_SHARE,
                    key: scan_key.to_bytes().to_vec(),
                },
                value: ecdh_share.to_bytes().to_vec(),
            });
        }

        #[cfg(feature = "silent-payments")]
        for (scan_key, dleq_proof) in &self.sp_dleq_proofs {
            rv.push(raw::Pair {
                key: raw::Key {
                    type_value: PSBT_GLOBAL_SP_DLEQ,
                    key: scan_key.to_bytes().to_vec(),
                },
                value: dleq_proof.as_bytes().to_vec(),
            });
        }

        rv.extend(
            self.proprietaries
                .iter()
                .map(|(key, value)| raw::Pair { key: key.to_key(), value: value.clone() }),
        );
        rv.extend(
            self.unknowns
                .iter()
                .map(|(key, value)| raw::Pair { key: key.clone(), value: value.clone() }),
        );

        rv
    }
}
/// Errors that can occur while decoding a PSBT v2 global map.
#[derive(Debug)]
#[non_exhaustive]
pub enum DecodeError {
/// A decoded pair could not be inserted into the map.
InsertPair(InsertPairError),
/// A raw key-value pair could not be deserialized.
DeserPair(serialize::Error),
/// The required `PSBT_GLOBAL_VERSION` key is absent.
MissingVersion,
/// The required `PSBT_GLOBAL_TX_VERSION` key is absent.
MissingTxVersion,
/// The required `PSBT_GLOBAL_INPUT_COUNT` key is absent.
MissingInputCount,
/// The input count does not fit in `usize` on this platform.
InputCountOverflow(u64),
/// The required `PSBT_GLOBAL_OUTPUT_COUNT` key is absent.
MissingOutputCount,
/// The output count does not fit in `usize` on this platform.
OutputCountOverflow(u64),
/// Exactly one of the silent-payment ECDH share and DLEQ proof maps is non-empty.
FieldMismatch,
}
impl fmt::Display for DecodeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use DecodeError::*;
match *self {
InsertPair(ref e) => write_err!(f, "error inserting a pair"; e),
DeserPair(ref e) => write_err!(f, "error deserializing a pair"; e),
MissingVersion => write!(f, "serialized PSBT is missing the version number"),
MissingTxVersion => {
write!(f, "serialized PSBT is missing the transaction version number")
}
MissingInputCount => write!(f, "serialized PSBT is missing the input count"),
InputCountOverflow(count) => {
write!(f, "input count overflows word size for current architecture: {}", count)
}
MissingOutputCount => write!(f, "serialized PSBT is missing the output count"),
OutputCountOverflow(count) => {
write!(f, "output count overflows word size for current architecture: {}", count)
}
FieldMismatch => {
write!(f, "ECDH shares and DLEQ proofs must both be present or both absent")
}
}
}
}
#[cfg(feature = "std")]
impl std::error::Error for DecodeError {
    /// Exposes the wrapped error for the two variants that carry one.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            DecodeError::InsertPair(inner) => Some(inner),
            DecodeError::DeserPair(inner) => Some(inner),
            DecodeError::MissingVersion
            | DecodeError::MissingTxVersion
            | DecodeError::MissingInputCount
            | DecodeError::InputCountOverflow(_)
            | DecodeError::MissingOutputCount
            | DecodeError::OutputCountOverflow(_)
            | DecodeError::FieldMismatch => None,
        }
    }
}
impl From<InsertPairError> for DecodeError {
fn from(e: InsertPairError) -> Self { Self::InsertPair(e) }
}
/// Errors that can occur while inserting a decoded pair into the global map.
#[derive(Debug)]
pub enum InsertPairError {
/// The key was already present in the map.
DuplicateKey(raw::Key),
/// The key type requires key data, but the key had none.
InvalidKeyDataEmpty(raw::Key),
/// The key type forbids key data, but the key had some.
InvalidKeyDataNotEmpty(raw::Key),
/// A raw value could not be deserialized.
Deser(serialize::Error),
/// A value could not be consensus-decoded.
Consensus(consensus::Error),
/// The value had the wrong length: `(got, want)`.
ValueWrongLength(usize, usize),
/// `PSBT_GLOBAL_VERSION` held a version other than 2.
WrongVersion(u32),
/// A `PSBT_GLOBAL_XPUB` value was shorter than the 4-byte fingerprint.
XpubInvalidFingerprint,
/// A `PSBT_GLOBAL_XPUB` value length (carried here) was not a multiple of 4.
XpubInvalidPath(usize),
/// An xpub could not be decoded.
Bip32(bip32::Error),
/// The same xpub appeared twice; carries the previously stored key source.
DuplicateXpub(KeySource),
/// A `PSBT_GLOBAL_PROPRIETARY` key could not be parsed as a proprietary key.
InvalidProprietaryKey,
/// The key type is explicitly excluded from the v2 global map.
ExcludedKey {
/// The excluded key type value.
key_type_value: u8,
},
/// The key had the wrong length: `(got, expected)`.
// NOTE(review): not constructed directly in this file; presumably produced by the
// silent-payments insertion macro — confirm.
KeyWrongLength(usize, usize),
}
impl fmt::Display for InsertPairError {
    /// Writes a human-readable description of the insertion failure.
    ///
    /// The two `PSBT_GLOBAL_XPUB` layout errors describe the actual problem: the value must
    /// start with a 4-byte fingerprint, followed by 4-byte (32-bit) child numbers.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use InsertPairError::*;

        match *self {
            DuplicateKey(ref key) => write!(f, "duplicate key: {}", key),
            InvalidKeyDataEmpty(ref key) => write!(f, "key should contain data: {}", key),
            InvalidKeyDataNotEmpty(ref key) => write!(f, "key should not contain data: {}", key),
            Deser(ref e) => write_err!(f, "error deserializing raw value"; e),
            Consensus(ref e) => write_err!(f, "error consensus deserializing type"; e),
            ValueWrongLength(got, want) => {
                write!(f, "value (keyvalue pair) wrong length (got, want) {} {}", got, want)
            }
            WrongVersion(v) => {
                write!(f, "PSBT_GLOBAL_VERSION: PSBT v2 expects the version to be 2, found: {}", v)
            }
            // Raised when the value is too short to hold the fingerprint.
            XpubInvalidFingerprint => {
                write!(f, "PSBT_GLOBAL_XPUB: value must start with a 4 byte key fingerprint")
            }
            // Carries the full value length, which is not a multiple of 4.
            XpubInvalidPath(len) => write!(
                f,
                "PSBT_GLOBAL_XPUB: derivation path must be a list of 4 byte child numbers, \
                 value length: {}",
                len
            ),
            Bip32(ref e) => write_err!(f, "PSBT_GLOBAL_XPUB: Failed to decode a BIP-32 type"; e),
            DuplicateXpub((fingerprint, ref derivation_path)) => write!(
                f,
                "PSBT_GLOBAL_XPUB: xpubs must be unique ({}, {})",
                fingerprint, derivation_path
            ),
            InvalidProprietaryKey => write!(f, "PSBT_GLOBAL_PROPRIETARY: Invalid proprietary key"),
            ExcludedKey { key_type_value } => write!(
                f,
                "found a keypair type that is explicitly excluded: {}",
                consts::psbt_global_key_type_value_to_str(key_type_value)
            ),
            KeyWrongLength(got, expected) => {
                write!(f, "key wrong length (got: {}, expected: {})", got, expected)
            }
        }
    }
}
#[cfg(feature = "std")]
impl std::error::Error for InsertPairError {
    /// Exposes the underlying deserialization, consensus or BIP-32 error, if any.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            InsertPairError::Deser(inner) => Some(inner),
            InsertPairError::Consensus(inner) => Some(inner),
            InsertPairError::Bip32(inner) => Some(inner),
            InsertPairError::DuplicateKey(_)
            | InsertPairError::InvalidKeyDataEmpty(_)
            | InsertPairError::InvalidKeyDataNotEmpty(_)
            | InsertPairError::ValueWrongLength(..)
            | InsertPairError::WrongVersion(_)
            | InsertPairError::XpubInvalidFingerprint
            | InsertPairError::XpubInvalidPath(_)
            | InsertPairError::DuplicateXpub(_)
            | InsertPairError::InvalidProprietaryKey
            | InsertPairError::ExcludedKey { .. }
            | InsertPairError::KeyWrongLength(..) => None,
        }
    }
}
impl From<serialize::Error> for InsertPairError {
fn from(e: serialize::Error) -> Self { Self::Deser(e) }
}
impl From<consensus::Error> for InsertPairError {
fn from(e: consensus::Error) -> Self { Self::Consensus(e) }
}
impl From<bip32::Error> for InsertPairError {
fn from(e: bip32::Error) -> Self { Self::Bip32(e) }
}
/// Errors that can occur while combining two PSBT v2 global maps.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum CombineError {
/// The two maps have different PSBT versions.
VersionMismatch {
/// Version of the map being combined into.
this: Version,
/// Version of the map being merged in.
that: Version,
},
/// The two maps have different transaction versions.
TxVersionMismatch {
/// Transaction version of the map being combined into.
this: transaction::Version,
/// Transaction version of the map being merged in.
that: transaction::Version,
},
/// The same xpub has key sources that cannot be reconciled.
InconsistentKeySources(InconsistentKeySourcesError),
}
impl fmt::Display for CombineError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use CombineError::*;
match *self {
VersionMismatch { ref this, ref that } => {
write!(f, "combine two PSBTs with different versions: {:?} {:?}", this, that)
}
TxVersionMismatch { ref this, ref that } => {
write!(f, "combine two PSBTs with different tx versions: {:?} {:?}", this, that)
}
InconsistentKeySources(ref e) => {
write_err!(f, "combine with inconsistent key sources"; e)
}
}
}
}
#[cfg(feature = "std")]
impl std::error::Error for CombineError {
    /// Only the inconsistent-key-sources variant wraps another error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            CombineError::VersionMismatch { .. } | CombineError::TxVersionMismatch { .. } => None,
            CombineError::InconsistentKeySources(inner) => Some(inner),
        }
    }
}
impl From<InconsistentKeySourcesError> for CombineError {
fn from(e: InconsistentKeySourcesError) -> Self { Self::InconsistentKeySources(e) }
}