llama-cpp-sys-4 0.2.56

Low Level Bindings to llama.cpp
Documentation
#include "cumsum.hpp"
#include "common.hpp"

#include <algorithm>

// Upper bound on the work-group size (local range in dimension 2) used when launching
// the cumsum kernel; shorter rows get a smaller, warp-rounded work-group.
#define SYCL_CUMSUM_BLOCK_SIZE 256

// Inclusive prefix sum of `x` over the sub-group that the calling work-item belongs to.
// Every work-item receives the sum of `x` from the sub-group's first work-item up to
// and including itself. This is a sub-group collective, so all work-items of the
// sub-group must reach the call together.
static __dpct_inline__ float warp_prefix_inclusive_sum_f32(float x, const sycl::nd_item<3> & item) {
    const auto sg = item.get_sub_group();
    const float scanned = sycl::inclusive_scan_over_group(sg, x, sycl::plus<float>());
    return scanned;
}

// Per-row inclusive prefix sum (cumsum) along dimension 0 of an f32 tensor.
//
// Each work-group handles the row selected by its group id: group(2) -> i1, group(1) -> i2,
// group(0) -> i3. The row is consumed in tiles of num_unroll * block_size elements. Within a
// tile every work-item scans num_unroll consecutive elements serially; the per-item totals are
// then combined with a sub-group scan, a scan over the per-warp totals, and a running carry
// that holds the sum of all previous tiles.
//
// src, dst      : input / output base pointers; row elements are accessed at unit stride.
// ne00..ne03    : extents; ne00 is the length of the scanned row, ne01..ne03 bound the group ids.
// s01, s02, s03 : src strides of dims 1..3, applied as element (float) offsets.
// d1, d2, d3    : dst strides of dims 1..3, applied as element (float) offsets.
// item          : nd_item of the calling work-item; the local id/range of dimension 2 is used.
// smem          : scratch buffer, indexed below as block_size + warps_per_block + 2 floats.
//
// NOTE(review): assumes block_size is a multiple of WARP_SIZE, that warps_per_block does not
// exceed WARP_SIZE, and that each sub-group covers WARP_SIZE consecutive local ids so that
// `warp`/`lane` match the sub-group used by warp_prefix_inclusive_sum_f32 — confirm against
// the launch configuration.
static void cumsum_f32_kernel(
        const float * __restrict__ src, float * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t s01, const int64_t s02, const int64_t s03,
        const int64_t  d1, const int64_t  d2, const int64_t  d3,
        const sycl::nd_item<3> & item, float * smem) {

    const int tid = item.get_local_id(2);
    const int block_size = item.get_local_range(2);
    const int lane = tid % WARP_SIZE;
    const int warp = tid / WARP_SIZE;
    const int warps_per_block = block_size / WARP_SIZE;

    // Scratch layout: [block_size per-item scan values][warps_per_block warp sums][2 carry slots].
    float * s_vals      = smem;
    float * s_warp_sums = smem + block_size;
    float * s_carry     = smem + block_size + warps_per_block;

    // s_carry[0] is the running sum of all tiles processed so far; it starts at zero.
    if (tid == 0) {
        s_carry[0] = 0.0f;
    }
    item.barrier(sycl::access::fence_space::local_space);

    // Row coordinates come from the group id, so every work-item of a group takes the
    // same branch here and the whole group leaves together.
    const int64_t i3 = item.get_group(0);
    const int64_t i2 = item.get_group(1);
    const int64_t i1 = item.get_group(2);
    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
        return;
    }

    const float * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
    float       * dst_row = dst + i1 * d1  + i2 * d2  + i3 * d3;

    // Each work-item owns num_unroll consecutive elements of the current tile.
    constexpr int num_unroll = 4;
    float temp[num_unroll];

    for (int64_t i = 0; i < ne00; i += num_unroll * block_size) {
        int64_t idx = i + tid * num_unroll;

        // Serial inclusive scan over this item's elements; positions past the end of the
        // row contribute nothing, so temp[j] simply repeats the last valid partial sum.
        temp[0] = (idx < ne00 ? src_row[idx] : 0.0f);
#pragma unroll
        for (int j = 1; j < num_unroll; j++) {
            temp[j] = temp[j - 1];
            if (idx + j < ne00) {
                temp[j] += src_row[idx + j];
            }
        }

        // Total of this item's elements (zero when the item is entirely out of range).
        float val = (idx < ne00) ? temp[num_unroll - 1] : 0.0f;

        // Inclusive scan of the per-item totals within the sub-group.
        val = warp_prefix_inclusive_sum_f32(val, item);
        s_vals[tid] = val;

        // The last lane of each warp holds that warp's total.
        if (lane == WARP_SIZE - 1) {
            s_warp_sums[warp] = val;
        }
        item.barrier(sycl::access::fence_space::local_space);

        // Warp 0 scans the per-warp totals and replaces each with its exclusive prefix
        // (inc - w), i.e. the sum of all warps before it. The inclusive value of the last
        // warp is the total of the whole tile and is stored in s_carry[1].
        if (warp == 0) {
            float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
            float inc = warp_prefix_inclusive_sum_f32(w, item);
            if (tid < warps_per_block) {
                s_warp_sums[tid] = inc - w;
            }
            if (tid == warps_per_block - 1) {
                s_carry[1] = inc;
            }
        }
        item.barrier(sycl::access::fence_space::local_space);

        // Offset to add to this item's local partial sums: the inclusive sub-group scan
        // minus the item's own total gives the sum of preceding items in the warp; the warp
        // offset and the carry from earlier tiles are added on top.
        float carry = s_carry[0];
        float final_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];

