
Module argmax_rows



argmax_rows_f16 op-diff harness (see crate::op_diff).

Computes the per-row argmax over an [m, n] logits buffer. Metal stores logits as f16, so to keep the argmax unambiguous across the f32 reference and the f16 kernel, each row gets a well-separated spike at a deterministic column. Both backends must then select that column regardless of f16 rounding. The compared output is the m winning indices, stored as f32.
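A minimal sketch of the spike scheme, assuming a lot that the module does not state. The spike column formula `(r * 7 + 3) % n`, the spike value `8.0`, the xorshift background in [-1, 1), and the helper names `spike_col`, `make_logits`, `round_to_f16` and `argmax_rows` are all hypothetical. `round_to_f16` only approximates an f32 to f16 to f32 round trip (round-to-nearest-even on the mantissa, ignoring overflow, NaN and f16 subnormals). It stands in for the Metal kernel's f16 storage and is not the crate's code.

```rust
/// Hypothetical spike value: far above the [-1, 1) background, so f16
/// rounding of either the spike or the background cannot change the winner.
const SPIKE: f32 = 8.0;

/// Hypothetical deterministic spike column for row `r`.
fn spike_col(r: usize, n: usize) -> usize {
    (r * 7 + 3) % n
}

/// Approximate f32 -> f16 -> f32 round trip: keep 10 mantissa bits with
/// round-to-nearest-even. Ignores overflow, NaN and f16 subnormals.
fn round_to_f16(x: f32) -> f32 {
    let bits = x.to_bits();
    let lsb = (bits >> 13) & 1;
    f32::from_bits(bits.wrapping_add(0x0FFF + lsb) & !0x1FFF)
}

/// Row-major [m, n] logits: xorshift32 noise in [-1, 1) plus one spike per row.
fn make_logits(m: usize, n: usize) -> Vec<f32> {
    let mut state: u32 = 0x1234_5678; // nonzero seed, so xorshift never sticks at 0
    let mut out = Vec::with_capacity(m * n);
    for r in 0..m {
        for c in 0..n {
            state ^= state << 13;
            state ^= state >> 17;
            state ^= state << 5;
            let u = (state >> 8) as f32 / (1u32 << 24) as f32; // in [0, 1)
            let v = if c == spike_col(r, n) { SPIKE } else { 2.0 * u - 1.0 };
            out.push(v);
        }
    }
    out
}

/// Per-row argmax, returned as f32 indices (first maximum wins on ties).
fn argmax_rows(logits: &[f32], m: usize, n: usize) -> Vec<f32> {
    assert_eq!(logits.len(), m * n);
    (0..m)
        .map(|r| {
            let row = &logits[r * n..(r + 1) * n];
            let mut best = 0;
            for (c, &v) in row.iter().enumerate() {
                if v > row[best] {
                    best = c;
                }
            }
            best as f32
        })
        .collect()
}

fn main() {
    let (m, n) = (4, 16);
    let logits = make_logits(m, n);
    // f32 "reference" path.
    let reference = argmax_rows(&logits, m, n);
    // f16-like path: round every logit before taking the argmax.
    let half: Vec<f32> = logits.iter().map(|&x| round_to_f16(x)).collect();
    let kernel_like = argmax_rows(&half, m, n);
    assert_eq!(reference, kernel_like);
    println!("{:?}", reference); // prints [3.0, 10.0, 1.0, 8.0]
}
```

Because every background value lies below 1.0 even after rounding, and the spike sits at 8.0 (exactly representable in f16), the two paths cannot disagree. The harness therefore tests index selection rather than f16 precision near ties.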

Structs

ArgmaxRowsOp