
Module take_along_axis



ADR-020 iter-11h-e1: take_along_axis (gather) plus its scatter backward for the GpuTape autograd pipeline. The forward pass gathers values along the last axis using a precomputed, non-differentiable index buffer; the backward pass scatters the incoming gradients back into a zero-initialised dx buffer.
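To make the forward/backward pairing concrete, here is a minimal CPU reference sketch of the same math. It assumes row-major f32 data of shape [rows, n], u32 indices of shape [rows, k], and gathering along the last axis. The helper names `take_along_axis_f32` and `take_along_axis_backward_f32` are hypothetical and are not this module's GPU dispatch functions. The backward sketch uses `+=`, so repeated indices within a row accumulate; whether the GPU shader handles duplicates the same way is an assumption, not something this page states.

```rust
/// Hypothetical CPU reference for the forward gather along the last axis:
/// y[r, j] = x[r, idx[r, j]]. `x` is [rows, n], `idx` is [rows, k].
fn take_along_axis_f32(x: &[f32], idx: &[u32], rows: usize, n: usize, k: usize) -> Vec<f32> {
    assert_eq!(x.len(), rows * n);
    assert_eq!(idx.len(), rows * k);
    let mut y = vec![0.0f32; rows * k];
    for r in 0..rows {
        for j in 0..k {
            let col = idx[r * k + j] as usize;
            assert!(col < n, "index out of range");
            y[r * k + j] = x[r * n + col];
        }
    }
    y
}

/// Hypothetical CPU reference for the backward scatter: start from a
/// zero-initialised dx of shape [rows, n] and add dy[r, j] into dx[r, idx[r, j]].
/// The indices carry no gradient of their own.
fn take_along_axis_backward_f32(dy: &[f32], idx: &[u32], rows: usize, n: usize, k: usize) -> Vec<f32> {
    assert_eq!(dy.len(), rows * k);
    assert_eq!(idx.len(), rows * k);
    let mut dx = vec![0.0f32; rows * n];
    for r in 0..rows {
        for j in 0..k {
            let col = idx[r * k + j] as usize;
            dx[r * n + col] += dy[r * k + j]; // accumulate, in case of duplicate indices
        }
    }
    dx
}

fn main() {
    // 2 rows, 4 columns; gather k = 2 entries per row.
    let x: [f32; 8] = [0.1, 0.4, 0.2, 0.3, 0.5, 0.1, 0.1, 0.3];
    let idx: [u32; 4] = [1, 3, 0, 3];
    let y = take_along_axis_f32(&x, &idx, 2, 4, 2);
    println!("{:?}", y); // [0.4, 0.3, 0.5, 0.3]

    let dy: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    let dx = take_along_axis_backward_f32(&dy, &idx, 2, 4, 2);
    println!("{:?}", dx); // [0.0, 1.0, 0.0, 2.0, 3.0, 0.0, 0.0, 4.0]
}
```

Every column that was not gathered receives a zero gradient, which is why dx must start zeroed rather than uninitialised.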

Used by the MoE router on GpuTape (iter-11h-e): y = take_along_axis(softmax(gate(x)), top_k_indices, axis=-1)
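The router expression above can be sketched on the CPU as follows, assuming the gate logits gate(x) are already available as a [rows, n] row-major f32 buffer. The helpers `softmax_rows`, `top_k_indices` and `gather_last_axis` are hypothetical illustrations, not functions exported by this module. `top_k_indices` is computed off the tape, which matches the description of a precomputed, non-differentiable index buffer. During backpropagation the gradient therefore reaches the softmax only through the k gathered probabilities.

```rust
/// Row-wise softmax over the last axis, stabilised by subtracting the row max.
fn softmax_rows(logits: &[f32], rows: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; rows * n];
    for r in 0..rows {
        let row = &logits[r * n..(r + 1) * n];
        let m = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0f32;
        for (j, &v) in row.iter().enumerate() {
            let e = (v - m).exp();
            out[r * n + j] = e;
            sum += e;
        }
        for j in 0..n {
            out[r * n + j] /= sum;
        }
    }
    out
}

/// Indices of the k largest entries per row, largest first.
/// Computed outside autograd, so it carries no gradient.
fn top_k_indices(p: &[f32], rows: usize, n: usize, k: usize) -> Vec<u32> {
    let mut idx = Vec::with_capacity(rows * k);
    for r in 0..rows {
        let row = &p[r * n..(r + 1) * n];
        let mut order: Vec<usize> = (0..n).collect();
        // Stable sort, descending by value; ties keep the lower index first.
        order.sort_by(|&a, &b| row[b].partial_cmp(&row[a]).unwrap());
        idx.extend(order[..k].iter().map(|&c| c as u32));
    }
    idx
}

/// Gather along the last axis: y[r, j] = p[r, idx[r, j]].
fn gather_last_axis(p: &[f32], idx: &[u32], n: usize, k: usize) -> Vec<f32> {
    (0..idx.len()).map(|i| p[(i / k) * n + idx[i] as usize]).collect()
}

fn main() {
    let (rows, n, k) = (1, 4, 2);
    let logits: [f32; 4] = [1.0, 3.0, 0.0, 2.0]; // assumed output of gate(x)
    let probs = softmax_rows(&logits, rows, n);
    let idx = top_k_indices(&probs, rows, n, k);
    let weights = gather_last_axis(&probs, &idx, n, k);
    println!("indices = {:?}", idx); // [1, 3]
    println!("weights = {:?}", weights);
}
```

Keeping the index computation outside the tape is the design choice that lets take_along_axis be the only differentiable step between the softmax and the selected expert weights.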

Statics

TAKE_ALONG_AXIS_SHADER_SOURCE

Functions

dispatch_take_along_axis_backward_f32
dispatch_take_along_axis_f32
register