#include "../include/minisketch.h"
#include <string.h>
#include <memory>
#include <vector>
#include <algorithm>
#include <random>
#include <iostream>
#include <thread>
#include "util.h"
uint64_t Combination(uint64_t n, uint64_t k) {
if (n - k < k) k = n - k;
uint64_t ret = 1;
for (uint64_t i = 1; i <= k; ++i) {
ret = (ret * n) / i;
--n;
}
return ret;
}
void TestAll(int bits, int impl, int count, uint32_t threadid, uint32_t threads, std::vector<uint64_t>& ret) {
minisketch* state = minisketch_create(bits, impl, count);
if (!state) return;
for (uint64_t x = threadid; (x >> (bits * count)) == 0; x += threads) {
unsigned char ser[8];
ser[0] = x;
ser[1] = x >> 8;
ser[2] = x >> 16;
ser[3] = x >> 24;
ser[4] = x >> 32;
ser[5] = x >> 40;
ser[6] = x >> 48;
ser[7] = x >> 56;
minisketch_deserialize(state, ser);
uint64_t roots[64];
int num_roots = minisketch_decode(state, 64, roots);
if (num_roots >= 0) {
CHECK(num_roots < 1 || minisketch_decode(state, num_roots - 1, roots) == -1);
minisketch* state2 = minisketch_create(bits, 0, count);
for (int i = 0; i < num_roots; ++i) {
minisketch_add_uint64(state2, roots[i]);
}
unsigned char nser[8] = {0};
minisketch_serialize(state2, nser);
CHECK(memcmp(ser, nser, 8) == 0);
if (num_roots +1 >= (int)ret.size()) ret.resize(num_roots + 2);
ret[num_roots + 1]++;
minisketch_destroy(state2);
} else {
if (ret.size() == 0) ret.resize(1);
ret[0]++;
}
}
minisketch_destroy(state);
}
std::vector<uint64_t> TestAll(int bits, int impl, int count, uint32_t threads) {
std::vector<std::vector<uint64_t>> outputs;
std::vector<std::thread> thread_list;
thread_list.reserve(threads);
outputs.resize(threads);
for (uint32_t i = 0; i < threads; ++i) {
thread_list.emplace_back([=,&outputs](){ TestAll(bits, impl, count, i, threads, outputs[i]); });
}
std::vector<uint64_t> ret;
for (uint32_t i = 0; i < threads; ++i) {
thread_list[i].join();
if (ret.size() < outputs[i].size()) ret.resize(outputs[i].size());
for (size_t j = 0; j < outputs[i].size(); ++j) {
ret[j] += outputs[i][j];
}
}
if (ret.size()) {
for (int i = 1; i <= count + 1; ++i) {
CHECK(ret[i] == Combination((uint64_t(1) << bits) - 1, i - 1));
}
}
return ret;
}
void TestRand(int bits, int impl, int count, int iter) {
std::vector<uint64_t> elems(count);
std::vector<uint64_t> roots(count + 1);
std::random_device rnd;
std::uniform_int_distribution<uint64_t> dist(1, bits == 64 ? -1 : ((uint64_t(1) << bits) - 1));
for (int i = 0; i < iter; ++i) {
bool overfill = iter & 1; minisketch* state = minisketch_create(bits, impl, count);
if (!state) return;
minisketch* basestate = minisketch_create(bits, 0, count);
for (int j = 0; j < count + 3*overfill; ++j) {
uint64_t r = dist(rnd);
if (!overfill) elems[j] = r;
minisketch_add_uint64(state, r);
minisketch_add_uint64(basestate, r);
}
roots.assign(count + 1, 0);
std::vector<unsigned char> data, basedata;
basedata.resize(((count + 1) * bits + 7) / 8);
data.resize(((count + 1) * bits + 7) / 8);
minisketch_serialize(basestate, basedata.data());
minisketch_serialize(state, data.data());
CHECK(data == basedata);
minisketch_deserialize(state, basedata.data());
int num_roots = minisketch_decode(state, count + 1, roots.data());
CHECK(overfill || num_roots >= 0);
CHECK(num_roots < 1 || minisketch_decode(state, num_roots - 1, roots.data()) == -1); if (!overfill) {
std::sort(roots.begin(), roots.begin() + num_roots);
std::sort(elems.begin(), elems.end());
int expected = elems.size();
for (size_t pos = 0; pos < elems.size(); ++pos) {
if (pos + 1 < elems.size() && elems[pos] == elems[pos + 1]) {
expected -= 2;
elems[pos] = 0;
elems[pos + 1] = 0;
++pos;
}
}
CHECK(num_roots == expected);
std::sort(elems.begin(), elems.end());
CHECK(std::equal(roots.begin(), roots.begin() + num_roots, elems.end() - expected));
}
minisketch_destroy(state);
minisketch_destroy(basestate);
}
}
int main(void) {
for (int j = 2; j <= 64; j += 1) {
fprintf(stderr, "%i random tests with %i bits:\n", 500 / j, j);
TestRand(j, 0, 150, 500 / j);
TestRand(j, 1, 150, 500 / j);
TestRand(j, 2, 150, 500 / j);
fprintf(stderr, "%i random tests with %i bits: done\n", 500 / j, j);
}
int counts[65] = {0};
for (int bits = 0; bits < 65; ++bits) {
counts[bits] = 1;
}
for (int weight = 0; weight <= 40; weight += 1) {
for (int bits = 2; bits <= 32; ++bits) {
int count = counts[bits];
while (count < (1 << bits) && count * bits <= weight) {
auto ret = TestAll(bits, 0, count, 4);
auto ret2 = TestAll(bits, 1, count, 4);
auto ret3 = TestAll(bits, 2, count, 4);
CHECK(ret2.empty() || ret == ret2);
CHECK(ret3.empty() || ret == ret3);
fprintf(stderr, "bits=%i count=%i below_bound=[", bits, count);
for (int i = 0; i <= count; ++i) {
if (i) fprintf(stderr, ",");
fprintf(stderr, "%llu", (unsigned long long)ret[i + 1]);
}
fprintf(stderr, "] above_bound=[");
for (int i = count + 1; i + 1 < (int)ret.size(); ++i) {
if (i > count + 1) fprintf(stderr, ",");
fprintf(stderr, "%llu/%llu", (unsigned long long)ret[i + 1], (unsigned long long)Combination((uint64_t(1) << bits) - 1, i));
}
fprintf(stderr, "] nodecode=[%g]\n", (double)ret[0] * pow(0.5, bits * count));
++count;
}
counts[bits] = count;
}
}
return 0;
}