
Enum SegmentKind 

#[non_exhaustive]
#[repr(u16)]
pub enum SegmentKind {
    SegmentSum = 0,
    SegmentSumBackward = 1,
    SegmentMean = 2,
    SegmentMeanBackward = 3,
    SegmentMax = 4,
    SegmentMin = 5,
    SegmentProd = 6,
    UnsortedSegmentSum = 7,
    UnsortedSegmentSumBackward = 8,
    UnsortedSegmentMean = 9,
    UnsortedSegmentMeanBackward = 10,
    UnsortedSegmentMax = 11,
    UnsortedSegmentMin = 12,
    SegmentMaxBackward = 13,
    SegmentMinBackward = 14,
    SegmentProdBackward = 15,
    UnsortedSegmentMaxBackward = 16,
    UnsortedSegmentMinBackward = 17,
    UnsortedSegmentProd = 18,
    UnsortedSegmentProdBackward = 19,
}

Discriminant for segment / scatter-reduce ops (Category S in the comprehensive plan).

Stored as a u16 in crate::KernelSku::op when category == OpCategory::SegmentOps. Each variant maps to a distinct kernel symbol. The sorted and unsorted families share this enum but occupy different op slots, because their kernel implementations differ: sorted variants use a binary-search, single-pass sweep; unsorted variants use an atomic scatter from the input side.
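Because the discriminant travels as a raw u16, a consumer of a KernelSku needs some way back to the enum. The sketch below mirrors the documented discriminants locally and adds two hypothetical helpers, `from_op` (decode the u16) and `is_unsorted` (classify the family). Neither helper is confirmed to exist in the crate; they only illustrate the layout described above.

```rust
use SegmentKind as K;

/// Local mirror of the documented enum, with the same discriminants (illustration only).
#[non_exhaustive]
#[repr(u16)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SegmentKind {
    SegmentSum = 0,
    SegmentSumBackward = 1,
    SegmentMean = 2,
    SegmentMeanBackward = 3,
    SegmentMax = 4,
    SegmentMin = 5,
    SegmentProd = 6,
    UnsortedSegmentSum = 7,
    UnsortedSegmentSumBackward = 8,
    UnsortedSegmentMean = 9,
    UnsortedSegmentMeanBackward = 10,
    UnsortedSegmentMax = 11,
    UnsortedSegmentMin = 12,
    SegmentMaxBackward = 13,
    SegmentMinBackward = 14,
    SegmentProdBackward = 15,
    UnsortedSegmentMaxBackward = 16,
    UnsortedSegmentMinBackward = 17,
    UnsortedSegmentProd = 18,
    UnsortedSegmentProdBackward = 19,
}

/// Every variant, listed in discriminant order so that `ALL[i] as u16 == i`.
pub const ALL: [SegmentKind; 20] = [
    K::SegmentSum, K::SegmentSumBackward, K::SegmentMean, K::SegmentMeanBackward,
    K::SegmentMax, K::SegmentMin, K::SegmentProd, K::UnsortedSegmentSum,
    K::UnsortedSegmentSumBackward, K::UnsortedSegmentMean,
    K::UnsortedSegmentMeanBackward, K::UnsortedSegmentMax, K::UnsortedSegmentMin,
    K::SegmentMaxBackward, K::SegmentMinBackward, K::SegmentProdBackward,
    K::UnsortedSegmentMaxBackward, K::UnsortedSegmentMinBackward,
    K::UnsortedSegmentProd, K::UnsortedSegmentProdBackward,
];

/// Hypothetical decoder for the u16 stored in `KernelSku::op`.
/// Returns `None` for values outside the known discriminant range.
pub fn from_op(op: u16) -> Option<SegmentKind> {
    ALL.get(op as usize).copied()
}

/// Hypothetical classifier: true for the unsorted (atomic-scatter) family.
pub fn is_unsorted(kind: SegmentKind) -> bool {
    matches!(
        kind,
        K::UnsortedSegmentSum
            | K::UnsortedSegmentSumBackward
            | K::UnsortedSegmentMean
            | K::UnsortedSegmentMeanBackward
            | K::UnsortedSegmentMax
            | K::UnsortedSegmentMin
            | K::UnsortedSegmentMaxBackward
            | K::UnsortedSegmentMinBackward
            | K::UnsortedSegmentProd
            | K::UnsortedSegmentProdBackward
    )
}
```

Indexing a discriminant-ordered table keeps the decoder short; the test below checks that the table order actually matches the discriminants.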

Phase 7 Milestone 7.6 wires:

Phase 25 closes the remaining backward (BW) gaps. Max / Min BW (sorted and unsorted) recompute the argmax inside the BW kernel, which preserves source compatibility of the forward (FW) API: the FW signature needs no paired index tensor. Prod BW (sorted and unsorted) computes d_output * prod / x by direct division, so the caller must either avoid zero-valued inputs in a segment or accept NaN / Inf in the gradient. Unsorted Prod FW uses an atomicCAS retry loop, because there is no native floating-point atomicMul.

Dtype coverage: f32 and f64 (the FP types with atomic support). f16 / bf16 are deferred, because the kernels use atomicAdd / atomicMax / atomicMin, which are restricted to types with native FP atomics.

Variants (Non-exhaustive)

This enum is marked as non-exhaustive.
Non-exhaustive enums may gain additional variants in the future. Therefore, code outside the defining crate that matches on this enum must include a wildcard arm to account for any future variants.

SegmentSum = 0

out[s, d] = Σ_{n : seg[n] == s} input[n, d], with sorted segment IDs (monotonically non-decreasing). Corresponds to TF / JAX segment_sum.
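A CPU reference sketch of the sorted sum, assuming a row-major [n, d] f64 buffer, usize segment ids, and zero output for segments with no rows (all assumptions; `segment_sum_sorted` is a hypothetical name, not the crate's API). Because the ids are sorted, `partition_point` can locate each segment's contiguous row range by binary search, which mirrors the binary-search sweep the sorted kernels are described as using.

```rust
/// Sequential CPU stand-in for SegmentSum (not the GPU kernel).
/// `seg` must be sorted (non-decreasing) with every id < `num_segments`.
pub fn segment_sum_sorted(input: &[f64], seg: &[usize], d: usize, num_segments: usize) -> Vec<f64> {
    assert_eq!(input.len(), seg.len() * d, "input must be [n, d] row-major");
    debug_assert!(seg.windows(2).all(|w| w[0] <= w[1]), "segment ids must be sorted");
    // Empty segments stay at 0.0 (assumption of this sketch).
    let mut out = vec![0.0; num_segments * d];
    for s in 0..num_segments {
        // Binary search for the half-open row range [start, end) holding id `s`.
        let start = seg.partition_point(|&id| id < s);
        let end = seg.partition_point(|&id| id <= s);
        for n in start..end {
            for j in 0..d {
                out[s * d + j] += input[n * d + j];
            }
        }
    }
    out
}
```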


SegmentSumBackward = 1

Gradient of Self::SegmentSum: d_input[n, d] = d_output[seg[n], d] (gather along seg ids).
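The backward pass is a plain gather: each input row copies the output-gradient row of its segment. A sketch under the same assumed layout (row-major f64, usize ids; `segment_sum_backward` is a hypothetical name). The formula does not depend on sortedness, so the same gather also matches the UnsortedSegmentSumBackward formula.

```rust
/// d_input[n, :] = d_output[seg[n], :], for a row-major [num_segments, d] `d_output`.
pub fn segment_sum_backward(d_output: &[f64], seg: &[usize], d: usize) -> Vec<f64> {
    let mut d_input = Vec::with_capacity(seg.len() * d);
    for &s in seg {
        // Copy the whole gradient row of segment `s` for this input row.
        d_input.extend_from_slice(&d_output[s * d..(s + 1) * d]);
    }
    d_input
}
```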


SegmentMean = 2

out[s, d] = mean_{n : seg[n] == s} input[n, d] — sorted.


SegmentMeanBackward = 3

Gradient of Self::SegmentMean: d_input[n, d] = d_output[seg[n], d] / count[seg[n]].
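SegmentMean and its gradient both depend on the per-segment row count. The sketch below computes the counts once and uses them for both passes, assuming row-major f64 buffers, usize ids, and a zero mean for empty segments (assumptions; the function names are hypothetical). It does not rely on sortedness, so the same arithmetic also matches the UnsortedSegmentMean formulas.

```rust
/// Number of input rows that fall in each segment.
fn segment_counts(seg: &[usize], num_segments: usize) -> Vec<usize> {
    let mut c = vec![0usize; num_segments];
    for &s in seg {
        c[s] += 1;
    }
    c
}

/// out[s, :] = mean of input rows with seg[n] == s (0.0 for empty segments, an assumption).
pub fn segment_mean(input: &[f64], seg: &[usize], d: usize, num_segments: usize) -> Vec<f64> {
    let mut out = vec![0.0; num_segments * d];
    for (n, &s) in seg.iter().enumerate() {
        for j in 0..d {
            out[s * d + j] += input[n * d + j];
        }
    }
    let c = segment_counts(seg, num_segments);
    for s in 0..num_segments {
        if c[s] > 0 {
            for j in 0..d {
                out[s * d + j] /= c[s] as f64;
            }
        }
    }
    out
}

/// d_input[n, :] = d_output[seg[n], :] / count[seg[n]].
pub fn segment_mean_backward(d_output: &[f64], seg: &[usize], d: usize, num_segments: usize) -> Vec<f64> {
    let c = segment_counts(seg, num_segments);
    let mut d_input = vec![0.0; seg.len() * d];
    for (n, &s) in seg.iter().enumerate() {
        for j in 0..d {
            // Every row in segment `s` has count[s] >= 1, so this never divides by zero.
            d_input[n * d + j] = d_output[s * d + j] / c[s] as f64;
        }
    }
    d_input
}
```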


SegmentMax = 4

out[s, d] = max_{n : seg[n] == s} input[n, d] — sorted.


SegmentMin = 5

out[s, d] = min_{n : seg[n] == s} input[n, d] — sorted.


SegmentProd = 6

out[s, d] = prod_{n : seg[n] == s} input[n, d] — sorted.
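SegmentMax, SegmentMin, and SegmentProd differ only in the combining function and its identity. The sketch below is one generic sorted sweep that walks each contiguous run of equal ids once, assuming row-major f64 buffers and usize ids. The values left in empty segments (-inf for max, +inf for min, 1.0 for prod) are assumptions of this sketch, not documented kernel behavior, and `segment_reduce_sorted` is a hypothetical name.

```rust
/// Generic single-pass reduction over sorted segment ids.
/// `identity` fills empty segments and seeds each run; `combine` folds one value in.
pub fn segment_reduce_sorted(
    input: &[f64],
    seg: &[usize],
    d: usize,
    num_segments: usize,
    identity: f64,
    combine: impl Fn(f64, f64) -> f64,
) -> Vec<f64> {
    assert_eq!(input.len(), seg.len() * d, "input must be [n, d] row-major");
    debug_assert!(seg.windows(2).all(|w| w[0] <= w[1]), "segment ids must be sorted");
    let mut out = vec![identity; num_segments * d];
    let mut start = 0;
    while start < seg.len() {
        // Sorted ids mean each segment occupies one contiguous run [start, end).
        let s = seg[start];
        let mut end = start + 1;
        while end < seg.len() && seg[end] == s {
            end += 1;
        }
        for j in 0..d {
            let mut acc = identity;
            for n in start..end {
                acc = combine(acc, input[n * d + j]);
            }
            out[s * d + j] = acc;
        }
        start = end;
    }
    out
}
```

Usage: `segment_reduce_sorted(&x, &seg, d, s, f64::NEG_INFINITY, f64::max)` for max, `f64::INFINITY` with `f64::min` for min, and `1.0` with `|a, b| a * b` for prod. Note that `f64::max` / `f64::min` ignore NaN operands, which a GPU kernel may or may not match.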


UnsortedSegmentSum = 7

out[s, d] = Σ_{n : seg[n] == s} input[n, d], unsorted (seg IDs in any order). Corresponds to TF unsorted_segment_sum.
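The unsorted family works from the input side: each input row scatters into its segment's output row, so ids can arrive in any order. A sequential CPU sketch, assuming row-major f64 buffers and usize ids, and panicking on an out-of-range id (that policy, and the name `unsorted_segment_sum`, are assumptions). On the GPU, threads for different rows can hit the same output cell, which is why the kernel needs an atomicAdd where this sketch uses a plain `+=`.

```rust
/// Scatter-add: out[seg[n], :] += input[n, :] for every row n, in input order.
pub fn unsorted_segment_sum(input: &[f64], seg: &[usize], d: usize, num_segments: usize) -> Vec<f64> {
    assert_eq!(input.len(), seg.len() * d, "input must be [n, d] row-major");
    let mut out = vec![0.0; num_segments * d];
    for (n, &s) in seg.iter().enumerate() {
        assert!(s < num_segments, "segment id out of range");
        for j in 0..d {
            // A parallel kernel would make this an atomic add on out[s, j].
            out[s * d + j] += input[n * d + j];
        }
    }
    out
}
```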


UnsortedSegmentSumBackward = 8

Gradient of Self::UnsortedSegmentSum: d_input[n, d] = d_output[seg[n], d].


UnsortedSegmentMean = 9

out[s, d] = mean_{n : seg[n] == s} input[n, d] — unsorted.


UnsortedSegmentMeanBackward = 10

Gradient of Self::UnsortedSegmentMean: d_input[n, d] = d_output[seg[n], d] / count[seg[n]].


UnsortedSegmentMax = 11

out[s, d] = max_{n : seg[n] == s} input[n, d] — unsorted.


UnsortedSegmentMin = 12

out[s, d] = min_{n : seg[n] == s} input[n, d] — unsorted.


SegmentMaxBackward = 13

Phase 25. Gradient of Self::SegmentMax: d_input[k, d] = d_output[seg[k], d] for the lowest-index k in the segment where input[k, d] == max. The argmax is recomputed in the BW kernel by re-scanning the segment, so the FW signature stays unchanged.
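A CPU sketch of the recompute-argmax pattern for sorted ids: the backward pass re-scans each segment column to find the winning row, using strict `>` so the lowest index wins ties, and routes the whole output gradient there. Row-major f64 buffers, usize ids, and zero gradient for every non-winning row are assumptions; `segment_max_backward` is a hypothetical name. Flipping `>` to `<` gives the SegmentMinBackward mirror.

```rust
/// Gradient of a sorted segment max, recomputing the argmax from `input`.
pub fn segment_max_backward(input: &[f64], seg: &[usize], d_output: &[f64], d: usize) -> Vec<f64> {
    let rows = seg.len();
    // Non-winning rows receive 0.0 (assumption of this sketch).
    let mut d_input = vec![0.0; rows * d];
    let mut start = 0;
    while start < rows {
        // Find the contiguous run [start, end) of rows in segment `s`.
        let s = seg[start];
        let mut end = start + 1;
        while end < rows && seg[end] == s {
            end += 1;
        }
        for j in 0..d {
            // Strict `>` keeps the lowest-index row on ties.
            let mut best = start;
            for n in start + 1..end {
                if input[n * d + j] > input[best * d + j] {
                    best = n;
                }
            }
            d_input[best * d + j] = d_output[s * d + j];
        }
        start = end;
    }
    d_input
}
```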


SegmentMinBackward = 14

Phase 25. Gradient of Self::SegmentMin — mirror of Self::SegmentMaxBackward.


SegmentProdBackward = 15

Phase 25. Gradient of Self::SegmentProd: d_input[k, d] = d_output[seg[k], d] * (prod[seg[k], d] / x[k, d]), where x is the FW input. This is a direct division, so the caller must avoid zero-valued inputs in the segment or accept NaN / Inf in the gradient.
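The direct-division formula and its zero hazard are easy to see on the CPU. In this sketch (row-major f64 buffers and usize ids assumed; `segment_prod_backward` is a hypothetical name), a zero input makes the segment product zero, so that row's gradient becomes 0 / 0 = NaN even though the true derivative, the product of the other rows, may be finite.

```rust
/// d_input[k, :] = d_output[seg[k], :] * (prod[seg[k], :] / input[k, :]).
/// `prod` is the forward result, row-major [num_segments, d].
pub fn segment_prod_backward(input: &[f64], seg: &[usize], prod: &[f64], d_output: &[f64], d: usize) -> Vec<f64> {
    let mut d_input = vec![0.0; seg.len() * d];
    for (k, &s) in seg.iter().enumerate() {
        for j in 0..d {
            // No zero guard: input[k] == 0.0 gives prod == 0.0 and hence 0.0 / 0.0 = NaN.
            d_input[k * d + j] = d_output[s * d + j] * (prod[s * d + j] / input[k * d + j]);
        }
    }
    d_input
}
```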


UnsortedSegmentMaxBackward = 16

Phase 25. Gradient of Self::UnsortedSegmentMax. Uses the same recompute-argmax pattern as the sorted variant, but scans the full input array for each (seg, d) cell. Tie-breaking is non-deterministic whenever the FW pass was non-deterministic.


UnsortedSegmentMinBackward = 17

Phase 25. Gradient of Self::UnsortedSegmentMin — mirror of Self::UnsortedSegmentMaxBackward.


UnsortedSegmentProd = 18

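Phase 25. out[s, d] = prod_{n : seg[n] == s} input[n, d], unsorted. Uses an atomicCAS retry loop because there is no native FP atomicMul. Non-deterministic: the order of the multiplications, and therefore the floating-point rounding, depends on thread scheduling.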
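The compare-and-swap retry loop can be reproduced on the CPU with std atomics: store each f64 output cell as raw bits in an AtomicU64, read it, compute the product, and retry `compare_exchange_weak` until no other thread has changed the cell in between. This is a sketch of the technique, not the CUDA kernel; one scoped thread per input row, empty segments left at 1.0, and the names `atomic_mul_f64` / `unsorted_segment_prod` are assumptions.

```rust
use std::sync::atomic::{AtomicU64, Ordering};

/// Atomically performs `*cell *= x` on an f64 stored as raw bits; returns the old value.
pub fn atomic_mul_f64(cell: &AtomicU64, x: f64) -> f64 {
    let mut current = cell.load(Ordering::Relaxed);
    loop {
        let new = (f64::from_bits(current) * x).to_bits();
        match cell.compare_exchange_weak(current, new, Ordering::AcqRel, Ordering::Relaxed) {
            Ok(prev) => return f64::from_bits(prev),
            // Another thread won the race (or a spurious failure): retry with the fresh value.
            Err(actual) => current = actual,
        }
    }
}

/// Unsorted segment product with one scoped thread per input row.
pub fn unsorted_segment_prod(input: &[f64], seg: &[usize], d: usize, num_segments: usize) -> Vec<f64> {
    // Every cell starts at 1.0, the multiplicative identity.
    let out: Vec<AtomicU64> = (0..num_segments * d)
        .map(|_| AtomicU64::new(1.0f64.to_bits()))
        .collect();
    std::thread::scope(|scope| {
        for (n, &s) in seg.iter().enumerate() {
            let out = &out;
            scope.spawn(move || {
                for j in 0..d {
                    atomic_mul_f64(&out[s * d + j], input[n * d + j]);
                }
            });
        }
    });
    out.into_iter().map(|a| f64::from_bits(a.into_inner())).collect()
}
```

With small integer inputs the products are exact, so the result does not depend on thread order; with general floating-point inputs the rounding can differ from run to run, matching the non-determinism noted above.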


UnsortedSegmentProdBackward = 19

Phase 25. Gradient of Self::UnsortedSegmentProd — same direct-division pattern as Self::SegmentProdBackward.

Trait Implementations

impl Clone for SegmentKind

fn clone(&self) -> SegmentKind

Returns a duplicate of the value.

1.0.0 (const: unstable)
fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl Copy for SegmentKind

impl Debug for SegmentKind

fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error>

Formats the value using the given formatter.

impl Eq for SegmentKind

impl Hash for SegmentKind

fn hash<__H>(&self, state: &mut __H)
where __H: Hasher,

Feeds this value into the given Hasher.

1.3.0
fn hash_slice<H>(data: &[Self], state: &mut H)
where H: Hasher, Self: Sized,

Feeds a slice of this type into the given Hasher.

impl PartialEq for SegmentKind

fn eq(&self, other: &SegmentKind) -> bool

Tests for self and other values to be equal, and is used by ==.

1.0.0 (const: unstable)
fn ne(&self, other: &Rhs) -> bool

Tests for !=. The default implementation is almost always sufficient, and should not be overridden without very good reason.

impl StructuralPartialEq for SegmentKind

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

This is a nightly-only experimental API (clone_to_uninit).
Performs copy-assignment from self to dest.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.