use core::fmt;
use core::str::FromStr;
use bitcoin::sighash::{self, EcdsaSighashType, NonStandardSighashTypeError, TapSighashType};
use crate::error::write_err;
use crate::prelude::*;
/// A signature hash type stored as its raw 32-bit value.
///
/// Values can be built from an [`EcdsaSighashType`], a [`TapSighashType`], or any `u32` via
/// [`PsbtSighashType::from_u32`], so the wrapped value is not guaranteed to be valid for either
/// scheme. Use [`PsbtSighashType::ecdsa_hash_ty`] or [`PsbtSighashType::taproot_hash_ty`] to
/// convert back with validation.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PsbtSighashType {
// Raw sighash value; nothing in this type validates it on construction.
pub(crate) inner: u32,
}
/// Formats the value through the `Display` impl of [`TapSighashType`] when it converts to one,
/// and as `0x`-prefixed lowercase hexadecimal otherwise.
impl fmt::Display for PsbtSighashType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Ok(named) = self.taproot_hash_ty() {
            // Delegate directly so the caller's formatter is handed to the inner type untouched.
            fmt::Display::fmt(&named, f)
        } else {
            // Not a taproot sighash type: fall back to the raw value in hex.
            write!(f, "{:#x}", self.inner)
        }
    }
}
impl FromStr for PsbtSighashType {
    type Err = ParseSighashTypeError;

    /// Parses a sighash type from a string.
    ///
    /// The input is first tried with the `FromStr` impl of [`TapSighashType`]. If that fails it
    /// is parsed as a hexadecimal `u32`, with or without a single leading `0x`, so that
    /// non-standard values written by the `Display` impl can be read back.
    ///
    /// # Errors
    ///
    /// Returns [`ParseSighashTypeError`] holding the original input if neither form matches.
    #[inline]
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if let Ok(ty) = TapSighashType::from_str(s) {
            return Ok(ty.into());
        }

        // Remove at most one "0x" prefix. `trim_start_matches` would strip it repeatedly and
        // therefore accept malformed input such as "0x0x01".
        let digits = s.strip_prefix("0x").unwrap_or(s);

        u32::from_str_radix(digits, 16)
            .map(|inner| PsbtSighashType { inner })
            .map_err(|_| ParseSighashTypeError { unrecognized: s.to_owned() })
    }
}
impl From<EcdsaSighashType> for PsbtSighashType {
fn from(ecdsa_hash_ty: EcdsaSighashType) -> Self {
PsbtSighashType { inner: ecdsa_hash_ty as u32 }
}
}
impl From<TapSighashType> for PsbtSighashType {
fn from(taproot_hash_ty: TapSighashType) -> Self {
PsbtSighashType { inner: taproot_hash_ty as u32 }
}
}
impl PsbtSighashType {
    /// Converts the wrapped value to an [`EcdsaSighashType`].
    ///
    /// # Errors
    ///
    /// Returns whatever error `EcdsaSighashType::from_standard` reports for the wrapped value.
    pub fn ecdsa_hash_ty(self) -> Result<EcdsaSighashType, NonStandardSighashTypeError> {
        EcdsaSighashType::from_standard(self.inner)
    }

    /// Converts the wrapped value to a [`TapSighashType`].
    ///
    /// # Errors
    ///
    /// Returns [`InvalidSighashTypeError::Invalid`] if the value does not fit in a single byte,
    /// and the converted error from `TapSighashType::from_consensus_u8` if the byte is rejected.
    pub fn taproot_hash_ty(self) -> Result<TapSighashType, InvalidSighashTypeError> {
        match self.inner {
            // The pattern guarantees the value fits in a `u8`, so the cast cannot truncate.
            byte @ 0..=0xff => TapSighashType::from_consensus_u8(byte as u8)
                .map_err(InvalidSighashTypeError::from),
            too_large => Err(InvalidSighashTypeError::Invalid(too_large)),
        }
    }

    /// Creates a [`PsbtSighashType`] from a raw `u32` without any validation.
    pub fn from_u32(n: u32) -> PsbtSighashType { Self { inner: n } }

