#include <arith_uint256.h>
#include <policy/feerate.h>
#include <util/feefrac.h>
#include <test/fuzz/FuzzedDataProvider.h>
#include <test/fuzz/fuzz.h>
#include <test/fuzz/util.h>
#include <compare>
#include <cmath>
#include <cstdint>
#include <iostream>
namespace {

/** 2^63: the magnitude of std::numeric_limits<int64_t>::min(), one more than INT64_MAX. */
const auto MAX_ABS_INT64 = arith_uint256{1} << 63;

/** Return |x| exactly as a 256-bit unsigned value; valid for every int64_t including INT64_MIN. */
arith_uint256 Abs256(int64_t x)
{
    // Negating INT64_MIN would overflow int64_t, so answer it directly.
    if (x == std::numeric_limits<int64_t>::min()) return MAX_ABS_INT64;
    const int64_t magnitude{x < 0 ? -x : x};
    return arith_uint256{static_cast<uint64_t>(magnitude)};
}

/**
 * Return the absolute value of the signed number x.first * 2^32 + x.second.
 * When the high part is negative, the (non-negative) low part shrinks the magnitude;
 * otherwise it grows it.
 */
arith_uint256 Abs256(std::pair<int64_t, uint32_t> x)
{
    const auto [high, low] = x;
    const auto high_abs = Abs256(high) << 32;
    const auto low_abs = Abs256(low);
    return high < 0 ? high_abs - low_abs : high_abs + low_abs;
}

/**
 * Exactly compare the products a1*a2 and b1*b2 (each at most 2^126 in magnitude) without
 * overflowing. The signs of the products are compared first; on equal signs the magnitudes
 * decide, with the order reversed when both products are negative.
 */
std::strong_ordering MulCompare(int64_t a1, int64_t a2, int64_t b1, int64_t b2)
{
    // -1, 0 or +1 according to the sign of v.
    const auto signum = [](int64_t v) -> int { return (v > 0) - (v < 0); };
    const int sign_prod_a{signum(a1) * signum(a2)};
    const int sign_prod_b{signum(b1) * signum(b2)};
    if (sign_prod_a != sign_prod_b) return sign_prod_a <=> sign_prod_b;

    const auto mag_a = Abs256(a1) * Abs256(a2);
    const auto mag_b = Abs256(b1) * Abs256(b2);
    // Same sign: a bigger magnitude means a bigger product, unless both products are negative.
    return (sign_prod_a < 0) ? (mag_b <=> mag_a) : (mag_a <=> mag_b);
}

}
FUZZ_TARGET(feefrac)
{
    // Check FeeFrac's feerate and total-order comparisons against an exact 256-bit reference.
    FuzzedDataProvider provider(buffer.data(), buffer.size());

    // Draw a fee, then a size. A zero size forces a zero fee so the pair forms an empty FeeFrac.
    const auto consume_fee_size = [&provider]() {
        int64_t fee = provider.ConsumeIntegral<int64_t>();
        const int32_t size = provider.ConsumeIntegral<int32_t>();
        if (size == 0) fee = 0;
        return std::pair<int64_t, int32_t>{fee, size};
    };

    const auto [f1, s1] = consume_fee_size();
    FeeFrac fr1(f1, s1);
    assert(fr1.IsEmpty() == (s1 == 0));

    const auto [f2, s2] = consume_fee_size();
    FeeFrac fr2(f2, s2);
    assert(fr2.IsEmpty() == (s2 == 0));

    // Feerate-only ordering: f1/s1 vs f2/s2, evaluated by exact cross-multiplication.
    const auto feerate_order = MulCompare(f1, s2, f2, s1);
    assert(FeeRateCompare(fr1, fr2) == feerate_order);
    assert((fr1 << fr2) == std::is_lt(feerate_order));
    assert((fr1 >> fr2) == std::is_gt(feerate_order));

    // Both multiplication backends must reproduce the reference ordering.
    assert((FeeFrac::Mul(f1, s2) <=> FeeFrac::Mul(f2, s1)) == feerate_order);
    assert((FeeFrac::MulFallback(f1, s2) <=> FeeFrac::MulFallback(f2, s1)) == feerate_order);

    // Total order: on a feerate tie, the FeeFrac with the larger size compares as smaller.
    const auto total_order = std::is_neq(feerate_order) ? feerate_order : (s2 <=> s1);
    assert((fr1 <=> fr2) == total_order);
    assert((fr1 < fr2) == std::is_lt(total_order));
    assert((fr1 > fr2) == std::is_gt(total_order));
    assert((fr1 <= fr2) == std::is_lteq(total_order));
    assert((fr1 >= fr2) == std::is_gteq(total_order));
    assert((fr1 == fr2) == std::is_eq(total_order));
    assert((fr1 != fr2) == std::is_neq(total_order));
}
FUZZ_TARGET(feefrac_div_fallback)
{
    // Verify FeeFrac::DivFallback against an exact arith_uint256 computation over all valid inputs.
    FuzzedDataProvider provider(buffer.data(), buffer.size());

    auto num_high = provider.ConsumeIntegral<int64_t>();
    auto num_low = provider.ConsumeIntegral<uint32_t>();
    // The numerator is the signed value num_high * 2^32 + num_low.
    std::pair<int64_t, uint32_t> num{num_high, num_low};
    auto den = provider.ConsumeIntegralInRange<int32_t>(1, std::numeric_limits<int32_t>::max());
    auto round_down = provider.ConsumeBool();

    // den is positive, so the sign of the result follows the numerator (num_low cannot
    // make a negative num_high non-negative, as it is below 2^32).
    bool is_negative = num_high < 0;
    // Exact magnitude of the quotient. Rounding down a negative result, or rounding up a
    // positive one, moves away from zero, i.e. rounds the magnitude up.
    auto num_abs = Abs256(num);
    auto den_abs = Abs256(den);
    auto quot_abs = (is_negative == round_down) ?
        (num_abs + den_abs - 1) / den_abs :
        num_abs / den_abs;

    // Bail out if the exact result is outside the int64_t range [-2^63, 2^63 - 1].
    if ((is_negative && quot_abs > MAX_ABS_INT64) || (!is_negative && quot_abs >= MAX_ABS_INT64)) {
        return;
    }

    auto res = FeeFrac::DivFallback(num, den, round_down);
    assert(res == 0 || (res < 0) == is_negative);
    assert(Abs256(res) == quot_abs);

    // Approximate floating-point cross-check. Round the *quotient*, as feefrac_mul_div does.
    // The numerator is already an integer, so flooring/ceiling it would be a no-op and the
    // rounding direction would not be reflected in `expect`.
    const long double num_approx = num_high * 4294967296.0L + num_low;
    long double expect = round_down ? std::floor(num_approx / den)
                                    : std::ceil(num_approx / den);
    // Expect agreement within ~50 bits of relative precision, +- 1 sat.
    if (expect == 0.0L) {
        assert(res >= -1 && res <= 1);
    } else if (expect > 0.0L) {
        assert(res >= expect * 0.999999999999999L - 1.0L);
        assert(res <= expect * 1.000000000000001L + 1.0L);
    } else {
        assert(res >= expect * 1.000000000000001L - 1.0L);
        assert(res <= expect * 0.999999999999999L + 1.0L);
    }
}
FUZZ_TARGET(feefrac_mul_div)
{
    // Verify FeeFrac::Mul + FeeFrac::Div (plus their fallbacks, EvaluateFeeDown/Up, and
    // CFeeRate::GetFee) against an exact arith_uint256 computation of (mul32 * mul64) / div.
    FuzzedDataProvider provider(buffer.data(), buffer.size());

    auto mul32 = provider.ConsumeIntegral<int32_t>();
    auto mul64 = provider.ConsumeIntegral<int64_t>();
    auto div = provider.ConsumeIntegralInRange<int32_t>(1, std::numeric_limits<int32_t>::max());
    auto round_down = provider.ConsumeBool();

    // div is positive, so the result is negative iff exactly one factor is negative and neither is zero.
    bool is_negative = ((mul32 < 0) && (mul64 > 0)) || ((mul32 > 0) && (mul64 < 0));
    // Exact magnitude of the quotient. Rounding down a negative result, or rounding up a
    // positive one, moves away from zero, i.e. rounds the magnitude up.
    auto prod_abs = Abs256(mul32) * Abs256(mul64);
    auto div_abs = Abs256(div);
    auto quot_abs = (is_negative == round_down) ?
        (prod_abs + div_abs - 1) / div_abs :
        prod_abs / div_abs;

    // Bail out if the exact result is outside the int64_t range [-2^63, 2^63 - 1]. For
    // 0 <= mul32 <= div the magnitude cannot exceed |mul64|, so overflow requires mul32 < 0 or mul32 > div.
    if ((is_negative && quot_abs > MAX_ABS_INT64) || (!is_negative && quot_abs >= MAX_ABS_INT64)) {
        assert(mul32 < 0 || mul32 > div);
        return;
    }

    auto res = FeeFrac::Div(FeeFrac::Mul(mul64, mul32), div, round_down);
    assert(res == 0 || (res < 0) == is_negative);
    assert(Abs256(res) == quot_abs);

    // The fallback implementations must agree exactly with the primary ones.
    auto res_fallback = FeeFrac::DivFallback(FeeFrac::MulFallback(mul64, mul32), div, round_down);
    assert(res == res_fallback);

    // Approximate floating-point cross-check: within ~50 bits of relative precision, +- 1 sat.
    long double expect = round_down ? std::floor(static_cast<long double>(mul32) * mul64 / div)
                                    : std::ceil(static_cast<long double>(mul32) * mul64 / div);
    if (expect == 0.0L) {
        assert(res >= -1 && res <= 1);
    } else if (expect > 0.0L) {
        assert(res >= expect * 0.999999999999999L - 1.0L);
        assert(res <= expect * 1.000000000000001L + 1.0L);
    } else {
        assert(res >= expect * 1.000000000000001L - 1.0L);
        assert(res <= expect * 0.999999999999999L + 1.0L);
    }

    // EvaluateFeeDown/Up are only exercised with a non-negative size (mul32).
    if (mul32 >= 0) {
        auto res_fee = round_down ?
            FeeFrac{mul64, div}.EvaluateFeeDown(mul32) :
            FeeFrac{mul64, div}.EvaluateFeeUp(mul32);
        assert(res == res_fee);

        // Approximate comparison with CFeeRate::GetFee, restricted to inputs within 1/1000 of the
        // int64_t range. The tolerance grows with mul32 / 1000, presumably to absorb CFeeRate's
        // per-1000-unit rounding. NOTE(review): confirm against CFeeRate's implementation.
        if (mul64 < std::numeric_limits<int64_t>::max() / 1000 &&
            mul64 > std::numeric_limits<int64_t>::min() / 1000 &&
            quot_abs < arith_uint256{std::numeric_limits<int64_t>::max() / 1000}) {
            // div is in [1, INT32_MAX], so the conversion to uint32_t is value-preserving.
            CFeeRate feerate(mul64, static_cast<uint32_t>(div));
            CAmount feerate_fee{feerate.GetFee(mul32)};
            auto allowed_gap = static_cast<int64_t>(mul32 / 1000 + 3 + round_down);
            assert(feerate_fee - res_fee >= -allowed_gap);
            assert(feerate_fee - res_fee <= allowed_gap);
        }
    }
}