llama-cpp-sys-4 0.2.46

Low Level Bindings to llama.cpp
Documentation
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"

#include <HAP_farf.h>
#include <HAP_perf.h>

#include <math.h>
#include <string.h>

#include "hex-dma.h"
#include "hvx-utils.h"

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-ops.h"
#include "htp-ops.h"

#define set_rows_preamble                      \
    const uint32_t ne00 = octx->src[0]->ne[0]; \
    const uint32_t ne01 = octx->src[0]->ne[1]; \
    const uint32_t ne02 = octx->src[0]->ne[2]; \
    const uint32_t ne03 = octx->src[0]->ne[3]; \
                                               \
    const uint32_t ne10 = octx->src[1]->ne[0]; \
    const uint32_t ne11 = octx->src[1]->ne[1]; \
    const uint32_t ne12 = octx->src[1]->ne[2]; \
    const uint32_t ne13 = octx->src[1]->ne[3]; \
                                               \
    const uint32_t nb01 = octx->src[0]->nb[1]; \
    const uint32_t nb02 = octx->src[0]->nb[2]; \
    const uint32_t nb03 = octx->src[0]->nb[3]; \
                                               \
    const uint32_t nb10 = octx->src[1]->nb[0]; \
    const uint32_t nb11 = octx->src[1]->nb[1]; \
    const uint32_t nb12 = octx->src[1]->nb[2]; \
                                               \
    const uint32_t nb1 = octx->dst->nb[1];     \
    const uint32_t nb2 = octx->dst->nb[2];     \
    const uint32_t nb3 = octx->dst->nb[3];     \
                                               \
    const uint32_t ne0 = octx->dst->ne[0];     \
    const uint32_t ne1 = octx->dst->ne[1];     \
    const uint32_t ne2 = octx->dst->ne[2];     \
    const uint32_t ne3 = octx->dst->ne[3];     \
                                               \
    const uint32_t nr  = ne01;

struct htp_set_rows_context {
    struct htp_ops_context * octx;
    struct fastdiv_values div_ne12;
    struct fastdiv_values div_ne11;
    uint32_t src0_nrows_per_thread;
};

static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) {
    struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
    struct htp_ops_context * octx = srctx->octx;

    set_rows_preamble;

    uint64_t qt = HAP_perf_get_qtimer_count();

    // parallelize by rows of src0
    const uint32_t dr  = srctx->src0_nrows_per_thread;
    const uint32_t ir0 = dr * ith;
    const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;

    const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32);

    for (uint32_t i03 = 0; i03 < ne03; ++i03) {
        for (uint32_t i02 = 0; i02 < ne02; ++i02) {
            for (uint32_t i = ir0; i < ir1; ++i) {
                const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
                const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
                const uint32_t i10 = i;

                const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12;

                uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
                if (i1 >= ne1) {
                    // ignore invalid indices
                    continue;
                }

                const uintptr_t src0_ptr = octx->src[0]->data + i*nb01 + i02*nb02 + i03*nb03;
                const uintptr_t dst_ptr  = octx->dst->data  + i1*nb1 + i02*nb2  + i03*nb3;

                // copy row
                hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
            }
        }
    }

    qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt);
    FARF(HIGH, "set-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
         ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt);
}

static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) {
    struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data;
    struct htp_ops_context * octx = srctx->octx;

    set_rows_preamble;

    uint64_t qt = HAP_perf_get_qtimer_count();

    // parallelize by rows of src0
    const uint32_t dr  = srctx->src0_nrows_per_thread;
    const uint32_t ir0 = dr * ith;
    const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;

    const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32);

    for (uint32_t i03 = 0; i03 < ne03; ++i03) {
        for (uint32_t i02 = 0; i02 < ne02; ++i02) {
            for (uint32_t i = ir0; i < ir1; ++i) {
                const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12);
                const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11);
                const uint32_t i10 = i;

                const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12;

                uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
                if (i1 >= ne1) {
                    // ignore invalid indices
                    continue;
                }

                const uint8_t* src0_ptr = (const uint8_t *) octx->src[0]->data + i*nb01 + i02*nb02 + i03*nb03;
                uint8_t*       dst_ptr  = (uint8_t *)       octx->dst->data  + i1*nb1 + i02*nb2  + i03*nb3;

                hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00);
            }
        }
    }

    qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt);
    FARF(HIGH, "set-rows-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
         ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt);
}

int op_set_rows(struct htp_ops_context * octx) {
    set_rows_preamble;

    const uint32_t n_threads = MIN(nr, octx->n_threads);

    if (octx->src[0]->type != HTP_TYPE_F32) {
        return HTP_STATUS_NO_SUPPORT;
    }

    if (octx->dst->type != HTP_TYPE_F32 && octx->dst->type != HTP_TYPE_F16) {
        return HTP_STATUS_NO_SUPPORT;
    }

    if (octx->src[1]->type != HTP_TYPE_I32 && octx->src[1]->type != HTP_TYPE_I64) {
        return HTP_STATUS_NO_SUPPORT;
    }

    if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
        return HTP_STATUS_OK;
    }

    struct htp_set_rows_context srctx;
    srctx.octx = octx;
    srctx.div_ne12 = init_fastdiv_values(ne12);
    srctx.div_ne11 = init_fastdiv_values(ne11);

    srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads;

    switch(octx->dst->type) {
    case HTP_TYPE_F32:
        worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_threads);
        break;
    case HTP_TYPE_F16:
        worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_threads);
        break;
    default:
        return HTP_STATUS_NO_SUPPORT;
    }

    return HTP_STATUS_OK;
}