1use std::num::ParseIntError;
24use std::str::FromStr;
25
26use derive::{
27 Address, AddressParseError, Keychain, LockTime, Network, NormalIndex, Outpoint, Sats,
28 ScriptPubkey, SeqNo, Terminal, Vout,
29};
30use descriptors::Descriptor;
31
32use crate::{Prevout, Psbt, PsbtError, PsbtVer};
33
34#[derive(Clone, Debug, Display, Error, From)]
35#[display(doc_comments)]
36pub enum ConstructionError {
37 #[display(inner)]
38 Psbt(PsbtError),
39
40 UnknownInput(Outpoint),
42
43 NoInputs,
45
46 Overflow(Sats),
48
49 OutputExceedsInputs {
52 input_value: Sats,
53 output_value: Sats,
54 },
55
56 NoFundsForFee {
59 input_value: Sats,
60 output_value: Sats,
61 fee: Sats,
62 },
63
64 NetworkMismatch(Address),
66}
67
68#[derive(Clone, Debug, Display, Error, From)]
69#[display(doc_comments)]
70pub enum BeneficiaryParseError {
71 #[display("invalid format of the invoice")]
72 InvalidFormat,
73
74 #[from]
75 Int(ParseIntError),
76
77 #[from]
78 Address(AddressParseError),
79}
80
81#[derive(Copy, Clone, Eq, PartialEq, Debug, Display, From)]
82pub enum Payment {
83 #[from]
84 #[display(inner)]
85 Fixed(Sats),
86 #[display("MAX")]
87 Max,
88}
89
90impl Payment {
91 #[inline]
92 pub fn sats(&self) -> Option<Sats> {
93 match self {
94 Payment::Fixed(sats) => Some(*sats),
95 Payment::Max => None,
96 }
97 }
98
99 #[inline]
100 pub fn unwrap_or(&self, default: impl Into<Sats>) -> Sats {
101 self.sats().unwrap_or(default.into())
102 }
103
104 #[inline]
105 pub fn is_max(&self) -> bool { *self == Payment::Max }
106}
107
108impl FromStr for Payment {
109 type Err = ParseIntError;
110
111 fn from_str(s: &str) -> Result<Self, Self::Err> {
112 if s == "MAX" {
113 return Ok(Payment::Max);
114 }
115 Sats::from_str(s).map(Payment::Fixed)
116 }
117}
118
119#[derive(Copy, Clone, Eq, PartialEq, Debug, Display)]
120#[display("{amount}@{address}", alt = "bitcoin:{address}?amount={amount}")]
121pub struct Beneficiary {
122 pub address: Address,
123 pub amount: Payment,
124}
125
126impl Beneficiary {
127 #[inline]
128 pub fn new(address: Address, amount: impl Into<Payment>) -> Self {
129 Beneficiary {
130 address,
131 amount: amount.into(),
132 }
133 }
134 #[inline]
135 pub fn with_max(address: Address) -> Self {
136 Beneficiary {
137 address,
138 amount: Payment::Max,
139 }
140 }
141 #[inline]
142 pub fn is_max(&self) -> bool { self.amount.is_max() }
143 #[inline]
144 pub fn script_pubkey(&self) -> ScriptPubkey { self.address.script_pubkey() }
145}
146
147impl FromStr for Beneficiary {
148 type Err = BeneficiaryParseError;
149
150 fn from_str(s: &str) -> Result<Self, Self::Err> {
151 let (amount, beneficiary) =
152 s.split_once('@').ok_or(BeneficiaryParseError::InvalidFormat)?;
153 Ok(Beneficiary::new(Address::from_str(beneficiary)?, Payment::from_str(amount)?))
154 }
155}
156
157#[derive(Copy, Clone, PartialEq, Debug)]
158pub struct TxParams {
159 pub fee: Sats,
160 pub lock_time: Option<LockTime>,
161 pub seq_no: SeqNo,
162 pub change_shift: bool,
163 pub change_keychain: Keychain,
164}
165
166impl TxParams {
167 pub fn with(fee: Sats) -> Self {
168 TxParams {
169 fee,
170 lock_time: None,
171 seq_no: SeqNo::from_consensus_u32(0),
172 change_shift: true,
173 change_keychain: Keychain::INNER,
174 }
175 }
176}
177
178#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
179pub struct PsbtMeta {
180 pub change_vout: Option<Vout>,
181 pub change_terminal: Option<Terminal>,
182}
183
184#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
185pub struct Utxo {
186 pub outpoint: Outpoint,
187 pub value: Sats,
188 pub terminal: Terminal,
189}
190
191impl Utxo {
192 #[inline]
193 pub fn to_prevout(&self) -> Prevout { Prevout::new(self.outpoint, self.value) }
194}
195
196pub trait PsbtConstructor {
197 type Key;
198 type Descr: Descriptor<Self::Key>;
199
200 fn descriptor(&self) -> &Self::Descr;
201 fn utxo(&self, outpoint: Outpoint) -> Option<(Utxo, ScriptPubkey)>;
202 fn network(&self) -> Network;
203 fn next_derivation_index(&mut self, keychain: impl Into<Keychain>, shift: bool) -> NormalIndex;
204
205 fn construct_psbt(
206 &mut self,
207 coins: impl IntoIterator<Item = Outpoint>,
208 beneficiaries: impl IntoIterator<Item = Beneficiary>,
209 params: TxParams,
210 ) -> Result<(Psbt, PsbtMeta), ConstructionError> {
211 let mut psbt = Psbt::create(PsbtVer::V2);
212
213 psbt.fallback_locktime = params.lock_time;
215
216 for spec in self.descriptor().xpubs() {
218 psbt.xpubs.insert(*spec.xpub(), spec.origin().clone());
219 }
220
221 for coin in coins {
223 let (utxo, spk) = self.utxo(coin).ok_or(ConstructionError::UnknownInput(coin))?;
224 if psbt.inputs().any(|inp| inp.previous_outpoint == utxo.outpoint) {
225 continue;
226 }
227 psbt.append_input_expect(
228 utxo.to_prevout(),
229 self.descriptor(),
230 utxo.terminal,
231 spk,
232 params.seq_no,
233 );
234 }
235 if psbt.inputs().count() == 0 {
236 return Err(ConstructionError::NoInputs);
237 }
238
239 let input_value = psbt.input_sum();
241 let mut max = Vec::new();
242 let mut output_value = Sats::ZERO;
243 for beneficiary in beneficiaries {
244 if beneficiary.address.network != self.network().into() {
245 return Err(ConstructionError::NetworkMismatch(beneficiary.address));
246 }
247 let amount = beneficiary.amount.unwrap_or(Sats::ZERO);
248 output_value
249 .checked_add_assign(amount)
250 .ok_or(ConstructionError::Overflow(output_value))?;
251 let out = psbt.append_output_expect(beneficiary.script_pubkey(), amount);
252 if beneficiary.amount.is_max() {
253 max.push(out.index());
254 }
255 }
256 let mut remaining_value = input_value
257 .checked_sub(output_value)
258 .ok_or(ConstructionError::OutputExceedsInputs {
259 input_value,
260 output_value,
261 })?
262 .checked_sub(params.fee)
263 .ok_or(ConstructionError::NoFundsForFee {
264 input_value,
265 output_value,
266 fee: params.fee,
267 })?;
268 if !max.is_empty() {
269 let portion = remaining_value / max.len();
270 for out in psbt.outputs_mut() {
271 if max.contains(&out.index()) {
272 out.amount = portion;
273 }
274 }
275 remaining_value = Sats::ZERO;
276 }
277
278 let (change_vout, change_terminal) =
280 if remaining_value > self.descriptor().class().dust_limit() {
281 let change_index =
282 self.next_derivation_index(params.change_keychain, params.change_shift);
283 let change_terminal = Terminal::new(params.change_keychain, change_index);
284 let change_vout = psbt
285 .append_change_expect(self.descriptor(), change_terminal, remaining_value)
286 .index();
287 (Some(Vout::from_u32(change_vout as u32)), Some(change_terminal))
288 } else {
289 (None, None)
290 };
291
292 Ok((psbt, PsbtMeta {
293 change_vout,
294 change_terminal,
295 }))
296 }
297}