use core::fmt;
use bitcoin::{FeeRate, Transaction, Txid};
use crate::error::{write_err, FeeError};
use crate::v2::{DetermineLockTimeError, Psbt};
/// Extracts a finalized transaction from a [`Psbt`].
///
/// Construct with [`Extractor::new`], which checks that every input is finalized and that the
/// PSBT's lock time can be determined before wrapping it.
// NOTE(review): `Serialize`/`Deserialize` are not imported in this file; presumably they are in
// scope crate-wide — confirm the `serde` feature builds.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Extractor(Psbt);
impl Extractor {
pub fn new(psbt: Psbt) -> Result<Self, ExtractError> {
if psbt.inputs.iter().any(|input| !input.is_finalized()) {
return Err(ExtractError::PsbtNotFinalized);
}
let _ = psbt.determine_lock_time()?;
Ok(Self(psbt))
}
pub fn id(&self) -> Txid {
self.0.id().expect("Extractor guarantees lock time can be determined")
}
}
impl Extractor {
    /// Fee rate ceiling (25,000 sat/vB) applied by [`Extractor::extract_tx`] and
    /// [`Extractor::extract_tx_fee_rate_limit`].
    pub const DEFAULT_MAX_FEE_RATE: FeeRate = FeeRate::from_sat_per_vb_unchecked(25_000);

    /// Extracts the finalized transaction, refusing it if its fee rate exceeds
    /// [`Self::DEFAULT_MAX_FEE_RATE`].
    ///
    /// # Errors
    ///
    /// Same as [`Extractor::extract_tx_with_fee_rate_limit`].
    pub fn extract_tx(&self) -> Result<Transaction, ExtractTxFeeRateError> {
        self.extract_tx_with_fee_rate_limit(Self::DEFAULT_MAX_FEE_RATE)
    }

    /// Equivalent to [`Extractor::extract_tx`]: extracts the transaction using
    /// [`Self::DEFAULT_MAX_FEE_RATE`] as the fee rate ceiling.
    ///
    /// # Errors
    ///
    /// Same as [`Extractor::extract_tx_with_fee_rate_limit`].
    pub fn extract_tx_fee_rate_limit(&self) -> Result<Transaction, ExtractTxFeeRateError> {
        self.extract_tx()
    }

    /// Extracts the finalized transaction, refusing it if its fee rate exceeds `max_fee_rate`.
    ///
    /// # Errors
    ///
    /// Checked in this order:
    /// - [`ExtractTxFeeRateError::Fee`] if the PSBT's fee cannot be calculated.
    /// - [`ExtractTxFeeRateError::ExtractTx`] if the transaction cannot be assembled.
    /// - [`ExtractTxFeeRateError::FeeTooHigh`] if the fee rate is above `max_fee_rate`.
    pub fn extract_tx_with_fee_rate_limit(
        &self,
        max_fee_rate: FeeRate,
    ) -> Result<Transaction, ExtractTxFeeRateError> {
        self.internal_extract_tx_with_fee_rate_limit(max_fee_rate)
    }

    /// Extracts the finalized transaction without any fee rate check.
    ///
    /// # Errors
    ///
    /// Returns [`ExtractTxError`] if the PSBT is not finalized or its lock time cannot be
    /// determined.
    // NOTE(review): rationale for silencing clippy's large-`Err` lint is not recorded here —
    // confirm it is still needed.
    #[allow(clippy::result_large_err)]
    pub fn extract_tx_unchecked_fee_rate(&self) -> Result<Transaction, ExtractTxError> {
        self.internal_extract_tx()
    }

    /// Shared implementation of the fee-rate-checked extraction methods.
    #[inline]
    fn internal_extract_tx_with_fee_rate_limit(
        &self,
        max_fee_rate: FeeRate,
    ) -> Result<Transaction, ExtractTxFeeRateError> {
        // The fee is computed before extraction, so fee errors take precedence.
        let fee = self.0.fee()?;
        let tx = self.internal_extract_tx()?;

        // sat/kwu = sat * 1000 / wu; saturating so an absurd fee clamps instead of overflowing.
        let weight_wu = tx.weight().to_wu();
        let fee_rate = FeeRate::from_sat_per_kwu(fee.to_sat().saturating_mul(1000) / weight_wu);

        if fee_rate <= max_fee_rate {
            Ok(tx)
        } else {
            Err(ExtractTxFeeRateError::FeeTooHigh { fee: fee_rate, max: max_fee_rate })
        }
    }

