use std::collections::HashMap;
use bitcoin::secp256k1::{All, Parity, Secp256k1, SecretKey};
use bitcoin::{OutPoint, ScriptBuf};
use crate::address::SilentPaymentAddress;
use crate::{Aggregate, InputHash, ScanPublicKey, SharedSecret, SpendPublicKey};
/// Sender-side state for creating silent payment outputs.
///
/// Holds the recipients, the aggregated input private keys and the input-hash data that the
/// builder methods accumulate.
pub struct SilentPayment<'a> {
    /// secp256k1 context borrowed for every key operation.
    secp: &'a Secp256k1<All>,
    /// Recipient addresses, kept in the order they were added.
    recipients: Vec<SilentPaymentAddress>,
    /// Running aggregate of the input private keys.
    input_secret_key: Aggregate<SecretKey>,
    /// Accumulates outpoints and input public keys for the input hash.
    input_hash: InputHash,
}
impl<'a> SilentPayment<'a> {
    /// Creates a sender context bound to `secp`.
    ///
    /// Recipients, the input secret key aggregate and the input hash start out
    /// default-initialized; fill them with the `add_*` methods.
    pub fn new(secp: &Secp256k1<All>) -> SilentPayment<'_> {
        SilentPayment {
            secp,
            recipients: Default::default(),
            input_secret_key: Default::default(),
            input_hash: Default::default(),
        }
    }

    /// Queues a silent payment address to receive an output.
    ///
    /// Returns `self` so calls can be chained.
    pub fn add_recipient(&mut self, address: SilentPaymentAddress) -> &mut Self {
        self.recipients.push(address);
        self
    }

    /// Adds the private key of a taproot input.
    ///
    /// If the key's public point has an odd Y coordinate the key is negated first, so the
    /// key that gets aggregated corresponds to the even-Y x-only public key. It is then
    /// added via [`add_private_key`](Self::add_private_key).
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`add_private_key`](Self::add_private_key).
    pub fn add_taproot_private_key<SK: Into<SecretKey>>(&mut self, key: SK) -> &mut Self {
        let key = key.into();
        let (_, y_parity) = key.public_key(self.secp).x_only_public_key();
        // Negating the secret key negates the public point, which flips its Y parity.
        let key = match y_parity {
            Parity::Odd => key.negate(),
            Parity::Even => key,
        };
        self.add_private_key(key)
    }

    /// Adds an input private key as-is (no parity adjustment).
    ///
    /// The key is folded into the aggregate input secret key and its public key is fed to
    /// the input hash.
    ///
    /// # Panics
    ///
    /// Panics if `InputHash::add_input_public_key` returns an error for the derived public
    /// key.
    pub fn add_private_key<SK: Into<SecretKey>>(&mut self, key: SK) -> &mut Self {
        let key = key.into();
        self.input_secret_key.add_key(&key);
        self.input_hash
            .add_input_public_key(&key.public_key(self.secp))
            .expect("failed to add input public key to the input hash");
        self
    }

    /// Records an outpoint spent by the transaction in the input hash.
    pub fn add_outpoint(&mut self, outpoint: OutPoint) -> &mut Self {
        self.input_hash.add_outpoint(&outpoint);
        self
    }

    /// Consumes the builder and derives one output script per queued recipient.
    ///
    /// Recipients are grouped by scan public key. One shared secret is derived per group,
    /// and the spend keys in that group get consecutive indices `k = 0, 1, ...` in the
    /// order the recipients were added. Groups are emitted in the order their scan key
    /// first appeared, so the result is deterministic for a given sequence of builder calls.
    ///
    /// Returns an empty vector if the input hash cannot be computed.
    ///
    /// # Panics
    ///
    /// Panics if the aggregate input secret key cannot be obtained or if a shared secret
    /// cannot be derived for a scan key.
    #[must_use]
    pub fn generate_output_scripts(self) -> Vec<ScriptBuf> {
        let input_hash = match self.input_hash.hash() {
            Ok(hash) => hash,
            Err(_) => return Vec::new(),
        };
        let input_secret_key = self
            .input_secret_key
            .get()
            .expect("failed to aggregate the input secret keys");
        let secp = self.secp;

        // A HashMap's iteration order is arbitrary (and differs between map instances), so
        // remember where each scan key was first seen and sort by that position afterwards.
        let mut by_scan_key: HashMap<ScanPublicKey, (usize, Vec<SpendPublicKey>)> =
            HashMap::new();
        for recipient in &self.recipients {
            let next_position = by_scan_key.len();
            by_scan_key
                .entry(recipient.scan_key())
                .or_insert_with(|| (next_position, Vec::new()))
                .1
                .push(recipient.spend_key());
        }
        let mut groups: Vec<_> = by_scan_key.into_iter().collect();
        groups.sort_unstable_by_key(|(_, (position, _))| *position);

        groups
            .into_iter()
            .flat_map(|(b_scan, (_, b_ms))| {
                let shared_secret = SharedSecret::new(
                    input_hash,
                    b_scan.into_public_key(),
                    input_secret_key,
                    secp,
                )
                .expect("failed to derive the shared secret");
                // `k` restarts at 0 for every scan key and counts up per spend key.
                b_ms.into_iter()
                    .zip(0..)
                    .map(move |(b_m, k)| shared_secret.destination_output(b_m, k, secp))
            })
            .collect()
    }
}