Skip to main content

baracuda_kernels_types/
plan.rs

1//! Plan-layer descriptors shared across kernel families: caller
2//! preferences, workspace handles, and the numerical-guarantee record
3//! every plan exposes.
4
5use baracuda_driver::DeviceSliceMut;
6
7use crate::element::{ElementKind, MathPrecision};
8use crate::sku::BackendKind;
9
10/// Caller-supplied workspace for a launch.
11///
12/// Plans never own device memory in baracuda — pass scratch in at
13/// `run` time. Pass [`Workspace::None`] for plans whose
14/// workspace size is zero.
15///
16/// **Intentionally NOT `#[non_exhaustive]`** — the two-variant
17/// `None` / `Borrowed` split is hot-path-matched by every plan's
18/// `run` method, and the API has been stable through 27 alphas. If
19/// a third variant (pool-backed, per-stream-cached) ever lands it
20/// will be a deliberate breaking change with a major-version bump.
21#[derive(Debug)]
22pub enum Workspace<'a> {
23    /// No workspace (only valid when the plan reports zero bytes needed).
24    None,
25    /// Borrowed device scratch. Length must be at least the plan's
26    /// reported workspace size.
27    Borrowed(DeviceSliceMut<'a, u8>),
28}
29
30/// Hints that influence kernel selection inside a plan's `select`
31/// method.
32///
33/// The fields are intentionally generic across kernel families — each
34/// op category may layer its own `*PlanPreference` wrapper on top
35/// (e.g. `GroupedPlanPreference` adds grouped-specific knobs) that
36/// embeds this struct.
37#[derive(Copy, Clone, Debug)]
38pub struct PlanPreference {
39    /// Maximum workspace the caller is willing to provide. The selector
40    /// only considers kernels whose workspace size for the descriptor
41    /// fits in this budget. Use `usize::MAX` to disable the constraint.
42    pub max_workspace_bytes: usize,
43    /// Allow Hopper-specialized (`sm_90a`) kernels in selection. Has no
44    /// effect when the `sm90a` feature is off in the underlying kernel
45    /// crate (no such kernels exist in the build).
46    pub allow_sm90a: bool,
47    /// Force a particular backend at plan-selection time, bypassing the
48    /// plan's built-in heuristic.
49    ///
50    /// `None` (the default) lets the plan's per-op-category heuristic
51    /// decide. `Some(BackendKind::Cublas)` / `Some(BackendKind::Cutlass)`
52    /// override the heuristic when a caller has profiling-driven
53    /// information the heuristic doesn't have (or wants deterministic
54    /// kernel selection for golden-output testing).
55    ///
56    /// Plans surface their actual choice through their `sku()` accessor —
57    /// inspect `sku.backend` to see what the heuristic picked.
58    ///
59    /// Returns `Error::Unsupported` from `select` if the requested
60    /// backend doesn't have a kernel for the requested
61    /// `(layout, epilogue, element)` triple. For example, the cuBLAS
62    /// backend doesn't support `EpilogueKind::BiasRelu` (cuBLAS has no
63    /// fused-bias-activation GEMM); forcing it on a Bias* epilogue
64    /// returns an error rather than silently falling back to CUTLASS.
65    pub prefer_backend: Option<BackendKind>,
66}
67
68impl Default for PlanPreference {
69    fn default() -> Self {
70        Self {
71            max_workspace_bytes: usize::MAX,
72            allow_sm90a: true,
73            prefer_backend: None,
74        }
75    }
76}
77
78/// Numerical guarantees a kernel provides.
79///
80/// Surfaces the salient numerical properties consumers need to decide
81/// whether a kernel SKU satisfies an op's precision contract — without
82/// having to re-derive them from documentation per kernel.
83///
84/// All fields are intentionally cheap to compare so this struct can be
85/// hashed into selection / autotuner caches.
86#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
87pub struct PrecisionGuarantee {
88    /// Bit-precision used inside the math instruction.
89    pub math_precision: MathPrecision,
90    /// Element type of the multiply-accumulate accumulator.
91    pub accumulator: ElementKind,
92    /// Whether the kernel produces bit-identical results across runs on
93    /// the same hardware with the same inputs.
94    ///
95    /// `false` for tensor-core kernels (F16, BF16, TF32) because the
96    /// warp-level reduction order isn't fixed by the spec — adjacent
97    /// runs can differ in the last bit even with the same inputs.
98    /// `true` for SIMT F32 and for integer kernels.
99    pub bit_stable_on_same_hardware: bool,
100    /// Whether the kernel produces bit-identical results across runs
101    /// from a single thread within a process — i.e. it has no internal
102    /// nondeterminism (no atomic accumulation across blocks, no random
103    /// tile-schedule decisions).
104    pub deterministic: bool,
105}