#pragma unroll
        for (int j = 0; j < num_unroll; j++) {
            if (idx + j < ne00) {
                dst_row[idx + j] = temp[j] + final_offset;
            }
        }

        // Every work-item must have read s_carry[0] before tid 0 updates it below.
        item.barrier(sycl::access::fence_space::local_space);

        // Fold this tile's total into the running carry. The two barriers of the next
        // iteration separate this write from the next read of s_carry[0].
        if (tid == 0) {
            s_carry[0] += s_carry[1];
        }
    }
}

// Launches the f32 cumulative-sum kernel computing dst = cumsum(dst->src[0]) along dimension 0.
//
// ctx : backend context supplying the device id and the queue the kernel is submitted to.
// dst : destination tensor; dst->src[0] is the input. Both must be F32 and have a unit
//       element stride along dimension 0, because the kernel indexes rows as row[idx].
//
// One work-group is launched per row (ne03 x ne02 x ne01 groups). The work-group size is the
// row length rounded up to a multiple of WARP_SIZE, capped at SYCL_CUMSUM_BLOCK_SIZE.
// Tensors with a zero extent in any dimension are a no-op.
inline void ggml_sycl_op_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    // The kernel reads src_row[idx] and writes dst_row[idx], i.e. it assumes contiguous rows.
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(dst->nb[0] == sizeof(float));

    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const float * src_d = static_cast<const float *>(src0->data);
    float       * dst_d = static_cast<float *>(dst->data);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    // Nothing to compute, and an nd_range with a zero-sized global or local range
    // cannot be launched.
    if (ne00 <= 0 || ne01 <= 0 || ne02 <= 0 || ne03 <= 0) {
        return;
    }

    // Byte strides converted to element strides for the kernel's pointer arithmetic.
    const size_t ts = sizeof(float);
    const int64_t s01 = src0->nb[1] / ts;
    const int64_t s02 = src0->nb[2] / ts;
    const int64_t s03 = src0->nb[3] / ts;
    const int64_t d1  = dst->nb[1] / ts;
    const int64_t d2  = dst->nb[2] / ts;
    const int64_t d3  = dst->nb[3] / ts;

    // Clamp the row length in 64-bit before narrowing to int: rounding the raw ne00 up to a
    // warp multiple and truncating it to int could yield a zero or negative block size for
    // very long rows. Rows longer than the cap are handled by the kernel's tile loop.
    const int64_t capped_len = std::min<int64_t>(ne00, SYCL_CUMSUM_BLOCK_SIZE);
    const int num_warps = static_cast<int>((capped_len + WARP_SIZE - 1) / WARP_SIZE);
    int block_size = num_warps * WARP_SIZE;
    block_size = std::min(block_size, SYCL_CUMSUM_BLOCK_SIZE);
    const int warps_per_block = block_size / WARP_SIZE;
    // Scratch floats: per-item scan values + per-warp sums + 2 carry slots.
    const int smem_size = block_size + warps_per_block + 2;

    const sycl::range<3> grid(ne03, ne02, ne01);
    const sycl::range<3> block(1, 1, block_size);

    stream->submit([&](sycl::handler & cgh) {
        sycl::local_accessor<float, 1> smem_acc(sycl::range<1>(smem_size), cgh);
        cgh.parallel_for(
            sycl::nd_range<3>(grid * block, block),
            [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                cumsum_f32_kernel(src_d, dst_d, ne00, ne01, ne02, ne03,
                                  s01, s02, s03, d1, d2, d3,
                                  item, get_pointer(smem_acc));
            });
    });
}

// Entry point for the SYCL cumsum op: creates the scoped debug-print object for `dst`
// (declared with one source operand) and forwards to ggml_sycl_op_cumsum.
void ggml_sycl_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // Named local so the object lives until the op has been dispatched.
    scope_op_debug_print dbg_scope(__func__, dst, /*num_src=*/1);

    ggml_sycl_op_cumsum(ctx, dst);
}