#[non_exhaustive]
#[repr(u16)]
pub enum SortKind {
    Sort = 0,
    SortBackward = 1,
    Argsort = 2,
    Msort = 3,
    MsortBackward = 4,
    Topk = 5,
    TopkBackward = 6,
    Kthvalue = 7,
    KthvalueBackward = 8,
    Unique = 9,
    UniqueConsecutive = 10,
    Histogram = 11,
    Histogramdd = 12,
    Bincount = 13,
    Searchsorted = 14,
}
Sorting / order-statistics op discriminant — Category O from the comprehensive plan (Phase 9).
Stored as u16 in crate::KernelSku::op when
category == OpCategory::Sorting. Phase 9 wires the block-bitonic
trailblazer family (row_len ≤ 1024, k ≤ 64):
- Self::Sort / Self::SortBackward: full sort, with indices saved for the BW. PyTorch torch.sort.
- Self::Argsort: indices-only variant. PyTorch torch.argsort.
- Self::Msort / Self::MsortBackward: stable sort (the tie-break on original index preserves input order). PyTorch torch.msort.
- Self::Topk / Self::TopkBackward: top-k by value (or bottom-k when largest == false). PyTorch torch.topk.
- Self::Kthvalue / Self::KthvalueBackward: composed atop topk; returns the k-th value and its index.
- Self::Unique / Self::UniqueConsecutive: set-valued ops. unique chains sort + consecutive dedup; unique_consecutive assumes the input is already sorted, or that only runs of equal neighboring cells should be collapsed. No BW (set-valued).
- Self::Histogram / Self::Histogramdd / Self::Bincount: atomic-bin accumulation. histogram and bincount FW are shipped; histogramdd is reserved (rank > 1 trailblazer follow-up).
- Self::Searchsorted: per-query binary search in a 1-D sorted array. PyTorch torch.searchsorted. No BW.
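The block-bitonic family sorts each row with a bitonic compare-exchange network. The following is a rough CPU-side sketch of that network, not the crate's kernel. It assumes a power-of-two row length (how the real kernel pads rows up to row_len ≤ 1024 is not described here) and carries each value's original index so the indices can be saved for the backward pass. The function name `bitonic_sort` is hypothetical.

```rust
/// Sorts one row ascending with a bitonic network, returning (values, original indices).
/// Assumes `row.len()` is a power of two. Keys are (value, original index) pairs,
/// so ties resolve by original index and the result is deterministic.
pub fn bitonic_sort(row: &[f32]) -> (Vec<f32>, Vec<usize>) {
    let n = row.len();
    assert!(n.is_power_of_two(), "row length must be a power of two");
    let mut keys: Vec<(f32, usize)> = row.iter().copied().zip(0..n).collect();

    // k: size of the bitonic sequences being merged; j: compare-exchange distance.
    let mut k = 2;
    while k <= n {
        let mut j = k / 2;
        while j > 0 {
            for i in 0..n {
                let l = i ^ j; // partner lane
                if l > i {
                    // Sub-sequences alternate direction depending on bit k of i.
                    let ascending = (i & k) == 0;
                    let swap = if ascending {
                        keys[i] > keys[l]
                    } else {
                        keys[i] < keys[l]
                    };
                    if swap {
                        keys.swap(i, l);
                    }
                }
            }
            j /= 2;
        }
        k *= 2;
    }
    keys.into_iter().unzip()
}
```

For example, `bitonic_sort(&[3.0, 1.0, 2.0, 1.0])` yields values `[1.0, 1.0, 2.0, 3.0]` and indices `[1, 3, 2, 0]`. The compare-exchange pattern is fixed regardless of the data, which is what makes the network a good fit for one GPU block per row.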
Dtype coverage:
- sort / argsort / msort FW: f32, f64, i32, i64.
- sort / msort BW: f32, f64 (FP grads only).
- topk FW + BW: f32, f64.
- kthvalue: composes topk; same dtype set.
- unique / unique_consecutive: f32, f64, i32.
- histogram: f32, f64 input → i32 counts.
- bincount: i32, i64 input → i32 counts.
- searchsorted: f32, f64, i32, i64.
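The coverage list above can be encoded as a lookup for validating a request before dispatch. This is a sketch under stated assumptions: `Dtype`, `forward_input_dtypes` and `backward_dtypes` are hypothetical names, not the crate's API, and ops are keyed by their PyTorch-style names rather than by `SortKind`.

```rust
/// Hypothetical element-type enum covering the dtypes named in the list.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Dtype {
    F32,
    F64,
    I32,
    I64,
}

/// Input dtypes accepted by each op's forward kernel, per the coverage list.
/// Ops with no listed forward kernel (e.g. histogramdd) get an empty slice.
pub fn forward_input_dtypes(op: &str) -> &'static [Dtype] {
    match op {
        "sort" | "argsort" | "msort" | "searchsorted" => {
            &[Dtype::F32, Dtype::F64, Dtype::I32, Dtype::I64]
        }
        "topk" | "kthvalue" => &[Dtype::F32, Dtype::F64],
        "unique" | "unique_consecutive" => &[Dtype::F32, Dtype::F64, Dtype::I32],
        "histogram" => &[Dtype::F32, Dtype::F64], // counts are i32
        "bincount" => &[Dtype::I32, Dtype::I64],  // counts are i32
        _ => &[],
    }
}

/// Backward kernels exist only for floating-point gradients.
pub fn backward_dtypes(op: &str) -> &'static [Dtype] {
    match op {
        "sort" | "msort" | "topk" | "kthvalue" => &[Dtype::F32, Dtype::F64],
        _ => &[],
    }
}
```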
Variants (Non-exhaustive)
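This enum is marked #[non_exhaustive]: variants may be added later, so a match on it outside the defining crate must include a wildcard arm.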
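Because the discriminant travels as a raw u16 in crate::KernelSku::op, reading it back needs a fallible conversion, and any match needs a wildcard arm. The sketch below uses a local mirror of the enum so it compiles on its own. `from_u16` and `backward_of` are hypothetical helpers; the crate may or may not provide equivalents.

```rust
/// Local mirror of `SortKind` so this sketch compiles on its own.
#[non_exhaustive]
#[repr(u16)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SortKind {
    Sort = 0,
    SortBackward = 1,
    Argsort = 2,
    Msort = 3,
    MsortBackward = 4,
    Topk = 5,
    TopkBackward = 6,
    Kthvalue = 7,
    KthvalueBackward = 8,
    Unique = 9,
    UniqueConsecutive = 10,
    Histogram = 11,
    Histogramdd = 12,
    Bincount = 13,
    Searchsorted = 14,
}

impl SortKind {
    /// Every variant this build knows about.
    const ALL: [SortKind; 15] = [
        SortKind::Sort, SortKind::SortBackward, SortKind::Argsort,
        SortKind::Msort, SortKind::MsortBackward, SortKind::Topk,
        SortKind::TopkBackward, SortKind::Kthvalue, SortKind::KthvalueBackward,
        SortKind::Unique, SortKind::UniqueConsecutive, SortKind::Histogram,
        SortKind::Histogramdd, SortKind::Bincount, SortKind::Searchsorted,
    ];

    /// Hypothetical decoder for a raw `KernelSku::op` value; returns `None`
    /// for discriminants this build does not recognize.
    pub fn from_u16(raw: u16) -> Option<SortKind> {
        Self::ALL.iter().copied().find(|&k| k as u16 == raw)
    }
}

/// Maps a forward op to its backward discriminant, if one exists.
pub fn backward_of(kind: SortKind) -> Option<SortKind> {
    match kind {
        SortKind::Sort => Some(SortKind::SortBackward),
        SortKind::Msort => Some(SortKind::MsortBackward),
        SortKind::Topk => Some(SortKind::TopkBackward),
        SortKind::Kthvalue => Some(SortKind::KthvalueBackward),
        // Required in downstream crates because the enum is #[non_exhaustive];
        // here it also covers the ops that have no BW.
        _ => None,
    }
}
```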
Sort = 0
sort(x, dim, descending) — returns sorted values + sorted
indices. PyTorch torch.sort.
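A CPU reference for one row of Sort might look like the following. It is a sketch, not the crate's kernel, and it assumes the sorted dimension has been laid out as contiguous rows. `sort_row` is a hypothetical name, and `f64::total_cmp` is used so that NaN has a defined position. The returned index vector is what Argsort emits on its own, and what SortBackward consumes.

```rust
/// Sorts one row, returning (sorted values, original index of each sorted value).
pub fn sort_row(x: &[f64], descending: bool) -> (Vec<f64>, Vec<usize>) {
    let mut idx: Vec<usize> = (0..x.len()).collect();
    // Sort positions by the values they point at. total_cmp is a total order
    // (NaN included); sort_by is stable, so ties keep their input order.
    idx.sort_by(|&a, &b| {
        let ord = x[a].total_cmp(&x[b]);
        if descending { ord.reverse() } else { ord }
    });
    let vals = idx.iter().map(|&i| x[i]).collect();
    (vals, idx)
}
```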
SortBackward = 1
Gradient of Self::Sort — scatter dy back to the original
positions via the saved indices.
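A sort only permutes the row, so its gradient is a plain scatter. If `idx[j]` is the original position of the j-th sorted element, the gradient at sorted position j goes back to position `idx[j]`. Here is a minimal sketch with a hypothetical function name:

```rust
/// Backward of a row sort: dx[idx[j]] = dy[j].
pub fn sort_backward_row(dy: &[f64], idx: &[usize]) -> Vec<f64> {
    assert_eq!(dy.len(), idx.len());
    let mut dx = vec![0.0; dy.len()];
    for (j, &src) in idx.iter().enumerate() {
        // idx is a permutation, so every slot of dx is written exactly once.
        dx[src] = dy[j];
    }
    dx
}
```

With the indices `[1, 2, 0]` from sorting `[3.0, 1.0, 2.0]`, an upstream `dy = [10.0, 20.0, 30.0]` scatters to `dx = [30.0, 10.0, 20.0]`.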
Argsort = 2
argsort(x, dim, descending) — returns sorted indices only.
PyTorch torch.argsort.
Msort = 3
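msort(x): stable sort along the last dimension. The tie-break on original index preserves input order. PyTorch torch.msort (note that torch.msort itself sorts along the first dimension, being equivalent to torch.sort(x, dim=0)[0]).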
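The tie-break described above turns any sort into a stable one: sort on (value, original index) keys and equal values can only come out in index order. The sketch below (hypothetical name, CPU only) shows this with an unstable sort on an integer row:

```rust
/// Stable ascending sort of one row via an unstable sort on (value, index) keys.
/// Equal values compare by original index, so they keep their input order.
pub fn msort_row(x: &[i32]) -> (Vec<i32>, Vec<usize>) {
    let mut keyed: Vec<(i32, usize)> = x.iter().copied().zip(0..).collect();
    keyed.sort_unstable(); // lexicographic on (value, index) => stable result
    keyed.into_iter().unzip()
}
```

For `[2, 1, 2, 1]` this gives values `[1, 1, 2, 2]` with indices `[1, 3, 0, 2]`: each pair of equal values appears in its original order.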
MsortBackward = 4
Gradient of Self::Msort — same scatter as
Self::SortBackward.
Topk = 5
topk(x, k, dim, largest) — top-k (or bottom-k) values + their
indices. PyTorch torch.topk. Trailblazer caps k ≤ 64.
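Topk's forward semantics can be sketched on the CPU by sorting the row's positions best-first and keeping the first k. A GPU kernel would avoid the full sort. The name `topk_row` and the full-sort strategy are illustrative assumptions. The sketch does not enforce the trailblazer's k ≤ 64 cap.

```rust
/// Top-k (largest == true) or bottom-k values of one row and their indices,
/// ordered best-first.
pub fn topk_row(x: &[f32], k: usize, largest: bool) -> (Vec<f32>, Vec<usize>) {
    assert!(k <= x.len(), "k must not exceed the row length");
    let mut idx: Vec<usize> = (0..x.len()).collect();
    idx.sort_by(|&a, &b| {
        let ord = x[a].total_cmp(&x[b]);
        if largest { ord.reverse() } else { ord }
    });
    idx.truncate(k); // keep only the k best positions
    let vals = idx.iter().map(|&i| x[i]).collect();
    (vals, idx)
}
```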
TopkBackward = 6
Gradient of Self::Topk: scatter the k-wide dy back into a zero-initialized dx of width row_len, using the saved indices.
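The scatter just described fits in a few lines. This is a minimal sketch with a hypothetical function name. Positions that were not selected by the forward top-k keep a zero gradient.

```rust
/// Scatter a k-wide upstream gradient into a zero-initialized row of length
/// `row_len`, using the indices saved by the forward top-k.
pub fn topk_backward_row(dy: &[f32], idx: &[usize], row_len: usize) -> Vec<f32> {
    assert_eq!(dy.len(), idx.len());
    let mut dx = vec![0.0f32; row_len];
    for (&g, &src) in dy.iter().zip(idx) {
        dx[src] += g; // top-k indices are distinct, so += and = agree here
    }
    dx
}
```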
Kthvalue = 7
kthvalue(x, k, dim) — the k-th smallest value + its index.
Composed at the Rust plan layer atop Self::Topk with the
“bottom-k” order.
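The composition above amounts to "take the bottom-k and keep the last one". The sketch below follows torch.kthvalue in treating k as 1-based. The helper names are hypothetical, and the plan layer's actual wiring may differ. Under the trailblazer cap, k would also be limited to 64.

```rust
/// Indices of the k smallest values of a row, ascending by value
/// (sort_by is stable, so ties stay in index order).
fn bottom_k_indices(x: &[f32], k: usize) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..x.len()).collect();
    idx.sort_by(|&a, &b| x[a].total_cmp(&x[b]));
    idx.truncate(k);
    idx
}

/// k-th smallest value (1-based k) and its index, composed atop bottom-k.
pub fn kthvalue_row(x: &[f32], k: usize) -> (f32, usize) {
    assert!(k >= 1 && k <= x.len(), "k must be in 1..=row_len");
    let src = *bottom_k_indices(x, k).last().unwrap();
    (x[src], src)
}
```

The matching backward is the single-element case of the top-k scatter: the scalar dy is written to position `src` of an otherwise zero row.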
KthvalueBackward = 8
Gradient of Self::Kthvalue — scatter the scalar dy back
to the single source position.
Unique = 9
unique(x, sorted=True) — returns the unique values in x. At
the Rust plan layer this chains Self::Sort + the consecutive
dedup. Set-valued — no BW.
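The sort-then-dedup chain maps directly onto the Rust standard library, which makes it a convenient reference. `Vec::dedup` removes only consecutive duplicates, which is exactly why the sort has to come first. A CPU sketch for an integer row (hypothetical function name):

```rust
/// unique(x, sorted=True) for one row: sort, then collapse runs of equal neighbors.
pub fn unique_sorted(x: &[i32]) -> Vec<i32> {
    let mut v = x.to_vec();
    v.sort_unstable(); // stability is irrelevant once duplicates are dropped
    v.dedup(); // consecutive dedup
    v
}
```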
UniqueConsecutive = 10
unique_consecutive(x): emits one cell per run start, collapsing each run of equal consecutive values. For globally unique output the input must already be sorted; otherwise only consecutive-equal cells are collapsed. Set-valued, so no BW.
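"One cell per run start" has a common data-parallel formulation: flag each run start, take an exclusive prefix sum of the flags to get each survivor's output slot, then scatter. The sketch below runs those three steps sequentially. It is an assumption that the crate's kernel is structured this way, and the function name is hypothetical.

```rust
/// unique_consecutive for one row via flag -> exclusive scan -> scatter.
pub fn unique_consecutive(x: &[i32]) -> Vec<i32> {
    // Step 1: run-start flags (cell 0 always starts a run).
    let flags: Vec<usize> = (0..x.len())
        .map(|i| (i == 0 || x[i] != x[i - 1]) as usize)
        .collect();
    // Step 2: exclusive prefix sum gives each run start its output position.
    let mut pos = vec![0usize; x.len()];
    let mut total = 0;
    for i in 0..x.len() {
        pos[i] = total;
        total += flags[i];
    }
    // Step 3: scatter run starts into a compacted output of length `total`.
    let mut out = vec![0i32; total];
    for i in 0..x.len() {
        if flags[i] == 1 {
            out[pos[i]] = x[i];
        }
    }
    out
}
```

On unsorted input such as `[1, 1, 2, 2, 1, 3, 3]` the result is `[1, 2, 1, 3]`: the second run of 1s survives because it is not adjacent to the first.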
Histogram = 11
histogram(x, bins, range) — 1-D uniform-bin histogram.
PyTorch torch.histogram. FW only.
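The following CPU sketch shows uniform-bin accumulation into i32 counts. Its conventions are assumptions: values outside [lo, hi] are skipped, and a value equal to hi falls into the last bin (the last bin is closed on the right, as in numpy.histogram). Where the sketch does a plain increment, the atomic-bin kernel would do an atomic add. The function name and the (lo, hi) parameters are hypothetical.

```rust
/// 1-D uniform-bin histogram over [lo, hi] with i32 counts.
pub fn histogram(x: &[f64], bins: usize, lo: f64, hi: f64) -> Vec<i32> {
    assert!(bins > 0 && lo < hi);
    let mut counts = vec![0i32; bins];
    let width = (hi - lo) / bins as f64;
    for &v in x {
        if !(lo..=hi).contains(&v) {
            continue; // out of range; NaN also fails this check
        }
        // Clamp so v == hi lands in the last bin rather than one past it.
        let b = (((v - lo) / width) as usize).min(bins - 1);
        counts[b] += 1; // a GPU kernel would use an atomic add here
    }
    counts
}
```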
Histogramdd = 12
histogramdd(x, bins, range): N-D histogram. Reserved discriminant with no kernel shipped yet; the rank > 1 trailblazer is a follow-up.
Bincount = 13
bincount(x, minlength) — count occurrences of each integer
in x. PyTorch torch.bincount. FW only.
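The bincount contract (output length max(max(x) + 1, minlength)) can be sketched on the CPU as follows. torch.bincount rejects negative inputs; this sketch signals that case by returning `None`. The function name and the `Option` return are illustrative choices, not the crate's API.

```rust
/// Count occurrences of each non-negative integer, padded to at least `minlength`.
pub fn bincount(x: &[i64], minlength: usize) -> Option<Vec<i32>> {
    if x.iter().any(|&v| v < 0) {
        return None; // negative values have no bin
    }
    // Output length: max(max(x) + 1, minlength); an empty input gives minlength.
    let len = x.iter().map(|&v| v as usize + 1).max().unwrap_or(0).max(minlength);
    let mut counts = vec![0i32; len];
    for &v in x {
        counts[v as usize] += 1; // atomic add in a GPU kernel
    }
    Some(counts)
}
```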
Searchsorted = 14
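searchsorted(sorted_seq, values, right): per-query binary search returning an insertion index. With right == false it is the lower bound (first element ≥ the query); with right == true, the upper bound (first element > the query). PyTorch torch.searchsorted. FW only.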
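Both bounds come down to one binary search with different predicates, which `slice::partition_point` expresses directly. The sketch below is a CPU reference for an ascending 1-D sequence; the function name is hypothetical.

```rust
/// Insertion index of each query into an ascending `sorted_seq`.
/// right == false: first i with sorted_seq[i] >= q (lower bound).
/// right == true:  first i with sorted_seq[i] >  q (upper bound).
pub fn searchsorted(sorted_seq: &[f64], values: &[f64], right: bool) -> Vec<usize> {
    values
        .iter()
        .map(|&q| {
            if right {
                sorted_seq.partition_point(|&s| s <= q)
            } else {
                sorted_seq.partition_point(|&s| s < q)
            }
        })
        .collect()
}
```

For `sorted_seq = [1, 3, 5, 7, 9]` and queries `[3, 6, 9]`, the lower bounds are `[1, 3, 4]` and the upper bounds are `[2, 3, 5]`. The two differ only when a query equals an element of the sequence.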