#[non_exhaustive]
#[repr(u16)]
pub enum SegmentKind {
    SegmentSum = 0,
    SegmentSumBackward = 1,
    SegmentMean = 2,
    SegmentMeanBackward = 3,
    SegmentMax = 4,
    SegmentMin = 5,
    SegmentProd = 6,
    UnsortedSegmentSum = 7,
    UnsortedSegmentSumBackward = 8,
    UnsortedSegmentMean = 9,
    UnsortedSegmentMeanBackward = 10,
    UnsortedSegmentMax = 11,
    UnsortedSegmentMin = 12,
    SegmentMaxBackward = 13,
    SegmentMinBackward = 14,
    SegmentProdBackward = 15,
    UnsortedSegmentMaxBackward = 16,
    UnsortedSegmentMinBackward = 17,
    UnsortedSegmentProd = 18,
    UnsortedSegmentProdBackward = 19,
}
Segment / scatter-reduce op discriminant — Category S from the comprehensive plan.
Stored as u16 in crate::KernelSku::op when
category == OpCategory::SegmentOps. Each variant maps to a
distinct kernel symbol — sorted and unsorted families live in the
same enum (different op slots) because the kernel implementation
differs (sorted = binary-search single-pass sweep; unsorted = atomic
scatter from the input side).
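The u16 storage described above can be illustrated with a small round-trip sketch. The enum below is a local, truncated copy for illustration only, and `to_op` / `from_op` are hypothetical helper names, not part of the crate's API.

```rust
/// Local, truncated stand-in for the real enum (illustration only).
#[repr(u16)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SegmentKind {
    SegmentSum = 0,
    SegmentSumBackward = 1,
    UnsortedSegmentSum = 7,
    UnsortedSegmentSumBackward = 8,
}

impl SegmentKind {
    /// Encode as the raw u16 that would be stored in an op slot.
    pub fn to_op(self) -> u16 {
        self as u16
    }

    /// Hypothetical decoder: returns None for values this sketch does not know.
    pub fn from_op(op: u16) -> Option<Self> {
        match op {
            0 => Some(Self::SegmentSum),
            1 => Some(Self::SegmentSumBackward),
            7 => Some(Self::UnsortedSegmentSum),
            8 => Some(Self::UnsortedSegmentSumBackward),
            _ => None,
        }
    }
}
```

Because the sorted and unsorted families occupy different discriminants, a decoder like this can select a distinct kernel symbol from the raw value alone.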
Phase 7 Milestone 7.6 wires:
- Sorted: Self::SegmentSum, Self::SegmentMean, Self::SegmentMax, Self::SegmentMin, Self::SegmentProd (FW). Sum / Mean carry a BW variant (Self::SegmentSumBackward, Self::SegmentMeanBackward).
- Unsorted: Self::UnsortedSegmentSum, Self::UnsortedSegmentMean, Self::UnsortedSegmentMax, Self::UnsortedSegmentMin (FW). Sum / Mean carry a BW variant (Self::UnsortedSegmentSumBackward, Self::UnsortedSegmentMeanBackward).
Phase 25 closes the remaining BW gaps: Max / Min BW (sorted +
unsorted) recompute the argmax in the BW kernel (preserves FW API
source-compat — no paired-index tensor in the FW signature). Prod
BW (sorted + unsorted) computes d_output * prod / x with direct
division — caller must avoid zero-valued inputs in the segment or
accept NaN/Inf in the gradient. Unsorted Prod FW uses an
atomicCAS retry loop (no native FP atomicMul).
Dtype coverage: f32, f64 (atomic-supported FP types). f16 / bf16
deferred — the kernels use atomicAdd / atomicMax / atomicMin
which are restricted to native-FP-atomic types.
Variants (Non-exhaustive)
This enum is marked as non-exhaustive
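In practice, non-exhaustive means that code in other crates must include a wildcard arm when matching on this enum, since variants may be added later. The sketch below uses a local, truncated copy of the enum, and `is_known_backward` is a hypothetical helper used only for illustration.

```rust
/// Local, truncated stand-in for the real enum (illustration only).
#[non_exhaustive]
#[repr(u16)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SegmentKind {
    SegmentSum = 0,
    SegmentSumBackward = 1,
    SegmentMean = 2,
    SegmentMeanBackward = 3,
}

/// Hypothetical helper: true only for the gradient variants listed here.
pub fn is_known_backward(kind: SegmentKind) -> bool {
    match kind {
        SegmentKind::SegmentSumBackward | SegmentKind::SegmentMeanBackward => true,
        // The wildcard arm is mandatory in downstream crates; it also
        // absorbs any variant added in a later release.
        _ => false,
    }
}
```

Inside the defining crate the compiler does not enforce the wildcard, so the attribute only constrains downstream users.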
SegmentSum = 0
out[s, d] = Σ_{n : seg[n] == s} input[n, d] — sorted segment
IDs (monotonically non-decreasing). TF / JAX segment_sum.
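As a reading aid for the formula above, here is a minimal CPU reference sketch. It is not the kernel (which the description says uses a binary-search single-pass sweep); the function name `segment_sum`, the row-major [n, d] layout, and the explicit `num_segments` argument are assumptions made for illustration.

```rust
/// CPU reference for a sorted segment sum over a row-major [n, d] input.
/// `seg` holds one segment id per row and must be non-decreasing.
pub fn segment_sum(input: &[f32], seg: &[usize], d: usize, num_segments: usize) -> Vec<f32> {
    assert_eq!(input.len(), seg.len() * d);
    debug_assert!(seg.windows(2).all(|w| w[0] <= w[1]), "seg ids must be sorted");

    let mut out = vec![0.0f32; num_segments * d];
    for (n, &s) in seg.iter().enumerate() {
        for j in 0..d {
            // out[s, j] += input[n, j]
            out[s * d + j] += input[n * d + j];
        }
    }
    out
}
```

For example, rows `[1, 2]`, `[3, 4]`, `[5, 6]` with segment ids `[0, 0, 1]` reduce to `[4, 6]` for segment 0 and `[5, 6]` for segment 1.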
SegmentSumBackward = 1
Gradient of Self::SegmentSum:
d_input[n, d] = d_output[seg[n], d] (gather along seg ids).
SegmentMean = 2
out[s, d] = mean_{n : seg[n] == s} input[n, d] — sorted.
SegmentMeanBackward = 3
Gradient of Self::SegmentMean:
d_input[n, d] = d_output[seg[n], d] / count[seg[n]].
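The mean gradient above can be sketched on the CPU as a gather followed by a division by the segment size. The function name, the row-major layout, and recomputing `count` from `seg` inside the function are assumptions of this sketch, not a description of the kernel.

```rust
/// CPU reference for the segment-mean gradient on a row-major [n, d] layout.
pub fn segment_mean_backward(
    d_output: &[f32],
    seg: &[usize],
    d: usize,
    num_segments: usize,
) -> Vec<f32> {
    assert_eq!(d_output.len(), num_segments * d);

    // count[s] = number of rows assigned to segment s.
    let mut count = vec![0usize; num_segments];
    for &s in seg {
        count[s] += 1;
    }

    let mut d_input = vec![0.0f32; seg.len() * d];
    for (n, &s) in seg.iter().enumerate() {
        for j in 0..d {
            // d_input[n, j] = d_output[seg[n], j] / count[seg[n]]
            d_input[n * d + j] = d_output[s * d + j] / count[s] as f32;
        }
    }
    d_input
}
```

Dropping the division gives the sum gradient, which is a plain gather along the segment ids.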
SegmentMax = 4
out[s, d] = max_{n : seg[n] == s} input[n, d] — sorted.
SegmentMin = 5
out[s, d] = min_{n : seg[n] == s} input[n, d] — sorted.
SegmentProd = 6
out[s, d] = prod_{n : seg[n] == s} input[n, d] — sorted.
UnsortedSegmentSum = 7
out[s, d] = Σ_{n : seg[n] == s} input[n, d] — unsorted
(seg IDs in any order). TF unsorted_segment_sum.
UnsortedSegmentSumBackward = 8
Gradient of Self::UnsortedSegmentSum:
d_input[n, d] = d_output[seg[n], d].
UnsortedSegmentMean = 9
out[s, d] = mean_{n : seg[n] == s} input[n, d] — unsorted.
UnsortedSegmentMeanBackward = 10
Gradient of Self::UnsortedSegmentMean:
d_input[n, d] = d_output[seg[n], d] / count[seg[n]].
UnsortedSegmentMax = 11
out[s, d] = max_{n : seg[n] == s} input[n, d] — unsorted.
UnsortedSegmentMin = 12
out[s, d] = min_{n : seg[n] == s} input[n, d] — unsorted.
SegmentMaxBackward = 13
Phase 25. Gradient of Self::SegmentMax:
d_input[k, d] = d_output[seg[k], d] for the lowest-index k in the
segment where input[k, d] == max. Argmax recomputed in BW kernel
(re-scans the segment) so the FW signature stays unchanged.
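The recompute-argmax pattern can be sketched on the CPU as follows. This is an illustrative reference, not the kernel: the function name and row-major layout are assumptions, and the sketch additionally assumes that every non-argmax position receives a zero gradient.

```rust
/// CPU reference for the sorted segment-max gradient.
/// Re-scans each segment to find the lowest-index argmax per column.
pub fn segment_max_backward(
    input: &[f32],
    d_output: &[f32],
    seg: &[usize],
    d: usize,
    num_segments: usize,
) -> Vec<f32> {
    let n_rows = seg.len();
    assert_eq!(input.len(), n_rows * d);
    assert_eq!(d_output.len(), num_segments * d);

    // Assumption of this sketch: non-argmax positions get zero gradient.
    let mut d_input = vec![0.0f32; n_rows * d];
    let mut start = 0;
    while start < n_rows {
        // Sorted ids make each segment a contiguous run [start, end).
        let s = seg[start];
        let mut end = start;
        while end < n_rows && seg[end] == s {
            end += 1;
        }
        for j in 0..d {
            let mut best = start;
            for k in (start + 1)..end {
                // Strict `>` keeps the lowest index on ties.
                if input[k * d + j] > input[best * d + j] {
                    best = k;
                }
            }
            d_input[best * d + j] = d_output[s * d + j];
        }
        start = end;
    }
    d_input
}
```

Because the argmax is found from `input` alone, the forward pass does not need to return an index tensor.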
SegmentMinBackward = 14
Phase 25. Gradient of Self::SegmentMin — mirror of
Self::SegmentMaxBackward.
SegmentProdBackward = 15
Phase 25. Gradient of Self::SegmentProd:
d_input[k, d] = d_output[seg[k], d] * (prod[seg[k], d] / input[k, d]).
Direct division: the caller must avoid zero-valued inputs in the
segment or accept NaN / Inf in the gradient.
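A small CPU sketch makes the zero-input caveat concrete. The function name, argument order, and row-major layout are assumptions for illustration; `prod` is taken to be the forward output.

```rust
/// CPU reference for the direct-division product gradient.
/// `prod` is the forward result with one row per segment.
pub fn segment_prod_backward(
    input: &[f32],
    prod: &[f32],
    d_output: &[f32],
    seg: &[usize],
    d: usize,
) -> Vec<f32> {
    assert_eq!(input.len(), seg.len() * d);
    assert_eq!(prod.len(), d_output.len());

    let mut d_input = vec![0.0f32; input.len()];
    for (k, &s) in seg.iter().enumerate() {
        for j in 0..d {
            // No guard against input == 0: a zero input makes prod zero too,
            // so the quotient is 0 / 0 and the gradient becomes NaN.
            d_input[k * d + j] = d_output[s * d + j] * (prod[s * d + j] / input[k * d + j]);
        }
    }
    d_input
}
```

With one segment holding rows `[2, 0]` and `[4, 3]`, the first column yields finite gradients, while the zero in the second column turns its own gradient into NaN.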
UnsortedSegmentMaxBackward = 16
Phase 25. Gradient of Self::UnsortedSegmentMax — same
recompute-argmax pattern as the sorted variant but scans the
full input array per (seg, d) cell. Non-deterministic w.r.t.
tie-breaking when the FW was non-deterministic.
UnsortedSegmentMinBackward = 17
Phase 25. Gradient of Self::UnsortedSegmentMin — mirror of
Self::UnsortedSegmentMaxBackward.
UnsortedSegmentProd = 18
Phase 25. out[s, d] = prod_{n : seg[n] == s} input[n, d] —
unsorted. Uses an atomicCAS retry loop because no native FP
atomicMul exists. Non-deterministic.
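The compare-and-swap retry pattern can be illustrated on the CPU with Rust atomics. This is an analogy to the GPU kernel, not its source: `atomic_mul_f32` and `unsorted_segment_prod` are hypothetical names, one thread per input row is used only to exercise contention, and initialising each output cell to 1.0 (so empty segments yield 1.0) is an assumption of the sketch.

```rust
use std::sync::atomic::{AtomicU32, Ordering};

/// Multiply the f32 stored (as raw bits) in `cell` by `factor`,
/// retrying until the compare-and-swap succeeds. Returns the old value.
pub fn atomic_mul_f32(cell: &AtomicU32, factor: f32) -> f32 {
    let mut old_bits = cell.load(Ordering::Relaxed);
    loop {
        let new_bits = (f32::from_bits(old_bits) * factor).to_bits();
        match cell.compare_exchange_weak(old_bits, new_bits, Ordering::AcqRel, Ordering::Relaxed) {
            Ok(_) => return f32::from_bits(old_bits),
            // Another thread won the race (or the CAS failed spuriously):
            // retry against the value it observed.
            Err(current) => old_bits = current,
        }
    }
}

/// Scatter-multiply every input row into its segment's output row.
pub fn unsorted_segment_prod(
    input: &[f32],
    seg: &[usize],
    d: usize,
    num_segments: usize,
) -> Vec<f32> {
    assert_eq!(input.len(), seg.len() * d);
    // Assumption: cells start at the multiplicative identity.
    let out: Vec<AtomicU32> = (0..num_segments * d)
        .map(|_| AtomicU32::new(1.0f32.to_bits()))
        .collect();

    std::thread::scope(|scope| {
        for (n, &s) in seg.iter().enumerate() {
            let out = &out;
            scope.spawn(move || {
                for j in 0..d {
                    atomic_mul_f32(&out[s * d + j], input[n * d + j]);
                }
            });
        }
    });

    out.iter()
        .map(|c| f32::from_bits(c.load(Ordering::Relaxed)))
        .collect()
}
```

The order in which threads win the CAS is not fixed, and floating-point multiplication is not associative, so results with three or more factors per cell can differ in the last bits between runs.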
UnsortedSegmentProdBackward = 19
Phase 25. Gradient of Self::UnsortedSegmentProd — same
direct-division pattern as Self::SegmentProdBackward.
Trait Implementations
impl Clone for SegmentKind
fn clone(&self) -> SegmentKind
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
impl Copy for SegmentKind
impl Debug for SegmentKind
impl Eq for SegmentKind
impl Hash for SegmentKind
impl PartialEq for SegmentKind
fn eq(&self, other: &SegmentKind) -> bool
Tests for self and other values to be equal, and is used by ==.