#include <node/txreconciliation.h>
#include <common/system.h>
#include <logging.h>
#include <util/check.h>
#include <unordered_map>
#include <variant>
namespace {

/** Tag string handed to TaggedHash() to build RECON_SALT_HASHER. */
const std::string RECON_STATIC_SALT{"Tx Relay Salting"};

/** Tagged hasher built once from RECON_STATIC_SALT; each ComputeSalt() call works on a copy. */
const HashWriter RECON_SALT_HASHER{TaggedHash(RECON_STATIC_SALT)};

/**
 * Combine two 64-bit salts into a 256-bit value using the tagged hasher.
 *
 * The smaller salt is always serialized first, so the result does not depend on argument
 * order: ComputeSalt(a, b) == ComputeSalt(b, a).
 *
 * @param salt1  one side's salt
 * @param salt2  the other side's salt
 * @return       SHA256 over (tagged prefix, min(salt1, salt2), max(salt1, salt2))
 */
uint256 ComputeSalt(uint64_t salt1, uint64_t salt2)
{
    const uint64_t lower{std::min(salt1, salt2)};
    const uint64_t upper{std::max(salt1, salt2)};

    HashWriter hasher{RECON_SALT_HASHER};
    hasher << lower << upper;
    return hasher.GetSHA256();
}

/** State kept for a peer once registration for transaction reconciliation has completed. */
class TxReconciliationState
{
public:
    /** True when we are the initiating side for this peer. */
    bool m_we_initiate;
    /** First 64-bit word of the combined salt. */
    uint64_t m_k0;
    /** Second 64-bit word of the combined salt. */
    uint64_t m_k1;

    TxReconciliationState(bool we_initiate, uint64_t k0, uint64_t k1)
        : m_we_initiate{we_initiate}, m_k0{k0}, m_k1{k1}
    {
    }
};

} // namespace
class TxReconciliationTracker::Impl
{
private:
    /** Guards m_states. Mutable so the const IsPeerRegistered() can still lock it. */
    mutable Mutex m_txreconciliation_mutex;

    /**
     * Highest reconciliation protocol version we support. It is only set in the constructor,
     * so it is const and can be read without holding m_txreconciliation_mutex.
     */
    const uint32_t m_recon_version;

    /**
     * Per-peer state keyed by NodeId. An entry holds exactly one of:
     *  - uint64_t: our local salt, stored by PreRegisterPeer() while registration is pending;
     *  - TxReconciliationState: stored by a successful RegisterPeer().
     */
    std::unordered_map<NodeId, std::variant<uint64_t, TxReconciliationState>> m_states GUARDED_BY(m_txreconciliation_mutex);

public:
    explicit Impl(uint32_t recon_version) : m_recon_version(recon_version) {}

    /**
     * Generate a random local salt for peer_id, record it as pending state, and return it.
     *
     * NOTE(review): this expects to be called at most once per peer_id. If an entry already
     * exists, the emplace fails (flagged by Assume) and the returned salt is NOT the stored one.
     */
    uint64_t PreRegisterPeer(NodeId peer_id) EXCLUSIVE_LOCKS_REQUIRED(!m_txreconciliation_mutex)
    {
        AssertLockNotHeld(m_txreconciliation_mutex);
        LOCK(m_txreconciliation_mutex);

        LogPrintLevel(BCLog::TXRECONCILIATION, BCLog::Level::Debug, "Pre-register peer=%d\n", peer_id);
        const uint64_t local_salt{FastRandomContext().rand64()};

        Assume(m_states.emplace(peer_id, local_salt).second);
        return local_salt;
    }

