pub struct KthvaluePlan<T: Element> { /* private fields */ }
kthvalue plan.
Returns the k-th smallest value and its index along the last axis
(the counterpart of PyTorch's torch.kthvalue, except that k is
0-indexed here and 1-indexed in PyTorch). Composed at the plan layer
as a bottom-(k+1) TopkPlan whose cell k is read as the result.
When to use: order-statistic queries, such as a median or a quantile
lookup, where k stays within the supported range (k < 64). Pair with
KthvalueBackwardPlan.
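The indexing convention and the k < 64 limit can be illustrated with two small helpers. This is a minimal sketch, not part of this crate: the names `from_pytorch_k`, `lower_median_k`, and `MAX_K_EXCLUSIVE` are hypothetical, and the median helper assumes the "lower median" convention that torch.median uses for even-length rows.

```rust
/// Exclusive upper bound on k, taken from the documented limit k < 64.
/// (Hypothetical constant; the crate may expose this differently.)
const MAX_K_EXCLUSIVE: usize = 64;

/// Convert PyTorch's 1-indexed k into the 0-indexed k used by this plan.
/// Returns `None` for k = 0 (invalid in PyTorch) or when the result
/// would exceed the documented limit.
fn from_pytorch_k(k_one_indexed: usize) -> Option<usize> {
    let k = k_one_indexed.checked_sub(1)?;
    (k < MAX_K_EXCLUSIVE).then_some(k)
}

/// 0-indexed k that selects the lower median of a row of `row_len`
/// elements. Returns `None` for an empty row or when k would exceed
/// the documented limit.
fn lower_median_k(row_len: usize) -> Option<usize> {
    if row_len == 0 {
        return None;
    }
    let k = (row_len - 1) / 2;
    (k < MAX_K_EXCLUSIVE).then_some(k)
}
```

Under these assumptions, a median query only fits the k < 64 limit for rows of at most 128 elements, even though row_len itself may be as large as 1024.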
Dtypes: {f32, f64}.
Shape limits: input [batch, row_len]; outputs [batch];
row_len ≤ 1024; k < 64 (composes a bottom-(k+1) topk).
Workspace: zero bytes of caller-provided Workspace; the plan
internally allocates a scratch [batch, k+1] topk-result buffer on each launch.
Precision guarantee: deterministic, bit-stable (inherits topk’s fixed-network guarantee).
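The semantics described above (take the bottom k+1 values of each row, then read cell k) can be checked against a small host-side reference. This is a minimal sketch for testing purposes only: `kthvalue_row` and `kthvalue_reference` are hypothetical names, the code runs on the CPU with a full sort rather than a fixed network, and the tie-breaking rule (lowest original index first) is an assumption that the device kernel may not share.

```rust
/// Returns `(value, index)` of the k-th smallest element of `row`
/// (k is 0-indexed), or `None` if `k` is out of range.
fn kthvalue_row(row: &[f32], k: usize) -> Option<(f32, usize)> {
    if k >= row.len() {
        return None;
    }
    // Pair every value with its original position in the row.
    let mut pairs: Vec<(f32, usize)> =
        row.iter().copied().enumerate().map(|(i, v)| (v, i)).collect();
    // Stable ascending sort by value. ASSUMPTION: ties resolve to the
    // lowest original index; the real kernel may break ties differently.
    pairs.sort_by(|a, b| a.0.total_cmp(&b.0));
    // The bottom-(k+1) cells are the first k+1 entries; cell k is the last.
    let bottom = &pairs[..=k];
    bottom.last().copied()
}

/// Applies `kthvalue_row` to every row of a flattened `[batch, row_len]`
/// input and returns `[batch]` values and `[batch]` indices.
fn kthvalue_reference(
    input: &[f32],
    row_len: usize,
    k: usize,
) -> Option<(Vec<f32>, Vec<usize>)> {
    if row_len == 0 || input.len() % row_len != 0 {
        return None;
    }
    let mut values = Vec::new();
    let mut indices = Vec::new();
    for row in input.chunks_exact(row_len) {
        let (v, i) = kthvalue_row(row, k)?;
        values.push(v);
        indices.push(i);
    }
    Some((values, indices))
}
```

For the two rows `[3, 1, 2]` and `[9, 7, 8]` with k = 1, this reference yields the values `[2, 8]` at indices `[2, 2]`.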
Implementations§
impl<T: Element> KthvaluePlan<T>
pub fn select(
    _stream: &Stream,
    desc: &KthvalueDescriptor,
    _pref: PlanPreference,
) -> Result<Self>
Pick a kernel for desc.
pub fn can_implement(&self, args: &KthvalueArgs<'_, T>) -> Result<()>
Validate args.
pub fn workspace_size(&self) -> usize
Workspace size in bytes. Internal device buffers are allocated fresh at run() time.
pub fn precision_guarantee(&self) -> PrecisionGuarantee
Numerical guarantees for this plan’s kernel.
pub fn run(
    &self,
    stream: &Stream,
    _workspace: Workspace<'_>,
    args: KthvalueArgs<'_, T>,
) -> Result<()>
Launch. Composes a bottom-(k+1) topk and reads its last cell as the
k-th smallest. Allocates two intermediate device buffers and
round-trips the bottom-(k+1) cells through host memory to
extract the k-th slot of each row (the data is small: batch *
(k+1) cells, with k+1 ≤ 64).
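The host-side extraction step can be sketched as follows. This is an illustration under stated assumptions, not the crate's implementation: `extract_kth_slot` is a hypothetical helper, the two buffers are assumed to be the topk values and indices copied back as flattened row-major `[batch, k+1]` slices sorted smallest first, and `u32` is an assumed index type.

```rust
/// Picks slot k out of each row of two flattened `[batch, k+1]` buffers.
/// Returns `[batch]` values and `[batch]` indices, or `None` if the
/// buffer lengths do not match `batch * (k + 1)`.
fn extract_kth_slot<T: Copy>(
    topk_values: &[T],
    topk_indices: &[u32], // ASSUMPTION: index dtype is u32
    batch: usize,
    k: usize,
) -> Option<(Vec<T>, Vec<u32>)> {
    // Each row holds k+1 cells, so k is the last cell of the row.
    let stride = k.checked_add(1)?;
    let expected = batch.checked_mul(stride)?;
    if topk_values.len() != expected || topk_indices.len() != expected {
        return None;
    }
    let mut values = Vec::with_capacity(batch);
    let mut indices = Vec::with_capacity(batch);
    for row in 0..batch {
        let cell = row * stride + k;
        values.push(topk_values[cell]);
        indices.push(topk_indices[cell]);
    }
    Some((values, indices))
}
```

Because each row holds at most 64 cells, the copied data stays small relative to the `[batch, row_len]` input, which is the trade-off the description above relies on.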