#ifndef HTP_CTX_H
#define HTP_CTX_H
#include "hex-dma.h"
#include "hmx-queue.h"
#include "htp-ops.h"
#include "worker-pool.h"
#include <assert.h>
#include <dspqueue.h>
#include <stdatomic.h>
#include <stdint.h>
/* Capacity of htp_context::dma[]; presumably the maximum number of worker threads — confirm. */
#define HTP_MAX_NTHREADS 10
/* Capacity of htp_context::mmap[] (number of struct htp_mmap slots). */
#define HTP_MAX_MMAPS 16
/*
 * One tracked memory mapping; htp_context keeps HTP_MAX_MMAPS of these.
 * Member order and widths define the layout: two 64-bit fields followed by
 * two 32-bit fields.
 */
struct htp_mmap {
    uint64_t size, base; /* extent and base address (units presumably bytes — confirm) */
    uint32_t fd, pinned; /* descriptor of the mapping and a pinned flag (semantics not shown here) */
};
/*
 * Scratch-pad buffer descriptor; htp_ops_context holds one per source
 * operand plus one for the destination.
 */
struct htp_spad {
    const struct htp_tensor *src;           /* tensor associated with this buffer — presumably the one staged into it; confirm */
    uint8_t                 *data;          /* start of the buffer */
    uint32_t stride, size, size_per_thread; /* geometry; units not visible here (presumably bytes — confirm) */
};
/* Forward declaration: htp_ops_context only needs a pointer to it. */
struct htp_context;

/*
 * Per-operation context handed to every op_* entry point in this header.
 * Members are listed in their original order; same-typed neighbours share
 * a declaration, which does not change the layout.
 */
struct htp_ops_context {
    struct htp_context *ctx; /* owning HTP context */

    /* which op to run, plus its raw integer parameters */
    enum htp_op_code op;
    int32_t          op_params[HTP_OP_MAX_PARAMS];

    /* operand tensors */
    const struct htp_tensor *src[HTP_OP_MAX_INPUTS];
    const struct htp_tensor *dst;

    /* scratch-pad buffers: one per source slot 0..3, one for the destination */
    struct htp_spad src0_spad, src1_spad, src2_spad, src3_spad;
    struct htp_spad dst_spad;

    uint32_t n_threads, flags; /* thread count and op flags (flag bits defined elsewhere) */
};
/*
 * Top-level HTP state: message queue, per-thread DMA queues, tracked memory
 * mappings, worker pool, VTCM bookkeeping and the embedded ops context.
 * Member order is unchanged from the previous definition, so the layout is
 * identical.
 */
struct htp_context {
    dspqueue_t            queue;                 /* dspqueue handle (<dspqueue.h>) */
    dma_queue *           dma[HTP_MAX_NTHREADS]; /* DMA queue pointers; presumably one per worker thread — confirm */
    struct htp_mmap       mmap[HTP_MAX_MMAPS];   /* tracked memory mappings */
    worker_pool_context_t worker_pool;           /* worker pool handle (worker-pool.h) */
    uint32_t              n_threads;
    int                   thread_id;             /* presumably id of the servicing thread — confirm */
    int                   thread_prio;           /* presumably its scheduling priority — confirm */
    int                   hmx_enabled;           /* nonzero presumably means HMX may be used — confirm */

    /* VTCM region: base pointer, size, and rctx (presumably a resource-context handle — confirm) */
    uint8_t *             vtcm_base;
    size_t                vtcm_size;
    uint32_t              vtcm_rctx;

    /* Atomic flags, so presumably read/written from more than one thread — confirm. */
    atomic_bool           vtcm_valid;
    atomic_bool           vtcm_needs_release;

    struct htp_ops_context octx;                 /* embedded per-op context */

#ifdef HTP_HAS_HMX
    struct hmx_queue *    hmx_queue;             /* present only when built with HTP_HAS_HMX */
/* A directive must begin its own line; sharing a line with a declaration it is
 * not recognized as #endif and leaves #ifdef HTP_HAS_HMX unterminated. */
#endif
};
/*
 * Op entry points, listed alphabetically. Each takes the per-operation
 * context (struct htp_ops_context) and returns an int — presumably a status
 * code; confirm against the implementations.
 */
int op_activations(struct htp_ops_context *octx);
int op_add_id(struct htp_ops_context *octx);
int op_argsort(struct htp_ops_context *octx);
int op_binary(struct htp_ops_context *octx);
int op_cpy(struct htp_ops_context *octx);
int op_cumsum(struct htp_ops_context *octx);
int op_flash_attn_ext(struct htp_ops_context *octx);
int op_get_rows(struct htp_ops_context *octx);
int op_matmul(struct htp_ops_context *octx);
int op_matmul_id(struct htp_ops_context *octx);
int op_repeat(struct htp_ops_context *octx);
int op_rope(struct htp_ops_context *octx);
int op_set_rows(struct htp_ops_context *octx);
int op_softmax(struct htp_ops_context *octx);
int op_ssm_conv(struct htp_ops_context *octx);
int op_sum_rows(struct htp_ops_context *octx);
int op_unary(struct htp_ops_context *octx);
#endif