    /// Returns the wrapped raw `u32` value.
    pub fn to_u32(self) -> u32 {
        let Self { inner } = self;
        inner
    }
}
/// Error returned when a string cannot be parsed as a [`PsbtSighashType`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct ParseSighashTypeError {
/// The input string that was not recognized.
pub unrecognized: String,
}
impl fmt::Display for ParseSighashTypeError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "unrecognized SIGHASH string '{}'", self.unrecognized)
}
}
#[cfg(feature = "std")]
impl std::error::Error for ParseSighashTypeError {
    /// Always `None`: this error does not wrap another error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        None
    }
}
/// Error returned when a [`PsbtSighashType`] cannot be converted to a [`TapSighashType`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum InvalidSighashTypeError {
/// The error reported by the `bitcoin` crate's `sighash` module.
Bitcoin(sighash::InvalidSighashTypeError),
/// The raw value was rejected by this crate; carries the offending `u32`.
Invalid(u32),
}
impl fmt::Display for InvalidSighashTypeError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use InvalidSighashTypeError::*;
match *self {
Bitcoin(ref e) => write_err!(f, "bitcoin"; e),
Invalid(invalid) => write!(f, "invalid sighash type {}", invalid),
}
}
}
#[cfg(feature = "std")]
impl std::error::Error for InvalidSighashTypeError {
    /// Returns the wrapped error for the `Bitcoin` variant and `None` for `Invalid`.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Bitcoin(e) => Some(e),
            Self::Invalid(_) => None,
        }
    }
}
impl From<sighash::InvalidSighashTypeError> for InvalidSighashTypeError {
fn from(e: sighash::InvalidSighashTypeError) -> Self { Self::Bitcoin(e) }
}
#[cfg(test)]
mod tests {
    use core::str::FromStr;

    use super::*;
    use crate::sighash_type::InvalidSighashTypeError;

    /// Formats `sighash` with `Display` and parses the resulting string back with `FromStr`.
    fn display_then_parse(sighash: PsbtSighashType) -> PsbtSighashType {
        let rendered = format!("{}", sighash);
        PsbtSighashType::from_str(&rendered).unwrap()
    }

    // Every listed ECDSA sighash type must survive a string round trip.
    #[test]
    fn psbt_sighash_type_ecdsa() {
        let variants = [
            EcdsaSighashType::All,
            EcdsaSighashType::None,
            EcdsaSighashType::Single,
            EcdsaSighashType::AllPlusAnyoneCanPay,
            EcdsaSighashType::NonePlusAnyoneCanPay,
            EcdsaSighashType::SinglePlusAnyoneCanPay,
        ];

        for &ecdsa in variants.iter() {
            let sighash = PsbtSighashType::from(ecdsa);
            let back = display_then_parse(sighash);
            assert_eq!(back, sighash);
            assert_eq!(back.ecdsa_hash_ty().unwrap(), ecdsa);
        }
    }

    // Every listed taproot sighash type must survive a string round trip.
    #[test]
    fn psbt_sighash_type_taproot() {
        let variants = [
            TapSighashType::Default,
            TapSighashType::All,
            TapSighashType::None,
            TapSighashType::Single,
            TapSighashType::AllPlusAnyoneCanPay,
            TapSighashType::NonePlusAnyoneCanPay,
            TapSighashType::SinglePlusAnyoneCanPay,
        ];

        for &tap in variants.iter() {
            let sighash = PsbtSighashType::from(tap);
            let back = display_then_parse(sighash);
            assert_eq!(back, sighash);
            assert_eq!(back.taproot_hash_ty().unwrap(), tap);
        }
    }

    // A non-standard raw value round-trips through its hex form but is not a taproot type.
    #[test]
    fn psbt_sighash_type_notstd() {
        let nonstd = 0xdddddddd;
        let sighash = PsbtSighashType::from_u32(nonstd);
        let back = display_then_parse(sighash);
        assert_eq!(back, sighash);
        assert_eq!(back.taproot_hash_ty(), Err(InvalidSighashTypeError::Invalid(nonstd)));
    }
}