#pragma once
#include "fields/kb31_t.cuh"
#include "poseidon2/poseidon2_kb31_16.cuh"
#include "poseidon2/poseidon2_bn254_3.cuh"
#include "poseidon2/poseidon2.cuh"
#include "fields/kb31_extension_t.cuh"
#include "fields/bn254_t.cuh"
// C-linkage entry points that are only declared in this header; their definitions are not
// visible here.
// NOTE(review): each presumably returns the address of the matching grind kernel so that the
// host side can launch it — confirm against the definitions.
extern "C" void* grind_koala_bear();
extern "C" void* grind_multi_field32();
// Duplex-sponge challenger over KoalaBear elements (kb31_t).
//
// The object is a non-owning view over caller-provided storage:
//   * sponge_state  : WIDTH elements, the permutation state
//   * input_buffer  : up to RATE observed elements not yet absorbed
//   * output_buffer : up to RATE squeezed elements not yet sampled
//   * buffer_sizes  : [0] = number of pending inputs, [1] = number of remaining outputs
//
// `load` lays the three element regions out back to back in one array of NUM_ELEMENTS
// (= WIDTH + 2 * RATE) elements.
class DuplexChallenger {
    static constexpr const int WIDTH = poseidon2_kb31_16::KoalaBear::WIDTH;
    static constexpr const int RATE = poseidon2_kb31_16::constants::RATE;

    // NOTE(review): member order is kept exactly as before in case the object layout is
    // shared with host code — confirm before reordering.
    kb31_t* sponge_state;
    kb31_t* input_buffer;
    size_t* buffer_sizes;
    kb31_t* output_buffer;

    // Absorbs the pending inputs into the sponge state, applies the permutation, and refills
    // the output buffer with the first RATE elements of the new state.
    __device__ void duplexing() {
        // The input buffer can never hold more than RATE elements.
        assert(buffer_sizes[0] <= RATE);

        // Overwrite the leading state elements with the pending inputs, then clear the buffer.
        const size_t pending = buffer_sizes[0];
        for (size_t i = 0; i < pending; i++) {
            sponge_state[i] = input_buffer[i];
        }
        buffer_sizes[0] = 0;

        // Permute into a WIDTH-sized scratch array. The output buffer only has RATE slots
        // (see `load` / NUM_ELEMENTS), so it must not be used as the permutation target:
        // doing so writes WIDTH elements through a RATE-sized region.
        kb31_t next_state[WIDTH];
        for (size_t i = 0; i < WIDTH; i++) {
            next_state[i] = kb31_t::zero();
        }
        poseidon2::KoalaBearHasher hasher;
        hasher.permute(sponge_state, next_state);

        // The permuted state becomes the new sponge state; its first RATE elements are the
        // values handed out by `sample`.
        for (size_t i = 0; i < WIDTH; i++) {
            sponge_state[i] = next_state[i];
            if (i < RATE) {
                output_buffer[i] = next_state[i];
            }
        }
        buffer_sizes[1] = RATE;
    }

  public:
    // Number of kb31_t elements backing one challenger (state + input + output regions).
    static constexpr const size_t NUM_ELEMENTS = WIDTH + 2 * RATE;

    // Returns the sponge-state element at `idx`, wrapping modulo 16.
    // NOTE(review): 16 is presumably WIDTH — confirm before replacing the literal.
    __device__ __forceinline__ kb31_t getVal(size_t idx) { return sponge_state[idx % 16]; }

    // Builds a challenger view over `shared` (NUM_ELEMENTS elements laid out as
    // [state | input | output]) and `buffer_sizes` (two counters). Nothing is copied and the
    // receiver object is not read.
    __device__ __forceinline__ DuplexChallenger load(kb31_t* shared, size_t* buffer_sizes) {
        DuplexChallenger challenger;
        challenger.sponge_state = shared;
        challenger.input_buffer = shared + WIDTH;
        challenger.output_buffer = shared + WIDTH + RATE;
        challenger.buffer_sizes = buffer_sizes;
        return challenger;
    }