    /**
     * Finish registering a pre-registered peer.
     *
     * @param peer_id             peer to register
     * @param is_peer_inbound     whether the peer is inbound; we initiate only if it is not
     * @param peer_recon_version  version advertised by the peer
     * @param remote_salt         salt received from the peer
     * @return NOT_FOUND if peer_id was never pre-registered; ALREADY_REGISTERED if it has
     *         already completed registration; PROTOCOL_VIOLATION if the negotiated version is
     *         below 1; SUCCESS otherwise.
     */
    ReconciliationRegisterResult RegisterPeer(NodeId peer_id, bool is_peer_inbound, uint32_t peer_recon_version,
                                              uint64_t remote_salt) EXCLUSIVE_LOCKS_REQUIRED(!m_txreconciliation_mutex)
    {
        AssertLockNotHeld(m_txreconciliation_mutex);
        LOCK(m_txreconciliation_mutex);

        auto recon_state = m_states.find(peer_id);
        if (recon_state == m_states.end()) return ReconciliationRegisterResult::NOT_FOUND;

        // A pending entry holds our local salt (uint64_t). Anything else means registration
        // already completed. Checking and extracting in one step avoids an unchecked
        // dereference of get_if() if a separate holds_alternative() test ever disagreed.
        const uint64_t* const pending_salt{std::get_if<uint64_t>(&recon_state->second)};
        if (pending_salt == nullptr) return ReconciliationRegisterResult::ALREADY_REGISTERED;
        const uint64_t local_salt{*pending_salt};

        // Use the lower of the two versions. A result below 1 is a protocol violation.
        const uint32_t recon_version{std::min(peer_recon_version, m_recon_version)};
        if (recon_version < 1) return ReconciliationRegisterResult::PROTOCOL_VIOLATION;

        LogPrintLevel(BCLog::TXRECONCILIATION, BCLog::Level::Debug, "Register peer=%d (inbound=%i)\n",
                      peer_id, is_peer_inbound);

        // Replace the pending salt with the full state. ComputeSalt() ignores argument order.
        const uint256 full_salt{ComputeSalt(local_salt, remote_salt)};
        recon_state->second = TxReconciliationState(!is_peer_inbound, full_salt.GetUint64(0), full_salt.GetUint64(1));
        return ReconciliationRegisterResult::SUCCESS;
    }

    /** Drop any state (pending or registered) for peer_id. Logs only if an entry was removed. */
    void ForgetPeer(NodeId peer_id) EXCLUSIVE_LOCKS_REQUIRED(!m_txreconciliation_mutex)
    {
        AssertLockNotHeld(m_txreconciliation_mutex);
        LOCK(m_txreconciliation_mutex);
        if (m_states.erase(peer_id)) {
            LogPrintLevel(BCLog::TXRECONCILIATION, BCLog::Level::Debug, "Forget txreconciliation state of peer=%d\n", peer_id);
        }
    }

    /** True only after a successful RegisterPeer(). A merely pre-registered peer yields false. */
    bool IsPeerRegistered(NodeId peer_id) const EXCLUSIVE_LOCKS_REQUIRED(!m_txreconciliation_mutex)
    {
        AssertLockNotHeld(m_txreconciliation_mutex);
        LOCK(m_txreconciliation_mutex);
        const auto recon_state = m_states.find(peer_id);
        return (recon_state != m_states.end() &&
                std::holds_alternative<TxReconciliationState>(recon_state->second));
    }
};
/** Create the tracker. All state and locking lives in the private Impl. */
TxReconciliationTracker::TxReconciliationTracker(uint32_t recon_version)
    : m_impl{std::make_unique<TxReconciliationTracker::Impl>(recon_version)}
{
}

/**
 * Defaulted out of line, after Impl is fully defined, so that the std::unique_ptr<Impl>
 * member can be destroyed here. (Presumably the header only forward-declares Impl.)
 */
TxReconciliationTracker::~TxReconciliationTracker() = default;

/** Forwards to Impl::PreRegisterPeer. */
uint64_t TxReconciliationTracker::PreRegisterPeer(NodeId peer_id)
{
    auto& impl{*m_impl};
    return impl.PreRegisterPeer(peer_id);
}

/** Forwards to Impl::RegisterPeer. */
ReconciliationRegisterResult TxReconciliationTracker::RegisterPeer(NodeId peer_id, bool is_peer_inbound,
                                                                   uint32_t peer_recon_version, uint64_t remote_salt)
{
    auto& impl{*m_impl};
    return impl.RegisterPeer(peer_id, is_peer_inbound, peer_recon_version, remote_salt);
}

/** Forwards to Impl::ForgetPeer. */
void TxReconciliationTracker::ForgetPeer(NodeId peer_id)
{
    auto& impl{*m_impl};
    impl.ForgetPeer(peer_id);
}

/** Forwards to Impl::IsPeerRegistered. */
bool TxReconciliationTracker::IsPeerRegistered(NodeId peer_id) const
{
    const auto& impl{*m_impl};
    return impl.IsPeerRegistered(peer_id);
}