pub struct SegmentSumPlan<T: Element> { /* private fields */ }
Plan for sorted segment_sum.
Computes out[s, d] = Σ_{n : segment_ids[n] == s} input[n, d], matching TF / JAX segment_sum. Requires segment_ids to be monotonically non-decreasing.
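The formula above can be read as a plain CPU loop. The following is a minimal reference sketch, not this crate's kernel: `segment_sum_ref` is a hypothetical helper, the `i32` ID type and row-major `[N, D]` layout are assumptions, and every ID is assumed to be in range.

```rust
/// Reference (CPU) sorted segment-sum over a row-major `[N, D]` input.
/// Hypothetical helper for illustration only; not part of the crate's API.
fn segment_sum_ref(
    input: &[f32],
    segment_ids: &[i32],
    d: usize,
    num_segments: usize,
) -> Vec<f32> {
    assert_eq!(input.len(), segment_ids.len() * d, "input must be [N, D]");
    // The sorted variant requires monotonically non-decreasing IDs.
    assert!(
        segment_ids.windows(2).all(|w| w[0] <= w[1]),
        "segment_ids must be non-decreasing"
    );

    let mut out = vec![0.0f32; num_segments * d];
    for (n, &s) in segment_ids.iter().enumerate() {
        // Assumption in this sketch: every ID is already in [0, num_segments).
        let s = s as usize;
        for k in 0..d {
            out[s * d + k] += input[n * d + k];
        }
    }
    out
}
```

With rows `[1, 2]`, `[3, 4]`, `[5, 6]` and IDs `[0, 0, 2]`, segment 0 receives the sum of the first two rows, segment 1 stays zero because no row maps to it, and segment 2 receives the last row.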
When to use: forward pass of a sorted segment-sum. For unsorted IDs, use UnsortedSegmentSumPlan. Pair with SegmentSumBackwardPlan for autograd.
Dtypes: {f32, f64}. This matches the rest of the family: some kernel paths rely on floating-point atomic primitives, even in the sorted variant.
Shape limits: input is [N, D], segment_ids is [N]
with values in [0, num_segments); output is [num_segments, D].
All extents non-negative.
Workspace: none.
Precision guarantee: deterministic and bit-stable. A single thread per output cell sweeps the segment’s row range in order.
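Floating-point addition is not associative, so bit-stability depends on a fixed summation order. The sketch below is a CPU model of the "one worker per output cell" idea, assuming sorted `i32` IDs and a row-major input; `segment_cell_sum` is a hypothetical name and this is not the crate's actual kernel.

```rust
/// Sum one output cell `(s, col)` by sweeping the segment's row range in order.
/// Hypothetical CPU model for illustration only; not the crate's kernel.
fn segment_cell_sum(
    input: &[f64],
    segment_ids: &[i32],
    d: usize,
    s: i32,
    col: usize,
) -> f64 {
    // Because the IDs are sorted, segment `s` occupies one contiguous
    // row range [lo, hi), which two binary searches can locate.
    let lo = segment_ids.partition_point(|&id| id < s);
    let hi = segment_ids.partition_point(|&id| id <= s);

    let mut acc = 0.0f64;
    for n in lo..hi {
        // Fixed left-to-right order: the same inputs always round the same way.
        acc += input[n * d + col];
    }
    acc
}
```

Since each cell is owned by a single sequential sweep, repeated runs over the same input produce the same bits. An empty segment yields an empty range and therefore a sum of zero.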
Index policy: out-of-range segment IDs (< 0 or ≥ num_segments) are silently dropped (TF / JAX semantics). The output buffer is fully overwritten (no accumulation into prior state).
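The index policy can be modeled on the CPU as follows. This is a sketch under stated assumptions rather than the crate's implementation: `segment_sum_into` is a hypothetical helper, IDs are assumed to be `i32`, and buffers are assumed row-major.

```rust
/// CPU model of the index policy: out-of-range IDs are skipped, and `out`
/// is overwritten rather than accumulated into.
/// Hypothetical helper for illustration only; not part of the crate's API.
fn segment_sum_into(
    out: &mut [f32],
    input: &[f32],
    segment_ids: &[i32],
    d: usize,
    num_segments: usize,
) {
    assert_eq!(out.len(), num_segments * d, "out must be [num_segments, D]");
    assert_eq!(input.len(), segment_ids.len() * d, "input must be [N, D]");

    // Discard whatever the buffer held before the call.
    out.fill(0.0);

    for (n, &s) in segment_ids.iter().enumerate() {
        // IDs below 0 or at/above num_segments contribute nothing.
        if s < 0 || (s as usize) >= num_segments {
            continue;
        }
        let s = s as usize;
        for k in 0..d {
            out[s * d + k] += input[n * d + k];
        }
    }
}
```

Calling this with a buffer that already holds stale values and IDs `[-1, 0, 1, 5]` for `num_segments = 2` leaves only the contributions of the two in-range rows in the output.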
Implementations
impl<T: Element> SegmentSumPlan<T>
pub fn select(_stream: &Stream, desc: &SegmentSumDescriptor, _pref: PlanPreference) -> Result<Self>
Pick a kernel for desc.
pub fn can_implement(&self, args: &SegmentSumArgs<'_, T>) -> Result<()>
Validate args.
pub fn workspace_size(&self) -> usize
Workspace size in bytes.
pub fn precision_guarantee(&self) -> PrecisionGuarantee
Numerical guarantees for this plan’s kernel.