    // Observes one field element: any unread outputs are discarded, the value is appended to
    // the input buffer, and the sponge is duplexed once RATE inputs are pending.
    __device__ __forceinline__ void observe(kb31_t* value) {
        // Clear the output buffer.
        buffer_sizes[1] = 0;
        // Push the value onto the input buffer.
        buffer_sizes[0] += 1;
        input_buffer[buffer_sizes[0] - 1] = *value;
        if (buffer_sizes[0] == RATE) {
            duplexing();
        }
    }

    // Observes every base-field coordinate of an extension element, in index order.
    __device__ __forceinline__ void observe_ext(kb31_extension_t* value) {
#pragma unroll
        for (size_t i = 0; i < kb31_extension_t::D; i++) {
            observe(&value->value[i]);
        }
    }

    // Samples one field element. Duplexes first when inputs are pending or no outputs remain,
    // then pops the last element of the output buffer.
    __device__ __forceinline__ kb31_t sample() {
        if (buffer_sizes[0] != 0 || buffer_sizes[1] == 0) {
            duplexing();
        }
        // Pop the last element of the buffer.
        kb31_t result = output_buffer[buffer_sizes[1] - 1];
        buffer_sizes[1] -= 1;
        return result;
    }

    // Samples an extension element by sampling each of its D coordinates in index order.
    __device__ __forceinline__ kb31_extension_t sample_ext() {
        kb31_extension_t result;
        for (size_t i = 0; i < kb31_extension_t::D; i++) {
            result.value[i] = sample();
        }
        return result;
    }

    // Samples a field element and returns the low `bits` bits of its canonical value.
    // `bits` must be smaller than the bit width of size_t.
    __device__ __forceinline__ size_t sample_bits(size_t bits) {
        kb31_t rand_f = sample();
        // Equivalent to "as_canonical_u32" in the Rust implementation.
        size_t rand_usize = rand_f.as_canonical_u32();
        // Build the mask in size_t: shifting the int literal `1` overflows for bits >= 31.
        const size_t mask = (size_t(1) << bits) - 1;
        return rand_usize & mask;
    }

    // Observes `witness` and reports whether the next `bits` sampled bits are all zero.
    // Mutates the challenger state.
    __device__ __forceinline__ bool check_witness(size_t bits, kb31_t* witness) {
        observe(witness);
        return sample_bits(bits) == 0;
    }

