#ifndef rpx_shader_h
#define rpx_shader_h
#include <metal_stdlib>
#include "../felt_u64.h.metal"
#include "common.h.metal"
using namespace metal;
namespace p18446744069414584321 {
// RPX forward-backward (FB) round applied in place to the STATE_WIDTH-element state
// located at shared[local_state_offset ...].
//
// Sequence:
//   1. MDS
//   2. add ROUND_CONSTANTS_0[round]
//   3. S-box x^7
//   4. MDS
//   5. add ROUND_CONSTANTS_1[round]
//   6. inverse S-box x^10540996611094048183
//
// shared             - threadgroup memory holding the state
// round              - selects the row (round * STATE_WIDTH) of both constant tables
// local_state_offset - offset of this thread's state inside `shared`
inline void apply_fb_round(threadgroup Fp* shared, unsigned round, unsigned local_state_offset) {
    threadgroup Fp* state = shared + local_state_offset;
    const unsigned rc_base = round * STATE_WIDTH;

    // Forward half: linear layer, constants, x^7.
    apply_mds_freq(shared, local_state_offset);
#pragma unroll
    for (unsigned k = 0; k < STATE_WIDTH; k++) {
        // Copy into a thread-local value before doing arithmetic on it.
        Fp x = state[k];
        x = x + ROUND_CONSTANTS_0[rc_base + k];
        state[k] = x.pow7();
    }

    // Backward half: linear layer, constants, inverse S-box.
    apply_mds_freq(shared, local_state_offset);
#pragma unroll
    for (unsigned k = 0; k < STATE_WIDTH; k++) {
        Fp x = state[k];
        x = x + ROUND_CONSTANTS_1[rc_base + k];
        state[k] = x.pow10540996611094048183();
    }
}
// RPX extension (E) round applied in place to the state at shared[local_state_offset ...].
//
// Steps:
//   1. Add ROUND_CONSTANTS_0[round] to all 12 elements.
//   2. Reinterpret the state as four cubic-extension elements (consecutive triples
//      0-2, 3-5, 6-8, 9-11).
//   3. Raise each extension element to the 7th power.
//   4. Write the base-field coefficients back.
//
// shared             - threadgroup memory holding the state
// round              - selects the row (round * STATE_WIDTH) of ROUND_CONSTANTS_0
// local_state_offset - offset of this thread's state inside `shared`
inline void apply_e_round(threadgroup Fp* shared, unsigned round, unsigned local_state_offset) {
    threadgroup Fp* state = shared + local_state_offset;
    const unsigned rc_base = round * STATE_WIDTH;

    // Round-constant addition.
#pragma unroll
    for (unsigned k = 0; k < STATE_WIDTH; k++) {
        Fp x = state[k];
        x = x + ROUND_CONSTANTS_0[rc_base + k];
        state[k] = x;
    }

    // All four x^7 results are computed before any coefficient is written back.
    Fq3 e0 = Fq3(state[0], state[1], state[2]).pow(7);
    Fq3 e1 = Fq3(state[3], state[4], state[5]).pow(7);
    Fq3 e2 = Fq3(state[6], state[7], state[8]).pow(7);
    Fq3 e3 = Fq3(state[9], state[10], state[11]).pow(7);

    // Split each extension element back into its three base-field coefficients.
    Fq3Decomposed d0 = e0.decompose();
    state[0] = d0.c0;
    state[1] = d0.c1;
    state[2] = d0.c2;

    Fq3Decomposed d1 = e1.decompose();
    state[3] = d1.c0;
    state[4] = d1.c1;
    state[5] = d1.c2;

    Fq3Decomposed d2 = e2.decompose();
    state[6] = d2.c0;
    state[7] = d2.c1;
    state[8] = d2.c2;

    Fq3Decomposed d3 = e3.decompose();
    state[9] = d3.c0;
    state[10] = d3.c1;
    state[11] = d3.c2;
}
// Closing RPX round: MDS followed by addition of ROUND_CONSTANTS_0[round].
// No S-box is applied.
// Operates in place on the state at shared[local_state_offset ...].
inline void apply_final_round(threadgroup Fp* shared, unsigned round, unsigned local_state_offset) {
    threadgroup Fp* state = shared + local_state_offset;
    const unsigned rc_base = round * STATE_WIDTH;

    apply_mds_freq(shared, local_state_offset);
#pragma unroll
    for (unsigned k = 0; k < STATE_WIDTH; k++) {
        Fp x = state[k];
        x = x + ROUND_CONSTANTS_0[rc_base + k];
        state[k] = x;
    }
}
// Full RPX permutation of the state at shared[local_state_offset ...].
//
// Round schedule:
//   - three FB/E pairs using round-constant rows 0..5 (FB on even rows, E on odd rows)
//   - the final round using row 6
//
// Declared `inline`, like every other helper in this header, for consistency.
inline void rpx_permute(threadgroup Fp* shared, unsigned local_state_offset) {
    apply_fb_round(shared, 0, local_state_offset);
    apply_e_round(shared, 1, local_state_offset);

    apply_fb_round(shared, 2, local_state_offset);
    apply_e_round(shared, 3, local_state_offset);

    apply_fb_round(shared, 4, local_state_offset);
    apply_e_round(shared, 5, local_state_offset);

    apply_final_round(shared, 6, local_state_offset);
}
// TODO: This should be moved to a new header file and made generic over hasher
// Rescue Prime eXtended hash function for 128 bit security: https://eprint.iacr.org/2023/1045.pdf
// Absorbs 8 columns of equal length. Hashes are generated row-wise.
//
// Thread `global_id` performs these steps:
//   1. Load its saved partial state from states[global_id].
//   2. Fill the 8 rate elements with col0[global_id] .. col7[global_id].
//   3. Run the RPX permutation.
//   4. Store the digest to digests[global_id] and the partial state back to states[global_id].
//
// Each thread uses a slice of 2 * STATE_WIDTH elements of `shared`, starting at
// local_id * STATE_WIDTH * 2.
[[ host_name("rpx_256_absorb_columns_and_permute_p18446744069414584321_fp") ]] kernel void
Rpx256AbsorbColumnsAndPermute(constant Fp *col0 [[ buffer(0) ]],
                              constant Fp *col1 [[ buffer(1) ]],
                              constant Fp *col2 [[ buffer(2) ]],
                              constant Fp *col3 [[ buffer(3) ]],
                              constant Fp *col4 [[ buffer(4) ]],
                              constant Fp *col5 [[ buffer(5) ]],
                              constant Fp *col6 [[ buffer(6) ]],
                              constant Fp *col7 [[ buffer(7) ]],
                              device Rp256PartialState *states [[ buffer(8) ]],
                              device Rp256Digest *digests [[ buffer(9) ]],
                              threadgroup Fp *shared [[ threadgroup(0) ]],
                              unsigned global_id [[ thread_position_in_grid ]],
                              unsigned local_id [[ thread_index_in_threadgroup ]]) {
    const unsigned local_state_offset = local_id * STATE_WIDTH * 2;
    threadgroup Fp* state = shared + local_state_offset;
    threadgroup Fp* rate = state + CAPACITY;

    // Restore the persisted partial state into the front of this thread's slice.
    *((threadgroup Rp256PartialState*) state) = states[global_id];

    // Overwrite the rate portion with row `global_id` of the eight input columns.
    rate[0] = col0[global_id];
    rate[1] = col1[global_id];
    rate[2] = col2[global_id];
    rate[3] = col3[global_id];
    rate[4] = col4[global_id];
    rate[5] = col5[global_id];
    rate[6] = col6[global_id];
    rate[7] = col7[global_id];

    rpx_permute(shared, local_state_offset);

    // TODO: add flag to only write to one of these buffers
    // Redundant writes here have a negligible performance cost (<1%).
    digests[global_id] = *((threadgroup Rp256Digest*) rate);
    states[global_id] = *((threadgroup Rp256PartialState*) state);
}
// TODO: This should be moved to a new header file and made generic over hasher
// Rescue Prime eXtended hash function for 128 bit security: https://eprint.iacr.org/2023/1045.pdf
// Absorbs rows of 8 field elements stored contiguously in row-major order:
// row r occupies rows[8 * r .. 8 * r + 7]. Hashes are generated row-wise;
// thread `global_id` hashes row `global_id`.
//
// Loading scheme (for coalesced reads):
//   - The threadgroup cooperatively copies its 8 * tg_size contiguous inputs into the
//     rate portions of all of its hashers' states.
//   - A threadgroup barrier follows before any thread permutes, because threads
//     write into each other's slices of `shared`.
//
// Each hasher's slice is 2 * STATE_WIDTH elements of `shared`.
[[ host_name("rpx_256_absorb_rows_and_permute_p18446744069414584321_fp") ]] kernel void
Rpx256AbsorbRowsAndPermute(constant Fp *rows [[ buffer(0) ]],
                           device Rp256PartialState *states [[ buffer(1) ]],
                           device Rp256Digest *digests [[ buffer(2) ]],
                           threadgroup Fp *shared [[ threadgroup(0) ]],
                           unsigned tg_size [[ threads_per_threadgroup ]],
                           unsigned tg_id [[ threadgroup_position_in_grid ]],
                           unsigned global_id [[ thread_position_in_grid ]],
                           unsigned local_id [[ thread_index_in_threadgroup ]]) {
    // Number of field elements absorbed per row (the rate).
    const unsigned row_width = 8;

    // load hasher state
    unsigned local_state_offset = local_id * STATE_WIDTH * 2;
    *((threadgroup Rp256PartialState*) (shared + local_state_offset)) = states[global_id];

    // First row handled by this threadgroup.
    // It is derived from global_id - local_id rather than tg_id * tg_size. With
    // non-uniform dispatch, [[threads_per_threadgroup]] reports the actual (smaller)
    // size of the trailing threadgroup, so tg_id * tg_size would point at the wrong
    // rows there. For uniform dispatch both expressions are equal.
    // `tg_id` is therefore no longer needed; it is kept so the signature is unchanged.
    unsigned first_row = global_id - local_id;
    unsigned tg_offset = first_row * row_width;

    // absorb the input into the state (done like this for coalesced reads)
#pragma unroll
    for (unsigned i = 0; i < row_width; i++) {
        unsigned local_offset = tg_size * i + local_id;
        unsigned hasher_id = local_offset / row_width;
        unsigned hasher_offset = local_offset % row_width;
        shared[hasher_id * STATE_WIDTH * 2 + CAPACITY + hasher_offset] = rows[tg_offset + local_offset];
    }
    // Other threads wrote into this thread's rate; wait for all of them.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    rpx_permute(shared, local_state_offset);

    // TODO: add flag to only write to one of these buffers
    // Redundant writes here have a negligible performance cost (<1%).
    digests[global_id] = *((threadgroup Rp256Digest*) (shared + local_state_offset + CAPACITY));
    states[global_id] = *((threadgroup Rp256PartialState*) (shared + local_state_offset));
}
// Generates the first row of Merkle tree nodes after the leaf nodes
// using the Rescue Prime eXtended hash function.
//
// Thread `global_id` performs these steps:
//   1. Build a state with an all-zero capacity and the digest pair leaves[global_id]
//      as the rate.
//   2. Permute the state.
//   3. Write the digest to nodes[N / 2 + global_id].
//
// NOTE(review): `N` is not declared in this block; presumably the leaf count,
// supplied elsewhere — TODO confirm.
[[ host_name("rpx_128_gen_merkle_nodes_first_row_p18446744069414584321_fp") ]] kernel void
Rpx256GenMerkleNodesFirstRow(constant Rp256DigestPair *leaves [[ buffer(0) ]],
                             device Rp256Digest *nodes [[ buffer(1) ]],
                             threadgroup Fp *shared [[ threadgroup(0) ]],
                             unsigned global_id [[ thread_position_in_grid ]],
                             unsigned local_id [[ thread_index_in_threadgroup ]]) {
    unsigned local_state_offset = local_id * STATE_WIDTH * 2;

    // Zero the capacity elements. The loop bound is tied to CAPACITY, the same
    // constant that positions the rate below, instead of a hard-coded count.
#pragma unroll
    for (unsigned i = 0; i < CAPACITY; i++) {
        shared[local_state_offset + i] = Fp(0);
    }

    // absorb children as input
    *((threadgroup Rp256DigestPair*) (shared + local_state_offset + CAPACITY)) = leaves[global_id];

    rpx_permute(shared, local_state_offset);

    // write digest
    nodes[N / 2 + global_id] = *((threadgroup Rp256Digest*) (shared + local_state_offset + CAPACITY));
}
// Generates a row of Merkle tree nodes using the Rescue Prime eXtended hash function.
//
// Let base = N >> round. Thread `global_id` performs these steps:
//   1. Read the child digest pair at pair index base + global_id of `nodes`. This
//      assumes Rp256DigestPair is two consecutive Rp256Digest values, i.e.
//      nodes[2 * (base + global_id)] and nodes[2 * (base + global_id) + 1].
//   2. Hash the pair with an all-zero capacity.
//   3. Write the parent digest to nodes[base + global_id].
//
// Presumably the host dispatches `base` threads per call (TODO confirm). In that case
// the writes (indices below 2 * base) never overlap the reads (indices 2 * base and above).
// NOTE(review): `N` is not declared in this block; presumably the leaf count — TODO confirm.
[[ host_name("rpx_128_gen_merkle_nodes_row_p18446744069414584321_fp") ]] kernel void
Rpx256GenMerkleNodesRow(device Rp256Digest *nodes [[ buffer(0) ]],
                        constant unsigned &round [[ buffer(1) ]],
                        threadgroup Fp *shared [[ threadgroup(0) ]],
                        unsigned global_id [[ thread_position_in_grid ]],
                        unsigned local_id [[ thread_index_in_threadgroup ]]) {
    unsigned local_state_offset = local_id * STATE_WIDTH * 2;
    unsigned node_index = (N >> round) + global_id;

    // Zero the capacity elements, bounded by CAPACITY rather than a hard-coded count.
#pragma unroll
    for (unsigned i = 0; i < CAPACITY; i++) {
        shared[local_state_offset + i] = Fp(0);
    }

    // absorb children as input
    // TODO: c++ cast for readability
    *((threadgroup Rp256DigestPair*) (shared + local_state_offset + CAPACITY)) = ((device Rp256DigestPair*) nodes)[node_index];

    rpx_permute(shared, local_state_offset);

    // write digest
    nodes[node_index] = *((threadgroup Rp256Digest*) (shared + local_state_offset + CAPACITY));
}
}
#endif /* rpx_shader_h */