/*
 * llama-cpp-sys-4 0.2.45
 *
 * Low-level bindings to llama.cpp (documentation).
 */
#include <string.h>
#include <stdlib.h>
#include <math.h>
#include <HAP_farf.h>
#include <HAP_perf.h>

#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "ggml.h"

#include "hvx-utils.h"
#include "hex-dma.h"

#include "htp-ctx.h"
#include "htp-ops.h"
#include "htp-ops.h"

#ifndef MIN
// Smaller of two values. Note: function-like macro, evaluates its arguments
// more than once, so only pass side-effect-free expressions.
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
#endif

// Shared, read-only job descriptor handed to every argsort worker thread.
struct htp_argsort_context {
    // Op context: tensors, op params and the VTCM scratchpad description.
    struct htp_ops_context * octx;
    // Number of rows each worker handles; worker i starts at row i * nrows_per_thread.
    uint32_t nrows_per_thread;
};

// Returns true only if x[k] > y[k] (fp32 compare) holds in all 32 lanes.
// Each lane's compare result is turned into 0/1, the lanes are summed, and the
// total is checked against the lane count.
static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y)
{
    const HVX_VectorPred lane_gt   = Q6_Q_vcmp_gt_VsfVsf(x, y);
    const HVX_Vector     lane_hits = Q6_V_vmux_QVV(lane_gt, Q6_V_vsplat_R(1), Q6_V_vzero());

    return hvx_vec_get_i32(hvx_vec_reduce_sum_i32(lane_hits)) == 32;
}

// Sorts values[left..right] (inclusive bounds) in ascending order and mirrors
// every swap into indices[], so indices[] follows its value through the sort.
//
// Hoare-style partition around the value of the middle element. Runs of 32
// floats that are already on the correct side of the pivot are skipped with a
// single HVX compare (unaligned vector loads that stay inside [i, j]).
//
// Only the smaller partition is handled recursively; the larger one is handled
// by looping. This bounds the recursion depth to O(log n) even for adversarial
// inputs, which matters on worker threads with small stacks. The result is the
// same as recursing into both sides, since the two partitions are disjoint.
static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) {
    while (left < right) {
        // Same index as (left + right) / 2 for non-negative bounds, without overflow.
        const int   pivot_idx = left + (right - left) / 2;
        const float pivot     = values[pivot_idx];
        int i = left;
        int j = right;

        const HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);

        while (i <= j) {
            // Advance i past elements that are < pivot.
            while (i <= j) {
                // values[i .. i+31] lies entirely before j: skip it if every lane < pivot.
                if (i + 32 <= j) {
                    HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
                    if (all_greater_f32(pivot_vec, vals_vec)) {
                        i += 32;
                        continue;
                    }
                }

                // Scalar fallback / cleanup.
                if (values[i] < pivot) {
                    i++;
                } else {
                    break;
                }
            }

            // Retreat j past elements that are > pivot.
            while (i <= j) {
                // values[j-31 .. j] lies entirely after i: skip it if every lane > pivot.
                if (j - 32 >= i) {
                    HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
                    if (all_greater_f32(vals_vec, pivot_vec)) {
                        j -= 32;
                        continue;
                    }
                }

                if (values[j] > pivot) {
                    j--;
                } else {
                    break;
                }
            }

            if (i <= j) {
                float tmp_val = values[i];
                values[i] = values[j];
                values[j] = tmp_val;

                int32_t tmp_idx = indices[i];
                indices[i] = indices[j];
                indices[j] = tmp_idx;
                i++;
                j--;
            }
        }

        // Now j < i: [left, j] holds values <= pivot, [i, right] values >= pivot.
        // Recurse into the smaller side, keep looping on the larger side.
        if (j - left < right - i) {
            if (left < j) quicksort_values_indices_asc(values, indices, left, j);
            left = i;
        } else {
            if (i < right) quicksort_values_indices_asc(values, indices, i, right);
            right = j;
        }
    }
}