    // Searches the candidates 0 .. n-1 for a witness accepted by `check_witness`.
    //
    // Every thread starts at its global thread index and advances by the total number of
    // launched threads. Each candidate is checked on a private copy of the challenger, so
    // this challenger's state is left untouched.
    //
    //   bits       : number of low bits that must be zero
    //   result     : result[0] receives an accepted witness
    //   found_flag : polled before every candidate; set to true once a witness is stored
    //   n          : exclusive upper bound of the candidate range
    //
    // Contains a __syncthreads(), so all threads of the block must call it.
    // NOTE(review): candidates are built with kb31_t((int)i), so n is presumably expected to
    // fit in an int — confirm with the caller.
    __device__ __forceinline__ void
    grind(size_t bits, kb31_t* result, volatile bool* found_flag, size_t n) {
        // Do the index arithmetic in size_t so large launches cannot wrap at 32 bits.
        const size_t first = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
        const size_t stride = (size_t)blockDim.x * gridDim.x;

        const size_t original_buffer_size = buffer_sizes[0];
        const size_t original_output_buffer_size = buffer_sizes[1];

        // One snapshot of the challenger per block, filled by thread 0.
        __shared__ kb31_t challenger_state[NUM_ELEMENTS];
        if (threadIdx.x == 0) {
            for (size_t j = 0; j < WIDTH; j++) {
                challenger_state[j] = sponge_state[j];
            }
            for (size_t j = 0; j < RATE; j++) {
                challenger_state[WIDTH + j] = input_buffer[j];
                challenger_state[WIDTH + RATE + j] = output_buffer[j];
            }
        }
        // Ensure all threads see the shared memory initialized.
        __syncthreads();

        // Per-thread working copy, reset from the snapshot before every candidate.
        kb31_t local_state[NUM_ELEMENTS];
        size_t local_buffer_sizes[2];

        for (size_t i = first; i < n && !*found_flag; i += stride) {
            local_buffer_sizes[0] = original_buffer_size;
            local_buffer_sizes[1] = original_output_buffer_size;
            for (size_t j = 0; j < NUM_ELEMENTS; j++) {
                local_state[j] = challenger_state[j];
            }

            DuplexChallenger temp_challenger = load(local_state, local_buffer_sizes);
            kb31_t witness = kb31_t((int)i);
            if (temp_challenger.check_witness(bits, &witness)) {
                result[0] = witness;
                // Publish the witness before raising the flag, so a thread that observes the
                // flag also observes the result.
                __threadfence();
                // Plain volatile store: the flag is a 1-byte bool, so a 4-byte atomicExch
                // through an int* would touch bytes outside the object.
                *found_flag = true;
                return;
            }
        }
    }
};
// Challenger that observes and samples KoalaBear elements (kb31_t) while running the sponge
// over BN254 elements (bn254_t).
//
// The object is a non-owning view over caller-provided storage:
//   * sponge_state  : WIDTH bn254_t elements, the permutation state
//   * input_buffer  : observed kb31_t elements not yet absorbed
//   * output_buffer : squeezed kb31_t elements not yet sampled
//   * buffer_sizes  : [0] = number of pending inputs
//                     [1] = number of remaining outputs
//                     [2] = number of kb31_t inputs packed into one bn254_t state element
//                     [3] = number of kb31_t outputs taken from one bn254_t state element
//                           (asserted to be 4)
class MultiField32Challenger {
    static constexpr const int WIDTH = poseidon2_bn254_3::Bn254::WIDTH;
    static constexpr const int RATE = poseidon2_bn254_3::constants::RATE;

    // NOTE(review): member order is kept exactly as before in case the object layout is
    // shared with host code — confirm before reordering.
    bn254_t* sponge_state;
    kb31_t* input_buffer;
    size_t* buffer_sizes;
    kb31_t* output_buffer;

    // Packs the pending inputs into the sponge state, applies the permutation, and refills
    // the output buffer from the first RATE elements of the new state.
    __device__ void duplexing() {
        // The output split below is hard-wired to four kb31_t values per bn254_t element.
        assert(buffer_sizes[3] == 4);
        // The pending inputs must fit into the RATE state elements being overwritten.
        assert(buffer_sizes[0] <= buffer_sizes[2] * RATE);

        // Reduce each group of up to buffer_sizes[2] pending inputs into one state element.
        for (size_t i = 0; i < buffer_sizes[0]; i += buffer_sizes[2]) {
            size_t end = min(buffer_sizes[0], i + buffer_sizes[2]);
            bn254_t reduced =
                poseidon2_bn254_3::reduceKoalaBear(input_buffer + i, nullptr, end - i, 0);
            sponge_state[i / buffer_sizes[2]] = reduced;
        }
        // Clear the input buffer.
        buffer_sizes[0] = 0;

        // Apply the permutation to the sponge state, writing into a zeroed scratch state.
        poseidon2::Bn254Hasher hasher;
        bn254_t next_state[WIDTH];
        for (size_t i = 0; i < WIDTH; i++) {
            next_state[i].set_to_zero();
        }
        hasher.permute(sponge_state, next_state);

        // The permuted state becomes the new sponge state; each of its first RATE elements
        // contributes four kb31_t outputs.
        buffer_sizes[1] = RATE * buffer_sizes[3];
        for (size_t i = 0; i < WIDTH; i++) {
            sponge_state[i] = next_state[i];
            if (i >= RATE) {
                continue;
            }
            bn254_t x = next_state[i];
            // NOTE(review): presumably converts x out of its internal (Montgomery)
            // representation so that x[k] yields canonical 32-bit limbs — confirm.
            x.from();
            // Combine the limbs pairwise into 64-bit values and reduce each one modulo
            // 0x7f000001.
            for (int limb = 0; limb < 4; limb++) {
                const uint64_t packed = (uint64_t)(x[2 * limb]) +
                                        (uint64_t(1) << 32) * (uint64_t)(x[2 * limb + 1]);
                const uint32_t reduced = (uint32_t)(packed % 0x7f000001);
                output_buffer[i * 4 + limb] = kb31_t::from_canonical_u32(reduced);
            }
        }
    }

