llama-cpp-sys-4 0.3.1

Low Level Bindings to llama.cpp
Documentation
// HMX operation entry-point declarations.
// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib)

#ifndef HMX_OPS_H
#define HMX_OPS_H

#include <stddef.h>
#include <stdint.h>

#include "htp-ops.h"

#ifdef __cplusplus
extern "C" {
#endif

typedef struct {
    float        *dst;
    const float  *activation;
    const __fp16 *permuted_weight;
    int           m;
    int           k;
    int           n;
    int           act_stride;
    int           weight_stride;
    int           dst_stride;
    int           ne02;
    int           ne03;
    int           ne12;
    int           ne13;
    size_t        src0_nb2;
    size_t        src0_nb3;
    size_t        src1_nb2;
    size_t        src1_nb3;
    size_t        dst_nb2;
    size_t        dst_nb3;
} hmx_matmul_f16_f32_batched_params_t;

// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output
// act_stride: activation row stride in elements (= k for contiguous, or
//             nb[1]/sizeof(float) for permuted tensors like attention Q).
// weight_stride: weight row stride in elements (= k for compact weights, or
//                nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK).
int hmx_matmul_f16_f32(struct htp_context *ctx,
                                float *restrict dst,
                                const float *activation,
                                const __fp16 *permuted_weight,
                                int m, int k, int n,
                                int act_stride,
                                int weight_stride);

// Batched F16 wrapper over hmx_mat_mul_f16_f32.
// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3.
int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params);

// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4)
int hmx_matmul_2d_f32(struct htp_context *ctx,
                                      float *restrict dst,
                                      const float *activation,
                                      const uint8_t *permuted_weight,
                                      int m, int k, int n,
                                      int act_stride,
                                      int weight_stride,
                                      int weight_type);

struct mmid_row_mapping;

int hmx_matmul_id_2d_f32(struct htp_context *ctx,
                                         float *restrict dst,
                                         const float *activation,
                                         const uint8_t *permuted_weight,
                                         int m, int k, int n,
                                         int ne11,
                                         size_t act_nb1, size_t act_nb2,
                                         size_t dst_nb1, size_t dst_nb2,
                                         int weight_stride,
                                         int weight_type,
                                         const struct mmid_row_mapping *matrix_rows,
                                         int cur_a,
                                         int mapping_stride);

// HMX flash attention
int hmx_flash_attn_ext(struct htp_ops_context * octx);

#ifdef __cplusplus
}
#endif

#endif // HMX_OPS_H