mod error;
use bitcoin::OutPoint;
pub(crate) use error::InternalRequestError;
pub use error::RequestError;
use super::*;
pub use crate::receive::common::{WantsFeeRange, WantsInputs, WantsOutputs};
use crate::uri::PjParam;
use crate::{IntoUrl, OutputSubstitution, PjParseError, Version};
/// Payjoin protocol versions accepted by this v1 receiver when parsing a request
/// (passed to `parse_payload` by `UncheckedOriginalPayload::from_request`).
const SUPPORTED_VERSIONS: &[Version] = &[Version::One];

/// Read-only access to the HTTP headers of an incoming Payjoin v1 request.
///
/// `validate_body` looks headers up by the lowercase names `content-type` and
/// `content-length`.
/// NOTE(review): lookups are done with lowercase keys only; implementations backed by a
/// case-sensitive map presumably need to normalize header names — confirm with integrators.
pub trait Headers {
    /// Return the value of the header named `key`, or `None` if it is absent.
    fn get_header(&self, key: &str) -> Option<&str>;
}
/// Build a Payjoin v1 URI that pays `address` and advertises `endpoint` as the `pj`
/// parameter.
///
/// `output_substitution` is stored in the URI's Payjoin extras alongside the parsed
/// endpoint. The address is cloned into the returned URI.
///
/// # Errors
/// Returns [`PjParseError`] if `endpoint` cannot be parsed by `PjParam::parse`.
pub fn build_v1_pj_uri<'a>(
    address: &bitcoin::Address,
    endpoint: impl IntoUrl,
    output_substitution: OutputSubstitution,
) -> Result<crate::uri::PjUri<'a>, PjParseError> {
    let extras =
        crate::uri::PayjoinExtras { pj_param: PjParam::parse(endpoint)?, output_substitution };
    let uri = bitcoin_uri::Uri::with_extras(address.clone(), extras);
    Ok(uri)
}
/// The original PSBT and query parameters received from a v1 sender, before any
/// receiver-side checks have been run.
#[derive(Debug, Clone)]
pub struct UncheckedOriginalPayload {
    original: OriginalPayload,
}

impl UncheckedOriginalPayload {
    /// Parse an incoming Payjoin v1 request into an unchecked payload.
    ///
    /// The body is first validated against `headers` (content type and length). It is
    /// then decoded as UTF-8 and passed, together with `query`, to the shared payload
    /// parser, which only accepts `SUPPORTED_VERSIONS`.
    ///
    /// # Errors
    /// - `ProtocolError::V1` if header/body validation fails.
    /// - `InternalPayloadError::Utf8` if the body is not valid UTF-8.
    /// - `ProtocolError::OriginalPayload` if the PSBT or query parameters fail to parse.
    ///
    /// Each of these is converted into [`Error`].
    pub fn from_request(body: &[u8], query: &str, headers: impl Headers) -> Result<Self, Error> {
        let checked_body = validate_body(headers, body).map_err(ProtocolError::V1)?;
        let body_text = std::str::from_utf8(checked_body).map_err(InternalPayloadError::Utf8)?;
        let parsed = crate::receive::parse_payload(body_text, query, SUPPORTED_VERSIONS);
        let (psbt, params) = parsed.map_err(ProtocolError::OriginalPayload)?;
        let original = OriginalPayload { psbt, params };
        Ok(Self { original })
    }

    /// Run the broadcast-suitability check on the original transaction, then advance to
    /// [`MaybeInputsOwned`].
    ///
    /// `min_fee_rate` and `can_broadcast` are forwarded unchanged to
    /// `OriginalPayload::check_broadcast_suitability`.
    ///
    /// # Errors
    /// Propagates any error from that check. For example, a `min_fee_rate` above the
    /// PSBT's own fee rate yields a `PsbtBelowFeeRate` payload error.
    pub fn check_broadcast_suitability(
        self,
        min_fee_rate: Option<FeeRate>,
        can_broadcast: impl Fn(&bitcoin::Transaction) -> Result<bool, ImplementationError>,
    ) -> Result<MaybeInputsOwned, Error> {
        let Self { original } = self;
        original.check_broadcast_suitability(min_fee_rate, can_broadcast)?;
        Ok(MaybeInputsOwned { original })
    }

