sp1-gpu-sys 6.2.2

FFI bindings and CUDA build system for SP1-GPU
#pragma once

#include "logup_gkr/round.cuh"
#include "config.cuh"

// Host functions exported with C linkage so they can be looked up from the FFI
// side (NOTE(review): presumably the Rust bindings of this crate — confirm).
// Each returns an untyped `void*`; presumably a handle/pointer to the matching
// first-layer LogUp-GKR CUDA kernel (fix-last-variable, sum-as-poly, and layer
// transition respectively) — TODO confirm against the definitions, which are
// not in this header.
extern "C" void* logup_gkr_fix_last_variable_first_layer();
extern "C" void* logup_gkr_sum_as_poly_first_layer();
extern "C" void* logup_gkr_first_layer_transition();

/// One row of the first GKR layer: two fractions n0/d0 and n1/d1.
///
/// Numerators are `felt_t` values and denominators are `ext_t` values. The
/// "Zero" members come from the first section of the concatenated input
/// buffers and the "One" members from the second section.
struct FirstLayerCircuitValues {
  public:
    felt_t numeratorZero;
    felt_t numeratorOne;
    ext_t denominatorZero;
    ext_t denominatorOne;

  public:
    /// Reads row `i` from the concatenated buffers.
    ///
    /// `numeratorValues` is laid out as `numerator_0 || numerator_1` (and
    /// `denominatorValues` likewise). The second section begins `2 * height`
    /// entries after the first, so row `i` of it lives at `i + 2 * height`.
    static __device__ __forceinline__ FirstLayerCircuitValues
    load(felt_t* numeratorValues, ext_t* denominatorValues, size_t i, size_t height) {
        const size_t secondSectionIdx = i + 2 * height;

        FirstLayerCircuitValues row;
        row.numeratorZero = felt_t::load(numeratorValues, i);
        row.numeratorOne = felt_t::load(numeratorValues, secondSectionIdx);
        row.denominatorZero = ext_t::load(denominatorValues, i);
        row.denominatorOne = ext_t::load(denominatorValues, secondSectionIdx);
        return row;
    }

    /// Read-only-pointer variant of `load`; same layout and indexing.
    static __device__ __forceinline__ FirstLayerCircuitValues
    load(const felt_t* numeratorValues, const ext_t* denominatorValues, size_t i, size_t height) {
        const size_t secondSectionIdx = i + 2 * height;

        FirstLayerCircuitValues row;
        row.numeratorZero = felt_t::load(numeratorValues, i);
        row.numeratorOne = felt_t::load(numeratorValues, secondSectionIdx);
        row.denominatorZero = ext_t::load(denominatorValues, i);
        row.denominatorOne = ext_t::load(denominatorValues, secondSectionIdx);
        return row;
    }

    /// A padding row: both numerators are zero and both denominators are one,
    /// i.e. the two fractions are 0/1.
    static __device__ __forceinline__ FirstLayerCircuitValues paddingValues() {
        FirstLayerCircuitValues pad;
        pad.denominatorZero = ext_t::one();
        pad.denominatorOne = ext_t::one();
        pad.numeratorZero = felt_t::zero();
        pad.numeratorOne = felt_t::zero();
        return pad;
    }

    /// Combines two rows by binding the last variable to `alpha`.
    ///
    /// Every component is computed as
    /// `alpha.interpolateLinear(<one-row value>, <zero-row value>)`.
    static __device__ __forceinline__ CircuitValues fix_last_variable(
        FirstLayerCircuitValues zeroValues,
        FirstLayerCircuitValues oneValues,
        ext_t alpha) {
        CircuitValues folded;
        folded.numeratorZero =
            alpha.interpolateLinear(oneValues.numeratorZero, zeroValues.numeratorZero);
        folded.numeratorOne =
            alpha.interpolateLinear(oneValues.numeratorOne, zeroValues.numeratorOne);
        folded.denominatorZero =
            alpha.interpolateLinear(oneValues.denominatorZero, zeroValues.denominatorZero);
        folded.denominatorOne =
            alpha.interpolateLinear(oneValues.denominatorOne, zeroValues.denominatorOne);
        return folded;
    }

