miden-gpu 0.6.0

GPU acceleration for the Miden VM prover
Documentation
#ifndef rpx_shader_h
#define rpx_shader_h

#include <metal_stdlib>
#include "../felt_u64.h.metal"
#include "common.h.metal"
using namespace metal;

namespace p18446744069414584321 {

    // -------------------------------------------------------------------------------------
    // Rescue Prime eXtended (RPX) permutation over p = 18446744069414584321 = 2^64 - 2^32 + 1.
    // Reference: https://eprint.iacr.org/2023/1045.pdf
    //
    // Every helper below operates on one thread's private slice of the threadgroup buffer
    // `shared`, starting at `local_state_offset`; the first STATE_WIDTH elements of that slice
    // are the hasher state. The kernels reserve STATE_WIDTH * 2 elements per thread (the extra
    // half is presumably scratch for apply_mds_freq - TODO confirm against common.h.metal).
    //
    // Note: values are always copied from threadgroup memory into a thread-local Fp before
    // arithmetic, mirroring how the Fp/Fq3 member operators are used throughout this file.
    // -------------------------------------------------------------------------------------

    // (FB) round: MDS, add ROUND_CONSTANTS_0[round], S-box x^7, then MDS, add
    // ROUND_CONSTANTS_1[round], inverse S-box x^10540996611094048183
    // (7 * 10540996611094048183 = 4 * (p - 1) + 1, so this exponent inverts x^7).
    inline void apply_fb_round(threadgroup Fp* shared, unsigned round, unsigned local_state_offset) {
        threadgroup Fp* state = shared + local_state_offset;
        const unsigned rc_base = round * STATE_WIDTH;

        // Forward half.
        apply_mds_freq(shared, local_state_offset);
        #pragma unroll
        for (unsigned j = 0; j < STATE_WIDTH; j++) {
            Fp v = state[j];
            v = v + ROUND_CONSTANTS_0[rc_base + j];
            state[j] = v.pow7();
        }

        // Backward half.
        apply_mds_freq(shared, local_state_offset);
        #pragma unroll
        for (unsigned j = 0; j < STATE_WIDTH; j++) {
            Fp v = state[j];
            v = v + ROUND_CONSTANTS_1[rc_base + j];
            state[j] = v.pow10540996611094048183();
        }
    }

    // (E) round: add ROUND_CONSTANTS_0[round], then view the 12-element state as four
    // cubic-extension elements (limbs 3i, 3i+1, 3i+2 for i = 0..3), raise each to the
    // 7th power in Fq3 and write the limbs back in place. No MDS step in this round.
    inline void apply_e_round(threadgroup Fp* shared, unsigned round, unsigned local_state_offset) {
        threadgroup Fp* state = shared + local_state_offset;
        const unsigned rc_base = round * STATE_WIDTH;

        // Add round constants.
        #pragma unroll
        for (unsigned j = 0; j < STATE_WIDTH; j++) {
            Fp v = state[j];
            v = v + ROUND_CONSTANTS_0[rc_base + j];
            state[j] = v;
        }

        // Each Fq3 element reads and writes only its own three limbs, so the four elements
        // can be processed one after another without changing the result.
        #pragma unroll
        for (unsigned i = 0; i < 4; i++) {
            const unsigned base = 3 * i;
            Fq3 ext = Fq3(state[base], state[base + 1], state[base + 2]).pow(7);
            Fq3Decomposed limbs = ext.decompose();
            state[base]     = limbs.c0;
            state[base + 1] = limbs.c1;
            state[base + 2] = limbs.c2;
        }
    }

    // Final round: MDS followed by adding ROUND_CONSTANTS_0[round]; no S-box.
    inline void apply_final_round(threadgroup Fp* shared, unsigned round, unsigned local_state_offset) {
        threadgroup Fp* state = shared + local_state_offset;
        const unsigned rc_base = round * STATE_WIDTH;

        apply_mds_freq(shared, local_state_offset);
        #pragma unroll
        for (unsigned j = 0; j < STATE_WIDTH; j++) {
            Fp v = state[j];
            v = v + ROUND_CONSTANTS_0[rc_base + j];
            state[j] = v;
        }
    }

    // Full RPX permutation of the state at shared[local_state_offset ..]:
    // FB(0) E(1) FB(2) E(3) FB(4) E(5) Final(6).
    inline void rpx_permute(threadgroup Fp* shared, unsigned local_state_offset) {
        apply_fb_round(shared, 0, local_state_offset);
        apply_e_round(shared, 1, local_state_offset);
        apply_fb_round(shared, 2, local_state_offset);
        apply_e_round(shared, 3, local_state_offset);
        apply_fb_round(shared, 4, local_state_offset);
        apply_e_round(shared, 5, local_state_offset);
        apply_final_round(shared, 6, local_state_offset);
    }