    /// Assembles a [`Transaction`] from the wrapped PSBT.
    ///
    /// The transaction is built from the global tx version, the determined lock time, each
    /// input's `signed_tx_in()` and each output's `tx_out()`.
    #[inline]
    #[allow(clippy::result_large_err)]
    fn internal_extract_tx(&self) -> Result<Transaction, ExtractTxError> {
        let psbt = &self.0;
        if !psbt.is_finalized() {
            return Err(ExtractTxError::Unfinalized);
        }

        let lock_time = psbt.determine_lock_time()?;
        let input: Vec<_> =
            psbt.inputs.iter().map(|psbt_input| psbt_input.signed_tx_in()).collect();
        let output: Vec<_> = psbt.outputs.iter().map(|psbt_output| psbt_output.tx_out()).collect();

        Ok(Transaction { version: psbt.global.tx_version, lock_time, input, output })
    }
}
#[derive(Debug)]
pub enum ExtractError {
PsbtNotFinalized,
DetermineLockTime(DetermineLockTimeError),
}
impl fmt::Display for ExtractError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use ExtractError::*;
match *self {
PsbtNotFinalized => write!(f, "attempted to extract tx from an unfinalized PSBT"),
DetermineLockTime(ref e) =>
write_err!(f, "extractor must be able to determine the lock time"; e),
}
}
}
#[cfg(feature = "std")]
impl std::error::Error for ExtractError {
    /// Exposes the lock time error as the source; `PsbtNotFinalized` has no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ExtractError::PsbtNotFinalized => None,
            ExtractError::DetermineLockTime(err) => Some(err),
        }
    }
}
impl From<DetermineLockTimeError> for ExtractError {
fn from(e: DetermineLockTimeError) -> Self { Self::DetermineLockTime(e) }
}
/// Error returned by the fee-rate-checked extraction methods of [`Extractor`]
/// (`extract_tx`, `extract_tx_fee_rate_limit` and `extract_tx_with_fee_rate_limit`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExtractTxFeeRateError {
    /// Calculating the PSBT's fee failed.
    Fee(FeeError),
    /// The extracted transaction's fee rate exceeds the permitted maximum.
    FeeTooHigh {
        /// The computed fee rate of the extracted transaction.
        fee: FeeRate,
        /// The maximum fee rate that was allowed.
        max: FeeRate,
    },
    /// Assembling the transaction from the PSBT failed.
    ExtractTx(ExtractTxError),
}
impl fmt::Display for ExtractTxFeeRateError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use ExtractTxFeeRateError::*;
match *self {
Fee(ref e) => write_err!(f, "fee calculation"; e),
FeeTooHigh { fee, max } => write!(f, "fee {} is greater than max {}", fee, max),
ExtractTx(ref e) => write_err!(f, "extract"; e),
}
}
}
#[cfg(feature = "std")]
impl std::error::Error for ExtractTxFeeRateError {
    /// Exposes the wrapped fee or extraction error; `FeeTooHigh` has no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ExtractTxFeeRateError::FeeTooHigh { .. } => None,
            ExtractTxFeeRateError::Fee(err) => Some(err),
            ExtractTxFeeRateError::ExtractTx(err) => Some(err),
        }
    }
}
impl From<FeeError> for ExtractTxFeeRateError {
fn from(e: FeeError) -> Self { Self::Fee(e) }
}
impl From<ExtractTxError> for ExtractTxFeeRateError {
fn from(e: ExtractTxError) -> Self { Self::ExtractTx(e) }
}
/// Error returned when assembling the transaction from an [`Extractor`]'s PSBT fails,
/// e.g. by `extract_tx_unchecked_fee_rate`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExtractTxError {
    /// The PSBT is not finalized.
    Unfinalized,
    /// The lock time of the PSBT could not be determined.
    DetermineLockTime(DetermineLockTimeError),
}
impl fmt::Display for ExtractTxError {
fn fmt(&self, _f: &mut fmt::Formatter<'_>) -> fmt::Result { todo!() }
}
#[cfg(feature = "std")]
impl std::error::Error for ExtractTxError {
    /// Exposes the lock time error as the source; `Unfinalized` has no underlying cause.
    ///
    /// This previously was `todo!()`, so walking an error chain that contained this error
    /// (directly or via `ExtractTxFeeRateError::ExtractTx`) panicked.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        use ExtractTxError::*;

        match *self {
            DetermineLockTime(ref e) => Some(e),
            Unfinalized => None,
        }
    }
}
impl From<DetermineLockTimeError> for ExtractTxError {
fn from(e: DetermineLockTimeError) -> Self { Self::DetermineLockTime(e) }
}