use crate::address::MultisigWallet;
use crate::error::{MultisigError, Result};
use crate::signer::PartialSignature;
use rustywallet_keys::prelude::PrivateKey;
use secp256k1::{Secp256k1, Message, SecretKey};
/// A single partial signature attached to one PSBT input.
///
/// `signature` holds the signature bytes *without* a trailing sighash flag;
/// [`MultisigPsbtBuilder`] appends the flag when it finalizes an input.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PsbtPartialSig {
    /// Compressed (33-byte) public key of the signer.
    pub pubkey: [u8; 33],
    /// DER-encoded ECDSA signature (no sighash byte).
    pub signature: Vec<u8>,
}
/// Collects per-input multisig signatures and UTXO data, then finalizes each
/// input into a witness stack or a `scriptSig`.
///
/// All per-input vectors are created with the same length by
/// [`MultisigPsbtBuilder::new`]; index `i` in each refers to input `i`.
#[derive(Clone, Debug)]
pub struct MultisigPsbtBuilder {
    /// The multisig wallet whose keys may sign and whose redeem script is used.
    pub wallet: MultisigWallet,
    /// Signatures gathered so far, one list per input.
    pub input_signatures: Vec<Vec<PsbtPartialSig>>,
    /// Per-input witness UTXO as `(value, script_pubkey)`; `None` until set.
    /// NOTE(review): `value` is presumably in satoshis — confirm with callers.
    pub witness_utxos: Vec<Option<(u64, Vec<u8>)>>,
    /// Per-input previous transaction bytes (presumably serialized); `None` until set.
    pub non_witness_utxos: Vec<Option<Vec<u8>>>,
}
impl MultisigPsbtBuilder {
pub fn new(wallet: MultisigWallet, input_count: usize) -> Self {
Self {
wallet,
input_signatures: vec![Vec::new(); input_count],
witness_utxos: vec![None; input_count],
non_witness_utxos: vec![None; input_count],
}
}
pub fn set_witness_utxo(&mut self, input_index: usize, value: u64, script_pubkey: Vec<u8>) -> Result<()> {
if input_index >= self.witness_utxos.len() {
return Err(MultisigError::InvalidSignature {
index: input_index,
reason: "Input index out of bounds".to_string(),
});
}
self.witness_utxos[input_index] = Some((value, script_pubkey));
Ok(())
}
pub fn set_non_witness_utxo(&mut self, input_index: usize, prev_tx: Vec<u8>) -> Result<()> {
if input_index >= self.non_witness_utxos.len() {
return Err(MultisigError::InvalidSignature {
index: input_index,
reason: "Input index out of bounds".to_string(),
});
}
self.non_witness_utxos[input_index] = Some(prev_tx);
Ok(())
}
pub fn add_signature(&mut self, input_index: usize, sig: PsbtPartialSig) -> Result<()> {
if input_index >= self.input_signatures.len() {
return Err(MultisigError::InvalidSignature {
index: input_index,
reason: "Input index out of bounds".to_string(),
});
}
if !self.wallet.config.contains_key(&sig.pubkey) {
return Err(MultisigError::InvalidPublicKey(
"Public key not in multisig".to_string(),
));
}
if self.input_signatures[input_index]
.iter()
.any(|s| s.pubkey == sig.pubkey)
{
return Err(MultisigError::InvalidSignature {
index: input_index,
reason: "Duplicate signature for this pubkey".to_string(),
});
}
self.input_signatures[input_index].push(sig);
Ok(())
}
pub fn sign_input(
&mut self,
input_index: usize,
sighash: &[u8; 32],
private_key: &PrivateKey,
) -> Result<()> {
let pubkey = private_key.public_key().to_compressed();
if !self.wallet.config.contains_key(&pubkey) {
return Err(MultisigError::InvalidPublicKey(
"Private key not in multisig".to_string(),
));
}
let secp = Secp256k1::new();
let secret_key = SecretKey::from_slice(&private_key.to_bytes())
.map_err(|e| MultisigError::SigningFailed(e.to_string()))?;
let message = Message::from_digest(*sighash);
let signature = secp.sign_ecdsa(&message, &secret_key);
let sig_bytes = signature.serialize_der().to_vec();
self.add_signature(input_index, PsbtPartialSig {
pubkey,
signature: sig_bytes,
})
}
pub fn input_is_complete(&self, input_index: usize) -> bool {
if input_index >= self.input_signatures.len() {
return false;
}
self.input_signatures[input_index].len() >= self.wallet.config.threshold() as usize
}
pub fn is_complete(&self) -> bool {
(0..self.input_signatures.len()).all(|i| self.input_is_complete(i))
}
pub fn signature_count(&self, input_index: usize) -> usize {
self.input_signatures
.get(input_index)
.map(|sigs| sigs.len())
.unwrap_or(0)
}
pub fn redeem_script(&self) -> &[u8] {
&self.wallet.redeem_script
}
pub fn witness_script(&self) -> &[u8] {
&self.wallet.redeem_script
}
pub fn build_witness(&self, input_index: usize) -> Result<Vec<Vec<u8>>> {
if !self.input_is_complete(input_index) {
return Err(MultisigError::NotEnoughSignatures {
need: self.wallet.config.threshold() as usize,
got: self.signature_count(input_index),
});
}
let mut witness = Vec::new();
witness.push(Vec::new());
let mut sigs: Vec<_> = self.input_signatures[input_index]
.iter()
.filter_map(|sig| {
self.wallet.config.key_index(&sig.pubkey).map(|idx| (idx, sig))
})
.collect();
sigs.sort_by_key(|(idx, _)| *idx);
for (_, sig) in sigs.iter().take(self.wallet.config.threshold() as usize) {
let mut sig_with_sighash = sig.signature.clone();
sig_with_sighash.push(0x01); witness.push(sig_with_sighash);
}
witness.push(self.wallet.redeem_script.clone());
Ok(witness)
}
pub fn build_script_sig(&self, input_index: usize) -> Result<Vec<u8>> {
if !self.input_is_complete(input_index) {
return Err(MultisigError::NotEnoughSignatures {
need: self.wallet.config.threshold() as usize,
got: self.signature_count(input_index),
});
}
let mut script_sig = Vec::new();
script_sig.push(0x00);
let mut sigs: Vec<_> = self.input_signatures[input_index]
.iter()
.filter_map(|sig| {
self.wallet.config.key_index(&sig.pubkey).map(|idx| (idx, sig))
})
.collect();
sigs.sort_by_key(|(idx, _)| *idx);
for (_, sig) in sigs.iter().take(self.wallet.config.threshold() as usize) {
let mut sig_with_sighash = sig.signature.clone();
sig_with_sighash.push(0x01); push_data(&mut script_sig, &sig_with_sighash);
}
push_data(&mut script_sig, &self.wallet.redeem_script);
Ok(script_sig)
}
}
impl From<PartialSignature> for PsbtPartialSig {
    /// Converts a signer-produced partial signature into PSBT form.
    ///
    /// `PsbtPartialSig` stores bare DER; [`MultisigPsbtBuilder`] appends the
    /// sighash flag itself when finalizing. If the incoming bytes are a complete
    /// DER structure followed by exactly one `SIGHASH_ALL` (0x01) byte, that
    /// byte is removed. Otherwise the bytes are kept unchanged.
    fn from(sig: PartialSignature) -> Self {
        let mut signature = sig.signature;
        // A DER ECDSA signature is `0x30 <len> <len bytes>`, i.e. `len + 2`
        // bytes total. Only a byte *beyond* that frame can be a sighash flag.
        // Checking just the last byte would truncate a bare DER signature whose
        // S value happens to end in 0x01.
        let has_trailing_sighash_all = match signature.as_slice() {
            [0x30, len, .., 0x01] => signature.len() == usize::from(*len) + 3,
            _ => false,
        };
        if has_trailing_sighash_all {
            signature.pop();
        }
        Self {
            pubkey: sig.pubkey,
            signature,
        }
    }
}
/// Appends `data` to `script`, preceded by a push opcode chosen by payload size.
///
/// - 0..=75 bytes: a single direct-push byte equal to the length;
/// - 76..=255: `OP_PUSHDATA1` (0x4c) + 1-byte length;
/// - 256..=65535: `OP_PUSHDATA2` (0x4d) + 2-byte little-endian length;
/// - larger: `OP_PUSHDATA4` (0x4e) + 4-byte little-endian length.
///
/// The length is cast with `as u32`, so a payload larger than `u32::MAX` bytes
/// would get a truncated length.
fn push_data(script: &mut Vec<u8>, data: &[u8]) {
    let size = data.len();
    match size {
        // Direct push: the opcode value is the byte count itself.
        0..=75 => script.push(size as u8),
        76..=0xff => script.extend_from_slice(&[0x4c, size as u8]),
        0x100..=0xffff => {
            script.push(0x4d);
            script.extend_from_slice(&(size as u16).to_le_bytes());
        }
        _ => {
            script.push(0x4e);
            script.extend_from_slice(&(size as u32).to_le_bytes());
        }
    }
    script.extend_from_slice(data);
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::address::Network;

    /// Builds a 2-of-3 mainnet wallet from three fresh random keys and returns
    /// the wallet together with the keys.
    fn create_test_wallet() -> (MultisigWallet, PrivateKey, PrivateKey, PrivateKey) {
        let key1 = PrivateKey::random();
        let key2 = PrivateKey::random();
        let key3 = PrivateKey::random();
        let pubkeys = vec![
            key1.public_key().to_compressed(),
            key2.public_key().to_compressed(),
            key3.public_key().to_compressed(),
        ];
        let wallet = MultisigWallet::from_pubkeys(2, pubkeys, Network::Mainnet).unwrap();
        (wallet, key1, key2, key3)
    }

    #[test]
    fn test_psbt_builder_creation() {
        let (wallet, _, _, _) = create_test_wallet();
        let builder = MultisigPsbtBuilder::new(wallet, 2);
        assert_eq!(builder.input_signatures.len(), 2);
        assert_eq!(builder.witness_utxos.len(), 2);
        assert_eq!(builder.non_witness_utxos.len(), 2);
        assert!(!builder.is_complete());
    }

    #[test]
    fn test_sign_input() {
        let (wallet, key1, key2, _) = create_test_wallet();
        let mut builder = MultisigPsbtBuilder::new(wallet, 1);
        let sighash = [0xab; 32];
        builder.sign_input(0, &sighash, &key1).unwrap();
        assert_eq!(builder.signature_count(0), 1);
        assert!(!builder.input_is_complete(0));
        builder.sign_input(0, &sighash, &key2).unwrap();
        assert_eq!(builder.signature_count(0), 2);
        assert!(builder.input_is_complete(0));
    }

    #[test]
    fn test_is_complete_requires_every_input() {
        let (wallet, key1, key2, _) = create_test_wallet();
        let mut builder = MultisigPsbtBuilder::new(wallet, 2);
        let sighash = [0x55; 32];
        builder.sign_input(0, &sighash, &key1).unwrap();
        builder.sign_input(0, &sighash, &key2).unwrap();
        assert!(!builder.is_complete());
        builder.sign_input(1, &sighash, &key1).unwrap();
        builder.sign_input(1, &sighash, &key2).unwrap();
        assert!(builder.is_complete());
    }

    #[test]
    fn test_build_witness() {
        let (wallet, key1, key2, _) = create_test_wallet();
        let mut builder = MultisigPsbtBuilder::new(wallet, 1);
        let sighash = [0xcd; 32];
        builder.sign_input(0, &sighash, &key1).unwrap();
        builder.sign_input(0, &sighash, &key2).unwrap();
        let witness = builder.build_witness(0).unwrap();
        assert_eq!(witness.len(), 4);
        assert!(witness[0].is_empty());
        // Each signature element is DER (starts with 0x30) plus SIGHASH_ALL.
        for sig in &witness[1..3] {
            assert_eq!(sig.first(), Some(&0x30));
            assert_eq!(sig.last(), Some(&0x01));
        }
        assert_eq!(witness[3].as_slice(), builder.witness_script());
    }

    #[test]
    fn test_witness_independent_of_signing_order() {
        let (wallet, key1, key2, _) = create_test_wallet();
        let sighash = [0x5a; 32];
        let mut forward = MultisigPsbtBuilder::new(wallet.clone(), 1);
        forward.sign_input(0, &sighash, &key1).unwrap();
        forward.sign_input(0, &sighash, &key2).unwrap();
        let mut reverse = MultisigPsbtBuilder::new(wallet, 1);
        reverse.sign_input(0, &sighash, &key2).unwrap();
        reverse.sign_input(0, &sighash, &key1).unwrap();
        // ECDSA signing here is deterministic, so only ordering could differ.
        assert_eq!(
            forward.build_witness(0).unwrap(),
            reverse.build_witness(0).unwrap()
        );
        assert_eq!(
            forward.build_script_sig(0).unwrap(),
            reverse.build_script_sig(0).unwrap()
        );
    }

    #[test]
    fn test_build_script_sig() {
        let (wallet, key1, key2, _) = create_test_wallet();
        let mut builder = MultisigPsbtBuilder::new(wallet, 1);
        let sighash = [0xef; 32];
        builder.sign_input(0, &sighash, &key1).unwrap();
        builder.sign_input(0, &sighash, &key2).unwrap();
        let script_sig = builder.build_script_sig(0).unwrap();
        assert_eq!(script_sig[0], 0x00);
        assert!(script_sig.ends_with(builder.redeem_script()));
    }

    #[test]
    fn test_build_before_threshold_fails() {
        let (wallet, key1, _, _) = create_test_wallet();
        let mut builder = MultisigPsbtBuilder::new(wallet, 1);
        builder.sign_input(0, &[0x33; 32], &key1).unwrap();
        assert!(builder.build_witness(0).is_err());
        assert!(builder.build_script_sig(0).is_err());
    }

    #[test]
    fn test_out_of_bounds_input_rejected() {
        let (wallet, key1, _, _) = create_test_wallet();
        let mut builder = MultisigPsbtBuilder::new(wallet, 1);
        assert!(builder.sign_input(1, &[0x44; 32], &key1).is_err());
        assert!(builder.set_witness_utxo(1, 1_000, vec![0x00, 0x20]).is_err());
        assert!(builder.set_non_witness_utxo(1, vec![0x01]).is_err());
        assert!(builder.set_witness_utxo(0, 1_000, vec![0x00, 0x20]).is_ok());
        assert!(builder.set_non_witness_utxo(0, vec![0x01]).is_ok());
        assert_eq!(builder.signature_count(1), 0);
        assert!(!builder.input_is_complete(1));
        assert!(builder.build_witness(1).is_err());
        assert!(builder.build_script_sig(1).is_err());
    }

    #[test]
    fn test_witness_script_matches_redeem_script() {
        let (wallet, _, _, _) = create_test_wallet();
        let builder = MultisigPsbtBuilder::new(wallet, 1);
        assert_eq!(builder.witness_script(), builder.redeem_script());
    }

    #[test]
    fn test_duplicate_signature_rejected() {
        let (wallet, key1, _, _) = create_test_wallet();
        let mut builder = MultisigPsbtBuilder::new(wallet, 1);
        let sighash = [0x11; 32];
        builder.sign_input(0, &sighash, &key1).unwrap();
        let result = builder.sign_input(0, &sighash, &key1);
        assert!(result.is_err());
        assert_eq!(builder.signature_count(0), 1);
    }

    #[test]
    fn test_wrong_key_rejected() {
        let (wallet, _, _, _) = create_test_wallet();
        let mut builder = MultisigPsbtBuilder::new(wallet, 1);
        let wrong_key = PrivateKey::random();
        let sighash = [0x22; 32];
        let result = builder.sign_input(0, &sighash, &wrong_key);
        assert!(result.is_err());
        assert_eq!(builder.signature_count(0), 0);
    }
}