use metaltile::{bench_kernel, kernel};
// Unmasked (bidirectional) scaled-dot-product attention.
//
// One threadgroup handles one (query row, query head) pair, derived from
// `tgid_x`, and scores it against every key row in `0 .. base_kv + n_query`.
// Each SIMD lane owns the two adjacent head dimensions `2 * lane` and
// `2 * lane + 1`; the dot product is completed with `simd_sum`.
//
// Every SIMD group strides over the key/value rows keeping a running maximum,
// a running exponent sum and a running weighted-V accumulator, so the softmax
// is evaluated in a single pass. The per-group partials are then combined
// through threadgroup scratch and SIMD group 0 writes the output.
//
// Parameters:
//   q, k, v, out     - flat tensors indexed by element offset.
//   head_dim         - element stride between consecutive rows.
//   n_q_heads        - number of query heads per query row.
//   base_kv, n_query - their sum is the number of key/value rows visited.
//   kv_stride        - rows reserved per KV head in `k` / `v`.
//   heads_per_group  - number of query heads that share one KV head.
//   scale            - multiplier folded into the query values.
//
// NOTE(review): loads and stores are not bounds-checked against `head_dim`,
// so this presumably requires head_dim == 64 with 32-lane SIMD groups.
// Confirm against the dispatcher.
#[bench_kernel(
    op="sdpa",
    subop="sdpa_bidirectional_d64",
    class=GenericEmpty,
    tol=1e-3,
    kernel_mode=Reduction,
)]
#[kernel]
pub fn ffai_sdpa_bidirectional_d64<T>(
    q: Tensor<T>,
    k: Tensor<T>,
    v: Tensor<T>,
    out: Tensor<T>,
    #[constexpr] head_dim: u32,
    #[constexpr] n_q_heads: u32,
    #[constexpr] base_kv: u32,
    #[constexpr] n_query: u32,
    #[constexpr] kv_stride: u32,
    #[constexpr] heads_per_group: u32,
    #[constexpr] scale: f32,
) {
    // Decode the threadgroup index into (query row, query head) and map the
    // query head onto the KV head it shares.
    let tg_index = tgid_x;
    let q_row = tg_index / n_q_heads;
    let head = tg_index % n_q_heads;
    let shared_kv_head = head / heads_per_group;

    let my_simd = simd_id;
    let my_lane = simd_lane;
    let simd_total = n_simd;
    let kv_len = base_kv + n_query;

    // Scratch for the cross-SIMD-group merge.
    // NOTE(review): 32 presumably bounds the SIMD-group count and
    // 1056 = 32 * 33 presumably covers 32 lanes with a padded pitch. TODO confirm.
    threadgroup_alloc("tg_max", 32);
    threadgroup_alloc("tg_sum", 32);
    threadgroup_alloc("tg_out0", 1056);
    threadgroup_alloc("tg_out1", 1056);

    let q_base = (q_row * n_q_heads + head) * head_dim;
    let kv_base = shared_kv_head * kv_stride * head_dim;

    // The two head dimensions owned by this lane.
    let dim_lo = my_lane * 2u32;
    let dim_hi = dim_lo + 1u32;

    // Pre-scale the query once so the loop only multiplies by K.
    let q_lo = load(q[q_base + dim_lo]).cast::<f32>() * scale;
    let q_hi = load(q[q_base + dim_hi]).cast::<f32>() * scale;

    let mut peak = neg_infinity();
    let mut denom = 0.0f32;
    let mut acc_lo = 0.0f32;
    let mut acc_hi = 0.0f32;

    // SIMD group g visits rows g, g + simd_total, g + 2 * simd_total, ...
    for _t in range(my_simd, kv_len, simd_total) {
        let row_base = kv_base + _t * head_dim;
        let at_lo = row_base + dim_lo;
        let at_hi = at_lo + 1u32;

        let k_lo = load(k[at_lo]).cast::<f32>();
        let k_hi = load(k[at_hi]).cast::<f32>();
        let logit = simd_sum(q_lo * k_lo + q_hi * k_hi);

        // Rebase the running statistics onto the new maximum.
        let next_peak = select(logit > peak, logit, peak);
        let decay = exp(peak - next_peak);
        let prob_w = exp(logit - next_peak);
        denom = denom * decay + prob_w;
        peak = next_peak;

        let v_lo = load(v[at_lo]).cast::<f32>();
        let v_hi = load(v[at_hi]).cast::<f32>();
        acc_lo = acc_lo * decay + prob_w * v_lo;
        acc_hi = acc_hi * decay + prob_w * v_hi;
    }

    // Publish one (max, sum) pair per SIMD group.
    if my_lane == 0 {
        threadgroup_store("tg_max", my_simd, peak);
        threadgroup_store("tg_sum", my_simd, denom);
    }
    threadgroup_barrier();

    // SIMD group 0 folds the per-group pairs: lane i reads group i's entry and
    // lanes beyond the group count contribute the identity (-inf / 0).
    if my_simd == 0 {
        let peer_peak = select(
            my_lane < simd_total,
            threadgroup_load("tg_max", my_lane),
            neg_infinity(),
        );
        let peak_all = simd_max(peer_peak);
        let peer_denom = select(
            my_lane < simd_total,
            threadgroup_load("tg_sum", my_lane) * exp(peer_peak - peak_all),
            0.0f32,
        );
        let denom_all = simd_sum(peer_denom);
        if my_lane == 0 {
            threadgroup_store("tg_max", 0, peak_all);
            threadgroup_store("tg_sum", 0, denom_all);
        }
    }
    threadgroup_barrier();

    let final_peak = threadgroup_load("tg_max", 0);
    let final_denom = threadgroup_load("tg_sum", 0);

    // Rebase this group's accumulator onto the global maximum and normalise.
    // A non-positive (or NaN) total yields a zero factor, hence zero output.
    let inv_norm = select(
        final_denom > 0.0f32,
        exp(peak - final_peak) / final_denom,
        0.0f32,
    );

    // Row per lane, column per SIMD group.
    // NOTE(review): the +1 pitch looks like bank-conflict padding. Confirm.
    let pitch = simd_total + 1u32;
    let slot = my_lane * pitch + my_simd;
    threadgroup_store("tg_out0", slot, acc_lo * inv_norm);
    threadgroup_store("tg_out1", slot, acc_hi * inv_norm);
    threadgroup_barrier();

    // SIMD group 0 sums each lane's row of partials and writes the result.
    if my_simd == 0 {
        let mut total_lo = 0.0f32;
        let mut total_hi = 0.0f32;
        for _g in range(0u32, simd_total, 1u32) {
            let src = my_lane * pitch + _g;
            total_lo = total_lo + threadgroup_load("tg_out0", src);
            total_hi = total_hi + threadgroup_load("tg_out1", src);
        }
        let dst = q_base + dim_lo;
        store(out[dst], total_lo.cast::<T>());
        store(out[dst + 1u32], total_hi.cast::<T>());
    }
}
#[bench_kernel(
op="sdpa",
subop="sdpa_bidirectional_d32",
class=GenericEmpty,
tol=1e-3,
kernel_mode=Reduction,
)]
#[kernel]
pub fn ffai_sdpa_bidirectional_d32<T>(
q: Tensor<T>,
k: Tensor<T>,
v: Tensor<T>,
out: Tensor<T>,
#[constexpr] head_dim: u32,
#[constexpr] n_q_heads: u32,
#[constexpr] base_kv: u32,
#[constexpr] n_query: u32,
#[constexpr] kv_stride: u32,
#[constexpr] heads_per_group: u32,
#[constexpr] scale: f32,
) {
let tg = tgid_x;
let query_idx = tg / n_q_heads;
let q_head = tg % n_q_heads;
let kv_head = q_head / heads_per_group;
let sg = simd_id;
let lane = simd_lane;
let ns = n_simd;
let n_kv = base_kv + n_query;
threadgroup_alloc("tg_max", 32);
threadgroup_alloc("tg_sum", 32);
threadgroup_alloc("tg_out0", 1056);
let q_off = (query_idx * n_q_heads + q_head) * head_dim;
let kv_head_base = kv_head * kv_stride * head_dim;
let d0 = lane;
let q0 = load(q[q_off + d0]).cast::<f32>() * scale;
let mut run_max = neg_infinity();
let mut run_sum = 0.0f32;
let mut o0 = 0.0f32;
for _t in range(sg, n_kv, ns) {
let base = kv_head_base + _t * head_dim;
let kv0 = base + d0;
let k0_raw = load(k[kv0]);
let k0 = k0_raw.cast::<f32>();
let partial = q0 * k0;
let score = simd_sum(partial);
let new_max = select(score > run_max, score, run_max);
let factor = exp(run_max - new_max);
let weight = exp(score - new_max);
run_sum = run_sum * factor + weight;
run_max = new_max;
let v0_raw = load(v[kv0]);
let v0 = v0_raw.cast::<f32>();
o0 = o0 * factor + weight * v0;
}
if lane == 0 {
threadgroup_store("tg_max", sg, run_max);
threadgroup_store("tg_sum", sg, run_sum);
}
threadgroup_barrier();
if sg == 0 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max = threadgroup_load("tg_max", 0);
let g_sum = threadgroup_load("tg_sum", 0);
let rescale = select(g_sum > 0.0f32, exp(run_max - g_max) / g_sum, 0.0f32);
let stride = ns + 1u32;
let idx = lane * stride + sg;
threadgroup_store("tg_out0", idx, o0 * rescale);
threadgroup_barrier();
if sg == 0 {
let mut so0 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so0 = so0 + threadgroup_load("tg_out0", ri);
}
let out_off = q_off + d0;
store(out[out_off], so0.cast::<T>());
}
}
#[bench_kernel(
op="sdpa",
subop="sdpa_bidirectional_d72",
class=GenericEmpty,
tol=1e-3,
kernel_mode=Reduction,
)]
#[kernel]
pub fn ffai_sdpa_bidirectional_d72<T>(
q: Tensor<T>,
k: Tensor<T>,
v: Tensor<T>,
out: Tensor<T>,
#[constexpr] head_dim: u32,
#[constexpr] n_q_heads: u32,
#[constexpr] base_kv: u32,
#[constexpr] n_query: u32,
#[constexpr] kv_stride: u32,
#[constexpr] heads_per_group: u32,
#[constexpr] scale: f32,
) {
let tg = tgid_x;
let query_idx = tg / n_q_heads;
let q_head = tg % n_q_heads;
let kv_head = q_head / heads_per_group;
let sg = simd_id;
let lane = simd_lane;
let ns = n_simd;
let n_kv = base_kv + n_query;
threadgroup_alloc("tg_max", 32);
threadgroup_alloc("tg_sum", 32);
threadgroup_alloc("tg_out0", 1056);
threadgroup_alloc("tg_out1", 1056);
threadgroup_alloc("tg_out2", 1056);
let q_off = (query_idx * n_q_heads + q_head) * head_dim;
let kv_head_base = kv_head * kv_stride * head_dim;
let d0 = lane * 3u32;
let d1 = d0 + 1u32;
let d2 = d0 + 2u32;
let d0s = select(d0 < head_dim, d0, 0u32);
let d1s = select(d1 < head_dim, d1, 0u32);
let d2s = select(d2 < head_dim, d2, 0u32);
let q0 = select(d0 < head_dim, load(q[q_off + d0s]).cast::<f32>() * scale, 0.0f32);
let q1 = select(d1 < head_dim, load(q[q_off + d1s]).cast::<f32>() * scale, 0.0f32);
let q2 = select(d2 < head_dim, load(q[q_off + d2s]).cast::<f32>() * scale, 0.0f32);
let mut run_max = neg_infinity();
let mut run_sum = 0.0f32;
let mut o0 = 0.0f32;
let mut o1 = 0.0f32;
let mut o2 = 0.0f32;
for _t in range(sg, n_kv, ns) {
let base = kv_head_base + _t * head_dim;
let k0 = select(d0 < head_dim, load(k[base + d0s]).cast::<f32>(), 0.0f32);
let k1 = select(d1 < head_dim, load(k[base + d1s]).cast::<f32>(), 0.0f32);
let k2 = select(d2 < head_dim, load(k[base + d2s]).cast::<f32>(), 0.0f32);
let partial = q0 * k0 + q1 * k1 + q2 * k2;
let score = simd_sum(partial);
let new_max = select(score > run_max, score, run_max);
let factor = exp(run_max - new_max);
let weight = exp(score - new_max);
run_sum = run_sum * factor + weight;
run_max = new_max;
let v0 = select(d0 < head_dim, load(v[base + d0s]).cast::<f32>(), 0.0f32);
let v1 = select(d1 < head_dim, load(v[base + d1s]).cast::<f32>(), 0.0f32);
let v2 = select(d2 < head_dim, load(v[base + d2s]).cast::<f32>(), 0.0f32);
o0 = o0 * factor + weight * v0;
o1 = o1 * factor + weight * v1;
o2 = o2 * factor + weight * v2;
}
if lane == 0u32 {
threadgroup_store("tg_max", sg, run_max);
threadgroup_store("tg_sum", sg, run_sum);
}
threadgroup_barrier();
if sg == 0u32 {
let g_max_in = select(lane < ns, threadgroup_load("tg_max", lane), neg_infinity());
let g_max = simd_max(g_max_in);
let g_sum_in =
select(lane < ns, threadgroup_load("tg_sum", lane) * exp(g_max_in - g_max), 0.0f32);
let g_sum = simd_sum(g_sum_in);
if lane == 0u32 {
threadgroup_store("tg_max", 0, g_max);
threadgroup_store("tg_sum", 0, g_sum);
}
}
threadgroup_barrier();
let g_max = threadgroup_load("tg_max", 0);
let g_sum = threadgroup_load("tg_sum", 0);
let rescale = select(g_sum > 0.0f32, exp(run_max - g_max) / g_sum, 0.0f32);
let stride = ns + 1u32;
let idx = lane * stride + sg;
threadgroup_store("tg_out0", idx, o0 * rescale);
threadgroup_store("tg_out1", idx, o1 * rescale);
threadgroup_store("tg_out2", idx, o2 * rescale);
threadgroup_barrier();
if sg == 0u32 {
let mut so0 = 0.0f32;
let mut so1 = 0.0f32;
let mut so2 = 0.0f32;
for _g in range(0u32, ns, 1u32) {
let ri = lane * stride + _g;
so0 = so0 + threadgroup_load("tg_out0", ri);
so1 = so1 + threadgroup_load("tg_out1", ri);
so2 = so2 + threadgroup_load("tg_out2", ri);
}
let out_off = q_off + d0;
if d0 < head_dim {
store(out[out_off], so0.cast::<T>());
}
if d1 < head_dim {
store(out[out_off + 1u32], so1.cast::<T>());
}
if d2 < head_dim {
store(out[out_off + 2u32], so2.cast::<T>());
}
}
}