    /// Advance to [`MaybeInputsOwned`] without running
    /// [`Self::check_broadcast_suitability`].
    pub fn assume_interactive_receiver(self) -> MaybeInputsOwned {
        let Self { original } = self;
        MaybeInputsOwned { original }
    }
}
/// Typestate reached once the broadcast-suitability step has been run or skipped.
///
/// The next step asks the receiver whether any of the original PSBT's inputs are its own.
#[derive(Debug, Clone)]
pub struct MaybeInputsOwned {
    pub(crate) original: OriginalPayload,
}

impl MaybeInputsOwned {
    /// Return the transaction extracted from a copy of the original PSBT.
    ///
    /// Extraction uses `extract_tx_unchecked_fee_rate`, so no fee-rate sanity check is
    /// applied. The stored PSBT itself is left untouched.
    pub fn extract_tx_to_schedule_broadcast(&self) -> bitcoin::Transaction {
        let psbt_copy = self.original.psbt.clone();
        psbt_copy.extract_tx_unchecked_fee_rate()
    }

    /// Check, via `is_owned`, that the original inputs do not belong to the receiver.
    ///
    /// `is_owned` should return `Ok(true)` for scripts the receiver's wallet controls. The
    /// check itself is delegated to `OriginalPayload::check_inputs_not_owned`.
    ///
    /// # Errors
    /// Propagates any error from that check, including errors returned by `is_owned`.
    pub fn check_inputs_not_owned(
        self,
        is_owned: &mut impl FnMut(&Script) -> Result<bool, ImplementationError>,
    ) -> Result<MaybeInputsSeen, Error> {
        let Self { original } = self;
        original.check_inputs_not_owned(is_owned)?;
        Ok(MaybeInputsSeen { original })
    }
}
/// Typestate reached after the "inputs not owned" check.
///
/// The next step asks whether any original input outpoint has been seen before.
#[derive(Debug, Clone)]
pub struct MaybeInputsSeen {
    original: OriginalPayload,
}

impl MaybeInputsSeen {
    /// Check, via `is_known`, that no original input outpoint has been seen before.
    ///
    /// `is_known` should return `Ok(true)` for outpoints the receiver has already seen. The
    /// check is delegated to `OriginalPayload::check_no_inputs_seen_before`.
    ///
    /// # Errors
    /// Propagates any error from that check, including errors returned by `is_known`.
    pub fn check_no_inputs_seen_before(
        self,
        is_known: &mut impl FnMut(&OutPoint) -> Result<bool, ImplementationError>,
    ) -> Result<OutputsUnknown, Error> {
        let Self { original } = self;
        original.check_no_inputs_seen_before(is_known)?;
        Ok(OutputsUnknown { original })
    }
}
/// Typestate reached after the "no inputs seen before" check.
///
/// The receiver must now identify which outputs of the original PSBT pay to it.
#[derive(Debug, Clone)]
pub struct OutputsUnknown {
    original: OriginalPayload,
}

