Skip to main content

psbt_v2/
sighash_type.rs

1// SPDX-License-Identifier: CC0-1.0
2
3use core::fmt;
4use core::str::FromStr;
5
6use bitcoin::sighash::{self, EcdsaSighashType, NonStandardSighashTypeError, TapSighashType};
7
8use crate::error::write_err;
9use crate::prelude::*;
10
/// Signature hash type for the corresponding PSBT input.
///
/// Since the Taproot upgrade a sighash type is either an [`EcdsaSighashType`] or a
/// [`TapSighashType`], and the raw value alone does not tell which of the two is meant.
/// This type therefore keeps the raw flag and leaves it to the caller to convert to/from
/// [`PsbtSighashType`] and whichever concrete sighash type they need.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PsbtSighashType {
    /// Raw sighash flag; may be non-standard or consensus-invalid (see `from_u32`).
    pub(crate) inner: u32,
}
20
21impl fmt::Display for PsbtSighashType {
22    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23        match self.taproot_hash_ty() {
24            Err(_) => write!(f, "{:#x}", self.inner),
25            Ok(taproot_hash_ty) => fmt::Display::fmt(&taproot_hash_ty, f),
26        }
27    }
28}
29
30impl FromStr for PsbtSighashType {
31    type Err = ParseSighashTypeError;
32
33    #[inline]
34    fn from_str(s: &str) -> Result<Self, Self::Err> {
35        // We accept strings of form: "SIGHASH_ALL" etc.
36        //
37        // NB: some of Taproot sighash types are non-standard for pre-taproot
38        // inputs. We also do not support SIGHASH_RESERVED in verbatim form
39        // ("0xFF" string should be used instead).
40        if let Ok(ty) = TapSighashType::from_str(s) {
41            return Ok(ty.into());
42        }
43
44        // We accept non-standard sighash values.
45        if let Ok(inner) = u32::from_str_radix(s.trim_start_matches("0x"), 16) {
46            return Ok(PsbtSighashType { inner });
47        }
48
49        Err(ParseSighashTypeError { unrecognized: s.to_owned() })
50    }
51}
52impl From<EcdsaSighashType> for PsbtSighashType {
53    fn from(ecdsa_hash_ty: EcdsaSighashType) -> Self {
54        PsbtSighashType { inner: ecdsa_hash_ty as u32 }
55    }
56}
57
58impl From<TapSighashType> for PsbtSighashType {
59    fn from(taproot_hash_ty: TapSighashType) -> Self {
60        PsbtSighashType { inner: taproot_hash_ty as u32 }
61    }
62}
63
64impl PsbtSighashType {
65    /// Returns the [`EcdsaSighashType`] if the [`PsbtSighashType`] can be
66    /// converted to one.
67    pub fn ecdsa_hash_ty(self) -> Result<EcdsaSighashType, NonStandardSighashTypeError> {
68        EcdsaSighashType::from_standard(self.inner)
69    }
70
71    /// Returns the [`TapSighashType`] if the [`PsbtSighashType`] can be
72    /// converted to one.
73    pub fn taproot_hash_ty(self) -> Result<TapSighashType, InvalidSighashTypeError> {
74        if self.inner > 0xffu32 {
75            return Err(InvalidSighashTypeError::Invalid(self.inner));
76        }
77
78        let ty = TapSighashType::from_consensus_u8(self.inner as u8)?;
79        Ok(ty)
80    }
81
82    /// Creates a [`PsbtSighashType`] from a raw `u32`.
83    ///
84    /// Allows construction of a non-standard or non-valid sighash flag
85    /// ([`EcdsaSighashType`], [`TapSighashType`] respectively).
86    pub fn from_u32(n: u32) -> PsbtSighashType { PsbtSighashType { inner: n } }
87
88    /// Converts [`PsbtSighashType`] to a raw `u32` sighash flag.
89    ///
90    /// No guarantees are made as to the standardness or validity of the returned value.
91    pub fn to_u32(self) -> u32 { self.inner }
92}
93
/// Error produced when a string cannot be parsed as a sighash type.
///
/// In this module it is returned by `PsbtSighashType`'s `FromStr` impl when the input is
/// neither a recognised sighash name nor a hexadecimal value.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct ParseSighashTypeError {
    /// The input string that could not be parsed.
    pub unrecognized: String,
}
103
104impl fmt::Display for ParseSighashTypeError {
105    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
106        write!(f, "unrecognized SIGHASH string '{}'", self.unrecognized)
107    }
108}
109
/// This error has no underlying cause, so the default `source()` (which returns `None`)
/// is used.
#[cfg(feature = "std")]
impl std::error::Error for ParseSighashTypeError {}
114
115// TODO: Remove this error after issue resolves.
116// https://github.com/rust-bitcoin/rust-bitcoin/issues/2423
117/// Integer is not a consensus valid sighash type.
118#[derive(Debug, Clone, PartialEq, Eq)]
119#[non_exhaustive]
120pub enum InvalidSighashTypeError {
121    /// The real invalid sighash type error.
122    Bitcoin(sighash::InvalidSighashTypeError),
123    /// Hack required because of non_exhaustive on the real error.
124    Invalid(u32),
125}
126
127impl fmt::Display for InvalidSighashTypeError {
128    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
129        use InvalidSighashTypeError::*;
130
131        match *self {
132            Bitcoin(ref e) => write_err!(f, "bitcoin"; e),
133            Invalid(invalid) => write!(f, "invalid sighash type {}", invalid),
134        }
135    }
136}
137
#[cfg(feature = "std")]
impl std::error::Error for InvalidSighashTypeError {
    /// Exposes the wrapped upstream error, if any; locally detected values have no cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Invalid(_) => None,
            Self::Bitcoin(inner) => Some(inner),
        }
    }
}
149
150impl From<sighash::InvalidSighashTypeError> for InvalidSighashTypeError {
151    fn from(e: sighash::InvalidSighashTypeError) -> Self { Self::Bitcoin(e) }
152}
153
#[cfg(test)]
mod tests {
    use core::str::FromStr;

