use metaltile::{bench_kernel, kernel};
use metaltile_core::ir::KernelMode;
use crate::{
bench_types::DType,
spec::{BenchDispatch, BenchSpec},
};
// Problem sizes benchmarked for `mt_scan`: a single configuration of 1024 x 4096.
// NOTE(review): presumably (rows, row_len) where row_len is the kernel's `n` — confirm
// against the bench harness.
static SCAN_SHAPES: &[(usize, usize)] = &[(1024, 4096)];
#[bench_kernel(
op="scan",
subop="scan",
class=Scan,
shapes=&SCAN_SHAPES,
tpg=256,
tol=1e-3,
mlx="contig_scan_inclusive_sum_{tn}_{tn}",
metal_file="scan.metal",
)]
#[kernel]
pub fn mt_scan<T>(inp: Tensor<T>, out: Tensor<T>, #[constexpr] n: u32) {
let row = program_id::<1>();
let lid = tid;
let lane = simd_lane;
let sg = simd_id;
let ns = n_simd;
let row_off = row * n;
threadgroup_alloc("sgs", 9);
if lid == 0 {
threadgroup_store("sgs", ns, 0);
}
threadgroup_barrier();
let zero_f = threadgroup_load("sgs", ns);
let chunk = lsize * 4u32;
let n_iters = (n + chunk - 1u32) / chunk;
for _r in range(0, n_iters, 1) {
let base = _r * chunk + lid * 4u32;
let v0 = select(base < n, load(inp[row_off + base]).cast::<f32>(), zero_f);
let v1 = select(base + 1u32 < n, load(inp[row_off + base + 1u32]).cast::<f32>(), zero_f);
let v2 = select(base + 2u32 < n, load(inp[row_off + base + 2u32]).cast::<f32>(), zero_f);
let v3 = select(base + 3u32 < n, load(inp[row_off + base + 3u32]).cast::<f32>(), zero_f);
let s1 = v0 + v1;
let s2 = s1 + v2;
let s3 = s2 + v3;
let thread_excl = simd_scan_exclusive(s3);
if lane == 31 {
threadgroup_store("sgs", sg, thread_excl + s3);
}
threadgroup_barrier();
if sg == 0 {
let wt = select(lane < ns, threadgroup_load("sgs", lane), zero_f);
let wt_excl = simd_scan_exclusive(wt);
if lane < ns {
threadgroup_store("sgs", lane, wt_excl);
}
}
threadgroup_barrier();
let cur_prefix = threadgroup_load("sgs", ns);
let warp_excl = threadgroup_load("sgs", sg);
let base_prefix = cur_prefix + warp_excl + thread_excl;
if base < n {
store(out[row_off + base], (base_prefix + v0).cast::<T>());
}
if base + 1u32 < n {
store(out[row_off + base + 1u32], (base_prefix + s1).cast::<T>());
}
if base + 2u32 < n {
store(out[row_off + base + 2u32], (base_prefix + s2).cast::<T>());
}
if base + 3u32 < n {
store(out[row_off + base + 3u32], (base_prefix + s3).cast::<T>());
}
threadgroup_barrier();
if lid == lsize - 1 {
threadgroup_store("sgs", ns, base_prefix + s3);
}
threadgroup_barrier();
}
}
// Exclusive prefix sum along each row: `out[0] = 0`, `out[i] = inp[0] + ... + inp[i-1]`.
//
// This kernel has the same structure as `mt_scan`. The row index comes from
// `program_id::<1>()`. The threadgroup walks the row in chunks of `lsize * 4` elements,
// with 4 consecutive elements per thread. Math is done in f32, with a cast back to `T`
// on store.
//
// Each chunk is scanned in three levels:
//   1. serial running sums over the thread's own 4 elements;
//   2. `simd_scan_exclusive` within each SIMD group;
//   3. SIMD group 0 exclusive-scans the per-group totals.
// Only the output offsets differ from the inclusive kernel: each element receives the
// prefix *before* it.
//
// Scratch layout: `sgs` has 9 slots. Slots `0..ns` hold per-SIMD-group totals/offsets and
// slot `ns` carries the running total between chunks, so at most 8 SIMD groups are
// supported. `lane == 31` assumes 32-lane SIMD groups.
#[kernel]
pub fn mt_scan_exclusive<T>(inp: Tensor<T>, out: Tensor<T>, #[constexpr] n: u32) {
    let row = program_id::<1>();
    let lid = tid;
    let lane = simd_lane;
    let sg = simd_id;
    let ns = n_simd;
    let row_off = row * n;
    // Slots 0..ns: per-SIMD-group totals / exclusive offsets; slot ns: running carry.
    threadgroup_alloc("sgs", 9);
    if lid == 0 {
        // Seed the carry with the additive identity. Spelled as an f32 literal so the
        // accumulator type is explicit, matching the `1.0f32` seed of the product kernels.
        threadgroup_store("sgs", ns, 0.0f32);
    }
    threadgroup_barrier();
    // Read the seed back to obtain an f32 zero, used to pad positions past the row end.
    let zero_f = threadgroup_load("sgs", ns);
    let chunk = lsize * 4u32;
    let n_iters = (n + chunk - 1u32) / chunk;
    for _r in range(0, n_iters, 1) {
        let base = _r * chunk + lid * 4u32;
        // Positions `>= n` contribute 0.
        // NOTE(review): if `select` evaluates both operands, these loads are still issued
        // for out-of-range indices. Confirm the DSL lowering.
        let v0 = select(base < n, load(inp[row_off + base]).cast::<f32>(), zero_f);
        let v1 = select(base + 1u32 < n, load(inp[row_off + base + 1u32]).cast::<f32>(), zero_f);
        let v2 = select(base + 2u32 < n, load(inp[row_off + base + 2u32]).cast::<f32>(), zero_f);
        let v3 = select(base + 3u32 < n, load(inp[row_off + base + 3u32]).cast::<f32>(), zero_f);
        // Level 1: running sums over this thread's 4 elements.
        let s1 = v0 + v1;
        let s2 = s1 + v2;
        let s3 = s2 + v3;
        // Level 2: exclusive scan of the per-thread totals across the SIMD group.
        let thread_excl = simd_scan_exclusive(s3);
        if lane == 31 {
            // The last lane's inclusive value is the SIMD group's total.
            threadgroup_store("sgs", sg, thread_excl + s3);
        }
        threadgroup_barrier();
        // Level 3: SIMD group 0 turns the per-group totals into exclusive offsets in place.
        if sg == 0 {
            let wt = select(lane < ns, threadgroup_load("sgs", lane), zero_f);
            let wt_excl = simd_scan_exclusive(wt);
            if lane < ns {
                threadgroup_store("sgs", lane, wt_excl);
            }
        }
        threadgroup_barrier();
        // Prefix before this thread's first element: chunk carry + group offset + lane offset.
        let cur_prefix = threadgroup_load("sgs", ns);
        let warp_excl = threadgroup_load("sgs", sg);
        let base_prefix = cur_prefix + warp_excl + thread_excl;
        // Each output gets the sum of everything strictly before it.
        if base < n {
            store(out[row_off + base], base_prefix.cast::<T>());
        }
        if base + 1u32 < n {
            store(out[row_off + base + 1u32], (base_prefix + v0).cast::<T>());
        }
        if base + 2u32 < n {
            store(out[row_off + base + 2u32], (base_prefix + s1).cast::<T>());
        }
        if base + 3u32 < n {
            store(out[row_off + base + 3u32], (base_prefix + s2).cast::<T>());
        }
        // Every thread must finish reading `sgs` for this chunk before the carry changes.
        threadgroup_barrier();
        if lid == lsize - 1 {
            // Running total through the end of this chunk (inclusive of the last thread's s3).
            threadgroup_store("sgs", ns, base_prefix + s3);
        }
        // Publish the new carry, and keep the next chunk's `sgs[sg]` writes after these reads.
        threadgroup_barrier();
    }
}
// Bench registration for the exclusive prefix-sum kernel (no MLX reference wired up).
inventory::submit! {
    BenchSpec {
        // Kernel identity.
        kernel_name: "mt_scan_exclusive",
        kernel_ir: mt_scan_exclusive::kernel_ir_for,
        op: "scan",
        subop: "scan_exclusive",
        // Launch configuration.
        dispatch: BenchDispatch::Generic,
        kernel_mode: Some(KernelMode::Reduction),
        shapes: &[],
        // Numerics.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        tol: 1e-3,
        // No MLX comparison kernel.
        mlx_src: None,
        mlx_pattern: None,
    }
}
// Inclusive prefix product along each row: out[i] = inp[0] * ... * inp[i].
//
// The row index comes from `program_id::<1>()`. The threadgroup walks the row in chunks of
// `lsize * 4` elements, and each thread owns 4 consecutive elements per chunk. Math is
// done in f32 and cast back to `T` on store.
//
// Unlike the sum kernels, this one does not use a SIMD-level scan. Every thread publishes
// its 4-element product to `tgs[lid]`, then serially multiplies `tgs[0..lid]` to get its
// exclusive prefix. That costs O(lid) threadgroup loads per thread per chunk.
// `tgs` has 256 slots (one per thread), so `lsize` must not exceed 256.
// Only slot `ns` of `sgs` is used here: it carries the running product between chunks.
#[kernel]
pub fn mt_scan_prod<T>(inp: Tensor<T>, out: Tensor<T>, #[constexpr] n: u32) {
let row = program_id::<1>();
let lid = tid;
let ns = n_simd;
let row_off = row * n;
threadgroup_alloc("sgs", 9);
threadgroup_alloc("tgs", 256);
if lid == 0 {
// Seed the carry with the multiplicative identity.
threadgroup_store("sgs", ns, 1.0f32);
}
threadgroup_barrier();
// f32 one, used to pad positions past the row end (leaves the product unchanged).
let one_f = threadgroup_load("sgs", ns);
let chunk = lsize * 4u32;
let n_iters = (n + chunk - 1u32) / chunk;
for _r in range(0, n_iters, 1) {
let base = _r * chunk + lid * 4u32;
// NOTE(review): if `select` evaluates both operands, these loads are still issued for
// out-of-range indices — confirm the DSL lowering.
let v0 = select(base < n, load(inp[row_off + base]).cast::<f32>(), one_f);
let v1 = select(base + 1u32 < n, load(inp[row_off + base + 1u32]).cast::<f32>(), one_f);
let v2 = select(base + 2u32 < n, load(inp[row_off + base + 2u32]).cast::<f32>(), one_f);
let v3 = select(base + 3u32 < n, load(inp[row_off + base + 3u32]).cast::<f32>(), one_f);
// Running products over this thread's 4 elements.
let p1 = v0 * v1;
let p2 = p1 * v2;
let p3 = p2 * v3;
// Publish this thread's chunk product so higher-numbered threads can fold it in.
threadgroup_store("tgs", lid, p3);
threadgroup_barrier();
// Exclusive prefix across threads: product of all lower-numbered threads' chunk products,
// multiplied serially in thread order.
let mut t_excl = one_f;
for _i in range(0u32, lid, 1u32) {
t_excl = t_excl * threadgroup_load("tgs", _i);
}
// Fold in the carry from previous chunks.
let cur_prefix = threadgroup_load("sgs", ns);
let base_prefix = cur_prefix * t_excl;
if base < n {
store(out[row_off + base], (base_prefix * v0).cast::<T>());
}
if base + 1u32 < n {
store(out[row_off + base + 1u32], (base_prefix * p1).cast::<T>());
}
if base + 2u32 < n {
store(out[row_off + base + 2u32], (base_prefix * p2).cast::<T>());
}
if base + 3u32 < n {
store(out[row_off + base + 3u32], (base_prefix * p3).cast::<T>());
}
// All reads of `sgs[ns]` and `tgs` for this chunk complete before either is overwritten.
threadgroup_barrier();
if lid == lsize - 1 {
// The last thread's inclusive product covers the whole chunk; it becomes the next carry.
threadgroup_store("sgs", ns, base_prefix * p3);
}
// Publish the new carry before the next chunk reads it.
threadgroup_barrier();
}
}
// Bench registration for the inclusive prefix-product kernel (no MLX reference wired up).
inventory::submit! {
    BenchSpec {
        // Kernel identity.
        kernel_name: "mt_scan_prod",
        kernel_ir: mt_scan_prod::kernel_ir_for,
        op: "scan",
        subop: "scan_prod",
        // Launch configuration.
        dispatch: BenchDispatch::Generic,
        kernel_mode: Some(KernelMode::Reduction),
        shapes: &[],
        // Numerics.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        tol: 1e-3,
        // No MLX comparison kernel.
        mlx_src: None,
        mlx_pattern: None,
    }
}
// Exclusive prefix product along each row: out[0] = 1, out[i] = inp[0] * ... * inp[i-1].
//
// This kernel has the same structure as `mt_scan_prod`. There is one row per
// `program_id::<1>()`, with chunks of `lsize * 4` elements and 4 consecutive elements per
// thread, in f32 math. The cross-thread prefix comes from a serial product over
// `tgs[0..lid]`, i.e. O(lid) threadgroup loads per thread per chunk.
// `tgs` has 256 slots, so `lsize` must not exceed 256. Slot `ns` of `sgs` carries the
// running product between chunks.
#[kernel]
pub fn mt_scan_prod_exclusive<T>(inp: Tensor<T>, out: Tensor<T>, #[constexpr] n: u32) {
let row = program_id::<1>();
let lid = tid;
let ns = n_simd;
let row_off = row * n;
threadgroup_alloc("sgs", 9);
threadgroup_alloc("tgs", 256);
if lid == 0 {
// Seed the carry with the multiplicative identity.
threadgroup_store("sgs", ns, 1.0f32);
}
threadgroup_barrier();
// f32 one, used to pad positions past the row end.
let one_f = threadgroup_load("sgs", ns);
let chunk = lsize * 4u32;
let n_iters = (n + chunk - 1u32) / chunk;
for _r in range(0, n_iters, 1) {
let base = _r * chunk + lid * 4u32;
// NOTE(review): if `select` evaluates both operands, these loads are still issued for
// out-of-range indices — confirm the DSL lowering.
let v0 = select(base < n, load(inp[row_off + base]).cast::<f32>(), one_f);
let v1 = select(base + 1u32 < n, load(inp[row_off + base + 1u32]).cast::<f32>(), one_f);
let v2 = select(base + 2u32 < n, load(inp[row_off + base + 2u32]).cast::<f32>(), one_f);
let v3 = select(base + 3u32 < n, load(inp[row_off + base + 3u32]).cast::<f32>(), one_f);
// Running products over this thread's 4 elements.
let p1 = v0 * v1;
let p2 = p1 * v2;
let p3 = p2 * v3;
// Publish this thread's chunk product.
threadgroup_store("tgs", lid, p3);
threadgroup_barrier();
// Product of all lower-numbered threads' chunk products, multiplied serially.
let mut t_excl = one_f;
for _i in range(0u32, lid, 1u32) {
t_excl = t_excl * threadgroup_load("tgs", _i);
}
// Fold in the carry from previous chunks.
let cur_prefix = threadgroup_load("sgs", ns);
let base_prefix = cur_prefix * t_excl;
// Each output gets the product of everything strictly before it.
if base < n {
store(out[row_off + base], base_prefix.cast::<T>());
}
if base + 1u32 < n {
store(out[row_off + base + 1u32], (base_prefix * v0).cast::<T>());
}
if base + 2u32 < n {
store(out[row_off + base + 2u32], (base_prefix * p1).cast::<T>());
}
if base + 3u32 < n {
store(out[row_off + base + 3u32], (base_prefix * p2).cast::<T>());
}
// All reads of `sgs[ns]` and `tgs` for this chunk complete before either is overwritten.
threadgroup_barrier();
if lid == lsize - 1 {
// Product through the end of this chunk (includes the last thread's full p3).
threadgroup_store("sgs", ns, base_prefix * p3);
}
// Publish the new carry before the next chunk reads it.
threadgroup_barrier();
}
}
// Bench registration for the exclusive prefix-product kernel (no MLX reference wired up).
inventory::submit! {
    BenchSpec {
        // Kernel identity.
        kernel_name: "mt_scan_prod_exclusive",
        kernel_ir: mt_scan_prod_exclusive::kernel_ir_for,
        op: "scan",
        subop: "scan_prod_exclusive",
        // Launch configuration.
        dispatch: BenchDispatch::Generic,
        kernel_mode: Some(KernelMode::Reduction),
        shapes: &[],
        // Numerics.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        tol: 1e-3,
        // No MLX comparison kernel.
        mlx_src: None,
        mlx_pattern: None,
    }
}
// Inclusive running maximum along each row: out[i] = max(inp[0], ..., inp[i]).
//
// There is one row per `program_id::<1>()`, with chunks of `lsize * 4` elements and 4
// consecutive elements per thread, compared in f32. The cross-thread prefix is a serial
// max over `tgs[0..lid]`, i.e. O(lid) threadgroup loads per thread per chunk. `tgs` has
// 256 slots, so `lsize` must not exceed 256. Slot `ns` of `sgs` carries the running max
// between chunks.
//
// NOTE(review): max is computed as `select(a > b, a, b)`. Under IEEE comparison a NaN
// operand makes the test false, so a NaN is kept only until the next comparison in which
// it is `a`. NaN propagation is therefore position-dependent. Confirm this matches the
// reference.
#[kernel]
pub fn mt_scan_max<T>(inp: Tensor<T>, out: Tensor<T>, #[constexpr] n: u32) {
let row = program_id::<1>();
let lid = tid;
let ns = n_simd;
let row_off = row * n;
threadgroup_alloc("sgs", 9);
threadgroup_alloc("tgs", 256);
if lid == 0 {
// Seed the carry with the identity for max.
threadgroup_store("sgs", ns, neg_infinity());
}
threadgroup_barrier();
// f32 -inf, used to pad positions past the row end (never wins a max).
let neginf = threadgroup_load("sgs", ns);
let chunk = lsize * 4u32;
let n_iters = (n + chunk - 1u32) / chunk;
for _r in range(0, n_iters, 1) {
let base = _r * chunk + lid * 4u32;
// NOTE(review): if `select` evaluates both operands, these loads are still issued for
// out-of-range indices — confirm the DSL lowering.
let v0 = select(base < n, load(inp[row_off + base]).cast::<f32>(), neginf);
let v1 = select(base + 1u32 < n, load(inp[row_off + base + 1u32]).cast::<f32>(), neginf);
let v2 = select(base + 2u32 < n, load(inp[row_off + base + 2u32]).cast::<f32>(), neginf);
let v3 = select(base + 3u32 < n, load(inp[row_off + base + 3u32]).cast::<f32>(), neginf);
// Max of this thread's 4 elements.
let m1 = select(v0 > v1, v0, v1);
let m2 = select(m1 > v2, m1, v2);
let m3 = select(m2 > v3, m2, v3);
// Publish this thread's chunk max.
threadgroup_store("tgs", lid, m3);
threadgroup_barrier();
// Max over all lower-numbered threads' chunk maxima.
let mut t_excl = neginf;
for _i in range(0u32, lid, 1u32) {
let v = threadgroup_load("tgs", _i);
t_excl = select(v > t_excl, v, t_excl);
}
// Combine with the carry from previous chunks, then extend through this thread's elements.
let cur_prefix = threadgroup_load("sgs", ns);
let base_prefix = select(cur_prefix > t_excl, cur_prefix, t_excl);
let out0 = select(base_prefix > v0, base_prefix, v0);
let out1 = select(out0 > v1, out0, v1);
let out2 = select(out1 > v2, out1, v2);
let out3 = select(out2 > v3, out2, v3);
if base < n {
store(out[row_off + base], out0.cast::<T>());
}
if base + 1u32 < n {
store(out[row_off + base + 1u32], out1.cast::<T>());
}
if base + 2u32 < n {
store(out[row_off + base + 2u32], out2.cast::<T>());
}
if base + 3u32 < n {
store(out[row_off + base + 3u32], out3.cast::<T>());
}
// All reads of `sgs[ns]` and `tgs` for this chunk complete before either is overwritten.
threadgroup_barrier();
if lid == lsize - 1 {
// The last thread's inclusive max covers the whole chunk; it becomes the next carry.
threadgroup_store("sgs", ns, out3);
}
// Publish the new carry before the next chunk reads it.
threadgroup_barrier();
}
}
// Bench registration for the inclusive running-max kernel (no MLX reference wired up).
inventory::submit! {
    BenchSpec {
        // Kernel identity.
        kernel_name: "mt_scan_max",
        kernel_ir: mt_scan_max::kernel_ir_for,
        op: "scan",
        subop: "scan_max",
        // Launch configuration.
        dispatch: BenchDispatch::Generic,
        kernel_mode: Some(KernelMode::Reduction),
        shapes: &[],
        // Numerics: max involves no rounding, hence the tighter tolerance.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        tol: 1e-4,
        // No MLX comparison kernel.
        mlx_src: None,
        mlx_pattern: None,
    }
}
// Exclusive running maximum along each row: out[0] = -inf,
// out[i] = max(inp[0], ..., inp[i-1]).
//
// This kernel has the same structure as `mt_scan_max`. There is one row per
// `program_id::<1>()`, with chunks of `lsize * 4` elements and 4 consecutive elements per
// thread, in f32. The cross-thread prefix is a serial max over `tgs[0..lid]`. `tgs` has
// 256 slots, so `lsize` must not exceed 256. Slot `ns` of `sgs` carries the running max
// between chunks.
// NOTE(review): NaN handling follows `select(a > b, a, b)` and is position-dependent.
#[kernel]
pub fn mt_scan_max_exclusive<T>(inp: Tensor<T>, out: Tensor<T>, #[constexpr] n: u32) {
let row = program_id::<1>();
let lid = tid;
let ns = n_simd;
let row_off = row * n;
threadgroup_alloc("sgs", 9);
threadgroup_alloc("tgs", 256);
if lid == 0 {
// Seed the carry with the identity for max.
threadgroup_store("sgs", ns, neg_infinity());
}
threadgroup_barrier();
// f32 -inf, used to pad positions past the row end.
let neginf = threadgroup_load("sgs", ns);
let chunk = lsize * 4u32;
let n_iters = (n + chunk - 1u32) / chunk;
for _r in range(0, n_iters, 1) {
let base = _r * chunk + lid * 4u32;
// NOTE(review): if `select` evaluates both operands, these loads are still issued for
// out-of-range indices — confirm the DSL lowering.
let v0 = select(base < n, load(inp[row_off + base]).cast::<f32>(), neginf);
let v1 = select(base + 1u32 < n, load(inp[row_off + base + 1u32]).cast::<f32>(), neginf);
let v2 = select(base + 2u32 < n, load(inp[row_off + base + 2u32]).cast::<f32>(), neginf);
let v3 = select(base + 3u32 < n, load(inp[row_off + base + 3u32]).cast::<f32>(), neginf);
// Max of this thread's 4 elements.
let m1 = select(v0 > v1, v0, v1);
let m2 = select(m1 > v2, m1, v2);
let m3 = select(m2 > v3, m2, v3);
// Publish this thread's chunk max.
threadgroup_store("tgs", lid, m3);
threadgroup_barrier();
// Max over all lower-numbered threads' chunk maxima.
let mut t_excl = neginf;
for _i in range(0u32, lid, 1u32) {
let v = threadgroup_load("tgs", _i);
t_excl = select(v > t_excl, v, t_excl);
}
// `base_prefix` is the max of everything before this thread's first element.
// `ep1`..`ep3` extend it one element at a time, each excluding its own position.
let cur_prefix = threadgroup_load("sgs", ns);
let base_prefix = select(cur_prefix > t_excl, cur_prefix, t_excl);
let ep1 = select(base_prefix > v0, base_prefix, v0);
let ep2 = select(ep1 > v1, ep1, v1);
let ep3 = select(ep2 > v2, ep2, v2);
if base < n {
store(out[row_off + base], base_prefix.cast::<T>());
}
if base + 1u32 < n {
store(out[row_off + base + 1u32], ep1.cast::<T>());
}
if base + 2u32 < n {
store(out[row_off + base + 2u32], ep2.cast::<T>());
}
if base + 3u32 < n {
store(out[row_off + base + 3u32], ep3.cast::<T>());
}
// All reads of `sgs[ns]` and `tgs` for this chunk complete before either is overwritten.
threadgroup_barrier();
// Max through the end of this chunk, including the last thread's own elements.
let chunk_max = select(base_prefix > m3, base_prefix, m3);
if lid == lsize - 1 {
threadgroup_store("sgs", ns, chunk_max);
}
// Publish the new carry before the next chunk reads it.
threadgroup_barrier();
}
}
// Bench registration for the exclusive running-max kernel (no MLX reference wired up).
inventory::submit! {
    BenchSpec {
        // Kernel identity.
        kernel_name: "mt_scan_max_exclusive",
        kernel_ir: mt_scan_max_exclusive::kernel_ir_for,
        op: "scan",
        subop: "scan_max_exclusive",
        // Launch configuration.
        dispatch: BenchDispatch::Generic,
        kernel_mode: Some(KernelMode::Reduction),
        shapes: &[],
        // Numerics: max involves no rounding, hence the tighter tolerance.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        tol: 1e-4,
        // No MLX comparison kernel.
        mlx_src: None,
        mlx_pattern: None,
    }
}
// Inclusive running minimum along each row: out[i] = min(inp[0], ..., inp[i]).
//
// This is the mirror image of `mt_scan_max`. There is one row per `program_id::<1>()`,
// with chunks of `lsize * 4` elements and 4 consecutive elements per thread, compared in
// f32. The cross-thread prefix is a serial min over `tgs[0..lid]`, i.e. O(lid) threadgroup
// loads per thread per chunk. `tgs` has 256 slots, so `lsize` must not exceed 256. Slot
// `ns` of `sgs` carries the running min between chunks.
// NOTE(review): min is `select(a < b, a, b)`, so NaN handling is position-dependent.
#[kernel]
pub fn mt_scan_min<T>(inp: Tensor<T>, out: Tensor<T>, #[constexpr] n: u32) {
let row = program_id::<1>();
let lid = tid;
let ns = n_simd;
let row_off = row * n;
threadgroup_alloc("sgs", 9);
threadgroup_alloc("tgs", 256);
if lid == 0 {
// Seed the carry with the identity for min.
threadgroup_store("sgs", ns, infinity());
}
threadgroup_barrier();
// f32 +inf, used to pad positions past the row end (never wins a min).
let posinf = threadgroup_load("sgs", ns);
let chunk = lsize * 4u32;
let n_iters = (n + chunk - 1u32) / chunk;
for _r in range(0, n_iters, 1) {
let base = _r * chunk + lid * 4u32;
// NOTE(review): if `select` evaluates both operands, these loads are still issued for
// out-of-range indices — confirm the DSL lowering.
let v0 = select(base < n, load(inp[row_off + base]).cast::<f32>(), posinf);
let v1 = select(base + 1u32 < n, load(inp[row_off + base + 1u32]).cast::<f32>(), posinf);
let v2 = select(base + 2u32 < n, load(inp[row_off + base + 2u32]).cast::<f32>(), posinf);
let v3 = select(base + 3u32 < n, load(inp[row_off + base + 3u32]).cast::<f32>(), posinf);
// Min of this thread's 4 elements.
let m1 = select(v0 < v1, v0, v1);
let m2 = select(m1 < v2, m1, v2);
let m3 = select(m2 < v3, m2, v3);
// Publish this thread's chunk min.
threadgroup_store("tgs", lid, m3);
threadgroup_barrier();
// Min over all lower-numbered threads' chunk minima.
let mut t_excl = posinf;
for _i in range(0u32, lid, 1u32) {
let v = threadgroup_load("tgs", _i);
t_excl = select(v < t_excl, v, t_excl);
}
// Combine with the carry from previous chunks, then extend through this thread's elements.
let cur_prefix = threadgroup_load("sgs", ns);
let base_prefix = select(cur_prefix < t_excl, cur_prefix, t_excl);
let out0 = select(base_prefix < v0, base_prefix, v0);
let out1 = select(out0 < v1, out0, v1);
let out2 = select(out1 < v2, out1, v2);
let out3 = select(out2 < v3, out2, v3);
if base < n {
store(out[row_off + base], out0.cast::<T>());
}
if base + 1u32 < n {
store(out[row_off + base + 1u32], out1.cast::<T>());
}
if base + 2u32 < n {
store(out[row_off + base + 2u32], out2.cast::<T>());
}
if base + 3u32 < n {
store(out[row_off + base + 3u32], out3.cast::<T>());
}
// All reads of `sgs[ns]` and `tgs` for this chunk complete before either is overwritten.
threadgroup_barrier();
if lid == lsize - 1 {
// The last thread's inclusive min covers the whole chunk; it becomes the next carry.
threadgroup_store("sgs", ns, out3);
}
// Publish the new carry before the next chunk reads it.
threadgroup_barrier();
}
}
// Bench registration for the inclusive running-min kernel (no MLX reference wired up).
inventory::submit! {
    BenchSpec {
        // Kernel identity.
        kernel_name: "mt_scan_min",
        kernel_ir: mt_scan_min::kernel_ir_for,
        op: "scan",
        subop: "scan_min",
        // Launch configuration.
        dispatch: BenchDispatch::Generic,
        kernel_mode: Some(KernelMode::Reduction),
        shapes: &[],
        // Numerics: min involves no rounding, hence the tighter tolerance.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        tol: 1e-4,
        // No MLX comparison kernel.
        mlx_src: None,
        mlx_pattern: None,
    }
}
// Exclusive running minimum along each row: out[0] = +inf,
// out[i] = min(inp[0], ..., inp[i-1]).
//
// This is the mirror image of `mt_scan_max_exclusive`. There is one row per
// `program_id::<1>()`, with chunks of `lsize * 4` elements and 4 consecutive elements per
// thread, in f32. The cross-thread prefix is a serial min over `tgs[0..lid]`. `tgs` has
// 256 slots, so `lsize` must not exceed 256. Slot `ns` of `sgs` carries the running min
// between chunks.
// NOTE(review): NaN handling follows `select(a < b, a, b)` and is position-dependent.
#[kernel]
pub fn mt_scan_min_exclusive<T>(inp: Tensor<T>, out: Tensor<T>, #[constexpr] n: u32) {
let row = program_id::<1>();
let lid = tid;
let ns = n_simd;
let row_off = row * n;
threadgroup_alloc("sgs", 9);
threadgroup_alloc("tgs", 256);
if lid == 0 {
// Seed the carry with the identity for min.
threadgroup_store("sgs", ns, infinity());
}
threadgroup_barrier();
// f32 +inf, used to pad positions past the row end.
let posinf = threadgroup_load("sgs", ns);
let chunk = lsize * 4u32;
let n_iters = (n + chunk - 1u32) / chunk;
for _r in range(0, n_iters, 1) {
let base = _r * chunk + lid * 4u32;
// NOTE(review): if `select` evaluates both operands, these loads are still issued for
// out-of-range indices — confirm the DSL lowering.
let v0 = select(base < n, load(inp[row_off + base]).cast::<f32>(), posinf);
let v1 = select(base + 1u32 < n, load(inp[row_off + base + 1u32]).cast::<f32>(), posinf);
let v2 = select(base + 2u32 < n, load(inp[row_off + base + 2u32]).cast::<f32>(), posinf);
let v3 = select(base + 3u32 < n, load(inp[row_off + base + 3u32]).cast::<f32>(), posinf);
// Min of this thread's 4 elements.
let m1 = select(v0 < v1, v0, v1);
let m2 = select(m1 < v2, m1, v2);
let m3 = select(m2 < v3, m2, v3);
// Publish this thread's chunk min.
threadgroup_store("tgs", lid, m3);
threadgroup_barrier();
// Min over all lower-numbered threads' chunk minima.
let mut t_excl = posinf;
for _i in range(0u32, lid, 1u32) {
let v = threadgroup_load("tgs", _i);
t_excl = select(v < t_excl, v, t_excl);
}
// `base_prefix` is the min of everything before this thread's first element.
// `ep1`..`ep3` extend it one element at a time, each excluding its own position.
let cur_prefix = threadgroup_load("sgs", ns);
let base_prefix = select(cur_prefix < t_excl, cur_prefix, t_excl);
let ep1 = select(base_prefix < v0, base_prefix, v0);
let ep2 = select(ep1 < v1, ep1, v1);
let ep3 = select(ep2 < v2, ep2, v2);
if base < n {
store(out[row_off + base], base_prefix.cast::<T>());
}
if base + 1u32 < n {
store(out[row_off + base + 1u32], ep1.cast::<T>());
}
if base + 2u32 < n {
store(out[row_off + base + 2u32], ep2.cast::<T>());
}
if base + 3u32 < n {
store(out[row_off + base + 3u32], ep3.cast::<T>());
}
// All reads of `sgs[ns]` and `tgs` for this chunk complete before either is overwritten.
threadgroup_barrier();
// Min through the end of this chunk, including the last thread's own elements.
let chunk_min = select(base_prefix < m3, base_prefix, m3);
if lid == lsize - 1 {
threadgroup_store("sgs", ns, chunk_min);
}
// Publish the new carry before the next chunk reads it.
threadgroup_barrier();
}
}
// Bench registration for the exclusive running-min kernel (no MLX reference wired up).
inventory::submit! {
    BenchSpec {
        // Kernel identity.
        kernel_name: "mt_scan_min_exclusive",
        kernel_ir: mt_scan_min_exclusive::kernel_ir_for,
        op: "scan",
        subop: "scan_min_exclusive",
        // Launch configuration.
        dispatch: BenchDispatch::Generic,
        kernel_mode: Some(KernelMode::Reduction),
        shapes: &[],
        // Numerics: min involves no rounding, hence the tighter tolerance.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        tol: 1e-4,
        // No MLX comparison kernel.
        mlx_src: None,
        mlx_pattern: None,
    }
}