use crate::pskt::{KeySource, Version};
use crate::utils::combine_if_no_conflicts;
use derive_builder::Builder;
use kaspa_consensus_core::tx::TransactionId;
use serde::{Deserialize, Serialize};
use std::{
collections::{btree_map, BTreeMap},
ops::Add,
};
/// Extended (BIP32) public key over secp256k1; used as the key of [`Global::xpubs`].
type Xpub = kaspa_bip32::ExtendedPublicKey<secp256k1::PublicKey>;
/// Global (transaction-wide) section of a PSKT.
///
/// Serialized with camelCase field names; keys that do not match a named field are captured in
/// [`Global::unknowns`] via `#[serde(flatten)]`. The derived builder fills fields that were not
/// set with default values (`#[builder(default)]`). Two globals are merged with `+`
/// (see the `Add` impl), which yields `Result<Global, CombineError>`.
#[derive(Debug, Clone, Builder, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
#[builder(default)]
pub struct Global {
    /// PSKT format version; both sides must agree when combining.
    pub version: Version,
    /// Version of the transaction under construction; both sides must agree when combining.
    pub tx_version: u16,
    /// Optional fallback lock time. The builder setter takes a plain `u64` (`strip_option`).
    /// NOTE(review): presumably used when no input demands a specific lock time — confirm.
    #[builder(setter(strip_option))]
    pub fallback_lock_time: Option<u64>,
    /// Whether the input set is modifiable; combining two globals ANDs this flag.
    pub inputs_modifiable: bool,
    /// Whether the output set is modifiable; combining two globals ANDs this flag.
    pub outputs_modifiable: bool,
    /// Number of inputs; combining keeps the larger of the two counts.
    pub input_count: usize,
    /// Number of outputs; combining keeps the larger of the two counts.
    pub output_count: usize,
    /// Extended public keys mapped to their key source (master fingerprint + derivation path).
    pub xpubs: BTreeMap<Xpub, KeySource>,
    /// Optional transaction id; when both sides set it, the ids must match.
    pub id: Option<TransactionId>,
    /// Proprietary key/value data; merged on combine, conflicting values are an error.
    pub proprietaries: BTreeMap<String, serde_value::Value>,
    /// Unrecognized fields, flattened into the top level of the serialized object.
    #[serde(flatten)]
    pub unknowns: BTreeMap<String, serde_value::Value>,
}
impl Add for Global {
type Output = Result<Self, CombineError>;
fn add(mut self, rhs: Self) -> Self::Output {
if self.version != rhs.version {
return Err(CombineError::VersionMismatch { this: self.version, that: rhs.version });
}
if self.tx_version != rhs.tx_version {
return Err(CombineError::TxVersionMismatch { this: self.tx_version, that: rhs.tx_version });
}
self.fallback_lock_time = match (self.fallback_lock_time, rhs.fallback_lock_time) {
(Some(lhs), Some(rhs)) if lhs != rhs => return Err(CombineError::LockTimeMismatch { this: lhs, that: rhs }),
(Some(v), _) | (_, Some(v)) => Some(v),
_ => None,
};
self.inputs_modifiable &= rhs.inputs_modifiable;
self.outputs_modifiable &= rhs.outputs_modifiable;
self.input_count = self.input_count.max(rhs.input_count);
self.output_count = self.output_count.max(rhs.output_count);
for (xpub, KeySource { key_fingerprint: fingerprint1, derivation_path: derivation1 }) in rhs.xpubs {
match self.xpubs.entry(xpub) {
btree_map::Entry::Vacant(entry) => {
entry.insert(KeySource::new(fingerprint1, derivation1));
}
btree_map::Entry::Occupied(mut entry) => {
let KeySource { key_fingerprint: fingerprint2, derivation_path: derivation2 } = entry.get().clone();
if (derivation1 == derivation2 && fingerprint1 == fingerprint2)
|| (derivation1.len() < derivation2.len()
&& derivation1.as_ref() == &derivation2.as_ref()[derivation2.len() - derivation1.len()..])
{
continue;
} else if derivation2.as_ref() == &derivation1.as_ref()[derivation1.len() - derivation2.len()..] {
entry.insert(KeySource::new(fingerprint1, derivation1));
continue;
}
return Err(CombineError::InconsistentKeySources(entry.key().clone()));
}
}
}
self.id = match (self.id, rhs.id) {
(Some(lhs), Some(rhs)) if lhs != rhs => return Err(CombineError::TransactionIdMismatch { this: lhs, that: rhs }),
(Some(v), _) | (_, Some(v)) => Some(v),
_ => None,
};
self.proprietaries =
combine_if_no_conflicts(self.proprietaries, rhs.proprietaries).map_err(CombineError::NotCompatibleProprietary)?;
self.unknowns = combine_if_no_conflicts(self.unknowns, rhs.unknowns).map_err(CombineError::NotCompatibleUnknownField)?;
Ok(self)
}
}
impl Default for Global {
fn default() -> Self {
Global {
version: Version::Zero,
tx_version: kaspa_consensus_core::constants::TX_VERSION,
fallback_lock_time: None,
inputs_modifiable: false,
outputs_modifiable: false,
input_count: 0,
output_count: 0,
xpubs: Default::default(),
id: None,
proprietaries: Default::default(),
unknowns: Default::default(),
}
}
}
/// Reasons why two [`Global`] sections cannot be combined with `+`.
///
/// In the variants carrying `this`/`that`, `this` is the left-hand side's value and `that` is
/// the right-hand side's value.
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum CombineError {
    /// The PSKT format versions differ.
    #[error("The version numbers are not the same")]
    VersionMismatch {
        /// Left-hand side version.
        this: Version,
        /// Right-hand side version.
        that: Version,
    },
    /// The transaction versions differ.
    #[error("The transaction version numbers are not the same")]
    TxVersionMismatch {
        /// Left-hand side transaction version.
        this: u16,
        /// Right-hand side transaction version.
        that: u16,
    },
    /// Both sides set a fallback lock time, and the values differ.
    #[error("The transaction lock times are not the same")]
    LockTimeMismatch {
        /// Left-hand side lock time.
        this: u64,
        /// Right-hand side lock time.
        that: u64,
    },
    /// Both sides set a transaction id, and the ids differ.
    #[error("The transaction ids are not the same")]
    TransactionIdMismatch {
        /// Left-hand side transaction id.
        this: TransactionId,
        /// Right-hand side transaction id.
        that: TransactionId,
    },
    /// The two sides carry irreconcilable key sources for the contained xpub.
    #[error("combining PSKT, key-source conflict for xpub {0}")]
    InconsistentKeySources(Xpub),
    /// Merging the `unknowns` maps failed because of conflicting values.
    #[error("Two different unknown field values")]
    NotCompatibleUnknownField(crate::utils::Error<String, serde_value::Value>),
    /// Merging the `proprietaries` maps failed because of conflicting values.
    #[error("Two different proprietary values")]
    NotCompatibleProprietary(crate::utils::Error<String, serde_value::Value>),
}