// sp1-gpu-sys 6.2.1
//
// FFI bindings and CUDA build system for SP1-GPU
// Documentation
#pragma once

#include "poseidon2/poseidon2_kb31_16.cuh"
#include "poseidon2/poseidon2_bn254_3.cuh"

namespace poseidon2 {

// Fixed-size bundle of Params::DIGEST_WIDTH field elements, declared with
// 16-byte alignment. Other code in this file reinterprets arrays of F_t as
// FDW_t so that a whole digest can be copied with one struct assignment.
template <typename Params>
struct __align__(16) FDW_t {
    using F_t = typename Params::F_t;
    // The digest elements, stored contiguously.
    F_t v[Params::DIGEST_WIDTH];
};

// Bundle of the constant tables a Hasher<Params> permutation reads.
//
// The member order matters: StaticHasher brace-initializes this aggregate
// positionally.
template <typename Params>
struct RoundConstants {
    using F_t = typename Params::F_t;
    using pF_t = typename Params::pF_t;

    // Indexed by internal round number; one constant is added to state[0]
    // per internal round.
    pF_t* internalRoundConstants;
    // Indexed as [round * Params::WIDTH + lane]; one constant per state
    // element per external round.
    pF_t* externalRoundConstants;
    // Forwarded unchanged to Params::internalLinearLayer.
    pF_t* matInternalDiagM1;
    // Forwarded unchanged to Params::internalLinearLayer.
    // NOTE(review): presumably a Montgomery-form inverse factor; confirm
    // against the Params implementation.
    pF_t montyInverse;
};

// Permutation, 2-to-1 compression and sponge-style hashing built from the
// field type, dimensions and linear layers supplied by Params. All entry
// points are static device functions that receive the round constants
// explicitly.
template <typename Params>
class Hasher {
    using F_t = typename Params::F_t;
    using pF_t = typename Params::pF_t;
    // The namespace-level template is named with a qualified id so that this
    // member alias does not change the meaning of the name it is built from.
    using FDW_t = ::poseidon2::FDW_t<Params>;
    using RoundConstants_t = RoundConstants<Params>;

  private:
    // Adds one external-round constant to every state element.
    __device__ static void addExtRc(F_t state[Params::WIDTH], pF_t rc[Params::WIDTH]) {
        for (int i = 0; i < Params::WIDTH; i++) {
            state[i] += rc[i];
        }
    }

    // Applies F_t's operator^= with Params::D to every state element.
    // NOTE(review): presumably exponentiation by D in the field; confirm
    // against the F_t definition.
    __device__ static void sbox(F_t state[Params::WIDTH]) {
        for (int i = 0; i < Params::WIDTH; i++) {
            state[i] ^= Params::D;
        }
    }

  public:
    // Runs the full permutation on `in` and writes the result to `out`.
    //
    // The input is copied into a local buffer first and `out` is written only
    // at the end, so `in` and `out` may refer to the same array.
    //
    // Round schedule: one external linear layer, then ROUNDS_F / 2 external
    // rounds, ROUNDS_P internal rounds, and the remaining external rounds.
    __device__ static void
    permute(F_t in[Params::WIDTH], F_t out[Params::WIDTH], RoundConstants_t roundConstants) {
        F_t state[Params::WIDTH];
        for (int i = 0; i < Params::WIDTH; i++) {
            state[i] = in[i];
        }

        Params::externalLinearLayer(state);

        const int rounds_f_half = Params::ROUNDS_F >> 1;

        // First half of the external rounds.
        for (int i = 0; i < rounds_f_half; i++) {
            addExtRc(state, roundConstants.externalRoundConstants + i * Params::WIDTH);
            sbox(state);
            Params::externalLinearLayer(state);
        }

        // Internal rounds: only state[0] gets a constant and the ^= step.
        for (int i = 0; i < Params::ROUNDS_P; i++) {
            state[0] += roundConstants.internalRoundConstants[i];
            state[0] ^= Params::D;
            Params::internalLinearLayer(
                state,
                roundConstants.matInternalDiagM1,
                roundConstants.montyInverse);
        }

        // Second half of the external rounds; the constant index keeps
        // counting from where the first half stopped.
        for (int i = rounds_f_half; i < Params::ROUNDS_F; i++) {
            addExtRc(state, roundConstants.externalRoundConstants + i * Params::WIDTH);
            sbox(state);
            Params::externalLinearLayer(state);
        }

        for (int i = 0; i < Params::WIDTH; i++) {
            out[i] = state[i];
        }
    }

