§baracuda-cutlass
Safe Rust wrapper for compiled CUTLASS kernels in the baracuda ecosystem. Plan-based GEMM and grouped-GEMM API with caller-supplied workspace, typed device-buffer arguments, and capture-safe launches.
See the crate README.md for the v0 scope and the design rationale.
See baracuda-cutlass-kernels-sys for the underlying compiled template instantiations.
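The plan-based flow described above can be modeled on the CPU: select a kernel once, validate the arguments, then launch using only memory the caller owns. The sketch below is a hypothetical model, not baracuda-cutlass code. `ModelPlan`, `ModelError`, the workspace-sizing rule, and the dimension limit are all invented for illustration. It shows one plausible reason for caller-supplied workspace: a launch path that never allocates is easier to keep compatible with CUDA stream capture.

```rust
// Hypothetical CPU-side model of the select -> can_implement -> run lifecycle.
// None of these names are baracuda-cutlass APIs.

#[derive(Debug, PartialEq)]
pub enum ModelError {
    /// The problem shape is outside what the selected kernel supports.
    Unsupported,
    /// The caller passed less scratch memory than the plan asked for.
    WorkspaceTooSmall { needed: usize, got: usize },
}

pub struct ModelPlan {
    /// Scratch bytes decided once, at selection time.
    pub workspace_bytes: usize,
    /// Hypothetical per-dimension limit of the selected kernel.
    pub max_dim: i32,
}

impl ModelPlan {
    /// Choose a "kernel" for an m x n x k problem. The sizing rule is invented:
    /// deep problems (k > 1024) ask for one f32 partial sum per output element.
    pub fn select(m: i32, n: i32, k: i32) -> ModelPlan {
        let workspace_bytes = if k > 1024 {
            m as usize * n as usize * std::mem::size_of::<f32>()
        } else {
            0
        };
        ModelPlan { workspace_bytes, max_dim: 65_536 }
    }

    /// Reject shapes the kernel cannot handle before anything is launched.
    pub fn can_implement(&self, m: i32, n: i32, k: i32) -> Result<(), ModelError> {
        if [m, n, k].iter().all(|&d| d > 0 && d <= self.max_dim) {
            Ok(())
        } else {
            Err(ModelError::Unsupported)
        }
    }

    /// "Launch" using only caller-owned memory; nothing is allocated here.
    pub fn run(&self, workspace: Option<&mut [u8]>) -> Result<(), ModelError> {
        let got = workspace.map_or(0, |w| w.len());
        if got < self.workspace_bytes {
            return Err(ModelError::WorkspaceTooSmall { needed: self.workspace_bytes, got });
        }
        // A real plan would enqueue its kernel on a stream at this point.
        Ok(())
    }
}
```

Because the workspace size is fixed when the plan is selected, the caller can allocate scratch memory once, outside any capture region, and reuse it for every launch.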
§Quick start
```rust
use baracuda_cutlass::{
    EpilogueKind, GemmArgs, GemmDescriptor, GemmPlan, LayoutSku,
    MatrixMut, MatrixRef, PlanPreference, Workspace,
};
use baracuda_driver::{Context, Device, DeviceBuffer, Stream};
use half::f16;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let ctx = Context::new(&Device::get(0)?)?;
    let stream = Stream::new(&ctx)?;

    // Problem size: D (m x n) = A (m x k) * B (k x n).
    let (m, n, k) = (128i32, 128i32, 128i32);

    let dev_a: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, (m * k) as usize)?;
    let dev_b: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, (k * n) as usize)?;
    let mut dev_d: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, (m * n) as usize)?;

    let desc = GemmDescriptor {
        m, n, k,
        layout: LayoutSku::Rcr,
        epilogue: EpilogueKind::Identity,
    };

    // Select a kernel for this descriptor.
    let plan = GemmPlan::<f16>::select(&stream, &desc, PlanPreference::default())?;

    // Leading dimensions: A row-major (ld = k), B column-major (ld = k), D row-major (ld = n).
    let args = GemmArgs::<f16> {
        a: MatrixRef { data: dev_a.as_slice(), rows: m, cols: k, ld: k as i64 },
        b: MatrixRef { data: dev_b.as_slice(), rows: k, cols: n, ld: k as i64 },
        c: None,
        d: MatrixMut { data: dev_d.as_slice_mut(), rows: m, cols: n, ld: n as i64 },
        bias: None,
        alpha: 1.0,
        beta: 0.0,
    };

    // Validate the arguments against the selected kernel before launching.
    plan.can_implement(&args)?;
    // Launch without a caller-supplied workspace.
    plan.run(&stream, Workspace::None, args)?;
    Ok(())
}
```
Re-exports§
pub use error::Error;
pub use error::Result;
pub use plan::BatchedGemmPlan;
pub use plan::GemmPlan;
pub use plan::GroupedGemmPlan;
pub use plan::IntGemmPlan;
pub use plan::PreparedGroupedGemm;
pub use types::BatchedGemmArgs;
pub use types::BatchedGemmDescriptor;
pub use types::GemmArgs;
pub use types::GemmDescriptor;
pub use types::GemmSku;
pub use types::GroupedPlanPreference;
pub use types::GroupedProblem;
pub use types::GroupedScheduleMode;
pub use types::IntGemmArgs;
pub use types::IntGemmDescriptor;
Modules§
- error - Error types for baracuda-cutlass.
- plan - Plan-based GEMM and grouped-GEMM API.
- types - Value types for the CUTLASS plan-based API.
Structs§
- F32Strict - Strict-precision f32 element marker.
- MatrixMut - Mutable view of a device-resident matrix (used for the output D).
- MatrixRef - Read-only view of a device-resident matrix.
- PlanPreference - Hints that influence kernel selection inside a plan’s select method.
- PrecisionGuarantee - Numerical guarantees a kernel provides.
- S8 - Signed 8-bit integer element marker. #[repr(transparent)] around i8.
- U8 - Unsigned 8-bit integer element marker. #[repr(transparent)] around u8.
- VectorRef - Read-only view of a device-resident vector.
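MatrixRef and MatrixMut pair a buffer with rows, cols, and a leading dimension ld, as the quick start shows. The sketch below works through the extent arithmetic such a view implies: how many elements the backing buffer must hold for a given storage order. `Order` and `required_len` are hypothetical helpers, and whether the crate's `can_implement` performs exactly this check is an assumption.

```rust
/// Storage order of a matrix view (hypothetical, not a crate type).
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Order {
    RowMajor,
    ColMajor,
}

/// Minimum element count a buffer must hold for a `rows x cols` view with
/// leading dimension `ld`, or `None` if the view is malformed.
pub fn required_len(rows: i32, cols: i32, ld: i64, order: Order) -> Option<usize> {
    if rows < 0 || cols < 0 {
        return None;
    }
    // `outer` counts strided lines; `inner` is the contiguous length of each line.
    let (outer, inner) = match order {
        Order::RowMajor => (rows as i64, cols as i64), // ld steps from row to row
        Order::ColMajor => (cols as i64, rows as i64), // ld steps from column to column
    };
    // Consecutive lines must not overlap, and ld must be at least 1.
    if ld < inner.max(1) {
        return None;
    }
    if outer == 0 || inner == 0 {
        return Some(0);
    }
    // The last line starts at (outer - 1) * ld and spans `inner` elements.
    Some(((outer - 1) * ld + inner) as usize)
}
```

With the quick start's square 128 x 128 operands, the packed case gives exactly 128 * 128 elements, matching the buffer sizes passed to `DeviceBuffer::zeros`.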
Enums§
- ActivationKind - Activation functions implemented by the Bias*Activation EpilogueKind variants. Surfaced for telemetry and selector logic; the kernel selection itself is driven by the enum variant.
- ArchSku - Compute capability bucket the selected kernel was compiled for.
- BackendKind - Which underlying compute backend served a kernel SKU.
- BiasElementKind - Runtime tag for a BiasElement.
- ElementKind - Runtime tag for an Element or IntElement.
- EpilogueKind - Epilogue applied after the matrix-multiply accumulation.
- LayoutSku - Layout SKU. Describes the row/column orientation of A, B, C, and D for matrix-multiply-shaped kernels.
- MathPrecision - Math precision used by the FMA / tensor-core instruction.
- Workspace - Caller-supplied workspace for a launch.
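A host-side reference is useful for checking GEMM output. The sketch below assumes two things this page does not state: that LayoutSku::Rcr means A row-major, B column-major, and C/D row-major (consistent with the leading dimensions in the quick start), and that EpilogueKind::Identity computes D = alpha · (A × B) + beta · C, the usual CUTLASS linear-combination form. `reference_gemm_rcr` is a hypothetical helper that works in f32 throughout; comparing it against f16 device output would need a tolerance.

```rust
/// Host-side reference for D = alpha * (A x B) + beta * C under the assumed
/// Rcr reading: A row-major (m x k), B column-major (k x n), C/D row-major (m x n).
#[allow(clippy::too_many_arguments)]
pub fn reference_gemm_rcr(
    m: usize, n: usize, k: usize,
    alpha: f32, a: &[f32], lda: usize,
    b: &[f32], ldb: usize,
    beta: f32, c: Option<&[f32]>, ldc: usize,
    d: &mut [f32], ldd: usize,
) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                // A[i][p] in row-major storage times B[p][j] in column-major storage.
                acc += a[i * lda + p] * b[j * ldb + p];
            }
            // C is optional, mirroring `c: None` in the quick start.
            let c_ij = c.map_or(0.0, |c| c[i * ldc + j]);
            d[i * ldd + j] = alpha * acc + beta * c_ij;
        }
    }
}
```

For example, A = [[1, 2], [3, 4]] stored row-major as [1, 2, 3, 4] and B = [[5, 6], [7, 8]] stored column-major as [5, 7, 6, 8] produce D = [[19, 22], [43, 50]] with alpha = 1 and beta = 0.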
Traits§
- BiasElement - Bias element types accepted by the int-GEMM bias epilogue family.
- CutlassElement - Back-compat alias for Element.
- IntElement - Integer element types supported by the int-GEMM kernel set.
- ScalarType - Sealed marker for the alpha/beta scalar type an Element uses.
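The items above rely on two common Rust patterns: a sealed trait (ScalarType cannot be implemented outside its defining crate) and #[repr(transparent)] newtypes (S8, U8) that share the layout of their inner integer. The sketch below reproduces both with hypothetical names (`ScalarMarker`, `Elem`, `unit_alpha`, `as_i8`). The element-to-scalar pairings shown, f32 to f32 and S8 to i32, are illustrative assumptions, not the crate's actual associations.

```rust
mod sealed {
    /// Not nameable outside this crate, so downstream crates cannot
    /// implement any trait that lists it as a supertrait.
    pub trait Sealed {}
}

/// Sealed marker for an alpha/beta scalar type (hypothetical sketch).
pub trait ScalarMarker: sealed::Sealed + Copy {}
impl sealed::Sealed for f32 {}
impl ScalarMarker for f32 {}
impl sealed::Sealed for i32 {}
impl ScalarMarker for i32 {}

/// Element marker that names the scalar type its alpha/beta use (hypothetical).
pub trait Elem: Copy {
    type Scalar: ScalarMarker;
}

/// Signed 8-bit marker with exactly the layout of `i8`.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct S8(pub i8);

impl Elem for f32 {
    type Scalar = f32;
}
impl Elem for S8 {
    type Scalar = i32; // illustrative pairing, an assumption
}

/// Generic code derives the alpha/beta type from the element type.
pub fn unit_alpha<E: Elem>() -> E::Scalar
where
    E::Scalar: From<u8>,
{
    1u8.into()
}

/// View typed elements as raw bytes without copying.
pub fn as_i8(xs: &[S8]) -> &[i8] {
    // SAFETY: S8 is #[repr(transparent)] over i8, so a slice of S8 has the
    // same layout as a slice of i8 with the same length.
    unsafe { std::slice::from_raw_parts(xs.as_ptr().cast::<i8>(), xs.len()) }
}
```

Sealing the scalar trait lets the crate add new element/scalar pairings later without breaking downstream code, since no external implementations can exist.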