Skip to main content

baracuda_kernels/
lib.rs

1//! # baracuda-kernels
2//!
3//! Unified ML op facade for the baracuda CUDA ecosystem.
4//!
5//! Exposes every primitive an ML framework would expect (union of
6//! PyTorch `torch.*` + `nn.functional` and JAX `lax.*` / `numpy` ops)
7//! through a single Plan-based Rust surface, internally dispatching to:
8//!
9//! 1. An NVIDIA-library wrapper crate when one already covers the op
10//!    (`baracuda-cublas`, `baracuda-cudnn`, `baracuda-cufft`,
11//!    `baracuda-cusparse`, `baracuda-cusolver`, `baracuda-curand`,
12//!    `baracuda-cutensor`, `baracuda-npp`, `baracuda-cvcuda`,
13//!    `baracuda-cutlass`).
14//! 2. A bespoke `.cu` kernel shipped in
15//!    [`baracuda-kernels-sys`](https://docs.rs/baracuda-kernels-sys)
16//!    when no NVIDIA library covers it (or covers it poorly at relevant
17//!    shapes).
18//!
19//! Callers import **one** crate and reach for **one** API style; the
20//! dispatch decision is an internal detail driven by `select`.
21//!
22//! ## Status
23//!
24//! Active. Covers ~2700 FFI launch points across Phase 1–66 work
25//! including: full elementwise unary/binary/ternary matrix (fwd + bwd,
26//! contig + strided), all standard reductions and scans, the
27//! normalizer family (RMS / Layer / Batch / Group / Instance with
28//! in-place SMEM-staged kernels for f32/f16/bf16/f64), softmax /
29//! log-softmax / sparsemax / gumbel-softmax (+ BW), full attention
30//! suite (SDPA contig + strided + BW, Flash SDPA sm_80 + sm_89 +
31//! varlen + Tri Dao FA2 v2.8.3, RoPE / ALiBi / KV-cache, paged-KV
32//! decode/prefill via FlashInfer, ring attention, block-sparse SDPA,
33//! arbitrary-mask SDPA), GEMM (f16/bf16/tf32/f32/f64/s8/u8/s4/u4/bin/
34//! fp8 with optional bias + ReLU/GELU/SiLU epilogues), GGUF MMVQ
35//! (11 block formats × {contig, strided, batched, multi-M}), the
36//! complete loss family (15 losses × FW+BW + CTC), conv + pool
37//! (cuDNN-backed + bit-exact bespoke Adaptive / LpPool /
38//! FractionalMaxPool), image ops (interpolate / upsample / grid
39//! sample / ROI / NMS / pixel shuffle), linalg (cuSOLVER facade +
40//! bespoke batched Ormqr WY + QR materialize, real + complex), FFT /
41//! cuRAND facades, full quantize family + GGUF + NF4 + AWQ +
42//! Marlin + STE backward, segment + embedding + indexing + scatter,
43//! Mamba-2 SSD + causal conv1d, TransformerEngine FP8 cast /
44//! recipe, mHC hyper-connections.
45//!
46//! Every public `_run` FFI symbol has a matching `_can_implement`
47//! pre-launch validator companion (Phase 66 closure, alpha.64).
48//!
49//! Cargo features are documented in the workspace `README.md`. The
50//! default build (`sm80` only) covers Ampere-baseline kernels;
51//! `sm89` adds Ada specializations (FP8 GEMM, sm_89 Flash SDPA);
52//! `sm90a` reserves the Hopper namespace. Feature flags for the
53//! vendored kernel families (`fa2`, `mhc`, `ozimmu`, `flashinfer`,
54//! `mamba`, `bnb_nf4`, `marlin`, `awq`, `xformers_*`,
55//! `tensor_engine`, `optim`, `ring_attention`, `megatron_tp`,
56//! `nvshmem`) are off by default.
57//!
58//! See [`ROADMAP.md`] for the live backlog and [`OP-MATRIX.md`] for
59//! per-op support status.
60//!
61//! [`ROADMAP.md`]: https://github.com/ciresnave/baracuda/blob/main/ROADMAP.md
62//! [`OP-MATRIX.md`]: https://github.com/ciresnave/baracuda/blob/main/OP-MATRIX.md
63
64#![deny(missing_docs)]
65
66// Re-export the shared type vocabulary.
67pub use baracuda_kernels_types::{
68    contiguous_stride, ActivationKind, ArchSku, ArgReduceKind, AttentionKind, BackendKind,
69    BiasElement, BiasElementKind, Bin, BinElement, BinaryCmpKind, BinaryKind, Bool, Complex32,
70    Complex64, CrossEntropyTargetKind, Element, ElementKind, EmbeddingKind, EpilogueKind,
71    F32Strict, FftKind, FillMode, Fp8E4M3, Fp8E5M2, FpElement, GatedActivationKind,
72    GgufBlockFormat, ImageKind, IndexElement, IndexElementKind, IndexOutputElement,
73    IndexOutputKind, IndexingKind,
74    IntElement, KernelDtype, KernelSku, LayoutSku, LinalgKind, LossKind, LossReduction,
75    MathPrecision, MatrixMut, MatrixRef, MoeKind, NormalizationKind, OpCategory, PadMode,
76    PlanPreference, PoolKind, PrecisionGuarantee, QuantizeKind, RandomKind, ReduceKind,
77    ReduceToOp, S4, S8, ScalarType, ScanKind, SegmentKind,
78    ShapeLayoutKind, SoftmaxKind, SortKind, TensorMut, TensorRef, TernaryKind, U4, U8, UnaryKind,
79    VectorRef, Workspace,
80};
81
82// Re-export the float-GEMM plan types from baracuda-cutlass unchanged.
83// No bespoke float-GEMM kernel exists yet; alongside this CUTLASS surface,
84// the cuBLAS-backed `DenseGemmPlan` (module `gemm`) covers plain dense FP GEMM.
85pub use baracuda_cutlass::{
86    BatchedGemmArgs, BatchedGemmDescriptor, BatchedGemmPlan, Error, GemmArgs, GemmDescriptor,
87    GemmPlan, GemmSku, GroupedGemmPlan, GroupedPlanPreference, GroupedProblem, GroupedScheduleMode,
88    PreparedGroupedGemm, Result,
89};
90
91// Unified GEMM plan dispatchers: the quantized / packed families
92// (int / fp8 / int4 / bin / sparse24) plus, since Phase 74, the plain
93// dense FP family (`DenseGemmPlan` — cuBLAS-backed, RRR / RCR / CRR,
94// strided-batch).
95pub mod gemm;
96
97pub use gemm::{
98    BinGemmArgs, BinGemmDescriptor, BinGemmPlan, DenseGemmArgs, DenseGemmDescriptor,
99    DenseGemmLayout, DenseGemmPlan, Fp8GemmArgs, Fp8GemmDescriptor, Fp8GemmPlan,
100    GemmSparse24Args, GemmSparse24Descriptor, GemmSparse24Plan, Int4GemmArgs, Int4GemmDescriptor,
101    Int4GemmPlan, IntGemmArgs, IntGemmDescriptor, IntGemmPlan,
102};
103
104// Phase 48 — Marlin + AWQ 4-bit GEMM + GPTQ→Marlin repack utility.
105// Plan types always exported; FFI calls inside `run()` are
106// `marlin` / `awq` feature-gated.
107pub use gemm::{
108    gptq_to_marlin_repack, AwqActivation, GptqWeights, Int4AwqGemmArgs, Int4AwqGemmDescriptor,
109    Int4AwqGemmPlan, Int4MarlinGemmArgs, Int4MarlinGemmDescriptor, Int4MarlinGemmPlan,
110    MarlinActivation, MarlinWeights, MARLIN_PERM_LEN, MARLIN_SCALE_PERM_LEN,
111};
112
113// Elementwise op family — Phase 3 trailblazer surface. See module docs
114// for the per-category Plan layout.
115pub mod elementwise;
116
117pub use elementwise::{
118    AffineArgs, AffineDescriptor, AffinePlan, BinaryArgs, BinaryBackwardArgs,
119    BinaryBackwardDescriptor, BinaryBackwardPlan, BinaryCmpArgs,
120    BinaryCmpDescriptor, BinaryCmpPlan, BinaryDescriptor, BinaryParamArgs,
121    BinaryParamBackwardArgs, BinaryParamBackwardDescriptor, BinaryParamBackwardPlan,
122    BinaryParamDescriptor, BinaryParamPlan, BinaryPlan, CastArgs, CastDescriptor, CastPlan,
123    CastSubByteArgs, CastSubByteDescriptor, CastSubBytePlan,
124    GatedActivationArgs,
125    GatedActivationBackwardArgs, GatedActivationBackwardDescriptor, GatedActivationBackwardPlan,
126    GatedActivationDescriptor, GatedActivationPlan, TernaryArgs, TernaryBackwardArgs,
127    TernaryBackwardDescriptor, TernaryBackwardPlan, TernaryDescriptor, TernaryPlan, UnaryArgs,
128    UnaryBackwardArgs, UnaryBackwardDescriptor, UnaryBackwardPlan, UnaryDescriptor,
129    UnaryParamArgs, UnaryParamBackwardArgs, UnaryParamBackwardDescriptor, UnaryParamBackwardPlan,
130    UnaryParamDescriptor, UnaryParamPlan, UnaryPlan, WhereArgs, WhereBackwardArgs,
131    WhereBackwardDescriptor, WhereBackwardPlan, WhereDescriptor, WherePlan,
132};
133
134pub use elementwise::{
135    PReluArgs, PReluBackwardArgs, PReluBackwardDescriptor, PReluBackwardPlan, PReluDescriptor,
136    PReluPlan,
137};
138
139// Shape / layout op family — Category N. Plan-per-op because each
140// op's descriptor / args shape differs.
141pub mod shape_layout;
142
143pub use shape_layout::{
144    ConcatArgs, ConcatBackwardArgs, ConcatBackwardDescriptor, ConcatBackwardPlan,
145    ConcatDescriptor, ConcatPlan, ContiguizeArgs, ContiguizeDescriptor, ContiguizePlan,
146    FillArgs, FillDescriptor, FillPlan, FlipArgs,
147    FlipBackwardArgs, FlipBackwardDescriptor,
148    FlipBackwardPlan, FlipDescriptor, FlipPlan, PadArgs, PadBackwardArgs,
149    PadBackwardDescriptor, PadBackwardPlan, PadDescriptor, PadPlan, PermuteArgs,
150    PermuteBackwardArgs, PermuteBackwardDescriptor, PermuteBackwardPlan, PermuteDescriptor,
151    PermutePlan, RepeatArgs, RepeatBackwardArgs, RepeatBackwardDescriptor,
152    RepeatBackwardPlan, RepeatDescriptor, RepeatPlan, RollArgs, RollBackwardArgs,
153    RollBackwardDescriptor, RollBackwardPlan, RollDescriptor, RollPlan,
154    TrilArgs, TrilBackwardArgs, TrilBackwardDescriptor, TrilBackwardPlan,
155    TrilDescriptor, TrilPlan, TriuArgs, TriuBackwardArgs, TriuBackwardDescriptor,
156    TriuBackwardPlan, TriuDescriptor, TriuPlan,
157    WriteSliceArgs, WriteSliceDescriptor, WriteSlicePlan,
158};
159
160// Reduction op family — Phase 4 (Category E). Output shape differs
161// from input by the reduced axes.
162pub mod reduce;
163
164pub use reduce::{
165    ArgReduceArgs, ArgReduceDescriptor, ArgReducePlan, BoolReduceArgs, BoolReduceDescriptor,
166    BoolReducePlan, CountReduceArgs, CountReduceDescriptor, CountReducePlan, ReduceArgs,
167    ReduceBackwardArgs, ReduceBackwardDescriptor, ReduceBackwardPlan, ReduceDescriptor, ReducePlan,
168    ReduceToArgs, ReduceToDescriptor, ReduceToPlan, TraceArgs, TraceDescriptor, TracePlan,
169};
170
171// Scan (associative prefix) op family — Phase 4 (Category F).
172// Length-preserving along the scan axis.
173pub mod scan;
174
175pub use scan::{
176    ScanArgs, ScanBackwardArgs, ScanBackwardDescriptor, ScanBackwardPlan, ScanDescriptor,
177    ScanPlan,
178};
179
180// Softmax family — Phase 5 (Category H). Length-preserving stable
181// softmax / log-softmax / sparsemax along a single axis.
182pub mod softmax;
183
184pub use softmax::{
185    GumbelSoftmaxArgs, GumbelSoftmaxBackwardArgs, GumbelSoftmaxBackwardDescriptor,
186    GumbelSoftmaxBackwardPlan, GumbelSoftmaxDescriptor, GumbelSoftmaxPlan, SoftmaxArgs,
187    SoftmaxBackwardArgs, SoftmaxBackwardDescriptor, SoftmaxBackwardPlan, SoftmaxDescriptor,
188    SoftmaxPlan, SparsemaxArgs, SparsemaxBackwardArgs, SparsemaxBackwardDescriptor,
189    SparsemaxBackwardPlan, SparsemaxDescriptor, SparsemaxPlan, SPARSEMAX_MAX_EXTENT,
190};
191
192// Normalization family — Phase 5 (Category G). Stable normalization with
193// optional per-feature affine (gamma / beta); RMS / Layer normalize per row along one axis.
194// Wired: RMSNorm / LayerNorm / BatchNorm / GroupNorm / InstanceNorm × FW + BW.
195pub mod norm;
196
197pub use norm::{
198    BatchNormArgs, BatchNormBackwardArgs, BatchNormBackwardDescriptor, BatchNormBackwardPlan,
199    BatchNormDescriptor, BatchNormPlan, GroupNormArgs, GroupNormBackwardArgs,
200    GroupNormBackwardDescriptor, GroupNormBackwardPlan, GroupNormDescriptor, GroupNormPlan,
201    InstanceNormArgs, InstanceNormBackwardArgs, InstanceNormBackwardDescriptor,
202    InstanceNormBackwardPlan, InstanceNormDescriptor, InstanceNormPlan, LayerNormArgs,
203    LayerNormBackwardArgs, LayerNormBackwardDescriptor, LayerNormBackwardPlan, LayerNormDescriptor,
204    LayerNormPlan, RMSNormArgs, RMSNormBackwardArgs, RMSNormBackwardDescriptor,
205    RMSNormBackwardPlan, RMSNormDescriptor, RMSNormPlan,
206};
207
208// Loss family — Phase 5 (Category R). MSE / NLL / CrossEntropy / BCE / KLDiv (FW + BW × 4 FP
209// dtypes × {None, Mean, Sum}) plus L1 / SmoothL1 / Huber / BCE-with-logits / Poisson + Gaussian NLL / FLCE.
210pub mod loss;
211
212pub use loss::{
213    BceLossArgs, BceLossBackwardArgs, BceLossBackwardDescriptor, BceLossBackwardPlan,
214    BceLossDescriptor, BceLossPlan, BceWithLogitsLossArgs, BceWithLogitsLossBackwardArgs,
215    BceWithLogitsLossBackwardDescriptor, BceWithLogitsLossBackwardPlan,
216    BceWithLogitsLossDescriptor, BceWithLogitsLossPlan, CrossEntropyLossArgs,
217    CrossEntropyLossBackwardArgs, CrossEntropyLossBackwardDescriptor,
218    CrossEntropyLossBackwardPlan, CrossEntropyLossDescriptor, CrossEntropyLossPlan,
219    FusedLinearCrossEntropyArgs, FusedLinearCrossEntropyBackwardArgs,
220    FusedLinearCrossEntropyBackwardDescriptor, FusedLinearCrossEntropyBackwardPlan,
221    FusedLinearCrossEntropyDescriptor, FusedLinearCrossEntropyPlan, FLCE_DEFAULT_IGNORE_INDEX,
222    GaussianNllLossArgs, GaussianNllLossBackwardArgs, GaussianNllLossBackwardDescriptor,
223    GaussianNllLossBackwardPlan, GaussianNllLossDescriptor, GaussianNllLossPlan, HuberLossArgs,
224    HuberLossBackwardArgs, HuberLossBackwardDescriptor, HuberLossBackwardPlan,
225    HuberLossDescriptor, HuberLossPlan, KlDivLossArgs, KlDivLossBackwardArgs,
226    KlDivLossBackwardDescriptor, KlDivLossBackwardPlan, KlDivLossDescriptor, KlDivLossPlan,
227    L1LossArgs, L1LossBackwardArgs, L1LossBackwardDescriptor, L1LossBackwardPlan,
228    L1LossDescriptor, L1LossPlan, MseLossArgs, MseLossBackwardArgs, MseLossBackwardDescriptor,
229    MseLossBackwardPlan, MseLossDescriptor, MseLossPlan, NllLossArgs, NllLossBackwardArgs,
230    NllLossBackwardDescriptor, NllLossBackwardPlan, NllLossDescriptor, NllLossPlan,
231    PoissonNllLossArgs, PoissonNllLossBackwardArgs, PoissonNllLossBackwardDescriptor,
232    PoissonNllLossBackwardPlan, PoissonNllLossDescriptor, PoissonNllLossPlan, SmoothL1LossArgs,
233    SmoothL1LossBackwardArgs, SmoothL1LossBackwardDescriptor, SmoothL1LossBackwardPlan,
234    SmoothL1LossDescriptor, SmoothL1LossPlan,
235};
236
237pub use loss::{
238    CosineEmbeddingLossArgs, CosineEmbeddingLossBackwardArgs,
239    CosineEmbeddingLossBackwardDescriptor, CosineEmbeddingLossBackwardPlan,
240    CosineEmbeddingLossDescriptor, CosineEmbeddingLossPlan, HingeEmbeddingLossArgs,
241    HingeEmbeddingLossBackwardArgs, HingeEmbeddingLossBackwardDescriptor,
242    HingeEmbeddingLossBackwardPlan, HingeEmbeddingLossDescriptor, HingeEmbeddingLossPlan,
243    MarginRankingLossArgs, MarginRankingLossBackwardArgs, MarginRankingLossBackwardDescriptor,
244    MarginRankingLossBackwardPlan, MarginRankingLossDescriptor, MarginRankingLossPlan,
245    MultiMarginLossArgs, MultiMarginLossBackwardArgs, MultiMarginLossBackwardDescriptor,
246    MultiMarginLossBackwardPlan, MultiMarginLossDescriptor, MultiMarginLossPlan,
247    MultilabelMarginLossArgs, MultilabelMarginLossBackwardArgs,
248    MultilabelMarginLossBackwardDescriptor, MultilabelMarginLossBackwardPlan,
249    MultilabelMarginLossDescriptor, MultilabelMarginLossPlan, MultilabelSoftMarginLossArgs,
250    MultilabelSoftMarginLossBackwardArgs, MultilabelSoftMarginLossBackwardDescriptor,
251    MultilabelSoftMarginLossBackwardPlan, MultilabelSoftMarginLossDescriptor,
252    MultilabelSoftMarginLossPlan, TripletMarginLossArgs, TripletMarginLossBackwardArgs,
253    TripletMarginLossBackwardDescriptor, TripletMarginLossBackwardPlan,
254    TripletMarginLossDescriptor, TripletMarginLossPlan,
255};
256
257// CTCLoss (Phase 5 Milestone 5.5) — DP-based sequence loss for
258// variable-length inputs/targets.
259pub use loss::{
260    CtcLossArgs, CtcLossBackwardArgs, CtcLossBackwardDescriptor, CtcLossBackwardPlan,
261    CtcLossDescriptor, CtcLossPlan,
262};
263
264// CTCLoss cuDNN sibling (Phase 7 Milestone 7.4) — same op, distinct
265// backend; Fuel's autotuner races this against the bespoke plan.
266// Gated behind the `cudnn` cargo feature.
267#[cfg(feature = "cudnn")]
268pub use loss::{CtcLossCudnnArgs, CtcLossCudnnDescriptor, CtcLossCudnnPlan};
269
270// Random / sampling family — Phase 4.5 (Category Q). Uniform / Normal
271// pass through cuRAND; Bernoulli + Dropout use bespoke kernels on top
272// of cuRAND-uniform.
273pub mod random;
274
275pub use random::{
276    DropoutArgs, DropoutBackwardArgs, DropoutBackwardDescriptor, DropoutBackwardPlan,
277    DropoutDescriptor, DropoutPlan, RandomArgs, RandomBoolArgs, RandomDescriptor, RandomPlan,
278};
279
280// Attention family — Phase 6 (Category K). Milestone 6.1 shipped RoPE (rotary,
281// Llama / Mistral / Gemma) and ALiBi (linear biases, MPT / BLOOM), FW + BW × 4 FP dtypes;
282// later phases added SDPA, Flash SDPA (+ varlen), FlashDecoding, KV-cache append, mHC.
283pub mod attention;
284
285pub use attention::{
286    AlibiArgs, AlibiBackwardArgs, AlibiBackwardDescriptor, AlibiBackwardPlan, AlibiDescriptor,
287    AlibiPlan,
288    // Phase 73 follow-up — FlashDecoding (split-K parallel decode).
289    FlashDecodingArgs, FlashDecodingDescriptor, FlashDecodingPlan, FLASH_DECODING_MAX_D,
290    FlashSdpaArgs, FlashSdpaBackwardArgs, FlashSdpaBackwardDescriptor,
291    FlashSdpaBackwardPlan, FlashSdpaDescriptor, FlashSdpaPlan,
292    // Phase 59b — packed-batch (varlen) FlashAttention v2 plans.
293    FlashSdpaVarlenArgs, FlashSdpaVarlenBackwardArgs, FlashSdpaVarlenBackwardPlan,
294    FlashSdpaVarlenDescriptor, FlashSdpaVarlenPlan,
295    HyperConnectionArgs, HyperConnectionDescriptor, HyperConnectionPlan, KvCacheAppendArgs,
296    KvCacheAppendDescriptor, KvCacheAppendPlan, RopeArgs, RopeBackwardArgs,
297    RopeBackwardDescriptor, RopeBackwardPlan, RopeDescriptor, RopePlan, SdpaArgs,
298    SdpaBackwardArgs, SdpaBackwardDescriptor, SdpaBackwardPlan, SdpaBlockSparseArgs,
299    SdpaBlockSparseDescriptor, SdpaBlockSparsePlan, SdpaDescriptor, SdpaPlan,
300    FLASH_SDPA_MAX_D, ROPE_DEFAULT_BASE, SDPA_BLOCK_SPARSE_MAX_BLOCK, SDPA_BLOCK_SPARSE_MAX_D,
301};
302
303// Phase 45 — long-context RoPE scaling helpers (pure-Rust host-side
304// cos/sin table builders for YaRN + LongRoPE).
305pub use attention::{RopeScaledTableBuilder, RopeScaling};
306
307// Phase 10 Milestone 10.3 — sm_89 (Ada Lovelace) Flash Attention FW
308// sibling. Same descriptor / args shape as the sm_80 baseline so callers
309// swap plans by changing the type, with `cp.async` double-buffered K/V
310// loads and a wider thread block for Ada's larger per-SM register file.
311// f16 + bf16 only.
312#[cfg(feature = "sm89")]
313pub use attention::{FlashSdpaSm89Args, FlashSdpaSm89Descriptor, FlashSdpaSm89Plan};
314
315// Dense linalg family — Milestone 6.3 (Category Linalg). Wraps cuSOLVER for Cholesky / LU /
316// QR / SVD / Eig / Eigh / Inverse / Solve / LstSq, plus bespoke batched Ormqr WY + QR materialize.
317// f32 + f64 only (cuSOLVER's dense API does not expose f16 / bf16 for these ops).
318pub mod linalg;
319
320pub use linalg::{
321    BatchedOrmqrArgs, BatchedOrmqrDescriptor, BatchedOrmqrOp, BatchedOrmqrPlan, BatchedOrmqrSide,
322    BatchedOrmqrWyArgs, BatchedOrmqrWyDescriptor, BatchedOrmqrWyPlan, BatchedQrArgs,
323    BatchedQrDescriptor, BatchedQrMaterializeArgs, BatchedQrMaterializeDescriptor,
324    BatchedQrMaterializePlan, BatchedQrPlan, BatchedSvdArgs, BatchedSvdDescriptor, BatchedSvdPlan,
325    BatchedSvdaArgs, BatchedSvdaDescriptor, BatchedSvdaPlan, CholeskyArgs, CholeskyDescriptor,
326    CholeskyPlan, EigArgs, EigDescriptor, EigPlan, EighArgs, EighDescriptor, EighPlan, InverseArgs,
327    InverseDescriptor, InversePlan, LstSqArgs, LstSqDescriptor, LstSqPlan, LuArgs, LuDescriptor,
328    LuPlan, QrArgs, QrDescriptor, QrPlan, SolveArgs, SolveDescriptor, SolvePlan, SvdArgs,
329    SvdDescriptor, SvdPlan, WY_NB,
330};
331
332// Convolution family — Phase 7 Milestone 7.1 (Category Convolution).
333// Wraps cuDNN's legacy descriptor-based API. Wired: Conv1d / 2d / 3d and ConvTranspose1d / 2d / 3d
334// (FW + BW data + BW filter; Conv2d × {f32, f64, f16, bf16}) plus Im2Col / Col2Im helpers.
335// No dedicated depthwise plan is exported from this module. Gated
336// behind the `cudnn` cargo feature — cuDNN is a separate NVIDIA
337// download not bundled with the stock CUDA toolkit.
338#[cfg(feature = "cudnn")]
339pub mod conv;
340
341#[cfg(feature = "cudnn")]
342pub use conv::{
343    Col2Im1dArgs, Col2Im1dDescriptor, Col2Im1dPlan, Conv1dArgs, Conv1dBwArgs, Conv1dDescriptor,
344    Conv1dDwArgs, Conv1dPlan, Conv2dArgs, Conv2dBwArgs, Conv2dDescriptor, Conv2dDwArgs,
345    Conv2dPlan, Conv3dArgs, Conv3dBwArgs, Conv3dDescriptor, Conv3dDwArgs, Conv3dPlan,
346    ConvTranspose1dArgs, ConvTranspose1dBwArgs, ConvTranspose1dDescriptor, ConvTranspose1dDwArgs,
347    ConvTranspose1dPlan, ConvTranspose2dArgs, ConvTranspose2dBwArgs, ConvTranspose2dDescriptor,
348    ConvTranspose2dDwArgs, ConvTranspose2dPlan, ConvTranspose3dArgs, ConvTranspose3dBwArgs,
349    ConvTranspose3dDescriptor, ConvTranspose3dDwArgs, ConvTranspose3dPlan, Im2Col1dArgs,
350    Im2Col1dDescriptor, Im2Col1dPlan, Im2ColArgs, Im2ColDescriptor, Im2ColPlan,
351};
352
353// Pooling family — Phase 7 Milestone 7.2 (Category Pooling). Max / Avg
354// Pool1d / 2d / 3d (FW + BW) wrap cuDNN's legacy pooling API (NCHW; 2d ×
355// {f32, f64, f16, bf16}); Adaptive / LpPool / FractionalMaxPool are bespoke
356// kernels but live in this module, so the whole family is gated behind the
357// `cudnn` cargo feature.
358#[cfg(feature = "cudnn")]
359pub mod pool;
360
361#[cfg(feature = "cudnn")]
362pub use pool::{
363    AdaptiveAvgPool1dPlan, AdaptiveAvgPool2dPlan, AdaptiveAvgPool3dPlan, AdaptiveMaxPool1dPlan,
364    AdaptiveMaxPool2dPlan, AdaptiveMaxPool3dPlan, AdaptivePool1dBwArgs, AdaptivePool1dDescriptor,
365    AdaptivePool1dFwArgs, AdaptivePool2dBwArgs, AdaptivePool2dDescriptor, AdaptivePool2dFwArgs,
366    AdaptivePool3dBwArgs, AdaptivePool3dDescriptor, AdaptivePool3dFwArgs, AvgPool1dPlan,
367    AvgPool2dPlan, AvgPool3dPlan, FractionalMaxPool2dBwArgs, FractionalMaxPool2dDescriptor,
368    FractionalMaxPool2dFwArgs, FractionalMaxPool2dPlan, FractionalMaxPool3dBwArgs,
369    FractionalMaxPool3dDescriptor, FractionalMaxPool3dFwArgs, FractionalMaxPool3dPlan,
370    LpPool1dBackwardPlan, LpPool1dBwArgs, LpPool1dDescriptor, LpPool1dFwArgs, LpPool1dPlan,
371    LpPool2dBackwardPlan, LpPool2dBwArgs, LpPool2dDescriptor, LpPool2dFwArgs, LpPool2dPlan,
372    MaxPool1dPlan, MaxPool2dPlan, MaxPool3dPlan, Pool1dBwArgs,
373    Pool1dDescriptor, Pool1dFwArgs, Pool2dBwArgs, Pool2dDescriptor, Pool2dFwArgs, Pool3dBwArgs,
374    Pool3dDescriptor, Pool3dFwArgs, PoolMode,
375};
376
377// FFT family — Milestone 6.4 (Category Fft). Wraps cuFFT for the canonical
378// PyTorch / JAX FFTs (FFT / IFFT / RFFT / IRFFT) in 1-D and N-D plans, plus
379// the bespoke index-permutation helpers (fftshift / ifftshift, 1-D and N-D).
380// f32 + f64 only (cuFFT's main API does not expose f16 / bf16).
381pub mod fft;
382
383pub use fft::{
384    FftArgs, FftDescriptor, FftNdArgs, FftNdDescriptor, FftNdPlan, FftPlan, FftShiftArgs,
385    FftShiftDescriptor, FftShiftNdArgs, FftShiftNdDescriptor, FftShiftNdPlan, FftShiftPlan,
386    IrfftArgs, IrfftDescriptor, IrfftNdArgs, IrfftNdDescriptor, IrfftNdPlan, IrfftPlan, RfftArgs,
387    RfftDescriptor, RfftNdArgs, RfftNdDescriptor, RfftNdPlan, RfftPlan, FFTSHIFT_ND_MAX_RANK,
388    FFTSHIFT_ND_MAX_SHIFT_AXES,
389};
390
391// Indexing / scatter / gather family — Phase 7 Milestone 7.3 (Category L).
392// Bespoke kernels for gather + gather_backward, scatter + scatter_add, index_add,
393// index_select + index_select_backward, masked_fill + masked_fill_backward, one_hot,
394// nonzero. Index dtype is i32 only (i64 deferred); out-of-bounds + negative
395// indices are skipped (no PyTorch-style wrap-around).
396pub mod indexing;
397
398pub use indexing::{
399    GatherArgs, GatherBackwardArgs, GatherBackwardDescriptor, GatherBackwardPlan,
400    GatherDescriptor, GatherPlan, IndexAddArgs, IndexAddDescriptor, IndexAddPlan, IndexSelectArgs,
401    IndexSelectBackwardArgs, IndexSelectBackwardDescriptor, IndexSelectBackwardPlan,
402    IndexSelectDescriptor, IndexSelectPlan, MaskedFillArgs, MaskedFillBackwardArgs,
403    MaskedFillBackwardDescriptor, MaskedFillBackwardPlan, MaskedFillDescriptor, MaskedFillPlan,
404    NonzeroArgs, NonzeroDescriptor, NonzeroPlan, OneHotArgs, OneHotDescriptor, OneHotPlan,
405    ScatterArgs, ScatterDescriptor, ScatterPlan, ScatterAddArgs, ScatterAddDescriptor,
406    ScatterAddPlan,
407};
408
409// Embedding family — Phase 7 Milestone 7.5 (Category M). Bespoke
410// kernels for `embedding` (FW + BW) with optional `padding_idx` and
411// `embedding_bag` (FW + BW × Sum / Mean modes). FW dtypes: f32 / f64 /
412// f16 / bf16 (pure copy / accumulator-typed reduce); BW dtypes: f32 /
413// f64 only (atomicAdd is native-FP). Max-mode `embedding_bag` (per-feature
414// argmax tracking) is exposed via `EmbeddingBagMaxPlan` + its BW plan.
415pub mod embedding;
416
417pub use embedding::{
418    EmbeddingArgs, EmbeddingBackwardArgs, EmbeddingBackwardDescriptor, EmbeddingBackwardPlan,
419    EmbeddingBagArgs, EmbeddingBagBackwardArgs, EmbeddingBagBackwardDescriptor,
420    EmbeddingBagBackwardPlan, EmbeddingBagDescriptor, EmbeddingBagMaxArgs,
421    EmbeddingBagMaxBackwardArgs, EmbeddingBagMaxBackwardDescriptor, EmbeddingBagMaxBackwardPlan,
422    EmbeddingBagMaxDescriptor, EmbeddingBagMaxPlan, EmbeddingBagMode, EmbeddingBagPlan,
423    EmbeddingDescriptor, EmbeddingPlan,
424};
425
426// Segment / scatter-reduce family — Phase 7 Milestone 7.6 (Category S).
427// Sorted (binary-search single-pass sweep) and unsorted (atomicAdd /
428// atomicMax-via-CAS / atomicMin-via-CAS) variants for sum / mean / max
429// / min / prod, each with a BW plan. Sum + mean BW share one launcher
430// across sorted and unsorted; max / min / prod BW (argmax tracking +
431// stable prod-div) were added later. f32 + f64 only.
432pub mod segment;
433
434pub use segment::{
435    SegmentMaxArgs, SegmentMaxBackwardArgs, SegmentMaxBackwardDescriptor, SegmentMaxBackwardPlan,
436    SegmentMaxDescriptor, SegmentMaxPlan, SegmentMeanArgs, SegmentMeanBackwardArgs,
437    SegmentMeanBackwardDescriptor, SegmentMeanBackwardPlan, SegmentMeanDescriptor, SegmentMeanPlan,
438    SegmentMinArgs, SegmentMinBackwardArgs, SegmentMinBackwardDescriptor, SegmentMinBackwardPlan,
439    SegmentMinDescriptor, SegmentMinPlan, SegmentProdArgs, SegmentProdBackwardArgs,
440    SegmentProdBackwardDescriptor, SegmentProdBackwardPlan, SegmentProdDescriptor, SegmentProdPlan,
441    SegmentSumArgs, SegmentSumBackwardArgs, SegmentSumBackwardDescriptor, SegmentSumBackwardPlan,
442    SegmentSumDescriptor, SegmentSumPlan, UnsortedSegmentMaxArgs, UnsortedSegmentMaxBackwardArgs,
443    UnsortedSegmentMaxBackwardDescriptor, UnsortedSegmentMaxBackwardPlan,
444    UnsortedSegmentMaxDescriptor, UnsortedSegmentMaxPlan, UnsortedSegmentMeanArgs,
445    UnsortedSegmentMeanBackwardArgs, UnsortedSegmentMeanBackwardDescriptor,
446    UnsortedSegmentMeanBackwardPlan, UnsortedSegmentMeanDescriptor, UnsortedSegmentMeanPlan,
447    UnsortedSegmentMinArgs, UnsortedSegmentMinBackwardArgs, UnsortedSegmentMinBackwardDescriptor,
448    UnsortedSegmentMinBackwardPlan, UnsortedSegmentMinDescriptor, UnsortedSegmentMinPlan,
449    UnsortedSegmentProdArgs, UnsortedSegmentProdBackwardArgs,
450    UnsortedSegmentProdBackwardDescriptor, UnsortedSegmentProdBackwardPlan,
451    UnsortedSegmentProdDescriptor, UnsortedSegmentProdPlan, UnsortedSegmentSumArgs,
452    UnsortedSegmentSumBackwardArgs, UnsortedSegmentSumBackwardDescriptor,
453    UnsortedSegmentSumBackwardPlan, UnsortedSegmentSumDescriptor, UnsortedSegmentSumPlan,
454};
455
456// Quantization family — Phase 8 (Category P). Split across two parallel
457// milestones: 8.1 ships per-tensor / per-channel / fake_quantize plans;
458// 8.2 ships per-token / per-group plans for LLM-style activation +
459// weight quantization (W8A8 and INT4 GPTQ). Dtype coverage:
460// {f32, f64, f16, bf16} × {s8, u8}. Backwards via STE for `quantize_*`
461// and straight-through scaling for `dequantize_*`.
462pub mod quantize;
463
464pub use quantize::{
465    DequantizePerGroupArgs, DequantizePerGroupBackwardArgs,
466    DequantizePerGroupBackwardDescriptor, DequantizePerGroupBackwardPlan,
467    DequantizePerGroupDescriptor, DequantizePerGroupPlan, DequantizePerTokenArgs,
468    DequantizePerTokenBackwardArgs, DequantizePerTokenBackwardDescriptor,
469    DequantizePerTokenBackwardPlan, DequantizePerTokenDescriptor, DequantizePerTokenPlan,
470    QuantizePerGroupArgs, QuantizePerGroupBackwardArgs, QuantizePerGroupBackwardDescriptor,
471    QuantizePerGroupBackwardPlan, QuantizePerGroupDescriptor, QuantizePerGroupPlan,
472    QuantizePerTokenArgs, QuantizePerTokenBackwardArgs, QuantizePerTokenBackwardDescriptor,
473    QuantizePerTokenBackwardPlan, QuantizePerTokenDescriptor, QuantizePerTokenPlan,
474};
475
476// Milestone 8.1 — per-tensor + per-channel + fake_quantize plan types.
477pub use quantize::{
478    DequantizePerChannelArgs, DequantizePerChannelBackwardArgs,
479    DequantizePerChannelBackwardDescriptor, DequantizePerChannelBackwardPlan,
480    DequantizePerChannelDescriptor, DequantizePerChannelPlan, DequantizePerTensorArgs,
481    DequantizePerTensorBackwardArgs, DequantizePerTensorBackwardDescriptor,
482    DequantizePerTensorBackwardPlan, DequantizePerTensorDescriptor, DequantizePerTensorPlan,
483    FakeQuantizeArgs, FakeQuantizeBackwardArgs, FakeQuantizeBackwardDescriptor,
484    FakeQuantizeBackwardPlan, FakeQuantizeDescriptor, FakeQuantizePlan, QuantizePerChannelArgs,
485    QuantizePerChannelBackwardArgs, QuantizePerChannelBackwardDescriptor,
486    QuantizePerChannelBackwardPlan, QuantizePerChannelDescriptor, QuantizePerChannelPlan,
487    QuantizePerTensorArgs, QuantizePerTensorBackwardArgs, QuantizePerTensorBackwardDescriptor,
488    QuantizePerTensorBackwardPlan, QuantizePerTensorDescriptor, QuantizePerTensorPlan,
489};
490
491// Milestone 8.3 — composing quantization ops (DynamicRangeQuantize +
492// QuantizedLinear).
493pub use quantize::{
494    DynamicRangeMode, DynamicRangeQuantizeArgs, DynamicRangeQuantizeDescriptor,
495    DynamicRangeQuantizePlan, DynamicRangeScope, QuantizedLinearArgs,
496    QuantizedLinearDescriptor, QuantizedLinearPlan,
497};
498
499// Phase 45 — SmoothQuant linear (pure Rust composition over the
500// existing `quantized_linear_w8a8` kernel; zero new CUDA).
501pub use quantize::{
502    SmoothQuantLinearArgs, SmoothQuantLinearDescriptor, SmoothQuantLinearPlan,
503};
504
505// Milestone 8.4 — GGUF block-format dequant + MMVQ (Category P).
506// Vendored from llama.cpp via fuel-cuda-kernels.
507pub use quantize::{
508    BlockQ2K, BlockQ3K, BlockQ4_0, BlockQ4_1, BlockQ4K, BlockQ5_0, BlockQ5_1, BlockQ5K, BlockQ6K,
509    BlockQ8_0, BlockQ8K, GgufDequantizeArgs, GgufDequantizeDescriptor, GgufDequantizePlan,
510    GgufMmvqArgs, GgufMmvqDescriptor, GgufMmvqPlan,
511};
512
513// Phase 20.1 — GGUF batched MMVQ × N-experts (general-purpose routing primitive).
514pub use quantize::{
515    GgufMmvqBatchedActivation, GgufMmvqBatchedArgs, GgufMmvqBatchedDescriptor,
516    GgufMmvqBatchedFormat, GgufMmvqBatchedPlan,
517};
518
519// Phase 33 — multi-M MMVQ via Q8_1 activation staging (prefill speedup).
520pub use quantize::{GgufMmvqMultiMArgs, GgufMmvqMultiMDescriptor, GgufMmvqMultiMPlan};
521
522// Phase 53 — bitsandbytes NF4 (NormalFloat 4-bit) dequant + GEMV
523// (QLoRA inference path). Plan types are always re-exported; the FFI
524// dispatch inside `run()` is gated behind the `bnb_nf4` cargo feature
525// (matches the Phase 46 FlashInfer pattern — public API stable, link
526// surface opt-in). See `quantize/nf4.rs` for the full docs.
527pub use quantize::{
528    Nf4Activation, Nf4DequantizeArgs, Nf4DequantizePlan, Nf4Descriptor, Nf4MmvqArgs,
529    Nf4MmvqMultiMArgs, Nf4MmvqMultiMDescriptor, Nf4MmvqMultiMPlan, Nf4MmvqPlan, NF4_CODEBOOK,
530};
531
532// Milestone 8.5 — Mixture-of-Experts inference forward (Category V).
533// Vendored from attention.rs via fuel-cuda-kernels.
534pub mod moe;
535pub use moe::{MoeArgs, MoeDescriptor, MoePlan, MoeVariant};
536
537// Image / spatial-transform family — Phase 9 Category T. Bespoke
538// kernels for the canonical vision-domain ops: interpolate (bilinear
539// 2D), grid_sample + affine_grid, pixel_shuffle / pixel_unshuffle (each
540// is the other's BW), roi_align, roi_pool, nms. f32 + f64 for math-
541// bearing ops; pixel_shuffle adds f16 + bf16 (memory-bound). NCHW.
542pub mod image;
543
544pub use image::{
545    AffineGridArgs, AffineGridDescriptor, AffineGridPlan, GridSampleArgs,
546    GridSampleBackwardArgs, GridSampleBackwardDescriptor, GridSampleBackwardPlan,
547    GridSampleDescriptor, GridSamplePlan, InterpolateArgs, InterpolateBackwardArgs,
548    InterpolateBackwardDescriptor, InterpolateBackwardPlan, InterpolateDescriptor,
549    InterpolateMode, InterpolatePlan, NmsArgs, NmsDescriptor, NmsPlan, PixelShuffleArgs,
550    PixelShuffleDescriptor, PixelShufflePlan, PixelUnshuffleArgs, PixelUnshuffleDescriptor,
551    PixelUnshufflePlan, RoiAlignArgs, RoiAlignBackwardArgs, RoiAlignBackwardDescriptor,
552    RoiAlignBackwardPlan, RoiAlignDescriptor, RoiAlignPlan, RoiPoolArgs, RoiPoolBackwardArgs,
553    RoiPoolBackwardDescriptor, RoiPoolBackwardPlan, RoiPoolDescriptor, RoiPoolPlan,
554};
555
556// Sorting / order-statistics family — Phase 9 Category O. Block-
557// bitonic sort + topk (one block per row, `row_len ≤ 1024`, `k ≤ 64`),
558// per-query binary search (searchsorted), atomic-bin histograms +
559// bincount, and the unique / unique_consecutive set-valued ops.
560// Sort / topk BW use the saved-indices scatter contract (FW emits
561// indices as a required output; BW reads them verbatim — no
562// recomputation). f32 + f64 + i32 + i64 FW for sort family; f32 + f64
563// for grads.
564pub mod sort;
565
566pub use sort::{
567    ArgsortArgs, ArgsortDescriptor, ArgsortPlan, BincountArgs, BincountDescriptor, BincountPlan,
568    HistogramArgs, HistogramDescriptor, HistogramPlan, HistogramddArgs, HistogramddDescriptor,
569    HistogramddPlan, KthvalueArgs, KthvalueBackwardArgs, KthvalueBackwardDescriptor,
570    KthvalueBackwardPlan, KthvalueDescriptor, KthvaluePlan, MsortArgs, MsortBackwardArgs,
571    MsortBackwardDescriptor, MsortBackwardPlan, MsortDescriptor, MsortPlan, SearchsortedArgs,
572    SearchsortedDescriptor, SearchsortedPlan, SortArgs, SortBackwardArgs, SortBackwardDescriptor,
573    SortBackwardPlan, SortDescriptor, SortPlan, TopkArgs, TopkBackwardArgs,
574    TopkBackwardDescriptor, TopkBackwardPlan, TopkDescriptor, TopkPlan, UniqueArgs,
575    UniqueConsecutiveArgs, UniqueConsecutiveDescriptor, UniqueConsecutivePlan, UniqueDescriptor,
576    UniquePlan, SORT_MAX_ROW, TOPK_MAX_K,
577};
578
579// Phase 50 — Mamba-2 causal-conv1d primitive. Bespoke kernel; lives at
580// the crate root because it isn't part of any existing op family
581// (no cuDNN dep, distinct shape contract from generic conv1d).
582#[cfg(feature = "mamba")]
583pub mod causal_conv1d;
584
585#[cfg(feature = "mamba")]
586pub use causal_conv1d::{
587    CausalConv1dArgs, CausalConv1dBackwardArgs, CausalConv1dBackwardDescriptor,
588    CausalConv1dBackwardPlan, CausalConv1dDescriptor, CausalConv1dPlan,
589};
590
591// Phase 50 — Mamba-2 SSD chunk-scan re-exports from the attention
592// family (SSD = State-Space Duality).
593#[cfg(feature = "mamba")]
594pub use attention::{
595    SsdChunkScanArgs, SsdChunkScanBackwardArgs, SsdChunkScanBackwardDescriptor,
596    SsdChunkScanBackwardPlan, SsdChunkScanDescriptor, SsdChunkScanPlan,
597};
598
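// Phase 50b — Mamba-1 selective_scan re-exports (sibling to SSD; powers
// Mamba-1-architecture models such as Mamba-7B / Falcon-Mamba. Codestral-
// Mamba is built on the Mamba-2 architecture, so it maps to the SSD
// chunk-scan re-exports above rather than to selective_scan).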
601#[cfg(feature = "mamba")]
602pub use attention::{
603    SelectiveScanArgs, SelectiveScanBackwardArgs, SelectiveScanBackwardDescriptor,
604    SelectiveScanBackwardPlan, SelectiveScanDescriptor, SelectiveScanPlan,
605};
606
607// Phase 56 — Ring Attention re-exports. Plan types are always
608// exposed (struct definitions compile without the feature); the
609// `run()` method that actually invokes NCCL + the kernel is gated
610// behind the `ring_attention` cargo feature.
611pub use attention::{
612    RingAttentionArgs, RingAttentionDescriptor, RingAttentionPlan, RING_ATTENTION_HEAD_DIM,
613};
614
615// Phase 49 — Apex multi-tensor optimizer subset (Adam / LAMB / SGD).
616// Vendored from NVIDIA Apex (BSD-3-Clause) and exposed under the
617// `optim` cargo feature. Deliberate scope expansion (training-
618// framework-adjacent); inference-only consumers don't pay the FFI
619// surface cost because they don't enable the feature.
#[cfg(feature = "optim")]
pub mod optim {
    //! Apex-derived multi-tensor optimizer plans from [`baracuda_optim`],
    //! surfaced through the unified kernel facade.
    //!
    //! This module only exists when the `optim` cargo feature is enabled.
    pub use baracuda_optim::{
        // Adam
        AdamConfig, AdamMode, AdamParamDtype, AdamStepPlan,
        // LAMB
        LambConfig, LambStepPlan,
        // SGD
        SgdConfig, SgdParamDtype, SgdStepPlan,
        // Multi-tensor-apply plumbing
        MultiTensorApplyContext, TensorList,
        // Upstream error types, re-exported under an `Optim` prefix
        Error as OptimError, Result as OptimResult,
    };
}
630
631// Phase 55 — TransformerEngine FP8 cast/transpose + delayed-scaling
632// recipe primitives. Vendored from NVIDIA TransformerEngine (Apache-2.0)
633// and exposed under the `tensor_engine` cargo feature. Cast / recipe
634// subset only — `normalization` / `fused_attn` / `fused_rope` /
635// `activation` / `gemm` overlap existing baracuda phases and are
636// deliberately NOT brought in. Sm_89 caveat (RTX 4070): the FP8 wins
637// are bandwidth-saving only on Ada — tensor-core FP8 MMA throughput
638// equals BF16. Forward-compatible with Hopper / Blackwell where the
639// MMA throughput win also materializes.
#[cfg(feature = "tensor_engine")]
pub mod transformer_engine {
    //! FP8 cast / recipe plans from [`baracuda_transformer_engine`], surfaced
    //! through the unified kernel facade.
    //!
    //! This module only exists when the `tensor_engine` cargo feature is
    //! enabled. Note that the feature name differs from this module's name.
    pub use baracuda_transformer_engine::{
        // Cast and dequantize plans
        Fp8CastPlan, Fp8DequantPlan,
        // FP8 format, scaling recipe, and wide-dtype selectors
        Fp8Format, Fp8Recipe, Fp8WideDtype,
        // Upstream error types, re-exported under a `TransformerEngine` prefix
        Error as TransformerEngineError, Result as TransformerEngineResult,
    };
}
650
651// Phase 57 — Megatron-LM tensor-parallel primitives. Pure-composition
652// crate over baracuda-cublas + baracuda-nccl (no new CUDA kernels) and
653// exposed under the `megatron_tp` cargo feature. Algorithmic reference
654// is Shoeybi et al. arXiv:1909.08053 (NVIDIA Megatron-LM, Apache-2.0);
655// no source is vendored.
#[cfg(feature = "megatron_tp")]
pub mod megatron {
    //! Tensor-parallel Linear plans from [`baracuda_megatron`], surfaced
    //! through the unified kernel facade.
    //!
    //! This module only exists when the `megatron_tp` cargo feature is enabled.
    pub use baracuda_megatron::{
        // Column-parallel / row-parallel Linear plans
        ColumnParallelLinearPlan, RowParallelLinearPlan,
        // Shared tensor-parallel context and GEMM element-type support
        TensorParallelContext, MegatronGemmScalar,
        // Upstream error types, re-exported under a `Megatron` prefix
        Error as MegatronError, Result as MegatronResult,
    };
}
666
667// =========================================================================
668// Phase 46 — FlashInfer cherry-pick re-exports (at the crate root).
669// =========================================================================
670//
// The three new plan families (paged / ragged KV-cache attention, cascade
// attention, and sampling) landed in `attention::*` and `random::*`;
// mirror them at the crate root so callers can use
// `baracuda_kernels::TopKTopPSamplingPlan` (the documented path).
674pub use attention::{
675    BatchPagedDecodeArgs, BatchPagedDecodeDescriptor, BatchPagedDecodePlan,
676    BatchPagedDecodeFp8Args, BatchPagedDecodeFp8Descriptor, BatchPagedDecodeFp8Plan,
677    BatchPagedPrefillArgs, BatchPagedPrefillDescriptor, BatchPagedPrefillPlan,
678    BatchRaggedPrefillArgs, BatchRaggedPrefillDescriptor, BatchRaggedPrefillPlan,
679    CascadeAttentionArgs, CascadeAttentionDescriptor, CascadeAttentionPlan,
680    CascadeMergeStatesArgs, CascadeMergeStatesDescriptor, CascadeMergeStatesPlan, Fp8KvDtype,
681    PagedKvAppendArgs, PagedKvAppendDescriptor, PagedKvAppendPlan, PagedKvCacheDescriptor,
682};
683pub use random::{
684    PerRowSampler, PerRowSamplingArgs, PerRowSamplingDescriptor, PerRowSamplingPlan, SamplerKind,
685    SpeculativeSamplingArgs, SpeculativeSamplingDescriptor, SpeculativeSamplingPlan,
686    TokenPenaltyArgs, TokenPenaltyDescriptor, TokenPenaltyPlan, TopKTopPSamplingArgs,
687    TopKTopPSamplingDescriptor, TopKTopPSamplingPlan,
688};