1use core::fmt;
4use core::str::FromStr;
5
6use bitcoin::sighash::{self, EcdsaSighashType, NonStandardSighashTypeError, TapSighashType};
7
8use crate::error::write_err;
9use crate::prelude::*;
10
/// A sighash type held as a raw `u32`.
///
/// The value is stored unvalidated: it can be built from an `EcdsaSighashType`, a
/// `TapSighashType`, or any `u32`, and is only checked when converted back with
/// `ecdsa_hash_ty` or `taproot_hash_ty`.
///
/// Equality, ordering and hashing are derived, so they operate on the raw value.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PsbtSighashType {
    /// The raw, unvalidated sighash value.
    pub(crate) inner: u32,
}
20
21impl fmt::Display for PsbtSighashType {
22 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23 match self.taproot_hash_ty() {
24 Err(_) => write!(f, "{:#x}", self.inner),
25 Ok(taproot_hash_ty) => fmt::Display::fmt(&taproot_hash_ty, f),
26 }
27 }
28}
29
30impl FromStr for PsbtSighashType {
31 type Err = ParseSighashTypeError;
32
33 #[inline]
34 fn from_str(s: &str) -> Result<Self, Self::Err> {
35 if let Ok(ty) = TapSighashType::from_str(s) {
41 return Ok(ty.into());
42 }
43
44 if let Ok(inner) = u32::from_str_radix(s.trim_start_matches("0x"), 16) {
46 return Ok(PsbtSighashType { inner });
47 }
48
49 Err(ParseSighashTypeError { unrecognized: s.to_owned() })
50 }
51}
52impl From<EcdsaSighashType> for PsbtSighashType {
53 fn from(ecdsa_hash_ty: EcdsaSighashType) -> Self {
54 PsbtSighashType { inner: ecdsa_hash_ty as u32 }
55 }
56}
57
58impl From<TapSighashType> for PsbtSighashType {
59 fn from(taproot_hash_ty: TapSighashType) -> Self {
60 PsbtSighashType { inner: taproot_hash_ty as u32 }
61 }
62}
63
64impl PsbtSighashType {
65 pub fn ecdsa_hash_ty(self) -> Result<EcdsaSighashType, NonStandardSighashTypeError> {
68 EcdsaSighashType::from_standard(self.inner)
69 }
70
71 pub fn taproot_hash_ty(self) -> Result<TapSighashType, InvalidSighashTypeError> {
74 if self.inner > 0xffu32 {
75 return Err(InvalidSighashTypeError::Invalid(self.inner));
76 }
77
78 let ty = TapSighashType::from_consensus_u8(self.inner as u8)?;
79 Ok(ty)
80 }
81
82 pub fn from_u32(n: u32) -> PsbtSighashType { PsbtSighashType { inner: n } }
87
88 pub fn to_u32(self) -> u32 { self.inner }
92}
93
/// Error returned when a string cannot be parsed as a `PsbtSighashType`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct ParseSighashTypeError {
    /// The input string that was not recognized.
    pub unrecognized: String,
}
103
104impl fmt::Display for ParseSighashTypeError {
105 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
106 write!(f, "unrecognized SIGHASH string '{}'", self.unrecognized)
107 }
108}
109
#[cfg(feature = "std")]
impl std::error::Error for ParseSighashTypeError {
    /// Always `None`: this error does not wrap another error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        None
    }
}
114
/// Error returned when a raw sighash value cannot be converted to a `TapSighashType`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum InvalidSighashTypeError {
    /// The underlying `bitcoin::sighash` conversion rejected the value.
    Bitcoin(sighash::InvalidSighashTypeError),
    /// The raw value (carried here) does not fit in a single byte.
    Invalid(u32),
}
126
127impl fmt::Display for InvalidSighashTypeError {
128 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
129 use InvalidSighashTypeError::*;
130
131 match *self {
132 Bitcoin(ref e) => write_err!(f, "bitcoin"; e),
133 Invalid(invalid) => write!(f, "invalid sighash type {}", invalid),
134 }
135 }
136}
137
#[cfg(feature = "std")]
impl std::error::Error for InvalidSighashTypeError {
    /// Returns the wrapped `bitcoin` error for the `Bitcoin` variant, `None` otherwise.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Bitcoin(e) => Some(e),
            Self::Invalid(_) => None,
        }
    }
}
149
150impl From<sighash::InvalidSighashTypeError> for InvalidSighashTypeError {
151 fn from(e: sighash::InvalidSighashTypeError) -> Self { Self::Bitcoin(e) }
152}
153
#[cfg(test)]
mod tests {
    use core::str::FromStr;

    use super::*;
    use crate::sighash_type::InvalidSighashTypeError;

    /// Every ECDSA sighash type survives a `Display` -> `FromStr` round trip.
    #[test]
    fn psbt_sighash_type_ecdsa() {
        let ecdsa_types = [
            EcdsaSighashType::All,
            EcdsaSighashType::None,
            EcdsaSighashType::Single,
            EcdsaSighashType::AllPlusAnyoneCanPay,
            EcdsaSighashType::NonePlusAnyoneCanPay,
            EcdsaSighashType::SinglePlusAnyoneCanPay,
        ];

        for expected in ecdsa_types.iter().copied() {
            let original = PsbtSighashType::from(expected);
            let rendered = format!("{}", original);
            let parsed = PsbtSighashType::from_str(&rendered).unwrap();

            assert_eq!(parsed, original);
            assert_eq!(parsed.ecdsa_hash_ty().unwrap(), expected);
        }
    }

    /// Every taproot sighash type survives a `Display` -> `FromStr` round trip.
    #[test]
    fn psbt_sighash_type_taproot() {
        let tap_types = [
            TapSighashType::Default,
            TapSighashType::All,
            TapSighashType::None,
            TapSighashType::Single,
            TapSighashType::AllPlusAnyoneCanPay,
            TapSighashType::NonePlusAnyoneCanPay,
            TapSighashType::SinglePlusAnyoneCanPay,
        ];

        for expected in tap_types.iter().copied() {
            let original = PsbtSighashType::from(expected);
            let rendered = format!("{}", original);
            let parsed = PsbtSighashType::from_str(&rendered).unwrap();

            assert_eq!(parsed, original);
            assert_eq!(parsed.taproot_hash_ty().unwrap(), expected);
        }
    }

    /// A value that is not a valid sighash type still round-trips through its hex rendering
    /// and is rejected by the taproot conversion.
    #[test]
    fn psbt_sighash_type_notstd() {
        let raw = 0xdddddddd;
        let original = PsbtSighashType { inner: raw };
        let rendered = format!("{}", original);
        let parsed = PsbtSighashType::from_str(&rendered).unwrap();

        assert_eq!(parsed, original);
        assert_eq!(parsed.taproot_hash_ty(), Err(InvalidSighashTypeError::Invalid(raw)));
    }
}