    // TODO: This should be moved to a new header file and made generic over hasher
    // Rescue Prime eXtended hash function for 128 bit security: https://eprint.iacr.org/2023/1045.pdf
    // Absorbs 8 columns of equal length. Hashes are generated row-wise.
    //
    // Thread `global_id` owns one hasher:
    // - it loads states[global_id] into the start (capacity part) of its slot;
    // - it absorbs element `global_id` of col0..col7 into the 8 rate slots;
    // - it permutes;
    // - it writes back both the digest and the new partial state.
    // Dispatch: 1-D grid with exactly one thread per row (there is no bounds check).
    // threadgroup(0) must hold STATE_WIDTH * 2 Fp elements per thread.
    [[ host_name("rpx_256_absorb_columns_and_permute_p18446744069414584321_fp") ]] kernel void 
    Rpx256AbsorbColumnsAndPermute(constant Fp *col0 [[ buffer(0) ]],
            constant Fp *col1 [[ buffer(1) ]],
            constant Fp *col2 [[ buffer(2) ]],
            constant Fp *col3 [[ buffer(3) ]],
            constant Fp *col4 [[ buffer(4) ]],
            constant Fp *col5 [[ buffer(5) ]],
            constant Fp *col6 [[ buffer(6) ]],
            constant Fp *col7 [[ buffer(7) ]],
            device Rp256PartialState *states [[ buffer(8) ]],
            device Rp256Digest *digests [[ buffer(9) ]],
            threadgroup Fp *shared [[ threadgroup(0) ]],
            unsigned global_id [[ thread_position_in_grid ]],
            unsigned local_id [[ thread_index_in_threadgroup ]]) {
        unsigned local_state_offset = local_id * STATE_WIDTH * 2;
        threadgroup Fp* state = shared + local_state_offset;

        // Load the hasher state.
        *((threadgroup Rp256PartialState*) state) = states[global_id];

        // Absorb this row of the 8 input columns into the rate portion.
        threadgroup Fp* rate = state + CAPACITY;
        rate[0] = col0[global_id];
        rate[1] = col1[global_id];
        rate[2] = col2[global_id];
        rate[3] = col3[global_id];
        rate[4] = col4[global_id];
        rate[5] = col5[global_id];
        rate[6] = col6[global_id];
        rate[7] = col7[global_id];

        // Each thread touches only its own slot, so no threadgroup barrier is needed here.
        rpx_permute(shared, local_state_offset);

        // TODO: add flag to only write to one of these buffers
        // redundant writes here are negligible on performance (<1%)
        digests[global_id] = *((threadgroup Rp256Digest*) (state + CAPACITY));
        states[global_id] = *((threadgroup Rp256PartialState*) state);
    }

    // TODO: This should be moved to a new header file and made generic over hasher
    // Rescue Prime eXtended hash function for 128 bit security: https://eprint.iacr.org/2023/1045.pdf
    // Absorbs 8 columns of equal length stored in row-major order (8 Fp per row). Hashes are
    // generated row-wise.
    //
    // Thread `global_id` owns the hasher for row `global_id`. Input rows are loaded
    // cooperatively by the whole threadgroup (see below), so every thread must reach the
    // barrier. Dispatch: 1-D grid with exactly one thread per row (no bounds check).
    // threadgroup(0) must hold STATE_WIDTH * 2 Fp elements per thread.
    [[ host_name("rpx_256_absorb_rows_and_permute_p18446744069414584321_fp") ]] kernel void 
    Rpx256AbsorbRowsAndPermute(constant Fp *rows [[ buffer(0) ]],
            device Rp256PartialState *states [[ buffer(1) ]],
            device Rp256Digest *digests [[ buffer(2) ]],
            threadgroup Fp *shared [[ threadgroup(0) ]],
            unsigned tg_size [[ threads_per_threadgroup ]],
            unsigned tg_id [[ threadgroup_position_in_grid ]],
            unsigned global_id [[ thread_position_in_grid ]],
            unsigned local_id [[ thread_index_in_threadgroup ]]) {
        unsigned local_state_offset = local_id * STATE_WIDTH * 2;

        // Load the hasher state.
        *((threadgroup Rp256PartialState*) (shared + local_state_offset)) = states[global_id];

        // Absorb the input cooperatively so that reads are coalesced.
        // - This threadgroup's rows start at row (global_id - local_id), and its 8 * tg_size
        //   input elements are contiguous in `rows` from that row onward.
        // - On iteration i, thread local_id reads element k = tg_size * i + local_id of that
        //   slice, so adjacent threads read adjacent addresses.
        // - It stores element k into rate slot (k % 8) of hasher (k / 8).
        //
        // The base is derived from global_id - local_id rather than tg_size * tg_id.
        // threads_per_threadgroup reports the actual size of the current threadgroup. With
        // non-uniform threadgroups, the edge threadgroup is smaller, so tg_size * tg_id would
        // point at the wrong rows. For uniform dispatches both expressions are identical.
        // tg_id is kept in the signature for interface compatibility.
        unsigned tg_offset = (global_id - local_id) * 8;
        for (unsigned i = 0; i < 8; i++) {
            unsigned local_offset = tg_size * i + local_id;
            unsigned hasher_id = local_offset / 8;
            unsigned hasher_offset = local_offset % 8;
            shared[hasher_id * STATE_WIDTH * 2 + CAPACITY + hasher_offset] = rows[tg_offset + local_offset];
        }

        // Threads wrote into other threads' slots above; make those writes visible.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        rpx_permute(shared, local_state_offset);

        // TODO: add flag to only write to one of these buffers
        // redundant writes here are negligible on performance (<1%)
        digests[global_id] = *((threadgroup Rp256Digest*) (shared + local_state_offset + CAPACITY));
        states[global_id] = *((threadgroup Rp256PartialState*) (shared + local_state_offset));
    }

