#include <crypto/common.h>
#include <crypto/sha256.h>
#include <crypto/siphash.h>
#include <primitives/transaction.h>
#include <test/fuzz/fuzz.h>
#include <txrequest.h>
#include <bitset>
#include <cstdint>
#include <queue>
#include <vector>
namespace {
//! Number of distinct transaction hashes the fuzz input can refer to.
constexpr int MAX_TXHASHES{16};
//! Number of distinct peers the fuzz input can refer to.
constexpr int MAX_PEERS{16};
//! Transaction hashes used by this test; populated by the Initializer below.
uint256 TXHASHES[MAX_TXHASHES];
//! Precomputed time offsets, populated by the Initializer below. The table has 256 entries so that
//! a single input byte is always a valid index.
std::chrono::microseconds DELAYS[256];
//! Fills TXHASHES and DELAYS once, during static initialization (via g_initializer).
struct Initializer
{
    Initializer()
    {
        // TXHASHES[n] is the SHA256 of the single byte n.
        for (uint8_t idx = 0; idx < MAX_TXHASHES; ++idx) {
            CSHA256().Write(&idx, 1).Finalize(TXHASHES[idx].begin());
        }

        // Entries 0..15: exactly `pos` microseconds.
        for (int pos = 0; pos < 16; ++pos) {
            DELAYS[pos] = std::chrono::microseconds{pos};
        }

        // Entries 16..127: each adds a pseudo-random step to the previous entry. The step's bit
        // width grows with the index, so the values grow roughly exponentially.
        for (int pos = 16; pos < 128; ++pos) {
            const int bits = (2 * (pos - 10)) / 9;
            const uint64_t step = (CSipHasher(0, 0).Write(pos).Finalize() >> (64 - bits)) + 1;
            DELAYS[pos] = DELAYS[pos - 1] + std::chrono::microseconds{step};
        }

        // Entries 128..255: negated mirror of entries 127..0.
        for (int pos = 128; pos < 256; ++pos) {
            DELAYS[pos] = -DELAYS[255 - pos];
        }
    }
} g_initializer;
/**
 * Naive reference model of TxRequestTracker.
 *
 * Every operation is applied both to a simple per-(txhash, peer) state table and to the real
 * TxRequestTracker (m_tracker). The observable results of the two are then asserted to be equal.
 */
class Tester
{
    //! The TxRequestTracker implementation under test.
    TxRequestTracker m_tracker;

    //! State of one txhash/peer combination in the naive model. CANDIDATE has no sub-states here;
    //! whether a candidate is ready is decided by comparing its m_time with m_now.
    enum class State {
        NOTHING, //!< No announcement for this txhash/peer combination.
        CANDIDATE,
        REQUESTED,
        COMPLETED,
    };

    //! Incremented each time a new CANDIDATE is created; used to order GetRequestable results.
    uint64_t m_current_sequence{0};

    //! Min-heap of future reqtimes/exptimes, so AdvanceToEvent can jump straight to the next one.
    std::priority_queue<std::chrono::microseconds, std::vector<std::chrono::microseconds>,
                        std::greater<std::chrono::microseconds>> m_events;

    //! Naive-model record of one txhash/peer combination.
    //! Fields other than m_state are only meaningful while m_state != NOTHING.
    struct Announcement
    {
        //! reqtime while CANDIDATE, exptime while REQUESTED.
        std::chrono::microseconds m_time{};
        //! Creation order, taken from m_current_sequence.
        uint64_t m_sequence{0};
        State m_state{State::NOTHING};
        bool m_preferred{false};
        bool m_is_wtxid{false};
        //! Priority precomputed via TxRequestTracker::ComputePriority; the highest one wins in GetSelected.
        uint64_t m_priority{0};
    };

    //! Model state for every txhash/peer combination, indexed [txhash][peer].
    Announcement m_announcements[MAX_TXHASHES][MAX_PEERS];

    //! Current model time; AdvanceTime may move it forward or backward.
    std::chrono::microseconds m_now{244466666};

