use bsv_primitives::ec::{PrivateKey, PublicKey};
use bsv_primitives::hash::hash160;
use bsv_script::opcodes::*;
use bsv_script::Script;
use crate::sighash::SIGHASH_ALL_FORKID;
use crate::template::UnlockingScriptTemplate;
use crate::transaction::Transaction;
use crate::TransactionError;
/// Largest number of public keys [`MultisigScript::new`] accepts.
///
/// The key count is serialized as a single opcode (`OP_BASE + n`) and parsed
/// back only when that opcode lies in `OP_1..=OP_16`.
pub const MAX_MULTISIG_KEYS: usize = 16;
/// Smallest number of public keys [`MultisigScript::new`] accepts.
pub const MIN_MULTISIG_KEYS: usize = 1;
/// Description of an m-of-n multisig: a signature threshold together with the
/// ordered list of public keys that appear in the script.
///
/// Instances are created through [`MultisigScript::new`] or
/// [`MultisigScript::from_script_bytes`], both of which validate the
/// threshold against the number of keys.
#[derive(Clone, Debug)]
pub struct MultisigScript {
/// Number of signatures required (`m`); validated to be in `1..=public_keys.len()`.
threshold: u8,
/// Public keys in the order they are written into the script (`n` entries).
public_keys: Vec<PublicKey>,
}
impl MultisigScript {
pub fn new(threshold: u8, public_keys: Vec<PublicKey>) -> Result<Self, TransactionError> {
let n = public_keys.len();
if n < MIN_MULTISIG_KEYS {
return Err(TransactionError::InvalidTransaction(
"multisig requires at least 1 public key".to_string(),
));
}
if n > MAX_MULTISIG_KEYS {
return Err(TransactionError::InvalidTransaction(format!(
"multisig supports at most {} public keys, got {}",
MAX_MULTISIG_KEYS, n
)));
}
if threshold == 0 {
return Err(TransactionError::InvalidTransaction(
"multisig threshold must be at least 1".to_string(),
));
}
if threshold as usize > n {
return Err(TransactionError::InvalidTransaction(format!(
"threshold {} exceeds number of keys {}",
threshold, n
)));
}
Ok(MultisigScript {
threshold,
public_keys,
})
}
pub fn threshold(&self) -> u8 {
self.threshold
}
pub fn n(&self) -> usize {
self.public_keys.len()
}
pub fn public_keys(&self) -> &[PublicKey] {
&self.public_keys
}
pub fn to_script(&self) -> Script {
let n = self.public_keys.len();
let mut bytes = Vec::with_capacity(3 + n * 34);
bytes.push(OP_BASE + self.threshold);
for pk in &self.public_keys {
let compressed = pk.to_compressed();
bytes.push(OP_DATA_33);
bytes.extend_from_slice(&compressed);
}
bytes.push(OP_BASE + n as u8);
bytes.push(OP_CHECKMULTISIG);
Script::from_bytes(&bytes)
}
pub fn to_bytes(&self) -> Vec<u8> {
self.to_script().to_bytes().to_vec()
}
pub fn mpkh(&self) -> [u8; 20] {
hash160(&self.to_bytes())
}
pub fn from_script_bytes(bytes: &[u8]) -> Result<Self, TransactionError> {
if bytes.len() < 37 {
return Err(TransactionError::InvalidTransaction(
"script too short for multisig".to_string(),
));
}
if *bytes.last().unwrap() != OP_CHECKMULTISIG {
return Err(TransactionError::InvalidTransaction(
"script does not end with OP_CHECKMULTISIG".to_string(),
));
}
let m_op = bytes[0];
if m_op < OP_1 || m_op > OP_16 {
return Err(TransactionError::InvalidTransaction(format!(
"invalid threshold opcode: 0x{:02x}",
m_op
)));
}
let m = (m_op - OP_BASE) as u8;
let n_op = bytes[bytes.len() - 2];
if n_op < OP_1 || n_op > OP_16 {
return Err(TransactionError::InvalidTransaction(format!(
"invalid key count opcode: 0x{:02x}",
n_op
)));
}
let n = (n_op - OP_BASE) as usize;
let key_section = &bytes[1..bytes.len() - 2]; if key_section.len() != n * 34 {
return Err(TransactionError::InvalidTransaction(format!(
"expected {} key slots ({} bytes), got {} bytes",
n,
n * 34,
key_section.len()
)));
}
let mut public_keys = Vec::with_capacity(n);
for i in 0..n {
let offset = i * 34;
if key_section[offset] != OP_DATA_33 {
return Err(TransactionError::InvalidTransaction(format!(
"expected OP_DATA_33 at key {}, got 0x{:02x}",
i,
key_section[offset]
)));
}
let pk_bytes = &key_section[offset + 1..offset + 34];
let pk = PublicKey::from_bytes(pk_bytes).map_err(|e| {
TransactionError::InvalidTransaction(format!(
"invalid public key at index {}: {}",
i, e
))
})?;
public_keys.push(pk);
}
MultisigScript::new(m, public_keys)
}
}
/// Produces the locking script for `multisig` by serializing it with
/// [`MultisigScript::to_script`].
///
/// The current implementation always returns `Ok`.
pub fn lock(multisig: &MultisigScript) -> Result<Script, TransactionError> {
    let script = multisig.to_script();
    Ok(script)
}
/// Creates a [`P2MPKH`] unlocking template for `multisig`.
///
/// `private_keys` are stored exactly as given; this function only checks how
/// many there are, not whether they correspond to the public keys in
/// `multisig` or in which order. When `sighash_flag` is `None`,
/// `SIGHASH_ALL_FORKID` is used.
///
/// # Errors
///
/// Returns [`TransactionError::SigningError`] when the number of private keys
/// differs from the multisig threshold.
pub fn unlock(
    private_keys: Vec<PrivateKey>,
    multisig: MultisigScript,
    sighash_flag: Option<u32>,
) -> Result<P2MPKH, TransactionError> {
    let required = multisig.threshold();
    let supplied = private_keys.len();
    if supplied != usize::from(required) {
        return Err(TransactionError::SigningError(format!(
            "expected {} private keys for threshold, got {}",
            required, supplied
        )));
    }
    let sighash_flag = match sighash_flag {
        Some(flag) => flag,
        None => SIGHASH_ALL_FORKID,
    };
    Ok(P2MPKH {
        private_keys,
        multisig,
        sighash_flag,
    })
}
/// Unlocking-script template for a multisig output.
///
/// Holds everything `sign` needs: the signing keys, the multisig description
/// and the sighash flag. Built by [`unlock`].
#[derive(Debug)]
pub struct P2MPKH {
/// Keys used to sign, in the order their signatures are pushed.
private_keys: Vec<PrivateKey>,
/// The multisig description this template was created for.
multisig: MultisigScript,
/// Sighash flag used for the signature hash; its low byte is appended to each signature.
sighash_flag: u32,
}
impl P2MPKH {
    /// Borrows the multisig description this template was created for.
    pub fn multisig(&self) -> &MultisigScript {
        let Self { multisig, .. } = self;
        multisig
    }
}
impl UnlockingScriptTemplate for P2MPKH {
    /// Builds the unlocking script for input `input_index` of `tx`.
    ///
    /// The script is an empty push followed by one push per stored private
    /// key, each holding the DER-encoded signature over the input's signature
    /// hash with the low byte of the sighash flag appended. Signatures are
    /// emitted in the order the private keys are stored.
    ///
    /// # Errors
    ///
    /// Returns [`TransactionError::SigningError`] when `input_index` is out of
    /// range or the input has no source output attached. Failures from the
    /// signature-hash calculation, signing, or script building are propagated.
    fn sign(&self, tx: &Transaction, input_index: u32) -> Result<Script, TransactionError> {
        let idx = input_index as usize;
        let input_count = tx.inputs.len();
        let input = tx.inputs.get(idx).ok_or_else(|| {
            TransactionError::SigningError(format!(
                "input index {} out of range (tx has {} inputs)",
                idx, input_count
            ))
        })?;
        if input.source_tx_output().is_none() {
            return Err(TransactionError::SigningError(
                "missing source output on input (no previous tx info)".to_string(),
            ));
        }

        let digest = tx.calc_input_signature_hash(idx, self.sighash_flag)?;
        // Only the low byte of the flag is serialized after each signature.
        let flag_byte = self.sighash_flag as u8;

        let mut unlocking = Script::new();
        // Leading empty push precedes the signatures.
        unlocking.append_push_data(&[])?;
        for key in &self.private_keys {
            let der = key.sign(&digest)?.to_der();
            let mut push = Vec::with_capacity(der.len() + 1);
            push.extend_from_slice(&der);
            push.push(flag_byte);
            unlocking.append_push_data(&push)?;
        }
        Ok(unlocking)
    }

