Skip to main content

Thunk

Enum Thunk 

Source
pub enum Thunk {
Show 107 variants Nop, Sgemm { a: usize, b: usize, c: usize, m: u32, k: u32, n: u32, }, DenseSolveF64 { a: usize, b: usize, x: usize, n: u32, nrhs: u32, }, DenseSolveF32 { a: usize, b: usize, x: usize, n: u32, nrhs: u32, }, BatchedDenseSolveF64 { a: usize, b: usize, x: usize, batch: u32, n: u32, nrhs: u32, }, BatchedDenseSolveF32 { a: usize, b: usize, x: usize, batch: u32, n: u32, nrhs: u32, }, BatchedDgemmF64 { a: usize, b: usize, c: usize, batch: u32, m: u32, k: u32, n: u32, }, BatchedSgemm { a: usize, b: usize, c: usize, batch: u32, m: u32, k: u32, n: u32, }, Dgemm { a: usize, b: usize, c: usize, m: u32, k: u32, n: u32, }, TransposeF64 { src: usize, dst: usize, in_total: u32, out_dims: Vec<u32>, in_strides: Vec<u32>, }, ActivationF64 { src: usize, dst: usize, len: u32, kind: Activation, }, ComplexNormSqF32 { src: usize, dst: usize, len: u32, }, ComplexNormSqBackwardF32 { z: usize, g: usize, dz: usize, len: u32, }, ConjugateC64 { src: usize, dst: usize, len: u32, }, ActivationC64 { src: usize, dst: usize, len: u32, kind: Activation, }, ReduceSumF64 { src: usize, dst: usize, outer: u32, reduced: u32, inner: u32, }, CopyF64 { src: usize, dst: usize, len: u32, }, BinaryFullF64 { lhs: usize, rhs: usize, dst: usize, len: u32, lhs_len: u32, rhs_len: u32, op: BinaryOp, out_dims_bcast: Vec<u32>, bcast_lhs_strides: Vec<u32>, bcast_rhs_strides: Vec<u32>, }, ConcatF64 { dst: usize, outer: u32, inner: u32, total_axis: u32, inputs: Vec<(usize, u32)>, }, BinaryFullC64 { lhs: usize, rhs: usize, dst: usize, len: u32, lhs_len: u32, rhs_len: u32, op: BinaryOp, out_dims_bcast: Vec<u32>, bcast_lhs_strides: Vec<u32>, bcast_rhs_strides: Vec<u32>, }, Scan { body: Arc<ThunkSchedule>, body_init: Arc<Vec<u8>>, body_input_off: usize, body_output_off: usize, outer_init_off: usize, outer_final_off: usize, length: u32, carry_bytes: u32, save_trajectory: bool, xs_inputs: Arc<Vec<(usize, usize, u32)>>, bcast_inputs: Arc<Vec<(usize, usize, u32)>>, num_checkpoints: u32, }, ScanBackward {
Show 21 fields body_vjp: Arc<ThunkSchedule>, body_init: Arc<Vec<u8>>, body_carry_in_off: usize, body_x_offs: Arc<Vec<usize>>, body_d_output_off: usize, body_dcarry_out_off: usize, outer_init_off: usize, outer_traj_off: usize, outer_upstream_off: usize, outer_xs_offs: Arc<Vec<(usize, u32)>>, outer_dinit_off: usize, length: u32, carry_bytes: u32, carry_elem_size: u32, save_trajectory: bool, num_checkpoints: u32, forward_body: Option<Arc<ThunkSchedule>>, forward_body_init: Option<Arc<Vec<u8>>>, forward_body_carry_in_off: usize, forward_body_output_off: usize, forward_body_x_offs: Arc<Vec<usize>>,
}, ScanBackwardXs {
Show 23 fields body_vjp: Arc<ThunkSchedule>, body_init: Arc<Vec<u8>>, body_carry_in_off: usize, body_x_offs: Arc<Vec<usize>>, body_d_output_off: usize, body_dcarry_out_off: usize, body_dxs_out_off: usize, outer_init_off: usize, outer_traj_off: usize, outer_upstream_off: usize, outer_xs_offs: Arc<Vec<(usize, u32)>>, outer_dxs_off: usize, length: u32, carry_bytes: u32, carry_elem_size: u32, per_step_bytes: u32, save_trajectory: bool, num_checkpoints: u32, forward_body: Option<Arc<ThunkSchedule>>, forward_body_init: Option<Arc<Vec<u8>>>, forward_body_carry_in_off: usize, forward_body_output_off: usize, forward_body_x_offs: Arc<Vec<usize>>,
}, CustomFn { body: Arc<ThunkSchedule>, body_init: Arc<Vec<u8>>, inputs: Arc<Vec<(usize, usize, u32)>>, body_output_off: usize, outer_output_off: usize, out_bytes: u32, }, FusedMmBiasAct { a: usize, w: usize, bias: usize, c: usize, m: u32, k: u32, n: u32, act: Option<Activation>, }, FusedResidualLN { x: usize, res: usize, bias: usize, g: usize, b: usize, out: usize, rows: u32, h: u32, eps: f32, has_bias: bool, }, FusedResidualRmsNorm { x: usize, res: usize, bias: usize, g: usize, b: usize, out: usize, rows: u32, h: u32, eps: f32, has_bias: bool, }, BiasAdd { src: usize, bias: usize, dst: usize, m: u32, n: u32, }, BinaryFull { lhs: usize, rhs: usize, dst: usize, len: u32, lhs_len: u32, rhs_len: u32, op: BinaryOp, out_dims_bcast: Vec<u32>, bcast_lhs_strides: Vec<u32>, bcast_rhs_strides: Vec<u32>, }, ActivationInPlace { data: usize, len: u32, act: Activation, }, Gather { table: usize, table_len: u32, idx: usize, dst: usize, num_idx: u32, trailing: u32, }, Narrow { src: usize, dst: usize, outer: u32, src_stride: u32, dst_stride: u32, inner: u32, elem_bytes: u8, }, Copy { src: usize, dst: usize, len: u32, }, LayerNorm { src: usize, g: usize, b: usize, dst: usize, rows: u32, h: u32, eps: f32, }, GroupNorm { src: usize, g: usize, b: usize, dst: usize, n: u32, c: u32, h: u32, w: u32, num_groups: u32, eps: f32, }, LayerNorm2d { src: usize, g: usize, b: usize, dst: usize, n: u32, c: u32, h: u32, w: u32, eps: f32, }, ConvTranspose2d {
Show 19 fields src: usize, weight: usize, dst: usize, n: u32, c_in: u32, h: u32, w_in: u32, c_out: u32, h_out: u32, w_out: u32, kh: u32, kw: u32, sh: u32, sw: u32, ph: u32, pw: u32, dh: u32, dw: u32, groups: u32,
}, ResizeNearest2x { src: usize, dst: usize, n: u32, c: u32, h: u32, w: u32, }, AxialRope2d { src: usize, dst: usize, batch: u32, seq: u32, hidden: u32, end_x: u32, end_y: u32, head_dim: u32, num_heads: u32, theta: f32, repeat_factor: u32, }, RmsNorm { src: usize, g: usize, b: usize, dst: usize, rows: u32, h: u32, eps: f32, }, Softmax { data: usize, rows: u32, cols: u32, }, Cumsum { src: usize, dst: usize, rows: u32, cols: u32, exclusive: bool, }, SelectiveScan { x: usize, delta: usize, a: usize, b: usize, c: usize, dst: usize, batch: u32, seq: u32, hidden: u32, state_size: u32, }, GatedDeltaNet { q: usize, k: usize, v: usize, g: usize, beta: usize, state: usize, dst: usize, batch: u32, seq: u32, heads: u32, state_size: u32, }, Conv2D1x1 { src: usize, weight: usize, dst: usize, n: u32, c_in: u32, c_out: u32, hw: u32, }, DequantMatMul { x: usize, w_q: usize, scale: usize, zp: usize, dst: usize, m: u32, k: u32, n: u32, block_size: u32, is_asymmetric: bool, }, DequantMatMulGguf { x: usize, w_q: usize, dst: usize, m: u32, k: u32, n: u32, scheme: QuantScheme, }, DequantMatMulInt4 { x: usize, w_q: usize, scale: usize, zp: usize, dst: usize, m: u32, k: u32, n: u32, block_size: u32, is_asymmetric: bool, }, DequantMatMulFp8 { x: usize, w_q: usize, scale: usize, dst: usize, m: u32, k: u32, n: u32, e5m2: bool, }, DequantMatMulNvfp4 { x: usize, w_q: usize, scale: usize, global_scale: usize, dst: usize, m: u32, k: u32, n: u32, }, LoraMatMul { x: usize, w: usize, a: usize, b: usize, dst: usize, m: u32, k: u32, n: u32, r: u32, scale: f32, }, Sample { logits: usize, dst: usize, batch: u32, vocab: u32, top_k: u32, top_p: f32, temperature: f32, seed: u64, }, Attention {
Show 15 fields q: usize, k: usize, v: usize, mask: usize, out: usize, batch: u32, seq: u32, kv_seq: u32, heads: u32, head_dim: u32, mask_kind: MaskKind, q_row_stride: u32, k_row_stride: u32, v_row_stride: u32, bhsd: bool,
}, AttentionBackward {
Show 14 fields q: usize, k: usize, v: usize, dy: usize, mask: usize, out: usize, batch: u32, seq: u32, kv_seq: u32, heads: u32, head_dim: u32, mask_kind: MaskKind, wrt: AttentionBwdWrt, bhsd: bool,
}, Rope { src: usize, cos: usize, sin: usize, dst: usize, batch: u32, seq: u32, hidden: u32, head_dim: u32, n_rot: u32, cos_len: u32, src_row_stride: u32, }, FusedAttnBlock {
Show 17 fields hidden: usize, qkv_w: usize, out_w: usize, mask: usize, out: usize, qkv_b: usize, out_b: usize, cos: usize, sin: usize, cos_len: u32, batch: u32, seq: u32, hs: u32, nh: u32, dh: u32, has_bias: bool, has_rope: bool,
}, FusedBertLayer {
Show 23 fields hidden: usize, qkv_w: usize, qkv_b: usize, out_w: usize, out_b: usize, mask: usize, ln1_g: usize, ln1_b: usize, eps1: f32, fc1_w: usize, fc1_b: usize, fc2_w: usize, fc2_b: usize, ln2_g: usize, ln2_b: usize, eps2: f32, out: usize, batch: u32, seq: u32, hs: u32, nh: u32, dh: u32, int_dim: u32,
}, FusedNomicLayer {
Show 23 fields hidden: usize, qkv_w: usize, out_w: usize, mask: usize, cos: usize, sin: usize, cos_len: u32, ln1_g: usize, ln1_b: usize, eps1: f32, fc11_w: usize, fc12_w: usize, fc2_w: usize, ln2_g: usize, ln2_b: usize, eps2: f32, out: usize, batch: u32, seq: u32, hs: u32, nh: u32, dh: u32, int_dim: u32,
}, FusedSwiGLU { src: usize, dst: usize, n_half: u32, total: u32, gate_first: bool, }, Concat { dst: usize, outer: u32, inner: u32, total_axis: u32, inputs: Vec<(usize, u32)>, }, Compare { lhs: usize, rhs: usize, dst: usize, len: u32, op: CmpOp, }, Reduce { src: usize, dst: usize, outer: u32, reduced: u32, inner: u32, op: ReduceOp, }, TopK { src: usize, dst: usize, outer: u32, axis_dim: u32, k: u32, }, GroupedMatMul { input: usize, weight: usize, expert_idx: usize, dst: usize, m: u32, k_dim: u32, n: u32, num_experts: u32, }, DequantGroupedMatMulGguf { input: usize, w_q: usize, expert_idx: usize, dst: usize, m: u32, k_dim: u32, n: u32, num_experts: u32, scheme: QuantScheme, }, DequantMoEWeightsGguf { w_q: usize, dst: usize, k_dim: u32, n: u32, num_experts: u32, scheme: QuantScheme, }, ScatterAdd { updates: usize, indices: usize, dst: usize, num_updates: u32, out_dim: u32, trailing: u32, }, Where { cond: usize, on_true: usize, on_false: usize, dst: usize, len: u32, }, Transpose { src: usize, dst: usize, in_total: u32, out_dims: Vec<u32>, in_strides: Vec<u32>, }, GatherAxis { table: usize, idx: usize, dst: usize, outer: u32, axis_dim: u32, num_idx: u32, trailing: u32, }, Pool2D {
Show 15 fields src: usize, dst: usize, n: u32, c: u32, h: u32, w: u32, h_out: u32, w_out: u32, kh: u32, kw: u32, sh: u32, sw: u32, ph: u32, pw: u32, kind: ReduceOp,
}, Conv2D {
Show 19 fields src: usize, weight: usize, dst: usize, n: u32, c_in: u32, h: u32, w: u32, c_out: u32, h_out: u32, w_out: u32, kh: u32, kw: u32, sh: u32, sw: u32, ph: u32, pw: u32, dh: u32, dw: u32, groups: u32,
}, QMatMul { x: usize, w: usize, bias: usize, out: usize, m: u32, k: u32, n: u32, x_zp: i32, w_zp: i32, out_zp: i32, mult: f32, }, QConv2d {
Show 24 fields x: usize, w: usize, bias: usize, out: usize, n: u32, c_in: u32, h: u32, w_in: u32, c_out: u32, h_out: u32, w_out: u32, kh: u32, kw: u32, sh: u32, sw: u32, ph: u32, pw: u32, dh: u32, dw: u32, groups: u32, x_zp: i32, w_zp: i32, out_zp: i32, mult: f32,
}, Quantize { x: usize, q: usize, len: u32, chan_axis: u32, chan_dim: u32, inner: u32, scales: Vec<f32>, zero_points: Vec<i32>, }, Dequantize { q: usize, x: usize, len: u32, chan_axis: u32, chan_dim: u32, inner: u32, scales: Vec<f32>, zero_points: Vec<i32>, }, FakeQuantize { x: usize, out: usize, len: u32, chan_axis: u32, chan_dim: u32, inner: u32, bits: u8, ste: SteKind, scale_mode: ScaleMode, state_off: Option<usize>, }, FakeQuantizeBackward { x: usize, dy: usize, dx: usize, len: u32, chan_axis: u32, chan_dim: u32, inner: u32, bits: u8, ste: SteKind, }, FakeQuantizeLSQ { x: usize, scale_off: usize, out: usize, len: u32, chan_axis: u32, chan_dim: u32, inner: u32, bits: u8, }, FakeQuantizeLSQBackwardX { x: usize, scale_off: usize, dy: usize, dx: usize, len: u32, chan_axis: u32, chan_dim: u32, inner: u32, bits: u8, }, FakeQuantizeLSQBackwardScale { x: usize, scale_off: usize, dy: usize, dscale: usize, len: u32, chan_axis: u32, chan_dim: u32, inner: u32, bits: u8, }, ReluBackward { x: usize, dy: usize, dx: usize, len: u32, }, ReluBackwardF64 { x: usize, dy: usize, dx: usize, len: u32, }, ActivationBackward { x: usize, dy: usize, dx: usize, len: u32, kind: Activation, }, ActivationBackwardF64 { x: usize, dy: usize, dx: usize, len: u32, kind: Activation, }, LayerNormBackwardInput { x: usize, gamma: usize, dy: usize, dx: usize, rows: u32, h: u32, eps: f32, }, LayerNormBackwardGamma { x: usize, dy: usize, dgamma: usize, rows: u32, h: u32, eps: f32, }, RmsNormBackwardInput { x: usize, gamma: usize, beta: usize, dy: usize, dx: usize, rows: u32, h: u32, eps: f32, }, RmsNormBackwardGamma { x: usize, gamma: usize, beta: usize, dy: usize, dgamma: usize, rows: u32, h: u32, eps: f32, }, RmsNormBackwardBeta { x: usize, gamma: usize, beta: usize, dy: usize, dbeta: usize, rows: u32, h: u32, eps: f32, }, RopeBackward { dy: usize, cos: usize, sin: usize, dx: usize, batch: u32, seq: u32, hidden: u32, head_dim: u32, n_rot: u32, cos_len: u32, }, CumsumBackward { dy: usize, dx: usize, 
rows: u32, cols: u32, exclusive: bool, }, GatherBackward { dy: usize, indices: usize, dst: usize, outer: u32, axis_dim: u32, num_idx: u32, trailing: u32, }, GroupNormBackwardInput { x: usize, gamma: usize, beta: usize, dy: usize, dx: usize, n: u32, c: u32, h: u32, w: u32, num_groups: u32, eps: f32, }, GroupNormBackwardGamma { x: usize, dy: usize, dgamma: usize, n: u32, c: u32, h: u32, w: u32, num_groups: u32, eps: f32, }, GroupNormBackwardBeta { dy: usize, dbeta: usize, n: u32, c: u32, h: u32, w: u32, }, MaxPool2dBackward {
Show 15 fields x: usize, dy: usize, dx: usize, n: u32, c: u32, h: u32, w: u32, h_out: u32, w_out: u32, kh: u32, kw: u32, sh: u32, sw: u32, ph: u32, pw: u32,
}, Conv2dBackwardInput {
Show 19 fields dy: usize, w: usize, dx: usize, n: u32, c_in: u32, h: u32, w_in: u32, c_out: u32, h_out: u32, w_out: u32, kh: u32, kw: u32, sh: u32, sw: u32, ph: u32, pw: u32, dh: u32, dw: u32, groups: u32,
}, Conv2dBackwardWeight {
Show 19 fields x: usize, dy: usize, dw: usize, n: u32, c_in: u32, h: u32, w: u32, c_out: u32, h_out: u32, w_out: u32, kh: u32, kw: u32, sh: u32, sw: u32, ph: u32, pw: u32, dh: u32, dw_dil: u32, groups: u32,
}, SoftmaxCrossEntropy { logits: usize, labels: usize, dst: usize, n: u32, c: u32, }, SoftmaxCrossEntropyBackward { logits: usize, labels: usize, d_loss: usize, dlogits: usize, n: u32, c: u32, }, CustomOp { kernel: Arc<dyn CpuKernel>, inputs: Vec<(usize, u32, Shape)>, output: (usize, u32, Shape), attrs: Vec<u8>, }, GaussianSplatRender {
Show 23 fields positions_off: usize, positions_len: usize, scales_off: usize, scales_len: usize, rotations_off: usize, rotations_len: usize, opacities_off: usize, opacities_len: usize, colors_off: usize, colors_len: usize, sh_coeffs_off: usize, sh_coeffs_len: usize, meta_off: usize, dst_off: usize, dst_len: usize, width: u32, height: u32, tile_size: u32, radius_scale: f32, alpha_cutoff: f32, max_splat_steps: u32, transmittance_threshold: f32, max_list_entries: u32,
}, GaussianSplatRenderBackward {
Show 28 fields positions_off: usize, positions_len: usize, scales_off: usize, scales_len: usize, rotations_off: usize, rotations_len: usize, opacities_off: usize, opacities_len: usize, colors_off: usize, colors_len: usize, sh_coeffs_off: usize, sh_coeffs_len: usize, meta_off: usize, d_loss_off: usize, d_loss_len: usize, packed_off: usize, packed_len: usize, width: u32, height: u32, tile_size: u32, radius_scale: f32, alpha_cutoff: f32, max_splat_steps: u32, transmittance_threshold: f32, max_list_entries: u32, loss_grad_clip: f32, sh_band: u32, max_anisotropy: f32,
}, GaussianSplatPrepare {
Show 24 fields positions_off: usize, positions_len: usize, scales_off: usize, scales_len: usize, rotations_off: usize, rotations_len: usize, opacities_off: usize, opacities_len: usize, colors_off: usize, colors_len: usize, sh_coeffs_off: usize, sh_coeffs_len: usize, meta_off: usize, meta_len: usize, prep_off: usize, prep_len: usize, width: u32, height: u32, tile_size: u32, radius_scale: f32, alpha_cutoff: f32, max_splat_steps: u32, transmittance_threshold: f32, max_list_entries: u32,
}, GaussianSplatRasterize {
Show 14 fields prep_off: usize, prep_len: usize, meta_off: usize, meta_len: usize, dst_off: usize, dst_len: usize, count: usize, width: u32, height: u32, tile_size: u32, alpha_cutoff: f32, max_splat_steps: u32, transmittance_threshold: f32, max_list_entries: u32,
}, Fft1d { src: usize, dst: usize, outer: u32, n_complex: u32, inverse: bool, dtype: DType, },
}
Expand description

A pre-compiled kernel call with all args resolved to arena offsets.

Variants§

§

Nop

Skip (Input/Param already in arena)

§

Sgemm

C = A @ B (BLAS sgemm)

Fields

§

DenseSolveF64

f64 dense solve x = A⁻¹·b via LAPACK dgesv. a, b, x are byte-offsets into the arena. n is the matrix dimension; nrhs is 1 for a vector RHS or >1 for multi-RHS. The kernel materializes scratch copies of A and b internally (LAPACK overwrites both with LU factors and solution).

Fields

§nrhs: u32
§

DenseSolveF32

f32 twin of DenseSolveF64. Calls LAPACK sgesv (or the no-blas Rust fallback). Same arena byte-offset contract.

Fields

§nrhs: u32
§

BatchedDenseSolveF64

Batched f64 dense solve. a, b, x are byte-offsets to the leading slice; batch is the number of independent systems. Per slice the kernel calls dgesv(A_i, b_i, n, nrhs) — LAPACK has no batched dgesv on Accelerate, so we loop.

Fields

§batch: u32
§nrhs: u32
§

BatchedDenseSolveF32

Batched f32 dense solve — loop of sgesv per batch slice.

Fields

§batch: u32
§nrhs: u32
§

BatchedDgemmF64

Batched f64 matmul. Both inputs and output have a leading batch axis of size batch. Per-batch independent dgemm: C[i] = A[i] @ B[i] for i in 0..batch. Used by VJP rules that emit per-batch outer products (e.g., BatchedDenseSolve VJP). The unbatched Dgemm thunk handles the rank-2 case.

Fields

§batch: u32
§

BatchedSgemm

Batched f32 matmul — same loop-per-batch shape as BatchedDgemmF64 but calling sgemm. Needed for attention patterns where both operands carry a batch dim (e.g. q@k^T and attn@v in decomposed self-attention). The 2-D Sgemm flatten trick is wrong in that case because it treats b as a single shared RHS across every batch.

Fields

§batch: u32
§

Dgemm

C = A @ B via Accelerate cblas_dgemm. Mirror of Sgemm at f64.

Fields

§

TransposeF64

f64 N-D index walk used for both Op::Transpose and Op::Expand. in_strides carries 0s on broadcast axes (Expand) or permuted strides (Transpose). Mirror of Thunk::Transpose at f64.

Fields

§src: usize
§dst: usize
§in_total: u32
§out_dims: Vec<u32>
§in_strides: Vec<u32>
§

ActivationF64

f64 element-wise activation. Single-input, single-output. The kernel always reads from src and writes to dst, so it works whether or not the planner aliased the two slots.

Fields

§src: usize
§dst: usize
§len: u32
§

ComplexNormSqF32

Element-wise complex squared-magnitude: |z|² = re² + im². Reads the C64 input at src as 2·len f32 ([re,im] pairs), writes len f32 to dst.

Fields

§src: usize
§dst: usize
§len: u32

Logical element count (number of complex values).

§

ComplexNormSqBackwardF32

Wirtinger backward for ComplexNormSqF32: dz = g · z as C64. Reads z as 2·len f32 and g as len f32; writes 2·len f32 to dz.

Fields

§len: u32
§

ConjugateC64

Element-wise C64 conjugate: writes [re_i, -im_i] per element. Layout matches the rest of C64 here ([re,im] interleaved f32).

Fields

§src: usize
§dst: usize
§len: u32
§

ActivationC64

C64 element-wise activation. Only kinds with well-defined complex extensions are supported: Neg, Exp, Log, Sqrt. Everything else (Sigmoid, Tanh, Relu, Abs, Sin/Cos/Tan/Atan, Round, GeLU family) is rejected at lowering — those don’t have single natural complex definitions. len is the complex element count (the f32 buffer holds 2·len floats).

Fields

§src: usize
§dst: usize
§len: u32
§

ReduceSumF64

f64 contiguous reduction along a single axis range. Layout [outer, reduced, inner] in memory; output is [outer, inner]. Sum only for now (Mean composes via 1/N multiply post-pass).

Fields

§src: usize
§dst: usize
§outer: u32
§reduced: u32
§inner: u32
§

CopyF64

f64 plain copy (Reshape / Cast at the same dtype). Mirrors Copy but at 8 bytes per element.

Fields

§src: usize
§dst: usize
§len: u32
§

BinaryFullF64

f64 element-wise binary with broadcast. len/lhs_len/rhs_len are element counts; kernel does out[i] = lhs[i % lhs_len] OP rhs[i % rhs_len]. Mirror of BinaryFull at 8 bytes per element.

Fields

§lhs: usize
§rhs: usize
§dst: usize
§len: u32
§lhs_len: u32
§rhs_len: u32
§out_dims_bcast: Vec<u32>

Output shape dims (row-major). Empty in the fast path. See BinaryFull doc for the broadcast convention.

§bcast_lhs_strides: Vec<u32>
§bcast_rhs_strides: Vec<u32>
§

ConcatF64

f64 concat — byte-for-byte mirror of Concat but copies 8 bytes per element. Element-counted offsets/strides match the f32 variant; the executor scales by elem_size internally.

Fields

§dst: usize
§outer: u32
§inner: u32
§total_axis: u32
§inputs: Vec<(usize, u32)>
§

BinaryFullC64

C64 element-wise binary with broadcast. Same len / lhs_len / rhs_len semantics as BinaryFull but each “element” is one complex value (8 bytes = [re, im] as two f32s). The executor reads the underlying f32 buffer at 2·len floats and walks element pairs. Supports Add / Sub / Mul / Div; Max / Min / Pow have no single natural complex definition and panic at lowering.

Fields

§lhs: usize
§rhs: usize
§dst: usize
§len: u32

Complex element count (NOT f32 count). f32 buffer length is 2·len.

§lhs_len: u32
§rhs_len: u32
§out_dims_bcast: Vec<u32>
§bcast_lhs_strides: Vec<u32>
§bcast_rhs_strides: Vec<u32>
§

Scan

Bounded scan. Holds a recursively-compiled body schedule + a pre-initialized body arena snapshot (constants filled). Each outer execution clones the snapshot, copies the carry-in slot from the outer arena, runs the body schedule length times, then writes the final carry to the outer arena.

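Single-carry MVP — the body has exactly one carry Input and one output, both with the same shape and dtype. Additional body Inputs may be fed per step (xs_inputs) or once before the loop (bcast_inputs).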

Fields

§body_init: Arc<Vec<u8>>
§body_input_off: usize
§body_output_off: usize
§outer_init_off: usize
§outer_final_off: usize
§length: u32
§carry_bytes: u32
§save_trajectory: bool

When true, write each step’s carry to the outer arena at offset outer_final_off + t * carry_bytes, producing a [length, *carry] stacked trajectory. When false, only the final carry lands at outer_final_off.

§xs_inputs: Arc<Vec<(usize, usize, u32)>>

Per-step xs inputs. For each: (body_x_input_off, outer_xs_base_off, per_step_bytes). Per iteration t, the executor copies outer_xs_base_off + t * per_step_bytes into body_x_input_off. Empty when the scan has no xs.

§bcast_inputs: Arc<Vec<(usize, usize, u32)>>

Broadcast inputs — values constant across iterations. For each: (body_bcast_input_off, outer_bcast_off, total_bytes). Filled into body_buf ONCE before the scan loop starts (xs in contrast are re-filled every iteration). Empty when the scan has no bcasts.

§num_checkpoints: u32

Number of trajectory checkpoints (when save_trajectory). 0 or length ⇒ save every iteration. Otherwise, with K = num_checkpoints, save only K rows at indices floor((k+1) * length / K) - 1 for k in 0..K. The last index is always length-1, so the final carry is always cached.

§

ScanBackward

Reverse-mode AD companion to Thunk::Scan. Walks t = length-1 .. 0, threading dcarry through the body’s VJP. Per iteration: writes carry_t (from outer init or trajectory), each xs_i[t] slice, and the current dcarry into the body_vjp’s Input slots, runs body_vjp, and reads the new dcarry from its single output. The upstream-accumulation step in trajectory mode is an element-wise add whose floating-point type (f32 or f64) is selected by carry_elem_size.

Fields

§body_vjp: Arc<ThunkSchedule>
§body_init: Arc<Vec<u8>>
§body_carry_in_off: usize
§body_x_offs: Arc<Vec<usize>>
§body_d_output_off: usize
§body_dcarry_out_off: usize
§outer_init_off: usize
§outer_traj_off: usize
§outer_upstream_off: usize
§outer_xs_offs: Arc<Vec<(usize, u32)>>

Per-xs entries: (outer_xs_base_off, per_step_bytes). Read xs_i[t] from outer_xs_base_off + t * per_step_bytes.

§outer_dinit_off: usize
§length: u32
§carry_bytes: u32
§carry_elem_size: u32

Bytes per element in the carry tensor: 4 for f32, 8 for f64. Used to dispatch the trajectory-mode upstream accumulation kernel (the dcarry += upstream[t] add must use the right floating-point type — a hard-coded f64 add silently does nothing for an f32 carry whose carry_bytes isn’t divisible by 8).

§save_trajectory: bool
§num_checkpoints: u32

Recursive checkpointing config. 0 or length ⇒ full trajectory cached, no recompute (existing behavior). 0 < K < length ⇒ trajectory has only K rows; the executor recomputes intermediate carries via forward_body between checkpoints. Memory: O(K · carry_bytes); time: O(length).

§forward_body: Option<Arc<ThunkSchedule>>

Forward body schedule (same compiled body as the forward Op::Scan), used for recompute when num_checkpoints is active. None for the All strategy (full trajectory cached, no recompute).

§forward_body_init: Option<Arc<Vec<u8>>>

Pristine forward body arena bytes (constants filled).

§forward_body_carry_in_off: usize

Forward body’s carry-Input and output slot offsets — needed to seed/read the body during recompute.

§forward_body_output_off: usize
§forward_body_x_offs: Arc<Vec<usize>>

Forward body’s per-step xs Input slots (one per outer xs). Same indexing convention as body_x_offs.

§

ScanBackwardXs

Companion to ScanBackward that materializes one stacked dxs_i. Same backward loop; per iteration, after running body_vjp, copies its body_dxs_out_off slot into the outer arena at outer_dxs_off + t * per_step_bytes. dcarry threading is identical — we still need it for the body_vjp recurrence even though we don’t write it back to the outer arena.

Fields

§body_vjp: Arc<ThunkSchedule>
§body_init: Arc<Vec<u8>>
§body_carry_in_off: usize
§body_x_offs: Arc<Vec<usize>>
§body_d_output_off: usize
§body_dcarry_out_off: usize
§body_dxs_out_off: usize
§outer_init_off: usize
§outer_traj_off: usize
§outer_upstream_off: usize
§outer_xs_offs: Arc<Vec<(usize, u32)>>
§outer_dxs_off: usize
§length: u32
§carry_bytes: u32
§carry_elem_size: u32

Same role as Thunk::ScanBackward::carry_elem_size.

§per_step_bytes: u32
§save_trajectory: bool
§num_checkpoints: u32

Recursive checkpointing config. Same semantics as Thunk::ScanBackward::num_checkpoints: 0 or length means “save every step’s carry”; 0 < K < length means the trajectory has only K rows and the executor recomputes intermediate carries via forward_body (which must be Some). Implemented via segment-cached recompute, mirroring the ScanBackward path.

§forward_body: Option<Arc<ThunkSchedule>>
§forward_body_init: Option<Arc<Vec<u8>>>
§forward_body_carry_in_off: usize
§forward_body_output_off: usize
§forward_body_x_offs: Arc<Vec<usize>>
§

CustomFn

User-defined sub-graph (Op::CustomFn) — runs fwd_body once. Per execution: clone body_init, copy each primal input from the outer arena into its body Input slot, run the body schedule, copy the body’s single output back to the outer arena.

Fields

§body_init: Arc<Vec<u8>>
§inputs: Arc<Vec<(usize, usize, u32)>>

Per primal input: (body_input_off, outer_input_off, bytes).

§body_output_off: usize
§outer_output_off: usize
§out_bytes: u32
§

FusedMmBiasAct

C = A @ B; C += bias; C = act(C)

§

FusedResidualLN

out = LN(x + residual + bias, gamma, beta)

Fields

§res: usize
§bias: usize
§out: usize
§rows: u32
§eps: f32
§has_bias: bool
§

FusedResidualRmsNorm

out = RmsNorm(x + residual + bias, gamma, beta)

Fields

§res: usize
§bias: usize
§out: usize
§rows: u32
§eps: f32
§has_bias: bool
§

BiasAdd

out = bias_add(data, bias, m, n) for Binary::Add with broadcast

Fields

§src: usize
§bias: usize
§dst: usize
§

BinaryFull

Element-wise binary op with NumPy-style broadcast.

Fast path (lhs_len == rhs_len == len): plain element-wise loop, SIMD-vectorized on aarch64 for Add/Mul. bcast_* fields are unused.

Broadcast path: uses out_dims_bcast + bcast_lhs_strides + bcast_rhs_strides to compute per-cell indices into each operand. The strides are precomputed at thunk-construction time from the operands’ true shapes (with stride 0 on any axis where the operand has size 1). This is the only correct way to handle bidirectional broadcasts like [N, 1] op [1, S] → [N, S], which simple i % lhs_len modulo indexing maps to wrong cells.

Fields

§lhs: usize
§rhs: usize
§dst: usize
§len: u32
§lhs_len: u32
§rhs_len: u32
§out_dims_bcast: Vec<u32>

Output shape dims (row-major). Empty in the fast path.

§bcast_lhs_strides: Vec<u32>

Per-dim stride into lhs (0 where lhs broadcasts).

§bcast_rhs_strides: Vec<u32>

Per-dim stride into rhs.

§

ActivationInPlace

Activation in-place

Fields

§data: usize
§len: u32
§

Gather

Gather axis=0: table[idx] → out

Fields

§table: usize
§table_len: u32
§idx: usize
§dst: usize
§num_idx: u32
§trailing: u32
§

Narrow

Narrow: copy slice (elem_bytes = source element size: 4 for f32, 8 for f64).

Fields

§src: usize
§dst: usize
§outer: u32
§src_stride: u32
§dst_stride: u32
§inner: u32
§elem_bytes: u8
§

Copy

Copy (reshape, expand)

Fields

§src: usize
§dst: usize
§len: u32
§

LayerNorm

LayerNorm standalone

Fields

§src: usize
§dst: usize
§rows: u32
§eps: f32
§

GroupNorm

GroupNorm on NCHW [N,C,H,W].

Fields

§src: usize
§dst: usize
§num_groups: u32
§eps: f32
§

LayerNorm2d

LayerNorm2d on NCHW (SAM / candle semantics).

Fields

§src: usize
§dst: usize
§eps: f32
§

ConvTranspose2d

ConvTranspose2d on NCHW.

Fields

§src: usize
§weight: usize
§dst: usize
§c_in: u32
§w_in: u32
§c_out: u32
§h_out: u32
§w_out: u32
§kh: u32
§kw: u32
§sh: u32
§sw: u32
§ph: u32
§pw: u32
§dh: u32
§dw: u32
§groups: u32
§

ResizeNearest2x

Nearest 2× upsample on NCHW (per-batch slice).

Fields

§src: usize
§dst: usize
§

AxialRope2d

SAM2 axial 2-D RoPE on [batch, seq, num_heads * head_dim].

Fields

§src: usize
§dst: usize
§batch: u32
§seq: u32
§hidden: u32
§end_x: u32
§end_y: u32
§head_dim: u32
§num_heads: u32
§theta: f32
§repeat_factor: u32
§

RmsNorm

RMSNorm: out = (x / sqrt(mean(x^2) + eps)) * gamma + beta. No mean subtraction, hence cheaper than LayerNorm. Used by Llama-class models.

Fields

§src: usize
§dst: usize
§rows: u32
§eps: f32
§

Softmax

Softmax

Fields

§data: usize
§rows: u32
§cols: u32
§

Cumsum

Inclusive (or exclusive) cumulative sum along the last axis (callers pre-flatten higher-dim cumsums via reshape views).

Fields

§src: usize
§dst: usize
§rows: u32
§cols: u32
§exclusive: bool
§

SelectiveScan

Mamba-style selective scan (plan #15). Inputs: x, delta [b,s,h], a [h,n], b [b,s,n], c [b,s,n]. Output: y [b,s,h]. The recurrent state (not to be confused with the hidden dim h) is carried across the seq axis.

Fields

§delta: usize
§dst: usize
§batch: u32
§seq: u32
§hidden: u32
§state_size: u32
§

GatedDeltaNet

Gated DeltaNet linear-attention scan (Qwen3.5/3.6 trunk). Inputs: q, k, v [b, s, h, n]; g, beta [b, s, h]. Output: [b, s, h, n]. See Op::GatedDeltaNet for math.

Fields

§beta: usize
§state: usize

When non-zero, load initial [b, h, n, n] state and write the final state back in place after the scan.

§dst: usize
§batch: u32
§seq: u32
§heads: u32
§state_size: u32
§

Conv2D1x1

1×1 conv fast path (plan #26). The general Conv2D thunk runs the textbook 7-deep loop; a 1×1 stride-1 padding-0 groups-1 conv is mathematically a per-batch matmul, and dispatching it through BLAS is 3-10× faster than the scalar nest. Common case: ViT patch-projection follow-on, transformer “expert” reductions in some MoE designs.

Per batch: weight [c_out, c_in] × input [c_in, h*w] = output [c_out, h*w].

Fields

§src: usize
§weight: usize
§dst: usize
§c_in: u32
§c_out: u32
§hw: u32
§

DequantMatMul

Fused dequant + matmul (plan #5). Today supports QuantScheme::Int8Block (symmetric); other schemes panic at lowering time with a clear message until kernels are added.

Fields

§w_q: usize
§scale: usize
§dst: usize
§block_size: u32
§is_asymmetric: bool
§

DequantMatMulGguf

GGUF-format dequant + matmul. Weight is a packed byte tensor in one of the K-quant super-block layouts (Q4_K, Q5_K, Q6_K, Q8_K). Scales / mins live inside the packed bytes — no side-channel scale tensor.

Today this is a “dequant-to-scratch then sgemm” kernel — it keeps the arena memory footprint down (weights stay packed) but the dequant itself happens per matmul. A future fully fused tile-streaming kernel would close the compute gap.

Fields

§w_q: usize
§dst: usize
§

DequantMatMulInt4

Int4 block dequant + matmul (packed nibbles, side scale/zp).

Fields

§w_q: usize
§scale: usize
§dst: usize
§block_size: u32
§is_asymmetric: bool
§

DequantMatMulFp8

FP8 dequant + matmul (per-tensor or per-column scale).

Fields

§w_q: usize
§scale: usize
§dst: usize
§e5m2: bool
§

DequantMatMulNvfp4

NVFP4 (E2M1) block dequant + matmul — 16-wide groups, FP8 scales.

Fields

§w_q: usize
§scale: usize
§global_scale: usize
§dst: usize
§

LoraMatMul

Fused LoRA matmul (plan #9): out = x·W + scale * (x·A)·B. r is the LoRA rank (typically 4-64) — the rank-r intermediate x·A lives in scratch, never on the arena.

Fields

§dst: usize
§scale: f32
§

Sample

Fused sample: logits [batch, vocab] → token ids [batch]. See Op::Sample. Output values are f32-encoded usize indices (matches the rest of the IR’s “ids as f32” convention).

Fields

§logits: usize
§dst: usize
§batch: u32
§vocab: u32
§top_k: u32
§top_p: f32
§temperature: f32
§seed: u64
§

Attention

Attention SDPA. mask is the offset of the optional mask tensor (only meaningful when mask_kind == MaskKind::Custom); other kinds synthesize the mask in-kernel.

Q/K/V each carry a _row_stride (elements per source row). It defaults to heads * head_dim, which matches the standalone case where Q/K/V are their own contiguous buffers. The Narrow → Attention fusion below rewrites these to the parent QKV stride (typically 3 * heads * head_dim) so the kernel reads QKV directly without materializing the per-head buffers (plan #46).

Fields

§mask: usize
§out: usize
§batch: u32
§seq: u32

Query sequence length.

§kv_seq: u32

Key/value sequence length. Differs from seq during cached decode.

§heads: u32
§head_dim: u32
§mask_kind: MaskKind
§q_row_stride: u32
§k_row_stride: u32
§v_row_stride: u32
§bhsd: bool

Memory layout flag. false (the historical default) → [B, S, H, D] row-major: per-head offset is bi*S*H*D + si*H*D + hi*D. true → [B, H, S, D] (head-major), matching the convention used by rlx-cuda / rlx-rocm / rlx-tpu: per-head offset is bi*H*S*D + hi*S*D + si*D. Detected at lowering time from the input shape vs num_heads / head_dim.

§

AttentionBackward

Op::AttentionBackward — emits dQ, dK, or dV (see wrt).

Fields

§mask: usize
§out: usize
§batch: u32
§seq: u32
§kv_seq: u32
§heads: u32
§head_dim: u32
§mask_kind: MaskKind
§bhsd: bool
§

Rope

RoPE (rotary position embeddings). src_row_stride is elements per source row (defaults to hidden for the standalone case; set to qkv_axis * inner when the thunk fusion pass below rewires Rope to read directly from the fused QKV buffer — plan #45).

Fields

§src: usize
§cos: usize
§sin: usize
§dst: usize
§batch: u32
§seq: u32
§hidden: u32
§head_dim: u32
§n_rot: u32
§cos_len: u32
§src_row_stride: u32
§

FusedAttnBlock

Fused attention block: QKV proj → split → [RoPE] → SDPA → output proj. All intermediates stay in L1 cache. Zero arena writes between ops.

Fields

§hidden: usize
§qkv_w: usize
§out_w: usize
§mask: usize
§out: usize
§qkv_b: usize
§out_b: usize
§cos: usize
§sin: usize
§cos_len: u32
§batch: u32
§seq: u32
§hs: u32
§nh: u32
§dh: u32
§has_bias: bool
§has_rope: bool
§

FusedBertLayer

Fuses an entire transformer layer: attention + residual + LN + FFN + residual + LN. Combines ~10 thunks into 1. All intermediates on stack. Zero arena traffic.

Fields

§hidden: usize
§qkv_w: usize
§qkv_b: usize
§out_w: usize
§out_b: usize
§mask: usize
§ln1_g: usize
§ln1_b: usize
§eps1: f32
§fc1_w: usize
§fc1_b: usize
§fc2_w: usize
§fc2_b: usize
§ln2_g: usize
§ln2_b: usize
§eps2: f32
§out: usize
§batch: u32
§seq: u32
§hs: u32
§nh: u32
§dh: u32
§int_dim: u32
§

FusedNomicLayer

Fused Nomic transformer layer: attention+RoPE + residual + LN + SwiGLU FFN + residual + LN.

Fields

§hidden: usize
§qkv_w: usize
§out_w: usize
§mask: usize
§cos: usize
§sin: usize
§cos_len: u32
§ln1_g: usize
§ln1_b: usize
§eps1: f32
§fc11_w: usize
§fc12_w: usize
§fc2_w: usize
§ln2_g: usize
§ln2_b: usize
§eps2: f32
§out: usize
§batch: u32
§seq: u32
§hs: u32
§nh: u32
§dh: u32
§int_dim: u32
§

FusedSwiGLU

Fused SwiGLU: out[r,i] = x[r,i] * silu(x[r, n_half+i]). Input: [outer, 2*n_half] — concatenated up||gate per row. Output: [outer, n_half].

Fields

§src: usize
§dst: usize
§n_half: u32
§total: u32
§gate_first: bool
§

Concat

Concat along an axis: output[outer, axis, inner] = inputs concatenated. Each entry of inputs is (src_offset, axis_len_for_that_input) in u32 elements. outer, inner, and total_axis are pre-computed at compile time to avoid per-run shape work.

Fields

§dst: usize
§outer: u32
§inner: u32
§total_axis: u32
§inputs: Vec<(usize, u32)>
§

Compare

Element-wise comparison: out = (lhs CMP rhs) ? 1.0 : 0.0

Fields

§lhs: usize
§rhs: usize
§dst: usize
§len: u32
§

Reduce

Reduction along a contiguous range of axes. Input layout (after shape decomposition) is [outer, reduced, inner]; output is [outer, inner]. The single-axis cases (axis=0 → outer=1; axis=last → inner=1) and contiguous multi-axis cases (e.g. reduce over [0, 1] of an [N, C, H, W] tensor → outer=1, reduced=N·C, inner=H·W) all map onto this triplet. Non-contiguous axes are not supported and bail to Nop in the compile pass.

Fields

§src: usize
§dst: usize
§outer: u32
§reduced: u32
§inner: u32
§

TopK

Top-K indices along the last axis. Input shape [outer, axis_dim], output [outer, k] of f32-encoded i64 indices. Ties are broken in favor of the smaller index. Used by MoE gating + beam search.

Fields

§src: usize
§dst: usize
§outer: u32
§axis_dim: u32
§

GroupedMatMul

Indexed batched matmul: out[i] = input[i] @ weight[expert_idx[i]]. Naive per-token implementation; for real MoE workloads, sorting tokens by expert and running a segmented GEMM would amortize the cost. That optimization is deferred until a workload needs it.

Fields

§input: usize
§weight: usize
§expert_idx: usize
§dst: usize
§k_dim: u32
§num_experts: u32
§

DequantGroupedMatMulGguf

GGUF K-quant packed expert stack + grouped matmul (MoE FFN).

Fields

§input: usize
§w_q: usize
§expert_idx: usize
§dst: usize
§k_dim: u32
§num_experts: u32
§

DequantMoEWeightsGguf

Materialize packed MoE weights to F32 [E, K, N] (autodiff helper).

Fields

§w_q: usize
§dst: usize
§k_dim: u32
§num_experts: u32
§

ScatterAdd

Scatter-add: dst[indices[i] * trailing + j] += updates[i * trailing + j]. Output is zeroed first; multiple updates to the same row accumulate.

Fields

§updates: usize
§indices: usize
§dst: usize
§num_updates: u32
§out_dim: u32
§trailing: u32
§

Where

Ternary select: out = cond != 0 ? on_true : on_false

Fields

§cond: usize
§on_true: usize
§on_false: usize
§dst: usize
§len: u32
§

Transpose

General N-D transpose / broadcast. out_dims[i] is the output’s dim i length; in_strides[i] is the input stride (in elements) used to index that dim — 0 for broadcast dims (Expand). in_total is the total element count in the source buffer (≤ output total when broadcasting). Strides are pre-computed at compile time.

Fields

§src: usize
§dst: usize
§in_total: u32
§out_dims: Vec<u32>
§in_strides: Vec<u32>
§

GatherAxis

Gather along an arbitrary axis. outer = product(dims[..axis]), trailing = product(dims[axis+1..]), axis_dim = the dimension being indexed into. Output: outer × num_idx × trailing. (axis=0 still routes to the simpler Thunk::Gather fast path.)

Fields

§table: usize
§idx: usize
§dst: usize
§outer: u32
§axis_dim: u32
§num_idx: u32
§trailing: u32
§

Pool2D

2D pooling (Max or Mean). Input layout [N, C, H, W], output [N, C, H_out, W_out]. Padding is implicit-zero; Mean divides by the full kernel area (matches torch’s count_include_pad=True).

Fields

§src: usize
§dst: usize
§h_out: u32
§w_out: u32
§kh: u32
§kw: u32
§sh: u32
§sw: u32
§ph: u32
§pw: u32
§

Conv2D

2D convolution. Input [N, C_in, H, W], weight [C_out, C_in_per_group, kH, kW], output [N, C_out, H_out, W_out]. Bias is a separate Op::Binary::Add after the conv (matching the IR’s input layout — Op::Conv has 2 inputs). Naive direct convolution; sufficient for correctness, not optimised.

Fields

§src: usize
§weight: usize
§dst: usize
§c_in: u32
§c_out: u32
§h_out: u32
§w_out: u32
§kh: u32
§kw: u32
§sh: u32
§sw: u32
§ph: u32
§pw: u32
§dh: u32
§dw: u32
§groups: u32
§

QMatMul

Real INT8 matmul with i32 accumulation: out[m, n] = requantize(bias[n] + Σₖ (x[m,k]-x_zp)·(w[k,n]-w_zp), mult, out_zp). Reads x and w as i8, bias as i32; writes out as i8. Same kernel shape as rlx_cortexm::dense::dense_i8 — promoted to a desktop thunk so a quantized graph compiled here doesn’t have to round-trip through fake-quant.

Fields

§bias: usize
§out: usize
§x_zp: i32
§w_zp: i32
§out_zp: i32
§mult: f32
§

QConv2d

Real INT8 conv2d, NCHW layout. Same loop shape as Thunk::Conv2D but with i8 reads, i32 accumulation, and per-output requantize to i8. Bias is i32 in the accumulator scale.

Fields

§bias: usize
§out: usize
§c_in: u32
§w_in: u32
§c_out: u32
§h_out: u32
§w_out: u32
§kh: u32
§kw: u32
§sh: u32
§sw: u32
§ph: u32
§pw: u32
§dh: u32
§dw: u32
§groups: u32
§x_zp: i32
§w_zp: i32
§out_zp: i32
§mult: f32
§

Quantize

INT8 quantize. Reads x as f32, writes q as i8. chan = (i / inner) % chan_dim selects the per-channel scale/zp; chan_axis is informational only (the kernel uses chan_dim and inner directly). For per-tensor, chan_dim = 1 and inner = len so chan is always 0.

Fields

§len: u32
§chan_axis: u32
§chan_dim: u32
§inner: u32
§scales: Vec<f32>
§zero_points: Vec<i32>
§

Dequantize

INT8 dequantize — inverse of Thunk::Quantize.

Fields

§len: u32
§chan_axis: u32
§chan_dim: u32
§inner: u32
§scales: Vec<f32>
§zero_points: Vec<i32>
§

FakeQuantize

QAT fake-quantize. Per-channel (or per-tensor) symmetric quantize-then-dequantize on the fly. Computes s[c] = max(|x[..., c, ...]|) / q_max then out[i] = clamp(round(x[i]/s[c]), -q_max, q_max) * s[c] with q_max = {127, 7, 1} for bits = {8, 4, 2}. Same channel-layout convention as Thunk::Quantize: every element’s channel is (i / inner) % chan_dim. The kernel does two passes — one to scan max-abs per channel, one to quant-dequant per element.

Fields

§out: usize
§len: u32
§chan_axis: u32
§chan_dim: u32
§inner: u32
§bits: u8
§ste: SteKind

STE variant — informational on the forward side (output is the same regardless), kernel-relevant in the matching FakeQuantizeBackward thunk.

§scale_mode: ScaleMode

Scale-tracking strategy. PerBatch recomputes max_abs/q_max every call (the original path). EMA{decay} blends per-batch max-abs into the state_off buffer; Fixed reads state_off and never updates it.

§state_off: Option<usize>

Some(off) for EMA and Fixed; None for PerBatch. Points at a [chan_dim] f32 buffer holding the running scale per channel.

§

FakeQuantizeBackward

Backward pass for Op::FakeQuantize under one of four STE variants. Computes dx[i] from the f32 forward input x and the upstream gradient dy, using the same per-channel scale scheme as the forward.

Fields

§len: u32
§chan_axis: u32
§chan_dim: u32
§inner: u32
§bits: u8
§

FakeQuantizeLSQ

LSQ forward — same kernel shape as FakeQuantize Fixed mode. Reads scale from scale_off (a [chan_dim] Param tensor).

Fields

§scale_off: usize
§out: usize
§len: u32
§chan_axis: u32
§chan_dim: u32
§inner: u32
§bits: u8
§

FakeQuantizeLSQBackwardX

LSQ backward, x-gradient. STE-clipped: passes upstream through inside the quantization range, zeros outside.

Fields

§scale_off: usize
§len: u32
§chan_axis: u32
§chan_dim: u32
§inner: u32
§bits: u8
§

FakeQuantizeLSQBackwardScale

LSQ backward, scale-gradient. Per-channel: dscale[c] = sum_i ψ(x[i]/s[c]) · upstream[i] where ψ(z) = -z + round(z) if |z| ≤ q_max else sign(z) · q_max. Output shape: [chan_dim].

Fields

§scale_off: usize
§dscale: usize
§len: u32
§chan_axis: u32
§chan_dim: u32
§inner: u32
§bits: u8
§

ReluBackward

ReLU backward: dx[i] = dy[i] if x[i] > 0 else 0.

Fields

§len: u32
§

ReluBackwardF64

f64 sibling of ReluBackward — same shape as the f32 variant but reads/writes 8 bytes per element. Required because ReluBackward’s &[f32] slot view would otherwise read only half of every f64, so backward silently produces 0 gradients on an f64 graph. Mirrors the ActivationBackwardF64 split.

Fields

§len: u32
§

ActivationBackward

Generic element-wise activation backward. dx[i] = (d/dx act(x))[i] · dy[i]. The closure dispatch is per-element; expensive activations (Gelu) recompute internals inline rather than threading an extra “saved y” tensor through.

Fields

§len: u32
§

ActivationBackwardF64

f64 sibling of ActivationBackward — slot offsets and len are in elements; the kernel reads/writes 8 bytes per element. Required because ActivationBackward’s &[f32] slot view silently returns garbage on an f64 graph: cb % 4 still holds, but every loaded value is only half of an f64, so the gradient is wrong.

Fields

§len: u32
§

LayerNormBackwardInput

LayerNorm backward — input gradient. Recomputes mean/var/x̂ from x and emits the closed-form d_x per row.

Fields

§gamma: usize
§rows: u32
§eps: f32
§

LayerNormBackwardGamma

LayerNorm backward — gamma gradient. d_gamma[d] = Σ_row dy·x̂.

Fields

§dgamma: usize
§rows: u32
§eps: f32
§

RmsNormBackwardInput

Fields

§gamma: usize
§beta: usize
§rows: u32
§eps: f32
§

RmsNormBackwardGamma

Fields

§gamma: usize
§beta: usize
§dgamma: usize
§rows: u32
§eps: f32
§

RmsNormBackwardBeta

Fields

§gamma: usize
§beta: usize
§dbeta: usize
§rows: u32
§eps: f32
§

RopeBackward

Fields

§cos: usize
§sin: usize
§batch: u32
§seq: u32
§hidden: u32
§head_dim: u32
§n_rot: u32
§cos_len: u32
§

CumsumBackward

Fields

§rows: u32
§cols: u32
§exclusive: bool
§

GatherBackward

Fields

§indices: usize
§dst: usize
§outer: u32
§axis_dim: u32
§num_idx: u32
§trailing: u32
§

GroupNormBackwardInput

Fields

§gamma: usize
§beta: usize
§num_groups: u32
§eps: f32
§

GroupNormBackwardGamma

Fields

§dgamma: usize
§num_groups: u32
§eps: f32
§

GroupNormBackwardBeta

Fields

§dbeta: usize
§

MaxPool2dBackward

2D max-pool backward (NCHW). Recomputes the argmax position inside each window and accumulates dy into dx at that position. Output is zeroed first; ties resolve to the first hit (lowest (kh,kw) index), matching what the forward kernel does with acc.max(v).

Fields

§h_out: u32
§w_out: u32
§kh: u32
§kw: u32
§sh: u32
§sw: u32
§ph: u32
§pw: u32
§

Conv2dBackwardInput

2D conv backward w.r.t. input (dx = conv_transpose(dy, w)). dy [N, C_out, H_out, W_out], w [C_out, C_in_per_group, kH, kW], dx [N, C_in, H, W].

Fields

§c_in: u32
§w_in: u32
§c_out: u32
§h_out: u32
§w_out: u32
§kh: u32
§kw: u32
§sh: u32
§sw: u32
§ph: u32
§pw: u32
§dh: u32
§dw: u32
§groups: u32
§

Conv2dBackwardWeight

2D conv backward w.r.t. weight. x [N, C_in, H, W], dy [N, C_out, H_out, W_out], dw [C_out, C_in_per_group, kH, kW]. dw is zeroed before accumulation.

Fields

§c_in: u32
§c_out: u32
§h_out: u32
§w_out: u32
§kh: u32
§kw: u32
§sh: u32
§sw: u32
§ph: u32
§pw: u32
§dh: u32
§dw_dil: u32
§groups: u32
§

SoftmaxCrossEntropy

Fused softmax + cross-entropy loss with f32-encoded integer labels. logits [N, C], labels [N], output [N] per-row loss. Numerically stable (max-subtract before exp).

Fields

§logits: usize
§labels: usize
§dst: usize
§

SoftmaxCrossEntropyBackward

Backward of the fused loss above. dlogits[n, k] = (softmax(logits[n])[k] - one_hot(labels[n])[k]) * d_loss[n].

Fields

§logits: usize
§labels: usize
§d_loss: usize
§dlogits: usize
§

CustomOp

User-registered custom op (CPU side). Lowered from Op::Custom. kernel is resolved against the global CPU kernel registry at compile time and stored as Arc<dyn CpuKernel> so execution avoids per-call lookups. v1: f32 contiguous only — see op_registry::CpuKernel::execute_f32.

Fields

§kernel: Arc<dyn CpuKernel>
§inputs: Vec<(usize, u32, Shape)>
§output: (usize, u32, Shape)
§attrs: Vec<u8>
§

GaussianSplatRender

CPU reference 3D Gaussian splat render (rlx_ir::Op::GaussianSplatRender).

Note: the following 1D FFT description documents Thunk::Fft1d, not this variant (misplaced doc comment). 1D FFT along the last axis. Input/output are [..., 2N] real-block layout (first N real, second N imag along the transformed axis). outer is the product of all leading axes; n_complex is N (the number of complex points). Both halves of the real-block layout are read together by the kernel. dtype selects the f32 or f64 path; the two share structure but not buffers, so a flag at compile time avoids per-row dispatch.

Fields

§positions_off: usize
§positions_len: usize
§scales_off: usize
§scales_len: usize
§rotations_off: usize
§rotations_len: usize
§opacities_off: usize
§opacities_len: usize
§colors_off: usize
§colors_len: usize
§sh_coeffs_off: usize
§sh_coeffs_len: usize
§meta_off: usize
§dst_off: usize
§dst_len: usize
§width: u32
§height: u32
§tile_size: u32
§radius_scale: f32
§alpha_cutoff: f32
§max_splat_steps: u32
§transmittance_threshold: f32
§max_list_entries: u32
§

GaussianSplatRenderBackward

Fields

§positions_off: usize
§positions_len: usize
§scales_off: usize
§scales_len: usize
§rotations_off: usize
§rotations_len: usize
§opacities_off: usize
§opacities_len: usize
§colors_off: usize
§colors_len: usize
§sh_coeffs_off: usize
§sh_coeffs_len: usize
§meta_off: usize
§d_loss_off: usize
§d_loss_len: usize
§packed_off: usize
§packed_len: usize
§width: u32
§height: u32
§tile_size: u32
§radius_scale: f32
§alpha_cutoff: f32
§max_splat_steps: u32
§transmittance_threshold: f32
§max_list_entries: u32
§loss_grad_clip: f32
§sh_band: u32
§max_anisotropy: f32
§

GaussianSplatPrepare

Strict IR stage 1 — project + bin + sort + rays (Op::GaussianSplatPrepare).

Fields

§positions_off: usize
§positions_len: usize
§scales_off: usize
§scales_len: usize
§rotations_off: usize
§rotations_len: usize
§opacities_off: usize
§opacities_len: usize
§colors_off: usize
§colors_len: usize
§sh_coeffs_off: usize
§sh_coeffs_len: usize
§meta_off: usize
§meta_len: usize
§prep_off: usize
§prep_len: usize
§width: u32
§height: u32
§tile_size: u32
§radius_scale: f32
§alpha_cutoff: f32
§max_splat_steps: u32
§transmittance_threshold: f32
§max_list_entries: u32
§

GaussianSplatRasterize

Strict IR stage 2 — tile raster from prepare buffer (Op::GaussianSplatRasterize).

Fields

§prep_off: usize
§prep_len: usize
§meta_off: usize
§meta_len: usize
§dst_off: usize
§dst_len: usize
§count: usize
§width: u32
§height: u32
§tile_size: u32
§alpha_cutoff: f32
§max_splat_steps: u32
§transmittance_threshold: f32
§max_list_entries: u32
§

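Fft1d

1D FFT along the last axis. Input/output are [..., 2N] real-block layout (first N real, second N imag along the transformed axis). outer is the product of all leading axes; n_complex is N (the number of complex points). Both halves of the real-block layout are read together by the kernel. dtype selects the f32 or f64 path; the two share structure but not buffers, so a flag at compile time avoids per-row dispatch.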

Fields

§src: usize
§dst: usize
§outer: u32
§n_complex: u32
§inverse: bool
§dtype: DType

Trait Implementations§

Source§

impl Clone for Thunk

Source§

fn clone(&self) -> Thunk

Returns a duplicate of the value. Read more
1.0.0 (const: unstable) · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more

Auto Trait Implementations§

§

impl Freeze for Thunk

§

impl !RefUnwindSafe for Thunk

§

impl Send for Thunk

§

impl Sync for Thunk

§

impl Unpin for Thunk

§

impl UnsafeUnpin for Thunk

§

impl !UnwindSafe for Thunk

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.