impl OutputsUnknown {
    /// Identify the receiver's outputs with `is_receiver_output` and move to
    /// [`WantsOutputs`].
    ///
    /// `is_receiver_output` should return `Ok(true)` for scripts that pay the receiver.
    /// The work is delegated to `OriginalPayload::identify_receiver_outputs`.
    ///
    /// # Errors
    /// Propagates any error from that step, including errors from `is_receiver_output`.
    #[cfg_attr(not(feature = "v1"), allow(dead_code))]
    pub fn identify_receiver_outputs(
        self,
        is_receiver_output: &mut impl FnMut(&Script) -> Result<bool, ImplementationError>,
    ) -> Result<WantsOutputs, Error> {
        let Self { original } = self;
        original.identify_receiver_outputs(is_receiver_output)
    }
}
/// Validate the HTTP headers of an incoming Payjoin v1 request against its body.
///
/// The request must carry:
/// - a `content-type` header whose value starts with `text/plain`. The comparison is ASCII
///   case-insensitive because media-type tokens are case-insensitive (RFC 9110 §8.3.1).
///   Trailing parameters such as `; charset=utf-8` are tolerated.
/// - a `content-length` header that parses as a `usize` and equals `body.len()`.
///
/// On success the unchanged `body` slice is returned.
///
/// # Errors
/// Returns a [`RequestError`] wrapping:
/// - `MissingHeader` if either header is absent.
/// - `InvalidContentType` (carrying the received value) if the media type is not `text/plain`.
/// - `InvalidContentLength` if `content-length` does not parse as a `usize`.
/// - `ContentLengthMismatch` if the declared length differs from the actual body length.
fn validate_body(headers: impl Headers, body: &[u8]) -> Result<&[u8], RequestError> {
    const EXPECTED_MEDIA_TYPE: &[u8] = b"text/plain";

    let content_type = headers
        .get_header("content-type")
        .ok_or(InternalRequestError::MissingHeader("Content-Type"))?;
    // Compare raw bytes rather than slicing the `&str`: a short or multi-byte header value
    // then simply fails the match instead of risking a char-boundary panic.
    let is_text_plain = matches!(
        content_type.as_bytes().get(..EXPECTED_MEDIA_TYPE.len()),
        Some(prefix) if prefix.eq_ignore_ascii_case(EXPECTED_MEDIA_TYPE)
    );
    if !is_text_plain {
        return Err(InternalRequestError::InvalidContentType(content_type.to_owned()).into());
    }

    let content_length = headers
        .get_header("content-length")
        .ok_or(InternalRequestError::MissingHeader("Content-Length"))?
        .parse::<usize>()
        .map_err(InternalRequestError::InvalidContentLength)?;
    if body.len() != content_length {
        return Err(InternalRequestError::ContentLengthMismatch {
            expected: content_length,
            actual: body.len(),
        }
        .into());
    }

    Ok(body)
}
impl crate::receive::common::WantsFeeRange {
    /// Apply the receiver's optional fee bounds and move to [`ProvisionalProposal`].
    ///
    /// `min_fee_rate` and `max_effective_fee_rate` are forwarded unchanged to
    /// `calculate_psbt_context_with_fee_range`. `None` means no bound was supplied.
    ///
    /// # Errors
    /// Whatever `calculate_psbt_context_with_fee_range` reports, converted into [`Error`].
    pub fn apply_fee_range(
        self,
        min_fee_rate: Option<FeeRate>,
        max_effective_fee_rate: Option<FeeRate>,
    ) -> Result<ProvisionalProposal, Error> {
        let context =
            self.calculate_psbt_context_with_fee_range(min_fee_rate, max_effective_fee_rate)?;
        Ok(ProvisionalProposal { psbt_context: context })
    }
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProvisionalProposal {
psbt_context: PsbtContext,
}
impl ProvisionalProposal {
pub fn finalize_proposal(
self,
wallet_process_psbt: impl Fn(&Psbt) -> Result<Psbt, ImplementationError>,
) -> Result<PayjoinProposal, Error> {
let finalized_psbt = self
.psbt_context
.finalize_proposal(wallet_process_psbt)
.map_err(|e| Error::Implementation(ImplementationError::new(e)))?;
Ok(PayjoinProposal { payjoin_psbt: finalized_psbt })
}
pub fn psbt_to_sign(&self) -> Psbt { self.psbt_context.payjoin_psbt.clone() }
}
/// The finalized payjoin PSBT produced by [`ProvisionalProposal::finalize_proposal`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct PayjoinProposal {
    payjoin_psbt: Psbt,
}

