#include "ssm_scan.hpp"
#include "common.hpp"
template <int c_factor, int d_state>
static void ssm_scan_f32_group(
const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
const int32_t * __restrict__ src6, float * __restrict__ dst,
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok,
const sycl::nd_item<2> & item) {
const int lane = item.get_local_id(1) % WARP_SIZE;
const int warp = item.get_local_id(1) / WARP_SIZE;
const int warp_idx = item.get_group(1) * c_factor + warp;
const int seq_idx = item.get_group(0);
const int head_idx = warp_idx / d_head;
const int head_off = (warp_idx % d_head) * sizeof(float);
const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1);
const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx;
float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const int stride_x = src1_nb2 / sizeof(float);
const int stride_dt = src2_nb1 / sizeof(float);
const int stride_B = src4_nb2 / sizeof(float);
const int stride_C = src5_nb2 / sizeof(float);
const int stride_y = n_head * d_head;
float state[c_factor];
float state_sum = 0.0f;
#pragma unroll
for (int j = 0; j < c_factor; j++) {
state[j] = s0_warp[WARP_SIZE * j + lane];
}
for (int64_t i = 0; i < n_tok; i++) {
const float dt_val = dt_warp[i * stride_dt];
const float dt_soft_plus = (dt_val <= 20.0f ? sycl::log1p(sycl::exp(dt_val)) : dt_val);
state_sum = 0.0f;
const float dA = sycl::exp(dt_soft_plus * A_warp[0]);
const float x_dt = x_warp[i * stride_x] * dt_soft_plus;
#pragma unroll
for (int j = 0; j < c_factor; j++) {
const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane];
const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];
state[j] = (state[j] * dA) + (B_val * x_dt);
state_sum += state[j] * C_val;
}
state_sum = warp_reduce_sum<WARP_SIZE>(state_sum);
if (lane == 0) {
y_warp[i * stride_y] = state_sum;
}
}
#pragma unroll
for (int j = 0; j < c_factor; j++) {
s_warp[WARP_SIZE * j + lane] = state[j];
}
}
static void ssm_scan_f32_sycl(
const float * src0, const float * src1, const float * src2, const float * src3,
const float * src4, const float * src5, const int32_t * src6, float * dst,
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
dpct::queue_ptr stream) {
GGML_ASSERT(src3_nb1 == sizeof(float));
if (d_state == 128) {
constexpr int threads = 128;
constexpr int num_warps = threads / WARP_SIZE;
const sycl::range<2> grid(n_seq, (n_head * head_dim + num_warps - 1) / num_warps);
const sycl::range<2> block(1, threads);
stream->parallel_for(
sycl::nd_range<2>(grid * block, block),
[=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
ssm_scan_f32_group<128 / WARP_SIZE, 128>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item);
});
} else if (d_state == 256) {
constexpr int threads = 256;
constexpr int num_warps = threads / WARP_SIZE;
const sycl::range<2> grid(n_seq, (n_head * head_dim + num_warps - 1) / num_warps);
const sycl::range<2> block(1, threads);
stream->parallel_for(
sycl::nd_range<2>(grid * block, block),
[=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
ssm_scan_f32_group<256 / WARP_SIZE, 256>(
src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item);
});
} else {
GGML_ABORT("ssm_scan: unsupported d_state (must be 128 or 256)");
}
}
inline void ggml_sycl_op_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
const ggml_tensor * src3 = dst->src[3];
const ggml_tensor * src4 = dst->src[4];
const ggml_tensor * src5 = dst->src[5];
const ggml_tensor * src6 = dst->src[6];
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src6->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t nc = src0->ne[0];
const int64_t nr = src0->ne[1];
const int64_t nh = src1->ne[1];
const int64_t ng = src4->ne[1];
const int64_t n_t = src1->ne[2];
const int64_t n_s = src1->ne[3];
const int64_t s_off = ggml_nelements(src1) * sizeof(float);
GGML_ASSERT(ggml_nelements(src1) + nc * nr * nh * n_s == ggml_nelements(dst));
dpct::queue_ptr stream = ctx.stream();
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
ssm_scan_f32_sycl(
static_cast<const float *>(src0->data), static_cast<const float *>(src1->data),
static_cast<const float *>(src2->data), static_cast<const float *>(src3->data),
static_cast<const float *>(src4->data), static_cast<const float *>(src5->data),
static_cast<const int32_t *>(src6->data), static_cast<float *>(dst->data),
src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
s_off, nc, nr, nh, ng, n_t, n_s, stream);
}
void ggml_sycl_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, 7);
ggml_sycl_op_ssm_scan(ctx, dst);
}