llama-cpp-sys-4 0.2.59

Low Level Bindings to llama.cpp
Documentation
#include "solve_tri.hpp"
#include "common.hpp"
#include <oneapi/mkl/blas.hpp>

template <int n_template, int k_template>
static void solve_tri_f32_fast(const float * __restrict__ A,
                               const float * __restrict__ B,
                               float * __restrict__ X,
                               const int64_t ne02, [[maybe_unused]] const int64_t ne03,
                               const int64_t nb02, const int64_t nb03,
                               const int64_t nb12, const int64_t nb13,
                               const int64_t nb2,  const int64_t nb3,
                               const int n_arg, const int k_arg,
                               const sycl::nd_item<2> & item, float * sA) {

    const int n = n_template == 0 ? n_arg : n_template;
    const int k = k_template == 0 ? k_arg : k_template;

    const int batch_idx = item.get_group(1);
    const int lane      = item.get_local_id(1) % WARP_SIZE;
    const int col_idx   = item.get_local_id(0);

    if (col_idx >= k) {
        return;
    }

    const int64_t i03 = batch_idx / ne02;
    const int64_t i02 = batch_idx % ne02;

    const float * A_batch = (const float *) ((const char *) A + i02 * nb02 + i03 * nb03);
    const float * B_batch = (const float *) ((const char *) B + i02 * nb12 + i03 * nb13);
    float *       X_batch = (float *)       ((char *)       X + i02 * nb2  + i03 * nb3);

    const int offset = item.get_local_id(1) + item.get_local_id(0) * item.get_local_range(1);

#pragma unroll
    for (int i = 0; i < n * n; i += k * WARP_SIZE) {
        const int i0 = i + offset;
        if (i0 < n * n) {
            sA[i0] = A_batch[i0];
        }
    }

    item.barrier(sycl::access::fence_space::local_space);

    float x_low  = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;

    const int half      = WARP_SIZE;
    const int nrows_low = (n < half) ? n : half;

#pragma unroll
    for (int row = 0; row < nrows_low; ++row) {
        float sum = 0.0f;
        if (lane < row) {
            sum += sA[row * n + lane] * x_low;
        }
        sum = warp_reduce_sum<WARP_SIZE>(sum);
        if (lane == row) {
            x_low = (x_low - sum) / sA[row * n + row];
        }
    }

#pragma unroll
    for (int row = half; row < n; ++row) {
        float     sum = sA[row * n + lane] * x_low;
        const int j   = half + lane;
        if (j < row) {
            sum += sA[row * n + j] * x_high;
        }
        sum = warp_reduce_sum<WARP_SIZE>(sum);
        if (lane == row - half) {
            x_high = (x_high - sum) / sA[row * n + row];
        }
    }

#pragma unroll
    for (int rr = 0; rr < 2; ++rr) {
        const int row = rr * WARP_SIZE + lane;
        if (row < n) {
            const float val            = (row < half) ? x_low : x_high;
            X_batch[row * k + col_idx] = val;
        }
    }
}

static void solve_tri_f32_mkl(dpct::queue_ptr stream,
                               const float * A, float * X,
                               int n, int k,
                               int64_t ne02, [[maybe_unused]] int64_t ne03,
                               int64_t nb02, [[maybe_unused]] int64_t nb03,
                               int64_t nb2,  [[maybe_unused]] int64_t nb3) {
    const float alpha = 1.0f;
    const int64_t total_batches = ne02 * ne03;
    if (total_batches == 0) {
        return;
    }

    const int64_t stride_a = nb02 / sizeof(float);
    const int64_t stride_x = nb2 / sizeof(float);

    oneapi::mkl::blas::trsm_batch(
        *stream,
        oneapi::mkl::side::right,
        oneapi::mkl::uplo::upper,
        oneapi::mkl::transpose::nontrans,
        oneapi::mkl::diag::nonunit,
        k, n, alpha,
        A, n, stride_a,
        X, k, stride_x,
        total_batches);
}

inline void ggml_sycl_op_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));
    GGML_ASSERT(src0->type == GGML_TYPE_F32);

    dpct::queue_ptr stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const int n    = src0->ne[0];
    const int k    = src1->ne[0];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    GGML_ASSERT(n <= SYCL_SOLVE_TRI_MAX_N && k <= SYCL_SOLVE_TRI_MAX_K);

    const float * A_d = static_cast<const float *>(src0->data);
    const float * B_d = static_cast<const float *>(src1->data);
    float * X_d       = static_cast<float *>(dst->data);

    if (X_d != B_d) {
        const int64_t total_elements = (int64_t)n * k * ne02 * ne03;
        stream->memcpy(X_d, B_d, total_elements * sizeof(float));
    }

    const int64_t nb02 = src0->nb[2];
    const int64_t nb03 = src0->nb[3];
    const int64_t nb12 = src1->nb[2];
    const int64_t nb13 = src1->nb[3];
    const int64_t nb2  = dst->nb[2];
    const int64_t nb3  = dst->nb[3];

    const int64_t total_batches = ne02 * ne03;

    if (n <= 2 * WARP_SIZE && k <= 32) {
        const int smem_size = 2 * WARP_SIZE * 2 * WARP_SIZE;
        const sycl::range<2> grid(1, total_batches);
        const sycl::range<2> block(k, WARP_SIZE);
        stream->submit([&](sycl::handler & cgh) {
            sycl::local_accessor<float, 1> smem_acc(sycl::range<1>(smem_size), cgh);
            cgh.parallel_for(
                sycl::nd_range<2>(grid * block, block),
                [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                    solve_tri_f32_fast<0, 0>(A_d, B_d, X_d, ne02, ne03,
                                              nb02, nb03, nb12, nb13, nb2, nb3,
                                              n, k, item, get_pointer(smem_acc));
                });
        });
    } else {
        solve_tri_f32_mkl(stream, A_d, X_d, n, k, ne02, ne03, nb02, nb03, nb2, nb3);
    }
}

void ggml_sycl_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
    ggml_sycl_op_solve_tri(ctx, dst);
}