impl PayjoinProposal {
    /// Iterate over the previous outpoints of every input of the payjoin transaction.
    ///
    /// NOTE(review): this yields all inputs of the unsigned transaction, not only the ones
    /// the receiver contributed — confirm callers expect that.
    pub fn utxos_to_be_locked(&self) -> impl '_ + Iterator<Item = &bitcoin::OutPoint> {
        let inputs = &self.payjoin_psbt.unsigned_tx.input;
        inputs.iter().map(|txin| &txin.previous_output)
    }

    /// Borrow the finalized payjoin PSBT.
    pub fn psbt(&self) -> &Psbt {
        let finalized = &self.payjoin_psbt;
        finalized
    }
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use bitcoin::absolute::{LockTime, Time};
    use bitcoin::{Address, Amount, Network, Transaction};
    use payjoin_test_utils::{
        ORIGINAL_PSBT, PARSED_ORIGINAL_PSBT, PARSED_PAYJOIN_PROPOSAL, QUERY_PARAMS,
    };

    use super::*;
    use crate::Version;

    /// Minimal `Headers` implementation for tests. It always reports a `text/plain`
    /// content type and a caller-chosen content length; every other header is absent.
    #[derive(Debug, Clone)]
    struct MockHeaders {
        length: String,
    }

    impl MockHeaders {
        fn new(length: u64) -> MockHeaders { MockHeaders { length: length.to_string() } }
    }

    impl Headers for MockHeaders {
        fn get_header(&self, key: &str) -> Option<&str> {
            match key {
                "content-length" => Some(&self.length),
                "content-type" => Some("text/plain"),
                _ => None,
            }
        }
    }

    /// A declared content length one byte longer than the body must be rejected with a
    /// `ContentLengthMismatch` error carrying both the declared and the actual length.
    #[test]
    fn test_parse_body() {
        let body = ORIGINAL_PSBT.as_bytes().to_vec();
        let headers = MockHeaders::new((body.len() + 1) as u64);

        let validated_request = validate_body(headers.clone(), body.as_slice());
        assert!(validated_request.is_err());

        match validated_request {
            Ok(_) => panic!("Expected error, got success"),
            Err(error) => {
                // Compare rendered messages rather than the error values themselves.
                assert_eq!(
                    error.to_string(),
                    RequestError::from(InternalRequestError::ContentLengthMismatch {
                        expected: body.len() + 1,
                        actual: body.len(),
                    })
                    .to_string()
                );
            }
        }
    }

    /// End-to-end parse of the test-vector request. The first input's witness UTXO must
    /// be a mainnet P2SH script, and the query parameters must yield version One and an
    /// additional fee contribution of 182 sat at output index 0.
    #[test]
    fn test_from_request() -> Result<(), Box<dyn std::error::Error>> {
        let body = ORIGINAL_PSBT.as_bytes();
        let headers = MockHeaders::new(body.len() as u64);
        let validated_request = validate_body(headers.clone(), body);
        assert!(validated_request.is_ok());

        let proposal = UncheckedOriginalPayload::from_request(body, QUERY_PARAMS, headers)?;

        let witness_utxo = proposal.original.psbt.inputs[0]
            .witness_utxo
            .as_ref()
            .expect("witness_utxo should be present");
        let address =
            Address::from_script(&witness_utxo.script_pubkey, bitcoin::params::Params::MAINNET)?;
        assert_eq!(address.address_type(), Some(AddressType::P2sh));

        assert_eq!(proposal.original.params.v, Version::One);
        assert_eq!(
            proposal.original.params.additional_fee_contribution,
            Some((Amount::from_sat(182), 0))
        );
        Ok(())
    }

    /// Build an `UncheckedOriginalPayload` directly from the parsed test-vector PSBT and
    /// query parameters, bypassing header/body validation.
    fn unchecked_proposal_from_test_vector() -> UncheckedOriginalPayload {
        let pairs = url::form_urlencoded::parse(QUERY_PARAMS.as_bytes());
        let params = Params::from_query_pairs(pairs, &[Version::One])
            .expect("Could not parse params from query pairs");
        UncheckedOriginalPayload {
            original: OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params },
        }
    }

    /// Build a `MaybeInputsOwned` directly from the test vector. This skips the
    /// broadcast-suitability step entirely.
    fn maybe_inputs_owned_from_test_vector() -> MaybeInputsOwned {
        let pairs = url::form_urlencoded::parse(QUERY_PARAMS.as_bytes());
        let params = Params::from_query_pairs(pairs, &[Version::One])
            .expect("Could not parse params from query pairs");
        MaybeInputsOwned {
            original: OriginalPayload { psbt: PARSED_ORIGINAL_PSBT.clone(), params },
        }
    }

    /// Drive `proposal` through the typestate checks up to `WantsOutputs`. No input is
    /// reported as owned or previously seen. The only receiver output is the one paying
    /// `3CZZi7aWFugaCdUCS15dgrUUViupmB8bVM` on mainnet.
    fn wants_outputs_from_test_vector(proposal: UncheckedOriginalPayload) -> WantsOutputs {
        proposal
            .assume_interactive_receiver()
            .check_inputs_not_owned(&mut |_| Ok(false))
            .expect("No inputs should be owned")
            .check_no_inputs_seen_before(&mut |_| Ok(false))
            .expect("No inputs should be seen before")
            .identify_receiver_outputs(&mut |script| {
                let network = Network::Bitcoin;
                Ok(Address::from_script(script, network).unwrap()
                    == Address::from_str("3CZZi7aWFugaCdUCS15dgrUUViupmB8bVM")
                        .unwrap()
                        .require_network(network)
                        .unwrap())
            })
            .expect("Receiver output should be identified")
    }

    /// Continue from `WantsOutputs` to a `ProvisionalProposal`. Outputs and inputs are
    /// committed without modification, and no fee bounds are applied.
    fn provisional_proposal_from_test_vector(
        proposal: UncheckedOriginalPayload,
    ) -> ProvisionalProposal {
        wants_outputs_from_test_vector(proposal)
            .commit_outputs()
            .commit_inputs()
            .apply_fee_range(None, None)
            .expect("Contributed inputs should allow for valid fee contributions")
    }

    /// The typestate checks accept `FnMut` closures that mutate captured state. A shared
    /// counter records how many times the callbacks run across the three steps.
    #[test]
    fn test_mutable_receiver_state_closures() {
        let mut call_count = 0;
        let maybe_inputs_owned = maybe_inputs_owned_from_test_vector();

        fn mock_callback(call_count: &mut usize, ret: bool) -> Result<bool, ImplementationError> {
            *call_count += 1;
            Ok(ret)
        }

        let maybe_inputs_seen = maybe_inputs_owned
            .check_inputs_not_owned(&mut |_| mock_callback(&mut call_count, false));
        assert_eq!(call_count, 1);

        let outputs_unknown = maybe_inputs_seen
            .map_err(|_| "Check inputs owned closure failed".to_string())
            .expect("Next receiver state should be accessible")
            .check_no_inputs_seen_before(&mut |_| mock_callback(&mut call_count, false));
        assert_eq!(call_count, 2);

        // The output callback runs twice here (counter goes 2 -> 4); presumably once per
        // output of the test-vector PSBT.
        let _wants_outputs = outputs_unknown
            .map_err(|_| "Check no inputs seen closure failed".to_string())
            .expect("Next receiver state should be accessible")
            .identify_receiver_outputs(&mut |_| mock_callback(&mut call_count, true));
        assert_eq!(call_count, 4);
    }

    /// The sender's `output_substitution` query parameter must be reflected by
    /// `WantsOutputs::output_substitution`.
    #[test]
    fn is_output_substitution_disabled() {
        let mut proposal = unchecked_proposal_from_test_vector();
        let payjoin = wants_outputs_from_test_vector(proposal.clone());
        assert_eq!(payjoin.output_substitution(), OutputSubstitution::Enabled);

        proposal.original.params.output_substitution = OutputSubstitution::Disabled;
        let payjoin = wants_outputs_from_test_vector(proposal);
        assert_eq!(payjoin.output_substitution(), OutputSubstitution::Disabled);
    }

    /// A `min_fee_rate` equal to the PSBT's own fee rate passes the check. `FeeRate::MAX`
    /// fails with `PsbtBelowFeeRate`, which carries both the PSBT's rate and the
    /// requested minimum.
    #[test]
    fn unchecked_proposal_min_fee() {
        let proposal = unchecked_proposal_from_test_vector();

        let min_fee_rate =
            proposal.original.psbt_fee_rate().expect("Feerate calculation should not fail");
        let _ = proposal
            .clone()
            .check_broadcast_suitability(Some(min_fee_rate), |_| Ok(true))
            .expect("Broadcast suitability check with appropriate min_fee_rate should succeed");
        assert_eq!(proposal.original.psbt_fee_rate().unwrap(), min_fee_rate);

        let min_fee_rate = FeeRate::MAX;
        let proposal_below_min_fee = proposal
            .clone()
            .check_broadcast_suitability(Some(min_fee_rate), |_| Ok(true))
            .expect_err("Broadcast suitability with min_fee_rate below minimum should fail");
        match proposal_below_min_fee {
            Error::Protocol(ProtocolError::OriginalPayload(PayloadError(
                InternalPayloadError::PsbtBelowFeeRate(original_fee_rate, min_fee_rate_param),
            ))) => {
                assert_eq!(original_fee_rate, proposal.original.psbt_fee_rate().unwrap());
                assert_eq!(min_fee_rate_param, min_fee_rate);
            }
            _ => panic!("Expected PsbtBelowFeeRate error, got: {proposal_below_min_fee:?}"),
        }
    }

    /// If the wallet callback returns a PSBT for a different (here: empty) transaction,
    /// finalization must fail with an implementation error naming both txids.
    #[test]
    fn test_finalize_proposal_invalid_payjoin_proposal() {
        let proposal = unchecked_proposal_from_test_vector();
        let provisional = provisional_proposal_from_test_vector(proposal);

        let empty_tx = Transaction {
            version: bitcoin::transaction::Version::TWO,
            lock_time: LockTime::Seconds(Time::MIN),
            input: vec![],
            output: vec![],
        };
        let other_psbt = Psbt::from_unsigned_tx(empty_tx).expect("Valid unsigned tx");

        let err = provisional.clone().finalize_proposal(|_| Ok(other_psbt.clone())).unwrap_err();
        assert_eq!(
            err.to_string(),
            format!(
                "Implementation error: Ntxid mismatch: expected {}, got {}",
                provisional.psbt_context.payjoin_psbt.unsigned_tx.compute_txid(),
                other_psbt.unsigned_tx.compute_txid()
            )
        );
    }

    /// `psbt_to_sign` returns (a copy of) the payjoin PSBT stored in the context.
    #[test]
    fn test_getting_psbt_to_sign() {
        let provisional_proposal = ProvisionalProposal {
            psbt_context: PsbtContext {
                payjoin_psbt: PARSED_PAYJOIN_PROPOSAL.clone(),
                original_psbt: PARSED_ORIGINAL_PSBT.clone(),
            },
        };
        let psbt = provisional_proposal.psbt_to_sign();
        assert_eq!(psbt, PARSED_PAYJOIN_PROPOSAL.clone());
    }
}