Skip to main content

baracuda_cutlass/
lib.rs

1//! # baracuda-cutlass
2//!
3//! Safe Rust wrapper for compiled CUTLASS kernels in the baracuda
4//! ecosystem. Plan-based GEMM and grouped-GEMM API with caller-supplied
5//! workspace, typed device-buffer arguments, and capture-safe launches.
6//!
7//! See the crate `README.md` for the v0 scope and the design rationale.
8//! See [`baracuda-cutlass-kernels-sys`] for the underlying compiled
9//! template instantiations.
10//!
11//! ## Quick start
12//!
13//! ```rust,no_run
14//! use baracuda_cutlass::{
15//!     EpilogueKind, GemmArgs, GemmDescriptor, GemmPlan, LayoutSku,
16//!     MatrixMut, MatrixRef, PlanPreference, Workspace,
17//! };
18//! use baracuda_driver::{Context, Device, DeviceBuffer, Stream};
19//! use half::f16;
20//!
21//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
22//! let ctx = Context::new(&Device::get(0)?)?;
23//! let stream = Stream::new(&ctx)?;
24//!
25//! let m = 128i32; let n = 128i32; let k = 128i32;
26//! let dev_a: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, (m * k) as usize)?;
27//! let dev_b: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, (k * n) as usize)?;
28//! let mut dev_d: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, (m * n) as usize)?;
29//!
30//! let desc = GemmDescriptor {
31//!     m, n, k,
32//!     layout: LayoutSku::Rcr,
33//!     epilogue: EpilogueKind::Identity,
34//! };
35//! let plan = GemmPlan::<f16>::select(&stream, &desc, PlanPreference::default())?;
36//! let args = GemmArgs::<f16> {
37//!     a: MatrixRef { data: dev_a.as_slice(), rows: m, cols: k, ld: k as i64 },
38//!     b: MatrixRef { data: dev_b.as_slice(), rows: k, cols: n, ld: k as i64 },
39//!     c: None,
40//!     d: MatrixMut { data: dev_d.as_slice_mut(), rows: m, cols: n, ld: n as i64 },
41//!     bias: None,
42//!     alpha: 1.0,
43//!     beta: 0.0,
44//! };
45//! plan.can_implement(&args)?;
46//! plan.run(&stream, Workspace::None, args)?;
47//! # Ok(()) }
48//! ```
49//!
50//! [`baracuda-cutlass-kernels-sys`]: https://docs.rs/baracuda-cutlass-kernels-sys
51
// Crate-wide lint: a public item without documentation fails the build.
52#![deny(missing_docs)]
53
// Public submodules that make up the crate.
54pub mod error;
55pub mod plan;
56pub mod types;
57
// Re-export the submodules' public items at the crate root so callers can
// import them directly (e.g. `baracuda_cutlass::GemmPlan`, as the quick-start
// example above does) without naming the `error` / `plan` / `types` modules.
58pub use error::{Error, Result};
59pub use plan::{BatchedGemmPlan, GemmPlan, GroupedGemmPlan, IntGemmPlan, PreparedGroupedGemm};
60pub use types::{
61    ActivationKind, ArchSku, BackendKind, BatchedGemmArgs, BatchedGemmDescriptor, BiasElement,
62    BiasElementKind, CutlassElement, ElementKind, EpilogueKind, F32Strict, GemmArgs,
63    GemmDescriptor, GemmSku, GroupedPlanPreference, GroupedProblem, GroupedScheduleMode,
64    IntElement, IntGemmArgs, IntGemmDescriptor, LayoutSku, MathPrecision, MatrixMut, MatrixRef,
65    PlanPreference, PrecisionGuarantee, S8, ScalarType, U8, VectorRef, Workspace,
66};