    // Number of KB31 elements for input buffer (WIDTH * num_duplex_elms = 3 * 8 = 24)
    static constexpr const int INPUT_BUFFER_SIZE = 24;
    // Number of KB31 elements for output buffer (WIDTH * num_f_elms = 3 * 4 = 12)
    static constexpr const int OUTPUT_BUFFER_SIZE = 12;
    // Number of buffer size entries
    static constexpr const int NUM_BUFFER_SIZES = 4;

  public:
    // Builds a challenger view over the given storage. Nothing is copied and the receiver
    // object is not read.
    __device__ __forceinline__ MultiField32Challenger load(
        bn254_t* sponge_shared, kb31_t* input_shared, kb31_t* output_shared,
        size_t* buffer_sizes) {
        MultiField32Challenger challenger;
        challenger.sponge_state = sponge_shared;
        challenger.input_buffer = input_shared;
        challenger.output_buffer = output_shared;
        challenger.buffer_sizes = buffer_sizes;
        return challenger;
    }

    // Observes one field element: any unread outputs are discarded, the value is appended to
    // the input buffer, and the sponge is duplexed once buffer_sizes[2] * RATE inputs are
    // pending.
    __device__ __forceinline__ void observe(kb31_t* value) {
        // Clear the output buffer.
        buffer_sizes[1] = 0;
        // Push the value onto the input buffer.
        buffer_sizes[0] += 1;
        input_buffer[buffer_sizes[0] - 1] = *value;
        if (buffer_sizes[0] == buffer_sizes[2] * RATE) {
            duplexing();
        }
    }

    // Observes every base-field coordinate of an extension element, in index order.
    __device__ __forceinline__ void observe_ext(kb31_extension_t* value) {
#pragma unroll
        for (size_t i = 0; i < kb31_extension_t::D; i++) {
            observe(&value->value[i]);
        }
    }

    // Samples one field element. Duplexes first when inputs are pending or no outputs remain,
    // then pops the last element of the output buffer.
    __device__ __forceinline__ kb31_t sample() {
        if (buffer_sizes[0] != 0 || buffer_sizes[1] == 0) {
            duplexing();
        }
        // Pop the last element of the buffer.
        kb31_t result = output_buffer[buffer_sizes[1] - 1];
        buffer_sizes[1] -= 1;
        return result;
    }

    // Samples an extension element by sampling each of its D coordinates in index order.
    __device__ __forceinline__ kb31_extension_t sample_ext() {
        kb31_extension_t result;
        for (size_t i = 0; i < kb31_extension_t::D; i++) {
            result.value[i] = sample();
        }
        return result;
    }

    // Samples a field element and returns the low `bits` bits of its canonical value.
    // `bits` must be smaller than the bit width of size_t.
    __device__ __forceinline__ size_t sample_bits(size_t bits) {
        kb31_t rand_f = sample();
        size_t rand_usize = rand_f.as_canonical_u32();
        // Build the mask in size_t: shifting the int literal `1` overflows for bits >= 31.
        const size_t mask = (size_t(1) << bits) - 1;
        return rand_usize & mask;
    }

    // Observes `witness` and reports whether the next `bits` sampled bits are all zero.
    // Mutates the challenger state.
    __device__ __forceinline__ bool check_witness(size_t bits, kb31_t* witness) {
        observe(witness);
        return sample_bits(bits) == 0;
    }

