pub enum Thunk {
Show 107 variants
Nop,
Sgemm {
a: usize,
b: usize,
c: usize,
m: u32,
k: u32,
n: u32,
},
DenseSolveF64 {
a: usize,
b: usize,
x: usize,
n: u32,
nrhs: u32,
},
DenseSolveF32 {
a: usize,
b: usize,
x: usize,
n: u32,
nrhs: u32,
},
BatchedDenseSolveF64 {
a: usize,
b: usize,
x: usize,
batch: u32,
n: u32,
nrhs: u32,
},
BatchedDenseSolveF32 {
a: usize,
b: usize,
x: usize,
batch: u32,
n: u32,
nrhs: u32,
},
BatchedDgemmF64 {
a: usize,
b: usize,
c: usize,
batch: u32,
m: u32,
k: u32,
n: u32,
},
BatchedSgemm {
a: usize,
b: usize,
c: usize,
batch: u32,
m: u32,
k: u32,
n: u32,
},
Dgemm {
a: usize,
b: usize,
c: usize,
m: u32,
k: u32,
n: u32,
},
TransposeF64 {
src: usize,
dst: usize,
in_total: u32,
out_dims: Vec<u32>,
in_strides: Vec<u32>,
},
ActivationF64 {
src: usize,
dst: usize,
len: u32,
kind: Activation,
},
ComplexNormSqF32 {
src: usize,
dst: usize,
len: u32,
},
ComplexNormSqBackwardF32 {
z: usize,
g: usize,
dz: usize,
len: u32,
},
ConjugateC64 {
src: usize,
dst: usize,
len: u32,
},
ActivationC64 {
src: usize,
dst: usize,
len: u32,
kind: Activation,
},
ReduceSumF64 {
src: usize,
dst: usize,
outer: u32,
reduced: u32,
inner: u32,
},
CopyF64 {
src: usize,
dst: usize,
len: u32,
},
BinaryFullF64 {
lhs: usize,
rhs: usize,
dst: usize,
len: u32,
lhs_len: u32,
rhs_len: u32,
op: BinaryOp,
out_dims_bcast: Vec<u32>,
bcast_lhs_strides: Vec<u32>,
bcast_rhs_strides: Vec<u32>,
},
ConcatF64 {
dst: usize,
outer: u32,
inner: u32,
total_axis: u32,
inputs: Vec<(usize, u32)>,
},
BinaryFullC64 {
lhs: usize,
rhs: usize,
dst: usize,
len: u32,
lhs_len: u32,
rhs_len: u32,
op: BinaryOp,
out_dims_bcast: Vec<u32>,
bcast_lhs_strides: Vec<u32>,
bcast_rhs_strides: Vec<u32>,
},
Scan {
body: Arc<ThunkSchedule>,
body_init: Arc<Vec<u8>>,
body_input_off: usize,
body_output_off: usize,
outer_init_off: usize,
outer_final_off: usize,
length: u32,
carry_bytes: u32,
save_trajectory: bool,
xs_inputs: Arc<Vec<(usize, usize, u32)>>,
bcast_inputs: Arc<Vec<(usize, usize, u32)>>,
num_checkpoints: u32,
},
ScanBackward {Show 21 fields
body_vjp: Arc<ThunkSchedule>,
body_init: Arc<Vec<u8>>,
body_carry_in_off: usize,
body_x_offs: Arc<Vec<usize>>,
body_d_output_off: usize,
body_dcarry_out_off: usize,
outer_init_off: usize,
outer_traj_off: usize,
outer_upstream_off: usize,
outer_xs_offs: Arc<Vec<(usize, u32)>>,
outer_dinit_off: usize,
length: u32,
carry_bytes: u32,
carry_elem_size: u32,
save_trajectory: bool,
num_checkpoints: u32,
forward_body: Option<Arc<ThunkSchedule>>,
forward_body_init: Option<Arc<Vec<u8>>>,
forward_body_carry_in_off: usize,
forward_body_output_off: usize,
forward_body_x_offs: Arc<Vec<usize>>,
},
ScanBackwardXs {Show 23 fields
body_vjp: Arc<ThunkSchedule>,
body_init: Arc<Vec<u8>>,
body_carry_in_off: usize,
body_x_offs: Arc<Vec<usize>>,
body_d_output_off: usize,
body_dcarry_out_off: usize,
body_dxs_out_off: usize,
outer_init_off: usize,
outer_traj_off: usize,
outer_upstream_off: usize,
outer_xs_offs: Arc<Vec<(usize, u32)>>,
outer_dxs_off: usize,
length: u32,
carry_bytes: u32,
carry_elem_size: u32,
per_step_bytes: u32,
save_trajectory: bool,
num_checkpoints: u32,
forward_body: Option<Arc<ThunkSchedule>>,
forward_body_init: Option<Arc<Vec<u8>>>,
forward_body_carry_in_off: usize,
forward_body_output_off: usize,
forward_body_x_offs: Arc<Vec<usize>>,
},
CustomFn {
body: Arc<ThunkSchedule>,
body_init: Arc<Vec<u8>>,
inputs: Arc<Vec<(usize, usize, u32)>>,
body_output_off: usize,
outer_output_off: usize,
out_bytes: u32,
},
FusedMmBiasAct {
a: usize,
w: usize,
bias: usize,
c: usize,
m: u32,
k: u32,
n: u32,
act: Option<Activation>,
},
FusedResidualLN {
x: usize,
res: usize,
bias: usize,
g: usize,
b: usize,
out: usize,
rows: u32,
h: u32,
eps: f32,
has_bias: bool,
},
FusedResidualRmsNorm {
x: usize,
res: usize,
bias: usize,
g: usize,
b: usize,
out: usize,
rows: u32,
h: u32,
eps: f32,
has_bias: bool,
},
BiasAdd {
src: usize,
bias: usize,
dst: usize,
m: u32,
n: u32,
},
BinaryFull {
lhs: usize,
rhs: usize,
dst: usize,
len: u32,
lhs_len: u32,
rhs_len: u32,
op: BinaryOp,
out_dims_bcast: Vec<u32>,
bcast_lhs_strides: Vec<u32>,
bcast_rhs_strides: Vec<u32>,
},
ActivationInPlace {
data: usize,
len: u32,
act: Activation,
},
Gather {
table: usize,
table_len: u32,
idx: usize,
dst: usize,
num_idx: u32,
trailing: u32,
},
Narrow {
src: usize,
dst: usize,
outer: u32,
src_stride: u32,
dst_stride: u32,
inner: u32,
elem_bytes: u8,
},
Copy {
src: usize,
dst: usize,
len: u32,
},
LayerNorm {
src: usize,
g: usize,
b: usize,
dst: usize,
rows: u32,
h: u32,
eps: f32,
},
GroupNorm {
src: usize,
g: usize,
b: usize,
dst: usize,
n: u32,
c: u32,
h: u32,
w: u32,
num_groups: u32,
eps: f32,
},
LayerNorm2d {
src: usize,
g: usize,
b: usize,
dst: usize,
n: u32,
c: u32,
h: u32,
w: u32,
eps: f32,
},
ConvTranspose2d {Show 19 fields
src: usize,
weight: usize,
dst: usize,
n: u32,
c_in: u32,
h: u32,
w_in: u32,
c_out: u32,
h_out: u32,
w_out: u32,
kh: u32,
kw: u32,
sh: u32,
sw: u32,
ph: u32,
pw: u32,
dh: u32,
dw: u32,
groups: u32,
},
ResizeNearest2x {
src: usize,
dst: usize,
n: u32,
c: u32,
h: u32,
w: u32,
},
AxialRope2d {
src: usize,
dst: usize,
batch: u32,
seq: u32,
hidden: u32,
end_x: u32,
end_y: u32,
head_dim: u32,
num_heads: u32,
theta: f32,
repeat_factor: u32,
},
RmsNorm {
src: usize,
g: usize,
b: usize,
dst: usize,
rows: u32,
h: u32,
eps: f32,
},
Softmax {
data: usize,
rows: u32,
cols: u32,
},
Cumsum {
src: usize,
dst: usize,
rows: u32,
cols: u32,
exclusive: bool,
},
SelectiveScan {
x: usize,
delta: usize,
a: usize,
b: usize,
c: usize,
dst: usize,
batch: u32,
seq: u32,
hidden: u32,
state_size: u32,
},
GatedDeltaNet {
q: usize,
k: usize,
v: usize,
g: usize,
beta: usize,
state: usize,
dst: usize,
batch: u32,
seq: u32,
heads: u32,
state_size: u32,
},
Conv2D1x1 {
src: usize,
weight: usize,
dst: usize,
n: u32,
c_in: u32,
c_out: u32,
hw: u32,
},
DequantMatMul {
x: usize,
w_q: usize,
scale: usize,
zp: usize,
dst: usize,
m: u32,
k: u32,
n: u32,
block_size: u32,
is_asymmetric: bool,
},
DequantMatMulGguf {
x: usize,
w_q: usize,
dst: usize,
m: u32,
k: u32,
n: u32,
scheme: QuantScheme,
},
DequantMatMulInt4 {
x: usize,
w_q: usize,
scale: usize,
zp: usize,
dst: usize,
m: u32,
k: u32,
n: u32,
block_size: u32,
is_asymmetric: bool,
},
DequantMatMulFp8 {
x: usize,
w_q: usize,
scale: usize,
dst: usize,
m: u32,
k: u32,
n: u32,
e5m2: bool,
},
DequantMatMulNvfp4 {
x: usize,
w_q: usize,
scale: usize,
global_scale: usize,
dst: usize,
m: u32,
k: u32,
n: u32,
},
LoraMatMul {
x: usize,
w: usize,
a: usize,
b: usize,
dst: usize,
m: u32,
k: u32,
n: u32,
r: u32,
scale: f32,
},
Sample {
logits: usize,
dst: usize,
batch: u32,
vocab: u32,
top_k: u32,
top_p: f32,
temperature: f32,
seed: u64,
},
Attention {Show 15 fields
q: usize,
k: usize,
v: usize,
mask: usize,
out: usize,
batch: u32,
seq: u32,
kv_seq: u32,
heads: u32,
head_dim: u32,
mask_kind: MaskKind,
q_row_stride: u32,
k_row_stride: u32,
v_row_stride: u32,
bhsd: bool,
},
AttentionBackward {Show 14 fields
q: usize,
k: usize,
v: usize,
dy: usize,
mask: usize,
out: usize,
batch: u32,
seq: u32,
kv_seq: u32,
heads: u32,
head_dim: u32,
mask_kind: MaskKind,
wrt: AttentionBwdWrt,
bhsd: bool,
},
Rope {
src: usize,
cos: usize,
sin: usize,
dst: usize,
batch: u32,
seq: u32,
hidden: u32,
head_dim: u32,
n_rot: u32,
cos_len: u32,
src_row_stride: u32,
},
FusedAttnBlock {Show 17 fields
hidden: usize,
qkv_w: usize,
out_w: usize,
mask: usize,
out: usize,
qkv_b: usize,
out_b: usize,
cos: usize,
sin: usize,
cos_len: u32,
batch: u32,
seq: u32,
hs: u32,
nh: u32,
dh: u32,
has_bias: bool,
has_rope: bool,
},
FusedBertLayer {Show 23 fields
hidden: usize,
qkv_w: usize,
qkv_b: usize,
out_w: usize,
out_b: usize,
mask: usize,
ln1_g: usize,
ln1_b: usize,
eps1: f32,
fc1_w: usize,
fc1_b: usize,
fc2_w: usize,
fc2_b: usize,
ln2_g: usize,
ln2_b: usize,
eps2: f32,
out: usize,
batch: u32,
seq: u32,
hs: u32,
nh: u32,
dh: u32,
int_dim: u32,
},
FusedNomicLayer {Show 23 fields
hidden: usize,
qkv_w: usize,
out_w: usize,
mask: usize,
cos: usize,
sin: usize,
cos_len: u32,
ln1_g: usize,
ln1_b: usize,
eps1: f32,
fc11_w: usize,
fc12_w: usize,
fc2_w: usize,
ln2_g: usize,
ln2_b: usize,
eps2: f32,
out: usize,
batch: u32,
seq: u32,
hs: u32,
nh: u32,
dh: u32,
int_dim: u32,
},
FusedSwiGLU {
src: usize,
dst: usize,
n_half: u32,
total: u32,
gate_first: bool,
},
Concat {
dst: usize,
outer: u32,
inner: u32,
total_axis: u32,
inputs: Vec<(usize, u32)>,
},
Compare {
lhs: usize,
rhs: usize,
dst: usize,
len: u32,
op: CmpOp,
},
Reduce {
src: usize,
dst: usize,
outer: u32,
reduced: u32,
inner: u32,
op: ReduceOp,
},
TopK {
src: usize,
dst: usize,
outer: u32,
axis_dim: u32,
k: u32,
},
GroupedMatMul {
input: usize,
weight: usize,
expert_idx: usize,
dst: usize,
m: u32,
k_dim: u32,
n: u32,
num_experts: u32,
},
DequantGroupedMatMulGguf {
input: usize,
w_q: usize,
expert_idx: usize,
dst: usize,
m: u32,
k_dim: u32,
n: u32,
num_experts: u32,
scheme: QuantScheme,
},
DequantMoEWeightsGguf {
w_q: usize,
dst: usize,
k_dim: u32,
n: u32,
num_experts: u32,
scheme: QuantScheme,
},
ScatterAdd {
updates: usize,
indices: usize,
dst: usize,
num_updates: u32,
out_dim: u32,
trailing: u32,
},
Where {
cond: usize,
on_true: usize,
on_false: usize,
dst: usize,
len: u32,
},
Transpose {
src: usize,
dst: usize,
in_total: u32,
out_dims: Vec<u32>,
in_strides: Vec<u32>,
},
GatherAxis {
table: usize,
idx: usize,
dst: usize,
outer: u32,
axis_dim: u32,
num_idx: u32,
trailing: u32,
},
Pool2D {Show 15 fields
src: usize,
dst: usize,
n: u32,
c: u32,
h: u32,
w: u32,
h_out: u32,
w_out: u32,
kh: u32,
kw: u32,
sh: u32,
sw: u32,
ph: u32,
pw: u32,
kind: ReduceOp,
},
Conv2D {Show 19 fields
src: usize,
weight: usize,
dst: usize,
n: u32,
c_in: u32,
h: u32,
w: u32,
c_out: u32,
h_out: u32,
w_out: u32,
kh: u32,
kw: u32,
sh: u32,
sw: u32,
ph: u32,
pw: u32,
dh: u32,
dw: u32,
groups: u32,
},
QMatMul {
x: usize,
w: usize,
bias: usize,
out: usize,
m: u32,
k: u32,
n: u32,
x_zp: i32,
w_zp: i32,
out_zp: i32,
mult: f32,
},
QConv2d {Show 24 fields
x: usize,
w: usize,
bias: usize,
out: usize,
n: u32,
c_in: u32,
h: u32,
w_in: u32,
c_out: u32,
h_out: u32,
w_out: u32,
kh: u32,
kw: u32,
sh: u32,
sw: u32,
ph: u32,
pw: u32,
dh: u32,
dw: u32,
groups: u32,
x_zp: i32,
w_zp: i32,
out_zp: i32,
mult: f32,
},
Quantize {
x: usize,
q: usize,
len: u32,
chan_axis: u32,
chan_dim: u32,
inner: u32,
scales: Vec<f32>,
zero_points: Vec<i32>,
},
Dequantize {
q: usize,
x: usize,
len: u32,
chan_axis: u32,
chan_dim: u32,
inner: u32,
scales: Vec<f32>,
zero_points: Vec<i32>,
},
FakeQuantize {
x: usize,
out: usize,
len: u32,
chan_axis: u32,
chan_dim: u32,
inner: u32,
bits: u8,
ste: SteKind,
scale_mode: ScaleMode,
state_off: Option<usize>,
},
FakeQuantizeBackward {
x: usize,
dy: usize,
dx: usize,
len: u32,
chan_axis: u32,
chan_dim: u32,
inner: u32,
bits: u8,
ste: SteKind,
},
FakeQuantizeLSQ {
x: usize,
scale_off: usize,
out: usize,
len: u32,
chan_axis: u32,
chan_dim: u32,
inner: u32,
bits: u8,
},
FakeQuantizeLSQBackwardX {
x: usize,
scale_off: usize,
dy: usize,
dx: usize,
len: u32,
chan_axis: u32,
chan_dim: u32,
inner: u32,
bits: u8,
},
FakeQuantizeLSQBackwardScale {
x: usize,
scale_off: usize,
dy: usize,
dscale: usize,
len: u32,
chan_axis: u32,
chan_dim: u32,
inner: u32,
bits: u8,
},
ReluBackward {
x: usize,
dy: usize,
dx: usize,
len: u32,
},
ReluBackwardF64 {
x: usize,
dy: usize,
dx: usize,
len: u32,
},
ActivationBackward {
x: usize,
dy: usize,
dx: usize,
len: u32,
kind: Activation,
},
ActivationBackwardF64 {
x: usize,
dy: usize,
dx: usize,
len: u32,
kind: Activation,
},
LayerNormBackwardInput {
x: usize,
gamma: usize,
dy: usize,
dx: usize,
rows: u32,
h: u32,
eps: f32,
},
LayerNormBackwardGamma {
x: usize,
dy: usize,
dgamma: usize,
rows: u32,
h: u32,
eps: f32,
},
RmsNormBackwardInput {
x: usize,
gamma: usize,
beta: usize,
dy: usize,
dx: usize,
rows: u32,
h: u32,
eps: f32,
},
RmsNormBackwardGamma {
x: usize,
gamma: usize,
beta: usize,
dy: usize,
dgamma: usize,
rows: u32,
h: u32,
eps: f32,
},
RmsNormBackwardBeta {
x: usize,
gamma: usize,
beta: usize,
dy: usize,
dbeta: usize,
rows: u32,
h: u32,
eps: f32,
},
RopeBackward {
dy: usize,
cos: usize,
sin: usize,
dx: usize,
batch: u32,
seq: u32,
hidden: u32,
head_dim: u32,
n_rot: u32,
cos_len: u32,
},
CumsumBackward {
dy: usize,
dx: usize,
rows: u32,
cols: u32,
exclusive: bool,
},
GatherBackward {
dy: usize,
indices: usize,
dst: usize,
outer: u32,
axis_dim: u32,
num_idx: u32,
trailing: u32,
},
GroupNormBackwardInput {
x: usize,
gamma: usize,
beta: usize,
dy: usize,
dx: usize,
n: u32,
c: u32,
h: u32,
w: u32,
num_groups: u32,
eps: f32,
},
GroupNormBackwardGamma {
x: usize,
dy: usize,
dgamma: usize,
n: u32,
c: u32,
h: u32,
w: u32,
num_groups: u32,
eps: f32,
},
GroupNormBackwardBeta {
dy: usize,
dbeta: usize,
n: u32,
c: u32,
h: u32,
w: u32,
},
MaxPool2dBackward {Show 15 fields
x: usize,
dy: usize,
dx: usize,
n: u32,
c: u32,
h: u32,
w: u32,
h_out: u32,
w_out: u32,
kh: u32,
kw: u32,
sh: u32,
sw: u32,
ph: u32,
pw: u32,
},
Conv2dBackwardInput {Show 19 fields
dy: usize,
w: usize,
dx: usize,
n: u32,
c_in: u32,
h: u32,
w_in: u32,
c_out: u32,
h_out: u32,
w_out: u32,
kh: u32,
kw: u32,
sh: u32,
sw: u32,
ph: u32,
pw: u32,
dh: u32,
dw: u32,
groups: u32,
},
Conv2dBackwardWeight {Show 19 fields
x: usize,
dy: usize,
dw: usize,
n: u32,
c_in: u32,
h: u32,
w: u32,
c_out: u32,
h_out: u32,
w_out: u32,
kh: u32,
kw: u32,
sh: u32,
sw: u32,
ph: u32,
pw: u32,
dh: u32,
dw_dil: u32,
groups: u32,
},
SoftmaxCrossEntropy {
logits: usize,
labels: usize,
dst: usize,
n: u32,
c: u32,
},
SoftmaxCrossEntropyBackward {
logits: usize,
labels: usize,
d_loss: usize,
dlogits: usize,
n: u32,
c: u32,
},
CustomOp {
kernel: Arc<dyn CpuKernel>,
inputs: Vec<(usize, u32, Shape)>,
output: (usize, u32, Shape),
attrs: Vec<u8>,
},
GaussianSplatRender {Show 23 fields
positions_off: usize,
positions_len: usize,
scales_off: usize,
scales_len: usize,
rotations_off: usize,
rotations_len: usize,
opacities_off: usize,
opacities_len: usize,
colors_off: usize,
colors_len: usize,
sh_coeffs_off: usize,
sh_coeffs_len: usize,
meta_off: usize,
dst_off: usize,
dst_len: usize,
width: u32,
height: u32,
tile_size: u32,
radius_scale: f32,
alpha_cutoff: f32,
max_splat_steps: u32,
transmittance_threshold: f32,
max_list_entries: u32,
},
GaussianSplatRenderBackward {Show 28 fields
positions_off: usize,
positions_len: usize,
scales_off: usize,
scales_len: usize,
rotations_off: usize,
rotations_len: usize,
opacities_off: usize,
opacities_len: usize,
colors_off: usize,
colors_len: usize,
sh_coeffs_off: usize,
sh_coeffs_len: usize,
meta_off: usize,
d_loss_off: usize,
d_loss_len: usize,
packed_off: usize,
packed_len: usize,
width: u32,
height: u32,
tile_size: u32,
radius_scale: f32,
alpha_cutoff: f32,
max_splat_steps: u32,
transmittance_threshold: f32,
max_list_entries: u32,
loss_grad_clip: f32,
sh_band: u32,
max_anisotropy: f32,
},
GaussianSplatPrepare {Show 24 fields
positions_off: usize,
positions_len: usize,
scales_off: usize,
scales_len: usize,
rotations_off: usize,
rotations_len: usize,
opacities_off: usize,
opacities_len: usize,
colors_off: usize,
colors_len: usize,
sh_coeffs_off: usize,
sh_coeffs_len: usize,
meta_off: usize,
meta_len: usize,
prep_off: usize,
prep_len: usize,
width: u32,
height: u32,
tile_size: u32,
radius_scale: f32,
alpha_cutoff: f32,
max_splat_steps: u32,
transmittance_threshold: f32,
max_list_entries: u32,
},
GaussianSplatRasterize {Show 14 fields
prep_off: usize,
prep_len: usize,
meta_off: usize,
meta_len: usize,
dst_off: usize,
dst_len: usize,
count: usize,
width: u32,
height: u32,
tile_size: u32,
alpha_cutoff: f32,
max_splat_steps: u32,
transmittance_threshold: f32,
max_list_entries: u32,
},
Fft1d {
src: usize,
dst: usize,
outer: u32,
n_complex: u32,
inverse: bool,
dtype: DType,
},
}Expand description
A pre-compiled kernel call with all args resolved to arena offsets.
Variants§
Nop
Skip (Input/Param already in arena)
Sgemm
C = A @ B (BLAS sgemm)
DenseSolveF64
f64 dense solve x = A⁻¹·b via LAPACK dgesv.
a, b, x are byte-offsets into the arena. n is the matrix
dimension; nrhs is 1 for a vector RHS or >1 for multi-RHS.
The kernel materializes scratch copies of A and b internally
(LAPACK overwrites both with LU factors and solution).
DenseSolveF32
f32 twin of DenseSolveF64. Calls LAPACK sgesv (or the
no-blas Rust fallback). Same arena byte-offset contract.
BatchedDenseSolveF64
Batched f64 dense solve. a, b, x are byte-offsets to
the leading slice; batch is the number of independent
systems. Per slice the kernel calls dgesv(A_i, b_i, n, nrhs)
— LAPACK has no batched dgesv on Accelerate, so we loop.
BatchedDenseSolveF32
Batched f32 dense solve — loop of sgesv per batch slice.
BatchedDgemmF64
Batched f64 matmul. Both inputs and output have a leading
batch axis of size batch. Per-batch independent dgemm:
C[i] = A[i] @ B[i] for i in 0..batch. Used by VJP rules
that emit per-batch outer products (e.g., BatchedDenseSolve
VJP). The unbatched Dgemm thunk handles the rank-2 case.
BatchedSgemm
Batched f32 matmul — same loop-per-batch shape as
BatchedDgemmF64 but calling sgemm. Needed for attention
patterns where both operands carry a batch dim (e.g. q@k^T
and attn@v in decomposed self-attention). The 2-D Sgemm
flatten trick is wrong in that case because it treats b as
a single shared RHS across every batch.
Dgemm
C = A @ B via Accelerate cblas_dgemm. Mirror of Sgemm at f64.
TransposeF64
f64 N-D index walk used for both Op::Transpose and Op::Expand.
in_strides carries 0s on broadcast axes (Expand) or permuted
strides (Transpose). Mirror of Thunk::Transpose at f64.
ActivationF64
f64 element-wise activation. Single-input, single-output. The
kernel always reads from src and writes to dst, so it works
whether or not the planner aliased the two slots.
ComplexNormSqF32
Element-wise complex squared-magnitude: |z|² = re² + im².
Reads the C64 input at src as 2·len f32 ([re,im] pairs),
writes len f32 to dst.
ComplexNormSqBackwardF32
Wirtinger backward for [ComplexNormSqF32]: dz = g · z as
C64. Reads z at 2·len f32 + g at len f32; writes
2·len f32 to dz.
ConjugateC64
Element-wise C64 conjugate: writes [re_i, -im_i] per element.
Layout matches the rest of C64 here ([re,im] interleaved f32).
ActivationC64
C64 element-wise activation. Only kinds with well-defined
complex extensions are supported: Neg, Exp, Log, Sqrt.
Everything else (Sigmoid, Tanh, Relu, Abs, Sin/Cos/Tan/Atan,
Round, GeLU family) is rejected at lowering — those don’t have
single natural complex definitions. len is the complex
element count (the f32 buffer holds 2·len floats).
ReduceSumF64
f64 contiguous reduction along a single axis range. Layout
[outer, reduced, inner] in memory; output is [outer, inner].
Sum only for now (Mean composes via 1/N multiply post-pass).
CopyF64
f64 plain copy (Reshape / Cast at the same dtype). Mirrors Copy
but at 8 bytes per element.
BinaryFullF64
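f64 element-wise binary with broadcast. len/lhs_len/rhs_len
are element counts; kernel does out[i] = lhs[i % lhs_len] OP rhs[i % rhs_len].
Mirror of BinaryFull at 8 bytes per element.
(Review note: this variant also carries out_dims_bcast / bcast_lhs_strides /
bcast_rhs_strides, which BinaryFull uses for its stride-based broadcast path;
confirm whether the f64 kernel still uses the modulo formula above.)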
Fields
ConcatF64
f64 concat — byte-for-byte mirror of Concat but copies
8 bytes per element. Element-counted offsets/strides match
the f32 variant; the executor scales by elem_size internally.
BinaryFullC64
C64 element-wise binary with broadcast. Same len /
lhs_len / rhs_len semantics as BinaryFull but each
“element” is one complex value (8 bytes = [re, im] as two
f32s). The executor reads the underlying f32 buffer at
2·len floats and walks element pairs. Supports Add / Sub /
Mul / Div; Max / Min / Pow have no single natural complex
definition and panic at lowering.
Fields
Scan
Bounded scan. Holds a recursively-compiled body schedule + a
pre-initialized body arena snapshot (constants filled). Each
outer execution clones the snapshot, copies the carry-in slot
from the outer arena, runs the body schedule length times,
then writes the final carry to the outer arena.
Single-carry MVP — the body has exactly one carry Input and one output, both of the same shape and dtype; per-step xs slices (see xs_inputs) are copied into additional body Input slots.
Fields
body: Arc<ThunkSchedule>
save_trajectory: bool
When true, write each step’s carry to the outer arena at
offset outer_final_off + t * carry_bytes, producing a
[length, *carry] stacked trajectory. When false, only the
final carry lands at outer_final_off.
xs_inputs: Arc<Vec<(usize, usize, u32)>>
Per-step xs inputs. For each: (body_x_input_off,
outer_xs_base_off, per_step_bytes). Per iteration t, the
executor copies outer_xs_base_off + t * per_step_bytes
into body_x_input_off. Empty when the scan has no xs.
ScanBackward
Reverse-mode AD companion to Thunk::Scan. Walks t = length-1 .. 0, threading dcarry through the body’s VJP. Per iteration:
writes carry_t (from outer init or trajectory), each xs_i[t]
slice, and the current dcarry into the body_vjp’s Input
slots, runs body_vjp, reads new dcarry from its single output.
Carry dtype: the upstream-accumulation step in trajectory mode does an
element-wise add whose float type is selected by carry_elem_size (f32 or f64).
Fields
body_vjp: Arc<ThunkSchedule>
outer_xs_offs: Arc<Vec<(usize, u32)>>
Per-xs entries: (outer_xs_base_off, per_step_bytes). Read
xs_i[t] from outer_xs_base_off + t * per_step_bytes.
carry_elem_size: u32
Bytes per element in the carry tensor: 4 for f32, 8 for f64.
Used to dispatch the trajectory-mode upstream accumulation
kernel (the dcarry += upstream[t] add must use the right
floating-point type — a hard-coded f64 add silently does
nothing for an f32 carry whose carry_bytes isn’t divisible by 8).
num_checkpoints: u32
Recursive checkpointing config. 0 or length ⇒ full
trajectory cached, no recompute (existing behavior).
0 < K < length ⇒ trajectory has only K rows; the executor
recomputes intermediate carries via forward_body between
checkpoints. Memory: O(K · carry_bytes); time: O(length).
forward_body: Option<Arc<ThunkSchedule>>
Forward body schedule (same compiled body as the forward
Op::Scan), used for recompute when num_checkpoints is
active. None for the All strategy.
ScanBackwardXs
Companion to ScanBackward that materializes one stacked
dxs_i. Same backward loop; per iteration, after running
body_vjp, copies its body_dxs_out_off slot into the outer
arena at outer_dxs_off + t * per_step_bytes. dcarry threading
is identical — we still need it for the body_vjp recurrence
even though we don’t write it back to the outer arena.
Fields
body_vjp: Arc<ThunkSchedule>
num_checkpoints: u32
Recursive checkpointing config. Same semantics as
Thunk::ScanBackward::num_checkpoints — 0 or length
means “save every step’s carry”; 0 < K < length means
the trajectory has only K rows and the executor recomputes
intermediate carries via forward_body (which must be
Some). Implemented via segment-cached recompute,
mirroring the ScanBackward path.
forward_body: Option<Arc<ThunkSchedule>>
CustomFn
User-defined sub-graph (Op::CustomFn) — runs fwd_body once.
Per execution: clone body_init, copy each primal input from the
outer arena into its body Input slot, run the body schedule,
copy the body’s single output back to the outer arena.
Fields
body: Arc<ThunkSchedule>
FusedMmBiasAct
C = A @ B; C += bias; C = act(C)
FusedResidualLN
out = LN(x + residual + bias, gamma, beta)
Fields
FusedResidualRmsNorm
out = RmsNorm(x + residual + bias, gamma, beta)
Fields
BiasAdd
dst = bias_add(src, bias, m, n) for Binary::Add with broadcast
BinaryFull
Element-wise binary op with NumPy-style broadcast.
Fast path (lhs_len == rhs_len == len): plain element-wise loop,
SIMD-vectorized on aarch64 for Add/Mul. bcast_* fields
are unused.
Broadcast path: uses out_dims_bcast + bcast_lhs_strides +
bcast_rhs_strides to compute per-cell indices into each
operand. The strides are precomputed at thunk-construction
time from the operands’ true shapes (with stride 0 on any axis
where the operand has size 1). This is the only correct way
to handle bidirectional broadcasts like [N, 1] op [1, S] → [N, S], which simple i % lhs_len modulo indexing maps to
wrong cells.
Fields
ActivationInPlace
Activation in-place
Gather
Gather axis=0: table[idx] → out
Narrow
Narrow: copy slice (elem_bytes = source element size: 4 for f32, 8 for f64).
Copy
Copy (reshape, expand)
LayerNorm
LayerNorm standalone
GroupNorm
GroupNorm on NCHW [N,C,H,W].
LayerNorm2d
LayerNorm2d on NCHW (SAM / candle semantics).
ConvTranspose2d
ConvTranspose2d on NCHW.
Fields
ResizeNearest2x
Nearest 2× upsample on NCHW (per-batch slice).
AxialRope2d
SAM2 axial 2-D RoPE on [batch, seq, num_heads * head_dim].
Fields
RmsNorm
RMSNorm: out = (x / sqrt(mean(x^2) + eps)) * gamma + beta. No mean subtraction, hence cheaper than LayerNorm. Used by Llama-class models.
Softmax
Softmax
Cumsum
Inclusive (or exclusive) cumulative sum along the last axis (callers pre-flatten higher-dim cumsums via reshape views).
SelectiveScan
Mamba-style selective scan (plan #15). Inputs: x, delta [b,s,h], a [h,n], b [b,s,n], c [b,s,n]. Output: y [b,s,h]. State h carries through the seq.
Fields
GatedDeltaNet
Gated DeltaNet linear-attention scan (Qwen3.5/3.6 trunk).
Inputs: q, k, v [b, s, h, n]; g, beta [b, s, h]. Output:
[b, s, h, n]. See Op::GatedDeltaNet for math.
Fields
Conv2D1x1
1×1 conv fast path (plan #26). The general Conv2D thunk runs the textbook 7-deep loop; a 1×1 stride-1 padding-0 groups-1 conv is mathematically a per-batch matmul, and dispatching it through BLAS is 3-10× faster than the scalar nest. Common case: ViT patch-projection follow-on, transformer “expert” reductions in some MoE designs.
Per batch: weight [c_out, c_in] × input [c_in, h*w]
= output [c_out, h*w].
DequantMatMul
Fused dequant + matmul (plan #5). Today supports
QuantScheme::Int8Block (symmetric); other schemes panic
at lowering time with a clear message until kernels are added.
Fields
DequantMatMulGguf
GGUF-format dequant + matmul. Weight is a packed byte tensor in one of the K-quant super-block layouts (Q4_K, Q5_K, Q6_K, Q8_K). Scales / mins live inside the packed bytes — no side-channel scale tensor.
Today this is a “dequant-to-scratch then sgemm” kernel — it keeps the arena memory footprint down (weights stay packed) but the dequant itself happens per matmul. A future fully fused tile-streaming kernel would close the compute gap.
DequantMatMulInt4
Int4 block dequant + matmul (packed nibbles, side scale/zp).
Fields
DequantMatMulFp8
FP8 dequant + matmul (per-tensor or per-column scale).
DequantMatMulNvfp4
NVFP4 (E2M1) block dequant + matmul — 16-wide groups, FP8 scales.
LoraMatMul
Fused LoRA matmul (plan #9): out = x·W + scale * (x·A)·B.
r is the LoRA rank (typically 4-64) — the rank-r
intermediate x·A lives in scratch, never on the arena.
Sample
Fused sample: logits [batch, vocab] → token ids [batch]. See Op::Sample. Output values are f32-encoded usize indices (matches the rest of the IR’s “ids as f32” convention).
Fields
Attention
Attention SDPA. mask is the offset of the optional mask tensor
(only meaningful when mask_kind == MaskKind::Custom); other
kinds synthesize the mask in-kernel.
Q/K/V each carry a _row_stride (elements per source row).
Defaults to heads * head_dim — matches the standalone
“Q/K/V are their own contiguous buffers” case. The Narrow→
Attention fusion below rewrites these to the parent QKV stride
(typically 3 * heads * head_dim) so the kernel reads QKV
directly without materializing the per-head buffers (plan #46).
Fields
bhsd: bool
Memory layout flag. false (the historical default) →
[B, S, H, D] row-major: per-head offset is
bi*S*H*D + si*H*D + hi*D. true → [B, H, S, D]
(head-major), matching the convention used by rlx-cuda /
rlx-rocm / rlx-tpu: per-head offset is
bi*H*S*D + hi*S*D + si*D. Detected at lowering time
from the input shape vs num_heads / head_dim.
AttentionBackward
Op::AttentionBackward — emits dQ, dK, or dV (see wrt).
Fields
wrt: AttentionBwdWrt
Rope
RoPE (rotary position embeddings).
src_row_stride is the number of elements per source row (defaults to hidden
for the standalone case; set to qkv_axis * inner when the
thunk fusion pass below rewires Rope to read directly from the
fused QKV buffer — plan #45).
Fields
FusedAttnBlock
Fused attention block: QKV proj → split → [RoPE] → SDPA → output proj. All intermediates stay in L1 cache. Zero arena writes between ops.
Fields
FusedBertLayer
Fused ENTIRE transformer layer: attention + residual + LN + FFN + residual + LN. Combines ~10 thunks into 1. All intermediates on stack. Zero arena traffic.
Fields
FusedNomicLayer
Fused Nomic transformer layer: attention+RoPE + residual + LN + SwiGLU FFN + residual + LN.
Fields
FusedSwiGLU
Fused SwiGLU: out[r,i] = x[r,i] * silu(x[r, n_half+i]). Input: [outer, 2*n_half] — concatenated up||gate per row. Output: [outer, n_half].
Concat
Concat along an axis: output[outer, axis, inner] = inputs concatenated.
Each entry of inputs is (src_offset, axis_len_for_that_input) in u32
elements. outer, inner, and total_axis_len are pre-computed
at compile time to avoid per-run shape work.
Compare
Element-wise comparison: out = (lhs CMP rhs) ? 1.0 : 0.0
Reduce
Reduction along a contiguous range of axes. Input layout (after
shape decomposition) is [outer, reduced, inner]; output is
[outer, inner]. The single-axis cases (axis=0 → outer=1;
axis=last → inner=1) and contiguous multi-axis (e.g. reduce over
[0, 1] of an [N, C, H, W] tensor → outer=1, reduced=N·C, inner=H·W)
all map onto this triplet. Non-contiguous axes are not supported
and bail to Nop in the compile pass.
TopK
Top-K indices along the last axis. Input shape [outer, axis_dim],
output [outer, k] of f32-encoded i64 indices. Ties broken by
smaller index. Used by MoE gating + beam search.
GroupedMatMul
Indexed batched matmul: out[i] = input[i] @ weight[expert_idx[i]]. The current implementation is naive and runs per token. For real MoE workloads, sorting tokens by expert and running a segmented GEMM would amortize the cost; this is deferred until a workload needs it.
Fields
DequantGroupedMatMulGguf
GGUF K-quant packed expert stack + grouped matmul (MoE FFN).
Fields
scheme: QuantScheme
DequantMoEWeightsGguf
Materialize packed MoE weights to F32 [E, K, N] (autodiff helper).
ScatterAdd
Scatter-add: dst[indices[i] * trailing + j] += updates[i * trailing + j]. Output is zeroed first; multiple updates to the same row accumulate.
Where
Ternary select: out = cond != 0 ? on_true : on_false
Transpose
General N-D transpose / broadcast. out_dims[i] is the length of output
dim i; in_strides[i] is the input stride (in elements) used to
index that dim — 0 for broadcast dims (Expand). in_total is the
total element count in the source buffer (≤ output total when
broadcasting). Strides are pre-computed at compile time.
GatherAxis
Gather along an arbitrary axis. outer = product(dims[..axis]),
trailing = product(dims[axis+1..]), axis_dim = the dimension
being indexed into. Output: outer × num_idx × trailing.
(axis=0 still routes to the simpler Thunk::Gather fast path.)
Pool2D
2D pooling (Max or Mean). Input layout [N, C, H, W], output
[N, C, H_out, W_out]. Padding is implicit-zero; Mean divides by
the full kernel area (matches torch’s count_include_pad=True).
Fields
Conv2D
2D convolution. Input [N, C_in, H, W], weight [C_out, C_in_per_group, kH, kW], output [N, C_out, H_out, W_out]. Bias is a separate Op::Binary::Add after the conv (matching the IR’s input layout — Op::Conv has 2 inputs). Naive direct convolution; sufficient for correctness, not optimised.
Fields
QMatMul
Real INT8 matmul with i32 accumulation.
out[m, n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp)
Reads x and w as i8, bias as i32; writes out as i8.
Same kernel shape as rlx_cortexm::dense::dense_i8 — promoted
to a desktop thunk so a quantized graph compiled here doesn’t
have to round-trip through fake-quant.
Fields
QConv2d
Real INT8 conv2d, NCHW layout. Same loop shape as Thunk::Conv2D
but with i8 reads, i32 accumulation, and per-output requantize
to i8. Bias is i32 in the accumulator scale.
Fields
Quantize
INT8 quantize. Reads x as f32, writes q as i8.
chan = (i / inner) % chan_dim selects the per-channel
scale/zp; chan_axis is informational only (the kernel uses
chan_dim and inner directly).
For per-tensor, chan_dim = 1 and inner = len so chan is
always 0.
Fields
Dequantize
INT8 dequantize — inverse of Thunk::Quantize.
Fields
FakeQuantize
QAT fake-quantize. Per-channel (or per-tensor) symmetric
quantize-then-dequantize on the fly. Computes
s[c] = max(|x[..., c, ...]|) / q_max
then
out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c]
with q_max = {127, 7, 1} for bits = {8, 4, 2}. Same
channel-layout convention as Thunk::Quantize: every
element’s channel is (i / inner) % chan_dim. The kernel
does two passes — one to scan max-abs per channel, one to
quant-dequant per element.
Fields
ste: SteKind — STE variant; informational on the forward side (output is
the same regardless), kernel-relevant in the matching
FakeQuantizeBackward thunk.
FakeQuantizeBackward
Backward pass for Op::FakeQuantize under one of four STE
variants. Computes dx[i] from the f32 forward input x and
the upstream gradient dy, using the same per-channel scale
scheme as the forward.
Fields
FakeQuantizeLSQ
LSQ forward — same kernel shape as FakeQuantize Fixed mode.
Reads scale from scale_off (a [chan_dim] Param tensor).
Fields
FakeQuantizeLSQBackwardX
LSQ backward, x-gradient. STE-clipped: passes upstream through inside the quantization range, zeros outside.
Fields
FakeQuantizeLSQBackwardScale
LSQ backward, scale-gradient. Per-channel:
dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i]
where ψ(z) = -z + round(z) if |z| ≤ q_max else
sign(z) · q_max. Output shape: [chan_dim].
Fields
ReluBackward
ReLU backward: dx[i] = dy[i] if x[i] > 0 else 0.
ReluBackwardF64
f64 sibling of ReluBackward — same shape as the f32 variant
but reads/writes 8 bytes per element. Required because, on an f64
graph, ReluBackward’s &[f32] slot view would return only half of
every f64, so the backward pass would silently produce 0 gradients.
Mirrors the ActivationBackwardF64 split.
ActivationBackward
Generic element-wise activation backward.
dx[i] = (d/dx act(x))[i] · dy[i]. The closure dispatch is
per-element; expensive activations (Gelu) recompute internals
inline rather than threading an extra “saved y” tensor through.
ActivationBackwardF64
f64 sibling of ActivationBackward — slot offsets, len in
elements; kernel reads/writes 8 bytes per element. Required
because ActivationBackward’s &[f32] slot view silently
returns garbage on an f64 graph (cb % 4 still works but every
loaded value is half of an f64 → wrong gradient).
LayerNormBackwardInput
LayerNorm backward — input gradient. Recomputes mean/var/x̂ from
x and emits the closed-form d_x per row.
LayerNormBackwardGamma
LayerNorm backward — gamma gradient. d_gamma[d] = Σ_row dy·x̂.
RmsNormBackwardInput
RmsNormBackwardGamma
RmsNormBackwardBeta
RopeBackward
Fields
CumsumBackward
GatherBackward
GroupNormBackwardInput
Fields
GroupNormBackwardGamma
GroupNormBackwardBeta
MaxPool2dBackward
2D max-pool backward (NCHW). Recomputes the argmax position
inside each window and accumulates dy into dx at that
position. Output is zeroed first; ties resolve to the first
hit (lowest (kh,kw) index), matching what the forward kernel
does with acc.max(v).
Fields
Conv2dBackwardInput
2D conv backward w.r.t. input (dx = conv_transpose(dy, w)).
dy [N, C_out, H_out, W_out], w [C_out, C_in_per_group, kH, kW],
dx [N, C_in, H, W].
Fields
Conv2dBackwardWeight
2D conv backward w.r.t. weight. x [N, C_in, H, W],
dy [N, C_out, H_out, W_out], dw [C_out, C_in_per_group, kH, kW].
dw is zeroed before accumulation.
Fields
SoftmaxCrossEntropy
Fused softmax + cross-entropy loss with f32-encoded integer
labels. logits [N, C], labels [N], output [N] per-row loss.
Numerically stable (max-subtract before exp).
SoftmaxCrossEntropyBackward
Backward of the fused loss above.
dlogits[n, k] = (softmax(logits[n])[k] - one_hot(labels[n])[k]) * d_loss[n].
CustomOp
User-registered custom op (CPU side). Lowered from Op::Custom.
kernel is resolved against the global CPU kernel registry at
compile time and stored as Arc<dyn CpuKernel> so execution
avoids per-call lookups. v1: f32 contiguous only — see
op_registry::CpuKernel::execute_f32.
Fields
GaussianSplatRender
[Misplaced doc: the following FFT paragraph describes the Fft1d variant, not GaussianSplatRender.] 1D FFT along the last axis. Input/output are [..., 2N]
real-block layout (first N real, second N imag along the
transformed axis). outer is the product of all leading axes;
n_complex is N (the number of complex points). Both halves
of the real-block layout are read together by the kernel.
dtype selects the f32 or f64 path; the two share structure
but not buffers, so a flag at compile time avoids per-row
dispatch.
CPU reference 3D Gaussian splat render (rlx_ir::Op::GaussianSplatRender).
Fields
GaussianSplatRenderBackward
Fields
GaussianSplatPrepare
Strict IR stage 1 — project + bin + sort + rays (Op::GaussianSplatPrepare).
Fields
GaussianSplatRasterize
Strict IR stage 2 — tile raster from prepare buffer (Op::GaussianSplatRasterize).
Fields
Fft1d
1D FFT along the last axis (its description currently appears above, under GaussianSplatRender).
Trait Implementations§
Auto Trait Implementations§
impl Freeze for Thunk
impl !RefUnwindSafe for Thunk
impl Send for Thunk
impl Sync for Thunk
impl Unpin for Thunk
impl UnsafeUnpin for Thunk
impl !UnwindSafe for Thunk
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Source§impl<T> CloneToUninit for T where
T: Clone,
impl<T> CloneToUninit for T where
T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise. Read more