    // Generates the first row of merkle tree nodes above the leaves using the
    // Rescue Prime eXtended hash function.
    //
    // Thread i does the following:
    // - zeroes the capacity part of its state;
    // - absorbs the digest pair leaves[i] into the rate;
    // - permutes;
    // - stores the resulting digest at nodes[N / 2 + i].
    // N is defined outside this file (presumably the number of leaf digests - TODO confirm).
    // Dispatch: exactly one thread per leaf pair (no bounds check).
    // threadgroup(0) must hold STATE_WIDTH * 2 Fp elements per thread.
    // NOTE: the host_name uses an "rpx_128" prefix (unlike the "rpx_256" absorb kernels). It is
    // left unchanged because the host looks kernels up by this exact string.
    [[ host_name("rpx_128_gen_merkle_nodes_first_row_p18446744069414584321_fp") ]] kernel void
    Rpx256GenMerkleNodesFirstRow(constant Rp256DigestPair *leaves [[ buffer(0) ]],
            device Rp256Digest *nodes [[ buffer(1) ]],
            threadgroup Fp *shared [[ threadgroup(0) ]],
            unsigned global_id [[ thread_position_in_grid ]],
            unsigned local_id [[ thread_index_in_threadgroup ]]) {
        unsigned local_state_offset = local_id * STATE_WIDTH * 2;
        threadgroup Fp* state = shared + local_state_offset;

        // Start from a zeroed capacity.
        for (unsigned i = 0; i < CAPACITY; i++) {
            state[i] = Fp(0);
        }

        // Absorb both children as input.
        *((threadgroup Rp256DigestPair*) (state + CAPACITY)) = leaves[global_id];

        rpx_permute(shared, local_state_offset);

        // Write the parent digest.
        nodes[N / 2 + global_id] = *((threadgroup Rp256Digest*) (state + CAPACITY));
    }

    // Generates one row of merkle tree nodes using the Rescue Prime eXtended hash function.
    //
    // Thread i does the following:
    // - reads the digest pair at pair index (N >> round) + i of `nodes`; these are digests
    //   2 * ((N >> round) + i) and the one after it;
    // - hashes them with a zeroed capacity;
    // - stores the result at nodes[(N >> round) + i].
    // Assuming N is a power of two and the row holds N >> round nodes, the digests written
    // by this row never overlap the digests it reads, so the in-place update is race-free.
    // Dispatch: exactly one thread per node in the row (no bounds check).
    // threadgroup(0) must hold STATE_WIDTH * 2 Fp elements per thread.
    // NOTE: the "rpx_128" host_name prefix is kept as-is; the host looks kernels up by it.
    [[ host_name("rpx_128_gen_merkle_nodes_row_p18446744069414584321_fp") ]] kernel void
    Rpx256GenMerkleNodesRow(device Rp256Digest *nodes [[ buffer(0) ]],
            constant unsigned &round [[ buffer(1) ]],
            threadgroup Fp *shared [[ threadgroup(0) ]],
            unsigned global_id [[ thread_position_in_grid ]],
            unsigned local_id [[ thread_index_in_threadgroup ]]) {
        unsigned local_state_offset = local_id * STATE_WIDTH * 2;
        threadgroup Fp* state = shared + local_state_offset;
        unsigned node_index = (N >> round) + global_id;

        // Start from a zeroed capacity.
        for (unsigned i = 0; i < CAPACITY; i++) {
            state[i] = Fp(0);
        }

        // Absorb both children as input: reinterpret `nodes` as an array of digest pairs.
        // TODO: c++ cast for readability
        *((threadgroup Rp256DigestPair*) (state + CAPACITY)) = ((device Rp256DigestPair*) nodes)[node_index];

        rpx_permute(shared, local_state_offset);

        // Write the parent digest.
        nodes[node_index] = *((threadgroup Rp256Digest*) (state + CAPACITY));
    }
}


#endif /* rpx_shader_h */