// Sorts values[left..right] (inclusive bounds) in descending order and mirrors
// every swap into indices[], so indices[] follows its value through the sort.
//
// Mirror image of the ascending variant: Hoare-style partition around the
// middle element's value, with 32-wide HVX compares used to skip runs that
// are already on the correct side of the pivot.
//
// Only the smaller partition is handled recursively; the larger one is handled
// by looping. This bounds the recursion depth to O(log n) even for adversarial
// inputs, which matters on worker threads with small stacks. The result is the
// same as recursing into both sides, since the two partitions are disjoint.
static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) {
    while (left < right) {
        // Same index as (left + right) / 2 for non-negative bounds, without overflow.
        const int   pivot_idx = left + (right - left) / 2;
        const float pivot     = values[pivot_idx];
        int i = left;
        int j = right;

        const HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);

        while (i <= j) {
            // Advance i past elements that are > pivot.
            while (i <= j) {
                // values[i .. i+31] lies entirely before j: skip it if every lane > pivot.
                if (i + 32 <= j) {
                    HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
                    if (all_greater_f32(vals_vec, pivot_vec)) {
                        i += 32;
                        continue;
                    }
                }

                if (values[i] > pivot) {
                    i++;
                } else {
                    break;
                }
            }

            // Retreat j past elements that are < pivot.
            while (i <= j) {
                // values[j-31 .. j] lies entirely after i: skip it if every lane < pivot.
                if (j - 32 >= i) {
                    HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
                    if (all_greater_f32(pivot_vec, vals_vec)) {
                        j -= 32;
                        continue;
                    }
                }

                if (values[j] < pivot) {
                    j--;
                } else {
                    break;
                }
            }

            if (i <= j) {
                float tmp_val = values[i];
                values[i] = values[j];
                values[j] = tmp_val;

                int32_t tmp_idx = indices[i];
                indices[i] = indices[j];
                indices[j] = tmp_idx;
                i++;
                j--;
            }
        }

        // Now j < i: [left, j] holds values >= pivot, [i, right] values <= pivot.
        // Recurse into the smaller side, keep looping on the larger side.
        if (j - left < right - i) {
            if (left < j) quicksort_values_indices_desc(values, indices, left, j);
            left = i;
        } else {
            if (i < right) quicksort_values_indices_desc(values, indices, i, right);
            right = j;
        }
    }
}

// Ramp 0..31 (one HVX vector of int32) used to seed the argsort index buffer;
// later vectors are produced by adding 32 per step. This table is only read, so
// it is file-local and const. (The "argosrt" spelling is kept to stay
// compatible with existing references.)
static const int32_t argosrt_ramp_lut[32] __attribute__((aligned(VLEN))) = {
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
    16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
};