    //! If every announcement for `txhash` is COMPLETED (and at least one exists), reset all of them
    //! to NOTHING. Does nothing if any announcement is CANDIDATE/REQUESTED, or if none exist.
    void Cleanup(int txhash)
    {
        bool all_nothing = true;
        for (int peer = 0; peer < MAX_PEERS; ++peer) {
            const Announcement& ann = m_announcements[txhash][peer];
            if (ann.m_state != State::NOTHING) {
                if (ann.m_state != State::COMPLETED) return;
                all_nothing = false;
            }
        }
        if (all_nothing) return;
        for (int peer = 0; peer < MAX_PEERS; ++peer) {
            m_announcements[txhash][peer].m_state = State::NOTHING;
        }
    }

    //! Return the peer that should be asked for `txhash` now, or -1 if there is none.
    //! Returns -1 whenever some announcement for `txhash` is REQUESTED. Otherwise, among CANDIDATEs
    //! whose time has been reached, picks the highest m_priority; on ties the lowest peer index wins.
    int GetSelected(int txhash) const
    {
        int ret = -1;
        uint64_t ret_priority = 0;
        for (int peer = 0; peer < MAX_PEERS; ++peer) {
            const Announcement& ann = m_announcements[txhash][peer];
            // An outstanding request blocks selecting anyone else.
            if (ann.m_state == State::REQUESTED) return -1;
            if (ann.m_state == State::CANDIDATE && ann.m_time <= m_now) {
                if (ret == -1 || ann.m_priority > ret_priority) {
                    ret = peer;
                    ret_priority = ann.m_priority;
                }
            }
        }
        return ret;
    }

public:
    // NOTE(review): the `true` argument presumably selects a deterministic tracker; confirm in txrequest.h.
    Tester() : m_tracker(true) {}

    //! Current model time.
    std::chrono::microseconds Now() const { return m_now; }

    //! Shift the model clock by `offset` (which may be negative), and drop events that are no longer in the future.
    void AdvanceTime(std::chrono::microseconds offset)
    {
        m_now += offset;
        while (!m_events.empty() && m_events.top() <= m_now) m_events.pop();
    }

    //! Jump the clock to the earliest pending event after m_now (if any), consuming that event.
    void AdvanceToEvent()
    {
        while (!m_events.empty() && m_events.top() <= m_now) m_events.pop();
        if (!m_events.empty()) {
            m_now = m_events.top();
            m_events.pop();
        }
    }

    //! Wipe every announcement from `peer` in the model, then forward the call to the tracker.
    void DisconnectedPeer(int peer)
    {
        for (int txhash = 0; txhash < MAX_TXHASHES; ++txhash) {
            if (m_announcements[txhash][peer].m_state != State::NOTHING) {
                m_announcements[txhash][peer].m_state = State::NOTHING;
                Cleanup(txhash);
            }
        }
        m_tracker.DisconnectedPeer(peer);
    }

    //! Wipe every announcement for `txhash` in the model, then forward the call to the tracker.
    void ForgetTxHash(int txhash)
    {
        for (int peer = 0; peer < MAX_PEERS; ++peer) {
            m_announcements[txhash][peer].m_state = State::NOTHING;
        }
        Cleanup(txhash);
        m_tracker.ForgetTxHash(TXHASHES[txhash]);
    }

    //! Record an announcement. In the model, a new CANDIDATE is created only if the txhash/peer slot is
    //! NOTHING; otherwise the model is unchanged. The call is always forwarded to the tracker.
    void ReceivedInv(int peer, int txhash, bool is_wtxid, bool preferred, std::chrono::microseconds reqtime)
    {
        Announcement& ann = m_announcements[txhash][peer];
        if (ann.m_state == State::NOTHING) {
            ann.m_preferred = preferred;
            ann.m_state = State::CANDIDATE;
            ann.m_time = reqtime;
            ann.m_is_wtxid = is_wtxid;
            ann.m_sequence = m_current_sequence++;
            ann.m_priority = m_tracker.ComputePriority(TXHASHES[txhash], peer, ann.m_preferred);
            // Let AdvanceToEvent jump to the moment this candidate becomes ready.
            if (reqtime > m_now) m_events.push(reqtime);
        }
        auto gtxid = is_wtxid ? GenTxid{Wtxid::FromUint256(TXHASHES[txhash])} : GenTxid{Txid::FromUint256(TXHASHES[txhash])};
        m_tracker.ReceivedInv(peer, gtxid, preferred, reqtime);
    }