    // Searches the candidates 0 .. n-1 for a witness accepted by `check_witness`.
    //
    // Every thread starts at its global thread index and advances by the total number of
    // launched threads. Each candidate is checked on a private copy of the challenger, so
    // this challenger's state is left untouched.
    //
    //   bits       : number of low bits that must be zero
    //   result     : result[0] receives an accepted witness
    //   found_flag : polled before every candidate; set to true once a witness is stored
    //   n          : exclusive upper bound of the candidate range
    //
    // Contains a __syncthreads(), so all threads of the block must call it. The snapshot
    // reads INPUT_BUFFER_SIZE / OUTPUT_BUFFER_SIZE elements from the input / output buffers,
    // so both must be at least that large.
    // NOTE(review): candidates are built with kb31_t((int)i), so n is presumably expected to
    // fit in an int — confirm with the caller.
    __device__ __forceinline__ void
    grind(size_t bits, kb31_t* result, volatile bool* found_flag, size_t n) {
        // Do the index arithmetic in size_t so large launches cannot wrap at 32 bits.
        const size_t first = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
        const size_t stride = (size_t)blockDim.x * gridDim.x;

        size_t original_buffer_sizes[NUM_BUFFER_SIZES];
        for (size_t j = 0; j < NUM_BUFFER_SIZES; j++) {
            original_buffer_sizes[j] = buffer_sizes[j];
        }

        // One snapshot of the challenger per block, filled by thread 0.
        __shared__ bn254_t shared_sponge_state[WIDTH];
        __shared__ kb31_t shared_input_buffer[INPUT_BUFFER_SIZE];
        __shared__ kb31_t shared_output_buffer[OUTPUT_BUFFER_SIZE];
        if (threadIdx.x == 0) {
            for (size_t j = 0; j < WIDTH; j++) {
                shared_sponge_state[j] = sponge_state[j];
            }
            for (size_t j = 0; j < INPUT_BUFFER_SIZE; j++) {
                shared_input_buffer[j] = input_buffer[j];
            }
            for (size_t j = 0; j < OUTPUT_BUFFER_SIZE; j++) {
                shared_output_buffer[j] = output_buffer[j];
            }
        }
        // Ensure all threads see the shared memory initialized.
        __syncthreads();

        // Per-thread working copy, reset from the snapshot before every candidate.
        bn254_t local_sponge_state[WIDTH];
        kb31_t local_input_buffer[INPUT_BUFFER_SIZE];
        kb31_t local_output_buffer[OUTPUT_BUFFER_SIZE];
        size_t local_buffer_sizes[NUM_BUFFER_SIZES];

        for (size_t i = first; i < n && !*found_flag; i += stride) {
            for (size_t j = 0; j < NUM_BUFFER_SIZES; j++) {
                local_buffer_sizes[j] = original_buffer_sizes[j];
            }
            for (size_t j = 0; j < WIDTH; j++) {
                local_sponge_state[j] = shared_sponge_state[j];
            }
            for (size_t j = 0; j < INPUT_BUFFER_SIZE; j++) {
                local_input_buffer[j] = shared_input_buffer[j];
            }
            for (size_t j = 0; j < OUTPUT_BUFFER_SIZE; j++) {
                local_output_buffer[j] = shared_output_buffer[j];
            }

            MultiField32Challenger temp_challenger = load(
                local_sponge_state, local_input_buffer, local_output_buffer, local_buffer_sizes);
            kb31_t witness = kb31_t((int)i);
            if (temp_challenger.check_witness(bits, &witness)) {
                result[0] = witness;
                // Publish the witness before raising the flag, so a thread that observes the
                // flag also observes the result.
                __threadfence();
                // Plain volatile store: the flag is a 1-byte bool, so a 4-byte atomicExch
                // through an int* would touch bytes outside the object.
                *found_flag = true;
                return;
            }
        }
    }
};