static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
    struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
    struct htp_ops_context * octx = actx->octx;

    // Unpack context
    const struct htp_tensor * src0 = octx->src[0];
    const struct htp_tensor * dst = octx->dst;

    // Scratchpad memory
    uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;

    // Dimensions
    uint32_t ne00 = src0->ne[0];
    uint32_t ne01 = src0->ne[1];
    uint32_t ne02 = src0->ne[2];
    uint32_t ne03 = src0->ne[3];

    uint32_t nb01 = src0->nb[1];
    //uint32_t nb02 = src0->nb[2];
    //uint32_t nb03 = src0->nb[3];

    uint32_t nb1 = dst->nb[1];
    //uint32_t nb2 = dst->nb[2];
    //uint32_t nb3 = dst->nb[3];

    // Sort order
    enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];

    // Rows to process
    uint32_t total_rows = ne01 * ne02 * ne03;
    uint32_t rows_per_thread = actx->nrows_per_thread;
    uint32_t start_row = rows_per_thread * i;
    uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);

    // Scratchpad layout:
    // We need space for one row of float data (values) and one row of int32 indices.
    // values: ne00 * sizeof(float)
    // indices: ne00 * sizeof(int32_t)
    // Padded to 128 bytes.

    size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
    size_t num_vec_ind_values = hmx_ceil_div(ne00, VLEN/(sizeof(int32_t)));
    float * values_buf = (float *) spad;
    int32_t * indices_buf = (int32_t *) (spad + values_size);
    HVX_Vector * indices_buf_vec = (HVX_Vector *) (spad + values_size);
    const HVX_Vector ind_init_vec = *(HVX_Vector *)argosrt_ramp_lut;
    const HVX_Vector ind_diff_vec = Q6_V_vsplat_R(32);

    for (uint32_t r = start_row; r < end_row; r++) {
        uint32_t src_offset = r * nb01;
        uint32_t dst_offset = r * nb1;

        uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
        uint8_t * dst_ptr = (uint8_t *) dst->data  + dst_offset;

        hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
        hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00);

        // Initialize indices - Start with values 0..31, add 32 for additional vec iterations
        HVX_Vector curr_ind_vec = ind_init_vec;
        for (uint32_t j_vec = 0; j_vec < num_vec_ind_values; j_vec++) {
            indices_buf_vec[j_vec] = curr_ind_vec;
            curr_ind_vec = Q6_Vw_vadd_VwVw(curr_ind_vec, ind_diff_vec);
        }

        // Sort values and mirror swaps to indices
        if (order == GGML_SORT_ORDER_ASC) {
            quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1);
        } else {
            quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1);
        }

        // Copy indices back to DDR
        hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00);
    }
}

// Entry point for the argsort op on HTP.
// Only F32 input is supported. The op sizes a per-thread VTCM scratchpad
// (one row of values plus one row of indices), splits the rows evenly across
// up to octx->n_threads workers, and runs htp_argsort_f32 on each.
// Returns HTP_STATUS_OK, HTP_STATUS_NO_SUPPORT for other input types, or
// HTP_STATUS_VTCM_TOO_SMALL if the scratchpad does not fit.
int op_argsort(struct htp_ops_context * octx) {
    // Check supported types
    if (octx->src[0]->type != HTP_TYPE_F32) {
        return HTP_STATUS_NO_SUPPORT;
    }

    const uint32_t total_rows = octx->src[0]->ne[1] * octx->src[0]->ne[2] * octx->src[0]->ne[3];

    // Empty tensor: nothing to sort. This guard also prevents n_threads == 0
    // below, which would otherwise cause a division by zero when computing
    // nrows_per_thread.
    if (total_rows == 0) {
        return HTP_STATUS_OK;
    }

    const uint32_t n_threads = MIN(total_rows, octx->n_threads);

    // Scratchpad: one row of float values plus one row of int32 indices per thread,
    // each padded to 128 bytes (the worker writes indices one whole HVX vector at a time).
    uint32_t ne00 = octx->src[0]->ne[0];
    size_t values_size  = hex_round_up(ne00 * sizeof(float), 128);
    size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
    size_t spad_per_thread = values_size + indices_size;

    // Make sure we round up to 256 for alignment requirements
    spad_per_thread = hex_round_up(spad_per_thread, 256);

    size_t total_spad_size = spad_per_thread * n_threads;

    if (octx->ctx->vtcm_size < total_spad_size) {
        FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
        return HTP_STATUS_VTCM_TOO_SMALL;
    }

    octx->src0_spad.data = octx->ctx->vtcm_base;
    octx->src0_spad.size = total_spad_size;
    octx->src0_spad.size_per_thread = spad_per_thread;

    FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
         octx->src[0]->ne[0], octx->src[0]->ne[1], octx->src[0]->ne[2], octx->src[0]->ne[3],
         octx->dst->ne[0], octx->dst->ne[1], octx->dst->ne[2], octx->dst->ne[3],
         octx->src[0]->data, octx->dst->data);

    struct htp_argsort_context actx;
    actx.octx = octx;
    // Ceiling division; n_threads >= 1 here.
    actx.nrows_per_thread = (total_rows + n_threads - 1) / n_threads;

    // Run jobs
    worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_threads);

    return HTP_STATUS_OK;
}