    use super::*;

    /// Formats `sighash`, parses the text back, asserts nothing changed and returns the
    /// parsed value for further checks.
    fn assert_round_trips(sighash: PsbtSighashType) -> PsbtSighashType {
        let rendered = format!("{}", sighash);
        let parsed = PsbtSighashType::from_str(&rendered).unwrap();
        assert_eq!(parsed, sighash);
        parsed
    }

    #[test]
    fn psbt_sighash_type_ecdsa() {
        let cases = [
            EcdsaSighashType::All,
            EcdsaSighashType::None,
            EcdsaSighashType::Single,
            EcdsaSighashType::AllPlusAnyoneCanPay,
            EcdsaSighashType::NonePlusAnyoneCanPay,
            EcdsaSighashType::SinglePlusAnyoneCanPay,
        ];
        for &ecdsa in cases.iter() {
            let parsed = assert_round_trips(PsbtSighashType::from(ecdsa));
            assert_eq!(parsed.ecdsa_hash_ty().unwrap(), ecdsa);
        }
    }

    #[test]
    fn psbt_sighash_type_taproot() {
        let cases = [
            TapSighashType::Default,
            TapSighashType::All,
            TapSighashType::None,
            TapSighashType::Single,
            TapSighashType::AllPlusAnyoneCanPay,
            TapSighashType::NonePlusAnyoneCanPay,
            TapSighashType::SinglePlusAnyoneCanPay,
        ];
        for &tap in cases.iter() {
            let parsed = assert_round_trips(PsbtSighashType::from(tap));
            assert_eq!(parsed.taproot_hash_ty().unwrap(), tap);
        }
    }

    #[test]
    fn psbt_sighash_type_notstd() {
        let nonstd: u32 = 0xdddd_dddd;
        let parsed = assert_round_trips(PsbtSighashType::from_u32(nonstd));

        // TODO: once `InvalidSighashTypeError` is removed, also check that
        // `parsed.ecdsa_hash_ty()` fails with `NonStandardSighashTypeError(nonstd)`.
        assert_eq!(parsed.taproot_hash_ty(), Err(InvalidSighashTypeError::Invalid(nonstd)));
    }
}