Crate baracuda_cutlass

baracuda-cutlass

Safe Rust wrapper for compiled CUTLASS kernels in the baracuda ecosystem. It provides a plan-based GEMM and grouped-GEMM API with caller-supplied workspace, typed device-buffer arguments, and launches that are safe to record under CUDA stream capture.

See the crate's README.md for the v0 scope and design rationale, and the baracuda-cutlass-kernels-sys crate for the compiled CUTLASS template instantiations that this crate wraps.
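The plan-based flow (select a kernel once, validate the arguments, then launch with workspace the caller owns) can be illustrated with a self-contained toy. Everything in this sketch is hypothetical: `Descriptor`, `Plan`, `PlanError`, the tile candidates, and the workspace rule are invented for illustration and are not the baracuda_cutlass API or its selection heuristics.

```rust
// Hypothetical toy types illustrating the select -> can_implement -> run flow.
// None of these names are the baracuda_cutlass API.

#[derive(Debug, Clone, Copy, PartialEq)]
struct Descriptor {
    m: usize,
    n: usize,
    k: usize,
}

#[derive(Debug, PartialEq)]
enum PlanError {
    NoKernel,
    ShapeMismatch,
    WorkspaceTooSmall { needed: usize, got: usize },
}

#[derive(Debug)]
struct Plan {
    desc: Descriptor,
    tile: (usize, usize),
    workspace_bytes: usize,
}

impl Plan {
    /// Selection runs once: pick the largest candidate tile that divides m and n.
    fn select(desc: Descriptor) -> Result<Plan, PlanError> {
        const CANDIDATES: [(usize, usize); 3] = [(128, 128), (64, 64), (32, 32)];
        let tile = CANDIDATES
            .iter()
            .copied()
            .find(|&(tm, tn)| desc.m % tm == 0 && desc.n % tn == 0)
            .ok_or(PlanError::NoKernel)?;
        // Invented rule: very deep K needs one f32 partial per output element.
        let workspace_bytes = if desc.k > 4096 { desc.m * desc.n * 4 } else { 0 };
        Ok(Plan { desc, tile, workspace_bytes })
    }

    /// Cheap per-launch check that the arguments match the plan's descriptor.
    fn can_implement(&self, args: Descriptor) -> Result<(), PlanError> {
        if args == self.desc {
            Ok(())
        } else {
            Err(PlanError::ShapeMismatch)
        }
    }

    /// The caller owns the scratch memory; the plan only checks its size.
    fn run(&self, workspace: &mut [u8]) -> Result<(), PlanError> {
        if workspace.len() < self.workspace_bytes {
            return Err(PlanError::WorkspaceTooSmall {
                needed: self.workspace_bytes,
                got: workspace.len(),
            });
        }
        // A real plan would enqueue the selected kernel on a stream here.
        Ok(())
    }
}

fn main() {
    let desc = Descriptor { m: 128, n: 128, k: 128 };
    let plan = Plan::select(desc).expect("a tile divides 128");
    plan.can_implement(desc).expect("same shape");
    let mut workspace = vec![0u8; plan.workspace_bytes];
    plan.run(&mut workspace).expect("workspace is large enough");
    println!("tile = {:?}, workspace = {} bytes", plan.tile, plan.workspace_bytes);
}
```

Keeping selection separate from launch means the expensive decision is made once and each launch only pays for a cheap argument check.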

Quick start

use baracuda_cutlass::{
    EpilogueKind, GemmArgs, GemmDescriptor, GemmPlan, LayoutSku,
    MatrixMut, MatrixRef, PlanPreference, Workspace,
};
use baracuda_driver::{Context, Device, DeviceBuffer, Stream};
use half::f16;

let ctx = Context::new(&Device::get(0)?)?;
let stream = Stream::new(&ctx)?;

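// Problem size: D (m x n) = A (m x k) * B (k x n).
let m = 128i32;
let n = 128i32;
let k = 128i32;
let dev_a: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, (m * k) as usize)?;
let dev_b: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, (k * n) as usize)?;
let mut dev_d: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, (m * n) as usize)?;

// Rcr: A row-major, B column-major, C and D row-major.
let desc = GemmDescriptor {
    m, n, k,
    layout: LayoutSku::Rcr,
    epilogue: EpilogueKind::Identity,
};
// Kernel selection happens here, once, when the plan is built.
let plan = GemmPlan::<f16>::select(&stream, &desc, PlanPreference::default())?;
let args = GemmArgs::<f16> {
    // Row-major A (m x k): leading dimension = k.
    a: MatrixRef { data: dev_a.as_slice(), rows: m, cols: k, ld: k as i64 },
    // Column-major B (k x n): leading dimension = k.
    b: MatrixRef { data: dev_b.as_slice(), rows: k, cols: n, ld: k as i64 },
    c: None, // no C operand
    // Row-major D (m x n): leading dimension = n.
    d: MatrixMut { data: dev_d.as_slice_mut(), rows: m, cols: n, ld: n as i64 },
    bias: None,
    alpha: 1.0,
    beta: 0.0,
};
// Confirm the selected kernel supports these arguments, then launch without extra workspace.
plan.can_implement(&args)?;
plan.run(&stream, Workspace::None, args)?;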
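To sanity-check a result after copying D back to the host, a plain CPU reference is handy. The sketch below assumes the usual CUTLASS linear-combination semantics, D = alpha * A * B + beta * C, with a missing C treated as zero, and the Rcr strides used in the example above. It uses f32 rather than f16 so it needs no external crates; `reference_gemm_rcr` is a hypothetical helper, not part of this crate.

```rust
/// Hypothetical CPU reference for an Rcr GEMM: D = alpha * A * B + beta * C.
/// A is m x k row-major (ld = lda), B is k x n column-major (ld = ldb),
/// C and D are m x n row-major; C is assumed to share D's leading dimension.
#[allow(clippy::too_many_arguments)]
fn reference_gemm_rcr(
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    a: &[f32],
    lda: usize,
    b: &[f32],
    ldb: usize,
    beta: f32,
    c: Option<&[f32]>,
    d: &mut [f32],
    ldd: usize,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                // A(i, p) is stored row-major; B(p, j) is stored column-major.
                acc += a[i * lda + p] * b[j * ldb + p];
            }
            // A missing C contributes zero.
            let c_ij = c.map_or(0.0, |c| c[i * ldd + j]);
            d[i * ldd + j] = alpha * acc + beta * c_ij;
        }
    }
}

fn main() {
    // A = [[1, 2, 3], [4, 5, 6]] row-major; B = [[7, 8], [9, 10], [11, 12]] stored column-major.
    let a: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b: [f32; 6] = [7.0, 9.0, 11.0, 8.0, 10.0, 12.0];
    let mut d = [0.0f32; 4];
    reference_gemm_rcr(2, 2, 3, 1.0, &a, 3, &b, 3, 0.0, None, &mut d, 2);
    println!("{:?}", d); // [58.0, 64.0, 139.0, 154.0]
}
```

When comparing against a GPU result computed in f16, use a tolerance rather than exact equality, since the device accumulation order and precision differ from this loop.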

Re-exports

pub use error::Error;
pub use error::Result;
pub use plan::BatchedGemmPlan;
pub use plan::GemmPlan;
pub use plan::GroupedGemmPlan;
pub use plan::IntGemmPlan;
pub use plan::PreparedGroupedGemm;
pub use types::BatchedGemmArgs;
pub use types::BatchedGemmDescriptor;
pub use types::GemmArgs;
pub use types::GemmDescriptor;
pub use types::GemmSku;
pub use types::GroupedPlanPreference;
pub use types::GroupedProblem;
pub use types::GroupedScheduleMode;
pub use types::IntGemmArgs;
pub use types::IntGemmDescriptor;

Modules

error
Error types for baracuda-cutlass.
plan
Plan-based GEMM and grouped-GEMM API.
types
Value types for the CUTLASS plan-based API.

Structs

F32Strict
Strict-precision f32 element marker.
MatrixMut
Mutable view of a device-resident matrix (used for the output D).
MatrixRef
Read-only view of a device-resident matrix.
PlanPreference
Hints that influence kernel selection inside a plan’s select method.
PrecisionGuarantee
Numerical guarantees a kernel provides.
S8
Signed 8-bit integer element marker; a #[repr(transparent)] wrapper around i8.
U8
Unsigned 8-bit integer element marker; a #[repr(transparent)] wrapper around u8.
VectorRef
Read-only view of a device-resident vector.
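MatrixRef and MatrixMut pair a device slice with rows, cols, and a leading dimension `ld`. A check along the following lines decides whether a buffer is long enough to back a strided view. It is a hypothetical host-side helper written for illustration (`Orientation`, `required_len`, and `check_view` are invented names), not the crate's own validation, which this page does not describe.

```rust
/// Orientation of a strided matrix view (hypothetical; mirrors the R/C letters of a layout SKU).
#[derive(Clone, Copy, Debug)]
enum Orientation {
    RowMajor,
    ColMajor,
}

/// Smallest buffer length (in elements) that can back a rows x cols view with
/// leading dimension `ld`, or None if `ld` is too small for the orientation.
fn required_len(rows: usize, cols: usize, ld: usize, o: Orientation) -> Option<usize> {
    // Row-major: a row is contiguous and rows start `ld` elements apart.
    // Column-major: the roles of rows and columns swap.
    let (outer, inner) = match o {
        Orientation::RowMajor => (rows, cols),
        Orientation::ColMajor => (cols, rows),
    };
    if ld < inner {
        return None;
    }
    if outer == 0 || inner == 0 {
        return Some(0);
    }
    // The last element sits at offset (outer - 1) * ld + (inner - 1).
    Some((outer - 1) * ld + inner)
}

fn check_view(len: usize, rows: usize, cols: usize, ld: usize, o: Orientation) -> Result<(), String> {
    match required_len(rows, cols, ld, o) {
        None => Err(format!("ld = {ld} is smaller than the contiguous extent for {o:?}")),
        Some(need) if need > len => Err(format!("buffer holds {len} elements, view needs {need}")),
        Some(_) => Ok(()),
    }
}

fn main() {
    // The quick-start B operand: k x n, column-major, ld = k, backed by k * n elements.
    let (k, n) = (128usize, 128usize);
    match check_view(k * n, k, n, k, Orientation::ColMajor) {
        Ok(()) => println!("B view fits"),
        Err(e) => println!("invalid B view: {e}"),
    }
}
```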

Enums

ActivationKind
Activation functions implemented by the Bias*Activation variants of EpilogueKind. Exposed for telemetry and selector logic; kernel selection itself is driven by the EpilogueKind variant.
ArchSku
Compute capability bucket the selected kernel was compiled for.
BackendKind
Which underlying compute backend served a kernel SKU.
BiasElementKind
Runtime tag for a BiasElement.
ElementKind
Runtime tag for an Element or IntElement.
EpilogueKind
Epilogue applied after the matrix-multiply accumulation.
LayoutSku
Layout SKU. Describes the row/column orientation of A, B, C, and D for matrix-multiply-shaped kernels.
MathPrecision
Math precision used by the FMA / tensor-core instruction.
Workspace
Caller-supplied workspace for a launch.
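The Bias*Activation epilogue variants fold a bias add and an activation into the output write. A host-side sketch of that idea follows. The three-variant `Epilogue` enum, the choice of ReLU, and the per-column bias broadcast are assumptions made for illustration; this page does not list the actual EpilogueKind variants or the bias orientation.

```rust
/// Toy epilogue variants (hypothetical; not the crate's EpilogueKind).
#[derive(Clone, Copy, Debug)]
enum Epilogue {
    Identity,
    Bias,
    BiasRelu,
}

/// Applies the epilogue in place to a row-major m x n output `d` with ld = n.
/// Assumption: `bias` holds one value per output column, broadcast down the rows.
fn apply_epilogue(
    d: &mut [f32],
    m: usize,
    n: usize,
    bias: Option<&[f32]>,
    e: Epilogue,
) -> Result<(), &'static str> {
    let needs_bias = !matches!(e, Epilogue::Identity);
    if needs_bias && bias.map_or(true, |b| b.len() < n) {
        return Err("this epilogue needs a bias vector with at least n entries");
    }
    for i in 0..m {
        for j in 0..n {
            let idx = i * n + j;
            let v = d[idx];
            d[idx] = match e {
                Epilogue::Identity => v,
                // unwrap cannot fail: presence and length were checked above.
                Epilogue::Bias => v + bias.unwrap()[j],
                Epilogue::BiasRelu => (v + bias.unwrap()[j]).max(0.0),
            };
        }
    }
    Ok(())
}

fn main() {
    let mut d = [1.0f32, -3.0, 2.0, -0.5]; // 2 x 2, row-major
    let bias = [0.5f32, 1.0];
    apply_epilogue(&mut d, 2, 2, Some(&bias[..]), Epilogue::BiasRelu).unwrap();
    println!("{:?}", d); // [1.5, 0.0, 2.5, 0.5]
}
```

On a GPU the same arithmetic runs inside the kernel before D is stored, which saves a separate pass over the output.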

Traits

BiasElement
Bias element types accepted by the int-GEMM bias epilogue family.
CutlassElement
Back-compat alias for Element.
IntElement
Integer element types supported by the int-GEMM kernel set.
ScalarType
Sealed marker for the alpha/beta scalar type an Element uses.
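ScalarType is described as a sealed marker for the alpha/beta scalar an Element uses. In Rust this is usually done with a private supertrait, so downstream crates can name the trait but cannot implement it. The sketch below shows that idiom together with an associated `Scalar` type. `Element`, `Half16`, and `alpha_for` here are hypothetical stand-ins, and mapping 16-bit storage to an f32 scalar is an assumption, not something this page states.

```rust
mod private {
    // A public trait inside a private module: usable here, unreachable from other crates.
    pub trait Sealed {}
    impl Sealed for f32 {}
    impl Sealed for f64 {}
}

/// Scalar type used for alpha/beta. Sealed: the impls below are the only ones possible.
pub trait ScalarType: private::Sealed + Copy + std::fmt::Debug {
    fn from_f64(x: f64) -> Self;
}

impl ScalarType for f32 {
    fn from_f64(x: f64) -> Self {
        x as f32
    }
}

impl ScalarType for f64 {
    fn from_f64(x: f64) -> Self {
        x
    }
}

/// Hypothetical storage-element trait: each element names its alpha/beta scalar.
pub trait Element: Copy {
    type Scalar: ScalarType;
}

/// Hypothetical 16-bit storage stand-in (a real crate would use `half::f16`).
#[derive(Clone, Copy)]
#[repr(transparent)]
pub struct Half16(pub u16);

impl Element for Half16 {
    type Scalar = f32; // assumption: 16-bit storage takes f32 alpha/beta
}

impl Element for f32 {
    type Scalar = f32;
}

impl Element for f64 {
    type Scalar = f64;
}

/// Converts a user-facing alpha into the scalar type the element expects.
fn alpha_for<E: Element>(alpha: f64) -> E::Scalar {
    <E::Scalar as ScalarType>::from_f64(alpha)
}

fn main() {
    let a16 = alpha_for::<Half16>(0.5);
    let a64 = alpha_for::<f64>(0.5);
    println!("{a16:?} {a64:?}"); // 0.5 0.5
}
```

Sealing lets the crate add methods to the trait later without breaking downstream code, because no outside implementations can exist.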