1use crate::crypto;
7use crate::encoding;
8use crate::error::SignerError;
9use std::collections::BTreeMap;
10
/// Leading bytes of every serialized PSBT: the ASCII string `psbt`.
const PSBT_MAGIC: [u8; 4] = *b"psbt";

/// Header byte written immediately after `PSBT_MAGIC`.
const PSBT_SEPARATOR: u8 = 0xFF;
16
/// Key-type bytes used in the PSBT global map (values as assigned by BIP 174).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum GlobalKey {
    /// The unsigned transaction being signed.
    UnsignedTx = 0x00,
    /// An extended public key.
    Xpub = 0x01,
    /// PSBT version number.
    Version = 0xFB,
}
30
/// Key-type bytes used in per-input PSBT maps (BIP 174, plus the BIP 371 taproot fields).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum InputKey {
    /// Full previous transaction.
    NonWitnessUtxo = 0x00,
    /// Spent output: 8-byte LE amount followed by the compact-size-prefixed scriptPubKey.
    WitnessUtxo = 0x01,
    /// Partial signature; the key suffix is the signer's public key.
    PartialSig = 0x02,
    /// Requested sighash type.
    SighashType = 0x03,
    /// Redeem script.
    RedeemScript = 0x04,
    /// Witness script.
    WitnessScript = 0x05,
    /// BIP 32 derivation path for a key.
    Bip32Derivation = 0x06,
    /// Finalized scriptSig.
    FinalScriptSig = 0x07,
    /// Finalized script witness.
    FinalScriptWitness = 0x08,
    /// Taproot key-path signature.
    TapKeySig = 0x13,
    /// Taproot script-path signature.
    TapScriptSig = 0x14,
    /// Taproot leaf script.
    TapLeafScript = 0x15,
    /// Taproot BIP 32 derivation.
    TapBip32Derivation = 0x16,
    /// Taproot x-only internal key.
    TapInternalKey = 0x17,
    /// Taproot script-tree merkle root.
    TapMerkleRoot = 0x18,
}
66
/// Key-type bytes used in per-output PSBT maps (BIP 174, plus the BIP 371 taproot fields).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum OutputKey {
    /// Redeem script.
    RedeemScript = 0x00,
    /// Witness script.
    WitnessScript = 0x01,
    /// BIP 32 derivation path for a key.
    Bip32Derivation = 0x02,
    /// Taproot x-only internal key.
    TapInternalKey = 0x05,
    /// Taproot script tree.
    TapTree = 0x06,
    /// Taproot BIP 32 derivation.
    TapBip32Derivation = 0x07,
}
84
/// One raw PSBT map entry: the full key bytes (type byte plus any key data) and its value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KeyValuePair {
    /// Key bytes, starting with the key-type byte.
    pub key: Vec<u8>,
    /// Value bytes.
    pub value: Vec<u8>,
}
95
/// In-memory PSBT: the global map plus one key/value map per input and per output.
///
/// Keys are raw bytes (key-type byte followed by any key data); maps are ordered so
/// serialization is deterministic.
#[derive(Debug, Clone)]
pub struct Psbt {
    /// Global key/value map.
    pub global: BTreeMap<Vec<u8>, Vec<u8>>,
    /// Per-input key/value maps, in input order.
    pub inputs: Vec<BTreeMap<Vec<u8>, Vec<u8>>>,
    /// Per-output key/value maps, in output order.
    pub outputs: Vec<BTreeMap<Vec<u8>, Vec<u8>>>,
}
108
109impl Psbt {
110 pub fn new() -> Self {
112 Self {
113 global: BTreeMap::new(),
114 inputs: Vec::new(),
115 outputs: Vec::new(),
116 }
117 }
118
119 pub fn set_unsigned_tx(&mut self, raw_tx: &[u8]) {
121 self.global
122 .insert(vec![GlobalKey::UnsignedTx as u8], raw_tx.to_vec());
123 }
124
125 pub fn unsigned_tx(&self) -> Option<&Vec<u8>> {
127 self.global.get(&vec![GlobalKey::UnsignedTx as u8])
128 }
129
130 pub fn add_input(&mut self) -> usize {
132 let idx = self.inputs.len();
133 self.inputs.push(BTreeMap::new());
134 idx
135 }
136
137 pub fn add_output(&mut self) -> usize {
139 let idx = self.outputs.len();
140 self.outputs.push(BTreeMap::new());
141 idx
142 }
143
144 pub fn set_input_kv(&mut self, input_idx: usize, key: Vec<u8>, value: Vec<u8>) {
146 if let Some(map) = self.inputs.get_mut(input_idx) {
147 map.insert(key, value);
148 }
149 }
150
151 pub fn set_output_kv(&mut self, output_idx: usize, key: Vec<u8>, value: Vec<u8>) {
153 if let Some(map) = self.outputs.get_mut(output_idx) {
154 map.insert(key, value);
155 }
156 }
157
158 pub fn set_witness_utxo(&mut self, input_idx: usize, amount: u64, script_pubkey: &[u8]) {
160 let mut value = Vec::new();
161 value.extend_from_slice(&amount.to_le_bytes());
162 encoding::encode_compact_size(&mut value, script_pubkey.len() as u64);
163 value.extend_from_slice(script_pubkey);
164 self.set_input_kv(input_idx, vec![InputKey::WitnessUtxo as u8], value);
165 }
166
167 pub fn set_tap_internal_key(&mut self, input_idx: usize, x_only_key: &[u8; 32]) {
169 self.set_input_kv(
170 input_idx,
171 vec![InputKey::TapInternalKey as u8],
172 x_only_key.to_vec(),
173 );
174 }
175
176 pub fn set_tap_merkle_root(&mut self, input_idx: usize, merkle_root: &[u8; 32]) {
178 self.set_input_kv(
179 input_idx,
180 vec![InputKey::TapMerkleRoot as u8],
181 merkle_root.to_vec(),
182 );
183 }
184
185 pub fn set_tap_key_sig(&mut self, input_idx: usize, signature: &[u8]) {
187 self.set_input_kv(
188 input_idx,
189 vec![InputKey::TapKeySig as u8],
190 signature.to_vec(),
191 );
192 }
193
194 pub fn sign_segwit_input(
204 &mut self,
205 input_idx: usize,
206 signer: &crate::bitcoin::BitcoinSigner,
207 sighash_type: crate::bitcoin::tapscript::SighashType,
208 ) -> Result<(), SignerError> {
209 use crate::bitcoin::sighash;
210 use crate::bitcoin::transaction::*;
211 use crate::traits::Signer;
212
213 let witness_utxo_key = vec![InputKey::WitnessUtxo as u8];
215 let utxo_data = self
216 .inputs
217 .get(input_idx)
218 .and_then(|m| m.get(&witness_utxo_key))
219 .ok_or_else(|| SignerError::SigningFailed("missing witness UTXO for input".into()))?
220 .clone();
221
222 if utxo_data.len() < 9 {
224 return Err(SignerError::SigningFailed("witness UTXO too short".into()));
225 }
226 let mut amount_bytes = [0u8; 8];
227 amount_bytes.copy_from_slice(&utxo_data[..8]);
228 let amount = u64::from_le_bytes(amount_bytes);
229
230 let mut utxo_off = 8usize;
232 let script_len = encoding::read_compact_size(&utxo_data, &mut utxo_off)? as usize;
233 let script_pk = &utxo_data[utxo_off..utxo_off + script_len];
234
235 if script_pk.len() != 22 || script_pk[0] != 0x00 || script_pk[1] != 0x14 {
237 return Err(SignerError::SigningFailed(
238 "input is not P2WPKH (expected OP_0 OP_PUSH20)".into(),
239 ));
240 }
241 let mut pubkey_hash = [0u8; 20];
242 pubkey_hash.copy_from_slice(&script_pk[2..22]);
243
244 let expected_hash = crate::crypto::hash160(&signer.public_key_bytes());
246 if pubkey_hash != expected_hash {
247 return Err(SignerError::SigningFailed(
248 "signer public key does not match the P2WPKH input".into(),
249 ));
250 }
251
252 let tx_bytes = self
254 .unsigned_tx()
255 .ok_or_else(|| SignerError::SigningFailed("missing unsigned tx".into()))?
256 .clone();
257
258 let tx = parse_unsigned_tx(&tx_bytes)?;
260
261 let script_code = sighash::p2wpkh_script_code(&pubkey_hash);
263 let prev_out = sighash::PrevOut {
264 script_code,
265 value: amount,
266 };
267 let sighash_value = sighash::segwit_v0_sighash(&tx, input_idx, &prev_out, sighash_type)?;
268
269 let sig = signer.sign_prehashed(&sighash_value)?;
271 let mut sig_bytes = sig.to_bytes();
272 sig_bytes.push(sighash_type.to_byte());
273
274 let pubkey = signer.public_key_bytes();
276 let mut key = vec![InputKey::PartialSig as u8];
277 key.extend_from_slice(&pubkey);
278 self.set_input_kv(input_idx, key, sig_bytes);
279
280 Ok(())
281 }
282
283 pub fn sign_taproot_input(
288 &mut self,
289 input_idx: usize,
290 signer: &crate::bitcoin::schnorr::SchnorrSigner,
291 sighash_type: crate::bitcoin::tapscript::SighashType,
292 ) -> Result<(), SignerError> {
293 use crate::bitcoin::sighash;
294 use crate::bitcoin::transaction::*;
295 use crate::traits::Signer;
296
297 let mut prevouts = Vec::new();
299 let witness_utxo_key = vec![InputKey::WitnessUtxo as u8];
300 for (i, input_map) in self.inputs.iter().enumerate() {
301 let utxo_data = input_map.get(&witness_utxo_key).ok_or_else(|| {
302 SignerError::SigningFailed(format!("missing witness UTXO for input {i}"))
303 })?;
304 if utxo_data.len() < 9 {
305 return Err(SignerError::SigningFailed(format!(
306 "witness UTXO {i} too short"
307 )));
308 }
309 let mut amount_bytes = [0u8; 8];
310 amount_bytes.copy_from_slice(&utxo_data[..8]);
311 let amount = u64::from_le_bytes(amount_bytes);
312 let mut utxo_off = 8usize;
313 let script_len = encoding::read_compact_size(utxo_data, &mut utxo_off)? as usize;
314 let script_pk = utxo_data[utxo_off..utxo_off + script_len].to_vec();
315 prevouts.push(TxOut {
316 value: amount,
317 script_pubkey: script_pk,
318 });
319 }
320
321 let tx_bytes = self
323 .unsigned_tx()
324 .ok_or_else(|| SignerError::SigningFailed("missing unsigned tx".into()))?
325 .clone();
326 let tx = parse_unsigned_tx(&tx_bytes)?;
327
328 let sighash_value =
330 sighash::taproot_key_path_sighash(&tx, input_idx, &prevouts, sighash_type)?;
331
332 let sig = signer.sign(&sighash_value)?;
334 let mut sig_bytes = sig.bytes.to_vec();
335 if sighash_type.to_byte() != 0x00 {
337 sig_bytes.push(sighash_type.to_byte());
338 }
339
340 self.set_tap_key_sig(input_idx, &sig_bytes);
342 Ok(())
343 }
344
345 pub fn serialize(&self) -> Vec<u8> {
349 let mut data = Vec::new();
350
351 data.extend_from_slice(&PSBT_MAGIC);
353 data.push(PSBT_SEPARATOR);
354
355 for (key, value) in &self.global {
357 encoding::encode_compact_size(&mut data, key.len() as u64);
358 data.extend_from_slice(key);
359 encoding::encode_compact_size(&mut data, value.len() as u64);
360 data.extend_from_slice(value);
361 }
362 data.push(0x00); for input in &self.inputs {
366 for (key, value) in input {
367 encoding::encode_compact_size(&mut data, key.len() as u64);
368 data.extend_from_slice(key);
369 encoding::encode_compact_size(&mut data, value.len() as u64);
370 data.extend_from_slice(value);
371 }
372 data.push(0x00); }
374
375 for output in &self.outputs {
377 for (key, value) in output {
378 encoding::encode_compact_size(&mut data, key.len() as u64);
379 data.extend_from_slice(key);
380 encoding::encode_compact_size(&mut data, value.len() as u64);
381 data.extend_from_slice(value);
382 }
383 data.push(0x00); }
385
386 data
387 }
388
389 pub fn deserialize(data: &[u8]) -> Result<Self, SignerError> {
396 if data.len() < 5 {
397 return Err(SignerError::ParseError("PSBT too short".into()));
398 }
399 if data[..4] != PSBT_MAGIC {
400 return Err(SignerError::ParseError("invalid PSBT magic".into()));
401 }
402 if data[4] != PSBT_SEPARATOR {
403 return Err(SignerError::ParseError("missing PSBT separator".into()));
404 }
405
406 let mut offset = 5;
407 let mut psbt = Psbt::new();
408
409 psbt.global = parse_kv_map(data, &mut offset)?;
411
412 let counts = psbt
414 .global
415 .get(&vec![0x00])
416 .and_then(|raw_tx| extract_tx_io_counts(raw_tx));
417
418 if let Some((num_inputs, num_outputs)) = counts {
419 for _ in 0..num_inputs {
421 if offset >= data.len() {
422 break;
423 }
424 psbt.inputs.push(parse_kv_map(data, &mut offset)?);
425 }
426 for _ in 0..num_outputs {
427 if offset >= data.len() {
428 break;
429 }
430 psbt.outputs.push(parse_kv_map(data, &mut offset)?);
431 }
432 } else {
433 while offset < data.len() {
435 let map = parse_kv_map(data, &mut offset)?;
436 if !map.is_empty() {
437 let has_input_keys = map.keys().any(|k| {
438 matches!(
439 k.first(),
440 Some(&0x00)
441 | Some(&0x01)
442 | Some(&0x02)
443 | Some(&0x03)
444 | Some(&0x06)
445 | Some(&0x07)
446 | Some(&0x08)
447 | Some(&0x13)
448 | Some(&0x14)
449 | Some(&0x17)
450 )
451 });
452 if has_input_keys {
453 psbt.inputs.push(map);
454 } else {
455 psbt.outputs.push(map);
456 }
457 }
458 }
459 }
460
461 Ok(psbt)
462 }
463
464 pub fn psbt_id(&self) -> [u8; 32] {
466 let serialized = self.serialize();
467 crypto::sha256(&serialized)
468 }
469}
470
471impl Default for Psbt {
472 fn default() -> Self {
473 Self::new()
474 }
475}
476
477fn parse_kv_map(
481 data: &[u8],
482 offset: &mut usize,
483) -> Result<BTreeMap<Vec<u8>, Vec<u8>>, SignerError> {
484 let mut map = BTreeMap::new();
485
486 loop {
487 if *offset >= data.len() {
488 return Ok(map);
489 }
490
491 let key_len = encoding::read_compact_size(data, offset)?;
493 if key_len == 0 {
494 return Ok(map);
496 }
497
498 let end = *offset + key_len as usize;
500 if end > data.len() {
501 return Err(SignerError::ParseError("PSBT key truncated".into()));
502 }
503 let key = data[*offset..end].to_vec();
504 *offset = end;
505
506 let val_len = encoding::read_compact_size(data, offset)?;
508
509 let end = *offset + val_len as usize;
511 if end > data.len() {
512 return Err(SignerError::ParseError("PSBT value truncated".into()));
513 }
514 let value = data[*offset..end].to_vec();
515 *offset = end;
516
517 map.insert(key, value);
518 }
519}
520
521fn extract_tx_io_counts(raw_tx: &[u8]) -> Option<(usize, usize)> {
526 if raw_tx.len() < 10 {
527 return None; }
529 let mut offset = 4;
531 let num_inputs = encoding::read_compact_size(raw_tx, &mut offset).ok()? as usize;
533 for _ in 0..num_inputs {
535 if offset + 36 > raw_tx.len() {
537 return None;
538 }
539 offset += 36;
540 let script_len = encoding::read_compact_size(raw_tx, &mut offset).ok()? as usize;
542 if offset + script_len + 4 > raw_tx.len() {
543 return None;
544 }
545 offset += script_len;
546 offset += 4;
548 }
549 let num_outputs = encoding::read_compact_size(raw_tx, &mut offset).ok()? as usize;
551 if num_inputs > 10_000 || num_outputs > 10_000 {
553 return None;
554 }
555 Some((num_inputs, num_outputs))
556}
557
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Minimal non-witness transaction: 1 input (empty scriptSig), 1 output (empty script).
    fn one_in_one_out_tx() -> Vec<u8> {
        let mut tx = Vec::new();
        tx.extend_from_slice(&2u32.to_le_bytes()); // version
        tx.push(0x01); // input count
        tx.extend_from_slice(&[0u8; 32]); // previous txid
        tx.extend_from_slice(&0u32.to_le_bytes()); // previous vout
        tx.push(0x00); // empty scriptSig
        tx.extend_from_slice(&0xFFFF_FFFFu32.to_le_bytes()); // sequence
        tx.push(0x01); // output count
        tx.extend_from_slice(&1_000u64.to_le_bytes()); // value
        tx.push(0x00); // empty scriptPubKey
        tx.extend_from_slice(&0u32.to_le_bytes()); // locktime
        tx
    }

    #[test]
    fn test_psbt_new() {
        let psbt = Psbt::new();
        assert!(psbt.global.is_empty());
        assert!(psbt.inputs.is_empty());
        assert!(psbt.outputs.is_empty());
    }

    #[test]
    fn test_psbt_set_unsigned_tx() {
        let mut psbt = Psbt::new();
        let fake_tx = vec![0x01, 0x02, 0x03, 0x04];
        psbt.set_unsigned_tx(&fake_tx);
        assert_eq!(psbt.unsigned_tx(), Some(&fake_tx));
    }

    #[test]
    fn test_psbt_add_input_output() {
        let mut psbt = Psbt::new();
        let idx_in = psbt.add_input();
        assert_eq!(idx_in, 0);
        let idx_out = psbt.add_output();
        assert_eq!(idx_out, 0);
        assert_eq!(psbt.inputs.len(), 1);
        assert_eq!(psbt.outputs.len(), 1);
    }

    #[test]
    fn test_psbt_serialize_magic() {
        let psbt = Psbt::new();
        let data = psbt.serialize();
        assert_eq!(&data[..4], &PSBT_MAGIC);
        assert_eq!(data[4], PSBT_SEPARATOR);
    }

    #[test]
    fn test_psbt_serialize_deserialize_roundtrip() {
        let mut psbt = Psbt::new();
        psbt.set_unsigned_tx(&[0x01, 0x00, 0x00, 0x00]);
        let idx = psbt.add_input();
        let script_pk = [
            0x00u8, 0x14, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
            0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA,
        ];
        psbt.set_witness_utxo(idx, 50000, &script_pk);

        let serialized = psbt.serialize();
        let parsed = Psbt::deserialize(&serialized).expect("valid PSBT");

        assert_eq!(parsed.global.len(), psbt.global.len());
        assert_eq!(parsed.unsigned_tx(), psbt.unsigned_tx());
        // The input map must survive the roundtrip too, not just the global map.
        assert_eq!(parsed.inputs, psbt.inputs);
        assert!(parsed.outputs.is_empty());
    }

    #[test]
    fn test_psbt_roundtrip_uses_tx_counts() {
        let mut psbt = Psbt::new();
        psbt.set_unsigned_tx(&one_in_one_out_tx());
        let in_idx = psbt.add_input();
        psbt.set_witness_utxo(in_idx, 1_000, &[0x51]);
        let out_idx = psbt.add_output();
        // Key type 0x00 would be misclassified by the heuristic fallback; with a
        // parseable unsigned tx it must land in `outputs`.
        psbt.set_output_kv(out_idx, vec![OutputKey::RedeemScript as u8], vec![0x51]);

        let parsed = Psbt::deserialize(&psbt.serialize()).expect("valid PSBT");
        assert_eq!(parsed.inputs, psbt.inputs);
        assert_eq!(parsed.outputs, psbt.outputs);
    }

    #[test]
    fn test_psbt_deserialize_invalid() {
        assert!(Psbt::deserialize(&[]).is_err());
        assert!(Psbt::deserialize(&[0x00, 0x01, 0x02, 0x03, 0xFF]).is_err());
        assert!(Psbt::deserialize(&[0x70, 0x73, 0x62, 0x74, 0x00]).is_err());
    }

    #[test]
    fn test_psbt_deserialize_truncated_key_and_value() {
        let mut truncated_key = PSBT_MAGIC.to_vec();
        truncated_key.push(PSBT_SEPARATOR);
        truncated_key.extend_from_slice(&[0x05, 0x00]);
        assert!(Psbt::deserialize(&truncated_key).is_err());

        let mut truncated_value = PSBT_MAGIC.to_vec();
        truncated_value.push(PSBT_SEPARATOR);
        truncated_value.extend_from_slice(&[0x01, 0x00, 0x05, 0xAA]);
        assert!(Psbt::deserialize(&truncated_value).is_err());
    }

    #[test]
    fn test_psbt_set_taproot_fields() {
        let mut psbt = Psbt::new();
        let idx = psbt.add_input();
        let key = [0xAA; 32];
        let root = [0xBB; 32];
        let sig = [0xCC; 64];

        psbt.set_tap_internal_key(idx, &key);
        psbt.set_tap_merkle_root(idx, &root);
        psbt.set_tap_key_sig(idx, &sig);

        let input = &psbt.inputs[0];
        assert_eq!(
            input.get(&vec![InputKey::TapInternalKey as u8]),
            Some(&key.to_vec())
        );
        assert_eq!(
            input.get(&vec![InputKey::TapMerkleRoot as u8]),
            Some(&root.to_vec())
        );
        assert_eq!(
            input.get(&vec![InputKey::TapKeySig as u8]),
            Some(&sig.to_vec())
        );
    }

    #[test]
    fn test_psbt_psbt_id_deterministic() {
        let mut psbt = Psbt::new();
        psbt.set_unsigned_tx(&[0x01, 0x00]);
        let id1 = psbt.psbt_id();
        let id2 = psbt.psbt_id();
        assert_eq!(id1, id2);
    }

    #[test]
    fn test_psbt_empty_roundtrip() {
        let psbt = Psbt::new();
        let data = psbt.serialize();
        let parsed = Psbt::deserialize(&data).expect("valid");
        assert!(parsed.global.is_empty());
    }

    #[test]
    fn test_psbt_multiple_inputs() {
        let mut psbt = Psbt::new();
        psbt.add_input();
        psbt.add_input();
        psbt.add_input();
        assert_eq!(psbt.inputs.len(), 3);
    }

    #[test]
    fn test_compact_size_roundtrip() {
        for val in [0u64, 1, 252, 253, 0xFFFF, 0x10000, 0xFFFFFFFF, 0x100000000] {
            let mut buf = Vec::new();
            encoding::encode_compact_size(&mut buf, val);
            let mut offset = 0;
            let parsed = encoding::read_compact_size(&buf, &mut offset).expect("valid");
            assert_eq!(parsed, val, "failed for value {val}");
        }
    }

    #[test]
    fn test_psbt_witness_utxo() {
        let mut psbt = Psbt::new();
        let idx = psbt.add_input();
        let script_pk = vec![0x00, 0x14, 0xAA, 0xBB, 0xCC];
        psbt.set_witness_utxo(idx, 100000, &script_pk);

        let input = &psbt.inputs[0];
        let value = input
            .get(&vec![InputKey::WitnessUtxo as u8])
            .expect("exists");
        assert_eq!(&value[..8], &100000u64.to_le_bytes());
        // Compact-size script length, then the script itself.
        assert_eq!(value[8], script_pk.len() as u8);
        assert_eq!(&value[9..], &script_pk[..]);
    }
}