    // Compresses two digests into one: the state is filled with `left`,
    // `right` and zero padding, permuted, and its first DIGEST_WIDTH elements
    // are written to `out`.
    //
    // `left`, `right` and `out` are accessed through FDW_t, which is declared
    // with 16-byte alignment, so the pointers are expected to satisfy it.
    __device__ static void compress(
        F_t left[Params::DIGEST_WIDTH],
        F_t right[Params::DIGEST_WIDTH],
        F_t out[Params::DIGEST_WIDTH],
        RoundConstants_t roundConstants) {
        // The FDW_t-based copies below index the state in digest-sized steps;
        // that is only valid when a digest has no tail padding and two of
        // them fit in the state.
        static_assert(
            sizeof(FDW_t) == sizeof(F_t) * Params::DIGEST_WIDTH,
            "FDW_t must not contain padding");
        static_assert(
            2 * Params::DIGEST_WIDTH <= Params::WIDTH,
            "two digests must fit in the permutation state");

        // Aligned to match FDW_t, through which this buffer is accessed.
        __align__(16) F_t state[Params::WIDTH];
        FDW_t* stateWidth = reinterpret_cast<FDW_t*>(state);
        stateWidth[0] = *reinterpret_cast<FDW_t*>(left);
        stateWidth[1] = *reinterpret_cast<FDW_t*>(right);
        for (int i = 2 * Params::DIGEST_WIDTH; i < Params::WIDTH; i++) {
            state[i].set_to_zero();
        }
        permute(state, state, roundConstants);
        *reinterpret_cast<FDW_t*>(out) = stateWidth[0];
    }

    // Hashes `nIn` elements starting at `in` into a DIGEST_WIDTH-element
    // digest written to `out`.
    //
    // Input is consumed in chunks of up to Params::RATE elements. Each chunk
    // overwrites the leading state elements and is followed by one
    // permutation; a short final chunk leaves the remaining elements as the
    // previous permutation produced them. With nIn == 0 no permutation runs
    // and the all-zero state is emitted.
    //
    // `out` is written through FDW_t and is therefore expected to be
    // 16-byte aligned.
    __device__ static void
    hash(F_t* in, size_t nIn, F_t out[Params::DIGEST_WIDTH], RoundConstants_t roundConstants) {
        // Aligned to match FDW_t, through which this buffer is read below.
        __align__(16) F_t state[Params::WIDTH];
        for (int i = 0; i < Params::WIDTH; i++) {
            state[i].set_to_zero();
        }

        // All input indexing is done in size_t so it agrees with nIn's type
        // and cannot overflow a signed int on large inputs.
        const size_t rate = Params::RATE;
        size_t consumed = 0;
        while (consumed < nIn) {
            const size_t remaining = nIn - consumed;
            const size_t take = remaining < rate ? remaining : rate;
            for (size_t j = 0; j < take; j++) {
                state[j] = in[consumed + j];
            }
            permute(state, state, roundConstants);
            consumed += take;
        }

        *reinterpret_cast<FDW_t*>(out) = *reinterpret_cast<FDW_t*>(state);
    }
};

// Hasher variant whose round constants are held in an instance member and
// filled in through the setters below. The setters carry no execution-space
// qualifier; the hashing entry points are device functions that forward to
// the static Hasher<Params> implementation with the stored constants.
template <typename Params>
class DynamicHasher : public Hasher<Params> {
    using F_t = typename Params::F_t;
    using pF_t = typename Params::pF_t;
    using Base_t = Hasher<Params>;

  public:
    // Constants handed to every permutation performed through this object.
    RoundConstants<Params> roundConstants;

    // Stores the table of per-round constants used by the internal rounds.
    void setInternalRoundConstants(pF_t* internalRoundConstants) {
        this->roundConstants.internalRoundConstants = internalRoundConstants;
    }

    // Stores the table of per-round, per-lane constants used by the external
    // rounds.
    void setExternalRoundConstants(pF_t* externalRoundConstants) {
        this->roundConstants.externalRoundConstants = externalRoundConstants;
    }

    // Stores the pointer that is forwarded to Params::internalLinearLayer.
    void setMatInternalDiagM1(pF_t* matInternalDiagM1) {
        this->roundConstants.matInternalDiagM1 = matInternalDiagM1;
    }

