1use crate::crypto;
7use crate::encoding;
8use crate::error::SignerError;
9use std::collections::BTreeMap;
10
/// Magic bytes that open every serialized PSBT: the ASCII string `"psbt"`.
const PSBT_MAGIC: [u8; 4] = *b"psbt";

/// Header byte that immediately follows [`PSBT_MAGIC`].
const PSBT_SEPARATOR: u8 = 0xFF;
16
/// Key-type bytes for entries in the PSBT global map (BIP-174).
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GlobalKey {
    /// `PSBT_GLOBAL_UNSIGNED_TX`
    UnsignedTx = 0x00,
    /// `PSBT_GLOBAL_XPUB`
    Xpub = 0x01,
    /// `PSBT_GLOBAL_VERSION`
    Version = 0xFB,
}
30
/// Key-type bytes for entries in a PSBT per-input map (BIP-174, plus the
/// taproot fields from BIP-371).
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InputKey {
    /// `PSBT_IN_NON_WITNESS_UTXO`
    NonWitnessUtxo = 0x00,
    /// `PSBT_IN_WITNESS_UTXO`
    WitnessUtxo = 0x01,
    /// `PSBT_IN_PARTIAL_SIG`
    PartialSig = 0x02,
    /// `PSBT_IN_SIGHASH_TYPE`
    SighashType = 0x03,
    /// `PSBT_IN_REDEEM_SCRIPT`
    RedeemScript = 0x04,
    /// `PSBT_IN_WITNESS_SCRIPT`
    WitnessScript = 0x05,
    /// `PSBT_IN_BIP32_DERIVATION`
    Bip32Derivation = 0x06,
    /// `PSBT_IN_FINAL_SCRIPTSIG`
    FinalScriptSig = 0x07,
    /// `PSBT_IN_FINAL_SCRIPTWITNESS`
    FinalScriptWitness = 0x08,
    /// `PSBT_IN_TAP_KEY_SIG`
    TapKeySig = 0x13,
    /// `PSBT_IN_TAP_SCRIPT_SIG`
    TapScriptSig = 0x14,
    /// `PSBT_IN_TAP_LEAF_SCRIPT`
    TapLeafScript = 0x15,
    /// `PSBT_IN_TAP_BIP32_DERIVATION`
    TapBip32Derivation = 0x16,
    /// `PSBT_IN_TAP_INTERNAL_KEY`
    TapInternalKey = 0x17,
    /// `PSBT_IN_TAP_MERKLE_ROOT`
    TapMerkleRoot = 0x18,
}
66
/// Key-type bytes for entries in a PSBT per-output map (BIP-174, plus the
/// taproot fields from BIP-371).
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OutputKey {
    /// `PSBT_OUT_REDEEM_SCRIPT`
    RedeemScript = 0x00,
    /// `PSBT_OUT_WITNESS_SCRIPT`
    WitnessScript = 0x01,
    /// `PSBT_OUT_BIP32_DERIVATION`
    Bip32Derivation = 0x02,
    /// `PSBT_OUT_TAP_INTERNAL_KEY`
    TapInternalKey = 0x05,
    /// `PSBT_OUT_TAP_TREE`
    TapTree = 0x06,
    /// `PSBT_OUT_TAP_BIP32_DERIVATION`
    TapBip32Derivation = 0x07,
}
84
/// A single raw key/value record.
///
/// NOTE(review): not used by the code in this file; presumably `key` follows
/// the same layout as the map keys in `Psbt` (key-type byte first) — confirm
/// with callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KeyValuePair {
    /// Raw key bytes.
    pub key: Vec<u8>,
    /// Raw value bytes.
    pub value: Vec<u8>,
}
95
/// An in-memory PSBT held as raw key/value maps.
///
/// Every map is keyed by the full raw key (key-type byte first, followed by
/// any key data) and stores raw value bytes. `BTreeMap` keeps keys sorted,
/// so iteration (and therefore serialization) order is deterministic.
#[derive(Debug, Clone)]
pub struct Psbt {
    /// Global map; the unsigned transaction lives under key `0x00`.
    pub global: BTreeMap<Vec<u8>, Vec<u8>>,
    /// One map per transaction input, in input order.
    pub inputs: Vec<BTreeMap<Vec<u8>, Vec<u8>>>,
    /// One map per transaction output, in output order.
    pub outputs: Vec<BTreeMap<Vec<u8>, Vec<u8>>>,
}
108
109impl Psbt {
110 pub fn new() -> Self {
112 Self {
113 global: BTreeMap::new(),
114 inputs: Vec::new(),
115 outputs: Vec::new(),
116 }
117 }
118
119 pub fn set_unsigned_tx(&mut self, raw_tx: &[u8]) {
121 self.global
122 .insert(vec![GlobalKey::UnsignedTx as u8], raw_tx.to_vec());
123 }
124
125 pub fn unsigned_tx(&self) -> Option<&Vec<u8>> {
127 self.global.get(&vec![GlobalKey::UnsignedTx as u8])
128 }
129
130 pub fn add_input(&mut self) -> usize {
132 let idx = self.inputs.len();
133 self.inputs.push(BTreeMap::new());
134 idx
135 }
136
137 pub fn add_output(&mut self) -> usize {
139 let idx = self.outputs.len();
140 self.outputs.push(BTreeMap::new());
141 idx
142 }
143
144 pub fn set_input_kv(&mut self, input_idx: usize, key: Vec<u8>, value: Vec<u8>) {
146 if let Some(map) = self.inputs.get_mut(input_idx) {
147 map.insert(key, value);
148 }
149 }
150
151 pub fn set_output_kv(&mut self, output_idx: usize, key: Vec<u8>, value: Vec<u8>) {
153 if let Some(map) = self.outputs.get_mut(output_idx) {
154 map.insert(key, value);
155 }
156 }
157
158 pub fn set_witness_utxo(&mut self, input_idx: usize, amount: u64, script_pubkey: &[u8]) {
160 let mut value = Vec::new();
161 value.extend_from_slice(&amount.to_le_bytes());
162 encoding::encode_compact_size(&mut value, script_pubkey.len() as u64);
163 value.extend_from_slice(script_pubkey);
164 self.set_input_kv(input_idx, vec![InputKey::WitnessUtxo as u8], value);
165 }
166
167 pub fn set_tap_internal_key(&mut self, input_idx: usize, x_only_key: &[u8; 32]) {
169 self.set_input_kv(
170 input_idx,
171 vec![InputKey::TapInternalKey as u8],
172 x_only_key.to_vec(),
173 );
174 }
175
176 pub fn set_tap_merkle_root(&mut self, input_idx: usize, merkle_root: &[u8; 32]) {
178 self.set_input_kv(
179 input_idx,
180 vec![InputKey::TapMerkleRoot as u8],
181 merkle_root.to_vec(),
182 );
183 }
184
185 pub fn set_tap_key_sig(&mut self, input_idx: usize, signature: &[u8]) {
187 self.set_input_kv(
188 input_idx,
189 vec![InputKey::TapKeySig as u8],
190 signature.to_vec(),
191 );
192 }
193
194 pub fn sign_segwit_input(
204 &mut self,
205 input_idx: usize,
206 signer: &crate::bitcoin::BitcoinSigner,
207 sighash_type: crate::bitcoin::tapscript::SighashType,
208 ) -> Result<(), SignerError> {
209 use crate::bitcoin::sighash;
210 use crate::bitcoin::transaction::*;
211 use crate::traits::Signer;
212
213 let witness_utxo_key = vec![InputKey::WitnessUtxo as u8];
215 let utxo_data = self
216 .inputs
217 .get(input_idx)
218 .and_then(|m| m.get(&witness_utxo_key))
219 .ok_or_else(|| SignerError::SigningFailed("missing witness UTXO for input".into()))?
220 .clone();
221
222 let (amount, script_pk) = parse_witness_utxo_value(&utxo_data, "witness UTXO")?;
223
224 if script_pk.len() != 22 || script_pk[0] != 0x00 || script_pk[1] != 0x14 {
226 return Err(SignerError::SigningFailed(
227 "input is not P2WPKH (expected OP_0 OP_PUSH20)".into(),
228 ));
229 }
230 let mut pubkey_hash = [0u8; 20];
231 pubkey_hash.copy_from_slice(&script_pk[2..22]);
232
233 let expected_hash = crate::crypto::hash160(&signer.public_key_bytes());
235 if pubkey_hash != expected_hash {
236 return Err(SignerError::SigningFailed(
237 "signer public key does not match the P2WPKH input".into(),
238 ));
239 }
240
241 let tx_bytes = self
243 .unsigned_tx()
244 .ok_or_else(|| SignerError::SigningFailed("missing unsigned tx".into()))?
245 .clone();
246
247 let tx = parse_unsigned_tx(&tx_bytes)?;
249
250 let script_code = sighash::p2wpkh_script_code(&pubkey_hash);
252 let prev_out = sighash::PrevOut {
253 script_code,
254 value: amount,
255 };
256 let sighash_value = sighash::segwit_v0_sighash(&tx, input_idx, &prev_out, sighash_type)?;
257
258 let sig = signer.sign_prehashed(&sighash_value)?;
260 let mut sig_bytes = sig.to_bytes();
261 sig_bytes.push(sighash_type.to_byte());
262
263 let pubkey = signer.public_key_bytes();
265 let mut key = vec![InputKey::PartialSig as u8];
266 key.extend_from_slice(&pubkey);
267 self.set_input_kv(input_idx, key, sig_bytes);
268
269 Ok(())
270 }
271
272 pub fn sign_taproot_input(
277 &mut self,
278 input_idx: usize,
279 signer: &crate::bitcoin::schnorr::SchnorrSigner,
280 sighash_type: crate::bitcoin::tapscript::SighashType,
281 ) -> Result<(), SignerError> {
282 use crate::bitcoin::sighash;
283 use crate::bitcoin::transaction::*;
284 use crate::traits::Signer;
285
286 let mut prevouts = Vec::new();
288 let witness_utxo_key = vec![InputKey::WitnessUtxo as u8];
289 for (i, input_map) in self.inputs.iter().enumerate() {
290 let utxo_data = input_map.get(&witness_utxo_key).ok_or_else(|| {
291 SignerError::SigningFailed(format!("missing witness UTXO for input {i}"))
292 })?;
293 let context = format!("witness UTXO {i}");
294 let (amount, script_pk) = parse_witness_utxo_value(utxo_data, &context)?;
295 prevouts.push(TxOut {
296 value: amount,
297 script_pubkey: script_pk.to_vec(),
298 });
299 }
300
301 let tx_bytes = self
303 .unsigned_tx()
304 .ok_or_else(|| SignerError::SigningFailed("missing unsigned tx".into()))?
305 .clone();
306 let tx = parse_unsigned_tx(&tx_bytes)?;
307
308 let sighash_value =
310 sighash::taproot_key_path_sighash(&tx, input_idx, &prevouts, sighash_type)?;
311
312 let sig = signer.sign(&sighash_value)?;
314 let mut sig_bytes = sig.bytes.to_vec();
315 if sighash_type.to_byte() != 0x00 {
317 sig_bytes.push(sighash_type.to_byte());
318 }
319
320 self.set_tap_key_sig(input_idx, &sig_bytes);
322 Ok(())
323 }
324
325 pub fn serialize(&self) -> Vec<u8> {
329 let mut data = Vec::new();
330
331 data.extend_from_slice(&PSBT_MAGIC);
333 data.push(PSBT_SEPARATOR);
334
335 for (key, value) in &self.global {
337 encoding::encode_compact_size(&mut data, key.len() as u64);
338 data.extend_from_slice(key);
339 encoding::encode_compact_size(&mut data, value.len() as u64);
340 data.extend_from_slice(value);
341 }
342 data.push(0x00); for input in &self.inputs {
346 for (key, value) in input {
347 encoding::encode_compact_size(&mut data, key.len() as u64);
348 data.extend_from_slice(key);
349 encoding::encode_compact_size(&mut data, value.len() as u64);
350 data.extend_from_slice(value);
351 }
352 data.push(0x00); }
354
355 for output in &self.outputs {
357 for (key, value) in output {
358 encoding::encode_compact_size(&mut data, key.len() as u64);
359 data.extend_from_slice(key);
360 encoding::encode_compact_size(&mut data, value.len() as u64);
361 data.extend_from_slice(value);
362 }
363 data.push(0x00); }
365
366 data
367 }
368
369 pub fn deserialize(data: &[u8]) -> Result<Self, SignerError> {
378 if data.len() < 5 {
379 return Err(SignerError::ParseError("PSBT too short".into()));
380 }
381 if data[..4] != PSBT_MAGIC {
382 return Err(SignerError::ParseError("invalid PSBT magic".into()));
383 }
384 if data[4] != PSBT_SEPARATOR {
385 return Err(SignerError::ParseError("missing PSBT separator".into()));
386 }
387
388 let mut offset = 5;
389 let mut psbt = Psbt::new();
390
391 psbt.global = parse_kv_map(data, &mut offset)?;
393
394 let counts = psbt
396 .global
397 .get(&vec![0x00])
398 .and_then(|raw_tx| extract_tx_io_counts(raw_tx));
399
400 let (num_inputs, num_outputs) = counts.ok_or_else(|| {
401 SignerError::ParseError(
402 "PSBT: missing or malformed unsigned transaction (key 0x00)".into(),
403 )
404 })?;
405
406 for i in 0..num_inputs {
408 if offset >= data.len() {
409 return Err(SignerError::ParseError(format!(
410 "PSBT truncated: expected {} inputs, got {}",
411 num_inputs, i
412 )));
413 }
414 psbt.inputs.push(parse_kv_map(data, &mut offset)?);
415 }
416 for i in 0..num_outputs {
417 if offset >= data.len() {
418 return Err(SignerError::ParseError(format!(
419 "PSBT truncated: expected {} outputs, got {}",
420 num_outputs, i
421 )));
422 }
423 psbt.outputs.push(parse_kv_map(data, &mut offset)?);
424 }
425
426 if offset != data.len() {
428 return Err(SignerError::ParseError(format!(
429 "PSBT has {} trailing bytes",
430 data.len() - offset
431 )));
432 }
433
434 Ok(psbt)
435 }
436
437 pub fn psbt_id(&self) -> [u8; 32] {
439 let serialized = self.serialize();
440 crypto::sha256(&serialized)
441 }
442}
443
444impl Default for Psbt {
445 fn default() -> Self {
446 Self::new()
447 }
448}
449
450fn parse_kv_map(
454 data: &[u8],
455 offset: &mut usize,
456) -> Result<BTreeMap<Vec<u8>, Vec<u8>>, SignerError> {
457 let mut map = BTreeMap::new();
458
459 loop {
460 if *offset >= data.len() {
461 return Err(SignerError::ParseError(
462 "PSBT map missing terminator".into(),
463 ));
464 }
465
466 let key_len = encoding::read_compact_size(data, offset)?;
468 if key_len == 0 {
469 return Ok(map);
471 }
472
473 let key_len_usize = usize::try_from(key_len).map_err(|_| {
475 SignerError::ParseError("PSBT key length exceeds platform usize".into())
476 })?;
477 let end = offset
478 .checked_add(key_len_usize)
479 .ok_or_else(|| SignerError::ParseError("PSBT key length overflow".into()))?;
480 if end > data.len() {
481 return Err(SignerError::ParseError("PSBT key truncated".into()));
482 }
483 let key = data[*offset..end].to_vec();
484 *offset = end;
485
486 let val_len = encoding::read_compact_size(data, offset)?;
488
489 let val_len_usize = usize::try_from(val_len).map_err(|_| {
491 SignerError::ParseError("PSBT value length exceeds platform usize".into())
492 })?;
493 let end = offset
494 .checked_add(val_len_usize)
495 .ok_or_else(|| SignerError::ParseError("PSBT value length overflow".into()))?;
496 if end > data.len() {
497 return Err(SignerError::ParseError("PSBT value truncated".into()));
498 }
499 let value = data[*offset..end].to_vec();
500 *offset = end;
501
502 if map.contains_key(&key) {
504 return Err(SignerError::ParseError("PSBT: duplicate key in map".into()));
505 }
506 map.insert(key, value);
507 }
508}
509
510fn extract_tx_io_counts(raw_tx: &[u8]) -> Option<(usize, usize)> {
515 if raw_tx.len() < 10 {
516 return None; }
518 let mut offset = 4;
520 let num_inputs =
522 usize::try_from(encoding::read_compact_size(raw_tx, &mut offset).ok()?).ok()?;
523 for _ in 0..num_inputs {
525 offset = offset.checked_add(36)?;
527 if offset > raw_tx.len() {
528 return None;
529 }
530 let script_len =
532 usize::try_from(encoding::read_compact_size(raw_tx, &mut offset).ok()?).ok()?;
533 let end = offset.checked_add(script_len)?.checked_add(4)?;
534 if end > raw_tx.len() {
535 return None;
536 }
537 offset = end;
538 }
539 let num_outputs =
541 usize::try_from(encoding::read_compact_size(raw_tx, &mut offset).ok()?).ok()?;
542 if num_inputs > 10_000 || num_outputs > 10_000 {
544 return None;
545 }
546 Some((num_inputs, num_outputs))
547}
548
549fn parse_witness_utxo_value<'a>(
550 utxo_data: &'a [u8],
551 context: &str,
552) -> Result<(u64, &'a [u8]), SignerError> {
553 if utxo_data.len() < 9 {
554 return Err(SignerError::SigningFailed(format!("{context} too short")));
555 }
556 let mut amount_bytes = [0u8; 8];
557 amount_bytes.copy_from_slice(&utxo_data[..8]);
558 let amount = u64::from_le_bytes(amount_bytes);
559
560 let mut utxo_off = 8usize;
561 let script_len_u64 = encoding::read_compact_size(utxo_data, &mut utxo_off)?;
562 let script_len = usize::try_from(script_len_u64).map_err(|_| {
563 SignerError::SigningFailed(format!("{context} script length exceeds platform usize"))
564 })?;
565 let script_end = utxo_off
566 .checked_add(script_len)
567 .ok_or_else(|| SignerError::SigningFailed(format!("{context} script length overflow")))?;
568 if script_end > utxo_data.len() {
569 return Err(SignerError::SigningFailed(format!(
570 "{context} script truncated"
571 )));
572 }
573 if script_end != utxo_data.len() {
574 return Err(SignerError::SigningFailed(format!(
575 "{context} has trailing bytes"
576 )));
577 }
578 Ok((amount, &utxo_data[utxo_off..script_end]))
579}
580
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Legacy-serialized transaction with one input (empty scriptSig) and one
    /// 50 000-sat output with an empty scriptPubKey.
    fn one_in_one_out_tx() -> Vec<u8> {
        let mut raw_tx = Vec::new();
        raw_tx.extend_from_slice(&1i32.to_le_bytes()); // version
        raw_tx.push(0x01); // input count
        raw_tx.extend_from_slice(&[0xAA; 32]); // previous txid
        raw_tx.extend_from_slice(&0u32.to_le_bytes()); // previous output index
        raw_tx.push(0x00); // empty scriptSig
        raw_tx.extend_from_slice(&0xFFFF_FFFFu32.to_le_bytes()); // sequence
        raw_tx.push(0x01); // output count
        raw_tx.extend_from_slice(&50_000u64.to_le_bytes()); // amount
        raw_tx.push(0x00); // empty scriptPubKey
        raw_tx.extend_from_slice(&0u32.to_le_bytes()); // locktime
        raw_tx
    }

    /// P2WPKH scriptPubKey: OP_0, push of 20 bytes, 20 bytes of 0xAA.
    fn p2wpkh_script() -> Vec<u8> {
        let mut script = vec![0x00, 0x14];
        script.extend_from_slice(&[0xAA; 20]);
        script
    }

    #[test]
    fn test_psbt_new() {
        let psbt = Psbt::new();
        assert!(psbt.global.is_empty());
        assert!(psbt.inputs.is_empty());
        assert!(psbt.outputs.is_empty());
    }

    #[test]
    fn test_psbt_set_unsigned_tx() {
        let mut psbt = Psbt::new();
        let fake_tx = vec![0x01, 0x02, 0x03, 0x04];
        psbt.set_unsigned_tx(&fake_tx);
        assert_eq!(psbt.unsigned_tx(), Some(&fake_tx));
    }

    #[test]
    fn test_psbt_add_input_output() {
        let mut psbt = Psbt::new();
        let idx_in = psbt.add_input();
        assert_eq!(idx_in, 0);
        let idx_out = psbt.add_output();
        assert_eq!(idx_out, 0);
        assert_eq!(psbt.inputs.len(), 1);
        assert_eq!(psbt.outputs.len(), 1);
    }

    #[test]
    fn test_psbt_serialize_magic() {
        let psbt = Psbt::new();
        let data = psbt.serialize();
        assert_eq!(&data[..4], &PSBT_MAGIC);
        assert_eq!(data[4], PSBT_SEPARATOR);
    }

    #[test]
    fn test_psbt_serialize_deserialize_roundtrip() {
        let mut psbt = Psbt::new();
        psbt.set_unsigned_tx(&one_in_one_out_tx());
        let idx = psbt.add_input();
        psbt.add_output();
        psbt.set_witness_utxo(idx, 50000, &p2wpkh_script());

        let serialized = psbt.serialize();
        let parsed = Psbt::deserialize(&serialized).expect("valid PSBT");

        // Every map must survive the roundtrip, not just the global one.
        assert_eq!(parsed.global, psbt.global);
        assert_eq!(parsed.inputs, psbt.inputs);
        assert_eq!(parsed.outputs, psbt.outputs);
        assert_eq!(parsed.unsigned_tx(), psbt.unsigned_tx());
        // Re-serializing the parsed PSBT reproduces the exact bytes.
        assert_eq!(parsed.serialize(), serialized);
    }

    #[test]
    fn test_psbt_deserialize_invalid() {
        assert!(Psbt::deserialize(&[]).is_err());
        assert!(Psbt::deserialize(&[0x00, 0x01, 0x02, 0x03, 0xFF]).is_err());
        // Correct magic, wrong separator.
        assert!(Psbt::deserialize(&[0x70, 0x73, 0x62, 0x74, 0x00]).is_err());
    }

    #[test]
    fn test_psbt_set_taproot_fields() {
        let mut psbt = Psbt::new();
        let idx = psbt.add_input();
        let key = [0xAA; 32];
        let root = [0xBB; 32];
        let sig = [0xCC; 64];

        psbt.set_tap_internal_key(idx, &key);
        psbt.set_tap_merkle_root(idx, &root);
        psbt.set_tap_key_sig(idx, &sig);

        let input = &psbt.inputs[0];
        assert_eq!(
            input.get(&vec![InputKey::TapInternalKey as u8]),
            Some(&key.to_vec())
        );
        assert_eq!(
            input.get(&vec![InputKey::TapMerkleRoot as u8]),
            Some(&root.to_vec())
        );
        assert_eq!(
            input.get(&vec![InputKey::TapKeySig as u8]),
            Some(&sig.to_vec())
        );
    }

    #[test]
    fn test_psbt_psbt_id_deterministic() {
        let mut psbt = Psbt::new();
        psbt.set_unsigned_tx(&[0x01, 0x00]);
        let id1 = psbt.psbt_id();
        let id2 = psbt.psbt_id();
        assert_eq!(id1, id2);
    }

    #[test]
    fn test_psbt_id_changes_with_content() {
        let mut a = Psbt::new();
        a.set_unsigned_tx(&[0x01, 0x00]);
        let mut b = Psbt::new();
        b.set_unsigned_tx(&[0x02, 0x00]);
        assert_ne!(a.psbt_id(), b.psbt_id());
    }

    #[test]
    fn test_psbt_empty_roundtrip() {
        // An empty PSBT has no unsigned transaction, so it cannot be parsed back.
        let psbt = Psbt::new();
        let data = psbt.serialize();
        assert!(Psbt::deserialize(&data).is_err());
    }

    #[test]
    fn test_psbt_multiple_inputs() {
        let mut psbt = Psbt::new();
        psbt.add_input();
        psbt.add_input();
        psbt.add_input();
        assert_eq!(psbt.inputs.len(), 3);
    }

    #[test]
    fn test_compact_size_roundtrip() {
        for val in [0u64, 1, 252, 253, 0xFFFF, 0x10000, 0xFFFFFFFF, 0x100000000] {
            let mut buf = Vec::new();
            encoding::encode_compact_size(&mut buf, val);
            let mut offset = 0;
            let parsed = encoding::read_compact_size(&buf, &mut offset).expect("valid");
            assert_eq!(parsed, val, "failed for value {val}");
        }
    }

    #[test]
    fn test_parse_kv_map_rejects_huge_key_length() {
        let mut data = vec![0xFF];
        data.extend_from_slice(&u64::MAX.to_le_bytes());
        let mut offset = 0;
        assert!(parse_kv_map(&data, &mut offset).is_err());
    }

    #[test]
    fn test_parse_kv_map_rejects_missing_terminator() {
        // One complete record (key [0x01], value [0x02]) and no 0x00 terminator.
        let data = vec![0x01, 0x01, 0x01, 0x02];
        let mut offset = 0;
        assert!(parse_kv_map(&data, &mut offset).is_err());
    }

    #[test]
    fn test_parse_kv_map_rejects_duplicate_key() {
        let data = [0x01, 0xAA, 0x00, 0x01, 0xAA, 0x00, 0x00];
        let mut offset = 0;
        assert!(parse_kv_map(&data, &mut offset).is_err());
    }

    #[test]
    fn test_extract_tx_io_counts_rejects_oversized_script_len() {
        let mut raw_tx = Vec::new();
        raw_tx.extend_from_slice(&1u32.to_le_bytes()); // version
        raw_tx.push(0x01); // input count
        raw_tx.extend_from_slice(&[0u8; 32]); // previous txid
        raw_tx.extend_from_slice(&0u32.to_le_bytes()); // previous output index
        raw_tx.push(0xFF); // 8-byte compact-size prefix
        raw_tx.extend_from_slice(&u64::MAX.to_le_bytes()); // absurd scriptSig length
        assert_eq!(extract_tx_io_counts(&raw_tx), None);
    }

    #[test]
    fn test_parse_witness_utxo_rejects_trailing_bytes() {
        let mut utxo = Vec::new();
        utxo.extend_from_slice(&50_000u64.to_le_bytes());
        encoding::encode_compact_size(&mut utxo, 22);
        utxo.extend_from_slice(&[0x00, 0x14]);
        utxo.extend_from_slice(&[0xAA; 20]);
        utxo.push(0x99); // one byte past the declared script
        assert!(parse_witness_utxo_value(&utxo, "witness UTXO").is_err());
    }

    #[test]
    fn test_parse_witness_utxo_accepts_exact_encoding() {
        let script = p2wpkh_script();
        let mut utxo = Vec::new();
        utxo.extend_from_slice(&50_000u64.to_le_bytes());
        encoding::encode_compact_size(&mut utxo, script.len() as u64);
        utxo.extend_from_slice(&script);
        let (amount, parsed) = parse_witness_utxo_value(&utxo, "witness UTXO").expect("valid");
        assert_eq!(amount, 50_000);
        assert_eq!(parsed, &script[..]);
    }

    #[test]
    fn test_psbt_witness_utxo() {
        let mut psbt = Psbt::new();
        let idx = psbt.add_input();
        let script_pk = vec![0x00, 0x14, 0xAA, 0xBB, 0xCC];
        psbt.set_witness_utxo(idx, 100000, &script_pk);

        let value = psbt.inputs[0]
            .get(&[InputKey::WitnessUtxo as u8][..])
            .expect("exists");
        assert_eq!(&value[..8], &100000u64.to_le_bytes());
        // Single-byte compact-size length, then the script itself.
        assert_eq!(value[8], script_pk.len() as u8);
        assert_eq!(&value[9..], &script_pk[..]);

        let (amount, script) =
            parse_witness_utxo_value(value, "witness UTXO").expect("well-formed");
        assert_eq!(amount, 100000);
        assert_eq!(script, &script_pk[..]);
    }
}