1#![deny(missing_docs)]
65
66pub use baracuda_kernels_types::{
68 contiguous_stride, ActivationKind, ArchSku, ArgReduceKind, AttentionKind, BackendKind,
69 BiasElement, BiasElementKind, Bin, BinElement, BinaryCmpKind, BinaryKind, Bool, Complex32,
70 Complex64, CrossEntropyTargetKind, Element, ElementKind, EmbeddingKind, EpilogueKind,
71 F32Strict, FftKind, FillMode, Fp8E4M3, Fp8E5M2, FpElement, GatedActivationKind,
72 GgufBlockFormat, ImageKind, IndexElement, IndexElementKind, IndexOutputElement,
73 IndexOutputKind, IndexingKind,
74 IntElement, KernelDtype, KernelSku, LayoutSku, LinalgKind, LossKind, LossReduction,
75 MathPrecision, MatrixMut, MatrixRef, MoeKind, NormalizationKind, OpCategory, PadMode,
76 PlanPreference, PoolKind, PrecisionGuarantee, QuantizeKind, RandomKind, ReduceKind,
77 ReduceToOp, S4, S8, ScalarType, ScanKind, SegmentKind,
78 ShapeLayoutKind, SoftmaxKind, SortKind, TensorMut, TensorRef, TernaryKind, U4, U8, UnaryKind,
79 VectorRef, Workspace,
80};
81
82pub use baracuda_cutlass::{
86 BatchedGemmArgs, BatchedGemmDescriptor, BatchedGemmPlan, Error, GemmArgs, GemmDescriptor,
87 GemmPlan, GemmSku, GroupedGemmPlan, GroupedPlanPreference, GroupedProblem, GroupedScheduleMode,
88 PreparedGroupedGemm, Result,
89};
90
91pub mod gemm;
96
97pub use gemm::{
98 BinGemmArgs, BinGemmDescriptor, BinGemmPlan, DenseGemmArgs, DenseGemmDescriptor,
99 DenseGemmLayout, DenseGemmPlan, Fp8GemmArgs, Fp8GemmDescriptor, Fp8GemmPlan,
100 GemmSparse24Args, GemmSparse24Descriptor, GemmSparse24Plan, Int4GemmArgs, Int4GemmDescriptor,
101 Int4GemmPlan, IntGemmArgs, IntGemmDescriptor, IntGemmPlan,
102};
103
104pub use gemm::{
108 gptq_to_marlin_repack, AwqActivation, GptqWeights, Int4AwqGemmArgs, Int4AwqGemmDescriptor,
109 Int4AwqGemmPlan, Int4MarlinGemmArgs, Int4MarlinGemmDescriptor, Int4MarlinGemmPlan,
110 MarlinActivation, MarlinWeights, MARLIN_PERM_LEN, MARLIN_SCALE_PERM_LEN,
111};
112
113pub mod elementwise;
116
117pub use elementwise::{
118 AffineArgs, AffineDescriptor, AffinePlan, BinaryArgs, BinaryBackwardArgs,
119 BinaryBackwardDescriptor, BinaryBackwardPlan, BinaryCmpArgs,
120 BinaryCmpDescriptor, BinaryCmpPlan, BinaryDescriptor, BinaryParamArgs,
121 BinaryParamBackwardArgs, BinaryParamBackwardDescriptor, BinaryParamBackwardPlan,
122 BinaryParamDescriptor, BinaryParamPlan, BinaryPlan, CastArgs, CastDescriptor, CastPlan,
123 CastSubByteArgs, CastSubByteDescriptor, CastSubBytePlan,
124 GatedActivationArgs,
125 GatedActivationBackwardArgs, GatedActivationBackwardDescriptor, GatedActivationBackwardPlan,
126 GatedActivationDescriptor, GatedActivationPlan, TernaryArgs, TernaryBackwardArgs,
127 TernaryBackwardDescriptor, TernaryBackwardPlan, TernaryDescriptor, TernaryPlan, UnaryArgs,
128 UnaryBackwardArgs, UnaryBackwardDescriptor, UnaryBackwardPlan, UnaryDescriptor,
129 UnaryParamArgs, UnaryParamBackwardArgs, UnaryParamBackwardDescriptor, UnaryParamBackwardPlan,
130 UnaryParamDescriptor, UnaryParamPlan, UnaryPlan, WhereArgs, WhereBackwardArgs,
131 WhereBackwardDescriptor, WhereBackwardPlan, WhereDescriptor, WherePlan,
132};
133
134pub use elementwise::{
135 PReluArgs, PReluBackwardArgs, PReluBackwardDescriptor, PReluBackwardPlan, PReluDescriptor,
136 PReluPlan,
137};
138
139pub mod shape_layout;
142
143pub use shape_layout::{
144 ConcatArgs, ConcatBackwardArgs, ConcatBackwardDescriptor, ConcatBackwardPlan,
145 ConcatDescriptor, ConcatPlan, ContiguizeArgs, ContiguizeDescriptor, ContiguizePlan,
146 FillArgs, FillDescriptor, FillPlan, FlipArgs,
147 FlipBackwardArgs, FlipBackwardDescriptor,
148 FlipBackwardPlan, FlipDescriptor, FlipPlan, PadArgs, PadBackwardArgs,
149 PadBackwardDescriptor, PadBackwardPlan, PadDescriptor, PadPlan, PermuteArgs,
150 PermuteBackwardArgs, PermuteBackwardDescriptor, PermuteBackwardPlan, PermuteDescriptor,
151 PermutePlan, RepeatArgs, RepeatBackwardArgs, RepeatBackwardDescriptor,
152 RepeatBackwardPlan, RepeatDescriptor, RepeatPlan, RollArgs, RollBackwardArgs,
153 RollBackwardDescriptor, RollBackwardPlan, RollDescriptor, RollPlan,
154 TrilArgs, TrilBackwardArgs, TrilBackwardDescriptor, TrilBackwardPlan,
155 TrilDescriptor, TrilPlan, TriuArgs, TriuBackwardArgs, TriuBackwardDescriptor,
156 TriuBackwardPlan, TriuDescriptor, TriuPlan,
157 WriteSliceArgs, WriteSliceDescriptor, WriteSlicePlan,
158};
159
160pub mod reduce;
163
164pub use reduce::{
165 ArgReduceArgs, ArgReduceDescriptor, ArgReducePlan, BoolReduceArgs, BoolReduceDescriptor,
166 BoolReducePlan, CountReduceArgs, CountReduceDescriptor, CountReducePlan, ReduceArgs,
167 ReduceBackwardArgs, ReduceBackwardDescriptor, ReduceBackwardPlan, ReduceDescriptor, ReducePlan,
168 ReduceToArgs, ReduceToDescriptor, ReduceToPlan, TraceArgs, TraceDescriptor, TracePlan,
169};
170
171pub mod scan;
174
175pub use scan::{
176 ScanArgs, ScanBackwardArgs, ScanBackwardDescriptor, ScanBackwardPlan, ScanDescriptor,
177 ScanPlan,
178};
179
180pub mod softmax;
183
184pub use softmax::{
185 GumbelSoftmaxArgs, GumbelSoftmaxBackwardArgs, GumbelSoftmaxBackwardDescriptor,
186 GumbelSoftmaxBackwardPlan, GumbelSoftmaxDescriptor, GumbelSoftmaxPlan, SoftmaxArgs,
187 SoftmaxBackwardArgs, SoftmaxBackwardDescriptor, SoftmaxBackwardPlan, SoftmaxDescriptor,
188 SoftmaxPlan, SparsemaxArgs, SparsemaxBackwardArgs, SparsemaxBackwardDescriptor,
189 SparsemaxBackwardPlan, SparsemaxDescriptor, SparsemaxPlan, SPARSEMAX_MAX_EXTENT,
190};
191
192pub mod norm;
196
197pub use norm::{
198 BatchNormArgs, BatchNormBackwardArgs, BatchNormBackwardDescriptor, BatchNormBackwardPlan,
199 BatchNormDescriptor, BatchNormPlan, GroupNormArgs, GroupNormBackwardArgs,
200 GroupNormBackwardDescriptor, GroupNormBackwardPlan, GroupNormDescriptor, GroupNormPlan,
201 InstanceNormArgs, InstanceNormBackwardArgs, InstanceNormBackwardDescriptor,
202 InstanceNormBackwardPlan, InstanceNormDescriptor, InstanceNormPlan, LayerNormArgs,
203 LayerNormBackwardArgs, LayerNormBackwardDescriptor, LayerNormBackwardPlan, LayerNormDescriptor,
204 LayerNormPlan, RMSNormArgs, RMSNormBackwardArgs, RMSNormBackwardDescriptor,
205 RMSNormBackwardPlan, RMSNormDescriptor, RMSNormPlan,
206};
207
208pub mod loss;
211
212pub use loss::{
213 BceLossArgs, BceLossBackwardArgs, BceLossBackwardDescriptor, BceLossBackwardPlan,
214 BceLossDescriptor, BceLossPlan, BceWithLogitsLossArgs, BceWithLogitsLossBackwardArgs,
215 BceWithLogitsLossBackwardDescriptor, BceWithLogitsLossBackwardPlan,
216 BceWithLogitsLossDescriptor, BceWithLogitsLossPlan, CrossEntropyLossArgs,
217 CrossEntropyLossBackwardArgs, CrossEntropyLossBackwardDescriptor,
218 CrossEntropyLossBackwardPlan, CrossEntropyLossDescriptor, CrossEntropyLossPlan,
219 FusedLinearCrossEntropyArgs, FusedLinearCrossEntropyBackwardArgs,
220 FusedLinearCrossEntropyBackwardDescriptor, FusedLinearCrossEntropyBackwardPlan,
221 FusedLinearCrossEntropyDescriptor, FusedLinearCrossEntropyPlan, FLCE_DEFAULT_IGNORE_INDEX,
222 GaussianNllLossArgs, GaussianNllLossBackwardArgs, GaussianNllLossBackwardDescriptor,
223 GaussianNllLossBackwardPlan, GaussianNllLossDescriptor, GaussianNllLossPlan, HuberLossArgs,
224 HuberLossBackwardArgs, HuberLossBackwardDescriptor, HuberLossBackwardPlan,
225 HuberLossDescriptor, HuberLossPlan, KlDivLossArgs, KlDivLossBackwardArgs,
226 KlDivLossBackwardDescriptor, KlDivLossBackwardPlan, KlDivLossDescriptor, KlDivLossPlan,
227 L1LossArgs, L1LossBackwardArgs, L1LossBackwardDescriptor, L1LossBackwardPlan,
228 L1LossDescriptor, L1LossPlan, MseLossArgs, MseLossBackwardArgs, MseLossBackwardDescriptor,
229 MseLossBackwardPlan, MseLossDescriptor, MseLossPlan, NllLossArgs, NllLossBackwardArgs,
230 NllLossBackwardDescriptor, NllLossBackwardPlan, NllLossDescriptor, NllLossPlan,
231 PoissonNllLossArgs, PoissonNllLossBackwardArgs, PoissonNllLossBackwardDescriptor,
232 PoissonNllLossBackwardPlan, PoissonNllLossDescriptor, PoissonNllLossPlan, SmoothL1LossArgs,
233 SmoothL1LossBackwardArgs, SmoothL1LossBackwardDescriptor, SmoothL1LossBackwardPlan,
234 SmoothL1LossDescriptor, SmoothL1LossPlan,
235};
236
237pub use loss::{
238 CosineEmbeddingLossArgs, CosineEmbeddingLossBackwardArgs,
239 CosineEmbeddingLossBackwardDescriptor, CosineEmbeddingLossBackwardPlan,
240 CosineEmbeddingLossDescriptor, CosineEmbeddingLossPlan, HingeEmbeddingLossArgs,
241 HingeEmbeddingLossBackwardArgs, HingeEmbeddingLossBackwardDescriptor,
242 HingeEmbeddingLossBackwardPlan, HingeEmbeddingLossDescriptor, HingeEmbeddingLossPlan,
243 MarginRankingLossArgs, MarginRankingLossBackwardArgs, MarginRankingLossBackwardDescriptor,
244 MarginRankingLossBackwardPlan, MarginRankingLossDescriptor, MarginRankingLossPlan,
245 MultiMarginLossArgs, MultiMarginLossBackwardArgs, MultiMarginLossBackwardDescriptor,
246 MultiMarginLossBackwardPlan, MultiMarginLossDescriptor, MultiMarginLossPlan,
247 MultilabelMarginLossArgs, MultilabelMarginLossBackwardArgs,
248 MultilabelMarginLossBackwardDescriptor, MultilabelMarginLossBackwardPlan,
249 MultilabelMarginLossDescriptor, MultilabelMarginLossPlan, MultilabelSoftMarginLossArgs,
250 MultilabelSoftMarginLossBackwardArgs, MultilabelSoftMarginLossBackwardDescriptor,
251 MultilabelSoftMarginLossBackwardPlan, MultilabelSoftMarginLossDescriptor,
252 MultilabelSoftMarginLossPlan, TripletMarginLossArgs, TripletMarginLossBackwardArgs,
253 TripletMarginLossBackwardDescriptor, TripletMarginLossBackwardPlan,
254 TripletMarginLossDescriptor, TripletMarginLossPlan,
255};
256
257pub use loss::{
260 CtcLossArgs, CtcLossBackwardArgs, CtcLossBackwardDescriptor, CtcLossBackwardPlan,
261 CtcLossDescriptor, CtcLossPlan,
262};
263
264#[cfg(feature = "cudnn")]
268pub use loss::{CtcLossCudnnArgs, CtcLossCudnnDescriptor, CtcLossCudnnPlan};
269
270pub mod random;
274
275pub use random::{
276 DropoutArgs, DropoutBackwardArgs, DropoutBackwardDescriptor, DropoutBackwardPlan,
277 DropoutDescriptor, DropoutPlan, RandomArgs, RandomBoolArgs, RandomDescriptor, RandomPlan,
278};
279
280pub mod attention;
284
285pub use attention::{
286 AlibiArgs, AlibiBackwardArgs, AlibiBackwardDescriptor, AlibiBackwardPlan, AlibiDescriptor,
287 AlibiPlan,
288 FlashDecodingArgs, FlashDecodingDescriptor, FlashDecodingPlan, FLASH_DECODING_MAX_D,
290 FlashSdpaArgs, FlashSdpaBackwardArgs, FlashSdpaBackwardDescriptor,
291 FlashSdpaBackwardPlan, FlashSdpaDescriptor, FlashSdpaPlan,
292 FlashSdpaVarlenArgs, FlashSdpaVarlenBackwardArgs, FlashSdpaVarlenBackwardPlan,
294 FlashSdpaVarlenDescriptor, FlashSdpaVarlenPlan,
295 HyperConnectionArgs, HyperConnectionDescriptor, HyperConnectionPlan, KvCacheAppendArgs,
296 KvCacheAppendDescriptor, KvCacheAppendPlan, RopeArgs, RopeBackwardArgs,
297 RopeBackwardDescriptor, RopeBackwardPlan, RopeDescriptor, RopePlan, SdpaArgs,
298 SdpaBackwardArgs, SdpaBackwardDescriptor, SdpaBackwardPlan, SdpaBlockSparseArgs,
299 SdpaBlockSparseDescriptor, SdpaBlockSparsePlan, SdpaDescriptor, SdpaPlan,
300 FLASH_SDPA_MAX_D, ROPE_DEFAULT_BASE, SDPA_BLOCK_SPARSE_MAX_BLOCK, SDPA_BLOCK_SPARSE_MAX_D,
301};
302
303pub use attention::{RopeScaledTableBuilder, RopeScaling};
306
307#[cfg(feature = "sm89")]
313pub use attention::{FlashSdpaSm89Args, FlashSdpaSm89Descriptor, FlashSdpaSm89Plan};
314
315pub mod linalg;
319
320pub use linalg::{
321 BatchedOrmqrArgs, BatchedOrmqrDescriptor, BatchedOrmqrOp, BatchedOrmqrPlan, BatchedOrmqrSide,
322 BatchedOrmqrWyArgs, BatchedOrmqrWyDescriptor, BatchedOrmqrWyPlan, BatchedQrArgs,
323 BatchedQrDescriptor, BatchedQrMaterializeArgs, BatchedQrMaterializeDescriptor,
324 BatchedQrMaterializePlan, BatchedQrPlan, BatchedSvdArgs, BatchedSvdDescriptor, BatchedSvdPlan,
325 BatchedSvdaArgs, BatchedSvdaDescriptor, BatchedSvdaPlan, CholeskyArgs, CholeskyDescriptor,
326 CholeskyPlan, EigArgs, EigDescriptor, EigPlan, EighArgs, EighDescriptor, EighPlan, InverseArgs,
327 InverseDescriptor, InversePlan, LstSqArgs, LstSqDescriptor, LstSqPlan, LuArgs, LuDescriptor,
328 LuPlan, QrArgs, QrDescriptor, QrPlan, SolveArgs, SolveDescriptor, SolvePlan, SvdArgs,
329 SvdDescriptor, SvdPlan, WY_NB,
330};
331
332#[cfg(feature = "cudnn")]
339pub mod conv;
340
341#[cfg(feature = "cudnn")]
342pub use conv::{
343 Col2Im1dArgs, Col2Im1dDescriptor, Col2Im1dPlan, Conv1dArgs, Conv1dBwArgs, Conv1dDescriptor,
344 Conv1dDwArgs, Conv1dPlan, Conv2dArgs, Conv2dBwArgs, Conv2dDescriptor, Conv2dDwArgs,
345 Conv2dPlan, Conv3dArgs, Conv3dBwArgs, Conv3dDescriptor, Conv3dDwArgs, Conv3dPlan,
346 ConvTranspose1dArgs, ConvTranspose1dBwArgs, ConvTranspose1dDescriptor, ConvTranspose1dDwArgs,
347 ConvTranspose1dPlan, ConvTranspose2dArgs, ConvTranspose2dBwArgs, ConvTranspose2dDescriptor,
348 ConvTranspose2dDwArgs, ConvTranspose2dPlan, ConvTranspose3dArgs, ConvTranspose3dBwArgs,
349 ConvTranspose3dDescriptor, ConvTranspose3dDwArgs, ConvTranspose3dPlan, Im2Col1dArgs,
350 Im2Col1dDescriptor, Im2Col1dPlan, Im2ColArgs, Im2ColDescriptor, Im2ColPlan,
351};
352
353#[cfg(feature = "cudnn")]
359pub mod pool;
360
361#[cfg(feature = "cudnn")]
362pub use pool::{
363 AdaptiveAvgPool1dPlan, AdaptiveAvgPool2dPlan, AdaptiveAvgPool3dPlan, AdaptiveMaxPool1dPlan,
364 AdaptiveMaxPool2dPlan, AdaptiveMaxPool3dPlan, AdaptivePool1dBwArgs, AdaptivePool1dDescriptor,
365 AdaptivePool1dFwArgs, AdaptivePool2dBwArgs, AdaptivePool2dDescriptor, AdaptivePool2dFwArgs,
366 AdaptivePool3dBwArgs, AdaptivePool3dDescriptor, AdaptivePool3dFwArgs, AvgPool1dPlan,
367 AvgPool2dPlan, AvgPool3dPlan, FractionalMaxPool2dBwArgs, FractionalMaxPool2dDescriptor,
368 FractionalMaxPool2dFwArgs, FractionalMaxPool2dPlan, FractionalMaxPool3dBwArgs,
369 FractionalMaxPool3dDescriptor, FractionalMaxPool3dFwArgs, FractionalMaxPool3dPlan,
370 LpPool1dBackwardPlan, LpPool1dBwArgs, LpPool1dDescriptor, LpPool1dFwArgs, LpPool1dPlan,
371 LpPool2dBackwardPlan, LpPool2dBwArgs, LpPool2dDescriptor, LpPool2dFwArgs, LpPool2dPlan,
372 MaxPool1dPlan, MaxPool2dPlan, MaxPool3dPlan, Pool1dBwArgs,
373 Pool1dDescriptor, Pool1dFwArgs, Pool2dBwArgs, Pool2dDescriptor, Pool2dFwArgs, Pool3dBwArgs,
374 Pool3dDescriptor, Pool3dFwArgs, PoolMode,
375};
376
377pub mod fft;
382
383pub use fft::{
384 FftArgs, FftDescriptor, FftNdArgs, FftNdDescriptor, FftNdPlan, FftPlan, FftShiftArgs,
385 FftShiftDescriptor, FftShiftNdArgs, FftShiftNdDescriptor, FftShiftNdPlan, FftShiftPlan,
386 IrfftArgs, IrfftDescriptor, IrfftNdArgs, IrfftNdDescriptor, IrfftNdPlan, IrfftPlan, RfftArgs,
387 RfftDescriptor, RfftNdArgs, RfftNdDescriptor, RfftNdPlan, RfftPlan, FFTSHIFT_ND_MAX_RANK,
388 FFTSHIFT_ND_MAX_SHIFT_AXES,
389};
390
391pub mod indexing;
397
398pub use indexing::{
399 GatherArgs, GatherBackwardArgs, GatherBackwardDescriptor, GatherBackwardPlan,
400 GatherDescriptor, GatherPlan, IndexAddArgs, IndexAddDescriptor, IndexAddPlan, IndexSelectArgs,
401 IndexSelectBackwardArgs, IndexSelectBackwardDescriptor, IndexSelectBackwardPlan,
402 IndexSelectDescriptor, IndexSelectPlan, MaskedFillArgs, MaskedFillBackwardArgs,
403 MaskedFillBackwardDescriptor, MaskedFillBackwardPlan, MaskedFillDescriptor, MaskedFillPlan,
404 NonzeroArgs, NonzeroDescriptor, NonzeroPlan, OneHotArgs, OneHotDescriptor, OneHotPlan,
405 ScatterArgs, ScatterDescriptor, ScatterPlan, ScatterAddArgs, ScatterAddDescriptor,
406 ScatterAddPlan,
407};
408
409pub mod embedding;
416
417pub use embedding::{
418 EmbeddingArgs, EmbeddingBackwardArgs, EmbeddingBackwardDescriptor, EmbeddingBackwardPlan,
419 EmbeddingBagArgs, EmbeddingBagBackwardArgs, EmbeddingBagBackwardDescriptor,
420 EmbeddingBagBackwardPlan, EmbeddingBagDescriptor, EmbeddingBagMaxArgs,
421 EmbeddingBagMaxBackwardArgs, EmbeddingBagMaxBackwardDescriptor, EmbeddingBagMaxBackwardPlan,
422 EmbeddingBagMaxDescriptor, EmbeddingBagMaxPlan, EmbeddingBagMode, EmbeddingBagPlan,
423 EmbeddingDescriptor, EmbeddingPlan,
424};
425
426pub mod segment;
433
434pub use segment::{
435 SegmentMaxArgs, SegmentMaxBackwardArgs, SegmentMaxBackwardDescriptor, SegmentMaxBackwardPlan,
436 SegmentMaxDescriptor, SegmentMaxPlan, SegmentMeanArgs, SegmentMeanBackwardArgs,
437 SegmentMeanBackwardDescriptor, SegmentMeanBackwardPlan, SegmentMeanDescriptor, SegmentMeanPlan,
438 SegmentMinArgs, SegmentMinBackwardArgs, SegmentMinBackwardDescriptor, SegmentMinBackwardPlan,
439 SegmentMinDescriptor, SegmentMinPlan, SegmentProdArgs, SegmentProdBackwardArgs,
440 SegmentProdBackwardDescriptor, SegmentProdBackwardPlan, SegmentProdDescriptor, SegmentProdPlan,
441 SegmentSumArgs, SegmentSumBackwardArgs, SegmentSumBackwardDescriptor, SegmentSumBackwardPlan,
442 SegmentSumDescriptor, SegmentSumPlan, UnsortedSegmentMaxArgs, UnsortedSegmentMaxBackwardArgs,
443 UnsortedSegmentMaxBackwardDescriptor, UnsortedSegmentMaxBackwardPlan,
444 UnsortedSegmentMaxDescriptor, UnsortedSegmentMaxPlan, UnsortedSegmentMeanArgs,
445 UnsortedSegmentMeanBackwardArgs, UnsortedSegmentMeanBackwardDescriptor,
446 UnsortedSegmentMeanBackwardPlan, UnsortedSegmentMeanDescriptor, UnsortedSegmentMeanPlan,
447 UnsortedSegmentMinArgs, UnsortedSegmentMinBackwardArgs, UnsortedSegmentMinBackwardDescriptor,
448 UnsortedSegmentMinBackwardPlan, UnsortedSegmentMinDescriptor, UnsortedSegmentMinPlan,
449 UnsortedSegmentProdArgs, UnsortedSegmentProdBackwardArgs,
450 UnsortedSegmentProdBackwardDescriptor, UnsortedSegmentProdBackwardPlan,
451 UnsortedSegmentProdDescriptor, UnsortedSegmentProdPlan, UnsortedSegmentSumArgs,
452 UnsortedSegmentSumBackwardArgs, UnsortedSegmentSumBackwardDescriptor,
453 UnsortedSegmentSumBackwardPlan, UnsortedSegmentSumDescriptor, UnsortedSegmentSumPlan,
454};
455
456pub mod quantize;
463
464pub use quantize::{
465 DequantizePerGroupArgs, DequantizePerGroupBackwardArgs,
466 DequantizePerGroupBackwardDescriptor, DequantizePerGroupBackwardPlan,
467 DequantizePerGroupDescriptor, DequantizePerGroupPlan, DequantizePerTokenArgs,
468 DequantizePerTokenBackwardArgs, DequantizePerTokenBackwardDescriptor,
469 DequantizePerTokenBackwardPlan, DequantizePerTokenDescriptor, DequantizePerTokenPlan,
470 QuantizePerGroupArgs, QuantizePerGroupBackwardArgs, QuantizePerGroupBackwardDescriptor,
471 QuantizePerGroupBackwardPlan, QuantizePerGroupDescriptor, QuantizePerGroupPlan,
472 QuantizePerTokenArgs, QuantizePerTokenBackwardArgs, QuantizePerTokenBackwardDescriptor,
473 QuantizePerTokenBackwardPlan, QuantizePerTokenDescriptor, QuantizePerTokenPlan,
474};
475
476pub use quantize::{
478 DequantizePerChannelArgs, DequantizePerChannelBackwardArgs,
479 DequantizePerChannelBackwardDescriptor, DequantizePerChannelBackwardPlan,
480 DequantizePerChannelDescriptor, DequantizePerChannelPlan, DequantizePerTensorArgs,
481 DequantizePerTensorBackwardArgs, DequantizePerTensorBackwardDescriptor,
482 DequantizePerTensorBackwardPlan, DequantizePerTensorDescriptor, DequantizePerTensorPlan,
483 FakeQuantizeArgs, FakeQuantizeBackwardArgs, FakeQuantizeBackwardDescriptor,
484 FakeQuantizeBackwardPlan, FakeQuantizeDescriptor, FakeQuantizePlan, QuantizePerChannelArgs,
485 QuantizePerChannelBackwardArgs, QuantizePerChannelBackwardDescriptor,
486 QuantizePerChannelBackwardPlan, QuantizePerChannelDescriptor, QuantizePerChannelPlan,
487 QuantizePerTensorArgs, QuantizePerTensorBackwardArgs, QuantizePerTensorBackwardDescriptor,
488 QuantizePerTensorBackwardPlan, QuantizePerTensorDescriptor, QuantizePerTensorPlan,
489};
490
491pub use quantize::{
494 DynamicRangeMode, DynamicRangeQuantizeArgs, DynamicRangeQuantizeDescriptor,
495 DynamicRangeQuantizePlan, DynamicRangeScope, QuantizedLinearArgs,
496 QuantizedLinearDescriptor, QuantizedLinearPlan,
497};
498
499pub use quantize::{
502 SmoothQuantLinearArgs, SmoothQuantLinearDescriptor, SmoothQuantLinearPlan,
503};
504
505pub use quantize::{
508 BlockQ2K, BlockQ3K, BlockQ4_0, BlockQ4_1, BlockQ4K, BlockQ5_0, BlockQ5_1, BlockQ5K, BlockQ6K,
509 BlockQ8_0, BlockQ8K, GgufDequantizeArgs, GgufDequantizeDescriptor, GgufDequantizePlan,
510 GgufMmvqArgs, GgufMmvqDescriptor, GgufMmvqPlan,
511};
512
513pub use quantize::{
515 GgufMmvqBatchedActivation, GgufMmvqBatchedArgs, GgufMmvqBatchedDescriptor,
516 GgufMmvqBatchedFormat, GgufMmvqBatchedPlan,
517};
518
519pub use quantize::{GgufMmvqMultiMArgs, GgufMmvqMultiMDescriptor, GgufMmvqMultiMPlan};
521
522pub use quantize::{
528 Nf4Activation, Nf4DequantizeArgs, Nf4DequantizePlan, Nf4Descriptor, Nf4MmvqArgs,
529 Nf4MmvqMultiMArgs, Nf4MmvqMultiMDescriptor, Nf4MmvqMultiMPlan, Nf4MmvqPlan, NF4_CODEBOOK,
530};
531
532pub mod moe;
535pub use moe::{MoeArgs, MoeDescriptor, MoePlan, MoeVariant};
536
537pub mod image;
543
544pub use image::{
545 AffineGridArgs, AffineGridDescriptor, AffineGridPlan, GridSampleArgs,
546 GridSampleBackwardArgs, GridSampleBackwardDescriptor, GridSampleBackwardPlan,
547 GridSampleDescriptor, GridSamplePlan, InterpolateArgs, InterpolateBackwardArgs,
548 InterpolateBackwardDescriptor, InterpolateBackwardPlan, InterpolateDescriptor,
549 InterpolateMode, InterpolatePlan, NmsArgs, NmsDescriptor, NmsPlan, PixelShuffleArgs,
550 PixelShuffleDescriptor, PixelShufflePlan, PixelUnshuffleArgs, PixelUnshuffleDescriptor,
551 PixelUnshufflePlan, RoiAlignArgs, RoiAlignBackwardArgs, RoiAlignBackwardDescriptor,
552 RoiAlignBackwardPlan, RoiAlignDescriptor, RoiAlignPlan, RoiPoolArgs, RoiPoolBackwardArgs,
553 RoiPoolBackwardDescriptor, RoiPoolBackwardPlan, RoiPoolDescriptor, RoiPoolPlan,
554};
555
556pub mod sort;
565
566pub use sort::{
567 ArgsortArgs, ArgsortDescriptor, ArgsortPlan, BincountArgs, BincountDescriptor, BincountPlan,
568 HistogramArgs, HistogramDescriptor, HistogramPlan, HistogramddArgs, HistogramddDescriptor,
569 HistogramddPlan, KthvalueArgs, KthvalueBackwardArgs, KthvalueBackwardDescriptor,
570 KthvalueBackwardPlan, KthvalueDescriptor, KthvaluePlan, MsortArgs, MsortBackwardArgs,
571 MsortBackwardDescriptor, MsortBackwardPlan, MsortDescriptor, MsortPlan, SearchsortedArgs,
572 SearchsortedDescriptor, SearchsortedPlan, SortArgs, SortBackwardArgs, SortBackwardDescriptor,
573 SortBackwardPlan, SortDescriptor, SortPlan, TopkArgs, TopkBackwardArgs,
574 TopkBackwardDescriptor, TopkBackwardPlan, TopkDescriptor, TopkPlan, UniqueArgs,
575 UniqueConsecutiveArgs, UniqueConsecutiveDescriptor, UniqueConsecutivePlan, UniqueDescriptor,
576 UniquePlan, SORT_MAX_ROW, TOPK_MAX_K,
577};
578
579#[cfg(feature = "mamba")]
583pub mod causal_conv1d;
584
585#[cfg(feature = "mamba")]
586pub use causal_conv1d::{
587 CausalConv1dArgs, CausalConv1dBackwardArgs, CausalConv1dBackwardDescriptor,
588 CausalConv1dBackwardPlan, CausalConv1dDescriptor, CausalConv1dPlan,
589};
590
591#[cfg(feature = "mamba")]
594pub use attention::{
595 SsdChunkScanArgs, SsdChunkScanBackwardArgs, SsdChunkScanBackwardDescriptor,
596 SsdChunkScanBackwardPlan, SsdChunkScanDescriptor, SsdChunkScanPlan,
597};
598
599#[cfg(feature = "mamba")]
602pub use attention::{
603 SelectiveScanArgs, SelectiveScanBackwardArgs, SelectiveScanBackwardDescriptor,
604 SelectiveScanBackwardPlan, SelectiveScanDescriptor, SelectiveScanPlan,
605};
606
607pub use attention::{
612 RingAttentionArgs, RingAttentionDescriptor, RingAttentionPlan, RING_ATTENTION_HEAD_DIM,
613};
614
/// Optimizer-step plans re-exported from the `baracuda_optim` crate.
///
/// Covers the Adam, LAMB and SGD step plans together with their configuration
/// and parameter-dtype types, plus `MultiTensorApplyContext` and `TensorList`.
/// Only compiled when the `optim` Cargo feature is enabled.
///
/// The upstream `Error` and `Result` are re-exported as `OptimError` and
/// `OptimResult` so they can be imported alongside this crate's root-level
/// `Error` / `Result` (re-exported from `baracuda_cutlass`) without a name
/// clash.
#[cfg(feature = "optim")]
pub mod optim {
    pub use baracuda_optim::{
        AdamConfig, AdamMode, AdamParamDtype, AdamStepPlan, Error as OptimError, LambConfig,
        LambStepPlan, MultiTensorApplyContext, Result as OptimResult, SgdConfig, SgdParamDtype,
        SgdStepPlan, TensorList,
    };
}
630
/// FP8 cast and dequantization plans re-exported from the
/// `baracuda_transformer_engine` crate.
///
/// Exposes `Fp8CastPlan`, `Fp8DequantPlan` and the supporting `Fp8Format`,
/// `Fp8Recipe` and `Fp8WideDtype` types. Only compiled when the
/// `tensor_engine` Cargo feature is enabled.
///
/// The upstream `Error` and `Result` are re-exported as
/// `TransformerEngineError` and `TransformerEngineResult` so they can be
/// imported alongside this crate's root-level `Error` / `Result` without a
/// name clash.
// NOTE(review): the gating feature is spelled `tensor_engine`, while the module
// and the upstream crate are named `transformer_engine`. Confirm this matches
// the feature actually declared in Cargo.toml.
#[cfg(feature = "tensor_engine")]
pub mod transformer_engine {
    pub use baracuda_transformer_engine::{
        Error as TransformerEngineError, Fp8CastPlan, Fp8DequantPlan, Fp8Format, Fp8Recipe,
        Fp8WideDtype, Result as TransformerEngineResult,
    };
}
650
/// Megatron-style tensor-parallel linear plans re-exported from the
/// `baracuda_megatron` crate.
///
/// Exposes `ColumnParallelLinearPlan`, `RowParallelLinearPlan`, the shared
/// `TensorParallelContext` and the `MegatronGemmScalar` element trait/type.
/// Only compiled when the `megatron_tp` Cargo feature is enabled.
///
/// The upstream `Error` and `Result` are re-exported as `MegatronError` and
/// `MegatronResult` so they can be imported alongside this crate's root-level
/// `Error` / `Result` without a name clash.
#[cfg(feature = "megatron_tp")]
pub mod megatron {
    pub use baracuda_megatron::{
        ColumnParallelLinearPlan, Error as MegatronError, MegatronGemmScalar,
        Result as MegatronResult, RowParallelLinearPlan, TensorParallelContext,
    };
}
666
667pub use attention::{
675 BatchPagedDecodeArgs, BatchPagedDecodeDescriptor, BatchPagedDecodePlan,
676 BatchPagedDecodeFp8Args, BatchPagedDecodeFp8Descriptor, BatchPagedDecodeFp8Plan,
677 BatchPagedPrefillArgs, BatchPagedPrefillDescriptor, BatchPagedPrefillPlan,
678 BatchRaggedPrefillArgs, BatchRaggedPrefillDescriptor, BatchRaggedPrefillPlan,
679 CascadeAttentionArgs, CascadeAttentionDescriptor, CascadeAttentionPlan,
680 CascadeMergeStatesArgs, CascadeMergeStatesDescriptor, CascadeMergeStatesPlan, Fp8KvDtype,
681 PagedKvAppendArgs, PagedKvAppendDescriptor, PagedKvAppendPlan, PagedKvCacheDescriptor,
682};
683pub use random::{
684 PerRowSampler, PerRowSamplingArgs, PerRowSamplingDescriptor, PerRowSamplingPlan, SamplerKind,
685 SpeculativeSamplingArgs, SpeculativeSamplingDescriptor, SpeculativeSamplingPlan,
686 TokenPenaltyArgs, TokenPenaltyDescriptor, TokenPenaltyPlan, TopKTopPSamplingArgs,
687 TopKTopPSamplingDescriptor, TopKTopPSamplingPlan,
688};