psbt/
constructor.rs

1// Modern, minimalistic & standard-compliant cold wallet library.
2//
3// SPDX-License-Identifier: Apache-2.0
4//
5// Written in 2020-2023 by
6//     Dr Maxim Orlovsky <orlovsky@lnp-bp.org>
7//
8// Copyright (C) 2020-2023 LNP/BP Standards Association. All rights reserved.
9// Copyright (C) 2020-2023 Dr Maxim Orlovsky. All rights reserved.
10//
11// Licensed under the Apache License, Version 2.0 (the "License");
12// you may not use this file except in compliance with the License.
13// You may obtain a copy of the License at
14//
15//     http://www.apache.org/licenses/LICENSE-2.0
16//
17// Unless required by applicable law or agreed to in writing, software
18// distributed under the License is distributed on an "AS IS" BASIS,
19// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20// See the License for the specific language governing permissions and
21// limitations under the License.
22
23use std::num::ParseIntError;
24use std::str::FromStr;
25
26use derive::{
27    Address, AddressParseError, Keychain, LockTime, Network, NormalIndex, Outpoint, Sats,
28    ScriptPubkey, SeqNo, Terminal, Vout,
29};
30use descriptors::Descriptor;
31
32use crate::{Prevout, Psbt, PsbtError, PsbtVer};
33
34#[derive(Clone, Debug, Display, Error, From)]
35#[display(doc_comments)]
36pub enum ConstructionError {
37    #[display(inner)]
38    Psbt(PsbtError),
39
40    /// the input spending {0} is not known for the current wallet.
41    UnknownInput(Outpoint),
42
43    /// impossible to construct transaction having no inputs.
44    NoInputs,
45
46    /// the total payment amount ({0} sats) exceeds number of sats in existence.
47    Overflow(Sats),
48
49    /// attempt to spend more than present in transaction inputs. Total transaction inputs are
50    /// {input_value} sats, but output is {output_value} sats.
51    OutputExceedsInputs {
52        input_value: Sats,
53        output_value: Sats,
54    },
55
56    /// not enough funds to pay fee of {fee} sats; the sum of inputs is {input_value} sats, and
57    /// outputs spends {output_value} sats out of them.
58    NoFundsForFee {
59        input_value: Sats,
60        output_value: Sats,
61        fee: Sats,
62    },
63
64    /// network for address {0} mismatch the one used by the wallet.
65    NetworkMismatch(Address),
66}
67
68#[derive(Clone, Debug, Display, Error, From)]
69#[display(doc_comments)]
70pub enum BeneficiaryParseError {
71    #[display("invalid format of the invoice")]
72    InvalidFormat,
73
74    #[from]
75    Int(ParseIntError),
76
77    #[from]
78    Address(AddressParseError),
79}
80
81#[derive(Copy, Clone, Eq, PartialEq, Debug, Display, From)]
82pub enum Payment {
83    #[from]
84    #[display(inner)]
85    Fixed(Sats),
86    #[display("MAX")]
87    Max,
88}
89
90impl Payment {
91    #[inline]
92    pub fn sats(&self) -> Option<Sats> {
93        match self {
94            Payment::Fixed(sats) => Some(*sats),
95            Payment::Max => None,
96        }
97    }
98
99    #[inline]
100    pub fn unwrap_or(&self, default: impl Into<Sats>) -> Sats {
101        self.sats().unwrap_or(default.into())
102    }
103
104    #[inline]
105    pub fn is_max(&self) -> bool { *self == Payment::Max }
106}
107
108impl FromStr for Payment {
109    type Err = ParseIntError;
110
111    fn from_str(s: &str) -> Result<Self, Self::Err> {
112        if s == "MAX" {
113            return Ok(Payment::Max);
114        }
115        Sats::from_str(s).map(Payment::Fixed)
116    }
117}
118
/// Payment recipient: a destination address together with the requested amount.
///
/// The default `Display` form is `<amount>@<address>`, which is also what `FromStr` parses.
/// The `alt` form renders a `bitcoin:` URI.
// NOTE(review): BIP-21 specifies the URI `amount` parameter in BTC, but here it is filled with
// the `Payment` display value (presumably a sats count, or the literal `MAX`) — confirm whether
// the alternate form is meant to be BIP-21 compliant.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Display)]
#[display("{amount}@{address}", alt = "bitcoin:{address}?amount={amount}")]
pub struct Beneficiary {
    /// Destination address of the payment output.
    pub address: Address,
    /// Amount to pay to `address`.
    pub amount: Payment,
}
125
126impl Beneficiary {
127    #[inline]
128    pub fn new(address: Address, amount: impl Into<Payment>) -> Self {
129        Beneficiary {
130            address,
131            amount: amount.into(),
132        }
133    }
134    #[inline]
135    pub fn with_max(address: Address) -> Self {
136        Beneficiary {
137            address,
138            amount: Payment::Max,
139        }
140    }
141    #[inline]
142    pub fn is_max(&self) -> bool { self.amount.is_max() }
143    #[inline]
144    pub fn script_pubkey(&self) -> ScriptPubkey { self.address.script_pubkey() }
145}
146
147impl FromStr for Beneficiary {
148    type Err = BeneficiaryParseError;
149
150    fn from_str(s: &str) -> Result<Self, Self::Err> {
151        let (amount, beneficiary) =
152            s.split_once('@').ok_or(BeneficiaryParseError::InvalidFormat)?;
153        Ok(Beneficiary::new(Address::from_str(beneficiary)?, Payment::from_str(amount)?))
154    }
155}
156
/// Parameters controlling how `PsbtConstructor::construct_psbt` builds a transaction.
#[derive(Copy, Clone, PartialEq, Debug)]
pub struct TxParams {
    /// Absolute fee to pay; subtracted from the input value left after all fixed outputs.
    pub fee: Sats,
    /// Written as the PSBT fallback lock time (`None` presumably means no fallback lock time).
    pub lock_time: Option<LockTime>,
    /// Sequence number assigned to every input.
    pub seq_no: SeqNo,
    /// Passed as the `shift` flag to `PsbtConstructor::next_derivation_index` when a change
    /// output is created.
    pub change_shift: bool,
    /// Keychain from which the change output is derived.
    pub change_keychain: Keychain,
}
165
166impl TxParams {
167    pub fn with(fee: Sats) -> Self {
168        TxParams {
169            fee,
170            lock_time: None,
171            seq_no: SeqNo::from_consensus_u32(0),
172            change_shift: true,
173            change_keychain: Keychain::INNER,
174        }
175    }
176}
177
/// Extra information about a PSBT produced by `PsbtConstructor::construct_psbt`.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub struct PsbtMeta {
    /// Index of the change output, or `None` if no change output was added.
    pub change_vout: Option<Vout>,
    /// Derivation terminal (keychain and index) used for the change output, if one was added.
    pub change_terminal: Option<Terminal>,
}
183
/// Unspent transaction output known to the wallet, usable as a transaction input.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct Utxo {
    /// Transaction output being spent.
    pub outpoint: Outpoint,
    /// Value of the output.
    pub value: Sats,
    /// Derivation terminal (keychain and index) of the output within the wallet descriptor.
    pub terminal: Terminal,
}
190
191impl Utxo {
192    #[inline]
193    pub fn to_prevout(&self) -> Prevout { Prevout::new(self.outpoint, self.value) }
194}
195
196pub trait PsbtConstructor {
197    type Key;
198    type Descr: Descriptor<Self::Key>;
199
200    fn descriptor(&self) -> &Self::Descr;
201    fn utxo(&self, outpoint: Outpoint) -> Option<(Utxo, ScriptPubkey)>;
202    fn network(&self) -> Network;
203    fn next_derivation_index(&mut self, keychain: impl Into<Keychain>, shift: bool) -> NormalIndex;
204
205    fn construct_psbt(
206        &mut self,
207        coins: impl IntoIterator<Item = Outpoint>,
208        beneficiaries: impl IntoIterator<Item = Beneficiary>,
209        params: TxParams,
210    ) -> Result<(Psbt, PsbtMeta), ConstructionError> {
211        let mut psbt = Psbt::create(PsbtVer::V2);
212
213        // Set locktime
214        psbt.fallback_locktime = params.lock_time;
215
216        // Add xpubs
217        for spec in self.descriptor().xpubs() {
218            psbt.xpubs.insert(*spec.xpub(), spec.origin().clone());
219        }
220
221        // 1. Add inputs
222        for coin in coins {
223            let (utxo, spk) = self.utxo(coin).ok_or(ConstructionError::UnknownInput(coin))?;
224            if psbt.inputs().any(|inp| inp.previous_outpoint == utxo.outpoint) {
225                continue;
226            }
227            psbt.append_input_expect(
228                utxo.to_prevout(),
229                self.descriptor(),
230                utxo.terminal,
231                spk,
232                params.seq_no,
233            );
234        }
235        if psbt.inputs().count() == 0 {
236            return Err(ConstructionError::NoInputs);
237        }
238
239        // 2. Add outputs
240        let input_value = psbt.input_sum();
241        let mut max = Vec::new();
242        let mut output_value = Sats::ZERO;
243        for beneficiary in beneficiaries {
244            if beneficiary.address.network != self.network().into() {
245                return Err(ConstructionError::NetworkMismatch(beneficiary.address));
246            }
247            let amount = beneficiary.amount.unwrap_or(Sats::ZERO);
248            output_value
249                .checked_add_assign(amount)
250                .ok_or(ConstructionError::Overflow(output_value))?;
251            let out = psbt.append_output_expect(beneficiary.script_pubkey(), amount);
252            if beneficiary.amount.is_max() {
253                max.push(out.index());
254            }
255        }
256        let mut remaining_value = input_value
257            .checked_sub(output_value)
258            .ok_or(ConstructionError::OutputExceedsInputs {
259                input_value,
260                output_value,
261            })?
262            .checked_sub(params.fee)
263            .ok_or(ConstructionError::NoFundsForFee {
264                input_value,
265                output_value,
266                fee: params.fee,
267            })?;
268        if !max.is_empty() {
269            let portion = remaining_value / max.len();
270            for out in psbt.outputs_mut() {
271                if max.contains(&out.index()) {
272                    out.amount = portion;
273                }
274            }
275            remaining_value = Sats::ZERO;
276        }
277
278        // 3. Add change - only if exceeded the dust limit
279        let (change_vout, change_terminal) =
280            if remaining_value > self.descriptor().class().dust_limit() {
281                let change_index =
282                    self.next_derivation_index(params.change_keychain, params.change_shift);
283                let change_terminal = Terminal::new(params.change_keychain, change_index);
284                let change_vout = psbt
285                    .append_change_expect(self.descriptor(), change_terminal, remaining_value)
286                    .index();
287                (Some(Vout::from_u32(change_vout as u32)), Some(change_terminal))
288            } else {
289                (None, None)
290            };
291
292        Ok((psbt, PsbtMeta {
293            change_vout,
294            change_terminal,
295        }))
296    }
297}