#pragma once
#include "flash_fwd_kernel.h"
namespace FLASH_NAMESPACE {
using namespace cute;
// Computes one kBlockM-row tile of the attention output for query row block `m_block`
// of head `bidh` in batch entry `bidb`, visiting only the key/value rows named by the
// sparse index arrays in `params`:
//   * params.block_count / params.block_offset : per row block, a count and a list of
//     starting key rows; a kBlockN-row K/V tile is loaded starting at each listed row.
//   * params.column_count / params.column_index : per row block, a count and a list of
//     individual key rows, gathered kBlockN rows at a time.
// Both lists are indexed by (bidb * params.h + bidh) * params.NUM_ROWS + m_block and have
// a per-row capacity of params.NNZ_S and params.NNZ_V entries respectively.
//
// Results are written to params.o_ptr (output tile) and to the LSE tile returned by
// get_lse_tile(). When both lists are empty the output tile is zero-filled and the LSE
// rows are set to INFINITY.
//
// Uses the dynamic shared memory buffer `smem_` for the Q, K and V tiles (and, in the
// epilogue, for staging the output tile). Contains __syncthreads() barriers, so it must
// be called by every thread of the block with the same (bidb, bidh, m_block).
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
inline __device__ void sparse_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Dynamic shared memory; carved into the sQ / sK / sV tiles below.
    extern __shared__ char smem_[];

    const int tidx = threadIdx.x;

    constexpr int kBlockM = Kernel_traits::kBlockM;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kNWarps = Kernel_traits::kNWarps;