    /// Writes this row at index `i`, using the same `2 * height` section
    /// offset that `load` reads from.
    __device__ __forceinline__ void
    store(felt_t* numeratorValues, ext_t* denominatorValues, size_t i, size_t height) {
        const size_t secondSectionIdx = i + 2 * height;

        felt_t::store(numeratorValues, i, numeratorZero);
        felt_t::store(numeratorValues, secondSectionIdx, numeratorOne);
        ext_t::store(denominatorValues, i, denominatorZero);
        ext_t::store(denominatorValues, secondSectionIdx, denominatorOne);
    }

    /// Compute the sumcheck sum values.
    ///
    /// Forms the numerator `n0*d1 + n1*d0` and denominator `d0*d1` of
    /// `n0/d0 + n1/d1`, batches them as `numerator * lambda + denominator`, and
    /// scales the result by `eqValue`.
    __device__ __forceinline__ ext_t sumAsPoly(ext_t lambda, ext_t eqValue) {
        ext_t crossSum = numeratorZero * denominatorOne + numeratorOne * denominatorZero;
        ext_t denomProduct = denominatorZero * denominatorOne;
        ext_t batched = crossSum * lambda + denomProduct;
        return eqValue * batched;
    }
};


/// A GKR layer.
///
/// Views the first-layer data through raw device pointers (`felt_t`
/// numerators, `ext_t` denominators). Each operation reads rows through
/// `FirstLayerCircuitValues` and writes a `CircuitValues` row into a
/// `JaggedGkrLayer` (exposed as `OutputDenseData`).
struct JaggedFirstGkrLayer {
    using OutputDenseData = JaggedGkrLayer;

  public:
    /// numerator_0 || numerator_1
    felt_t* numeratorValues;
    /// denominator_0 || denominator_1
    ext_t* denominatorValues;
    /// Half of the length of each section.
    size_t height;

    /// Loads rows `zeroIdx` and `oneIdx`, binds the last variable to `alpha`,
    /// and writes the folded row into `other` at `restrictedIdx`.
    __forceinline__ __device__ void fixLastVariable(
        JaggedGkrLayer& other,
        size_t restrictedIdx,
        size_t zeroIdx,
        size_t oneIdx,
        ext_t alpha) const {
        // Both loads are plain reads, so evaluating them as arguments is safe.
        CircuitValues folded = FirstLayerCircuitValues::fix_last_variable(
            FirstLayerCircuitValues::load(numeratorValues, denominatorValues, zeroIdx, height),
            FirstLayerCircuitValues::load(numeratorValues, denominatorValues, oneIdx, height),
            alpha);

        folded.store(other.layer, restrictedIdx, other.height);
    }

    /// Writes `CircuitValues::paddingValues()` into `other` at `restrictedIdx`.
    __forceinline__ __device__ void pad(JaggedGkrLayer& other, size_t restrictedIdx) const {
        CircuitValues padding = CircuitValues::paddingValues();
        padding.store(other.layer, restrictedIdx, other.height);
    }

    /// Advances one circuit layer without folding.
    ///
    /// Rows `zeroIdx` and `oneIdx` are each collapsed to a single fraction
    /// `(n0*d1 + n1*d0) / (d0*d1)`. These become the "Zero" and "One" halves
    /// of the row written into `other` at `restrictedIdx`.
    __forceinline__ __device__ void
    circuitTransition(JaggedGkrLayer& other, size_t restrictedIdx, size_t zeroIdx, size_t oneIdx)
        const {
        FirstLayerCircuitValues lhs =
            FirstLayerCircuitValues::load(numeratorValues, denominatorValues, zeroIdx, height);
        FirstLayerCircuitValues rhs =
            FirstLayerCircuitValues::load(numeratorValues, denominatorValues, oneIdx, height);

        CircuitValues next;

        // Row `zeroIdx` -> "Zero" half of the output row.
        next.numeratorZero =
            lhs.numeratorZero * lhs.denominatorOne + lhs.numeratorOne * lhs.denominatorZero;
        next.denominatorZero = lhs.denominatorZero * lhs.denominatorOne;

        // Row `oneIdx` -> "One" half of the output row.
        next.numeratorOne =
            (rhs.denominatorOne * rhs.numeratorZero) + (rhs.denominatorZero * rhs.numeratorOne);
        next.denominatorOne = rhs.denominatorZero * rhs.denominatorOne;

        next.store(other.layer, restrictedIdx, other.height);
    }
};