    //! Mark `txhash` as requested from `peer` until `exptime`. In the model this only applies if the
    //! slot is a CANDIDATE; any other REQUESTED announcement for the same txhash becomes COMPLETED.
    //! This keeps at most one REQUESTED per txhash.
    void RequestedTx(int peer, int txhash, std::chrono::microseconds exptime)
    {
        if (m_announcements[txhash][peer].m_state == State::CANDIDATE) {
            for (int peer2 = 0; peer2 < MAX_PEERS; ++peer2) {
                if (m_announcements[txhash][peer2].m_state == State::REQUESTED) {
                    m_announcements[txhash][peer2].m_state = State::COMPLETED;
                }
            }
            m_announcements[txhash][peer].m_state = State::REQUESTED;
            m_announcements[txhash][peer].m_time = exptime;
        }
        // Let AdvanceToEvent jump to the moment this request expires.
        if (exptime > m_now) m_events.push(exptime);
        m_tracker.RequestedTx(peer, TXHASHES[txhash], exptime);
    }

    //! Record a response: any existing model announcement for txhash/peer becomes COMPLETED.
    //! The call is always forwarded to the tracker.
    void ReceivedResponse(int peer, int txhash)
    {
        if (m_announcements[txhash][peer].m_state != State::NOTHING) {
            m_announcements[txhash][peer].m_state = State::COMPLETED;
            Cleanup(txhash);
        }
        m_tracker.ReceivedResponse(peer, TXHASHES[txhash]);
    }

    //! Compute the expected GetRequestable(peer) answer from the model and compare it with the tracker.
    //! The compared values are both the requestable list and the list of expired requests.
    void GetRequestable(int peer)
    {
        //! Expected results as (sequence number, txhash, is_wtxid) tuples.
        std::vector<std::tuple<uint64_t, int, bool>> result;
        std::vector<std::pair<NodeId, GenTxid>> expected_expired;
        for (int txhash = 0; txhash < MAX_TXHASHES; ++txhash) {
            // Expire the (at most one, see RequestedTx) REQUESTED announcement whose exptime has passed.
            for (int peer2 = 0; peer2 < MAX_PEERS; ++peer2) {
                Announcement& ann2 = m_announcements[txhash][peer2];
                if (ann2.m_state == State::REQUESTED && ann2.m_time <= m_now) {
                    auto gtxid = ann2.m_is_wtxid ? GenTxid{Wtxid::FromUint256(TXHASHES[txhash])} : GenTxid{Txid::FromUint256(TXHASHES[txhash])};
                    expected_expired.emplace_back(peer2, gtxid);
                    ann2.m_state = State::COMPLETED;
                    break;
                }
            }
            Cleanup(txhash);
            // A txhash is requestable from `peer` if its CANDIDATE is the currently selected one.
            const Announcement& ann = m_announcements[txhash][peer];
            if (ann.m_state == State::CANDIDATE && GetSelected(txhash) == peer) {
                result.emplace_back(ann.m_sequence, txhash, ann.m_is_wtxid);
            }
        }
        // Expected order is by announcement sequence number.
        std::sort(result.begin(), result.end());
        std::sort(expected_expired.begin(), expected_expired.end());

        std::vector<std::pair<NodeId, GenTxid>> expired;
        const auto actual = m_tracker.GetRequestable(peer, m_now, &expired);
        std::sort(expired.begin(), expired.end());
        assert(expired == expected_expired);

        m_tracker.PostGetRequestableSanityCheck(m_now);
        assert(result.size() == actual.size());
        for (size_t pos = 0; pos < actual.size(); ++pos) {
            assert(TXHASHES[std::get<1>(result[pos])] == actual[pos].ToUint256());
            assert(std::get<2>(result[pos]) == actual[pos].IsWtxid());
        }
    }