    // Stores the value that is forwarded to Params::internalLinearLayer.
    void setMontyInverse(pF_t montyInverse) {
        this->roundConstants.montyInverse = montyInverse;
    }

    // Permutes `in` into `out` using the stored constants.
    __device__ void permute(F_t in[Params::WIDTH], F_t out[Params::WIDTH]) {
        Base_t::permute(in, out, this->roundConstants);
    }

    // Compresses `left` and `right` into `out` using the stored constants.
    __device__ void compress(
        F_t left[Params::DIGEST_WIDTH],
        F_t right[Params::DIGEST_WIDTH],
        F_t out[Params::DIGEST_WIDTH]) {
        Base_t::compress(left, right, out, this->roundConstants);
    }

    // Hashes `nIn` elements from `in` into `out` using the stored constants.
    __device__ void hash(F_t* in, size_t nIn, F_t out[Params::DIGEST_WIDTH]) {
        Base_t::hash(in, nIn, out, this->roundConstants);
    }
};

// Hasher variant whose round constants are taken from Params at compile
// time. Every entry point is a static device function that forwards to
// Hasher<Params> with the class-level constant bundle.
template <typename Params>
class StaticHasher : public Hasher<Params> {
    using F_t = typename Params::F_t;
    using Base_t = Hasher<Params>;

  public:
    // Initialized positionally, so the order here must follow the member
    // order of RoundConstants<Params>.
    static constexpr const RoundConstants<Params> roundConstants = {
        Params::INTERNAL_ROUND_CONSTANTS,
        Params::EXTERNAL_ROUND_CONSTANTS,
        Params::MAT_INTERNAL_DIAG_M1,
        Params::MONTY_INVERSE};

    // Permutes `in` into `out` with the compile-time constants.
    __device__ static void permute(F_t in[Params::WIDTH], F_t out[Params::WIDTH]) {
        Base_t::permute(in, out, StaticHasher::roundConstants);
    }

    // Compresses `left` and `right` into `out` with the compile-time
    // constants.
    __device__ static void compress(
        F_t left[Params::DIGEST_WIDTH],
        F_t right[Params::DIGEST_WIDTH],
        F_t out[Params::DIGEST_WIDTH]) {
        Base_t::compress(left, right, out, StaticHasher::roundConstants);
    }

    // Hashes `nIn` elements from `in` into `out` with the compile-time
    // constants.
    __device__ static void hash(F_t* in, size_t nIn, F_t out[Params::DIGEST_WIDTH]) {
        Base_t::hash(in, nIn, out, StaticHasher::roundConstants);
    }
};

// Incremental hashing state: a WIDTH-element buffer plus the position of the
// next element to overwrite. Elements are absorbed one at a time and the
// buffer is permuted every time Params::RATE elements have been written.
template <typename Params, typename Hasher_t>
struct HasherState {
    using F_t = typename Params::F_t;
    // The namespace-level template is named with a qualified id so that this
    // member alias does not change the meaning of the name it is built from.
    using FDW_t = ::poseidon2::FDW_t<Params>;

    // Current permutation state.
    F_t data[Params::WIDTH];
    // Number of elements written since the last permutation; always less
    // than Params::RATE between calls.
    size_t index;

    // Starts from an all-zero state with nothing absorbed.
    __device__ HasherState() : index(0) {
        for (int i = 0; i < Params::WIDTH; ++i) {
            data[i].set_to_zero();
        }
    }

    // Absorbs `nIn` elements from `in`, permuting through `hasher` whenever
    // a full block of Params::RATE elements has been written.
    __device__ void absorb(Hasher_t hasher, F_t* in, size_t nIn) {
        // size_t index: matches nIn's type, so the comparison is not
        // mixed-sign and cannot overflow a signed int on large inputs.
        for (size_t i = 0; i < nIn; i++) {
            data[index] = in[i];
            index++;
            if (index == Params::RATE) {
                hasher.permute(data, data);
                index = 0;
            }
        }
    }