    auto seed_offset = at::cuda::philox::unpack(params.philox_args);
    flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
                           bidb, bidh, tidx, params.h);

    // A single thread of the grid publishes the RNG seed and offset to params.rng_state.
    if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
        params.rng_state[0] = std::get<0>(seed_offset);
        params.rng_state[1] = std::get<1>(seed_offset);
    }

    const BlockInfo<!Is_even_MN> binfo(params, bidb);
    // Row blocks that start past the end of this sequence have nothing to do.
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    // Only used to position the gP (returned softmax) tile below and to seed n_block.
    int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
    if (Is_causal || Is_local) {
        n_block_max = std::min(n_block_max,
                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
    }

    // The first operand is widened to index_t so the batch/head/sequence products are not
    // evaluated in plain int before being stored into an index_t.
    const index_t row_offset_p = ((static_cast<index_t>(bidb) * params.h + bidh) * params.seqlen_q_rounded
        + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;

    // Global-memory views of Q, K and V for this batch entry, and the tiles for this head.
    Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr)
                                          + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)),
                            make_shape(binfo.actual_seqlen_q, params.h, params.d),
                            make_stride(params.q_row_stride, params.q_head_stride, _1{}));
    Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                           make_coord(m_block, 0));
    Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr)
                                          + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)),
                            make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
                            make_stride(params.k_row_stride, params.k_head_stride, _1{}));
    Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
                           make_coord(_, 0));
    Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr)
                                          + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)),
                            make_shape(binfo.actual_seqlen_k, params.h_k, params.d),
                            make_stride(params.v_row_stride, params.v_head_stride, _1{}));
    Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{},
                           make_coord(_, 0));

    // "Token" views used for gathering individual K/V rows. The row stride is 0, so every
    // row of the tile aliases the same global row; the gather loops below re-point the
    // data pointer at the wanted key row before copying each tile row.
    const index_t row_offset_k_token =
        binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
        + (bidh / params.h_h_k_ratio) * params.k_head_stride;
    const index_t row_offset_v_token =
        binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
        + (bidh / params.h_h_k_ratio) * params.v_head_stride;
    Tensor gKToken = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k_token),
                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                 make_stride(_0{}, _1{}));
    Tensor gVToken = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v_token),
                                 Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                 make_stride(_0{}, _1{}));

    // Destination tile for the optionally returned softmax probabilities.
    Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
                            Shape<Int<kBlockM>, Int<kBlockN>>{},
                            make_stride(params.seqlen_k_rounded, _1{}));

    // Shared-memory tiles. sK starts at sQ when Share_Q_K_smem, otherwise right after it;
    // sV follows sK. sVt / sVtNoSwizzle are alternative layouts over the sV storage.
    Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                            typename Kernel_traits::SmemLayoutQ{});
    Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
                            typename Kernel_traits::SmemLayoutKV{});
    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
    Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});

    // Per-thread partitions for the global -> shared copies.
    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);

    Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
    Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK);
    // Block path: tile 0 of K/V is used as a template whose data pointer is rebased to
    // the listed start row; the *Data copies keep the un-rebased base pointers.
    Tensor tKgKBlock = tKgK(_, _, _, 0);
    auto tKgKBlockData = tKgKBlock.data();
    Tensor tKgKToken = gmem_thr_copy_QKV.partition_S(gKToken);
    auto tKgKTokenData = tKgKToken.data();
    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);
    Tensor tVgVBlock = tVgV(_, _, _, 0);
    auto tVgVBlockData = tVgVBlock.data();
    Tensor tVgVToken = gmem_thr_copy_QKV.partition_S(gVToken);
    auto tVgVTokenData = tVgVToken.data();
    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);

    // Per-thread MMA fragments.
    typename Kernel_traits::TiledMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tidx);
    Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
    Tensor tSrK = thr_mma.partition_fragment_B(sK);
    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);

    Tensor tSgS = thr_mma.partition_C(gP);

    // Output accumulator for the whole row block.
    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});

    // Shared -> register copy partitions for the MMA operands.
    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);

    auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
    Tensor tSsK = smem_thr_copy_K.partition_S(sK);

    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

    // Identity tensors give each copied element its (row, column) coordinate in the tile;
    // they drive the row bounds checks and the head-dimension predicates.
    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));
    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));
    Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);
    Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV);

    // Column predicates: true where the head-dimension column is < params.d.
    // They are only filled in (and only consulted) when !Is_even_K.
    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
    Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
        #pragma unroll
        for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
    }

    // ---- Prologue: bring the Q tile into shared memory -------------------------------
    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                       binfo.actual_seqlen_q - m_block * kBlockM);
    cute::cp_async_fence();
    // NOTE(review): with Is_Q_in_regs a second fence is issued with no copy in between.
    // This looks intended to let the cp_async_wait<1>() below cover the Q copy even though
    // no K copy has been issued yet at that point - confirm.
    if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }

    if (Kernel_traits::Share_Q_K_smem) {
        // Q must be moved into registers before K is written over the shared storage.
        flash::cp_async_wait<0>();
        __syncthreads();
        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));
        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
        __syncthreads();
    }

    // NOTE(review): n_block is never updated below, so block_col_idx (which feeds the
    // dropout helper) is identical for every processed tile - confirm this is intended.
    int n_block = n_block_max - 1;

    // Row of the sparse metadata for this (batch, head, row block). Computed once in
    // 64-bit so that multiplying by NNZ_S / NNZ_V cannot overflow int for long sequences.
    const int64_t sparse_row = (int64_t(bidb) * params.h + bidh) * params.NUM_ROWS + m_block;
    const int num_blks = params.block_count[sparse_row];
    auto* blks_ptr = params.block_offset + sparse_row * params.NNZ_S;

    if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
        flash::cp_async_wait<1>();
        __syncthreads();
        Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
        CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));
        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
    }

    clear(acc_o);

    flash::Softmax<2 * size<1>(acc_o)> softmax;

    const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
    flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);

    // Number of trailing tiles that get masking applied.
    constexpr int n_masking_steps = (!Is_causal && !Is_local)
        ? 1
        : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);

    const int num_cols = params.column_count[sparse_row];
    const int num_cols_block = (num_cols + kBlockN - 1) / kBlockN;

    // ---- Nothing to attend to: write zeros to O and INFINITY to LSE, then exit --------
    // Both counts depend only on (bidb, bidh, m_block), so the branch is taken uniformly
    // by all threads that were called with the same arguments.
    if (num_blks <= 0 && num_cols_block <= 0) {
        Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
                                              + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
                                make_shape(binfo.actual_seqlen_q, params.h, params.d),
                                make_stride(params.o_row_stride, params.o_head_stride, _1{}));
        Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                               make_coord(m_block, 0));
        Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);

        typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
        Tensor tOrO = make_tensor<Element>(shape(tOgO));
        clear(tOrO);
        Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));
        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
        if (!Is_even_K) {
            #pragma unroll
            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
        }
        flash::copy<Is_even_MN, Is_even_K, false, false>(
            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
        );
        // Only the thread whose slice starts at column 0 writes the LSE for a given row.
        #pragma unroll
        for (int m = 0; m < size<1>(tOgO); ++m) {
            const int row = get<0>(tOcO(0, m, 0));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
        }
        return;
    }

    // ---- Phase 1: contiguous K/V blocks from block_offset, last entry first ------------
    if (num_blks > 0) {
        int block_index = num_blks - 1;

        // Start loading the first K tile; rows past actual_seqlen_k are bounds-checked.
        tKgKBlock.data() = tKgKBlockData + blks_ptr[block_index] * int64_t(params.k_row_stride);
        flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgKBlock, tKsK, tKVcKV, tKVpKV,
                                           binfo.actual_seqlen_k - blks_ptr[block_index]);
        cute::cp_async_fence();

        // Masked steps: the first n_masking_steps tiles visited go through mask.apply_mask.
        for (int n = 0; n < n_masking_steps && block_index >= 0; ++n, --block_index) {
            int start_n = blks_ptr[block_index];
            Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
            clear(acc_s);
            // K tile must be resident before the QK^T gemm.
            flash::cp_async_wait<0>();
            __syncthreads();

            // Kick off the matching V tile so it loads while QK^T is computed.
            tVgVBlock.data() = tVgVBlockData + start_n * int64_t(params.v_row_stride);
            if (block_index < num_blks - 1) {
                flash::copy<true, Is_even_K>(gmem_tiled_copy_QKV, tVgVBlock, tVsV, tKVcKV, tKVpKV);
            } else {
                // Last listed block: bounds-checked against actual_seqlen_k.
                flash::copy<Is_even_MN, Is_even_K, true>(
                    gmem_tiled_copy_QKV, tVgVBlock, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - start_n
                );
            }
            cute::cp_async_fence();

            flash::gemm<Kernel_traits::Is_Q_in_regs>(
                acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
                smem_thr_copy_Q, smem_thr_copy_K
            );
            if constexpr (Is_softcap){
                flash::apply_softcap(acc_s, params.softcap);
            }

            // Column offset is the block's start row; the row offset is this thread's
            // first row in the tile (warp * 16 + lane / 4).
            mask.template apply_mask<Is_causal, Is_even_MN>(
                acc_s, start_n, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
            );

            // V tile must be resident (and sK free to overwrite) before continuing.
            flash::cp_async_wait<0>();
            __syncthreads();
            if (block_index > 0) {
                // Prefetch the next K tile.
                tKgKBlock.data() = tKgKBlockData + blks_ptr[block_index - 1] * int64_t(params.k_row_stride);
                flash::copy<true, Is_even_K>(gmem_tiled_copy_QKV, tKgKBlock, tKsK, tKVcKV, tKVpKV);
                cute::cp_async_fence();
            }

            // The first template argument is true only for the first tile folded into
            // the running softmax state.
            n == 0
                ? softmax.template softmax_rescale_o</*Is_first=*/true, Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)
                : softmax.template softmax_rescale_o</*Is_first=*/false, Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);

            Tensor rP = flash::convert_type<Element>(acc_s);
            int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
            int block_col_idx = n_block * (kBlockN / 32);
            if (Return_softmax) {
                Tensor rP_drop = make_fragment_like(rP);
                cute::copy(rP, rP_drop);
                dropout.template apply_dropout<true>(
                    rP_drop, block_row_idx, block_col_idx, kNWarps
                );
                cute::copy(rP_drop, tSgS);
                // NOTE(review): gP is stepped back by one tile per iteration regardless
                // of which key rows the tile actually covered - confirm for sparse use.
                tSgS.data() = tSgS.data() + (-kBlockN);
            }
            if (Is_dropout) {
                dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
            }

            // Reinterpret the P accumulator layout as an MMA A operand for the P*V gemm.
            Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
            flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
        }

        // Remaining blocks: same pipeline without masking and without bounds checks.
        for (; block_index >= 0; --block_index) {
            int start_n = blks_ptr[block_index];
            Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
            clear(acc_s);
            flash::cp_async_wait<0>();
            __syncthreads();

            tVgVBlock.data() = tVgVBlockData + start_n * int64_t(params.v_row_stride);
            flash::copy<true, Is_even_K>(gmem_tiled_copy_QKV, tVgVBlock, tVsV, tKVcKV, tKVpKV);
            cute::cp_async_fence();

            flash::gemm<Kernel_traits::Is_Q_in_regs>(
                acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
                smem_thr_copy_Q, smem_thr_copy_K
            );
            if constexpr (Is_softcap){
                flash::apply_softcap(acc_s, params.softcap);
            }

            flash::cp_async_wait<0>();
            __syncthreads();
            if (block_index > 0) {
                tKgKBlock.data() = tKgKBlockData + blks_ptr[block_index - 1] * int64_t(params.k_row_stride);
                flash::copy<true, Is_even_K>(gmem_tiled_copy_QKV, tKgKBlock, tKsK, tKVcKV, tKVpKV);
                cute::cp_async_fence();
            }

            softmax.template softmax_rescale_o</*Is_first=*/false, Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);

            Tensor rP = flash::convert_type<Element>(acc_s);
            int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
            int block_col_idx = n_block * (kBlockN / 32);
            if (Return_softmax) {
                Tensor rP_drop = make_fragment_like(rP);
                cute::copy(rP, rP_drop);
                dropout.template apply_dropout<true>(
                    rP_drop, block_row_idx, block_col_idx, kNWarps
                );
                cute::copy(rP_drop, tSgS);
                tSgS.data() = tSgS.data() + (-kBlockN);
            }
            if (Is_dropout) {
                dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
            }

            Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
            flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
        }
    }

    // ---- Phase 2: individually listed key rows, gathered kBlockN at a time -------------
    if (num_cols > 0) {
        auto* cols_ptr = params.column_index + sparse_row * params.NNZ_V;

        // Gather the first K tile: tile row m comes from key row cols_ptr[m]. Tile rows at
        // or beyond num_cols are not loaded; their scores are set to -INFINITY below.
        #pragma unroll
        for (int m = 0; m < size<1>(tKgKToken); ++m) {
            if (Is_even_MN || get<0>(tKVcKV(0, m, 0)) < num_cols) {
                tKgKToken.data() = tKgKTokenData + cols_ptr[get<0>(tKVcKV(0, m, 0))] * int64_t(params.k_row_stride);
                #pragma unroll
                for (int k = 0; k < size<2>(tKgKToken); ++k) {
                    if (Is_even_K || tKVpKV(k)) {
                        cute::copy(gmem_tiled_copy_QKV, tKgKToken(_, m, k), tKsK(_, m, k));
                    } else {
                        cute::clear(tKsK(_, m, k));
                    }
                }
            }
        }
        cute::cp_async_fence();

        for (int n = 0; n < num_cols_block; ++n) {
            Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
            clear(acc_s);
            flash::cp_async_wait<0>();
            __syncthreads();

            // Gather the V rows matching the K rows of this tile.
            if (n < num_cols_block - 1) {
                // Full tile: every list position is below num_cols.
                #pragma unroll
                for (int m = 0; m < size<1>(tVgVToken); ++m) {
                    tVgVToken.data() = tVgVTokenData + cols_ptr[n * kBlockN + get<0>(tKVcKV(0, m, 0))] * int64_t(params.v_row_stride);
                    #pragma unroll
                    for (int k = 0; k < size<2>(tVgVToken); ++k) {
                        if (Is_even_K || tKVpKV(k)) {
                            cute::copy(gmem_tiled_copy_QKV, tVgVToken(_, m, k), tVsV(_, m, k));
                        } else {
                            cute::clear(tVsV(_, m, k));
                        }
                    }
                }
            } else {
                // Last tile: positions past num_cols get zeroed V rows instead of a load.
                #pragma unroll
                for (int m = 0; m < size<1>(tVgVToken); ++m) {
                    if (Is_even_MN || n * kBlockN + get<0>(tKVcKV(0, m, 0)) < num_cols) {
                        tVgVToken.data() = tVgVTokenData + cols_ptr[n * kBlockN + get<0>(tKVcKV(0, m, 0))] * int64_t(params.v_row_stride);
                        #pragma unroll
                        for (int k = 0; k < size<2>(tVgVToken); ++k) {
                            if (Is_even_K || tKVpKV(k)) {
                                cute::copy(gmem_tiled_copy_QKV, tVgVToken(_, m, k), tVsV(_, m, k));
                            } else {
                                cute::clear(tVsV(_, m, k));
                            }
                        }
                    } else {
                        cute::clear(tVsV(_, m, _));
                    }
                }
            }
            cute::cp_async_fence();

            flash::gemm<Kernel_traits::Is_Q_in_regs>(
                acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
                smem_thr_copy_Q, smem_thr_copy_K
            );
            if constexpr (Is_softcap){
                flash::apply_softcap(acc_s, params.softcap);
            }

            // Masking for the last n_masking_steps tiles. A score is set to -INFINITY when
            // its list position is past num_cols, or when the key row it refers to is at
            // or beyond the limit (seqlen_k, further capped per query row when Is_causal).
            if (n >= num_cols_block - n_masking_steps) {
                Tensor tensor = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
                const int lane_id = threadIdx.x % 32;
                // Positions within the gathered column list, not key rows.
                const int col_idx_offset = n * kBlockN + (lane_id % 4) * 2;
                const int row_idx_offset = m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4;
                const int warp_row_stride = kNWarps * 16;
                const int max_seqlen_k = binfo.actual_seqlen_k;
                const int max_seqlen_q = binfo.actual_seqlen_q;

                #pragma unroll
                for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
                    const int row_idx_base = row_idx_offset + mi * warp_row_stride;
                    #pragma unroll
                    for (int i = 0; i < size<0, 0>(tensor); ++i) {
                        const int row_idx = row_idx_base + i * 8;
                        const int col_idx_limit = Is_causal ? std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q) : max_seqlen_k;
                        #pragma unroll
                        for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                            const int col_idx_base = col_idx_offset + nj * 8;
                            #pragma unroll
                            for (int j = 0; j < size<1, 0>(tensor); ++j) {
                                // The position check comes first so cols_ptr is only
                                // read for positions below num_cols.
                                if (col_idx_base + j >= num_cols || cols_ptr[col_idx_base + j] >= col_idx_limit) {
                                    tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                                }
                            }
                        }
                    }
                }
            }

            flash::cp_async_wait<0>();
            __syncthreads();

            // Prefetch the next gathered K tile, if any.
            if (n < num_cols_block - 2) {
                // Next tile is full: no bounds check needed.
                #pragma unroll
                for (int m = 0; m < size<1>(tKgKToken); ++m) {
                    int token_idx = cols_ptr[(n + 1) * kBlockN + get<0>(tKVcKV(0, m, 0))];
                    tKgKToken.data() = tKgKTokenData + token_idx * int64_t(params.k_row_stride);
                    #pragma unroll
                    for (int k = 0; k < size<2>(tKgKToken); ++k) {
                        if (Is_even_K || tKVpKV(k)) {
                            cute::copy(gmem_tiled_copy_QKV, tKgKToken(_, m, k), tKsK(_, m, k));
                        } else {
                            cute::clear(tKsK(_, m, k));
                        }
                    }
                }
                cute::cp_async_fence();
            } else if (n == num_cols_block - 2) {
                // Next tile is the last one: skip list positions past num_cols.
                #pragma unroll
                for (int m = 0; m < size<1>(tKgKToken); ++m) {
                    if (Is_even_MN || (n + 1) * kBlockN + get<0>(tKVcKV(0, m, 0)) < num_cols) {
                        int token_idx = cols_ptr[(n + 1) * kBlockN + get<0>(tKVcKV(0, m, 0))];
                        tKgKToken.data() = tKgKTokenData + token_idx * int64_t(params.k_row_stride);
                        #pragma unroll
                        for (int k = 0; k < size<2>(tKgKToken); ++k) {
                            if (Is_even_K || tKVpKV(k)) {
                                cute::copy(gmem_tiled_copy_QKV, tKgKToken(_, m, k), tKsK(_, m, k));
                            } else {
                                cute::clear(tKsK(_, m, k));
                            }
                        }
                    }
                }
                cute::cp_async_fence();
            }

            // Is_first is true only if no block tile was processed in phase 1 and this is
            // the first gathered tile.
            (num_blks <= 0 && n == 0)
                ? softmax.template softmax_rescale_o</*Is_first=*/true, Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)
                : softmax.template softmax_rescale_o</*Is_first=*/false, Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);

            Tensor rP = flash::convert_type<Element>(acc_s);
            int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
            int block_col_idx = n_block * (kBlockN / 32);
            if (Return_softmax) {
                Tensor rP_drop = make_fragment_like(rP);
                cute::copy(rP, rP_drop);
                dropout.template apply_dropout<true>(
                    rP_drop, block_row_idx, block_col_idx, kNWarps
                );
                cute::copy(rP_drop, tSgS);
                tSgS.data() = tSgS.data() + (-kBlockN);
            }
            if (Is_dropout) {
                dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
            }

            Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
            flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
        }
    }

    // ---- Epilogue: normalise, stage O through shared memory, write O and LSE -----------
    Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);

    // Convert the accumulator to the output element type.
    Tensor rO = flash::convert_type<Element>(acc_o);
    // The output tile reuses the shared memory that held Q.
    Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{});
    auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
    auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
    Tensor taccOrO = smem_thr_copy_O.retile_S(rO);
    Tensor taccOsO = smem_thr_copy_O.partition_D(sO);

    // With shared Q/K storage, sO overlaps sK: wait until every thread is done reading it.
    if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }

    cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

    Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr)
                                          + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)),
                            make_shape(binfo.actual_seqlen_q, params.h, params.d),
                            make_stride(params.o_row_stride, params.o_head_stride, _1{}));
    Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{},
                           make_coord(m_block, 0));
    Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo);

    typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
    auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
    Tensor tOsO = gmem_thr_copy_O.partition_S(sO);
    Tensor tOgO = gmem_thr_copy_O.partition_D(gO);

    // All writes to sO must be visible before it is read back with the output layout.
    __syncthreads();

    Tensor tOrO = make_tensor<Element>(shape(tOgO));
    cute::copy(gmem_tiled_copy_O, tOsO, tOrO);

    Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});
    Tensor taccOcO = thr_mma.partition_C(caccO);
    static_assert(decltype(size<0>(taccOcO))::value == 4);
    // Keep only the row coordinate of each accumulator row owned by this thread.
    Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
    CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row));
    // Only threads whose accumulator slice starts at column 0 write the LSE.
    if (get<1>(taccOcO_row(0)) == 0) {
        #pragma unroll
        for (int mi = 0; mi < size(lse); ++mi) {
            const int row = get<0>(taccOcO_row(mi));
            if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
        }
    }

    Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO)));
    Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
    Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
    if (!Is_even_K) {
        #pragma unroll
        for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
    }
    // Rows past actual_seqlen_q are skipped (no clearing of out-of-bounds elements).
    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
    );
}
// Per-thread-block entry point for the sparse attention forward pass.
//
// Maps the launch grid onto the problem and forwards to sparse_attn_1rowblock:
//   blockIdx.x -> m_block (query row block)
//   blockIdx.y -> bidb    (batch index)
//   blockIdx.z -> bidh    (head index)
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params>
inline __device__ void compute_sparse_attn(const Params &params) {
    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;

    // sparse_attn_1rowblock is defined in this file inside FLASH_NAMESPACE, so the call is
    // qualified with that same name rather than a hard-coded `flash::`.
    FLASH_NAMESPACE::sparse_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block);
}
}