    //! Final consistency check of the tracker's query functions against the naive model.
    void Check()
    {
        // Per txhash: GetCandidatePeers must report exactly the peers whose announcement is CANDIDATE
        // or REQUESTED. This depends only on the txhash, so it is checked once per txhash rather than
        // once per (peer, txhash) pair.
        for (int txhash = 0; txhash < MAX_TXHASHES; ++txhash) {
            std::bitset<MAX_PEERS> expected_announcers;
            for (int announcer = 0; announcer < MAX_PEERS; ++announcer) {
                const State state = m_announcements[txhash][announcer].m_state;
                if (state == State::CANDIDATE || state == State::REQUESTED) {
                    expected_announcers[announcer] = true;
                }
            }
            std::vector<NodeId> candidate_peers;
            m_tracker.GetCandidatePeers(TXHASHES[txhash], candidate_peers);
            assert(expected_announcers.count() == candidate_peers.size());
            for (const NodeId candidate : candidate_peers) {
                // std::bitset::operator[] is not bounds-checked. Reject out-of-range ids explicitly
                // instead of invoking undefined behaviour.
                assert(candidate >= 0 && candidate < MAX_PEERS);
                assert(expected_announcers[candidate]);
            }
        }

        // Per peer: compare Count / CountInFlight / CountCandidates with the model.
        size_t total = 0;
        for (int peer = 0; peer < MAX_PEERS; ++peer) {
            size_t tracked = 0;
            size_t inflight = 0;
            size_t candidates = 0;
            for (int txhash = 0; txhash < MAX_TXHASHES; ++txhash) {
                const State state = m_announcements[txhash][peer].m_state;
                tracked += state != State::NOTHING;
                inflight += state == State::REQUESTED;
                candidates += state == State::CANDIDATE;
            }
            assert(m_tracker.Count(peer) == tracked);
            assert(m_tracker.CountInFlight(peer) == inflight);
            assert(m_tracker.CountCandidates(peer) == candidates);
            total += tracked;
        }
        assert(m_tracker.Size() == total);

        // Internal consistency check of the tracker itself.
        m_tracker.SanityCheck();
    }
};
}
FUZZ_TARGET(txrequest)
{
    // Wraps a TxRequestTracker together with a naive reference model of it.
    Tester tester;

    auto it = buffer.begin();
    // Consume one input byte. Once the input runs out, missing operands default to 0.
    const auto next_byte = [&]() -> int { return it == buffer.end() ? 0 : *(it++); };

    // The input is a sequence of commands, each followed by its operand bytes. Operands are read in
    // separate statements so the byte-consumption order is well defined.
    while (it != buffer.end()) {
        const int cmd = *(it++) % 11;
        switch (cmd) {
        case 0: // Jump to the next pending reqtime/exptime event.
            tester.AdvanceToEvent();
            break;
        case 1: { // Move the clock by a precomputed (possibly negative) delay.
            const int delaynum = next_byte();
            tester.AdvanceTime(DELAYS[delaynum]);
            break;
        }
        case 2: { // Ask which transactions should be requested from a peer.
            const int peer = next_byte() % MAX_PEERS;
            tester.GetRequestable(peer);
            break;
        }
        case 3: { // Peer disconnected.
            const int peer = next_byte() % MAX_PEERS;
            tester.DisconnectedPeer(peer);
            break;
        }
        case 4: { // Transaction hash no longer needed.
            const int txidnum = next_byte();
            tester.ForgetTxHash(txidnum % MAX_TXHASHES);
            break;
        }
        case 5:   // Undelayed announcement; `preferred` is cmd & 1 (odd cmd = preferred).
        case 6: {
            const int peer = next_byte() % MAX_PEERS;
            const int txidnum = next_byte();
            tester.ReceivedInv(peer, txidnum % MAX_TXHASHES, (txidnum / MAX_TXHASHES) & 1, cmd & 1,
                               std::chrono::microseconds::min());
            break;
        }
        case 7:   // Delayed announcement; `preferred` is cmd & 1 (odd cmd = preferred).
        case 8: {
            const int peer = next_byte() % MAX_PEERS;
            const int txidnum = next_byte();
            const int delaynum = next_byte();
            tester.ReceivedInv(peer, txidnum % MAX_TXHASHES, (txidnum / MAX_TXHASHES) & 1, cmd & 1,
                               tester.Now() + DELAYS[delaynum]);
            break;
        }
        case 9: { // Transaction requested from a peer, expiring after a delay.
            const int peer = next_byte() % MAX_PEERS;
            const int txidnum = next_byte();
            const int delaynum = next_byte();
            tester.RequestedTx(peer, txidnum % MAX_TXHASHES, tester.Now() + DELAYS[delaynum]);
            break;
        }
        case 10: { // Response received from a peer.
            const int peer = next_byte() % MAX_PEERS;
            const int txidnum = next_byte();
            tester.ReceivedResponse(peer, txidnum % MAX_TXHASHES);
            break;
        }
        default:
            assert(false);
        }
    }
    tester.Check();
}