    // Permutes once more if a partial block is pending, then writes the
    // first DIGEST_WIDTH state elements to `out`.
    //
    // `index` is left unchanged. The copy goes through FDW_t, which is
    // declared with 16-byte alignment, so `out` is expected to satisfy it.
    __device__ void finalize(Hasher_t hasher, F_t out[Params::DIGEST_WIDTH]) {
        if (index != 0) {
            hasher.permute(data, data);
        }
        *reinterpret_cast<FDW_t*>(out) = *reinterpret_cast<FDW_t*>(data);
    }
};

// HasherState extension that additionally buffers up to R elements of a
// second field type P_t ("overhang"). On finalize, any buffered elements are
// reduced to a single F_t value and absorbed before the digest is produced.
// Only the kb31_t -> bn254_t combination is accepted.
template <typename Params, typename Hasher_t, typename P_t, int R>
struct MultiFieldHasherState : public HasherState<Params, Hasher_t> {
    using F_t = typename Params::F_t;

    static_assert(
        std::is_same<F_t, bn254_t>::value,
        "MultiFieldHasherState only supports kb31 reduction to bn254 (F_t must be bn254_t)");
    static_assert(
        std::is_same<P_t, kb31_t>::value,
        "MultiFieldHasherState only supports kb31 reduction to bn254 (P_t must be kb31_t)");

    // Buffered P_t elements not yet reduced into an F_t value. Not
    // initialized by the constructor; only the first overhangSize entries
    // are meaningful.
    P_t overhang[R];
    // Number of valid entries in `overhang`.
    size_t overhangSize;

    // Starts from an all-zero state with an empty overhang buffer.
    __device__ MultiFieldHasherState() : HasherState<Params, Hasher_t>(), overhangSize(0) {}

    // Flushes any buffered overhang into the state, then finalizes as the
    // base class does. overhangSize is not reset.
    __device__ void finalize(Hasher_t hasher, F_t out[Params::DIGEST_WIDTH]) {
        if (overhangSize > 0) {
            // NOTE(review): the meaning of the trailing arguments
            // (nullptr, ..., 0, 1, 0) is defined by reduceKoalaBear; confirm
            // against its declaration.
            F_t value =
                poseidon2_bn254_3::reduceKoalaBear(overhang, nullptr, overhangSize, 0, 1, 0);
            // absorb is a member of a dependent base class, which unqualified
            // lookup does not search; this-> makes the call a member access
            // that is resolved at instantiation.
            this->absorb(hasher, &value, 1);
        }
        HasherState<Params, Hasher_t>::finalize(hasher, out);
    }
};

// Concrete hashers with compile-time round constants, one for each parameter
// set provided by the two included poseidon2 headers.
using KoalaBearHasher = StaticHasher<poseidon2_kb31_16::KoalaBear>;
using Bn254Hasher = StaticHasher<poseidon2_bn254_3::Bn254>;

// Incremental hashing state for KoalaBearHasher with a row-absorption entry
// point.
class KoalaBearHasherState : public HasherState<poseidon2_kb31_16::KoalaBear, KoalaBearHasher> {
  public:
    // Forwards to poseidon2_kb31_16::absorbRow, passing this object as the
    // state to update.
    // NOTE(review): `in`, `rowIdx`, `width` and `height` presumably describe
    // a matrix and the row to absorb; confirm against absorbRow.
    __device__ void
    absorbRow(KoalaBearHasher hasher, kb31_t* in, int rowIdx, size_t width, size_t height) {
        using State_t = HasherState<poseidon2_kb31_16::KoalaBear, KoalaBearHasher>;
        poseidon2_kb31_16::absorbRow<KoalaBearHasher, State_t>(
            hasher, in, rowIdx, width, height, this);
    }
};

// Incremental hashing state for Bn254Hasher that takes kb31_t input, with an
// overhang buffer of 8 elements and a row-absorption entry point.
class Bn254HasherState
    : public MultiFieldHasherState<poseidon2_bn254_3::Bn254, Bn254Hasher, kb31_t, 8> {
  public:
    // Forwards to poseidon2_bn254_3::absorbRow, passing this object as the
    // state to update.
    // NOTE(review): `in`, `rowIdx`, `width` and `height` presumably describe
    // a matrix and the row to absorb; confirm against absorbRow.
    __device__ void
    absorbRow(Bn254Hasher hasher, kb31_t* in, int rowIdx, size_t width, size_t height) {
        using State_t = MultiFieldHasherState<poseidon2_bn254_3::Bn254, Bn254Hasher, kb31_t, 8>;
        poseidon2_bn254_3::absorbRow<Bn254Hasher, State_t>(
            hasher, in, rowIdx, width, height, this);
    }
};


} // namespace poseidon2