    /// Estimated unlocking-script size in bytes: `1 + threshold * 73`.
    ///
    /// The transaction and input index are not consulted.
    fn estimate_length(&self, _tx: &Transaction, _input_index: u32) -> u32 {
        let signatures = u32::from(self.multisig.threshold());
        1 + signatures * 73
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::input::TransactionInput;
    use crate::output::TransactionOutput;

    /// Generates `n` fresh private keys together with their public keys.
    fn gen_keys(n: usize) -> (Vec<PrivateKey>, Vec<PublicKey>) {
        let mut privs = Vec::with_capacity(n);
        let mut pubs = Vec::with_capacity(n);
        for _ in 0..n {
            let key = PrivateKey::new();
            pubs.push(key.pub_key());
            privs.push(key);
        }
        (privs, pubs)
    }

    /// Builds an `m`-of-`n` multisig from fresh keys and an unlocker that
    /// signs with the first `m` private keys.
    fn first_m_unlocker(m: u8, n: usize, sighash_flag: Option<u32>) -> P2MPKH {
        let (privs, pubs) = gen_keys(n);
        let ms = MultisigScript::new(m, pubs).unwrap();
        let signers: Vec<PrivateKey> = privs.into_iter().take(usize::from(m)).collect();
        unlock(signers, ms, sighash_flag).unwrap()
    }

    /// Creates a one-input, one-output transaction whose input carries a
    /// P2PKH source output worth `satoshis`.
    fn mock_tx_with_source(satoshis: u64) -> Transaction {
        use bsv_script::Script;
        let locking_script = Script::from_asm(
            "OP_DUP OP_HASH160 aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa OP_EQUALVERIFY OP_CHECKSIG",
        )
        .unwrap();

        let mut input = TransactionInput::new();
        input.source_txid = [0u8; 32];
        input.source_tx_out_index = 0;
        input.set_source_output(Some(TransactionOutput {
            satoshis,
            locking_script: locking_script.clone(),
            change: false,
        }));

        let mut tx = Transaction::new();
        tx.add_input(input);
        tx.add_output(TransactionOutput {
            satoshis: satoshis.saturating_sub(1000),
            locking_script,
            change: false,
        });
        tx
    }

    #[test]
    fn multisig_script_2_of_3_roundtrip() {
        let (_, pubs) = gen_keys(3);
        let original = MultisigScript::new(2, pubs).unwrap();
        assert_eq!(original.threshold(), 2);
        assert_eq!(original.n(), 3);

        let encoded = original.to_bytes();
        let decoded = MultisigScript::from_script_bytes(&encoded).unwrap();
        assert_eq!(decoded.threshold(), 2);
        assert_eq!(decoded.n(), 3);
        assert_eq!(original.mpkh(), decoded.mpkh());
    }

    #[test]
    fn multisig_script_1_of_1() {
        let (_, pubs) = gen_keys(1);
        let ms = MultisigScript::new(1, pubs).unwrap();
        assert_eq!(ms.threshold(), 1);
        assert_eq!(ms.n(), 1);
        assert_eq!(ms.to_bytes().len(), 37);
    }

    #[test]
    fn multisig_script_max_keys() {
        let (_, pubs) = gen_keys(MAX_MULTISIG_KEYS);
        let ms = MultisigScript::new(1, pubs).unwrap();
        assert_eq!(ms.n(), MAX_MULTISIG_KEYS);
    }

    #[test]
    fn multisig_script_rejects_zero_threshold() {
        let (_, pubs) = gen_keys(3);
        let message = MultisigScript::new(0, pubs).unwrap_err().to_string();
        assert!(message.contains("threshold must be at least 1"));
    }

    #[test]
    fn multisig_script_rejects_threshold_exceeding_keys() {
        let (_, pubs) = gen_keys(2);
        let message = MultisigScript::new(3, pubs).unwrap_err().to_string();
        assert!(message.contains("threshold 3 exceeds"));
    }

    #[test]
    fn multisig_script_rejects_too_many_keys() {
        let (_, pubs) = gen_keys(MAX_MULTISIG_KEYS + 1);
        let message = MultisigScript::new(1, pubs).unwrap_err().to_string();
        assert!(message.contains("at most"));
    }

    #[test]
    fn multisig_script_rejects_empty_keys() {
        let message = MultisigScript::new(1, vec![]).unwrap_err().to_string();
        assert!(message.contains("at least 1"));
    }

    #[test]
    fn lock_produces_bare_multisig() {
        let (_, pubs) = gen_keys(3);
        let ms = MultisigScript::new(2, pubs).unwrap();
        let script = lock(&ms).unwrap();
        let bytes = script.to_bytes();
        let len = bytes.len();
        assert_eq!(bytes[0], OP_2);
        assert_eq!(*bytes.last().unwrap(), OP_CHECKMULTISIG);
        assert_eq!(bytes[len - 2], OP_3);
    }

    #[test]
    fn mpkh_is_20_bytes() {
        let (_, pubs) = gen_keys(3);
        let digest = MultisigScript::new(2, pubs).unwrap().mpkh();
        assert_eq!(digest.len(), 20);
    }

    #[test]
    fn mpkh_differs_for_different_key_sets() {
        let (_, first_pubs) = gen_keys(3);
        let (_, second_pubs) = gen_keys(3);
        let first = MultisigScript::new(2, first_pubs).unwrap();
        let second = MultisigScript::new(2, second_pubs).unwrap();
        assert_ne!(first.mpkh(), second.mpkh());
    }

    #[test]
    fn mpkh_differs_for_different_thresholds() {
        let (_, pubs) = gen_keys(3);
        let one_of_three = MultisigScript::new(1, pubs.clone()).unwrap();
        let two_of_three = MultisigScript::new(2, pubs).unwrap();
        assert_ne!(one_of_three.mpkh(), two_of_three.mpkh());
    }

    #[test]
    fn unlock_rejects_wrong_key_count() {
        let (privs, pubs) = gen_keys(3);
        let ms = MultisigScript::new(2, pubs).unwrap();
        let message = unlock(privs, ms, None).unwrap_err().to_string();
        assert!(message.contains("expected 2 private keys"));
    }

    #[test]
    fn estimate_length_2_of_3() {
        let unlocker = first_m_unlocker(2, 3, None);
        let tx = Transaction::default();
        assert_eq!(unlocker.estimate_length(&tx, 0), 147);
    }

    #[test]
    fn from_script_bytes_rejects_garbage() {
        let garbage = [0x00, 0x01, 0x02];
        let message = MultisigScript::from_script_bytes(&garbage)
            .unwrap_err()
            .to_string();
        assert!(message.contains("too short"));
    }

    #[test]
    fn p2mpkh_sign_2_of_3_script_structure() {
        let unlocker = first_m_unlocker(2, 3, None);
        let tx = mock_tx_with_source(10_000);
        let script = unlocker.sign(&tx, 0).unwrap();
        let chunks = script.chunks().unwrap();
        assert_eq!(chunks.len(), 3, "expected OP_0 + 2 signature pushes");

        let dummy_is_empty = match chunks[0].data.as_ref() {
            None => true,
            Some(data) => data.is_empty(),
        };
        assert!(dummy_is_empty, "first chunk should be OP_0 (empty push)");

        // Chunk 0 is the dummy push; the remaining chunks are the signatures.
        for (i, chunk) in chunks.iter().enumerate().skip(1) {
            let sig_data = chunk.data.as_ref().expect("signature should be push data");
            assert!(
                (71..=73).contains(&sig_data.len()),
                "signature {} length {} not in 71..=73",
                i - 1,
                sig_data.len()
            );
            assert_eq!(
                *sig_data.last().unwrap(),
                0x41,
                "signature should end with SIGHASH_ALL_FORKID"
            );
        }
    }

    #[test]
    fn p2mpkh_sign_1_of_1() {
        let unlocker = first_m_unlocker(1, 1, None);
        let tx = mock_tx_with_source(5_000);
        let script = unlocker.sign(&tx, 0).unwrap();
        let chunks = script.chunks().unwrap();
        assert_eq!(chunks.len(), 2);

        let sig = chunks[1].data.as_ref().unwrap();
        assert!((71..=73).contains(&sig.len()));
        assert_eq!(*sig.last().unwrap(), 0x41);
    }

    #[test]
    fn p2mpkh_sign_missing_source_output_returns_error() {
        let unlocker = first_m_unlocker(2, 3, None);
        let mut tx = Transaction::new();
        tx.add_input(TransactionInput::new());
        tx.add_output(TransactionOutput::new());

        let result = unlocker.sign(&tx, 0);
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(
            message.contains("missing source output"),
            "error should mention missing source output"
        );
    }

    #[test]
    fn p2mpkh_sign_custom_sighash_flag() {
        use crate::sighash::{SIGHASH_FORKID, SIGHASH_NONE};
        let sighash_none_forkid = SIGHASH_NONE | SIGHASH_FORKID;

        let unlocker = first_m_unlocker(1, 2, Some(sighash_none_forkid));
        let tx = mock_tx_with_source(8_000);
        let script = unlocker.sign(&tx, 0).unwrap();
        let chunks = script.chunks().unwrap();
        assert_eq!(chunks.len(), 2);

        let sig = chunks[1].data.as_ref().unwrap();
        assert_eq!(
            *sig.last().unwrap(),
            sighash_none_forkid as u8,
            "signature should use the custom sighash flag"
        );
    }

    #[test]
    fn p2mpkh_estimate_length_1_of_1() {
        let unlocker = first_m_unlocker(1, 1, None);
        let tx = Transaction::default();
        assert_eq!(unlocker.estimate_length(&tx, 0), 74);
    }

    #[test]
    fn p2mpkh_estimate_length_3_of_5() {
        let unlocker = first_m_unlocker(3, 5, None);
        let tx = Transaction::default();
        assert_eq!(unlocker.estimate_length(&tx, 0), 220);
    }
}