1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
//! Copyright 2026 0xClandestine, Ekryski, TheTom, Ambisphaeric
//! SPDX-License-Identifier: Apache-2.0
//! AURA compressed-domain value aggregation.
//!
//! For each (q_head, dim) output element, computes
//! `Σ_t weight[head, t] · norm[kv_head, t] · codebook[unpack(packed[t, d])]`,
//! skipping tokens whose weight is below `sparse_threshold`.
//!
//! Port of `turbo_value` from
//! `ekryski/mlx@alpha:mlx/backend/metal/kernels/turbo_quant.metal`.
//!
//! ## Layout
//!
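//! Inputs (`T` is the kernel's generic I/O dtype — f32, f16 or bf16):
//! - `weights [q_heads, tokens]` `T` — softmax(scores).
//! - `packed [kv_heads, tokens, packed_width]` u32 — codebook indices,
//!   `bits` bits per dim slot, packed LSB-first; a slot may straddle two
//!   consecutive words.
//! - `norms [kv_heads, tokens]` `T` — per-position norm.
//! - `codebook [1 << bits]` `T` — centroids.
//!
//! Output:
//! - `output [q_heads, dim]` `T` (accumulated in f32, cast on store).
//!
//! Query head `h` reads KV head `h / repeat_count`.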
//!
//! ## Dispatch
//!
//! Grid3D, one thread per (q_head, dim) output element:
//! `gid.x = d`, `gid.y = head_idx`. The kernel has no bounds guard on
//! either coordinate, so the grid must be exactly `dim × q_heads`.
//! Each thread runs a single sequential loop over tokens and accumulates
//! its dim slot's contribution in f32. The sparsity check
//! (`w >= sparse_threshold`) skips tokens whose weight is too small to
//! matter, mirroring the flash-pass2-style aggregation guard in the MLX
//! upstream.
use metaltile::{bench_kernel, kernel};
use crate::bench_types::DType;
// The kernels below are generic over `T` (fp32/fp16/bf16) for their I/O
// dtype, so they no longer name `DType` directly; `FLOAT_DTYPES` (defined
// elsewhere) is now the production dtype shortlist. This anonymous const
// keeps the `DType` import referenced — presumably so it is not reported
// as an unused import.
const _: DType = DType::F32;
// Stamps out one AURA value-aggregation kernel for a given packed index width.
//   $name  — name of the generated kernel fn.
//   $bits  — bits per codebook index in `packed` (a `u32` literal).
//   $subop — benchmark sub-op label forwarded to `#[bench_kernel(subop = ...)]`.
macro_rules! aura_value_kernel {
($name:ident, $bits:literal, $subop:literal) => {
// Generated kernel: for each (head_idx, d),
//   output[head_idx, d] = Σ_t weights[head_idx, t] · norms[kv_head, t] · codebook[idx(kv_head, t, d)]
// where kv_head = head_idx / repeat_count, summed only over tokens whose
// weight is >= sparse_threshold.
#[bench_kernel(op="aura", subop=$subop, class=GenericEmpty, tol=0.0, kernel_mode=Grid3D,)]
#[kernel]
pub fn $name<T>(
// Flat [q_heads, tokens], indexed as head_idx * tokens + t.
weights: Tensor<T>,
// Flat [kv_heads, tokens, packed_width] u32 words of bit-packed indices.
packed: Tensor<u32>,
// Flat [kv_heads, tokens], indexed as kv_head * tokens + t.
norms: Tensor<T>,
// Indexed by the unpacked value, so it needs at least 1 << $bits entries.
codebook: Tensor<T>,
// Flat [q_heads, dim]; each thread writes exactly one element.
mut output: Tensor<T>,
// Row stride of `output`.
#[constexpr] dim: u32,
// u32 words per (kv_head, t) row of `packed`.
// NOTE(review): assumes packed_width >= ceil(dim * $bits / 32) so the
// straddling read below stays inside the row — confirm at the call site.
#[constexpr] packed_width: u32,
// Sequence length; row stride of `weights` and `norms`.
#[constexpr] tokens: u32,
// Query heads per KV head (kv_head = head_idx / repeat_count); must be non-zero.
#[constexpr] repeat_count: u32,
// Tokens with a weight below this are skipped entirely.
#[constexpr] sparse_threshold: f32,
) {
// Thread coordinates: axis 0 → dim slot, axis 1 → query head. There is no
// bounds check, so the launch grid is assumed to be exactly dim × q_heads.
let d = program_id::<0>();
let head_idx = program_id::<1>();
let kv_head = head_idx / repeat_count;
let mask = (1u32 << $bits) - 1u32;
// Pre-compute the bit-stream coordinates for this thread's
// dim slot. Same for every token — only the base packed
// pointer changes per t.
let bit_offset = d * $bits;
let word_idx = bit_offset / 32u32;
let shift = bit_offset & 31u32;
let bits_in_w0 = 32u32 - shift;
// lo_bits = min($bits, bits_in_w0) (reading `select` as cond ? a : b);
// spill = number of high index bits that live in the next word. spill is
// always 0 for widths that divide 32 (2, 4, 8); 3- and 6-bit slots can
// straddle a word boundary.
let lo_bits = select(bits_in_w0 >= $bits, $bits, bits_in_w0);
let spill = $bits - lo_bits;
// Accumulate in f32 regardless of T; cast once on store.
let mut acc = 0.0f32;
for t in range(0u32, tokens, 1u32) {
let w = load(weights[head_idx * tokens + t]).cast::<f32>();
// Sparsity guard: the norm, packed and codebook loads are skipped for
// tokens below the threshold.
if w >= sparse_threshold {
let norm_val = load(norms[kv_head * tokens + t]).cast::<f32>();
let packed_row = (kv_head * tokens + t) * packed_width;
let w0 = load(packed[packed_row + word_idx]);
// When spill == 0, re-read word_idx so the second load never leaves the
// row; its contribution is masked to zero below ((1 << 0) - 1 == 0).
let w1_idx = select(spill > 0u32, word_idx + 1u32, word_idx);
let w1 = load(packed[packed_row + w1_idx]);
// Low lo_bits come from w0 at `shift`; the remaining `spill` bits come
// from the bottom of w1 and sit above them.
let lo = (w0 >> shift) & ((1u32 << lo_bits) - 1u32);
let hi = (w1 & ((1u32 << spill) - 1u32)) << lo_bits;
// lo | hi already fits in $bits bits; the mask is defensive.
let value = (lo | hi) & mask;
let centroid = load(codebook[value]).cast::<f32>();
acc = acc + w * norm_val * centroid;
}
}
store(output[head_idx * dim + d], acc.cast::<T>());
}
};
}
// Instantiate one kernel per supported index width. Widths that divide 32
// (2, 4, 8) never straddle a u32 word; the 3- and 6-bit variants can.
aura_value_kernel! { aura_value_int2, 2u32, "value_int2" }
aura_value_kernel! { aura_value_int3, 3u32, "value_int3" }
aura_value_kernel! { aura_value_int4, 4u32, "value_int4" }
aura_value_kernel! { aura_value_int6, 6u32, "value_int6" }
aura_value_kernel! { aura_value_int8, 8u32, "value_int8" }